Example #1
def monitor(f_log_prob, FLAGS, valid_set, train_set=None, states=None):
    """Compute the average negative log-likelihood, in nats and bits per
    symbol, on the validation set and (optionally) the training set."""
    print("Start monitoring phase")
    returns = OrderedDict()
    if train_set is None:
        tr_nats = 0.0
        tr_bits = 0.0
    else:
        if states is not None:
            reset_state(states)
        _cost = 0
        _len = 0
        for x in train_set:
            x, x_mask = gen_mask(x, max_seq_len=FLAGS.max_seq_len)
            _tr_cost, _tr_cost_len = f_log_prob(x, x_mask)
            _cost += _tr_cost.sum()
            _len += _tr_cost_len.sum()
        tr_nats = _cost / _len
        tr_bits = nats2bits(tr_nats)
    returns['tr_nats'] = tr_nats
    returns['tr_bits'] = tr_bits
    if states is not None:
        reset_state(states)
    _cost = 0
    _len = 0
    for x in valid_set:
        x, x_mask = gen_mask(x, max_seq_len=FLAGS.max_seq_len)
        _val_cost, _val_cost_len = f_log_prob(x, x_mask)
        _cost += _val_cost.sum()
        _len += _val_cost_len.sum()
    val_nats = _cost / _len
    val_bits = nats2bits(val_nats)
    returns['val_nats'] = val_nats
    returns['val_bits'] = val_bits
    return returns
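
Note: the snippet above depends on two helpers that are not shown. `nats2bits` is just a change of logarithm base (1 nat = 1/ln 2 bits). `gen_mask` presumably pads each batch to `max_seq_len` and returns a binary mask over the valid timesteps; the version below is a hypothetical sketch of that contract, not the repository's actual implementation.

import numpy as np

def nats2bits(x):
    # Convert log_e units to log_2 units: 1 nat = 1 / ln(2) bits
    return x / np.log(2.)

def gen_mask(x, max_seq_len):
    # Hypothetical sketch: pad a list of variable-length integer sequences
    # into a (max_seq_len, batch_size) array and build a {0, 1} mask that
    # marks the valid (non-padding) timesteps.
    batch_size = len(x)
    padded = np.zeros((max_seq_len, batch_size), dtype='int64')
    mask = np.zeros((max_seq_len, batch_size), dtype='float32')
    for i, seq in enumerate(x):
        length = min(len(seq), max_seq_len)
        padded[:length, i] = seq[:length]
        mask[:length, i] = 1.
    return padded, mask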
Example #2
 def reset(self):
     """Reset the recurrent states to their initial values."""
     reset_state(self.states)
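
`reset_state` clears the recurrent state between monitoring passes and epochs. Assuming `states` is a list of Theano shared variables (as the `set_value`/`get_value` usage elsewhere in these examples suggests), a minimal sketch could be:

import numpy as np

def reset_state(states):
    # Zero each shared state variable in place (a sketch under the
    # assumption that states are Theano shared variables)
    for state in states:
        state.set_value(np.zeros_like(state.get_value()))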
Example #3
 def __call__(self):
     # Fix random seeds
     _seed = self.FLAGS.base_seed + self.FLAGS.add_seed
     np.random.seed(seed=_seed)
     # Prefixed names for save files
     prefix_name = os.path.join(self.FLAGS.log_dir, self._file_name)
     file_name = '%s.npz' % prefix_name
     best_file_name = '%s.best.npz' % prefix_name
     opt_file_name = '%s.grads.npz' % prefix_name
     best_opt_file_name = '%s.best.grads.npz' % prefix_name
     if self.FLAGS.start_from_ckpt and os.path.exists(file_name):
         self._ckpt_file_name = file_name
     # Declare summary
     summary = OrderedDict()
     # Initialize the variables
     f_prop, f_update, f_log_prob, f_debug, tparams, opt_tparams, \
         states, st_slope = self._build_graph(self.FLAGS)
     # Restore from checkpoint if FLAGS.start_from_ckpt is on
     if self.FLAGS.start_from_ckpt and os.path.exists(file_name):
         tparams = init_tparams_with_restored_value(tparams, file_name)
         model = np.load(file_name)
         for k, v in model.items():
             if 'summary' in k:
                 summary[k] = list(v)
             if 'time' in k:
                 summary[k] = list(v)
         global_step = model['global_step']
         epoch_step = model['epoch_step']
         batch_step = model['batch_step']
         print("Restore from the last checkpoint. "
               "Restarting from %d step." % global_step)
     else:
         global_step = 0
         epoch_step = 0
         batch_step = 0
     # Construct dataset objects
     train_set = TextIterator(which_set='train',
                              max_seq_len=self.FLAGS.max_seq_len,
                              batch_size=self.FLAGS.batch_size,
                              shuffle_every_epoch=1)
     if self.FLAGS.eval_train:
         train_infer_set = TextIterator(which_set='train',
                                        max_seq_len=self.FLAGS.max_seq_len,
                                        batch_size=self.FLAGS.batch_size,
                                        shuffle_every_epoch=0)
     else:
         train_infer_set = None
     valid_set = TextIterator(which_set='valid',
                              max_seq_len=self.FLAGS.max_seq_len,
                              batch_size=self.FLAGS.batch_size,
                              shuffle_every_epoch=0)
     if self.FLAGS.start_from_ckpt:
         _summary = self._monitor(f_log_prob, self.FLAGS, valid_set, None,
                                  states)
         _val_bits = _summary['val_bits']
         if _val_bits != summary['val_bits'][-1]:
             raise ValueError(
                 "Sanity check failed: restored validation bits do not match.")
         try:
             # Fast-forward the iterator past already-seen batches
             for cc in range(batch_step + 1):
                 next(train_set)
         except StopIteration:
             batch_step = 0
     best_params = None
     tr_costs = []
     _best_score = np.inf  # lowest validation bits seen so far
     # Keep training until max iteration
     print("Starting the optimization")
     for _epoch in range(self.FLAGS.n_epoch):
         reset_state(states)
         _n_exp = 0
         _time = time.time()
         __time = time.time()
         # Keep the restored batch_step on the first epoch after a restart
         if not (self.FLAGS.start_from_ckpt and batch_step != 0):
             batch_step = 0
         if self.FLAGS.use_slope_anneal:
             # Linearly anneal the straight-through (ST) slope from 1 to
             # n_slope over the first n_anneal_epoch epochs
             if _epoch <= self.FLAGS.n_anneal_epoch:
                 new_slope = float(1. + (self.FLAGS.n_slope - 1) /
                                   float(self.FLAGS.n_anneal_epoch) *
                                   _epoch)
                 st_slope.set_value(new_slope)
                 print("Changed the ST slope to : %f" %
                       st_slope.get_value())
         for x in train_set:
             x, x_mask = gen_mask(x, max_seq_len=self.FLAGS.max_seq_len)
             _n_exp += self.FLAGS.batch_size
             # Run f-prop and optimization functions (backprop)
             cost = f_prop(x, x_mask)
             f_update(self.FLAGS.learning_rate)
             tr_costs.append(cost)
             if np.mod(global_step, self.FLAGS.display_freq) == 0:
                 _time_spent = time.time() - _time
                 tr_cost = np.array(tr_costs).mean()
                 print("Epoch " + str(_epoch) + \
                       ", Iter " + str(global_step) + \
                       ", Average batch loss= " + "{:.6f}".format(tr_cost) + \
                       ", Elapsed time= " + "{:.5f}".format(_time_spent))
                 _time = time.time()
                 tr_costs = []
             batch_step += 1
             global_step += 1
         # Monitor training/validation nats and bits
         _summary = self._monitor(f_log_prob, self.FLAGS, valid_set,
                                  train_infer_set, states)
         feed_dict(summary, _summary)
         print("Train average nats= " + "{:.6f}".format(_summary['tr_nats']) + \
               ", Train average bits= " + "{:.6f}".format(_summary['tr_bits']) + \
               ", Valid average nats= " + "{:.6f}".format(_summary['val_nats']) + \
               ", Valid average bits= " + "{:.6f}".format(_summary['val_bits']) + \
               ", Elapsed time= " + "{:.5f}".format(time.time() - __time)) + \
               ", Observed examples= " + "{:d}".format(_n_exp)
         # Record the wall-clock time spent on this epoch
         insert_item2dict(summary, 'time', time.time() - __time)
         # Save model
         _val_bits = summary['val_bits'][-1]
         if _val_bits < _best_score:
             _best_score = _val_bits
             # Save the best model
             best_params = unzip(tparams)
             if self.FLAGS.use_slope_anneal:
                 best_params['st_slope'] = st_slope.get_value()
             save_npz(best_file_name, global_step, epoch_step, batch_step,
                      best_params, summary)
             # Save the gradients of best model
             best_opt_params = unzip(opt_tparams)
             save_npz2(best_opt_file_name, best_opt_params)
             print("Best checkpoint stored in: %s" % best_file_name)
         # Save the latest model
         params = unzip(tparams)
         if self.FLAGS.use_slope_anneal:
             params['st_slope'] = st_slope.get_value()
         save_npz(file_name, global_step, epoch_step, batch_step, params,
                  summary)
         # Save the gradients of latest model
         opt_params = unzip(opt_tparams)
         save_npz2(opt_file_name, opt_params)
         print("Checkpointed in: %s" % file_name)
     print("Optimization Finished.")