Example #1
def train(args: Dict):
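    # Assumes module-level imports: math, sys, time, numpy as np, torch,
    # plus the project helpers read_corpus, Vocab, NMT, batch_iter and evaluate_ppl.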
    train_data_src = read_corpus(args['--train-src'], source='src')
    train_data_tgt = read_corpus(args['--train-tgt'], source='tgt')

    dev_data_src = read_corpus(args['--dev-src'], source='src')
    dev_data_tgt = read_corpus(args['--dev-tgt'], source='tgt')

    # [(src_0, tgt_0), (src_1, tgt_1), ..., ]
    train_data = list(zip(train_data_src, train_data_tgt))
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    train_batch_size = int(args['--batch-size'])
    clip_grad = float(args['--clip-grad'])
    valid_niter = int(args['--valid-niter'])
    log_every = int(args['--log-every'])
    model_save_path = args['--save-to']

    # vocab = Vocab.load(args['--vocab'])
    vocab = Vocab.build(train_data_src, train_data_tgt,
                        int(args['--vocab-size']), 1)

    model = NMT(embed_size=int(args['--embed-size']),
                hidden_size=int(args['--hidden-size']),
                dropout_rate=float(args['--dropout']),
                vocab=vocab)
    model.train()
    print(model)

    uniform_init = float(args['--uniform-init'])
    if np.abs(uniform_init) > 0.:
        print('uniformly initialize parameters [-%f, +%f]' %
              (uniform_init, uniform_init),
              file=sys.stderr)
        for p in model.parameters():
            p.data.uniform_(-uniform_init, uniform_init)

    # vocab_mask = torch.ones(len(vocab.tgt))
    # vocab_mask[vocab.tgt['<pad>']] = 0

    device = torch.device("cuda:0" if args['--cuda'] else "cpu")
    print('use device: %s' % device, file=sys.stderr)

    model = model.to(device)
    model.save(model_save_path)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr']))

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = report_examples = epoch = valid_num = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()

    print('begin Maximum Likelihood training')

    while True:
        epoch += 1

        for src_sents, tgt_sents in batch_iter(train_data,
                                               batch_size=train_batch_size,
                                               shuffle=True):
            train_iter += 1

            optimizer.zero_grad()

            batch_size = len(src_sents)

            #################### forward pass and compute loss #########################
            # example_losses = -model(src_sents, tgt_sents) # (batch_size,)
            example_losses = model(src_sents, tgt_sents)  # [batch_size,]
            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            #################### backward pass to compute gradients ####################
            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       clip_grad)

            #################### update model parameters ###############################
            optimizer.step()

            #################### do some statistics ####################################
            batch_losses_val = batch_loss.item()
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            tgt_words_num_to_predict = sum(
                len(s[1:]) for s in tgt_sents)  # omitting leading `<s>`
            report_tgt_words += tgt_words_num_to_predict
            cum_tgt_words += tgt_words_num_to_predict
            report_examples += batch_size
            cum_examples += batch_size

            #################### print log #############################################
            if train_iter % log_every == 0:
                print(
                    'epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f '
                    'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec'
                    % (
                        epoch,
                        train_iter,
                        report_loss / report_examples,
                        math.exp(report_loss / report_tgt_words),
                        cum_examples,
                        report_tgt_words / (time.time() - train_time),
                        time.time() - begin_time),
                    file=sys.stderr)

                train_time = time.time()
                report_loss = report_tgt_words = report_examples = 0.

            ##################### perform validation ##################################
            if train_iter % valid_niter == 0:
                print(
                    'epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d'
                    % (epoch, train_iter, cum_loss / cum_examples,
                       np.exp(cum_loss / cum_tgt_words), cum_examples),
                    file=sys.stderr)

                cum_loss = cum_examples = cum_tgt_words = 0.
                valid_num += 1

                print('begin validation ...', file=sys.stderr)

                # compute dev. ppl and bleu
                dev_ppl = evaluate_ppl(
                    model, dev_data,
                    batch_size=128)  # dev batch size can be a bit larger
                valid_metric = -dev_ppl

                print('validation: iter %d, dev. ppl %f' %
                      (train_iter, dev_ppl),
                      file=sys.stderr)

                is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores)
                hist_valid_scores.append(valid_metric)

                # hypotheses = beam_search(model, dev_data_src,
                #                          beam_size=4,
                #                          max_decoding_time_step=10)

                if is_better:
                    patience = 0
                    print('save currently the best model to [%s]' %
                          model_save_path,
                          file=sys.stderr)
                    model.save(model_save_path)

                    # also save the optimizers' state
                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')
                elif patience < int(args['--patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['--patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['--max-num-trial']):
                            print('early stop!', file=sys.stderr)
                            exit(0)

                        # decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]['lr'] * float(
                            args['--lr-decay'])
                        print(
                            'load previously best model and decay learning rate to %f'
                            % lr,
                            file=sys.stderr)

                        # load model
                        params = torch.load(
                            model_save_path,
                            map_location=lambda storage, loc: storage)
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizers',
                              file=sys.stderr)
                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim'))

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # reset patience
                        patience = 0

                if epoch == int(args['--max-epoch']):
                    print('reached maximum number of epochs!', file=sys.stderr)
                    exit(0)
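
Example #1 depends on a batch_iter helper that is not shown. Below is a minimal sketch of such a helper, assuming each training example is a (src_tokens, tgt_tokens) pair and that batches are sorted by source length for the encoder; the original project's implementation may differ.

def batch_iter(data, batch_size, shuffle=False):
    """Yield batches of (src_sents, tgt_sents) from a list of sentence pairs."""
    batch_num = math.ceil(len(data) / batch_size)
    index_array = list(range(len(data)))

    if shuffle:
        np.random.shuffle(index_array)

    for i in range(batch_num):
        indices = index_array[i * batch_size: (i + 1) * batch_size]
        examples = [data[idx] for idx in indices]

        # longest source sentence first, as packed-sequence encoders expect
        examples = sorted(examples, key=lambda e: len(e[0]), reverse=True)
        src_sents = [e[0] for e in examples]
        tgt_sents = [e[1] for e in examples]

        yield src_sents, tgt_sents
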
Example #2
class Trainer:
    """
    训练类,使用训练集训练模型

    Args:
        _hparams (NameSpace): 人为设定的超参数,默认值见config.py,也可以在命令行指定。
    """

    def __init__(self, _hparams):
        self.hparams = _hparams
        set_seed(_hparams.fixed_seed)
        self.train_loader = get_dataloader(_hparams.train_src_path, _hparams.train_dst_path,
                                           _hparams.batch_size, _hparams.num_workers)
        self.src_vocab, self.dst_vocab = load_vocab(_hparams.train_src_pkl, _hparams.train_dst_pkl)
        self.device = torch.device(_hparams.device)
        self.model = NMT(_hparams.embed_size, _hparams.hidden_size,
                         self.src_vocab, self.dst_vocab, self.device,
                         _hparams.dropout_rate).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=_hparams.lr)

    def train(self):
        print('*' * 20, 'train', '*' * 20)
        hist_valid_scores = []
        patience = 0
        num_trial = 0

        for epoch in range(int(self.hparams.max_epochs)):
            self.model.train()

            epoch_loss_val = 0
            epoch_steps = len(self.train_loader)
            for step, data_pairs in tqdm(enumerate(self.train_loader), total=epoch_steps):
                sents = [(dp.src, dp.dst) for dp in data_pairs]
                src_sents, tgt_sents = zip(*sents)

                self.optimizer.zero_grad()

                batch_size = len(src_sents)
                example_losses = -self.model(src_sents, tgt_sents)
                batch_loss = example_losses.sum()
                train_loss = batch_loss / batch_size
                epoch_loss_val += train_loss.item()
                train_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.hparams.clip_gradient)
                self.optimizer.step()

            epoch_loss_val /= epoch_steps
            print('epoch: {}, epoch_loss_val: {}'.format(epoch, epoch_loss_val))

            # perform validation
            if epoch % self.hparams.valid_niter == 0:
                print('*' * 20, 'validate', '*' * 20)
                dev_ppl = evaluate_ppl(self.model, self.hparams.val_src_path, self.hparams.val_dst_path,
                                       self.hparams.batch_val_size, self.hparams.num_workers)
                valid_metric = -dev_ppl

                is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores)
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    print('save currently the best model to {}'.format(self.hparams.model_save_path))
                    self.model.save(self.hparams.model_save_path)
                    torch.save(self.optimizer.state_dict(), self.hparams.optimizer_save_path)
                elif patience < self.hparams.patience:
                    patience += 1
                    print('hit patience %d' % patience)

                    if patience == self.hparams.patience:
                        num_trial += 1
                        print('hit #{} trial'.format(num_trial))
                        if num_trial == self.hparams.max_num_trial:
                            print('early stop!')
                            exit(0)

                        # Compatibility check: Adam adapts its per-parameter step sizes on its own,
                        # so manual lr decay and checkpoint restore only apply to other optimizers.
                        if not isinstance(self.optimizer, torch.optim.Adam):
                            # decay lr, and restore from previously best checkpoint
                            lr = self.optimizer.param_groups[0]['lr'] * self.hparams.lr_decay
                            print('load previously best model and decay learning rate to %f' % lr)

                            params = torch.load(self.hparams.model_save_path, map_location=lambda storage, loc: storage)
                            self.model.load_state_dict(params['state_dict'])
                            self.model = self.model.to(self.device)

                            print('restore parameters of the optimizers')
                            self.optimizer.load_state_dict(torch.load(self.hparams.optimizer_save_path))

                            # set new lr
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = lr

                        # reset patience
                        patience = 0
                print('*' * 20, 'end validate', '*' * 20)
        print('*' * 20, 'end train', '*' * 20)
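
Each example saves with model.save(...) and later restores with params['state_dict']. A minimal sketch of a save method consistent with that restore code is shown below; the attribute names stored under 'args' are assumptions about the NMT class.

    def save(self, path: str):
        """Save the model as a dict with 'args', 'vocab' and 'state_dict' keys,
        matching the params['state_dict'] lookup used when restoring."""
        params = {
            'args': dict(embed_size=self.embed_size,
                         hidden_size=self.hidden_size,
                         dropout_rate=self.dropout_rate),
            'vocab': self.vocab,
            'state_dict': self.state_dict(),
        }
        torch.save(params, path)
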
Example #3
def train(args: Dict):
    """ Train the NMT Model.
    @param args (Dict): args from cmd line
    """
    do_bleu = '--ignore-test-bleu' not in args or not args['--ignore-test-bleu']
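    # NOTE: dev_mode is not defined in this snippet; it is assumed to be a
    # module-level flag (e.g. toggling a small debug subset of the corpus).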
    train_data_src = read_corpus(args['--train-src'],
                                 source='src',
                                 dev_mode=dev_mode)
    train_data_tgt = read_corpus(args['--train-tgt'],
                                 source='tgt',
                                 dev_mode=dev_mode)

    dev_data_src = read_corpus(args['--dev-src'],
                               source='src',
                               dev_mode=dev_mode)
    dev_data_tgt = read_corpus(args['--dev-tgt'],
                               source='tgt',
                               dev_mode=dev_mode)

    if do_bleu:
        test_data_src = read_corpus(args['--test-src'],
                                    source='src',
                                    dev_mode=dev_mode)
        test_data_tgt = read_corpus(args['--test-tgt'],
                                    source='tgt',
                                    dev_mode=dev_mode)

    train_data = list(zip(train_data_src, train_data_tgt))
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    max_tokens_in_sentence = int(args['--max-decoding-time-step'])
    train_data = clean_data(train_data, max_tokens_in_sentence)
    dev_data = clean_data(dev_data, max_tokens_in_sentence)

    train_batch_size = int(args['--batch-size'])
    dev_batch_size = 128
    clip_grad = float(args['--clip-grad'])
    valid_niter = int(args['--valid-niter'])
    bleu_niter = int(args['--bleu-niter'])
    log_every = int(args['--log-every'])
    model_save_path = args['--save-to']

    vocab = Vocab.load(args['--vocab'], args['--word_freq'])

    model = NMT(embed_size=int(args['--embed-size']),
                hidden_size=int(args['--hidden-size']),
                dropout_rate=float(args['--dropout']),
                vocab=vocab)
    writer = SummaryWriter()

    # model = TransformerNMT(vocab, num_hidden_layers=3)

    model.train()

    uniform_init = float(args['--uniform-init'])
    if np.abs(uniform_init) > 0.:
        print('uniformly initialize parameters [-%f, +%f]' %
              (uniform_init, uniform_init),
              file=sys.stderr)
        for p in model.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)
            else:
                p.data.uniform_(-uniform_init, uniform_init)

    vocab_mask = torch.ones(len(vocab.tgt))
    vocab_mask[vocab.tgt['<pad>']] = 0

    device = torch.device("cuda:0" if args['--cuda'] else "cpu")
    print('use device: %s' % device, file=sys.stderr)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr']))

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = report_examples = epoch = valid_num = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()

    print("Sorting dataset based on difficulty...")
    dataset = (train_data, dev_data)
    ordered_dataset = load_order(args['--order-name'], dataset, vocab)
    # TODO: order = balance_order(order, dataset)
    (train_data, dev_data) = ordered_dataset

    visualize_scoring_examples = False
    if visualize_scoring_examples:
        visualize_scoring(ordered_dataset, vocab)

    n_iters = math.ceil(len(train_data) / train_batch_size)
    print("n_iters per epoch is {}: ({} / {})".format(n_iters, len(train_data),
                                                      train_batch_size))
    max_epoch = int(args['--max-epoch'])
    max_iters = max_epoch * n_iters

    print('begin Maximum Likelihood training')
    print('Using order function: {}'.format(args['--order-name']))
    print('Using pacing function: {}'.format(args['--pacing-name']))
    while True:
        epoch += 1
        for _ in range(n_iters):
            # Get pacing data according to train_iter
            current_train_data, current_dev_data = pacing_data(
                train_data,
                dev_data,
                time=train_iter,
                warmup_iters=int(args["--warmup-iters"]),
                method=args['--pacing-name'],
                tb=writer)

            # Uniformly sample batches from the paced dataset
            src_sents, tgt_sents = get_pacing_batch(
                current_train_data, batch_size=train_batch_size, shuffle=True)

            train_iter += 1

            optimizer.zero_grad()

            batch_size = len(src_sents)

            example_losses = -model(src_sents, tgt_sents)  # (batch_size,)
            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            loss.backward()
            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       clip_grad)

            optimizer.step()

            batch_losses_val: float = batch_loss.item()
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            tgt_words_num_to_predict = sum(
                len(s[1:]) for s in tgt_sents)  # omitting leading `<s>`
            report_tgt_words += tgt_words_num_to_predict
            cum_tgt_words += tgt_words_num_to_predict
            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % log_every == 0:
                print(
                    'epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f '
                    'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec'
                    % (epoch, train_iter, report_loss / report_examples,
                       math.exp(report_loss / report_tgt_words), cum_examples,
                       report_tgt_words /
                       (time.time() - train_time), time.time() - begin_time),
                    file=sys.stderr)
                writer.add_scalar('Loss/train', report_loss / report_examples,
                                  train_iter)
                writer.add_scalar('ppl/train',
                                  math.exp(report_loss / report_tgt_words),
                                  train_iter)
                train_time = time.time()
                report_loss = report_tgt_words = report_examples = 0.

            # evaluate BLEU
            if train_iter % bleu_niter == 0 and do_bleu:
                bleu = decode_with_params(
                    model, test_data_src, test_data_tgt,
                    int(args['--beam-size']),
                    int(args['--max-decoding-time-step']))
                writer.add_scalar('bleu/test', bleu, train_iter)

            # perform validation
            if train_iter % valid_niter == 0:
                print(
                    'epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d'
                    % (epoch, train_iter, cum_loss / cum_examples,
                       np.exp(cum_loss / cum_tgt_words), cum_examples),
                    file=sys.stderr)

                cum_loss = cum_examples = cum_tgt_words = 0.
                valid_num += 1

                print('begin validation ...', file=sys.stderr)

                # compute dev. ppl and bleu
                # dev batch size can be a bit larger
                dev_ppl = evaluate_ppl(model,
                                       current_dev_data,
                                       batch_size=dev_batch_size)
                valid_metric = -dev_ppl
                writer.add_scalar('ppl/valid', dev_ppl, train_iter)

                print('validation: iter %d, dev. ppl %f' %
                      (train_iter, dev_ppl),
                      file=sys.stderr)

                is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores)
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    print('save currently the best model to [%s]' %
                          model_save_path,
                          file=sys.stderr)
                    model.save(model_save_path)

                    # also save the optimizers' state
                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')
                elif patience < int(args['--patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['--patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['--max-num-trial']):
                            print('early stop!', file=sys.stderr)
                            exit(0)

                        # decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]['lr'] * \
                            float(args['--lr-decay'])
                        print(
                            'load previously best model and decay learning rate to %f'
                            % lr,
                            file=sys.stderr)

                        # load model
                        params = torch.load(
                            model_save_path,
                            map_location=lambda storage, loc: storage)
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizers',
                              file=sys.stderr)
                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim'))

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # reset patience
                        patience = 0

                if epoch >= int(args['--max-epoch']):
                    print('reached maximum number of epochs!', file=sys.stderr)
                    exit(0)
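
All three examples rely on an evaluate_ppl helper that is not shown. The sketch below follows the (model, dev_data, batch_size) signature used in Examples #1 and #3; Example #2 passes file paths and builds its own dataloader instead. The sign of the loss is an assumption: drop the minus if the model already returns per-sentence losses (Example #1) rather than log-likelihoods (Example #3).

def evaluate_ppl(model, dev_data, batch_size=32):
    """Compute corpus-level perplexity of dev_data under the model."""
    was_training = model.training
    model.eval()

    cum_loss = 0.
    cum_tgt_words = 0.

    with torch.no_grad():
        for src_sents, tgt_sents in batch_iter(dev_data, batch_size):
            loss = -model(src_sents, tgt_sents).sum()

            cum_loss += loss.item()
            tgt_words_num_to_predict = sum(len(s[1:]) for s in tgt_sents)  # omit leading <s>
            cum_tgt_words += tgt_words_num_to_predict

        ppl = np.exp(cum_loss / cum_tgt_words)

    if was_training:
        model.train()

    return ppl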