class DummyTrainer(Trainer):
  '''Worker-side trainer: receives each batch broadcast from the master (rank 0)
  and runs it through the distributed network.'''

  def _finish_init(self):
    self.net = DistFastNet(self.learning_rate, self.image_shape, self.n_out, self.init_model)

  def get_from_master(self, data):
    # Collective receive: every rank must call bcast with the same root.
    data = comm.bcast(data, root=0)
    comm.barrier()
    return data

  def train(self):
    util.log('rank %d starting training...' % rank)
    while self.should_continue_training():
      train_data = self.get_from_master(None)
      self.curr_epoch = train_data.epoch
      input, label = train_data.data, train_data.labels
      self.net.train_batch(input, label)
      self.curr_batch += 1

      if self.check_test_data():
        self.get_test_error()
      if self.factor != 1.0 and self.check_adjust_lr():
        self.adjust_lr()

  def get_test_error(self):
    test_data = self.get_from_master(None)
    input, label = test_data.data, test_data.labels
    self.net.train_batch(input, label, TEST)

  def save_checkpoint(self):
    # Workers only participate in dumping layer state; no file is written here.
    self.net.get_dumped_layers()
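Both trainers rely on the same collective handshake: `get_from_master` on the workers and `scatter_to_worker` on the master (below) issue a matching `bcast` with root 0, so every rank must take part or the job deadlocks. A minimal, self-contained sketch of that handshake with mpi4py; the file name and toy payload are illustrative and not part of this module:

# bcast_handshake.py -- illustrative only; run with: mpiexec -n 2 python bcast_handshake.py
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
  payload = {'data': [1, 2, 3], 'labels': [0, 1, 0]}  # the master owns the real batch
else:
  payload = None                                      # workers pass a placeholder

# Pickle-based collective: the root's object is delivered to every other rank.
payload = comm.bcast(payload, root=0)
comm.barrier()  # keep ranks in lock-step, as get_from_master()/scatter_to_worker() do
print('rank %d received %r' % (rank, payload))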
class ServerTrainer(Trainer):
  '''Master-side trainer (rank 0): reads batches from the data provider,
  broadcasts them to the workers and records train/test statistics.'''

  def _finish_init(self):
    self.net = DistFastNet(self.learning_rate, self.image_shape, self.init_model)

  def reshape_data(self, data):
    # Pull the flat pixel buffer off the GPU and view it as
    # (channels, height, width, cases) for the convolutional layers.
    batch_data = data.data.get()
    num_case = batch_data.size / (self.image_color * self.image_size * self.image_size)
    batch_data = batch_data.reshape((self.image_color, self.image_size, self.image_size, num_case))
    data.data = batch_data

  def scatter_to_worker(self, data):
    # Matched by get_from_master() on every worker rank.
    comm.bcast(data, root=0)
    comm.barrier()

  def train(self):
    self.print_net_summary()
    util.log('Starting training...')
    while self.should_continue_training():
      start = time.time()
      train_data = self.train_dp.get_next_batch()
      # self.train_dp.wait()
      self.reshape_data(train_data)
      self.scatter_to_worker(train_data)

      self.curr_epoch = train_data.epoch
      input, label = train_data.data, train_data.labels
      self.net.train_batch(input, label)
      self.curr_batch += 1

      cost, correct, numCase = self.net.get_batch_information()
      self.train_outputs += [({'logprob': [cost, 1 - correct]}, numCase, time.time() - start)]
      print >> sys.stderr, '%d.%d: error: %f logreg: %f time: %f' % (
          self.curr_epoch, self.curr_batch, 1 - correct, cost, time.time() - start)
      self.num_batch += 1

      if self.check_test_data():
        self.get_test_error()
      if self.factor != 1.0 and self.check_adjust_lr():
        self.adjust_lr()
      if self.check_save_checkpoint():
        self.save_checkpoint()

    self.get_test_error()
    self.save_checkpoint()

  def get_test_error(self):
    start = time.time()
    test_data = self.test_dp.get_next_batch()
    self.reshape_data(test_data)
    self.scatter_to_worker(test_data)

    input, label = test_data.data, test_data.labels
    self.net.train_batch(input, label, TEST)

    cost, correct, numCase = self.net.get_batch_information()
    self.test_outputs += [({'logprob': [cost, 1 - correct]}, numCase, time.time() - start)]
    print >> sys.stderr, 'error: %f logreg: %f time: %f' % (1 - correct, cost, time.time() - start)
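`reshape_data` recovers the case count from the size of the flat pixel buffer the data provider delivers and views the buffer as (channels, height, width, cases) before broadcasting. The same arithmetic, sketched with NumPy under assumed sizes (3-channel 32x32 images, 128 per batch), which are illustrative only:

import numpy as np

image_color, image_size, num_case = 3, 32, 128  # assumed sizes for illustration

# A flat float buffer, as if copied off the GPU with .get().
flat = np.zeros(image_color * image_size * image_size * num_case, dtype=np.float32)

# Same arithmetic as reshape_data(): derive the case count from the buffer size,
# then view the buffer as (channels, height, width, cases).
cases = flat.size // (image_color * image_size * image_size)
batch = flat.reshape((image_color, image_size, image_size, cases))
assert batch.shape == (3, 32, 32, 128)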