1 #ifndef MSHADOW_TENSOR_EXPR_ENGINE_INL_HPP
2 #define MSHADOW_TENSOR_EXPR_ENGINE_INL_HPP
19 template<
typename SubType,
typename SrcExp,
int dim>
25 return *
static_cast<const SubType*
>(
this);
32 template<
typename ExpType>
42 template <
typename Device,
int dim>
48 return dptr_[ y * stride_ + x ];
55 template <
typename Device>
78 template<
typename OP,
typename TA,
typename TB,
int etype>
82 :lhs_(lhs), rhs_(rhs){}
84 return OP::Map( lhs_.Eval( y, x ), rhs_.Eval( y, x ) );
91 template<
typename OP,
typename TA,
int etype>
96 return OP::Map( src_.Eval( y, x ) );
103 template<
typename SubType,
typename SrcExp,
int dim>
108 return src_.Eval( y, x );
115 template<
typename OP,
typename TA,
typename TB,
int etype>
124 inline Plan<T> MakePlan(
const ContainerExp<T> &e ){
125 return Plan<T>( e.self() );
128 template<
typename T,
typename SrcExp,
int dim>
129 inline Plan< T > MakePlan(
const MakeTensorExp<T,SrcExp,dim> &e ){
130 return Plan< T >( e.real_self() );
133 template<
typename OP,
typename TA,
int etype>
134 inline Plan< UnaryMapExp<OP,TA,etype> > MakePlan(
const UnaryMapExp<OP,TA,etype> &e ){
135 return Plan< UnaryMapExp<OP,TA,etype> >( MakePlan(e.src_) );
138 template<
typename OP,
typename TA,
typename TB,
int etype>
139 inline Plan< BinaryMapExp<OP,TA,TB,etype> > MakePlan(
const BinaryMapExp<OP,TA,TB,etype> &e ){
140 return Plan< BinaryMapExp<OP,TA,TB,etype> >( MakePlan(e.lhs_), MakePlan(e.rhs_) );
154 const static int kDim = -1;
155 const static int kDevMask = 0;
159 const static int kDim = 0;
160 const static int kDevMask = 0xffff;
162 template<
typename Device,
int dim>
164 const static int kDim = dim;
165 const static int kDevMask = Device::kDevMask;
167 template<
typename T,
typename SrcExp,
int dim>
170 const static int kDim = kDimSrc >= 0 ? dim : -1;
173 template<
typename OP,
typename TA,
int etype>
178 template<
typename OP,
typename TA,
typename TB,
int etype>
182 const static int kDim = (kDimLhs>=0 && kDimRhs >= 0) ? \
183 ( kDimLhs==0 ? kDimRhs : ( (kDimRhs==0||kDimLhs==kDimRhs) ? kDimLhs : -1 ) ):-1;
188 template<
typename Device,
int dim,
typename E>
206 inline static void Error_All_Tensor_in_Exp_Must_Have_Same_Type(
void ){}
207 inline static void Error_TypeCheck_Not_Pass_For_Reduce_Exp(
void ){}
208 inline static void Error_Expression_Does_Not_Meet_Dimension_Req(
void ){}
214 template<
int dim,
typename E>
227 template<
int dim,
typename Device>
233 template<
int dim,
typename SrcExp,
typename T>
239 template<
int dim,
typename OP,
typename TA,
int etype>
246 template<
int dim,
typename OP,
typename TA,
typename TB,
int etype>
251 if( shape1[0] == 0 )
return shape2;
252 if( shape2[0] == 0 )
return shape1;
253 utils::Assert( shape1 == shape2,
"BinaryMapExp: Shapes of two tensors in BinaryMapExp expression is not the same");
261 template<
typename SV,
typename Device,
int ddim,
int ldim,
int rdim,
bool ltrans,
bool rtrans>
267 template<
typename Device>
270 #if (MSHADOW_USE_CBLAS||MSHADOW_USE_MKL)
273 inline static CBLAS_TRANSPOSE GetT(
bool t ){
274 return t ? CblasTrans : CblasNoTrans;
276 inline static void gemm(
bool transa,
bool transb,
int m,
int n,
int k,
float alpha, \
277 const float *A,
int lda,
const float *B,
int ldb,
float beta,
float *C,
int ldc ){
278 cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb), m,n,k,alpha,A,lda,B,ldb,beta,C,ldc);
280 inline static void gemm(
bool transa,
bool transb,
int m,
int n,
int k,
double alpha, \
281 const double *A,
int lda,
const double *B,
int ldb,
double beta,
double *C,
int ldc ){
282 cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb), m,n,k,alpha,A,lda,B,ldb,beta,C,ldc);
284 inline static void gemv(
bool trans,
int m,
int n,
float alpha,
const float *A,
int lda, \
285 const float *X,
int incX,
float beta,
float *Y,
int incY ){
286 cblas_sgemv(CblasColMajor, GetT(trans), m,n,alpha,A,lda,X,incX,beta,Y,incY);
288 inline static void gemv(
bool trans,
int m,
int n,
double alpha,
const double *A,
int lda, \
289 const double *X,
int incX,
double beta,
double *Y,
int incY ){
290 cblas_dgemv(CblasColMajor, GetT(trans), m,n,alpha,A,lda,X,incX,beta,Y,incY);
292 inline static void ger(
int m,
int n,
float alpha,
const float *X,
int incX,
const float *Y,
int incY,
float *A,
int lda ){
293 cblas_sger(CblasColMajor,m,n,alpha,X,incX,Y,incY,A,lda);
295 inline static void ger(
int m,
int n,
double alpha,
const double *X,
int incX,
const double *Y,
int incY,
double *A,
int lda ){
296 cblas_dger(CblasColMajor,m,n,alpha,X,incX,Y,incY,A,lda);
299 #endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL
305 inline static char GetT(
bool t ){
306 return t ?
'T' :
'N';
308 inline static void gemm(
bool transa,
bool transb,
int m,
int n,
int k,
float alpha,
309 const float *A,
int lda,
const float *B,
int ldb,
float beta,
float *C,
int ldc ){
310 cublasSgemm(GetT(transa),GetT(transb),m,n,k,alpha,A,lda,B,ldb,beta,C,ldc);
312 inline static void gemm(
bool transa,
bool transb,
int m,
int n,
int k,
double alpha,
313 const double *A,
int lda,
const double *B,
int ldb,
double beta,
double *C,
int ldc ){
314 cublasDgemm(GetT(transa),GetT(transb),m,n,k,alpha,A,lda,B,ldb,beta,C,ldc);
316 inline static void gemv(
bool trans,
int m,
int n,
float alpha,
const float *A,
int lda, \
317 const float *X,
int incX,
float beta,
float *Y,
int incY ){
318 cublasSgemv(GetT(trans), m,n,alpha,A,lda,X,incX,beta,Y,incY);
320 inline static void gemv(
bool trans,
int m,
int n,
double alpha,
const double *A,
int lda, \
321 const double *X,
int incX,
double beta,
double *Y,
int incY ){
322 cublasDgemv(GetT(trans), m,n,alpha,A,lda,X,incX,beta,Y,incY);
324 inline static void ger(
int m,
int n,
float alpha,
const float *X,
int incX,
const float *Y,
int incY,
float *A,
int lda ){
325 cublasSger(m,n,alpha,X,incX,Y,incY,A,lda);
327 inline static void ger(
int m,
int n,
double alpha,
const double *X,
int incX,
const double *Y,
int incY,
double *A,
int lda ){
328 cublasDger(m,n,alpha,X,incX,Y,incY,A,lda);
335 return transpose ?
Shape2(shape[0],shape[1]) : shape;
338 template<
typename SV,
typename xpu,
bool transpose_left,
bool transpose_right>
339 struct DotEngine<SV,xpu,2,2,2,transpose_left,transpose_right>{
344 && sleft[0] == sright[1] ,
"dot-gemm: matrix shape mismatch" );
347 ( transpose_right , transpose_left,
348 transpose_right ? rhs.
shape[1] : rhs.
shape[0],
350 transpose_right ? rhs.
shape[0] : rhs.
shape[1],
351 scale * SV::kAlphaBLAS,
358 template<
typename SV,
typename xpu,
bool transpose_right>
365 rhs.
shape[0], rhs.
shape[1], scale * SV::kAlphaBLAS,
367 lhs.
dptr, 1, SV::kBetaBLAS,
371 template<
typename SV,
typename xpu>
375 if( SV::kBetaBLAS < 1e-6f ){
377 ( rhs.
shape[0], lhs.
shape[0], scale * SV::kAlphaBLAS,
389 template<
typename SV,
typename Device,
int dim,
typename E>
393 template<
typename SV,
typename Device,
int dim>
397 MapExp<SV,dim,E>( dst, exp );
401 MapExp<SV,dim,E>( dst, exp );
408 template<
typename SV,
typename Device,
int dim,
int ldim,
int rdim,
bool ltrans,
bool rtrans>
static const bool kMapPass
whether the expression can be mapped to expression of dim
Definition: tensor_expr_engine-inl.hpp:195
unsigned index_t
type that will be used for index
Definition: tensor_base.h:123
Shape< dim > shape_
the shape of this expression
Definition: tensor_expr_engine-inl.hpp:22
Definition: tensor_expr_engine-inl.hpp:67
This part of the code provides the plan that can be used to carry out execution.
Definition: tensor_expr_engine-inl.hpp:33
Definition: tensor_expr_engine-inl.hpp:201
template to do type check
Definition: tensor_expr_engine-inl.hpp:189
binary map expression lhs [op] rhs
Definition: tensor_expr.h:225
void Assert(bool exp)
assert that an expression is true
Definition: tensor_base.h:285
static const bool kRedPass
whether the expression can be reduced to expression of dim
Definition: tensor_expr_engine-inl.hpp:197
const SubType & real_self(void) const
true self of subtype
Definition: tensor_expr_engine-inl.hpp:24
static const int kExpDim
dimension of expression
Definition: tensor_expr_engine-inl.hpp:191
Definition: tensor_expr_engine-inl.hpp:262
float real_t
type that will be used for content
Definition: tensor_base.h:118
const SubType & self(void) const
Definition: tensor_expr.h:52
header file of tensor data structures and functions. Convention: this lib requires explicit memory alloc...
device name CPU
Definition: tensor.h:185
device name GPU
Definition: tensor.h:192
const TA & src_
source expression
Definition: tensor_expr.h:342
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: tensor_expr_engine-inl.hpp:153
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s1, index_t s0)
construct a two-dimensional shape; stride will equal s0
Definition: tensor.h:152
MSHADOW_XINLINE real_t Eval(index_t y, index_t x) const
evaluate the expression at index [y][x]; to be implemented by SubType
Definition: tensor_expr_engine-inl.hpp:215
static const bool kDevPass
whether the expression device type matches
Definition: tensor_expr_engine-inl.hpp:193
an engine that evaluates complex expressions
Definition: tensor_expr_engine-inl.hpp:390
real_t * dptr
pointer to the data
Definition: tensor.h:215
MSHADOW_XINLINE Tensor< Device, 2 > FlatTo2D(void) const
flatten the tensor to 2 dimension, collapse the higher dimensions together
Definition: tensor.h:229
MSHADOW_XINLINE real_t Eval(index_t y, index_t x) const
evaluate at [y][x]
Definition: tensor_expr_engine-inl.hpp:71
unary map expression op(src)
Definition: tensor_expr.h:340
matrix multiplication expression dot( lhs[.T], rhs[.T] )
Definition: tensor_expr.h:172
Shape< dimension > shape
shape of the tensor
Definition: tensor.h:217
scalar expression
Definition: tensor_expr.h:62
base class for expression
Definition: tensor_expr.h:49
const TA & lhs_
left operand
Definition: tensor_expr.h:227
a general class that allows extension that makes tensors of some shape
Definition: tensor_expr_engine-inl.hpp:20
definitions of abstract expressions and expressions template
expression engine that actually interprets these expressions; this is a function template that needs ...
Definition: tensor_expr.h:34
real_t scalar_
scalar value
Definition: tensor_expr.h:64
Definition: tensor_expr_engine-inl.hpp:268
general tensor
Definition: tensor.h:206
const TB & rhs_
right operand
Definition: tensor_expr.h:229