def sigmoid_rampup(global_step, start_iter, end_iter):
    """Sigmoid-shaped ramp-up coefficient in [0, 1].

    Returns 0.0 while ``global_step < start_iter``, then follows the
    mean-teacher schedule exp(-5 * (1 - t)**2) where t is the fraction of
    the ramp completed, reaching 1.0 at ``end_iter`` and staying there.

    Args:
        global_step: Current training iteration.
        start_iter: Iteration at which the ramp begins.
        end_iter: Iteration at which the ramp reaches 1.0.

    Returns:
        float ramp-up weight in [0, 1].
    """
    if global_step < start_iter:
        return 0.0
    rampup_length = end_iter - start_iter
    # Fix: the original divided by rampup_length unconditionally and raised
    # ZeroDivisionError when end_iter == start_iter. A zero-length ramp
    # means the ramp is already complete, so return 1.0.
    if rampup_length <= 0:
        return 1.0
    cur_ramp = np.clip(global_step - start_iter, 0, rampup_length)
    phase = 1.0 - cur_ramp / rampup_length
    return np.exp(-5.0 * phase * phase)


# --- Script-level setup: data iterators, networks, optimizers ----------------

# Labeled-data iterator for the classifier.
itr = inputs.get_data_iter(batch_size=FLAGS.bs_c, subset=FLAGS.n_labels)
# itr_u = inputs.get_data_iter(batch_size=FLAGS.bs_c)

netG, optim_G = inputs.get_generator_optimizer()
netD, optim_D = inputs.get_discriminator_optimizer()
netC, optim_c = inputs.get_classifier_optimizer()

netG, netD, netC = netG.to(device), netD.to(device), netC.to(device)
netG = nn.DataParallel(netG)
netD = nn.DataParallel(netD)
netC = nn.DataParallel(netC)

# Teacher classifier (EMA copy of netC); its optimizer is discarded.
netC_T, _ = inputs.get_classifier_optimizer()
netC_T = netC_T.to(device)
netC_T = nn.DataParallel(netC_T)

netC.train()
netC_T.train()
# Decay 0 makes the initial teacher an exact copy of the student.
Torture.update_average(netC_T, netC, 0)
# The teacher is updated only via EMA, never by backprop.
for p in netC_T.parameters():
    p.requires_grad_(False)

if FLAGS.c_step == "ramp_swa":
    # SWA copy of the classifier (remainder of this branch may continue
    # beyond this chunk).
    netC_swa, _ = inputs.get_classifier_optimizer()
# --- Reproducibility & device setup ------------------------------------------
torch.cuda.manual_seed(1235)
np.random.seed(1236)
torch.backends.cudnn.deterministic = True
# Fix: the original set benchmark = True, which lets cuDNN autotune and pick
# algorithms non-deterministically, defeating the deterministic flag and the
# manual seeds above. For reproducible runs benchmark must be False.
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FLAGS.device = device

nlabels = FLAGS.y_dist.dim
batch_size = FLAGS.training.batch_size

# --- Logging / checkpointing --------------------------------------------------
checkpoint_io = CheckpointIO(checkpoint_dir=MODELS_FOLDER)
logger = Logger(log_dir=SUMMARIES_FOLDER)

# --- Data, networks, optimizers ----------------------------------------------
itr = inputs.get_data_iter(batch_size)
GNet, GOptim = inputs.get_generator_optimizer()
DNet, DOptim = inputs.get_discriminator_optimizer()

# Averaged copy of the generator used for evaluation; beta=0.0 makes the
# initial copy exact, after which it is presumably EMA-updated during training.
GNet_test = copy.deepcopy(GNet)
update_average(GNet_test, GNet, 0.0)

ydist = get_ydist(**vars(FLAGS.y_dist))
zdist = get_zdist(**vars(FLAGS.z_dist))

checkpoint_io.register_modules(GNet=GNet, GOptim=GOptim, DNet=DNet, DOptim=DOptim)
checkpoint_io.register_modules(GNet_test=GNet_test)

# --- Trainer dispatch ---------------------------------------------------------
# Select the trainer module by config name; a KeyError here means an
# unsupported FLAGS.trainer.name.
trainer_dict = {"baseline": trainer_baseline, "fd": trainer_fd}
trainer_used = trainer_dict[FLAGS.trainer.name]
trainer = trainer_used.Trainer(
    GNet, DNet, GOptim, DOptim, **vars(FLAGS.trainer.kwargs)
)