def train_batch(self, imgs):
    """Run one adversarial training step on a batch of real images.

    Performs a discriminator update on a combined real/fake batch, then a
    generator update on a fresh generated batch, and finally refreshes the
    (presumably EMA) target generator via ``update_target_generator``.

    Args:
        imgs: batch of real images; moved onto ``self.device`` here.

    Returns:
        dict of Python-float scalars: ``g_loss``, ``d_loss`` and ``gp``
        (the gradient-penalty term, ``0.`` when the penalty is disabled).
    """
    # Cleanup vs. original: removed commented-out debug prints,
    # save_image calls and a dead alternative d_loss computation.
    imgs = imgs.to(self.device)
    self.g.train()
    self.d.train()

    # --- discriminator step: freeze G, train D ---
    toggle_grad(self.g, False)
    toggle_grad(self.d, True)
    self.optimizer_d.zero_grad()
    batch_imgs, labels = self.make_adversarial_batch(imgs)
    # Layout assumed from the slicing: first batch_size entries are real,
    # the rest are generated.
    real, fake = batch_imgs[:self.args.batch_size], batch_imgs[self.args.batch_size:]
    if self.args.grad_penalty:
        # The penalty needs d(D)/d(real), so real inputs must track grads.
        real.requires_grad_()
    p_labels_real = self.d(real)
    p_labels_fake = self.d(fake.detach())  # detach: no grad flows into G here
    p_labels = torch.cat([p_labels_real, p_labels_fake], dim=0)
    d_loss = self.bce_loss(p_labels, labels)
    d_grad_penalty = 0.
    if self.args.grad_penalty:
        d_grad_penalty = self.args.grad_penalty * gradient_penalty(
            p_labels_real, real)
        d_loss += d_grad_penalty
    d_loss.backward()
    self.optimizer_d.step()

    # --- generator step: freeze D, train G ---
    toggle_grad(self.g, True)
    toggle_grad(self.d, False)
    self.optimizer_g.zero_grad()
    batch_imgs, labels = self.make_generator_batch(imgs)
    p_labels = self.d(batch_imgs)
    g_loss = self.bce_loss(p_labels, labels)
    g_loss.backward()
    self.optimizer_g.step()

    self.update_target_generator()
    return dict(g_loss=float(g_loss), d_loss=float(d_loss),
                gp=float(d_grad_penalty))
def train_batch(self, input_indexes):
    """Train the word embedding every step and, once pretraining is done,
    also run one GAN step on the embedded sequences.

    Args:
        input_indexes: integer token-index tensor; cast to long and moved
            to ``self.device``. Assumes the trailing dimension is long
            enough to hold a random window of ``2 * context + 1`` tokens
            at any sampled offset — TODO confirm against the data loader.

    Returns:
        dict of Python-float scalars: ``g_loss``, ``d_loss``, ``gp`` and
        ``embedding_loss`` (GAN entries are ``0.`` during pretraining).
    """
    input_indexes = input_indexes.long().to(self.device)

    # train embedding (runs every batch, including during GAN training)
    self.embedding.train()
    toggle_grad(self.embedding, True)
    self.optimizer_embedding.zero_grad()
    # extract random windows from the input docs: one window of
    # 2*context+1 tokens per document, at a per-document random offset
    window_size = self.args.context * 2 + 1
    offsets = np.random.randint(0, window_size, len(input_indexes))
    windows = torch.stack([
        input_indexes[i, ..., offset:offset + window_size]
        for i, offset in enumerate(offsets)
    ])
    # get a list of pivot words and their contexts from the windows:
    # the center token is the pivot, everything around it is context
    words = windows[..., self.args.context]
    contexts = torch.cat([
        windows[..., :self.args.context],
        windows[..., self.args.context + 1:]
    ], dim=-1)
    # get the loss (skip-gram-style objective — defined in self.embedding)
    embedding_loss = self.embedding.loss(words, contexts)
    embedding_loss.backward()
    self.optimizer_embedding.step()

    # countdown of embedding-only pretraining batches; GAN training
    # starts once the counter reaches zero
    self.pretraining_embedding = max(self.pretraining_embedding - 1, 0)
    if not self.pretraining_embedding:
        # embed the full sequences for the GAN; detach so GAN losses do
        # not backprop into the embedding (it also has grads toggled off)
        inputs = self.embedding(input_indexes).permute((0, 2, 1)).detach()
        self.g.train()
        self.d.train()
        # train discriminator: freeze embedding and G, train D
        toggle_grad(self.embedding, False)
        toggle_grad(self.g, False)
        toggle_grad(self.d, True)
        self.optimizer_d.zero_grad()
        batch_inputs, labels = self.make_adversarial_batch(inputs)
        # first batch_size entries are real, the rest generated
        real, fake = batch_inputs[:self.args.batch_size], batch_inputs[
            self.args.batch_size:]
        if self.args.grad_penalty:
            # penalty needs gradients w.r.t. the real inputs
            real.requires_grad_()
        p_labels_real = self.d(real)
        p_labels_fake = self.d(fake.detach())  # no grad into G here
        p_labels = torch.cat([p_labels_real, p_labels_fake], dim=0)
        d_loss = self.bce_loss(p_labels, labels)
        d_grad_penalty = 0.
        if self.args.grad_penalty:
            d_grad_penalty = self.args.grad_penalty * gradient_penalty(
                p_labels_real, real)
            d_loss += d_grad_penalty
        d_loss.backward()
        self.optimizer_d.step()
        # train generator: freeze D, train G
        toggle_grad(self.g, True)
        toggle_grad(self.d, False)
        self.optimizer_g.zero_grad()
        batch_inputs, labels = self.make_generator_batch(inputs)
        p_labels = self.d(batch_inputs)
        g_loss = self.bce_loss(p_labels, labels)
        g_loss.backward()
        self.optimizer_g.step()
    else:
        # still pretraining the embedding: report zero GAN losses
        g_loss = d_loss = d_grad_penalty = 0.
    # NOTE(review): target generator is refreshed even during pretraining,
    # while G is untouched — presumably harmless; confirm intended.
    self.update_target_generator()
    return dict(g_loss=float(g_loss), d_loss=float(d_loss),
                gp=float(d_grad_penalty),
                embedding_loss=float(embedding_loss))
def train_batch(self, imgs):
    """Run one InfoGAN-style training step on a batch of real images.

    Same two-phase scheme as the plain GAN step (discriminator on a
    real/fake batch, then generator), with an additional mutual-information
    surrogate loss: the discriminator's second head predicts the latent
    code ``z`` that produced each fake, and both networks are penalized on
    the reconstruction of the categorical and/or continuous code slices.

    Args:
        imgs: batch of real images; moved onto ``self.device`` here.

    Returns:
        dict of Python-float scalars: ``g_loss``, ``g_code_loss``,
        ``d_loss``, ``d_code_loss`` and ``gp``.
    """
    imgs = imgs.to(self.device)
    self.g.train()
    self.d.train()

    # --- discriminator step: freeze G, train D ---
    toggle_grad(self.g, False)
    toggle_grad(self.d, True)
    self.optimizer_d.zero_grad()
    batch_imgs, labels, z = self.make_adversarial_batch(imgs)
    # first batch_size entries are real, the rest generated from z
    real, fake = batch_imgs[:self.args.batch_size], batch_imgs[self.args.batch_size:]
    if self.args.grad_penalty:
        # penalty needs gradients w.r.t. the real inputs
        real.requires_grad_()
    p_labels_real, _ = self.d(real)  # code head is unused for real images
    p_labels_fake, p_codes = self.d(fake.detach())  # no grad into G here
    p_labels = torch.cat([p_labels_real, p_labels_fake], dim=0)
    d_loss = self.bce_loss(p_labels, labels)

    # infogan loss: D must recover the latent code slices from the fakes.
    # Fix vs. original: initialized as float 0. (was int 0) for
    # consistency with g_code_loss and d_grad_penalty below.
    d_code_loss = 0.
    if self.args.info_cat_dims:
        z_cat_code = self.z_categorical_code(z)
        p_z_cat_code = self.z_categorical_code(p_codes)
        # NOTE(review): BCE on a categorical code — cross-entropy is the
        # usual InfoGAN choice; confirm the code head's activation.
        d_cat_code_loss = self.bce_loss(p_z_cat_code, z_cat_code)
        d_code_loss += d_cat_code_loss
    if self.args.info_cont_dims:
        z_cont_code = self.z_continuous_code(z)
        p_z_cont_code = self.z_continuous_code(p_codes)
        d_cont_code_loss = self.mse_loss(p_z_cont_code, z_cont_code)
        d_code_loss += d_cont_code_loss
    d_loss += self.args.info_w * d_code_loss

    d_grad_penalty = 0.
    if self.args.grad_penalty:
        d_grad_penalty = self.args.grad_penalty * gradient_penalty(
            p_labels_real, real)
        d_loss += d_grad_penalty
    d_loss.backward()
    self.optimizer_d.step()

    # --- generator step: freeze D, train G ---
    toggle_grad(self.g, True)
    toggle_grad(self.d, False)
    self.optimizer_g.zero_grad()
    batch_imgs, labels, z = self.make_generator_batch(imgs)
    p_labels, p_codes = self.d(batch_imgs)
    g_loss = self.bce_loss(p_labels, labels)

    # infogan loss: same code-reconstruction terms, now training G
    g_code_loss = 0.
    if self.args.info_cat_dims:
        z_cat_code = self.z_categorical_code(z)
        p_z_cat_code = self.z_categorical_code(p_codes)
        g_cat_code_loss = self.bce_loss(p_z_cat_code, z_cat_code)
        g_code_loss += g_cat_code_loss
    if self.args.info_cont_dims:
        z_cont_code = self.z_continuous_code(z)
        p_z_cont_code = self.z_continuous_code(p_codes)
        g_cont_code_loss = self.mse_loss(p_z_cont_code, z_cont_code)
        g_code_loss += g_cont_code_loss
    g_loss += self.args.info_w * g_code_loss
    g_loss.backward()
    self.optimizer_g.step()

    self.update_target_generator()
    return dict(
        g_loss=float(g_loss),
        g_code_loss=float(g_code_loss),
        d_loss=float(d_loss),
        d_code_loss=float(d_code_loss),
        gp=float(d_grad_penalty)
    )