26 #include "stafield_cuda.h"
30 #ifndef STA_FIELD_CLASS_H
31 #define STA_FIELD_CLASS_H
36 #include "sta_error.h"
38 #include "stensor_kernels.h"
40 #ifdef _SUPPORT_MATLAB_
41 #include "sta_mex_helpfunc.h"
45 #include "stensorcuda.h"
57 static int classcount=0;
58 static bool do_classcount=
false;
68 int object_is_dead_soon;
87 static void set_death(
const _stafield * cfield)
90 field->object_is_dead_soon++;
97 std::size_t getNumVoxel()
const {
98 return this->shape[0]*this->shape[1]*this->shape[2];
101 void switchFourierFlag()
117 static bool equalShape(
const _stafield & a,
const _stafield & b)
119 if ((a.shape[0]!=b.shape[0]) ||
120 (a.shape[1]!=b.shape[1]) ||
121 (a.shape[2]!=b.shape[2]) )
137 if ((this->shape[0]!=field.
shape[0]) ||
138 (this->shape[1]!=field.
shape[1]) ||
139 (this->shape[2]!=field.
shape[2]) ||
152 return (!((*
this)==field));
166 this->element_size[0]=T(1);
167 this->element_size[1]=T(1);
168 this->element_size[2]=T(1);
175 this->own_memory=
true;
179 this->classcount_id=classcount;
180 this->object_is_dead_soon=0;
195 return this->own_memory;
198 bool oneBlockMem()
const {
199 return (this->stride==0);
202 std::size_t getStride()
const {
204 return hanalysis::order2numComponents(this->field_storage,this->field_type,this->L);
208 void setElementSize(
const T element_size[])
210 this->element_size[0]=element_size[0];
211 this->element_size[1]=element_size[1];
212 this->element_size[2]=element_size[2];
215 const T * getElementSize()
const
217 return this->element_size;
244 template <
typename T>
255 std::complex<T> * data;
261 if ((this->L==-1)&&(f.
getRank()==-1))
263 if (!f.oneBlockMem())
264 throw (
hanalysis::STAError(
"error copying host memory to device memory, the host memory must be alignd in one single block!"));
269 throw STAError(
"warning: operator= (stride!=0) shared memory block but alocating new (own) memory would be nrequired \n");
271 if (!this->own_memory)
272 throw STAError(
"warning: operator= (!own_memory): shared memory block but alocating new (own) memory would be nrequired \n");
274 if (this->own_memory && (this->data!=NULL))
275 cudaFree(this->data);
277 this->field_storage=f.getStorage();
278 this->field_type=f.getType();
283 this->setElementSize(f.getElementSize());
287 this->own_memory=
true;
289 int numcomponents=hanalysis::order2numComponents(this->field_storage,this->field_type,this->L);
290 this->data=
new std::complex<T>[numcomponents*this->getNumVoxel()];
297 int numcomponents=hanalysis::order2numComponents(this->field_storage,this->field_type,this->L);
298 this->setElementSize(f.getElementSize());
300 T * dest=( T* )(this->
getData());
301 hanalysis_cuda::gpu_memcpy_d2h(src,dest,this->getNumVoxel()*numcomponents*
sizeof(std::complex<T>));
303 throw (
hanalysis::STAError(
"error copying host memory to device memory, the (existing) device memory must be alignd in one single block!"));
313 this->set_death(&result);
336 void operator=(S value)
339 throw (
hanalysis::STAError(
"operator= : field type must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH));
341 int thestride=this->getStride();
342 std::size_t jumpz=this->shape[1]*this->shape[2];
343 int strid_all=jumpz*thestride;
344 int numcomponents=hanalysis::order2numComponents(this->field_storage,this->field_type,this->L);
345 int remaining=thestride-numcomponents;
347 #pragma omp parallel for num_threads(get_numCPUs())
348 for (std::size_t z=0;z<this->shape[0];z++)
350 std::complex<T> * p=this->data+z*strid_all;
351 for (std::size_t a=0;a<jumpz;a++)
353 for (
int a=0;a<numcomponents;a++)
370 if ((this->L==-1)&&(f.
L==-1))
388 if ((f.object_is_dead_soon==1)
397 if (f.object_is_dead_soon>1)
398 throw STAError(
"error: something went wrong with the memory managemant \n");
400 if (this->own_memory && (this->data!=NULL) && (this->object_is_dead_soon<2))
401 delete [] this->data;
409 this->shape[0]=f.
shape[0];
410 this->shape[1]=f.
shape[1];
411 this->shape[2]=f.
shape[2];
413 this->setElementSize(f.getElementSize());
417 this->stride=f.stride;
418 this->own_memory=f.own_memory;
428 throw STAError(
"warning: operator= (stride!=0) shared memory block but allocation of new (own) memory would be required \n");
430 if (!this->own_memory)
431 throw STAError(
"warning: operator= (!own_memory): shared memory block but allocation of new (own) memory would be required \n");
435 if (this->own_memory && (this->data!=NULL))
436 delete [] this->data;
440 this->shape[0]=f.
shape[0];
441 this->shape[1]=f.
shape[1];
442 this->shape[2]=f.
shape[2];
444 this->setElementSize(f.getElementSize());
448 this->own_memory=
true;
450 int numcomponents=hanalysis::order2numComponents(this->field_storage,this->field_type,this->L);
451 int numVoxel=this->shape[0]*this->shape[1]*this->shape[2];
452 this->data=
new std::complex<T>[numcomponents*numVoxel];
456 if ((f.stride==0)&&(this->stride==0))
458 int numcomponents=hanalysis::order2numComponents(this->field_storage,this->field_type,this->L);
459 this->setElementSize(f.getElementSize());
464 memcpy(this->data,f.data,this->shape[0]*this->shape[1]*this->shape[2]*numcomponents*
sizeof(std::complex<T>));
469 this->setElementSize(f.getElementSize());
475 throw STAError(
"operator= this cannot happen ! the input field must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH);
477 int numcomponents_new=this->L+1;
479 numcomponents_new=2*this->L+1;
491 numcomponents_new*=2;
494 int strid_in=2*(f.getStride())-numcomponents_new;
495 int strid_out=2*this->getStride()-numcomponents_new;
496 std::size_t jumpz=this->shape[1]*this->shape[2];
497 int strid_in_all=jumpz*f.getStride();
498 int strid_out_all=jumpz*this->getStride();
500 #pragma omp parallel for num_threads(get_numCPUs())
501 for (std::size_t z=0;z<this->shape[0];z++)
503 T * in=(T*)(f.data+z*strid_in_all);
504 T * out=(T*)(this->data+z*strid_out_all);
505 for (std::size_t a=0;a<jumpz;a++)
507 for (
int b=0;b<numcomponents_new;b++)
563 Mult(*
this,*
this,alpha,
false,
true);
579 Mult(*
this,*
this,T(1)/alpha,
false,
true);
598 stafield result(this->shape,this->L,this->field_storage,this->field_type);
599 Mult(*
this,result,T(1),
false,
true);
600 Mult(a,result,T(1),
false,
false);
620 stafield result(this->shape,this->L,this->field_storage,this->field_type);
621 Mult(*
this,result,T(-1),
false,
true);
622 Mult(a,result,T(-1),
false,
false);
644 stafield result(this->shape,this->L,this->field_storage,this->field_type);
645 Mult(*
this,result,alpha,
false,
true);
661 stafield result(this->shape,this->L,this->field_storage,this->field_type);
662 Mult(*
this,result,T(1)/alpha,
false,
true);
735 stafield(
const stafield & field) : _stafield<T>(), data(NULL)
768 const T element_size[]=NULL)
775 if (element_size!=NULL)
776 this->setElementSize(element_size);
784 this->shape[0]=shape[0];
785 this->shape[1]=shape[1];
786 this->shape[2]=shape[2];
790 int numcomponents=hanalysis::order2numComponents(field_storage,field_type,L);
791 int numVoxel=shape[0]*shape[1]*shape[2];
792 if (hanalysis::verbose>0)
793 printf(
"L: %d , (%i,%i,%i) , // %i\n",L,shape[0],shape[1],shape[2],numcomponents);
794 if (hanalysis::verbose>0)
795 printf(
"allocating %d bytes\n",numcomponents*numVoxel*
sizeof(std::complex<T>));
796 this->data=
new std::complex<T>[numcomponents*numVoxel];
819 const std::size_t shape[],
820 std::vector<T> param_v,
824 const T element_size[]=NULL)
831 if (element_size!=NULL)
832 this->setElementSize(element_size);
838 this->shape[0]=shape[0];
839 this->shape[1]=shape[1];
840 this->shape[2]=shape[2];
842 this->own_memory=
true;
844 int numcomponents=hanalysis::order2numComponents(field_storage,this->field_type,L);
845 int numVoxel=shape[0]*shape[1]*shape[2];
846 if (hanalysis::verbose>0)
847 printf(
"L: %d , (%i,%i,%i) , // %i\n",L,shape[0],shape[1],shape[2],numcomponents);
848 if (hanalysis::verbose>0)
849 printf(
"allocating %d bytes\n",numcomponents*numVoxel*
sizeof(std::complex<T>));
850 this->data=
new std::complex<T>[numcomponents*numVoxel];
852 makeKernel(kernelname,param_v,*
this,centered);
869 std::complex<T> * data,
870 std::size_t stride=0,
871 const T element_size[]=NULL) :
_stafield<T>()
877 if (element_size!=NULL)
878 this->setElementSize(element_size);
886 this->shape[0]=shape[0];
887 this->shape[1]=shape[1];
888 this->shape[2]=shape[2];
890 this->own_memory=
false;
902 #ifdef _SUPPORT_MATLAB_
906 const T element_size[]=NULL)
909 stafieldFromMxArray(field,field_storage,field_type,element_size);
912 stafield(mxArray * field) : _stafield<T>()
924 if (mex_isStaFieldStruct<T>(field))
927 mxArray *storage = mxGetField(field,0,(
char *)
"storage");
928 mxArray *type = mxGetField(field,0,(
char *)
"type");
929 mxArray *data = mxGetField(field,0,(
char *)
"data");
930 mxArray *element_size = mxGetField(field,0,(
char *)
"element_size");
934 element_size_p[0]=element_size_p[1]=element_size_p[2]=1;
936 switch (mxGetClassID(element_size))
941 float * esize_p= (
float*) mxGetData(element_size);
942 element_size_p[0]=esize_p[2];
943 element_size_p[1]=esize_p[1];
944 element_size_p[2]=esize_p[0];
949 double * esize_p= (
double*) mxGetData(element_size);
950 element_size_p[0]=esize_p[2];
951 element_size_p[1]=esize_p[1];
952 element_size_p[2]=esize_p[0];
956 mexErrMsgTxt(
"unsupported data type for element size!\n");
960 stafieldFromMxArray(data,
961 enumfromstring_storage(mex_mex2string(storage)),
962 enumfromstring_type(mex_mex2string(type)),
967 throw (
hanalysis::STAError(
"mxArray contains no valid stafield class nor a valid starfield struct!"));
971 void stafieldFromMxArray(mxArray * field,
974 const T element_size[]=NULL)
976 if (mxGetClassID(field)!=mex_getClassId<T>())
983 if (element_size!=NULL)
984 this->setElementSize(element_size);
989 this->own_memory=
false;
990 const mwSize *dimsFIELD = mxGetDimensions(field);
991 const int numdimFIELD = mxGetNumberOfDimensions(field);
995 throw (
hanalysis::STAError(
"wrong number of components in first dimension (must be 2 -> real/imag)"));
996 int ncomponents=dimsFIELD[1];
998 this->L=numComponents2order(field_storage,field_type,ncomponents);
1000 this->data = (std::complex<T> *) (mxGetData(field));
1001 this->shape[0]=dimsFIELD[2];
1002 this->shape[1]=dimsFIELD[3];
1003 this->shape[2]=dimsFIELD[4];
1004 std::swap(this->shape[0],this->shape[2]);
1005 if (hanalysis::verbose>0)
1006 printf(
"L: %d , (%i,%i,%i) , // %i\n",this->L,this->shape[0],this->shape[1],this->shape[2],ncomponents);
1060 static stafield createFieldAndmxStruct( mxArray * & newMxArray,
1061 const std::size_t shape[],
1065 const T element_size[]=NULL)
1067 int numcomponents=hanalysis::order2numComponents(field_storage,field_type,L);
1071 ndims[1] = numcomponents;
1076 const char *field_names[] = {
"storage",
"type",
"L",
"data",
"shape",
"element_size"};
1079 newMxArray=mxCreateStructArray(1,&dim,6,field_names);
1081 mxArray *theL = mxCreateNumericArray(1,&dim,mxDOUBLE_CLASS,mxREAL);
1082 double *s = (
double*) mxGetData(theL);
1085 mwSize dims[2]={1,3};
1087 mxArray *theSHAPE = mxCreateNumericArray(dim,dims,mxDOUBLE_CLASS,mxREAL);
1088 s = (
double*) mxGetData(theSHAPE);
1094 mxArray *theELEMENT_SIZE = mxCreateNumericArray(dim,dims,mxDOUBLE_CLASS,mxREAL);
1095 s = (
double*) mxGetData(theELEMENT_SIZE);
1096 s[0]=s[1]=s[2]=T(1);
1098 if (element_size!=NULL)
1100 s[0] = element_size[2];
1101 s[1] = element_size[1];
1102 s[2] = element_size[0];
1108 mxSetField(newMxArray,0,
"storage", mxCreateString(enumtostring_storage(field_storage).c_str()));
1109 mxSetField(newMxArray,0,
"type", mxCreateString(enumtostring_type(field_type).c_str()));
1110 mxSetField(newMxArray,0,
"L", theL);
1111 mxSetField(newMxArray,0,
"data", mxCreateNumericArray(5,ndims,mex_getClassId<T>(),mxREAL));
1112 mxSetField(newMxArray,0,
"shape", theSHAPE);
1113 mxSetField(newMxArray,0,
"element_size", theELEMENT_SIZE);
1116 std::complex<T> * datap=(std::complex<T> *) (mxGetData(mxGetField(newMxArray,0,(
char *)
"data")));
1122 stafield stOut(shape,L,field_storage,field_type,datap,0,element_size);
1123 _stafield<T>::set_death(&stOut);
1128 static stafield createFieldAndmxArray( mxArray * & newMxArray,
1129 const std::size_t shape[],
1133 const T element_size[]=NULL)
1135 int numcomponents=hanalysis::order2numComponents(field_storage,field_type,L);
1139 ndims[1] = numcomponents;
1143 newMxArray = mxCreateNumericArray(5,ndims,mex_getClassId<T>(),mxREAL);
1145 stafield stOut(shape,L,field_storage,field_type, (std::complex<T> *) (mxGetData(newMxArray)),element_size);
1146 _stafield<T>::set_death(&stOut);
1158 if (this->own_memory && (this->data!=NULL))
1161 printf(
"destroying stafield %i / remaining: %i [own]",this->classcount_id,--classcount);
1163 if (this->object_is_dead_soon<2)
1166 printf(
" (deleting data)\n");
1168 if (hanalysis::verbose>0)
1169 printf(
"field destrucor -> deleting data\n");
1170 delete [] this->data;
1174 printf(
" (not deleting data, still having references)\n");
1180 printf(
"destroying stafield %i / remaining: %i [empty]\n",this->classcount_id,--classcount);
1182 printf(
"destroying stafield %i / remaining: %i [view]\n",this->classcount_id,--classcount);
1184 if (hanalysis::verbose>0)
1185 printf(
"field destrucor -> --\n");
1199 std::complex<T> sum()
const {
1200 std::complex<T> Sum=0;
1201 int numcomponents=hanalysis::order2numComponents(this->field_storage,this->field_type,this->L);
1202 std::size_t numel=this->getNumVoxel()*numcomponents;
1204 for (std::size_t i=0;i<numel;i++)
1218 if (this->own_memory)
return false;
1223 std::complex<T> * own_mem=NULL;
1226 if (this->stride==0)
1228 int numcomponents=hanalysis::order2numComponents(this->field_storage,this->field_type,this->L);
1229 if (hanalysis::verbose>0)
1230 printf(
"copying field with %i components\n",numcomponents);
1231 if (hanalysis::verbose>0)
1232 printf(
"copying %i bytes\n",this->shape[0]*this->shape[1]*this->shape[2]*numcomponents*
sizeof(std::complex<T>));
1233 own_mem=
new std::complex<T>[this->shape[0]*this->shape[1]*this->shape[2]*numcomponents];
1234 memcpy(own_mem,this->data,this->shape[0]*this->shape[1]*this->shape[2]*numcomponents*
sizeof(std::complex<T>));
1238 int numcomponents=hanalysis::order2numComponents(this->field_storage,this->field_type,this->L);
1239 own_mem=
new std::complex<T>[this->shape[0]*this->shape[1]*this->shape[2]*numcomponents];
1244 int strid_in=2*(this->getStride())-numcomponents;
1245 std::size_t jumpz=this->shape[1]*this->shape[2];
1246 int strid_in_all=jumpz*this->getStride();
1247 int strid_out_all=jumpz*numcomponents/2;
1249 #pragma omp parallel for num_threads(get_numCPUs())
1250 for (std::size_t z=0;z<this->shape[0];z++)
1252 T * in=(T*)(this->data+z*strid_in_all);
1253 T * out=(T*)(own_mem+z*strid_out_all);
1254 for (std::size_t a=0;a<jumpz;a++)
1256 for (
int b=0;b<numcomponents;b++)
1264 std::swap(this->data,own_mem);
1265 this->own_memory=
true;
1266 this->object_is_dead_soon=0;
1296 throw (
hanalysis::STAError(
"error retrieving (sub) field l, l must be >= 0",STA_RESULT_INVALID_TENSOR_RANK));
1310 std::complex<T> * component_data;
1311 int offset=hanalysis::getComponentOffset(this->field_storage,this->field_type,l);
1312 int numcomponents=hanalysis::order2numComponents(this->field_storage,this->field_type,this->L);
1316 component_data=this->data+offset;
1320 this->field_storage,
1323 this->set_death(&view);
1324 view.stride=numcomponents;
1326 view.setElementSize(this->getElementSize());
1334 throw (
hanalysis::STAError(
"error retrieving (sub) field l, l must be >= 0",STA_RESULT_INVALID_TENSOR_RANK));
1350 throw (
hanalysis::STAError(
"error retrieving (sub) field l,m, storage m>0 but not STA_FIELD_STORAGE_C"));
1353 std::complex<T> * component_data;
1354 int offset=hanalysis::getComponentOffset(this->field_storage,this->field_type,l);
1356 int numcomponents=hanalysis::order2numComponents(this->field_storage,this->field_type,this->L);
1358 component_data=this->data+offset;
1362 this->field_storage,
1365 this->set_death(&view);
1366 view.stride=numcomponents;
1368 view.setElementSize(this->getElementSize());
1372 std::complex<T>
get(
int l,
int m,
int Z,
int Y,
int X)
const
1375 stafield tmp=this->
get(l,m);
1376 std::size_t stride=tmp.getStride();
1377 const std::complex<T> * datap=tmp.getDataConst();
1378 std::size_t voxel=X+(this->shape[2]*(Y+this->shape[1]*Z));
1379 if (voxel>this->getNumVoxel())
1382 return (*(datap+voxel*stride));
1390 std::complex<T> set(
int l,
int m,
int Z,
int Y,
int X, std::complex<T> value)
1393 stafield tmp=this->
get(l,m);
1394 std::size_t stride=tmp.getStride();
1395 std::complex<T> * datap=tmp.getData();
1396 std::size_t voxel=X+(this->shape[2]*(Y+this->shape[1]*Z));
1397 if (voxel>this->getNumVoxel())
1400 (*(datap+voxel*stride))=value;
1440 bool conjugate=
false,
1441 std::complex<T> alpha= T( 1 ),
1442 #ifdef _STA_LINK_FFTW
1443 int flag=FFTW_ESTIMATE )
1461 std::size_t ncomponents_in=hanalysis::order2numComponents(stIn.getStorage(),stIn.getType(),stIn.
L);
1462 std::size_t ncomponents_out=hanalysis::order2numComponents(stOut.getStorage(),stOut.getType(),stOut.
L);
1463 if (((stIn.stride!=0)&&(ncomponents_in!=stIn.stride))||((stOut.stride!=0)&&(ncomponents_out!=stOut.stride)))
1465 if ((stIn.data==stOut.data))
1478 throw (
hanalysis::STAError(
"FFT: storage type must be the same",STA_RESULT_STORAGE_MISMATCH));
1483 int stride_in = stIn.getStride();
1484 int stride_out = stOut.getStride();
1485 int ncomponents=hanalysis::order2numComponents(stIn.getStorage(),stIn.getType(),stIn.
L);
1487 if (hanalysis::verbose>0)
1488 printf(
"FFT: stride_in: %i stride_out %i , ncomp: %i\n",stride_in,stride_out,ncomponents);
1499 if (err!=STA_RESULT_SUCCESS)
1502 stOut.setElementSize(stIn.getElementSize());
1518 std::complex<T> alpha= T( 1 ),
1519 bool conjugate=
false,
1520 bool clear_result =
false)
1524 if (stIn.getStorage()!=stOut.getStorage())
1525 throw (
hanalysis::STAError(
"Mult: storage type must be the same",STA_RESULT_STORAGE_MISMATCH));
1526 int stride_in = stIn.getStride();
1527 int stride_out = stOut.getStride();
1528 int ncomponents=hanalysis::order2numComponents(stIn.getStorage(),stIn.getType(),stIn.
L);
1530 if (hanalysis::verbose>0)
1531 printf(
"Mult: stride_in: %i stride_out %i , ncomp: %i\n",stride_in,stride_out,ncomponents);
1533 STA_RESULT err=hanalysis::sta_mult<T,std::complex<T> > (
1544 if (err!=STA_RESULT_SUCCESS)
1547 stOut.setElementSize(stIn.getElementSize());
1551 static void Fspecial(
const stafield & stIn,
1554 bool clear_result =
false)
1558 if (stIn.getStorage()!=stOut.getStorage())
1559 throw (
hanalysis::STAError(
"Fspecial: storage type must be the same",STA_RESULT_STORAGE_MISMATCH));
1560 int stride_in = stIn.getStride();
1561 int stride_out = stOut.getStride();
1562 int ncomponents=hanalysis::order2numComponents(stIn.getStorage(),stIn.getType(),stIn.
L);
1564 if (hanalysis::verbose>0)
1565 printf(
"Fspecial: stride_in: %i stride_out %i , ncomp: %i\n",stride_in,stride_out,ncomponents);
1577 if (err!=STA_RESULT_SUCCESS)
1580 stOut.setElementSize(stIn.getElementSize());
1590 bool clear_result =
false)
1592 if (!stafield::equalShape(stIn,stOut))
1595 throw (
hanalysis::STAError(
"Norm: stOut must be 0-order tensor field!!",STA_RESULT_INVALID_TENSOR_RANK));
1597 throw (
hanalysis::STAError(
"Deriv: first input field type must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH));
1598 if (stOut.getType()!=stIn.getType())
1599 throw (
hanalysis::STAError(
"Deriv: stOut field type must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH));
1601 if (stIn.getStorage()!=stOut.getStorage())
1602 throw (
hanalysis::STAError(
"Norm: storage type must be the same",STA_RESULT_STORAGE_MISMATCH));
1603 int stride_in = stIn.getStride();
1604 int stride_out = stOut.getStride();
1606 if (hanalysis::verbose>0)
1607 printf(
"Norm: stride_in: %i stride_out %i\n",stride_in,stride_out);
1619 if (err!=STA_RESULT_SUCCESS)
1623 stOut.setElementSize(stIn.getElementSize());
1657 bool conjugate=
false,
1658 std::complex<T> alpha= T( 1 ),
1659 bool clear_result =
false,
1663 if (!stafield::equalShape(stIn,stOut))
1665 if (stIn.getStorage()!=stOut.getStorage())
1666 throw (
hanalysis::STAError(
"Deriv: storage type must be the same",STA_RESULT_STORAGE_MISMATCH));
1668 throw (
hanalysis::STAError(
"Deriv: first input field type must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH));
1669 if (stOut.getType()!=stIn.getType())
1670 throw (
hanalysis::STAError(
"Deriv: stOut field type must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH));
1672 throw (
hanalysis::STAError(
"Deriv: stOut rank must be input rank+Jupdown",STA_RESULT_INVALID_TENSOR_RANK));
1674 int stride_in = stIn.getStride();
1675 int stride_out = stOut.getStride();
1677 if (hanalysis::verbose>0)
1678 printf(
"Deriv: stride_in: %i stride_out %i\n",stride_in,stride_out);
1689 stIn.getElementSize(),
1699 if (err!=STA_RESULT_SUCCESS)
1702 stOut.setElementSize(stIn.getElementSize());
1737 bool conjugate=
false,
1738 std::complex<T> alpha= T( 1 ),
1739 bool clear_result =
false)
1741 if (!stafield::equalShape(stIn,stOut))
1743 if (stIn.getStorage()!=stOut.getStorage())
1747 if (stOut.getType()!=stIn.getType())
1752 int stride_in = stIn.getStride();
1753 int stride_out = stOut.getStride();
1755 if (hanalysis::verbose>0)
1756 printf(
"Deriv2: stride_in: %i stride_out %i\n",stride_in,stride_out);
1767 stIn.getElementSize(),
1772 if (err!=STA_RESULT_SUCCESS)
1775 stOut.setElementSize(stIn.getElementSize());
1779 static bool Is_nan(
const stafield & stIn)
1782 int stride_in = stIn.getStride();
1783 int ncomponents=hanalysis::order2numComponents(stIn.getStorage(),stIn.getType(),stIn.
L);
1785 return hanalysis::sta_isnan (
1791 static bool Is_inf(
const stafield & stIn)
1794 int stride_in = stIn.getStride();
1795 int ncomponents=hanalysis::order2numComponents(stIn.getStorage(),stIn.getType(),stIn.L);
1797 return hanalysis::sta_isinf (
1798 stIn.getDataConst(),
1818 std::complex<T> alpha= T( 1 ),
1819 bool clear_result =
false,
1822 if (stIn.getStorage()!=stOut.getStorage())
1823 throw (
hanalysis::STAError(
"Lap: storage type must be the same",STA_RESULT_STORAGE_MISMATCH));
1826 int stride_in = stIn.getStride();
1827 int stride_out = stOut.getStride();
1828 int ncomponents=hanalysis::order2numComponents(stIn.getStorage(),stIn.getType(),stIn.
L);
1831 if (hanalysis::verbose>0)
1832 printf(
"Lap: stride_in: %i stride_out %i , ncomp: %i\n",stride_in,stride_out,ncomponents);
1842 stIn.getElementSize(),
1847 if (err!=STA_RESULT_SUCCESS)
1850 stOut.setElementSize(stIn.getElementSize());
1880 bool normalize=
false,
1881 std::complex<T> alpha= T( 1 ),
1882 bool clear_result =
false)
1892 throw (
hanalysis::STAError(
"Prod: ensure that |l1-l2|<=J && |l1+l2|<=J",STA_RESULT_INVALID_PRODUCT));
1893 if ( ( ( stIn1.
getRank()+stIn2.
getRank()+J ) %2!=0 ) && ( normalize ) )
1897 throw (
hanalysis::STAError(
"Prod: stOut has wrong Rank!",STA_RESULT_INVALID_TENSOR_RANK));
1900 if ((!stafield::equalShape(stIn1,stOut))||(!stafield::equalShape(stIn2,stOut)))
1902 if ((stIn1.getStorage()!=stOut.getStorage())||(stIn2.getStorage()!=stOut.getStorage()))
1903 throw (
hanalysis::STAError(
"Prod: storage type must be the same",STA_RESULT_STORAGE_MISMATCH));
1905 throw (
hanalysis::STAError(
"Prod: first input field type must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH));
1906 if (stIn1.getType()!=stIn2.getType())
1907 throw (
hanalysis::STAError(
"Prod: first input field type must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH));
1908 if (stOut.getType()!=stIn1.getType())
1909 throw (
hanalysis::STAError(
"Prod: stOut field type must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH));
1911 int stride_in1 = stIn1.getStride();
1912 int stride_in2 = stIn2.getStride();
1913 int stride_out = stOut.getStride();
1915 if (hanalysis::verbose>0)
1916 printf(
"Prod: stride_in1: %i,stride_in2: %i stride_out %i\n",stride_in1,stride_in2,stride_out);
1937 if (err!=STA_RESULT_SUCCESS)
1943 stOut.setElementSize(stIn1.getElementSize());
1947 static void Prod3(
const stafield & stIn1,
1953 bool normalize=
false,
1954 std::complex<T> alpha= T( 1 ),
1955 bool clear_result =
false)
1959 throw (
hanalysis::STAError(
"Prod: ensure that |l1-l2|<=J && |l1+l2|<=J",STA_RESULT_INVALID_PRODUCT));
1960 if ( ( ( stIn1.
getRank()+stIn2.
getRank()+Jprod1 ) %2!=0 ) && ( normalize ) )
1963 if ( ( std::abs ( stIn3.
getRank()-Jprod1 ) >Jprod2 ) ||
1964 ( Jprod1>std::abs ( Jprod1+stIn3.
getRank() ) ) )
1965 throw (
hanalysis::STAError(
"Prod: ensure that |l1-l2|<=J && |l1+l2|<=J",STA_RESULT_INVALID_PRODUCT));
1966 if ( ( ( stIn3.
getRank()+Jprod2+Jprod1 ) %2!=0 ) && ( normalize ) )
1970 throw (
hanalysis::STAError(
"Prod: stOut has wrong Rank!",STA_RESULT_INVALID_TENSOR_RANK));
1973 if ((!stafield::equalShape(stIn1,stOut))||(!stafield::equalShape(stIn2,stOut))||(!stafield::equalShape(stIn3,stOut)))
1975 if ((stIn1.getStorage()!=stOut.getStorage())||(stIn2.getStorage()!=stOut.getStorage())||(stIn3.getStorage()!=stOut.getStorage()))
1976 throw (
hanalysis::STAError(
"Prod: storage type must be the same",STA_RESULT_STORAGE_MISMATCH));
1978 throw (
hanalysis::STAError(
"Prod: first input field type must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH));
1979 if (stIn1.getType()!=stIn2.getType())
1980 throw (
hanalysis::STAError(
"Prod: first input field type must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH));
1981 if (stOut.getType()!=stIn1.getType())
1982 throw (
hanalysis::STAError(
"Prod: stOut field type must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH));
1983 if (stOut.getType()!=stIn3.getType())
1984 throw (
hanalysis::STAError(
"Prod: stOut field type must be STA_OFIELD_SINGLE",STA_RESULT_OFIELD_TYPE_MISMATCH));
1986 int stride_in1 = stIn1.getStride();
1987 int stride_in2 = stIn2.getStride();
1988 int stride_in3 = stIn3.getStride();
1989 int stride_out = stOut.getStride();
1991 if (hanalysis::verbose>0)
1992 printf(
"Prod: stride_in1: %i,stride_in2: %i stride_out %i\n",stride_in1,stride_in2,stride_out);
2017 if (err!=STA_RESULT_SUCCESS)
2023 stOut.setElementSize(stIn1.getElementSize());
2028 static std::vector<T> kernel_param(std::string params)
2030 for (std::size_t a=0;a<params.length();a++)
2033 std::vector<T> param_v;
2034 std::stringstream param_s;
2037 while (!param_s.eof() && param_s.good())
2041 param_v.push_back(tmp);
2044 throw STAError(
"ahhh, something went completely wrong while parsing parameters! ");
2050 static void makeKernel(std::string kernelname,
2051 std::vector<T> param_v,
2053 bool centered=
false)
2056 throw (
hanalysis::STAError(
"makeKernel: first input field type must be STA_OFIELD_SINGLE"));
2059 hanalysis::STA_CONVOLUTION_KERNELS kernel=hanalysis::STA_CONV_KERNEL_UNKNOWN;
2061 if (kernelname==
"gauss")
2062 kernel=hanalysis::STA_CONV_KERNEL_GAUSS;
2063 if (kernelname==
"gaussLaguerre")
2064 kernel=hanalysis::STA_CONV_KERNEL_GAUSS_LAGUERRE;
2065 if (kernelname==
"gaussBessel")
2066 kernel=hanalysis::STA_CONV_KERNEL_GAUSS_BESSEL;
2067 if (kernelname==
"sh")
2068 kernel=hanalysis::STA_CONV_KERNEL_SH;
2069 if (kernelname==
"fourier")
2070 kernel=hanalysis::STA_CONV_KERNEL_FOURIER;
2074 case hanalysis::STA_CONV_KERNEL_FOURIER:
2076 if (!(param_v.size()==2))
2077 throw STAError(
"error: wrong numger of parameters \n");
2083 case hanalysis::STA_CONV_KERNEL_SH:
2085 if (!(param_v.size()==2))
2086 throw STAError(
"error: wrong numger of parameters \n");
2092 case hanalysis::STA_CONV_KERNEL_GAUSS:
2094 if (!(param_v.size()==1))
2095 throw STAError(
"error: wrong numger of parameters");
2100 case hanalysis::STA_CONV_KERNEL_GAUSS_LAGUERRE:
2102 if (!(param_v.size()==2))
2103 throw STAError(
"error: wrong numger of parameters");
2109 case hanalysis::STA_CONV_KERNEL_GAUSS_BESSEL:
2111 if (!(param_v.size()==3))
2112 throw STAError(
"error: wrong numger of parameters");
2120 throw STAError(
"error: unsupported kernel \n");
2125 range=stIn.getRank();
2129 for (
int m=-stIn.getRank();m<=range;m++)
2131 hanalysis::renderKernel(
2139 stIn.getElementSize(),
2145 delete currentKernel;
2157 bool conjugate=
false,
2158 std::complex<T> alpha= T( 1 ),
2159 #ifdef _STA_LINK_FFTW
2160 int flag=FFTW_ESTIMATE )
2167 stafield result(this->shape,this->L,this->field_storage,this->field_type);
2168 this->set_death(&result);
2169 FFT(*
this,result,forward,conjugate,alpha,flag);
2181 #ifdef _STA_LINK_FFTW
2182 int flag=FFTW_ESTIMATE )
2189 return (*this).
fft(
true,
false,T(1),flag).
prod(b.
fft(
true,
true,T(1),flag),J,
true).
fft(
false,
false,T(1)/T(b.getNumVoxel()),flag);;
2201 bool conjugate=
false)
const
2206 stafield result(this->shape,this->L,this->field_storage,this->field_type);
2207 this->set_death(&result);
2208 Mult(*
this,result,alpha,conjugate,
true);
2224 stafield result(this->shape,this->L,this->field_storage,this->field_type);
2225 this->set_death(&result);
2226 Fspecial(*
this,result,fsp,
true);
2242 stafield result(this->shape,0,this->field_storage,this->field_type);
2244 this->set_death(&result);
2245 Norm(*
this,result,
true);
2257 bool conjugate=
false,
2258 std::complex<T> alpha= T( 1 ),
2266 stafield result(this->shape,this->L+J,this->field_storage,this->field_type);
2268 this->set_death(&result);
2270 Deriv(*
this,result,J,conjugate,alpha,
true,accuracy);
2285 bool conjugate=
false,
2286 std::complex<T> alpha= T( 1 ),
2293 stafield result(this->shape,this->L+J,this->field_storage,this->field_type);
2294 this->set_death(&result);
2295 Deriv2(*
this,result,J,conjugate,alpha,
true);
2311 stafield result(this->shape,this->L,this->field_storage,this->field_type);
2312 this->set_death(&result);
2313 Lap(*
this,result,alpha,
true,type);
2328 bool normalize=
false,
2329 std::complex<T> alpha= T( 1 ))
const
2333 stafield result(this->shape,J,this->field_storage,this->field_type);
2334 this->set_death(&result);
2337 Prod(*
this,b,result,J,normalize,alpha,
true);
2364 std::complex<T> fspecial(
const std::complex<T> & value)
const
2366 return std::exp(v*value);
2373 std::complex<T> fspecial(
const std::complex<T> & value)
const
2375 return std::sqrt(value);
2391 std::complex<T> fspecial(
const std::complex<T> & value)
const
2393 return std::pow(value,v);
2404 v=std::numeric_limits<T>::epsilon();
2409 std::complex<T> fspecial(
const std::complex<T> & value)
const
2411 return T(1)/(value+v);
2423 stafield result(this->shape,this->L,this->field_storage,this->field_type);
2424 this->set_death(&result);
2425 Fspecial(*
this,result,Exp,
true);
2442 stafield result(this->shape,this->L,this->field_storage,this->field_type);
2443 this->set_death(&result);
2444 Fspecial(*
this,result,Sqrt,
true);
2459 stafield result(this->shape,this->L,this->field_storage,this->field_type);
2460 this->set_death(&result);
2461 Fspecial(*
this,result,Pow,
true);
2476 stafield result(this->shape,this->L,this->field_storage,this->field_type);
2477 this->set_death(&result);
2478 Fspecial(*
this,result,Inv,
true);
2498 #define STA_FIELD hanalysis::stafieldGPU
2500 #define STA_FIELD hanalysis::stafield
STA_FIELD_STORAGE
tensor field data storage
Definition: stensor.h:5163
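tensor field has all components of even ranks: $(\mathbf{f}^0,\mathbf{f}^2,\ldots,\mathbf{f}^L) \in \mathcal{T}_0 \oplus \mathcal{T}_2 \oplus \cdots \oplus \mathcal{T}_L$ ($L$ even)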
Definition: stensor.h:5184
stafield exp(std::complex< T > value=T(1)) const
exp, component by component
Definition: stafield.h:2418
stafield deriv(int J, bool conjugate=false, std::complex< T > alpha=T(1), int accuracy=0) const
see Deriv
Definition: stafield.h:2256
stafield(const std::size_t shape[], int L, hanalysis::STA_FIELD_STORAGE field_storage, hanalysis::STA_FIELD_TYPE field_type, std::complex< T > *data, std::size_t stride=0, const T element_size[]=NULL)
Definition: stafield.h:865
STA_RESULT sta_derivatives(const std::complex< T > *stIn, std::complex< T > *stOut, const std::size_t shape[], int J, int Jupdown, bool conjugate=false, std::complex< T > alpha=(T) 1.0, STA_FIELD_STORAGE field_storage=STA_FIELD_STORAGE_C, const T v_size[]=NULL, int stride_in=-1, int stride_out=-1, bool clear_field=false, int accuracy=0)
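spherical tensor derivative: $\alpha\,(\nabla \bullet_{(J+\mathrm{Jupdown})} \mathbf{stIn}) \in \mathcal{T}_{J+\mathrm{Jupdown}}$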
Definition: stensor.h:5864
Definition: stafield.h:2353
bool operator!=(const _stafield &field) const
Definition: stafield.h:150
Definition: stensor_kernels.h:70
represents spherical tensor fields (CPU version)
Definition: stafield.h:251
stafield lap(std::complex< T > alpha=T(1), int type=1) const
see Lap
Definition: stafield.h:2305
hanalysis::STA_FIELD_STORAGE field_storage
must be either STA_FIELD_STORAGE_C, STA_FIELD_STORAGE_R or STA_FIELD_STORAGE_RF
Definition: stafield.h:78
const stafield operator-(const stafield &a) const
Definition: stafield.h:613
static void Deriv(const stafield &stIn, stafield &stOut, int Jupdown, bool conjugate=false, std::complex< T > alpha=T(1), bool clear_result=false, int accuracy=0)
spherical tensor derivative:
Definition: stafield.h:1654
static void Prod(const stafield &stIn1, const stafield &stIn2, stafield &stOut, int J, bool normalize=false, std::complex< T > alpha=T(1), bool clear_result=false)
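spherical tensor product: $\alpha(\mathbf{stIn1} \circ_{J} \mathbf{stIn2})$ and $\alpha(\mathbf{stIn1} \bullet_{J} \mathbf{stIn2})$, respectively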
Definition: stafield.h:1876
the STA error class
Definition: sta_error.h:68
stafield pow(T v=T(1)) const
pow, component by component
Definition: stafield.h:2454
Definition: stensor_kernels.h:69
static void Lap(const stafield &stIn, stafield &stOut, std::complex< T > alpha=T(1), bool clear_result=false, int type=1)
Laplacian: $\alpha\,\Delta(\mathbf{stIn})$.
Definition: stafield.h:1816
const stafield operator*(std::complex< T > alpha) const
Definition: stafield.h:640
stafield(std::string kernelname, const std::size_t shape[], std::vector< T > param_v, bool centered=false, int L=0, hanalysis::STA_FIELD_STORAGE field_storage=hanalysis::STA_FIELD_STORAGE_R, const T element_size[]=NULL)
Definition: stafield.h:818
represents spherical tensor fields (GPU version)
Definition: stafield_cuda.h:44
stafield operator[](int l) const
Definition: stafield.h:695
const std::size_t * getShape() const
Definition: stafield.h:187
tensor field has one single component of rank :
Definition: stensor.h:5180
STA_FIELD_TYPE
tensor field data interpretations according to certain symmetries
Definition: stensor.h:5177
int L
tensor rank
Definition: stafield.h:82
STA_RESULT
function return value
Definition: stensor.h:124
const std::complex< T > * getDataConst() const
Definition: stafield.h:1281
std::size_t shape[3]
image shape
Definition: stafield.h:74
Definition: stensor.h:6190
stafield mult(std::complex< T > alpha=T(1), bool conjugate=false) const
see Mult
Definition: stafield.h:2200
stafield & operator+=(const stafield &a)
Definition: stafield.h:523
static void Deriv2(const stafield &stIn, stafield &stOut, int Jupdown, bool conjugate=false, std::complex< T > alpha=T(1), bool clear_result=false)
spherical tensor double-derivative:
Definition: stafield.h:1734
stafield & operator*=(std::complex< T > alpha)
Definition: stafield.h:559
Definition: stafield.h:2370
stafield fft(bool forward, bool conjugate=false, std::complex< T > alpha=T(1), int flag=0) const
see FFT
Definition: stafield.h:2156
static void Mult(const stafield &stIn, stafield &stOut, std::complex< T > alpha=T(1), bool conjugate=false, bool clear_result=false)
computes $\alpha\,\mathbf{stIn}$
Definition: stafield.h:1516
Definition: stensor_kernels.h:67
Definition: stensor.h:5173
Definition: stensor_kernels.h:68
hanalysis::STA_FIELD_TYPE field_type
must be either STA_OFIELD_SINGLE, STA_OFIELD_FULL, STA_OFIELD_EVEN or STA_OFIELD_ODD ...
Definition: stafield.h:80
Definition: stafield.h:2380
Definition: stafield.h:2398
STA_RESULT sta_product(const std::complex< T > *stIn1, const std::complex< T > *stIn2, std::complex< T > *stOut, const std::size_t shape[], int J1, int J2, int J, std::complex< T > alpha=T(1), bool normalize=false, STA_FIELD_STORAGE field_storage=STA_FIELD_STORAGE_C, int stride_in1=-1, int stride_in2=-1, int stride_out=-1, bool clear_field=false)
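spherical tensor product: $\alpha(\mathbf{stIn1} \circ_{J} \mathbf{stIn2})$ and $\alpha(\mathbf{stIn1} \bullet_{J} \mathbf{stIn2})$, respectively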
Definition: stensor.h:5528
bool ownMemory() const
Definition: stafield.h:194
Definition: stensor.h:5167
STA_RESULT sta_fft(const std::complex< T > *stIn, std::complex< T > *stOut, const std::size_t shape[], int components, bool forward, bool conjugate=false, S alpha=(S) 1, int flag=0)
tensor fft component by component
Definition: stensor.h:6140
stafield deriv2(int J, bool conjugate=false, std::complex< T > alpha=T(1), int accuracy=0) const
see Deriv2
Definition: stafield.h:2284
stafield(const std::size_t shape[], int L, hanalysis::STA_FIELD_STORAGE field_storage=hanalysis::STA_FIELD_STORAGE_R, hanalysis::STA_FIELD_TYPE field_type=hanalysis::STA_OFIELD_SINGLE, const T element_size[]=NULL)
Definition: stafield.h:764
stafield & operator/=(std::complex< T > alpha)
Definition: stafield.h:575
const T * getDataConst() const
Definition: stafield_cuda.h:62
Definition: stensor_kernels.h:71
int getRank() const
Definition: stafield.h:236
The STA-ImageAnalysisToolkit namespace.
Definition: stafield.h:55
stafield sqrt() const
sqrt, component by component
Definition: stafield.h:2437
tensor field has all components of odd ranks :
Definition: stensor.h:5186
stafield norm() const
see Norm
Definition: stafield.h:2237
Definition: stensor_kernels.h:49
stafield & operator=(const stafield &f)
Definition: stafield.h:363
STA_RESULT sta_derivatives2(const std::complex< T > *stIn, std::complex< T > *stOut, const std::size_t shape[], int J, int Jupdown, bool conjugate=false, std::complex< T > alpha=(T) 1.0, STA_FIELD_STORAGE field_storage=STA_FIELD_STORAGE_C, const T v_size[]=NULL, int stride_in=-1, int stride_out=-1, bool clear_field=false)
spherical tensor double-derivative:
Definition: stensor.h:5991
bool createMemCopy()
Definition: stafield.h:1216
stafield & operator-=(const stafield &a)
Definition: stafield.h:541
std::complex< T > * getData()
Definition: stafield.h:1274
stafield prod(const stafield &b, int J, bool normalize=false, std::complex< T > alpha=T(1)) const
see Prod
Definition: stafield.h:2326
static void FFT(const stafield &stIn, stafield &stOut, bool forward, bool conjugate=false, std::complex< T > alpha=T(1), int flag=0)
tensor fft component by component
Definition: stafield.h:1437
static void Norm(const stafield &stIn, stafield &stOut, bool clear_result=false)
returns lengths of vectors component by component
Definition: stafield.h:1588
const stafield operator/(std::complex< T > alpha) const
Definition: stafield.h:657
stafield convolve(stafield &b, int J=0, int flag=0)
see FFT
Definition: stafield.h:2179
STA_RESULT sta_laplace(const std::complex< T > *stIn, std::complex< T > *stOut, const std::size_t shape[], int components=1, int type=1, std::complex< T > alpha=1, STA_FIELD_STORAGE field_storage=STA_FIELD_STORAGE_C, const T v_size[]=NULL, int stride_in=-1, int stride_out=-1, bool clear_field=false)
Laplacian: $\alpha\,\Delta(\mathbf{stIn})$.
Definition: stensor.h:6078
const stafield operator+(const stafield &a) const
Definition: stafield.h:591
represents spherical tensor fields
Definition: stafield.h:62
bool operator==(const _stafield &field) const
Definition: stafield.h:135
stafield invert(T v=std::numeric_limits< T >::epsilon()) const
invert, component by component
Definition: stafield.h:2471
Definition: stensor.h:5170