Example #1
0
class Trainer(object):
    """Run the PyTorch training loop for ``Model``.

    The trainer owns the configuration, the data manager, the validator,
    the model and an Adam optimizer. It accumulates loss statistics per
    logging window and per epoch, periodically validates and checkpoints,
    and finally translates the dev/test files with the best checkpoints.
    """

    def __init__(self, args):
        """Build every training component from the selected configuration.

        Args:
            args: Namespace-like object. ``args.proto`` names a factory in
                ``configurations`` that returns the config mapping, and
                ``args.num_preload`` is forwarded to
                ``DataManager.get_batch`` during training.
        """
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])

        # Use the first GPU when available, otherwise stay on the CPU.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        self.normalize_loss = self.config['normalize_loss']
        self.patience = self.config['patience']
        self.lr = self.config['lr']
        self.lr_decay = self.config['lr_decay']
        self.max_epochs = self.config['max_epochs']
        self.warmup_steps = self.config['warmup_steps']

        # One entry per finished epoch, appended by report_epoch().
        self.train_smooth_perps = []
        self.train_true_perps = []

        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)

        # validate_freq counts epochs when val_per_epoch is truthy,
        # otherwise it counts batches.
        self.val_per_epoch = self.config['val_per_epoch']
        self.validate_freq = int(self.config['validate_freq'])
        self.logger.info('Evaluate every {} {}'.format(
            self.validate_freq, 'epochs' if self.val_per_epoch else 'batches'))

        # For logging
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_nll_loss = 0.  # total nll loss every log_freq batches
        self.log_train_weights = 0.  # total target words every log_freq batches
        self.num_batches_done = 0  # number of batches done for the whole training
        self.epoch_batches_done = 0  # number of batches done for this epoch
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_nll_loss = 0.  # total nll loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch

        # get model
        self.model = Model(self.config).to(self.device)

        # numel() yields plain Python ints, so the ',' format spec is safe.
        param_count = sum(p.numel() for p in self.model.parameters())
        self.logger.info('Model has {:,} parameters'.format(param_count))

        # get optimizer
        beta1 = self.config['beta1']
        beta2 = self.config['beta2']
        epsilon = self.config['epsilon']
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          betas=(beta1, beta2),
                                          eps=epsilon)

    @staticmethod
    def _perplexity(total_loss, total_weights):
        """Return exp(total_loss / total_weights), clamped to inf.

        The mean is only exponentiated while it is below 300; any other
        value (including NaN, which fails the comparison) maps to
        ``float('inf')``.
        """
        mean_loss = total_loss / total_weights
        return numpy.exp(mean_loss) if mean_loss < 300 else float('inf')

    def report_epoch(self, e):
        """Log timing and perplexities for a finished epoch, then reset.

        Args:
            e: Epoch number written to the log as-is (the caller passes a
                1-based value).

        The epoch accumulators are zeroed and the two perplexities are
        appended to ``train_smooth_perps`` / ``train_true_perps``.
        """
        self.logger.info('Finish epoch {}'.format(e))
        self.logger.info('    It takes {}'.format(
            ut.format_seconds(self.epoch_time)))
        self.logger.info('    Avergage # words/second    {}'.format(
            self.epoch_weights / self.epoch_time))
        self.logger.info('    Average seconds/batch    {}'.format(
            self.epoch_time / self.epoch_batches_done))

        # Compute both perplexities before the accumulators are cleared.
        train_smooth_perp = self._perplexity(self.epoch_loss,
                                             self.epoch_weights)
        train_true_perp = self._perplexity(self.epoch_nll_loss,
                                           self.epoch_weights)

        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_nll_loss = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.

        self.train_smooth_perps.append(train_smooth_perp)
        self.train_true_perps.append(train_true_perp)

        self.logger.info(
            '    smoothed train perplexity: {}'.format(train_smooth_perp))
        self.logger.info(
            '    true train perplexity: {}'.format(train_true_perp))

    def run_log(self, b, e, batch_data):
        """Run one optimisation step on a batch and update the statistics.

        Args:
            b: Batch counter within the epoch; only used in log output.
            e: Zero-based epoch index; logged as ``e + 1``.
            batch_data: Triple ``(src_toks, trg_toks, targets)`` of tensors.
                ``targets`` must live on the CPU because the word count is
                taken through ``.numpy()``.
        """
        start = time.time()
        src_toks, trg_toks, targets = batch_data
        src_toks_cuda = src_toks.to(self.device)
        trg_toks_cuda = trg_toks.to(self.device)
        targets_cuda = targets.to(self.device)

        # zero grad
        self.optimizer.zero_grad()

        # get loss
        ret = self.model(src_toks_cuda, trg_toks_cuda, targets_cuda)
        loss = ret['loss']
        nll_loss = ret['nll_loss']
        if self.normalize_loss == ac.LOSS_TOK:
            # Divide by the number of non-padding target tokens.
            opt_loss = loss / (targets_cuda != ac.PAD_ID).type(
                loss.type()).sum()
        elif self.normalize_loss == ac.LOSS_BATCH:
            # size(0) is a plain int, so no tensor-type cast is needed
            # (or possible) before dividing.
            opt_loss = loss / targets_cuda.size(0)
        else:
            opt_loss = loss

        opt_loss.backward()
        # clip gradient
        global_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                     self.config['grad_clip'])

        # update
        self.adjust_lr()
        self.optimizer.step()

        # update training stats; the un-normalised losses are accumulated
        num_words = (targets != ac.PAD_ID).detach().numpy().sum()

        loss = loss.cpu().detach().numpy()
        nll_loss = nll_loss.cpu().detach().numpy()
        self.num_batches_done += 1
        self.log_train_loss += loss
        self.log_nll_loss += nll_loss
        self.log_train_weights += num_words

        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_nll_loss += nll_loss
        self.epoch_weights += num_words
        self.epoch_time += time.time() - start

        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done

            avg_smooth_perp = self._perplexity(self.log_train_loss,
                                               self.log_train_weights)
            avg_true_perp = self._perplexity(self.log_nll_loss,
                                             self.log_train_weights)

            # Start a fresh logging window.
            self.log_train_loss = 0.
            self.log_nll_loss = 0.
            self.log_train_weights = 0.

            self.logger.info('Batch {}, epoch {}/{}:'.format(
                b, e + 1, self.max_epochs))
            self.logger.info(
                '   avg smooth perp:   {0:.2f}'.format(avg_smooth_perp))
            self.logger.info(
                '   avg true perp:   {0:.2f}'.format(avg_true_perp))
            self.logger.info('   acc trg words/s: {}'.format(
                int(acc_speed_word)))
            self.logger.info(
                '   acc sec/batch:   {0:.2f}'.format(acc_speed_time))
            self.logger.info('   global norm:     {0:.2f}'.format(global_norm))

    def adjust_lr(self):
        """Set the optimizer learning rate from the warmup schedule.

        Only acts when ``config['warmup_style']`` equals ``ac.ORG_WARMUP``:
        the rate grows linearly with the step count until
        ``config['warmup_steps']``, then decays with the inverse square
        root of the step, never dropping below ``config['min_lr']``.
        """
        if self.config['warmup_style'] == ac.ORG_WARMUP:
            step = self.num_batches_done + 1.0
            if step < self.config['warmup_steps']:
                lr = self.config['embed_dim']**(
                    -0.5) * step * self.config['warmup_steps']**(-1.5)
            else:
                lr = max(self.config['embed_dim']**(-0.5) * step**(-0.5),
                         self.config['min_lr'])

            for p in self.optimizer.param_groups:
                p['lr'] = lr

    def train(self):
        """Train for ``max_epochs`` epochs, then evaluate the best checkpoints.

        Validation happens every ``validate_freq`` batches or epochs,
        depending on ``val_per_epoch``. After training, the per-epoch
        perplexities are logged and saved as ``.npy`` files under
        ``config['save_to']``, a final checkpoint is written, and for each
        entry of ``data_manager.checkpoints`` whose test file exists the
        best checkpoint is restored and used to translate test and dev.
        """
        self.model.train()
        train_ids_file = self.data_manager.data_files['ids']
        for e in range(self.max_epochs):
            b = 0
            for batch_data in self.data_manager.get_batch(
                    ids_file=train_ids_file,
                    shuffle=True,
                    num_preload=self.num_preload):
                b += 1
                self.run_log(b, e, batch_data)
                if not self.val_per_epoch:
                    self.maybe_validate()

            self.report_epoch(e + 1)
            if self.val_per_epoch and (e + 1) % self.validate_freq == 0:
                self.maybe_validate(just_validate=True)

        # validate 1 last time
        if not self.val_per_epoch:
            self.maybe_validate(just_validate=True)

        self.logger.info('It is finally done, mate!')
        self.logger.info('Train smoothed perps:')
        self.logger.info(', '.join(map(str, self.train_smooth_perps)))
        self.logger.info('Train true perps:')
        self.logger.info(', '.join(map(str, self.train_true_perps)))
        numpy.save(join(self.config['save_to'], 'train_smooth_perps.npy'),
                   self.train_smooth_perps)
        numpy.save(join(self.config['save_to'], 'train_true_perps.npy'),
                   self.train_true_perps)

        self.logger.info('Save final checkpoint')
        self.save_checkpoint()

        # Evaluate on test
        for checkpoint in self.data_manager.checkpoints:
            self.logger.info('Translate for {}'.format(checkpoint))
            dev_file = self.data_manager.dev_files[checkpoint][
                self.data_manager.src_lang]
            test_file = self.data_manager.test_files[checkpoint][
                self.data_manager.src_lang]
            if exists(test_file):
                self.logger.info('  Evaluate on test')
                self.restart_to_best_checkpoint(checkpoint)
                self.validator.translate(self.model, test_file)
                self.logger.info('  Also translate dev')
                self.validator.translate(self.model, dev_file)

    def save_checkpoint(self):
        """Write the model state dict to ``<save_to>/<model_name>.pth``."""
        cpkt_path = join(self.config['save_to'],
                         '{}.pth'.format(self.config['model_name']))
        torch.save(self.model.state_dict(), cpkt_path)

    def restart_to_best_checkpoint(self, checkpoint):
        """Load the lowest-perplexity checkpoint recorded for ``checkpoint``.

        Args:
            checkpoint: Key into ``validator.best_perps``; also passed to
                ``validator.get_cpkt_path`` to locate the file.
        """
        best_perp = numpy.min(self.validator.best_perps[checkpoint])
        best_cpkt_path = self.validator.get_cpkt_path(checkpoint, best_perp)

        self.logger.info('Restore best cpkt from {}'.format(best_cpkt_path))
        # map_location lets a checkpoint written on GPU load on a CPU-only
        # host (and vice versa).
        self.model.load_state_dict(
            torch.load(best_cpkt_path, map_location=self.device))

    def maybe_validate(self, just_validate=False):
        """Checkpoint and validate when due; anneal the learning rate.

        Args:
            just_validate: When true, validate regardless of the batch
                counter. Otherwise validation only runs when
                ``num_batches_done`` is a multiple of ``validate_freq``.

        Annealing applies only when ``config['warmup_style']`` equals
        ``ac.NO_WARMUP`` and ``lr_decay`` is positive: if the latest
        validation perplexity is worse than every one of the previous
        ``patience`` values, the learning rate is multiplied by
        ``lr_decay`` unless that would fall below ``config['min_lr']``.
        """
        if self.num_batches_done % self.validate_freq == 0 or just_validate:
            self.save_checkpoint()
            self.validator.validate_and_save(self.model)

            # if doing annealing
            if self.config[
                    'warmup_style'] == ac.NO_WARMUP and self.lr_decay > 0:
                perp_curve = self.validator.perp_curve
                cond = len(perp_curve) > self.patience and perp_curve[-1] > max(
                    perp_curve[-1 - self.patience:-1])
                if cond:
                    metric = 'perp'
                    scores = ', '.join(
                        map(str, perp_curve[-1 - self.patience:]))

                    self.logger.info('Past {} are {}'.format(metric, scores))
                    # when don't use warmup, decay lr if dev not improve
                    if self.lr * self.lr_decay >= self.config['min_lr']:
                        self.logger.info(
                            'Anneal the learning rate from {} to {}'.format(
                                self.lr, self.lr * self.lr_decay))
                        self.lr = self.lr * self.lr_decay
                        for p in self.optimizer.param_groups:
                            p['lr'] = self.lr
class Trainer(object):
    """Run the TensorFlow graph/session training loop for ``Model``.

    The trainer builds a training and a validation copy of the model that
    share variables, feeds batches from the data manager, logs running
    perplexities, saves checkpoints at a fixed batch interval and
    validates through ``Validator``.
    """

    def __init__(self, args):
        """Read the configuration and set up data, validation and counters.

        Args:
            args: Namespace-like object. ``args.proto`` names a factory in
                ``configurations``; ``args.fixed_var_list`` is stored in
                the config under ``'fixed_var_list'``; ``args.num_preload``
                is forwarded to ``DataManager.get_batch``.
        """
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.config['fixed_var_list'] = args.fixed_var_list
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])

        self.lr = self.config['lr']
        self.max_epochs = self.config['max_epochs']
        self.save_freq = self.config['save_freq']
        self.train_perps = []  # one perplexity per finished epoch

        # Filled in later: checkpoint path and saver by
        # reload_and_get_cpkt_saver(), the two models by train().
        self.cpkt_path = None
        self.saver = None
        self.train_m = None
        self.dev_m = None

        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)
        self.validate_freq = ut.get_validation_frequency(
            self.data_manager.length_files[ac.TRAINING],
            self.config['validate_freq'], self.config['batch_size'])
        self.logger.info('Evaluate every {} batches'.format(
            self.validate_freq))

        # Only the second element of init_vocab's result is kept; it is
        # used by sample_input() to map token ids back to strings.
        _, self.src_ivocab = self.data_manager.init_vocab(
            self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(
            self.data_manager.trg_lang)

        # For logging
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_train_weights = 0.  # total target weights every log_freq batches
        self.num_batches_done = 0  # number of batches done for the whole training
        self.epoch_batches_done = 0  # number of batches done for this epoch
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch

    @staticmethod
    def _perplexity(total_loss, total_weights):
        """Return exp(total_loss / total_weights), clamped to inf.

        The mean is only exponentiated while it is below 300; any other
        value (including NaN, which fails the comparison) maps to
        ``float('inf')``.
        """
        mean_loss = total_loss / total_weights
        return numpy.exp(mean_loss) if mean_loss < 300 else float('inf')

    def get_model(self, mode):
        """Build a ``Model`` for ``mode`` inside the shared variable scope.

        Args:
            mode: One of the ``ac`` mode constants. Every mode other than
                ``ac.TRAINING`` opens the scope with ``reuse=True``.

        Returns:
            The constructed ``Model``.
        """
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d, seed=ac.SEED)
        with tf.variable_scope(self.config['model_name'],
                               reuse=reuse,
                               initializer=initializer):
            return Model(self.config, mode)

    def reload_and_get_cpkt_saver(self, config, sess):
        """Create the checkpoint saver and optionally restore a checkpoint.

        Args:
            config: Mapping providing ``'save_to'``, ``'model_name'``,
                ``'n_best'`` and ``'reload'``.
            sess: Session into which an existing checkpoint is restored.

        Sets ``self.cpkt_path`` and ``self.saver``.
        """
        cpkt_path = join(config['save_to'],
                         '{}.cpkt'.format(config['model_name']))
        saver = tf.train.Saver(var_list=tf.trainable_variables(),
                               max_to_keep=config['n_best'] + 1)

        # The checkpoint is detected through its '.meta' companion file
        # rather than through a file at cpkt_path itself.
        if exists(cpkt_path + '.meta') and config['reload']:
            self.logger.info('Reload model from {}'.format(cpkt_path))
            saver.restore(sess, cpkt_path)

        self.cpkt_path = cpkt_path
        self.saver = saver

    def train(self):
        """Train for ``max_epochs`` epochs, then evaluate on the test set.

        After the last epoch the per-epoch perplexities are logged and
        saved to ``train_perps.npy`` under ``config['save_to']`` and a
        final checkpoint is written. If the test ids file exists, the
        checkpoint with the highest recorded BLEU is restored and
        evaluated in ``ac.TESTING`` mode.
        """
        tf.reset_default_graph()
        with tf.Graph().as_default():
            tf.set_random_seed(ac.SEED)
            self.train_m = self.get_model(ac.TRAINING)
            self.dev_m = self.get_model(ac.VALIDATING)

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                self.reload_and_get_cpkt_saver(self.config, sess)
                self.logger.info('Set learning rate to {}'.format(self.lr))
                sess.run(tf.assign(self.train_m.lr, self.lr))

                # range (not xrange) keeps this loop working on Python 3.
                for e in range(self.max_epochs):
                    for b, batch_data in self.data_manager.get_batch(
                            mode=ac.TRAINING, num_batches=self.num_preload):
                        self.run_log_save(sess, b, e, batch_data)
                        self.maybe_validate(sess)

                    self.report_epoch(e)

                self.logger.info('It is finally done, mate!')
                self.logger.info('Train perplexities:')
                self.logger.info(', '.join(map(str, self.train_perps)))
                numpy.save(join(self.config['save_to'], 'train_perps.npy'),
                           self.train_perps)

                self.logger.info('Save final checkpoint')
                self.saver.save(sess, self.cpkt_path)

                # Evaluate on test
                if exists(self.data_manager.ids_files[ac.TESTING]):
                    self.logger.info('Evaluate on test')
                    best_bleu = numpy.max(self.validator.best_bleus)
                    best_cpkt_path = self.validator.get_cpkt_path(best_bleu)
                    self.logger.info(
                        'Restore best cpkt from {}'.format(best_cpkt_path))
                    self.saver.restore(sess, best_cpkt_path)
                    self.validator.evaluate(sess, self.dev_m, mode=ac.TESTING)

    def report_epoch(self, e):
        """Log timing and perplexity for a finished epoch, then reset.

        Args:
            e: Zero-based epoch index; logged as ``e + 1``.

        The epoch accumulators are zeroed and the perplexity is appended
        to ``train_perps``.
        """
        self.logger.info('Finish epoch {}'.format(e + 1))
        self.logger.info('    It takes {}'.format(
            ut.format_seconds(self.epoch_time)))
        self.logger.info('    Avergage # words/second    {}'.format(
            self.epoch_weights / self.epoch_time))
        self.logger.info('    Average seconds/batch    {}'.format(
            self.epoch_time / self.epoch_batches_done))

        # Compute the perplexity before the accumulators are cleared.
        train_perp = self._perplexity(self.epoch_loss, self.epoch_weights)

        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.

        self.train_perps.append(train_perp)
        self.logger.info('    train perplexity: {}'.format(train_perp))

    def sample_input(self, batch_data):
        """Log the first example of a batch in human-readable form.

        Args:
            batch_data: Indexable of five per-batch sequences in the order
                source ids, source lengths, target-input ids, target ids
                and target weights. Only element ``[0]`` of each is used.
        """
        # TODO: more on sampling
        src_sent = batch_data[0][0]
        src_length = batch_data[1][0]
        trg_sent = batch_data[2][0]
        target = batch_data[3][0]
        weight = batch_data[4][0]

        # Map ids back to tokens through the vocabularies kept in __init__.
        src_sent = u' '.join(map(self.src_ivocab.get, src_sent))
        trg_sent = u' '.join(map(self.trg_ivocab.get, trg_sent))
        target = u' '.join(map(self.trg_ivocab.get, target))
        weight = ' '.join(map(str, weight))

        self.logger.info('Sample input data:')
        self.logger.info(u'Src: {}'.format(src_sent))
        self.logger.info(u'Src len: {}'.format(src_length))
        self.logger.info(u'Trg: {}'.format(trg_sent))
        self.logger.info(u'Tar: {}'.format(target))
        self.logger.info(u'W: {}'.format(weight))

    def run_log_save(self, sess, b, e, batch_data):
        """Run one training step, update statistics, log and checkpoint.

        Args:
            sess: Session used to run the training op and to save.
            b: Batch index supplied by the data manager; only logged.
            e: Zero-based epoch index; logged as ``e + 1``.
            batch_data: Five-tuple ``(src_inputs, src_seq_lengths,
                trg_inputs, trg_targets, target_weights)`` fed to the
                training model's placeholders of the same names.
        """
        start = time.time()
        src_inputs, src_seq_lengths, trg_inputs, trg_targets, target_weights = batch_data
        feed = {
            self.train_m.src_inputs: src_inputs,
            self.train_m.src_seq_lengths: src_seq_lengths,
            self.train_m.trg_inputs: trg_inputs,
            self.train_m.trg_targets: trg_targets,
            self.train_m.target_weights: target_weights
        }
        loss, _ = sess.run([self.train_m.loss, self.train_m.train_op], feed)

        # The summed target weights serve as the word count of the batch.
        num_trg_words = numpy.sum(target_weights)
        self.num_batches_done += 1
        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_weights += num_trg_words
        self.log_train_loss += loss
        self.log_train_weights += num_trg_words
        self.epoch_time += time.time() - start

        # Dump a sample ten times less often than the regular statistics.
        if self.num_batches_done % (10 * self.log_freq) == 0:
            self.sample_input(batch_data)

        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done

            avg_word_perp = self._perplexity(self.log_train_loss,
                                             self.log_train_weights)
            # Start a fresh logging window.
            self.log_train_loss = 0.
            self.log_train_weights = 0.

            self.logger.info('Batch {}, epoch {}/{}:'.format(
                b, e + 1, self.max_epochs))
            self.logger.info(
                '   avg word perp:   {0:.2f}'.format(avg_word_perp))
            self.logger.info('   acc trg words/s: {}'.format(
                int(acc_speed_word)))
            self.logger.info(
                '   acc sec/batch:   {0:.2f}'.format(acc_speed_time))

        if self.num_batches_done % self.save_freq == 0:
            start = time.time()
            self.saver.save(sess, self.cpkt_path)
            self.logger.info('Save model to {}, takes {}'.format(
                self.cpkt_path, ut.format_seconds(time.time() - start)))

    def maybe_validate(self, sess):
        """Validate when the batch counter hits a multiple of ``validate_freq``.

        Args:
            sess: Session handed to ``validator.validate_and_save`` together
                with the validation model and the saver.
        """
        if self.num_batches_done % self.validate_freq == 0:
            self.validator.validate_and_save(sess, self.dev_m, self.saver)

    def _call_me_maybe(self):
        """Do nothing; placeholder method."""
        pass  # NO