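# Imports this module relies on. The project-local paths in the second group
# are assumptions about the repository layout and are left commented out so
# they can be adjusted to the actual module names.
import os
import time

import torch
import torch.nn as nn
import torch.distributed as dist
import torchvision.utils as vutils
from torch import autograd
from tqdm import tqdm

# Assumed project-local helpers (hypothetical paths):
# from model.networks import Generator, LocalDis, GlobalDis
# from utils.tools import (get_model_list, local_patch, spatial_discounting_mask,
#                          random_bbox, mask_image)
# from data.dataset import Dataset
# `logger` used in Trainer.resume is assumed to be a module-level logging.Logger.

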
class Trainer(nn.Module):
    def __init__(self, config):
        super(Trainer, self).__init__()
        self.config = config
        self.use_cuda = self.config['cuda']
        self.device_ids = self.config['gpu_ids']

        self.netG = Generator(self.config['netG'], self.use_cuda,
                              self.device_ids)
        self.localD = LocalDis(self.config['netD'], self.use_cuda,
                               self.device_ids)
        self.globalD = GlobalDis(self.config['netD'], self.use_cuda,
                                 self.device_ids)

        self.optimizer_g = torch.optim.Adam(self.netG.parameters(),
                                            lr=self.config['lr'],
                                            betas=(self.config['beta1'],
                                                   self.config['beta2']))
        d_params = list(self.localD.parameters()) + list(
            self.globalD.parameters())
        self.optimizer_d = torch.optim.Adam(d_params,
                                            lr=self.config['lr'],
                                            betas=(self.config['beta1'],
                                                   self.config['beta2']))
        if self.use_cuda:
            self.netG.to(self.device_ids[0])
            self.localD.to(self.device_ids[0])
            self.globalD.to(self.device_ids[0])

    def forward(self, x, bboxes, masks, ground_truth, compute_loss_g=False):
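        # One training forward pass: run the two-stage generator, composite the
        # coarse (x1) and refined (x2) outputs into the unmasked input, and
        # compute the local/global WGAN discriminator losses plus the gradient
        # penalty; generator losses (l1, ae, wgan_g) are added only when
        # compute_loss_g is True.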
        self.train()
        l1_loss = nn.L1Loss()
        losses = {}

        x1, x2, offset_flow = self.netG(x, masks)
        local_patch_gt = local_patch(ground_truth, bboxes)
        x1_inpaint = x1 * masks + x * (1. - masks)
        x2_inpaint = x2 * masks + x * (1. - masks)
        local_patch_x1_inpaint = local_patch(x1_inpaint, bboxes)
        local_patch_x2_inpaint = local_patch(x2_inpaint, bboxes)

        # D part
        # wgan d loss
        local_patch_real_pred, local_patch_fake_pred = self.dis_forward(
            self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
        global_real_pred, global_fake_pred = self.dis_forward(
            self.globalD, ground_truth, x2_inpaint.detach())
        losses['wgan_d'] = torch.mean(local_patch_fake_pred - local_patch_real_pred) + \
            torch.mean(global_fake_pred - global_real_pred) * self.config['global_wgan_loss_alpha']
        # gradients penalty loss
        local_penalty = self.calc_gradient_penalty(
            self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
        global_penalty = self.calc_gradient_penalty(self.globalD, ground_truth,
                                                    x2_inpaint.detach())
        losses['wgan_gp'] = local_penalty + global_penalty

        # G part
        if compute_loss_g:
            sd_mask = spatial_discounting_mask(self.config)
            losses['l1'] = l1_loss(local_patch_x1_inpaint * sd_mask, local_patch_gt * sd_mask) * \
                self.config['coarse_l1_alpha'] + \
                l1_loss(local_patch_x2_inpaint * sd_mask, local_patch_gt * sd_mask)
            losses['ae'] = l1_loss(x1 * (1. - masks), ground_truth * (1. - masks)) * \
                self.config['coarse_l1_alpha'] + \
                l1_loss(x2 * (1. - masks), ground_truth * (1. - masks))

            # wgan g loss
            local_patch_real_pred, local_patch_fake_pred = self.dis_forward(
                self.localD, local_patch_gt, local_patch_x2_inpaint)
            global_real_pred, global_fake_pred = self.dis_forward(
                self.globalD, ground_truth, x2_inpaint)
            losses['wgan_g'] = - torch.mean(local_patch_fake_pred) - \
                torch.mean(global_fake_pred) * self.config['global_wgan_loss_alpha']

        return losses, x2_inpaint, offset_flow

    def dis_forward(self, netD, ground_truth, x_inpaint):
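        # Concatenate real and generated samples along the batch dimension, run
        # a single discriminator forward pass, and split the output back into
        # real and fake predictions.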
        assert ground_truth.size() == x_inpaint.size()
        batch_size = ground_truth.size(0)
        batch_data = torch.cat([ground_truth, x_inpaint], dim=0)
        batch_output = netD(batch_data)
        real_pred, fake_pred = torch.split(batch_output, batch_size, dim=0)

        return real_pred, fake_pred

    # Calculate gradient penalty
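    # WGAN-GP term: E[(||grad_xhat D(xhat)||_2 - 1)^2], evaluated at random
    # interpolations xhat between real and generated samples.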
    def calc_gradient_penalty(self, netD, real_data, fake_data):
        batch_size = real_data.size(0)
        alpha = torch.rand(batch_size, 1, 1, 1)
        alpha = alpha.expand_as(real_data)
        if self.use_cuda:
            alpha = alpha.cuda()

        interpolates = alpha * real_data + (1 - alpha) * fake_data
        interpolates = interpolates.requires_grad_().clone()

        disc_interpolates = netD(interpolates)
        grad_outputs = torch.ones(disc_interpolates.size())

        if self.use_cuda:
            grad_outputs = grad_outputs.cuda()

        gradients = autograd.grad(outputs=disc_interpolates,
                                  inputs=interpolates,
                                  grad_outputs=grad_outputs,
                                  create_graph=True,
                                  retain_graph=True,
                                  only_inputs=True)[0]

        gradients = gradients.view(batch_size, -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()

        return gradient_penalty

    def inference(self, x, masks):
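        # Evaluation-only pass: run the generator and composite the refined
        # output (x2) into the known regions of the input.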
        self.eval()
        x1, x2, offset_flow = self.netG(x, masks)
        # x1_inpaint = x1 * masks + x * (1. - masks)
        x2_inpaint = x2 * masks + x * (1. - masks)

        return x2_inpaint, offset_flow

    def save_model(self, checkpoint_dir, iteration):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(checkpoint_dir, 'gen_%08d.pt' % iteration)
        dis_name = os.path.join(checkpoint_dir, 'dis_%08d.pt' % iteration)
        opt_name = os.path.join(checkpoint_dir, 'optimizer.pt')
        torch.save(self.netG.state_dict(), gen_name)
        torch.save(
            {
                'localD': self.localD.state_dict(),
                'globalD': self.globalD.state_dict()
            }, dis_name)
        torch.save(
            {
                'gen': self.optimizer_g.state_dict(),
                'dis': self.optimizer_d.state_dict()
            }, opt_name)

    def resume(self, checkpoint_dir, iteration=0, test=False):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir,
                                         "gen",
                                         iteration=iteration)
        self.netG.load_state_dict(torch.load(last_model_name))
        iteration = int(last_model_name[-11:-3])

        if not test:
            # Load discriminators
            last_model_name = get_model_list(checkpoint_dir,
                                             "dis",
                                             iteration=iteration)
            state_dict = torch.load(last_model_name)
            self.localD.load_state_dict(state_dict['localD'])
            self.globalD.load_state_dict(state_dict['globalD'])
            # Load optimizers
            state_dict = torch.load(
                os.path.join(checkpoint_dir, 'optimizer.pt'))
            self.optimizer_d.load_state_dict(state_dict['dis'])
            self.optimizer_g.load_state_dict(state_dict['gen'])

        print("Resume from {} at iteration {}".format(checkpoint_dir,
                                                      iteration))
        logger.info("Resume from {} at iteration {}".format(
            checkpoint_dir, iteration))

        return iteration
def train_distributed(config, logger, writer, checkpoint_path):
    # Initialize the default process group; NCCL is the standard backend for
    # multi-GPU training ('gloo' can be used instead for CPU-only debugging).
    dist.init_process_group(backend='nccl', init_method='env://')

    # get_rank() returns this process's global rank; with a single-node launch
    # (one process per GPU) it doubles as the local GPU index used below.
    local_rank = torch.distributed.get_rank()

    # Total number of processes across all nodes, one per GPU
    # (e.g. 2 nodes with 4 GPUs each gives a world size of 8).
    world_size = torch.distributed.get_world_size()
    print("### global rank of current process: {} of {}".format(local_rank, world_size))

    # For multiprocessing distributed training, the DistributedDataParallel
    # constructor should always set a single device scope; otherwise it will
    # use all available devices.
    torch.cuda.set_device(local_rank)

    print("Creating models on device: ", local_rank)

    input_dim = config['netG']['input_dim']
    cnum = config['netG']['ngf']
    use_cuda = True
    gated = config['netG']['gated']
    
    
    # Models
    #
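    # Each network is wrapped in DistributedDataParallel pinned to this
    # process's GPU; find_unused_parameters=True tolerates parameters that do
    # not receive gradients on a given forward/backward pass.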
    netG = Generator(config['netG'], use_cuda=True, device=local_rank).cuda()
    netG = torch.nn.parallel.DistributedDataParallel(
        netG,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )

    
    localD = LocalDis(config['netD'], use_cuda=True, device_id=local_rank).cuda()
    localD = torch.nn.parallel.DistributedDataParallel(
        localD,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )
    
    
    globalD = GlobalDis(config['netD'], use_cuda=True, device_id=local_rank).cuda()
    globalD = torch.nn.parallel.DistributedDataParallel(
        globalD,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )
    
    
    if local_rank == 0:
        logger.info("\n{}".format(netG))
        logger.info("\n{}".format(localD))
        logger.info("\n{}".format(globalD))
        
    
    # Optimizers
    #
    optimizer_g = torch.optim.Adam(
        netG.parameters(),
        lr=config['lr'],
        betas=(config['beta1'], config['beta2'])
    )

    
    d_params = list(localD.parameters()) + list(globalD.parameters())
    optimizer_d = torch.optim.Adam(
        d_params,  
        lr=config['lr'],                                    
        betas=(config['beta1'], config['beta2'])                              
    )
    
    
    # Data
    #
    sampler = None
    train_dataset = Dataset(
        data_path=config['train_data_path'],
        with_subfolder=config['data_with_subfolder'],
        image_shape=config['image_shape'],
        random_crop=config['random_crop']
    )
        
    
    # Each process draws a distinct shard of the training set. num_replicas and
    # rank default to the process group's world size and rank when omitted; here
    # the replica count is taken from the config (single-node assumption).
    sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=len(config['gpu_ids']),
    )
    
    
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=config['batch_size'],
        shuffle=(sampler is None),
        num_workers=config['num_workers'],
        pin_memory=True,
        sampler=sampler,
        drop_last=True
    )
    
    
    # Get the resume iteration to restart training. Checkpoint resuming is not
    # wired up for this distributed path, so training starts from iteration 1.
    start_iteration = 1
    print("\n\nStarting iteration: ", start_iteration)

    iterable_train_loader = iter(train_loader)

    if local_rank == 0: 
        time_count = time.time()

    epochs = config['niter'] + 1
    pbar = tqdm(range(start_iteration, epochs), dynamic_ncols=True, smoothing=0.01)
    for iteration in pbar:
        sampler.set_epoch(iteration)
        
        try:
            ground_truth = next(iterable_train_loader)
        except StopIteration:
            iterable_train_loader = iter(train_loader)
            ground_truth = next(iterable_train_loader)

        # Prepare the inputs
        bboxes = random_bbox(config, batch_size=ground_truth.size(0))
        x, mask = mask_image(ground_truth, bboxes, config)

        
        # Move to proper device.
        #
        bboxes = bboxes.cuda(local_rank)
        x = x.cuda(local_rank)
        mask = mask.cuda(local_rank)
        ground_truth = ground_truth.cuda(local_rank)
        

        ###### Forward pass ######
        compute_g_loss = iteration % config['n_critic'] == 0
        losses, inpainted_result, offset_flow = forward(config, x, bboxes, mask, ground_truth,
                                                       netG=netG, localD=localD, globalD=globalD,
                                                       local_rank=local_rank, compute_loss_g=compute_g_loss)

        
        # If a loss comes back as a vector (e.g. one value per sample or per
        # device), reduce it to a scalar mean before combining terms.
        for k in losses.keys():
            if losses[k].dim() != 0:
                losses[k] = torch.mean(losses[k])
                
                
        ###### Backward pass ######
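        # The discriminators are stepped on every iteration except the
        # n_critic-th ones, where only the generator is stepped.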
        # Update D
        if not compute_g_loss:
            optimizer_d.zero_grad()
            losses['d'] = losses['wgan_d'] + losses['wgan_gp'] * config['wgan_gp_lambda']
            losses['d'].backward()
            optimizer_d.step() 

        # Update G
        if compute_g_loss:
            optimizer_g.zero_grad()
            losses['g'] = losses['ae'] * config['ae_loss_alpha']
            losses['g'] += losses['l1'] * config['l1_loss_alpha']
            losses['g'] += losses['wgan_g'] * config['gan_loss_alpha']
            losses['g'].backward()
            optimizer_g.step()


        # Set tqdm description
        #
        if local_rank == 0:
            log_losses = ['l1', 'ae', 'wgan_g', 'wgan_d', 'wgan_gp', 'g', 'd']
            message = ' '
            for k in log_losses:
                v = losses.get(k, 0.)
                writer.add_scalar(k, v, iteration)
                message += '%s: %.4f ' % (k, v)

            pbar.set_description(message)
            
                
        if local_rank == 0:
            if iteration % config['viz_iter'] == 0:
                viz_max_out = config['viz_max_out']
                if x.size(0) > viz_max_out:
                    viz_images = torch.stack([x[:viz_max_out], inpainted_result[:viz_max_out],
                                              offset_flow[:viz_max_out]], dim=1)
                else:
                    viz_images = torch.stack([x, inpainted_result, offset_flow], dim=1)
                viz_images = viz_images.view(-1, *list(x.size())[1:])
                vutils.save_image(viz_images,
                                  '%s/niter_%08d.png' % (checkpoint_path, iteration),
                                  nrow=3 * 4,
                                  normalize=True)

            # Save the model
            if iteration % config['snapshot_save_iter'] == 0:
                save_model(
                    netG, globalD, localD, optimizer_g, optimizer_d, checkpoint_path, iteration
                )