Exemple #1
0
def main(args):
    """Train the selected model, checkpoint the best weights, and plot results.

    Loads the train/validation loaders, builds the model chosen by
    ``args.model``, trains it with Adam for ``args.num_epochs`` epochs and,
    after each epoch, saves ``model.state_dict()`` to ``args.save_path``
    whenever the validation loss improves on the best seen so far. Finally
    plots the training reconstruction error / perplexity curves and
    generates samples from the validation loader.

    Args:
        args: Namespace-like configuration. Reads ``data_folder``,
            ``batch_size``, ``enc_height``, ``model`` (``'diffvqvae'`` or
            ``'vqvae'``), ``learning_rate``, ``num_epochs``,
            ``num_codebooks``, ``num_embeddings`` and ``save_path``. Also
            writes derived fields onto it that downstream code reads:
            ``input_size``, ``downsample``, ``data_variance``, ``KL``,
            ``num_pixels`` and ``global_it``.

    Raises:
        ValueError: If ``args.model`` is not a supported model name.
    """
    ###############################
    # TRAIN PREP
    ###############################
    print("Loading data")
    train_loader, valid_loader, data_var, input_size = \
                                data.get_data(args.data_folder, args.batch_size)

    # Derived settings are stored on args so the model/loss code can read them.
    args.input_size = input_size
    args.downsample = args.input_size[-1] // args.enc_height
    args.data_variance = data_var
    print(f"Training set size {len(train_loader.dataset)}")
    print(f"Validation set size {len(valid_loader.dataset)}")

    print("Loading model")
    if args.model == 'diffvqvae':
        model = DiffVQVAE(args).to(device)
    elif args.model == 'vqvae':
        model = VQVAE(args).to(device)
    else:
        # Without this branch `model` would be unbound and the failure would
        # surface later as an opaque UnboundLocalError.
        raise ValueError(
            f"Unknown model {args.model!r}; expected 'diffvqvae' or 'vqvae'"
        )
    print(
        f'The model has {utils.count_parameters(model):,} trainable parameters'
    )

    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           amsgrad=False)

    print(f"Start training for {args.num_epochs} epochs")
    num_batches = math.ceil(
        len(train_loader.dataset) / train_loader.batch_size)
    pbar = Progress(num_batches, bar_length=10, custom_increment=True)

    # Needed for bpd (bits per dimension).
    args.KL = args.enc_height * args.enc_height * args.num_codebooks * \
                                                    np.log(args.num_embeddings)
    args.num_pixels = np.prod(args.input_size)

    ###############################
    # MAIN TRAIN LOOP
    ###############################
    best_valid_loss = float('inf')
    # Per-iteration metric histories; train_epoch appends to these in place.
    train_bpd = []
    train_recon_error = []
    train_perplexity = []
    args.global_it = 0
    for epoch in range(args.num_epochs):
        pbar.epoch_start()
        train_epoch(args, vq_vae_loss, pbar, train_loader, model, optimizer,
                    train_bpd, train_recon_error, train_perplexity)
        valid_loss = evaluate(args, vq_vae_loss, pbar, valid_loader, model)
        # Checkpoint only on improvement so save_path holds the best weights.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), args.save_path)
        pbar.print_end_epoch()

    print("Plotting training results")
    utils.plot_results(train_recon_error, train_perplexity,
                       "results/train.png")

    print("Evaluate and plot validation set")
    # NOTE(review): this uses the final-epoch weights, not the best checkpoint
    # saved to args.save_path — reload it first if best-model samples are wanted.
    generate_samples(model, valid_loader)