Example #1
def train_mle(args: Dict):
    train_data_src = read_corpus(args['--train-src'], source='src')
    train_data_tgt = read_corpus(args['--train-tgt'], source='tgt')

    dev_data_src = read_corpus(args['--dev-src'], source='src')
    dev_data_tgt = read_corpus(args['--dev-tgt'], source='tgt')

    train_data = list(zip(train_data_src, train_data_tgt))
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    train_batch_size = int(args['--batch-size'])
    clip_grad = float(args['--clip-grad'])
    valid_niter = int(args['--valid-niter'])
    log_every = int(args['--log-every'])
    notify_slack_every = int(args['--notify-slack-every'])
    model_save_path = args['--save-to']

    vocab = Vocab.load(args['--vocab'])

    model = NMT(embed_size=int(args['--embed-size']),
                hidden_size=int(args['--hidden-size']),
                dropout_rate=float(args['--dropout']),
                input_feed=args['--input-feed'],
                label_smoothing=float(args['--label-smoothing']),
                vocab=vocab)
    model.train()

    uniform_init = float(args['--uniform-init'])
    if np.abs(uniform_init) > 0.:
        print('uniformly initialize parameters [-%f, +%f]' %
              (uniform_init, uniform_init),
              file=sys.stderr)
        for p in model.parameters():
            p.data.uniform_(-uniform_init, uniform_init)

    vocab_mask = torch.ones(len(vocab.tgt))
    vocab_mask[vocab.tgt['<pad>']] = 0

    device = torch.device("cuda:0" if args['--cuda'] else "cpu")
    print('use device: %s' % device, file=sys.stderr)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr']))

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = report_examples = epoch = valid_num = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()

    log_data = {'args': args}  # for logging; used later to inspect training convergence

    _info = f"""
        begin Maximum Likelihood training
        - train: {len(train_data)} pairs
        - dev: {len(dev_data)} pairs, validated every {valid_niter} iter
        - batch size: {train_batch_size}
        - 1 epoch = {len(train_data)} pairs = {int(len(train_data) / train_batch_size)} iter
        - max epoch: {args['--max-epoch']}
    """
    print(_info)
    print(_info, file=sys.stderr)

    _notify_slack_if_need(f"""
    {_info}
    {args}
    """, args)

    while True:
        epoch += 1

        for src_sents, tgt_sents in batch_iter(train_data,
                                               batch_size=train_batch_size,
                                               shuffle=True):
            train_iter += 1

            optimizer.zero_grad()

            batch_size = len(src_sents)

            # (batch_size)
            example_losses = -model(src_sents, tgt_sents)
            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), clip_grad)

            optimizer.step()

            batch_losses_val = batch_loss.item()
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            tgt_words_num_to_predict = sum(
                len(s[1:]) for s in tgt_sents)  # omitting leading `<s>`
            report_tgt_words += tgt_words_num_to_predict
            cum_tgt_words += tgt_words_num_to_predict
            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % log_every == 0 or train_iter % notify_slack_every == 0:
                _report = ('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f '
                           'cum. examples %d, speed %.2f words/sec, '
                           'time elapsed %.2f sec' %
                           (epoch, train_iter,
                            report_loss / report_examples,
                            math.exp(report_loss / report_tgt_words),
                            cum_examples,
                            report_tgt_words / (time.time() - train_time),
                            time.time() - begin_time))
                print(_report, file=sys.stderr)

                _list_dict_update(
                    log_data, {
                        'epoch': epoch,
                        'train_iter': train_iter,
                        'loss': report_loss / report_examples,
                        'ppl': math.exp(report_loss / report_tgt_words),
                        'examples': cum_examples,
                        'speed': report_tgt_words / (time.time() - train_time),
                        'elapsed': time.time() - begin_time
                    }, 'train')

                train_time = time.time()
                report_loss = report_tgt_words = report_examples = 0.

                if train_iter % notify_slack_every == 0:
                    _notify_slack_if_need(_report, args)

            # perform validation
            if train_iter % valid_niter == 0:
                print(
                    'epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d'
                    % (epoch, train_iter, cum_loss / cum_examples,
                       np.exp(cum_loss / cum_tgt_words), cum_examples),
                    file=sys.stderr)

                cum_loss = cum_examples = cum_tgt_words = 0.
                valid_num += 1

                print('begin validation ...', file=sys.stderr)

                # compute dev. ppl and bleu
                dev_ppl, dev_loss = evaluate_ppl(
                    model, dev_data,
                    batch_size=128)  # dev batch size can be a bit larger
                valid_metric, eval_info = evaluate_valid_metric(
                    model, dev_data, dev_ppl, args)

                _report = 'validation: iter %d, dev. ppl %f, dev. %s %f, time elapsed %.2f sec' % (
                    train_iter, dev_ppl, args['--valid-metric'], valid_metric,
                    eval_info['elapsed'])
                print(_report, file=sys.stderr)
                _notify_slack_if_need(_report, args)

                if 'dev_data' not in log_data:
                    # keep a slice of the dev set in the log for later inspection
                    log_data['dev_data'] = dev_data[:int(
                        args['--dev-decode-limit'])]

                _list_dict_update(log_data, {
                    'epoch': epoch,
                    'train_iter': train_iter,
                    'loss': dev_loss,
                    'ppl': dev_ppl,
                    args['--valid-metric']: valid_metric,
                    **eval_info,
                }, 'validation', is_save=True)

                is_better = (len(hist_valid_scores) == 0
                             or valid_metric > max(hist_valid_scores))
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    print('save currently the best model to [%s]' %
                          model_save_path,
                          file=sys.stderr)
                    model.save(model_save_path)

                    # also save the optimizers' state
                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')
                elif patience < int(args['--patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['--patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['--max-num-trial']):
                            _report = 'early stop!'
                            _notify_slack_if_need(_report, args)
                            print(_report, file=sys.stderr)
                            exit(0)

                        # decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]['lr'] * float(
                            args['--lr-decay'])
                        print(
                            'load previously best model and decay learning rate to %f'
                            % lr,
                            file=sys.stderr)

                        # load model
                        params = torch.load(
                            model_save_path,
                            map_location=lambda storage, loc: storage)
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizers',
                              file=sys.stderr)
                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim'))

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # reset patience
                        patience = 0

                if epoch == int(args['--max-epoch']):
                    _report = 'reached maximum number of epochs!'
                    _notify_slack_if_need(_report, args)
                    print(_report, file=sys.stderr)
                    exit(0)
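
Neither example defines the data helper batch_iter that drives the training loop. The sketch below shows, under stated assumptions, what it is taken to do here: shuffle the (src, tgt) pairs once per epoch and yield mini-batches sorted by source length in descending order, the usual precondition for packing padded sequences in the encoder. It is a plausible stand-in for illustration, not necessarily the repository's actual implementation.

import math
from typing import List, Tuple

import numpy as np


def batch_iter(data: List[Tuple[List[str], List[str]]],
               batch_size: int,
               shuffle: bool = False):
    """Minimal sketch (assumption): yield (src_sents, tgt_sents) batches from
    a list of (src, tgt) sentence pairs, each batch sorted by source length."""
    batch_num = math.ceil(len(data) / batch_size)
    index_array = list(range(len(data)))

    if shuffle:
        np.random.shuffle(index_array)

    for i in range(batch_num):
        indices = index_array[i * batch_size:(i + 1) * batch_size]
        examples = sorted((data[idx] for idx in indices),
                          key=lambda e: len(e[0]),
                          reverse=True)
        yield [e[0] for e in examples], [e[1] for e in examples]

Note also that the args dict passed to these functions is docopt-style: every option key starts with '--' and every value arrives as a string, which is why the code wraps numeric options in int(...) / float(...) before use.
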
Example #2
def train_raml(args: Dict):
    train_data_src = read_corpus(args['--train-src'], source='src')
    train_data_tgt = read_corpus(args['--train-tgt'], source='tgt')

    dev_data_src = read_corpus(args['--dev-src'], source='src')
    dev_data_tgt = read_corpus(args['--dev-tgt'], source='tgt')

    train_data = list(zip(train_data_src, train_data_tgt))
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    train_batch_size = int(args['--batch-size'])
    clip_grad = float(args['--clip-grad'])
    valid_niter = int(args['--valid-niter'])
    log_every = int(args['--log-every'])
    notify_slack_every = int(args['--notify-slack-every'])
    model_save_path = args['--save-to']

    vocab = Vocab.load(args['--vocab'])

    model = NMT(embed_size=int(args['--embed-size']),
                hidden_size=int(args['--hidden-size']),
                dropout_rate=float(args['--dropout']),
                input_feed=args['--input-feed'],
                label_smoothing=float(args['--label-smoothing']),
                vocab=vocab)
    model.train()

    # NOTE: RAML
    tau = float(args['--raml-temp'])
    raml_sample_mode = args['--raml-sample-mode']
    raml_sample_size = int(args['--raml-sample-size'])

    uniform_init = float(args['--uniform-init'])
    if np.abs(uniform_init) > 0.:
        print('uniformly initialize parameters [-%f, +%f]' %
              (uniform_init, uniform_init),
              file=sys.stderr)
        for p in model.parameters():
            p.data.uniform_(-uniform_init, uniform_init)

    vocab_mask = torch.ones(len(vocab.tgt))
    vocab_mask[vocab.tgt['<pad>']] = 0

    device = torch.device("cuda:0" if args['--cuda'] else "cpu")
    print('use device: %s' % device, file=sys.stderr)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr']))

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = report_examples = epoch = valid_num = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()

    # NOTE: RAML
    report_weighted_loss = cum_weighted_loss = 0

    # NOTE: RAML: load (or generate) the target samples
    if raml_sample_mode == 'pre_sample':
        # dict of (src, [tgt: (sent, prob)])
        print('read in raml training data...', file=sys.stderr, end='')
        read_begin_time = time.time()
        raml_samples = read_raml_train_data(args['--raml-sample-file'],
                                            temp=tau)
        print('done[%d s].' % (time.time() - read_begin_time), file=sys.stderr)
    else:
        raise Exception(
            f'sampling mode "{raml_sample_mode}" is not implemented yet')

    log_data = {'args': args}  # for logging; used later to inspect training convergence

    _info = f"""
        begin RAML training
        - train: {len(train_data)} pairs
        - dev: {len(dev_data)} pairs, validated every {valid_niter} iter
        - batch size: {train_batch_size}
        - 1 epoch = {len(train_data)} pairs = {int(len(train_data) / train_batch_size)} iter
        - max epoch: {args['--max-epoch']}
    """
    print(_info)
    print(_info, file=sys.stderr)

    _notify_slack_if_need(f"""
    {_info}
    {args}
    """, args)

    while True:
        epoch += 1

        for src_sents, tgt_sents in batch_iter(train_data,
                                               batch_size=train_batch_size,
                                               shuffle=True):
            train_iter += 1

            # NOTE: RAML
            # fetch the pre-generated samples tied to each sentence in
            # src_sents and use them as the training batch
            raml_src_sents = []
            raml_tgt_sents = []
            raml_tgt_weights = []
            if raml_sample_mode == 'pre_sample':
                for src_sent in src_sents:
                    sent = ' '.join(src_sent)
                    tgt_samples_all = raml_samples[sent]
                    # random choice from candidate samples
                    if raml_sample_size >= len(tgt_samples_all):
                        tgt_samples = tgt_samples_all
                    else:
                        tgt_samples_id = np.random.choice(
                            range(1, len(tgt_samples_all)),
                            size=raml_sample_size - 1,
                            replace=False)
                        # [ground truth y*] + samples
                        tgt_samples = [tgt_samples_all[0]] + [
                            tgt_samples_all[i] for i in tgt_samples_id
                        ]

                    raml_src_sents.extend([src_sent] * len(tgt_samples))
                    raml_tgt_sents.extend([['<s>'] + sent.split(' ') +
                                           ['</s>']
                                           for sent, weight in tgt_samples])
                    raml_tgt_weights.extend(
                        [weight for sent, weight in tgt_samples])
            else:
                raise Exception(
                    f'sampling mode "{raml_sample_mode}" is not implemented yet')

            optimizer.zero_grad()

            # NOTE: RAML
            weights_var = torch.tensor(raml_tgt_weights,
                                       dtype=torch.float,
                                       device=device)
            batch_size = len(raml_src_sents)

            # (batch_size)
            unweighted_loss = -model(raml_src_sents, raml_tgt_sents)
            batch_loss = weighted_loss = (unweighted_loss * weights_var).sum()
            loss = batch_loss / batch_size

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), clip_grad)

            optimizer.step()

            # NOTE: RAML
            weighted_loss_val = weighted_loss.item()
            batch_losses_val = unweighted_loss.sum().item()

            report_weighted_loss += weighted_loss_val
            cum_weighted_loss += weighted_loss_val
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            tgt_words_num_to_predict = sum(
                len(s[1:]) for s in tgt_sents)  # omitting leading `<s>`
            report_tgt_words += tgt_words_num_to_predict
            cum_tgt_words += tgt_words_num_to_predict
            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % log_every == 0 or train_iter % notify_slack_every == 0:
                _report = ('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f '
                           'cum. examples %d, speed %.2f words/sec, '
                           'time elapsed %.2f sec' %
                           (epoch, train_iter,
                            report_weighted_loss / report_examples,
                            math.exp(report_loss / report_tgt_words),
                            cum_examples,
                            report_tgt_words / (time.time() - train_time),
                            time.time() - begin_time))
                print(_report, file=sys.stderr)

                _list_dict_update(
                    log_data, {
                        'epoch': epoch,
                        'train_iter': train_iter,
                        'loss': report_loss / report_examples,
                        'ppl': math.exp(report_loss / report_tgt_words),
                        'examples': cum_examples,
                        'speed': report_tgt_words / (time.time() - train_time),
                        'elapsed': time.time() - begin_time
                    }, 'train')

                train_time = time.time()
                report_loss = report_weighted_loss = report_tgt_words = report_examples = 0.

                if train_iter % notify_slack_every == 0:
                    _notify_slack_if_need(_report, args)

            # perform validation
            if train_iter % valid_niter == 0:
                print(
                    'epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d'
                    % (epoch, train_iter, cum_weighted_loss / cum_examples,
                       np.exp(cum_loss / cum_tgt_words), cum_examples),
                    file=sys.stderr)

                cum_loss = cum_weighted_loss = cum_examples = cum_tgt_words = 0.
                valid_num += 1

                print('begin validation ...', file=sys.stderr)

                # compute dev. ppl and bleu
                dev_ppl, dev_loss = evaluate_ppl(
                    model, dev_data,
                    batch_size=128)  # dev batch size can be a bit larger
                valid_metric, eval_info = evaluate_valid_metric(
                    model, dev_data, dev_ppl, args)

                _report = 'validation: iter %d, dev. ppl %f, dev. %s %f, time elapsed %.2f sec' % (
                    train_iter, dev_ppl, args['--valid-metric'], valid_metric,
                    eval_info['elapsed'])
                print(_report, file=sys.stderr)
                _notify_slack_if_need(_report, args)

                if 'dev_data' not in log_data:
                    # keep a slice of the dev set in the log for later inspection
                    log_data['dev_data'] = dev_data[:int(
                        args['--dev-decode-limit'])]

                _list_dict_update(log_data, {
                    'epoch': epoch,
                    'train_iter': train_iter,
                    'loss': dev_loss,
                    'ppl': dev_ppl,
                    args['--valid-metric']: valid_metric,
                    **eval_info,
                }, 'validation', is_save=True)

                is_better = (len(hist_valid_scores) == 0
                             or valid_metric > max(hist_valid_scores))
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    print('save currently the best model to [%s]' %
                          model_save_path,
                          file=sys.stderr)
                    model.save(model_save_path)

                    # also save the optimizers' state
                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')
                elif patience < int(args['--patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['--patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['--max-num-trial']):
                            _report = 'early stop!'
                            _notify_slack_if_need(_report, args)
                            print(_report, file=sys.stderr)
                            exit(0)

                        # decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]['lr'] * float(
                            args['--lr-decay'])
                        print(
                            'load previously best model and decay learning rate to %f'
                            % lr,
                            file=sys.stderr)

                        # load model
                        params = torch.load(
                            model_save_path,
                            map_location=lambda storage, loc: storage)
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizers',
                              file=sys.stderr)
                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim'))

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # reset patience
                        patience = 0

                if epoch == int(args['--max-epoch']):
                    _report = 'reached maximum number of epochs!'
                    _notify_slack_if_need(_report, args)
                    print(_report, file=sys.stderr)
                    exit(0)
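
train_raml additionally relies on read_raml_train_data, which is also not shown. From its call site and from how its output is consumed (a dict keyed by the whitespace-joined source sentence, whose values are (target sentence, weight) pairs with the gold target at index 0), it is assumed to turn a file of pre-generated samples into RAML's exponentiated-payoff weights exp(reward / tau), normalized per source. The file format in the sketch below, one tab-separated candidate per line holding source, target, and reward, with the gold candidate listed first, is purely an assumption for illustration.

from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np


def read_raml_train_data(sample_file: str,
                         temp: float) -> Dict[str, List[Tuple[str, float]]]:
    """Minimal sketch, not the original helper. Assumed input: one candidate
    per line as tab-separated (source, target, reward), gold target first for
    each source. Returns {src: [(tgt, weight), ...]} with weights proportional
    to exp(reward / temp), normalized over each source's candidates."""
    grouped = defaultdict(list)
    with open(sample_file, encoding='utf-8') as f:
        for line in f:
            src, tgt, reward = line.rstrip('\n').split('\t')
            grouped[src].append((tgt, float(reward)))

    raml_samples = {}
    for src, candidates in grouped.items():
        rewards = np.array([r for _, r in candidates], dtype=np.float64)
        weights = np.exp(rewards / temp)
        weights /= weights.sum()
        raml_samples[src] = [(tgt, float(w))
                             for (tgt, _), w in zip(candidates, weights)]
    return raml_samples

Whatever loader is actually used, train_raml depends on two properties visible in the code above: the gold target must sit at index 0 of each candidate list (the "[ground truth y*] + samples" branch prepends it), and tau plays the usual RAML temperature role, with small values concentrating the weight on the gold target and larger values spreading it across the samples.
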