Example #1
class DummyTrainer(Trainer):
    def _finish_init(self):
        self.net = DistFastNet(self.learning_rate, self.image_shape,
                               self.n_out, self.init_model)

    def get_from_master(self, data):
        data = comm.bcast(data, root=0)
        comm.barrier()
        return data

    def train(self):
        util.log('rank %d starting training...' % rank)
        while self.should_continue_training():
            train_data = self.get_from_master(None)
            self.curr_epoch = train_data.epoch

            input, label = train_data.data, train_data.labels
            self.net.train_batch(input, label)
            self.curr_batch += 1

            if self.check_test_data():
                self.get_test_error()

            if self.factor != 1.0 and self.check_adjust_lr():
                self.adjust_lr()

    def get_test_error(self):
        test_data = self.get_from_master(None)
        input, label = test_data.data, test_data.labels
        self.net.train_batch(input, label, TEST)

    def save_checkpoint(self):
        self.net.get_dumped_layers()
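The worker class above relies on module-level state that the fragment does not show: an MPI communicator comm, the process rank, plus util, DistFastNet, Trainer, and the TEST flag from the surrounding project. A minimal sketch of the MPI part only, assuming mpi4py (the project-specific names are not reproduced here):

from mpi4py import MPI

comm = MPI.COMM_WORLD   # shared communicator; rank 0 acts as the master
rank = comm.Get_rank()  # worker ranks call get_from_master(None) and get back
                        # whatever object rank 0 passed to comm.bcast(..., root=0)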
Example #2
class DummyTrainer(Trainer):
  def _finish_init(self):
    self.net = DistFastNet(self.learning_rate, self.image_shape, self.n_out, self.init_model)

  def get_from_master(self, data):
    data = comm.bcast(data, root = 0)
    comm.barrier()
    return data

  def train(self):
    util.log('rank %d starting training...' % rank)
    while self.should_continue_training():
      train_data = self.get_from_master(None)
      self.curr_epoch = train_data.epoch

      input, label = train_data.data, train_data.labels
      self.net.train_batch(input, label)
      self.curr_batch += 1

      if self.check_test_data():
        self.get_test_error()

      if self.factor != 1.0 and self.check_adjust_lr():
        self.adjust_lr()

  def get_test_error(self):
    test_data = self.get_from_master(None)
    input, label = test_data.data, test_data.labels
    self.net.train_batch(input, label, TEST)

  def save_checkpoint(self):
    self.net.get_dumped_layers()
Example #3
class ServerTrainer(Trainer):
  def _finish_init(self):
    self.net = DistFastNet(self.learning_rate, self.image_shape, self.init_model)

  def reshape_data(self, data):
    batch_data = data.data.get()
    batch_size = self.batch_size

    num_case = batch_data.size / (self.image_color * self.image_size * self.image_size)
    batch_data = batch_data.reshape((self.image_color, self.image_size, self.image_size, num_case))
    data.data = batch_data

  def scatter_to_worker(self, data):
    comm.bcast(data, root = 0)
    comm.barrier()

  def train(self):
    self.print_net_summary()
    util.log('Starting training...')
    while self.should_continue_training():
      start = time.time()
      train_data = self.train_dp.get_next_batch()  # self.train_dp.wait()
      self.reshape_data(train_data)
      self.scatter_to_worker(train_data)
      self.curr_epoch = train_data.epoch

      input, label = train_data.data, train_data.labels
      self.net.train_batch(input, label)

      self.curr_batch += 1
      cost, correct, numCase = self.net.get_batch_information()
      self.train_outputs += [({'logprob': [cost, 1 - correct]}, numCase, time.time() - start)]
      print >> sys.stderr, '%d.%d: error: %f logreg: %f time: %f' % (self.curr_epoch, self.curr_batch, 1 - correct, cost, time.time() - start)

      self.num_batch += 1
      if self.check_test_data():
        self.get_test_error()

      if self.factor != 1.0 and self.check_adjust_lr():
        self.adjust_lr()

      if self.check_save_checkpoint():
        self.save_checkpoint()

    self.get_test_error()
    self.save_checkpoint()

  def get_test_error(self):
    start = time.time()
    test_data = self.test_dp.get_next_batch()
    self.reshape_data(test_data)
    self.scatter_to_worker(test_data)

    input, label = test_data.data, test_data.labels
    self.net.train_batch(input, label, TEST)

    cost, correct, numCase = self.net.get_batch_information()
    self.test_outputs += [({'logprob': [cost, 1 - correct]}, numCase, time.time() - start)]
    print >> sys.stderr, 'error: %f logreg: %f time: %f' % (1 - correct, cost, time.time() - start)
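reshape_data recovers the batch size from the total element count of the flat buffer pulled off the GPU and views it as a (color, height, width, case) array. A standalone sketch of the same computation with NumPy and hypothetical sizes (3-channel 32x32 images, 128 per batch; the numbers are illustrative, not taken from the snippet):

import numpy as np

image_color, image_size = 3, 32
flat = np.zeros(image_color * image_size * image_size * 128, dtype=np.float32)

# Recover the number of cases from the buffer size, then view the buffer
# as (channels, height, width, cases), exactly as reshape_data does.
num_case = flat.size // (image_color * image_size * image_size)
batch = flat.reshape((image_color, image_size, image_size, num_case))
assert batch.shape == (3, 32, 32, 128)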
Example #4
def _finish_init(self):
    self.net = DistFastNet(self.learning_rate, self.image_shape,
                           self.init_model)
Example #5
class ServerTrainer(Trainer):
    def _finish_init(self):
        self.net = DistFastNet(self.learning_rate, self.image_shape,
                               self.init_model)

    def reshape_data(self, data):
        batch_data = data.data.get()
        batch_size = self.batch_size

        num_case = batch_data.size / (self.image_color * self.image_size *
                                      self.image_size)
        batch_data = batch_data.reshape(
            (self.image_color, self.image_size, self.image_size, num_case))
        data.data = batch_data

    def scatter_to_worker(self, data):
        comm.bcast(data, root=0)
        comm.barrier()

    def train(self):
        self.print_net_summary()
        util.log('Starting training...')
        while self.should_continue_training():
            start = time.time()
            train_data = self.train_dp.get_next_batch()  # self.train_dp.wait()
            self.reshape_data(train_data)
            self.scatter_to_worker(train_data)
            self.curr_epoch = train_data.epoch

            input, label = train_data.data, train_data.labels
            self.net.train_batch(input, label)

            self.curr_batch += 1
            cost, correct, numCase = self.net.get_batch_information()
            self.train_outputs += [({
                'logprob': [cost, 1 - correct]
            }, numCase, time.time() - start)]
            print >> sys.stderr, '%d.%d: error: %f logreg: %f time: %f' % (
                self.curr_epoch, self.curr_batch, 1 - correct, cost,
                time.time() - start)

            self.num_batch += 1
            if self.check_test_data():
                self.get_test_error()

            if self.factor != 1.0 and self.check_adjust_lr():
                self.adjust_lr()

            if self.check_save_checkpoint():
                self.save_checkpoint()

        self.get_test_error()
        self.save_checkpoint()

    def get_test_error(self):
        start = time.time()
        test_data = self.test_dp.get_next_batch()
        self.reshape_data(test_data)
        self.scatter_to_worker(test_data)

        input, label = test_data.data, test_data.labels
        self.net.train_batch(input, label, TEST)

        cost, correct, numCase = self.net.get_batch_information()
        self.test_outputs += [({
            'logprob': [cost, 1 - correct]
        }, numCase, time.time() - start)]
        print >> sys.stderr, 'error: %f logreg: %f time: %f' % (
            1 - correct, cost, time.time() - start)
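The two sides of the transfer are meant to pair up: the master's scatter_to_worker and the workers' get_from_master enter the same comm.bcast, the root supplying the batch and every other rank passing None. A minimal sketch of that hand-off, again assuming mpi4py and using a placeholder dictionary instead of a real batch object:

from mpi4py import MPI

comm = MPI.COMM_WORLD

if comm.Get_rank() == 0:
    # Master side (scatter_to_worker): broadcast the actual batch.
    batch = {'data': [1, 2, 3], 'labels': [0, 1, 0]}   # stand-in for a batch
else:
    # Worker side (get_from_master): pass None, receive what the root sent.
    batch = None

batch = comm.bcast(batch, root=0)   # every rank now holds the root's object
comm.barrier()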
Example #6
def _finish_init(self):
  self.net = DistFastNet(self.learning_rate, self.image_shape, self.init_model)