def _adv_loss(logits, target):
    """Return BCE-with-logits of ``logits`` against a constant ``target``.

    ``target`` is 1 when the discriminator output should read "real" and 0
    when it should read "fake".
    """
    return F.binary_cross_entropy_with_logits(
        logits, torch.full_like(logits, target))


def _mix_rec_rand_shift(rec, rand, shift):
    """Blend the rec/rand/shift variants of a loss with weights 0.4/0.4/0.2."""
    return 0.4 * rec + 0.4 * rand + 0.2 * shift


def main(config):
    """Train the generator against image, object and attribute discriminators.

    Builds the networks, resumes all four of them from ``config.resume_iter``
    through ``load_model``, and then runs one discriminator update followed by
    one generator update per iteration until ``config.niter`` is reached.
    Losses are printed every ``config.log_step`` iterations, optionally sent
    to TensorBoard every ``config.tensorboard_step`` iterations, and all
    networks are checkpointed every ``config.save_step`` iterations.

    Args:
        config: Options object. Fields read here: ``exp_name``,
            ``batch_size``, ``image_size``, ``clstm_layers``,
            ``embedding_dim``, ``z_dim``, ``object_size``, ``learning_rate``,
            ``resume_iter``, ``niter``, ``use_tensorboard``, ``log_step``,
            ``tensorboard_step``, ``save_step`` and the loss weights
            ``lambda_img_adv``, ``lambda_obj_adv``, ``lambda_obj_cls``,
            ``lambda_att_cls``, ``lambda_img_rec``, ``lambda_z_rec`` and
            ``lambda_kl``.

    Returns:
        None.

    Note:
        ``pos_weight`` is not defined inside this function; it has to be
        available at module level.
    """
    # Row `obj` of this table supplies the sampling weights over attributes
    # when ground-truth attributes are replaced further down.
    matrix = torch.load("matrix_obj_vs_att.pt")
    cudnn.benchmark = True

    # Keep using the second GPU when there is one, but fall back instead of
    # crashing on machines with a single GPU or none.
    if torch.cuda.device_count() > 1:
        device = torch.device('cuda:1')
    elif torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)

    attribute_nums = 106

    data_loader, _ = get_dataloader_vg(batch_size=config.batch_size,
                                       attribute_embedding=attribute_nums,
                                       image_size=config.image_size)

    vocab_num = data_loader.dataset.num_objects

    if config.clstm_layers == 0:
        netG = Generator_nolstm(num_embeddings=vocab_num,
                                embedding_dim=config.embedding_dim,
                                z_dim=config.z_dim).to(device)
    else:
        netG = Generator(num_embeddings=vocab_num,
                         obj_att_dim=config.embedding_dim,
                         z_dim=config.z_dim,
                         clstm_layers=config.clstm_layers,
                         obj_size=config.object_size,
                         attribute_dim=attribute_nums).to(device)

    netD_image = add_sn(
        ImageDiscriminator(conv_dim=config.embedding_dim).to(device))
    netD_object = add_sn(ObjectDiscriminator(n_class=vocab_num).to(device))
    netD_att = add_sn(
        AttributeDiscriminator(n_attribute=attribute_nums).to(device))

    def make_adam(net):
        # Every network shares the same learning rate and betas.
        return torch.optim.Adam(net.parameters(), config.learning_rate,
                                [0.5, 0.999])

    netG_optimizer = make_adam(netG)
    netD_image_optimizer = make_adam(netD_image)
    netD_object_optimizer = make_adam(netD_object)
    netD_att_optimizer = make_adam(netD_att)

    # Resume the discriminators; only the generator's return value is used
    # as the iteration to start from.
    for net, tag in ((netD_object, 'netD_object'),
                     (netD_att, 'netD_attribute'),
                     (netD_image, 'netD_image')):
        load_model(net,
                   model_dir=model_save_dir,
                   appendix=tag,
                   iter=config.resume_iter)
    start_iter = load_model(netG,
                            model_dir=model_save_dir,
                            appendix='netG',
                            iter=config.resume_iter)

    # (network, checkpoint name) in the order they are written to disk.
    checkpoints = ((netG, 'netG'),
                   (netD_image, 'netD_image'),
                   (netD_object, 'netD_object'),
                   (netD_att, 'netD_attribute'))

    data_iter = iter(data_loader)

    if start_iter >= config.niter:
        return

    writer = SummaryWriter(log_save_dir) if config.use_tensorboard else None
    try:
        # Moved to the device once instead of four times per iteration.
        att_pos_weight = pos_weight.to(device)

        for i in range(start_iter, config.niter):

            # ============================================================== #
            #                    1. Preprocess input data                    #
            # ============================================================== #
            try:
                batch = next(data_iter)
            except StopIteration:
                # The loader is exhausted: start another pass over the data.
                data_iter = iter(data_loader)
                batch = next(data_iter)

            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
            z = torch.randn(objs.size(0), config.z_dim)

            # obj_to_img is intentionally left on its original device.
            imgs, objs, boxes, masks = (imgs.to(device), objs.to(device),
                                        boxes.to(device), masks.to(device))
            z, attribute = z.to(device), attribute.to(device)
            masks_shift = masks_shift.to(device)
            boxes_shift = boxes_shift.to(device)

            # ============================================================== #
            #                   2. Train the discriminator                   #
            # ============================================================== #
            attribute_GT = attribute.clone()

            # Estimate an attribute for every object that has none: take the
            # attribute discriminator's top prediction on the real crop.
            attribute_est = attribute.clone()
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img,
                                          config.object_size)
            best_att = netD_att(crops_input).argmax(1)
            unannotated = (attribute.sum(dim=1) == 0).nonzero().squeeze(1)
            attribute_est[unannotated, best_att[unannotated]] = 1

            # Replace the attributes of half the objects in the first third
            # of the images with freshly sampled ones.
            num_img_to_change = math.floor(imgs.shape[0] / 3)
            for img_idx in range(num_img_to_change):
                obj_indices = torch.nonzero(obj_to_img == img_idx).view(-1)
                num_objs_to_change = math.floor(len(obj_indices) / 2)

                for obj_idx in obj_indices[:num_objs_to_change]:
                    obj = objs[obj_idx]
                    old_attributes = torch.nonzero(
                        attribute_GT[obj_idx]).view(-1)
                    # Current attributes get weight 0 so that the sampled
                    # ones differ from what the object already has.
                    new_attribute = random.choices(
                        range(attribute_nums),
                        matrix[obj].scatter(0, old_attributes.cpu(), 0),
                        k=random.randrange(1, 3))
                    new_idx = torch.LongTensor(new_attribute).to(device)
                    new_row = torch.zeros_like(attribute[obj_idx]).scatter(
                        0, new_idx, 1)

                    attribute[obj_idx] = new_row
                    attribute_est[obj_idx] = new_row

            # Generate fake images.
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            crops_input, crops_input_rec, crops_rand, crops_shift, img_rec, img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift = output

            # Image adversarial loss: generated images should read as fake.
            d_image_adv_loss_fake_rec = _adv_loss(
                netD_image(img_rec.detach()), 0)
            d_image_adv_loss_fake_rand = _adv_loss(
                netD_image(img_rand.detach()), 0)
            d_image_adv_loss_fake_shift = _adv_loss(
                netD_image(img_shift.detach()), 0)
            d_image_adv_loss_fake = _mix_rec_rand_shift(
                d_image_adv_loss_fake_rec, d_image_adv_loss_fake_rand,
                d_image_adv_loss_fake_shift)

            # ... and real images should read as real.
            d_image_adv_loss_real = _adv_loss(netD_image(imgs), 1)

            # Object adversarial loss on the generated crops.
            out_logits, _ = netD_object(crops_input_rec.detach(), objs)
            d_object_adv_loss_fake_rec = _adv_loss(out_logits, 0)

            out_logits, _ = netD_object(crops_rand.detach(), objs)
            d_object_adv_loss_fake_rand = _adv_loss(out_logits, 0)

            out_logits, _ = netD_object(crops_shift.detach(), objs)
            d_object_adv_loss_fake_shift = _adv_loss(out_logits, 0)

            d_object_adv_loss_fake = _mix_rec_rand_shift(
                d_object_adv_loss_fake_rec, d_object_adv_loss_fake_rand,
                d_object_adv_loss_fake_shift)

            # Object adversarial and classification loss on the real crops.
            out_logits_src, out_logits_cls = netD_object(
                crops_input.detach(), objs)
            d_object_adv_loss_real = _adv_loss(out_logits_src, 1)
            d_object_cls_loss_real = F.cross_entropy(out_logits_cls, objs)

            # Attribute classification, restricted to the objects that carry
            # ground-truth attributes.
            att_cls = netD_att(crops_input.detach())
            att_idx = attribute_GT.sum(dim=1).nonzero().squeeze(1)
            d_object_att_cls_loss_real = F.binary_cross_entropy_with_logits(
                torch.index_select(att_cls, 0, att_idx),
                torch.index_select(attribute_GT, 0, att_idx),
                pos_weight=att_pos_weight)

            # Backward and optimize.
            d_loss = (
                config.lambda_img_adv *
                (d_image_adv_loss_fake + d_image_adv_loss_real) +
                config.lambda_obj_adv *
                (d_object_adv_loss_fake + d_object_adv_loss_real) +
                config.lambda_obj_cls * d_object_cls_loss_real +
                config.lambda_att_cls * d_object_att_cls_loss_real)

            netD_image.zero_grad()
            netD_object.zero_grad()
            netD_att.zero_grad()

            d_loss.backward()

            netD_image_optimizer.step()
            netD_object_optimizer.step()
            netD_att_optimizer.step()

            # Logging.
            loss = {
                'D/loss': d_loss.item(),
                'D/image_adv_loss_real': d_image_adv_loss_real.item(),
                'D/image_adv_loss_fake': d_image_adv_loss_fake.item(),
                'D/object_adv_loss_real': d_object_adv_loss_real.item(),
                'D/object_adv_loss_fake': d_object_adv_loss_fake.item(),
                'D/object_cls_loss_real': d_object_cls_loss_real.item(),
                'D/object_att_cls_loss': d_object_att_cls_loss_real.item(),
            }

            # ============================================================== #
            #                     3. Train the generator                     #
            # ============================================================== #
            # Regenerate so the graph is built against the updated
            # discriminators.
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            crops_input, crops_input_rec, crops_rand, crops_shift, img_rec, img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift = output

            # Image reconstruction loss. Images whose attributes were
            # replaced above are masked out and excluded from the average.
            rec_img_mask = torch.ones(imgs.shape[0], device=device)
            rec_img_mask[:num_img_to_change] = 0
            g_img_rec_loss = rec_img_mask * torch.abs(img_rec - imgs).view(
                imgs.shape[0], -1).mean(1)
            g_img_rec_loss = g_img_rec_loss.sum() / (imgs.shape[0] -
                                                     num_img_to_change)

            # Latent reconstruction loss.
            g_z_rec_loss_rand = torch.abs(z_rand_rec - z).mean()
            g_z_rec_loss_shift = torch.abs(z_rand_shift - z).mean()
            g_z_rec_loss = 0.5 * g_z_rec_loss_rand + 0.5 * g_z_rec_loss_shift

            # KL loss: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
            kl_element = mu.pow(2).add_(
                logvar.exp()).mul_(-1).add_(1).add_(logvar)
            g_kl_loss = torch.sum(kl_element).mul_(-0.5)

            # Image adversarial loss: generated images should read as real.
            g_image_adv_loss_fake_rec = _adv_loss(netD_image(img_rec), 1)
            g_image_adv_loss_fake_rand = _adv_loss(netD_image(img_rand), 1)
            g_image_adv_loss_fake_shift = _adv_loss(netD_image(img_shift), 1)
            g_image_adv_loss_fake = _mix_rec_rand_shift(
                g_image_adv_loss_fake_rec, g_image_adv_loss_fake_rand,
                g_image_adv_loss_fake_shift)

            # Object-level losses on the three kinds of generated crops,
            # using the (partly replaced) attributes as targets.
            att_idx = attribute.sum(dim=1).nonzero().squeeze(1)
            attribute_annotated = torch.index_select(attribute, 0, att_idx)

            adv_terms, cls_terms, att_terms = [], [], []
            for crops in (crops_input_rec, crops_rand, crops_shift):
                out_logits_src, out_logits_cls = netD_object(crops, objs)
                adv_terms.append(_adv_loss(out_logits_src, 1))
                cls_terms.append(F.cross_entropy(out_logits_cls, objs))

                att_cls_annotated = torch.index_select(
                    netD_att(crops), 0, att_idx)
                att_terms.append(
                    F.binary_cross_entropy_with_logits(
                        att_cls_annotated,
                        attribute_annotated,
                        pos_weight=att_pos_weight))

            g_object_att_cls_loss = _mix_rec_rand_shift(*att_terms)
            g_object_adv_loss = _mix_rec_rand_shift(*adv_terms)
            g_object_cls_loss = _mix_rec_rand_shift(*cls_terms)

            # Backward and optimize.
            g_loss = (config.lambda_img_rec * g_img_rec_loss +
                      config.lambda_z_rec * g_z_rec_loss +
                      config.lambda_img_adv * g_image_adv_loss_fake +
                      config.lambda_obj_adv * g_object_adv_loss +
                      config.lambda_obj_cls * g_object_cls_loss +
                      config.lambda_att_cls * g_object_att_cls_loss +
                      config.lambda_kl * g_kl_loss)

            netG.zero_grad()

            g_loss.backward()

            netG_optimizer.step()

            loss.update({
                'G/loss': g_loss.item(),
                'G/image_adv_loss': g_image_adv_loss_fake.item(),
                'G/object_adv_loss': g_object_adv_loss.item(),
                'G/object_cls_loss': g_object_cls_loss.item(),
                'G/rec_img': g_img_rec_loss.item(),
                'G/rec_z': g_z_rec_loss.item(),
                'G/kl': g_kl_loss.item(),
                'G/object_att_cls_loss': g_object_att_cls_loss.item(),
            })

            # ============================================================== #
            #                             4. Log                             #
            # ============================================================== #
            step = i + 1

            if step % config.log_step == 0:
                log = 'iter [{:06d}/{:06d}]'.format(step, config.niter)
                log += ''.join(", {}: {:.4f}".format(tag, roi_value)
                               for tag, roi_value in loss.items())
                print(log)

            if writer is not None and step % config.tensorboard_step == 0:
                for tag, roi_value in loss.items():
                    writer.add_scalar(tag, roi_value, step)

                previews = (('Result/crop_real', crops_input),
                            ('Result/crop_real_rec', crops_input_rec),
                            ('Result/crop_rand', crops_rand),
                            ('Result/img_real', imgs),
                            ('Result/img_real_rec', img_rec),
                            ('Result/img_fake_rand', img_rand))
                for tag, preview in previews:
                    writer.add_images(
                        tag,
                        imagenet_deprocess_batch(preview).float() / 255,
                        step)

            if step % config.save_step == 0:
                for net, tag in checkpoints:
                    save_model(net,
                               model_dir=model_save_dir,
                               appendix=tag,
                               iter=step,
                               save_num=2,
                               save_step=config.save_step)
    finally:
        # Flush TensorBoard events even when training is interrupted.
        if writer is not None:
            writer.close()
# Example no. 2
# 0
def main(config):
    """Generate images, test an attribute swap and score attribute prediction.

    For every batch of the second loader returned by ``get_dataloader_vg``
    (named ``val_data_loader`` here) this function:

    1. fills in an estimated attribute for objects that have none, using the
       attribute discriminator's top prediction on the real crop;
    2. generates images and records which attributes the attribute
       discriminator predicts (sigmoid > 0.9) on the generated crops of
       annotated objects, together with their ground-truth attributes;
    3. forces attribute index ``tgt`` on every object, generates again, and
       treats an object as successfully modified when ``tgt`` was not among
       the top-5 predictions before but is among the top-3 after;
    4. writes the shift/rand/rec/real images to ``result_save_dir``, plus the
       modified variants for images that contain a successfully modified
       object.

    Per-object precision and recall of step 2 are printed at the end.

    Args:
        config: Options object. Fields read here: ``exp_name``,
            ``batch_size``, ``embedding_dim``, ``z_dim``, ``clstm_layers``,
            ``object_size`` and ``resume_iter``.

    Returns:
        None.
    """
    cudnn.benchmark = True

    # Keep using the second GPU when there is one, but fall back instead of
    # crashing on machines with a single GPU or none.
    if torch.cuda.device_count() > 1:
        device = torch.device('cuda:1')
    elif torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)

    with open("data/vocab.json", 'r') as f:
        vocab = json.load(f)
    att_idx_to_name = np.array(vocab['attribute_idx_to_name'])
    print(att_idx_to_name)
    object_idx_to_name = np.array(vocab['object_idx_to_name'])

    attribute_nums = 106

    train_data_loader, val_data_loader = get_dataloader_vg(
        batch_size=config.batch_size, attribute_embedding=attribute_nums)

    vocab_num = train_data_loader.dataset.num_objects

    netG = Generator(num_embeddings=vocab_num,
                     obj_att_dim=config.embedding_dim,
                     z_dim=config.z_dim,
                     clstm_layers=config.clstm_layers,
                     obj_size=config.object_size,
                     attribute_dim=attribute_nums).to(device)

    netD_att = add_sn(
        AttributeDiscriminator(n_attribute=attribute_nums).to(device))

    # "~" is only meaningful to a shell, so expand it before handing the
    # directory to load_model.
    load_model(netD_att,
               model_dir=os.path.expanduser("~/models/trained_models"),
               appendix='netD_attribute',
               iter=config.resume_iter)
    load_model(netG,
               model_dir=model_save_dir,
               appendix='netG',
               iter=config.resume_iter)

    # Attribute indices cleared before the target attribute is forced.
    # NOTE(review): the original comments call these "color" attributes and
    # mention 2 = white, 8 = red, 94 = blue, 95 = black; confirm against
    # data/vocab.json.
    color_attributes = [2, 8, 0, 94, 90, 95, 96, 34, 25, 70, 58, 104]
    tgt = 95

    def save_image(image_batch, index, file_name):
        # Images are stored channel-first; imwrite gets them channel-last.
        imwrite(os.path.join(result_save_dir, file_name),
                image_batch[index].numpy().transpose(1, 2, 0))

    # One (num_annotated_objects, attribute_nums) array per batch; stacked
    # after the loop so no dataset size has to be known in advance.
    pred_chunks = []
    gt_chunks = []

    with torch.no_grad():
        netG.eval()
        # NOTE(review): netD_att is left in its default mode here, exactly
        # as before; confirm whether it should be switched to eval() too.
        for i, batch in enumerate(val_data_loader):
            print('batch {}'.format(i))

            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
            z = torch.randn(objs.size(0), config.z_dim)

            # obj_to_img is intentionally left on its original device.
            imgs, objs, boxes, masks = (imgs.to(device), objs.to(device),
                                        boxes.to(device), masks.to(device))
            z, attribute = z.to(device), attribute.to(device)
            masks_shift = masks_shift.to(device)
            boxes_shift = boxes_shift.to(device)

            # Rows with / without at least one ground-truth attribute.
            att_sum = attribute.sum(dim=1)
            att_idx = att_sum.nonzero().squeeze(1)
            unannotated = (att_sum == 0).nonzero().squeeze(1)

            # Estimate an attribute for every object that has none.
            attribute_est = attribute.clone()
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img,
                                          config.object_size)
            best_att = netD_att(crops_input).argmax(1)
            attribute_est[unannotated, best_att[unannotated]] = 1

            # Generate fake images.
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            crops_input, crops_input_rec, crops_rand, crops_shift, img_rec, img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift = output

            # Predict attributes on the generated crops of annotated objects.
            crops_rand_yes = torch.index_select(crops_rand, 0, att_idx)
            attribute_yes = torch.index_select(attribute, 0, att_idx)
            att_cls = torch.sigmoid(netD_att(crops_rand_yes))

            pred_chunks.append(
                (att_cls > 0.9).cpu().numpy().astype(np.float64))
            # index_select copies, so the overwrite of `attribute` below
            # does not leak into the recorded ground truth.
            gt_chunks.append(attribute_yes.cpu().numpy())

            img_rand = imagenet_deprocess_batch(img_rand)
            img_shift = imagenet_deprocess_batch(img_shift)
            img_rec = imagenet_deprocess_batch(img_rec)

            # Attribute modification: clear the listed attributes on every
            # object, then switch the target one on.
            for att_tensor in (attribute, attribute_est):
                att_tensor[:, color_attributes] = 0
                att_tensor[:, tgt] = 1
            changed_list = list(range(objs.size(0)))

            # Generate again with the modified attributes.
            z = torch.randn(objs.size(0), config.z_dim).to(device)
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            crops_input, crops_input_rec, crops_rand_y, crops_shift_y, img_rec_y, img_rand_y, img_shift_y, mu, logvar, z_rand_rec, z_rand_shift = output

            img_rand_y = imagenet_deprocess_batch(img_rand_y)
            img_shift_y = imagenet_deprocess_batch(img_shift_y)
            img_rec_y = imagenet_deprocess_batch(img_rec_y)
            imgs = imagenet_deprocess_batch(imgs)

            # Keep objects for which the target was NOT in the top-5 before
            # the change ...
            top_before = netD_att(crops_rand).topk(5)[1]
            changed_list = [
                c for c in changed_list if tgt not in top_before[c]
            ]
            # ... and IS in the top-3 afterwards.
            top_after = netD_att(crops_rand_y).topk(3)[1]
            changed_list_success = [
                c for c in changed_list if tgt in top_after[c]
            ]

            for j in range(imgs.shape[0]):
                stem = 'img{:06d}'.format(i * config.batch_size + j)

                save_image(img_shift, j, '{}_shift.png'.format(stem))
                save_image(img_rand, j, '{}_rand.png'.format(stem))
                save_image(img_rec, j, '{}_rec.png'.format(stem))
                save_image(imgs, j, '{}_real.png'.format(stem))

                cur_obj_success = [
                    int(objs[c]) for c in changed_list_success
                    if obj_to_img[c] == j
                ]
                if not cur_obj_success:
                    continue

                # Save the successfully modified images. The batches are
                # only indexed, never rebound, so later images stay valid.
                save_image(
                    img_shift_y, j, '{}_shift_{}_modified.png'.format(
                        stem, object_idx_to_name[cur_obj_success]))
                save_image(img_rec_y, j, '{}_rec_modified.png'.format(stem))
                save_image(img_rand_y, j,
                           '{}_rand_modified.png'.format(stem))

    if not gt_chunks:
        print("no batches were evaluated; skipping precision/recall")
        return

    attributes_gt = np.vstack(gt_chunks)
    attributes_pred = np.vstack(pred_chunks)

    # Per-object confusion counts, stored as (tn, fp, fn, tp).
    num_data = attributes_gt.shape[0]
    count = np.zeros((num_data, 4))
    recall, precision = np.zeros((num_data)), np.zeros((num_data))

    for row in range(num_data):
        count[row] = confusion_matrix(attributes_gt[row],
                                      attributes_pred[row]).ravel()
        tn, fp, fn, tp = count[row]
        # Every row belongs to an annotated object, so tp + fn > 0.
        recall[row] = tp / (tp + fn)
        precision[row] = safe_division(tp, tp + fp)

    print("average precision = {}".format(precision.mean()))
    print("average recall = {}".format(recall.mean()))

    print("average pred # per obj")
    print((count[:, 1].sum() + count[:, 3].sum()) / count.shape[0])

    print("average GT # per obj")
    print((count[:, 2].sum() + count[:, 3].sum()) / count.shape[0])

    print("% of data that predict something")
    print(((count[:, 1] + count[:, 3]) > 0).sum() / count.shape[0])

    print("% of data at least predicted correct once")
    print((count[:, 3] > 0).sum() / count.shape[0])
# Example no. 3
# 0
def main(config):
    """Train the generator together with the image and object discriminators.

    Each iteration first updates both discriminators on detached generator
    outputs, then runs the generator again and updates it against the
    freshly-updated discriminators. Training resumes from the iteration
    returned by ``load_model`` for the generator and runs until
    ``config.niter``.

    Args:
        config: Namespace-like object. Attributes read here: ``exp_name``,
            ``dataset`` (``'vg'`` or ``'coco'``), ``batch_size``, ``vg_dir``
            / ``coco_dir``, ``clstm_layers`` (must be > 0),
            ``embedding_dim``, ``z_dim``, ``learning_rate``, ``resume_iter``,
            ``niter``, ``use_tensorboard``, ``log_step``,
            ``tensorboard_step``, ``save_step`` and the loss weights
            ``lambda_img_adv``, ``lambda_obj_adv``, ``lambda_obj_cls``,
            ``lambda_img_rec``, ``lambda_z_rec``, ``lambda_kl``.

    Returns:
        The tuple ``(log_save_dir, model_save_dir, sample_save_dir,
        result_save_dir)`` produced by ``prepare_dir(config.exp_name)``.

    Raises:
        ValueError: If ``config.dataset`` is neither ``'vg'`` nor ``'coco'``.
        AssertionError: If ``config.clstm_layers`` is not positive.
    """
    cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)

    if config.dataset == 'vg':
        data_loader, _ = get_dataloader_vg(batch_size=config.batch_size,
                                           VG_DIR=config.vg_dir)
    elif config.dataset == 'coco':
        data_loader, _ = get_dataloader_coco(batch_size=config.batch_size,
                                             COCO_DIR=config.coco_dir)
    else:
        # Fail here with a clear message instead of an unbound-name error
        # on the first use of data_loader below.
        raise ValueError(
            "unsupported dataset {!r}; expected 'vg' or 'coco'".format(
                config.dataset))
    vocab_num = data_loader.dataset.num_objects

    assert config.clstm_layers > 0

    netG = Generator(num_embeddings=vocab_num,
                     embedding_dim=config.embedding_dim,
                     z_dim=config.z_dim,
                     clstm_layers=config.clstm_layers).to(device)
    netD_image = ImageDiscriminator(conv_dim=config.embedding_dim).to(device)
    netD_object = ObjectDiscriminator(n_class=vocab_num).to(device)

    netD_image = add_sn(netD_image)
    netD_object = add_sn(netD_object)

    netG_optimizer = torch.optim.Adam(netG.parameters(), config.learning_rate,
                                      [0.5, 0.999])
    netD_image_optimizer = torch.optim.Adam(netD_image.parameters(),
                                            config.learning_rate, [0.5, 0.999])
    netD_object_optimizer = torch.optim.Adam(netD_object.parameters(),
                                             config.learning_rate,
                                             [0.5, 0.999])

    # Only the generator's checkpoint decides where training resumes; the
    # discriminators' return values are ignored.
    start_iter = load_model(netG,
                            model_dir=model_save_dir,
                            appendix='netG',
                            iter=config.resume_iter)
    _ = load_model(netD_image,
                   model_dir=model_save_dir,
                   appendix='netD_image',
                   iter=config.resume_iter)
    _ = load_model(netD_object,
                   model_dir=model_save_dir,
                   appendix='netD_object',
                   iter=config.resume_iter)

    data_iter = iter(data_loader)

    if start_iter < config.niter:

        if config.use_tensorboard:
            writer = SummaryWriter(log_save_dir)

        for i in range(start_iter, config.niter):
            try:
                batch = next(data_iter)
            except StopIteration:
                # The loader is exhausted: start a new pass over the data.
                # Any other error (bad sample, interrupt, ...) propagates.
                data_iter = iter(data_loader)
                batch = next(data_iter)

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #
            imgs, objs, boxes, masks, obj_to_img = batch
            z = torch.randn(objs.size(0), config.z_dim)
            # obj_to_img is deliberately left where the loader put it.
            imgs, objs, boxes, masks, obj_to_img, z = imgs.to(device), objs.to(device), boxes.to(device), \
                masks.to(device), obj_to_img, z.to(device)

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Generate fake image
            output = netG(imgs, objs, boxes, masks, obj_to_img, z)
            crops_input, crops_input_rec, crops_rand, img_rec, img_rand, mu, logvar, z_rand_rec = output

            # Compute image adv loss with fake images (detached so that no
            # gradient reaches the generator during the discriminator step).
            out_logits = netD_image(img_rec.detach())
            d_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            out_logits = netD_image(img_rand.detach())
            d_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            d_image_adv_loss_fake = 0.5 * d_image_adv_loss_fake_rec + 0.5 * d_image_adv_loss_fake_rand

            # Compute image adv loss with real images.
            out_logits = netD_image(imgs)
            d_image_adv_loss_real = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            # Compute object sn adv loss with fake rec crops
            out_logits, _ = netD_object(crops_input_rec.detach(), objs)
            d_object_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            # Compute object sn adv loss with fake rand crops
            out_logits, _ = netD_object(crops_rand.detach(), objs)
            d_object_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            d_object_adv_loss_fake = 0.5 * d_object_adv_loss_fake_rec + 0.5 * d_object_adv_loss_fake_rand

            # Compute object sn adv loss with real crops.
            out_logits_src, out_logits_cls = netD_object(
                crops_input.detach(), objs)
            d_object_adv_loss_real = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            d_object_cls_loss_real = F.cross_entropy(out_logits_cls, objs)

            # Backward and optimize.
            d_loss = 0
            d_loss += config.lambda_img_adv * (d_image_adv_loss_fake +
                                               d_image_adv_loss_real)
            d_loss += config.lambda_obj_adv * (d_object_adv_loss_fake +
                                               d_object_adv_loss_real)
            d_loss += config.lambda_obj_cls * d_object_cls_loss_real

            netD_image.zero_grad()
            netD_object.zero_grad()

            d_loss.backward()

            netD_image_optimizer.step()
            netD_object_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss'] = d_loss.item()
            loss['D/image_adv_loss_real'] = d_image_adv_loss_real.item()
            loss['D/image_adv_loss_fake'] = d_image_adv_loss_fake.item()
            loss['D/object_adv_loss_real'] = d_object_adv_loss_real.item()
            loss['D/object_adv_loss_fake'] = d_object_adv_loss_fake.item()
            loss['D/object_cls_loss_real'] = d_object_cls_loss_real.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #
            # Generate fake image again: the outputs are not detached this
            # time, so the losses below back-propagate into the generator.
            output = netG(imgs, objs, boxes, masks, obj_to_img, z)
            crops_input, crops_input_rec, crops_rand, img_rec, img_rand, mu, logvar, z_rand_rec = output

            # reconstruction loss of ae and img (both mean absolute error)
            g_img_rec_loss = torch.abs(img_rec - imgs).mean()
            g_z_rec_loss = torch.abs(z_rand_rec - z).mean()

            # kl loss: -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
            kl_element = mu.pow(2).add_(
                logvar.exp()).mul_(-1).add_(1).add_(logvar)
            g_kl_loss = torch.sum(kl_element).mul_(-0.5)

            # Compute image adv loss with fake images (target 1: the
            # generator wants the discriminator to call them real).
            out_logits = netD_image(img_rec)
            g_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            out_logits = netD_image(img_rand)
            g_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            g_image_adv_loss_fake = 0.5 * g_image_adv_loss_fake_rec + 0.5 * g_image_adv_loss_fake_rand

            # Compute object adv loss with fake images.
            out_logits_src, out_logits_cls = netD_object(crops_input_rec, objs)
            g_object_adv_loss_rec = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_rec = F.cross_entropy(out_logits_cls, objs)

            out_logits_src, out_logits_cls = netD_object(crops_rand, objs)
            g_object_adv_loss_rand = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_rand = F.cross_entropy(out_logits_cls, objs)

            g_object_adv_loss = 0.5 * g_object_adv_loss_rec + 0.5 * g_object_adv_loss_rand
            g_object_cls_loss = 0.5 * g_object_cls_loss_rec + 0.5 * g_object_cls_loss_rand

            # Backward and optimize.
            g_loss = 0
            g_loss += config.lambda_img_rec * g_img_rec_loss
            g_loss += config.lambda_z_rec * g_z_rec_loss
            g_loss += config.lambda_img_adv * g_image_adv_loss_fake
            g_loss += config.lambda_obj_adv * g_object_adv_loss
            g_loss += config.lambda_obj_cls * g_object_cls_loss
            g_loss += config.lambda_kl * g_kl_loss

            netG.zero_grad()
            g_loss.backward()
            netG_optimizer.step()

            loss['G/loss'] = g_loss.item()
            loss['G/image_adv_loss'] = g_image_adv_loss_fake.item()
            loss['G/object_adv_loss'] = g_object_adv_loss.item()
            loss['G/object_cls_loss'] = g_object_cls_loss.item()
            loss['G/rec_img'] = g_img_rec_loss.item()
            loss['G/rec_z'] = g_z_rec_loss.item()
            loss['G/kl'] = g_kl_loss.item()

            # =================================================================================== #
            #                               4. Log                                                #
            # =================================================================================== #
            if (i + 1) % config.log_step == 0:
                log = 'iter [{:06d}/{:06d}]'.format(i + 1, config.niter)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

            if (i + 1
                ) % config.tensorboard_step == 0 and config.use_tensorboard:
                for tag, value in loss.items():
                    writer.add_scalar(tag, value, i + 1)
                writer.add_image(
                    'Result/crop_real',
                    imagenet_deprocess_batch(crops_input).float() / 255, i + 1)
                writer.add_image(
                    'Result/crop_real_rec',
                    imagenet_deprocess_batch(crops_input_rec).float() / 255,
                    i + 1)
                writer.add_image(
                    'Result/crop_rand',
                    imagenet_deprocess_batch(crops_rand).float() / 255, i + 1)
                writer.add_image('Result/img_real',
                                 imagenet_deprocess_batch(imgs).float() / 255,
                                 i + 1)
                writer.add_image(
                    'Result/img_real_rec',
                    imagenet_deprocess_batch(img_rec).float() / 255, i + 1)
                writer.add_image(
                    'Result/img_fake_rand',
                    imagenet_deprocess_batch(img_rand).float() / 255, i + 1)

            if (i + 1) % config.save_step == 0:
                save_model(netG,
                           model_dir=model_save_dir,
                           appendix='netG',
                           iter=i + 1,
                           save_num=5,
                           save_step=config.save_step)
                save_model(netD_image,
                           model_dir=model_save_dir,
                           appendix='netD_image',
                           iter=i + 1,
                           save_num=5,
                           save_step=config.save_step)
                save_model(netD_object,
                           model_dir=model_save_dir,
                           appendix='netD_object',
                           iter=i + 1,
                           save_num=5,
                           save_step=config.save_step)

        if config.use_tensorboard:
            writer.close()
        # exist_ok: the directory may already be there (e.g. a resumed run);
        # that must not turn a finished training run into a crash.
        Path(result_save_dir).mkdir(parents=True, exist_ok=True)
    return log_save_dir, model_save_dir, sample_save_dir, result_save_dir


# Script-level setup for training the attribute discriminator on its own.
# All settings are hard-coded module-level names rather than read from a
# config object.
cudnn.benchmark = True
# NOTE(review): pinned to the first CUDA device with no CPU fallback.
device = torch.device('cuda:0')
exp_name = 'training_att_cls_128'
log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
    exp_name)

niter = 400000
attribute_nums = 106
batch_size = 12
save_step = 2000
# Build the discriminator, then pass it through add_sn before any weights
# are loaded or the optimizer is created.
netD_att = AttributeDiscriminator(n_attribute=attribute_nums).to(device)
netD_att = add_sn(netD_att)
# iter='l' presumably asks load_model for the latest checkpoint, and the
# return value is presumably the iteration to resume from -- TODO confirm
# against load_model.
start_iter = load_model(netD_att,
                        model_dir=model_save_dir,
                        appendix='netD_attribute',
                        iter='l')
# Adam with learning rate 2e-4 and betas (0.5, 0.999).
netD_att_optimizer = torch.optim.Adam(netD_att.parameters(), 2e-4,
                                      [0.5, 0.999])

# The dataset is fixed to 'vg' here, so the 'coco' branch is not taken unless
# the constant is edited. Only the VG loader receives attribute_embedding.
dataset = 'vg'
if dataset == 'vg':
    train_data_loader, val_data_loader = get_dataloader_vg(
        batch_size=batch_size, attribute_embedding=attribute_nums)
elif dataset == 'coco':
    train_data_loader, val_data_loader = get_dataloader_coco(
        batch_size=batch_size)