Exemplo n.º 1
0
def question_1i_sanity_check():
    """ Sanity check for nmt_model.py
        basic shape check

    Builds an NMT model over freshly constructed (default) vocabularies,
    runs one forward pass on a two-sentence batch and checks that the output
    has one entry per example along its first dimension.
    """
    print("-" * 80)
    print("Running Sanity Check for Question 1i: NMT")
    print("-" * 80)
    src_vocab_entry = VocabEntry()
    tgt_vocab_entry = VocabEntry()
    dummy_vocab = Vocab(src_vocab_entry, tgt_vocab_entry)
    word_embed_size = 5
    hidden_size = 10

    nmt = NMT(word_embed_size, hidden_size, dummy_vocab)
    # Each sentence must be a list of word tokens, not one string; otherwise
    # every sequence has length 1 and the shape check is trivial.
    source = [["Hello", "my", "friend"], ["How", "are", "you"]]
    # Targets start with `<s>` as in training (train() skips the leading `<s>`
    # when counting words to predict). The closing `</s>` presumably mirrors
    # read_corpus(..., source='tgt') — TODO confirm.
    target = [["<s>", "Bonjour", "mon", "ami", "</s>"],
              ["<s>", "Comment", "vas", "tu", "</s>"]]
    output = nmt.forward(source, target)

    print(output)
    # train() treats the negated forward() output as per-example losses (sums
    # them, then divides by the batch size), so the first dimension is
    # expected to be the batch size — NOTE(review): confirm against NMT.forward.
    assert output.size(0) == len(source), \
        "output should have one entry per example ({}) but has shape {}".format(
            len(source), list(output.size()))
    print("Sanity Check Passed for Question 1i: NMT!")
    print("-" * 80)
Exemplo n.º 2
0
Arquivo: run.py Projeto: aaniin/cs224n
def train(args: Dict):
    """ Train the NMT Model.

    Loads the parallel train/dev corpora and the vocabulary, builds an NMT
    model and optimises it with Adam and gradient-norm clipping. Every
    ``--valid-niter`` iterations the dev perplexity is evaluated. Whenever it
    improves, the model is saved to ``--save-to`` and the optimizer state to
    ``--save-to`` + '.optim'. After ``--patience`` validations without
    improvement, the best checkpoint is reloaded and the learning rate is
    multiplied by ``--lr-decay``. The process exits (status 0) after
    ``--max-num-trial`` such restarts or once ``--max-epoch`` is reached.

    @param args (Dict): args from cmd line
    """
    train_data_src = read_corpus(args['--train-src'], source='src')
    train_data_tgt = read_corpus(args['--train-tgt'], source='tgt')

    dev_data_src = read_corpus(args['--dev-src'], source='src')
    dev_data_tgt = read_corpus(args['--dev-tgt'], source='tgt')

    train_data = list(zip(train_data_src, train_data_tgt))
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    train_batch_size = int(args['--batch-size'])
    clip_grad = float(args['--clip-grad'])
    valid_niter = int(args['--valid-niter'])
    log_every = int(args['--log-every'])
    model_save_path = args['--save-to']
    max_patience = int(args['--patience'])
    max_num_trial = int(args['--max-num-trial'])
    lr_decay = float(args['--lr-decay'])
    max_epoch = int(args['--max-epoch'])

    vocab = Vocab.load(args['--vocab'])

    model = NMT(embed_size=int(args['--embed-size']),
                hidden_size=int(args['--hidden-size']),
                dropout_rate=float(args['--dropout']),
                vocab=vocab)
    model.train()

    uniform_init = float(args['--uniform-init'])
    if np.abs(uniform_init) > 0.:
        print('uniformly initialize parameters [-%f, +%f]' %
              (uniform_init, uniform_init),
              file=sys.stderr)
        for p in model.parameters():
            p.data.uniform_(-uniform_init, uniform_init)

    device = torch.device("cuda:0" if args['--cuda'] else "cpu")
    print('use device: %s' % device, file=sys.stderr)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr']))

    # Set counters
    num_trial = 0
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = report_examples = epoch = valid_num = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()

    # Begin training
    print('begin Maximum Likelihood training')

    while True:
        epoch += 1

        #  Loop over all data in selection batches
        for src_sents, tgt_sents in batch_iter(train_data,
                                               batch_size=train_batch_size,
                                               shuffle=True):

            # Source sentences must be sorted by decreasing length. Sort the
            # (src, tgt) pairs together: sorting the two lists independently
            # would pair source sentences with the wrong translations.
            pairs = sorted(zip(src_sents, tgt_sents),
                           key=lambda e: len(e[0]),
                           reverse=True)
            src_sents = [src for src, _ in pairs]
            tgt_sents = [tgt for _, tgt in pairs]
            # The last batch of an epoch may be smaller than --batch-size.
            batch_size = len(src_sents)

            train_iter += 1
            # Zero out gradients, pytorch accumulates them
            optimizer.zero_grad()

            # Negated model output = per-example losses (assumes forward
            # returns per-example log-likelihoods — TODO confirm). Calling the
            # module rather than .forward() keeps any registered hooks active.
            example_losses = -model(src_sents, tgt_sents)
            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            # Get gradients
            loss.backward()

            # Clip the global gradient norm before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

            optimizer.step()

            # Report progress
            batch_losses_val = batch_loss.item()
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            tgt_words_num_to_predict = sum(
                len(s[1:]) for s in tgt_sents)  # omitting leading `<s>`
            report_tgt_words += tgt_words_num_to_predict
            cum_tgt_words += tgt_words_num_to_predict
            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % log_every == 0:
                print('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f ' \
                        'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec' % (epoch, train_iter,
                                                                                            report_loss / report_examples,
                                                                                            math.exp(report_loss / report_tgt_words),
                                                                                            cum_examples,
                                                                                            report_tgt_words / (time.time() - train_time),
                                                                                            time.time() - begin_time), file=sys.stderr)

                train_time = time.time()
                report_loss = report_tgt_words = report_examples = 0.

            # perform validation
            if train_iter % valid_niter == 0:
                print(
                    'epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f, cum. examples %d'
                    % (epoch, train_iter, cum_loss / cum_examples,
                       np.exp(cum_loss / cum_tgt_words), cum_examples),
                    file=sys.stderr)

                cum_loss = cum_examples = cum_tgt_words = 0.
                valid_num += 1

                print('begin validation ...', file=sys.stderr)

                # compute dev. ppl; dev batch size can be a bit larger
                dev_ppl = evaluate_ppl(model,
                                       dev_data,
                                       batch_size=train_batch_size * 2)
                # Higher is better for the validation metric.
                valid_metric = -dev_ppl

                print('validation: iter %d, dev. ppl %f' %
                      (train_iter, dev_ppl),
                      file=sys.stderr)

                is_better = (not hist_valid_scores
                             or valid_metric > max(hist_valid_scores))
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    print('save currently the best model to [%s]' %
                          model_save_path,
                          file=sys.stderr)
                    model.save(model_save_path)

                    # also save the optimizers' state
                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')
                elif patience < max_patience:
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == max_patience:
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == max_num_trial:
                            print('early stop!', file=sys.stderr)
                            sys.exit(0)

                        # decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]['lr'] * lr_decay
                        print(
                            'load previously best model and decay learning rate to %f'
                            % lr,
                            file=sys.stderr)

                        # Load onto the CPU first so a checkpoint saved on a GPU
                        # also loads on CPU-only machines; see
                        # https://github.com/pytorch/pytorch/issues/7415
                        params = torch.load(model_save_path,
                                            map_location='cpu')
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizers',
                              file=sys.stderr)
                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim',
                                       map_location='cpu'))
                        optimizer_to(optimizer, device)

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # reset patience
                        patience = 0

                # `>=` rather than `==`: this check only runs on validation
                # iterations, which need not occur in every epoch, so an exact
                # match could be skipped and training would never stop.
                if epoch >= max_epoch:
                    print('reached maximum number of epochs!', file=sys.stderr)
                    sys.exit(0)