Code example #1
    def refinement_step(self, imgs, imgs_pred, foreground, ca_masks):
        # Generator (Refinement Network Loss)
        refinement_losses = LossManager()
        # unpack
        f_boxes, f_obj_to_img = foreground
        fake_local_patches = crop_bbox_batch(imgs_pred, f_boxes, f_obj_to_img, self.patch_size)
        real_local_patches = crop_bbox_batch(imgs, f_boxes, f_obj_to_img, self.patch_size)

        # Get Spatial Discounted L1 Loss
        spatial_loss = spatial_l1(fake_local_patches, real_local_patches,
                                  f_obj_to_img, self.patch_size,
                                  gamma=self.spatial_gamma)
        refinement_losses.add_loss(self.spatial_loss_weight * spatial_loss, 'r_spatial_loss')

        # Get L1 Pix Loss
        l1_pixel_weight = self.l1_pixel_loss_weight
        l1_pixel_loss = F.l1_loss(imgs_pred, imgs)
        refinement_losses.add_loss(l1_pixel_weight * l1_pixel_loss, 'r_pix_loss')

        # Get WGAN Loss
        fake_local_pred, fake_global_pred = self.critic(
            fake_local_patches, imgs_pred, f_obj_to_img, get_ave=True)
        
        # Local Patch Loss
        local_loss = self.critic_g_loss(fake_local_pred)

        # Global Loss
        global_loss = self.critic_g_loss(fake_global_pred)

        critic_loss = self.critic_global_weight * global_loss + local_loss
        refinement_losses.add_loss(self.critic_g_weight * critic_loss, 'r_wgan_loss')

        # backward
        self.reset_grad()
        refinement_losses.total_loss.backward()
        self.refinement_optimizer.step()
        return refinement_losses
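
The helper spatial_l1 used above is not defined in these excerpts. Below is a minimal sketch of one plausible implementation of a spatially discounted L1 loss in the spirit of contextual-attention inpainting, assuming square patches of shape (num_objs, C, patch_size, patch_size) and a per-pixel weight gamma**d that decays with the distance d to the nearest patch border; the actual helper may differ (for example, it may use obj_to_img for per-image averaging, which this sketch ignores).

import torch

def spatial_discount_mask(patch_size, gamma, device=None):
    # Weight gamma**d per pixel, where d is the distance to the nearest
    # patch border; pixels deep inside the patch are discounted more.
    idx = torch.arange(patch_size, device=device, dtype=torch.float)
    dist = torch.min(idx, patch_size - 1 - idx)         # per-axis distance to the nearest edge
    d = torch.min(dist.view(-1, 1), dist.view(1, -1))   # (patch_size, patch_size)
    return gamma ** d

def spatial_l1(fake_patches, real_patches, obj_to_img, patch_size, gamma=0.99):
    # fake_patches, real_patches: (num_objs, C, patch_size, patch_size)
    # obj_to_img is accepted only to mirror the call site; it is unused here.
    weights = spatial_discount_mask(patch_size, gamma, fake_patches.device)
    weights = weights.view(1, 1, patch_size, patch_size)
    return (weights * (fake_patches - real_patches).abs()).mean()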
Code example #2
def extract_patches(imgs, boxes, obj_to_img, patch_size=32, batch_size=None):
    # Box format: (x0, y0, x1, y1)
    # Reuse the crop utility from the scene-graph code
    img_patches = crop_bbox_batch(imgs, boxes, obj_to_img, patch_size)

    # Optional (disabled) completeness check: compare the number of distinct
    # entries in obj_to_img against batch_size and, for images that have no
    # patches, fill the gaps via
    # complete_patches(imgs, img_patches, obj_to_img, patch_size, batch_size).

    return img_patches
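
crop_bbox_batch itself belongs to the scene-graph codebase and is not reproduced in these excerpts. The simplified sketch below (crop_bbox_batch_naive is a hypothetical name) shows the basic idea under the assumption that boxes use normalized (x0, y0, x1, y1) coordinates: slice each box out of its image and resize the crop to patch_size x patch_size. The real utility likely uses differentiable bilinear sampling instead of integer slicing.

import torch
import torch.nn.functional as F

def crop_bbox_batch_naive(imgs, boxes, obj_to_img, patch_size):
    # imgs: (N, C, H, W); boxes: (O, 4) normalized (x0, y0, x1, y1)
    # obj_to_img: (O,) index of the image each box belongs to
    N, C, H, W = imgs.shape
    crops = []
    for box, img_idx in zip(boxes, obj_to_img):
        x0, y0, x1, y1 = box.tolist()
        # Convert normalized coordinates to pixel indices, clamped to the image
        x0i = max(0, min(W - 1, int(x0 * W)))
        x1i = max(x0i + 1, min(W, int(x1 * W)))
        y0i = max(0, min(H - 1, int(y0 * H)))
        y1i = max(y0i + 1, min(H, int(y1 * H)))
        crop = imgs[img_idx, :, y0i:y1i, x0i:x1i].unsqueeze(0)
        crops.append(F.interpolate(crop, size=(patch_size, patch_size),
                                   mode='bilinear', align_corners=False))
    return torch.cat(crops, dim=0)  # (O, C, patch_size, patch_size)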
Code example #3
    def forward(self, imgs, objs, boxes, obj_to_img):
        crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size)
        real_scores, ac_loss = self.discriminator(crops, objs)
        return real_scores, ac_loss
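
This wrapper crops each object out of the image and pairs it with its label, so the downstream discriminator can return both a real/fake score and an auxiliary classification loss. A hypothetical usage inside a discriminator step might look like the following; obj_discriminator, gan_d_loss, and the loss names are illustrative and do not come from the excerpts.

# Score ground-truth and generated crops, then combine an adversarial loss on
# the scores with the auxiliary classification losses.
real_scores, ac_loss_real = obj_discriminator(imgs, objs, boxes, obj_to_img)
fake_scores, ac_loss_fake = obj_discriminator(imgs_pred.detach(), objs, boxes, obj_to_img)

d_obj_losses = LossManager()
d_obj_losses.add_loss(gan_d_loss(real_scores, fake_scores), 'd_obj_gan_loss')
d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')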
Code example #4
    def critic_step(self, imgs, imgs_pred, foreground, ca_masks):
        # Detach imgs_pred so the critic step does not update the refinement network
        critic_losses = LossManager()
        imgs_fake = imgs_pred.detach()
        f_boxes, f_obj_to_img = foreground

        patch_masks = crop_bbox_batch(
            ca_masks, f_boxes, f_obj_to_img, self.patch_size)

        # build input for critics
        fake_local_patches = crop_bbox_batch(
            imgs_fake, f_boxes, f_obj_to_img, self.patch_size)
        real_local_patches = crop_bbox_batch(
            imgs, f_boxes, f_obj_to_img, self.patch_size)
        local_vectors = torch.cat(
            [fake_local_patches, real_local_patches], dim=0)
        
        global_vectors = torch.cat([imgs_fake, imgs], dim=0)

        # Feed to the critic then split output to (fake, real)
        local_critic_out, global_critic_out = self.critic(
            local_vectors, global_vectors, f_obj_to_img)
        
        fake_local_pred, real_local_pred = torch.split(
            local_critic_out, imgs.shape[0], dim=0)
        fake_global_pred, real_global_pred = torch.split(
            global_critic_out, imgs.shape[0], dim=0)
        
        # Local Loss
        local_loss = self.critic_d_loss(real_local_pred, fake_local_pred)

        # Global Loss
        global_loss = self.critic_d_loss(real_global_pred, fake_global_pred)

        critic_losses.add_loss(global_loss + local_loss, 'c_wgan_loss')

        # Gradient Penalty Loss

        # Interpolate Images
        local_interpolate = random_interpolate(
            real_local_patches, fake_local_patches)
        global_interpolate = random_interpolate(imgs, imgs_fake)

        local_interpolate = to_var(local_interpolate, requires_grad=True)
        global_interpolate = to_var(global_interpolate, requires_grad=True)
        
        # GP Loss
        # Forward interpolated images to the critic
        local_gp_out, global_gp_out = self.critic(
            local_interpolate, global_interpolate, f_obj_to_img)

        local_gp = self.critic_gp_loss(
            local_interpolate, local_gp_out, mask=patch_masks, f_obj_to_img=f_obj_to_img)
        global_gp = self.critic_gp_loss(
            global_interpolate, global_gp_out, mask=ca_masks)

        critic_losses.add_loss(
            self.critic_gp_weight * (local_gp + global_gp), 'c_gp_loss')

        # backpropagate and optimize critic
        self.reset_grad()
        critic_losses.total_loss.backward()
        self.critic_optimizer.step()
        return critic_losses
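
critic_d_loss, random_interpolate, and critic_gp_loss are external helpers not shown here. The sketch below gives a minimal reading of what they might compute, assuming a standard WGAN-GP formulation; the mask handling is an assumption, and the excerpt's critic_gp_loss also receives f_obj_to_img, which this sketch accepts but ignores.

import torch

def critic_d_loss(real_pred, fake_pred):
    # WGAN critic objective: push real scores up and fake scores down
    return fake_pred.mean() - real_pred.mean()

def random_interpolate(real, fake):
    # Sample points on straight lines between real and fake samples (WGAN-GP)
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    return alpha * real + (1 - alpha) * fake

def critic_gp_loss(interpolated, critic_out, mask=None, f_obj_to_img=None):
    # Penalize the critic's gradient norm at the interpolated points;
    # an optional mask restricts the penalty to the inpainted regions
    grads = torch.autograd.grad(outputs=critic_out.sum(), inputs=interpolated,
                                create_graph=True, retain_graph=True)[0]
    if mask is not None:
        grads = grads * mask
    grads = grads.view(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()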
Code example #5
    def train(self):
        iters_per_epoch = len(self.train_loader)
        print("Iterations per epoch: " + str(iters_per_epoch))

        # Start with trained model if exists
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0
        fixed_batch = self.build_sample_images(self.test_loader)

        # Start training
        iter_ctr = self.init_iterations
        e = iter_ctr // iters_per_epoch
        start_time = time.time()
        while True:
            # Stop training if iter_ctr reached num_iterations
            if iter_ctr >= self.num_iterations:
                break
            e += 1

            for i, batch in enumerate(tqdm(self.train_loader)):
                if iter_ctr == self.eval_mode_after:
                    self.generator.eval()
                    self.gen_optimizer = torch.optim.Adam(
                        self.generator.parameters(), lr=self.learning_rate)
                elif iter_ctr < self.eval_mode_after:
                    self.generator.train()

                masks = None
                if len(batch) == 6:
                    imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
                elif len(batch) == 7:
                    imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch

                start = time.time()

                # Prepare Data
                imgs = to_var(imgs)
                objs = to_var(objs)
                boxes = to_var(boxes)
                triples = to_var(triples)
                obj_to_img = to_var(obj_to_img)
                triple_to_img = to_var(triple_to_img)
                if masks is not None:
                    masks = to_var(masks)
                predicates = triples[:, 1]  # get p from triples(s, p ,o)

                # variables needed for generator and discriminator steps
                step_vars = (imgs, objs, boxes, obj_to_img, predicates, masks)

                # Foreground objects
                f_inds = [
                    i for i, obj in enumerate(objs)
                    if obj in self.foreground_objs
                ]
                f_boxes = boxes[f_inds]
                f_obj_to_img = obj_to_img[f_inds]

                # Build Masks for Contextual Attention Module
                ca_masks = build_masks(imgs, f_boxes, f_obj_to_img)
                ca_masks = to_var(ca_masks)

                patch_masks = crop_bbox_batch(ca_masks, f_boxes, f_obj_to_img,
                                              self.patch_size)

                # Forward to Model
                model_boxes = boxes
                model_masks = masks
                model_out = self.generator(objs,
                                           triples,
                                           obj_to_img,
                                           ca_masks,
                                           boxes_gt=model_boxes,
                                           masks_gt=model_masks)

                imgs_pred, _, _, _ = model_out

                # build input for critics
                fake_local_patches = crop_bbox_batch(imgs_pred, f_boxes,
                                                     f_obj_to_img,
                                                     self.patch_size)
                real_local_patches = crop_bbox_batch(imgs, f_boxes,
                                                     f_obj_to_img,
                                                     self.patch_size)

                # Forward to Critic
                critic_pred = None
                gp_vectors, gp_masks = None, None
                if self.critic is not None:
                    local_vectors = torch.cat(
                        [fake_local_patches, real_local_patches], dim=0)
                    global_vectors = torch.cat([imgs_pred, imgs], dim=0)

                    # Feed input to the critic to get scores
                    local_critic_out, global_critic_out = self.critic(
                        local_vectors, global_vectors, f_obj_to_img)

                    # Split scores to (fake, real)
                    fake_local_pred, real_local_pred = torch.split(
                        local_critic_out, imgs.shape[0], dim=0)
                    fake_global_pred, real_global_pred = torch.split(
                        global_critic_out, imgs.shape[0], dim=0)

                    critic_pred = (fake_local_pred, real_local_pred,
                                   fake_global_pred, real_global_pred)

                    # for gradient penalty of critic
                    gp_vectors = (real_local_patches, fake_local_patches)
                    gp_masks = (ca_masks, patch_masks)

                # Generator Step
                total_loss, losses = self.generator_step(
                    step_vars, model_out, critic_pred)

                # Logging
                loss = {}

                loss['G/total_loss'] = total_loss.data.item()
                loss = self.log_losses(loss, 'G', losses)

                # Discriminator Step
                total_loss, dis_obj_losses, dis_img_losses, critic_losses = self.discriminator_step(
                    imgs_pred, step_vars, f_obj_to_img, critic_pred,
                    gp_vectors, gp_masks)

                loss['D/total_loss'] = total_loss.data.item()
                loss = self.log_losses(loss, 'D', dis_obj_losses)
                loss = self.log_losses(loss, 'D', dis_img_losses)
                loss = self.log_losses(loss, 'C', critic_losses)

                # Print out log info
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    total_time = (self.num_iterations - iter_ctr +
                                  1) * elapsed / (iter_ctr + 1)
                    epoch_time = (iters_per_epoch - i) * elapsed / \
                        (iter_ctr + 1)

                    epoch_time = str(datetime.timedelta(seconds=epoch_time))
                    total_time = str(datetime.timedelta(seconds=total_time))
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed {}/{} -- {} , Iteration [{}/{}], Epoch [{}]".format(
                        elapsed, epoch_time, total_time, iter_ctr + 1,
                        self.num_iterations, e)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value,
                                                       iter_ctr + 1)

                # Save model checkpoints
                if (iter_ctr + 1) % self.model_save_step == 0:
                    torch.save(
                        self.generator.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_G.pth'.format(iter_ctr + 1)))
                    torch.save(
                        self.obj_discriminator.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_D_OBJ.pth'.format(iter_ctr + 1)))
                    torch.save(
                        self.img_discriminator.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_D_IMG.pth'.format(iter_ctr + 1)))
                    torch.save(
                        self.critic.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_C_IMG.pth'.format(iter_ctr + 1)))

                if (iter_ctr + 1) % self.sample_step == 0:
                    self.sample_images(fixed_batch, iter_ctr)

                iter_ctr += 1
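
LossManager appears throughout these excerpts but is never defined in them. For reference, here is a minimal sketch of the interface the code above relies on (add_loss, total_loss, and iterable named values for logging); the actual class may apply per-loss weights or store raw tensors instead.

class LossManager(object):
    # Minimal accumulator: keeps a running total for .backward()
    # and a dict of named scalar values for logging
    def __init__(self):
        self.total_loss = None
        self.all_losses = {}

    def add_loss(self, loss, name):
        self.total_loss = loss if self.total_loss is None else self.total_loss + loss
        self.all_losses[name] = loss.item()

    def items(self):
        return self.all_losses.items()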