Code example #1
    def G_step(result):
        imgs, imgs_pred, objs = result.imgs, result.imgs_pred, result.objs
        g_scores_fake_crop = result.g_scores_fake_crop
        g_obj_scores_fake_crop = result.g_obj_scores_fake_crop
        g_scores_fake_img = result.g_scores_fake_img
        mask_noise_indexes = result.mask_noise_indexes
        g_rec_feature_fake_crop = result.g_rec_feature_fake_crop
        obj_fmaps = result.obj_fmaps
        g_scores_fake_bg = result.g_scores_fake_bg

        with timeit('loss', args.timing):
            total_loss, losses = calculate_model_losses(
                args, imgs, imgs_pred, mask_noise_indexes)

            if criterionVGG is not None:
                if mask_noise_indexes is not None and args.perceptual_not_on_noise:
                    perceptual_loss = criterionVGG(
                        imgs_pred[mask_noise_indexes],
                        imgs[mask_noise_indexes])
                else:
                    perceptual_loss = criterionVGG(imgs_pred, imgs)
                total_loss = add_loss(total_loss, perceptual_loss, losses,
                                      'perceptual_loss',
                                      args.perceptual_loss_weight)

            if all_in_one_model.obj_discriminator is not None:
                total_loss = add_loss(
                    total_loss, F.cross_entropy(g_obj_scores_fake_crop, objs),
                    losses, 'ac_loss', args.ac_loss_weight)
                weight = args.discriminator_loss_weight * args.d_obj_weight
                total_loss = add_loss(total_loss,
                                      gan_g_loss(g_scores_fake_crop), losses,
                                      'g_gan_obj_loss', weight)
                if args.d_obj_rec_feat_weight > 0:
                    total_loss = add_loss(
                        total_loss,
                        F.l1_loss(g_rec_feature_fake_crop, obj_fmaps), losses,
                        'g_obj_fea_rec_loss', args.d_obj_rec_feat_weight)

            if all_in_one_model.img_discriminator is not None:
                weight = args.discriminator_loss_weight * args.d_img_weight
                total_loss = add_loss(total_loss,
                                      gan_g_loss(g_scores_fake_img), losses,
                                      'g_gan_img_loss', weight)

            if all_in_one_model.bg_discriminator is not None:
                weight = args.discriminator_loss_weight * args.d_bg_weight
                total_loss = add_loss(total_loss, gan_g_loss(g_scores_fake_bg),
                                      losses, 'g_gan_bg_loss', weight)

        losses['total_loss'] = total_loss.item()

        if math.isfinite(losses['total_loss']):
            with timeit('backward', args.timing):
                all_in_one_model.optimizer.zero_grad()
                total_loss.backward()
                all_in_one_model.optimizer.step()

        return losses
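The G_step above leans on an add_loss helper defined elsewhere in the repository. As a point of reference, a minimal sketch with a compatible signature might look like the following (an assumption, not the project's actual implementation):

def add_loss(total_loss, curr_loss, loss_dict, loss_name, weight=1.0):
    # apply the weight, record the scalar value for logging, and fold the
    # term into the running total (creating it on the first call)
    curr_loss = curr_loss * weight
    loss_dict[loss_name] = curr_loss.item()
    if total_loss is None:
        return curr_loss
    return total_loss + curr_loss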
Code example #2
def main(args):

    if args.device == 'cpu':
        device = torch.device('cpu')
    elif args.device == 'gpu':
        device = torch.device('cuda:0')
        if not torch.cuda.is_available():
            print('WARNING: CUDA not available; falling back to CPU')
            device = torch.device('cpu')

    # Load the model, with a bit of care in case there are no GPUs
    map_location = 'cpu' if device == torch.device('cpu') else None
    checkpoint = torch.load(args.checkpoint, map_location=map_location)
    model = Sg2ImModel(**checkpoint['model_kwargs'])
    model.load_state_dict(checkpoint['model_state'], strict=False)
    model.eval()
    model.to(device)

    vocab, train_loader, val_loader = build_loaders(args)

    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)
        print('Created %s' % args.output_dir)

    ## add code for validation visualization
    #logger = Logger(args.output_dir)
    logger = None
    t = 1
    with timeit('forward', args.timing):
        print('checking on val')

        import json
        # start a JSON array in the output file; it is closed again below,
        # after the validation pass has run
        with open('vg_captions_500.json', 'w') as f:
            f.write('[')

        # check_model_predicate_debug(args, t, val_loader, model, logger=logger, log_tag='Validation', write_images=True)
        val_results = check_model(args,
                                  t,
                                  val_loader,
                                  model,
                                  device,
                                  logger=logger,
                                  log_tag='Validation',
                                  write_images=True)

        # rel_score, avg_iou = get_rel_score(args, t, val_loader, model)
        # print ('relation score: ', rel_score)
        # print ('average iou: ', avg_iou)
        # val_losses, val_avg_iou = val_results
        # print('val iou: ', val_avg_iou)

        with open('vg_captions_500.json', 'a') as f:
            f.write(']')
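All of these scripts wrap their forward passes in timeit(name, args.timing). A plausible sketch of such a context manager, assuming it simply prints elapsed wall-clock time when the timing flag is set:

import time
from contextlib import contextmanager

@contextmanager
def timeit(msg, should_time=True):
    start = time.time() if should_time else None
    yield
    if should_time:
        print('%s: %.2f ms' % (msg, (time.time() - start) * 1000.0))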
Code example #3
def main(args):

    if args.device == 'cpu':
        device = torch.device('cpu')
    elif args.device == 'gpu':
        device = torch.device('cuda:0')
        if not torch.cuda.is_available():
            print('WARNING: CUDA not available; falling back to CPU')
            device = torch.device('cpu')

    # Load the model, with a bit of care in case there are no GPUs
    map_location = 'cpu' if device == torch.device('cpu') else None
    checkpoint = torch.load(args.checkpoint, map_location=map_location)
    # for flags added after model trained.
    checkpoint['model_kwargs']['triplet_box_net'] = args.triplet_box_net
    checkpoint['model_kwargs']['triplet_mask_size'] = args.triplet_mask_size
    checkpoint['model_kwargs'][
        'triplet_embedding_size'] = args.triplet_embedding_size
    model = Sg2ImModel(**checkpoint['model_kwargs'])
    model.load_state_dict(checkpoint['model_state'], strict=False)
    model.eval()
    model.to(device)

    vocab, train_loader, val_loader = build_loaders(args)

    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)
        print('Created %s' % args.output_dir)

    ## add code for validation visualization
    #logger = Logger(args.output_dir)
    logger = None

    t = 1
    with timeit('forward', args.timing):
        #print('Extracting embeddings from train set:')
        #train_results = check_model(args, t, train_loader, model, log_tag='Train', write_images=False)
        print('Extracting embeddings from val test set:')
        val_results = check_model(args,
                                  t,
                                  val_loader,
                                  model,
                                  log_tag='Validation',
                                  write_images=True)
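The three assignments into checkpoint['model_kwargs'] above unconditionally overwrite whatever the checkpoint stored. If the intent is only to supply defaults for flags added after the model was trained, a hedged alternative is to fill in missing keys without clobbering existing ones:

for key in ('triplet_box_net', 'triplet_mask_size', 'triplet_embedding_size'):
    # only add the kwarg if the checkpoint predates the flag
    checkpoint['model_kwargs'].setdefault(key, getattr(args, key))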
Code example #4
def main(args):

    if args.device == 'cpu':
        device = torch.device('cpu')
    elif args.device == 'gpu':
        device = torch.device('cuda:0')
        if not torch.cuda.is_available():
            print('WARNING: CUDA not available; falling back to CPU')
            device = torch.device('cpu')

    # Load the model, with a bit of care in case there are no GPUs
    map_location = 'cpu' if device == torch.device('cpu') else None
    checkpoint = torch.load(args.checkpoint, map_location=map_location)
    model = Sg2ImModel(**checkpoint['model_kwargs'])
    model.load_state_dict(checkpoint['model_state'], strict=False)
    model.eval()
    model.to(device)

    vocab, train_loader, val_loader = build_loaders(args)

    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)
        print('Created %s' % args.output_dir)

    ## add code for validation visualization
    #logger = Logger(args.output_dir)
    logger = None
    t = 1
    with timeit('forward', args.timing):
        print('checking on val')

        check_model_predicate_debug(args,
                                    t,
                                    val_loader,
                                    model,
                                    logger=logger,
                                    log_tag='Validation',
                                    write_images=True)
Code example #5
File: model.py  Project: LUGUANSONG/i2g2i
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        # forward detector
        with timeit('detector forward', self.args.timing):
            result = self.detector(x,
                                   im_sizes,
                                   image_offset,
                                   gt_boxes,
                                   gt_classes,
                                   gt_rels,
                                   proposals,
                                   train_anchor_inds,
                                   return_fmap=True)
        if result.is_none():
            raise ValueError("detector returned an empty result")

        # forward generator
        imgs = F.interpolate(x, size=self.args.image_size)
        objs = result.obj_preds
        boxes = result.rm_box_priors / BOX_SCALE
        obj_to_img = result.im_inds - image_offset
        obj_fmap = result.obj_fmap

        # check whether every image has at least one detection; if not, drop
        # the images without detections and remap obj_to_img to the kept ones
        cnt = torch.zeros(len(imgs)).byte()
        cnt[obj_to_img] += 1
        if (cnt > 0).sum() != len(imgs):
            print("some images have no detection")
            print(cnt)
            imgs = imgs[cnt]
            obj_to_img_new = obj_to_img.clone()
            for i in range(len(cnt)):
                if cnt[i] == 0:
                    obj_to_img_new -= (obj_to_img > i).long()
            obj_to_img = obj_to_img_new

        with timeit('generator forward', self.args.timing):
            imgs_pred = self.model(obj_to_img, boxes, obj_fmap)

        # forward discriminators to train generator
        if self.obj_discriminator is not None:
            with timeit('d_obj forward for g', self.args.timing):
                g_scores_fake_crop, g_obj_scores_fake_crop = self.obj_discriminator(
                    imgs_pred, objs, boxes, obj_to_img)

        if self.img_discriminator is not None:
            with timeit('d_img forward for g', self.args.timing):
                g_scores_fake_img = self.img_discriminator(imgs_pred)

        # forward discriminators to train discriminators
        if self.obj_discriminator is not None:
            imgs_fake = imgs_pred.detach()
            with timeit('d_obj forward for d', self.args.timing):
                d_scores_fake_crop, d_obj_scores_fake_crop = self.obj_discriminator(
                    imgs_fake, objs, boxes, obj_to_img)
                d_scores_real_crop, d_obj_scores_real_crop = self.obj_discriminator(
                    imgs, objs, boxes, obj_to_img)

        if self.img_discriminator is not None:
            imgs_fake = imgs_pred.detach()
            with timeit('d_img forward for d', self.args.timing):
                d_scores_fake_img = self.img_discriminator(imgs_fake)
                d_scores_real_img = self.img_discriminator(imgs)

        return Result(imgs=imgs,
                      imgs_pred=imgs_pred,
                      objs=objs,
                      g_scores_fake_crop=g_scores_fake_crop,
                      g_obj_scores_fake_crop=g_obj_scores_fake_crop,
                      g_scores_fake_img=g_scores_fake_img,
                      d_scores_fake_crop=d_scores_fake_crop,
                      d_obj_scores_fake_crop=d_obj_scores_fake_crop,
                      d_scores_real_crop=d_scores_real_crop,
                      d_obj_scores_real_crop=d_obj_scores_real_crop,
                      d_scores_fake_img=d_scores_fake_img,
                      d_scores_real_img=d_scores_real_img)
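The cnt-based block in the forward pass above drops images that received no detections and remaps obj_to_img with an explicit Python loop. An equivalent vectorized sketch, shown only for illustration and assuming the same imgs and obj_to_img tensors:

keep = torch.zeros(len(imgs), dtype=torch.bool, device=obj_to_img.device)
keep[obj_to_img] = True                            # images owning >= 1 detection
new_index = torch.cumsum(keep.long(), dim=0) - 1   # old image index -> new index
imgs = imgs[keep.to(imgs.device)]
obj_to_img = new_index[obj_to_img]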
Code example #6
    def forward(self, imgs, img_offset, gt_boxes, gt_classes, gt_fmaps):
        obj_to_img = gt_classes[:, 0] - img_offset
        # print("obj_to_img.min(), obj_to_img.max(), len(imgs) {} {} {}".format(obj_to_img.min(), obj_to_img.max(), len(imgs)))
        assert obj_to_img.min() >= 0 and obj_to_img.max() < len(imgs), \
            "obj_to_img.min() >= 0 and obj_to_img.max() < len(imgs) is not satidfied: {} {} {}".format(obj_to_img.min(), obj_to_img.max(), len(imgs))
        boxes = gt_boxes
        obj_fmaps = gt_fmaps
        objs = gt_classes[:, 1]

        if self.args is not None:
            if self.args.exchange_feat_cls:
                print("exchange feature vectors and classes among bboxes")
                for img_ind in range(imgs.shape[0]):
                    ind = (obj_to_img == img_ind).nonzero()[:, 0]
                    # permute = ind[torch.randperm(len(ind))]
                    # obj_fmaps[ind] = obj_fmaps[permute]
                    permute_ind = ind[torch.randperm(len(ind))[:2]]
                    permute = permute_ind[[1, 0]]
                    obj_fmaps[permute_ind] = obj_fmaps[permute]
                    objs[permute_ind] = objs[permute]

            if self.args.change_bbox:
                print("change the position of bboxes")
                for img_ind in range(imgs.shape[0]):
                    ind = (obj_to_img == img_ind).nonzero()[:, 0]
                    ind = ind[torch.randperm(len(ind))[0]]
                    if boxes[ind][3] < 0.8:
                        print("move to bottom")
                        boxes[ind][1] += (1 - boxes[ind][3])
                        boxes[ind][3] = 1
                    elif boxes[ind][1] > 0.2:
                        print("move to top")
                        boxes[ind][3] -= boxes[ind][1]
                        boxes[ind][1] = 0
                    elif boxes[ind][0] > 0.2:
                        print("move to left")
                        boxes[ind][2] -= boxes[ind][0]
                        boxes[ind][0] = 0
                    elif boxes[ind][2] < 0.8:
                        print("move to right")
                        boxes[ind][0] += (1 - boxes[ind][2])
                        boxes[ind][2] = 1
                    else:
                        print("move to bottom right")
                        boxes[ind][1] += (1 - boxes[ind][3])
                        boxes[ind][3] = 1
                        boxes[ind][0] += (1 - boxes[ind][2])
                        boxes[ind][2] = 1

        # split the batch in two: the first half gets latents from the object
        # encoder ("encoded"), the second half gets latents drawn from the
        # prior ("random"); per-half obj_to_img, boxes, obj_fmaps and
        # mask_noise_indexes are gathered below
        half_size = imgs.shape[0] // 2

        obj_index_encoded = []
        obj_index_random = []
        for ind in range(half_size):
            obj_index_encoded.append((obj_to_img == ind).nonzero()[:, 0])
        obj_index_encoded = torch.cat(obj_index_encoded)
        for ind in range(half_size, imgs.shape[0]):
            obj_index_random.append((obj_to_img == ind).nonzero()[:, 0])
        obj_index_random = torch.cat(obj_index_random)

        imgs_encoded = imgs[:half_size]
        obj_to_img_encoded = obj_to_img[obj_index_encoded]
        boxes_encoded = boxes[obj_index_encoded]
        obj_fmaps_encoded = obj_fmaps[obj_index_encoded]
        mask_noise_indexes_encoded = torch.randperm(
            half_size)[:int(self.args.noise_mask_ratio * half_size)].to(
                imgs.device)
        if len(mask_noise_indexes_encoded) == 0:
            mask_noise_indexes_encoded = None
        crops_encoded = crop_bbox_batch(imgs_encoded, boxes_encoded,
                                        obj_to_img_encoded,
                                        self.args.crop_size)

        imgs_random = imgs[half_size:]
        obj_to_img_random = obj_to_img[obj_index_random] - half_size
        boxes_random = boxes[obj_index_random]
        obj_fmaps_random = obj_fmaps[obj_index_random]
        mask_noise_indexes_random = torch.randperm(imgs.shape[0] - half_size)\
            [:int(self.args.noise_mask_ratio * (imgs.shape[0] - half_size))].to(imgs.device)
        if len(mask_noise_indexes_random) == 0:
            mask_noise_indexes_random = None
        # crops_random = crop_bbox_batch(imgs_random, boxes_random, obj_to_img_random, self.args.crop_size)

        mask_noise_indexes = None
        if mask_noise_indexes_encoded is not None:
            mask_noise_indexes = mask_noise_indexes_encoded
        if mask_noise_indexes_random is not None:
            if mask_noise_indexes is not None:
                mask_noise_indexes = torch.cat([
                    mask_noise_indexes, mask_noise_indexes_random + half_size
                ])
            else:
                mask_noise_indexes = mask_noise_indexes_random + half_size

        if self.forward_G:
            with timeit('generator forward', self.args.timing):
                if self.training:
                    mu_encoded, logvar_encoded = self.obj_encoder(
                        crops_encoded)
                    std = logvar_encoded.mul(0.5).exp_()
                    eps = torch.randn((std.size(0), std.size(1)),
                                      dtype=std.dtype,
                                      device=std.device)
                    z_encoded = eps.mul(std).add_(mu_encoded)
                    z_random = torch.randn((obj_fmaps_random.shape[0],
                                            self.args.object_noise_dim),
                                           dtype=obj_fmaps_random.dtype,
                                           device=obj_fmaps_random.device)

                    imgs_pred_encoded, layout_encoded = self.model(
                        obj_to_img_encoded,
                        boxes_encoded,
                        obj_fmaps_encoded,
                        mask_noise_indexes=mask_noise_indexes_encoded,
                        object_noise=z_encoded)
                    imgs_pred_random, layout_random = self.model(
                        obj_to_img_random,
                        boxes_random,
                        obj_fmaps_random,
                        mask_noise_indexes=mask_noise_indexes_random,
                        object_noise=z_random)

                    crops_pred_encoded = crop_bbox_batch(
                        imgs_pred_encoded, boxes_encoded, obj_to_img_encoded,
                        self.args.crop_size)

                    crops_pred_random = crop_bbox_batch(
                        imgs_pred_random, boxes_random, obj_to_img_random,
                        self.args.crop_size)
                    mu_rec, logvar_rec = self.obj_encoder(crops_pred_random)
                    z_random_rec = mu_rec

                    imgs_pred = torch.cat(
                        [imgs_pred_encoded, imgs_pred_random], dim=0)

                    layout = torch.cat([layout_encoded, layout_random],
                                       dim=0).detach()
                else:
                    z_random = torch.randn(
                        (obj_fmaps.shape[0], self.args.object_noise_dim),
                        dtype=obj_fmaps.dtype,
                        device=obj_fmaps.device)
                    imgs_pred, layout = self.model(
                        obj_to_img,
                        boxes,
                        obj_fmaps,
                        mask_noise_indexes=mask_noise_indexes,
                        object_noise=z_random)
                    layout = layout.detach()
                    crops_encoded = None
                    crops_pred_encoded = None
                    z_random = None
                    z_random_rec = None
                    mu_encoded = None
                    logvar_encoded = None

        H, W = self.args.image_size
        bg_layout = boxes_to_layout(
            torch.ones(boxes.shape[0], 3).to(imgs.device), boxes, obj_to_img,
            H, W)
        bg_layout = (bg_layout <= 0).type(imgs.dtype)

        if self.args.condition_d_img_on_class_label_map:
            layout = boxes_to_layout(
                (objs + 1).view(-1, 1).repeat(1, 3).type(imgs.dtype), boxes,
                obj_to_img, H, W)

        g_scores_fake_crop, g_obj_scores_fake_crop, g_rec_feature_fake_crop = None, None, None
        g_scores_fake_img = None
        g_scores_fake_bg = None
        if self.calc_G_D_loss:
            # forward discriminators to train generator
            if self.obj_discriminator is not None:
                with timeit('d_obj forward for g', self.args.timing):
                    g_scores_fake_crop, g_obj_scores_fake_crop, _, g_rec_feature_fake_crop = \
                        self.obj_discriminator(imgs_pred, objs, boxes, obj_to_img)

            if self.img_discriminator is not None:
                with timeit('d_img forward for g', self.args.timing):
                    if self.args.condition_d_img:
                        g_scores_fake_img = self.img_discriminator(
                            imgs_pred, layout)
                    else:
                        g_scores_fake_img = self.img_discriminator(imgs_pred)

            if self.bg_discriminator is not None:
                with timeit('d_bg forward for g', self.args.timing):
                    if self.args.condition_d_bg:
                        g_scores_fake_bg = self.bg_discriminator(
                            imgs_pred, bg_layout)
                    else:
                        g_scores_fake_bg = self.bg_discriminator(imgs_pred *
                                                                 bg_layout)

        d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = None, None, None, None
        d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = None, None, None, None
        d_obj_gp = None
        d_scores_fake_img = None
        d_scores_real_img = None
        d_img_gp = None
        d_scores_fake_bg = None
        d_scores_real_bg = None
        d_bg_gp = None
        if self.forward_D:
            # forward discriminators to train discriminators
            if self.obj_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_obj forward for d', self.args.timing):
                    d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = \
                        self.obj_discriminator(imgs_fake, objs, boxes, obj_to_img)
                    d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = \
                        self.obj_discriminator(imgs, objs, boxes, obj_to_img)
                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        d_obj_gp = gradient_penalty(
                            real_crops.detach(), fake_crops.detach(),
                            self.obj_discriminator.discriminator)

            if self.img_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_img forward for d', self.args.timing):
                    if self.args.condition_d_img:
                        d_scores_fake_img = self.img_discriminator(
                            imgs_fake, layout)
                        d_scores_real_img = self.img_discriminator(
                            imgs, layout)
                    else:
                        d_scores_fake_img = self.img_discriminator(imgs_fake)
                        d_scores_real_img = self.img_discriminator(imgs)

                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        if self.args.condition_d_img:
                            d_img_gp = gradient_penalty(
                                torch.cat([imgs, layout], dim=1),
                                torch.cat([imgs_fake, layout], dim=1),
                                self.img_discriminator)
                        else:
                            d_img_gp = gradient_penalty(
                                imgs, imgs_fake, self.img_discriminator)

            if self.bg_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_bg forward for d', self.args.timing):
                    if self.args.condition_d_bg:
                        d_scores_fake_bg = self.bg_discriminator(
                            imgs_fake, bg_layout)
                        d_scores_real_bg = self.bg_discriminator(
                            imgs, bg_layout)
                    else:
                        d_scores_fake_bg = self.bg_discriminator(imgs_fake *
                                                                 bg_layout)
                        d_scores_real_bg = self.bg_discriminator(imgs *
                                                                 bg_layout)

                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        if self.args.condition_d_bg:
                            d_bg_gp = gradient_penalty(
                                torch.cat([imgs, bg_layout], dim=1),
                                torch.cat([imgs_fake, bg_layout], dim=1),
                                self.bg_discriminator)
                        else:
                            d_bg_gp = gradient_penalty(imgs * bg_layout,
                                                       imgs_fake * bg_layout,
                                                       self.bg_discriminator)
        return Result(imgs=imgs,
                      imgs_pred=imgs_pred,
                      objs=objs,
                      obj_fmaps=obj_fmaps,
                      boxes=boxes,
                      obj_to_img=obj_to_img + img_offset,
                      g_scores_fake_crop=g_scores_fake_crop,
                      g_obj_scores_fake_crop=g_obj_scores_fake_crop,
                      g_scores_fake_img=g_scores_fake_img,
                      d_scores_fake_crop=d_scores_fake_crop,
                      d_obj_scores_fake_crop=d_obj_scores_fake_crop,
                      d_scores_real_crop=d_scores_real_crop,
                      d_obj_scores_real_crop=d_obj_scores_real_crop,
                      d_scores_fake_img=d_scores_fake_img,
                      d_scores_real_img=d_scores_real_img,
                      d_obj_gp=d_obj_gp,
                      d_img_gp=d_img_gp,
                      fake_crops=fake_crops,
                      real_crops=real_crops,
                      mask_noise_indexes=(mask_noise_indexes + img_offset)
                      if mask_noise_indexes is not None else None,
                      g_rec_feature_fake_crop=g_rec_feature_fake_crop,
                      d_rec_feature_fake_crop=d_rec_feature_fake_crop,
                      d_rec_feature_real_crop=d_rec_feature_real_crop,
                      g_scores_fake_bg=g_scores_fake_bg,
                      d_scores_fake_bg=d_scores_fake_bg,
                      d_scores_real_bg=d_scores_real_bg,
                      d_bg_gp=d_bg_gp,
                      bg_layout=bg_layout,
                      crops_encoded=crops_encoded,
                      crops_pred_encoded=crops_pred_encoded,
                      z_random=z_random,
                      z_random_rec=z_random_rec,
                      mu_encoded=mu_encoded,
                      logvar_encoded=logvar_encoded)
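Code examples #6 and #9 call gradient_penalty when gan_loss_type is "wgan-gp". A minimal sketch of a WGAN-GP penalty with a compatible signature (an assumption about the helper, not its actual source):

import torch

def gradient_penalty(real, fake, discriminator, weight=10.0):
    # interpolate between real and fake samples and penalize critic gradients
    # whose norm deviates from 1 (Gulrajani et al., 2017)
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
    scores = discriminator(interp)
    grads = torch.autograd.grad(outputs=scores.sum(), inputs=interp,
                                create_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return weight * ((grad_norm - 1.0) ** 2).mean()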
Code example #7
    def G_step(result):
        imgs, imgs_pred, objs = result.imgs, result.imgs_pred, result.objs
        g_scores_fake_crop = result.g_scores_fake_crop
        g_obj_scores_fake_crop = result.g_obj_scores_fake_crop
        g_scores_fake_img = result.g_scores_fake_img
        mask_noise_indexes = result.mask_noise_indexes
        g_rec_feature_fake_crop = result.g_rec_feature_fake_crop
        obj_fmaps = result.obj_fmaps
        g_scores_fake_bg = result.g_scores_fake_bg

        bg_layout = result.bg_layout
        crops_encoded = result.crops_encoded
        crops_pred_encoded = result.crops_pred_encoded
        z_random = result.z_random
        z_random_rec = result.z_random_rec
        mu_encoded = result.mu_encoded
        logvar_encoded = result.logvar_encoded

        with timeit('loss', args.timing):
            total_loss, losses = calculate_model_losses(
                args, imgs, imgs_pred, mask_noise_indexes, bg_layout)

            crops_encoded_rec_loss = F.l1_loss(crops_pred_encoded,
                                               crops_encoded)
            total_loss = add_loss(total_loss, crops_encoded_rec_loss, losses,
                                  'crops_encoded_rec_loss',
                                  args.crops_encoded_rec_loss_weight)

            kl_loss = torch.sum(1 + logvar_encoded - mu_encoded.pow(2) -
                                logvar_encoded.exp()) * (-0.5)
            total_loss = add_loss(total_loss, kl_loss, losses, 'kl_loss',
                                  args.kl_loss_weight)

            if criterionVGG is not None:
                if args.perceptual_on_bg:
                    perceptual_imgs = imgs * bg_layout
                    perceptual_imgs_pred = imgs_pred * bg_layout
                else:
                    # without this branch the two variables below are undefined
                    # whenever perceptual_on_bg is off
                    perceptual_imgs = imgs
                    perceptual_imgs_pred = imgs_pred
                if mask_noise_indexes is not None and args.perceptual_not_on_noise:
                    perceptual_loss = criterionVGG(
                        perceptual_imgs_pred[mask_noise_indexes],
                        perceptual_imgs[mask_noise_indexes])
                else:
                    perceptual_loss = criterionVGG(perceptual_imgs_pred,
                                                   perceptual_imgs)
                total_loss = add_loss(total_loss, perceptual_loss, losses,
                                      'perceptual_loss',
                                      args.perceptual_loss_weight)

            if all_in_one_model.obj_discriminator is not None:
                total_loss = add_loss(
                    total_loss, F.cross_entropy(g_obj_scores_fake_crop, objs),
                    losses, 'ac_loss', args.ac_loss_weight)
                weight = args.discriminator_loss_weight * args.d_obj_weight
                total_loss = add_loss(total_loss,
                                      gan_g_loss(g_scores_fake_crop), losses,
                                      'g_gan_obj_loss', weight)
                if args.d_obj_rec_feat_weight > 0:
                    total_loss = add_loss(
                        total_loss,
                        F.l1_loss(g_rec_feature_fake_crop, obj_fmaps), losses,
                        'g_obj_fea_rec_loss', args.d_obj_rec_feat_weight)

            if all_in_one_model.img_discriminator is not None:
                weight = args.discriminator_loss_weight * args.d_img_weight
                total_loss = add_loss(total_loss,
                                      gan_g_loss(g_scores_fake_img), losses,
                                      'g_gan_img_loss', weight)

            if all_in_one_model.bg_discriminator is not None:
                weight = args.discriminator_loss_weight * args.d_bg_weight
                total_loss = add_loss(total_loss, gan_g_loss(g_scores_fake_bg),
                                      losses, 'g_gan_bg_loss', weight)

        losses['total_loss'] = total_loss.item()

        if math.isfinite(losses['total_loss']):
            with timeit('backward', args.timing):
                all_in_one_model.optimizer_e_obj.zero_grad()
                all_in_one_model.optimizer.zero_grad()
                total_loss.backward(retain_graph=True)
                all_in_one_model.optimizer.step()
                all_in_one_model.optimizer_e_obj.step()

        z_random_rec_loss = torch.mean(
            torch.abs(z_random_rec - z_random)) * args.z_random_rec_loss_weight
        all_in_one_model.optimizer.zero_grad()
        all_in_one_model.optimizer_e_obj.zero_grad()
        z_random_rec_loss.backward()
        all_in_one_model.optimizer.step()

        total_loss = add_loss(total_loss, z_random_rec_loss, losses,
                              'z_random_rec_loss', 1.)
        losses['total_loss'] = total_loss.item()

        return losses
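gan_g_loss and gan_d_loss are produced by get_gan_losses(args.gan_loss_type) in the training scripts (code examples #8 and #10). A hedged sketch of such a factory covering the two loss styles that appear here; the 'gan' branch name is assumed, only 'wgan-gp' is spelled out in the source:

import torch
import torch.nn.functional as F

def get_gan_losses(gan_loss_type):
    if gan_loss_type == 'gan':
        # non-saturating GAN losses on raw discriminator logits
        def g_loss(scores_fake):
            return F.binary_cross_entropy_with_logits(
                scores_fake, torch.ones_like(scores_fake))

        def d_loss(scores_real, scores_fake):
            loss_real = F.binary_cross_entropy_with_logits(
                scores_real, torch.ones_like(scores_real))
            loss_fake = F.binary_cross_entropy_with_logits(
                scores_fake, torch.zeros_like(scores_fake))
            return loss_real + loss_fake
    elif gan_loss_type == 'wgan-gp':
        # Wasserstein critic scores; the gradient penalty is added separately
        def g_loss(scores_fake):
            return -scores_fake.mean()

        def d_loss(scores_real, scores_fake):
            return scores_fake.mean() - scores_real.mean()
    else:
        raise ValueError('Unrecognized gan_loss_type: %s' % gan_loss_type)
    return g_loss, d_loss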
Code example #8
File: train.py  Project: LUGUANSONG/i2g2i
def main(args):
    print(args)
    check_args(args)
    float_dtype = torch.cuda.FloatTensor
    long_dtype = torch.cuda.LongTensor
    detector_gather_device = args.num_gpus - 1
    sg2im_device = torch.device(args.num_gpus - 1)
    args.detector_gather_device = detector_gather_device
    args.sg2im_device = sg2im_device
    if not exists(args.output_dir):
        os.makedirs(args.output_dir)
    summary_writer = SummaryWriter(args.output_dir)

    # vocab, train_loader, val_loader = build_loaders(args)
    # self.ind_to_classes, self.ind_to_predicates
    vocab = {
        'object_idx_to_name': load_detector.train.ind_to_classes,
    }
    model, model_kwargs = build_model(args)  # vocab is no longer passed here

    # model.type(float_dtype)
    model = model.to(sg2im_device)
    # model = DataParallel(model, list(range(args.num_gpus)))
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
    img_discriminator, d_img_kwargs = build_img_discriminator(args)  #, vocab)
    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

    if obj_discriminator is not None:
        # obj_discriminator.type(float_dtype)
        obj_discriminator = obj_discriminator.to(sg2im_device)
        # obj_discriminator = DataParallel(obj_discriminator, list(range(args.num_gpus)))
        obj_discriminator.train()
        print(obj_discriminator)
        optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                           lr=args.learning_rate)

    if img_discriminator is not None:
        # img_discriminator.type(float_dtype)
        img_discriminator = img_discriminator.to(sg2im_device)
        # img_discriminator = DataParallel(img_discriminator, list(range(args.num_gpus)))
        img_discriminator.train()
        print(img_discriminator)
        optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                           lr=args.learning_rate)

    restore_path = None
    if args.restore_from_checkpoint:
        restore_path = '%s_with_model.pt' % args.checkpoint_name
        restore_path = os.path.join(args.output_dir, restore_path)
    if restore_path is not None and os.path.isfile(restore_path):
        print('Restoring from checkpoint:')
        print(restore_path)
        checkpoint = torch.load(restore_path)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optim_state'])

        if obj_discriminator is not None:
            obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
            optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])

        if img_discriminator is not None:
            img_discriminator.load_state_dict(checkpoint['d_img_state'])
            optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])

        t = checkpoint['counters']['t']
        if 0 <= args.eval_mode_after <= t:
            model.eval()
        else:
            model.train()
        epoch = checkpoint['counters']['epoch']
    else:
        t, epoch = 0, 0
        checkpoint = {
            # 'args': args.__dict__,
            'vocab': vocab,
            'model_kwargs': model_kwargs,
            'd_obj_kwargs': d_obj_kwargs,
            'd_img_kwargs': d_img_kwargs,
            'losses_ts': [],
            'losses': defaultdict(list),
            'd_losses': defaultdict(list),
            'checkpoint_ts': [],
            'train_batch_data': [],
            'train_samples': [],
            'train_iou': [],
            'val_batch_data': [],
            'val_samples': [],
            'val_losses': defaultdict(list),
            'val_iou': [],
            'norm_d': [],
            'norm_g': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'model_state': None,
            'model_best_state': None,
            'optim_state': None,
            'd_obj_state': None,
            'd_obj_best_state': None,
            'd_obj_optim_state': None,
            'd_img_state': None,
            'd_img_best_state': None,
            'd_img_optim_state': None,
            'best_t': [],
        }

    while True:
        if t >= args.num_iterations:
            break
        epoch += 1
        print('Starting epoch %d' % epoch)

        # for batch in train_loader:
        # for batch in train_detector.train_loader:
        for step, batch in enumerate(
                tqdm(load_detector.train_loader,
                     desc='Training Epoch %d' % epoch,
                     total=len(load_detector.train_loader))):
            if t == args.eval_mode_after:
                print('switching to eval mode')
                model.eval()
                optimizer = torch.optim.Adam(model.parameters(),
                                             lr=args.learning_rate)
            t += 1
            # batch = [tensor.cuda() for tensor in batch]
            # masks = None
            # if len(batch) == 6:
            #   imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
            # elif len(batch) == 7:
            #   imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
            # else:
            #   assert False
            # predicates = triples[:, 1]

            with timeit('forward', args.timing):
                # model_boxes = boxes
                # model_masks = masks
                # model_out = model(objs, triples, obj_to_img,
                #                   boxes_gt=model_boxes, masks_gt=model_masks)
                # imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
                imgs = F.interpolate(batch.imgs,
                                     size=args.image_size).to(sg2im_device)
                with torch.no_grad():
                    if args.num_gpus > 2:
                        result = load_detector.detector.__getitem__(
                            batch, target_device=detector_gather_device)
                    else:
                        result = load_detector.detector[batch]
                objs = result.obj_preds
                boxes = result.rm_box_priors
                obj_to_img = result.im_inds
                obj_fmap = result.obj_fmap
                if args.num_gpus == 2:
                    objs = objs.to(sg2im_device)
                    boxes = boxes.to(sg2im_device)
                    obj_to_img = obj_to_img.to(sg2im_device)
                    obj_fmap = obj_fmap.to(sg2im_device)

                boxes /= load_detector.IM_SCALE
                # check whether every image has at least one detection; if
                # not, drop them and remap obj_to_img to the kept images
                cnt = torch.zeros(len(imgs)).byte()
                cnt[obj_to_img] += 1
                if (cnt > 0).sum() != len(imgs):
                    print("some images have no detection")
                    # print(obj_to_img)
                    print(cnt)
                    imgs = imgs[cnt]
                    obj_to_img_new = obj_to_img.clone()
                    for i in range(len(cnt)):
                        if cnt[i] == 0:
                            obj_to_img_new -= (obj_to_img > i).long()
                    obj_to_img = obj_to_img_new

                # assert (cnt > 0).sum() == len(imgs), "some imgs have no detection"
                model_out = model(obj_to_img, boxes, obj_fmap)
                imgs_pred = model_out

            with timeit('loss', args.timing):
                # Skip the pixel loss if using GT boxes
                # skip_pixel_loss = (model_boxes is None)
                skip_pixel_loss = False
                total_loss, losses = calculate_model_losses(
                    args, skip_pixel_loss, model, imgs, imgs_pred)

            if obj_discriminator is not None:
                scores_fake, ac_loss = obj_discriminator(
                    imgs_pred, objs, boxes, obj_to_img)
                total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
                                      args.ac_loss_weight)
                weight = args.discriminator_loss_weight * args.d_obj_weight
                total_loss = add_loss(total_loss, gan_g_loss(scores_fake),
                                      losses, 'g_gan_obj_loss', weight)

            if img_discriminator is not None:
                scores_fake = img_discriminator(imgs_pred)
                weight = args.discriminator_loss_weight * args.d_img_weight
                total_loss = add_loss(total_loss, gan_g_loss(scores_fake),
                                      losses, 'g_gan_img_loss', weight)

            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                print('WARNING: Got loss = NaN, not backpropping')
                continue

            optimizer.zero_grad()
            with timeit('backward', args.timing):
                total_loss.backward()
            optimizer.step()
            total_loss_d = None
            ac_loss_real = None
            ac_loss_fake = None
            d_losses = {}

            if obj_discriminator is not None:
                d_obj_losses = LossManager()
                imgs_fake = imgs_pred.detach()
                scores_fake, ac_loss_fake = obj_discriminator(
                    imgs_fake, objs, boxes, obj_to_img)
                scores_real, ac_loss_real = obj_discriminator(
                    imgs, objs, boxes, obj_to_img)

                d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
                d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
                d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
                d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')

                optimizer_d_obj.zero_grad()
                d_obj_losses.total_loss.backward()
                optimizer_d_obj.step()

            if img_discriminator is not None:
                d_img_losses = LossManager()
                imgs_fake = imgs_pred.detach()
                scores_fake = img_discriminator(imgs_fake)
                scores_real = img_discriminator(imgs)

                d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
                d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

                optimizer_d_img.zero_grad()
                d_img_losses.total_loss.backward()
                optimizer_d_img.step()

            if t % args.print_every == 0:
                print('t = %d / %d' % (t, args.num_iterations))
                G_loss_list = []
                for name, val in losses.items():
                    # print(' G [%s]: %.4f' % (name, val))
                    G_loss_list.append('[%s]: %.4f' % (name, val))
                    checkpoint['losses'][name].append(val)
                    summary_writer.add_scalar("G_%s" % name, val, t)
                print("G: %s" % ", ".join(G_loss_list))
                checkpoint['losses_ts'].append(t)

                if obj_discriminator is not None:
                    D_obj_loss_list = []
                    for name, val in d_obj_losses.items():
                        # print(' D_obj [%s]: %.4f' % (name, val))
                        D_obj_loss_list.append('[%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)
                        summary_writer.add_scalar("D_obj_%s" % name, val, t)
                    print("D_obj: %s" % ", ".join(D_obj_loss_list))

                if img_discriminator is not None:
                    D_img_loss_list = []
                    for name, val in d_img_losses.items():
                        # print(' D_img [%s]: %.4f' % (name, val))
                        D_img_loss_list.append('[%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)
                        summary_writer.add_scalar("D_img_%s" % name, val, t)
                    print("D_img: %s" % ", ".join(D_img_loss_list))

            if t % args.checkpoint_every == 0:
                print('checking on train')
                train_results = check_model(args, t,
                                            load_detector.train_loader, model)
                # t_losses, t_samples, t_batch_data, t_avg_iou = train_results
                t_losses, t_samples, t_batch_data = train_results

                checkpoint['train_batch_data'].append(t_batch_data)
                checkpoint['train_samples'].append(t_samples)
                checkpoint['checkpoint_ts'].append(t)
                # checkpoint['train_iou'].append(t_avg_iou)
                for name, images in t_samples.items():
                    # images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
                    # images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
                    summary_writer.add_image("train_%s" % name, images, t)

                print('checking on val')
                val_results = check_model(args, t, load_detector.val_loader,
                                          model)
                # val_losses, val_samples, val_batch_data, val_avg_iou = val_results
                val_losses, val_samples, val_batch_data = val_results
                checkpoint['val_samples'].append(val_samples)
                checkpoint['val_batch_data'].append(val_batch_data)
                # checkpoint['val_iou'].append(val_avg_iou)
                for name, images in val_samples.items():
                    # images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
                    # images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
                    summary_writer.add_image("val_%s" % name, images, t)

                # print('train iou: ', t_avg_iou)
                # print('val iou: ', val_avg_iou)

                for k, v in val_losses.items():
                    checkpoint['val_losses'][k].append(v)
                    summary_writer.add_scalar("val_%s" % k, v, t)
                checkpoint['model_state'] = model.state_dict()

                if obj_discriminator is not None:
                    checkpoint['d_obj_state'] = obj_discriminator.state_dict()
                    checkpoint[
                        'd_obj_optim_state'] = optimizer_d_obj.state_dict()

                if img_discriminator is not None:
                    checkpoint['d_img_state'] = img_discriminator.state_dict()
                    checkpoint[
                        'd_img_optim_state'] = optimizer_d_img.state_dict()

                checkpoint['optim_state'] = optimizer.state_dict()
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
                print('Saving checkpoint to ', checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

                # Save another checkpoint without any model or optim state
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                key_blacklist = [
                    'model_state', 'optim_state', 'model_best_state',
                    'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
                    'd_img_state', 'd_img_optim_state', 'd_img_best_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
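Code example #8 accumulates discriminator losses through a LossManager exposing add_loss, total_loss and items(). A minimal sketch consistent with that usage (an assumption about the class, not its real definition):

class LossManager:
    def __init__(self):
        self.total_loss = None
        self.all_losses = {}

    def add_loss(self, loss, name, weight=1.0):
        # keep a scalar copy for logging, fold the weighted term into the total
        self.all_losses[name] = loss.item()
        weighted = loss * weight
        if self.total_loss is None:
            self.total_loss = weighted
        else:
            self.total_loss = self.total_loss + weighted

    def items(self):
        return self.all_losses.items()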
Code example #9
    def forward(self, imgs, img_offset, gt_boxes, gt_classes, gt_fmaps):
        obj_to_img = gt_classes[:, 0] - img_offset
        # print("obj_to_img.min(), obj_to_img.max(), len(imgs) {} {} {}".format(obj_to_img.min(), obj_to_img.max(), len(imgs)))
        assert obj_to_img.min() >= 0 and obj_to_img.max() < len(imgs), \
            "obj_to_img.min() >= 0 and obj_to_img.max() < len(imgs) is not satidfied: {} {} {}".format(obj_to_img.min(), obj_to_img.max(), len(imgs))
        boxes = gt_boxes
        obj_fmaps = gt_fmaps
        objs = gt_classes[:, 1]

        if self.args is not None:
            if self.args.exchange_feat_cls:
                print("exchange feature vectors and classes among bboxes")
                for img_ind in range(imgs.shape[0]):
                    ind = (obj_to_img == img_ind).nonzero()[:, 0]
                    # permute = ind[torch.randperm(len(ind))]
                    # obj_fmaps[ind] = obj_fmaps[permute]
                    permute_ind = ind[torch.randperm(len(ind))[:2]]
                    permute = permute_ind[[1, 0]]
                    obj_fmaps[permute_ind] = obj_fmaps[permute]
                    objs[permute_ind] = objs[permute]

            if self.args.change_bbox:
                print("change the position of bboxes")
                for img_ind in range(imgs.shape[0]):
                    ind = (obj_to_img == img_ind).nonzero()[:, 0]
                    ind = ind[torch.randperm(len(ind))[0]]
                    if boxes[ind][3] < 0.8:
                        print("move to bottom")
                        boxes[ind][1] += (1 - boxes[ind][3])
                        boxes[ind][3] = 1
                    elif boxes[ind][1] > 0.2:
                        print("move to top")
                        boxes[ind][3] -= boxes[ind][1]
                        boxes[ind][1] = 0
                    elif boxes[ind][0] > 0.2:
                        print("move to left")
                        boxes[ind][2] -= boxes[ind][0]
                        boxes[ind][0] = 0
                    elif boxes[ind][2] < 0.8:
                        print("move to right")
                        boxes[ind][0] += (1 - boxes[ind][2])
                        boxes[ind][2] = 1
                    else:
                        print("move to bottom right")
                        boxes[ind][1] += (1 - boxes[ind][3])
                        boxes[ind][3] = 1
                        boxes[ind][0] += (1 - boxes[ind][2])
                        boxes[ind][2] = 1

        mask_noise_indexes = torch.randperm(
            imgs.shape[0])[:int(self.args.noise_mask_ratio *
                                imgs.shape[0])].to(imgs.device)
        if len(mask_noise_indexes) == 0:
            mask_noise_indexes = None

        if self.forward_G:
            with timeit('generator forward', self.args.timing):
                imgs_pred, layout, z_random = self.model(
                    obj_to_img, boxes, obj_fmaps, mask_noise_indexes)

                if self.training:
                    mu_rec, logvar_rec = self.img_encoder(imgs_pred)
                    z_random_rec = mu_rec
                else:
                    z_random_rec = None

        H, W = self.args.image_size
        bg_layout = boxes_to_layout(
            torch.ones(boxes.shape[0], 3).to(imgs.device), boxes, obj_to_img,
            H, W)
        bg_layout = (bg_layout <= 0).type(imgs.dtype)

        if self.args.condition_d_img_on_class_label_map:
            layout = boxes_to_layout(
                (objs + 1).view(-1, 1).repeat(1, 3).type(imgs.dtype), boxes,
                obj_to_img, H, W)

        g_scores_fake_crop, g_obj_scores_fake_crop, g_rec_feature_fake_crop = None, None, None
        g_scores_fake_img = None
        g_scores_fake_bg = None
        if self.calc_G_D_loss:
            # forward discriminators to train generator
            if self.obj_discriminator is not None:
                with timeit('d_obj forward for g', self.args.timing):
                    g_scores_fake_crop, g_obj_scores_fake_crop, _, g_rec_feature_fake_crop = \
                        self.obj_discriminator(imgs_pred, objs, boxes, obj_to_img)

            if self.img_discriminator is not None:
                with timeit('d_img forward for g', self.args.timing):
                    if self.args.condition_d_img:
                        g_scores_fake_img = self.img_discriminator(
                            imgs_pred, layout)
                    else:
                        g_scores_fake_img = self.img_discriminator(imgs_pred)

            if self.bg_discriminator is not None:
                with timeit('d_bg forward for g', self.args.timing):
                    if self.args.condition_d_bg:
                        g_scores_fake_bg = self.bg_discriminator(
                            imgs_pred, bg_layout)
                    else:
                        g_scores_fake_bg = self.bg_discriminator(imgs_pred *
                                                                 bg_layout)

        d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = None, None, None, None
        d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = None, None, None, None
        d_obj_gp = None
        d_scores_fake_img = None
        d_scores_real_img = None
        d_img_gp = None
        d_scores_fake_bg = None
        d_scores_real_bg = None
        d_bg_gp = None
        if self.forward_D:
            # forward discriminators to train discriminators
            if self.obj_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_obj forward for d', self.args.timing):
                    d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = \
                        self.obj_discriminator(imgs_fake, objs, boxes, obj_to_img)
                    d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = \
                        self.obj_discriminator(imgs, objs, boxes, obj_to_img)
                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        d_obj_gp = gradient_penalty(
                            real_crops.detach(), fake_crops.detach(),
                            self.obj_discriminator.discriminator)

            if self.img_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_img forward for d', self.args.timing):
                    if self.args.condition_d_img:
                        d_scores_fake_img = self.img_discriminator(
                            imgs_fake, layout)
                        d_scores_real_img = self.img_discriminator(
                            imgs, layout)
                    else:
                        d_scores_fake_img = self.img_discriminator(imgs_fake)
                        d_scores_real_img = self.img_discriminator(imgs)

                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        if self.args.condition_d_img:
                            d_img_gp = gradient_penalty(
                                torch.cat([imgs, layout], dim=1),
                                torch.cat([imgs_fake, layout], dim=1),
                                self.img_discriminator)
                        else:
                            d_img_gp = gradient_penalty(
                                imgs, imgs_fake, self.img_discriminator)

            if self.bg_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_bg forward for d', self.args.timing):
                    if self.args.condition_d_bg:
                        d_scores_fake_bg = self.bg_discriminator(
                            imgs_fake, bg_layout)
                        d_scores_real_bg = self.bg_discriminator(
                            imgs, bg_layout)
                    else:
                        d_scores_fake_bg = self.bg_discriminator(imgs_fake *
                                                                 bg_layout)
                        d_scores_real_bg = self.bg_discriminator(imgs *
                                                                 bg_layout)

                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        if self.args.condition_d_bg:
                            d_bg_gp = gradient_penalty(
                                torch.cat([imgs, bg_layout], dim=1),
                                torch.cat([imgs_fake, bg_layout], dim=1),
                                self.bg_discriminator)
                        else:
                            d_bg_gp = gradient_penalty(imgs * bg_layout,
                                                       imgs_fake * bg_layout,
                                                       self.bg_discriminator)
        return Result(imgs=imgs,
                      imgs_pred=imgs_pred,
                      objs=objs,
                      obj_fmaps=obj_fmaps,
                      boxes=boxes,
                      obj_to_img=obj_to_img + img_offset,
                      g_scores_fake_crop=g_scores_fake_crop,
                      g_obj_scores_fake_crop=g_obj_scores_fake_crop,
                      g_scores_fake_img=g_scores_fake_img,
                      d_scores_fake_crop=d_scores_fake_crop,
                      d_obj_scores_fake_crop=d_obj_scores_fake_crop,
                      d_scores_real_crop=d_scores_real_crop,
                      d_obj_scores_real_crop=d_obj_scores_real_crop,
                      d_scores_fake_img=d_scores_fake_img,
                      d_scores_real_img=d_scores_real_img,
                      d_obj_gp=d_obj_gp,
                      d_img_gp=d_img_gp,
                      fake_crops=fake_crops,
                      real_crops=real_crops,
                      mask_noise_indexes=(mask_noise_indexes + img_offset)
                      if mask_noise_indexes is not None else None,
                      g_rec_feature_fake_crop=g_rec_feature_fake_crop,
                      d_rec_feature_fake_crop=d_rec_feature_fake_crop,
                      d_rec_feature_real_crop=d_rec_feature_real_crop,
                      g_scores_fake_bg=g_scores_fake_bg,
                      d_scores_fake_bg=d_scores_fake_bg,
                      d_scores_real_bg=d_scores_real_bg,
                      d_bg_gp=d_bg_gp,
                      bg_layout=bg_layout,
                      z_random=z_random,
                      z_random_rec=z_random_rec)
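
The discriminator pass above calls a gradient_penalty helper that is not defined in this excerpt. Below is a minimal sketch of such a helper under the standard WGAN-GP formulation (interpolate real and fake samples, then penalize deviations of the critic's gradient norm from 1); the signature gradient_penalty(real, fake, discriminator) is taken from the calls above, while the penalty weight and internals are assumptions and may differ from the project's own implementation.

import torch


def gradient_penalty(real, fake, discriminator, gp_weight=10.0):
    # Sketch of a WGAN-GP penalty matching gradient_penalty(real, fake, D) above.
    # Interpolate each real/fake pair with a per-sample random coefficient.
    alpha_shape = [real.size(0)] + [1] * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

    scores = discriminator(interp)
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=interp,
                                 create_graph=True)
    # Penalize the squared deviation of each sample's gradient norm from 1.
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return gp_weight * ((grad_norm - 1) ** 2).mean()
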
コード例 #10
0
def main(args):
    print(args)
    check_args(args)
    float_dtype = torch.cuda.FloatTensor
    long_dtype = torch.cuda.LongTensor

    vocab, train_loader, val_loader = build_loaders(args)
    model, model_kwargs = build_model(args, vocab)
    model.type(float_dtype)
    model = model.cuda()
    layoutgen = LayoutGenerator(args.batch_size,
                                args.max_objects_per_image + 1, 184).cuda()
    optimizer_params = list(model.parameters()) + list(layoutgen.parameters())
    optimizer = torch.optim.Adam(params=optimizer_params,
                                 lr=args.learning_rate)

    obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
    img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab)
    obj_discriminator = obj_discriminator.cuda()
    img_discriminator = img_discriminator.cuda()
    layout_discriminator = LayoutDiscriminator(args.batch_size,
                                               args.max_objects_per_image + 1,
                                               184, 64, 64).cuda()

    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

    obj_discriminator.type(float_dtype)
    obj_discriminator.train()
    optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                       lr=args.learning_rate)

    img_discriminator.type(float_dtype)
    img_discriminator.train()
    optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                       lr=args.learning_rate)

    optimizer_d_layout = torch.optim.Adam(
        params=layout_discriminator.parameters(), lr=args.learning_rate)

    model_path = 'stats/epoch_2_batch_399_with_model.pt'
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state'])
    layoutgen.load_state_dict(checkpoint['layout_gen'])

    obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
    img_discriminator.load_state_dict(checkpoint['d_img_state'])
    layout_discriminator.load_state_dict(checkpoint['d_layout_state'])

    optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])
    optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])
    optimizer_d_layout.load_state_dict(checkpoint['d_layout_optim_state'])
    optimizer.load_state_dict(checkpoint['optim_state'])

    # checkpoint = torch.load('sg2im-models/coco64.pt')
    # model.load_state_dict(checkpoint['model_state'])
    # 0/0

    # 'model_state':model.state_dict(),
    #   'layout_gen':layoutgen.state_dict(),
    #   'd_obj_state': obj_discriminator.state_dict(),
    #   'd_img_state': img_discriminator.state_dict(),
    #   'd_layout_state':layout_discriminator.state_dict(),

    #   'd_obj_optim_state': optimizer_d_obj.state_dict(),
    #   'd_img_optim_state': optimizer_d_img.state_dict(),
    #   'd_layout_optim_state':optimizer_d_layout.state_dict(),
    #   'optim_state': optimizer.state_dict(),

    # restore_path = None
    # if args.restore_from_checkpoint:
    #   restore_path = 'stats/%s_with_model.pt' % args.checkpoint_name
    #   restore_path = os.path.join(args.output_dir, restore_path)

    # if restore_path is not None and os.path.isfile(restore_path):
    #   print('Restoring from checkpoint:')
    #   print(restore_path)
    #   checkpoint = torch.load(restore_path)
    #   model.load_state_dict(checkpoint['model_state'])
    #   optimizer.load_state_dict(checkpoint['optim_state'])

    #   obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
    #   optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])

    #   img_discriminator.load_state_dict(checkpoint['d_img_state'])
    #   optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])

    #   t = checkpoint['counters']['t']
    #   if 0 <= args.eval_mode_after <= t:
    #     model.eval()
    #   else:
    #     model.train()
    #   epoch = checkpoint['counters']['epoch']
    # else:
    #   t, epoch = 0, 0
    #   checkpoint = {
    #     'args': args.__dict__,
    #     'vocab': vocab,
    #     'model_kwargs': model_kwargs,
    #     'd_obj_kwargs': d_obj_kwargs,
    #     'd_img_kwargs': d_img_kwargs,
    #     'losses_ts': [],
    #     'losses': defaultdict(list),
    #     'd_losses': defaultdict(list),
    #     'checkpoint_ts': [],
    #     'train_batch_data': [],
    #     'train_samples': [],
    #     'train_iou': [],
    #     'val_batch_data': [],
    #     'val_samples': [],
    #     'val_losses': defaultdict(list),
    #     'val_iou': [],
    #     'norm_d': [],
    #     'norm_g': [],
    #     'counters': {
    #       't': None,
    #       'epoch': None,
    #     },
    #     'model_state': None, 'model_best_state': None, 'optim_state': None,
    #     'd_obj_state': None, 'd_obj_best_state': None, 'd_obj_optim_state': None,
    #     'd_img_state': None, 'd_img_best_state': None, 'd_img_optim_state': None,
    #     'best_t': [],
    #   }

    epoch = 2
    while True:
        if (epoch >= 20):
            break

        epoch += 1
        print('Starting epoch %d' % epoch)

        for batchnum, batch in enumerate(tqdm(train_loader)):
            imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, combined, all_num_objs = batch
            imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, combined, all_num_objs = imgs.cuda(
            ), objs.cuda(), boxes.cuda(), masks.cuda(), triples.cuda(
            ), obj_to_img.cuda(), triple_to_img.cuda(), combined.cuda(
            ), all_num_objs.cuda()

            if (imgs.shape[0] < args.batch_size):
                print('current size was', imgs.shape[0])
                continue
            #print('\n\nimages',imgs.shape,objs.shape,boxes.shape,masks.shape,combined.shape)
            #print(all_num_objs)
            zlist = []
            for i in range(args.batch_size):
                geo_z = torch.normal(0,
                                     1,
                                     size=(args.max_objects_per_image + 1, 4))
                z = torch.FloatTensor(geo_z)
                zlist.append(z)

            zlist = torch.stack(zlist).cuda()
            zlist = torch.cat((zlist, combined[:, :, 4:]), dim=2)

            feature_vectors, logit_boxes = layoutgen(zlist.cuda())
            # logits -> box coordinates in [0, 1]; equivalent to 1 / (1 + exp(-x))
            generated_boxes = torch.sigmoid(logit_boxes)

            new_gen_boxes = torch.empty((0, 4)).cuda()
            new_feature_vecs = torch.empty((0, 128)).cuda()

            for kb in range(args.batch_size):
                new_gen_boxes = torch.cat([
                    new_gen_boxes,
                    torch.squeeze(generated_boxes[kb, :all_num_objs[kb], :4])
                ],
                                          dim=0)
                new_feature_vecs = torch.cat([
                    new_feature_vecs,
                    torch.squeeze(feature_vectors[kb, :all_num_objs[kb], :])
                ],
                                             dim=0)

            #print('Shape of new gen boxes:',new_gen_boxes.shape)
            #print('Shape of new feature vec:',new_feature_vecs.shape)
            boxes_pred = new_gen_boxes

            with timeit('forward', args.timing):
                #model_boxes = boxes
                model_boxes = generated_boxes
                #model_masks = masks
                model_masks = None
                triples = None

                imgs_pred = model(new_feature_vecs, new_gen_boxes, triples,
                                  obj_to_img)
                #boxes_gt=model_boxes, masks_gt=model_masks)
                #imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
            with timeit('loss', args.timing):
                # Skip the pixel loss if using GT boxes
                skip_pixel_loss = (model_boxes is None)
                total_loss, losses = calculate_model_losses(
                    args, model, imgs, imgs_pred, boxes, boxes_pred,
                    logit_boxes, generated_boxes, combined)

            scores_fake, ac_loss = obj_discriminator(imgs_pred, objs, boxes,
                                                     obj_to_img)
            total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
                                  args.ac_loss_weight)
            weight = args.discriminator_loss_weight * args.d_obj_weight
            total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                                  'g_gan_obj_loss', weight)

            scores_fake = img_discriminator(imgs_pred)
            weight = args.discriminator_loss_weight * args.d_img_weight
            total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                                  'g_gan_img_loss', weight)

            scores_fake = layout_discriminator(logit_boxes)
            weight = args.discriminator_loss_weight * args.d_img_weight
            total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                                  'g_gan_layout_loss', weight)

            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                print('WARNING: Got loss = NaN, not backpropping')
                continue

            optimizer.zero_grad()
            with timeit('backward', args.timing):
                #print('Total loss:',total_loss)
                total_loss.backward()

            optimizer.step()
            total_loss_d = None
            ac_loss_real = None
            ac_loss_fake = None
            d_losses = {}

            d_obj_losses = LossManager()
            imgs_fake = imgs_pred.detach().cuda()
            scores_fake, ac_loss_fake = obj_discriminator(
                imgs_fake, objs, boxes, obj_to_img)
            scores_real, ac_loss_real = obj_discriminator(
                imgs, objs, boxes, obj_to_img)

            d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
            d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
            d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
            d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')

            optimizer_d_obj.zero_grad()
            d_obj_losses.total_loss.backward()
            optimizer_d_obj.step()

            d_img_losses = LossManager()
            imgs_fake = imgs_pred.detach().cuda()
            scores_fake = img_discriminator(imgs_fake)
            scores_real = img_discriminator(imgs)

            d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
            d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

            optimizer_d_img.zero_grad()
            d_img_losses.total_loss.backward()
            optimizer_d_img.step()

            d_layout_losses = LossManager()
            layout_fake = logit_boxes.detach()
            scores_fake = layout_discriminator(layout_fake)
            scores_real = layout_discriminator(combined)

            d_layout_gan_loss = gan_d_loss(scores_real, scores_fake)
            d_layout_losses.add_loss(d_layout_gan_loss, 'd_layout_gan_loss')

            optimizer_d_layout.zero_grad()
            d_layout_losses.total_loss.backward()
            optimizer_d_layout.step()

            if ((batchnum + 1) % 10 == 0):
                towrite = '\n|Epoch:' + str(epoch) + ' | layout Loss:' + str(
                    d_layout_losses.total_loss) + ' | img disc loss:' + str(
                        d_img_losses.total_loss) + ' | obj disc loss:' + str(
                            d_obj_losses.total_loss
                        ) + ' | total gen loss:' + str(total_loss)
                with open('stats/training_stats.txt', 'a+') as f:
                    f.write(towrite)

            if ((batchnum + 1) % 100 == 0):
                checkpoint = {
                    'model_state': model.state_dict(),
                    'layout_gen': layoutgen.state_dict(),
                    'd_obj_state': obj_discriminator.state_dict(),
                    'd_img_state': img_discriminator.state_dict(),
                    'd_layout_state': layout_discriminator.state_dict(),
                    'd_obj_optim_state': optimizer_d_obj.state_dict(),
                    'd_img_optim_state': optimizer_d_img.state_dict(),
                    'd_layout_optim_state': optimizer_d_layout.state_dict(),
                    'optim_state': optimizer.state_dict(),
                }
                print('Saving checkpoint to ', 'stats/')
                checkpoint_path = os.path.join(
                    'stats/', 'epoch_' + str(epoch) + '_batch_' +
                    str(batchnum) + '_with_model.pt')
                torch.save(checkpoint, checkpoint_path)

        checkpoint = {
            'model_state': model.state_dict(),
            'layout_gen': layoutgen.state_dict(),
            'd_obj_state': obj_discriminator.state_dict(),
            'd_img_state': img_discriminator.state_dict(),
            'd_layout_state': layout_discriminator.state_dict(),
            'd_obj_optim_state': optimizer_d_obj.state_dict(),
            'd_img_optim_state': optimizer_d_img.state_dict(),
            'd_layout_optim_state': optimizer_d_layout.state_dict(),
            'optim_state': optimizer.state_dict(),
        }
        print('Saving checkpoint to ', 'stats/')
        checkpoint_path = os.path.join(
            'stats/', 'epoch_' + str(epoch) + '_with_model.pt')
        torch.save(checkpoint, checkpoint_path)
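
コード例 #10 obtains its adversarial objectives through get_gan_losses(args.gan_loss_type), which returns a (gan_g_loss, gan_d_loss) pair but is not shown here. The sketch below illustrates one plausible shape of that factory for vanilla and least-squares GAN losses; the loss-type names and exact formulas supported by the project may differ.

import torch
import torch.nn.functional as F


def get_gan_losses(gan_loss_type):
    # Sketch: return (generator_loss, discriminator_loss) callables.
    if gan_loss_type == 'gan':
        def g_loss(scores_fake):
            # The generator wants fake scores to be classified as real.
            return F.binary_cross_entropy_with_logits(
                scores_fake, torch.ones_like(scores_fake))

        def d_loss(scores_real, scores_fake):
            return (F.binary_cross_entropy_with_logits(
                        scores_real, torch.ones_like(scores_real)) +
                    F.binary_cross_entropy_with_logits(
                        scores_fake, torch.zeros_like(scores_fake)))
    elif gan_loss_type == 'lsgan':
        def g_loss(scores_fake):
            return F.mse_loss(scores_fake, torch.ones_like(scores_fake))

        def d_loss(scores_real, scores_fake):
            return (F.mse_loss(scores_real, torch.ones_like(scores_real)) +
                    F.mse_loss(scores_fake, torch.zeros_like(scores_fake)))
    else:
        raise ValueError('Unrecognized gan_loss_type "%s"' % gan_loss_type)
    return g_loss, d_loss
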
コード例 #11
0
ファイル: sequence_train.py プロジェクト: louis2889184/sg2im
def main(args):
    print(args)
    check_args(args)

    if USE_GPU:
        float_dtype = torch.cuda.FloatTensor
        long_dtype = torch.cuda.LongTensor
    else:
        float_dtype = torch.FloatTensor
        long_dtype = torch.LongTensor

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased-itokens")

    # add_tokens(tokenizer)

    vocab, train_loader, val_loader = build_loaders(args, tokenizer)
    model_kwargs = {}

    encoder_decoder_config = EncoderDecoderConfig.from_pretrained(
        "bert-base-uncased-itokens")
    model = EncoderDecoderModel.from_pretrained("bert-base-uncased-itokens",
                                                config=encoder_decoder_config)

    # modify_network(model, tokenizer)
    # model, model_kwargs = build_model(args, vocab)
    # model.type(float_dtype)
    model.cuda()
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
    img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab)
    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

    if obj_discriminator is not None:
        obj_discriminator.type(float_dtype)
        obj_discriminator.train()
        print(obj_discriminator)
        optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                           lr=args.learning_rate)

    if img_discriminator is not None:
        img_discriminator.type(float_dtype)
        img_discriminator.train()
        print(img_discriminator)
        optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                           lr=args.learning_rate)

    restore_path = None
    if args.restore_from_checkpoint:
        restore_path = '%s_with_model.pt' % args.checkpoint_name
        restore_path = os.path.join(args.output_dir, restore_path)
    if restore_path is not None and os.path.isfile(restore_path):
        print('Restoring from checkpoint:')
        print(restore_path)
        checkpoint = torch.load(restore_path)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optim_state'])

        if obj_discriminator is not None:
            obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
            optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])

        if img_discriminator is not None:
            img_discriminator.load_state_dict(checkpoint['d_img_state'])
            optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])

        t = checkpoint['counters']['t']
        if 0 <= args.eval_mode_after <= t:
            model.eval()
        else:
            model.train()
        epoch = checkpoint['counters']['epoch']
    else:
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'vocab': vocab,
            'model_kwargs': model_kwargs,
            'd_obj_kwargs': d_obj_kwargs,
            'd_img_kwargs': d_img_kwargs,
            'losses_ts': [],
            'losses': defaultdict(list),
            'd_losses': defaultdict(list),
            'checkpoint_ts': [],
            'train_batch_data': [],
            'train_samples': [],
            'train_iou': [],
            'val_batch_data': [],
            'val_samples': [],
            'val_losses': defaultdict(list),
            'val_iou': [],
            'norm_d': [],
            'norm_g': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'model_state': None,
            'model_best_state': None,
            'optim_state': None,
            'd_obj_state': None,
            'd_obj_best_state': None,
            'd_obj_optim_state': None,
            'd_img_state': None,
            'd_img_best_state': None,
            'd_img_optim_state': None,
            'best_t': [],
        }

    while True:
        if t >= args.num_iterations:
            break
        epoch += 1
        print('Starting epoch %d' % epoch)

        for batch in train_loader:
            print(batch)
            exit()
            if t == args.eval_mode_after:
                print('switching to eval mode')
                model.eval()
                optimizer = optim.Adam(model.parameters(),
                                       lr=args.learning_rate)
            t += 1
            if USE_GPU:
                for k in batch.keys():
                    batch[k] = batch[k].cuda().long()
            masks = None

            with timeit('forward', args.timing):
                output = model(**batch)
            # with timeit('loss', args.timing):
            #   # Skip the pixel loss if using GT boxes
            #   skip_pixel_loss = False
            #   total_loss, losses = calculate_model_losses(
            #                           args, skip_pixel_loss, model, imgs, imgs_pred)

            # if img_discriminator is not None:
            #   scores_fake = img_discriminator(imgs_pred)
            #   weight = args.discriminator_loss_weight * args.d_img_weight
            #   total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
            #                         'g_gan_img_loss', weight)

            losses = {}
            total_loss = output["loss"]
            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                print('WARNING: Got loss = NaN, not backpropping')
                continue

            optimizer.zero_grad()
            with timeit('backward', args.timing):
                total_loss.backward()
            optimizer.step()
            total_loss_d = None
            ac_loss_real = None
            ac_loss_fake = None
            d_losses = {}

            # if img_discriminator is not None:
            #   d_img_losses = LossManager()
            #   imgs_fake = imgs_pred.detach()
            #   scores_fake = img_discriminator(imgs_fake)
            #   scores_real = img_discriminator(imgs)

            #   d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
            #   d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

            #   optimizer_d_img.zero_grad()
            #   d_img_losses.total_loss.backward()
            #   optimizer_d_img.step()

            if t % args.print_every == 0:
                print('t = %d / %d' % (t, args.num_iterations))
                for name, val in losses.items():
                    print(' G [%s]: %.4f' % (name, val))
                    checkpoint['losses'][name].append(val)
                checkpoint['losses_ts'].append(t)

                # if img_discriminator is not None:
                #   for name, val in d_img_losses.items():
                #     print(' D_img [%s]: %.4f' % (name, val))
                #     checkpoint['d_losses'][name].append(val)

            if t % args.checkpoint_every == 0:
                print('checking on train')
                train_results = check_model(args, t, train_loader, model)
                t_losses = train_results[0]

                print('checking on val')
                val_results = check_model(args, t, val_loader, model)
                val_losses = val_results[0]

                for k, v in val_losses.items():
                    checkpoint['val_losses'][k].append(v)

                checkpoint['model_state'] = model.state_dict()

                if obj_discriminator is not None:
                    checkpoint['d_obj_state'] = obj_discriminator.state_dict()
                    checkpoint[
                        'd_obj_optim_state'] = optimizer_d_obj.state_dict()

                if img_discriminator is not None:
                    checkpoint['d_img_state'] = img_discriminator.state_dict()
                    checkpoint[
                        'd_img_optim_state'] = optimizer_d_img.state_dict()

                checkpoint['optim_state'] = optimizer.state_dict()
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
                print('Saving checkpoint to ', checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

                # Save another checkpoint without any model or optim state
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                key_blacklist = [
                    'model_state', 'optim_state', 'model_best_state',
                    'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
                    'd_img_state', 'd_img_optim_state', 'd_img_best_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
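
The training loops above lean on two small helpers whose definitions are not included: add_loss, which folds a weighted loss term into a running total while recording it by name, and LossManager, which does the same for the discriminator losses. A minimal sketch consistent with how they are called is given below; the real helpers may carry additional bookkeeping.

def add_loss(total_loss, curr_loss, loss_dict, loss_name, weight=1.0):
    # Record a named, weighted loss term and fold it into the running total.
    curr_loss = curr_loss * weight
    loss_dict[loss_name] = curr_loss.item()
    return curr_loss if total_loss is None else total_loss + curr_loss


class LossManager(object):
    # Tiny container mirroring the d_*_losses usage in the examples above.
    def __init__(self):
        self.total_loss = None
        self.all_losses = {}

    def add_loss(self, loss, name, weight=1.0):
        self.total_loss = add_loss(self.total_loss, loss, self.all_losses,
                                   name, weight)

    def items(self):
        return self.all_losses.items()
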
コード例 #12
0
def main(args):

  device = torch.device('cuda:0')
  if not torch.cuda.is_available():
    print('WARNING: CUDA not available; falling back to CPU')
    device = torch.device('cpu')

  print(args)
  check_args(args)
  float_dtype = torch.cuda.FloatTensor
  long_dtype = torch.cuda.LongTensor

  vocab, train_loader, val_loader = build_loaders(args)
  model, model_kwargs = build_model(args, vocab)
  model.type(float_dtype)
  print(model)

  if not os.path.isdir(args.output_dir):
    os.mkdir(args.output_dir)
    print('Created %s' %args.output_dir)

  optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
  vgg_featureExractor = FeatureExtractor(requires_grad=False).cuda()
  ## add code for training visualization
  logger = Logger(args.output_dir)

  obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
  img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab)
  gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)
  
  if args.matching_aware_loss and args.sg_context_dim > 0:
    gan_g_matching_aware_loss, gan_d_matching_aware_loss = get_gan_losses('matching_aware_gan')

  ### quick hack
  obj_discriminator = None
  img_discriminator = None


  ############
  if obj_discriminator is not None:
    obj_discriminator.type(float_dtype)
    obj_discriminator.train()
    print(obj_discriminator)
    optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                       lr=args.learning_rate)

  if img_discriminator is not None:
    img_discriminator.type(float_dtype)
    img_discriminator.train()
    print(img_discriminator)
    optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                       lr=args.learning_rate)

  restore_path = None
  if args.checkpoint_start_from is not None:
    restore_path = args.checkpoint_start_from
  elif args.restore_from_checkpoint:
    restore_path = '%s_with_model.pt' % args.checkpoint_name
    restore_path = os.path.join(args.output_dir, restore_path)
  if restore_path is not None and os.path.isfile(restore_path):
    print('Restoring from checkpoint:')
    print(restore_path)
    checkpoint = torch.load(restore_path)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optim_state'])

    if obj_discriminator is not None:
      obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
      optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])

    if img_discriminator is not None:

      img_discriminator.load_state_dict(checkpoint['d_img_state'])
      optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])

    t = checkpoint['counters']['t']
    if 0 <= args.eval_mode_after <= t:
      model.eval()
    else:
      model.train()
    epoch = checkpoint['counters']['epoch']
  else:
    t, epoch = 0, 0
    checkpoint = {
      'args': args.__dict__,
      'vocab': vocab,
      'model_kwargs': model_kwargs,
      'd_obj_kwargs': d_obj_kwargs,
      'd_img_kwargs': d_img_kwargs,
      'losses_ts': [],
      'losses': defaultdict(list),
      'd_losses': defaultdict(list),
      'checkpoint_ts': [],
      'train_batch_data': [], 
      'train_samples': [],
      'train_iou': [],
      'val_batch_data': [], 
      'val_samples': [],
      'val_losses': defaultdict(list),
      'val_iou': [], 
      'norm_d': [], 
      'norm_g': [],
      'counters': {
        't': None,
        'epoch': None,
      },
      'model_state': None, 'model_best_state': None, 'optim_state': None,
      'd_obj_state': None, 'd_obj_best_state': None, 'd_obj_optim_state': None,
      'd_img_state': None, 'd_img_best_state': None, 'd_img_optim_state': None,
      'best_t': [],
    }


  while True:
    if t >= args.num_iterations:
      break
    epoch += 1
    print('Starting epoch %d' % epoch)
   
    for batch in train_loader:
      if t == args.eval_mode_after:
        print('switching to eval mode')
        model.eval()
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
      t += 1
      batch = [tensor.cuda() for tensor in batch]
      masks = None
      if len(batch) == 6:
        imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
        triplet_masks = None 
      elif len(batch) == 8:
        imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, triplet_masks = batch
      #elif len(batch) == 7:
      #  imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
      else:
        assert False
      predicates = triples[:, 1]

      with timeit('forward', args.timing):
        model_boxes = boxes
        model_masks = masks
        model_out = model(objs, triples, obj_to_img,
                          boxes_gt=model_boxes, masks_gt=model_masks
                          )
        # imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
        imgs_pred, boxes_pred, masks_pred, objs_vec, layout, layout_boxes, layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d, predicate_scores, obj_embeddings, pred_embeddings, triplet_boxes_pred, triplet_boxes, triplet_masks_pred, boxes_pred_info, triplet_superboxes_pred = model_out
        # boxes_pred, masks_pred, objs_vec, layout, layout_boxes, layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d, predicate_scores = model_out
     
      # add additional information for GT boxes (hack to not change coco.py)
      boxes_info = None
      if args.use_bbox_info and boxes_pred_info is not None:
        boxes_info = add_bbox_info(boxes) 
      # GT for triplet superbox
      triplet_superboxes = None
      if args.triplet_superbox_net and triplet_superboxes_pred is not None:
        # triplet_boxes = [ x1_0 y1_0 x1_1 y1_1 x2_0 y2_0 x2_1 y2_1]
        min_pts = triplet_boxes[:,:2]
        max_pts = triplet_boxes[:,6:8]
        triplet_superboxes = torch.cat([min_pts, max_pts], dim=1)

      with timeit('loss', args.timing):
        # Skip the pixel loss if using GT boxes
        #skip_pixel_loss = (model_boxes is None)
        skip_pixel_loss = True 
        # Skip the perceptual loss if using GT boxes
        #skip_perceptual_loss = (model_boxes is None)
        skip_perceptual_loss = True 

        if args.perceptual_loss_weight:
          total_loss, losses =  calculate_model_losses(
                                  args, skip_pixel_loss, model, imgs, imgs_pred,
                                  boxes, boxes_pred, masks, masks_pred,
                                  boxes_info, boxes_pred_info,
                                  predicates, predicate_scores, 
                                  triplet_boxes, triplet_boxes_pred, 
                                  triplet_masks, triplet_masks_pred, 
                                  triplet_superboxes, triplet_superboxes_pred,
                                  skip_perceptual_loss,
                                  perceptual_extractor=vgg_featureExractor)
        else:
          total_loss, losses =  calculate_model_losses(
                                    args, skip_pixel_loss, model, imgs, imgs_pred,
                                    boxes, boxes_pred, masks, masks_pred,
                                    boxes_info, boxes_pred_info,
                                    predicates, predicate_scores, 
                                    triplet_boxes, triplet_boxes_pred, 
                                    triplet_masks, triplet_masks_pred,
                                    triplet_superboxes, triplet_superboxes_pred,
                                    skip_perceptual_loss)  
          #total_loss, losses =  calculate_model_losses(
          #                          args, skip_pixel_loss, model, imgs, imgs_pred,
          #                          boxes, boxes_pred, masks, masks_pred,
          #                          predicates, predicate_scores, 

      if obj_discriminator is not None:
        scores_fake, ac_loss = obj_discriminator(imgs_pred, objs, boxes, obj_to_img)
        total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
                              args.ac_loss_weight)
        weight = args.discriminator_loss_weight * args.d_obj_weight
        total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                              'g_gan_obj_loss', weight)

      if img_discriminator is not None:
        weight = args.discriminator_loss_weight * args.d_img_weight
      
        # scene_graph context by pooled GCNN features
        if args.sg_context_dim > 0:
          ## concatenate => imgs, (layout_embedding), sg_context_pred
          if args.layout_for_discrim == 1:
            discrim_pred = torch.cat([imgs_pred, layout, sg_context_pred_d], dim=1) 
          else:
            discrim_pred = torch.cat([imgs_pred, sg_context_pred_d], dim=1)  
          
          if args.matching_aware_loss:
            # shuffle sg_context_pred_d to use additional fake examples with real images
            matching_aware_size = sg_context_pred_d.size()[0]
            s_sg_context_pred_d = sg_context_pred_d[torch.randperm(matching_aware_size)]  
            if args.layout_for_discrim == 1:
              match_aware_discrim_pred = torch.cat([imgs, layout, s_sg_context_pred_d], dim=1) 
            else: 
              match_aware_discrim_pred = torch.cat([imgs, s_sg_context_pred_d], dim=1 )           
            discrim_pred = torch.cat([discrim_pred, match_aware_discrim_pred], dim=0)         
          
          scores_fake = img_discriminator(discrim_pred)         
          if args.matching_aware_loss:
            total_loss = add_loss(total_loss, gan_g_matching_aware_loss(scores_fake), losses,
                            'g_gan_img_loss', weight)
          else:
            total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                              'g_gan_img_loss', weight)
        else:
          scores_fake = img_discriminator(imgs_pred)
          #weight = args.discriminator_loss_weight * args.d_img_weight
          total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                                'g_gan_img_loss', weight)

      losses['total_loss'] = total_loss.item()
      if not math.isfinite(losses['total_loss']):
        print('WARNING: Got loss = NaN, not backpropping')
        continue

      optimizer.zero_grad()
      with timeit('backward', args.timing):
        total_loss.backward()
      optimizer.step()
      total_loss_d = None
      ac_loss_real = None
      ac_loss_fake = None
      d_losses = {}
      
      if obj_discriminator is not None:
        d_obj_losses = LossManager()
        imgs_fake = imgs_pred.detach()
        scores_fake, ac_loss_fake = obj_discriminator(imgs_fake, objs, boxes, obj_to_img)
        scores_real, ac_loss_real = obj_discriminator(imgs, objs, boxes, obj_to_img)

        d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
        d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
        d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
        d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')

        optimizer_d_obj.zero_grad()
        d_obj_losses.total_loss.backward()
        optimizer_d_obj.step()


      if img_discriminator is not None:
        d_img_losses = LossManager()
               
        imgs_fake = imgs_pred.detach()

        if args.sg_context_dim_d > 0:
          sg_context_p = sg_context_pred_d.detach()  
        
        layout_p = layout.detach()
        # layout_gt_p = layout_gt.detach()

        ## concatenate=> imgs_fake, (layout_embedding), sg_context_pred
        if args.sg_context_dim > 0:
          if args.layout_for_discrim:
            discrim_fake = torch.cat([imgs_fake, layout_p, sg_context_p], dim=1 )  
            discrim_real = torch.cat([imgs, layout_p, sg_context_p], dim=1 ) 
            # discrim_real = torch.cat([imgs, layout_gt_p, sg_context_p], dim=1 ) 
          else:   
            discrim_fake = torch.cat([imgs_fake, sg_context_p], dim=1 ) 
            discrim_real = torch.cat([imgs, sg_context_p], dim=1 )   
              
          if args.matching_aware_loss:
            # shuffle sg_context_p to use additional fake examples with real images
            matching_aware_size = sg_context_p.size()[0]
            s_sg_context_p = sg_context_p[torch.randperm(matching_aware_size)]
            # s_sg_context_p = sg_context_p[torch.randperm(args.batch_size)]
            if args.layout_for_discrim:
              match_aware_discrim_fake = torch.cat([imgs, layout_p, s_sg_context_p], dim=1 ) 
            else:
              match_aware_discrim_fake = torch.cat([imgs, s_sg_context_p], dim=1 ) 
            discrim_fake = torch.cat([discrim_fake, match_aware_discrim_fake], dim=0)   

          scores_fake = img_discriminator(discrim_fake) 
          scores_real = img_discriminator(discrim_real)

          if args.matching_aware_loss:
            d_img_gan_loss = gan_d_matching_aware_loss(scores_real, scores_fake)
          else:
            d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
        else:
          # imgs_fake = imgs_pred.detach()
          scores_fake = img_discriminator(imgs_fake)
          scores_real = img_discriminator(imgs)

        if args.matching_aware_loss:
          d_img_gan_loss = gan_d_matching_aware_loss(scores_real, scores_fake)
        else:
          d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
          
        d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')
        
        optimizer_d_img.zero_grad()
        d_img_losses.total_loss.backward()
        optimizer_d_img.step()

      # report intermediary values to stdout
      if t % args.print_every == 0:
        print('t = %d / %d' % (t, args.num_iterations))
        for name, val in losses.items():
          print(' G [%s]: %.4f' % (name, val))
          checkpoint['losses'][name].append(val)
        checkpoint['losses_ts'].append(t)

        if obj_discriminator is not None:
          for name, val in d_obj_losses.items():
            print(' D_obj [%s]: %.4f' % (name, val))
            checkpoint['d_losses'][name].append(val)

        if img_discriminator is not None:
          for name, val in d_img_losses.items():
            print(' D_img [%s]: %.4f' % (name, val))
            checkpoint['d_losses'][name].append(val)

        # ================================================================== #
        #                        Tensorboard Logging                         #
        # ================================================================== #

        # 1. Log scalar values (scalar summary)
        for name, val in losses.items():
            logger.scalar_summary(name, val, t)
        if obj_discriminator is not None:    
          for name, val in d_obj_losses.items():
              logger.scalar_summary(name, val, t)           
        if img_discriminator is not None:
          for name, val in d_img_losses.items():
              logger.scalar_summary(name, val, t)   
        logger.scalar_summary('score', t, t)

        if t % args.check_val_metrics_every == 0 and t > args.print_every:
          print('Checking val metrics')
          rel_score, avg_iou = get_rel_score(args, t, val_loader, model)
          logger.scalar_summary('relation_score', rel_score, t) 
          logger.scalar_summary('avg_IoU', avg_iou, t) 
          print(t, ': val relation score = ', rel_score)
          print(t, ': IoU = ', avg_iou)
          # checks entire val dataset
          val_results = check_model(args, t, val_loader, model, logger=logger, log_tag='Val', write_images=False)
          v_losses, v_samples, v_batch_data, v_avg_iou = val_results
          logger.scalar_summary('val_total_loss', v_losses['total_loss'], t)
      
      if t % args.checkpoint_every == 0:
        print('checking on train')
        train_results = check_model(args, t, train_loader, model, logger=logger, log_tag='Train', write_images=False)
        t_losses, t_samples, t_batch_data, t_avg_iou = train_results

        checkpoint['train_batch_data'].append(t_batch_data)
        checkpoint['train_samples'].append(t_samples)
        checkpoint['checkpoint_ts'].append(t)
        checkpoint['train_iou'].append(t_avg_iou)

        print('checking on val')
        val_results = check_model(args, t, val_loader, model, logger=logger, log_tag='Validation', write_images=True)

        val_losses, val_samples, val_batch_data, val_avg_iou = val_results
        checkpoint['val_samples'].append(val_samples)
        checkpoint['val_batch_data'].append(val_batch_data)
        checkpoint['val_iou'].append(val_avg_iou)
        
        print('train iou: ', t_avg_iou)
        print('val iou: ', val_avg_iou)

        for k, v in val_losses.items():
          checkpoint['val_losses'][k].append(v)
        checkpoint['model_state'] = model.state_dict()

        if obj_discriminator is not None:
          checkpoint['d_obj_state'] = obj_discriminator.state_dict()
          checkpoint['d_obj_optim_state'] = optimizer_d_obj.state_dict()

        if img_discriminator is not None:
          checkpoint['d_img_state'] = img_discriminator.state_dict()
          checkpoint['d_img_optim_state'] = optimizer_d_img.state_dict()

        checkpoint['optim_state'] = optimizer.state_dict()
        checkpoint['counters']['t'] = t
        checkpoint['counters']['epoch'] = epoch
        checkpoint_path = os.path.join(args.output_dir,
                              '%s_with_model_%d.pt' %(args.checkpoint_name, t) 
                              #'%s_with_model.pt' %args.checkpoint_name
                              )
        print('Saving checkpoint to ', checkpoint_path)
        torch.save(checkpoint, checkpoint_path)

        # Save another checkpoint without any model or optim state
        checkpoint_path = os.path.join(args.output_dir,
                              '%s_no_model.pt' % args.checkpoint_name)
        key_blacklist = ['model_state', 'optim_state', 'model_best_state',
                         'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
                         'd_img_state', 'd_img_optim_state', 'd_img_best_state']
        small_checkpoint = {}
        for k, v in checkpoint.items():
          if k not in key_blacklist:
            small_checkpoint[k] = v
        torch.save(small_checkpoint, checkpoint_path)
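
Every forward and backward pass in these examples is wrapped in `with timeit('...', args.timing):`. A sketch of a context manager with that behaviour (time the block and print the duration only when the flag is set) is shown below; the actual helper in the repository may format its output differently.

import time
from contextlib import contextmanager

import torch


@contextmanager
def timeit(msg, should_time=True):
    # Sketch of the timing helper used as `with timeit('forward', args.timing):`.
    if should_time:
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # make GPU timings meaningful
        start = time.time()
    yield
    if should_time:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print('%s: %.2f ms' % (msg, (time.time() - start) * 1000.0))
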
コード例 #13
0
    def forward(self, imgs, img_offset, gt_boxes, gt_classes, gt_fmaps):
        # forward detector
        # with timeit('detector forward', self.args.timing):
        #     result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
        #                                train_anchor_inds, return_fmap=True)
        # if result.is_none():
        #     return ValueError("heck")

        # forward generator
        # imgs = F.interpolate(x, size=self.args.image_size)
        # objs = result.obj_preds
        # boxes = result.rm_box_priors / BOX_SCALE
        # obj_to_img = result.im_inds - image_offset
        # obj_fmap = result.obj_fmap
        #
        # # check if all image have detection
        # cnt = torch.zeros(len(imgs)).byte()
        # cnt[obj_to_img] += 1
        # if (cnt > 0).sum() != len(imgs):
        #     print("some imgs have no detection")
        #     print(cnt)
        #     imgs = imgs[cnt]
        #     obj_to_img_new = obj_to_img.clone()
        #     for i in range(len(cnt)):
        #         if cnt[i] == 0:
        #             obj_to_img_new -= (obj_to_img > i).long()
        #     obj_to_img = obj_to_img_new

        obj_to_img = gt_classes[:, 0] - img_offset
        # print("obj_to_img.min(), obj_to_img.max(), len(imgs) {} {} {}".format(obj_to_img.min(), obj_to_img.max(), len(imgs)))
        assert obj_to_img.min() >= 0 and obj_to_img.max() < len(imgs), \
            "obj_to_img.min() >= 0 and obj_to_img.max() < len(imgs) is not satidfied: {} {} {}".format(obj_to_img.min(), obj_to_img.max(), len(imgs))
        boxes = gt_boxes
        obj_fmaps = gt_fmaps
        objs = gt_classes[:, 1]

        if self.args is not None:
            if self.args.exchange_feat_cls:
                print("exchange feature vectors and classes among bboxes")
                for img_ind in range(imgs.shape[0]):
                    ind = (obj_to_img == img_ind).nonzero()[:, 0]
                    # permute = ind[torch.randperm(len(ind))]
                    # obj_fmaps[ind] = obj_fmaps[permute]
                    permute_ind = ind[torch.randperm(len(ind))[:2]]
                    permute = permute_ind[[1, 0]]
                    obj_fmaps[permute_ind] = obj_fmaps[permute]
                    objs[permute_ind] = objs[permute]

            if self.args.change_bbox:
                print("change the position of bboxes")
                for img_ind in range(imgs.shape[0]):
                    ind = (obj_to_img == img_ind).nonzero()[:, 0]
                    ind = ind[torch.randperm(len(ind))[0]]
                    if boxes[ind][3] < 0.8:
                        print("move to bottom")
                        boxes[ind][1] += (1 - boxes[ind][3])
                        boxes[ind][3] = 1
                    elif boxes[ind][1] > 0.2:
                        print("move to top")
                        boxes[ind][3] -= boxes[ind][1]
                        boxes[ind][1] = 0
                    elif boxes[ind][0] > 0.2:
                        print("move to left")
                        boxes[ind][2] -= boxes[ind][0]
                        boxes[ind][0] = 0
                    elif boxes[ind][2] < 0.8:
                        print("move to right")
                        boxes[ind][0] += (1 - boxes[ind][2])
                        boxes[ind][2] = 1
                    else:
                        print("move to bottom right")
                        boxes[ind][1] += (1 - boxes[ind][3])
                        boxes[ind][3] = 1
                        boxes[ind][0] += (1 - boxes[ind][2])
                        boxes[ind][2] = 1

        mask_noise_indexes = torch.randperm(
            imgs.shape[0])[:int(self.args.noise_mask_ratio *
                                imgs.shape[0])].to(imgs.device)
        if len(mask_noise_indexes) == 0:
            mask_noise_indexes = None

        H, W = self.args.image_size
        fg_layout = boxes_to_layout(
            torch.ones(boxes.shape[0], 3).to(imgs.device), boxes, obj_to_img,
            H, W)
        bg_layout = (fg_layout <= 0).type(imgs.dtype)

        if self.forward_G:
            with timeit('generator forward', self.args.timing):
                imgs_pred, layout = self.model(obj_to_img,
                                               boxes,
                                               obj_fmaps,
                                               mask_noise_indexes,
                                               bg_layout=bg_layout)

        layout = layout.detach()
        if self.args.condition_d_img_on_class_label_map:
            layout = boxes_to_layout(
                (objs + 1).view(-1, 1).repeat(1, 3).type(imgs.dtype), boxes,
                obj_to_img, H, W)

        g_scores_fake_crop, g_obj_scores_fake_crop, g_rec_feature_fake_crop = None, None, None
        g_scores_fake_img = None
        g_scores_fake_bg = None
        if self.calc_G_D_loss:
            # forward discriminators to train generator
            if self.obj_discriminator is not None:
                with timeit('d_obj forward for g', self.args.timing):
                    g_scores_fake_crop, g_obj_scores_fake_crop, _, g_rec_feature_fake_crop = \
                        self.obj_discriminator(imgs_pred, objs, boxes, obj_to_img)

            if self.img_discriminator is not None:
                with timeit('d_img forward for g', self.args.timing):
                    if self.args.condition_d_img:
                        g_scores_fake_img = self.img_discriminator(
                            imgs_pred, layout)
                    else:
                        g_scores_fake_img = self.img_discriminator(imgs_pred)

            if self.bg_discriminator is not None:
                with timeit('d_bg forward for g', self.args.timing):
                    if self.args.condition_d_bg:
                        g_scores_fake_bg = self.bg_discriminator(
                            imgs_pred, bg_layout)
                    else:
                        g_scores_fake_bg = self.bg_discriminator(imgs_pred *
                                                                 bg_layout)

        d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = None, None, None, None
        d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = None, None, None, None
        d_obj_gp = None
        d_scores_fake_img = None
        d_scores_real_img = None
        d_img_gp = None
        d_scores_fake_bg = None
        d_scores_real_bg = None
        d_bg_gp = None
        if self.forward_D:
            # forward discriminators to train discriminators
            if self.obj_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_obj forward for d', self.args.timing):
                    d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = \
                        self.obj_discriminator(imgs_fake, objs, boxes, obj_to_img)
                    d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = \
                        self.obj_discriminator(imgs, objs, boxes, obj_to_img)
                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        d_obj_gp = gradient_penalty(
                            real_crops.detach(), fake_crops.detach(),
                            self.obj_discriminator.discriminator)

            if self.img_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_img forward for d', self.args.timing):
                    if self.args.condition_d_img:
                        d_scores_fake_img = self.img_discriminator(
                            imgs_fake, layout)
                        d_scores_real_img = self.img_discriminator(
                            imgs, layout)
                    else:
                        d_scores_fake_img = self.img_discriminator(imgs_fake)
                        d_scores_real_img = self.img_discriminator(imgs)

                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        if self.args.condition_d_img:
                            d_img_gp = gradient_penalty(
                                torch.cat([imgs, layout], dim=1),
                                torch.cat([imgs_fake, layout], dim=1),
                                self.img_discriminator)
                        else:
                            d_img_gp = gradient_penalty(
                                imgs, imgs_fake, self.img_discriminator)

            if self.bg_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_bg forward for d', self.args.timing):
                    if self.args.condition_d_bg:
                        d_scores_fake_bg = self.bg_discriminator(
                            imgs_fake, bg_layout)
                        d_scores_real_bg = self.bg_discriminator(
                            imgs, bg_layout)
                    else:
                        d_scores_fake_bg = self.bg_discriminator(imgs_fake *
                                                                 bg_layout)
                        d_scores_real_bg = self.bg_discriminator(imgs *
                                                                 bg_layout)

                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        if self.args.condition_d_bg:
                            d_bg_gp = gradient_penalty(
                                torch.cat([imgs, bg_layout], dim=1),
                                torch.cat([imgs_fake, bg_layout], dim=1),
                                self.bg_discriminator)
                        else:
                            d_bg_gp = gradient_penalty(imgs * bg_layout,
                                                       imgs_fake * bg_layout,
                                                       self.bg_discriminator)
        return Result(
            imgs=imgs,
            imgs_pred=imgs_pred,
            objs=objs,
            obj_fmaps=obj_fmaps,
            boxes=boxes,
            obj_to_img=obj_to_img + img_offset,
            g_scores_fake_crop=g_scores_fake_crop,
            g_obj_scores_fake_crop=g_obj_scores_fake_crop,
            g_scores_fake_img=g_scores_fake_img,
            d_scores_fake_crop=d_scores_fake_crop,
            d_obj_scores_fake_crop=d_obj_scores_fake_crop,
            d_scores_real_crop=d_scores_real_crop,
            d_obj_scores_real_crop=d_obj_scores_real_crop,
            d_scores_fake_img=d_scores_fake_img,
            d_scores_real_img=d_scores_real_img,
            d_obj_gp=d_obj_gp,
            d_img_gp=d_img_gp,
            fake_crops=fake_crops,
            real_crops=real_crops,
            mask_noise_indexes=(mask_noise_indexes + img_offset)
            if mask_noise_indexes is not None else None,
            g_rec_feature_fake_crop=g_rec_feature_fake_crop,
            d_rec_feature_fake_crop=d_rec_feature_fake_crop,
            d_rec_feature_real_crop=d_rec_feature_real_crop,
            g_scores_fake_bg=g_scores_fake_bg,
            d_scores_fake_bg=d_scores_fake_bg,
            d_scores_real_bg=d_scores_real_bg,
            d_bg_gp=d_bg_gp,
            bg_layout=bg_layout,
        )
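
The forward method above packs everything the G/D steps need into a Result object whose definition is not shown; note that some fields (for example z_random and z_random_rec) are supplied in one variant of the method but omitted in the other. One minimal container consistent with that usage, assuming unset fields should simply read as None, is sketched below; the project's own Result class may well differ.

class Result(object):
    # Sketch: keyword-only attribute bag; fields that were not passed read as None.
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __getattr__(self, name):
        # Only called for attributes missing from __dict__, e.g. z_random
        # when the caller did not provide it.
        return None

    def is_none(self):
        # True when no field carries a value (mirrors the result.is_none()
        # check in the commented-out detector code above).
        return all(v is None for v in self.__dict__.values())
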
コード例 #14
0
def main(args):
    print(args)
    check_args(args)
    float_dtype = torch.cuda.FloatTensor
    long_dtype = torch.cuda.LongTensor

    vocab, train_loader, val_loader = build_loaders(args)
    print("vocab")
    print(vocab)
    obj_dict = vocab['object_name_to_idx']
    print("vocab['object_name_to_idx']")
    print(vocab['object_name_to_idx'])
    model, model_kwargs = build_model(args, vocab)
    model.type(float_dtype)
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
    img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab)
    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

    if obj_discriminator is not None:
        obj_discriminator.type(float_dtype)
        obj_discriminator.train()
        print(obj_discriminator)
        optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                           lr=args.learning_rate)

    if img_discriminator is not None:
        img_discriminator.type(float_dtype)
        img_discriminator.train()
        print(img_discriminator)
        optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                           lr=args.learning_rate)

    restore_path = './'

    # if args.restore_from_checkpoint:
    #   restore_path = '%s_with_model.pt' % args.checkpoint_name
    #   restore_path = os.path.join(args.output_dir, restore_path)
    restore_path = './sg2im-models/vg128.pt'
    if restore_path is not None and os.path.isfile(restore_path):
        print('Restoring from checkpoint:')
        print(restore_path)
        checkpoint = torch.load(restore_path)
        model.load_state_dict(checkpoint['model_state'])
        #optimizer.load_state_dict(checkpoint['optim_state'])

        # if obj_discriminator is not None:
        #   obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
        #   optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])

        #   if img_discriminator is not None:
        #     img_discriminator.load_state_dict(checkpoint['d_img_state'])
        #     optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])

        #   t = checkpoint['counters']['t']
        #   if 0 <= args.eval_mode_after <= t:
        #     model.eval()
        #   else:
        #     model.train()
        #   epoch = checkpoint['counters']['epoch']
        # else:
        #   t, epoch = 0, 0
        #   checkpoint = {
        #     'args': args.__dict__,
        #     'vocab': vocab,
        #     'model_kwargs': model_kwargs,
        #     'd_obj_kwargs': d_obj_kwargs,
        #     'd_img_kwargs': d_img_kwargs,
        #     'losses_ts': [],
        #     'losses': defaultdict(list),
        #     'd_losses': defaultdict(list),
        #     'checkpoint_ts': [],
        #     'train_batch_data': [],
        #     'train_samples': [],
        #     'train_iou': [],
        #     'val_batch_data': [],
        #     'val_samples': [],
        #     'val_losses': defaultdict(list),
        #     'val_iou': [],
        #     'norm_d': [],
        #     'norm_g': [],
        #     'counters': {
        #       't': None,
        #       'epoch': None,
        #     },
        #     'model_state': None, 'model_best_state': None, 'optim_state': None,
        #     'd_obj_state': None, 'd_obj_best_state': None, 'd_obj_optim_state': None,
        #     'd_img_state': None, 'd_img_best_state': None, 'd_img_optim_state': None,
        #     'best_t': [],
        #   }

        # while True:
        #   if t >= args.num_iterations:
        #     break
        #   epoch += 1
        #   print('Starting epoch %d' % epoch)

        # for batch in train_loader:
        #   if t == args.eval_mode_after:
        #     print('switching to eval mode')
        #     model.eval()
        #     optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
        #   t += 1
        #   batch = [tensor.cuda() for tensor in batch]
        #   masks = None
        #   if len(batch) == 6:
        #     imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
        #   elif len(batch) == 7:
        #     imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
        #   else:
        #     assert False
        #   predicates = triples[:, 1]

        #   with timeit('forward', args.timing):
        #     model_boxes = boxes

        #     model_masks = masks
        #     model_out = model(objs, triples, obj_to_img,
        #                       boxes_gt=model_boxes, masks_gt=model_masks)
        #     imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
        #   with timeit('loss', args.timing):
        #     # Skip the pixel loss if using GT boxes
        #     skip_pixel_loss = (model_boxes is None)
        #     total_loss, losses =  calculate_model_losses(
        #                             args, skip_pixel_loss, model, imgs, imgs_pred,
        #                             boxes, boxes_pred, masks, masks_pred,
        #                             predicates, predicate_scores)

        #   if obj_discriminator is not None:
        #     scores_fake, ac_loss = obj_discriminator(imgs_pred, objs, boxes, obj_to_img)
        #     total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
        #                           args.ac_loss_weight)
        #     weight = args.discriminator_loss_weight * args.d_obj_weight
        #     total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
        #                           'g_gan_obj_loss', weight)

        #   if img_discriminator is not None:
        #     scores_fake = img_discriminator(imgs_pred)
        #     weight = args.discriminator_loss_weight * args.d_img_weight
        #     total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
        #                           'g_gan_img_loss', weight)

        #   losses['total_loss'] = total_loss.item()
        #   if not math.isfinite(losses['total_loss']):
        #     print('WARNING: Got loss = NaN, not backpropping')
        #     continue

        #   optimizer.zero_grad()
        #   with timeit('backward', args.timing):
        #     total_loss.backward()
        #   optimizer.step()
        #   total_loss_d = None
        #   ac_loss_real = None
        #   ac_loss_fake = None
        #   d_losses = {}

        #   if obj_discriminator is not None:
        #     d_obj_losses = LossManager()
        #     imgs_fake = imgs_pred.detach()
        #     scores_fake, ac_loss_fake = obj_discriminator(imgs_fake, objs, boxes, obj_to_img)
        #     scores_real, ac_loss_real = obj_discriminator(imgs, objs, boxes, obj_to_img)

        #     d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
        #     d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
        #     d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
        #     d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')

        #     optimizer_d_obj.zero_grad()
        #     d_obj_losses.total_loss.backward()
        #     optimizer_d_obj.step()

        #   if img_discriminator is not None:
        #     d_img_losses = LossManager()
        #     imgs_fake = imgs_pred.detach()
        #     scores_fake = img_discriminator(imgs_fake)
        #     scores_real = img_discriminator(imgs)

        #     d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
        #     d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

        #     optimizer_d_img.zero_grad()
        #     d_img_losses.total_loss.backward()
        #     optimizer_d_img.step()

        #   if t % args.print_every == 0:
        #     print('t = %d / %d' % (t, args.num_iterations))
        #     for name, val in losses.items():
        #       print(' G [%s]: %.4f' % (name, val))
        #       checkpoint['losses'][name].append(val)
        #     checkpoint['losses_ts'].append(t)

        #     if obj_discriminator is not None:
        #       for name, val in d_obj_losses.items():
        #         print(' D_obj [%s]: %.4f' % (name, val))
        #         checkpoint['d_losses'][name].append(val)

        #     if img_discriminator is not None:
        #       for name, val in d_img_losses.items():
        #         print(' D_img [%s]: %.4f' % (name, val))
        #         checkpoint['d_losses'][name].append(val)

        #   if t % args.checkpoint_every == 0:
        #     print('checking on train')
        #     train_results = check_model(args, t, train_loader, model)
        #     t_losses, t_samples, t_batch_data, t_avg_iou = train_results

        #     checkpoint['train_batch_data'].append(t_batch_data)
        #     checkpoint['train_samples'].append(t_samples)
        #     checkpoint['checkpoint_ts'].append(t)
        #     checkpoint['train_iou'].append(t_avg_iou)

        #     print('checking on val')
        #     val_results = check_model(args, t, val_loader, model)
        #     val_losses, val_samples, val_batch_data, val_avg_iou = val_results
        #     checkpoint['val_samples'].append(val_samples)
        #     checkpoint['val_batch_data'].append(val_batch_data)
        #     checkpoint['val_iou'].append(val_avg_iou)

        #     print('train iou: ', t_avg_iou)
        #     print('val iou: ', val_avg_iou)

        #     for k, v in val_losses.items():
        #       checkpoint['val_losses'][k].append(v)
        #     checkpoint['model_state'] = model.state_dict()

        #     if obj_discriminator is not None:
        #       checkpoint['d_obj_state'] = obj_discriminator.state_dict()
        #       checkpoint['d_obj_optim_state'] = optimizer_d_obj.state_dict()

        #     if img_discriminator is not None:
        #       checkpoint['d_img_state'] = img_discriminator.state_dict()
        #       checkpoint['d_img_optim_state'] = optimizer_d_img.state_dict()

        #     checkpoint['optim_state'] = optimizer.state_dict()
        #     checkpoint['counters']['t'] = t
        #     checkpoint['counters']['epoch'] = epoch
        #     checkpoint_path = os.path.join(args.output_dir,
        #                           '%s_with_model.pt' % args.checkpoint_name)
        #     print('Saving checkpoint to ', checkpoint_path)
        #     torch.save(checkpoint, checkpoint_path)

        #     # Save another checkpoint without any model or optim state
        #     checkpoint_path = os.path.join(args.output_dir,
        #                           '%s_no_model.pt' % args.checkpoint_name)
        #     key_blacklist = ['model_state', 'optim_state', 'model_best_state',
        #                      'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
        #                      'd_img_state', 'd_img_optim_state', 'd_img_best_state']
        #     small_checkpoint = {}
        #     for k, v in checkpoint.items():
        #       if k not in key_blacklist:
        #         small_checkpoint[k] = v
        #     torch.save(small_checkpoint, checkpoint_path)
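        # Evaluation only: run the restored model over the validation set and save the generated images to disk.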
        batch_index = 0
        print('switching to eval mode')
        model.eval()
        for val_batch in val_loader:
            batch_index += 1
            print("val_batch_index: ", batch_index)

            val_batch = [tensor.cuda() for tensor in val_batch]
            val_masks = None
            if len(val_batch) == 6:
                val_imgs, val_objs, val_boxes, val_triples, val_obj_to_img, val_triple_to_img = val_batch
            elif len(val_batch) == 7:
                val_imgs, val_objs, val_boxes, val_masks, val_triples, val_obj_to_img, val_triple_to_img = val_batch
            else:
                assert False
            predicates = val_triples[:, 1]

            with timeit('forward', args.timing):
                val_model_boxes = val_boxes

                val_model_out = model(val_objs,
                                      val_triples,
                                      val_obj_to_img,
                                      boxes_gt=val_model_boxes,
                                      masks_gt=val_masks)
                val_imgs_pred, val_boxes_pred, val_masks_pred, val_predicate_scores = val_model_out

                val_imgs = imagenet_deprocess_batch(val_imgs_pred)
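                # val_imgs now holds the deprocessed *predicted* images; these predictions are what gets saved below.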

                output_img_dir = "./output_batch"

                if not os.path.exists(output_img_dir):
                    os.makedirs(output_img_dir)

                print("label: ")
                print(val_objs.shape[0])
                print(val_objs)

                object_name_list = []
                for label_index in range(val_objs.shape[0]):
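                    # Reverse lookup: map each label index back to its object-name string in the vocab.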
                    object_index = val_objs[label_index].cpu().data.numpy()

                    object_name = list(obj_dict.keys())[list(
                        obj_dict.values()).index(object_index)]
                    object_name_list.append(object_name)
                    #print("val_objs[label_index]", val_objs[label_index].cpu().data.numpy())
                    #print("object_name: ", object_name)
                print(object_name_list)
                print("val_obj_to_img")
                print(val_obj_to_img)
                print("gt_boxes: ", val_model_boxes.shape)
                print(val_model_boxes)
                # Save the generated images
                for img_index in range(val_imgs.shape[0]):
                    img_np = val_imgs[img_index].numpy().transpose(1, 2, 0)
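                    # Note: the deprocessed images are presumably RGB; cv2.imwrite expects BGR, so the saved colors may be channel-swapped unless converted with cv2.cvtColor first.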
                    img_path = os.path.join(
                        output_img_dir,
                        'img_{}_{}.png'.format('%04d' % batch_index,
                                               '%03d' % img_index))
                    cv2.imwrite(img_path, img_np)

                #print("val_imgs_pred.shape: ", val_imgs_pred.shape)
                raise Exception("stopping after saving the first validation batch of generated images")
Code Example #15
0
        imgs, objs = imgs.cuda(), objs.cuda()
        zs[:, 50:] = objs.view(-1, 1).repeat(1, 50)
        boxes = torch.Tensor([0, 0, 1, 1]).view(1,
                                                -1).repeat(args.batch_size,
                                                           1).cuda()
        obj_to_img = torch.arange(args.batch_size).cuda()

        imgs_pred = generator[0](zs).view(args.batch_size, 64, 2, 2)
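        # The remaining generator stages upsample the initial 64-channel 2x2 feature map stage by stage.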
        for i in range(1, len(generator)):
            imgs_pred = generator[i](imgs_pred)

        # print(zs.shape, imgs.shape, objs.shape, boxes.shape, obj_to_img.shape, imgs_pred.shape)
        if t % (args.n_critic + 1) != 0:
            if obj_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_obj forward for d', args.timing):
                    d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = \
                        obj_discriminator(imgs_fake, objs, boxes, obj_to_img)
                    d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = \
                        obj_discriminator(imgs, objs, boxes, obj_to_img)
                    if args.gan_loss_type == "wgan-gp":
                        d_obj_gp = gradient_penalty(
                            real_crops.detach(), fake_crops.detach(),
                            obj_discriminator.discriminator)

                ## train the object discriminator
                with timeit('d_obj loss', args.timing):
                    d_obj_losses = LossManager()
                    if args.d_obj_weight > 0:
                        d_obj_gan_loss = gan_d_loss(d_obj_scores_real_crop,
                                                    d_obj_scores_fake_crop)
Code Example #16
0
File: train_all_in_one.py Project: LUGUANSONG/i2g2i
def main(args):
    print(args)
    check_args(args)
    if not exists(args.output_dir):
        os.makedirs(args.output_dir)
    summary_writer = SummaryWriter(args.output_dir)

    if args.coco:
        train, val = CocoDetection.splits()
        val.ids = val.ids[:args.val_size]
        train.ids = train.ids
        train_loader, val_loader = CocoDataLoader.splits(train, val, batch_size=args.batch_size,
                                                         num_workers=args.num_workers,
                                                         num_gpus=args.num_gpus)
    else:
        train, val, _ = VG.splits(num_val_im=args.val_size, filter_non_overlap=False,
                                  filter_empty_rels=False, use_proposals=args.use_proposals)
        train_loader, val_loader = VGDataLoader.splits(train, val, batch_size=args.batch_size,
                                                       num_workers=args.num_workers,
                                                       num_gpus=args.num_gpus)
    print(train.ind_to_classes)
    os._exit(0)
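    # NOTE: this debug exit terminates the script right after printing the class list, so the training code below never runs unless it is removed.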

    all_in_one_model = neural_motifs_sg2im_model(args, train.ind_to_classes)
    # Freeze the detector
    for n, param in all_in_one_model.detector.named_parameters():
        param.requires_grad = False
    all_in_one_model.cuda()
    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

    t, epoch, checkpoint = all_in_one_model.t, all_in_one_model.epoch, all_in_one_model.checkpoint
    while True:
        if t >= args.num_iterations:
            break
        epoch += 1
        print('Starting epoch %d' % epoch)

        for step, batch in enumerate(tqdm(train_loader, desc='Training Epoch %d' % epoch, total=len(train_loader))):
            if t == args.eval_mode_after:
                print('switching to eval mode')
                all_in_one_model.model.eval()
                all_in_one_model.optimizer = optim.Adam(all_in_one_model.parameters(), lr=args.learning_rate)
            t += 1

            with timeit('forward', args.timing):
                result = all_in_one_model[batch]
                imgs, imgs_pred, objs, g_scores_fake_crop, g_obj_scores_fake_crop, g_scores_fake_img, \
                d_scores_fake_crop, d_obj_scores_fake_crop, d_scores_real_crop, d_obj_scores_real_crop, \
                d_scores_fake_img, d_scores_real_img = result.imgs, result.imgs_pred, result.objs, \
                result.g_scores_fake_crop, result.g_obj_scores_fake_crop, result.g_scores_fake_img, \
                result.d_scores_fake_crop, result.d_obj_scores_fake_crop, result.d_scores_real_crop, \
                result.d_obj_scores_real_crop, result.d_scores_fake_img, result.d_scores_real_img

            with timeit('loss', args.timing):
                total_loss, losses = calculate_model_losses(
                    args, imgs, imgs_pred)

                if all_in_one_model.obj_discriminator is not None:
                    total_loss = add_loss(total_loss, F.cross_entropy(g_obj_scores_fake_crop, objs), losses, 'ac_loss',
                                          args.ac_loss_weight)
                    weight = args.discriminator_loss_weight * args.d_obj_weight
                    total_loss = add_loss(total_loss, gan_g_loss(g_scores_fake_crop), losses,
                                          'g_gan_obj_loss', weight)

                if all_in_one_model.img_discriminator is not None:
                    weight = args.discriminator_loss_weight * args.d_img_weight
                    total_loss = add_loss(total_loss, gan_g_loss(g_scores_fake_img), losses,
                                          'g_gan_img_loss', weight)

            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                print('WARNING: Got loss = NaN, not backpropping')
                continue

            with timeit('backward', args.timing):
                all_in_one_model.optimizer.zero_grad()
                total_loss.backward()
                all_in_one_model.optimizer.step()


            if all_in_one_model.obj_discriminator is not None:
                with timeit('d_obj loss', args.timing):
                    d_obj_losses = LossManager()
                    d_obj_gan_loss = gan_d_loss(d_scores_real_crop, d_scores_fake_crop)
                    d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
                    d_obj_losses.add_loss(F.cross_entropy(d_obj_scores_real_crop, objs), 'd_ac_loss_real')
                    d_obj_losses.add_loss(F.cross_entropy(d_obj_scores_fake_crop, objs), 'd_ac_loss_fake')

                with timeit('d_obj backward', args.timing):
                    all_in_one_model.optimizer_d_obj.zero_grad()
                    d_obj_losses.total_loss.backward()
                    all_in_one_model.optimizer_d_obj.step()

            if all_in_one_model.img_discriminator is not None:
                with timeit('d_img loss', args.timing):
                    d_img_losses = LossManager()
                    d_img_gan_loss = gan_d_loss(d_scores_real_img, d_scores_fake_img)
                    d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

                with timeit('d_img backward', args.timing):
                    all_in_one_model.optimizer_d_img.zero_grad()
                    d_img_losses.total_loss.backward()
                    all_in_one_model.optimizer_d_img.step()

            if t % args.print_every == 0:
                print('t = %d / %d' % (t, args.num_iterations))
                G_loss_list = []
                for name, val in losses.items():
                    G_loss_list.append('[%s]: %.4f' % (name, val))
                    checkpoint['losses'][name].append(val)
                    summary_writer.add_scalar("G_%s" % name, val, t)
                print("G: %s" % ", ".join(G_loss_list))
                checkpoint['losses_ts'].append(t)

                if all_in_one_model.obj_discriminator is not None:
                    D_obj_loss_list = []
                    for name, val in d_obj_losses.items():
                        D_obj_loss_list.append('[%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)
                        summary_writer.add_scalar("D_obj_%s" % name, val, t)
                    print("D_obj: %s" % ", ".join(D_obj_loss_list))

                if all_in_one_model.img_discriminator is not None:
                    D_img_loss_list = []
                    for name, val in d_img_losses.items():
                        D_img_loss_list.append('[%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)
                        summary_writer.add_scalar("D_img_%s" % name, val, t)
                    print("D_img: %s" % ", ".join(D_img_loss_list))

            if t % args.checkpoint_every == 0:
                print('checking on train')
                train_results = check_model(args, train_loader, all_in_one_model)
                t_losses, t_samples = train_results

                checkpoint['train_samples'].append(t_samples)
                checkpoint['checkpoint_ts'].append(t)
                for name, images in t_samples.items():
                    summary_writer.add_image("train_%s" % name, images, t)

                print('checking on val')
                val_results = check_model(args, val_loader, all_in_one_model)
                val_losses, val_samples = val_results
                checkpoint['val_samples'].append(val_samples)
                for name, images in val_samples.items():
                    summary_writer.add_image("val_%s" % name, images, t)

                for k, v in val_losses.items():
                    checkpoint['val_losses'][k].append(v)
                    summary_writer.add_scalar("val_%s" % k, v, t)
                checkpoint['model_state'] = all_in_one_model.model.state_dict()

                if all_in_one_model.obj_discriminator is not None:
                    checkpoint['d_obj_state'] = all_in_one_model.obj_discriminator.state_dict()
                    checkpoint['d_obj_optim_state'] = all_in_one_model.optimizer_d_obj.state_dict()

                if all_in_one_model.img_discriminator is not None:
                    checkpoint['d_img_state'] = all_in_one_model.img_discriminator.state_dict()
                    checkpoint['d_img_optim_state'] = all_in_one_model.optimizer_d_img.state_dict()

                checkpoint['optim_state'] = all_in_one_model.optimizer.state_dict()
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint_path = os.path.join(args.output_dir,
                                               '%s_with_model.pt' % args.checkpoint_name)
                print('Saving checkpoint to ', checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

                # Save another checkpoint without any model or optim state
                checkpoint_path = os.path.join(args.output_dir,
                                               '%s_no_model.pt' % args.checkpoint_name)
                key_blacklist = ['model_state', 'optim_state', 'model_best_state',
                                 'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
                                 'd_img_state', 'd_img_optim_state', 'd_img_best_state']
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
Code Example #17
0
File: train.py Project: bschroedr/sg2im
def main(args):
    print(args)
    check_args(args)
    float_dtype = torch.cuda.FloatTensor
    long_dtype = torch.cuda.LongTensor

    vocab, train_loader, val_loader = build_loaders(args)
    model, model_kwargs = build_model(args, vocab)
    model.type(float_dtype)
    print(model)

    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)
        print('Created %s' % args.output_dir)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    vgg_featureExtractor = FeatureExtractor(requires_grad=False).cuda()
    ## add code for training visualization
    logger = Logger(args.output_dir)

    obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
    img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab)
    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

    if args.matching_aware_loss and args.sg_context_dim > 0:
        gan_g_matching_aware_loss, gan_d_matching_aware_loss = get_gan_losses(
            'matching_aware_gan')
    ############
    if obj_discriminator is not None:
        obj_discriminator.type(float_dtype)
        obj_discriminator.train()
        print(obj_discriminator)
        optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                           lr=args.learning_rate)

    if img_discriminator is not None:
        img_discriminator.type(float_dtype)
        img_discriminator.train()
        print(img_discriminator)
        optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                           lr=args.learning_rate)

    ##########################################################
    if args.kubernetes:

        ###ailab experiments
        from lib.exp import Client
        import kubernetes

        namespace = os.getenv("EXPERIMENT_NAMESPACE")
        job_name = os.getenv("JOB_NAME")
        client = Client(namespace)
        experiment = client.current_experiment()
        job = client.get_job(job_name)
        experiment_result = experiment.result(job)
        try:
            result = client.create_result(experiment_result)
        except kubernetes.client.rest.ApiException as e:
            body = json.loads(e.body)
            if body['reason'] != 'AlreadyExists':
                raise e
            result = client.get_result(experiment_result.name)

    restore_path = None
    if args.restore_from_checkpoint:
        restore_path = '%s_with_model.pt' % args.checkpoint_name
        restore_path = os.path.join(args.output_dir, restore_path)
    if restore_path is not None and os.path.isfile(restore_path):
        print('Restoring from checkpoint:')
        print(restore_path)
        checkpoint = torch.load(restore_path)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optim_state'])

        if obj_discriminator is not None:
            obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
            optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])

        if img_discriminator is not None:

            img_discriminator.load_state_dict(checkpoint['d_img_state'])
            optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])

        t = checkpoint['counters']['t']
        if 0 <= args.eval_mode_after <= t:
            model.eval()
        else:
            model.train()
        epoch = checkpoint['counters']['epoch']
    else:
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'vocab': vocab,
            'model_kwargs': model_kwargs,
            'd_obj_kwargs': d_obj_kwargs,
            'd_img_kwargs': d_img_kwargs,
            'losses_ts': [],
            'losses': defaultdict(list),
            'd_losses': defaultdict(list),
            'checkpoint_ts': [],
            'train_batch_data': [],
            'train_samples': [],
            'train_iou': [],
            'val_batch_data': [],
            'val_samples': [],
            'val_losses': defaultdict(list),
            'val_iou': [],
            'norm_d': [],
            'norm_g': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'model_state': None,
            'model_best_state': None,
            'optim_state': None,
            'd_obj_state': None,
            'd_obj_best_state': None,
            'd_obj_optim_state': None,
            'd_img_state': None,
            'd_img_best_state': None,
            'd_img_optim_state': None,
            'best_t': [],
        }

    while True:
        if t >= args.num_iterations:
            break
        epoch += 1
        print('Starting epoch %d' % epoch)

        for batch in train_loader:
            if t == args.eval_mode_after:
                print('switching to eval mode')
                model.eval()
                optimizer = optim.Adam(model.parameters(),
                                       lr=args.learning_rate)
            t += 1
            batch = [tensor.cuda() for tensor in batch]
            masks = None
            if len(batch) == 6:
                imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
            elif len(batch) == 7:
                imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
            else:
                assert False
            predicates = triples[:, 1]

            with timeit('forward', args.timing):
                model_boxes = boxes
                model_masks = masks
                model_out = model(objs,
                                  triples,
                                  obj_to_img,
                                  boxes_gt=model_boxes,
                                  masks_gt=model_masks)
                # imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
                imgs_pred, boxes_pred, masks_pred, objs_vec, layout, layout_boxes, layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d, predicate_scores = model_out

            with timeit('loss', args.timing):
                # Skip the pixel loss if using GT boxes
                skip_pixel_loss = (model_boxes is None)
                # Skip the perceptual loss if using GT boxes
                skip_perceptual_loss = (model_boxes is None)

                if args.perceptual_loss_weight:
                    total_loss, losses = calculate_model_losses(
                        args,
                        skip_pixel_loss,
                        model,
                        imgs,
                        imgs_pred,
                        boxes,
                        boxes_pred,
                        masks,
                        masks_pred,
                        predicates,
                        predicate_scores,
                        skip_perceptual_loss,
                        perceptual_extractor=vgg_featureExtractor)
                else:
                    total_loss, losses = calculate_model_losses(
                        args, skip_pixel_loss, model, imgs, imgs_pred, boxes,
                        boxes_pred, masks, masks_pred, predicates,
                        predicate_scores, skip_perceptual_loss)

            if obj_discriminator is not None:
                scores_fake, ac_loss = obj_discriminator(
                    imgs_pred, objs, boxes, obj_to_img)
                total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
                                      args.ac_loss_weight)
                weight = args.discriminator_loss_weight * args.d_obj_weight
                total_loss = add_loss(total_loss, gan_g_loss(scores_fake),
                                      losses, 'g_gan_obj_loss', weight)

            if img_discriminator is not None:
                weight = args.discriminator_loss_weight * args.d_img_weight

                # scene_graph context by pooled GCNN features
                if args.sg_context_dim > 0:
                    ## concatenate => imgs, (layout_embedding), sg_context_pred
                    if args.layout_for_discrim == 1:
                        discrim_pred = torch.cat(
                            [imgs_pred, layout, sg_context_pred_d], dim=1)
                    else:
                        discrim_pred = torch.cat(
                            [imgs_pred, sg_context_pred_d], dim=1)

                    if args.matching_aware_loss:
                        # shuffle sg_context_pred_d to create additional mismatched fake examples paired with real images
                        matching_aware_size = sg_context_pred_d.size()[0]
                        s_sg_context_pred_d = sg_context_pred_d[torch.randperm(
                            matching_aware_size)]
                        if args.layout_for_discrim == 1:
                            match_aware_discrim_pred = torch.cat(
                                [imgs, layout, s_sg_context_pred_d], dim=1)
                        else:
                            match_aware_discrim_pred = torch.cat(
                                [imgs, s_sg_context_pred_d], dim=1)
                        discrim_pred = torch.cat(
                            [discrim_pred, match_aware_discrim_pred], dim=0)

                    scores_fake = img_discriminator(discrim_pred)
                    if args.matching_aware_loss:
                        total_loss = add_loss(
                            total_loss, gan_g_matching_aware_loss(scores_fake),
                            losses, 'g_gan_img_loss', weight)
                    else:
                        total_loss = add_loss(total_loss,
                                              gan_g_loss(scores_fake), losses,
                                              'g_gan_img_loss', weight)
                else:
                    scores_fake = img_discriminator(imgs_pred)
                    #weight = args.discriminator_loss_weight * args.d_img_weight
                    total_loss = add_loss(total_loss, gan_g_loss(scores_fake),
                                          losses, 'g_gan_img_loss', weight)

            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                print('WARNING: Got loss = NaN, not backpropping')
                continue

            optimizer.zero_grad()
            with timeit('backward', args.timing):
                total_loss.backward()
            optimizer.step()
            total_loss_d = None
            ac_loss_real = None
            ac_loss_fake = None
            d_losses = {}

            if obj_discriminator is not None:
                d_obj_losses = LossManager()
                imgs_fake = imgs_pred.detach()
                scores_fake, ac_loss_fake = obj_discriminator(
                    imgs_fake, objs, boxes, obj_to_img)
                scores_real, ac_loss_real = obj_discriminator(
                    imgs, objs, boxes, obj_to_img)

                d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
                d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
                d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
                d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')

                optimizer_d_obj.zero_grad()
                d_obj_losses.total_loss.backward()
                optimizer_d_obj.step()

            if img_discriminator is not None:
                d_img_losses = LossManager()

                imgs_fake = imgs_pred.detach()

                if args.sg_context_dim_d > 0:
                    sg_context_p = sg_context_pred_d.detach()

                layout_p = layout.detach()
                # layout_gt_p = layout_gt.detach()

                ## concatenate=> imgs_fake, (layout_embedding), sg_context_pred
                if args.sg_context_dim > 0:
                    if args.layout_for_discrim:
                        discrim_fake = torch.cat(
                            [imgs_fake, layout_p, sg_context_p], dim=1)
                        discrim_real = torch.cat(
                            [imgs, layout_p, sg_context_p], dim=1)
                        # discrim_real = torch.cat([imgs, layout_gt_p, sg_context_p], dim=1 )
                    else:
                        discrim_fake = torch.cat([imgs_fake, sg_context_p],
                                                 dim=1)
                        discrim_real = torch.cat([imgs, sg_context_p], dim=1)

                    if args.matching_aware_loss:
                        # shuffle sg_context_p to use additional fake examples with real images
                        matching_aware_size = sg_context_p.size()[0]
                        s_sg_context_p = sg_context_p[torch.randperm(
                            matching_aware_size)]
                        # s_sg_context_p = sg_context_p[torch.randperm(args.batch_size)]
                        if args.layout_for_discrim:
                            match_aware_discrim_fake = torch.cat(
                                [imgs, layout_p, s_sg_context_p], dim=1)
                        else:
                            match_aware_discrim_fake = torch.cat(
                                [imgs, s_sg_context_p], dim=1)
                        discrim_fake = torch.cat(
                            [discrim_fake, match_aware_discrim_fake], dim=0)

                    scores_fake = img_discriminator(discrim_fake)
                    scores_real = img_discriminator(discrim_real)

                    # d_img_gan_loss for both branches is computed once, below
                else:
                    # imgs_fake = imgs_pred.detach()
                    scores_fake = img_discriminator(imgs_fake)
                    scores_real = img_discriminator(imgs)

                if args.matching_aware_loss:
                    d_img_gan_loss = gan_d_matching_aware_loss(
                        scores_real, scores_fake)
                else:
                    d_img_gan_loss = gan_d_loss(scores_real, scores_fake)

                d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

                optimizer_d_img.zero_grad()
                d_img_losses.total_loss.backward()
                optimizer_d_img.step()

            if t % args.print_every == 0:
                print('t = %d / %d' % (t, args.num_iterations))
                for name, val in losses.items():
                    print(' G [%s]: %.4f' % (name, val))
                    checkpoint['losses'][name].append(val)
                checkpoint['losses_ts'].append(t)

                if obj_discriminator is not None:
                    for name, val in d_obj_losses.items():
                        print(' D_obj [%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)

                if img_discriminator is not None:
                    for name, val in d_img_losses.items():
                        print(' D_img [%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)

                # ================================================================== #
                #                        Tensorboard Logging                         #
                # ================================================================== #

                # 1. Log scalar values (scalar summary)
                for name, val in losses.items():
                    logger.scalar_summary(name, val, t)
                if obj_discriminator is not None:
                    for name, val in d_obj_losses.items():
                        logger.scalar_summary(name, val, t)
                if img_discriminator is not None:
                    for name, val in d_img_losses.items():
                        logger.scalar_summary(name, val, t)

            if t % args.checkpoint_every == 0:
                print('checking on train')
                train_results = check_model(args,
                                            t,
                                            train_loader,
                                            model,
                                            logger=logger,
                                            log_tag='Train',
                                            write_images=False)
                t_losses, t_samples, t_batch_data, t_avg_iou = train_results

                checkpoint['train_batch_data'].append(t_batch_data)
                checkpoint['train_samples'].append(t_samples)
                checkpoint['checkpoint_ts'].append(t)
                checkpoint['train_iou'].append(t_avg_iou)

                print('checking on val')
                val_results = check_model(args,
                                          t,
                                          val_loader,
                                          model,
                                          logger=logger,
                                          log_tag='Validation',
                                          write_images=True)

                val_losses, val_samples, val_batch_data, val_avg_iou = val_results
                checkpoint['val_samples'].append(val_samples)
                checkpoint['val_batch_data'].append(val_batch_data)
                checkpoint['val_iou'].append(val_avg_iou)

                print('train iou: ', t_avg_iou)
                print('val iou: ', val_avg_iou)

                for k, v in val_losses.items():
                    checkpoint['val_losses'][k].append(v)
                checkpoint['model_state'] = model.state_dict()

                if obj_discriminator is not None:
                    checkpoint['d_obj_state'] = obj_discriminator.state_dict()
                    checkpoint[
                        'd_obj_optim_state'] = optimizer_d_obj.state_dict()

                if img_discriminator is not None:
                    checkpoint['d_img_state'] = img_discriminator.state_dict()
                    checkpoint[
                        'd_img_optim_state'] = optimizer_d_img.state_dict()

                checkpoint['optim_state'] = optimizer.state_dict()
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint_path = os.path.join(
                    args.output_dir,
                    #'%s_with_model_%d.pt' %(args.checkpoint_name, t)
                    '%s_with_model.pt' % args.checkpoint_name)
                print('Saving checkpoint to ', checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

                # Save another checkpoint without any model or optim state
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                key_blacklist = [
                    'model_state', 'optim_state', 'model_best_state',
                    'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
                    'd_img_state', 'd_img_optim_state', 'd_img_best_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
Code Example #18
0
def main(args):
    print(args)
    check_args(args)
    if not exists(args.output_dir):
        os.makedirs(args.output_dir)
    summary_writer = SummaryWriter(args.output_dir)

    # if args.coco:
    #     train, val = CocoDetection.splits()
    #     val.ids = val.ids[:args.val_size]
    #     train.ids = train.ids
    #     train_loader, val_loader = CocoDataLoader.splits(train, val, batch_size=args.batch_size,
    #                                                      num_workers=args.num_workers,
    #                                                      num_gpus=args.num_gpus)
    # else:
    train, val, _ = VG.splits(transform=transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
                              args=args)
    train_loader, val_loader = VGDataLoader.splits(
        train,
        val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        num_gpus=args.num_gpus)
    print(train.ind_to_classes)

    all_in_one_model = neural_motifs_sg2im_model(args, train.ind_to_classes)
    print(all_in_one_model)
    # Freeze the detector
    # for n, param in all_in_one_model.detector.named_parameters():
    #     param.requires_grad = False
    all_in_one_model.cuda()
    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)
    criterionVGG = VGGLoss() if args.perceptual_loss_weight > 0 else None

    t, epoch, checkpoint = all_in_one_model.t, all_in_one_model.epoch, all_in_one_model.checkpoint

    def D_step(result):
        imgs, imgs_pred, objs, \
        d_scores_fake_crop, d_obj_scores_fake_crop, d_scores_real_crop, \
        d_obj_scores_real_crop, d_scores_fake_img, d_scores_real_img, \
        d_obj_gp, d_img_gp \
        = result.imgs, result.imgs_pred, result.objs, \
          result.d_scores_fake_crop, result.d_obj_scores_fake_crop, result.d_scores_real_crop, \
          result.d_obj_scores_real_crop, result.d_scores_fake_img, result.d_scores_real_img, \
          result.d_obj_gp, result.d_img_gp
        d_rec_feature_fake_crop, d_rec_feature_real_crop = result.d_rec_feature_fake_crop, result.d_rec_feature_real_crop
        obj_fmaps = result.obj_fmaps
        d_scores_fake_bg, d_scores_real_bg, d_bg_gp = result.d_scores_fake_bg, result.d_scores_real_bg, result.d_bg_gp

        d_obj_losses, d_img_losses, d_bg_losses = None, None, None
        if all_in_one_model.obj_discriminator is not None:
            with timeit('d_obj loss', args.timing):
                d_obj_losses = LossManager()
                if args.d_obj_weight > 0:
                    d_obj_gan_loss = gan_d_loss(d_scores_real_crop,
                                                d_scores_fake_crop)
                    d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
                    if args.gan_loss_type == 'wgan-gp':
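                        # WGAN-GP: add the gradient penalty on the object critic, computed during the forward pass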
                        d_obj_losses.add_loss(d_obj_gp.mean(), 'd_obj_gp',
                                              args.d_obj_gp_weight)
                if args.ac_loss_weight > 0:
                    d_obj_losses.add_loss(
                        F.cross_entropy(d_obj_scores_real_crop, objs),
                        'd_ac_loss_real')
                    d_obj_losses.add_loss(
                        F.cross_entropy(d_obj_scores_fake_crop, objs),
                        'd_ac_loss_fake')
                if args.d_obj_rec_feat_weight > 0:
                    d_obj_losses.add_loss(
                        F.l1_loss(d_rec_feature_fake_crop, obj_fmaps),
                        'd_obj_fea_rec_loss_fake')
                    d_obj_losses.add_loss(
                        F.l1_loss(d_rec_feature_real_crop, obj_fmaps),
                        'd_obj_fea_rec_loss_real')

            with timeit('d_obj backward', args.timing):
                all_in_one_model.optimizer_d_obj.zero_grad()
                d_obj_losses.total_loss.backward()
                all_in_one_model.optimizer_d_obj.step()

        if all_in_one_model.img_discriminator is not None:
            with timeit('d_img loss', args.timing):
                d_img_losses = LossManager()
                d_img_gan_loss = gan_d_loss(d_scores_real_img,
                                            d_scores_fake_img)
                d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')
                if args.gan_loss_type == 'wgan-gp':
                    d_img_losses.add_loss(d_img_gp.mean(), 'd_img_gp',
                                          args.d_img_gp_weight)

            with timeit('d_img backward', args.timing):
                all_in_one_model.optimizer_d_img.zero_grad()
                d_img_losses.total_loss.backward()
                all_in_one_model.optimizer_d_img.step()

        if all_in_one_model.bg_discriminator is not None:
            with timeit('d_bg loss', args.timing):
                d_bg_losses = LossManager()
                d_bg_gan_loss = gan_d_loss(d_scores_real_bg, d_scores_fake_bg)
                d_bg_losses.add_loss(d_bg_gan_loss, 'd_bg_gan_loss')
                if args.gan_loss_type == 'wgan-gp':
                    d_bg_losses.add_loss(d_bg_gp.mean(), 'd_bg_gp',
                                         args.d_bg_gp_weight)

            with timeit('d_bg backward', args.timing):
                all_in_one_model.optimizer_d_bg.zero_grad()
                d_bg_losses.total_loss.backward()
                all_in_one_model.optimizer_d_bg.step()

        return d_obj_losses, d_img_losses, d_bg_losses

    def G_step(result):
        imgs, imgs_pred, objs, \
        g_scores_fake_crop, g_obj_scores_fake_crop, g_scores_fake_img, \
        = result.imgs, result.imgs_pred, result.objs, \
          result.g_scores_fake_crop, result.g_obj_scores_fake_crop, result.g_scores_fake_img
        mask_noise_indexes = result.mask_noise_indexes
        g_rec_feature_fake_crop = result.g_rec_feature_fake_crop
        obj_fmaps = result.obj_fmaps
        g_scores_fake_bg = result.g_scores_fake_bg

        bg_layout = result.bg_layout
        crops_encoded = result.crops_encoded
        crops_pred_encoded = result.crops_pred_encoded
        z_random = result.z_random
        z_random_rec = result.z_random_rec
        mu_encoded = result.mu_encoded
        logvar_encoded = result.logvar_encoded

        with timeit('loss', args.timing):
            total_loss, losses = calculate_model_losses(
                args, imgs, imgs_pred, mask_noise_indexes, bg_layout)

            crops_encoded_rec_loss = F.l1_loss(crops_pred_encoded,
                                               crops_encoded)
            total_loss = add_loss(total_loss, crops_encoded_rec_loss, losses,
                                  'crops_encoded_rec_loss',
                                  args.crops_encoded_rec_loss_weight)

            kl_loss = torch.sum(1 + logvar_encoded - mu_encoded.pow(2) -
                                logvar_encoded.exp()) * (-0.5)
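            # i.e. -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2): the KL divergence of the encoded posterior N(mu, sigma^2) from the standard normal prior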
            total_loss = add_loss(total_loss, kl_loss, losses, 'kl_loss',
                                  args.kl_loss_weight)

            if criterionVGG is not None:
                if args.perceptual_on_bg:
                    perceptual_imgs = imgs * bg_layout
                    perceptual_imgs_pred = imgs_pred * bg_layout
                else:
                    # Without this fallback the variables below are undefined
                    # whenever perceptual_on_bg is False.
                    perceptual_imgs = imgs
                    perceptual_imgs_pred = imgs_pred
                if mask_noise_indexes is not None and args.perceptual_not_on_noise:
                    perceptual_loss = criterionVGG(
                        perceptual_imgs_pred[mask_noise_indexes],
                        perceptual_imgs[mask_noise_indexes])
                else:
                    perceptual_loss = criterionVGG(perceptual_imgs_pred,
                                                   perceptual_imgs)
                total_loss = add_loss(total_loss, perceptual_loss, losses,
                                      'perceptual_loss',
                                      args.perceptual_loss_weight)

            if all_in_one_model.obj_discriminator is not None:
                total_loss = add_loss(
                    total_loss, F.cross_entropy(g_obj_scores_fake_crop, objs),
                    losses, 'ac_loss', args.ac_loss_weight)
                weight = args.discriminator_loss_weight * args.d_obj_weight
                total_loss = add_loss(total_loss,
                                      gan_g_loss(g_scores_fake_crop), losses,
                                      'g_gan_obj_loss', weight)
                if args.d_obj_rec_feat_weight > 0:
                    total_loss = add_loss(
                        total_loss,
                        F.l1_loss(g_rec_feature_fake_crop, obj_fmaps), losses,
                        'g_obj_fea_rec_loss', args.d_obj_rec_feat_weight)

            if all_in_one_model.img_discriminator is not None:
                weight = args.discriminator_loss_weight * args.d_img_weight
                total_loss = add_loss(total_loss,
                                      gan_g_loss(g_scores_fake_img), losses,
                                      'g_gan_img_loss', weight)

            if all_in_one_model.bg_discriminator is not None:
                weight = args.discriminator_loss_weight * args.d_bg_weight
                total_loss = add_loss(total_loss, gan_g_loss(g_scores_fake_bg),
                                      losses, 'g_gan_bg_loss', weight)

        losses['total_loss'] = total_loss.item()

        if math.isfinite(losses['total_loss']):
            with timeit('backward', args.timing):
                all_in_one_model.optimizer_e_obj.zero_grad()
                all_in_one_model.optimizer.zero_grad()
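                # retain_graph=True keeps the graph alive so z_random_rec_loss below can backpropagate through the same forward pass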
                total_loss.backward(retain_graph=True)
                all_in_one_model.optimizer.step()
                all_in_one_model.optimizer_e_obj.step()

        z_random_rec_loss = torch.mean(
            torch.abs(z_random_rec - z_random)) * args.z_random_rec_loss_weight
        all_in_one_model.optimizer.zero_grad()
        all_in_one_model.optimizer_e_obj.zero_grad()
        z_random_rec_loss.backward()
        all_in_one_model.optimizer.step()

        total_loss = add_loss(total_loss, z_random_rec_loss, losses,
                              'z_random_rec_loss', 1.)
        losses['total_loss'] = total_loss.item()

        return losses

    while True:
        if t >= args.num_iterations * (args.n_critic + args.n_gen):
            break
        epoch += 1
        print('Starting epoch %d' % epoch)

        for step, batch in enumerate(
                tqdm(train_loader,
                     desc='Training Epoch %d' % epoch,
                     total=len(train_loader))):
            # if t == args.eval_mode_after:
            #     print('switching to eval mode')
            #     all_in_one_model.model.eval()
            #     all_in_one_model.optimizer = optim.Adam(all_in_one_model.parameters(), lr=args.learning_rate)
            all_in_one_model.train()
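            # Optionally anneal selected loss weights and the noise std: '*_mode' == 'change' applies step changes at the listed iterations, while 'change_linear' interpolates linearly between two values over an iteration range.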
            modes = ['l1', 'noise_std', 'd_obj', 'd_img', 'ac_loss']
            attrs = [
                'l1_pixel_loss_weight', 'noise_std', 'd_obj_weight',
                'd_img_weight', 'ac_loss_weight'
            ]
            for mode, attr in zip(modes, attrs):
                old_value = getattr(args, attr)
                if getattr(args,
                           "%s_mode" % mode) == "change" and t in getattr(
                               args, "%s_change_iters" % mode):
                    step_index = getattr(args,
                                         "%s_change_iters" % mode).index(t)
                    new_value = getattr(args,
                                        "%s_change_vals" % mode)[step_index]
                    setattr(args, attr, new_value)
                    print("Change %s from %.10f to %.10f at iteration %d" %
                          (attr, old_value, getattr(args, attr), t))
                elif getattr(args, "%s_mode" % mode) == "change_linear":
                    start_step = getattr(args, "%s_change_iters" % mode)[0]
                    end_step = getattr(args, "%s_change_iters" % mode)[1]
                    if start_step <= t <= end_step:
                        start_val = getattr(args, "%s_change_vals" % mode)[0]
                        end_val = getattr(args, "%s_change_vals" % mode)[1]
                        new_value = start_val + (end_val - start_val) * (
                            t - start_step) / (end_step - start_step)
                        setattr(args, attr, new_value)
                        print("Change %s from %.10f to %.10f at iteration %d" %
                              (attr, old_value, getattr(args, attr), t))
                    elif t > end_step:
                        end_val = getattr(args, "%s_change_vals" % mode)[1]
                        if old_value != end_val:
                            new_value = end_val
                            setattr(args, attr, new_value)
                            print(
                                "probably resuming training from a previous checkpoint"
                            )
                            print(
                                "Change %s from %.10f to %.10f at iteration %d"
                                % (attr, old_value, getattr(args, attr), t))
            t += 1
            if args.gan_loss_type in ["wgan", "wgan-gp"] or args.n_critic != 0:
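                # Within each window of (n_critic + n_gen) iterations, the first n_critic steps update the discriminators and the remaining steps update the generator.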
                # train discriminator (critic) for n_critic iterations
                if t % (args.n_critic + args.n_gen) in list(
                        range(1, args.n_critic + 1)):
                    all_in_one_model.forward_G = True
                    all_in_one_model.calc_G_D_loss = False
                    all_in_one_model.forward_D = True
                    all_in_one_model.set_requires_grad([
                        all_in_one_model.obj_discriminator,
                        all_in_one_model.img_discriminator
                    ], True)
                    with timeit('forward', args.timing):
                        result = all_in_one_model[batch]
                    d_obj_losses, d_img_losses, d_bg_losses = D_step(result)

                # train generator for 1 iteration after n_critic iterations
                if t % (args.n_critic + args.n_gen) in (list(
                        range(args.n_critic + 1, args.n_critic + args.n_gen)) +
                                                        [0]):
                    all_in_one_model.forward_G = True
                    all_in_one_model.calc_G_D_loss = True
                    all_in_one_model.forward_D = False
                    all_in_one_model.set_requires_grad([
                        all_in_one_model.obj_discriminator,
                        all_in_one_model.img_discriminator
                    ], False)
                    result = all_in_one_model[batch]

                    losses = G_step(result)
                    if not math.isfinite(losses['total_loss']):
                        print('WARNING: Got loss = NaN, not backpropping')
                        continue
            else:  # vanilla gan or lsgan
                all_in_one_model.forward_G = True
                all_in_one_model.calc_G_D_loss = True
                all_in_one_model.forward_D = True
                with timeit('forward', args.timing):
                    result = all_in_one_model[batch]
                losses = G_step(result)
                if not math.isfinite(losses['total_loss']):
                    print('WARNING: Got loss = NaN, not backpropping')
                    continue
                d_obj_losses, d_img_losses, d_bg_losses = D_step(result)

            if t % (args.print_every * (args.n_critic + args.n_gen)) == 0:
                print('t = %d / %d' % (t, args.num_iterations))
                G_loss_list = []
                for name, val in losses.items():
                    G_loss_list.append('[%s]: %.4f' % (name, val))
                    checkpoint['losses'][name].append(val)
                    summary_writer.add_scalar("G_%s" % name, val, t)
                print("G: %s" % ", ".join(G_loss_list))
                checkpoint['losses_ts'].append(t)

                if all_in_one_model.obj_discriminator is not None:
                    D_obj_loss_list = []
                    for name, val in d_obj_losses.items():
                        D_obj_loss_list.append('[%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)
                        summary_writer.add_scalar("D_obj_%s" % name, val, t)
                    print("D_obj: %s" % ", ".join(D_obj_loss_list))

                if all_in_one_model.img_discriminator is not None:
                    D_img_loss_list = []
                    for name, val in d_img_losses.items():
                        D_img_loss_list.append('[%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)
                        summary_writer.add_scalar("D_img_%s" % name, val, t)
                    print("D_img: %s" % ", ".join(D_img_loss_list))

                if all_in_one_model.bg_discriminator is not None:
                    D_bg_loss_list = []
                    for name, val in d_bg_losses.items():
                        D_bg_loss_list.append('[%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)
                        summary_writer.add_scalar("D_bg_%s" % name, val, t)
                    print("D_bg: %s" % ", ".join(D_bg_loss_list))

            if t % (args.checkpoint_every * (args.n_critic + args.n_gen)) == 0:
                print('checking on train')
                train_results = check_model(args, train_loader,
                                            all_in_one_model)
                t_losses, t_samples = train_results

                checkpoint['train_samples'].append(t_samples)
                checkpoint['checkpoint_ts'].append(t)
                for name, images in t_samples.items():
                    summary_writer.add_image("train_%s" % name, images, t)

                print('checking on val')
                val_results = check_model(args, val_loader, all_in_one_model)
                val_losses, val_samples = val_results
                checkpoint['val_samples'].append(val_samples)
                for name, images in val_samples.items():
                    summary_writer.add_image("val_%s" % name, images, t)

                for k, v in val_losses.items():
                    checkpoint['val_losses'][k].append(v)
                    summary_writer.add_scalar("val_%s" % k, v, t)
                checkpoint['model_state'] = all_in_one_model.model.state_dict()
                checkpoint['optim_state'] = all_in_one_model.optimizer.state_dict()

                checkpoint['e_obj_state'] = all_in_one_model.obj_encoder.state_dict()
                checkpoint['e_obj_optim_state'] = all_in_one_model.optimizer_e_obj.state_dict()

                if all_in_one_model.obj_discriminator is not None:
                    checkpoint['d_obj_state'] = all_in_one_model.obj_discriminator.state_dict()
                    checkpoint['d_obj_optim_state'] = all_in_one_model.optimizer_d_obj.state_dict()

                if all_in_one_model.img_discriminator is not None:
                    checkpoint['d_img_state'] = all_in_one_model.img_discriminator.state_dict()
                    checkpoint['d_img_optim_state'] = all_in_one_model.optimizer_d_img.state_dict()

                if all_in_one_model.bg_discriminator is not None:
                    checkpoint['d_bg_state'] = all_in_one_model.bg_discriminator.state_dict()
                    checkpoint['d_bg_optim_state'] = all_in_one_model.optimizer_d_bg.state_dict()

                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
                print('Saving checkpoint to ', checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

                # Save another checkpoint without any model or optim state
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                key_blacklist = [
                    'model_state', 'optim_state', 'model_best_state',
                    'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
                    'd_img_state', 'd_img_optim_state', 'd_img_best_state',
                    'd_bg_state', 'd_bg_optim_state', 'd_bg_best_state',
                    'e_obj_state', 'e_obj_optim_state', 'e_obj_best_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
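For reference, the "change_linear" branch near the top of this example is a clamped linear interpolation of a hyperparameter between two iterations. The standalone sketch below illustrates the same formula; linear_schedule is an illustrative name, not a function from the source.

def linear_schedule(t, start_step, end_step, start_val, end_val):
    # Interpolate from start_val to end_val over [start_step, end_step],
    # holding the endpoint values outside that range.
    if t <= start_step:
        return start_val
    if t >= end_step:
        return end_val
    frac = (t - start_step) / (end_step - start_step)
    return start_val + (end_val - start_val) * frac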
Code Example #19
def main(args):
    print(args)
    check_args(args)
    float_dtype = torch.cuda.FloatTensor
    long_dtype = torch.cuda.LongTensor

    vocab, train_loader, val_loader = build_loaders(args)
    model, model_kwargs = build_model(args, vocab)
    model.type(float_dtype)
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
    img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab)
    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

    if obj_discriminator is not None:
        obj_discriminator.type(float_dtype)
        obj_discriminator.train()
        print(obj_discriminator)
        optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                           lr=args.learning_rate)

    if img_discriminator is not None:
        img_discriminator.type(float_dtype)
        img_discriminator.train()
        print(img_discriminator)
        optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                           lr=args.learning_rate)

    restore_path = None
    if args.restore_from_checkpoint:
        restore_path = '%s_with_model.pt' % args.checkpoint_name
        restore_path = os.path.join(args.output_dir, restore_path)
    if restore_path is not None and os.path.isfile(restore_path):
        print('Restoring from checkpoint:')
        print(restore_path)
        checkpoint = torch.load(restore_path)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optim_state'])

        if obj_discriminator is not None:
            obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
            optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])

        if img_discriminator is not None:
            img_discriminator.load_state_dict(checkpoint['d_img_state'])
            optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])

        t = checkpoint['counters']['t']
        if 0 <= args.eval_mode_after <= t:
            model.eval()
        else:
            model.train()
        epoch = checkpoint['counters']['epoch']
    else:
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'vocab': vocab,
            'model_kwargs': model_kwargs,
            'd_obj_kwargs': d_obj_kwargs,
            'd_img_kwargs': d_img_kwargs,
            'losses_ts': [],
            'losses': defaultdict(list),
            'd_losses': defaultdict(list),
            'checkpoint_ts': [],
            'train_batch_data': [],
            'train_samples': [],
            'train_iou': [],
            'val_batch_data': [],
            'val_samples': [],
            'val_losses': defaultdict(list),
            'val_iou': [],
            'norm_d': [],
            'norm_g': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'model_state': None,
            'model_best_state': None,
            'optim_state': None,
            'd_obj_state': None,
            'd_obj_best_state': None,
            'd_obj_optim_state': None,
            'd_img_state': None,
            'd_img_best_state': None,
            'd_img_optim_state': None,
            'best_t': [],
        }

    while True:
        if t >= args.num_iterations:
            break
        epoch += 1
        print('Starting epoch %d' % epoch)

        for batch in train_loader:
            if t == args.eval_mode_after:
                print('switching to eval mode')
                model.eval()
                optimizer = torch.optim.Adam(model.parameters(),
                                             lr=args.learning_rate)
            t += 1
            batch = [tensor.cuda() for tensor in batch]
            masks = None
            if len(batch) == 6:
                imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
            elif len(batch) == 7:
                imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
            else:
                assert False
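            # Each triple is (subject_index, predicate_label, object_index),
            # so column 1 holds the predicate labels.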
            predicates = triples[:, 1]

            with timeit('forward', args.timing):
                model_boxes = boxes
                model_masks = masks
                model_out = model(objs,
                                  triples,
                                  obj_to_img,
                                  boxes_gt=model_boxes,
                                  masks_gt=model_masks)
                imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
            with timeit('loss', args.timing):
                # Skip the pixel loss if using GT boxes
                skip_pixel_loss = (model_boxes is None)
                total_loss, losses = calculate_model_losses(
                    args, skip_pixel_loss, model, imgs, imgs_pred, boxes,
                    boxes_pred, masks, masks_pred, predicates,
                    predicate_scores)

            if obj_discriminator is not None:
                scores_fake, ac_loss = obj_discriminator(
                    imgs_pred, objs, boxes, obj_to_img)
                total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
                                      args.ac_loss_weight)
                weight = args.discriminator_loss_weight * args.d_obj_weight
                total_loss = add_loss(total_loss, gan_g_loss(scores_fake),
                                      losses, 'g_gan_obj_loss', weight)

            if img_discriminator is not None:
                scores_fake = img_discriminator(imgs_pred)
                weight = args.discriminator_loss_weight * args.d_img_weight
                total_loss = add_loss(total_loss, gan_g_loss(scores_fake),
                                      losses, 'g_gan_img_loss', weight)

            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                print('WARNING: Got loss = NaN, not backpropping')
                continue

            optimizer.zero_grad()
            with timeit('backward', args.timing):
                total_loss.backward()
            optimizer.step()
            total_loss_d = None
            ac_loss_real = None
            ac_loss_fake = None
            d_losses = {}

            if obj_discriminator is not None:
                d_obj_losses = LossManager()
                imgs_fake = imgs_pred.detach()
                scores_fake, ac_loss_fake = obj_discriminator(
                    imgs_fake, objs, boxes, obj_to_img)
                scores_real, ac_loss_real = obj_discriminator(
                    imgs, objs, boxes, obj_to_img)

                d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
                d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
                d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
                d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')

                optimizer_d_obj.zero_grad()
                d_obj_losses.total_loss.backward()
                optimizer_d_obj.step()

            if img_discriminator is not None:
                d_img_losses = LossManager()
                imgs_fake = imgs_pred.detach()
                scores_fake = img_discriminator(imgs_fake)
                scores_real = img_discriminator(imgs)

                d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
                d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

                optimizer_d_img.zero_grad()
                d_img_losses.total_loss.backward()
                optimizer_d_img.step()

            if t % args.print_every == 0:
                print('t = %d / %d' % (t, args.num_iterations))
                for name, val in losses.items():
                    print(' G [%s]: %.4f' % (name, val))
                    checkpoint['losses'][name].append(val)
                checkpoint['losses_ts'].append(t)

                if obj_discriminator is not None:
                    for name, val in d_obj_losses.items():
                        print(' D_obj [%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)

                if img_discriminator is not None:
                    for name, val in d_img_losses.items():
                        print(' D_img [%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)

            if t % args.checkpoint_every == 0:
                print('checking on train')
                train_results = check_model(args, t, train_loader, model)
                t_losses, t_samples, t_batch_data, t_avg_iou = train_results

                checkpoint['train_batch_data'].append(t_batch_data)
                checkpoint['train_samples'].append(t_samples)
                checkpoint['checkpoint_ts'].append(t)
                checkpoint['train_iou'].append(t_avg_iou)

                print('checking on val')
                val_results = check_model(args, t, val_loader, model)
                val_losses, val_samples, val_batch_data, val_avg_iou = val_results
                checkpoint['val_samples'].append(val_samples)
                checkpoint['val_batch_data'].append(val_batch_data)
                checkpoint['val_iou'].append(val_avg_iou)

                print('train iou: ', t_avg_iou)
                print('val iou: ', val_avg_iou)

                for k, v in val_losses.items():
                    checkpoint['val_losses'][k].append(v)
                checkpoint['model_state'] = model.state_dict()

                if obj_discriminator is not None:
                    checkpoint['d_obj_state'] = obj_discriminator.state_dict()
                    checkpoint[
                        'd_obj_optim_state'] = optimizer_d_obj.state_dict()

                if img_discriminator is not None:
                    checkpoint['d_img_state'] = img_discriminator.state_dict()
                    checkpoint[
                        'd_img_optim_state'] = optimizer_d_img.state_dict()

                checkpoint['optim_state'] = optimizer.state_dict()
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
                print('Saving checkpoint to ', checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

                # Save another checkpoint without any model or optim state
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                key_blacklist = [
                    'model_state', 'optim_state', 'model_best_state',
                    'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
                    'd_img_state', 'd_img_optim_state', 'd_img_best_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
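The generator losses above are accumulated through the add_loss helper. A minimal sketch consistent with the call sites (scale the term, record it for logging, fold it into the running total) is shown below; it may differ from the repository's actual implementation.

def add_loss(total_loss, curr_loss, loss_dict, loss_name, weight=1):
    # Illustrative sketch, not necessarily the repository's implementation.
    curr_loss = curr_loss * weight
    loss_dict[loss_name] = curr_loss.item()
    if total_loss is not None:
        curr_loss = curr_loss + total_loss
    return curr_loss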
Code Example #20
    def D_step(result):
        imgs, imgs_pred, objs, \
        d_scores_fake_crop, d_obj_scores_fake_crop, d_scores_real_crop, \
        d_obj_scores_real_crop, d_scores_fake_img, d_scores_real_img, \
        d_obj_gp, d_img_gp \
        = result.imgs, result.imgs_pred, result.objs, \
          result.d_scores_fake_crop, result.d_obj_scores_fake_crop, result.d_scores_real_crop, \
          result.d_obj_scores_real_crop, result.d_scores_fake_img, result.d_scores_real_img, \
          result.d_obj_gp, result.d_img_gp
        d_rec_feature_fake_crop, d_rec_feature_real_crop = result.d_rec_feature_fake_crop, result.d_rec_feature_real_crop
        obj_fmaps = result.obj_fmaps
        d_scores_fake_bg, d_scores_real_bg, d_bg_gp = result.d_scores_fake_bg, result.d_scores_real_bg, result.d_bg_gp

        d_obj_losses, d_img_losses, d_bg_losses = None, None, None
        if all_in_one_model.obj_discriminator is not None:
            with timeit('d_obj loss', args.timing):
                d_obj_losses = LossManager()
                if args.d_obj_weight > 0:
                    d_obj_gan_loss = gan_d_loss(d_scores_real_crop,
                                                d_scores_fake_crop)
                    d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
                    if args.gan_loss_type == 'wgan-gp':
                        d_obj_losses.add_loss(d_obj_gp.mean(), 'd_obj_gp',
                                              args.d_obj_gp_weight)
                if args.ac_loss_weight > 0:
                    d_obj_losses.add_loss(
                        F.cross_entropy(d_obj_scores_real_crop, objs),
                        'd_ac_loss_real')
                    d_obj_losses.add_loss(
                        F.cross_entropy(d_obj_scores_fake_crop, objs),
                        'd_ac_loss_fake')
                if args.d_obj_rec_feat_weight > 0:
                    d_obj_losses.add_loss(
                        F.l1_loss(d_rec_feature_fake_crop, obj_fmaps),
                        'd_obj_fea_rec_loss_fake')
                    d_obj_losses.add_loss(
                        F.l1_loss(d_rec_feature_real_crop, obj_fmaps),
                        'd_obj_fea_rec_loss_real')

            with timeit('d_obj backward', args.timing):
                all_in_one_model.optimizer_d_obj.zero_grad()
                d_obj_losses.total_loss.backward()
                all_in_one_model.optimizer_d_obj.step()

        if all_in_one_model.img_discriminator is not None:
            with timeit('d_img loss', args.timing):
                d_img_losses = LossManager()
                d_img_gan_loss = gan_d_loss(d_scores_real_img,
                                            d_scores_fake_img)
                d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')
                if args.gan_loss_type == 'wgan-gp':
                    d_img_losses.add_loss(d_img_gp.mean(), 'd_img_gp',
                                          args.d_img_gp_weight)

            with timeit('d_img backward', args.timing):
                all_in_one_model.optimizer_d_img.zero_grad()
                d_img_losses.total_loss.backward()
                all_in_one_model.optimizer_d_img.step()

        if all_in_one_model.bg_discriminator is not None:
            with timeit('d_bg loss', args.timing):
                d_bg_losses = LossManager()
                d_bg_gan_loss = gan_d_loss(d_scores_real_bg, d_scores_fake_bg)
                d_bg_losses.add_loss(d_bg_gan_loss, 'd_bg_gan_loss')
                if args.gan_loss_type == 'wgan-gp':
                    d_bg_losses.add_loss(d_bg_gp.mean(), 'd_bg_gp',
                                         args.d_bg_gp_weight)

            with timeit('d_bg backward', args.timing):
                all_in_one_model.optimizer_d_bg.zero_grad()
                d_bg_losses.total_loss.backward()
                all_in_one_model.optimizer_d_bg.step()

        return d_obj_losses, d_img_losses, d_bg_losses
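The gradient-penalty tensors consumed above (d_obj_gp, d_img_gp, d_bg_gp) are produced elsewhere in the model's forward pass; with gan_loss_type == 'wgan-gp' they correspond to the standard WGAN-GP penalty. A generic, self-contained sketch (not taken from this repository) is:

import torch

def gradient_penalty(discriminator, real, fake):
    # Per-sample WGAN-GP penalty: score random interpolates of real and fake
    # images and penalize gradient norms that deviate from 1.
    real, fake = real.detach(), fake.detach()
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = discriminator(interp)
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=interp,
                                 create_graph=True)
    grads = grads.view(grads.size(0), -1)
    return (grads.norm(2, dim=1) - 1) ** 2

The .mean() of this per-sample penalty is what D_step adds to each discriminator loss, weighted by the corresponding *_gp_weight argument.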