Example #1
def test():
    # Evaluate the checkpoint from the best validation epoch on the held-out
    # test ratings.
    model.load_state_dict(
        torch.load(os.path.join(args.model_path, 'w-%d.pkl' % (best_epoch))))
    model.eval()
    with torch.no_grad():
        u = torch.from_numpy(np.array(range(num_users))).to(device)
        v = torch.from_numpy(np.array(range(num_items))).to(device)
        output, m_hat, _ = model(u, v, rating_test)
        loss_ce, loss_rmse = compute_loss(rating_test, u, v, output, m_hat)

    print('[test loss] : ' + str(loss_ce.item()) + ' [test rmse] : ' +
          str(loss_rmse.item()))
Example #2
def train(config):
    print("Thank you for choosing the Sentence VAE today!")
    print(config)
    print()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset, data_loader = load_dataset(config,
                                        type_='train',
                                        dropout_rate=config.word_dropout)
    dataset_test_eval, data_loader_test_eval = load_dataset(
        config, type_='test', sorted_words=dataset.sorted_words)
    dataset_validation_eval, data_loader_validation_eval = load_dataset(
        config, type_='validation', sorted_words=dataset.sorted_words)
    dataset_train_eval, data_loader_train_eval = load_dataset(
        config, type_='train_eval', sorted_words=dataset.sorted_words)

    print("Size of train dataset: %d" % len(dataset))
    print("Size of test dataset: %d" % len(dataset_test_eval))

    model = SentVAE(dataset.vocab_size, config.embedding_size, config.num_hidden,
                    config.latent_size, config.num_layers,
                    dataset.word_2_idx(dataset.PAD), dataset.word_2_idx(dataset.SOS),
                    config.word_dropout, device)
    model.to(device)

    if config.generate:
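        # Generation-only mode: restore a saved checkpoint, decode one batch with
        # non-greedy and greedy decoding, print the inputs and both outputs, then exit.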
        model.load_state_dict(torch.load('trained_models/vae-model-39819.pt'))
        markdown_str = ''

        model.eval()
        for step, (batch_inputs, batch_targets, masks,
                   lengths) in enumerate(data_loader):
            batch_inputs = batch_inputs.t().to(device)
            batch_targets = batch_targets.t().to(device)
            masks = masks.t().to(device)
            lengths = lengths.to(device)

            input_sample = data_loader.print_batch(batch_inputs.t(),
                                                   stop_after_EOS=True)
            print(input_sample)

            predictions, mu, sigma = model.forward(batch_inputs,
                                                   lengths,
                                                   greedy=False,
                                                   sample=True)
            predicted_targets = predictions.argmax(dim=-1)
            non_greedy_sample = data_loader.print_batch(predicted_targets.t(),
                                                        stop_after_EOS=True)
            print(non_greedy_sample)

            predictions, mu, sigma = model.forward(batch_inputs,
                                                   lengths,
                                                   greedy=True,
                                                   sample=True)
            predicted_targets = predictions.argmax(dim=-1)
            greedy_sample = data_loader.print_batch(predicted_targets.t(),
                                                    stop_after_EOS=True)
            print(greedy_sample)
            break
        for i, ng, g in zip(input_sample, non_greedy_sample, greedy_sample):
            print('input: ', i)
            print('non-greedy: ', ng)
            print('greedy: ', g)
            print()

        exit()

    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.learning_rate_step,
        gamma=config.learning_rate_decay)

    loss_sum, loss_kl_sum, loss_ce_sum, accuracy_sum = 0, 0, 0, 0

    current_epoch = -1
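    # Main training loop: iterate over batches until config.epochs full passes
    # over the data have been made (see the break at the bottom of the loop).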
    for step, (batch_inputs, batch_targets, masks,
               lengths) in enumerate(data_loader):
        optimizer.zero_grad()
        batch_inputs = batch_inputs.t().to(device)
        batch_targets = batch_targets.t().to(device)
        masks = masks.t().to(device)
        lengths = lengths.to(device)

        predictions, mu, sigma = model.forward(batch_inputs, lengths)

        predicted_targets = predictions.argmax(dim=-1)

        accuracy = metrics.ACC(predicted_targets, batch_targets, masks,
                               lengths)

        ce_loss = metrics.compute_loss(
            predictions.transpose(1, 0).contiguous(),
            batch_targets.t().contiguous(), masks.t())
        kl_loss = metrics.KL(mu, sigma).mean()

        # KL annealing
        annealing_steps = config.annealing_end - config.annealing_start
        if annealing_steps > 0:
            annealing_frac = max(0.0, data_loader.epoch -
                                 config.annealing_start) / annealing_steps
        else:
            annealing_frac = 1.0

        kl_scale = torch.FloatTensor(1).fill_(min(
            1.0, annealing_frac**2)).to(device)
        kl_loss = kl_scale * kl_loss

        # free bits
        if config.free_bits:
            kl_loss = torch.max(
                torch.FloatTensor(1).fill_(config.free_bits).to(device),
                kl_loss)

        # ELBO
        loss = ce_loss + kl_loss

        loss.backward()
        optimizer.step()
        #scheduler.step()

        loss_sum += loss.item()
        loss_kl_sum += kl_loss.item()
        loss_ce_sum += ce_loss.item()
        accuracy_sum += accuracy.item()

        if step % config.print_every == 0:
            print("Epoch: %2d      STEP %4d     Accuracy: %.3f   Total-loss: %.3f    CE-loss: %.3f   KL-loss: %.3f" %\
                (data_loader.epoch, step, accuracy_sum/config.print_every, loss_sum/config.print_every, loss_ce_sum/config.print_every, loss_kl_sum/config.print_every))

            loss_sum, loss_kl_sum, loss_ce_sum, accuracy_sum = 0, 0, 0, 0

        if step % config.sample_every == 0:
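            # Periodically print target/prediction pairs, free-running samples, and a
            # latent-space interpolation.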

            targets = data_loader.print_batch(batch_targets.t())
            predictions = data_loader.print_batch(predicted_targets.t())
            for i in range(len(targets)):
                print("----------------------------")
                print(targets[i])
                print()
                print(predictions[i])
                print()

            print("%s\nSAMPLES:" % ("-" * 60))
            sample = model.sample()
            sample = data_loader.print_batch(sample.t(), stop_after_EOS=True)
            for s in sample:
                print(s)
                print()

            print("%s\nINTERPOLATION:" % ("-" * 60))
            result = model.interpolation(n_steps=10)
            result = data_loader.print_batch(result.t(), stop_after_EOS=True)
            for s in result:
                print(s)
                print()

        # if step % 5000 == 0:
        if data_loader.epoch != current_epoch:
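            # At each epoch boundary: evaluate on the train/validation/test splits,
            # log metrics and fresh samples via `writer`, and save a checkpoint.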
            current_epoch = data_loader.epoch

            eval_acc, eval_ppl, eval_ll = evaluate(model,
                                                   data_loader_test_eval,
                                                   dataset_test_eval, device)
            val_acc, val_ppl, val_ll = evaluate(model,
                                                data_loader_validation_eval,
                                                dataset_validation_eval,
                                                device)
            train_acc, train_ppl, train_ll = evaluate(model,
                                                      data_loader_train_eval,
                                                      dataset_train_eval,
                                                      device)

            print("Train accuracy-perplexity_likelihood: %.3f %.3f %.3f" %
                  (eval_acc, eval_ppl, eval_ll))
            print("Test accuracy-perplexity-likelihood: %.3f %.3f %.3f" %
                  (train_acc, train_ppl, train_ll))
            print("Validation accuracy-perplexity-likelihood: %.3f %.3f %.3f" %
                  (val_acc, val_ppl, val_ll))

            writer.add_scalar('SVAE/KL Loss', kl_loss.item(), current_epoch)
            writer.add_scalar('SVAE/ELBO', loss.item(), current_epoch)

            writer.add_scalar('SVAE/Train accuracy', train_acc, current_epoch)
            writer.add_scalar('SVAE/Train perplexity', train_ppl,
                              current_epoch)
            writer.add_scalar('SVAE/Train likelihood', train_ll, current_epoch)

            writer.add_scalar('SVAE/Test accuracy', eval_acc, current_epoch)
            writer.add_scalar('SVAE/Test perplexity', eval_ppl, current_epoch)
            writer.add_scalar('SVAE/Test likelihood', eval_ll, current_epoch)

            writer.add_scalar('SVAE/Valid accuracy', val_acc, current_epoch)
            writer.add_scalar('SVAE/Valid perplexity', val_ppl, current_epoch)
            writer.add_scalar('SVAE/Valid likelihood', val_ll, current_epoch)

            markdown_str = ''
            sample = model.sample()
            sample = data_loader.print_batch(sample.t(), stop_after_EOS=True)
            for s in sample:
                markdown_str += '{}  \n'.format(s)
            writer.add_text('SVAE/Samples', markdown_str, current_epoch)

            markdown_str = ''
            result = model.interpolation(n_steps=10)
            result = data_loader.print_batch(result.t(), stop_after_EOS=True)
            for s in result:
                markdown_str += '{}  \n'.format(s)
            writer.add_text('SVAE/Interpolation', markdown_str, current_epoch)

            torch.save(model.state_dict(), 'models/vae-model-%d.pt' % step)

        if data_loader.epoch == config.epochs:
            break
Example #3
def train(config):
    print("Thank you for choosing the RNNLM today!")
    print(config)
    print()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset, data_loader = load_dataset(config, type_='train')
    dataset_test_eval, data_loader_test_eval = load_dataset(
        config, type_='test', sorted_words=dataset.sorted_words)
    dataset_validation_eval, data_loader_validation_eval = load_dataset(
        config, type_='validation', sorted_words=dataset.sorted_words)
    dataset_train_eval, data_loader_train_eval = load_dataset(
        config, type_='train_eval', sorted_words=dataset.sorted_words)

    print("Size of train dataset: %d" % len(dataset))
    print("Size of test dataset: %d" % len(dataset_test_eval))

    model = RNNLM(dataset.vocab_size, config.embedding_size, config.num_hidden,
                  config.num_layers, dataset.word_2_idx(dataset.PAD), device)
    model.to(device)

    if config.generate:
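        # Generation-only mode: restore a saved checkpoint, draw multinomial samples,
        # log them via `writer`, and exit.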
        model.load_state_dict(torch.load('trained_models/rnn-model-20.pt'))
        markdown_str = ''
        sample = model.sample(dataset.word_2_idx(dataset.SOS),
                              30,
                              10,
                              sample=True)
        sample = data_loader.print_batch(sample, stop_after_EOS=True)
        print()
        for i in range(len(sample)):
            markdown_str += '{}  \n'.format(sample[i])
            print(sample[i])

        writer.add_text('RNNLM/Multinomial Samples', markdown_str, 20)
        exit()

    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.learning_rate_step,
        gamma=config.learning_rate_decay)

    loss_ce_sum, accuracy_sum = 0, 0

    current_epoch = -1

    for step, (batch_inputs, batch_targets, masks,
               lengths) in enumerate(data_loader):
        optimizer.zero_grad()

        batch_inputs = batch_inputs.t().to(device)
        batch_targets = batch_targets.t().to(device)

        masks = masks.t().to(device)
        lengths = lengths.to(device)

        predictions = model.forward(batch_inputs, lengths)

        predicted_targets = predictions.argmax(dim=-1)

        accuracy = metrics.ACC(predicted_targets, batch_targets, masks,
                               lengths)

        loss = metrics.compute_loss(
            predictions.transpose(1, 0).contiguous(),
            batch_targets.t().contiguous(), masks.t())

        loss.backward()
        # Clip gradients after the backward pass, before the optimizer update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
        optimizer.step()
        #scheduler.step()

        loss_ce_sum += loss.item()
        accuracy_sum += accuracy.item()

        if step % config.print_every == 0:
            print("Epoch: %2d   STEP %4d    Accuracy: %.3f   CE-loss: %.3f" %\
                (data_loader.epoch, step, accuracy_sum/config.print_every, loss_ce_sum/config.print_every))

            loss_ce_sum, accuracy_sum = 0, 0

        if step % config.sample_every == 0:
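            # Periodically print target/prediction pairs and a few fresh samples
            # from the model.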
            predictions = data_loader.print_batch(predicted_targets.t())
            targets = data_loader.print_batch(batch_targets.t())
            for i in range(len(targets)):
                print("-----------------------")
                print(targets[i])
                print()
                print(predictions[i])

            sample = model.sample(dataset.word_2_idx(dataset.SOS), 30)
            sample = data_loader.print_batch(sample, stop_after_EOS=True)
            print()
            for i in range(len(sample)):
                print(sample[i])

        # if step % 5000 == 0:
        if data_loader.epoch != current_epoch:
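            # At each epoch boundary: evaluate on the train/validation/test splits,
            # log metrics and samples via `writer`, and save a checkpoint.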
            current_epoch = data_loader.epoch
            eval_acc, eval_ppl, eval_ll = evaluate(model,
                                                   data_loader_test_eval,
                                                   dataset_test_eval, device)
            val_acc, val_ppl, val_ll = evaluate(model,
                                                data_loader_validation_eval,
                                                dataset_validation_eval,
                                                device)
            train_acc, train_ppl, train_ll = evaluate(model,
                                                      data_loader_train_eval,
                                                      dataset_train_eval,
                                                      device)

            print("Train accuracy-perplexity_likelihood: %.3f %.3f %.3f" %
                  (eval_acc, eval_ppl, eval_ll))
            print("Test accuracy-perplexity-likelihood: %.3f %.3f %.3f" %
                  (train_acc, train_ppl, train_ll))
            print("Validation accuracy-perplexity-likelihood: %.3f %.3f %.3f" %
                  (val_acc, val_ppl, val_ll))

            writer.add_scalar('RNNLM/Train accuracy', train_acc, current_epoch)
            writer.add_scalar('RNNLM/Train perplexity', train_ppl,
                              current_epoch)
            writer.add_scalar('RNNLM/Train likelihood', train_ll,
                              current_epoch)

            writer.add_scalar('RNNLM/Test accuracy', eval_acc, current_epoch)
            writer.add_scalar('RNNLM/Test perplexity', eval_ppl, current_epoch)
            writer.add_scalar('RNNLM/Test likelihood', eval_ll, current_epoch)

            writer.add_scalar('RNNLM/Valid accuracy', val_acc, current_epoch)
            writer.add_scalar('RNNLM/Valid perplexity', val_ppl, current_epoch)
            writer.add_scalar('RNNLM/Valid likelihood', val_ll, current_epoch)

            sample = model.sample(dataset.word_2_idx(dataset.SOS), 30)
            sample = data_loader.print_batch(sample, stop_after_EOS=True)

            markdown_str = ''
            for i in range(len(sample)):
                markdown_str += '{}  \n'.format(sample[i])
            writer.add_text('RNNLM/Samples', markdown_str, current_epoch)

            torch.save(model.state_dict(),
                       'models/rnn-model-%d.pt' % current_epoch)

        if data_loader.epoch == config.epochs:
            break
Example #4
def train():
    global best_loss, best_epoch
    if args.start_epoch:
        # Resume from a previously saved state_dict checkpoint.
        model.load_state_dict(
            torch.load(
                os.path.join(args.model_path,
                             'w-%d.pkl' % (args.start_epoch))))

    # Training
    for epoch in range(args.start_epoch, args.num_epochs):
        model.train()

        train_loss = 0.
        train_rmse = 0.
        shuffled_users = sample(range(num_users), k=num_users)
        shuffled_items = sample(range(num_items), k=num_items)
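        # Loop over blocks of user and item indices; with batch_size set to the full
        # user/item counts, each epoch performs one update over the whole rating matrix.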
        for s, u in enumerate(
                BatchSampler(SequentialSampler(shuffled_users),
                             batch_size=num_users,
                             drop_last=False)):
            u = torch.from_numpy(np.array(u)).to(device)
            for t, v in enumerate(
                    BatchSampler(SequentialSampler(shuffled_items),
                                 batch_size=num_items,
                                 drop_last=False)):
                v = torch.from_numpy(np.array(v)).to(device)
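                # Skip user/item blocks with no observed ratings in rating_train.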
                if len(
                        torch.nonzero(
                            torch.index_select(
                                torch.index_select(rating_train, 1, u), 2,
                                v))) == 0:
                    continue

                output, m_hat, _ = model(u, v, rating_train)

                optimizer.zero_grad()
                loss_ce, loss_rmse = compute_loss(rating_train, u, v, output,
                                                  m_hat)
                loss_ce.backward()
                optimizer.step()

                train_loss += loss_ce.item()
                train_rmse += loss_rmse.item()

        log = 'epoch: ' + str(epoch + 1) + ' loss_ce: ' + str(train_loss / (s + 1) / (t + 1)) \
              + ' loss_rmse: ' + str(train_rmse / (s + 1) / (t + 1))
        print(log)

        if (epoch + 1) % args.val_step == 0:
            # Validation
            model.eval()
            with torch.no_grad():
                u = torch.from_numpy(np.array(range(num_users))).to(device)
                v = torch.from_numpy(np.array(range(num_items))).to(device)
                output, m_hat, _ = model(u, v, rating_val)
                loss_ce, loss_rmse = compute_loss(rating_val, u, v, output,
                                                  m_hat)

            print('[val loss] : ' + str(loss_ce.item()) + ' [val rmse] : ' +
                  str(loss_rmse.item()))
            if best_loss > loss_rmse.item():
                best_loss = loss_rmse.item()
                best_epoch = epoch + 1
                torch.save(
                    model.state_dict(),
                    os.path.join(args.model_path, 'w-%d.pkl' % (best_epoch)))

    # A second, independently configured training loop over (x, y) batches,
    # instrumented with wandb logging.
    iteration = 0
    for epoch in range(args.epochs):
        for x, y in train_dataloader:
            start_time = time.time()
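            # Learning-rate warm-up for the first args.warmup iterations
            # (warmup() is not shown in this snippet).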
            if iteration < args.warmup:
                warmup(iteration, optimizer, args.learning_rate, args.warmup)
            x, y = x.to(device), y.to(device)
            y_raw_prediction, _ = model(x)
            loss = loss_function(y_raw_prediction, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            wandb.log({'training loss': loss}, step=iteration * bs)
            wandb.log({'learning rate': get_lr(optimizer)},
                      step=iteration * bs)

            if iteration % 10 == 0:
                test_loss = compute_loss(model, test_dataloader, loss_function,
                                         device)
                wandb.log({'test loss': test_loss}, step=iteration * bs)
            wandb.log({'iteration': iteration}, step=iteration * bs)
            wandb.log({'iteration time': time.time() - start_time},
                      step=iteration * bs)
            iteration += 1

        lr_scheduler.step()
        training_accuracy = compute_accuracy(model, train_dataloader, device)
        test_accuracy = compute_accuracy(model, test_dataloader, device)
        wandb.log({'training accuracy': training_accuracy},
                  step=iteration * bs)
        wandb.log({'test_accuracy': test_accuracy}, step=iteration * bs)