def main(args):
    """Train a VQ-VAE (or DiffVQVAE) model end to end.

    Loads the dataset, constructs the model selected by ``args.model``,
    runs the epoch loop (train + validation), checkpoints the best model
    by validation loss, then plots training curves and generates samples.

    Args:
        args: Namespace of hyperparameters. Reads ``data_folder``,
            ``batch_size``, ``enc_height``, ``model``, ``learning_rate``,
            ``num_epochs``, ``num_codebooks``, ``num_embeddings`` and
            ``save_path``; mutates it in place to stash derived values
            (``input_size``, ``downsample``, ``data_variance``, ``KL``,
            ``num_pixels``, ``global_it``) for downstream helpers.

    Raises:
        ValueError: if ``args.model`` is not a recognized model name.
    """
    ###############################
    # TRAIN PREP
    ###############################
    print("Loading data")
    train_loader, valid_loader, data_var, input_size = \
        data.get_data(args.data_folder, args.batch_size)
    args.input_size = input_size
    # Spatial downsampling factor between input and encoder output.
    args.downsample = args.input_size[-1] // args.enc_height
    args.data_variance = data_var
    print(f"Training set size {len(train_loader.dataset)}")
    print(f"Validation set size {len(valid_loader.dataset)}")

    print("Loading model")
    if args.model == 'diffvqvae':
        model = DiffVQVAE(args).to(device)
    elif args.model == 'vqvae':
        model = VQVAE(args).to(device)
    else:
        # Fail fast with a clear message instead of a later NameError
        # on an unbound `model`.
        raise ValueError(f"Unknown model: {args.model}")

    print(
        f'The model has {utils.count_parameters(model):,} trainable parameters'
    )

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate,
                           amsgrad=False)

    print(f"Start training for {args.num_epochs} epochs")
    num_batches = math.ceil(
        len(train_loader.dataset) / train_loader.batch_size)
    pbar = Progress(num_batches, bar_length=10, custom_increment=True)

    # Needed for bpd: nats contributed by the discrete code assignment.
    args.KL = args.enc_height * args.enc_height * args.num_codebooks * \
        np.log(args.num_embeddings)
    args.num_pixels = np.prod(args.input_size)

    ###############################
    # MAIN TRAIN LOOP
    ###############################
    best_valid_loss = float('inf')
    train_bpd = []
    train_recon_error = []
    train_perplexity = []
    args.global_it = 0

    for epoch in range(args.num_epochs):
        pbar.epoch_start()
        train_epoch(args, vq_vae_loss, pbar, train_loader, model, optimizer,
                    train_bpd, train_recon_error, train_perplexity)
        valid_loss = evaluate(args, vq_vae_loss, pbar, valid_loader, model)
        # Checkpoint only when validation loss improves.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), args.save_path)
        pbar.print_end_epoch()

    print("Plotting training results")
    utils.plot_results(train_recon_error, train_perplexity,
                       "results/train.png")
    print("Evaluate and plot validation set")
    generate_samples(model, valid_loader)