Пример #1
0
    def forward(self, imgs, objs, boxes, masks, obj_to_img, z_rand, attribute, masks_shift, boxes_shift, attribute_est):
        """Run the reconstruction, random-latent and shifted-layout branches.

        Object crops from ``imgs`` are encoded by ``self.crop_encoder`` into
        ``(z_rec, mu, logvar)``. Three layouts are then built with
        ``self.layout_encoder`` and decoded with ``self.decoder``, each using
        ``self.global_encoder`` features as context:

        * rec   -- estimated-attribute embeddings, ``masks``, ``z_rec``
        * rand  -- given-attribute embeddings, ``masks``, ``z_rand``
        * shift -- given-attribute embeddings, ``masks_shift``, ``z_rand``

        The generated rand/shift images are re-cropped (the shift image with
        ``boxes_shift``) and passed back through ``self.crop_encoder``. From
        that call, the second output is returned as ``z_rand_rec`` /
        ``z_rand_shift``; it is the same slot that holds ``mu`` for the input
        crops.

        Returns:
            tuple: (crops_input, crops_input_rec, crops_rand, crops_shift,
            img_rec, img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift)
        """
        size = self.obj_size

        real_crops = crop_bbox_batch(imgs, boxes, obj_to_img, size)
        z_rec, mu, logvar = self.crop_encoder(real_crops, objs)

        emb_given = self.attribute_encoder(objs, attribute)
        emb_estimated = self.attribute_encoder(objs, attribute_est)

        # Branch order (rec, rand, shift) is kept identical at every stage.
        # An earlier note gives the layout features as (n, clstm_dim*2, 8, 8)
        # -- not verifiable here.
        branches = (
            (emb_estimated, masks, z_rec),
            (emb_given, masks, z_rand),
            (emb_given, masks_shift, z_rand),
        )
        layouts = [self.layout_encoder(emb, layout_mask, obj_to_img, latent, objs)
                   for emb, layout_mask, latent in branches]
        contexts = [self.global_encoder(feat) for feat in layouts]
        img_rec, img_rand, img_shift = [self.decoder(feat, ctx)
                                        for feat, ctx in zip(layouts, contexts)]

        rand_crops = crop_bbox_batch(img_rand, boxes, obj_to_img, size)
        _, z_rand_rec, _ = self.crop_encoder(rand_crops, objs)

        rec_crops = crop_bbox_batch(img_rec, boxes, obj_to_img, size)

        shift_crops = crop_bbox_batch(img_shift, boxes_shift, obj_to_img, size)
        _, z_rand_shift, _ = self.crop_encoder(shift_crops, objs)

        return (real_crops, rec_crops, rand_crops, shift_crops,
                img_rec, img_rand, img_shift,
                mu, logvar, z_rand_rec, z_rand_shift)
Пример #2
0
    def forward(self, imgs, objs, boxes, masks, obj_to_img, z_rand):
        """Run the reconstruction and random-latent branches.

        Object crops from ``imgs`` are encoded by ``self.crop_encoder`` into
        ``(z_rec, mu, logvar)``. Two layouts are built with
        ``self.layout_encoder`` from the same ``objs``/``masks``: one with
        ``z_rec`` and one with ``z_rand``. Each is decoded by ``self.decoder``.
        The random image is re-cropped and re-encoded; the second output of
        that ``crop_encoder`` call is returned as ``z_rand_rec``.

        Returns:
            tuple: (crops_input, crops_input_rec, crops_rand, img_rec,
            img_rand, mu, logvar, z_rand_rec)
        """
        size = self.obj_size

        real_crops = crop_bbox_batch(imgs, boxes, obj_to_img, size)
        z_rec, mu, logvar = self.crop_encoder(real_crops, objs)

        # Reconstruction branch first, then the random-latent branch.
        # An earlier note gives the layout features as (n, clstm_dim*2, 8, 8)
        # -- not verifiable here.
        layouts = [self.layout_encoder(objs, masks, obj_to_img, latent)
                   for latent in (z_rec, z_rand)]
        img_rec, img_rand = [self.decoder(feat) for feat in layouts]

        rand_crops = crop_bbox_batch(img_rand, boxes, obj_to_img, size)
        _, z_rand_rec, _ = self.crop_encoder(rand_crops, objs)

        rec_crops = crop_bbox_batch(img_rec, boxes, obj_to_img, size)

        return (real_crops, rec_crops, rand_crops, img_rec, img_rand,
                mu, logvar, z_rand_rec)
Пример #3
0
def test_model(model, dataloaders, criterion, input_size=224):
    """Evaluate an object classifier on object crops from the 'val' loader.

    Each batch of ``dataloaders['val']`` is unpacked as
    ``(imgs, objs, boxes, masks, obj_to_img, attributes)``. The boxes are
    cropped with ``crop_bbox_batch`` to ``input_size`` and classified by
    ``model``; ``objs`` are the target labels. Running loss/accuracy is
    printed every 20 batches, and the overall loss/accuracy at the end.

    Relies on a module-level ``device`` (not defined in this function).

    Args:
        model: classifier called as ``model(crops)`` returning class scores.
        dataloaders: mapping that must contain the key ``'val'``.
        criterion: loss called as ``criterion(outputs, labels)``.
        input_size: crop size passed to ``crop_bbox_batch``.

    Returns:
        tuple[float, float]: ``(epoch_loss, epoch_acc)``. Both are NaN when
        the loader yields no objects (previously this raised).
    """
    phase = 'val'
    model.eval()

    running_loss = 0.0
    running_corrects = 0
    running_count = 0

    for i, batch in enumerate(dataloaders[phase]):
        imgs, objs, boxes, _masks, obj_to_img, _attributes = batch
        inputs = crop_bbox_batch(imgs, boxes, obj_to_img, input_size).to(device)
        labels = objs.to(device)

        # Evaluation only: no autograd history is needed.
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)

        num_objs = labels.size(0)
        batch_correct = (preds == labels).sum().item()
        running_loss += loss.item() * num_objs
        running_corrects += batch_correct
        running_count += num_objs

        if (i + 1) % 20 == 0:
            batch_acc = batch_correct / num_objs if num_objs else float('nan')
            print('loss: {:.4f} accu: {:.4f}'.format(loss.item(), batch_acc))

    # Guard against an empty loader: the original divided by zero here.
    if running_count:
        epoch_loss = running_loss / running_count
        epoch_acc = running_corrects / running_count
    else:
        epoch_loss = epoch_acc = float('nan')

    print('================================================================')
    print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
    return epoch_loss, epoch_acc
def _adv_loss(logits, target):
    """BCE-with-logits of ``logits`` against the constant ``target`` (0.0 fake, 1.0 real)."""
    return F.binary_cross_entropy_with_logits(logits, torch.full_like(logits, target))


def _annotated_att_loss(att_logits, attribute):
    """BCE-with-logits restricted to rows of ``attribute`` that have any attribute set.

    Uses the module-level ``pos_weight``. When no row is annotated, returns a
    zero loss that stays connected to ``att_logits``'s graph. Without this,
    the mean over an empty selection would be NaN and poison the total loss.
    """
    # view(-1) (not squeeze()) keeps a 1-D index even for a single match.
    att_idx = attribute.sum(dim=1).nonzero().view(-1)
    if att_idx.numel() == 0:
        return att_logits.sum() * 0.0
    return F.binary_cross_entropy_with_logits(
        torch.index_select(att_logits, 0, att_idx),
        torch.index_select(attribute, 0, att_idx),
        pos_weight=pos_weight.to(att_logits.device))


def _fill_missing_attributes(attribute, att_logits):
    """Return a copy of ``attribute`` where each all-zero row gets its argmax attribute set.

    Rows with at least one attribute are copied unchanged. For a row that sums
    to zero, the column of the highest ``att_logits`` score is set to 1.
    """
    attribute_est = attribute.clone()
    unannotated = (attribute.sum(dim=1) == 0).nonzero().view(-1)
    if unannotated.numel() > 0:
        best = att_logits.argmax(1)
        attribute_est[unannotated, best[unannotated]] = 1
    return attribute_est


def _perturb_attributes(attribute, attribute_est, attribute_gt, objs, obj_to_img,
                        matrix, num_imgs, num_attributes):
    """Re-assign attributes of some objects IN PLACE; return the number of images touched.

    For each of the first ``num_imgs // 3`` images, the first half (rounded
    down) of its objects get 1 or 2 attribute indices. These are sampled with
    ``random.choices`` using ``matrix[object_class]`` as weights, with the
    object's ground-truth attributes given zero weight. The row is cleared and
    the sampled indices are set to 1 in both ``attribute`` and
    ``attribute_est``.

    NOTE(review): ``random.choices`` raises ValueError if all weights are zero.
    """
    num_img_to_change = num_imgs // 3
    for img_idx in range(num_img_to_change):
        obj_indices = torch.nonzero(obj_to_img == img_idx).view(-1)
        num_objs_to_change = len(obj_indices) // 2
        for obj_idx in obj_indices[:num_objs_to_change].tolist():
            obj_class = int(objs[obj_idx])
            old_attributes = torch.nonzero(attribute_gt[obj_idx]).view(-1)
            weights = matrix[obj_class].clone()
            weights[old_attributes.to(weights.device)] = 0
            # randrange is drawn before choices, matching the original RNG order.
            k = random.randrange(1, 3)
            new_attribute = random.choices(range(num_attributes), weights.tolist(), k=k)
            new_idx = torch.tensor(new_attribute, dtype=torch.long, device=attribute.device)
            attribute[obj_idx] = 0  # remove all attributes for obj
            attribute[obj_idx, new_idx] = 1  # assign new attributes
            attribute_est[obj_idx] = attribute[obj_idx]
    return num_img_to_change


def main(config):
    """Train the generator together with image, object and attribute discriminators.

    Each iteration:

    1. Draws ``imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift,
       boxes_shift`` from ``get_dataloader_vg``.
    2. Builds ``attribute_est``: attributes for unannotated objects are
       filled from ``netD_att``'s predictions.
    3. Re-assigns attributes for some objects.
    4. Takes one discriminator step and one generator step.

    Models are resumed from, and checkpointed to, ``model_save_dir`` returned
    by ``prepare_dir(config.exp_name)``.

    Relies on module-level names not defined here, including ``pos_weight``.
    """
    matrix = torch.load("matrix_obj_vs_att.pt")
    cudnn.benchmark = True
    device = torch.device('cuda:1')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)

    attribute_nums = 106

    data_loader, _ = get_dataloader_vg(batch_size=config.batch_size,
                                       attribute_embedding=attribute_nums,
                                       image_size=config.image_size)

    vocab_num = data_loader.dataset.num_objects

    if config.clstm_layers == 0:
        netG = Generator_nolstm(num_embeddings=vocab_num,
                                embedding_dim=config.embedding_dim,
                                z_dim=config.z_dim).to(device)
    else:
        netG = Generator(num_embeddings=vocab_num,
                         obj_att_dim=config.embedding_dim,
                         z_dim=config.z_dim,
                         clstm_layers=config.clstm_layers,
                         obj_size=config.object_size,
                         attribute_dim=attribute_nums).to(device)

    netD_image = add_sn(ImageDiscriminator(conv_dim=config.embedding_dim).to(device))
    netD_object = add_sn(ObjectDiscriminator(n_class=vocab_num).to(device))
    netD_att = add_sn(AttributeDiscriminator(n_attribute=attribute_nums).to(device))

    betas = [0.5, 0.999]
    netG_optimizer = torch.optim.Adam(netG.parameters(), config.learning_rate, betas)
    netD_image_optimizer = torch.optim.Adam(netD_image.parameters(), config.learning_rate, betas)
    netD_object_optimizer = torch.optim.Adam(netD_object.parameters(), config.learning_rate, betas)
    netD_att_optimizer = torch.optim.Adam(netD_att.parameters(), config.learning_rate, betas)

    # Only the generator's checkpoint iteration decides where training resumes.
    for net, appendix in ((netD_object, 'netD_object'),
                          (netD_att, 'netD_attribute'),
                          (netD_image, 'netD_image')):
        load_model(net, model_dir=model_save_dir, appendix=appendix,
                   iter=config.resume_iter)
    start_iter = load_model(netG, model_dir=model_save_dir, appendix='netG',
                            iter=config.resume_iter)

    if start_iter >= config.niter:
        return

    data_iter = iter(data_loader)
    writer = SummaryWriter(log_save_dir) if config.use_tensorboard else None

    try:
        for i in range(start_iter, config.niter):
            # ========================================================== #
            #                 1. Preprocess input data                   #
            # ========================================================== #
            try:
                batch = next(data_iter)
            except StopIteration:
                # Loader exhausted: start a new pass over the data.
                data_iter = iter(data_loader)
                batch = next(data_iter)

            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
            z = torch.randn(objs.size(0), config.z_dim)

            # ========================================================== #
            #                2. Train the discriminator                  #
            # ========================================================== #
            # obj_to_img is not moved to `device`.
            imgs, objs, boxes, masks = imgs.to(device), objs.to(device), boxes.to(device), masks.to(device)
            z, attribute = z.to(device), attribute.to(device)
            masks_shift, boxes_shift = masks_shift.to(device), boxes_shift.to(device)

            attribute_GT = attribute.clone()

            # Estimate attributes for objects without any annotation. Only the
            # argmax is used, so no autograd graph is needed.
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img, config.object_size)
            with torch.no_grad():
                estimated_att = netD_att(crops_input)
            attribute_est = _fill_missing_attributes(attribute, estimated_att)

            # Change GT attributes of some objects (modifies attribute / attribute_est in place).
            num_img_to_change = _perturb_attributes(
                attribute, attribute_est, attribute_GT, objs, obj_to_img,
                matrix, imgs.shape[0], attribute_nums)

            # Generate fake images.
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand, crops_shift,
             img_rec, img_rand, img_shift, mu, logvar,
             z_rand_rec, z_rand_shift) = output

            # Image adversarial loss: fakes (detached) labelled 0, reals labelled 1.
            d_image_adv_loss_fake_rec = _adv_loss(netD_image(img_rec.detach()), 0.0)
            d_image_adv_loss_fake_rand = _adv_loss(netD_image(img_rand.detach()), 0.0)
            d_image_adv_loss_fake_shift = _adv_loss(netD_image(img_shift.detach()), 0.0)
            d_image_adv_loss_fake = (0.4 * d_image_adv_loss_fake_rec
                                     + 0.4 * d_image_adv_loss_fake_rand
                                     + 0.2 * d_image_adv_loss_fake_shift)
            d_image_adv_loss_real = _adv_loss(netD_image(imgs), 1.0)

            # Object adversarial loss on fake crops.
            out_logits, _ = netD_object(crops_input_rec.detach(), objs)
            d_object_adv_loss_fake_rec = _adv_loss(out_logits, 0.0)
            out_logits, _ = netD_object(crops_rand.detach(), objs)
            d_object_adv_loss_fake_rand = _adv_loss(out_logits, 0.0)
            out_logits, _ = netD_object(crops_shift.detach(), objs)
            d_object_adv_loss_fake_shift = _adv_loss(out_logits, 0.0)
            d_object_adv_loss_fake = (0.4 * d_object_adv_loss_fake_rec
                                      + 0.4 * d_object_adv_loss_fake_rand
                                      + 0.2 * d_object_adv_loss_fake_shift)

            # Object adversarial + classification loss on real crops.
            out_logits_src, out_logits_cls = netD_object(crops_input.detach(), objs)
            d_object_adv_loss_real = _adv_loss(out_logits_src, 1.0)
            d_object_cls_loss_real = F.cross_entropy(out_logits_cls, objs)

            # Attribute classification on real crops against unperturbed GT.
            d_object_att_cls_loss_real = _annotated_att_loss(
                netD_att(crops_input.detach()), attribute_GT)

            d_loss = 0
            d_loss += config.lambda_img_adv * (d_image_adv_loss_fake + d_image_adv_loss_real)
            d_loss += config.lambda_obj_adv * (d_object_adv_loss_fake + d_object_adv_loss_real)
            d_loss += config.lambda_obj_cls * d_object_cls_loss_real
            d_loss += config.lambda_att_cls * d_object_att_cls_loss_real

            netD_image.zero_grad()
            netD_object.zero_grad()
            netD_att.zero_grad()

            d_loss.backward()

            netD_image_optimizer.step()
            netD_object_optimizer.step()
            netD_att_optimizer.step()

            loss = {}
            loss['D/loss'] = d_loss.item()
            loss['D/image_adv_loss_real'] = d_image_adv_loss_real.item()
            loss['D/image_adv_loss_fake'] = d_image_adv_loss_fake.item()
            loss['D/object_adv_loss_real'] = d_object_adv_loss_real.item()
            loss['D/object_adv_loss_fake'] = d_object_adv_loss_fake.item()
            loss['D/object_cls_loss_real'] = d_object_cls_loss_real.item()
            loss['D/object_att_cls_loss'] = d_object_att_cls_loss_real.item()

            # ========================================================== #
            #                  3. Train the generator                    #
            # ========================================================== #
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand, crops_shift,
             img_rec, img_rand, img_shift, mu, logvar,
             z_rand_rec, z_rand_shift) = output

            # L1 image reconstruction, excluding the first num_img_to_change
            # images (those whose attributes were re-assigned above).
            batch_size = imgs.shape[0]
            rec_img_mask = torch.ones(batch_size, device=device)
            rec_img_mask[:num_img_to_change] = 0
            per_img_l1 = torch.abs(img_rec - imgs).view(batch_size, -1).mean(1)
            g_img_rec_loss = (rec_img_mask * per_img_l1).sum() / (batch_size - num_img_to_change)

            g_z_rec_loss_rand = torch.abs(z_rand_rec - z).mean()
            g_z_rec_loss_shift = torch.abs(z_rand_shift - z).mean()
            g_z_rec_loss = 0.5 * g_z_rec_loss_rand + 0.5 * g_z_rec_loss_shift

            # KL term: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
            kl_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
            g_kl_loss = torch.sum(kl_element).mul_(-0.5)

            # Image adversarial loss: generator wants fakes labelled 1.
            g_image_adv_loss_fake_rec = _adv_loss(netD_image(img_rec), 1.0)
            g_image_adv_loss_fake_rand = _adv_loss(netD_image(img_rand), 1.0)
            g_image_adv_loss_fake_shift = _adv_loss(netD_image(img_shift), 1.0)
            g_image_adv_loss_fake = (0.4 * g_image_adv_loss_fake_rec
                                     + 0.4 * g_image_adv_loss_fake_rand
                                     + 0.2 * g_image_adv_loss_fake_shift)

            # Object adversarial / class / attribute losses, per branch.
            # Attribute targets are the (possibly re-assigned) `attribute`.
            out_logits_src, out_logits_cls = netD_object(crops_input_rec, objs)
            g_object_adv_loss_rec = _adv_loss(out_logits_src, 1.0)
            g_object_cls_loss_rec = F.cross_entropy(out_logits_cls, objs)
            g_object_att_cls_loss_rec = _annotated_att_loss(netD_att(crops_input_rec), attribute)

            out_logits_src, out_logits_cls = netD_object(crops_rand, objs)
            g_object_adv_loss_rand = _adv_loss(out_logits_src, 1.0)
            g_object_cls_loss_rand = F.cross_entropy(out_logits_cls, objs)
            g_object_att_cls_loss_rand = _annotated_att_loss(netD_att(crops_rand), attribute)

            out_logits_src, out_logits_cls = netD_object(crops_shift, objs)
            g_object_adv_loss_shift = _adv_loss(out_logits_src, 1.0)
            g_object_cls_loss_shift = F.cross_entropy(out_logits_cls, objs)
            g_object_att_cls_loss_shift = _annotated_att_loss(netD_att(crops_shift), attribute)

            g_object_att_cls_loss = (0.4 * g_object_att_cls_loss_rec
                                     + 0.4 * g_object_att_cls_loss_rand
                                     + 0.2 * g_object_att_cls_loss_shift)
            g_object_adv_loss = (0.4 * g_object_adv_loss_rec
                                 + 0.4 * g_object_adv_loss_rand
                                 + 0.2 * g_object_adv_loss_shift)
            g_object_cls_loss = (0.4 * g_object_cls_loss_rec
                                 + 0.4 * g_object_cls_loss_rand
                                 + 0.2 * g_object_cls_loss_shift)

            g_loss = 0
            g_loss += config.lambda_img_rec * g_img_rec_loss
            g_loss += config.lambda_z_rec * g_z_rec_loss
            g_loss += config.lambda_img_adv * g_image_adv_loss_fake
            g_loss += config.lambda_obj_adv * g_object_adv_loss
            g_loss += config.lambda_obj_cls * g_object_cls_loss
            g_loss += config.lambda_att_cls * g_object_att_cls_loss
            g_loss += config.lambda_kl * g_kl_loss

            netG.zero_grad()
            g_loss.backward()
            netG_optimizer.step()

            loss['G/loss'] = g_loss.item()
            loss['G/image_adv_loss'] = g_image_adv_loss_fake.item()
            loss['G/object_adv_loss'] = g_object_adv_loss.item()
            loss['G/object_cls_loss'] = g_object_cls_loss.item()
            loss['G/rec_img'] = g_img_rec_loss.item()
            loss['G/rec_z'] = g_z_rec_loss.item()
            loss['G/kl'] = g_kl_loss.item()
            loss['G/object_att_cls_loss'] = g_object_att_cls_loss.item()

            # ========================================================== #
            #                          4. Log                            #
            # ========================================================== #
            if (i + 1) % config.log_step == 0:
                log = 'iter [{:06d}/{:06d}]'.format(i + 1, config.niter)
                for tag, roi_value in loss.items():
                    log += ", {}: {:.4f}".format(tag, roi_value)
                print(log)

            if writer is not None and (i + 1) % config.tensorboard_step == 0:
                for tag, roi_value in loss.items():
                    writer.add_scalar(tag, roi_value, i + 1)
                for tag, images in (('Result/crop_real', crops_input),
                                    ('Result/crop_real_rec', crops_input_rec),
                                    ('Result/crop_rand', crops_rand),
                                    ('Result/img_real', imgs),
                                    ('Result/img_real_rec', img_rec),
                                    ('Result/img_fake_rand', img_rand)):
                    writer.add_images(tag, imagenet_deprocess_batch(images).float() / 255, i + 1)

            if (i + 1) % config.save_step == 0:
                for net, appendix in ((netG, 'netG'),
                                      (netD_image, 'netD_image'),
                                      (netD_object, 'netD_object'),
                                      (netD_att, 'netD_attribute')):
                    save_model(net,
                               model_dir=model_save_dir,
                               appendix=appendix,
                               iter=i + 1,
                               save_num=2,
                               save_step=config.save_step)
    finally:
        # Close the writer even if training raises.
        if writer is not None:
            writer.close()
def test_model(model, pickle_files, criterion, input_size=224):
    """Classify object crops from real, random-latent and shifted generated images.

    Each pickle file holds a dict with keys ``'imgs'``, ``'imgs_rand'``,
    ``'imgs_shift'``, ``'objs'``, ``'boxes'``, ``'boxes_shift'`` and
    ``'obj_to_img'``. Crops are taken with ``crop_bbox_batch``:

    * real and random images use ``boxes``;
    * shifted images use ``boxes_shift``.

    Each set of crops is scored by ``model`` against ``objs``. Per-batch
    metrics are printed every 20 files and overall metrics at the end.

    Relies on a module-level ``device`` (not defined in this function).

    Returns:
        dict: ``{'real'|'fake'|'shift': (loss, accuracy)}``. Values are NaN
        when no objects were seen (previously this raised).
    """
    phase = 'val'
    model.eval()

    variant_names = ('real', 'fake', 'shift')
    # name -> [summed loss, correct predictions, object count]
    totals = {name: [0.0, 0, 0] for name in variant_names}

    for i, pickle_file in enumerate(pickle_files):
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(pickle_file, 'rb') as f:
            batch = pickle.load(f)

        obj_to_img = batch['obj_to_img']
        labels = batch['objs'].to(device)
        num_objs = labels.size(0)

        variants = (
            ('real', batch['imgs'], batch['boxes']),
            ('fake', batch['imgs_rand'], batch['boxes']),
            ('shift', batch['imgs_shift'], batch['boxes_shift']),
        )
        batch_stats = {}
        for name, images, variant_boxes in variants:
            inputs = crop_bbox_batch(images, variant_boxes, obj_to_img, input_size).to(device)
            # Evaluation only: no autograd history is needed.
            with torch.no_grad():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)

            correct = (preds == labels).sum().item()
            batch_acc = correct / num_objs if num_objs else float('nan')
            batch_stats[name] = (loss.item(), batch_acc)

            running = totals[name]
            running[0] += loss.item() * num_objs
            running[1] += correct
            running[2] += num_objs

        if (i + 1) % 20 == 0:
            print(
                'real loss: {:.4f} accu: {:.4f} fake loss: {:.4f} accu: {:.4f} shift loss: {:.4f} accu: {:.4f}'
                .format(*batch_stats['real'], *batch_stats['fake'], *batch_stats['shift']))

    results = {}
    for name in variant_names:
        loss_sum, correct, count = totals[name]
        if count:
            results[name] = (loss_sum / count, correct / count)
        else:
            results[name] = (float('nan'), float('nan'))

    print('================================================================')
    print(
        '{} Real Loss: {:.4f} Acc: {:.4f} Fake Loss: {:.4f} Acc: {:.4f} shift loss: {:.4f} accu: {:.4f}'
        .format(phase, *results['real'], *results['fake'], *results['shift']))
    return results
Пример #6
0
def main(config):
    """Evaluate attribute control of the generator on the VG validation set.

    For each validation batch this function:

    1. fills in one attribute for every object without annotated attributes,
       using the argmax of ``netD_att`` on the real object crop;
    2. generates images with ``netG`` and records, for every annotated object,
       the attributes ``netD_att`` predicts on its generated crop
       (sigmoid > 0.9) together with the ground-truth attribute vector;
    3. clears the colour attributes of every object, sets colour ``tgt``,
       regenerates, and saves the images in which ``netD_att`` detects the
       new colour.

    Finally it prints per-object precision / recall statistics of the
    predicted attributes against the ground truth.

    Args:
        config: experiment configuration; reads ``exp_name``, ``batch_size``,
            ``embedding_dim``, ``z_dim``, ``clstm_layers``, ``object_size``
            and ``resume_iter``.
    """
    cudnn.benchmark = True

    device = torch.device('cuda:1')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)

    with open("data/vocab.json", 'r') as f:
        vocab = json.load(f)
    att_idx_to_name = np.array(vocab['attribute_idx_to_name'])
    print(att_idx_to_name)
    object_idx_to_name = np.array(vocab['object_idx_to_name'])

    attribute_nums = 106
    # Colour attributes cleared before the target colour is forced on.
    color_att_idx = [2, 8, 0, 94, 90, 95, 96, 34, 25, 70, 58, 104]
    tgt = 95  # author's notes: 2 white, 8 red, 94 blue, 95 black

    train_data_loader, val_data_loader = get_dataloader_vg(
        batch_size=config.batch_size, attribute_embedding=attribute_nums)

    vocab_num = train_data_loader.dataset.num_objects

    netG = Generator(num_embeddings=vocab_num,
                     obj_att_dim=config.embedding_dim,
                     z_dim=config.z_dim,
                     clstm_layers=config.clstm_layers,
                     obj_size=config.object_size,
                     attribute_dim=attribute_nums).to(device)

    netD_att = AttributeDiscriminator(n_attribute=attribute_nums).to(device)
    netD_att = add_sn(netD_att)

    load_model(netD_att,
               model_dir="~/models/trained_models",
               appendix='netD_attribute',
               iter=config.resume_iter)
    load_model(netG,
               model_dir=model_save_dir,
               appendix='netG',
               iter=config.resume_iter)

    def save_image(img, file_name):
        """Write one CHW CPU image tensor to ``result_save_dir/file_name``."""
        imwrite(os.path.join(result_save_dir, file_name),
                img.numpy().transpose(1, 2, 0))

    # Per-batch attribute predictions / ground truth for annotated objects,
    # concatenated after the loop (no fixed-size preallocation needed).
    pred_chunks, gt_chunks = [], []

    with torch.no_grad():
        netG.eval()
        for i, batch in enumerate(val_data_loader):
            print('batch {}'.format(i))

            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
            z = torch.randn(objs.size(0), config.z_dim)
            # Objects with at least one annotated attribute.  view(-1) keeps
            # this 1-D even when exactly one object is annotated; squeeze()
            # would yield a 0-d tensor that index_select cannot use.
            att_idx = attribute.sum(dim=1).nonzero().view(-1)

            imgs, objs, boxes, masks, obj_to_img, z, attribute, masks_shift, boxes_shift = \
                imgs.to(device), objs.to(device), boxes.to(device), masks.to(device), \
                obj_to_img, z.to(device), attribute.to(device), masks_shift.to(device), boxes_shift.to(device)

            # Estimate attributes: every unannotated object gets the single
            # attribute that netD_att scores highest on its real crop.
            attribute_est = attribute.clone()
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img,
                                          config.object_size)
            max_idx = netD_att(crops_input).argmax(1)
            annotated = set(att_idx.tolist())
            for row in range(attribute.shape[0]):
                if row not in annotated:
                    attribute_est[row, int(max_idx[row])] = 1

            # Generate fake image
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand, crops_shift, img_rec,
             img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift) = output

            # Predict attributes on the generated crops of annotated objects
            # and keep the matching ground truth (taken before the in-place
            # colour modification below).
            att_idx_dev = att_idx.to(device)
            crops_rand_yes = torch.index_select(crops_rand, 0, att_idx_dev)
            attribute_yes = torch.index_select(attribute, 0, att_idx_dev)
            att_cls = torch.sigmoid(netD_att(crops_rand_yes))
            pred_chunks.append(
                (att_cls > 0.9).cpu().numpy().astype(np.float64))
            gt_chunks.append(attribute_yes.cpu().numpy())

            img_rand = imagenet_deprocess_batch(img_rand)
            img_shift = imagenet_deprocess_batch(img_shift)
            img_rec = imagenet_deprocess_batch(img_rec)

            # Attribute modification: clear every colour attribute and set
            # the target colour on all objects, in both attribute tensors.
            for att in (attribute, attribute_est):
                att[:, color_att_idx] = 0
                att[:, tgt] = 1
            changed_list = list(range(objs.size(0)))

            # Generate image with the modified attributes
            z = torch.randn(objs.size(0), config.z_dim).to(device)
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (_, _, crops_rand_y, _, img_rec_y, img_rand_y, img_shift_y,
             _, _, _, _) = output

            img_rand_y = imagenet_deprocess_batch(img_rand_y)
            img_shift_y = imagenet_deprocess_batch(img_shift_y)
            img_rec_y = imagenet_deprocess_batch(img_rec_y)
            imgs = imagenet_deprocess_batch(imgs)

            # An object counts as successfully changed when tgt is NOT among
            # netD_att's top-5 attributes on the original generation but IS
            # among its top-3 on the modified generation.
            top5 = netD_att(crops_rand).topk(5)[1]
            changed_list = [c for c in changed_list if tgt not in top5[c]]
            top3_y = netD_att(crops_rand_y).topk(3)[1]
            changed_list_success = [
                c for c in changed_list if tgt in top3_y[c]
            ]

            for j in range(imgs.shape[0]):
                img_id = i * config.batch_size + j
                save_image(img_shift[j], 'img{:06d}_shift.png'.format(img_id))
                save_image(img_rand[j], 'img{:06d}_rand.png'.format(img_id))
                save_image(img_rec[j], 'img{:06d}_rec.png'.format(img_id))
                save_image(imgs[j], 'img{:06d}_real.png'.format(img_id))

                cur_obj_success = [
                    int(objs[c]) for c in changed_list_success
                    if obj_to_img[c] == j
                ]

                # Save successfully modified images.  Index the batch tensors
                # per image and never rebind them: later images in the same
                # batch still need them.
                if cur_obj_success:
                    save_image(
                        img_shift_y[j],
                        'img{:06d}_shift_{}_modified.png'.format(
                            img_id, object_idx_to_name[cur_obj_success]))
                    save_image(img_rec_y[j],
                               'img{:06d}_rec_modified.png'.format(img_id))
                    save_image(img_rand_y[j],
                               'img{:06d}_rand_modified.png'.format(img_id))

        # calculate recall precision
        attributes_pred = (np.concatenate(pred_chunks) if pred_chunks
                           else np.zeros((0, attribute_nums)))
        attributes_gt = (np.concatenate(gt_chunks) if gt_chunks
                         else np.zeros((0, attribute_nums)))
        num_data = attributes_gt.shape[0]
        count = np.zeros((num_data, 4))  # columns: tn, fp, fn, tp
        recall, precision = np.zeros((num_data)), np.zeros((num_data))

        for k in range(num_data):
            # labels=[0, 1] always yields a 2x2 matrix; otherwise a row with
            # a single class gives 1x1, which would broadcast into count[k].
            count[k] = confusion_matrix(attributes_gt[k], attributes_pred[k],
                                        labels=[0, 1]).ravel()
            # Every ground-truth row has at least one positive attribute
            # (rows were selected by attribute.sum() != 0), so tp + fn > 0.
            recall[k] = count[k][3] / (count[k][3] + count[k][2])
            precision[k] = safe_division(count[k][3],
                                         count[k][3] + count[k][1])

        print("average precision = {}".format(precision.mean()))
        print("average recall = {}".format(recall.mean()))

        print("average pred # per obj")
        print((count[:, 1].sum() + count[:, 3].sum()) / count.shape[0])

        print("average GT # per obj")
        print((count[:, 2].sum() + count[:, 3].sum()) / count.shape[0])

        print("% of data that predict something")
        print(((count[:, 1] + count[:, 3]) > 0).sum() / count.shape[0])

        print("% of data at least predicted correct once")
        print((count[:, 3] > 0).sum() / count.shape[0])
def main(config):
    """Dump validation generations of ``netG`` to pickles and PNG files.

    For every validation batch it fills in one attribute for each
    unannotated object (argmax of ``netD_att`` on the real crop), runs
    ``netG``, and then:

      * pickles the inputs and generated images to ``batch_<i>.pkl`` under
        ``~/pickle/128_vg_pkl_resinet50_247k``;
      * writes the random-z image as ``img<id>.png`` and the shifted-layout
        image as ``img<id>s.png`` into ``result_save_dir``.

    Args:
        config: experiment configuration; reads ``exp_name``, ``dataset``,
            ``batch_size``, ``clstm_layers``, ``embedding_dim``, ``z_dim``,
            ``object_size`` and ``resume_iter``.

    Raises:
        ValueError: if ``config.dataset`` is neither ``'vg'`` nor ``'coco'``.
    """
    cudnn.benchmark = True
    device = torch.device('cuda:0')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)

    # os.path.exists / os.mkdir do not expand "~"; expand it explicitly and
    # create missing parent directories as well.
    resinet_dir = os.path.expanduser("~/pickle/128_vg_pkl_resinet50_247k")
    os.makedirs(resinet_dir, exist_ok=True)

    attribute_nums = 106
    if config.dataset == 'vg':
        train_data_loader, val_data_loader = get_dataloader_vg(
            batch_size=config.batch_size, attribute_embedding=attribute_nums)
    elif config.dataset == 'coco':
        train_data_loader, val_data_loader = get_dataloader_coco(
            batch_size=config.batch_size)
    else:
        raise ValueError('unknown dataset: {}'.format(config.dataset))
    vocab_num = train_data_loader.dataset.num_objects

    if config.clstm_layers == 0:
        netG = Generator_nolstm(num_embeddings=vocab_num,
                                embedding_dim=config.embedding_dim,
                                z_dim=config.z_dim).to(device)
    else:
        netG = Generator(num_embeddings=vocab_num,
                         obj_att_dim=config.embedding_dim,
                         z_dim=config.z_dim,
                         clstm_layers=config.clstm_layers,
                         obj_size=config.object_size,
                         attribute_dim=attribute_nums).to(device)

    load_model(netG,
               model_dir=model_save_dir,
               appendix='netG',
               iter=config.resume_iter)

    netD_att = train_att_change.AttributeDiscriminator(
        n_attribute=attribute_nums).to(device)
    netD_att = train_att_change.add_sn(netD_att)

    load_model(netD_att,
               model_dir="~/models/trained_models",
               appendix='netD_attribute',
               iter=config.resume_iter)

    data_loader = val_data_loader
    data_iter = iter(data_loader)

    with torch.no_grad():
        netG.eval()
        for i, batch in enumerate(data_iter):
            print('batch {}'.format(i))
            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
            # Objects with at least one annotated attribute; view(-1) keeps
            # it 1-D even when a single object is annotated.
            att_idx = attribute.sum(dim=1).nonzero().view(-1)

            z = torch.randn(objs.size(0), config.z_dim)
            imgs, objs, boxes, masks, obj_to_img, z, attribute, masks_shift, boxes_shift = \
                imgs.to(device), objs.to(device), boxes.to(device), masks.to(device), \
                obj_to_img, z.to(device), attribute.to(device), masks_shift.to(device), boxes_shift.to(device)

            # Estimate attributes: every unannotated object gets the single
            # attribute netD_att scores highest on its real crop.
            attribute_est = attribute.clone()
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img,
                                          config.object_size)
            max_idx = netD_att(crops_input).argmax(1)
            annotated = set(att_idx.tolist())
            for row in range(attribute.shape[0]):
                if row not in annotated:
                    attribute_est[row, int(max_idx[row])] = 1

            # Generate fake image
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand, crops_shift, img_rec,
             img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift) = output

            dict_to_save = {
                'imgs': imgs,
                'imgs_rand': img_rand,
                'imgs_shift': img_shift,
                'objs': objs,
                'boxes': boxes,
                'boxes_shift': boxes_shift,
                'obj_to_img': obj_to_img
            }

            out_name = os.path.join(resinet_dir, 'batch_{}.pkl'.format(i))
            with open(out_name, 'wb') as f:
                pickle.dump(dict_to_save, f)

            img_rand = imagenet_deprocess_batch(img_rand)
            img_shift = imagenet_deprocess_batch(img_shift)

            # Save the generated images
            for j in range(img_rand.shape[0]):
                img_id = i * config.batch_size + j

                img_np = img_rand[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(result_save_dir,
                                        'img{:06d}.png'.format(img_id))
                imwrite(img_path, img_np)

                img_np = img_shift[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(result_save_dir,
                                        'img{:06d}s.png'.format(img_id))
                imwrite(img_path, img_np)
    # NOTE(review): the code below uses names not defined in this function
    # (start_iter, niter, pos_weight -- possibly module globals, verify) and
    # stops right after netD_att.zero_grad().  It looks like a fragment of a
    # separate netD_att training loop pasted here; if those names are not
    # globals it raises NameError once reached.  Confirm and move/remove it.
    writer = SummaryWriter(log_save_dir)

    for i in range(start_iter, niter):

        try:
            batch = next(data_iter)
        except:
            data_iter = iter(data_loader)
            batch = next(data_iter)

        imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
        imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = \
            imgs.to(device), objs.to(device), boxes.to(device), masks.to(device), \
            obj_to_img, attribute.to(device), masks_shift.to(device), boxes_shift.to(device)

        crops_input = crop_bbox_batch(imgs, boxes, obj_to_img, 64)

        att_cls = netD_att(crops_input.detach())
        att_idx = attribute.sum(dim=1).nonzero().squeeze()
        att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
        attribute_annotated = torch.index_select(attribute, 0, att_idx)
        d_object_att_cls_loss_real = F.binary_cross_entropy_with_logits(
            att_cls_annotated,
            attribute_annotated,
            pos_weight=pos_weight.to(device))

        d_loss = 0
        d_loss += d_object_att_cls_loss_real

        netD_att.zero_grad()
def main(config):
    """Measure how much shifting the layout changes the generated images.

    For every validation batch it generates a random-z image (``img_rand``)
    and the same scene with shifted boxes (``img_shift``), pickles both (with
    the layout) to ``~/pickle/vg_pkl_resinet50``, and accumulates three L1
    distances on the deprocessed int32 images:

      * background: mean |rand - shift| over pixels outside the union of all
        "foreground" object masks in both layouts;
      * foreground: per foreground object, |rand - shift| between the pixels
        under its original and its shifted mask, normalised by 3 * mask sum;
      * rand: a baseline between two different images of the batch.

    Background-only images are written to ``result_save_dir1`` (rand) and
    ``result_save_dir2`` (shift).  Per-batch and averaged distances are
    printed.

    Args:
        config: experiment configuration; reads ``exp_name``, ``dataset``,
            ``batch_size``, ``clstm_layers``, ``embedding_dim``, ``z_dim``,
            ``object_size`` and ``resume_iter``.

    Raises:
        ValueError: if ``config.dataset`` is neither ``'vg'`` nor ``'coco'``.
    """
    cudnn.benchmark = True
    device = torch.device('cuda:0')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir1, result_save_dir2 = prepare_dir(
        config.exp_name)

    # os.path.exists / os.mkdir do not expand "~"; expand it explicitly and
    # create missing parent directories as well.
    resinet_dir = os.path.expanduser("~/pickle/vg_pkl_resinet50")
    os.makedirs(resinet_dir, exist_ok=True)

    attribute_nums = 106
    if config.dataset == 'vg':
        train_data_loader, val_data_loader = get_dataloader_vg(
            batch_size=config.batch_size, attribute_embedding=attribute_nums)
    elif config.dataset == 'coco':
        train_data_loader, val_data_loader = get_dataloader_coco(
            batch_size=config.batch_size)
    else:
        raise ValueError('unknown dataset: {}'.format(config.dataset))
    vocab_num = train_data_loader.dataset.num_objects

    if config.clstm_layers == 0:
        netG = Generator_nolstm(num_embeddings=vocab_num,
                                embedding_dim=config.embedding_dim,
                                z_dim=config.z_dim).to(device)
    else:
        netG = Generator(num_embeddings=vocab_num,
                         obj_att_dim=config.embedding_dim,
                         z_dim=config.z_dim,
                         clstm_layers=config.clstm_layers,
                         obj_size=config.object_size,
                         attribute_dim=attribute_nums).to(device)

    load_model(netG,
               model_dir=model_save_dir,
               appendix='netG',
               iter=config.resume_iter)

    netD_att = train_att_change.AttributeDiscriminator(
        n_attribute=attribute_nums).to(device)
    netD_att = train_att_change.add_sn(netD_att)

    load_model(netD_att,
               model_dir="~/models/trained_models",
               appendix='netD_attribute',
               iter=config.resume_iter)

    data_loader = val_data_loader
    data_iter = iter(data_loader)

    L1_dist_b = 0
    L1_dist_f = 0
    L1_rand_dist = 0
    num_batches = 0
    with torch.no_grad():
        netG.eval()
        for i, batch in enumerate(data_iter):
            print('batch {}'.format(i))
            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
            # Objects with at least one annotated attribute; view(-1) keeps
            # it 1-D even when a single object is annotated.
            att_idx = attribute.sum(dim=1).nonzero().view(-1)

            z = torch.randn(objs.size(0), config.z_dim)
            imgs, objs, boxes, masks, obj_to_img, z, attribute, masks_shift, boxes_shift = \
                imgs.to(device), objs.to(device), boxes.to(device), masks.to(device), \
                obj_to_img, z.to(device), attribute.to(device), masks_shift.to(device), boxes_shift.to(device)

            # Estimate attributes: every unannotated object gets the single
            # attribute netD_att scores highest on its real crop.
            attribute_est = attribute.clone()
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img,
                                          config.object_size)
            max_idx = netD_att(crops_input).argmax(1)
            annotated = set(att_idx.tolist())
            for row in range(attribute.shape[0]):
                if row not in annotated:
                    attribute_est[row, int(max_idx[row])] = 1

            # Generate fake image
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand, crops_shift, img_rec,
             img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift) = output

            # Objects with box[2] - box[0] < 0.5 count as foreground
            # (presumably normalised box width -- confirm the box format).
            foreground = [(box[2] - box[0]) < 0.5 for box in boxes]

            dict_to_save = {
                'imgs': imgs,
                'imgs_rand': img_rand,
                'imgs_shift': img_shift,
                'objs': objs,
                'boxes': boxes,
                'boxes_shift': boxes_shift,
                'obj_to_img': obj_to_img,
                'foreground': foreground
            }

            out_name = os.path.join(resinet_dir, 'batch_{}.pkl'.format(i))
            with open(out_name, 'wb') as f:
                pickle.dump(dict_to_save, f)

            img_rand = imagenet_deprocess_batch(img_rand).type(torch.int32)
            imgs = imagenet_deprocess_batch(imgs).type(torch.int32)
            img_shift = imagenet_deprocess_batch(img_shift).type(torch.int32)

            num_imgs = img_rand.shape[0]
            fore_count = 0
            background_dist = 0
            foreground_dist = 0
            for j in range(num_imgs):
                obj_indices = torch.nonzero(obj_to_img == j).view(-1)
                foreground_obj_indices = [
                    int(o) for o in obj_indices if foreground[int(o)] == 1
                ]
                if foreground_obj_indices:
                    # Union of the foreground masks in both layouts.
                    foreground_mask = torch.max(
                        torch.max(masks[foreground_obj_indices], 0)[0],
                        torch.max(masks_shift[foreground_obj_indices],
                                  0)[0]).byte()
                else:
                    print("no foreground")
                    # Same shape as the union above.
                    foreground_mask = torch.zeros_like(masks[0]).byte()

                # Logical NOT via (mask == 0): ``~`` on a uint8 tensor is a
                # bitwise NOT (255 / 254) in PyTorch >= 1.2.
                background_mask = (foreground_mask == 0).cpu()
                background_dist += safe_division(
                    torch.abs((img_rand[j] - img_shift[j]) *
                              background_mask.int()).sum(),
                    (3 * background_mask.float().sum()))

                img_np = (img_rand[j] *
                          background_mask.int()).numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir1,
                    'img{:06d}.png'.format(i * config.batch_size + j))
                imwrite(img_path, img_np)

                img_np = (img_shift[j] *
                          background_mask.int()).numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir2,
                    'img{:06d}.png'.format(i * config.batch_size + j))
                imwrite(img_path, img_np)

                for fg_idx in foreground_obj_indices:
                    rand_sel = masks[fg_idx].cpu().byte() != 0
                    shift_sel = masks_shift[fg_idx].cpu().byte() != 0
                    try:
                        diff = (torch.masked_select(img_rand[j], rand_sel) -
                                torch.masked_select(img_shift[j], shift_sel))
                    except RuntimeError:
                        # The two masks can select different pixel counts
                        # (presumably when the shift moves part of the
                        # object out of frame); no pixel-wise diff then.
                        continue
                    foreground_dist += torch.abs(diff).sum() / (
                        3 * masks[fg_idx].sum().cpu())
                    fore_count += 1

            # Baseline: L1 distance between two different images of the
            # batch (last random image vs. the previous shifted one; for a
            # single-image batch index -1 wraps to image 0).
            last = num_imgs - 1
            rand_diff = torch.abs(img_rand[last] - img_shift[last - 1]).sum(
            ) / img_rand[last].numel()

            # Average over the images actually in this batch (the last batch
            # may be smaller than config.batch_size).
            batch_b_dist = background_dist / num_imgs
            batch_f_dist = safe_division(foreground_dist, fore_count)
            L1_dist_b += batch_b_dist
            L1_dist_f += batch_f_dist
            L1_rand_dist += rand_diff
            num_batches += 1

            # %f (unlike %d) also formats NaN and float tensors.
            print("running b_dist = %f, f_dist = %f, rand_diff = %f" %
                  (batch_b_dist, batch_f_dist, rand_diff))

        # Average over the number of batches processed (i + 1, not i).
        num_batches = max(num_batches, 1)
        print("Background L1 dist = %f" % (L1_dist_b / num_batches))
        print("Foreground L1 dist = %f" % (L1_dist_f / num_batches))
        print("Rand L1 dist = %f" % (L1_rand_dist / num_batches))
Пример #10
0
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False, input_size=224, start_epoch=0):
    """Train an object classifier on ground-truth box crops, validating each epoch.

    Every batch from ``dataloaders[phase]`` is cropped to its object boxes
    with ``crop_bbox_batch`` at ``input_size``, and the object class
    (``objs``) is the label.  Relies on names defined elsewhere in the file:
    ``log_save_dir``, ``exp_name``, ``model_save_dir``, ``device``,
    ``iters_per_epoch`` and ``save_model``.

    Args:
        model: classifier; with ``is_inception`` it must return
            ``(outputs, aux_outputs)`` in training mode.
        dataloaders: dict with ``'train'`` and ``'val'`` loaders.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: epoch index to stop at (exclusive).
        is_inception: add 0.4 * auxiliary loss during training.
        input_size: side length of the object crops.
        start_epoch: epoch index to resume from.

    Returns:
        ``(model, val_acc_history)``: ``model`` with the weights of the epoch
        with the highest validation accuracy loaded, and the list of every
        epoch's validation accuracy.
    """
    since = time.time()

    val_acc_history = []

    writer = {"train": SummaryWriter(log_save_dir + exp_name + "_train/"),
              "val": SummaryWriter(log_save_dir + exp_name + "_val/")}

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    try:
        for epoch in range(start_epoch, num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)
            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()  # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0
                running_count = 0

                # Iterate over data.
                for i, batch in enumerate(dataloaders[phase]):
                    imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
                    inputs = crop_bbox_batch(imgs, boxes, obj_to_img,
                                             input_size).to(device)
                    labels = objs.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # Track autograd history only in the training phase.
                    with torch.set_grad_enabled(phase == 'train'):
                        if is_inception and phase == 'train':
                            # Inception also returns an auxiliary output in
                            # training; its loss is added with weight 0.4.
                            # https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                            outputs, aux_outputs = model(inputs)
                            loss = (criterion(outputs, labels) +
                                    0.4 * criterion(aux_outputs, labels))
                        else:
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)

                        _, preds = torch.max(outputs, 1)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # statistics
                    correct = preds == labels.data
                    running_loss += loss.item() * labels.size(0)
                    running_corrects += torch.sum(correct)
                    running_count += labels.size(0)

                    if (i + 1) % 20 == 0:
                        print('epoch: {:04d} iter: {:08d} loss: {:.4f} accu: {:.4f}'.format(
                            epoch + 1, i + 1, loss.item(),
                            torch.mean(correct.float())))

                    # Logging (only built when it is actually written).
                    if (i + 1) % 100 == 0:
                        loss_logging = {
                            'avg loss': loss.item(),
                            'accu': torch.mean(correct.float()).item(),
                        }
                        for tag, roi_value in loss_logging.items():
                            writer[phase].add_scalar(
                                tag, roi_value,
                                epoch * iters_per_epoch[phase] + i + 1)

                epoch_loss = running_loss / running_count
                epoch_acc = running_corrects.double() / running_count

                print('================================================================')
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

                if phase == 'val':
                    val_acc_history.append(epoch_acc)
                    # A checkpoint is written every epoch; only the best
                    # validation weights are kept in memory for the return.
                    save_model(model, model_save_dir, '128_resinet50_vg_best', epoch + 1, save_step=1)
                    if epoch_acc > best_acc:
                        best_acc = epoch_acc
                        best_model_wts = copy.deepcopy(model.state_dict())

            print()
    finally:
        # The writers are keyed 'train' and 'val'; close both even if
        # training raises.
        writer['train'].close()
        writer['val'].close()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history