Apache SINGA
A distributed deep learning platform.
 All Classes Namespaces Files Functions Variables Typedefs Enumerator Macros
worker.h
1 #ifndef INCLUDE_TRAINER_WORKER_H_
2 #define INCLUDE_TRAINER_WORKER_H_
3 #include <map>
4 #include <exception>
5 #include "neuralnet/neuralnet.h"
6 #include "proto/model.pb.h"
7 #include "utils/cluster.h"
8 #include "communication/socket.h"
9 #include "communication/msg.h"
10 
11 namespace singa {
/// Collect sleep time, expressed in milliseconds.
constexpr int kCollectSleepTime = 5;
18 class Worker {
19  public:
20  Worker(int thread_id, int group_id, int worker_id);
21  ~Worker(){}
22  void Setup(const ModelProto& model, shared_ptr<NeuralNet> train_net);
23  void set_test_net(shared_ptr<NeuralNet> test_net){
24  test_net_=test_net;
25  }
26  void set_validation_net(shared_ptr<NeuralNet> val_net){
27  validation_net_=val_net;
28  }
29 
30 
31  void Stop();
32  int Put(shared_ptr<Param> param, int step);
33  int Get(shared_ptr<Param> param, int step);
34  int Update(shared_ptr<Param> param, int step);
35  int Collect(shared_ptr<Param> param, int step);
36  int CollectAll(shared_ptr<NeuralNet> net, int step);
43  void RunOneBatch(int step, Metric* perf=nullptr);
48  virtual void TrainOneBatch(int step)=0;
52  virtual void TestOneBatch(shared_ptr<NeuralNet> net, int step, Phase phase)=0;
59  void Test(shared_ptr<NeuralNet> net, int nsteps, const string &prefix);
60 
66  virtual void Run();
67 
68 
77  const bool DisplayNow(const int step) const {
78  return (modelproto_.display_frequency() > 0
79  && step >= modelproto_.display_after_steps()
80  && ((step - modelproto_.display_after_steps())
81  % modelproto_.display_frequency() == 0));
82  }
83 
84  const bool DisplayDebugInfo(const int step) const {
85  return DisplayNow(step)&&modelproto_.debug()&&group_id_==0;
86  }
87  const void DisplayPerformance(const Metric & perf, const string& prefix);
88 
93  const bool StopNow(const int step) const{
94  return (step >= modelproto_.train_steps());
95  }
100  const bool CheckpointNow(const int step) const{
101  return (group_id_==0
102  && modelproto_.checkpoint_frequency() > 0
103  && step >= modelproto_.checkpoint_after_steps()
104  && ((step - modelproto_.checkpoint_after_steps())
105  % modelproto_.checkpoint_frequency() == 0));
106  }
111  const bool TestNow(const int step) const{
112  return (group_id_==0
113  && modelproto_.test_frequency() > 0
114  && modelproto_.test_steps() > 0
115  && step >= modelproto_.test_after_steps()
116  && ((step - modelproto_.test_after_steps())
117  % modelproto_.test_frequency() == 0));
118  }
123  const bool ValidateNow(const int step) {
124  return (group_id_==0
125  && modelproto_.validation_frequency() > 0
126  && modelproto_.validation_steps() > 0
127  && step >= modelproto_.validation_after_steps()
128  && ((step - modelproto_.validation_after_steps())
129  % modelproto_.validation_frequency() == 0));
130  }
131 
132 
142  void ReceiveBlobs(shared_ptr<NeuralNet> net);
143  void SendBlob();
144  protected:
145  int thread_id_, group_id_, worker_id_;
146  int step_;
147  ModelProto modelproto_;
148  shared_ptr<NeuralNet> train_net_, test_net_, validation_net_;
149  shared_ptr<Dealer> layer_dealer_, param_dealer_;
150  Poller layer_poller_, param_poller_;
151 };
152 
153 class BPWorker: public Worker{
154  public:
155  ~BPWorker(){}
156  BPWorker(int thread_id, int group_id, int worker_id):Worker(thread_id, group_id, worker_id){}
157  virtual void TrainOneBatch(int step);
158  virtual void TestOneBatch(shared_ptr<NeuralNet> net, int step, Phase phase);
159  void Forward(shared_ptr<NeuralNet> net, int step, bool training);
160  void Backward(shared_ptr<NeuralNet> net, int step);
188 };
189 } // namespace singa
190 
191 #endif // INCLUDE_TRAINER_WORKER_H_
virtual void Run()
Main function of Worker.
The Worker class which runs the training algorithm.
Definition: worker.h:18
Definition: model.pb.h:316
virtual void TrainOneBatch(int step)=0
Train one mini-batch.
Definition: worker.h:153
const bool StopNow(const int step) const
Return true if the stop condition is satisfied, e.g., the maximum number of steps has been reached...
Definition: worker.h:93
void Test(shared_ptr< NeuralNet > net, int nsteps, const string &prefix)
Test the performance of the learned model on the validation or test dataset.
virtual void TestOneBatch(shared_ptr< NeuralNet > net, int step, Phase phase)
Test/validate one mini-batch.
const bool ValidateNow(const int step)
Check whether it is time to do validation.
Definition: worker.h:123
const bool CheckpointNow(const int step) const
Check whether it is time to do a checkpoint.
Definition: worker.h:100
void ReceiveBlobs(shared_ptr< NeuralNet > net)
start training from scratch.
virtual void TrainOneBatch(int step)
Train one mini-batch.
Definition: socket.h:56
void RunOneBatch(int step, Metric *perf=nullptr)
Check validation/test first, then run TrainOneBatch. Performance collects performance for the whole neur...
Definition: common.h:49
const bool DisplayNow(const int step) const
Check whether it is time to display information for the given step.
Definition: worker.h:77
virtual void TestOneBatch(shared_ptr< NeuralNet > net, int step, Phase phase)=0
Test/validate one mini-batch.
const bool TestNow(const int step) const
Check whether it is time to do testing.
Definition: worker.h:111