def __init__(self, opt):
    """Create the networks, optimizers and LR schedulers described by ``opt``.

    Args:
        opt: Nested options mapping. This method reads ``opt['network_G']``,
            ``opt['network_D']``, ``opt['path']['pretrain_G']`` and the
            ``lr_G``/``lr_D``, ``b1_*``/``b2_*``, ``lr_steps`` and
            ``lr_gamma`` entries of ``opt['train']``.
    """
    self.device = torch.device('cuda')
    self.opt = opt

    # Generator: initialise first, then optionally overwrite the weights
    # with a pretrained state dict when a path is configured.
    self.G = Generator(opt['network_G']).to(self.device)
    util.init_weights(self.G, init_type='kaiming', scale=0.1)
    pretrained_g = opt['path']['pretrain_G']
    if pretrained_g:
        self.G.load_state_dict(torch.load(pretrained_g), strict=True)

    # Discriminator and the VGG feature extractor.
    self.D = Discriminator(opt['network_D']).to(self.device)
    util.init_weights(self.D, init_type='kaiming', scale=1)
    self.FE = VGGFeatureExtractor().to(self.device)

    # G and D are put in training mode; the feature extractor stays in
    # eval mode.
    self.G.train()
    self.D.train()
    self.FE.eval()

    self.log_dict = OrderedDict()

    # Only generator parameters that require gradients are optimised.
    self.optim_params = [p for p in self.G.parameters() if p.requires_grad]

    train_cfg = opt['train']
    self.opt_G = torch.optim.Adam(
        self.optim_params,
        lr=train_cfg['lr_G'],
        betas=(train_cfg['b1_G'], train_cfg['b2_G']),
    )
    self.opt_D = torch.optim.Adam(
        self.D.parameters(),
        lr=train_cfg['lr_D'],
        betas=(train_cfg['b1_D'], train_cfg['b2_D']),
    )
    self.optimizers = [self.opt_G, self.opt_D]

    # One MultiStepLR per optimizer, all sharing milestones and gamma.
    schedulers = []
    for optimizer in self.optimizers:
        schedulers.append(
            lr_scheduler.MultiStepLR(optimizer, train_cfg['lr_steps'],
                                     train_cfg['lr_gamma']))
    self.schedulers = schedulers
def __init__(self):
    """Build the data loader, networks, optimizers and assistant network.

    ``self.load_checkpoint(model_dump_path)`` decides the start-up path:
    when it returns ``None`` as checkpoint, freshly built networks are
    initialised with ``initital_network_weights`` and training starts at
    epoch 0; otherwise the network and optimizer states and the epoch
    counter are restored from the returned checkpoint.
    """
    logger.info('Set Data Loader')
    self.dataset = FoodDataset(transform=transforms.Compose([ToTensor()]))
    self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers,
                                                   drop_last=True)

    checkpoint, checkpoint_name = self.load_checkpoint(model_dump_path)
    resuming = checkpoint is not None

    # Networks are constructed the same way on both paths; only the log
    # line and the source of the weights differ.
    if resuming:
        logger.info('Load Generator and Discriminator')
    else:
        logger.info(
            'Don\'t have pre-trained model. Ignore loading model process.')
        logger.info('Set Generator and Discriminator')
    self.G = Generator(tag=tag_size).to(device)
    self.D = Discriminator(tag=tag_size).to(device)

    if resuming:
        # The original format string had no placeholder, so the checkpoint
        # name never reached the log.
        logger.info('Load Pre-Trained Weights From Checkpoint {}'.format(
            checkpoint_name))
        self.G.load_state_dict(checkpoint['G'])
        self.D.load_state_dict(checkpoint['D'])
        logger.info('Load Optimizers')
    else:
        logger.info('Initialize Weights')
        self.G.apply(initital_network_weights)
        self.D.apply(initital_network_weights)
        logger.info('Set Optimizers')

    self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                        lr=learning_rate,
                                        betas=(beta_1, 0.999))
    self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                        lr=learning_rate,
                                        betas=(beta_1, 0.999))

    if resuming:
        self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        self.epoch = checkpoint['epoch']
    else:
        self.epoch = 0

    logger.info('Set Criterion')
    # Assistant AlexNet with ``tag_size`` outputs and its own optimizer.
    self.a_D = alexnet.alexnet(num_classes=tag_size).to(device)
    self.optimizer_a_D = torch.optim.Adam(self.a_D.parameters(),
                                          lr=learning_rate,
                                          betas=(beta_1, .999))
def _build_model(self):
    """Build the generator, the discriminator and their optimizers.

    Layer sizes and layer options are read from ``self.config``; both
    networks are moved to CUDA and registered with ``wandb.watch``.

    Returns:
        The tuple ``(generator, discriminator, g_optimizer, d_optimizer)``.
    """
    device = torch.device('cuda')

    data_dimension = self.config.data['dimension']
    model_cfg = self.config.model
    generator_hidden_layers = model_cfg['generator_hidden_layers']
    # Options passed unchanged to both networks, in constructor order.
    layer_options = (
        model_cfg['use_dropout'],
        model_cfg['drop_prob'],
        model_cfg['use_ac_func'],
        model_cfg['activation'],
    )
    disc_hidden_layers = model_cfg['disc_hidden_layers']

    # Generator: latent_dim -> hidden layers -> data dimension.
    logger.log("Loading {} network ...".format(colored('generator', 'red')))
    gen_fc_layers = [self.latent_dim] + list(generator_hidden_layers)
    gen_fc_layers.append(data_dimension)
    generator = Generator(gen_fc_layers, *layer_options).to(device)

    # Discriminator: data dimension -> hidden layers -> single output.
    logger.log("Loading {} network ...".format(
        colored('discriminator', 'red')))
    disc_fc_layers = [data_dimension] + list(disc_hidden_layers) + [1]
    discriminator = Discriminator(disc_fc_layers, *layer_options).to(device)

    wandb.watch([generator, discriminator])

    g_optimizer, d_optimizer = self._setup_optimizers(
        generator, discriminator)
    return generator, discriminator, g_optimizer, d_optimizer
def __init__(self, args):
    """Set up networks, losses and optimizers; optionally resume model_G.

    Args:
        args: Run configuration. Reads ``data_modality`` (``'oct'`` or
            ``'fundus'``), ``lr``, ``d2g_lr``, ``weight_decay``, ``b1``,
            ``b2``, ``resume``, ``output_root`` and ``project``. When
            resuming, ``args.start_epoch`` is overwritten with the epoch
            stored in the checkpoint.

    Raises:
        ValueError: ``args.resume`` is set but no file exists at the
            resolved checkpoint path.
    """
    super(SegTransferModel, self).__init__()
    self.args = args

    assert args.data_modality in [
        'oct', 'fundus'
    ], 'error in seg_mode, got {}'.format(args.data_modality)

    # Segmentation classes: 1 for fundus, 12 for OCT.
    n_classes = 1 if self.args.data_modality == 'fundus' else 12
    generator = UNet_4mp(n_channels=1, n_classes=n_classes)
    critic = Discriminator(in_channels=1)

    # Everything lives on the GPU; the two networks are wrapped in
    # DataParallel. Registration order matches the tuple order.
    submodules = (
        ('model_G', nn.DataParallel(generator).cuda()),
        ('model_D', nn.DataParallel(critic).cuda()),
        ('l1_loss', nn.L1Loss().cuda()),
        ('nll_loss', nn.NLLLoss().cuda()),
        ('adversarial_loss', AdversarialLoss().cuda()),
    )
    for module_name, module in submodules:
        self.add_module(module_name, module)

    # Both optimizers share betas and weight decay; the discriminator
    # learning rate is the generator's scaled by ``d2g_lr``.
    adam_betas = (args.b1, args.b2)
    self.optimizer_G = torch.optim.Adam(params=self.model_G.parameters(),
                                        lr=args.lr,
                                        weight_decay=args.weight_decay,
                                        betas=adam_betas)
    self.optimizer_D = torch.optim.Adam(params=self.model_D.parameters(),
                                        lr=args.lr * args.d2g_lr,
                                        weight_decay=args.weight_decay,
                                        betas=adam_betas)

    # Optionally resume from a checkpoint.
    if not self.args.resume:
        return

    ckpt_path = os.path.join(args.output_root, args.project, 'checkpoints',
                             args.resume)
    if not os.path.isfile(ckpt_path):
        raise ValueError("=> no checkpoint found at '{}'".format(
            args.resume))

    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(ckpt_path)
    args.start_epoch = checkpoint['epoch']
    # Only the generator weights are restored here.
    self.model_G.load_state_dict(checkpoint['state_dict_G'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        args.resume, checkpoint['epoch']))
def __init__(self, args, ablation_mode=4):
    """Set up G1/G2/D, losses and optimizers, then load checkpoints.

    The first-stage checkpoint for ``model_G1`` is mandatory; resuming
    ``model_G2`` from ``args.resume`` is optional.

    Args:
        args: Run configuration. Reads ``data_modality``, ``lr``,
            ``d2g_lr``, ``weight_decay``, ``b1``, ``b2``, ``server``,
            ``project``, ``DA_ablation_mode_isee`` (fundus only),
            ``resume`` and ``output_root``. When a resume checkpoint is
            found, ``args.start_epoch`` is overwritten.
        ablation_mode: Forwarded to ``Image_Reconstruction_Network``.
            According to the original notes: 0 = output_structure
            (1 feature); 2 = image (1 feature), i.e. auto-encoder;
            4 = output_structure + image (2 features).

    Raises:
        NotImplementedError: fundus modality with an unsupported
            ``args.DA_ablation_mode_isee`` value.
        ValueError: the first-stage G1 checkpoint file does not exist.
    """
    super(PNetModel, self).__init__()
    self.args = args
    is_fundus = self.args.data_modality == 'fundus'

    # Structure extractor: 1 class for fundus, 12 otherwise.
    structure_net = Strcutre_Extraction_Network(
        n_channels=1, n_classes=1 if is_fundus else 12)
    reconstruction_net = Image_Reconstruction_Network(
        in_ch=1,
        modality=self.args.data_modality,
        ablation_mode=ablation_mode)
    critic = Discriminator(in_channels=1)

    # Everything lives on the GPU; networks are wrapped in DataParallel.
    submodules = (
        ('model_G1', nn.DataParallel(structure_net).cuda()),
        ('model_G2', nn.DataParallel(reconstruction_net).cuda()),
        ('model_D', nn.DataParallel(critic).cuda()),
        ('l1_loss', nn.L1Loss().cuda()),
        ('adversarial_loss', AdversarialLoss().cuda()),
    )
    for module_name, module in submodules:
        self.add_module(module_name, module)

    # The generator optimizer covers model_G2 only; model_G1 parameters
    # are not handed to any optimizer here.
    adam_betas = (args.b1, args.b2)
    self.optimizer_G = torch.optim.Adam(params=self.model_G2.parameters(),
                                        lr=args.lr,
                                        weight_decay=args.weight_decay,
                                        betas=adam_betas)
    self.optimizer_D = torch.optim.Adam(params=self.model_D.parameters(),
                                        lr=args.lr * args.d2g_lr,
                                        weight_decay=args.weight_decay,
                                        betas=adam_betas)

    # ---- first-stage (G1) checkpoint --------------------------------
    if self.args.server == 'ai':
        workspace = '/root/workspace'
    else:
        workspace = '/home/imed/new_disk/workspace'
    seg_ckpt_root = os.path.join(workspace, args.project, 'save_models')

    if is_fundus:
        # File-name suffix per supported value (0.0001 is the default
        # according to the original note).
        suffix_table = ((0, '0'), (0.001, '001'), (0.0001, '0001'))
        for value, suffix in suffix_table:
            if self.args.DA_ablation_mode_isee == value:
                _g_zero_point = suffix
                break
        else:
            raise NotImplementedError('error')
        # The original segmentation model was '1st_fundus_seg_vgg.pth.tar'.
        seg_ckpt_name = '1st_fundus_seg_g_{}.pth.tar'.format(_g_zero_point)
    else:
        seg_ckpt_name = '1st_oct_seg.pth.tar'
    seg_ckpt_path = os.path.join(seg_ckpt_root, seg_ckpt_name)

    if not os.path.isfile(seg_ckpt_path):
        raise ValueError(
            "=> no checkpoint found at '{}'".format(seg_ckpt_path))
    print("=> loading G1 checkpoint")
    checkpoint = torch.load(seg_ckpt_path)
    self.model_G1.load_state_dict(checkpoint['state_dict_G'])
    print("=> loaded G1 checkpoint (epoch {}) \n from {}".format(
        checkpoint['epoch'], seg_ckpt_path))

    # ---- optional resume of G2 --------------------------------------
    if not self.args.resume:
        return
    ckpt_path = os.path.join(self.args.output_root, args.project,
                             'checkpoints', args.resume)
    if not os.path.isfile(ckpt_path):
        # A missing resume checkpoint is only reported, not raised.
        print("=> no checkpoint found at '{}'".format(args.resume))
        return
    print("=> loading G2 checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(ckpt_path)
    args.start_epoch = checkpoint['epoch']
    self.model_G2.load_state_dict(checkpoint['state_dict_G'])
    print("=> loaded G2 checkpoint '{}' (epoch {})".format(
        args.resume, checkpoint['epoch']))
def main(args):
    """Train a PPO policy on CartPole-v0 with discriminator rewards.

    Each iteration rolls out one episode with the current policy, trains
    the discriminator on expert vs. agent (state, action) pairs, and uses
    the discriminator output as the reward for the PPO update. Training
    stops early, saving the model, after 100 consecutive episodes whose
    environment reward is at least 195.

    Args:
        args: Namespace providing ``gamma``, ``logdir``, ``savedir`` and
            ``iteration``.
    """
    env = gym.make('CartPole-v0')
    env.seed(0)
    ob_space = env.observation_space
    Policy = Policy_net('policy', env)
    Old_Policy = Policy_net('old_policy', env)
    PPO = PPOTrain(Policy, Old_Policy, gamma=args.gamma)
    D = Discriminator(env)

    expert_observations = np.genfromtxt('trajectory/observations.csv')
    expert_actions = np.genfromtxt('trajectory/actions.csv', dtype=np.int32)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(args.logdir, sess.graph)
        try:
            sess.run(tf.global_variables_initializer())

            obs = env.reset()
            success_num = 0

            for iteration in range(args.iteration):
                observations = []
                actions = []
                # Environment rewards are only used for logging and the
                # success criterion; do NOT use them to update the policy.
                rewards = []
                v_preds = []
                run_policy_steps = 0
                while True:
                    run_policy_steps += 1
                    # Add a leading axis to feed placeholder Policy.obs.
                    obs = np.stack([obs]).astype(dtype=np.float32)
                    act, v_pred = Policy.act(obs=obs, stochastic=True)

                    # ``.item()`` is what the removed ``np.asscalar``
                    # called internally.
                    act = act.item()
                    v_pred = v_pred.item()
                    next_obs, reward, done, _ = env.step(act)

                    observations.append(obs)
                    actions.append(act)
                    rewards.append(reward)
                    v_preds.append(v_pred)

                    if done:
                        next_obs = np.stack([next_obs]).astype(
                            dtype=np.float32)
                        _, v_pred = Policy.act(obs=next_obs, stochastic=True)
                        # Shift value predictions by one step and append
                        # the prediction for the state after the episode.
                        v_preds_next = v_preds[1:] + [v_pred.item()]
                        obs = env.reset()
                        break
                    obs = next_obs

                episode_reward = sum(rewards)
                writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag='episode_length',
                                         simple_value=run_policy_steps)
                    ]), iteration)
                writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag='episode_reward',
                                         simple_value=episode_reward)
                    ]), iteration)

                if episode_reward >= 195:
                    success_num += 1
                    if success_num >= 100:
                        saver.save(sess, args.savedir + '/model.ckpt')
                        print('Clear!! Model saved.')
                        break
                else:
                    success_num = 0

                # Convert lists to numpy arrays for feeding tf.placeholder.
                # The shape is passed positionally because the ``newshape``
                # keyword is deprecated in recent NumPy releases.
                observations = np.reshape(observations,
                                          [-1] + list(ob_space.shape))
                actions = np.array(actions).astype(dtype=np.int32)

                # Train the discriminator.
                for _ in range(2):
                    D.train(expert_s=expert_observations,
                            expert_a=expert_actions,
                            agent_s=observations,
                            agent_a=actions)

                # The discriminator output is used as the reward.
                d_rewards = D.get_rewards(agent_s=observations,
                                          agent_a=actions)
                d_rewards = np.reshape(d_rewards,
                                       [-1]).astype(dtype=np.float32)

                gaes = PPO.get_gaes(rewards=d_rewards,
                                    v_preds=v_preds,
                                    v_preds_next=v_preds_next)
                gaes = np.array(gaes).astype(dtype=np.float32)
                v_preds_next = np.array(v_preds_next).astype(
                    dtype=np.float32)

                # Train the policy on random mini-batches of the episode.
                inp = [observations, actions, gaes, d_rewards, v_preds_next]
                PPO.assign_policy_parameters()
                for _ in range(6):
                    # Indices are drawn from [low, high), with replacement.
                    sample_indices = np.random.randint(
                        low=0, high=observations.shape[0], size=32)
                    sampled_inp = [
                        np.take(a=a, indices=sample_indices, axis=0)
                        for a in inp
                    ]
                    PPO.train(obs=sampled_inp[0],
                              actions=sampled_inp[1],
                              gaes=sampled_inp[2],
                              rewards=sampled_inp[3],
                              v_preds_next=sampled_inp[4])

                summary = PPO.get_summary(obs=inp[0],
                                          actions=inp[1],
                                          gaes=inp[2],
                                          rewards=inp[3],
                                          v_preds_next=inp[4])
                writer.add_summary(summary, iteration)
        finally:
            # Close the writer even if training raises.
            writer.close()
class SRGAN():
    """GAN trainer for the anime-face dataset.

    The discriminator returns a real/fake logit and a tag prediction; both
    networks are trained with a label loss and a tag loss, and the
    discriminator additionally receives a gradient penalty.
    """

    def __init__(self):
        """Build the data loader, networks, optimizers and criteria.

        ``self.load_checkpoint(model_dump_path)`` decides the start-up
        path: without a checkpoint the networks are initialised with
        ``initital_network_weights`` and training starts at epoch 0;
        otherwise network and optimizer states and the epoch counter are
        restored from it.
        """
        logger.info('Set Data Loader')
        self.dataset = AnimeFaceDataset(
            avatar_tag_dat_path=avatar_tag_dat_path,
            transform=transforms.Compose([ToTensor()]))
        self.data_loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            drop_last=True)

        checkpoint, checkpoint_name = self.load_checkpoint(model_dump_path)
        resuming = checkpoint is not None

        if resuming:
            logger.info('Load Generator and Discriminator')
        else:
            logger.info(
                'Don\'t have pre-trained model. Ignore loading model process.'
            )
            logger.info('Set Generator and Discriminator')
        self.G = Generator().to(device)
        self.D = Discriminator().to(device)

        if resuming:
            # The original format string had no placeholder, so the
            # checkpoint name never reached the log.
            logger.info('Load Pre-Trained Weights From Checkpoint {}'.format(
                checkpoint_name))
            self.G.load_state_dict(checkpoint['G'])
            self.D.load_state_dict(checkpoint['D'])
            logger.info('Load Optimizers')
        else:
            logger.info('Initialize Weights')
            self.G.apply(initital_network_weights)
            self.D.apply(initital_network_weights)
            logger.info('Set Optimizers')

        self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))

        if resuming:
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            self.epoch = checkpoint['epoch']
        else:
            self.epoch = 0

        logger.info('Set Criterion')
        self.label_criterion = nn.BCEWithLogitsLoss().to(device)
        self.tag_criterion = nn.MultiLabelSoftMarginLoss().to(device)

    def load_checkpoint(self, model_dir):
        """Load the last checkpoint (by sorted file name) in ``model_dir``.

        Args:
            model_dir: Directory handed to ``read_newest_model``.

        Returns:
            ``(checkpoint, path)``, or ``(None, None)`` when
            ``read_newest_model`` returns nothing.
        """
        models_path = read_newest_model(model_dir)
        if not models_path:
            return None, None
        models_path.sort()
        # Join with the directory that was searched (the original joined
        # with the global ``model_dump_path`` and ignored ``model_dir``).
        new_model_path = os.path.join(model_dir, models_path[-1])
        # Map onto the active device so a checkpoint saved on another
        # device can still be loaded.
        checkpoint = torch.load(new_model_path, map_location=device)
        return checkpoint, new_model_path

    def train(self):
        """Run the adversarial training loop until ``max_epoch``.

        A checkpoint is written at the start of every epoch. For each
        batch the discriminator is updated once (real loss, fake loss and
        gradient penalty), then the generator once. Every ``verbose_T``
        iterations a real and a generated image grid are saved to
        ``tmp_path``; the collected losses are logged when ``verbose``.
        """
        iteration = -1
        # Reusable target tensor, refilled with 1.0 / 0.0 before each use.
        # The trailing size must be an int (the original passed ``1.0``).
        label = Variable(torch.FloatTensor(batch_size, 1)).to(device)

        logging.info('Current epoch: {}. Max epoch: {}.'.format(
            self.epoch, max_epoch))
        while self.epoch <= max_epoch:
            # dump checkpoint
            checkpoint_path = '{}/checkpoint_{}.tar'.format(
                model_dump_path,
                str(self.epoch).zfill(4))
            torch.save(
                {
                    'epoch': self.epoch,
                    'D': self.D.state_dict(),
                    'G': self.G.state_dict(),
                    'optimizer_D': self.optimizer_D.state_dict(),
                    'optimizer_G': self.optimizer_G.state_dict(),
                }, checkpoint_path)
            logger.info('Checkpoint saved in: {}'.format(checkpoint_path))

            msg = {}
            adjust_learning_rate(self.optimizer_G, iteration)
            adjust_learning_rate(self.optimizer_D, iteration)

            for i, (avatar_tag, avatar_img) in enumerate(self.data_loader):
                iteration += 1
                on_report_step = iteration % verbose_T == 0
                log_step = verbose and on_report_step
                if log_step:
                    msg['epoch'] = int(self.epoch)
                    msg['step'] = int(i)
                    msg['iteration'] = iteration

                avatar_img = Variable(avatar_img).to(device)
                avatar_tag = Variable(
                    torch.FloatTensor(avatar_tag)).to(device)

                # 1. Training D
                # 1.1. discriminate real images
                self.D.zero_grad()
                label_p, tag_p = self.D(avatar_img)
                label.data.fill_(1.0)

                # 1.2. real image loss
                real_label_loss = self.label_criterion(label_p, label)
                real_tag_loss = self.tag_criterion(tag_p, avatar_tag)
                real_loss_sum = (real_label_loss * lambda_adv / 2.0 +
                                 real_tag_loss * lambda_adv / 2.0)
                real_loss_sum.backward()
                if log_step:
                    msg['discriminator real loss'] = float(real_loss_sum)

                # 1.3. discriminate fake images; the generator output is
                # detached so only D receives gradients here.
                g_noise, fake_tag = fake_generator()
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat).detach()
                fake_label_p, fake_tag_p = self.D(fake_img)
                label.data.fill_(.0)

                # 1.4. fake image loss
                fake_label_loss = self.label_criterion(fake_label_p, label)
                fake_tag_loss = self.tag_criterion(fake_tag_p, fake_tag)
                fake_loss_sum = (fake_label_loss * lambda_adv / 2.0 +
                                 fake_tag_loss * lambda_adv / 2.0)
                fake_loss_sum.backward()
                if log_step:
                    msg['discriminator fake loss'] = float(fake_loss_sum)

                # 1.5. gradient penalty
                # https://github.com/jfsantos/dragan-pytorch/blob/master/dragan.py
                # One interpolation weight per sample, broadcast over the
                # remaining dimensions.
                alpha_size = [1] * avatar_img.dim()
                alpha_size[0] = avatar_img.size(0)
                alpha = torch.rand(alpha_size).to(device)
                perturbed = avatar_img.data + 0.5 * avatar_img.data.std() * \
                    Variable(torch.rand(avatar_img.size())).to(device)
                x_hat = Variable(alpha * avatar_img.data +
                                 (1 - alpha) * perturbed,
                                 requires_grad=True).to(device)
                pred_hat, pred_tag = self.D(x_hat)
                gradients = grad(outputs=pred_hat,
                                 inputs=x_hat,
                                 grad_outputs=torch.ones(
                                     pred_hat.size()).to(device),
                                 create_graph=True,
                                 retain_graph=True,
                                 only_inputs=True)[0].view(
                                     x_hat.size(0), -1)
                gradient_penalty = lambda_gp * (
                    (gradients.norm(2, dim=1) - 1)**2).mean()
                gradient_penalty.backward()
                if log_step:
                    msg['discriminator gradient penalty'] = float(
                        gradient_penalty)

                # 1.6. update D
                self.optimizer_D.step()

                # 2. Training G
                # 2.1. generate fake images
                self.G.zero_grad()
                g_noise, fake_tag = fake_generator()
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat)
                fake_label_p, fake_tag_p = self.D(fake_img)
                label.data.fill_(1.0)

                # 2.2. generator loss (fake images scored against "real")
                label_loss_g = self.label_criterion(fake_label_p, label)
                tag_loss_g = self.tag_criterion(fake_tag_p, fake_tag)
                loss_g = (label_loss_g * lambda_adv / 2.0 +
                          tag_loss_g * lambda_adv / 2.0)
                loss_g.backward()
                if log_step:
                    msg['generator loss'] = float(loss_g)

                # 2.3. update G
                self.optimizer_G.step()

                if log_step:
                    logger.info('------------------------------------------')
                    for key in msg.keys():
                        logger.info('{} : {}'.format(key, msg[key]))

                # save intermediate images
                if on_report_step:
                    height, width = avatar_img.size(2), avatar_img.size(3)
                    step_tag = str(iteration).zfill(8)
                    vutils.save_image(
                        avatar_img.data.view(batch_size, 3, height, width),
                        os.path.join(tmp_path,
                                     'real_image_{}.png'.format(step_tag)))
                    g_noise, fake_tag = fake_generator()
                    fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                    # Preview only: no autograd graph is needed.
                    with torch.no_grad():
                        fake_img = self.G(fake_feat)
                    fake_image_path = os.path.join(
                        tmp_path, 'fake_image_{}.png'.format(step_tag))
                    vutils.save_image(
                        fake_img.data.view(batch_size, 3, height, width),
                        fake_image_path)
                    logger.info('Saved intermediate file in {}'.format(
                        fake_image_path))
            self.epoch += 1
def get_discriminator(self, task_id):
    """Return a new ``Discriminator`` for ``task_id`` on ``self.args.device``."""
    return Discriminator(self.args, task_id).to(self.args.device)
# Dataset-dependent switches.
RANDOM_ALL = True
PRECROP = DATASET.lower() == 'ntu'
VP_VALUE_COUNT = 1 if DATASET.lower() == 'ntu' else 3
CLOSE_VIEWS = DATASET.lower() == 'panoptic'

(vgg_weights_path, i3d_weights_path, gen_weights_path,
 disc_weights_path) = pretrained_weights_config()

# Generator: built with pretrained VGG / I3D weights, optionally followed
# by a full generator state dict.
generator = FullNetwork(
    vp_value_count=VP_VALUE_COUNT,
    stdev=STDEV,
    output_shape=(BATCH_SIZE, CHANNELS, FRAMES, HEIGHT, WIDTH),
    pretrained=True,
    vgg_weights_path=vgg_weights_path,
    i3d_weights_path=i3d_weights_path,
)
if GEN_PRETRAINED:
    generator.load_state_dict(torch.load(gen_weights_path))
generator = generator.to(device)

# Discriminator.
# NOTE(review): ``pretrained`` is driven by GEN_PRETRAINED, the generator
# flag - confirm this is intended.
discriminator = Discriminator(
    in_channels=3,
    pretrained=GEN_PRETRAINED,
    weights_path=disc_weights_path,
).to(device)

if device == 'cuda':
    # NOTE(review): ``net`` is not used by the optimizers below, which are
    # built from ``generator`` directly - confirm where it is consumed.
    net = torch.nn.DataParallel(generator)
    cudnn.benchmark = True

# Loss functions.
criterion = nn.MSELoss()
adversarial_loss = nn.BCELoss()
perceptual_loss = vgg16().to(device)

# One Adam optimizer per network, sharing the learning rate.
optimizer_G = optim.Adam(generator.parameters(), lr=LR)
optimizer_D = optim.Adam(discriminator.parameters(), lr=LR)
else: device = torch.device("cpu") random.seed(args.SEED) torch.manual_seed(args.SEED) if args.USE_CUDA: torch.cuda.manual_seed_all(args.SEED) # Initialize Generator generator = Generator(args.GAN_TYPE, args.ZDIM, args.NUM_CLASSES) generator.apply(weights_init) generator.to(device) print(generator) # Initialize Discriminator discriminator = Discriminator(args.GAN_TYPE, args.NUM_CLASSES) discriminator.apply(weights_init) discriminator.to(device) print(discriminator) # Initialize loss function and optimizer criterionLabel = nn.BCELoss() criterionClass = nn.CrossEntropyLoss() optimizerD = Adam(discriminator.parameters(), lr=args.LR, betas=(0.5, 0.999)) optimizerG = Adam(generator.parameters(), lr=args.LR, betas=(0.5, 0.999)) # Prepare the noise for evaluation during training phase fixedNoise = torch.FloatTensor(args.BATCH_SIZE, args.ZDIM, 1, 1).normal_(0, 1) if args.GAN_TYPE in ["CGAN", "ACGAN"]: fixedClass = F.one_hot(torch.LongTensor([i % args.NUM_CLASSES for i in range(args.BATCH_SIZE)]), num_classes=args.NUM_CLASSES) fixedConstraint = fixedClass.unsqueeze(-1).unsqueeze(-1)
beta_1=0.5, beta_2=0.999) optimizer_keypoint_detector = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5, beta_2=0.999) optimizer_discriminator = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5, beta_2=0.999) batch_size = 20 epochs = 150 train_steps = 99 # change keypoint_detector = KeypointDetector() generator = Generator() discriminator = Discriminator() generator_full = FullGenerator(keypoint_detector, generator, discriminator) discriminator_full = FullDiscriminator(discriminator) @tf.function def train_step(source_images, driving_images): with tf.GradientTape(persistent=True) as tape: losses_generator, generated = generator_full(source_images, driving_images, tape) generator_loss = tf.math.reduce_sum(list(losses_generator.values())) generator_gradients = tape.gradient(generator_loss, generator_full.trainable_variables) keypoint_detector_gradients = tape.gradient(
""" device = 'cuda' if torch.cuda.is_available() else 'cpu' RANDOM_ALL = True PRECROP = True if DATASET.lower() == 'ntu' else False VP_VALUE_COUNT = 1 if DATASET.lower() == 'ntu' else 3 CLOSE_VIEWS = True if DATASET.lower() == 'panoptic' else False # vgg_weights_path, i3d_weights_path = pretrained_weights_config() # generator generator = FullNetwork(vp_value_count=VP_VALUE_COUNT, stdev=STDEV, output_shape=(BATCH_SIZE, CHANNELS, FRAMES, HEIGHT, WIDTH)) # pretrained=True, vgg_weights_path=vgg_weights_path, i3d_weights_path=i3d_weights_path) generator = generator.to(device) # discriminator discriminator = Discriminator(in_channels=3) discriminator = discriminator.to(device) if device == 'cuda': net = torch.nn.DataParallel(generator) cudnn.benchmark = True # Loss functions criterion = nn.MSELoss() adversarial_loss = nn.BCELoss() # categorical_loss = torch.nn.CrossEntropyLoss() # continuous_loss = torch.nn.MSELoss() optimizer_G = optim.Adam(generator.parameters(), lr=LR) optimizer_D = optim.Adam(discriminator.parameters(), lr=LR)
test_noise = [[]] img = G(torch.Tensor(test_noise).to(device)) t_img = vutils.make_grid(img.data.view(1, 3, 128, 128)).numpy() t_img = np.transpose(t_img, (1, 2, 0)) t_img[t_img < 0] = 0 # min_max_scaler = preprocessing.MinMaxScaler() # t_img[..., 0] = min_max_scaler.fit_transform(t_img[..., 0]) # t_img[..., 1] = min_max_scaler.fit_transform(t_img[..., 1]) # t_img[..., 2] = min_max_scaler.fit_transform(t_img[..., 2]) plt.imshow(t_img) plt.show() label_p, tag_p = D(img) label = Variable(torch.FloatTensor(1, 1.0)).to(device) lbl_criterion = nn.BCEWithLogitsLoss().to(device) loss = lbl_criterion(label_p, label) * 10000 loss.backward() print(grad_list[0][0].shape) if __name__ == '__main__': checkpoint, _ = load_checkpoint(model_dump_path) G = Generator().to(device) G.load_state_dict(checkpoint['G']) D = Discriminator().to(device) D.load_state_dict(checkpoint['D']) # generate(G, 'test', ['white hair']) image_backward_D(G, D)
class SRGAN():
    """GAN trainer for the food dataset.

    The generator and discriminator are trained with binary cross-entropy
    on the discriminator output. An assistant AlexNet (``a_D``) is
    constructed and run forward, but its output does not enter any loss
    that is back-propagated here.
    """

    def __init__(self):
        """Build the data loader, networks, optimizers and assistant network.

        ``self.load_checkpoint(model_dump_path)`` decides the start-up
        path: without a checkpoint the networks are initialised with
        ``initital_network_weights`` and training starts at epoch 0;
        otherwise network and optimizer states and the epoch counter are
        restored from it.
        """
        logger.info('Set Data Loader')
        self.dataset = FoodDataset(
            transform=transforms.Compose([ToTensor()]))
        self.data_loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            drop_last=True)

        checkpoint, checkpoint_name = self.load_checkpoint(model_dump_path)
        resuming = checkpoint is not None

        if resuming:
            logger.info('Load Generator and Discriminator')
        else:
            logger.info(
                'Don\'t have pre-trained model. Ignore loading model process.'
            )
            logger.info('Set Generator and Discriminator')
        self.G = Generator(tag=tag_size).to(device)
        self.D = Discriminator(tag=tag_size).to(device)

        if resuming:
            # The original format string had no placeholder, so the
            # checkpoint name never reached the log.
            logger.info('Load Pre-Trained Weights From Checkpoint {}'.format(
                checkpoint_name))
            self.G.load_state_dict(checkpoint['G'])
            self.D.load_state_dict(checkpoint['D'])
            logger.info('Load Optimizers')
        else:
            logger.info('Initialize Weights')
            self.G.apply(initital_network_weights)
            self.D.apply(initital_network_weights)
            logger.info('Set Optimizers')

        self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))

        if resuming:
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            self.epoch = checkpoint['epoch']
        else:
            self.epoch = 0

        logger.info('Set Criterion')
        # Assistant AlexNet with ``tag_size`` outputs and its own optimizer.
        self.a_D = alexnet.alexnet(num_classes=tag_size).to(device)
        self.optimizer_a_D = torch.optim.Adam(self.a_D.parameters(),
                                              lr=learning_rate,
                                              betas=(beta_1, .999))

    def load_checkpoint(self, model_dir):
        """Load the last checkpoint (by sorted file name) in ``model_dir``.

        Args:
            model_dir: Directory handed to ``utils.read_newest_model``.

        Returns:
            ``(checkpoint, path)``, or ``(None, None)`` when
            ``utils.read_newest_model`` returns nothing.
        """
        models_path = utils.read_newest_model(model_dir)
        if not models_path:
            return None, None
        models_path.sort()
        # Join with the directory that was searched (the original joined
        # with the global ``model_dump_path`` and ignored ``model_dir``).
        new_model_path = os.path.join(model_dir, models_path[-1])
        # Without CUDA, remap stored tensors to the CPU; with CUDA keep the
        # default placement. Same outcome as the original two branches.
        map_location = None if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(new_model_path, map_location=map_location)
        return checkpoint, new_model_path

    def _save_checkpoint(self, path):
        """Write the network/optimizer states and current epoch to ``path``."""
        torch.save(
            {
                'epoch': self.epoch,
                'D': self.D.state_dict(),
                'G': self.G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
            }, path)
        logger.info('Checkpoint saved in: {}'.format(path))

    def train(self):
        """Run the adversarial training loop until ``max_epoch``.

        For each batch the discriminator is updated once (real and fake
        loss), then the generator once. A checkpoint is written every
        10000 iterations and at the end of every epoch. Every
        ``verbose_T`` iterations a real and a generated image grid are
        saved to ``tmp_path``; the collected losses are logged when
        ``verbose``.
        """
        iteration = -1
        # Reusable target tensor, refilled with 1.0 / 0.0 before each use.
        label = Variable(torch.FloatTensor(batch_size, 1)).to(device)

        logging.info('Current epoch: {}. Max epoch: {}.'.format(
            self.epoch, max_epoch))
        while self.epoch <= max_epoch:
            msg = {}
            adjust_learning_rate(self.optimizer_G, iteration)
            adjust_learning_rate(self.optimizer_D, iteration)

            for i, (food_tag, food_img) in enumerate(self.data_loader):
                iteration += 1
                if food_img.shape[0] != batch_size:
                    # ``logging.warn`` is a deprecated alias.
                    logging.warning('Batch size not satisfied. Ignoring.')
                    continue

                on_report_step = iteration % verbose_T == 0
                log_step = verbose and on_report_step
                if log_step:
                    msg['epoch'] = int(self.epoch)
                    msg['step'] = int(i)
                    msg['iteration'] = iteration

                food_img = Variable(food_img).to(device)

                # 0. assistant D
                # NOTE(review): this forward result is never used and
                # ``optimizer_a_D`` is never stepped - looks unfinished.
                self.a_D.zero_grad()
                a_D_feat = self.a_D(food_img)

                # 1. Training D
                # 1.1. discriminate real images
                self.D.zero_grad()
                label_p = self.D(food_img)
                label.data.fill_(1.0)

                # 1.2. real image loss
                real_label_loss = F.binary_cross_entropy(label_p, label)
                real_loss_sum = real_label_loss
                real_loss_sum.backward()
                if log_step:
                    msg['discriminator real loss'] = float(real_loss_sum)

                # 1.3. discriminate fake images; the generator output is
                # detached so only D receives gradients here.
                g_noise, fake_tag = utils.fake_generator(
                    batch_size, noise_size, device)
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat).detach()
                fake_label_p = self.D(fake_img)
                label.data.fill_(.0)

                # 1.4. fake image loss
                fake_label_loss = F.binary_cross_entropy(fake_label_p, label)
                fake_loss_sum = fake_label_loss
                fake_loss_sum.backward()
                if log_step:
                    print('predicted fake label: {}'.format(fake_label_p))
                    msg['discriminator fake loss'] = float(fake_loss_sum)

                # 1.6. update D
                self.optimizer_D.step()

                # 2. Training G
                # 2.1. generate fake images
                self.G.zero_grad()
                g_noise, fake_tag = utils.fake_generator(
                    batch_size, noise_size, device)
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat)
                fake_label_p = self.D(fake_img)
                label.data.fill_(1.0)

                # NOTE(review): ``feat_loss`` is computed but not added to
                # ``loss_g`` - confirm whether it should be.
                a_D_feat = self.a_D(fake_img)
                feat_loss = F.binary_cross_entropy(a_D_feat, fake_tag)

                # 2.2. generator loss (fake images scored against "real")
                label_loss_g = F.binary_cross_entropy(fake_label_p, label)
                loss_g = label_loss_g
                loss_g.backward()
                if log_step:
                    msg['generator loss'] = float(loss_g)

                # 2.3. update G
                self.optimizer_G.step()

                if log_step:
                    logger.info('------------------------------------------')
                    for key in msg.keys():
                        logger.info('{} : {}'.format(key, msg[key]))

                # periodic checkpoint, named by iteration
                if iteration % 10000 == 0:
                    self._save_checkpoint('{}/checkpoint_{}.tar'.format(
                        model_dump_path,
                        str(iteration).zfill(8)))

                # save intermediate images
                if on_report_step:
                    height, width = food_img.size(2), food_img.size(3)
                    step_tag = str(iteration).zfill(8)
                    vutils.save_image(
                        food_img.data.view(batch_size, 3, height, width),
                        os.path.join(tmp_path,
                                     'real_image_{}.png'.format(step_tag)))
                    g_noise, fake_tag = utils.fake_generator(
                        batch_size, noise_size, device)
                    fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                    fake_img = self.G(fake_feat)
                    fake_image_path = os.path.join(
                        tmp_path, 'fake_image_{}.png'.format(step_tag))
                    vutils.save_image(
                        fake_img.data.view(batch_size, 3, height, width),
                        fake_image_path)
                    logger.info('Saved intermediate file in {}'.format(
                        fake_image_path))

            # end-of-epoch checkpoint, named by epoch
            self._save_checkpoint('{}/checkpoint_{}.tar'.format(
                model_dump_path,
                str(self.epoch).zfill(4)))
            self.epoch += 1
class Trainer:
    """Adversarial trainer holding a generator, a discriminator and a
    feature extractor, plus their optimizers and learning-rate schedulers.

    The loss callables ``cri_pix``, ``cri_fea`` and ``cri_gan`` as well as
    ``Generator``, ``Discriminator``, ``VGGFeatureExtractor``, ``util``,
    ``PSNR`` and ``SSIM`` are not defined in this class; they are expected
    to be available at module level.
    """

    def __init__(self, opt):
        """Build networks, optimizers and schedulers from ``opt``.

        Args:
            opt: Nested mapping. Keys read here are ``network_G``,
                ``network_D``, ``path.pretrain_G`` and, under ``train``,
                ``lr_G``, ``b1_G``, ``b2_G``, ``lr_D``, ``b1_D``, ``b2_D``,
                ``lr_steps`` and ``lr_gamma``.
        """
        self.device = torch.device('cuda')
        self.opt = opt
        train_opt = self.opt['train']

        self.G = Generator(self.opt['network_G']).to(self.device)
        util.init_weights(self.G, init_type='kaiming', scale=0.1)
        pretrain_path = self.opt['path']['pretrain_G']
        if pretrain_path:
            # map_location keeps loading independent of the device the
            # checkpoint was written from.
            self.G.load_state_dict(
                torch.load(pretrain_path, map_location=self.device),
                strict=True)

        self.D = Discriminator(self.opt['network_D']).to(self.device)
        util.init_weights(self.D, init_type='kaiming', scale=1)

        self.FE = VGGFeatureExtractor().to(self.device)

        self.G.train()
        self.D.train()
        self.FE.eval()

        self.log_dict = OrderedDict()

        # Only generator parameters that require gradients are optimized.
        self.optim_params = [
            p for p in self.G.parameters() if p.requires_grad
        ]
        self.opt_G = torch.optim.Adam(
            self.optim_params,
            lr=train_opt['lr_G'],
            betas=(train_opt['b1_G'], train_opt['b2_G']))
        self.opt_D = torch.optim.Adam(
            self.D.parameters(),
            lr=train_opt['lr_D'],
            betas=(train_opt['b1_D'], train_opt['b2_D']))
        self.optimizers = [self.opt_G, self.opt_D]
        self.schedulers = [
            lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'],
                                     train_opt['lr_gamma'])
            for optimizer in self.optimizers
        ]

    def update_learning_rate(self):
        """Advance every learning-rate scheduler by one step."""
        for scheduler in self.schedulers:
            scheduler.step()

    def get_current_log(self):
        """Return the dictionary of values logged by the last ``train``."""
        return self.log_dict

    def get_current_learning_rate(self):
        """Return the learning rate currently applied by the first optimizer.

        The value is read from the optimizer's first parameter group, which
        is where the scheduler writes the rate it computed. Calling the
        scheduler's ``get_lr()`` outside of ``step()`` is not reliable: for
        ``MultiStepLR`` it applies the decay factor a second time whenever
        the current epoch is a milestone.
        """
        return self.optimizers[0].param_groups[0]['lr']

    def load_model(self, step, strict=True):
        """Load generator and discriminator weights saved at ``step``.

        Args:
            step: Prefix of the ``{step}_G.pth`` / ``{step}_D.pth`` files
                inside ``opt['path']['checkpoints']['models']``.
            strict: Forwarded to ``load_state_dict``.
        """
        models_dir = self.opt['path']['checkpoints']['models']
        self.G.load_state_dict(
            torch.load(f"{models_dir}/{step}_G.pth",
                       map_location=self.device),
            strict=strict)
        self.D.load_state_dict(
            torch.load(f"{models_dir}/{step}_D.pth",
                       map_location=self.device),
            strict=strict)

    def resume_training(self, resume_state):
        '''Resume the optimizers and schedulers for training'''
        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(
            self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(
            self.schedulers), 'Wrong lengths of schedulers'
        for optimizer, saved in zip(self.optimizers, resume_optimizers):
            optimizer.load_state_dict(saved)
        for scheduler, saved in zip(self.schedulers, resume_schedulers):
            scheduler.load_state_dict(saved)

    def save_network(self, network, network_label, iter_step):
        """Write ``network``'s weights to ``{iter_step}_{network_label}.pth``.

        The file is placed in ``opt['path']['checkpoints']['models']`` and
        all tensors are moved to the CPU before saving.
        """
        models_dir = self.opt['path']['checkpoints']['models']
        util.mkdir(models_dir)
        save_filename = '{}_{}.pth'.format(iter_step, network_label)
        save_path = os.path.join(models_dir, save_filename)
        if isinstance(network, nn.DataParallel):
            # Save the wrapped module's weights rather than the wrapper's.
            network = network.module
        state_dict = network.state_dict()
        # Values are replaced in place so the mapping object returned by
        # state_dict() is the one that gets saved.
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path)

    def save_model(self, epoch, current_step):
        """Save both networks and the optimizer/scheduler training state."""
        self.save_network(self.G, 'G', current_step)
        self.save_network(self.D, 'D', current_step)
        self.save_training_state(epoch, current_step)

    def save_training_state(self, epoch, iter_step):
        '''Saves training state during training, which will be used for resuming'''
        state = {
            'epoch': epoch,
            'iter': iter_step,
            'schedulers': [s.state_dict() for s in self.schedulers],
            'optimizers': [o.state_dict() for o in self.optimizers],
        }
        save_filename = '{}.state'.format(iter_step)
        states_dir = self.opt['path']['checkpoints']['states']
        util.mkdir(states_dir)
        torch.save(state, os.path.join(states_dir, save_filename))

    def train(self, train_batch, step):
        """Run one generator update followed by one discriminator update.

        Args:
            train_batch: Mapping with tensors under ``'LR'`` and ``'HR'``.
            step: Not used by this method.

        Side effects:
            Sets ``self.lr``, ``self.hr`` and ``self.sr`` and fills
            ``self.log_dict`` with Python floats for the individual losses
            and the mean discriminator outputs.
        """
        train_opt = self.opt['train']
        self.lr = train_batch['LR'].to(self.device)
        self.hr = train_batch['HR'].to(self.device)

        # ---- G ----
        # Freeze D so the generator's backward pass does not accumulate
        # gradients in the discriminator's parameters.
        for p in self.D.parameters():
            p.requires_grad = False

        self.opt_G.zero_grad()
        self.sr = self.G(self.lr)

        l_g_total = 0
        # pixel loss
        l_g_pix = train_opt['wt_pix'] * cri_pix(self.sr, self.hr)
        l_g_total += l_g_pix
        # feature loss
        real_fea = self.FE(self.hr).detach()
        fake_fea = self.FE(self.sr)
        l_g_fea = train_opt['wt_fea'] * cri_fea(fake_fea, real_fea)
        l_g_total += l_g_fea
        # ragan loss: each prediction is compared against the mean of the
        # opposite class's predictions.
        pred_g_fake = self.D(self.sr)
        pred_d_real = self.D(self.hr).detach()
        l_g_gan = train_opt['wt_gan'] * (
            cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
            cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
        l_g_total += l_g_gan

        l_g_total.backward()
        self.opt_G.step()

        # ---- D ----
        for p in self.D.parameters():
            p.requires_grad = True

        self.opt_D.zero_grad()
        pred_d_real = self.D(self.hr)
        pred_d_fake = self.D(self.sr.detach())  # detach to avoid BP to G
        l_d_real = cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
        l_d_fake = cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
        l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.opt_D.step()

        # ---- log ----
        # Everything is stored as a plain float so the log never keeps
        # device tensors alive.
        # G
        self.log_dict['l_g_pix'] = l_g_pix.item()
        self.log_dict['l_g_fea'] = l_g_fea.item()
        self.log_dict['l_g_gan'] = l_g_gan.item()
        # D
        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        # D outputs
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach()).item()
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()).item()

    def validate(self, val_batch, current_step):
        """Run the generator on validation data and average PSNR and SSIM.

        For every item of ``val_batch`` the first image of the batch is
        super-resolved, written to
        ``<val_image_dir>/<name>/<name>_<current_step>.png`` and scored
        against the ground truth after cropping a 4-pixel border.

        Args:
            val_batch: Iterable of mappings with ``'LR'``, ``'HR'`` and
                ``'LR_path'`` entries.
            current_step: Integer used in the saved image file name.

        Returns:
            Tuple ``(avg_psnr, avg_ssim)``. ``(0.0, 0.0)`` when
            ``val_batch`` yields nothing.
        """
        crop_size = 4
        psnr_sum = 0.0
        ssim_sum = 0.0
        num_images = 0

        self.G.eval()
        try:
            for num_images, val_data in enumerate(val_batch, start=1):
                img_name = os.path.splitext(
                    os.path.basename(val_data['LR_path'][0]))[0]
                img_dir = os.path.join(
                    self.opt['path']['checkpoints']['val_image_dir'],
                    img_name)
                util.mkdir(img_dir)

                self.val_lr = val_data['LR'].to(self.device)
                self.val_hr = val_data['HR'].to(self.device)
                with torch.no_grad():
                    self.val_sr = self.G(self.val_lr)

                # Only the first image of each batch is evaluated.
                val_SR = self.val_sr.detach()[0].float().cpu()
                val_HR = self.val_hr.detach()[0].float().cpu()
                sr_img = util.tensor2img(val_SR)  # uint8
                gt_img = util.tensor2img(val_HR)  # uint8

                # Save SR images for reference
                save_img_path = os.path.join(
                    img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                cv2.imwrite(save_img_path, sr_img)

                # The divide / re-multiply round trip hands floating-point
                # arrays (not uint8) to the metric functions.
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.
                cropped_sr_img = sr_img[crop_size:-crop_size,
                                        crop_size:-crop_size, :]
                cropped_gt_img = gt_img[crop_size:-crop_size,
                                        crop_size:-crop_size, :]
                psnr_sum += PSNR(cropped_sr_img * 255, cropped_gt_img * 255)
                ssim_sum += SSIM(cropped_sr_img * 255, cropped_gt_img * 255)
        finally:
            # Always hand the generator back in training mode, even when
            # validation is interrupted by an error.
            self.G.train()

        if num_images == 0:
            # Nothing to average; avoid dividing by zero.
            return 0.0, 0.0
        return psnr_sum / num_images, ssim_sum / num_images
train_dataset = TrainDataset( root=args.train_data_path, scale_factor=args.scale_factor, hr_size=args.hr_size, random_crop_size=args.random_crop_size, ) test_dataset = TestDataset(root=args.test_data_path, scale_factor=args.scale_factor, hr_size=args.hr_size) train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size) test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64) gen_net = Generator(num_res_blocks=args.gen_res_blocks, upscale_factor=args.scale_factor).to(device) dis_net = Discriminator(hr_size=args.random_crop_size, sigmoid=not args.no_sigmoid).to(device) print(f"Generator number of parameters: {count_parameters(gen_net)}") print(f"Discriminator number of parameters: {count_parameters(dis_net)}") gen_path = args.from_pretrained_gen if gen_path: gen_net.load_state_dict(torch.load(gen_path)) dis_path = args.from_pretrained_dis if dis_path: dis_net.load_state_dict(torch.load(dis_path)) perceptual_loss = PerceptualLoss(device=device) mse_loss = nn.MSELoss() beta1 = 0.9 opt_gen = optim.Adam(gen_net.parameters(),
# Experiment tracking: the run is named after the configured version and the
# full config namespace is mirrored into the wandb run config.
# NOTE(review): `config`, `dataloader` and `latent_shape` are not defined in
# this fragment — presumably set up earlier in the file; confirm.
run_name = 'correlation-GAN_{}'.format(config.version)
wandb.init(
    name=run_name,
    dir=config.checkpoint_dir,
    notes=config.description,
)
wandb.config.update(config.__dict__)

device = torch.device('cuda')

# Layer settings shared by both networks. Three entries each, presumably one
# per fully connected layer (four sizes -> three layers) — TODO confirm
# against the Generator/Discriminator constructors.
use_dropout = [True, True, False]
drop_prob = [0.5] * 3
use_ac_func = [True, True, False]
activation = 'relu'

# Layer widths: the generator maps latent_dim -> 2, the discriminator 2 -> 1.
latent_dim = 10
gen_fc_layers = [latent_dim, 16, 32, 2]
disc_fc_layers = [2, 32, 16, 1]

generator = Generator(
    gen_fc_layers,
    use_dropout,
    drop_prob,
    use_ac_func,
    activation,
).to(device)
discriminator = Discriminator(
    disc_fc_layers,
    use_dropout,
    drop_prob,
    use_ac_func,
    activation,
).to(device)
wandb.watch([generator, discriminator])

# Both networks use Adam with identical hyper-parameters.
g_optimizer = Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
d_optimizer = Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))

wgan_gp = WGAN_GP(
    config,
    generator,
    discriminator,
    g_optimizer,
    d_optimizer,
    latent_shape,
)
wgan_gp.train(dataloader, 200)