# generator update: push the discriminator's output on fake pairs toward "real"
_, dis_output_fake = cgan(batch_y, batch_x)
loss_gen = bce_loss(dis_output_fake, y_dis_real)
loss_gen.backward()
optim_gen.step()

count_update_step_cgan += 1

# visualize losses
if count_update_step_cgan == 1:
    viz = Visdom()
    x_value = np.asarray(count_update_step_cgan).reshape(1, )
    x_label = 'Training Step'
    y_value = np.column_stack(
        (np.asarray(loss_dis.item()), np.asarray(loss_gen.item())))
    y_label = 'Loss'
    title = 'Discriminator and Generator Losses'
    legend = ['Loss_Dis', 'Loss_Gen']
    win_dis_gen = creat_vis_plot(viz, x_value, y_value,
                                 x_label, y_label, title, legend)
elif count_update_step_cgan % 50 == 0:
    x_value = np.asarray(count_update_step_cgan).reshape(1, )
    y_value = np.column_stack(
        (np.asarray(loss_dis.item()), np.asarray(loss_gen.item())))
    update_vis(viz, win_dis_gen, x_value, y_value)

# evaluate the model
if count_update_step_cgan % 1000 == 0:
    print('\nUpdate step: {:d}'
          '\nmean loss_dis: {:.4f}'
          '\nmean loss_gen: {:.4f}'.format(count_update_step_cgan,
                                           loss_dis.item(),
                                           loss_gen.item()))

# save the midterm model states
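The fragment above relies on `bce_loss` and the adversarial targets `y_dis_real` being defined before the loop. A minimal sketch of that setup, assuming a standard conditional-GAN formulation; `batch_size` and `y_dis_fake` are hypothetical names introduced here, only `bce_loss` and `y_dis_real` appear in the fragment:

import torch
import torch.nn as nn

bce_loss = nn.BCELoss()

# Discriminator targets: 1 for real (x, y) pairs, 0 for generated ones.
# The generator step above reuses the all-ones targets so that minimizing
# BCE pushes dis_output_fake toward 1 (the non-saturating GAN loss).
batch_size = 64  # hypothetical; must match the DataLoader's batch size
y_dis_real = torch.ones(batch_size, 1)
y_dis_fake = torch.zeros(batch_size, 1)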
def training(params, encoder, decoder, optim_encoder, optim_decoder,
             lr_scheduler_encoder, lr_scheduler_decoder, device, attrs_class,
             decorr_regul, train_loader, valid_loader, save_dir):
    encoder.train()
    decoder.train()

    count_update_step = 0
    for i in range(params.n_epochs):
        for batch_idx, sample_batched in enumerate(train_loader):
            data, label = sample_batched['image'], sample_batched['attributes']
            batch_x = data.to(device)
            batch_y = label.to(device)
            count_update_step += 1

            ############################
            # (1) update encoder
            ############################
            for p in encoder.parameters():
                p.requires_grad_(True)
            for p in decoder.parameters():
                p.requires_grad_(False)

            optim_encoder.zero_grad()
            y_attrs, z_latent = encoder(batch_x)
            x_recons = decoder(batch_y, z_latent)

            # classification loss
            loss_class = attrs_class(y_attrs, batch_y)
            # decorrelation loss
            y_attrs_sigmoid = torch.sigmoid(y_attrs)
            loss_decorr = decorr_regul(y_attrs_sigmoid, z_latent)
            # image reconstruction loss
            loss_recons_image = torch.mean(
                0.5 * (batch_x.view(len(batch_x), -1)
                       - x_recons.view(len(x_recons), -1))**2)

            loss_encoder = params.lambda_class * loss_class \
                + params.lambda_decorr * loss_decorr \
                + params.lambda_recons * loss_recons_image
            loss_encoder.backward()
            optim_encoder.step()

            ############################
            # (2) update decoder
            ############################
            for p in decoder.parameters():
                p.requires_grad_(True)
            for p in encoder.parameters():
                p.requires_grad_(False)

            optim_decoder.zero_grad()
            y_attrs, z_latent = encoder(batch_x)
            # detach z_latent so no gradient flows back into the encoder
            x_recons = decoder(batch_y, z_latent.detach())

            # image reconstruction loss
            loss_recons_image = torch.mean(
                0.5 * (batch_x.view(len(batch_x), -1)
                       - x_recons.view(len(x_recons), -1))**2)
            loss_decoder = params.lambda_recons * loss_recons_image
            loss_decoder.backward()
            optim_decoder.step()

            # visualize losses
            if count_update_step == 1:
                viz = Visdom()
                x_value = np.asarray(count_update_step).reshape(1, )
                x_label = 'Training Step'

                y_value = np.asarray(loss_recons_image.item()).reshape(1, )
                y_label = 'MSE'
                title = 'Image Reconstruction Loss'
                legend = ['MSE']
                win_img_recons = creat_vis_plot(viz, x_value, y_value,
                                                x_label, y_label, title, legend)

                y_value = np.asarray(loss_class.item()).reshape(1, )
                y_label = 'Loss'
                title = 'Binary Cross Entropy Loss'
                legend = ['Loss_BCE']
                win_attrs_class = creat_vis_plot(viz, x_value, y_value,
                                                 x_label, y_label, title, legend)

                y_value = np.asarray(loss_decorr.item()).reshape(1, )
                y_label = 'Loss'
                title = 'Decorrelation Loss'
                legend = ['Loss_Decorr']
                win_decorr = creat_vis_plot(viz, x_value, y_value,
                                            x_label, y_label, title, legend)
            elif count_update_step % 50 == 0:
                x_value = np.asarray(count_update_step).reshape(1, )
                y_value = np.asarray(loss_recons_image.item()).reshape(1, )
                update_vis(viz, win_img_recons, x_value, y_value)
                y_value = np.asarray(loss_class.item()).reshape(1, )
                update_vis(viz, win_attrs_class, x_value, y_value)
                y_value = np.asarray(loss_decorr.item()).reshape(1, )
                update_vis(viz, win_decorr, x_value, y_value)

            # evaluate the model
            if count_update_step % 1000 == 0:
                print('\nUpdate step: {:d}'
                      '\nmean loss_img_recons: {:.4f}'
                      '\nmean loss_attrs_class: {:.4f}'
                      '\nmean loss_decorr: {:.4f}'.format(
                          count_update_step, loss_recons_image.item(),
                          loss_class.item(), loss_decorr.item()))

                # evaluation on validation set
                _, er_total_valid, \
                    loss_decorr_valid, loss_recons_valid = evaluate_classification(
                        encoder, decoder, valid_loader, params.n_valid,
                        decorr_regul, device)
                print('Attribute Classification Error Rate on Validation Set: '
                      '{:.2f}%'.format(er_total_valid))
                print('Decorrelation Error on Each Mini-Batch: {:.4f}'.format(
                    loss_decorr_valid))
                print('Reconstruction Error on Each Sample: {:.4f}'.format(
                    loss_recons_valid))

                if count_update_step == 1000:
                    viz = Visdom()
                    x_value = np.asarray(count_update_step).reshape(1, )
                    x_label = 'Training Step'

                    y_value = np.asarray(loss_recons_valid).reshape(1, )
                    y_label = 'Reconstruction Error'
                    title = 'Valid Image Reconstruction Error'
                    legend = ['Recons_Error']
                    win_img_recons_valid = creat_vis_plot(
                        viz, x_value, y_value, x_label, y_label, title, legend)

                    y_value = np.asarray(er_total_valid).reshape(1, )
                    y_label = 'Classification Error'
                    title = 'Valid Attr. Classification Error'
                    legend = ['Class_Error']
                    win_attrs_class_valid = creat_vis_plot(
                        viz, x_value, y_value, x_label, y_label, title, legend)

                    y_value = np.asarray(loss_decorr_valid).reshape(1, )
                    y_label = 'Decorrelation Error'
                    title = 'Valid Decorrelation Error'
                    legend = ['Decorr_Error']
                    win_decorr_valid = creat_vis_plot(
                        viz, x_value, y_value, x_label, y_label, title, legend)
                else:
                    x_value = np.asarray(count_update_step).reshape(1, )
                    y_value = np.asarray(loss_recons_valid).reshape(1, )
                    update_vis(viz, win_img_recons_valid, x_value, y_value)
                    y_value = np.asarray(er_total_valid).reshape(1, )
                    update_vis(viz, win_attrs_class_valid, x_value, y_value)
                    y_value = np.asarray(loss_decorr_valid).reshape(1, )
                    update_vis(viz, win_decorr_valid, x_value, y_value)

                encoder.train()
                decoder.train()

            # save the midterm model states
            if count_update_step % 5000 == 0:
                torch.save(
                    encoder.state_dict(),
                    save_dir + '/encoder_step' + str(count_update_step) + '.pt')
                torch.save(
                    decoder.state_dict(),
                    save_dir + '/decoder_step' + str(count_update_step) + '.pt')
            if (10000 < count_update_step < 20000) and (count_update_step % 1000 == 0):
                torch.save(
                    encoder.state_dict(),
                    save_dir + '/encoder_step' + str(count_update_step) + '.pt')
                torch.save(
                    decoder.state_dict(),
                    save_dir + '/decoder_step' + str(count_update_step) + '.pt')

        lr_scheduler_encoder.step()
        lr_scheduler_decoder.step()

    # save the whole model
    torch.save(encoder.state_dict(), save_dir + '/encoder_final.pt')
    torch.save(decoder.state_dict(), save_dir + '/decoder_final.pt')

    return count_update_step
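All of the listings in this section plot through the same two helpers, `creat_vis_plot` and `update_vis`, which are not defined here, and they assume a common set of imports. A plausible implementation consistent with how the helpers are called above, using only the public Visdom line-plot API; this is a sketch of the assumed scaffolding, not the original code:

import time

import numpy as np
import torch
import torch.nn.functional as F
from visdom import Visdom


def creat_vis_plot(viz, x_value, y_value, x_label, y_label, title, legend):
    # Open a new line plot and return its window handle for later updates.
    return viz.line(X=x_value, Y=y_value,
                    opts=dict(xlabel=x_label, ylabel=y_label,
                              title=title, legend=legend))


def update_vis(viz, win, x_value, y_value):
    # Append the new point (or one point per curve) to an existing window.
    viz.line(X=x_value, Y=y_value, win=win, update='append')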
def training(params, encoder_y, encoder_z, decoder, discriminator,
             optim_encoder_y, optim_encoder_z, optim_decoder,
             optim_discriminator, lr_scheduler_encoder_z, lr_scheduler_decoder,
             lr_scheduler_dis, device, attrs_class, decorr_regul, train_loader,
             valid_loader, margin, equilibrium, save_dir):
    ################################################
    # training Encoder_Y
    ################################################
    encoder_y.train()

    count_update_step_class = 0
    for epoch in range(params.n_epochs_EncY):
        for batch_idx, sample_batched in enumerate(train_loader):
            data, label = sample_batched['image'], sample_batched['attributes']
            batch_x = data.to(device)
            batch_y = label.to(device)

            optim_encoder_y.zero_grad()
            y_attrs = encoder_y(batch_x)
            y_attrs_sigmoid = torch.sigmoid(y_attrs)
            loss_class = attrs_class(y_attrs_sigmoid, batch_y)
            loss_class.backward()
            optim_encoder_y.step()
            count_update_step_class += 1

            if count_update_step_class % 500 == 0:
                # evaluation on validation set
                er_each_attr_valid, er_total_valid = evaluate_class(
                    encoder_y, valid_loader, params.n_valid, device)
                print('\nUpdate step: {:d}'
                      '\nAttribute Classification Error Rate on Validation Set: '
                      '{:.4f}%'.format(count_update_step_class, er_total_valid))

                if count_update_step_class == 500:
                    viz = Visdom()
                    x_value = np.asarray(count_update_step_class).reshape(1, )
                    x_label = 'Training Step'
                    y_value = np.asarray(er_total_valid).reshape(1, )
                    y_label = 'Classification Error'
                    title = 'Valid Attr. Classification Error'
                    legend = ['Class_Error']
                    win_attrs_class_valid = creat_vis_plot(
                        viz, x_value, y_value, x_label, y_label, title, legend)
                else:
                    x_value = np.asarray(count_update_step_class).reshape(1, )
                    y_value = np.asarray(er_total_valid).reshape(1, )
                    update_vis(viz, win_attrs_class_valid, x_value, y_value)

                encoder_y.train()

            # save the midterm model state
            if count_update_step_class % 1000 == 0:
                torch.save(
                    encoder_y.state_dict(),
                    save_dir + '/encoder_y_step' + str(count_update_step_class) + '.pt')

    _, er_total_valid = evaluate_class(encoder_y, valid_loader, params.n_valid,
                                       device)
    print('\nFinal Attribute Classification Error Rate on Validation Set: '
          '{:.4f}%'.format(er_total_valid))
    torch.save(encoder_y.state_dict(), save_dir + '/encoder_y_final.pt')

    ################################################
    # training Encoder_Z, Decoder, and Discriminator
    ################################################
    encoder_y.eval()
    encoder_z.train()
    decoder.train()
    discriminator.train()

    count_update_step = 0
    for epoch in range(params.n_epochs):
        for batch_idx, sample_batched in enumerate(train_loader):
            data, label = sample_batched['image'], sample_batched['attributes']
            batch_x = data.to(device)
            count_update_step += 1

            ############################
            # (1) update discriminator
            ############################
            for p in discriminator.parameters():
                p.requires_grad_(True)
            for p in decoder.parameters():
                p.requires_grad_(False)
            for p in encoder_z.parameters():
                p.requires_grad_(False)

            optim_discriminator.zero_grad()
            y_attrs = encoder_y(batch_x)
            y_attrs_sigmoid = torch.sigmoid(y_attrs)
            z_latent = encoder_z(batch_x)
            x_recons = decoder(z_latent.detach(), y_attrs_sigmoid)

            # using x_recons as fake data
            dis_output = discriminator(x_recons.detach(), batch_x, mode='GAN')
            dis_output_sampled = dis_output[:batch_x.size(0)]
            dis_output_original = dis_output[batch_x.size(0):]

            # GAN loss
            dis_original = -torch.log(dis_output_original + 1e-3)
            dis_sampled = -torch.log(1 - dis_output_sampled + 1e-3)
            loss_discriminator = torch.mean(dis_original) + torch.mean(dis_sampled)

            train_dis = True
            train_dec = True
            # equilibrium heuristic: freeze the decoder while the discriminator
            # is losing (loss above the band), freeze the discriminator while
            # it is winning (loss below the band)
            if ((torch.mean(dis_original)).item() > equilibrium + margin) \
                    or ((torch.mean(dis_sampled)).item() > equilibrium + margin):
                train_dec = False
            if ((torch.mean(dis_original)).item() < equilibrium - margin) \
                    or ((torch.mean(dis_sampled)).item() < equilibrium - margin):
                train_dis = False
            if train_dec is False and train_dis is False:
                train_dis = True
                train_dec = True

            if train_dis:
                loss_discriminator.backward()
                optim_discriminator.step()

            ############################
            # (2) update decoder
            ############################
            for p in decoder.parameters():
                p.requires_grad_(True)
            for p in discriminator.parameters():
                p.requires_grad_(False)

            optim_decoder.zero_grad()
            y_attrs = encoder_y(batch_x)
            y_attrs_sigmoid = torch.sigmoid(y_attrs)
            z_latent = encoder_z(batch_x)
            x_recons = decoder(z_latent.detach(), y_attrs_sigmoid)

            mid_repre = discriminator(x_recons, batch_x, mode='REC')
            mid_repre_recons = mid_repre[:batch_x.size(0)]
            mid_repre_original = mid_repre[batch_x.size(0):]

            # using x_recons as fake data
            dis_output = discriminator(x_recons, batch_x, mode='GAN')
            dis_output_sampled = dis_output[:batch_x.size(0)]
            dis_output_original = dis_output[batch_x.size(0):]

            # image reconstruction loss
            loss_recons_image = torch.mean(
                0.5 * (batch_x.view(len(batch_x), -1)
                       - x_recons.view(len(x_recons), -1))**2)
            # feature reconstruction loss
            loss_recons_feature = torch.mean(
                0.5 * (mid_repre_original - mid_repre_recons)**2)
            # GAN loss
            dis_original = -torch.log(dis_output_original + 1e-3)
            dis_sampled = -torch.log(1 - dis_output_sampled + 1e-3)
            loss_discriminator = torch.mean(dis_original) + torch.mean(dis_sampled)

            loss_decoder = 1 * loss_recons_image \
                + params.lambda_recons * loss_recons_feature \
                - params.lambda_dis * loss_discriminator

            train_dis = True
            train_dec = True
            if ((torch.mean(dis_original)).item() > equilibrium + margin) \
                    or ((torch.mean(dis_sampled)).item() > equilibrium + margin):
                train_dec = False
            if ((torch.mean(dis_original)).item() < equilibrium - margin) \
                    or ((torch.mean(dis_sampled)).item() < equilibrium - margin):
                train_dis = False
            if train_dec is False and train_dis is False:
                train_dis = True
                train_dec = True

            if train_dec:
                loss_decoder.backward()
                optim_decoder.step()

            ############################
            # (3) update encoder_z
            ############################
            for p in encoder_z.parameters():
                p.requires_grad_(True)
            for p in decoder.parameters():
                p.requires_grad_(False)

            optim_encoder_z.zero_grad()
            y_attrs = encoder_y(batch_x)
            y_attrs_sigmoid = torch.sigmoid(y_attrs)
            z_latent = encoder_z(batch_x)
            x_recons = decoder(z_latent, y_attrs_sigmoid)

            mid_repre = discriminator(x_recons, batch_x, mode='REC')
            mid_repre_recons = mid_repre[:batch_x.size(0)]
            mid_repre_original = mid_repre[batch_x.size(0):]

            # decorrelation loss (timed for profiling)
            start_time = time.time()
            loss_decorr = decorr_regul(y_attrs_sigmoid, z_latent)
            end_time = time.time()
            print('Time cost of computing decorr_regul: batch_id={:d}, '
                  'time={:.9f}'.format(batch_idx, end_time - start_time))

            # image reconstruction loss
            loss_recons_image = torch.mean(
                0.5 * (batch_x.view(len(batch_x), -1)
                       - x_recons.view(len(x_recons), -1))**2)
            # feature reconstruction loss
            loss_recons_feature = torch.mean(
                0.5 * (mid_repre_original - mid_repre_recons)**2)

            loss_encoder_z = 1 * loss_recons_image \
                + params.lambda_recons * loss_recons_feature \
                + get_lambda(params.lambda_decorr, params.lambda_schedule,
                             count_update_step) * loss_decorr
            loss_encoder_z.backward()
            optim_encoder_z.step()

            # visualize losses
            if count_update_step == 1:
                viz = Visdom()
                x_value = np.asarray(count_update_step).reshape(1, )
                x_label = 'Training Step'

                y_value = np.asarray(loss_recons_image.item()).reshape(1, )
                y_label = 'MSE'
                title = 'Image Reconstruction Loss'
                legend = ['MSE']
                win_img_recons = creat_vis_plot(viz, x_value, y_value,
                                                x_label, y_label, title, legend)

                y_value = np.asarray(loss_recons_feature.item()).reshape(1, )
                y_label = 'MSE'
                title = 'Feature Reconstruction Loss'
                legend = ['MSE']
                win_feature_recons = creat_vis_plot(viz, x_value, y_value,
                                                    x_label, y_label, title,
                                                    legend)

                y_value = np.column_stack(
                    (np.asarray(loss_discriminator.item()),
                     np.asarray(loss_decoder.item())))
                y_label = 'Loss'
                title = 'Discriminator and Decoder Losses'
                legend = ['Loss_Dis', 'Loss_Dec']
                win_dis_gen = creat_vis_plot(viz, x_value, y_value,
                                             x_label, y_label, title, legend)

                y_value = np.asarray(loss_decorr.item()).reshape(1, )
                y_label = 'Loss'
                title = 'Decorrelation Loss'
                legend = ['Loss_Decorr']
                win_decorr = creat_vis_plot(viz, x_value, y_value,
                                            x_label, y_label, title, legend)
            elif count_update_step % 50 == 0:
                x_value = np.asarray(count_update_step).reshape(1, )
                y_value = np.asarray(loss_recons_image.item()).reshape(1, )
                update_vis(viz, win_img_recons, x_value, y_value)
                y_value = np.asarray(loss_recons_feature.item()).reshape(1, )
                update_vis(viz, win_feature_recons, x_value, y_value)
                y_value = np.column_stack(
                    (np.asarray(loss_discriminator.item()),
                     np.asarray(loss_decoder.item())))
                update_vis(viz, win_dis_gen, x_value, y_value)
                y_value = np.asarray(loss_decorr.item()).reshape(1, )
                update_vis(viz, win_decorr, x_value, y_value)

            # evaluate the model
            if count_update_step % 1000 == 0:
                print('\nUpdate step: {:d}'
                      '\nmean loss_img_recons: {:.4f}'
                      '\nmean loss_feature_recons: {:.4f}'
                      '\nmean loss_GAN_dis: {:.4f}'
                      '\nmean loss_decoder: {:.4f}'
                      '\nmean loss_decorr: {:.4f}'.format(
                          count_update_step, loss_recons_image.item(),
                          loss_recons_feature.item(),
                          loss_discriminator.item(), loss_decoder.item(),
                          loss_decorr.item()))

                # evaluation on validation set
                loss_decorr_valid, loss_recons_valid = evaluate_learning(
                    encoder_y, encoder_z, decoder, valid_loader,
                    params.n_valid, decorr_regul, device)
                print('Decorrelation Error on Each Mini-Batch: {:.4f}'.format(
                    loss_decorr_valid))
                print('Reconstruction Error on Each Sample: {:.4f}'.format(
                    loss_recons_valid))

                if count_update_step == 1000:
                    viz = Visdom()
                    x_value = np.asarray(count_update_step).reshape(1, )
                    x_label = 'Training Step'

                    y_value = np.asarray(loss_recons_valid).reshape(1, )
                    y_label = 'Reconstruction Error'
                    title = 'Valid Image Reconstruction Error'
                    legend = ['Recons_Error']
                    win_img_recons_valid = creat_vis_plot(
                        viz, x_value, y_value, x_label, y_label, title, legend)

                    y_value = np.asarray(loss_decorr_valid).reshape(1, )
                    y_label = 'Decorrelation Error'
                    title = 'Valid Decorrelation Error'
                    legend = ['Decorr_Error']
                    win_decorr_valid = creat_vis_plot(
                        viz, x_value, y_value, x_label, y_label, title, legend)
                else:
                    x_value = np.asarray(count_update_step).reshape(1, )
                    y_value = np.asarray(loss_recons_valid).reshape(1, )
                    update_vis(viz, win_img_recons_valid, x_value, y_value)
                    y_value = np.asarray(loss_decorr_valid).reshape(1, )
                    update_vis(viz, win_decorr_valid, x_value, y_value)

                encoder_z.train()
                decoder.train()

            # save the midterm model states
            if count_update_step % 5000 == 0:
                torch.save(
                    encoder_z.state_dict(),
                    save_dir + '/encoder_z_step' + str(count_update_step) + '.pt')
                torch.save(
                    decoder.state_dict(),
                    save_dir + '/decoder_step' + str(count_update_step) + '.pt')
                torch.save(
                    discriminator.state_dict(),
                    save_dir + '/discriminator_step' + str(count_update_step) + '.pt')

        lr_scheduler_encoder_z.step()
        lr_scheduler_decoder.step()
        lr_scheduler_dis.step()

    torch.save(encoder_z.state_dict(), save_dir + '/encoder_z_final.pt')
    torch.save(decoder.state_dict(), save_dir + '/decoder_final.pt')
    torch.save(discriminator.state_dict(), save_dir + '/discriminator_final.pt')

    return count_update_step
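`get_lambda` is called in the encoder_z update but not defined in this listing; a linear warm-up is one plausible schedule consistent with its arguments (a sketch, assuming `lambda_schedule` counts warm-up steps). The margin/equilibrium bookkeeping that appears twice above could likewise be factored into a helper with identical behavior:

def get_lambda(lambda_max, schedule_steps, step):
    # Hypothetical linear warm-up: ramp the decorrelation weight from 0 to
    # lambda_max over schedule_steps updates, then hold it constant.
    if schedule_steps <= 0:
        return lambda_max
    return lambda_max * min(float(step) / schedule_steps, 1.0)


def gate_updates(dis_original, dis_sampled, equilibrium, margin):
    # Same logic as the inline checks above: freeze the decoder while the
    # discriminator is losing (loss above the band), freeze the discriminator
    # while it is winning (loss below the band); if both would freeze, train both.
    train_dis = train_dec = True
    if dis_original.mean().item() > equilibrium + margin \
            or dis_sampled.mean().item() > equilibrium + margin:
        train_dec = False
    if dis_original.mean().item() < equilibrium - margin \
            or dis_sampled.mean().item() < equilibrium - margin:
        train_dis = False
    if not (train_dis or train_dec):
        train_dis = train_dec = True
    return train_dis, train_dec

With this helper, each inline block reduces to a single call: train_dis, train_dec = gate_updates(dis_original, dis_sampled, equilibrium, margin).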
def training(params, encoder, decoder, optim_encoder, optim_decoder, device,
             digit_class, decorr_regul, train_loader, valid_loader, save_dir):
    encoder.train()
    decoder.train()

    count_update_step = 0
    # preallocated buffer for the one-hot labels; this assumes every batch has
    # exactly params.batch_size samples (e.g. the DataLoader drops the last
    # incomplete batch)
    labels_onehot = torch.FloatTensor(params.batch_size, params.n_class)
    for i in range(params.n_epochs):
        for batch_idx, (batch_x, batch_y) in enumerate(train_loader):
            batch_x = batch_x.to(device)
            # convert the labels into one-hot vectors
            indices = batch_y.view(-1, 1)
            labels_onehot.zero_()
            labels_onehot.scatter_(1, indices, 1)
            batch_y_onehot = labels_onehot.to(device)
            batch_y = batch_y.to(device)
            count_update_step += 1

            ############################
            # (1) update encoder
            ############################
            for p in encoder.parameters():
                p.requires_grad_(True)
            for p in decoder.parameters():
                p.requires_grad_(False)

            optim_encoder.zero_grad()
            y_class, z_latent = encoder(batch_x)
            x_recons = decoder(batch_y_onehot, z_latent)

            # classification loss
            loss_class = digit_class(y_class, batch_y)
            # decorrelation loss
            y_class_softmax = F.softmax(y_class, 1)
            loss_decorr = decorr_regul(y_class_softmax, z_latent)
            # image reconstruction loss
            loss_recons_image = torch.mean(
                0.5 * (batch_x.view(len(batch_x), -1)
                       - x_recons.view(len(x_recons), -1))**2)

            loss_encoder = params.lambda_class * loss_class \
                + params.lambda_decorr * loss_decorr \
                + params.lambda_recons * loss_recons_image
            loss_encoder.backward()
            optim_encoder.step()

            ############################
            # (2) update decoder
            ############################
            for p in decoder.parameters():
                p.requires_grad_(True)
            for p in encoder.parameters():
                p.requires_grad_(False)

            optim_decoder.zero_grad()
            _, z_latent = encoder(batch_x)
            x_recons = decoder(batch_y_onehot, z_latent)

            # image reconstruction loss
            loss_recons_image = torch.mean(
                0.5 * (batch_x.view(len(batch_x), -1)
                       - x_recons.view(len(x_recons), -1))**2)
            loss_decoder = params.lambda_recons * loss_recons_image
            loss_decoder.backward()
            optim_decoder.step()

            # visualize losses
            if count_update_step == 1:
                viz = Visdom()
                x_value = np.asarray(count_update_step).reshape(1, )
                x_label = 'Training Step'

                y_value = np.asarray(loss_recons_image.item()).reshape(1, )
                y_label = 'MSE'
                title = 'Image Reconstruction Loss'
                legend = ['MSE']
                win_img_recons = creat_vis_plot(viz, x_value, y_value,
                                                x_label, y_label, title, legend)

                y_value = np.asarray(loss_class.item()).reshape(1, )
                y_label = 'Loss'
                title = 'Cross Entropy Loss'
                legend = ['Loss_CE']
                win_attrs_class = creat_vis_plot(viz, x_value, y_value,
                                                 x_label, y_label, title, legend)

                y_value = np.asarray(loss_decorr.item()).reshape(1, )
                y_label = 'Loss'
                title = 'Decorrelation Loss'
                legend = ['Loss_Decorr']
                win_decorr = creat_vis_plot(viz, x_value, y_value,
                                            x_label, y_label, title, legend)
            elif count_update_step % 50 == 0:
                x_value = np.asarray(count_update_step).reshape(1, )
                y_value = np.asarray(loss_recons_image.item()).reshape(1, )
                update_vis(viz, win_img_recons, x_value, y_value)
                y_value = np.asarray(loss_class.item()).reshape(1, )
                update_vis(viz, win_attrs_class, x_value, y_value)
                y_value = np.asarray(loss_decorr.item()).reshape(1, )
                update_vis(viz, win_decorr, x_value, y_value)

            # evaluate the model
            if count_update_step % 1000 == 0:
                print('\nUpdate step: {:d}'
                      '\nmean loss_img_recons: {:.4f}'
                      '\nmean loss_class: {:.4f}'
                      '\nmean loss_decorr: {:.4f}'.format(
                          count_update_step, loss_recons_image.item(),
                          loss_class.item(), loss_decorr.item()))

                # evaluation on validation set
                loss_recons, loss_decorr, class_error_rate = evaluate_classification(
                    params, encoder, decoder, valid_loader, params.n_valid,
                    decorr_regul, device)
                print('Classification Error Rate on Validation Set: '
                      '{:.2f}%'.format(class_error_rate))
                print('Decorrelation Error on Each Mini-Batch: {:.4f}'.format(
                    loss_decorr))
                print('Reconstruction Error on Each Sample: {:.4f}'.format(
                    loss_recons))

                if count_update_step == 1000:
                    viz = Visdom()
                    x_value = np.asarray(count_update_step).reshape(1, )
                    x_label = 'Training Step'

                    y_value = np.asarray(loss_recons).reshape(1, )
                    y_label = 'Reconstruction Error'
                    title = 'Valid Image Reconstruction Error'
                    legend = ['Recons_Error']
                    win_img_recons_valid = creat_vis_plot(
                        viz, x_value, y_value, x_label, y_label, title, legend)

                    y_value = np.asarray(class_error_rate).reshape(1, )
                    y_label = 'Classification Error'
                    title = 'Valid Classification Error'
                    legend = ['Class_Error']
                    win_class_valid = creat_vis_plot(
                        viz, x_value, y_value, x_label, y_label, title, legend)

                    y_value = np.asarray(loss_decorr).reshape(1, )
                    y_label = 'Decorrelation Error'
                    title = 'Valid Decorrelation Error'
                    legend = ['Decorr_Error']
                    win_decorr_valid = creat_vis_plot(
                        viz, x_value, y_value, x_label, y_label, title, legend)
                else:
                    x_value = np.asarray(count_update_step).reshape(1, )
                    y_value = np.asarray(loss_recons).reshape(1, )
                    update_vis(viz, win_img_recons_valid, x_value, y_value)
                    y_value = np.asarray(class_error_rate).reshape(1, )
                    update_vis(viz, win_class_valid, x_value, y_value)
                    y_value = np.asarray(loss_decorr).reshape(1, )
                    update_vis(viz, win_decorr_valid, x_value, y_value)

                encoder.train()
                decoder.train()

            # save the midterm model states
            if count_update_step % 5000 == 0:
                torch.save(
                    encoder.state_dict(),
                    save_dir + '/encoder_step' + str(count_update_step) + '.pt')
                torch.save(
                    decoder.state_dict(),
                    save_dir + '/decoder_step' + str(count_update_step) + '.pt')

        # lr_scheduler_encoder.step()
        # lr_scheduler_decoder.step()

    # save the whole model
    torch.save(encoder.state_dict(), save_dir + '/encoder_final.pt')
    torch.save(decoder.state_dict(), save_dir + '/decoder_final.pt')

    return count_update_step
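The scatter_-based label conversion above can be written more compactly with torch.nn.functional.one_hot (available since PyTorch 1.1), which also handles a smaller final mini-batch without a preallocated buffer; a minimal equivalent sketch:

import torch.nn.functional as F

# one-hot encode the integer labels: shape (batch, n_class), float for the decoder
batch_y_onehot = F.one_hot(batch_y, num_classes=params.n_class).float().to(device)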
if train_dis:
    loss_discriminator.backward()
    optim_discriminator.step()

# visualize losses
count_update_step += 1
if count_update_step == 1:
    viz = Visdom()
    x_value = np.asarray(count_update_step).reshape(1, )
    x_label = 'Training Step'

    y_value = np.asarray(torch.mean(nle_value).item()).reshape(1, )
    y_label = 'MSE'
    title = 'Image Reconstruction Loss'
    legend = ['MSE']
    win_img_recons = creat_vis_plot(viz, x_value, y_value,
                                    x_label, y_label, title, legend)

    y_value = np.asarray(torch.mean(mse_value).item()).reshape(1, )
    y_label = 'MSE'
    title = 'Feature Reconstruction Loss'
    legend = ['MSE']
    win_feature_recons = creat_vis_plot(viz, x_value, y_value,
                                        x_label, y_label, title, legend)

    y_value = np.column_stack((np.asarray(loss_discriminator.item()),
                               np.asarray(loss_decoder.item())))
    y_label = 'Loss'
    title = 'Discriminator and Decoder Losses'
    legend = ['Loss_Dis', 'Loss_Dec']
    win_dis_gen = creat_vis_plot(viz, x_value, y_value,
                                 x_label, y_label, title, legend)