def test():
    """
    Run the testing/validation loop for the Full Network over a single pass of
    the test set, saving the input/output videos and estimated viewpoints for
    every batch and printing the running losses.

    Relies on module-level globals: model, testloader, device, criterion,
    get_first_frame, convert_to_vid, export_vps, output_video_dir.
    :return: None
    """
    running_recon_loss = 0.0
    running_vp_loss = 0.0
    model.eval()
    for batch_idx, (vp_diff, vid1, vid2) in enumerate(testloader):
        vp_diff = vp_diff.type(torch.FloatTensor).to(device)
        vid1, vid2 = vid1.to(device), vid2.to(device)
        # Only the first frame of view 2 is fed to the model; the original
        # also extracted (and never used) the first frame of view 1.
        img2 = get_first_frame(vid2).to(device)
        with torch.no_grad():
            gen_v2, vp_est = model(vp_diff=vp_diff, vid1=vid1, img2=img2)
        # save videos
        convert_to_vid(tensor=vid1, output_dir=output_video_dir,
                       batch_num=batch_idx + 1, view=1, item_type='input')
        convert_to_vid(tensor=vid2, output_dir=output_video_dir,
                       batch_num=batch_idx + 1, view=2, item_type='input')
        convert_to_vid(tensor=gen_v2, output_dir=output_video_dir,
                       batch_num=batch_idx + 1, view=2, item_type='output')
        export_vps(vp_gt=vp_diff, vp_est=vp_est, output_dir=output_video_dir,
                   batch_num=batch_idx + 1)
        # loss (evaluation only -- no backward pass, so the summed total that
        # the original computed and never used has been dropped)
        recon_loss = criterion(gen_v2, vid2)
        vp_loss = criterion(vp_est, vp_diff)
        running_recon_loss += recon_loss.item()
        running_vp_loss += vp_loss.item()
        if (batch_idx + 1) % 10 == 0:
            print('\tBatch {}/{} ReconLoss:{} VPLoss:{}'.format(
                batch_idx + 1, len(testloader),
                "{0:.5f}".format(recon_loss),
                "{0:.5f}".format(vp_loss)))
    print('Testing Complete ReconLoss:{} VPLoss:{}'.format(
        "{0:.5f}".format((running_recon_loss / len(testloader))),
        "{0:.5f}".format((running_vp_loss / len(testloader)))))
def training_loop(epoch):
    """
    Run the training loop for the Full Network for a single epoch.

    Relies on module-level globals: model, trainloader, device, criterion,
    perceptual_loss, optimizer, get_first_frame, BATCH_SIZE, FRAMES, CHANNELS,
    HEIGHT, WIDTH, NUM_EPOCHS.
    :param epoch: (int) The current epoch in which the generator is training.
    :return: None
    """
    running_recon_loss = 0.0
    running_vp_loss = 0.0
    running_perc_loss = 0.0
    model.train()
    for batch_idx, (vp_diff, vid1, vid2) in enumerate(trainloader):
        vp_diff = vp_diff.type(torch.FloatTensor).to(device)
        vid1, vid2 = vid1.to(device), vid2.to(device)
        # Only the first frame of view 2 is used by the model (the original
        # extracted img1 as well but never used it).
        img2 = get_first_frame(vid2).to(device)
        optimizer.zero_grad()
        gen_v2, vp_est = model(vp_diff=vp_diff, vid1=vid1, img2=img2)
        # loss
        recon_loss = criterion(gen_v2, vid2)
        vp_loss = criterion(vp_est, vp_diff)
        # Flatten the time axis into the batch axis so the perceptual network
        # sees one image per frame.
        feat_gen = perceptual_loss(
            torch.reshape(gen_v2, (BATCH_SIZE * FRAMES, CHANNELS, HEIGHT, WIDTH)))
        feat_gt = perceptual_loss(
            torch.reshape(vid2, (BATCH_SIZE * FRAMES, CHANNELS, HEIGHT, WIDTH)))
        # BUG FIX: f.cosine_similarity returns a per-sample 1-D tensor, so the
        # summed loss was non-scalar and loss.backward() would fail; reduce to
        # a scalar with mean (matching how train_model reduces it).  The old
        # debug call perc_loss.shape() also raised TypeError, since .shape is
        # an attribute, not a method -- it has been removed.
        perc_loss = torch.mean(f.cosine_similarity(feat_gen, feat_gt))
        loss = (0.1 * recon_loss) + vp_loss + perc_loss
        loss.backward()
        optimizer.step()
        running_recon_loss += recon_loss.item()
        running_vp_loss += vp_loss.item()
        running_perc_loss += perc_loss.item()
        if (batch_idx + 1) % 10 == 0:
            print('\tBatch {}/{} ReconLoss:{} VPLoss:{} PLoss:{}'.format(
                batch_idx + 1, len(trainloader),
                "{0:.5f}".format(recon_loss),
                "{0:.5f}".format(vp_loss),
                "{0:.5f}".format(perc_loss)))
    print('Training Epoch {}/{} ReconLoss:{} VPLoss:{} PLoss:{}'.format(
        epoch + 1, NUM_EPOCHS,
        "{0:.5f}".format((running_recon_loss / len(trainloader))),
        "{0:.5f}".format((running_vp_loss / len(trainloader))),
        "{0:.5f}".format((running_perc_loss / len(trainloader)))))
def testing_loop(epoch):
    """
    Run the validation loop for the Full Network for a single epoch and report
    the average reconstruction and viewpoint losses.
    :param epoch: (int) The current epoch in which the generator is
                  testing/validating.
    :return: (float) The average reconstruction loss over the test set.
    """
    total_recon = 0.0
    total_vp = 0.0
    num_batches = len(testloader)
    model.eval()
    for batch_idx, (vp_diff, vid1, vid2) in enumerate(testloader):
        vp_diff = vp_diff.type(torch.FloatTensor).to(device)
        vid1 = vid1.to(device)
        vid2 = vid2.to(device)
        img1 = get_first_frame(vid1).to(device)
        img2 = get_first_frame(vid2).to(device)
        # forward pass only -- no gradients needed during validation
        with torch.no_grad():
            gen_v2, vp_est = model(vp_diff=vp_diff, vid1=vid1, img2=img2)
        # per-batch losses
        recon_loss = criterion(gen_v2, vid2)
        vp_loss = criterion(vp_est, vp_diff)
        total_recon += recon_loss.item()
        total_vp += vp_loss.item()
        batch_num = batch_idx + 1
        if batch_num % 10 == 0:
            print('\tBatch {}/{} ReconLoss:{} VPLoss:{}'.format(
                batch_num, num_batches,
                "{0:.5f}".format(recon_loss),
                "{0:.5f}".format(vp_loss)))
    avg_recon = total_recon / num_batches
    avg_vp = total_vp / num_batches
    print('Validation Epoch {}/{} ReconLoss:{} VPLoss:{}'.format(
        epoch + 1, NUM_EPOCHS,
        "{0:.5f}".format(avg_recon),
        "{0:.5f}".format(avg_vp)))
    return avg_recon
def train_model(starting_epoch):
    """
    Adversarially train the generator/discriminator pair from starting_epoch
    through NUM_EPOCHS, checkpointing each network whenever its running
    average loss improves.

    Relies on module-level globals: generator, discriminator, trainloader,
    device, criterion, adversarial_loss, perceptual_loss, optimizer_G,
    optimizer_D, get_first_frame, FloatTensor, Variable, BATCH_SIZE, FRAMES,
    CHANNELS, HEIGHT, WIDTH, NUM_EPOCHS, gen_weight_file, disc_weight_file.
    :param starting_epoch: (int) The epoch to resume training from.
    :return: None
    """
    # NOTE(review): when resuming with starting_epoch > 0 no checkpoint is
    # saved until the average loss dips below 1.0 -- confirm this is intended.
    min_gloss = 1.0
    min_dloss = 1.0
    for epoch in range(starting_epoch, NUM_EPOCHS):
        running_g_loss = 0.0
        running_recon_loss = 0.0
        running_vp_loss = 0.0
        running_perc_loss = 0.0
        running_d_loss = 0.0
        for batch_idx, (vp_diff, vid1, vid2) in enumerate(trainloader):
            vp_diff = vp_diff.type(torch.FloatTensor).to(device)
            vid1, vid2 = vid1.to(device), vid2.to(device)
            # Only the first frame of view 2 is used by the generator (the
            # original extracted img1 as well but never used it).
            img2 = get_first_frame(vid2).to(device)
            batch_size = vp_diff.shape[0]

            # Adversarial ground truths
            valid = Variable(FloatTensor(batch_size).fill_(1.0), requires_grad=False)
            fake = Variable(FloatTensor(batch_size).fill_(0.0), requires_grad=False)

            # Configure input (real_vids_v1 in the original was never used)
            real_vids_v2 = Variable(vid2.type(FloatTensor))

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            gen_v2, vp_est = generator(vp_diff=vp_diff, vid1=vid1, img2=img2)

            # Loss measures generator's ability to fool the discriminator
            validity = discriminator(gen_v2)
            g_loss = adversarial_loss(validity, valid)
            recon_loss = criterion(gen_v2, vid2)
            vp_loss = criterion(vp_est, vp_diff)
            # Flatten the time axis into the batch axis for the perceptual net.
            feat_gen = perceptual_loss(
                torch.reshape(gen_v2, (BATCH_SIZE * FRAMES, CHANNELS, HEIGHT, WIDTH)))
            feat_gt = perceptual_loss(
                torch.reshape(vid2, (BATCH_SIZE * FRAMES, CHANNELS, HEIGHT, WIDTH)))
            # BUG FIX: the original range(4) loop appended four *identical*
            # perceptual terms (the computation is loop-invariant); compute it
            # once and keep the x4 weight so the total loss is unchanged.
            # NOTE(review): four per-layer feature maps were probably intended
            # here -- confirm against the perceptual_loss implementation.
            perc_loss = torch.mean(f.cosine_similarity(feat_gen, feat_gt))
            total_g_loss = ((0.3 * g_loss) + (0.3 * recon_loss + vp_loss)
                            + 4 * (0.3 * perc_loss))
            total_g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            # Loss for real images
            real_pred = discriminator(real_vids_v2)
            d_real_loss = adversarial_loss(real_pred, valid)
            # Loss for fake images (detach so no gradients reach the generator)
            fake_pred = discriminator(gen_v2.detach())
            d_fake_loss = adversarial_loss(fake_pred, fake)
            # Total discriminator loss
            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # --------------
            #  Log Progress
            # --------------
            running_g_loss += g_loss.item()
            running_recon_loss += recon_loss.item()
            running_vp_loss += vp_loss.item()
            running_perc_loss += perc_loss.item()
            running_d_loss += d_loss.item()
            if (batch_idx + 1) % 10 == 0:
                print('\tBatch {}/{} GLoss:{} ReconLoss:{} VPLoss:{} PLoss:{} DLoss:{}'.format(
                    batch_idx + 1, len(trainloader),
                    "{0:.5f}".format(g_loss),
                    "{0:.5f}".format(recon_loss),
                    "{0:.5f}".format(vp_loss),
                    "{0:.5f}".format(perc_loss),
                    "{0:.5f}".format(d_loss)))
        print('Training Epoch {}/{} GLoss:{} ReconLoss:{} VPLoss:{} PLoss:{} DLoss:{}'.format(
            epoch + 1, NUM_EPOCHS,
            "{0:.5f}".format((running_g_loss / len(trainloader))),
            "{0:.5f}".format((running_recon_loss / len(trainloader))),
            "{0:.5f}".format((running_vp_loss / len(trainloader))),
            "{0:.5f}".format((running_perc_loss / len(trainloader))),
            "{0:.5f}".format((running_d_loss / len(trainloader)))))
        # Checkpoint whichever network improved its running average this epoch.
        avg_gloss = ((running_g_loss + running_recon_loss + running_vp_loss
                      + running_perc_loss) / len(trainloader))
        avg_dloss = running_d_loss / len(trainloader)
        if avg_gloss < min_gloss or epoch == 0:
            min_gloss = avg_gloss
            torch.save(generator.state_dict(),
                       gen_weight_file[:-3] + '_{}'.format(epoch) + '.pt')
        if avg_dloss < min_dloss or epoch == 0:
            min_dloss = avg_dloss
            torch.save(discriminator.state_dict(),
                       disc_weight_file[:-3] + '_{}'.format(epoch) + '.pt')
        print('MinGloss:{} MinDLoss:{}'.format(min_gloss, min_dloss))