Example no. 1

import math
import os

import torch
from torch import optim
# SummaryWriter may come from tensorboardX in the original project; the
# torch.utils.tensorboard import is assumed here.
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange

# loadPrepareData, batchify, AspectModel, ReviewModel, train, evaluate and
# evaluateRandomly are assumed to be imported from the project's own modules;
# a sketch of adjust_learning_rate is given at the end of this example.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def runTest(corpus, rnn_layers, hidden_size, embed_size, node_size, capsule_size, gcn_layers, gcn_filters, capsule_num,
            saved_aspect_model, saved_review_model, beam_size, max_length, min_length, save_dir):
    vocabs, train_pairs, valid_pairs, test_pairs = loadPrepareData(corpus, save_dir)

    print('Building aspect model ...')
    aspect_model = AspectModel(vocabs, embed_size, node_size, hidden_size, capsule_size,
                               gcn_layers, gcn_filters, rnn_layers, capsule_num).to(device)

    print('Building review model ...')
    review_model = ReviewModel(vocabs, embed_size, node_size, hidden_size, rnn_layers).to(device)
    
    checkpoint = torch.load(saved_aspect_model, map_location=device)
    aspect_model.load_state_dict(checkpoint['aspect_model'])

    checkpoint = torch.load(saved_review_model, map_location=device)
    review_model.load_state_dict(checkpoint['review_model'])

    # Switch to evaluation mode; this only affects layers such as dropout
    # and batch norm, which behave differently at inference time.
    aspect_model.eval()
    review_model.eval()

    evaluateRandomly(aspect_model, review_model, vocabs, test_pairs, len(test_pairs), beam_size,
                     max_length, min_length, save_dir)
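
# A hypothetical invocation of runTest; every path and hyperparameter below is
# an illustrative placeholder, not a value taken from the original project:
#
#   runTest(corpus='reviews', rnn_layers=2, hidden_size=512, embed_size=300,
#           node_size=300, capsule_size=128, gcn_layers=2, gcn_filters=64,
#           capsule_num=10, saved_aspect_model='save/model/aspect.tar',
#           saved_review_model='save/model/review.tar', beam_size=4,
#           max_length=100, min_length=5, save_dir='save')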


def trainIters(corpus, learning_rate, lr_decay_epoch, lr_decay_ratio, weight_decay, batch_size, rnn_layers,
               hidden_size, embed_size, node_size, epochs, save_dir, load_file=None):

    print('Loading data ...')
    vocabs, train_pairs, valid_pairs, test_pairs = loadPrepareData(corpus, save_dir)
    print('Finished loading data ...')

    data_path = os.path.join(save_dir, "batches")
    if not os.path.exists(data_path):
        os.makedirs(data_path)

    corpus_name = corpus
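    # Cache the batchified data on disk so repeated runs can skip batchify;
    # torch.load raises FileNotFoundError when the cache does not exist yet.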
    try:
        training_batches = torch.load(os.path.join(data_path, '{}_{}.tar'.format('training_batches', batch_size)))
    except FileNotFoundError:
        print('Training pairs not found, generating ...')
        training_batches = batchify(train_pairs, batch_size, vocabs)
        print('Complete building training pairs ...')
        torch.save(training_batches, os.path.join(data_path, '{}_{}.tar'.format('training_batches', batch_size)))

    # Validation data
    eval_batch_size = 10
    try:
        val_batches = torch.load(os.path.join(data_path, '{}_{}.tar'.format('val_batches', eval_batch_size)))
    except FileNotFoundError:
        print('Validation pairs not found, generating ...')
        val_batches = batchify(valid_pairs, eval_batch_size, vocabs)
        print('Complete building validation pairs ...')
        torch.save(val_batches, os.path.join(data_path, '{}_{}.tar'.format('val_batches', eval_batch_size)))

    print('Building review model ...')
    review_model = ReviewModel(vocabs, embed_size, node_size, hidden_size, rnn_layers).to(device)

    print('Building optimizers ...')
    review_optimizer = optim.Adam(review_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    print('Initializing ...')
    global_step = 1
    last_epoch = 1
    perplexities = []
    losses = []
    best_val_loss = None

    log_path = os.path.join('ckpt', corpus_name)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    writer = SummaryWriter(log_path)

    # Optionally resume from a checkpoint, replaying the logged history into
    # TensorBoard so the curves stay continuous across runs.
    if load_file:
        checkpoint = torch.load(load_file, map_location=device)
        review_model.load_state_dict(checkpoint['review_model'])
        global_step = checkpoint['global_step']
        last_epoch = checkpoint['epoch'] + 1
        perplexities = checkpoint['perplexity']
        losses = checkpoint['loss']
        for i in range(len(losses)):
            writer.add_scalar("Train/loss", losses[i], i)
            writer.add_scalar("Train/perplexity", perplexities[i], i)

    for epoch in tqdm(range(last_epoch, epochs+1), desc="Epoch: ", leave=True):

        # Train for one epoch.
        review_model.train()

        tr_loss = 0
        steps = trange(len(training_batches), desc="Train Loss")
        for step in steps:
            context_input, aspect_input, review_input, review_output, extend_input = training_batches[step]

            loss = train(context_input, aspect_input, review_input, review_output, extend_input,
                         review_model, review_optimizer)

            global_step += 1
            tr_loss += loss

            losses.append(loss)
            perplexities.append(math.exp(loss))

            writer.add_scalar("Train/loss", loss, global_step)
            writer.add_scalar("Train/perplexity", math.exp(loss), global_step)

            steps.set_description("ReviewModel (Loss=%g, PPL=%g)" % (round(loss, 4), round(math.exp(loss), 4)))

        cur_loss = tr_loss / len(training_batches)
        cur_ppl = math.exp(cur_loss)

        print('\nTrain | Epoch: {:3d} | Avg Loss={:4.4f} | Avg PPL={:4.4f}\n'.format(epoch, cur_loss, cur_ppl))

        # Evaluate on the validation set.
        review_model.eval()
        with torch.no_grad():
            vl_loss = 0
            for val_batch in val_batches:
                context_input, aspect_input, review_input, review_output, extend_input = val_batch

                loss = evaluate(context_input, aspect_input, review_input, review_output,
                                extend_input, review_model)

                vl_loss += loss
            vl_loss /= len(val_batches)
            vl_ppl = math.exp(vl_loss)

        writer.add_scalar("Valid/loss", vl_loss, global_step)
        writer.add_scalar("Valid/perplexity", vl_ppl, global_step)

        print('\nValid | Epoch: {:3d} | Avg Loss={:4.4f} | Avg PPL={:4.4f}\n'.format(epoch, vl_loss, vl_ppl))

        # Save the model if the validation loss is the best we've seen so far;
        # otherwise stop early.
        model_path = os.path.join(save_dir, "model")
        if best_val_loss is None or vl_loss < best_val_loss:
            directory = os.path.join(model_path, '{}_{}_{}'.format(batch_size, hidden_size, rnn_layers))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'global_step': global_step,
                'epoch': epoch,
                'review_model': review_model.state_dict(),
                'loss': losses,
                'perplexity': perplexities
            }, os.path.join(directory, '{}_{}_{}.tar'.format(epoch, round(vl_loss, 4), 'review_model')))
            best_val_loss = vl_loss
        elif vl_loss > best_val_loss:
            print('Validation loss did not improve on the best so far; stopping early.')
            break

        # Adjust the learning rate based on the number of epochs completed in
        # this run.
        adjust_learning_rate(review_optimizer, epoch - last_epoch + 1, learning_rate, lr_decay_epoch, lr_decay_ratio)
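

# adjust_learning_rate is called above but not defined in this example. Below
# is a minimal sketch, assuming a standard step decay (scale the base rate by
# lr_decay_ratio every lr_decay_epoch epochs); the original project may use a
# different schedule.
def adjust_learning_rate(optimizer, epoch, init_lr, lr_decay_epoch, lr_decay_ratio):
    # The number of completed decay periods determines the exponent.
    lr = init_lr * (lr_decay_ratio ** (epoch // lr_decay_epoch))
    # Apply the decayed rate to every parameter group of the optimizer.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


# A hypothetical call to trainIters (placeholders, not project defaults):
#
#   trainIters(corpus='reviews', learning_rate=1e-3, lr_decay_epoch=5,
#              lr_decay_ratio=0.5, weight_decay=1e-5, batch_size=32,
#              rnn_layers=2, hidden_size=512, embed_size=300, node_size=300,
#              epochs=20, save_dir='save')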