def refinement_step(self, imgs, imgs_pred, foreground, ca_masks):
    # Generator (Refinement Network) loss
    refinement_losses = LossManager()

    # Unpack foreground boxes and their image indices
    f_boxes, f_obj_to_img = foreground
    fake_local_patches = crop_bbox_batch(
        imgs_pred, f_boxes, f_obj_to_img, self.patch_size)
    real_local_patches = crop_bbox_batch(
        imgs, f_boxes, f_obj_to_img, self.patch_size)

    # Spatially discounted L1 loss on the local patches
    spatial_loss = spatial_l1(
        fake_local_patches, real_local_patches, f_obj_to_img,
        self.patch_size, gamma=self.spatial_gamma)
    refinement_losses.add_loss(
        self.spatial_loss_weight * spatial_loss, 'r_spatial_loss')

    # L1 pixel loss on the full images
    l1_pixel_loss = F.l1_loss(imgs_pred, imgs)
    refinement_losses.add_loss(
        self.l1_pixel_loss_weight * l1_pixel_loss, 'r_pix_loss')

    # WGAN loss: score the fake patches and fake images with the critic
    fake_local_pred, fake_global_pred = self.critic(
        fake_local_patches, imgs_pred, f_obj_to_img, get_ave=True)
    local_loss = self.critic_g_loss(fake_local_pred)    # local patch loss
    global_loss = self.critic_g_loss(fake_global_pred)  # global image loss
    critic_loss = self.critic_global_weight * global_loss + local_loss
    refinement_losses.add_loss(
        self.critic_g_weight * critic_loss, 'r_wgan_loss')

    # Backpropagate and optimize the refinement network
    self.reset_grad()
    refinement_losses.total_loss.backward()
    self.refinement_optimizer.step()

    return refinement_losses
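# NOTE: spatial_l1 is referenced above but not defined in this section. The
# sketch below is an assumption: a spatially discounted L1 in the spirit of
# Yu et al., "Generative Image Inpainting with Contextual Attention" (2018),
# where pixels near the patch border (closer to known context) are weighted
# more heavily than pixels at the center. The signature mirrors the call
# site; obj_to_img is accepted for symmetry but unused in this sketch.
import torch

def spatial_l1_sketch(fake_patches, real_patches, obj_to_img,
                      patch_size, gamma=0.99):
    # Discount map: weight = max(gamma**row_dist, gamma**col_dist), where
    # distance is measured from the nearest patch edge, so the weight is
    # 1.0 at the border and decays toward the center.
    coords = torch.arange(patch_size, dtype=torch.float32)
    dist = torch.min(coords, patch_size - 1 - coords)
    weight = torch.max(gamma ** dist.view(-1, 1), gamma ** dist.view(1, -1))
    weight = weight.to(fake_patches.device).view(1, 1, patch_size, patch_size)
    # Weighted mean absolute error over all foreground patches
    return (weight * (fake_patches - real_patches).abs()).mean()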
def extract_patches(imgs, boxes, obj_to_img, patch_size=32, batch_size=None):
    # Box format: (x0, y0, x1, y1)
    # Reuse the scene-graph crop function to cut out and resize each box
    img_patches = crop_bbox_batch(imgs, boxes, obj_to_img, patch_size)
    return img_patches
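# NOTE: crop_bbox_batch is imported from the scene-graph-to-image (sg2im)
# codebase and is not shown here. The simplified sketch below is an
# assumption about its behavior: cut each box out of its source image and
# resize the crop to a fixed square patch so the results can be batched.
# Boxes are assumed to be (x0, y0, x1, y1) in normalized [0, 1] coordinates.
import torch
import torch.nn.functional as F

def crop_bbox_batch_sketch(imgs, boxes, obj_to_img, patch_size):
    N, C, H, W = imgs.shape
    crops = []
    for (x0, y0, x1, y1), idx in zip(boxes.tolist(), obj_to_img.tolist()):
        # Convert normalized coordinates to pixel indices; guarantee at
        # least a 1-pixel-wide crop
        px0, px1 = int(x0 * W), max(int(x1 * W), int(x0 * W) + 1)
        py0, py1 = int(y0 * H), max(int(y1 * H), int(y0 * H) + 1)
        crop = imgs[int(idx):int(idx) + 1, :, py0:py1, px0:px1]
        # Resize every crop to (patch_size, patch_size)
        crops.append(F.interpolate(crop, size=(patch_size, patch_size),
                                   mode='bilinear', align_corners=False))
    return torch.cat(crops, dim=0)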
def forward(self, imgs, objs, boxes, obj_to_img):
    # Crop each object out of its image and score the crops with the
    # object discriminator (also returns the auxiliary classifier loss)
    crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size)
    real_scores, ac_loss = self.discriminator(crops, objs)
    return real_scores, ac_loss
def critic_step(self, imgs, imgs_pred, foreground, ca_masks):
    # Detach imgs_pred so the refinement network won't be updated
    critic_losses = LossManager()
    imgs_fake = imgs_pred.detach()
    f_boxes, f_obj_to_img = foreground
    patch_masks = crop_bbox_batch(
        ca_masks, f_boxes, f_obj_to_img, self.patch_size)

    # Build input for the critic
    fake_local_patches = crop_bbox_batch(
        imgs_fake, f_boxes, f_obj_to_img, self.patch_size)
    real_local_patches = crop_bbox_batch(
        imgs, f_boxes, f_obj_to_img, self.patch_size)
    local_vectors = torch.cat(
        [fake_local_patches, real_local_patches], dim=0)
    global_vectors = torch.cat([imgs_fake, imgs], dim=0)

    # Feed to the critic, then split the output into (fake, real)
    local_critic_out, global_critic_out = self.critic(
        local_vectors, global_vectors, f_obj_to_img)
    fake_local_pred, real_local_pred = torch.split(
        local_critic_out, imgs.shape[0], dim=0)
    fake_global_pred, real_global_pred = torch.split(
        global_critic_out, imgs.shape[0], dim=0)

    # Local and global WGAN losses
    local_loss = self.critic_d_loss(real_local_pred, fake_local_pred)
    global_loss = self.critic_d_loss(real_global_pred, fake_global_pred)
    critic_losses.add_loss(global_loss + local_loss, 'c_wgan_loss')

    # Gradient penalty: interpolate between real and fake samples
    local_interpolate = random_interpolate(
        real_local_patches, fake_local_patches)
    global_interpolate = random_interpolate(imgs, imgs_fake)
    local_interpolate = to_var(local_interpolate, requires_grad=True)
    global_interpolate = to_var(global_interpolate, requires_grad=True)

    # Forward the interpolated images to the critic and penalize its
    # gradients at those points
    local_gp_out, global_gp_out = self.critic(
        local_interpolate, global_interpolate, f_obj_to_img)
    local_gp = self.critic_gp_loss(
        local_interpolate, local_gp_out,
        mask=patch_masks, f_obj_to_img=f_obj_to_img)
    global_gp = self.critic_gp_loss(
        global_interpolate, global_gp_out, mask=ca_masks)
    critic_losses.add_loss(
        self.critic_gp_weight * (local_gp + global_gp), 'c_gp_loss')

    # Backpropagate and optimize the critic
    self.reset_grad()
    critic_losses.total_loss.backward()
    self.critic_optimizer.step()

    return critic_losses
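# NOTE: critic_d_loss, critic_g_loss, random_interpolate and critic_gp_loss
# are referenced above but not defined in this section. The sketches below
# assume the standard WGAN-GP formulation (Gulrajani et al., 2017); the real
# critic_gp_loss evidently also accepts mask/f_obj_to_img arguments to
# restrict the penalty region, which this simplified version omits.
import torch
from torch import autograd

def critic_d_loss_sketch(real_pred, fake_pred):
    # Critic minimizes E[fake] - E[real], i.e. maximizes the Wasserstein gap
    return fake_pred.mean() - real_pred.mean()

def critic_g_loss_sketch(fake_pred):
    # Generator maximizes the critic's score on fakes
    return -fake_pred.mean()

def random_interpolate_sketch(real, fake):
    # Sample uniformly along straight lines between real and fake samples
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    return alpha * real + (1 - alpha) * fake

def critic_gp_loss_sketch(interpolated, critic_out):
    # Penalize the critic's gradient norm at the interpolated points for
    # deviating from 1 (the two-sided WGAN-GP penalty)
    grads = autograd.grad(outputs=critic_out.sum(), inputs=interpolated,
                          create_graph=True, retain_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()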
def train(self):
    iters_per_epoch = len(self.train_loader)
    print("Iterations per epoch: " + str(iters_per_epoch))

    # Start with a trained model if one exists
    if self.pretrained_model:
        start = int(self.pretrained_model.split('_')[0])
    else:
        start = 0

    fixed_batch = self.build_sample_images(self.test_loader)

    # Start training
    iter_ctr = self.init_iterations
    e = iter_ctr // iters_per_epoch
    start_time = time.time()
    while True:
        # Stop training once iter_ctr reaches num_iterations
        if iter_ctr >= self.num_iterations:
            break
        e += 1

        for i, batch in enumerate(tqdm(self.train_loader)):
            if iter_ctr == self.eval_mode_after:
                self.generator.eval()
                self.gen_optimizer = torch.optim.Adam(
                    self.generator.parameters(), lr=self.learning_rate)
            elif iter_ctr < self.eval_mode_after:
                self.generator.train()

            masks = None
            if len(batch) == 6:
                imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
            elif len(batch) == 7:
                imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch

            # Prepare data
            imgs = to_var(imgs)
            objs = to_var(objs)
            boxes = to_var(boxes)
            triples = to_var(triples)
            obj_to_img = to_var(obj_to_img)
            triple_to_img = to_var(triple_to_img)
            if masks is not None:
                masks = to_var(masks)
            predicates = triples[:, 1]  # get p from triples (s, p, o)

            # Variables needed for the generator and discriminator steps
            step_vars = (imgs, objs, boxes, obj_to_img, predicates, masks)

            # Foreground objects
            f_inds = [
                i for i, obj in enumerate(objs)
                if obj in self.foreground_objs
            ]
            f_boxes = boxes[f_inds]
            f_obj_to_img = obj_to_img[f_inds]

            # Build masks for the contextual attention module
            ca_masks = build_masks(imgs, f_boxes, f_obj_to_img)
            ca_masks = to_var(ca_masks)
            patch_masks = crop_bbox_batch(
                ca_masks, f_boxes, f_obj_to_img, self.patch_size)

            # Forward to the model
            model_boxes = boxes
            model_masks = masks
            model_out = self.generator(
                objs, triples, obj_to_img, ca_masks,
                boxes_gt=model_boxes, masks_gt=model_masks)
            imgs_pred, _, _, _ = model_out

            # Build input for the critic
            fake_local_patches = crop_bbox_batch(
                imgs_pred, f_boxes, f_obj_to_img, self.patch_size)
            real_local_patches = crop_bbox_batch(
                imgs, f_boxes, f_obj_to_img, self.patch_size)

            # Forward to the critic
            critic_pred = None
            if self.critic is not None:
                local_vectors = torch.cat(
                    [fake_local_patches, real_local_patches], dim=0)
                global_vectors = torch.cat([imgs_pred, imgs], dim=0)

                # Feed input to the critic to get scores
                local_critic_out, global_critic_out = self.critic(
                    local_vectors, global_vectors, f_obj_to_img)

                # Split scores into (fake, real)
                fake_local_pred, real_local_pred = torch.split(
                    local_critic_out, imgs.shape[0], dim=0)
                fake_global_pred, real_global_pred = torch.split(
                    global_critic_out, imgs.shape[0], dim=0)
                critic_pred = (fake_local_pred, real_local_pred,
                               fake_global_pred, real_global_pred)

            # For the critic's gradient penalty
            gp_vectors = (real_local_patches, fake_local_patches)
            gp_masks = (ca_masks, patch_masks)

            # Generator step
            total_loss, losses = self.generator_step(
                step_vars, model_out, critic_pred)

            # Logging
            loss = {}
            loss['G/total_loss'] = total_loss.data.item()
            loss = self.log_losses(loss, 'G', losses)

            # Discriminator step
            total_loss, dis_obj_losses, dis_img_losses, critic_losses = \
                self.discriminator_step(
                    imgs_pred, step_vars, f_obj_to_img, critic_pred,
                    gp_vectors, gp_masks)
            loss['D/total_loss'] = total_loss.data.item()
            loss = self.log_losses(loss, 'D', dis_obj_losses)
            loss = self.log_losses(loss, 'D', dis_img_losses)
            loss = self.log_losses(loss, 'C', critic_losses)

            # Print out log info
            if (i + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                total_time = (self.num_iterations - iter_ctr + 1) \
                    * elapsed / (iter_ctr + 1)
                epoch_time = (iters_per_epoch - i) * elapsed / (iter_ctr + 1)
                epoch_time = str(datetime.timedelta(seconds=epoch_time))
                total_time = str(datetime.timedelta(seconds=total_time))
                elapsed = str(datetime.timedelta(seconds=elapsed))
                log = "Elapsed {}/{} -- {} , Iteration [{}/{}], Epoch [{}]".format(
                    elapsed, epoch_time, total_time,
                    iter_ctr + 1, self.num_iterations, e)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, iter_ctr + 1)

            # Save model checkpoints
            if (iter_ctr + 1) % self.model_save_step == 0:
                torch.save(
                    self.generator.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(iter_ctr + 1)))
                torch.save(
                    self.obj_discriminator.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D_OBJ.pth'.format(iter_ctr + 1)))
                torch.save(
                    self.img_discriminator.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D_IMG.pth'.format(iter_ctr + 1)))
                torch.save(
                    self.critic.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_C_IMG.pth'.format(iter_ctr + 1)))

            if (iter_ctr + 1) % self.sample_step == 0:
                self.sample_images(fixed_batch, iter_ctr)

            iter_ctr += 1
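# NOTE: build_masks is called in train() but not defined in this section.
# A plausible sketch under the assumption that it produces one hole-mask
# channel per image for the contextual attention module: 1 inside every
# foreground box, 0 elsewhere (the exact mask convention is an assumption,
# as is the normalized box format).
import torch

def build_masks_sketch(imgs, f_boxes, f_obj_to_img):
    N, _, H, W = imgs.shape
    masks = torch.zeros(N, 1, H, W, device=imgs.device)
    for (x0, y0, x1, y1), idx in zip(f_boxes.tolist(), f_obj_to_img.tolist()):
        # Normalized box coordinates -> pixel indices, at least 1 px wide
        px0, px1 = int(x0 * W), max(int(x1 * W), int(x0 * W) + 1)
        py0, py1 = int(y0 * H), max(int(y1 * H), int(y0 * H) + 1)
        masks[int(idx), :, py0:py1, px0:px1] = 1.0  # mark the box as a hole
    return masks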