Example #1
nbmasks = len(allmasks)
print('Number of masks: %d'%nbmasks)
dataset = Inferer(imfile=allimages, mfiles=allmasks, category_names=category_names,
                  transform=transform, final_img_size=img_size)

train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
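#shuffle=False keeps a deterministic image/mask ordering for the evaluation outputs below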
#val_loader = DataLoader(dataset_val, batch_size=batch_size, shuffle=True)


#--------------------------Define Networks------------------------------------
G_local = Generator_Baseline_2(z_dim=noise_size, label_channel=len(category_names),num_res_blocks=num_res_blocks)
G_local.cuda()
G_local.load_state_dict(torch.load(opt.maskmodel))
print('Parameters for mask model are loaded!')

G_fg = Generator_FG(z_dim=noise_size, label_channel=len(category_names),num_res_blocks=5)
G_fg.load_state_dict(torch.load(opt.fgmodel))
print('Parameters for FG model are loaded!')
G_fg.cuda()


#---------------------- results save folder-----------------------------------
root = './results_KID'
mask_folder = os.path.join(root, opt.category_name, 'ms')
im_folder = os.path.join(root, opt.category_name, 'ims')

os.makedirs(mask_folder, exist_ok=True)
os.makedirs(im_folder, exist_ok=True)

#Save the file for inspection later
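#A minimal sketch of that save step, mirroring the copyfile call used in
#Examples #2 and #3 below (this destination path is an assumption):
copyfile(os.path.basename(__file__),
         os.path.join(root, opt.category_name, os.path.basename(__file__)))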
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_imgs', type=str, help='dataset path')
    parser.add_argument('--mask_imgs', type=str, help='dataset path')
    parser.add_argument('--log_dir',
                        type=str,
                        default='log',
                        help='Name of the log folder')
    #Note: argparse's type=bool treats any non-empty string (even 'False') as True;
    #omit the flag to keep the default.
    parser.add_argument('--save_models',
                        type=bool,
                        default=True,
                        help='Set True if you want to save trained models')
    parser.add_argument('--pre_trained_model_path',
                        type=str,
                        default=None,
                        help='Pre-trained model path')
    parser.add_argument('--pre_trained_model_epoch',
                        type=str,
                        default=None,
                        help='Pre-trained model epoch, e.g. 200')
    parser.add_argument('--train_imgs_path',
                        type=str,
                        default='C:/Users/motur/coco/images/train2017',
                        help='Path to training images')
    parser.add_argument(
        '--train_annotation_path',
        type=str,
        default='C:/Users/motur/coco/annotations/instances_train2017.json',
        help='Path to annotation file, .json file')
    parser.add_argument('--category_names',
                        type=str,
                        default='giraffe,elephant,zebra,sheep,cow,bear',
                        help='List of categories in MS-COCO dataset')
    parser.add_argument('--num_test_img',
                        type=int,
                        default=16,
                        help='Number of images saved during training')
    parser.add_argument('--img_size',
                        type=int,
                        default=256,
                        help='Generated image size')
    parser.add_argument(
        '--local_patch_size',
        type=int,
        default=256,
        help='Image size of instance images after interpolation')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Mini-batch size')
    parser.add_argument('--train_epoch',
                        type=int,
                        default=20,
                        help='Maximum training epoch')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='Initial learning rate')
    parser.add_argument('--optim_step_size',
                        type=int,
                        default=80,
                        help='Learning rate decay step size')
    parser.add_argument('--optim_gamma',
                        type=float,
                        default=0.5,
                        help='Learning rate decay ratio')
    parser.add_argument(
        '--critic_iter',
        type=int,
        default=5,
        help='Number of discriminator updates per generator update')
    parser.add_argument('--noise_size',
                        type=int,
                        default=128,
                        help='Noise vector size')
    parser.add_argument('--lambda_FM',
                        type=float,
                        default=1,
                        help='Trade-off param for feature matching loss')
    parser.add_argument('--lambda_recon',
                        type=float,
                        default=0.00001,
                        help='Trade-off param for reconstruction loss')
    parser.add_argument('--num_res_blocks',
                        type=int,
                        default=5,
                        help='Number of residual blocks in the generator network')
    parser.add_argument(
        '--trade_off_G',
        type=float,
        default=0.1,
        help=
        'Trade-off parameter which controls gradient flow to generator from D_local and D_glob'
    )

    opt = parser.parse_args()
    print(opt)

    #Create log folder
    root = 'result_fg/' + opt.category_names + '/'
    model = 'coco_model_'
    result_folder_name = 'images_' + opt.log_dir
    model_folder_name = 'models_' + opt.log_dir
    if not os.path.isdir(root):
        os.makedirs(root)
    if not os.path.isdir(root + result_folder_name):
        os.makedirs(root + result_folder_name)
    if not os.path.isdir(root + model_folder_name):
        os.makedirs(root + model_folder_name)

    #Save the script
    copyfile(os.path.basename(__file__),
             root + result_folder_name + '/' + os.path.basename(__file__))

    #Define transformation for dataset images - e.g. scaling
    transform = transforms.Compose([
        transforms.Resize((opt.img_size, opt.img_size)),  #torchvision renamed Scale to Resize
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    #Load dataset
    category_names = opt.category_names.split(',')
    allmasks = sorted(
        glob.glob(os.path.join(opt.mask_imgs, '**', '*.png'), recursive=True))
    print('Number of masks: %d' % len(allmasks))
    dataset = chairs(imfile=opt.train_imgs,
                     mfiles=allmasks,
                     category_names=category_names,
                     transform=transform,
                     final_img_size=opt.img_size)

    #Discard images containing very small instances
    # dataset.discard_small(min_area=0.03, max_area=1)

    #Define data loader
    train_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)

    #For evaluation define fixed masks and noises
    data_iter = iter(train_loader)
    sample_batched = next(data_iter)  #.next() is Python 2 only; use next()
    x_fixed = sample_batched['image'][0:opt.num_test_img]
    x_fixed = Variable(x_fixed.cuda())
    y_fixed = sample_batched['single_fg_mask'][0:opt.num_test_img]
    y_fixed = Variable(y_fixed.cuda())
    z_fixed = torch.randn((opt.num_test_img, opt.noise_size))
    z_fixed = Variable(z_fixed.cuda())

    #Define networks
    G_fg = Generator_FG(z_dim=opt.noise_size,
                        label_channel=len(category_names),
                        num_res_blocks=opt.num_res_blocks)
    D_glob = Discriminator(channels=3 + len(category_names))
    D_instance = Discriminator(channels=3 + len(category_names),
                               input_size=opt.local_patch_size)
    G_fg.cuda()
    D_glob.cuda()
    D_instance.cuda()

    #Load parameters from pre-trained models
    if opt.pre_trained_model_path is not None and opt.pre_trained_model_epoch is not None:
        try:
            G_fg.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'G_fg_epoch_' +
                           opt.pre_trained_model_epoch))
            D_glob.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'D_glob_epoch_' +
                           opt.pre_trained_model_epoch))
            D_instance.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'D_local_epoch_' +
                           opt.pre_trained_model_epoch))
            print('Parameters are loaded!')
        except Exception:
            print('Error: Pre-trained parameters are not loaded!')

    #Define interpolation operation
    up_instance = nn.Upsample(size=(opt.local_patch_size,
                                    opt.local_patch_size),
                              mode='bilinear')

    #Define pooling operation for the case where the image size and local patch size differ
    pooling_instance = nn.Sequential()
    if opt.local_patch_size != opt.img_size:
        pooling_instance.add_module(
            '0', nn.AvgPool2d(int(opt.img_size / opt.local_patch_size)))
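        #e.g. img_size=256 with local_patch_size=128 gives AvgPool2d(2), halving
        #the instance masks down to the local patch resolution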

    #Define training loss function - binary cross entropy
    BCE_loss = nn.BCELoss()

    #Define feature matching loss
    criterionVGG = VGGLoss()
    criterionVGG = criterionVGG.cuda()

    #Define optimizer
    G_local_optimizer = optim.Adam(G_fg.parameters(),
                                   lr=opt.lr,
                                   betas=(0.0, 0.9))
    D_local_optimizer = optim.Adam(
        list(filter(lambda p: p.requires_grad, D_glob.parameters())) +
        list(filter(lambda p: p.requires_grad, D_instance.parameters())),
        lr=opt.lr,
        betas=(0.0, 0.9))
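    #betas=(0.0, 0.9) is the zero-momentum Adam setting commonly used to stabilize GAN training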
    #Define learning rate scheduler
    scheduler_G = lr_scheduler.StepLR(G_local_optimizer,
                                      step_size=opt.optim_step_size,
                                      gamma=opt.optim_gamma)
    scheduler_D = lr_scheduler.StepLR(D_local_optimizer,
                                      step_size=opt.optim_step_size,
                                      gamma=opt.optim_gamma)

    #----------------------------TRAIN-----------------------------------------
    print('training start!')
    start_time = time.time()

    for epoch in range(opt.train_epoch):
        epoch_start_time = time.time()

        scheduler_G.step()
        scheduler_D.step()
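        #Note: stepping the schedulers before any optimizer.step() triggers a warning
        #in newer PyTorch; Example #3 below moves these calls to the end of the epoch.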

        D_local_losses = []
        G_local_losses = []

        y_real_ = torch.ones(opt.batch_size)
        y_fake_ = torch.zeros(opt.batch_size)
        y_real_, y_fake_ = Variable(y_real_.cuda()), Variable(y_fake_.cuda())

        data_iter = iter(train_loader)
        num_iter = 0
        while num_iter < len(train_loader):

            j = 0
            while j < opt.critic_iter and num_iter < len(train_loader):
                j += 1
                sample_batched = next(data_iter)
                num_iter += 1
                x_ = sample_batched['image']
                y_ = sample_batched['single_fg_mask']
                fg_mask = sample_batched['seg_mask']

                y_instances = sample_batched['mask_instance']
                bbox = sample_batched['bbox']

                mini_batch = x_.size()[0]
                if mini_batch != opt.batch_size:
                    break

                #Update discriminators - D
                #Real examples
                D_glob.zero_grad()
                D_instance.zero_grad()

                x_, y_ = Variable(x_.cuda()), Variable(y_.cuda())
                fg_mask = Variable(fg_mask.cuda())
                y_reduced = torch.sum(y_,
                                      1).clamp(0,
                                               1).view(y_.size(0), 1,
                                                       opt.img_size,
                                                       opt.img_size)
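                #y_reduced collapses the per-category masks into one binary foreground
                #mask; (1 - y_reduced) below keeps only the background pixels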

                x_d = torch.cat([x_, fg_mask], 1)

                x_instances = torch.zeros(
                    (opt.batch_size, 3, opt.local_patch_size,
                     opt.local_patch_size))
                x_instances = Variable(x_instances.cuda())
                y_instances = Variable(y_instances.cuda())
                y_instances = pooling_instance(y_instances)
                G_instances = torch.zeros(
                    (opt.batch_size, 3, opt.local_patch_size,
                     opt.local_patch_size))
                G_instances = Variable(G_instances.cuda())

                #Obtain instances
                for t in range(x_d.size()[0]):
                    x_instance = x_[t, 0:3, bbox[0][t]:bbox[1][t],
                                    bbox[2][t]:bbox[3][t]]
                    x_instance = x_instance.contiguous().view(
                        1,
                        x_instance.size()[0],
                        x_instance.size()[1],
                        x_instance.size()[2])
                    x_instances[t] = up_instance(x_instance)

                D_result_instance = D_instance(
                    torch.cat([x_instances, y_instances], 1)).squeeze()
                D_result = D_glob(x_d).squeeze()
                D_real_loss = BCE_loss(D_result, y_real_) + BCE_loss(
                    D_result_instance, y_real_)
                D_real_loss.backward()

                #Fake examples
                z_ = torch.randn((mini_batch, opt.noise_size))
                z_ = Variable(z_.cuda())

                #Generate fake images
                G_fg_result = G_fg(z_, y_, torch.mul(x_, (1 - y_reduced)))
                G_result_d = torch.cat([G_fg_result, fg_mask], 1)

                #Obtain fake instances
                for t in range(x_d.size()[0]):
                    G_instance = G_result_d[t, 0:3, bbox[0][t]:bbox[1][t],
                                            bbox[2][t]:bbox[3][t]]
                    G_instance = G_instance.contiguous().view(
                        1,
                        G_instance.size()[0],
                        G_instance.size()[1],
                        G_instance.size()[2])
                    G_instances[t] = up_instance(G_instance)

                D_result_instance = D_instance(
                    torch.cat([G_instances, y_instances],
                              1).detach()).squeeze()
                D_result = D_glob(G_result_d.detach()).squeeze()
                D_fake_loss = BCE_loss(D_result, y_fake_) + BCE_loss(
                    D_result_instance, y_fake_)
                D_fake_loss.backward()
                D_local_optimizer.step()

                D_train_loss = D_real_loss + D_fake_loss
                D_local_losses.append(D_train_loss.data)

            if mini_batch != opt.batch_size:
                break

            #Update generator G
            G_fg.zero_grad()
            D_result = D_glob(G_result_d).squeeze()
            D_result_instance = D_instance(
                torch.cat([G_instances, y_instances], 1)).squeeze()
            G_train_loss = (1 - opt.trade_off_G) * BCE_loss(
                D_result, y_real_) + opt.trade_off_G * BCE_loss(
                    D_result_instance, y_real_)

            #Feature matching loss between generated image and corresponding ground truth
            FM_loss = criterionVGG(G_fg_result, x_)

            #Reconstruction loss
            Recon_loss = mse_loss(torch.mul(x_, (1 - y_reduced)),
                                  torch.mul(G_fg_result, (1 - y_reduced)))

            total_loss = G_train_loss + opt.lambda_FM * FM_loss + opt.lambda_recon * Recon_loss
            total_loss.backward()
            G_local_optimizer.step()
            G_local_losses.append(G_train_loss.data)

            print('loss_d: %.3f, loss_g: %.3f' %
                  (D_train_loss.data, G_train_loss.data))
            if (num_iter % 100) == 0:
                print('%d - %d complete!' % ((epoch + 1), num_iter))
                print(result_folder_name)

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' %
              ((epoch + 1), opt.train_epoch, per_epoch_ptime,
               torch.mean(torch.FloatTensor(D_local_losses)),
               torch.mean(torch.FloatTensor(G_local_losses))))

        #Save images
        G_fg.eval()

        if epoch == 0:
            show_result_rgb((epoch + 1),
                            x_fixed,
                            save=True,
                            path=root + result_folder_name + '/' + model +
                            str(epoch + 1) + '_gt.png')
            for t in range(y_fixed.size()[1]):
                show_result_rgb((epoch + 1),
                                y_fixed[:, t:t + 1, :, :],
                                save=True,
                                path=root + result_folder_name + '/' + model +
                                str(epoch + 1) + '_' + str(t) + '_masked.png')

        show_result_rgb(
            (epoch + 1),
            G_fg(
                z_fixed, y_fixed,
                torch.mul(x_fixed, (1 - torch.sum(y_fixed, 1).view(
                    y_fixed.size(0), 1, opt.img_size, opt.img_size)))),
            save=True,
            path=root + result_folder_name + '/' + model + str(epoch + 1) +
            '_fg.png')
        G_fg.train()

        #Save model params
        if opt.save_models and (epoch > 11 and epoch % 10 == 0):
            torch.save(
                G_fg.state_dict(), root + model_folder_name + '/' + model +
                'G_fg_epoch_' + str(epoch) + '.pth')
            torch.save(
                D_glob.state_dict(), root + model_folder_name + '/' + model +
                'D_glob_epoch_' + str(epoch) + '.pth')
            torch.save(
                D_instance.state_dict(), root + model_folder_name + '/' +
                model + 'D_local_epoch_' + str(epoch) + '.pth')

    torch.save(
        G_fg.state_dict(), root + model_folder_name + '/' + model +
        'G_fg_epoch_' + str(epoch) + '.pth')
    torch.save(
        D_glob.state_dict(), root + model_folder_name + '/' + model +
        'D_glob_epoch_' + str(epoch) + '.pth')
    torch.save(
        D_instance.state_dict(), root + model_folder_name + '/' + model +
        'D_local_epoch_' + str(epoch) + '.pth')
    end_time = time.time()
    total_ptime = end_time - start_time
    print("Training finish!... save training results")
    print('Training time: ' + str(total_ptime))
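The listing defines main() but never shows an entry point; a minimal sketch, assuming the script is meant to be run directly (e.g. python <script>.py --train_imgs ... --category_names giraffe,elephant):

if __name__ == '__main__':
    main()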
Example #3
def main(rank):
    
    #Seed - Added for TPU purposes
    torch.manual_seed(1)
       
    #Create log folder
    root = 'result_fg/'
    model = 'coco_model_'
    result_folder_name = 'images_' + FLAGS['log_dir']
    model_folder_name = 'models_' + FLAGS['log_dir']
    if not os.path.isdir(root):
        os.mkdir(root)
    if not os.path.isdir(root + result_folder_name):
        os.mkdir(root + result_folder_name)
    if not os.path.isdir(root + model_folder_name):
        os.mkdir(root + model_folder_name)
    
    #Save the script
    copyfile(os.path.basename(__file__), root + result_folder_name + '/' + os.path.basename(__file__))
    
    #Define transformation for dataset images - e.g. scaling
    transform = transforms.Compose(
        [
            transforms.Resize((FLAGS['img_size'],FLAGS['img_size'])), #torchvision renamed Scale to Resize
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )
    #Load dataset
    category_names = FLAGS['category_names'].split(',')
    
    #Serial executor - runs dataset creation one TPU process at a time to limit memory use
    SERIAL_EXEC = xmp.MpSerialExecutor()
    
    #Define Dataset
    dataset = SERIAL_EXEC.run(
        lambda: CocoData(
            root = FLAGS['train_imgs_path'],
            annFile = FLAGS['train_annotation_path'],
            category_names = category_names,
            transform=transform,
            final_img_size=FLAGS['img_size']
        )
    )
    
    #Discard images containing very small instances
    dataset.discard_small(min_area=0.03, max_area=1)
    
    #Define data sampler - Added for TPU purposes
    train_sampler = DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )

    #Define data loader
    train_loader = DataLoader( #Modified for TPU purposes
        dataset,
        batch_size=FLAGS['batch_size'],
        sampler=train_sampler,
        num_workers=FLAGS['num_workers']
        # shuffle=True
    )
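    #shuffling is handled by the DistributedSampler above, so the loader itself does not shuffle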

    #Define device - Added for TPU purposes
    device = xm.xla_device(devkind='TPU')

    #For evaluation define fixed masks and noises
    data_iter = iter(train_loader)
    sample_batched = next(data_iter)  #.next() is Python 2 only; use next()
    x_fixed = sample_batched['image'][0:FLAGS['num_test_img']]
    x_fixed = Variable(x_fixed.to(device))
    y_fixed = sample_batched['single_fg_mask'][0:FLAGS['num_test_img']]
    y_fixed = Variable(y_fixed.to(device))
    z_fixed = torch.randn((FLAGS['num_test_img'],FLAGS['noise_size']))
    z_fixed = Variable(z_fixed.to(device))
    
    #Define networks
    generator = Generator_FG(
        z_dim=FLAGS['noise_size'],
        label_channel=len(category_names),
        num_res_blocks=FLAGS['num_res_blocks']
    )

    discriminator_glob = Discriminator(
        channels=3+len(category_names)
    )

    discriminator_instance = Discriminator(
        channels=3+len(category_names),
        input_size=FLAGS['local_patch_size']
    )

    WRAPPED_GENERATOR = xmp.MpModelWrapper(generator) #Added for TPU purposes
    WRAPPED_DISCRIMINATOR_GLOB = xmp.MpModelWrapper(discriminator_glob) #Added for TPU purposes
    WRAPPED_DISCRIMINATOR_INSTANCE = xmp.MpModelWrapper(discriminator_instance) #Added for TPU purposes

    G_fg = WRAPPED_GENERATOR.to(device) #Modified for TPU purposes
    D_glob = WRAPPED_DISCRIMINATOR_GLOB.to(device) #Modified for TPU purposes
    D_instance = WRAPPED_DISCRIMINATOR_INSTANCE.to(device) #Modified for TPU purposes
    
    #Load parameters from pre-trained models
    if FLAGS['pre_trained_model_path'] is not None and FLAGS['pre_trained_model_epoch'] is not None:
        try:
            G_fg.load_state_dict(xser.load(FLAGS['pre_trained_model_path'] + 'G_fg_epoch_' + FLAGS['pre_trained_model_epoch']))
            D_glob.load_state_dict(xser.load(FLAGS['pre_trained_model_path'] + 'D_glob_epoch_' + FLAGS['pre_trained_model_epoch']))
            D_instance.load_state_dict(xser.load(FLAGS['pre_trained_model_path'] + 'D_local_epoch_' + FLAGS['pre_trained_model_epoch']))
  
            xm.master_print('Parameters are loaded!')
        except Exception:
            xm.master_print('Error: Pre-trained parameters are not loaded!')
    
    #Define interpolation operation
    up_instance =  nn.Upsample(
        size=(FLAGS['local_patch_size'],FLAGS['local_patch_size']),
        mode='bilinear'
    )
    
    #Define pooling operation for the case where the image size and local patch size differ
    pooling_instance = nn.Sequential()
    if FLAGS['local_patch_size']!=FLAGS['img_size']:
        pooling_instance.add_module(
            '0',
            nn.AvgPool2d(int(FLAGS['img_size']/FLAGS['local_patch_size']))
        )
        
    #Define training loss function - binary cross entropy
    BCE_loss = nn.BCELoss()
    
    #Define feature matching loss
    criterionVGG = VGGLoss()
    criterionVGG = criterionVGG.to(device) #Modified for TPU Purposes
         
    #Define optimizer
    G_local_optimizer = optim.Adam(
        G_fg.parameters(),
        lr=FLAGS['lr'],
        betas=(0.0, 0.9)
    )
    D_local_optimizer = optim.Adam(
        list(filter(lambda p: p.requires_grad, D_glob.parameters())) + list(filter(lambda p: p.requires_grad, D_instance.parameters())),
        lr=FLAGS['lr'],
        betas=(0.0,0.9)
    )

    #Define learning rate scheduler
    scheduler_G = lr_scheduler.StepLR(
        G_local_optimizer,
        step_size=FLAGS['optim_step_size'],
        gamma=FLAGS['optim_gamma']
    )
    scheduler_D = lr_scheduler.StepLR(
        D_local_optimizer,
        step_size=FLAGS['optim_step_size'],
        gamma=FLAGS['optim_gamma']
    )
    
    #----------------------------TRAIN-----------------------------------------
    xm.master_print('training start!')
    tracker = xm.RateTracker() #Added for TPU reasons
    start_time = time.time()
    
    for epoch in range(FLAGS['train_epoch']):
        epoch_start_time = time.time()
        para_loader = pl.ParallelLoader(train_loader, [device]) #Added for TPU purposes
        loader = para_loader.per_device_loader(device) #Added for TPU purposes
         
        D_local_losses = []
        G_local_losses = []
    
        y_real_ = torch.ones(FLAGS['batch_size'])
        y_fake_ = torch.zeros(FLAGS['batch_size'])
        y_real_ = Variable(y_real_.to(device)) #Modified for TPU purposes
        y_fake_ = Variable(y_fake_.to(device)) #Modified for TPU purposes

        data_iter = iter(loader)
        num_iter = 0

        while num_iter < len(loader): #Modified for TPU purposes 
            j=0
            while j < FLAGS['critic_iter'] and num_iter < len(loader):
                j += 1
                sample_batched = next(data_iter)
                num_iter += 1

                x_ = sample_batched['image']
                x_ = Variable(x_.to(device)) #Modified for TPU purposes

                y_ = sample_batched['single_fg_mask']
                y_ = Variable(y_.to(device)) #Modified for TPU purposes

                fg_mask = sample_batched['seg_mask']
                fg_mask = Variable(fg_mask.to(device)) #Modified for TPU purposes

                y_instances = sample_batched['mask_instance']
                bbox = sample_batched['bbox']
                
                mini_batch = x_.size()[0]
                if mini_batch != FLAGS['batch_size']:
                    break
                
                #Update discriminators - D 
                #Real examples
                D_glob.zero_grad()
                D_instance.zero_grad()
                    
                y_reduced = torch.sum(y_,1).clamp(0,1).view(y_.size(0),1,FLAGS['img_size'],FLAGS['img_size'])
                
                x_d = torch.cat([x_,fg_mask],1)
                
                x_instances = torch.zeros((FLAGS['batch_size'],3,FLAGS['local_patch_size'],FLAGS['local_patch_size']))
                x_instances = Variable(x_instances.to(device))
                y_instances = Variable(y_instances.to(device))
                y_instances = pooling_instance(y_instances)
                G_instances = torch.zeros((FLAGS['batch_size'],3,FLAGS['local_patch_size'],FLAGS['local_patch_size']))
                G_instances = Variable(G_instances.to(device))
                      
                #Obtain instances
                for t in range(x_d.size()[0]):
                    x_instance = x_[t,0:3,bbox[0][t]:bbox[1][t],bbox[2][t]:bbox[3][t]] 
                    x_instance = x_instance.contiguous().view(1,x_instance.size()[0],x_instance.size()[1],x_instance.size()[2]) 
                    x_instances[t] = up_instance(x_instance)
                    
                D_result_instance = D_instance(torch.cat([x_instances,y_instances],1)).squeeze()       
                D_result = D_glob(x_d).squeeze()
                D_real_loss = BCE_loss(D_result, y_real_) +  BCE_loss(D_result_instance, y_real_)
                D_real_loss.backward()
                
                #Fake examples
                z_ = torch.randn((mini_batch,FLAGS['noise_size']))
                z_ = Variable(z_.to(device))
    
                #Generate fake images
                G_fg_result = G_fg(z_,y_, torch.mul(x_,(1-y_reduced)))
                G_result_d = torch.cat([G_fg_result,fg_mask],1) 
                            
                #Obtain fake instances
                for t in range(x_d.size()[0]):
                    G_instance = G_result_d[t,0:3,bbox[0][t]:bbox[1][t],bbox[2][t]:bbox[3][t]] 
                    G_instance = G_instance.contiguous().view(1,G_instance.size()[0],G_instance.size()[1],G_instance.size()[2]) 
                    G_instances[t] = up_instance(G_instance)
                
                
                D_result_instance = D_instance(torch.cat([G_instances,y_instances],1).detach()).squeeze() 
                D_result = D_glob(G_result_d.detach()).squeeze() 
                D_fake_loss = BCE_loss(D_result, y_fake_) +  BCE_loss(D_result_instance, y_fake_)
                D_fake_loss.backward()

                xm.optimizer_step(D_local_optimizer) #Modified for TPU purposes
                
                D_train_loss = D_real_loss + D_fake_loss
                D_local_losses.append(D_train_loss.item()) #.data[0] errors on 0-dim tensors in PyTorch >= 0.4
    
            if mini_batch != FLAGS['batch_size']:
                break  
            
            #Update generator G
            G_fg.zero_grad()   
            D_result = D_glob(G_result_d).squeeze() 
            D_result_instance = D_instance(torch.cat([G_instances,y_instances],1)).squeeze() 
            G_train_loss = (1-FLAGS['trade_off_G'])*BCE_loss(D_result, y_real_) + FLAGS['trade_off_G']*BCE_loss(D_result_instance, y_real_) 
            
            #Feature matching loss between generated image and corresponding ground truth
            FM_loss = criterionVGG(G_fg_result, x_)
            
            #Reconstruction loss
            Recon_loss = mse_loss(torch.mul(x_,(1-y_reduced) ), torch.mul(G_fg_result,(1-y_reduced))  )
    
            total_loss = G_train_loss + FLAGS['lambda_FM']*FM_loss + FLAGS['lambda_recon']*Recon_loss
            total_loss.backward() 

            xm.optimizer_step(G_local_optimizer)

            G_local_losses.append(G_train_loss.item())

            xm.master_print('loss_d: %.3f, loss_g: %.3f' % (D_train_loss.item(), G_train_loss.item()))
            if (num_iter % 100) == 0:
                xm.master_print('%d - %d complete!' % ((epoch+1), num_iter))
                xm.master_print(result_folder_name)

        #Modified location of the scheduler step to avoid warning
        scheduler_G.step()
        scheduler_D.step()

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        xm.master_print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), FLAGS['train_epoch'], per_epoch_ptime, torch.mean(torch.FloatTensor(D_local_losses)), torch.mean(torch.FloatTensor(G_local_losses))))
    
        #Save images
        G_fg.eval()
        
        if epoch == 0:
            show_result(
                (epoch+1),
                x_fixed,
                save=True,
                path=root + result_folder_name+ '/' + model + str(epoch + 1 ) + '_gt.png'
            )
            for t in range(y_fixed.size()[1]):
                show_result(
                    (epoch+1),
                    y_fixed[:,t:t+1,:,:],
                    save=True,
                    path=root + result_folder_name+ '/' + model + str(epoch + 1 ) +'_'+ str(t) +'_masked.png'
                )
            
        show_result(
            (epoch+1),
            G_fg(
                z_fixed,
                y_fixed,
                torch.mul(
                    x_fixed,
                    (1-torch.sum(y_fixed,1).view(y_fixed.size(0),1,FLAGS['img_size'],FLAGS['img_size']))
                )
            ),
            save=True,
            path=root + result_folder_name+ '/' + model + str(epoch + 1 ) + '_fg.png'
        )
        
        G_fg.train()
        
        #Save model params
        if FLAGS['save_models'] and (epoch>11 and epoch % 10 == 0 ):
            xser.save(
                G_fg.state_dict(),
                root + model_folder_name + '/' + model + 'G_fg_epoch_'+str(epoch)+'.pth',
                master_only=True
            )
            xser.save(
                D_glob.state_dict(),
                root + model_folder_name + '/' + model + 'D_glob_epoch_'+str(epoch)+'.pth',
                master_only=True
            )
            xser.save(
                D_instance.state_dict(),
                root + model_folder_name + '/' + model + 'D_local_epoch_'+str(epoch)+'.pth',
                master_only=True
            )
                         
    end_time = time.time()
    total_ptime = end_time - start_time
    xm.master_print("Training finish!... save training results")
    xm.master_print('Training time: ' + str(total_ptime))
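Example #3 defines main(rank) but, as listed, never launches it. With torch_xla this kind of per-process function is normally started through xmp.spawn; a minimal sketch, assuming an 8-core TPU and that FLAGS is defined at module level:

if __name__ == '__main__':
    #xmp.spawn forks one process per TPU core and passes the process index as `rank`
    xmp.spawn(main, args=(), nprocs=8, start_method='fork')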