Exemple #1
0
class Trainer:
  CHECKPOINT_REGEX = None
  def __init__(self, test_id, data_dir, data_provider, checkpoint_dir, train_range, test_range, test_freq, save_freq, batch_size, num_epoch, image_size,
               image_color, learning_rate, auto_init=False, init_model=None, adjust_freq=1, factor=1.0):
    """Store the training configuration, create the data providers and build the network.

    Args:
      test_id: integer id embedded in checkpoint file names
        ('test<test_id>-<epoch>.<batch>').
      data_dir: first argument given to the data provider when it is instantiated.
      data_provider: name looked up through DataProvider.get_by_name().
      checkpoint_dir: directory that save_checkpoint() writes into.
      train_range: second argument for the training data provider.
      test_range: second argument for the test data provider.
      test_freq: the test error is computed every `test_freq` trained batches.
      save_freq: a checkpoint is written every `save_freq` trained batches.
      batch_size: minibatch size; first entry of the image shape given to FastNet.
      num_epoch: training keeps going while the current epoch is <= this value.
      image_size: used for both trailing entries of the image shape.
      image_color: second entry of the image shape.
      learning_rate: handed to FastNet.
      auto_init: accepted for interface compatibility; not used by this constructor.
      init_model: optional mapping handed to FastNet. When it contains a
        'model_state' entry, the stored train/test outputs are restored from it.
      adjust_freq: the learning rate is adjusted every `adjust_freq` trained batches.
      factor: argument for net.adjust_learning_rate(); train() skips the
        adjustment entirely when this is exactly 1.0.
    """
    self.test_id = test_id
    self.data_dir = data_dir
    self.data_provider = data_provider
    self.checkpoint_dir = checkpoint_dir
    self.train_range = train_range
    self.test_range = test_range
    self.test_freq = test_freq
    self.save_freq = save_freq
    self.batch_size = batch_size
    self.num_epoch = num_epoch
    self.image_size = image_size
    self.image_color = image_color
    self.learning_rate = learning_rate
    # Output count handed to FastNet; the original author notes it
    # "doesn't matter anymore".
    self.n_out = 10
    self.factor = factor
    self.adjust_freq = adjust_freq
    # Matches the checkpoint names written by save_checkpoint() for this test id.
    # Raw string so that \d and \. reach the regex engine as written.
    self.regex = re.compile(r'^test%d-(\d+)\.(\d+)$' % self.test_id)

    self.init_data_provider()
    self.image_shape = (self.batch_size, self.image_color, self.image_size, self.image_size)

    # Resume the recorded result history when a saved model state is supplied.
    if init_model is not None and 'model_state' in init_model:
      self.train_outputs = init_model['model_state']['train_outputs']
      self.test_outputs = init_model['model_state']['test_outputs']
    else:
      self.train_outputs = []
      self.test_outputs = []

    # Progress counters; num_batch counts trained batches and drives the
    # test / save / adjust frequency checks.
    self.curr_minibatch = self.num_batch = self.curr_epoch = self.curr_batch = 0
    self.net = FastNet(self.learning_rate, self.image_shape, self.n_out, init_model=init_model)

    self.train_data = None
    self.test_data = None

    self.num_train_minibatch = 0
    self.num_test_minibatch = 0
    self.checkpoint_file = ''

    # Optional dumpers for captured layer outputs; while they stay None the
    # _capture_* methods return without doing anything.
    self.train_dumper = None
    self.test_dumper = None
    # GPU array reused between minibatches by get_next_minibatch().
    self.input = None


  def init_data_provider(self):
    """Create the training and test data providers.

    The provider factory is looked up by name via DataProvider.get_by_name()
    and called twice with the data directory: once with the training range
    and once with the test range.
    """
    provider_factory = DataProvider.get_by_name(self.data_provider)
    self.train_dp = provider_factory(self.data_dir, self.train_range)
    self.test_dp = provider_factory(self.data_dir, self.test_range)


  def get_next_minibatch(self, i, train=TRAIN):
    """Copy minibatch `i` of the current batch to the GPU and return it with its labels.

    Args:
      i: zero-based index of the minibatch inside the currently loaded batch.
      train: TRAIN reads from self.train_data; any other value reads from
        self.test_data.

    Returns:
      A pair (self.input, label): the GPU array holding the selected columns
      of the batch data, and the matching slice of the batch labels (left on
      the host).
    """
    source = self.train_data if train == TRAIN else self.test_data
    first = i * self.batch_size
    last = first + self.batch_size

    # The minibatch is the column range [first, last) of the batch data.
    columns = source.data[:, first:last]

    # Stage the slice in a page-locked host buffer before the device transfer.
    staging = driver.pagelocked_empty(columns.shape, columns.dtype, order='C',
                                      mem_flags=driver.host_alloc_flags.PORTABLE)
    staging[:] = columns

    # Reuse the existing device array when the shape is unchanged; otherwise
    # (first call, or a differently sized minibatch) allocate a fresh one.
    can_reuse = self.input is not None and staging.shape == self.input.shape
    if can_reuse:
      self.input.set(staging)
    else:
      self.input = gpuarray.to_gpu(staging)

    return self.input, source.labels[first:last]


  def save_checkpoint(self):
    """Pickle the current model state into the checkpoint directory.

    The file is named 'test<test_id>-<curr_epoch>.<curr_batch>' and its path
    is remembered in self.checkpoint_file. Once the new file has been written,
    every other file in the directory that matches this trainer's checkpoint
    pattern is removed, so only the newest checkpoint is kept.

    Side effects: prints the net summary and the checkpoint path to stderr,
    creates the checkpoint directory if it is missing.
    """
    model = {}
    model['batchnum'] = self.train_dp.get_batch_num()
    # NOTE(review): this stores num_epoch + 1 rather than curr_epoch; confirm
    # that this is what the code loading the checkpoint expects.
    model['epoch'] = self.num_epoch + 1
    model['layers'] = self.net.get_dumped_layers()

    model['train_outputs'] = self.train_outputs
    model['test_outputs'] = self.test_outputs

    dic = {'model_state': model, 'op': None}
    self.print_net_summary()

    # Create the directory (and parents) directly instead of shelling out to
    # `mkdir -p`, which breaks on paths containing quotes and hides failures.
    if not os.path.exists(self.checkpoint_dir):
      os.makedirs(self.checkpoint_dir)

    checkpoint_filename = "test%d-%d.%d" % (self.test_id, self.curr_epoch, self.curr_batch)
    checkpoint_file_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
    self.checkpoint_file = checkpoint_file_path
    print >> sys.stderr, checkpoint_file_path

    # protocol=-1 selects a binary pickle format, so the file must be binary.
    with open(checkpoint_file_path, 'wb') as f:
      cPickle.dump(dic, f, protocol=-1)

    # Remove older checkpoints only after the new one is safely on disk, so a
    # failed dump cannot leave the directory without any checkpoint.
    for stale in os.listdir(self.checkpoint_dir):
      if stale != checkpoint_filename and self.regex.match(stale):
        os.remove(os.path.join(self.checkpoint_dir, stale))
    util.log('save file finished')

  def get_test_error(self):
    """Evaluate the net on the next test batch and record the result.

    Fetches the next batch from the test provider, runs every minibatch of it
    through the net in TEST mode, appends a
    ({'logprob': [cost, error]}, numCase, elapsed) entry to self.test_outputs
    and prints a one-line summary to stderr.
    """
    began = time.time()
    self.test_data = self.test_dp.get_next_batch()
    self.num_test_minibatch = divup(self.test_data.data.shape[1], self.batch_size)

    for idx in range(self.num_test_minibatch):
      gpu_input, labels = self.get_next_minibatch(idx, TEST)
      self.net.train_batch(gpu_input, labels, TEST)
      self._capture_test_data()

    cost, correct, num_case = self.net.get_batch_information()
    self.test_outputs += [({'logprob': [cost, 1 - correct]}, num_case, time.time() - began)]
    summary = (self.test_data.batchnum, 1 - correct, cost, time.time() - began)
    print >> sys.stderr, '[%d] error: %f logreg: %f time: %f' % summary

  def print_net_summary(self):
    """Print a separator line followed by each layer's weight and bias figures to stderr.

    Every entry returned by net.get_summary() is read as (name, values), where
    values[0:2] are printed on the weight line and values[2:4] on the bias line.
    """
    print >> sys.stderr, '--------------------------------------------------------------'
    for entry in self.net.get_summary():
      layer, stats = entry[0], entry[1]
      print >> sys.stderr, "Layer '%s' weight: %e [%e]" % (layer, stats[0], stats[1])
      print >> sys.stderr, "Layer '%s' bias: %e [%e]" % (layer, stats[2], stats[3])


  def should_continue_training(self):
    return self.curr_epoch <= self.num_epoch

  def check_test_data(self):
    return self.num_batch % self.test_freq == 0

  def check_save_checkpoint(self):
    return self.num_batch % self.save_freq == 0

  def check_adjust_lr(self):
    return self.num_batch % self.adjust_freq == 0
  
  def _finished_training(self):
    if self.train_dumper is not None:
      self.train_dumper.flush()
    
    if self.test_dumper is not None:
      self.test_dumper.flush()
      
  def _capture_training_data(self):
    if not self.train_dumper:
      return

    self.train_dumper.add({'labels' : self.net.label.get(),
                           'fc' : self.net.outputs[-3].get().transpose() })
    
  def _capture_test_data(self):
    if not self.test_dumper:
      return
    self.test_dumper.add({'labels' : self.net.label.get(),
                           'fc' : self.net.outputs[-3].get().transpose() })

  def train(self):
    """Train batch by batch until should_continue_training() turns false, then finalize.

    For every batch fetched from the training provider the method trains the
    net on each of its minibatches, appends a
    ({'logprob': [cost, error]}, numCase, elapsed) entry to self.train_outputs
    and prints a summary line to stderr. Depending on the trained-batch count
    it then runs a test pass, adjusts the learning rate (only when
    self.factor is not 1.0) and writes a checkpoint.

    After the loop it runs one last test pass, saves a final checkpoint,
    prints the net's report and flushes the dumpers.
    """
    self.print_net_summary()
    util.log('Starting training...')
    # The epoch is only known once a batch has been fetched, so the batch that
    # first reports an epoch beyond num_epoch is still trained on before the
    # loop condition stops the training.
    while self.should_continue_training():
      self.train_data = self.train_dp.get_next_batch()
      self.curr_epoch = self.train_data.epoch
      self.curr_batch = self.train_data.batchnum

      start = time.time()
      self.num_train_minibatch = divup(self.train_data.data.shape[1], self.batch_size)

      for i in range(self.num_train_minibatch):
        batch_input, label = self.get_next_minibatch(i)
        self.net.train_batch(batch_input, label)
        self._capture_training_data()
        self.curr_minibatch += 1

      cost, correct, numCase = self.net.get_batch_information()
      self.train_outputs += [({'logprob': [cost, 1 - correct]}, numCase, time.time() - start)]
      print >> sys.stderr, '%d.%d: error: %f logreg: %f time: %f' % (self.curr_epoch, self.curr_batch, 1 - correct, cost, time.time() - start)

      # num_batch drives the three periodic actions below.
      self.num_batch += 1
      if self.check_test_data():
        print >> sys.stderr, '---- test ----'
        self.get_test_error()
        print >> sys.stderr, '------------'

      if self.factor != 1.0 and self.check_adjust_lr():
        print >> sys.stderr, '---- adjust learning rate ----'
        self.net.adjust_learning_rate(self.factor)
        print >> sys.stderr, '--------'

      if self.check_save_checkpoint():
        print >> sys.stderr, '---- save checkpoint ----'
        self.save_checkpoint()
        print >> sys.stderr, '------------'

    self.get_test_error()
    self.save_checkpoint()
    self.report()
    self._finished_training()

  def predict(self, save_layers = None, filename = None):
    """Run the net in TEST mode over test batches and optionally save layer outputs.

    Args:
      save_layers: handed to net.save_layerouput(). When it is not None, the
        result of net.get_save_output() is collected after every batch.
      filename: when both `save_layers` and `filename` are not None, the
        collected outputs are pickled to this path.

    Batches are processed while self.curr_epoch is below 2; the epoch is
    updated from each fetched batch, so the first batch reporting epoch 2 or
    later is still processed. NOTE(review): curr_epoch is shared with train(),
    so the loop is skipped if it is already >= 2 when predict() is called.
    """
    self.net.save_layerouput(save_layers)
    self.print_net_summary()
    util.log('Starting predict...')
    save_output = []
    while self.curr_epoch < 2:
      start = time.time()
      self.test_data = self.test_dp.get_next_batch()
      self.curr_epoch = self.test_data.epoch
      self.curr_batch = self.test_data.batchnum

      self.num_test_minibatch = divup(self.test_data.data.shape[1], self.batch_size)
      for i in range(self.num_test_minibatch):
        batch_input, label = self.get_next_minibatch(i, TEST)
        self.net.train_batch(batch_input, label, TEST)
      cost, correct, numCase = self.net.get_batch_information()
      print >> sys.stderr, '%d.%d: error: %f logreg: %f time: %f' % (self.curr_epoch, self.curr_batch, 1 - correct, cost, time.time() - start)
      if save_layers is not None:
        save_output.extend(self.net.get_save_output())

    if save_layers is not None and filename is not None:
      # protocol=-1 selects a binary pickle format, so the file must be binary.
      with open(filename, 'wb') as f:
        cPickle.dump(save_output, f, protocol = -1)
      util.log('save layer output finished')


  def report(self):
    rep = self.net.get_report()
    if rep is not None:
      print rep