torch.cuda.manual_seed(opt.seed)

# Networks
# opt.upsample == 'ori' selects the Generator_ori architecture; any other
# value falls back to Generator. B2A swaps the input/output channel counts.
if opt.upsample == 'ori':
    netG_A2B = Generator_ori(opt.input_nc, opt.output_nc)
    netG_B2A = Generator_ori(opt.output_nc, opt.input_nc)
else:
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)

# (checkpoint prefix, network) pairs, in the order every step below visits them.
_initial_networks = (
    ("netG_A2B", netG_A2B),
    ("netG_B2A", netG_B2A),
    ("netD_A", netD_A),
    ("netD_B", netD_B),
)

# Move every network to the GPU first, then initialise, so that the
# initialisation runs on the CUDA copies of the parameters.
for _prefix, _network in _initial_networks:
    _network.cuda()
for _prefix, _network in _initial_networks:
    _network.apply(weights_init_normal)

# Persist the freshly initialised weights, one file per network, tagged with
# the seed. torch.save does not create missing parent directories, so make
# sure the target directory exists instead of failing after the networks
# have already been built.
import os  # local import: the file's import header is not visible from here

os.makedirs("initial_weights", exist_ok=True)
for _prefix, _network in _initial_networks:
    torch.save(
        _network.state_dict(),
        "initial_weights/{}_seed_{}.pth.tar".format(_prefix, opt.seed),
    )
optimizer_D_B.step() ################################### # print loss: if i % 100 == 0: print('epoch %d-%d: loss_G %.4f, loss_D %.4f' % (epoch, i, loss_G.data, (loss_D_A + loss_D_B).data)) loss_G_value += loss_G.data loss_D_value += (loss_D_A + loss_D_B).data loss_G_GAN_value += (loss_GAN_A2B + loss_GAN_B2A).data loss_G_cycle_value += (loss_cycle_ABA + loss_cycle_BAB).data loss_G_identity_value += (loss_identity_A + loss_identity_B).data if epoch % 5 == 0 or epoch == opt.n_epochs - 1: torch.save( netG_A2B.state_dict(), os.path.join(pth_dir, 'netG_A2B_epoch_{}.pth'.format(epoch))) torch.save( netG_B2A.state_dict(), os.path.join(pth_dir, 'netG_B2A_epoch_{}.pth'.format(epoch))) torch.save(netD_A.state_dict(), os.path.join(pth_dir, 'netD_A_epoch_{}.pth'.format(epoch))) torch.save(netD_B.state_dict(), os.path.join(pth_dir, 'netD_B_epoch_{}.pth'.format(epoch))) ## at the end of each epoch # plot loss: losses = { 'loss_G': (loss_G_lst, loss_G_value), 'loss_D': (loss_D_lst, loss_D_value), 'loss_G_GAN': (loss_G_GAN_lst, loss_G_GAN_value), 'loss_G_cycle': (loss_G_cycle_lst, loss_G_cycle_value),
netD_A.load_state_dict( torch.load(os.path.join(args.pretrain, 'netD_A_epoch_199.pth'))) netD_B.load_state_dict( torch.load(os.path.join(args.pretrain, 'netD_B_epoch_199.pth'))) #one shot pruning pruning_generate(netG_A2B, (1 - args.percent)) pruning_generate(netG_B2A, (1 - args.percent)) see_remain_rate(netG_A2B) see_remain_rate(netG_B2A) #rewind to random weight a2b_init = torch.load(os.path.join(args.rand, 'netG_A2B_seed_1.pth.tar')) b2a_init = torch.load(os.path.join(args.rand, 'netG_B2A_seed_1.pth.tar')) a2b_orig_weight = rewind_weight(a2b_init, netG_A2B.state_dict().keys()) b2a_orig_weight = rewind_weight(b2a_init, netG_B2A.state_dict().keys()) a2b_weight = netG_A2B.state_dict() b2a_weight = netG_B2A.state_dict() a2b_weight.update(a2b_orig_weight) b2a_weight.update(b2a_orig_weight) netG_A2B.load_state_dict(a2b_weight) netG_B2A.load_state_dict(b2a_weight) # Lossess criterion_GAN = torch.nn.MSELoss() criterion_cycle = torch.nn.L1Loss() criterion_identity = torch.nn.L1Loss() # optimizers & LR schedulers optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
# Visiting order shared by the per-network steps below.
_all_networks = (netG_A2B, netG_B2A, netD_A, netD_B)

# Prune with opt.rate, then report each network's remaining rate.
# pruning_generate and see_remain_rate are helpers defined elsewhere in the
# project; netG_A2B is not pruned in this visible part.
for _network in (netG_B2A, netD_A, netD_B):
    pruning_generate(_network, opt.rate)
for _network in _all_networks:
    see_remain_rate(_network)

# Rewind to random G weight
# NOTE(review): the checkpoint names hard-code seed 1 -- confirm that this is
# the seed the stored initial weights were generated with.
a2b_init, b2a_init, a_init, b_init = (
    torch.load(os.path.join(opt.rand, _init_name))
    for _init_name in (
        'netG_A2B_seed_1.pth.tar',
        'netG_B2A_seed_1.pth.tar',
        'netD_A_seed_1.pth.tar',
        'netD_B_seed_1.pth.tar',
    )
)

# rewind_weight receives each loaded initial state together with the key set
# of the corresponding network's current state dict.
a2b_orig_weight, b2a_orig_weight, a_orig_weight, b_orig_weight = (
    rewind_weight(_init_state, _network.state_dict().keys())
    for _init_state, _network in zip(
        (a2b_init, b2a_init, a_init, b_init), _all_networks
    )
)

# Overlay the rewound entries on top of each network's current state dict.
a2b_weight, b2a_weight, a_weight, b_weight = [
    _network.state_dict() for _network in _all_networks
]
for _current, _rewound in zip(
    (a2b_weight, b2a_weight, a_weight, b_weight),
    (a2b_orig_weight, b2a_orig_weight, a_orig_weight, b_orig_weight),
):
    _current.update(_rewound)

for _network, _merged in ((netG_A2B, a2b_weight), (netG_B2A, b2a_weight)):
    _network.load_state_dict(_merged)

# Rewind to random D weight
netD_A.load_state_dict(a_weight)