Ejemplo n.º 1
0
            model, test_dataloader, tense_list)
        writer.add_scalar('Loss/reconstruction loss', rc_loss / trainset_size,
                          epoch + 1)
        writer.add_scalar('Loss/KL loss', kl_loss / trainset_size, epoch + 1)
        writer.add_scalar('BLEU-4 score', average_bleu_score, epoch + 1)
        writer.add_scalar('Setting parameter/KL weight', kl_weight, epoch + 1)
        writer.add_scalar('Setting parameter/teacher forcing ratio',
                          teacher_forcing_ratio, epoch + 1)
        writer.add_scalars(
            'Comparison', {
                'rc_loss': rc_loss / trainset_size,
                'kl_loss': kl_loss / trainset_size,
                'BLEU-4 socre': average_bleu_score,
                'kl_weight': kl_weight,
                'teacher_ratio': teacher_forcing_ratio,
                'gaussian_score': gaussian_score
            }, epoch + 1)
        print('Epoch ', epoch + 1)
        print('Average Reconstruction loss: %f\n Average KL loss: %f' %
              (rc_loss / trainset_size, kl_loss / trainset_size))
        print('Average BLEU-4 score: %f\n' % average_bleu_score)

        if average_bleu_score > best_bleu_score:
            record_score(average_bleu_score, gaussian_score, predict_list,
                         generate_words, test_dataloader, transformer)
            best_bleu_score = average_bleu_score

        torch.save(model.state_dict(),
                   'model/checkpoint' + str(epoch) + '.pkl')
    end = time.time()
    print('Total training time: ' + str((end - start) // 60) + ' minutess')
Ejemplo n.º 2
0
Archivo: main.py Proyecto: mori97/MVAE
def main(argv=None):
    """Train the MVAE (CVAE) model on the VCC2018 dataset from the command line.

    Parses command-line options and loads the training dataset with
    ``torch.load``. Builds a shuffled ``DataLoader`` and the evaluation set,
    computes an ILRMA baseline on it, and then trains for ``--epochs`` epochs.
    Validation runs every ``--eval-interval`` epochs and also after the final
    epoch. After each validation, a checkpoint ``model-<epoch>.pth`` is written
    into ``--output``. Metrics are logged through a TensorBoard ``SummaryWriter``.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            which matches the original no-argument call.

    Raises:
        SystemExit: Raised by argparse when the arguments are invalid. This
            includes a missing required path, a non-positive ``--batch-size``,
            or a non-positive ``--eval-interval``.
    """
    parser = argparse.ArgumentParser(
        description='Train MVAE with VCC2018 dataset',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--train-dataset',
                        help='Path of training dataset.',
                        type=str, required=True)
    parser.add_argument('--val-dataset',
                        help='Path of validation dataset.',
                        type=str, required=True)
    parser.add_argument('--batch-size', '-b',
                        help='Batch size.',
                        type=int, default=32)
    parser.add_argument('--epochs', '-e',
                        help='Number of epochs.',
                        type=int, default=800)
    parser.add_argument('--eval-interval',
                        help='Evaluate and save model every N epochs.',
                        type=int, default=200, metavar='N')
    parser.add_argument('--gpu', '-g',
                        help='GPU id. (Negative number indicates CPU)',
                        type=int, default=-1)
    parser.add_argument('--learning-rate', '-l',
                        help='Learning Rate.',
                        type=float, default=1e-3)
    parser.add_argument('--output',
                        help='Save model to PATH',
                        type=str, default='./models')
    args = parser.parse_args(argv)

    # Reject bad values before any side effects (directory creation, GPU
    # setup, dataset loading) so the user gets an immediate, clear error.
    if args.batch_size <= 0:
        parser.error('--batch-size must be a positive integer')
    if args.eval_interval <= 0:
        # eval_interval is used as a modulus in the training loop; 0 would
        # raise ZeroDivisionError only after all the expensive setup.
        parser.error('--eval-interval must be a positive integer')

    # makedirs also creates missing parent directories and is a no-op when
    # the directory already exists.
    os.makedirs(args.output, exist_ok=True)

    use_cuda = torch.cuda.is_available() and args.gpu >= 0
    if use_cuda:
        device = torch.device(f'cuda:{args.gpu}')
        # Make CuPy operate on the same GPU as PyTorch.
        cp.cuda.Device(args.gpu).use()
    else:
        if args.gpu >= 0:
            # A GPU was requested but CUDA is unavailable: say so instead of
            # silently training on the CPU.
            import warnings
            warnings.warn(f'CUDA is not available; ignoring --gpu {args.gpu} '
                          'and running on CPU.')
        device = torch.device('cpu')

    train_dataset = torch.load(args.train_dataset)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, args.batch_size, shuffle=True)
    val_dataset = make_eval_set(args.val_dataset)

    baseline = baseline_ilrma(val_dataset, device)

    # The speaker count is taken from dimension 0 of the second element of
    # the first training sample (assumed to be a speaker label vector --
    # TODO confirm against the dataset builder).
    model = CVAE(n_speakers=train_dataset[0][1].size(0)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)

    # TensorBoard
    writer = SummaryWriter()
    try:
        for epoch in range(1, args.epochs + 1):
            train(model, train_dataloader, optimizer, device, epoch, writer)
            # Also validate/save after the last epoch. Otherwise, epochs
            # trained after the final multiple of eval_interval would be lost.
            if epoch % args.eval_interval == 0 or epoch == args.epochs:
                validate(model, val_dataset, baseline, device, epoch, writer)
                # Save model: move the weights to the CPU so the checkpoint
                # holds CPU tensors, then move the model back for training.
                model.cpu()
                path = os.path.join(args.output, f'model-{epoch}.pth')
                torch.save(model.state_dict(), path)
                model.to(device)
    finally:
        # Close the TensorBoard writer even if training raises.
        writer.close()