Apache SINGA
A distributed deep learning platform.
 All Classes Namespaces Files Functions Variables Typedefs Enumerator Macros
param.h
1 #ifndef INCLUDE_UTILS_PARAM_H_
2 #define INCLUDE_UTILS_PARAM_H_
#include <algorithm>
#include <functional>
#include <map>
#include <string>
#include <vector>
#include "proto/model.pb.h"
#include "utils/blob.h"
#include "communication/msg.h"
10 // Base parameter class.
11 namespace singa {
12 class Param {
13  public:
14  Param():data_(nullptr){}
15  virtual ~Param(){};
16 
17  virtual Msg* GenGetMsg(void* arg=nullptr);
18  virtual Msg* GenPutMsg(void* arg=nullptr);
19  virtual Msg* GenUpdateMsg(void* arg=nullptr);
20  virtual Msg* GenSyncMsg(void* arg=nullptr);
21 
22  virtual Msg* HandleGetMsg(Msg** msg);
23  virtual Msg* HandlePutMsg(Msg** msg);
24  virtual int ParseUpdateMsg(Msg** msg);
25  virtual Msg* GenUpdateResponseMsg(void* arg=nullptr);
26  virtual Msg* HandleSyncMsg(Msg** msg);
27 
28  virtual int ParseGetResponseMsg(Msg** msg);
29  virtual int ParsePutResponseMsg(Msg** msg);
30  virtual int ParseUpdateResponseMsg(Msg** msg);
31  virtual int ParseSyncResponseMsg(Msg** msg);
32 
36  virtual void Setup(const ParamProto& proto, const std::vector<int>& shape, int fan_in);
37  /*
38  * fill the data according to initmethod, i.e., random/gaussian/fixed value
39  */
40  virtual void Init(int v=0);
41  void ShareData(shared_ptr<Param> other){
42  proto_.set_owner(other->owner());
43  if(data_!=nullptr)
44  CHECK(std::equal(data_->shape().begin(), data_->shape().end(),
45  other->data_->shape().begin()));
46  data_=other->data_;
47  }
48  float learning_rate_multiplier() {
49  return proto_.learning_rate_multiplier();
50  }
51  float weight_decay_multiplier() {
52  return proto_.weight_decay_multiplier();
53  }
54  /*
55  const int split_threshold(){
56  return proto_.split_threshold();
57  }
58  */
59  const std::string& name() {
60  return proto_.name();
61  }
66  const int owner() const{
67  return proto_.owner();
68  }
69  int id() const{
70  return proto_.id();
71  }
72  void set_id(int id){
73  proto_.set_id(id);
74  proto_.set_owner(id);
75  }
76 
77  int version() const {
78  return data_->version(); // TODO store version in data blob
79  }
80  void set_version(int v) {
81  data_->set_version(v); // TODO read version from data blob
82  }
86  int size() const {
87  return data_->count();
88  }
92  const Blob<float> &data() {
93  return *data_;
94  }
95  Blob<float> *mutable_data() {
96  return data_.get();
97  }
101  const Blob<float> &grad() {
102  return grad_;
103  }
104  Blob<float> *mutable_grad() {
105  return &grad_;
106  }
107 
108  const Blob<float> &history() {
109  return history_;
110  }
111  Blob<float> *mutable_history() {
112  return &history_;
113  }
114 
115  float* mutable_cpu_data(){
116  return data_->mutable_cpu_data();
117  }
118  float* mutable_cpu_grad(){
119  return grad_.mutable_cpu_data();
120  }
121  float* mutable_cpu_history(){
122  return history_.mutable_cpu_data();
123  }
124  protected:
128  std::string name_;
129  shared_ptr<Blob<float>> data_;
131  Blob<float> grad_, history_;
132  ParamProto proto_;
133  int fan_in_;
134 };
170 } // namespace singa
171 
172 #endif // INCLUDE_UTILS_PARAM_H_
std::string name_
name of the parameter used to share weights between neural nets
Definition: param.h:128
const int owner() const
if the Param shares data with others, then owner is the id of that param.
Definition: param.h:66
Definition: param.h:12
Definition: msg.h:59
Blob< float > grad_
gradient, history gradient of this parameter
Definition: param.h:131
const Blob< float > & data()
Return const mem address for the content of this parameter.
Definition: param.h:92
virtual void Setup(const ParamProto &proto, const std::vector< int > &shape, int fan_in)
setup param shape
const Blob< float > & grad()
Return gradient of this parameter.
Definition: param.h:101
Definition: model.pb.h:764
int size() const
Definition: param.h:86