Example #1
def main(args):
    num_frames = 15
    ms_per_frame = 40

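    # Video-frame prediction set-up: the encoder-decoder is trained with per-pixel MSE against the true future frame.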
    network = EncoderDecoder(args).cuda()
    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr, betas=(0.9, 0.99))
    criterion = nn.MSELoss()

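    # KTH video clips split into train/dev/test loaders.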
    train_loader, dev_loader, test_loader = fetch_kth_data(args)

    # test_tens = next(iter(train_loader))['instance'][0, :, :, :, :].transpose(0, 1)
    # print(test_tens.shape)
    # save_image(test_tens, './img/test_tens.png')
    # print(next(iter(train_loader))['instance'][0, :, 0, :, :].shape)
    train_loss = []
    dev_loss = []
    for epoch in range(args.epochs):
        epoch_loss = 0
        batch_num = 0
        for item in train_loader:
            #label = item['label']
            item = item['instance'].cuda()

            frames_processed = 0
            batch_loss = 0

            # train the batch on every ordered frame pair (i, j), i.e. all the different millisecond offsets
            for i in range(num_frames-1):
                for j in range(i+1, num_frames):
                    network.zero_grad()
                    frame_diff = j - i
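                    # elapsed time in milliseconds between source frame i and target frame j, one value per batch element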
                    time_delta = torch.tensor(frame_diff * ms_per_frame).float().repeat(args.batch_size).cuda()
                    time_delta.requires_grad = True

                    seq = item[:, :, i, :, :]
                    #print(seq.shape)

                    # downsample
                    #seq = F.interpolate(seq, size=(64, 64))
                    #print(seq.shape)

                    seq.requires_grad = True

                    seq_targ = item[:, :, j, :, :]

                    # downsample
                    #seq_targ = F.interpolate(seq_targ, size=(64, 64))

                    seq_targ.requires_grad = False

                    assert seq.requires_grad and time_delta.requires_grad, 'No Gradients'

                    outputs = network(seq, time_delta)
                    error = criterion(outputs, seq_targ)
                    error.backward()
                    optimizer.step()

                    batch_loss += error.cpu().item()
                    frames_processed += 1

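                    # save predictions made from the first source frame so training progress can be inspected visually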
                    if i == 0:
                        save_image(outputs, '/scratch/eecs-share/dinkinst/kth/img/train_output_{}_epoch_{}.png'.format(j, epoch))

            batch_num += 1
            epoch_loss += batch_loss
            print('Epoch {} Batch #{} Total Error {}'.format(epoch, batch_num, batch_loss))
        print('\nEpoch {} Total Loss {} Scaled Loss {}\n'.format(epoch, epoch_loss, epoch_loss/(frames_processed*batch_num)))
        train_loss.append(epoch_loss)
        if epoch % 10 == 0:
            torch.save(network.state_dict(), KTH_PATH + '/model_new_{}.pth'.format(epoch))
            torch.save(optimizer.state_dict(), KTH_PATH + '/optim_new_{}.pth'.format(epoch))

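        # evaluate on the dev split, then switch the network back to training mode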
        dev_loss.append(eval_model(network, dev_loader, epoch))
        network.train()

    plt.plot(range(args.epochs), train_loss)
    plt.grid()
    plt.savefig('/scratch/eecs-share/dinkinst/kth/img/loss_train.png', dpi=64)
    plt.close('all')
    plt.plot(range(args.epochs), dev_loss)
    plt.grid()
    plt.savefig('/scratch/eecs-share/dinkinst/kth/img/loss_dev.png', dpi=64)
    plt.close('all')
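
The loop above calls eval_model, which is not part of this listing. A minimal sketch of what that helper could look like, assuming it mirrors the training loop over frame pairs but runs without gradient updates and returns the accumulated MSE (the signature, frame counts and return value are assumptions inferred from the call site):

def eval_model(network, dev_loader, epoch, num_frames=15, ms_per_frame=40):
    # Hypothetical dev-set evaluation: same frame-pair scheme as training, no optimiser step.
    # `epoch` is kept only for parity with the call site above.
    network.eval()
    criterion = nn.MSELoss()
    total_loss = 0.0
    with torch.no_grad():
        for item in dev_loader:
            item = item['instance'].cuda()
            for i in range(num_frames - 1):
                for j in range(i + 1, num_frames):
                    time_delta = torch.tensor((j - i) * ms_per_frame).float().repeat(item.size(0)).cuda()
                    outputs = network(item[:, :, i, :, :], time_delta)
                    total_loss += criterion(outputs, item[:, :, j, :, :]).cpu().item()
    return total_loss
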
Example #2
def main():
    torch.manual_seed(10)  # fix seed for reproducibility
    torch.cuda.manual_seed(10)

    train_data, train_source_text, train_target_text = create_data(
        os.path.join(train_data_dir, train_dataset), lang)
    #dev_data, dev_source_text, dev_target_text = create_data(os.path.join(eval_data_dir, 'newstest2012_2013'), lang)

    eval_data, eval_source_text, eval_target_text = create_data(
        os.path.join(dev_data_dir, eval_dataset), lang)

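    # embedding lookup matrices taken from the source/target vocabularies, moved to the training device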
    en_emb_lookup_matrix = train_source_text.vocab.vectors.to(device)
    target_emb_lookup_matrix = train_target_text.vocab.vectors.to(device)

    global en_vocab_size
    global target_vocab_size

    en_vocab_size = train_source_text.vocab.vectors.size(0)
    target_vocab_size = train_target_text.vocab.vectors.size(0)

    if verbose:
        print('English vocab size: ', en_vocab_size)
        print(lang, 'vocab size: ', target_vocab_size)
        print_runtime_metric('Vocabs loaded')

    model = EncoderDecoder(en_emb_lookup_matrix, target_emb_lookup_matrix,
                           hidden_size, bidirectional, attention,
                           attention_type, decoder_cell_type).to(device)

    model.encoder.device = device

    criterion = nn.CrossEntropyLoss(
        ignore_index=1
    )  # ignore_index=1 comes from the target_data generation from the data iterator

    #optimiser = torch.optim.Adadelta(model.parameters(), lr=1.0, rho=0.9, eps=1e-06, weight_decay=0) # This is the exact optimiser in the paper; rho=0.95
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)

    best_loss = 10e+10  # large initial value; any real eval loss will improve on it
    best_bleu = 0
    epoch = 1  # initial epoch id

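    # optionally restore model/optimiser state and the sub-epoch position from the latest checkpoint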
    if resume:
        print('\n ---------> Resuming training <----------')
        checkpoint_path = os.path.join(save_dir, 'checkpoint.pth')
        checkpoint = torch.load(checkpoint_path)
        epoch = checkpoint['epoch']
        subepoch, num_subepochs = checkpoint['subepoch_num']
        model.load_state_dict(checkpoint['state_dict'])
        best_loss = checkpoint['best_loss']
        optimiser.load_state_dict(checkpoint['optimiser'])
        is_best = checkpoint['is_best']
        metric_store.load(os.path.join(save_dir, 'checkpoint_metrics.pickle'))

        if subepoch == num_subepochs:
            epoch += 1
            subepoch = 1
        else:
            subepoch += 1

    if verbose:
        print_runtime_metric('Model initialised')

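    # main loop: train sub-epoch by sub-epoch, evaluating and checkpointing after each one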
    while epoch <= num_epochs:
        is_best = False  # best loss or not

        # Initialise the iterators
        train_iter = BatchIterator(train_data,
                                   batch_size,
                                   do_train=True,
                                   seed=epoch**2)

        num_subepochs = train_iter.num_batches // subepoch_size
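        # an epoch is split into sub-epochs of subepoch_size batches each, so training can resume mid-epoch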

        # Start from sub-epoch 1 unless we are resuming mid-epoch from a checkpoint;
        # this allows training to be resumed at sub-epoch granularity.
        if not resume:
            subepoch = 1
        while subepoch <= num_subepochs:

            if verbose:
                print(' Running code on: ', device)

                print('------> Training epoch {}, sub-epoch {}/{} <------'.
                      format(epoch, subepoch, num_subepochs))

            mean_train_loss = train(model, criterion, optimiser, train_iter,
                                    train_source_text, train_target_text,
                                    subepoch, num_subepochs)

            if verbose:
                print_runtime_metric('Training sub-epoch complete')
                print(
                    '------> Evaluating sub-epoch {} <------'.format(subepoch))

            eval_iter = BatchIterator(eval_data,
                                      batch_size,
                                      do_train=False,
                                      seed=325632)

            mean_eval_loss, mean_eval_bleu, _, mean_eval_sent_bleu, _, _ = evaluate(
                model, criterion, eval_iter, eval_source_text.vocab,
                eval_target_text.vocab, train_source_text.vocab,
                train_target_text.vocab)  # here should be the eval data

            if verbose:
                print_runtime_metric('Evaluating sub-epoch complete')

            if mean_eval_loss < best_loss:
                best_loss = mean_eval_loss
                is_best = True

            if mean_eval_bleu > best_bleu:
                best_bleu = mean_eval_bleu
                is_best = True

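            # hyper-parameters stored with every checkpoint so the model can be rebuilt at load time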
            config_dict = {
                'train_dataset': train_dataset,
                'b_size': batch_size,
                'h_size': hidden_size,
                'bidirectional': bidirectional,
                'attention': attention,
                'attention_type': attention_type,
                'decoder_cell_type': decoder_cell_type
            }

            # Save the model and the optimiser state for resumption (after each epoch)
            checkpoint = {
                'epoch': epoch,
                'subepoch_num': (subepoch, num_subepochs),
                'state_dict': model.state_dict(),
                'config': config_dict,
                'best_loss': best_loss,
                'best_BLEU': best_bleu,
                'optimiser': optimiser.state_dict(),
                'is_best': is_best
            }
            torch.save(checkpoint, os.path.join(save_dir, 'checkpoint.pth'))
            metric_store.log(mean_train_loss, mean_eval_loss)
            metric_store.save(
                os.path.join(save_dir, 'checkpoint_metrics.pickle'))

            if verbose:
                print('Checkpoint.')

            # Save the best model so far
            if is_best:
                save_dict = {
                    'state_dict': model.state_dict(),
                    'config': config_dict,
                    'epoch': epoch
                }
                torch.save(save_dict, os.path.join(save_dir, 'best_model.pth'))
                metric_store.save(
                    os.path.join(save_dir, 'best_model_metrics.pickle'))

            if verbose:
                if is_best:
                    print('Best model saved!')
                print(
                    'Ep {} Sub-ep {}/{} Tr loss {} Eval loss {} Eval BLEU {} Eval sent BLEU {}'
                    .format(epoch, subepoch, num_subepochs,
                            round(mean_train_loss, 3),
                            round(mean_eval_loss, 3), round(mean_eval_bleu, 4),
                            round(mean_eval_sent_bleu, 4)))

            subepoch += 1
        subepoch = 1  # reset so the next epoch starts again from its first sub-epoch
        epoch += 1
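
A model saved this way could later be restored for inference roughly as follows; the loader function itself is not in the original code and simply reuses the constructor arguments and config keys shown above:

def load_best_model(save_dir, en_emb_lookup_matrix, target_emb_lookup_matrix):
    # Hypothetical helper: rebuild the model from best_model.pth using the saved config.
    save_dict = torch.load(os.path.join(save_dir, 'best_model.pth'), map_location=device)
    cfg = save_dict['config']
    model = EncoderDecoder(en_emb_lookup_matrix, target_emb_lookup_matrix,
                           cfg['h_size'], cfg['bidirectional'], cfg['attention'],
                           cfg['attention_type'], cfg['decoder_cell_type']).to(device)
    model.encoder.device = device
    model.load_state_dict(save_dict['state_dict'])
    model.eval()
    return model
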