class LSGAN(object):
    def __init__(self, batch_size, adopt_gas=False):
        self.batch_size = batch_size
        self.generator = Generator(batch_size=self.batch_size, base_filter=32)
        self.discriminator = Discriminator(batch_size=self.batch_size,
                                           base_filter=32, adopt_gas=adopt_gas)
        self.generator.cuda()
        self.discriminator.cuda()
        self.gen_optimizer = RMSprop(self.generator.parameters())
        self.dis_optimizer = RMSprop(self.discriminator.parameters())

    def train(self, epoch, loader):
        self.generator.train()
        self.discriminator.train()
        self.gen_loss_sum = 0.0
        self.dis_loss_sum = 0.0
        for i, (batch_img, batch_tag) in enumerate(loader):
            # Get logits
            batch_img = Variable(batch_img.cuda())
            batch_z = Variable(torch.randn(self.batch_size, 100).cuda())
            self.gen_image = self.generator(batch_z)
            true_logits = self.discriminator(batch_img)
            fake_logits = self.discriminator(self.gen_image)

            # Least-squares losses; the fake term is squared, per the LSGAN objective
            self.dis_loss = torch.sum((true_logits - 1) ** 2 + fake_logits ** 2) / 2
            self.gen_loss = torch.sum((fake_logits - 1) ** 2) / 2

            # Update the discriminator every step
            self.dis_optimizer.zero_grad()
            self.dis_loss.backward(retain_graph=True)
            self.dis_loss_sum += self.dis_loss.data.cpu().numpy()[0]
            self.dis_optimizer.step()

            # Update the generator every 5th step
            if i % 5 == 0:
                self.gen_optimizer.zero_grad()
                self.gen_loss.backward()
                self.gen_loss_sum += self.gen_loss.data.cpu().numpy()[0]
                self.gen_optimizer.step()
            if i > 300:
                break

    def eval(self):
        self.generator.eval()
        batch_z = Variable(torch.randn(32, 100).cuda())
        return self.generator(batch_z)
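# A minimal standalone sketch of the least-squares objective that LSGAN.train above
# optimizes, written against plain tensors so it can be sanity-checked in isolation.
# The helper names (lsgan_d_loss, lsgan_g_loss) are illustrative and not part of the
# original class; the class sums over the batch, while this sketch averages.
import torch

def lsgan_d_loss(true_logits, fake_logits):
    # D is pushed toward 1 on real samples and toward 0 on generated samples.
    return 0.5 * (torch.mean((true_logits - 1) ** 2) + torch.mean(fake_logits ** 2))

def lsgan_g_loss(fake_logits):
    # G is pushed to make D output 1 on generated samples.
    return 0.5 * torch.mean((fake_logits - 1) ** 2)

# Example: lsgan_d_loss(torch.rand(8, 1), torch.rand(8, 1)) returns a scalar tensor.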
def main(args):
    os.makedirs('models', exist_ok=True)
    os.makedirs('outputs', exist_ok=True)

    # -------------- dataset ----------------------------
    g_train_loader = IAMDataLoader(args.batch_size, args.T, args.data_scale,
                                   chars=args.chars,
                                   points_per_char=args.points_per_char)
    print('number of batches:', g_train_loader.num_batches)
    args.c_dimension = len(g_train_loader.chars) + 1
    args.U = g_train_loader.max_U

    # -------------- pretrain generator ----------------------------
    generator = Generator(num_gaussians=args.M, mode=args.mode,
                          c_dimension=args.c_dimension, K=args.K, U=args.U,
                          batch_size=args.batch_size, T=args.T, bias=args.b,
                          sample_random=args.sample_random,
                          learning_rate=args.g_learning_rate).to(device)
    generator = generator.train()
    if args.g_path and os.path.exists(args.g_path):
        print('Start loading generator: %s' % (args.g_path))
        generator.load_state_dict(torch.load(args.g_path))
    else:
        print('Start pre-training generator:')
        pre_g(generator, g_train_loader, num_epochs=40, mode=args.mode)

    # -------------- pretrain discriminator ----------------------------
    if args.batch_size > 16:  # Do not set batch_size too large
        generator.batch_size = 16
        args.batch_size = 16
    discriminator = Discriminator(learning_rate=args.d_learning_rate,
                                  weight_decay=args.d_weight_decay).to(device)
    discriminator = discriminator.train()
    if args.d_path and os.path.exists(args.d_path):
        print('Start loading discriminator: %s' % (args.d_path))
        discriminator.load_state_dict(torch.load(args.d_path))
    else:
        print('Start pre-training discriminator:')
        pre_d(discriminator, generator, g_train_loader, num_steps=200)

    generator.set_learning_rate(args.ad_g_learning_rate)
    print('Start training discriminator:')
    ad_train(args, generator, discriminator, g_train_loader, num_steps=100)
def main(args):
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)
    vocab_size = len(vocab)
    print('vocab_size:', vocab_size)

    dataloader = get_loader(image_dir, caption_path, vocab, args.batch_size,
                            crop_size, shuffle=True, num_workers=num_workers)

    generator = Generator(attention_dim, embedding_size, lstm_size, vocab_size,
                          load_path=args.g_path, noise=args.noise)
    generator = generator.to(device)
    generator = generator.train()

    discriminator = Discriminator(vocab_size, embedding_size, lstm_size,
                                  attention_dim, load_path=args.d_path)
    discriminator = discriminator.to(device)
    discriminator = discriminator.train()

    if args.train_mode == 'gd':
        for _ in range(5):
            for i in range(4):
                generator.pre_train(dataloader, vocab)
            for i in range(1):
                discriminator.fit(generator, dataloader, vocab)
    elif args.train_mode == 'dg':
        discriminator.fit(generator, dataloader, vocab)
        generator.pre_train(dataloader, vocab)
    elif args.train_mode == 'd':
        discriminator.fit(generator, dataloader, vocab)
    elif args.train_mode == 'g':
        generator.pre_train(dataloader, vocab)
    elif args.train_mode == 'ad':
        for i in range(5):
            generator.ad_train(dataloader, discriminator, vocab,
                               gamma=args.gamma,
                               update_every=args.update_every,
                               alpha_c=1.0,
                               num_rollouts=args.num_rollouts)
class TRPO: def __init__(self, generator_env, expert_env, ent_coeff=0, g_step=4, d_step=1, vf_step=3, gamma=0.995, lam=0.97, max_kl=0.01, cg_damping=0.1): self.generator_env = generator_env self.expert_env = expert_env self.ent_coeff = ent_coeff self.g_step = g_step self.d_step = d_step self.vf_step = vf_step self.gamma = gamma self.lam = lam self.max_kl = max_kl self.cg_damping = cg_damping self.build_net(generator_env, ent_coeff, cg_damping) def build_net(self, env, ent_coeff, cg_damping): # Build two policies. Optimize the performance of pi w.r.t oldpi in each step. self.ob = tf.placeholder(dtype=tf.float32, shape=(None, ) + env.observation_space.shape, name="ob") self.pi = MultiLayerPolicy("pi", self.ob, env.action_space.shape) self.oldpi = MultiLayerPolicy("oldpi", self.ob, env.action_space.shape) self.assignNewToOld = [ tf.assign(oldv, newv) for oldv, newv in zip( self.oldpi.get_variables(), self.pi.get_variables()) ] # Build discriminator. self.d = Discriminator(name="discriminator", ob_shape=(4, ), st_shape=(6, ), ob_slice=range(4)) # KL divergence and entropy self.meanKl = tf.reduce_mean(self.oldpi.pd.kl( self.pi.pd)) # D_KL using Monte Carlo on a batch meanEnt = tf.reduce_mean( self.pi.pd.entropy()) # entropy using Monte Carlo on a batch entBonus = ent_coeff * meanEnt # surrogate gain, L(pi)=J(pi)-J(oldpi)=sum(p_new/p_old*adv) self.ac = tf.placeholder(dtype=tf.float32, shape=(None, ) + env.action_space.shape, name="ac") self.atarg = tf.placeholder( dtype=tf.float32, shape=(None, ), name="advantage") # advantage function for each action ratio = tf.exp(self.pi.pd.logp(self.ac) - self.oldpi.pd.logp(self.ac)) # p_new/p_old surrgain = tf.reduce_mean(ratio * self.atarg) # J(pi)-J(oldpi) self.optimgain = surrgain + entBonus # fisher vector product all_var_list = self.pi.get_trainable_variables() policyVars = [ v for v in all_var_list if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd") ] self.vector = tf.placeholder(dtype=tf.float32, shape=(None, ), name="vector") self.fvp = self.build_fisher_vector_product(self.meanKl, self.vector, policyVars, cg_damping) # loss and gradient self.optimgrad = tf.gradients(self.optimgain, policyVars) self.optimgrad = tf.concat( [tf.reshape(g, [int(np.prod(g.shape))]) for g in self.optimgrad], axis=0) # utils self.init = tf.global_variables_initializer() self.get_theta = tf.concat( [tf.reshape(var, [int(np.prod(var.shape))]) for var in policyVars], axis=0) self.set_theta = self.setFromTheta(policyVars) def setFromTheta(self, var_list): # count the number of elements in var_list shape_list = [var.shape.as_list() for var in var_list] size_list = [int(np.prod(shape)) for shape in shape_list] total_size = np.sum(size_list) self.theta = tf.placeholder(dtype=tf.float32, shape=(total_size, ), name="theta") # assign var_list from the given theta assign_list = [] start = 0 for (shape, var, size) in zip(shape_list, var_list, size_list): assign_list.append( tf.assign(var, tf.reshape(self.theta[start:start + size], shape))) start += size return tf.group(*assign_list) @staticmethod def build_fisher_vector_product(kl, vector, var_list, cg_damping): kl_grad_list = tf.gradients(kl, var_list) # transform vector into the same shape of matrix in var_list shape_list = [var.shape.as_list() for var in var_list] start = 0 vector_list = [] n_list = [] for shape in shape_list: n = int(np.prod(shape)) vector_list.append(tf.reshape(vector[start:start + n], shape)) n_list.append(n) start += n # element-wise product of kl_grad and vector, then add all the elements 
together gvp = tf.add_n([ tf.reduce_sum(g * tangent) for (g, tangent) in zip(kl_grad_list, vector_list) ]) fvp_list = tf.gradients(gvp, var_list) fvp = tf.concat([ tf.reshape(fvp_part, (n, )) for (fvp_part, n) in zip(fvp_list, n_list) ], axis=0) return fvp + cg_damping * vector def train(self, max_episode): with tf.Session() as sess: # Preparation fvp = lambda p: sess.run(self.fvp, { self.ob: obs, self.ac: acs, self.atarg: advs, self.vector: p }) set_theta = lambda theta: sess.run(self.set_theta, {self.theta: theta}) get_theta = lambda: sess.run(self.get_theta) get_kl = lambda ob, ac, adv: sess.run(self.meanKl, { self.ob: ob, self.ac: ac, self.atarg: adv }) saver = tf.train.Saver() save_var = lambda path="./log/gail.ckpt": saver.save(sess, path) load_var = lambda path="./log/gail.ckpt": saver.restore(sess, path) assign = lambda: sess.run(self.assignNewToOld) def ob_proc(ob): target = ob[:2] r = sqrt(target[0]**2 + target[1]**2) l1 = 0.1 l2 = 0.11 q_target = np.array([ arctan2(target[1], target[0]) - arccos( (r**2 + l1**2 - l2**2) / 2 / r / l1), PI - arccos( (l1**2 + l2**2 - r**2) / 2 / l1 / l2) ]) q = arctan2(ob[4:6], ob[6:8]) return np.mod(q_target - q + PI, 2 * PI) - PI # Build generator expert = Generator(PIDPolicy(shape=(2, ), ob_proc=ob_proc), self.expert_env, self.d, 1000) generator = Generator(self.pi, self.generator_env, self.d, 1000) # Start training if glob.glob("./log/gail.ckpt.*"): with logger("load last trained data"): load_var() else: with logger("initialize variable"): sess.run(self.init) with logger("training"): for episode in range(max_episode): with logger("episode %d" % episode): if episode % 20 == 0: with logger("save data"): save_var() with logger("train generator"): for g_iter in range(self.g_step): # sample trajectory with logger("sample trajectory"): traj = generator.sample_trajectory() traj = generator.process_trajectory( traj, self.gamma, self.lam) obs, acs, advs, vtarg, vpred = traj[ "ob"], traj["ac"], traj["adv"], traj[ "tdlamret"], traj["vpred"] # normalization advs = (advs - advs.mean()) / advs.std( ) # advantage is normalized on a batch self.pi.ob_rms.update( obs ) # observation is normalized on all history data assign() # loss and gradients on this batch loss, g = sess.run( [self.optimgain, self.optimgrad], { self.ob: obs, self.ac: acs, self.atarg: advs }) if not np.allclose(g, 0): with logger("update policy"): # use conjunct gradient method to solve Hs=g, where H = nabla^2 D_KL stepdir = cg(fvp, g) sHs = 0.5 * stepdir.dot(fvp(stepdir)) lm = np.sqrt(sHs / self.max_kl) fullstep = stepdir / lm # get step from direction expertedimprove = g.dot(fullstep) surrogate_gain_before = loss stepsize = 1.0 theta_before = get_theta() for _ in range(10): theta_new = theta_before + fullstep * stepsize set_theta(theta_new) surr, kl = sess.run( [self.optimgain, self.meanKl], { self.ob: obs, self.ac: acs, self.atarg: advs }) if kl > 1.5 * self.max_kl: pass # violate kl constraint elif surr - surrogate_gain_before < 0: pass # surrogate gain not improve else: break # stepsize OK stepsize *= 5 else: set_theta(theta_before ) # find no good step # update value function with logger("update value function"): for _ in range(self.vf_step): traj = generator.sample_trajectory() traj = generator.process_trajectory( traj, self.gamma, self.lam) obs, vtarg = traj["ob"], traj[ "tdlamret"] self.pi.train_value_function( obs, vtarg) with logger("train discriminator"): for _ in range(self.d_step): traj_g = generator.sample_trajectory() traj_e = expert.sample_trajectory() 
self.d.train(traj_g["ob"], traj_g["st"], traj_e["ob"], traj_e["st"]) def test(self): with tf.Session() as sess: saver = tf.train.Saver() saver.restore(sess, "./log/gail.ckpt") writer = tf.summary.FileWriter("./log/graph", tf.get_default_graph()) generator = Generator(self.pi, self.generator_env, self.d, 1000, "./record/test.mp4") generator.sample_trajectory(display=True, record=True) writer.close()
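# The TRPO update above calls cg(fvp, g) to approximately solve H s = g, where fvp(p)
# returns the Fisher-vector product H p computed by the graph built in
# build_fisher_vector_product. The exact solver used by the original code is not shown
# in this excerpt; a standard conjugate-gradient sketch, assumed here, looks like this:
import numpy as np

def cg(fvp, b, iters=10, tol=1e-10):
    x = np.zeros_like(b)
    r = b.copy()          # residual b - H x, with x initialized to 0
    p = r.copy()          # search direction
    rdotr = r.dot(r)
    for _ in range(iters):
        Hp = fvp(p)
        alpha = rdotr / (p.dot(Hp) + 1e-8)
        x += alpha * p
        r -= alpha * Hp
        new_rdotr = r.dot(r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x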
gl_data_sampler = MyDatasetSampler(args.data_dir, device=device, size=args.size)
disc_data_sampler = MyDatasetSampler(args.data_dir, device=device, length=3,
                                     size=args.size)
compute_perceptual = PerceptualLoss().to(device)
gen_optim = ranger(generator.parameters())
disc_optim = ranger(discriminator.parameters())
losses = []

for e in range(epochs):
    print('EPOCH {}'.format(e))
    for first, second, third in tqdm(DataLoader(gl_data_sampler, batch_size=1)):
        generator.train(False)
        discriminator.train(True)
        for d_first, d_second, _ in DataLoader(disc_data_sampler, batch_size=1):
            disc_optim.zero_grad()
            gen_in = torch.cat([d_first[0], d_second[1]], 1)
            gen_out = generator(gen_in)
            true_out = discriminator(d_first[0]).view((-1))
            fake_out = discriminator(gen_out.detach()).view((-1))
            pred = torch.cat((fake_out, true_out), 0)
            label = torch.Tensor([0, 1]).to(device)
            bce_loss = F.binary_cross_entropy(pred, label)
            bce_loss.backward()
            disc_optim.step()
samples = []
losses = []
print_every = 400
sample_size = 16

if torch.cuda.is_available():
    cuda = True
else:
    cuda = False

fixed_z = generate_z_vector(sample_size, z_size, cuda)

D.train()
G.train()
if cuda:
    D.cuda()
    G.cuda()

def train_discriminator(real_images, optimizer, batch_size, z_size):
    optimizer.zero_grad()
    if cuda:
        real_images = real_images.cuda()
    # Loss for real image
    d_real_loss = real_loss(D(real_images), cuda, smooth=True)
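# train_discriminator above relies on a real_loss helper (and, presumably, a matching
# fake_loss) that is not included in this excerpt. A common implementation, sketched
# here under the assumption that D outputs raw logits, uses BCE-with-logits and
# optional one-sided label smoothing; the signatures are inferred from the call site.
import torch
import torch.nn as nn

def real_loss(d_out, cuda, smooth=False):
    batch_size = d_out.size(0)
    labels = torch.ones(batch_size) * (0.9 if smooth else 1.0)  # smoothed "real" target
    if cuda:
        labels = labels.cuda()
    return nn.BCEWithLogitsLoss()(d_out.squeeze(), labels)

def fake_loss(d_out, cuda):
    batch_size = d_out.size(0)
    labels = torch.zeros(batch_size)  # "fake" target
    if cuda:
        labels = labels.cuda()
    return nn.BCEWithLogitsLoss()(d_out.squeeze(), labels)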
class GAIL: def __init__(self, exp_dir, exp_thresh, state_dim, action_dim, learn_rate, betas, _device, _gamma, load_weights=False): """ exp_dir : directory containing the expert episodes exp_thresh : parameter to control number of episodes to load as expert based on returns (lower means more episodes) state_dim : dimesnion of state action_dim : dimesnion of action learn_rate : learning rate for optimizer _device : GPU or cpu _gamma : discount factor _load_weights : load weights from directory """ # storing runtime device self.device = _device # discount factor self.gamma = _gamma # Expert trajectory self.expert = ExpertTrajectories(exp_dir, exp_thresh, gamma=self.gamma) # Defining the actor and its optimizer self.actor = ActorNetwork(state_dim).to(self.device) self.optim_actor = torch.optim.Adam(self.actor.parameters(), lr=learn_rate, betas=betas) # Defining the discriminator and its optimizer self.disc = Discriminator(state_dim, action_dim).to(self.device) self.optim_disc = torch.optim.Adam(self.disc.parameters(), lr=learn_rate, betas=betas) if not load_weights: self.actor.apply(init_weights) self.disc.apply(init_weights) else: self.load() # Loss function crtiterion self.criterion = torch.nn.BCELoss() def get_action(self, state): """ obtain action for a given state using actor network """ state = torch.tensor(state, dtype=torch.float, device=self.device).view(1, -1) return self.actor(state).cpu().data.numpy().flatten() def update(self, n_iter, batch_size=100): """ train discriminator and actor for mini-batch """ # memory to store disc_losses = np.zeros(n_iter, dtype=np.float) act_losses = np.zeros(n_iter, dtype=np.float) for i in range(n_iter): # Get expert state and actions batch exp_states, exp_actions = self.expert.sample(batch_size) exp_states = torch.FloatTensor(exp_states).to(self.device) exp_actions = torch.FloatTensor(exp_actions).to(self.device) # Get state, and actions using actor states, _ = self.expert.sample(batch_size) states = torch.FloatTensor(states).to(self.device) actions = self.actor(states) ''' train the discriminator ''' self.optim_disc.zero_grad() # label tensors exp_labels = torch.full((batch_size, 1), 1, device=self.device) policy_labels = torch.full((batch_size, 1), 0, device=self.device) # with expert transitions prob_exp = self.disc(exp_states, exp_actions) exp_loss = self.criterion(prob_exp, exp_labels) # with policy actor transitions prob_policy = self.disc(states, actions.detach()) policy_loss = self.criterion(prob_policy, policy_labels) # use backprop disc_loss = exp_loss + policy_loss disc_losses[i] = disc_loss.mean().item() disc_loss.backward() self.optim_disc.step() ''' train the actor ''' self.optim_actor.zero_grad() loss_actor = -self.disc(states, actions) act_losses[i] = loss_actor.mean().detach().item() loss_actor.mean().backward() self.optim_actor.step() print("Finished training minibatch") return act_losses, disc_losses def save( self, directory='/home/aman/Programming/RL-Project/Deterministic-GAIL/weights', name='GAIL'): torch.save(self.actor.state_dict(), '{}/{}_actor.pth'.format(directory, name)) torch.save(self.disc.state_dict(), '{}/{}_discriminator.pth'.format(directory, name)) def load( self, directory='/home/aman/Programming/RL-Project/Deterministic-GAIL/weights', name='GAIL'): print(os.getcwd()) self.actor.load_state_dict( torch.load('{}/{}_actor.pth'.format(directory, name))) self.disc.load_state_dict( torch.load('{}/{}_discriminator.pth'.format(directory, name))) def set_mode(self, mode="train"): if mode == "train": self.actor.train() 
self.disc.train() else: self.actor.eval() self.disc.eval()
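# A minimal, hypothetical driver for the GAIL class above, showing how set_mode,
# update, and save fit together with a Gym environment. The environment name, expert
# directory, thresholds, and hyperparameters are placeholders rather than values from
# the original project.
import gym
import torch

env = gym.make("BipedalWalker-v3")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

agent = GAIL(exp_dir="./expert_episodes", exp_thresh=200.0,
             state_dim=state_dim, action_dim=action_dim,
             learn_rate=3e-4, betas=(0.5, 0.999),
             _device=device, _gamma=0.99)

for episode in range(100):
    agent.set_mode("train")
    act_losses, disc_losses = agent.update(n_iter=10, batch_size=100)
    if (episode + 1) % 10 == 0:
        agent.save(directory="./weights", name="GAIL")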
        digit (int): The digit shown in the image.
    """
    pixels = X.reshape((28, 28))
    plt.title(str(digit))
    plt.imshow(pixels, cmap='gray')
    plt.show()


X = get_normal_shaped_arrays(60000, (1, 784))
X_train, y_train, X_test, y_test = discriminator_train_test_set(
    X, X_train, params.DISCRIMINATOR_TRAIN_TEST_SPLIT)

discriminator = Discriminator(params.DISCRIMINATOR_BATCH_SIZE,
                              params.DISCRIMINATOR_EPOCHS)
discriminator.train(X_train, y_train)
print(discriminator.eval(X_test, y_test))

generator = Generator()
gan = Gan(generator, discriminator)
gan.set_discriminator_trainability(False)
gan.show_trainable()

X = get_normal_shaped_arrays(100000, (1, 16))
y = []
for _ in range(100000):
    y.append([0, 1])
y = np.array(y)
def train_D_With_G(): aD = Discriminator() aD.cuda() aG = Generator() aG.cuda() optimizer_g = torch.optim.Adam(aG.parameters(), lr=0.0001, betas=(0, 0.9)) optimizer_d = torch.optim.Adam(aD.parameters(), lr=0.0001, betas=(0, 0.9)) criterion = nn.CrossEntropyLoss() n_z = 100 n_classes = 10 np.random.seed(352) label = np.asarray(list(range(10)) * 10) noise = np.random.normal(0, 1, (100, n_z)) label_onehot = np.zeros((100, n_classes)) label_onehot[np.arange(100), label] = 1 noise[np.arange(100), :n_classes] = label_onehot[np.arange(100)] noise = noise.astype(np.float32) save_noise = torch.from_numpy(noise) save_noise = Variable(save_noise).cuda() start_time = time.time() # Train the model num_epochs = 500 loss1 = [] loss2 = [] loss3 = [] loss4 = [] loss5 = [] acc1 = [] for epoch in range(0, num_epochs): aG.train() aD.train() avoidOverflow(optimizer_d) avoidOverflow(optimizer_g) for batch_idx, (X_train_batch, Y_train_batch) in enumerate(trainloader): if (Y_train_batch.shape[0] < batch_size): continue # train G if batch_idx % gen_train == 0: for p in aD.parameters(): p.requires_grad_(False) aG.zero_grad() label = np.random.randint(0, n_classes, batch_size) noise = np.random.normal(0, 1, (batch_size, n_z)) label_onehot = np.zeros((batch_size, n_classes)) label_onehot[np.arange(batch_size), label] = 1 noise[np.arange(batch_size), :n_classes] = label_onehot[ np.arange(batch_size)] noise = noise.astype(np.float32) noise = torch.from_numpy(noise) noise = Variable(noise).cuda() fake_label = Variable(torch.from_numpy(label)).cuda() fake_data = aG(noise) gen_source, gen_class = aD(fake_data) gen_source = gen_source.mean() gen_class = criterion(gen_class, fake_label) gen_cost = -gen_source + gen_class gen_cost.backward() optimizer_g.step() # train D for p in aD.parameters(): p.requires_grad_(True) aD.zero_grad() # train discriminator with input from generator label = np.random.randint(0, n_classes, batch_size) noise = np.random.normal(0, 1, (batch_size, n_z)) label_onehot = np.zeros((batch_size, n_classes)) label_onehot[np.arange(batch_size), label] = 1 noise[np.arange(batch_size), :n_classes] = label_onehot[np.arange( batch_size)] noise = noise.astype(np.float32) noise = torch.from_numpy(noise) noise = Variable(noise).cuda() fake_label = Variable(torch.from_numpy(label)).cuda() with torch.no_grad(): fake_data = aG(noise) disc_fake_source, disc_fake_class = aD(fake_data) disc_fake_source = disc_fake_source.mean() disc_fake_class = criterion(disc_fake_class, fake_label) # train discriminator with input from the discriminator real_data = Variable(X_train_batch).cuda() real_label = Variable(Y_train_batch).cuda() disc_real_source, disc_real_class = aD(real_data) prediction = disc_real_class.data.max(1)[1] accuracy = (float(prediction.eq(real_label.data).sum()) / float(batch_size)) * 100.0 disc_real_source = disc_real_source.mean() disc_real_class = criterion(disc_real_class, real_label) gradient_penalty = calc_gradient_penalty(aD, real_data, fake_data) disc_cost = disc_fake_source - disc_real_source + disc_real_class + disc_fake_class + gradient_penalty disc_cost.backward() optimizer_d.step() loss1.append(gradient_penalty.item()) loss2.append(disc_fake_source.item()) loss3.append(disc_real_source.item()) loss4.append(disc_real_class.item()) loss5.append(disc_fake_class.item()) acc1.append(accuracy) if batch_idx % 50 == 0: print(epoch, batch_idx, "%.2f" % np.mean(loss1), "%.2f" % np.mean(loss2), "%.2f" % np.mean(loss3), "%.2f" % np.mean(loss4), "%.2f" % np.mean(loss5), "%.2f" % np.mean(acc1)) # Test the model 
aD.eval() with torch.no_grad(): test_accu = [] for batch_idx, (X_test_batch, Y_test_batch) in enumerate(testloader): X_test_batch, Y_test_batch = Variable( X_test_batch).cuda(), Variable(Y_test_batch).cuda() with torch.no_grad(): _, output = aD(X_test_batch) prediction = output.data.max(1)[ 1] # first column has actual prob. accuracy = (float(prediction.eq(Y_test_batch.data).sum()) / float(batch_size)) * 100.0 test_accu.append(accuracy) accuracy_test = np.mean(test_accu) print('Testing', accuracy_test, time.time() - start_time) # save output with torch.no_grad(): aG.eval() samples = aG(save_noise) samples = samples.data.cpu().numpy() samples += 1.0 samples /= 2.0 samples = samples.transpose(0, 2, 3, 1) aG.train() fig = plot(samples) plt.savefig('output/%s.png' % str(epoch).zfill(3), bbox_inches='tight') plt.close(fig) if (epoch + 1) % 1 == 0: torch.save(aG, 'tempG.model') torch.save(aD, 'tempD.model') torch.save(aG, 'generator.model') torch.save(aD, 'discriminator.model')
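# The training loop above calls calc_gradient_penalty(aD, real_data, fake_data), which
# is not shown in this excerpt. A standard WGAN-GP penalty is sketched below, under the
# assumption (consistent with the loop) that aD returns (source_logit, class_logits)
# and that real/fake data are image batches of identical shape.
import torch

def calc_gradient_penalty(netD, real_data, fake_data, lam=10.0):
    batch_size = real_data.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates.requires_grad_(True)
    disc_interpolates, _ = netD(interpolates)   # use the source (real/fake) head only
    grads = torch.autograd.grad(outputs=disc_interpolates.sum(),
                                inputs=interpolates,
                                create_graph=True)[0]
    grads = grads.view(batch_size, -1)
    return lam * ((grads.norm(2, dim=1) - 1) ** 2).mean()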
class GAN_CLS(object): def __init__(self, args, data_loader, SUPERVISED=True): """ args : Arguments data_loader = An instance of class DataLoader for loading our dataset in batches """ self.data_loader = data_loader self.num_epochs = args.num_epochs self.batch_size = args.batch_size self.log_step = args.log_step self.sample_step = args.sample_step self.log_dir = args.log_dir self.checkpoint_dir = args.checkpoint_dir self.sample_dir = args.sample_dir self.final_model = args.final_model self.model_save_step = args.model_save_step #self.dataset = args.dataset #self.model_name = args.model_name self.img_size = args.img_size self.z_dim = args.z_dim self.text_embed_dim = args.text_embed_dim self.text_reduced_dim = args.text_reduced_dim self.learning_rate = args.learning_rate self.beta1 = args.beta1 self.beta2 = args.beta2 self.l1_coeff = args.l1_coeff self.resume_epoch = args.resume_epoch self.resume_idx = args.resume_idx self.SUPERVISED = SUPERVISED # Logger setting log_name = datetime.datetime.now().strftime('%Y-%m-%d') + '.log' self.logger = logging.getLogger('__name__') self.logger.setLevel(logging.INFO) self.formatter = logging.Formatter( '%(asctime)s:%(levelname)s:%(message)s') self.file_handler = logging.FileHandler( os.path.join(self.log_dir, log_name)) self.file_handler.setFormatter(self.formatter) self.logger.addHandler(self.file_handler) self.build_model() def smooth_label(self, tensor, offset): return tensor + offset def dump_imgs(images_Array, name): with open('{}.pickle'.format(name), 'wb') as file: dump(images_Array, file) def build_model(self): """ A function of defining following instances : ----- Generator ----- Discriminator ----- Optimizer for Generator ----- Optimizer for Discriminator ----- Defining Loss functions """ # ---------------------------------------------------------------------# # 1. Network Initialization # # ---------------------------------------------------------------------# self.gen = Generator(batch_size=self.batch_size, img_size=self.img_size, z_dim=self.z_dim, text_embed_dim=self.text_embed_dim, text_reduced_dim=self.text_reduced_dim) self.disc = Discriminator(batch_size=self.batch_size, img_size=self.img_size, text_embed_dim=self.text_embed_dim, text_reduced_dim=self.text_reduced_dim) self.gen_optim = optim.Adam(self.gen.parameters(), lr=self.learning_rate, betas=(self.beta1, self.beta2)) self.disc_optim = optim.Adam(self.disc.parameters(), lr=self.learning_rate, betas=(self.beta1, self.beta2)) self.cls_gan_optim = optim.Adam(itertools.chain( self.gen.parameters(), self.disc.parameters()), lr=self.learning_rate, betas=(self.beta1, self.beta2)) print('------------- Generator Model Info ---------------') self.print_network(self.gen, 'G') print('------------------------------------------------') print('------------- Discriminator Model Info ---------------') self.print_network(self.disc, 'D') print('------------------------------------------------') self.criterion = nn.BCELoss().cuda() # self.CE_loss = nn.CrossEntropyLoss().cuda() # self.MSE_loss = nn.MSELoss().cuda() self.gen.train() self.disc.train() def print_network(self, model, name): """ A function for printing total number of model parameters """ num_params = 0 for p in model.parameters(): num_params += p.numel() print(model) print(name) print("Total number of parameters: {}".format(num_params)) def load_checkpoints(self, resume_epoch, idx): """Restore the trained generator and discriminator.""" print('Loading the trained models from epoch {} and iteration {}...'. 
format(resume_epoch, idx)) G_path = os.path.join(self.checkpoint_dir, '{}-{}-G.ckpt'.format(resume_epoch, idx)) D_path = os.path.join(self.checkpoint_dir, '{}-{}-D.ckpt'.format(resume_epoch, idx)) self.gen.load_state_dict( torch.load(G_path, map_location=lambda storage, loc: storage)) self.disc.load_state_dict( torch.load(D_path, map_location=lambda storage, loc: storage)) def train_model(self): data_loader = self.data_loader start_epoch = 0 if self.resume_epoch >= 0: start_epoch = self.resume_epoch self.load_checkpoints(self.resume_epoch, self.resume_idx) print('--------------- Model Training Started ---------------') start_time = time.time() for epoch in range(start_epoch, self.num_epochs): print("Epoch: {}".format(epoch + 1)) for idx, batch in enumerate(data_loader): print("Index: {}".format(idx + 1), end="\t") true_imgs = batch['true_imgs'] true_embed = batch['true_embds'] false_imgs = batch['false_imgs'] real_labels = torch.ones(true_imgs.size(0)) fake_labels = torch.zeros(true_imgs.size(0)) smooth_real_labels = torch.FloatTensor( self.smooth_label(real_labels.numpy(), -0.1)) true_imgs = Variable(true_imgs.float()).cuda() true_embed = Variable(true_embed.float()).cuda() false_imgs = Variable(false_imgs.float()).cuda() real_labels = Variable(real_labels).cuda() smooth_real_labels = Variable(smooth_real_labels).cuda() fake_labels = Variable(fake_labels).cuda() # ---------------------------------------------------------------# # 2. Training the generator # # ---------------------------------------------------------------# self.gen.zero_grad() z = Variable(torch.randn(true_imgs.size(0), self.z_dim)).cuda() fake_imgs = self.gen.forward(true_embed, z) fake_out, fake_logit = self.disc.forward(fake_imgs, true_embed) fake_out = Variable(fake_out.data, requires_grad=True).cuda() true_out, true_logit = self.disc.forward(true_imgs, true_embed) true_out = Variable(true_out.data, requires_grad=True).cuda() g_sf = self.criterion(fake_out, real_labels) #g_img = self.l1_coeff * nn.L1Loss()(fake_imgs, true_imgs) gen_loss = g_sf gen_loss.backward() self.gen_optim.step() # ---------------------------------------------------------------# # 3. Training the discriminator # # ---------------------------------------------------------------# self.disc.zero_grad() false_out, false_logit = self.disc.forward( false_imgs, true_embed) false_out = Variable(false_out.data, requires_grad=True) sr = self.criterion(true_out, smooth_real_labels) sw = self.criterion(true_out, fake_labels) sf = self.criterion(false_out, smooth_real_labels) disc_loss = torch.log(sr) + (torch.log(1 - sw) + torch.log(1 - sf)) / 2 disc_loss.backward() self.disc_optim.step() self.cls_gan_optim.step() # Logging loss = {} loss['G_loss'] = gen_loss.item() loss['D_loss'] = disc_loss.item() # ---------------------------------------------------------------# # 4. Logging INFO into log_dir # # ---------------------------------------------------------------# log = "" if (idx + 1) % self.log_step == 0: end_time = time.time() - start_time end_time = datetime.timedelta(seconds=end_time) log = "Elapsed [{}], Epoch [{}/{}], Idx [{}]".format( end_time, epoch + 1, self.num_epochs, idx) for net, loss_value in loss.items(): log += "{}: {:.4f}".format(net, loss_value) self.logger.info(log) print(log) """ # ---------------------------------------------------------------# # 5. 
Saving generated images # # ---------------------------------------------------------------# if (idx + 1) % self.sample_step == 0: concat_imgs = torch.cat((true_imgs, fake_imgs), 0) # ?????????? concat_imgs = (concat_imgs + 1) / 2 # out.clamp_(0, 1) save_path = os.path.join(self.sample_dir, '{}-{}-images.jpg'.format(epoch, idx + 1)) # concat_imgs.cpu().detach().numpy() self.dump_imgs(concat_imgs.cpu().numpy(), save_path) #save_image(concat_imgs.data.cpu(), self.sample_dir, nrow=1, padding=0) print ('Saved real and fake images into {}...'.format(self.sample_dir)) """ # ---------------------------------------------------------------# # 6. Saving the checkpoints & final model # # ---------------------------------------------------------------# if (idx + 1) % self.model_save_step == 0: G_path = os.path.join( self.checkpoint_dir, '{}-{}-G.ckpt'.format(epoch, idx + 1)) D_path = os.path.join( self.checkpoint_dir, '{}-{}-D.ckpt'.format(epoch, idx + 1)) torch.save(self.gen.state_dict(), G_path) torch.save(self.disc.state_dict(), D_path) print('Saved model checkpoints into {}...\n'.format( self.checkpoint_dir)) print('--------------- Model Training Completed ---------------') # Saving final model into final_model directory G_path = os.path.join(self.final_model, '{}-G.pth'.format('final')) D_path = os.path.join(self.final_model, '{}-D.pth'.format('final')) torch.save(self.gen.state_dict(), G_path) torch.save(self.disc.state_dict(), D_path) print('Saved final model into {}...'.format(self.final_model))
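# For reference, the matching-aware GAN-CLS discriminator objective (Reed et al., 2016)
# is a direct combination of three cross-entropy terms: (real image, matching text) with
# label 1, (mismatching pair) with label 0, and (generated image, matching text) with
# label 0. The sketch below restates that standard formulation with BCE terms only; it
# is an illustration, not a drop-in replacement for the loss used in train_model above
# (which additionally wraps the BCE terms in torch.log).
import torch.nn as nn

def gan_cls_d_loss(match_out, mismatch_out, fake_out, real_labels, fake_labels):
    criterion = nn.BCELoss()
    sr = criterion(match_out, real_labels)      # real image, matching text -> real
    sw = criterion(mismatch_out, fake_labels)   # mismatching pair -> fake
    sf = criterion(fake_out, fake_labels)       # generated image, matching text -> fake
    return sr + 0.5 * (sw + sf)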
def main(args): use_cuda = (len(args.gpuid) >= 1) print("{0} GPU(s) are available".format(cuda.device_count())) # Load dataset splits = ['train', 'valid'] if data.has_binary_files(args.data, splits): dataset = data.load_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) else: dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) if args.src_lang is None or args.trg_lang is None: # record inferred languages in args, so that it's saved in checkpoints args.src_lang, args.trg_lang = dataset.src, dataset.dst print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict))) print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict))) for split in splits: print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split]))) g_logging_meters = OrderedDict() g_logging_meters['train_loss'] = AverageMeter() g_logging_meters['valid_loss'] = AverageMeter() g_logging_meters['train_acc'] = AverageMeter() g_logging_meters['valid_acc'] = AverageMeter() g_logging_meters['bsz'] = AverageMeter() # sentences per batch d_logging_meters = OrderedDict() d_logging_meters['train_loss'] = AverageMeter() d_logging_meters['valid_loss'] = AverageMeter() d_logging_meters['train_acc'] = AverageMeter() d_logging_meters['valid_acc'] = AverageMeter() d_logging_meters['bsz'] = AverageMeter() # sentences per batch # Set model parameters args.encoder_embed_dim = 1000 args.encoder_layers = 2 # 4 args.encoder_dropout_out = 0 args.decoder_embed_dim = 1000 args.decoder_layers = 2 # 4 args.decoder_out_embed_dim = 1000 args.decoder_dropout_out = 0 args.bidirectional = False generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) print("Generator loaded successfully!") discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) print("Discriminator loaded successfully!") g_model_path = 'checkpoints/zhenwarm/generator.pt' assert os.path.exists(g_model_path) # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) model_dict = generator.state_dict() model = torch.load(g_model_path) pretrained_dict = model.state_dict() # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. load the new state dict generator.load_state_dict(model_dict) print("pre-trained Generator loaded successfully!") # # Load discriminator model d_model_path = 'checkpoints/zhenwarm/discri.pt' assert os.path.exists(d_model_path) # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) d_model_dict = discriminator.state_dict() d_model = torch.load(d_model_path) d_pretrained_dict = d_model.state_dict() # 1. filter out unnecessary keys d_pretrained_dict = { k: v for k, v in d_pretrained_dict.items() if k in d_model_dict } # 2. overwrite entries in the existing state dict d_model_dict.update(d_pretrained_dict) # 3. 
load the new state dict discriminator.load_state_dict(d_model_dict) print("pre-trained Discriminator loaded successfully!") if use_cuda: if torch.cuda.device_count() > 1: discriminator = torch.nn.DataParallel(discriminator).cuda() generator = torch.nn.DataParallel(generator).cuda() else: generator.cuda() discriminator.cuda() else: discriminator.cpu() generator.cpu() # adversarial training checkpoints saving path if not os.path.exists('checkpoints/myzhencli5'): os.makedirs('checkpoints/myzhencli5') checkpoints_path = 'checkpoints/myzhencli5/' # define loss function g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum') d_criterion = torch.nn.BCELoss() pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True) # fix discriminator word embedding (as Wu et al. do) for p in discriminator.embed_src_tokens.parameters(): p.requires_grad = False for p in discriminator.embed_trg_tokens.parameters(): p.requires_grad = False # define optimizer g_optimizer = eval("torch.optim." + args.g_optimizer)(filter( lambda x: x.requires_grad, generator.parameters()), args.g_learning_rate) d_optimizer = eval("torch.optim." + args.d_optimizer)( filter(lambda x: x.requires_grad, discriminator.parameters()), args.d_learning_rate, momentum=args.momentum, nesterov=True) # start joint training best_dev_loss = math.inf num_update = 0 # main training loop for epoch_i in range(1, args.epochs + 1): logging.info("At {0}-th epoch.".format(epoch_i)) seed = args.seed + epoch_i torch.manual_seed(seed) max_positions_train = (args.fixed_max_len, args.fixed_max_len) # Initialize dataloader, starting at batch_offset trainloader = dataset.train_dataloader( 'train', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_train, # seed=seed, epoch=epoch_i, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch_i <= args.curriculum), shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(trainloader): # set training mode generator.train() discriminator.train() update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer) if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) ## part I: use gradient policy method to train the generator # use policy gradient training when random.random() > 50% if random.random() >= 0.5: print("Policy Gradient Training") sys_out_batch = generator(sample) # 64 X 50 X 6632 out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 * 50) X 6632 _, prediction = out_batch.topk(1) prediction = prediction.squeeze(1) # 64*50 = 3200 prediction = torch.reshape( prediction, sample['net_input']['src_tokens'].shape) # 64 X 50 with torch.no_grad(): reward = discriminator(sample['net_input']['src_tokens'], prediction) # 64 X 1 train_trg_batch = sample['target'] # 64 x 50 pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward, use_cuda) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] # 64 logging_loss = pg_loss / math.log(2) g_logging_meters['train_loss'].update(logging_loss.item(), sample_size) logging.debug( f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}" ) g_optimizer.zero_grad() pg_loss.backward() 
torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) g_optimizer.step() else: # MLE training print("MLE Training") sys_out_batch = generator(sample) out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 train_trg_batch = sample['target'].view(-1) # 64*50 = 3200 loss = g_criterion(out_batch, train_trg_batch) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] nsentences = sample['target'].size(0) logging_loss = loss.data / sample_size / math.log(2) g_logging_meters['bsz'].update(nsentences) g_logging_meters['train_loss'].update(logging_loss, sample_size) logging.debug( f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}" ) g_optimizer.zero_grad() loss.backward() # all-reduce grads and rescale by grad_denom for p in generator.parameters(): if p.requires_grad: p.grad.data.div_(sample_size) torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) g_optimizer.step() num_update += 1 # part II: train the discriminator if num_update % 5 == 0: bsz = sample['target'].size(0) # batch_size = 64 src_sentence = sample['net_input'][ 'src_tokens'] # 64 x max-len i.e 64 X 50 # now train with machine translation output i.e generator output true_sentence = sample['target'].view(-1) # 64*50 = 3200 true_labels = Variable( torch.ones( sample['target'].size(0)).float()) # 64 length vector with torch.no_grad(): sys_out_batch = generator(sample) # 64 X 50 X 6632 out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 _, prediction = out_batch.topk(1) prediction = prediction.squeeze(1) # 64 * 50 = 6632 fake_labels = Variable( torch.zeros( sample['target'].size(0)).float()) # 64 length vector fake_sentence = torch.reshape(prediction, src_sentence.shape) # 64 X 50 true_sentence = torch.reshape(true_sentence, src_sentence.shape) if use_cuda: fake_labels = fake_labels.cuda() true_labels = true_labels.cuda() # fake_disc_out = discriminator(src_sentence, fake_sentence) # 64 X 1 # true_disc_out = discriminator(src_sentence, true_sentence) # # fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels) # true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels) # # fake_acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) # true_acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels) # acc = (fake_acc + true_acc) / 2 # # d_loss = fake_d_loss + true_d_loss if random.random() > 0.5: fake_disc_out = discriminator(src_sentence, fake_sentence) fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels) fake_acc = torch.sum( torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) d_loss = fake_d_loss acc = fake_acc else: true_disc_out = discriminator(src_sentence, true_sentence) true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels) true_acc = torch.sum( torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels) d_loss = true_d_loss acc = true_acc d_logging_meters['train_acc'].update(acc) d_logging_meters['train_loss'].update(d_loss) logging.debug( f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}" ) d_optimizer.zero_grad() d_loss.backward() d_optimizer.step() if num_update % 10000 == 0: # validation # set validation mode generator.eval() discriminator.eval() # Initialize dataloader max_positions_valid = 
(args.fixed_max_len, args.fixed_max_len) valloader = dataset.eval_dataloader( 'valid', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_valid, skip_invalid_size_inputs_valid_test=True, descending= True, # largest batch first to warm the caching allocator shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(valloader): with torch.no_grad(): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) # generator validation sys_out_batch = generator(sample) out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 dev_trg_batch = sample['target'].view( -1) # 64*50 = 3200 loss = g_criterion(out_batch, dev_trg_batch) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] loss = loss / sample_size / math.log(2) g_logging_meters['valid_loss'].update( loss, sample_size) logging.debug( f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}" ) # discriminator validation bsz = sample['target'].size(0) src_sentence = sample['net_input']['src_tokens'] # train with half human-translation and half machine translation true_sentence = sample['target'] true_labels = Variable( torch.ones(sample['target'].size(0)).float()) with torch.no_grad(): sys_out_batch = generator(sample) out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 _, prediction = out_batch.topk(1) prediction = prediction.squeeze(1) # 64 * 50 = 6632 fake_labels = Variable( torch.zeros(sample['target'].size(0)).float()) fake_sentence = torch.reshape( prediction, src_sentence.shape) # 64 X 50 true_sentence = torch.reshape(true_sentence, src_sentence.shape) if use_cuda: fake_labels = fake_labels.cuda() true_labels = true_labels.cuda() fake_disc_out = discriminator(src_sentence, fake_sentence) # 64 X 1 true_disc_out = discriminator(src_sentence, true_sentence) fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels) true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels) d_loss = fake_d_loss + true_d_loss fake_acc = torch.sum( torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) true_acc = torch.sum( torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels) acc = (fake_acc + true_acc) / 2 d_logging_meters['valid_acc'].update(acc) d_logging_meters['valid_loss'].update(d_loss) logging.debug( f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}" ) # torch.save(discriminator, # open(checkpoints_path + f"numupdate_{num_update/10000}k.discri_{d_logging_meters['valid_loss'].avg:.3f}.pt",'wb'), pickle_module=dill) # if d_logging_meters['valid_loss'].avg < best_dev_loss: # best_dev_loss = d_logging_meters['valid_loss'].avg # torch.save(discriminator, open(checkpoints_path + "best_dmodel.pt", 'wb'), pickle_module=dill) torch.save( generator, open( checkpoints_path + f"numupdate_{num_update/10000}k.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt", 'wb'), pickle_module=dill)
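# update_learning_rate, called in the joint-training loop above with the arguments
# (num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer), is not included
# in this excerpt. One plausible implementation, assumed here, shrinks the generator's
# learning rate by args.lr_shrink once the update count passes the given threshold:
def update_learning_rate(num_updates, threshold, initial_lr, lr_shrink, optimizer):
    lr = initial_lr * (lr_shrink if num_updates > threshold else 1.0)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr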
def main(args): # log hyperparameter print(args) # select device args.cuda = not args.no_cuda and torch.cuda.is_available() device = torch.device("cuda: 0" if args.cuda else "cpu") # set random seed np.random.seed(args.seed) torch.manual_seed(args.seed) # data loader transform = transforms.Compose([ utils.Normalize(), utils.ToTensor() ]) train_dataset = TVDataset( root=args.root, sub_size=args.block_size, volume_list=args.volume_train_list, max_k=args.training_step, train=True, transform=transform ) test_dataset = TVDataset( root=args.root, sub_size=args.block_size, volume_list=args.volume_test_list, max_k=args.training_step, train=False, transform=transform ) kwargs = {"num_workers": 4, "pin_memory": True} if args.cuda else {} train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs) # model def generator_weights_init(m): if isinstance(m, nn.Conv3d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.zeros_(m.bias) def discriminator_weights_init(m): if isinstance(m, nn.Conv3d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') if m.bias is not None: nn.init.zeros_(m.bias) g_model = Generator(args.upsample_mode, args.forward, args.backward, args.gen_sn, args.residual) g_model.apply(generator_weights_init) if args.data_parallel and torch.cuda.device_count() > 1: g_model = nn.DataParallel(g_model) g_model.to(device) if args.gan_loss != "none": d_model = Discriminator(args.dis_sn) d_model.apply(discriminator_weights_init) # if args.dis_sn: # d_model = add_sn(d_model) if args.data_parallel and torch.cuda.device_count() > 1: d_model = nn.DataParallel(d_model) d_model.to(device) mse_loss = nn.MSELoss() adversarial_loss = nn.MSELoss() train_losses, test_losses = [], [] d_losses, g_losses = [], [] # optimizer g_optimizer = optim.Adam(g_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) if args.gan_loss != "none": d_optimizer = optim.Adam(d_model.parameters(), lr=args.d_lr, betas=(args.beta1, args.beta2)) Tensor = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor # load checkpoint if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint {}".format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint["epoch"] g_model.load_state_dict(checkpoint["g_model_state_dict"]) # g_optimizer.load_state_dict(checkpoint["g_optimizer_state_dict"]) if args.gan_loss != "none": d_model.load_state_dict(checkpoint["d_model_state_dict"]) # d_optimizer.load_state_dict(checkpoint["d_optimizer_state_dict"]) d_losses = checkpoint["d_losses"] g_losses = checkpoint["g_losses"] train_losses = checkpoint["train_losses"] test_losses = checkpoint["test_losses"] print("=> load chekcpoint {} (epoch {})" .format(args.resume, checkpoint["epoch"])) # main loop for epoch in tqdm(range(args.start_epoch, args.epochs)): # training.. g_model.train() if args.gan_loss != "none": d_model.train() train_loss = 0. 
volume_loss_part = np.zeros(args.training_step) for i, sample in enumerate(train_loader): params = list(g_model.named_parameters()) # pdb.set_trace() # params[0][1].register_hook(lambda g: print("{}.grad: {}".format(params[0][0], g))) # adversarial ground truths real_label = Variable(Tensor(sample["v_i"].shape[0], sample["v_i"].shape[1], 1, 1, 1, 1).fill_(1.0), requires_grad=False) fake_label = Variable(Tensor(sample["v_i"].shape[0], sample["v_i"].shape[1], 1, 1, 1, 1).fill_(0.0), requires_grad=False) v_f = sample["v_f"].to(device) v_b = sample["v_b"].to(device) v_i = sample["v_i"].to(device) g_optimizer.zero_grad() fake_volumes = g_model(v_f, v_b, args.training_step, args.wo_ori_volume, args.norm) # adversarial loss # update discriminator if args.gan_loss != "none": avg_d_loss = 0. avg_d_loss_real = 0. avg_d_loss_fake = 0. for k in range(args.n_d): d_optimizer.zero_grad() decisions = d_model(v_i) d_loss_real = adversarial_loss(decisions, real_label) fake_decisions = d_model(fake_volumes.detach()) d_loss_fake = adversarial_loss(fake_decisions, fake_label) d_loss = d_loss_real + d_loss_fake d_loss.backward() avg_d_loss += d_loss.item() / args.n_d avg_d_loss_real += d_loss_real / args.n_d avg_d_loss_fake += d_loss_fake / args.n_d d_optimizer.step() # update generator if args.gan_loss != "none": avg_g_loss = 0. avg_loss = 0. for k in range(args.n_g): loss = 0. g_optimizer.zero_grad() # adversarial loss if args.gan_loss != "none": fake_decisions = d_model(fake_volumes) g_loss = args.gan_loss_weight * adversarial_loss(fake_decisions, real_label) loss += g_loss avg_g_loss += g_loss.item() / args.n_g # volume loss if args.volume_loss: volume_loss = args.volume_loss_weight * mse_loss(v_i, fake_volumes) for j in range(v_i.shape[1]): volume_loss_part[j] += mse_loss(v_i[:, j, :], fake_volumes[:, j, :]) / args.n_g / args.log_every loss += volume_loss # feature loss if args.feature_loss: feat_real = d_model.extract_features(v_i) feat_fake = d_model.extract_features(fake_volumes) for m in range(len(feat_real)): loss += args.feature_loss_weight / len(feat_real) * mse_loss(feat_real[m], feat_fake[m]) avg_loss += loss / args.n_g loss.backward() g_optimizer.step() train_loss += avg_loss # log training status subEpoch = (i + 1) // args.log_every if (i+1) % args.log_every == 0: print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( epoch, (i+1) * args.batch_size, len(train_loader.dataset), 100. * (i+1) / len(train_loader), avg_loss )) print("Volume Loss: ") for j in range(volume_loss_part.shape[0]): print("\tintermediate {}: {:.6f}".format( j+1, volume_loss_part[j] )) if args.gan_loss != "none": print("DLossReal: {:.6f} DLossFake: {:.6f} DLoss: {:.6f}, GLoss: {:.6f}".format( avg_d_loss_real, avg_d_loss_fake, avg_d_loss, avg_g_loss )) d_losses.append(avg_d_loss) g_losses.append(avg_g_loss) # train_losses.append(avg_loss) train_losses.append(train_loss.item() / args.log_every) print("====> SubEpoch: {} Average loss: {:.6f} Time {}".format( subEpoch, train_loss.item() / args.log_every, time.asctime(time.localtime(time.time())) )) train_loss = 0. volume_loss_part = np.zeros(args.training_step) # testing... if (i + 1) % args.test_every == 0: g_model.eval() if args.gan_loss != "none": d_model.eval() test_loss = 0. 
with torch.no_grad(): for i, sample in enumerate(test_loader): v_f = sample["v_f"].to(device) v_b = sample["v_b"].to(device) v_i = sample["v_i"].to(device) fake_volumes = g_model(v_f, v_b, args.training_step, args.wo_ori_volume, args.norm) test_loss += args.volume_loss_weight * mse_loss(v_i, fake_volumes).item() test_losses.append(test_loss * args.batch_size / len(test_loader.dataset)) print("====> SubEpoch: {} Test set loss {:4f} Time {}".format( subEpoch, test_losses[-1], time.asctime(time.localtime(time.time())) )) # saving... if (i+1) % args.check_every == 0: print("=> saving checkpoint at epoch {}".format(epoch)) if args.gan_loss != "none": torch.save({"epoch": epoch + 1, "g_model_state_dict": g_model.state_dict(), "g_optimizer_state_dict": g_optimizer.state_dict(), "d_model_state_dict": d_model.state_dict(), "d_optimizer_state_dict": d_optimizer.state_dict(), "d_losses": d_losses, "g_losses": g_losses, "train_losses": train_losses, "test_losses": test_losses}, os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + "_" + "pth.tar") ) else: torch.save({"epoch": epoch + 1, "g_model_state_dict": g_model.state_dict(), "g_optimizer_state_dict": g_optimizer.state_dict(), "train_losses": train_losses, "test_losses": test_losses}, os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + "_" + "pth.tar") ) torch.save(g_model.state_dict(), os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + ".pth")) num_subEpoch = len(train_loader) // args.log_every print("====> Epoch: {} Average loss: {:.6f} Time {}".format( epoch, np.array(train_losses[-num_subEpoch:]).mean(), time.asctime(time.localtime(time.time())) ))
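# The commented-out call to add_sn(d_model) in main() above suggests applying spectral
# normalization to every convolutional/linear layer of the discriminator. A common
# recursive implementation is sketched here as an assumption; the original helper is
# not part of this excerpt.
import torch.nn as nn
from torch.nn.utils import spectral_norm

def add_sn(module):
    for name, child in module.named_children():
        module.add_module(name, add_sn(child))
    if isinstance(module, (nn.Conv3d, nn.ConvTranspose3d, nn.Linear)):
        return spectral_norm(module)
    return module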
def run_experiment(fraction_train, load_model=False, old_runname=None, start_epoch=None): runname = f'splitted_data_{str(fraction_train)}' device = Device.GPU1 epochs = 50 features = 64 batch_size = 4 all_image_size = 96 in_chan = 15 context = cpu() if device.value == -1 else gpu(device.value) # ---------------------------------------------------- if load_model: summaryWriter = SummaryWriter('logs/' + old_runname, flush_secs=5) else: summaryWriter = SummaryWriter('logs/' + runname, flush_secs=5) train_iter = modules.make_iterator_preprocessed( 'training', 'V1', 'V2', 'V3', batch_size=batch_size, shuffle=True, fraction_train=fraction_train) test_iter = modules.make_iterator_preprocessed('testing', 'V1', 'V2', 'V3', batch_size=batch_size, shuffle=True) RFlocs_V1_overlapped_avg = modules.get_RFs('V1', context) RFlocs_V2_overlapped_avg = modules.get_RFs('V2', context) RFlocs_V3_overlapped_avg = modules.get_RFs('V3', context) with Context(context): discriminator = Discriminator(in_chan) generator = Generator(in_chan, context) if load_model: generator.network.load_parameters( f'saved_models/{old_runname}/netG_{start_epoch}.model', ctx=context) discriminator.network.load_parameters( f'saved_models/{old_runname}/netD_{start_epoch}.model') gen_lossfun = gen.Lossfun(1, 100, 1, context) d = discriminator.network dis_lossfun = dis.Lossfun(1) g = generator.network print('train_dataset_length:', len(train_iter._dataset)) for epoch in range(epochs): loss_discriminator_train = [] loss_generator_train = [] # ==================== # T R AI N I N G # ==================== for RFsignalsV1, RFsignalsV2, RFsignalsV3, targets in tqdm( train_iter, total=len(train_iter)): # ------- # Inputs # ------- inputs1 = modules.get_inputsROI(RFsignalsV1, RFlocs_V1_overlapped_avg, context) inputs2 = modules.get_inputsROI(RFsignalsV2, RFlocs_V2_overlapped_avg, context) inputs3 = modules.get_inputsROI(RFsignalsV3, RFlocs_V3_overlapped_avg, context) inputs = concat(inputs1, inputs2, inputs3, dim=1) # ------------------------------------ # T R A I N D i s c r i m i n a t o r # ------------------------------------ targets = targets.as_in_context(context).transpose( (0, 1, 3, 2)) loss_discriminator_train.append( discriminator.train(g, inputs, targets)) # ---------------------------- # T R A I N G e n e r a t o r # ---------------------------- loss_generator_train.append(generator.train( d, inputs, targets)) if load_model: os.makedirs('saved_models/' + old_runname, exist_ok=True) generator.network.save_parameters( f'saved_models/{old_runname}/netG_{epoch+start_epoch+1}.model' ) discriminator.network.save_parameters( f'saved_models/{old_runname}/netD_{epoch+start_epoch+1}.model' ) else: os.makedirs('saved_models/' + runname, exist_ok=True) generator.network.save_parameters( f'saved_models/{runname}/netG_{epoch}.model') discriminator.network.save_parameters( f'saved_models/{runname}/netD_{epoch}.model') # ==================== # T E S T I N G # ==================== loss_discriminator_test = [] loss_generator_test = [] for RFsignalsV1, RFsignalsV2, RFsignalsV3, targets in test_iter: # ------- # Inputs # ------- inputs1 = modules.get_inputsROI(RFsignalsV1, RFlocs_V1_overlapped_avg, context) inputs2 = modules.get_inputsROI(RFsignalsV2, RFlocs_V2_overlapped_avg, context) inputs3 = modules.get_inputsROI(RFsignalsV3, RFlocs_V3_overlapped_avg, context) inputs = concat(inputs1, inputs2, inputs3, dim=1) # ----- # Targets # ----- targets = targets.as_in_context(context).transpose( (0, 1, 3, 2)) # ---- # sample randomly from history buffer 
(capacity 50) # ---- z = concat(inputs, g(inputs), dim=1) dis_loss_test = 0.5 * (dis_lossfun(0, d(z)) + dis_lossfun( 1, d(concat(inputs, targets, dim=1)))) loss_discriminator_test.append(float(dis_loss_test.asscalar())) gen_loss_test = (lambda y_hat: gen_lossfun( 1, d(concat(inputs, y_hat, dim=1)), targets, y_hat))( generator.network(inputs)) loss_generator_test.append(float(gen_loss_test.asscalar())) summaryWriter.add_image( "input", modules.leclip(inputs.expand_dims(2).sum(1)), epoch) summaryWriter.add_image("target", modules.leclip(targets), epoch) summaryWriter.add_image("pred", modules.leclip(g(inputs)), epoch) summaryWriter.add_scalar( "dis/loss_discriminator_train", sum(loss_discriminator_train) / len(loss_discriminator_train), epoch) summaryWriter.add_scalar( "gen/loss_generator_train", sum(loss_generator_train) / len(loss_generator_train), epoch) summaryWriter.add_scalar( "dis/loss_discriminator_test", sum(loss_discriminator_test) / len(loss_discriminator_test), epoch) summaryWriter.add_scalar( "gen/loss_generator_test", sum(loss_generator_test) / len(loss_generator_test), epoch) # ------------------------------------------------------------------ # T R A I N I N G Losses # ------------------------------------------------------------------ np.save(f'saved_models/{runname}/Gloss_train', np.array(loss_generator_train)) np.save(f'saved_models/{runname}/Dloss_train', np.array(loss_discriminator_train)) # ------------------------------------------------------------------ # T E S T I N G Losses # ------------------------------------------------------------------ np.save(f'saved_models/{runname}/Gloss_test', np.array(loss_generator_test)) np.save(f'saved_models/{runname}/Dloss_test', np.array(loss_discriminator_test))
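# The comment in the testing loop above mentions sampling from a history buffer of
# capacity 50 (the image-pool trick from pix2pix/CycleGAN training), but no buffer
# implementation appears in this excerpt. A framework-agnostic sketch, offered as an
# assumption about how such a buffer usually works:
import random

class HistoryBuffer:
    def __init__(self, capacity=50):
        self.capacity = capacity
        self.items = []

    def query(self, item):
        # Store until full; afterwards, with probability 0.5 return a stored item
        # (replacing it with the new one), otherwise return the new item directly.
        if len(self.items) < self.capacity:
            self.items.append(item)
            return item
        if random.random() < 0.5:
            idx = random.randrange(self.capacity)
            old, self.items[idx] = self.items[idx], item
            return old
        return item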
class MGAIL(object): def __init__(self, environment, use_irl=False): self.use_irl = use_irl self.env = environment # Create placeholders for all the inputs self.states_ = tf.compat.v1.placeholder("float", shape=(None, self.env.state_size), name='states_') # Batch x State self.states = tf.compat.v1.placeholder("float", shape=(None, self.env.state_size), name='states') # Batch x State self.actions = tf.compat.v1.placeholder("float", shape=(None, self.env.action_size), name='action') # Batch x Action self.label = tf.compat.v1.placeholder("float", shape=(None, 1), name='label') self.gamma = tf.compat.v1.placeholder("float", shape=(), name='gamma') self.temp = tf.compat.v1.placeholder("float", shape=(), name='temperature') self.noise = tf.compat.v1.placeholder("float", shape=(), name='noise_flag') self.do_keep_prob = tf.compat.v1.placeholder("float", shape=(), name='do_keep_prob') self.lprobs = tf.compat.v1.placeholder('float', shape=(None, 1), name='log_probs') # Create MGAIL blocks self.forward_model = ForwardModel(state_size=self.env.state_size, action_size=self.env.action_size, encoding_size=self.env.fm_size, lr=self.env.fm_lr) # MODIFYING THE NEW DISCRIMINATOR: if self.use_irl: self.discriminator = DiscriminatorIRL(in_dim=self.env.state_size + self.env.action_size, out_dim=1, size=self.env.d_size, lr=self.env.d_lr, do_keep_prob=self.do_keep_prob, weight_decay=self.env.weight_decay, state_only=True, gamma=self.gamma, state_size = self.env.state_size, action_size = self.env.action_size) # END MODIFYING THE NEW DISCRIMINATOR else: self.discriminator = Discriminator(in_dim=self.env.state_size + self.env.action_size, out_dim=2, size=self.env.d_size, lr=self.env.d_lr, do_keep_prob=self.do_keep_prob, weight_decay=self.env.weight_decay) self.policy = Policy(in_dim=self.env.state_size, out_dim=self.env.action_size, size=self.env.p_size, lr=self.env.p_lr, do_keep_prob=self.do_keep_prob, n_accum_steps=self.env.policy_accum_steps, weight_decay=self.env.weight_decay) # Create experience buffers self.er_agent = ER(memory_size=self.env.er_agent_size, state_dim=self.env.state_size, action_dim=self.env.action_size, batch_size=self.env.batch_size, history_length=1) self.er_expert = common.load_d4rl_er(h5path=os.path.join(self.env.run_dir, self.env.expert_data), batch_size=self.env.batch_size, history_length=1, traj_length=2) self.env.sigma = self.er_expert.actions_std / self.env.noise_intensity # Normalize the inputs states_ = common.normalize(self.states_, self.er_expert.states_mean, self.er_expert.states_std) states = common.normalize(self.states, self.er_expert.states_mean, self.er_expert.states_std) if self.env.continuous_actions: actions = common.normalize(self.actions, self.er_expert.actions_mean, self.er_expert.actions_std) else: actions = self.actions # 1. Forward Model initial_gru_state = np.ones((1, self.forward_model.encoding_size)) forward_model_prediction, _ = self.forward_model.forward([states_, actions, initial_gru_state]) forward_model_loss = tf.reduce_mean(tf.square(states-forward_model_prediction)) self.forward_model.train(objective=forward_model_loss) # 2. 
Discriminator labels = tf.concat([1 - self.label, self.label], 1) lprobs = self.lprobs # MODIFIED DISCRIMINATOR SECTION if self.use_irl: self.discrim_output, log_p_tau, log_q_tau, log_pq = self.discriminator.forward(states_, actions, states, lprobs) correct_predictions = tf.equal(tf.cast(tf.round(self.discrim_output), tf.int64), tf.argmax(labels, 1)) self.discriminator.acc = tf.reduce_mean(tf.cast(correct_predictions, "float")) d_cross_entropy = self.label*(log_p_tau-log_pq) + (1-self.label)*(log_q_tau-log_pq) d_loss_weighted = self.env.cost_sensitive_weight * tf.multiply(tf.compat.v1.to_float(tf.equal(tf.squeeze(self.label), 1.)), d_cross_entropy) +\ tf.multiply(tf.compat.v1.to_float(tf.equal(tf.squeeze(self.label), 0.)), d_cross_entropy) discriminator_loss = -tf.reduce_mean(d_loss_weighted) self.discriminator.train(objective=discriminator_loss) # END MODIFIED DISCRIMINATOR SECTION else: d = self.discriminator.forward(states, actions) # 2.1 0-1 accuracy correct_predictions = tf.equal(tf.argmax(d, 1), tf.argmax(labels, 1)) self.discriminator.acc = tf.reduce_mean(tf.cast(correct_predictions, "float")) # 2.2 prediction d_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=d, labels=labels) # cost sensitive weighting (weight true=expert, predict=agent mistakes) d_loss_weighted = self.env.cost_sensitive_weight * tf.multiply(tf.compat.v1.to_float(tf.equal(tf.squeeze(self.label), 1.)), d_cross_entropy) +\ tf.multiply(tf.compat.v1.to_float(tf.equal(tf.squeeze(self.label), 0.)), d_cross_entropy) discriminator_loss = tf.reduce_mean(d_loss_weighted) self.discriminator.train(objective=discriminator_loss) # 3. Collect experience mu = self.policy.forward(states) if self.env.continuous_actions: a = common.denormalize(mu, self.er_expert.actions_mean, self.er_expert.actions_std) eta = tf.random.normal(shape=tf.shape(a), stddev=self.env.sigma) self.action_test = a + self.noise * eta # self.action_means = mu N = tf.shape(self.action_test)[0] expanded_sigma= tf.repeat(tf.expand_dims(tf.cast(self.env.sigma, dtype=tf.float32), 0), N, axis=0) self.action_probs_test = common.compute_action_probs_tf(self.action_test, mu, expanded_sigma) else: a = common.gumbel_softmax(logits=mu, temperature=self.temp) self.action_test = tf.compat.v1.argmax(a, dimension=1) self.action_means = tf.squeeze(mu) # 4.3 AL def policy_loop(state_, t, total_cost, total_trans_err, _): mu = self.policy.forward(state_, reuse=True) if self.env.continuous_actions: eta = self.env.sigma * tf.random.normal(shape=tf.shape(mu)) action = mu + eta N = tf.shape(action)[0] expanded_sigma= tf.repeat(tf.expand_dims(tf.cast(self.env.sigma, dtype=tf.float32), 0), N, axis=0) a_prob = common.compute_action_probs_tf(action, mu, expanded_sigma) else: action = common.gumbel_softmax_sample(logits=mu, temperature=self.temp) a_prob = 0.5 # get action if self.env.continuous_actions: a_sim = common.denormalize(action, self.er_expert.actions_mean, self.er_expert.actions_std) else: a_sim = tf.compat.v1.argmax(action, dimension=1) # get next state state_env, _, env_term_sig, = self.env.step(a_sim, mode='tensorflow')[:3] state_e = common.normalize(state_env, self.er_expert.states_mean, self.er_expert.states_std) state_e = tf.stop_gradient(state_e) state_a, _ = self.forward_model.forward([state_, action, initial_gru_state], reuse=True) state, nu = common.re_parametrization(state_e=state_e, state_a=state_a) total_trans_err += tf.reduce_mean(abs(nu)) t += 1 # minimize the gap between agent logit (d[:,0]) and expert logit (d[:,1]) # MODIFIED DISCRIMINATOR SECTION: 
if self.use_irl: self.discrim_output, log_p_tau, log_q_tau, log_pq = self.discriminator.forward(state_, action, state, a_prob, reuse=True) cost = self.al_loss(log_p=log_p_tau, log_q=log_q_tau, log_pq=log_pq) else: d = self.discriminator.forward(state_, action, reuse=True) cost = self.al_loss(d=d) # END MODIFIED DISCRIMINATOR SECTION # add step cost total_cost += tf.multiply(tf.pow(self.gamma, t), cost) return state, t, total_cost, total_trans_err, env_term_sig def policy_stop_condition(state_, t, cost, trans_err, env_term_sig): cond = tf.logical_not(env_term_sig) cond = tf.logical_and(cond, t < self.env.n_steps_train) cond = tf.logical_and(cond, trans_err < self.env.total_trans_err_allowed) return cond state_0 = tf.slice(states, [0, 0], [1, -1]) loop_outputs = tf.while_loop(policy_stop_condition, policy_loop, [state_0, 0., 0., 0., False]) self.policy.train(objective=loop_outputs[2]) def al_loss(self, d=None, log_p=None, log_q=None, log_pq=None): if not self.use_irl: logit_agent, logit_expert = tf.split(axis=1, num_or_size_splits=2, value=d) labels = tf.concat([tf.zeros_like(logit_agent), tf.ones_like(logit_expert)], 1) d_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=d, labels=labels) else: # USING IRL d_cross_entropy = - (log_p - log_pq) + (log_q - log_pq) loss = tf.reduce_mean(d_cross_entropy) return loss*self.env.policy_al_w
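# ----------------------------------------------------------------------
# Reference sketch for reading the IRL branch above.  In an AIRL-style
# discriminator the three log terms are typically derived from a learned
# score f(s, a) and the policy log-probability log pi(a|s):
#     log_p_tau = f(s, a)
#     log_q_tau = log pi(a|s)
#     log_pq    = logsumexp(log_p_tau, log_q_tau)
# so that log D = log_p_tau - log_pq and log(1 - D) = log_q_tau - log_pq.
# The helper names below are illustrative, not taken from the code above.
# ----------------------------------------------------------------------
import numpy as np


def airl_log_terms(f, log_pi):
    """Return (log_p_tau, log_q_tau, log_pq) for scores f and log pi(a|s)."""
    log_p_tau = f
    log_q_tau = log_pi
    log_pq = np.logaddexp(log_p_tau, log_q_tau)  # stable log(exp(f) + pi)
    return log_p_tau, log_q_tau, log_pq


if __name__ == '__main__':
    f = np.array([1.2, -0.3])
    log_pi = np.array([-0.7, -1.5])
    lp, lq, lpq = airl_log_terms(f, log_pi)
    d_out = np.exp(lp - lpq)  # discriminator output D(s, a)
    assert np.allclose(d_out, np.exp(f) / (np.exp(f) + np.exp(log_pi)))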
class Trainer: def __init__(self, params, *, n_samples=10000000): self.model = [ fastText.load_model( os.path.join(params.dataDir, params.model_path_0)), fastText.load_model( os.path.join(params.dataDir, params.model_path_1)) ] self.dic = [ list(zip(*self.model[id].get_words(include_freq=True))) for id in [0, 1] ] x = [ np.empty((params.vocab_size, params.emb_dim), dtype=np.float64) for _ in [0, 1] ] for id in [0, 1]: for i in range(params.vocab_size): x[id][i, :] = self.model[id].get_word_vector( self.dic[id][i][0]) x[id] = normalize_embeddings_np(x[id], params.normalize_pre) u0, s0, _ = scipy.linalg.svd(x[0], full_matrices=False) u1, s1, _ = scipy.linalg.svd(x[1], full_matrices=False) if params.spectral_align_pre: s = (s0 + s1) * 0.5 x[0] = u0 @ np.diag(s) x[1] = u1 @ np.diag(s) else: x[0] = u0 @ np.diag(s0) x[1] = u1 @ np.diag(s1) self.embedding = [ nn.Embedding.from_pretrained(torch.from_numpy(x[id]).to( torch.float).to(GPU), freeze=True, sparse=True) for id in [0, 1] ] self.discriminator = Discriminator( params.emb_dim, n_layers=params.d_n_layers, n_units=params.d_n_units, drop_prob=params.d_drop_prob, drop_prob_input=params.d_drop_prob_input, leaky=params.d_leaky, batch_norm=params.d_bn).to(GPU) self.mapping = Mapping(params.emb_dim).to(GPU) if params.d_optimizer == "SGD": self.d_optimizer, self.d_scheduler = optimizers.get_sgd_adapt( self.discriminator.parameters(), lr=params.d_lr, mode="max", wd=params.d_wd) elif params.d_optimizer == "RMSProp": self.d_optimizer, self.d_scheduler = optimizers.get_rmsprop_linear( self.discriminator.parameters(), params.n_steps, lr=params.d_lr, wd=params.d_wd) else: raise Exception(f"Optimizer {params.d_optimizer} not found.") if params.m_optimizer == "SGD": self.m_optimizer, self.m_scheduler = optimizers.get_sgd_adapt( self.mapping.parameters(), lr=params.m_lr, mode="max", wd=params.m_wd, factor=params.m_lr_decay, patience=params.m_lr_patience) elif params.m_optimizer == "RMSProp": self.m_optimizer, self.m_scheduler = optimizers.get_rmsprop_linear( self.mapping.parameters(), params.n_steps, lr=params.m_lr, wd=params.m_wd) else: raise Exception(f"Optimizer {params.m_optimizer} not found") self.m_beta = params.m_beta self.smooth = params.smooth self.wgan = params.wgan self.d_clip_mode = params.d_clip_mode if params.wgan: self.loss_fn = _wasserstein_distance else: self.loss_fn = nn.BCEWithLogitsLoss(reduction="elementwise_mean") self.sampler = [ WordSampler(self.dic[id], n_urns=n_samples, alpha=params.a_sample_factor, top=params.a_sample_top) for id in [0, 1] ] self.d_bs = params.d_bs self.d_gp = params.d_gp def get_adv_batch(self, *, reverse, gp=False): batch = [ torch.LongTensor( [self.sampler[id].sample() for _ in range(self.d_bs)]).view(self.d_bs, 1).to(GPU) for id in [0, 1] ] with torch.no_grad(): x = [ self.embedding[id](batch[id]).view(self.d_bs, -1) for id in [0, 1] ] y = torch.FloatTensor(self.d_bs * 2).to(GPU).uniform_(0.0, self.smooth) if reverse: y[:self.d_bs] = 1 - y[:self.d_bs] else: y[self.d_bs:] = 1 - y[self.d_bs:] x[0] = self.mapping(x[0]) if gp: t = torch.FloatTensor(self.d_bs, 1).to(GPU).uniform_(0.0, 1.0).expand_as(x[0]) z = x[0] * t + x[1] * (1.0 - t) x = torch.cat(x, 0) return x, y, z else: x = torch.cat(x, 0) return x, y def adversarial_step(self): self.m_optimizer.zero_grad() self.discriminator.eval() x, y = self.get_adv_batch(reverse=True) y_hat = self.discriminator(x) loss = self.loss_fn(y_hat, y) loss.backward() self.m_optimizer.step() self.mapping.clip_weights() return loss.item() def discriminator_step(self): 
        self.d_optimizer.zero_grad()
        self.discriminator.train()
        with torch.no_grad():
            if self.d_gp > 0:
                x, y, z = self.get_adv_batch(reverse=False, gp=True)
            else:
                x, y = self.get_adv_batch(reverse=False)
                z = None
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        if self.d_gp > 0:
            z.requires_grad_()
            z_out = self.discriminator(z)
            g = autograd.grad(z_out,
                              z,
                              grad_outputs=torch.ones_like(z_out, device=GPU),
                              retain_graph=True,
                              create_graph=True,
                              only_inputs=True)[0]
            gp = torch.mean((g.norm(p=2, dim=1) - 1.0)**2)
            loss += self.d_gp * gp
        loss.backward()
        self.d_optimizer.step()
        if self.wgan:
            self.discriminator.clip_weights(self.d_clip_mode)
        return loss.item()

    def scheduler_step(self, metric):
        self.m_scheduler.step(metric)
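# ----------------------------------------------------------------------
# Stand-alone sketch of the gradient-penalty term assembled inside
# discriminator_step above (WGAN-GP, Gulrajani et al., 2017), kept here
# only to make the autograd.grad call easier to read.  It assumes 2-D
# (batch, dim) inputs as in the embedding setting above; `critic` and the
# interpolation variable are illustrative names, not the Trainer's code.
# ----------------------------------------------------------------------
import torch
from torch import autograd


def gradient_penalty(critic, real, fake):
    """E[(||grad_z critic(z)||_2 - 1)^2] at points z between real and fake."""
    t = torch.rand(real.size(0), 1, device=real.device).expand_as(real)
    z = (t * real + (1.0 - t) * fake).requires_grad_(True)
    score = critic(z)
    grad = autograd.grad(score,
                         z,
                         grad_outputs=torch.ones_like(score),
                         create_graph=True,
                         retain_graph=True,
                         only_inputs=True)[0]
    return ((grad.norm(p=2, dim=1) - 1.0) ** 2).mean()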
for RFsignal, RFlocs_overlapped_avg in zip( RFsignals, [ RFlocs_V1_overlapped_avg, RFlocs_V2_overlapped_avg, RFlocs_V3_overlapped_avg ]) ] brain_inputs = concat(*brain_inputs, dim=1) # ------------------------------------ # T R A I N D i s c r i m i n a t o r # ------------------------------------ targets = targets.as_in_context(context) loss_discriminator_train.append( discriminator.train( g, generator.rf_mapper(RFsignals) + brain_inputs, targets)) # ---------------------------- # T R A I N G e n e r a t o r # ---------------------------- loss_generator_train.append( generator.train(d, RFsignals, brain_inputs, targets)) os.makedirs(MODEL_DIR + '/saved_models/' + runname, exist_ok=True) generator.network.save_parameters( f'{MODEL_DIR}/saved_models/{runname}/netG_{epoch}.model') generator.rf_mapper.save_parameters( f'{MODEL_DIR}/saved_models/{runname}/RFlayers_{epoch}.model') discriminator.network.save_parameters( f'{MODEL_DIR}/saved_models/{runname}/netD_{epoch}.model')
def run_gail(agent, index_gail, env): DG_flag = 1 #env.seed(0) ob_space = env.observation_space Policy = Policy_net('policy_' + str(index_gail), env) Old_Policy = Policy_net('old_policy' + str(index_gail), env) gamma = 0.95 PPO = PPOTrain(Policy, Old_Policy, gamma) D = Discriminator(env, index_gail) if DG_flag: # with open(Config.DEMO_DATA_PATH, 'rb') as f: # demo_transitions = pickle.load(f) # demo_transitions = deque(itertools.islice(demo_transitions, 0, Config.demo_buffer_size)) # assert len(demo_transitions) == Config.demo_buffer_size expert_data = agent.replay_memory if agent.replay_memory.full( ) else agent.demo_memory _, demo_transitions, _ = expert_data.sample(agent.config.BATCH_SIZE) expert_observations = [data[0] for data in demo_transitions] expert_actions = [data[1] for data in demo_transitions] else: expert_observations = np.genfromtxt('trajectory/observations.csv') expert_actions = np.genfromtxt('trajectory/actions.csv', dtype=np.int32) with tf.Session() as sess: # writer = tf.summary.FileWriter(args.logdir, sess.graph) #load_path=saver.restore(sess,"trained_models/model.ckpt") #sess.run(tf.global_variables_initializer()) #if index_gail>1: # saver.restore(sess, 'trained_models/model' + str(index_gail-1) + '.ckpt') obs = env.reset() state_for_memory = obs #为了处理两套程序中使用的数据格式不同 success_num = 0 iteration = int(2000) #0319 for iteration in range(iteration): #print("running policy ") observations = [] #states_for_memory=[] actions = [] # do NOT use rewards to update policy , # 0319 why ? rewards = [] v_preds = [] run_policy_steps = 0 score = 0 if DG_flag: t_q = deque(maxlen=Config.trajectory_n) done, score, n_step_reward, state_for_memory = False, 0, None, env.reset( ) while True: run_policy_steps += 1 obs = np.stack([obs]).astype( dtype=np.float32) # prepare to feed placeholder Policy.obs act, v_pred = Policy.act(obs=obs, stochastic=True) act = np.asscalar(act) v_pred = np.asscalar(v_pred) next_obs, reward, done, info = env.step(act) next_state_for_memory = next_obs score += reward if DG_flag: reward_to_sub = 0. 
if len(t_q) < t_q.maxlen else t_q[0][ 2] # record the earliest reward for the sub t_q.append([ state_for_memory, act, reward, next_state_for_memory, done, 0.0 ]) if len(t_q) == t_q.maxlen: if n_step_reward is None: # only compute once when t_q first filled n_step_reward = sum([ t[2] * Config.GAMMA**i for i, t in enumerate(t_q) ]) else: n_step_reward = (n_step_reward - reward_to_sub) / Config.GAMMA n_step_reward += reward * Config.GAMMA**( Config.trajectory_n - 1) t_q[0].extend([ n_step_reward, next_state_for_memory, done, t_q.maxlen ]) # actual_n is max_len here #agent.perceive(t_q[0]) # perceive when a transition is completed env.render() # 0313 observations.append(obs) actions.append(act) rewards.append(reward) v_preds.append(v_pred) if done: if DG_flag: t_q.popleft( ) # first transition's n-step is already set transitions = set_n_step(t_q, Config.trajectory_n) next_obs = np.stack([next_obs]).astype( dtype=np.float32 ) # prepare to feed placeholder Policy.obs _, v_pred = Policy.act(obs=next_obs, stochastic=True) v_preds_next = v_preds[1:] + [np.asscalar(v_pred)] obs = env.reset() print("iteration", iteration, "score", score) break else: obs = next_obs state_for_memory = next_state_for_memory #print("state_for memory",state_for_memory) #writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_length', simple_value=run_policy_steps)]), iteration) #writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_reward', simple_value=sum(rewards))]), iteration) # if sum(rewards) >= 100: success_num += 1 # todo # 在 能够得到较好的回报 的时候 存储这次的demo if DG_flag: for t in transitions: agent.perceive(t) agent.replay_memory.memory_len() if success_num >= 3: #saver.save(sess, 'trained_models/model.ckpt') #saver.save(sess, 'trained_models/model' + str(index_gail) + '.ckpt') print(success_num) print('Clear!! Model saved.') env.close() break else: success_num = 0 # convert list to numpy array for feeding tf.placeholder observations = np.reshape(observations, newshape=[-1] + list(ob_space.shape)) actions = np.array(actions).astype(dtype=np.int32) # train discriminator for i in range(2): #print("training D") D.train(expert_s=expert_observations, expert_a=expert_actions, agent_s=observations, agent_a=actions) # output of this discriminator is reward d_rewards = D.get_rewards(agent_s=observations, agent_a=actions) d_rewards = np.reshape(d_rewards, newshape=[-1]).astype(dtype=np.float32) gaes = PPO.get_gaes(rewards=d_rewards, v_preds=v_preds, v_preds_next=v_preds_next) gaes = np.array(gaes).astype(dtype=np.float32) # gaes = (gaes - gaes.mean()) / gaes.std() v_preds_next = np.array(v_preds_next).astype(dtype=np.float32) # train policy inp = [observations, actions, gaes, d_rewards, v_preds_next] PPO.assign_policy_parameters() for epoch in range(6): #print("updating PPO ") sample_indices = np.random.randint( low=0, high=observations.shape[0], size=32) # indices are in [low, high) sampled_inp = [ np.take(a=a, indices=sample_indices, axis=0) for a in inp ] # sample training data PPO.train(obs=sampled_inp[0], actions=sampled_inp[1], gaes=sampled_inp[2], rewards=sampled_inp[3], v_preds_next=sampled_inp[4])
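# ----------------------------------------------------------------------
# The n_step_reward bookkeeping above maintains a sliding discounted sum
#     R_t = sum_{i=0}^{n-1} gamma^i * r_{t+i}
# incrementally via
#     R_{t+1} = (R_t - r_t) / gamma + gamma^(n-1) * r_{t+n}.
# The plain-Python check below (illustrative only; GAMMA and N stand in
# for Config.GAMMA and Config.trajectory_n) confirms the update against
# a direct recomputation.
# ----------------------------------------------------------------------
GAMMA, N = 0.95, 4
rewards = [1.0, 0.5, -0.2, 2.0, 0.3, 1.1]


def direct(t):
    return sum(GAMMA ** i * rewards[t + i] for i in range(N))


R = direct(0)
for t in range(1, len(rewards) - N + 1):
    R = (R - rewards[t - 1]) / GAMMA + GAMMA ** (N - 1) * rewards[t + N - 1]
    assert abs(R - direct(t)) < 1e-9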
def train(): FLAGS(sys.argv) with sc2_env.SC2Env( map_name='MoveToBeacon', agent_interface_format=sc2_env.parse_agent_interface_format( feature_screen=64, feature_minimap=64, rgb_screen=None, rgb_minimap=None, action_space=None, use_feature_units=False), step_mul=step_mul, game_steps_per_episode=None, disable_fog=False, visualize=False) as env: r = tf.placeholder(tf.float32) ######## rr = tf.summary.scalar('reward', r) merged = tf.summary.merge_all() ######## expert_observations = np.genfromtxt('trajectory/observations.csv') expert_actions = np.genfromtxt('trajectory/actions.csv', dtype=np.int32) with tf.Session() as sess: Policy = Policy_net('policy', 2, 4) Old_Policy = Policy_net('old_policy', 2, 4) PPO = PPOTrain(Policy, Old_Policy) D = Discriminator() sess.run(tf.global_variables_initializer()) saver = tf.train.Saver() writer = tf.summary.FileWriter('./board/gail', sess.graph) ######## c = 0 for episodes in range(100000): done = False obs = env.reset() while not 331 in obs[0].observation.available_actions: actions = actAgent2Pysc2(100, obs) obs = env.step(actions=[actions]) state = obs2state(obs) observations = [] actions_list = [] rewards = [] v_preds = [] reward = 0 global_step = 0 while not done: global_step += 1 state = np.stack([state]).astype(dtype=np.float32) act, v_pred = Policy.act(obs=state, stochastic=True) act, v_pred = np.asscalar(act), np.asscalar(v_pred) observations.append(state) actions_list.append(act) rewards.append(reward) v_preds.append(v_pred) actions = actAgent2Pysc2(act, obs) obs = env.step(actions=[actions]) next_state = obs2state(obs) distance = obs2distance(obs) if distance < 0.03 or global_step == 100: done = True if done: v_preds_next = v_preds[1:] + [0] break state = next_state observations = np.reshape(observations, newshape=[-1, 2]) actions_list = np.array(actions_list).astype(dtype=np.int32) for i in range(2): sample_indices = (np.random.randint( expert_observations.shape[0], size=observations.shape[0])) inp = [expert_observations, expert_actions] sampled_inp = [ np.take(a=a, indices=sample_indices, axis=0) for a in inp ] # sample training data D.train(expert_s=sampled_inp[0], expert_a=sampled_inp[1], agent_s=observations, agent_a=actions_list) d_rewards = D.get_rewards(agent_s=observations, agent_a=actions_list) d_rewards = np.reshape(d_rewards, newshape=[-1]).astype(dtype=np.float32) gaes = PPO.get_gaes(rewards=d_rewards, v_preds=v_preds, v_preds_next=v_preds_next) gaes = np.array(gaes).astype(dtype=np.float32) v_preds_next = np.array(v_preds_next).astype(dtype=np.float32) inp = [ observations, actions_list, gaes, d_rewards, v_preds_next ] PPO.assign_policy_parameters() for epoch in range(15): sample_indices = np.random.randint( low=0, high=observations.shape[0], size=32) # indices are in [low, high) sampled_inp = [ np.take(a=a, indices=sample_indices, axis=0) for a in inp ] # sample training data PPO.train(obs=sampled_inp[0], actions=sampled_inp[1], gaes=sampled_inp[2], rewards=sampled_inp[3], v_preds_next=sampled_inp[4]) summary = sess.run(merged, feed_dict={r: global_step}) writer.add_summary(summary, episodes) if global_step < 50: c += 1 else: c = 0 if c > 10: saver.save(sess, './model/gail.cpkt') print('save model') break print(episodes, global_step, c)
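# ----------------------------------------------------------------------
# PPO.get_gaes is defined elsewhere; the generalized advantage estimate
# it refers to (Schulman et al., 2016) is, per time step,
#     delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
#     A_t     = delta_t + gamma * lam * A_{t+1}
# A minimal NumPy version under that assumption; the gamma/lam defaults
# are placeholders, not values taken from the scripts above.
# ----------------------------------------------------------------------
import numpy as np


def get_gaes(rewards, v_preds, v_preds_next, gamma=0.95, lam=0.95):
    deltas = [r + gamma * v_next - v
              for r, v, v_next in zip(rewards, v_preds, v_preds_next)]
    gaes = list(deltas)
    for t in reversed(range(len(gaes) - 1)):  # accumulate backwards in time
        gaes[t] = gaes[t] + gamma * lam * gaes[t + 1]
    return np.asarray(gaes, dtype=np.float32)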
def main(args): use_cuda = (len(args.gpuid) >= 1) print("{0} GPU(s) are available".format(cuda.device_count())) print("======printing args========") print(args) print("=================================") # Load dataset splits = ['train', 'valid'] if data.has_binary_files(args.data, splits): print("Loading bin dataset") dataset = data.load_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) #args.data, splits, args.src_lang, args.trg_lang) else: print(f"Loading raw text dataset {args.data}") dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) #args.data, splits, args.src_lang, args.trg_lang) if args.src_lang is None or args.trg_lang is None: # record inferred languages in args, so that it's saved in checkpoints args.src_lang, args.trg_lang = dataset.src, dataset.dst print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict))) print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict))) for split in splits: print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split]))) g_logging_meters = OrderedDict() g_logging_meters['train_loss'] = AverageMeter() g_logging_meters['valid_loss'] = AverageMeter() g_logging_meters['train_acc'] = AverageMeter() g_logging_meters['valid_acc'] = AverageMeter() g_logging_meters['bsz'] = AverageMeter() # sentences per batch d_logging_meters = OrderedDict() d_logging_meters['train_loss'] = AverageMeter() d_logging_meters['valid_loss'] = AverageMeter() d_logging_meters['train_acc'] = AverageMeter() d_logging_meters['valid_acc'] = AverageMeter() d_logging_meters['bsz'] = AverageMeter() # sentences per batch # Set model parameters args.encoder_embed_dim = 1000 args.encoder_layers = 4 args.encoder_dropout_out = 0 args.decoder_embed_dim = 1000 args.decoder_layers = 4 args.decoder_out_embed_dim = 1000 args.decoder_dropout_out = 0 args.bidirectional = False # try to load generator model g_model_path = 'checkpoints/generator/best_gmodel.pt' if not os.path.exists(g_model_path): print("Start training generator!") train_g(args, dataset) assert os.path.exists(g_model_path) generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) model_dict = generator.state_dict() pretrained_dict = torch.load(g_model_path) #print(f"First dict: {pretrained_dict}") # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } #print(f"Second dict: {pretrained_dict}") # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) #print(f"model dict: {model_dict}") # 3. load the new state dict generator.load_state_dict(model_dict) print("Generator has successfully loaded!") # try to load discriminator model d_model_path = 'checkpoints/discriminator/best_dmodel.pt' if not os.path.exists(d_model_path): print("Start training discriminator!") train_d(args, dataset) assert os.path.exists(d_model_path) discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) model_dict = discriminator.state_dict() pretrained_dict = torch.load(d_model_path) # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. 
load the new state dict discriminator.load_state_dict(model_dict) print("Discriminator has successfully loaded!") #return print("starting main training loop") torch.autograd.set_detect_anomaly(True) if use_cuda: if torch.cuda.device_count() > 1: discriminator = torch.nn.DataParallel(discriminator).cuda() generator = torch.nn.DataParallel(generator).cuda() else: generator.cuda() discriminator.cuda() else: discriminator.cpu() generator.cpu() # adversarial training checkpoints saving path if not os.path.exists('checkpoints/joint'): os.makedirs('checkpoints/joint') checkpoints_path = 'checkpoints/joint/' # define loss function g_criterion = torch.nn.NLLLoss(size_average=False, ignore_index=dataset.dst_dict.pad(), reduce=True) d_criterion = torch.nn.BCEWithLogitsLoss() pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True) # fix discriminator word embedding (as Wu et al. do) for p in discriminator.embed_src_tokens.parameters(): p.requires_grad = False for p in discriminator.embed_trg_tokens.parameters(): p.requires_grad = False # define optimizer g_optimizer = eval("torch.optim." + args.g_optimizer)(filter( lambda x: x.requires_grad, generator.parameters()), args.g_learning_rate) d_optimizer = eval("torch.optim." + args.d_optimizer)( filter(lambda x: x.requires_grad, discriminator.parameters()), args.d_learning_rate, momentum=args.momentum, nesterov=True) # start joint training best_dev_loss = math.inf num_update = 0 # main training loop for epoch_i in range(1, args.epochs + 1): logging.info("At {0}-th epoch.".format(epoch_i)) # seed = args.seed + epoch_i # torch.manual_seed(seed) max_positions_train = (args.fixed_max_len, args.fixed_max_len) # Initialize dataloader, starting at batch_offset itr = dataset.train_dataloader( 'train', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_train, # seed=seed, epoch=epoch_i, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch_i <= args.curriculum), shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() # set training mode generator.train() discriminator.train() update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer) for i, sample in enumerate(itr): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) ## part I: use gradient policy method to train the generator # use policy gradient training when rand > 50% rand = random.random() if rand >= 0.5: # policy gradient training generator.decoder.is_testing = True sys_out_batch, prediction, _ = generator(sample) generator.decoder.is_testing = False with torch.no_grad(): n_i = sample['net_input']['src_tokens'] #print(f"net input:\n{n_i}, pred: \n{prediction}") reward = discriminator( sample['net_input']['src_tokens'], prediction) # dataset.dst_dict.pad()) train_trg_batch = sample['target'] #print(f"sys_out_batch: {sys_out_batch.shape}:\n{sys_out_batch}") pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward, use_cuda) # logging.debug("G policy gradient loss at batch {0}: {1:.3f}, lr={2}".format(i, pg_loss.item(), g_optimizer.param_groups[0]['lr'])) g_optimizer.zero_grad() pg_loss.backward() torch.nn.utils.clip_grad_norm(generator.parameters(), args.clip_norm) g_optimizer.step() # oracle valid _, _, loss = generator(sample) sample_size = 
sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] logging_loss = loss.data / sample_size / math.log(2) g_logging_meters['train_loss'].update(logging_loss, sample_size) logging.debug( "G MLE loss at batch {0}: {1:.3f}, lr={2}".format( i, g_logging_meters['train_loss'].avg, g_optimizer.param_groups[0]['lr'])) else: # MLE training #print(f"printing sample: \n{sample}") _, _, loss = generator(sample) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] nsentences = sample['target'].size(0) logging_loss = loss.data / sample_size / math.log(2) g_logging_meters['bsz'].update(nsentences) g_logging_meters['train_loss'].update(logging_loss, sample_size) logging.debug( "G MLE loss at batch {0}: {1:.3f}, lr={2}".format( i, g_logging_meters['train_loss'].avg, g_optimizer.param_groups[0]['lr'])) g_optimizer.zero_grad() loss.backward() # all-reduce grads and rescale by grad_denom for p in generator.parameters(): if p.requires_grad: p.grad.data.div_(sample_size) torch.nn.utils.clip_grad_norm(generator.parameters(), args.clip_norm) g_optimizer.step() num_update += 1 # part II: train the discriminator bsz = sample['target'].size(0) src_sentence = sample['net_input']['src_tokens'] # train with half human-translation and half machine translation true_sentence = sample['target'] true_labels = Variable( torch.ones(sample['target'].size(0)).float()) with torch.no_grad(): generator.decoder.is_testing = True _, prediction, _ = generator(sample) generator.decoder.is_testing = False fake_sentence = prediction fake_labels = Variable( torch.zeros(sample['target'].size(0)).float()) trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0) labels = torch.cat([true_labels, fake_labels], dim=0) indices = np.random.permutation(2 * bsz) trg_sentence = trg_sentence[indices][:bsz] labels = labels[indices][:bsz] if use_cuda: labels = labels.cuda() disc_out = discriminator(src_sentence, trg_sentence) #, dataset.dst_dict.pad()) #print(f"disc out: {disc_out.shape}, labels: {labels.shape}") #print(f"labels: {labels}") d_loss = d_criterion(disc_out, labels.long()) acc = torch.sum(torch.Sigmoid() (disc_out).round() == labels).float() / len(labels) d_logging_meters['train_acc'].update(acc) d_logging_meters['train_loss'].update(d_loss) # logging.debug("D training loss {0:.3f}, acc {1:.3f} at batch {2}: ".format(d_logging_meters['train_loss'].avg, # d_logging_meters['train_acc'].avg, # i)) d_optimizer.zero_grad() d_loss.backward() d_optimizer.step() # validation # set validation mode generator.eval() discriminator.eval() # Initialize dataloader max_positions_valid = (args.fixed_max_len, args.fixed_max_len) itr = dataset.eval_dataloader( 'valid', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_valid, skip_invalid_size_inputs_valid_test=True, descending=True, # largest batch first to warm the caching allocator shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(itr): with torch.no_grad(): if use_cuda: sample['id'] = sample['id'].cuda() sample['net_input']['src_tokens'] = sample['net_input'][ 'src_tokens'].cuda() sample['net_input']['src_lengths'] = sample['net_input'][ 'src_lengths'].cuda() sample['net_input']['prev_output_tokens'] = sample[ 'net_input']['prev_output_tokens'].cuda() sample['target'] = sample['target'].cuda() # 
generator validation _, _, loss = generator(sample) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] loss = loss / sample_size / math.log(2) g_logging_meters['valid_loss'].update(loss, sample_size) logging.debug("G dev loss at batch {0}: {1:.3f}".format( i, g_logging_meters['valid_loss'].avg)) # discriminator validation bsz = sample['target'].size(0) src_sentence = sample['net_input']['src_tokens'] # train with half human-translation and half machine translation true_sentence = sample['target'] true_labels = Variable( torch.ones(sample['target'].size(0)).float()) with torch.no_grad(): generator.decoder.is_testing = True _, prediction, _ = generator(sample) generator.decoder.is_testing = False fake_sentence = prediction fake_labels = Variable( torch.zeros(sample['target'].size(0)).float()) trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0) labels = torch.cat([true_labels, fake_labels], dim=0) indices = np.random.permutation(2 * bsz) trg_sentence = trg_sentence[indices][:bsz] labels = labels[indices][:bsz] if use_cuda: labels = labels.cuda() disc_out = discriminator(src_sentence, trg_sentence, dataset.dst_dict.pad()) d_loss = d_criterion(disc_out, labels) acc = torch.sum(torch.Sigmoid()(disc_out).round() == labels).float() / len(labels) d_logging_meters['valid_acc'].update(acc) d_logging_meters['valid_loss'].update(d_loss) # logging.debug("D dev loss {0:.3f}, acc {1:.3f} at batch {2}".format(d_logging_meters['valid_loss'].avg, # d_logging_meters['valid_acc'].avg, i)) torch.save(generator, open( checkpoints_path + "joint_{0:.3f}.epoch_{1}.pt".format( g_logging_meters['valid_loss'].avg, epoch_i), 'wb'), pickle_module=dill) if g_logging_meters['valid_loss'].avg < best_dev_loss: best_dev_loss = g_logging_meters['valid_loss'].avg torch.save(generator, open(checkpoints_path + "best_gmodel.pt", 'wb'), pickle_module=dill)
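# ----------------------------------------------------------------------
# PGLoss is defined elsewhere; a common form of the policy-gradient loss
# used in this kind of adversarial NMT setup is REINFORCE with the
# discriminator score as a per-sentence reward:
#     loss = - mean_b [ r_b * sum_t log p(y_t | y_<t, x) ]   (pad masked)
# The sketch below is an assumption about that interface, not the
# project's PGLoss; `lprobs` is a (batch, T, vocab) log-softmax output.
# ----------------------------------------------------------------------
import torch


def pg_loss(lprobs, targets, reward, pad_idx):
    """Sequence-level REINFORCE loss with a per-sentence scalar reward."""
    tok_lprobs = lprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # (B, T)
    mask = targets.ne(pad_idx).float()
    neg_logprob = -(tok_lprobs * mask).sum(dim=1)  # -log p(y | x) per sentence
    return (reward * neg_logprob).mean()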
if len(sys.argv) > 1:
    already_trained_g = True
    netG.load_state_dict(torch.load(sys.argv[1]))

if (device.type == 'cuda') and (ngpu > 1):  # Handle multi-gpu if desired
    netG = nn.DataParallel(netG, list(range(ngpu)))
#print(netG)  # Print the model
netG.train()

# DISCRIMINATOR
netD = Discriminator(in_nc_d, out_nc_d, ndf).to(device)
if (device.type == 'cuda') and (ngpu > 1):  # Handle multi-gpu if desired
    netD = nn.DataParallel(netD, list(range(ngpu)))
#print(netD)  # Print the model
netD.train()

# VGG19
vgg = VGG19(init_weights=vggroot, feature_mode=True).to(device)

# Initialize BCELoss, L1Loss function
bce_loss = nn.BCELoss().to(device)
l1_loss = nn.L1Loss().to(device)

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
G_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizerG,
    milestones=[num_epochs // 2, num_epochs // 4 * 3],
def semi_main(options): print('\nSemi-Supervised Learning!\n') # 1. Make sure the options are valid argparse CLI options indeed assert isinstance(options, argparse.Namespace) # 2. Set up the logger logging.basicConfig(level=str(options.loglevel).upper()) # 3. Make sure the output dir `outf` exists _check_out_dir(options) # 4. Set the random state _set_random_state(options) # 5. Configure CUDA and Cudnn, set the global `device` for PyTorch device = _configure_cuda(options) # 6. Prepare the datasets and split it for semi-supervised learning if options.dataset != 'cifar10': raise NotImplementedError( 'Semi-supervised learning only support CIFAR10 dataset at the moment!' ) test_data_loader, semi_data_loader, train_data_loader = _prepare_semi_dataset( options) # 7. Set the parameters ngpu = int(options.ngpu) # num of GPUs nz = int( options.nz) # size of latent vector, also the number of the generators ngf = int(options.ngf) # depth of feature maps through G ndf = int(options.ndf) # depth of feature maps through D nc = int(options.nc ) # num of channels of the input images, 3 indicates color images M = int(options.mcmc) # num of SGHMC chains run concurrently nd = int(options.nd) # num of discriminators nsetz = int(options.nsetz) # num of noise batches # 8. Special preparations for Bayesian GAN for Generators # In order to inject the SGHMAC into the training process, instead of pause the gradient descent at # each training step, which can be easily defined with static computation graph(Tensorflow), in PyTorch, # we have to move the Generator Sampling to the very beginning of the whole training process, and use # a trick that initializing all of the generators explicitly for later usages. Generator_chains = [] for _ in range(nsetz): for __ in range(M): netG = Generator(ngpu, nz, ngf, nc).to(device) netG.apply(weights_init) Generator_chains.append(netG) logging.info( f'Showing the first generator of the Generator chain: \n {Generator_chains[0]}\n' ) # 9. Special preparations for Bayesian GAN for Discriminators assert options.dataset == 'cifar10', 'Semi-supervised learning only support CIFAR10 dataset at the moment!' num_class = 10 + 1 # To simplify the implementation we only consider the situation of 1 discriminator # if nd <= 1: # netD = Discriminator(ngpu, ndf, nc, num_class=num_class).to(device) # netD.apply(weights_init) # else: # Discriminator_chains = [] # for _ in range(nd): # for __ in range(M): # netD = Discriminator(ngpu, ndf, nc, num_class=num_class).to(device) # netD.apply(weights_init) # Discriminator_chains.append(netD) netD = Discriminator(ngpu, ndf, nc, num_class=num_class).to(device) netD.apply(weights_init) logging.info(f'Showing the Discriminator model: \n {netD}\n') # 10. Loss function criterion = nn.CrossEntropyLoss() all_criterion = ComplementCrossEntropyLoss(except_index=0, device=device) # 11. Set up optimizers optimizerG_chains = [ optim.Adam(netG.parameters(), lr=options.lr, betas=(options.beta1, 0.999)) for netG in Generator_chains ] # optimizerD_chains = [ # optim.Adam(netD.parameters(), lr=options.lr, betas=(options.beta1, 0.999)) for netD in Discriminator_chains # ] optimizerD = optim.Adam(netD.parameters(), lr=options.lr, betas=(options.beta1, 0.999)) import math # 12. Set up the losses for priors and noises gprior = PriorLoss(prior_std=1., total=500.) gnoise = NoiseLoss(params=Generator_chains[0].parameters(), device=device, scale=math.sqrt(2 * options.alpha / options.lr), total=500.) dprior = PriorLoss(prior_std=1., total=50000.) 
dnoise = NoiseLoss(params=netD.parameters(), device=device, scale=math.sqrt(2 * options.alpha * options.lr), total=50000.) gprior.to(device=device) gnoise.to(device=device) dprior.to(device=device) dnoise.to(device=device) # In order to let G condition on a specific noise, we attach the noise to a fixed Tensor fixed_noise = torch.FloatTensor(options.batchSize, options.nz, 1, 1).normal_(0, 1).to(device=device) inputT = torch.FloatTensor(options.batchSize, 3, options.imageSize, options.imageSize).to(device=device) noiseT = torch.FloatTensor(options.batchSize, options.nz, 1, 1).to(device=device) labelT = torch.FloatTensor(options.batchSize).to(device=device) real_label = 1 fake_label = 0 # 13. Transfer all the tensors and modules to GPU if applicable # for netD in Discriminator_chains: # netD.to(device=device) netD.to(device=device) for netG in Generator_chains: netG.to(device=device) criterion.to(device=device) all_criterion.to(device=device) # ======================== # === Training Process === # ======================== # Lists to keep track of progress img_list = [] G_losses = [] D_losses = [] stats = [] iters = 0 try: print("\nStarting Training Loop...\n") for epoch in range(options.niter): top1 = Metrics() for i, data in enumerate(train_data_loader, 0): # ################## # Train with real # ################## netD.zero_grad() real_cpu = data[0].to(device) batch_size = real_cpu.size(0) # label = torch.full((batch_size,), real_label, device=device) inputT.resize_as_(real_cpu).copy_(real_cpu) labelT.resize_(batch_size).fill_(real_label) inputv = torch.autograd.Variable(inputT) labelv = torch.autograd.Variable(labelT) output = netD(inputv) errD_real = all_criterion(output) errD_real.backward() D_x = 1 - torch.nn.functional.softmax( output).data[:, 0].mean().item() # ################## # Train with fake # ################## fake_images = [] for i_z in range(nsetz): noiseT.resize_(batch_size, nz, 1, 1).normal_( 0, 1) # prior, sample from N(0, 1) distribution noisev = torch.autograd.Variable(noiseT) for m in range(M): idx = i_z * M + m netG = Generator_chains[idx] _fake = netG(noisev) fake_images.append(_fake) # output = torch.stack(fake_images) fake = torch.cat(fake_images) output = netD(fake.detach()) labelv = torch.autograd.Variable( torch.LongTensor(fake.data.shape[0]).to( device=device).fill_(fake_label)) errD_fake = criterion(output, labelv) errD_fake.backward() D_G_z1 = 1 - torch.nn.functional.softmax( output).data[:, 0].mean().item() # ################## # Semi-supervised learning # ################## for ii, (input_sup, target_sup) in enumerate(semi_data_loader): input_sup, target_sup = input_sup.to( device=device), target_sup.to(device=device) break input_sup_v = input_sup.to(device=device) target_sup_v = (target_sup + 1).to(device=device) output_sup = netD(input_sup_v) err_sup = criterion(output_sup, target_sup_v) err_sup.backward() pred1 = accuracy(output_sup.data, target_sup + 1, topk=(1, ))[0] top1.update(value=pred1.item(), N=input_sup.size(0)) errD_prior = dprior(netD.parameters()) errD_prior.backward() errD_noise = dnoise(netD.parameters()) errD_noise.backward() errD = errD_real + errD_fake + err_sup + errD_prior + errD_noise optimizerD.step() # ################## # Sample and construct generator(s) # ################## for netG in Generator_chains: netG.zero_grad() labelv = torch.autograd.Variable( torch.FloatTensor(fake.data.shape[0]).to( device=device).fill_(real_label)) output = netD(fake) errG = all_criterion(output) for netG in Generator_chains: errG = errG + 
gprior(netG.parameters()) errG = errG + gnoise(netG.parameters()) errG.backward() D_G_z2 = 1 - torch.nn.functional.softmax( output).data[:, 0].mean().item() for optimizerG in optimizerG_chains: optimizerG.step() # ################## # Evaluate testing accuracy # ################## # Pause and compute the test accuracy after every 10 times of the notefreq if iters % 10 * int(options.notefreq) == 0: # get test accuracy on train and test netD.eval() compute_test_accuracy(discriminator=netD, testing_data_loader=test_data_loader, device=device) netD.train() # ################## # Note down # ################## # Report status for the current iteration training_status = f"[{epoch}/{options.niter}][{i}/{len(train_data_loader)}] Loss_D: {errD.item():.4f} " \ f"Loss_G: " \ f"{errG.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}" \ f" | Acc {top1.value:.1f} / {top1.mean:.1f}" print(training_status) # Save samples to disk if i % int(options.notefreq) == 0: vutils.save_image( real_cpu, f"{options.outf}/real_samples_epoch_{epoch:{0}{3}}_{i}.png", normalize=True) for _iz in range(nsetz): for _m in range(M): gidx = _iz * M + _m netG = Generator_chains[gidx] fake = netG(fixed_noise) vutils.save_image( fake.detach(), f"{options.outf}/fake_samples_epoch_{epoch:{0}{3}}_{i}_z{_iz}_m{_m}.png", normalize=True) # Save Losses statistics for post-mortem G_losses.append(errG.item()) D_losses.append(errD.item()) stats.append(training_status) # # Check how the generator is doing by saving G's output on fixed_noise # if (iters % 500 == 0) or ((epoch == options.niter - 1) and (i == len(data_loader) - 1)): # with torch.no_grad(): # fake = netG(fixed_noise).detach().cpu() # img_list.append(vutils.make_grid(fake, padding=2, normalize=True)) iters += 1 # TODO: find an elegant way to support saving checkpoints in Bayesian GAN context except Exception as e: print(e) # save training stats no matter what kind of errors occur in the processes _save_stats(statistic=G_losses, save_name='G_losses', options=options) _save_stats(statistic=D_losses, save_name='D_losses', options=options) _save_stats(statistic=stats, save_name='Training_stats', options=options)
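# ----------------------------------------------------------------------
# PriorLoss / NoiseLoss above come from the Bayesian GAN recipe (Saatchi
# & Wilson, 2017): in addition to the usual GAN objective, each network
# gets (i) a Gaussian prior over its weights and (ii) an injected-noise
# term whose gradient realises the SGHMC exploration noise.  A compact
# sketch of the idea as plain functions; details such as the `total`
# normaliser and the alpha default are assumptions, not the exact
# implementation used above.
# ----------------------------------------------------------------------
import math
import torch


def gaussian_prior_loss(params, prior_std=1.0, total=500.0):
    """0.5 * ||theta||^2 / prior_std^2, scaled down by the dataset-size factor."""
    loss = sum(0.5 * (p ** 2).sum() / prior_std ** 2 for p in params)
    return loss / total


def sghmc_noise_loss(params, lr, alpha=0.01, total=500.0):
    """<eps, theta> with eps ~ N(0, 2*alpha/lr); its gradient w.r.t. theta is the noise."""
    scale = math.sqrt(2.0 * alpha / lr)
    loss = sum((scale * torch.randn_like(p) * p).sum() for p in params)
    return loss / total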
model_g.cuda()
input, label = input.cuda(), label.cuda()
noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
one_hot_labels = one_hot_labels.cuda()
fixed_labels = fixed_labels.cuda()

optim_d = optim.SGD(model_d.parameters(), lr=args.lr)
optim_g = optim.SGD(model_g.parameters(), lr=args.lr)

fixed_noise = Variable(fixed_noise)
fixed_labels = Variable(fixed_labels)

real_label = 1
fake_label = 0

for epoch_idx in range(args.epochs):
    model_d.train()
    model_g.train()

    d_loss = 0.0
    g_loss = 0.0

    for batch_idx, (train_x, train_y) in enumerate(train_loader):
        batch_size = train_x.size(0)
        train_x = train_x.view(-1, INPUT_SIZE)
        if args.cuda:
            train_x = train_x.cuda()
            train_y = train_y.cuda()

        input.resize_as_(train_x).copy_(train_x)
        label.resize_(batch_size).fill_(real_label)
        one_hot_labels.resize_(batch_size, NUM_LABELS).zero_()
class MGAIL(object): def __init__(self, environment): self.env = environment # Create placeholders for all the inputs self.states_ = tf.placeholder( "float", shape=(None, ) + self.env.state_size, name='states_') # Batch x State, previous state self.states = tf.placeholder( "float", shape=(None, ) + self.env.state_size, name='states') # Batch x State, current_state self.actions = tf.placeholder("float", shape=(None, self.env.action_size), name='action') # Batch x Action self.label = tf.placeholder("float", shape=(None, 1), name='label') self.gamma = tf.placeholder("float", shape=(), name='gamma') self.temp = tf.placeholder("float", shape=(), name='temperature') self.noise = tf.placeholder("float", shape=(), name='noise_flag') self.do_keep_prob = tf.placeholder("float", shape=(), name='do_keep_prob') if self.env.use_airl: self.done_ph = tf.placeholder(name="dones", shape=(None, ), dtype=tf.float32) # Create MGAIL blocks self.forward_model = ForwardModel( state_size=self.env.state_size[0] if self.env.obs_mode == 'state' else self.env.encoder_feat_size, action_size=self.env.action_size, encoding_size=self.env.fm_size, lr=self.env.fm_lr, forward_model_type=self.env.forward_model_type, obs_mode=self.env.obs_mode, use_scale_dot_product=self.env.use_scale_dot_product, use_skip_connection=self.env.use_skip_connection, use_dropout=self.env.use_dropout) if self.env.obs_mode == 'pixel': if self.env.state_only: feat_in_dim = 1024 # self.env.encoder_feat_size[0] policy_input_feat = 1024 else: feat_in_dim = 1024 + self.env.action_size # self.env.encoder_feat_size[0] policy_input_feat = 1024 else: if self.env.state_only: feat_in_dim = self.env.state_size[0] policy_input_feat = self.env.state_size[0] else: feat_in_dim = self.env.state_size[0] + self.env.action_size policy_input_feat = self.env.state_size[0] self.discriminator = Discriminator( in_dim=feat_in_dim, out_dim=self.env.disc_out_dim, size=self.env.d_size, lr=self.env.d_lr, do_keep_prob=self.do_keep_prob, weight_decay=self.env.weight_decay, use_airl=self.env.use_airl, phi_hidden_size=self.env.phi_size, state_only=self.env.state_only, ) self.policy = Policy(in_dim=policy_input_feat, out_dim=self.env.action_size, size=self.env.p_size, lr=self.env.p_lr, do_keep_prob=self.do_keep_prob, n_accum_steps=self.env.policy_accum_steps, weight_decay=self.env.weight_decay) # Create experience buffers self.er_agent = ER( memory_size=self.env.er_agent_size, state_dim=self.env.state_size, action_dim=self.env.action_size, reward_dim=1, # stub connection qpos_dim=self.env.qpos_size, qvel_dim=self.env.qvel_size, batch_size=self.env.batch_size, history_length=1) self.er_expert = common.load_er(fname=os.path.join( self.env.run_dir, self.env.expert_data), batch_size=self.env.batch_size, history_length=1, traj_length=2) self.env.sigma = self.er_expert.actions_std / self.env.noise_intensity if self.env.obs_mode == 'pixel': current_states = ops.preprocess(self.states, bits=8) current_states_feat = ops.encoder(current_states, reuse=tf.AUTO_REUSE) prev_states = ops.preprocess(self.states_, bits=8) prev_states_feat = ops.encoder(prev_states, reuse=tf.AUTO_REUSE) else: # Normalize the inputs prev_states = common.normalize(self.states_, self.er_expert.states_mean, self.er_expert.states_std) current_states = common.normalize(self.states, self.er_expert.states_mean, self.er_expert.states_std) prev_states_feat = prev_states current_states_feat = current_states if self.env.continuous_actions: actions = common.normalize(self.actions, self.er_expert.actions_mean, 
self.er_expert.actions_std) else: actions = self.actions # 1. Forward Model initial_gru_state = np.ones((1, self.forward_model.encoding_size)) forward_model_prediction, _, divergence_loss = self.forward_model.forward( [prev_states_feat, actions, initial_gru_state]) if self.env.obs_mode == 'pixel': forward_model_prediction = ops.decoder( forward_model_prediction, data_shape=self.env.state_size, reuse=tf.AUTO_REUSE) self.forward_model_prediction = ops.postprocess( forward_model_prediction, bits=8, dtype=tf.uint8) else: self.forward_model_prediction = forward_model_prediction forward_model_loss = tf.reduce_mean( tf.square(current_states - forward_model_prediction) ) + self.env.forward_model_lambda * tf.reduce_mean(divergence_loss) self.forward_model.train(objective=forward_model_loss) if self.env.use_airl: # 1.1 action log prob logits = self.policy.forward(current_states_feat) if self.env.continuous_actions: mean, logstd = logits, tf.log(tf.ones_like(logits)) std = tf.exp(logstd) n_elts = tf.cast(tf.reduce_prod(mean.shape[1:]), tf.float32) # first dimension is batch size log_normalizer = n_elts / 2. * (np.log(2 * np.pi).astype( np.float32)) + 1 / 2 * tf.reduce_sum(logstd, axis=1) # Diagonal Gaussian action probability, for every action action_logprob = -tf.reduce_sum(tf.square(actions - mean) / (2 * std), axis=1) - log_normalizer else: # Override since the implementation of tfp.RelaxedOneHotCategorical # yields positive values. if actions.shape[1:] != logits.shape[1:]: actions = tf.cast(actions, tf.int8) values = tf.one_hot(actions, logits.shape.as_list()[-1], dtype=tf.float32) assert values.shape == logits.shape, (values.shape, logits.shape) else: values = actions # [0]'s implementation (see line below) seems to be an approximation # to the actual Gumbel Softmax density. # TODO: to confirm 'action' or 'value' action_logprob = -tf.reduce_sum( -values * tf.nn.log_softmax(logits, axis=-1), axis=-1) # prob = logit[np.arange(self.action_test.shape[0]), self.action_test] # action_logprob = tf.log(prob) # 2. Discriminator self.discriminator.airl_entropy_weight = self.env.airl_entropy_weight # labels = tf.concat([1 - self.label, self.label], 1) # labels = 1 - self.label # 0 for expert, 1 for policy labels = self.label # 1 for expert, 0 for policy d, self.disc_shaped_reward_output, self.disc_reward = self.discriminator.forward( state=current_states_feat, action=actions, prev_state=prev_states_feat, done_inp=self.done_ph, log_policy_act_prob=action_logprob, ) # 2.1 0-1 accuracy correct_predictions = tf.equal(tf.argmax(d, 1), tf.argmax(labels, 1)) self.discriminator.acc = tf.reduce_mean( tf.cast(correct_predictions, "float")) # 2.2 prediction d_cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits( labels=labels, logits=d, name="disc_loss", ) # Construct generator reward: # \[\hat{r}(s,a) = \log(D_{\theta}(s,a)) - \log(1 - D_{\theta}(s,a)).\] # This simplifies to: # \[\hat{r}(s,a) = f_{\theta}(s,a) - \log \pi(a \mid s).\] # This is just an entropy-regularized objective # ent_bonus = -self.env.airl_entropy_weight * self.discriminator.log_policy_act_prob_ph # policy_train_reward = self.discriminator.reward_net.reward_output_train + ent_bonus else: # 2. 
Discriminator labels = tf.concat([1 - self.label, self.label], 1) d, _, _ = self.discriminator.forward(state=current_states_feat, action=actions) # 2.1 0-1 accuracy correct_predictions = tf.equal(tf.argmax(d, 1), tf.argmax(labels, 1)) self.discriminator.acc = tf.reduce_mean( tf.cast(correct_predictions, "float")) # 2.2 prediction d_cross_entropy = tf.nn.softmax_cross_entropy_with_logits( logits=d, labels=labels) # cost sensitive weighting (weight true=expert, predict=agent mistakes) d_loss_weighted = self.env.cost_sensitive_weight * tf.multiply(tf.to_float(tf.equal(tf.squeeze(self.label), 1.)), d_cross_entropy) +\ tf.multiply(tf.to_float(tf.equal(tf.squeeze(self.label), 0.)), d_cross_entropy) discriminator_loss = tf.reduce_mean(d_loss_weighted) self.discriminator.train(objective=discriminator_loss) # 3. Collect experience mu = self.policy.forward(current_states_feat) if self.env.continuous_actions: a = common.denormalize(mu, self.er_expert.actions_mean, self.er_expert.actions_std) eta = tf.random_normal(shape=tf.shape(a), stddev=self.env.sigma) self.action_test = tf.squeeze(a + self.noise * eta) else: a = common.gumbel_softmax(logits=mu, temperature=self.temp) self.action_test = tf.argmax(a, dimension=1) # 4.3 AL def policy_loop(current_state_policy_update, t, total_cost, total_trans_err, env_term_sig, prev_state): if self.env.obs_mode == 'pixel': current_state_feat_policy_update = ops.encoder( current_state_policy_update, reuse=True) prev_state_feat_policy_update = ops.encoder(prev_state, reuse=True) else: current_state_feat_policy_update = current_state_policy_update prev_state_feat_policy_update = prev_state mu = self.policy.forward(current_state_feat_policy_update, reuse=True) if self.env.continuous_actions: eta = self.env.sigma * tf.random_normal(shape=tf.shape(mu)) action = mu + eta if self.env.use_airl: mean, logstd = mu, tf.log( tf.ones_like(mu) * self.env.sigma) std = tf.exp(logstd) n_elts = tf.cast( tf.reduce_prod(mean.shape[1:]), tf.float32) # first dimension is batch size log_normalizer = n_elts / 2. * (np.log(2 * np.pi).astype( np.float32)) + 1 / 2 * tf.reduce_sum(logstd, axis=1) # Diagonal Gaussian action probability, for every action action_logprob = -tf.reduce_sum(tf.square(action - mean) / (2 * std), axis=1) - log_normalizer else: action = common.gumbel_softmax_sample(logits=mu, temperature=self.temp) if self.env.use_airl: # Override since the implementation of tfp.RelaxedOneHotCategorical # yields positive values. if action.shape[1:] != logits.shape[1:]: actions = tf.cast(action, tf.int8) values = tf.one_hot(actions, logits.shape.as_list()[-1], dtype=tf.float32) assert values.shape == logits.shape, (values.shape, logits.shape) else: values = action # [0]'s implementation (see line below) seems to be an approximation # to the actual Gumbel Softmax density. 
# TODO: to confirm 'action' or 'value' action_logprob = -tf.reduce_sum( -values * tf.nn.log_softmax(logits, axis=-1), axis=-1) # minimize the gap between agent logit (d[:,0]) and expert logit (d[:,1]) if self.env.use_airl: d, shaped_reward_output, reward = self.discriminator.forward( state=current_state_feat_policy_update, action=action, prev_state=prev_state_feat_policy_update, done_inp=tf.cast(env_term_sig, tf.float32), log_policy_act_prob=action_logprob, reuse=True) if self.env.alg in ['mairlTransfer', 'mairlImit4Transfer']: reward_for_updating_policy = reward else: # 'mairlImit' reward_for_updating_policy = shaped_reward_output if self.env.train_mode and not self.env.alg in [ 'mairlTransfer', 'mairlImit4Transfer' ]: ent_bonus = -self.env.airl_entropy_weight * tf.stop_gradient( action_logprob) policy_reward = reward_for_updating_policy + ent_bonus else: policy_reward = reward_for_updating_policy cost = tf.reduce_mean(-policy_reward) * self.env.policy_al_w else: d, _, _ = self.discriminator.forward( state=current_state_feat_policy_update, action=action, reuse=True) cost = self.al_loss(d) # add step cost total_cost += tf.multiply(tf.pow(self.gamma, t), cost) # get action if self.env.continuous_actions: a_sim = common.denormalize(action, self.er_expert.actions_mean, self.er_expert.actions_std) else: a_sim = tf.argmax(action, dimension=1) # get next state state_env, _, env_term_sig, = self.env.step(a_sim, mode='tensorflow')[:3] state_e = common.normalize(state_env, self.er_expert.states_mean, self.er_expert.states_std) state_e = tf.stop_gradient(state_e) state_a, _, divergence_loss_a = self.forward_model.forward( [current_state_feat_policy_update, action, initial_gru_state], reuse=True) if self.env.obs_mode == 'pixel': state_a = ops.decoder(state_a, data_shape=self.env.state_size, reuse=True) if True: # self.env.alg in ['mgail']: state, nu = common.re_parametrization(state_e=state_e, state_a=state_a) else: _, nu = common.re_parametrization(state_e=state_e, state_a=state_a) state = state_a total_trans_err += tf.reduce_mean(abs(nu)) t += 1 if self.env.obs_mode == 'pixel': state = tf.slice(state, [0, 0, 0, 0], [1, -1, -1, -1]) return state, t, total_cost, total_trans_err, env_term_sig, current_state_policy_update def policy_stop_condition(current_state_policy_update, t, cost, trans_err, env_term_sig, prev_state): cond = tf.logical_not( env_term_sig) # not done: env_term_sig = False cond = tf.logical_and(cond, t < self.env.n_steps_train) cond = tf.logical_and(cond, trans_err < self.env.total_trans_err_allowed) return cond if self.env.obs_mode == 'pixel': state_0 = tf.slice(current_states, [0, 0, 0, 0], [1, -1, -1, -1]) else: state_0 = tf.slice(current_states, [0, 0], [1, -1]) # prev_state_0 = tf.slice(states_, [0, 0], [1, -1]) loop_outputs = tf.while_loop(policy_stop_condition, policy_loop, [state_0, 0., 0., 0., False, state_0]) self.policy.train(objective=loop_outputs[2]) def al_loss(self, d): logit_agent, logit_expert = tf.split(axis=1, num_or_size_splits=2, value=d) # Cross entropy loss labels = tf.concat( [tf.zeros_like(logit_agent), tf.ones_like(logit_expert)], 1) d_cross_entropy = tf.nn.softmax_cross_entropy_with_logits( logits=d, labels=labels) loss = tf.reduce_mean(d_cross_entropy) return loss * self.env.policy_al_w
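# ----------------------------------------------------------------------
# Reference for the action_logprob expressions above: the log-density of
# a diagonal Gaussian N(mu, diag(std^2)) is
#     log p(a) = -0.5 * sum((a - mu)^2 / std^2)
#                - sum(log std) - 0.5 * n * log(2 * pi)
# The NumPy helper below is illustrative only and independent of the
# TensorFlow graph built above.
# ----------------------------------------------------------------------
import numpy as np


def diag_gaussian_logprob(a, mu, std):
    n = a.shape[-1]
    quad = np.sum((a - mu) ** 2 / std ** 2, axis=-1)
    return -0.5 * quad - np.sum(np.log(std), axis=-1) - 0.5 * n * np.log(2 * np.pi)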
class trainer(object): def __init__(self, cfg): self.cfg = cfg self.OldLabel_generator = U_Net(in_ch=cfg.DATASET.N_CLASS, out_ch=cfg.DATASET.N_CLASS, side='out') self.Image_generator = U_Net(in_ch=3, out_ch=cfg.DATASET.N_CLASS, side='in') self.discriminator = Discriminator(cfg.DATASET.N_CLASS + 3, cfg.DATASET.IMGSIZE, patch=True) self.criterion_G = GeneratorLoss(cfg.LOSS.LOSS_WEIGHT[0], cfg.LOSS.LOSS_WEIGHT[1], cfg.LOSS.LOSS_WEIGHT[2], ignore_index=cfg.LOSS.IGNORE_INDEX) self.criterion_D = DiscriminatorLoss() train_dataset = BaseDataset(cfg, split='train') valid_dataset = BaseDataset(cfg, split='val') self.train_dataloader = data.DataLoader( train_dataset, batch_size=cfg.DATASET.BATCHSIZE, num_workers=8, shuffle=True, drop_last=True) self.valid_dataloader = data.DataLoader( valid_dataset, batch_size=cfg.DATASET.BATCHSIZE, num_workers=8, shuffle=True, drop_last=True) self.ckpt_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints') if not os.path.isdir(self.ckpt_outdir): os.mkdir(self.ckpt_outdir) self.val_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'val') if not os.path.isdir(self.val_outdir): os.mkdir(self.val_outdir) self.start_epoch = cfg.TRAIN.RESUME self.n_epoch = cfg.TRAIN.N_EPOCH self.optimizer_G = torch.optim.Adam( [{ 'params': self.OldLabel_generator.parameters() }, { 'params': self.Image_generator.parameters() }], lr=cfg.OPTIMIZER.G_LR, betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2), # betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2), weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY) self.optimizer_D = torch.optim.Adam( [{ 'params': self.discriminator.parameters(), 'initial_lr': cfg.OPTIMIZER.D_LR }], lr=cfg.OPTIMIZER.D_LR, betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2), # betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2), weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY) iter_per_epoch = len(train_dataset) // cfg.DATASET.BATCHSIZE lambda_poly = lambda iters: pow( (1.0 - iters / (cfg.TRAIN.N_EPOCH * iter_per_epoch)), 0.9) self.scheduler_G = torch.optim.lr_scheduler.LambdaLR( self.optimizer_G, lr_lambda=lambda_poly, ) # last_epoch=(self.start_epoch+1)*iter_per_epoch) self.scheduler_D = torch.optim.lr_scheduler.LambdaLR( self.optimizer_D, lr_lambda=lambda_poly, ) # last_epoch=(self.start_epoch+1)*iter_per_epoch) self.logger = logger(cfg.TRAIN.OUTDIR, name='train') self.running_metrics = runningScore(n_classes=cfg.DATASET.N_CLASS) if self.start_epoch >= 0: self.OldLabel_generator.load_state_dict( torch.load( os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints', '{}epoch.pth'.format( self.start_epoch)))['model_G_N']) self.Image_generator.load_state_dict( torch.load( os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints', '{}epoch.pth'.format( self.start_epoch)))['model_G_I']) self.discriminator.load_state_dict( torch.load( os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints', '{}epoch.pth'.format( self.start_epoch)))['model_D']) self.optimizer_G.load_state_dict( torch.load( os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints', '{}epoch.pth'.format( self.start_epoch)))['optimizer_G']) self.optimizer_D.load_state_dict( torch.load( os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints', '{}epoch.pth'.format( self.start_epoch)))['optimizer_D']) log = "Using the {}th checkpoint".format(self.start_epoch) self.logger.info(log) self.Image_generator = self.Image_generator.cuda() self.OldLabel_generator = self.OldLabel_generator.cuda() self.discriminator = self.discriminator.cuda() self.criterion_G = self.criterion_G.cuda() self.criterion_D = self.criterion_D.cuda() def train(self): all_train_iter_total_loss = [] all_train_iter_corr_loss = [] 
all_train_iter_recover_loss = [] all_train_iter_change_loss = [] all_train_iter_gan_loss_gen = [] all_train_iter_gan_loss_dis = [] all_val_epo_iou = [] all_val_epo_acc = [] iter_num = [0] epoch_num = [] num_batches = len(self.train_dataloader) for epoch_i in range(self.start_epoch + 1, self.n_epoch): iter_total_loss = AverageTracker() iter_corr_loss = AverageTracker() iter_recover_loss = AverageTracker() iter_change_loss = AverageTracker() iter_gan_loss_gen = AverageTracker() iter_gan_loss_dis = AverageTracker() batch_time = AverageTracker() tic = time.time() # train self.OldLabel_generator.train() self.Image_generator.train() self.discriminator.train() for i, meta in enumerate(self.train_dataloader): image, old_label, new_label = meta[0].cuda(), meta[1].cuda( ), meta[2].cuda() recover_pred, feats = self.OldLabel_generator( label2onehot(old_label, self.cfg.DATASET.N_CLASS)) corr_pred = self.Image_generator(image, feats) # ------------------- # Train Discriminator # ------------------- self.discriminator.set_requires_grad(True) self.optimizer_D.zero_grad() fake_sample = torch.cat((image, corr_pred), 1).detach() real_sample = torch.cat( (image, label2onehot(new_label, cfg.DATASET.N_CLASS)), 1) score_fake_d = self.discriminator(fake_sample) score_real = self.discriminator(real_sample) gan_loss_dis = self.criterion_D(pred_score=score_fake_d, real_score=score_real) gan_loss_dis.backward() self.optimizer_D.step() self.scheduler_D.step() # --------------- # Train Generator # --------------- self.discriminator.set_requires_grad(False) self.optimizer_G.zero_grad() score_fake = self.discriminator( torch.cat((image, corr_pred), 1)) total_loss, corr_loss, recover_loss, change_loss, gan_loss_gen = self.criterion_G( corr_pred, recover_pred, score_fake, old_label, new_label) total_loss.backward() self.optimizer_G.step() self.scheduler_G.step() iter_total_loss.update(total_loss.item()) iter_corr_loss.update(corr_loss.item()) iter_recover_loss.update(recover_loss.item()) iter_change_loss.update(change_loss.item()) iter_gan_loss_gen.update(gan_loss_gen.item()) iter_gan_loss_dis.update(gan_loss_dis.item()) batch_time.update(time.time() - tic) tic = time.time() log = '{}: Epoch: [{}][{}/{}], Time: {:.2f}, ' \ 'Total Loss: {:.6f}, Corr Loss: {:.6f}, Recover Loss: {:.6f}, Change Loss: {:.6f}, GAN_G Loss: {:.6f}, GAN_D Loss: {:.6f}'.format( datetime.now(), epoch_i, i, num_batches, batch_time.avg, total_loss.item(), corr_loss.item(), recover_loss.item(), change_loss.item(), gan_loss_gen.item(), gan_loss_dis.item()) print(log) if (i + 1) % 10 == 0: all_train_iter_total_loss.append(iter_total_loss.avg) all_train_iter_corr_loss.append(iter_corr_loss.avg) all_train_iter_recover_loss.append(iter_recover_loss.avg) all_train_iter_change_loss.append(iter_change_loss.avg) all_train_iter_gan_loss_gen.append(iter_gan_loss_gen.avg) all_train_iter_gan_loss_dis.append(iter_gan_loss_dis.avg) iter_total_loss.reset() iter_corr_loss.reset() iter_recover_loss.reset() iter_change_loss.reset() iter_gan_loss_gen.reset() iter_gan_loss_dis.reset() vis.line(X=np.column_stack( np.repeat(np.expand_dims(iter_num, 0), 6, axis=0)), Y=np.column_stack((all_train_iter_total_loss, all_train_iter_corr_loss, all_train_iter_recover_loss, all_train_iter_change_loss, all_train_iter_gan_loss_gen, all_train_iter_gan_loss_dis)), opts={ 'legend': [ 'total_loss', 'corr_loss', 'recover_loss', 'change_loss', 'gan_loss_gen', 'gan_loss_dis' ], 'linecolor': np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0], [0, 255, 255], [255, 0, 255]]), 'title': 
'Train loss of generator and discriminator' }, win='Train loss of generator and discriminator') iter_num.append(iter_num[-1] + 1) # eval self.OldLabel_generator.eval() self.Image_generator.eval() self.discriminator.eval() with torch.no_grad(): for j, meta in enumerate(self.valid_dataloader): image, old_label, new_label = meta[0].cuda(), meta[1].cuda( ), meta[2].cuda() recover_pred, feats = self.OldLabel_generator( label2onehot(old_label, self.cfg.DATASET.N_CLASS)) corr_pred = self.Image_generator(image, feats) preds = np.argmax(corr_pred.cpu().detach().numpy().copy(), axis=1) target = new_label.cpu().detach().numpy().copy() self.running_metrics.update(target, preds) if j == 0: color_map1 = gen_color_map(preds[0, :]).astype( np.uint8) color_map2 = gen_color_map(preds[1, :]).astype( np.uint8) color_map = cv2.hconcat([color_map1, color_map2]) cv2.imwrite( os.path.join( self.val_outdir, '{}epoch*{}*{}.png'.format( epoch_i, meta[3][0], meta[3][1])), color_map) score = self.running_metrics.get_scores() oa = score['Overall Acc: \t'] precision = score['Precision: \t'][1] recall = score['Recall: \t'][1] iou = score['Class IoU: \t'][1] miou = score['Mean IoU: \t'] self.running_metrics.reset() epoch_num.append(epoch_i) all_val_epo_acc.append(oa) all_val_epo_iou.append(miou) vis.line(X=np.column_stack( np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)), Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)), opts={ 'legend': ['val epoch Overall Acc', 'val epoch Mean IoU'], 'linecolor': np.array([[255, 0, 0], [0, 255, 0]]), 'title': 'Validate Accuracy and IoU' }, win='validate Accuracy and IoU') log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \ .format(datetime.now(), epoch_i, oa, recall, miou) self.logger.info(log) state = { 'epoch': epoch_i, "acc": oa, "recall": recall, "iou": miou, 'model_G_N': self.OldLabel_generator.state_dict(), 'model_G_I': self.Image_generator.state_dict(), 'model_D': self.discriminator.state_dict(), 'optimizer_G': self.optimizer_G.state_dict(), 'optimizer_D': self.optimizer_D.state_dict() } save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints', '{}epoch.pth'.format(epoch_i)) torch.save(state, save_path)
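# The trainer above repeatedly calls `label2onehot(label, N_CLASS)` to turn integer label
# maps into one-hot channel volumes for the U-Net generators and the discriminator, but the
# helper is not shown in this excerpt (note also that the real-sample construction in the
# training loop references `cfg.DATASET.N_CLASS` where `self.cfg.DATASET.N_CLASS` is
# presumably intended). A minimal sketch of the helper, assuming labels of shape (B, H, W)
# holding integer class ids:
import torch
import torch.nn.functional as F

def label2onehot(label, num_classes):
    """Convert an integer label map (B, H, W) into a one-hot tensor (B, C, H, W)."""
    onehot = F.one_hot(label.long(), num_classes=num_classes)  # (B, H, W, C)
    return onehot.permute(0, 3, 1, 2).float()                  # move classes to channel axis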
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# fixed noise for tracking generator progress in TensorBoard
fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    for batch_idx, real in enumerate(loader):
        real = real.to(device)
        noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
        fake = gen(noise)

        # train discriminator: maximize log(D(real)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward(retain_graph=True)
        opt_disc.step()
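        # The excerpt ends after the discriminator update; a DCGAN loop of this shape
        # typically continues with a generator step and periodic TensorBoard logging on
        # `fixed_noise`. The block below is a hedged sketch of that continuation (the
        # logging cadence and the use of torchvision.utils.make_grid are assumptions).
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))  # maximize log(D(G(z)))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if batch_idx % 100 == 0:
            with torch.no_grad():
                fake_fixed = gen(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake_fixed[:32], normalize=True)
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
            step += 1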
class CycleGAN(AlignmentModel): """This class implements the alignment model for GAN networks with two generators and two discriminators (cycle GAN). For description of the implemented functions, refer to the alignment model.""" def __init__(self, device, config, generator_a=None, generator_b=None, discriminator_a=None, discriminator_b=None): """Initialize two new generators and two discriminators from the config or use pre-trained ones and create Adam optimizers for all models.""" super().__init__(device, config) self.epoch_losses = [0., 0., 0., 0.] if generator_a is None: generator_a_conf = dict( dim_1=config['dim_b'], dim_2=config['dim_a'], layer_number=config['generator_layers'], layer_expansion=config['generator_expansion'], initialize_generator=config['initialize_generator'], norm=config['gen_norm'], batch_norm=config['gen_batch_norm'], activation=config['gen_activation'], dropout=config['gen_dropout']) self.generator_a = Generator(generator_a_conf, device) self.generator_a.to(device) else: self.generator_a = generator_a if 'optimizer' in config: self.optimizer_g_a = OPTIMIZERS[config['optimizer']]( self.generator_a.parameters(), config['learning_rate']) elif 'optimizer_default' in config: if config['optimizer_default'] == 'sgd': self.optimizer_g_a = OPTIMIZERS[config['optimizer_default']]( self.generator_a.parameters(), config['learning_rate']) else: self.optimizer_g_a = OPTIMIZERS[config['optimizer_default']]( self.generator_a.parameters()) else: self.optimizer_g_a = torch.optim.Adam( self.generator_a.parameters(), config['learning_rate']) if generator_b is None: generator_b_conf = dict( dim_1=config['dim_a'], dim_2=config['dim_b'], layer_number=config['generator_layers'], layer_expansion=config['generator_expansion'], initialize_generator=config['initialize_generator'], norm=config['gen_norm'], batch_norm=config['gen_batch_norm'], activation=config['gen_activation'], dropout=config['gen_dropout']) self.generator_b = Generator(generator_b_conf, device) self.generator_b.to(device) else: self.generator_b = generator_b if 'optimizer' in config: self.optimizer_g_b = OPTIMIZERS[config['optimizer']]( self.generator_b.parameters(), config['learning_rate']) elif 'optimizer_default' in config: if config['optimizer_default'] == 'sgd': self.optimizer_g_b = OPTIMIZERS[config['optimizer_default']]( self.generator_b.parameters(), config['learning_rate']) else: self.optimizer_g_b = OPTIMIZERS[config['optimizer_default']]( self.generator_b.parameters()) else: self.optimizer_g_b = torch.optim.Adam( self.generator_b.parameters(), config['learning_rate']) if discriminator_a is None: discriminator_a_conf = dict( dim=config['dim_a'], layer_number=config['discriminator_layers'], layer_expansion=config['discriminator_expansion'], batch_norm=config['disc_batch_norm'], activation=config['disc_activation'], dropout=config['disc_dropout']) self.discriminator_a = Discriminator(discriminator_a_conf, device) self.discriminator_a.to(device) else: self.discriminator_a = discriminator_a if 'optimizer' in config: self.optimizer_d_a = OPTIMIZERS[config['optimizer']]( self.discriminator_a.parameters(), config['learning_rate']) elif 'optimizer_default' in config: if config['optimizer_default'] == 'sgd': self.optimizer_d_a = OPTIMIZERS[config['optimizer_default']]( self.discriminator_a.parameters(), config['learning_rate']) else: self.optimizer_d_a = OPTIMIZERS[config['optimizer_default']]( self.discriminator_a.parameters()) else: self.optimizer_d_a = torch.optim.Adam( self.discriminator_a.parameters(), 
config['learning_rate']) if discriminator_b is None: discriminator_b_conf = dict( dim=config['dim_b'], layer_number=config['discriminator_layers'], layer_expansion=config['discriminator_expansion'], batch_norm=config['disc_batch_norm'], activation=config['disc_activation'], dropout=config['disc_dropout']) self.discriminator_b = Discriminator(discriminator_b_conf, device) self.discriminator_b.to(device) else: self.discriminator_b = discriminator_b if 'optimizer' in config: self.optimizer_d_b = OPTIMIZERS[config['optimizer']]( self.discriminator_b.parameters(), config['learning_rate']) elif 'optimizer_default' in config: if config['optimizer_default'] == 'sgd': self.optimizer_d_b = OPTIMIZERS[config['optimizer_default']]( self.discriminator_b.parameters(), config['learning_rate']) else: self.optimizer_d_b = OPTIMIZERS[config['optimizer_default']]( self.discriminator_b.parameters()) else: self.optimizer_d_b = torch.optim.Adam( self.discriminator_b.parameters(), config['learning_rate']) def train(self): self.generator_a.train() self.generator_b.train() self.discriminator_a.train() self.discriminator_b.train() def eval(self): self.generator_a.eval() self.generator_b.eval() self.discriminator_a.eval() self.discriminator_b.eval() def zero_grad(self): self.optimizer_g_a.zero_grad() self.optimizer_g_b.zero_grad() self.optimizer_d_a.zero_grad() self.optimizer_d_b.zero_grad() def optimize_all(self): self.optimizer_g_a.step() self.optimizer_g_b.step() self.optimizer_d_a.step() self.optimizer_d_b.step() def optimize_generator(self): """Do the optimization step only for generators (e.g. when training generators and discriminators separately or in turns).""" self.optimizer_g_a.step() self.optimizer_g_b.step() def optimize_discriminator(self): """Do the optimization step only for discriminators (e.g. when training generators and discriminators separately or in turns).""" self.optimizer_d_a.step() self.optimizer_d_b.step() def change_lr(self, factor): self.current_lr = self.current_lr * factor for param_group in self.optimizer_g_a.param_groups: param_group['lr'] = self.current_lr for param_group in self.optimizer_g_b.param_groups: param_group['lr'] = self.current_lr def update_losses_batch(self, *losses): loss_g_a, loss_g_b, loss_d_a, loss_d_b = losses self.epoch_losses[0] += loss_g_a self.epoch_losses[1] += loss_g_b self.epoch_losses[2] += loss_d_a self.epoch_losses[3] += loss_d_b def complete_epoch(self, epoch_metrics): self.metrics.append(epoch_metrics + [sum(self.epoch_losses)]) self.losses.append(self.epoch_losses) self.epoch_losses = [0., 0., 0., 0.] 
def print_epoch_info(self): print( f"{len(self.metrics)} ### {self.losses[-1][0]:.2f} - {self.losses[-1][1]:.2f} " f"- {self.losses[-1][2]:.2f} - {self.losses[-1][3]:.2f} ### {self.metrics[-1]}" ) def copy_model(self): self.model_copy = deepcopy(self.generator_a.state_dict()), deepcopy(self.generator_b.state_dict()),\ deepcopy(self.discriminator_a.state_dict()), deepcopy(self.discriminator_b.state_dict()) def restore_model(self): self.generator_a.load_state_dict(self.model_copy[0]) self.generator_b.load_state_dict(self.model_copy[1]) self.discriminator_a.load_state_dict(self.model_copy[2]) self.discriminator_b.load_state_dict(self.model_copy[3]) def export_model(self, test_results, description=None): if description is None: description = f"CycleGAN_{self.config['evaluation']}_{self.config['subset']}" export_cyclegan_alignment(description, self.config, self.generator_a, self.generator_b, self.discriminator_a, self.discriminator_b, self.metrics) save_alignment_test_results(test_results, description) print(f"Saved model to directory {description}.") @classmethod def load_model(cls, name, device): generator_a, generator_b, discriminator_a, discriminator_b, config = load_cyclegan_alignment( name, device) model = cls(device, config, generator_a, generator_b, discriminator_a, discriminator_b) return model
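# The CycleGAN wrapper above only manages the four models, their optimizers, and loss
# bookkeeping; the actual cycle-consistency and adversarial losses are computed by an
# external training loop. Below is a minimal usage sketch of one alternating batch update
# with this interface; `generator_losses` and `discriminator_losses` are hypothetical
# helpers used only for illustration.
def train_one_batch(model, batch_a, batch_b):
    """One alternating generator/discriminator update via the wrapper's optimizer API."""
    model.train()

    # generator update (only the generator optimizers are stepped)
    model.zero_grad()
    loss_g_a, loss_g_b = generator_losses(model, batch_a, batch_b)      # hypothetical
    (loss_g_a + loss_g_b).backward()
    model.optimize_generator()

    # discriminator update
    model.zero_grad()
    loss_d_a, loss_d_b = discriminator_losses(model, batch_a, batch_b)  # hypothetical
    (loss_d_a + loss_d_b).backward()
    model.optimize_discriminator()

    model.update_losses_batch(loss_g_a.item(), loss_g_b.item(),
                              loss_d_a.item(), loss_d_b.item())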
class Trainer: def __init__(self, corpus_data_0, corpus_data_1, *, params, n_samples=10000000): self.fast_text = [FastText(corpus_data_0.model).to(GPU), FastText(corpus_data_1.model).to(GPU)] self.discriminator = Discriminator(params.emb_dim, n_layers=params.d_n_layers, n_units=params.d_n_units, drop_prob=params.d_drop_prob, drop_prob_input=params.d_drop_prob_input, leaky=params.d_leaky, batch_norm=params.d_bn).to(GPU) self.mapping = nn.Linear(params.emb_dim, params.emb_dim, bias=False) self.mapping.weight.data.copy_(torch.diag(torch.ones(params.emb_dim))) self.mapping = self.mapping.to(GPU) self.ft_optimizer, self.ft_scheduler = [], [] for id in [0, 1]: optimizer, scheduler = optimizers.get_sgd_adapt(self.fast_text[id].parameters(), lr=params.ft_lr, mode="max", factor=params.ft_lr_decay, patience=params.ft_lr_patience) self.ft_optimizer.append(optimizer) self.ft_scheduler.append(scheduler) self.a_optimizer, self.a_scheduler = [], [] for id in [0, 1]: optimizer, scheduler = optimizers.get_sgd_adapt( [{"params": self.fast_text[id].u.parameters()}, {"params": self.fast_text[id].v.parameters()}], lr=params.a_lr, mode="max", factor=params.a_lr_decay, patience=params.a_lr_patience) self.a_optimizer.append(optimizer) self.a_scheduler.append(scheduler) if params.d_optimizer == "SGD": self.d_optimizer, self.d_scheduler = optimizers.get_sgd_adapt(self.discriminator.parameters(), lr=params.d_lr, mode="max", wd=params.d_wd) elif params.d_optimizer == "RMSProp": self.d_optimizer, self.d_scheduler = optimizers.get_rmsprop_linear(self.discriminator.parameters(), params.n_steps, lr=params.d_lr, wd=params.d_wd) else: raise Exception(f"Optimizer {params.d_optimizer} not found.") if params.m_optimizer == "SGD": self.m_optimizer, self.m_scheduler = optimizers.get_sgd_adapt(self.mapping.parameters(), lr=params.m_lr, mode="max", wd=params.m_wd, factor=params.m_lr_decay, patience=params.m_lr_patience) elif params.m_optimizer == "RMSProp": self.m_optimizer, self.m_scheduler = optimizers.get_rmsprop_linear(self.mapping.parameters(), params.n_steps, lr=params.m_lr, wd=params.m_wd) else: raise Exception(f"Optimizer {params.m_optimizer} not found") self.m_beta = params.m_beta self.smooth = params.smooth self.wgan = params.wgan self.d_clip_mode = params.d_clip_mode if params.wgan: self.loss_fn = _wasserstein_distance else: self.loss_fn = nn.BCEWithLogitsLoss(reduction="elementwise_mean") self.corpus_data_queue = [ _data_queue(corpus_data_0, n_threads=(params.n_threads + 1) // 2, n_sentences=params.n_sentences, batch_size=params.ft_bs), _data_queue(corpus_data_1, n_threads=(params.n_threads + 1) // 2, n_sentences=params.n_sentences, batch_size=params.ft_bs) ] self.sampler = [ WordSampler(corpus_data_0.dic, n_urns=n_samples, alpha=params.a_sample_factor, top=params.a_sample_top), WordSampler(corpus_data_1.dic, n_urns=n_samples, alpha=params.a_sample_factor, top=params.a_sample_top)] self.d_bs = params.d_bs self.dic_0, self.dic_1 = corpus_data_0.dic, corpus_data_1.dic self.d_gp = params.d_gp def fast_text_step(self): losses = [] for id in [0, 1]: self.ft_optimizer[id].zero_grad() u_b, v_b = self.corpus_data_queue[id].__next__() s = self.fast_text[id](u_b, v_b) loss = FastText.loss_fn(s) loss.backward() self.ft_optimizer[id].step() losses.append(loss.item()) return losses[0], losses[1] def get_adv_batch(self, *, reverse, fix_embedding=False, gp=False): batch = [[self.sampler[id].sample() for _ in range(self.d_bs)] for id in [0, 1]] batch = [self.fast_text[id].model.get_bag(batch[id], self.fast_text[id].u.weight.device) 
for id in [0, 1]] if fix_embedding: with torch.no_grad(): x = [self.fast_text[id].u(batch[id][0], batch[id][1]).view(self.d_bs, -1) for id in [0, 1]] else: x = [self.fast_text[id].u(batch[id][0], batch[id][1]).view(self.d_bs, -1) for id in [0, 1]] y = torch.FloatTensor(self.d_bs * 2).to(GPU).uniform_(0.0, self.smooth) if reverse: y[: self.d_bs] = 1 - y[: self.d_bs] else: y[self.d_bs:] = 1 - y[self.d_bs:] x[0] = self.mapping(x[0]) if gp: t = torch.FloatTensor(self.d_bs, 1).to(GPU).uniform_(0.0, 1.0).expand_as(x[0]) z = x[0] * t + x[1] * (1.0 - t) x = torch.cat(x, 0) return x, y, z else: x = torch.cat(x, 0) return x, y def adversarial_step(self, fix_embedding=False): for id in [0, 1]: self.a_optimizer[id].zero_grad() self.m_optimizer.zero_grad() self.discriminator.eval() x, y = self.get_adv_batch(reverse=True, fix_embedding=fix_embedding) y_hat = self.discriminator(x) loss = self.loss_fn(y_hat, y) loss.backward() for id in [0, 1]: self.a_optimizer[id].step() self.m_optimizer.step() _orthogonalize(self.mapping, self.m_beta) return loss.item() def discriminator_step(self): self.d_optimizer.zero_grad() self.discriminator.train() with torch.no_grad(): if self.d_gp > 0: x, y, z = self.get_adv_batch(reverse=False, gp=True) else: x, y = self.get_adv_batch(reverse=False) z = None y_hat = self.discriminator(x) loss = self.loss_fn(y_hat, y) if self.d_gp > 0: z.requires_grad_() z_out = self.discriminator(z) g = autograd.grad(z_out, z, grad_outputs=torch.ones_like(z_out, device=GPU), retain_graph=True, create_graph=True, only_inputs=True)[0] gp = torch.mean((g.norm(p=2, dim=1) - 1.0) ** 2) loss += self.d_gp * gp loss.backward() self.d_optimizer.step() if self.wgan: self.discriminator.clip_weights(self.d_clip_mode) return loss.item() def scheduler_step(self, metric): for id in [0, 1]: self.ft_scheduler[id].step(metric) self.a_scheduler[id].step(metric) # self.d_scheduler.step(metric) self.m_scheduler.step(metric)
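# `adversarial_step` above finishes by calling `_orthogonalize(self.mapping, self.m_beta)`,
# which is not included in this excerpt. In MUSE-style embedding alignment this is usually
# the iterative update W <- (1 + beta) * W - beta * (W W^T) W that keeps the linear mapping
# close to orthogonal; a minimal sketch under that assumption:
import torch

def _orthogonalize(mapping, beta):
    """Pull the square mapping weight toward the orthogonal manifold (no-op if beta == 0)."""
    if beta > 0:
        W = mapping.weight.data
        W.copy_((1 + beta) * W - beta * W.mm(W.t()).mm(W))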
class BigGAN(): """Big GAN""" def __init__(self, device, dataloader, num_classes, configs): self.device = device self.dataloader = dataloader self.num_classes = num_classes # model settings & hyperparams # self.total_steps = configs.total_steps self.epochs = configs.epochs self.d_iters = configs.d_iters self.g_iters = configs.g_iters self.batch_size = configs.batch_size self.imsize = configs.imsize self.nz = configs.nz self.ngf = configs.ngf self.ndf = configs.ndf self.g_lr = configs.g_lr self.d_lr = configs.d_lr self.beta1 = configs.beta1 self.beta2 = configs.beta2 # instance noise self.inst_noise_sigma = configs.inst_noise_sigma self.inst_noise_sigma_iters = configs.inst_noise_sigma_iters # model logging and saving self.log_step = configs.log_step self.save_epoch = configs.save_epoch self.model_path = configs.model_path self.sample_path = configs.sample_path # pretrained self.pretrained_model = configs.pretrained_model # building self.build_model() # archive of all losses self.ave_d_losses = [] self.ave_d_losses_real = [] self.ave_d_losses_fake = [] self.ave_g_losses = [] if self.pretrained_model: self.load_pretrained() def build_model(self): """Initiate Generator and Discriminator""" self.G = Generator(self.nz, self.ngf, self.num_classes).to(self.device) self.D = Discriminator(self.ndf, self.num_classes).to(self.device) self.g_optimizer = optim.Adam( filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr, [self.beta1, self.beta2]) self.d_optimizer = optim.Adam( filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr, [self.beta1, self.beta2]) print("Generator Parameters: ", parameters(self.G)) print(self.G) print("Discriminator Parameters: ", parameters(self.D)) print(self.D) print("Number of classes: ", self.num_classes) def load_pretrained(self): """Loading pretrained model""" checkpoint = torch.load( os.path.join(self.model_path, "{}_biggan.pth".format(self.pretrained_model))) # load models self.G.load_state_dict(checkpoint["g_state_dict"]) self.D.load_state_dict(checkpoint["d_state_dict"]) # load optimizers self.g_optimizer.load_state_dict(checkpoint["g_optimizer"]) self.d_optimizer.load_state_dict(checkpoint["d_optimizer"]) # load losses self.ave_d_losses = checkpoint["ave_d_losses"] self.ave_d_losses_real = checkpoint["ave_d_losses_real"] self.ave_d_losses_fake = checkpoint["ave_d_losses_fake"] self.ave_g_losses = checkpoint["ave_g_losses"] print("Loading pretrained models (epoch: {})..!".format( self.pretrained_model)) def reset_grad(self): """Reset gradients""" self.g_optimizer.zero_grad() self.d_optimizer.zero_grad() def train(self): """Train model""" step_per_epoch = len(self.dataloader) epochs = self.epochs total_steps = epochs * step_per_epoch # fixed z and labels for sampling generator images fixed_z = tensor2var(torch.randn(self.batch_size, self.nz), device=self.device) fixed_labels = tensor2var(torch.from_numpy( np.tile(np.arange(self.num_classes), self.batch_size)).long(), device=self.device) print("Initiating Training") print("Epochs: {}, Total Steps: {}, Steps/Epoch: {}".format( epochs, total_steps, step_per_epoch)) if self.pretrained_model: start_epoch = self.pretrained_model else: start_epoch = 0 self.D.train() self.G.train() # Instance noise - make random noise mean (0) and std for injecting inst_noise_mean = torch.full( (self.batch_size, 3, self.imsize, self.imsize), 0).to(self.device) inst_noise_std = torch.full( (self.batch_size, 3, self.imsize, self.imsize), self.inst_noise_sigma).to(self.device) # total time start_time = time.time() for 
epoch in range(start_epoch, epochs): # local losses d_losses = [] d_losses_real = [] d_losses_fake = [] g_losses = [] data_iter = iter(self.dataloader) for step in range(step_per_epoch): # Instance noise std is linearly annealed from self.inst_noise_sigma to 0 thru self.inst_noise_sigma_iters inst_noise_sigma_curr = 0 if step > self.inst_noise_sigma_iters else ( 1 - step / self.inst_noise_sigma_iters) * self.inst_noise_sigma inst_noise_std.fill_(inst_noise_sigma_curr) # get real images real_images, real_labels = next(data_iter) real_images = real_images.to(self.device) real_labels = real_labels.to(self.device) # ================== TRAIN DISCRIMINATOR ================== # for _ in range(self.d_iters): self.reset_grad() # TRAIN REAL # creating instance noise inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to( self.device) # adding noise to real images d_real = self.D(real_images + inst_noise, real_labels) d_loss_real = loss_hinge_dis_real(d_real) d_loss_real.backward() # delete loss if (step + 1) % self.log_step != 0: del d_real, d_loss_real # TRAIN FAKE # create fake images using latent vector z = tensor2var(torch.randn(real_images.size(0), self.nz), device=self.device) fake_images = self.G(z, real_labels) # creating instance noise inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to( self.device) # adding noise to fake images # detach fake_images tensor from graph d_fake = self.D(fake_images.detach() + inst_noise, real_labels) d_loss_fake = loss_hinge_dis_fake(d_fake) d_loss_fake.backward() # delete loss, output del fake_images if (step + 1) % self.log_step != 0: del d_fake, d_loss_fake # optimize D self.d_optimizer.step() # ================== TRAIN GENERATOR ================== # for _ in range(self.g_iters): self.reset_grad() # create new latent vector z = tensor2var(torch.randn(real_images.size(0), self.nz), device=self.device) # generate fake images inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to( self.device) fake_images = self.G(z, real_labels) g_fake = self.D(fake_images + inst_noise, real_labels) # compute hinge loss for G g_loss = loss_hinge_gen(g_fake) g_loss.backward() del fake_images if (step + 1) % self.log_step != 0: del g_fake, g_loss # optimize G self.g_optimizer.step() # logging step progression if (step + 1) % self.log_step == 0: d_loss = d_loss_real + d_loss_fake # logging losses d_losses.append(d_loss.item()) d_losses_real.append(d_loss_real.item()) d_losses_fake.append(d_loss_fake.item()) g_losses.append(g_loss.item()) # print out elapsed = time.time() - start_time elapsed = str(datetime.timedelta(seconds=elapsed)) print( "Elapsed [{}], Epoch: [{}/{}], Step [{}/{}], g_loss: {:.4f}, d_loss: {:.4f}," " d_loss_real: {:.4f}, d_loss_fake: {:.4f}".format( elapsed, (epoch + 1), epochs, (step + 1), step_per_epoch, g_loss, d_loss, d_loss_real, d_loss_fake)) del d_real, d_loss_real, d_fake, d_loss_fake, g_fake, g_loss # logging average losses over epoch self.ave_d_losses.append(mean(d_losses)) self.ave_d_losses_real.append(mean(d_losses_real)) self.ave_d_losses_fake.append(mean(d_losses_fake)) self.ave_g_losses.append(mean(g_losses)) # epoch update print( "Elapsed [{}], Epoch: [{}/{}], ave_g_loss: {:.4f}, ave_d_loss: {:.4f}," " ave_d_loss_real: {:.4f}, ave_d_loss_fake: {:.4f},".format( elapsed, epoch + 1, epochs, self.ave_g_losses[epoch], self.ave_d_losses[epoch], self.ave_d_losses_real[epoch], self.ave_d_losses_fake[epoch])) # sample images every epoch fake_images = self.G(fixed_z, fixed_labels) fake_images = 
denorm(fake_images.data) save_image( fake_images, os.path.join(self.sample_path, "Epoch {}.png".format(epoch + 1))) # save model if (epoch + 1) % self.save_epoch == 0: torch.save( { "g_state_dict": self.G.state_dict(), "d_state_dict": self.D.state_dict(), "g_optimizer": self.g_optimizer.state_dict(), "d_optimizer": self.d_optimizer.state_dict(), "ave_d_losses": self.ave_d_losses, "ave_d_losses_real": self.ave_d_losses_real, "ave_d_losses_fake": self.ave_d_losses_fake, "ave_g_losses": self.ave_g_losses }, os.path.join(self.model_path, "{}_biggan.pth".format(epoch + 1))) print("Saving models (epoch {})..!".format(epoch + 1)) def plot(self): plt.plot(self.ave_d_losses) plt.plot(self.ave_d_losses_real) plt.plot(self.ave_d_losses_fake) plt.plot(self.ave_g_losses) plt.legend(["d loss", "d real", "d fake", "g loss"], loc="upper left") plt.show()
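# The BigGAN trainer relies on `loss_hinge_dis_real`, `loss_hinge_dis_fake`, and
# `loss_hinge_gen`, none of which appear in this excerpt. These names conventionally denote
# the standard hinge adversarial losses; a minimal sketch under that assumption:
import torch
import torch.nn.functional as F

def loss_hinge_dis_real(d_real):
    # discriminator hinge loss on real samples: E[max(0, 1 - D(x))]
    return torch.mean(F.relu(1.0 - d_real))

def loss_hinge_dis_fake(d_fake):
    # discriminator hinge loss on generated samples: E[max(0, 1 + D(G(z)))]
    return torch.mean(F.relu(1.0 + d_fake))

def loss_hinge_gen(g_fake):
    # generator hinge loss: -E[D(G(z))]
    return -torch.mean(g_fake)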
def main(): random.seed(SEED) np.random.seed(SEED) torch.manual_seed(SEED) gen_data_loader = Gen_Data_loader(BATCH_SIZE) likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing vocab_size = 2000 dis_data_loader = Dis_dataloader(BATCH_SIZE) generator = Generator(vocab_size, EMB_DIM, HIDDEN_DIM, 1, START_TOKEN, SEQ_LENGTH).to(device) target_lstm = Generator(vocab_size, EMB_DIM, HIDDEN_DIM, 1, START_TOKEN, SEQ_LENGTH, oracle=True).to(device) discriminator = Discriminator(vocab_size, dis_embedding_dim, dis_filter_sizes, dis_num_filters, dis_dropout).to(device) generate_samples(target_lstm, BATCH_SIZE, generated_num, positive_file) gen_data_loader.create_batches(positive_file) pre_gen_opt = torch.optim.Adam(generator.parameters(), 1e-2) adv_gen_opt = torch.optim.Adam(generator.parameters(), 1e-2) dis_opt = torch.optim.Adam(discriminator.parameters(), 1e-4) dis_criterion = nn.NLLLoss() log = open('save/experiment-log.txt', 'w') print('Start pre-training...') log.write('pre-training...\n') for epoch in range(PRE_EPOCH_NUM): loss = pre_train_epoch(generator, pre_gen_opt, gen_data_loader) if (epoch + 1) % 5 == 0: generate_samples(generator, BATCH_SIZE, generated_num, eval_file) likelihood_data_loader.create_batches(eval_file) test_loss = target_loss(target_lstm, likelihood_data_loader) print('pre-train epoch ', epoch + 1, '\tnll:\t', test_loss) buffer = 'epoch:\t' + str(epoch + 1) + '\tnll:\t' + str(test_loss) + '\n' log.write(buffer) print('Start pre-training discriminator...') # Train 3 epoch on the generated data and do this for 50 times for e in range(50): generate_samples(generator, BATCH_SIZE, generated_num, negative_file) dis_data_loader.load_train_data(positive_file, negative_file) d_total_loss = [] for _ in range(3): dis_data_loader.reset_pointer() total_loss = [] for it in range(dis_data_loader.num_batch): x_batch, y_batch = dis_data_loader.next_batch() x_batch = x_batch.to(device) y_batch = y_batch.to(device) dis_output = discriminator(x_batch.detach()) d_loss = dis_criterion(dis_output, y_batch.detach()) dis_opt.zero_grad() d_loss.backward() dis_opt.step() total_loss.append(d_loss.data.cpu().numpy()) d_total_loss.append(np.mean(total_loss)) if (e + 1) % 5 == 0: buffer = 'Epoch [{}], discriminator loss [{:.4f}]\n'.format( e + 1, np.mean(d_total_loss)) print(buffer) log.write(buffer) rollout = Rollout(generator, 0.8) print( '#########################################################################' ) print('Start Adversarial Training...') log.write('adversarial training...\n') gan_loss = GANLoss() for total_batch in range(TOTAL_BATCH): # Train the generator for one step discriminator.eval() for it in range(1): samples, _ = generator.sample(num_samples=BATCH_SIZE) rewards = rollout.get_reward(samples, 16, discriminator) prob = generator(samples.detach()) adv_loss = gan_loss(prob, samples.detach(), rewards.detach()) adv_gen_opt.zero_grad() adv_loss.backward() nn.utils.clip_grad_norm_(generator.parameters(), 5.0) adv_gen_opt.step() # Test if (total_batch + 1) % 5 == 0: generate_samples(generator, BATCH_SIZE, generated_num, eval_file) likelihood_data_loader.create_batches(eval_file) test_loss = target_loss(target_lstm, likelihood_data_loader) self_bleu_score = self_bleu(generator) buffer = 'epoch:\t' + str(total_batch + 1) + '\tnll:\t' + str( test_loss) + '\tSelf Bleu:\t' + str(self_bleu_score) + '\n' print(buffer) log.write(buffer) # Update roll-out parameters rollout.update_params() # Train the discriminator discriminator.train() for _ in range(5): generate_samples(generator, 
BATCH_SIZE, generated_num, negative_file) dis_data_loader.load_train_data(positive_file, negative_file) d_total_loss = [] for _ in range(3): dis_data_loader.reset_pointer() total_loss = [] for it in range(dis_data_loader.num_batch): x_batch, y_batch = dis_data_loader.next_batch() x_batch = x_batch.to(device) y_batch = y_batch.to(device) dis_output = discriminator(x_batch.detach()) d_loss = dis_criterion(dis_output, y_batch.detach()) dis_opt.zero_grad() d_loss.backward() dis_opt.step() total_loss.append(d_loss.data.cpu().numpy()) d_total_loss.append(np.mean(total_loss)) if (total_batch + 1) % 5 == 0: buffer = 'Epoch [{}], discriminator loss [{:.4f}]\n'.format( total_batch + 1, np.mean(d_total_loss)) print(buffer) log.write(buffer) log.close()
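# The adversarial generator step above computes `adv_loss = gan_loss(prob, samples, rewards)`
# with rewards produced by Monte-Carlo rollouts scored by the discriminator. In SeqGAN-style
# training this is typically a reward-weighted negative log-likelihood over the sampled
# tokens; the sketch below assumes `prob` has shape (batch, seq_len, vocab) and is only an
# illustration of that objective, not the repository's GANLoss.
import torch
import torch.nn as nn

class GANLoss(nn.Module):
    """Policy-gradient loss: L = -E[ reward * log p(sampled token) ], summed over time."""

    def forward(self, prob, target, reward):
        # prob:   (batch, seq_len, vocab) token probabilities from the generator
        # target: (batch, seq_len) sampled token ids
        # reward: (batch, seq_len) per-step rollout rewards from the discriminator
        log_prob = torch.log(prob.clamp_min(1e-12))
        picked = log_prob.gather(2, target.unsqueeze(2)).squeeze(2)
        return -(picked * reward).sum() / prob.size(0)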
def main(): # # -------------------- Data -------------------- num_workers = 8 # number of subprocesses to use for data loading batch_size = 64 # how many samples per batch to load transform = transforms.ToTensor() # convert data to torch.FloatTensor train_data = datasets.MNIST(root='../data', train=True, download=True, transform=transform) train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=num_workers) # # Obtain one batch of training images # dataiter = iter(train_loader) # images, labels = dataiter.next() # images = images.numpy() # # Get one image from the batch for visualization # img = np.squeeze(images[0]) # fig = plt.figure(figsize=(3, 3)) # ax = fig.add_subplot(111) # ax.imshow(img, cmap='gray') # plt.show() # # -------------------- Discriminator and Generator -------------------- # Discriminator hyperparams input_size = 784 # Size of input image to discriminator (28*28) d_output_size = 1 # Size of discriminator output (real or fake) d_hidden_size = 32 # Size of last hidden layer in the discriminator # Generator hyperparams z_size = 100 # Size of latent vector to give to generator g_output_size = 784 # Size of discriminator output (generated image) g_hidden_size = 32 # Size of first hidden layer in the generator # Instantiate discriminator and generator D = Discriminator(input_size, d_hidden_size, d_output_size) G = Generator(z_size, g_hidden_size, g_output_size) # # -------------------- Optimizers and Criterion -------------------- # Training hyperparams num_epochs = 100 print_every = 400 lr = 0.002 # Create optimizers for the discriminator and generator, respectively d_optimizer = optim.Adam(D.parameters(), lr) g_optimizer = optim.Adam(G.parameters(), lr) losses = [] # keep track of generated "fake" samples criterion = nn.BCEWithLogitsLoss() # -------------------- Training -------------------- D.train() G.train() # Get some fixed data for sampling. These are images that are held # constant throughout training, and allow us to inspect the model's performance sample_size = 16 fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size)) fixed_z = torch.from_numpy(fixed_z).float() samples = [] # keep track of loss for epoch in range(num_epochs): for batch_i, (real_images, _) in enumerate(train_loader): batch_size = real_images.size(0) # Important rescaling step real_images = real_images * 2 - 1 # rescale input images from [0,1) to [-1, 1) # Generate fake images, used for both discriminator and generator z = np.random.uniform(-1, 1, size=(batch_size, z_size)) z = torch.from_numpy(z).float() fake_images = G(z) real_labels = torch.ones(batch_size) fake_labels = torch.zeros(batch_size) # ============================================ # TRAIN THE DISCRIMINATOR # ============================================ d_optimizer.zero_grad() # 1. Train with real images # Compute the discriminator losses on real images D_real = D(real_images) d_real_loss = real_loss(criterion, D_real, real_labels, smooth=True) # 2. Train with fake images # Compute the discriminator losses on fake images # ------------------------------------------------------- # ATTENTION: # *.detach(), thus, generator is fixed when we optimize # the discriminator # ------------------------------------------------------- D_fake = D(fake_images.detach()) d_fake_loss = fake_loss(criterion, D_fake, fake_labels) # 3. 
Add up loss and perform backprop d_loss = (d_real_loss + d_fake_loss) * 0.5 d_loss.backward() d_optimizer.step() # ========================================= # TRAIN THE GENERATOR # ========================================= g_optimizer.zero_grad() # Make the discriminator fixed when optimizing the generator set_model_gradient(D, False) # 1. Train with fake images and flipped labels # Compute the discriminator losses on fake images using flipped labels! G_D_fake = D(fake_images) g_loss = real_loss(criterion, G_D_fake, real_labels) # use real loss to flip labels # 2. Perform backprop g_loss.backward() g_optimizer.step() # Make the discriminator require_grad=True after optimizing the generator set_model_gradient(D, True) # ========================================= # Print some loss stats # ========================================= if batch_i % print_every == 0: print( 'Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'. format(epoch + 1, num_epochs, d_loss.item(), g_loss.item())) # AFTER EACH EPOCH losses.append((d_loss.item(), g_loss.item())) # generate and save sample, fake images G.eval() # eval mode for generating samples samples_z = G(fixed_z) samples.append(samples_z) view_samples(-1, samples, "last_sample.png") G.train() # back to train mode # Save models and training generator samples torch.save(G.state_dict(), "G.pth") torch.save(D.state_dict(), "D.pth") with open('train_samples.pkl', 'wb') as f: pkl.dump(samples, f) # Plot the loss curve fig, ax = plt.subplots() losses = np.array(losses) plt.plot(losses.T[0], label='Discriminator') plt.plot(losses.T[1], label='Generator') plt.title("Training Losses") plt.legend() plt.savefig("loss.png") plt.show()
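# The MNIST GAN above calls `real_loss(criterion, D_out, labels, smooth=True)` and
# `fake_loss(criterion, D_out, labels)`, which are not part of this excerpt. With
# `criterion = nn.BCEWithLogitsLoss()` these helpers usually just apply the criterion,
# optionally smoothing the real targets (the 0.9 factor below is an assumption):
def real_loss(criterion, d_out, labels, smooth=False):
    """BCE-with-logits loss against 'real' targets, with optional one-sided label smoothing."""
    targets = labels * 0.9 if smooth else labels
    return criterion(d_out.squeeze(), targets)

def fake_loss(criterion, d_out, labels):
    """BCE-with-logits loss against 'fake' (all-zero) targets."""
    return criterion(d_out.squeeze(), labels)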