# Example #1
# 0
def generate(img, img_mask_path, model_path):
    """Inpaint a single image with a pretrained generator and return it as a numpy image.

    Relies on module-level ``config``, ``cuda`` and ``device_ids`` (defined
    outside this function) for the image shape, checkpoint location and device
    setup.

    Args:
        img: Input image. It is converted with ``Image.fromarray`` when it is not
            already a PIL image, so presumably a numpy array — TODO confirm with
            callers.
        img_mask_path: Path to a mask image. If falsy, a random box mask from
            ``random_bbox``/``mask_image`` is used instead.
        model_path: Checkpoint directory. If falsy, defaults to
            ``checkpoints/<dataset_name>/<mask_type>_<expname>``.

    Returns:
        The value of ``from_torch_img_to_numpy`` for the composited result
        (generator output inside the mask, original pixels outside).

    Raises:
        TypeError: If ``img_mask_path`` is given but is not an image file.
    """
    # All but the last entry of config['image_shape'] is used as the target size.
    resize_shape = config['image_shape'][:-1]

    # torchvision's Resize/CenterCrop need a PIL image (or tensor). Previously
    # only the masked branch converted ``img``, so the random-mask branch
    # failed for numpy input. Convert once for both branches.
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)

    with torch.no_grad():   # inference only: no autograd graph
        if img_mask_path and is_image_file(img_mask_path):
            # Test a single masked image with a given mask
            mask = default_loader(img_mask_path)
            x = transforms.Resize(resize_shape)(img)
            x = transforms.CenterCrop(resize_shape)(x)
            mask = transforms.Resize(resize_shape)(mask)
            mask = transforms.CenterCrop(resize_shape)(mask)
            x = transforms.ToTensor()(x)
            # Keep only the first mask channel, restoring a leading channel dim.
            mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
            x = normalize(x)
            x = x * (1. - mask)          # blank out the region to be filled
            x = x.unsqueeze(dim=0)       # add batch dimension
            mask = mask.unsqueeze(dim=0)
        elif img_mask_path:
            raise TypeError("{} is not an image file.".format(img_mask_path))
        else:
            # Test a single ground-truth image with a random mask
            ground_truth = transforms.Resize(resize_shape)(img)
            ground_truth = transforms.CenterCrop(resize_shape)(ground_truth)
            ground_truth = transforms.ToTensor()(ground_truth)
            ground_truth = normalize(ground_truth)
            ground_truth = ground_truth.unsqueeze(dim=0)
            bboxes = random_bbox(config, batch_size=ground_truth.size(0))
            x, mask = mask_image(ground_truth, bboxes, config)

        # Set checkpoint path
        if not model_path:
            checkpoint_path = os.path.join('checkpoints',
                                           config['dataset_name'],
                                           config['mask_type'] + '_' + config['expname'])
        else:
            checkpoint_path = model_path

        # Define the generator and resume its weights
        netG = Generator(config['netG'], cuda, device_ids)
        last_model_name = get_model_list(checkpoint_path, "gen", iteration=0)

        if cuda:
            state_dict = torch.load(last_model_name)
        else:
            # Allow GPU-saved checkpoints to be loaded on a CPU-only machine.
            state_dict = torch.load(last_model_name, map_location='cpu')
        netG.load_state_dict(state_dict)

        # The iteration number is read from fixed character positions of the
        # checkpoint file name (e.g. "..._00012345.pt").
        model_iteration = int(last_model_name[-11:-3])
        print("Resume from {} at iteration {}".format(checkpoint_path, model_iteration))

        if cuda:
            netG = nn.parallel.DataParallel(netG, device_ids=device_ids)
            x = x.cuda()
            mask = mask.cuda()

        # Inference: only the refined (second-stage) output is composited.
        _, x2, offset_flow = netG(x, mask)
        inpainted_result = x2 * mask + x * (1. - mask)
        inpainted_result = from_torch_img_to_numpy(inpainted_result, 'output.png', padding=0, normalize=True)

        return inpainted_result
# Example #2
# 0
def main():
    """Train the inpainting GAN configured by the YAML file given on the command line.

    Restricts the visible CUDA devices, and creates
    ``checkpoints/<dataset_name>/<mask_type>_<expname>`` holding a copy of the
    config. It also sets up the TensorBoard writer and logger and seeds the
    RNGs. It then runs iterations ``start_iteration..config['niter']``, doing
    one discriminator update followed by one generator update per iteration.

    Any exception raised during training is logged and re-raised.
    """
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        # Expose only the requested GPUs; torch then numbers them 0..n-1.
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    # Configure checkpoint path and keep a copy of the config next to it
    checkpoint_path = os.path.join(
        'checkpoints', config['dataset_name'],
        config['mask_type'] + '_' + config['expname'])
    os.makedirs(checkpoint_path, exist_ok=True)
    shutil.copy(args.config,
                os.path.join(checkpoint_path, os.path.basename(args.config)))
    writer = SummaryWriter(logdir=checkpoint_path)  # noqa: F841 (creates the event dir)
    logger = get_logger(
        checkpoint_path)  # get logger and configure it at the first call

    logger.info(f"Arguments: {args}")
    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    logger.info(f"Random seed: {args.seed}")
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    # Log the configuration
    logger.info(f"Configuration: {config}")

    try:  # for unexpected error logging
        # Load the dataset
        logger.info(f"Training on dataset: {config['dataset_name']}")
        train_dataset = Dataset(data_path=config['train_data_path'],
                                with_subfolder=config['data_with_subfolder'],
                                image_shape=config['image_shape'],
                                random_crop=config['random_crop'])

        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=config['num_workers'])

        # Define the trainer (kept unwrapped; DataParallel wrapping is disabled)
        trainer = Trainer(config)
        trainer_module = trainer

        # Get the resume iteration to restart training
        start_iteration = trainer_module.resume(
            config['resume']) if config['resume'] else 1
        iterable_train_loader = iter(train_loader)
        time_count = time.time()
        epoch = 1
        log_losses = ['l1', 'ae', 'wgan_g', 'wgan_d', 'wgan_gp', 'g', 'd']
        for iteration in range(start_iteration, config['niter'] + 1):
            # Use the builtin next(): DataLoader iterators in current PyTorch
            # no longer provide a ``.next()`` method.
            try:
                ground_truth = next(iterable_train_loader)
            except StopIteration:
                logger.info(f"Epoch: {epoch}")
                epoch += 1
                iterable_train_loader = iter(train_loader)
                ground_truth = next(iterable_train_loader)

            # Prepare the inputs
            bboxes = random_bbox(config, batch_size=ground_truth.size(0))
            x, mask = mask_image(ground_truth, bboxes, config)
            if cuda:
                x = x.cuda()
                mask = mask.cuda()
                ground_truth = ground_truth.cuda()

            ###### Forward pass ######
            losses, inpainted_result, offset_flow = trainer(
                x, bboxes, mask, ground_truth)
            # Scalars from different devices are gathered into vectors
            for k in losses.keys():
                if not losses[k].dim() == 0:
                    losses[k] = torch.mean(losses[k])

            ###### Backward pass ######
            # Update D
            trainer_module.optimizer_d.zero_grad()
            losses['d'] = losses[
                'wgan_d'] + losses['wgan_gp'] * config['wgan_gp_lambda']
            losses['d'].backward()
            trainer_module.optimizer_d.step()

            # Update G
            trainer_module.optimizer_g.zero_grad()
            losses['g'] = losses['l1'] * config['l1_loss_alpha'] \
                          + losses['ae'] * config['ae_loss_alpha'] \
                          + losses['wgan_g'] * config['gan_loss_alpha']
            losses['g'].backward()
            trainer_module.optimizer_g.step()

            # Log and visualization
            message = 'Iter: [%d/%d] ' % (iteration, config['niter'])
            for k in log_losses:
                message += '%s: %.6f ' % (k, losses[k])
            print(f"\r{message}", end="")
            if iteration % config['print_iter'] == 0:
                print("")
                time_count = time.time() - time_count
                speed = config['print_iter'] / time_count
                speed_msg = 'speed: %.2f batches/s ' % speed
                time_count = time.time()

                message += speed_msg
                logger.info(message)

            if iteration % (config['viz_iter']) == 0:
                viz_max_out = config['viz_max_out']
                if x.size(0) > viz_max_out:
                    viz_images = torch.stack([
                        x[:viz_max_out], inpainted_result[:viz_max_out],
                        offset_flow[:viz_max_out]
                    ], dim=1)
                else:
                    viz_images = torch.stack(
                        [x, inpainted_result, offset_flow], dim=1)
                # Interleave (input, result, flow) triples into one image batch.
                viz_images = viz_images.view(-1, *list(x.size())[1:])
                vutils.save_image(viz_images,
                                  '%s/niter_%03d.png' %
                                  (checkpoint_path, iteration),
                                  nrow=3 * 4,
                                  normalize=True)

            # Save the model
            if iteration % config['snapshot_save_iter'] == 0:
                trainer_module.save_model(checkpoint_path, iteration)

    except Exception as e:  # for unexpected error logging
        logger.error(f"{e}")
        raise
def main():
    """Inpaint one image from the command line and save the result.

    Loads ``args.image``. It is masked with ``args.mask`` when given, otherwise
    with a random box from ``random_bbox``. The generator weights come from
    ``args.checkpoint_path``, or the config-derived default directory, at
    iteration ``args.iter``. The composited result is written to
    ``args.output``, and the offset flow to ``args.flow`` when that is set.

    Raises:
        TypeError: If ``args.image`` or ``args.mask`` is not an image file.
    """
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        # Expose only the requested GPUs; torch then numbers them 0..n-1.
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    print("Arguments: {}".format(args))

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    print("Configuration: {}".format(config))

    try:  # for unexpected error logging
        with torch.no_grad():   # inference only: no autograd graph
            if not is_image_file(args.image):
                # Previously ``.format`` was never called, so the message was
                # a bound-method repr instead of the offending path.
                raise TypeError("{} is not an image file.".format(args.image))

            resize_shape = config['image_shape'][:-1]
            if args.mask and is_image_file(args.mask):
                # Test a single masked image with a given mask
                x = default_loader(args.image)
                mask = default_loader(args.mask)
                x = transforms.Resize(resize_shape)(x)
                x = transforms.CenterCrop(resize_shape)(x)
                mask = transforms.Resize(resize_shape)(mask)
                mask = transforms.CenterCrop(resize_shape)(mask)
                x = transforms.ToTensor()(x)
                # Keep only the first mask channel, with a leading channel dim.
                mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
                x = normalize(x)
                x = x * (1. - mask)          # blank out the region to be filled
                x = x.unsqueeze(dim=0)       # add batch dimension
                mask = mask.unsqueeze(dim=0)
            elif args.mask:
                raise TypeError(
                    "{} is not an image file.".format(args.mask))
            else:
                # Test a single ground-truth image with a random mask
                ground_truth = default_loader(args.image)
                ground_truth = transforms.Resize(resize_shape)(ground_truth)
                ground_truth = transforms.CenterCrop(resize_shape)(ground_truth)
                ground_truth = transforms.ToTensor()(ground_truth)
                ground_truth = normalize(ground_truth)
                ground_truth = ground_truth.unsqueeze(dim=0)
                bboxes = random_bbox(
                    config, batch_size=ground_truth.size(0))
                x, mask = mask_image(ground_truth, bboxes, config)

            # Set checkpoint path
            if not args.checkpoint_path:
                checkpoint_path = os.path.join('checkpoints',
                                               config['dataset_name'],
                                               config['mask_type'] + '_' + config['expname'])
            else:
                checkpoint_path = args.checkpoint_path

            # Define the generator and resume its weights
            netG = Generator(config['netG'], cuda, device_ids)
            last_model_name = get_model_list(
                checkpoint_path, "gen", iteration=args.iter)
            if cuda:
                state_dict = torch.load(last_model_name)
            else:
                # Allow GPU-saved checkpoints to be loaded on a CPU-only machine.
                state_dict = torch.load(last_model_name, map_location='cpu')
            netG.load_state_dict(state_dict)
            # The iteration number is read from fixed character positions of
            # the checkpoint file name.
            model_iteration = int(last_model_name[-11:-3])
            print("Resume from {} at iteration {}".format(
                checkpoint_path, model_iteration))

            if cuda:
                netG = nn.parallel.DataParallel(
                    netG, device_ids=device_ids)
                x = x.cuda()
                mask = mask.cuda()

            # Inference: composite the refined output into the masked region
            _, x2, offset_flow = netG(x, mask)
            inpainted_result = x2 * mask + x * (1. - mask)

            vutils.save_image(inpainted_result, args.output,
                              padding=0, normalize=True)
            print("Saved the inpainted result to {}".format(args.output))
            if args.flow:
                vutils.save_image(offset_flow, args.flow,
                                  padding=0, normalize=True)
                print("Saved offset flow to {}".format(args.flow))
        # exit no grad context
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise
def main():
    """Batch-inpaint a test set, once per mask in a mask directory.

    Every ``*.jpg`` under ``args.test_root`` is inpainted with every ``*.png``
    under ``args.mask_dir``. The result for image ``i`` and mask ``j`` is
    written to ``<args.output_dir>output_<j+1>/output_<i>.png``.

    If the mask directory yields no files, each image is instead masked with the
    box from ``test_bbox(..., t=50, l=50)`` and written under ``output_1/``.
    Generator weights are read from ``<checkpoint_path>/gen.pt``.

    Raises:
        TypeError: If a test file or a mask file is not an image file.
    """
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        # Expose only the requested GPUs; torch then numbers them 0..n-1.
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    print("Arguments: {}".format(args))

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    print("Configuration: {}".format(config))

    try:  # for unexpected error logging
        with torch.no_grad():  # inference only: no autograd graph
            resize_shape = config['image_shape'][:-1]
            file = dataset_files(args.test_root, "*.jpg")
            mask_file = dataset_files(args.mask_dir, "*.png")

            # Set checkpoint path
            if not args.checkpoint_path:
                checkpoint_path = os.path.join(
                    'checkpoints', config['dataset_name'],
                    config['mask_type'] + '_' + config['expname'])
            else:
                checkpoint_path = args.checkpoint_path

            # The generator does not depend on the image or mask. Build and
            # load it once, instead of once per (mask, image) pair.
            netG = Generator(config['netG'], cuda, device_ids)
            g_checkpoint_file = f'{checkpoint_path}/gen.pt'
            if cuda:
                g_checkpoint = torch.load(g_checkpoint_file)
            else:
                # Allow GPU-saved checkpoints to be loaded on a CPU-only machine.
                g_checkpoint = torch.load(g_checkpoint_file, map_location='cpu')
            netG.load_state_dict(g_checkpoint)
            print("Model Resumed from {}".format(checkpoint_path))
            if cuda:
                netG = nn.parallel.DataParallel(netG, device_ids=device_ids)

            # With no masks the outer loop never ran, which made the fixed-box
            # branch below unreachable. ``None`` selects that branch.
            masks = list(mask_file) or [None]

            for j, mask_path in enumerate(masks):
                output_subdir = args.output_dir + 'output_{}/'.format(j + 1)
                os.makedirs(output_subdir, exist_ok=True)
                for i, image_path in enumerate(file):
                    if not is_image_file(image_path):
                        # ``.format`` was previously never called here.
                        raise TypeError(
                            "{} is not an image file.".format(image_path))

                    if mask_path and is_image_file(mask_path):
                        # Test a single masked image with a given mask
                        x = default_loader(image_path)
                        mask = default_loader(mask_path)
                        x = transforms.Resize(resize_shape)(x)
                        x = transforms.CenterCrop(resize_shape)(x)
                        # NOTE(review): the mask is not resized/cropped, so it
                        # presumably already matches config['image_shape'];
                        # confirm.
                        x = transforms.ToTensor()(x)
                        mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
                        x = normalize(x)
                        x = x * (1. - mask)          # blank out the hole
                        x = x.unsqueeze(dim=0)       # add batch dimension
                        mask = mask.unsqueeze(dim=0)
                    elif mask_path:
                        raise TypeError(
                            "{} is not an image file.".format(mask_path))
                    else:
                        # Test a single ground-truth image with a fixed box mask
                        ground_truth = default_loader(image_path)
                        ground_truth = transforms.Resize(resize_shape)(ground_truth)
                        ground_truth = transforms.CenterCrop(resize_shape)(ground_truth)
                        ground_truth = transforms.ToTensor()(ground_truth)
                        ground_truth = normalize(ground_truth)
                        ground_truth = ground_truth.unsqueeze(dim=0)
                        bboxes = test_bbox(config,
                                           batch_size=ground_truth.size(0),
                                           t=50,
                                           l=50)
                        x, mask = mask_image(ground_truth, bboxes, config)

                    if cuda:
                        x = x.cuda()
                        mask = mask.cuda()

                    # Inference: composite the refined output into the hole
                    _, x2 = netG(x, mask)
                    inpainted_result = x2 * mask + x * (1. - mask)

                    vutils.save_image(inpainted_result,
                                      output_subdir + 'output_{}.png'.format(i),
                                      padding=0,
                                      normalize=True)
                    print("{}th image saved".format(i))
        # exit no grad context
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise
def train(config, logger, checkpoint_path):
    """Run the single-process training loop for the inpainting GAN.

    Iterations where ``iteration % config['n_critic'] == 0`` update the
    generator; all other iterations update the discriminators. Losses are
    written to TensorBoard and shown in the tqdm progress bar. A visualisation
    is saved to ``checkpoint_path`` every ``viz_iter`` iterations and a model
    snapshot every ``snapshot_save_iter`` iterations.

    NOTE(review): ``writer`` is not a parameter of this function. It is
    presumably a module-level SummaryWriter defined elsewhere; confirm.

    Args:
        config: Parsed training configuration (dict-like).
        logger: Logger used for the model summaries and error reporting.
        checkpoint_path: Directory for visualisations and snapshots.

    Raises:
        Re-raises any exception after logging it.
    """
    log_losses = ['l1', 'ae', 'wgan_g', 'wgan_d', 'wgan_gp', 'g', 'd']
    try:  # for unexpected error logging
        # Load the dataset
        logger.info("Training on dataset: {}".format(config['dataset_name']))
        train_dataset = Dataset(data_path=config['train_data_path'],
                                with_subfolder=config['data_with_subfolder'],
                                image_shape=config['image_shape'],
                                random_crop=config['random_crop'])
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=config['batch_size'],
                                                   shuffle=True,
                                                   num_workers=config['num_workers'])

        # Define the trainer (kept unwrapped; DataParallel wrapping is disabled)
        trainer = Trainer(config)
        logger.info("\n{}".format(trainer.netG))
        logger.info("\n{}".format(trainer.localD))
        logger.info("\n{}".format(trainer.globalD))
        trainer_module = trainer

        # Get the resume iteration to restart training
        start_iteration = trainer_module.resume(config['resume']) if config['resume'] else 1
        print("\n\nStarting iteration: ", start_iteration)

        iterable_train_loader = iter(train_loader)

        # tqdm already reports the iteration rate, so no manual speed timer.
        pbar = tqdm(range(start_iteration, config['niter'] + 1),
                    dynamic_ncols=True, smoothing=0.01)
        for iteration in pbar:
            # Restart the loader whenever an epoch is exhausted.
            try:
                ground_truth = next(iterable_train_loader)
            except StopIteration:
                iterable_train_loader = iter(train_loader)
                ground_truth = next(iterable_train_loader)

            # Prepare the inputs
            bboxes = random_bbox(config, batch_size=ground_truth.size(0))
            x, mask = mask_image(ground_truth, bboxes, config)
            x = x.cuda()
            mask = mask.cuda()
            ground_truth = ground_truth.cuda()

            ###### Forward pass ######
            compute_g_loss = iteration % config['n_critic'] == 0
            losses, inpainted_result, offset_flow = trainer(x, bboxes, mask, ground_truth, compute_g_loss)
            # Scalars from different devices are gathered into vectors
            for k in losses.keys():
                if not losses[k].dim() == 0:
                    losses[k] = torch.mean(losses[k])

            ###### Backward pass ######
            if compute_g_loss:
                # Update G
                trainer_module.optimizer_g.zero_grad()
                losses['g'] = losses['l1'] * config['l1_loss_alpha'] \
                              + losses['ae'] * config['ae_loss_alpha'] \
                              + losses['wgan_g'] * config['gan_loss_alpha']
                losses['g'].backward()
                trainer_module.optimizer_g.step()
            else:
                # Update D
                trainer_module.optimizer_d.zero_grad()
                losses['d'] = losses['wgan_d'] + losses['wgan_gp'] * config['wgan_gp_lambda']
                losses['d'].backward()
                trainer_module.optimizer_d.step()

            # Log each loss once per iteration. Losses absent on this step
            # (e.g. 'g' on D steps, 'd' on G steps) are shown as 0 in the bar
            # but not written to TensorBoard, which would distort the curves.
            message = ' '
            for k in log_losses:
                if k in losses:
                    v = float(losses[k])
                    writer.add_scalar(k, v, iteration)
                else:
                    v = 0.
                message += '%s: %.4f ' % (k, v)
            pbar.set_description(f" {message}")

            if iteration % (config['viz_iter']) == 0:
                viz_max_out = config['viz_max_out']
                if x.size(0) > viz_max_out:
                    viz_images = torch.stack([x[:viz_max_out], inpainted_result[:viz_max_out],
                                              offset_flow[:viz_max_out]], dim=1)
                else:
                    viz_images = torch.stack([x, inpainted_result, offset_flow], dim=1)
                # Interleave (input, result, flow) triples into one image batch.
                viz_images = viz_images.view(-1, *list(x.size())[1:])
                vutils.save_image(viz_images,
                                  '%s/niter_%03d.png' % (checkpoint_path, iteration),
                                  nrow=3 * 4,
                                  normalize=True)

            # Save the model
            if iteration % config['snapshot_save_iter'] == 0:
                trainer_module.save_model(checkpoint_path, iteration)

    except Exception as e:  # for unexpected error logging
        logger.error("{}".format(e))
        raise
def train_distributed(config, logger, writer, checkpoint_path):
    """Train the inpainting GAN with one process per GPU using DistributedDataParallel.

    The process group is initialised with ``init_method='env://'``, so the
    launcher must provide the rendezvous environment (e.g. ``torchrun``). Only
    the global rank-0 process logs model summaries and writes TensorBoard
    scalars. It also saves the visualisations and the checkpoints.

    Args:
        config: Parsed training configuration (dict-like).
        logger: Logger for the model summaries (used on rank 0 only).
        writer: TensorBoard writer for loss scalars (used on rank 0 only).
        checkpoint_path: Directory for visualisations and snapshots.
    """
    import os  # LOCAL_RANK is exported by torchrun / launch --use_env

    dist.init_process_group(
        backend='nccl',
        init_method='env://'
    )

    # ``get_rank`` is the *global* rank across all nodes, while the GPU index
    # on this node is LOCAL_RANK. Fall back to the global rank when
    # LOCAL_RANK is unset; the two are identical on a single node.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get('LOCAL_RANK', rank))
    is_main_process = rank == 0
    print("### global rank of curr node: {} of {}".format(rank, world_size))
    print("local_rank: ", local_rank)

    # For multiprocessing distributed, DistributedDataParallel needs the
    # single-device scope set; otherwise it would use all visible devices.
    torch.cuda.set_device(local_rank)

    print("Creating models on device: ", local_rank)

    def to_ddp(module):
        # Move a model to this process's GPU and wrap it for DDP.
        return torch.nn.parallel.DistributedDataParallel(
            module.cuda(),
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True
        )

    # Models
    netG = to_ddp(Generator(config['netG'], use_cuda=True, device=local_rank))
    localD = to_ddp(LocalDis(config['netD'], use_cuda=True, device_id=local_rank))
    globalD = to_ddp(GlobalDis(config['netD'], use_cuda=True, device_id=local_rank))

    if is_main_process:
        logger.info("\n{}".format(netG))
        logger.info("\n{}".format(localD))
        logger.info("\n{}".format(globalD))

    # Optimizers: one for G, one shared by both discriminators
    betas = (config['beta1'], config['beta2'])
    optimizer_g = torch.optim.Adam(netG.parameters(), lr=config['lr'], betas=betas)
    d_params = list(localD.parameters()) + list(globalD.parameters())
    optimizer_d = torch.optim.Adam(d_params, lr=config['lr'], betas=betas)

    # Data: shard across all processes in the group. Previously num_replicas
    # was len(config['gpu_ids']), which is wrong whenever it differs from the
    # world size (e.g. multiple nodes).
    train_dataset = Dataset(
        data_path=config['train_data_path'],
        with_subfolder=config['data_with_subfolder'],
        image_shape=config['image_shape'],
        random_crop=config['random_crop']
    )
    sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=config['batch_size'],
        shuffle=False,  # shuffling is done by the sampler
        num_workers=config['num_workers'],
        pin_memory=True,
        sampler=sampler,
        drop_last=True
    )

    # Resuming is not wired up here yet; always start from iteration 1.
    start_iteration = 1
    print("\n\nStarting iteration: ", start_iteration)

    # DistributedSampler only reshuffles when set_epoch() is called before a
    # new iterator is created, so track epochs explicitly.
    epoch = 0
    sampler.set_epoch(epoch)
    iterable_train_loader = iter(train_loader)

    log_losses = ['l1', 'ae', 'wgan_g', 'wgan_d', 'wgan_gp', 'g', 'd']
    pbar = tqdm(range(start_iteration, config['niter'] + 1),
                dynamic_ncols=True, smoothing=0.01,
                disable=not is_main_process)  # one progress bar, not one per rank
    for iteration in pbar:
        try:
            ground_truth = next(iterable_train_loader)
        except StopIteration:
            epoch += 1
            sampler.set_epoch(epoch)
            iterable_train_loader = iter(train_loader)
            ground_truth = next(iterable_train_loader)

        # Prepare the inputs and move them to this process's GPU
        bboxes = random_bbox(config, batch_size=ground_truth.size(0))
        x, mask = mask_image(ground_truth, bboxes, config)
        bboxes = bboxes.cuda(local_rank)
        x = x.cuda(local_rank)
        mask = mask.cuda(local_rank)
        ground_truth = ground_truth.cuda(local_rank)

        ###### Forward pass ######
        compute_g_loss = iteration % config['n_critic'] == 0
        losses, inpainted_result, offset_flow = forward(config, x, bboxes, mask, ground_truth,
                                                       netG=netG, localD=localD, globalD=globalD,
                                                       local_rank=local_rank, compute_loss_g=compute_g_loss)

        # Scalars from different devices are gathered into vectors
        for k in losses.keys():
            if not losses[k].dim() == 0:
                losses[k] = torch.mean(losses[k])

        ###### Backward pass ######
        if compute_g_loss:
            # Update G
            optimizer_g.zero_grad()
            losses['g'] = (losses['ae'] * config['ae_loss_alpha']
                           + losses['l1'] * config['l1_loss_alpha']
                           + losses['wgan_g'] * config['gan_loss_alpha'])
            losses['g'].backward()
            optimizer_g.step()
        else:
            # Update D
            optimizer_d.zero_grad()
            losses['d'] = losses['wgan_d'] + losses['wgan_gp'] * config['wgan_gp_lambda']
            losses['d'].backward()
            optimizer_d.step()

        if not is_main_process:
            continue

        # Losses absent on this step ('g' on D steps, 'd' on G steps) are
        # shown as 0 but not written to TensorBoard.
        message = ' '
        for k in log_losses:
            if k in losses:
                v = float(losses[k])
                writer.add_scalar(k, v, iteration)
            else:
                v = 0.
            message += '%s: %.4f ' % (k, v)
        pbar.set_description(f" {message}")

        if iteration % (config['viz_iter']) == 0:
            viz_max_out = config['viz_max_out']
            if x.size(0) > viz_max_out:
                viz_images = torch.stack([x[:viz_max_out], inpainted_result[:viz_max_out],
                                          offset_flow[:viz_max_out]], dim=1)
            else:
                viz_images = torch.stack([x, inpainted_result, offset_flow], dim=1)
            # Interleave (input, result, flow) triples into one image batch.
            viz_images = viz_images.view(-1, *list(x.size())[1:])
            vutils.save_image(viz_images,
                              '%s/niter_%08d.png' % (checkpoint_path, iteration),
                              nrow=3 * 4,
                              normalize=True)

        # Save the model
        if iteration % config['snapshot_save_iter'] == 0:
            save_model(
                netG, globalD, localD, optimizer_g, optimizer_d, checkpoint_path, iteration
            )
def main():
    """Train the inpainting model described by the config file given on the CLI.

    Parses command-line arguments, loads the configuration with
    ``get_config``, configures CUDA and creates the checkpoint directory
    (copying the config file into it). It then runs the training loop:
    the discriminator is updated every iteration and the generator every
    ``config['n_critic']`` iterations. Losses are logged every
    ``config['print_iter']`` iterations, sample grids are saved every
    ``config['viz_iter']`` iterations and snapshots are saved every
    ``config['snapshot_save_iter']`` iterations.

    Any exception raised while building the data pipeline or training is
    logged (including its traceback) and then re-raised unchanged.
    """
    args = parser.parse_args()
    config = get_config(args.config)  # way to use config

    # CUDA configuration.
    # CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is
    # initialised, and torch.cuda.device_count() may initialise it. The mask
    # is therefore applied *before* the device-count query. Otherwise the
    # remap of gpu_ids to 0..k-1 below could address the wrong physical GPUs.
    device_ids = config['gpu_ids']
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in device_ids)
    cuda = config['cuda']  # specify cuda in config.yaml
    # When any visible GPU exists, CUDA is enabled even if config['cuda'] is False.
    if torch.cuda.device_count() > 0:
        cuda = True
        config['cuda'] = True
    if cuda:
        # Inside this process the selected GPUs are renumbered 0..k-1.
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    # Configure checkpoint path
    # say you use imagenet pretrained model, since the dataset name is altered to be "dtd"
    # the trained checkpoint of this model will be stored in checkpoints/dtd/hole_benchmark
    checkpoint_path = os.path.join(
        'checkpoints', config['dataset_name'],
        config['mask_type'] + '_' + config['expname'])
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(checkpoint_path, exist_ok=True)
    shutil.copy(args.config,
                os.path.join(checkpoint_path, os.path.basename(args.config)))
    writer = SummaryWriter(logdir=checkpoint_path)
    logger = get_logger(
        checkpoint_path)  # get logger and configure it at the first call

    logger.info("Arguments: {}".format(args))
    # Set random seed (a random one is drawn and logged if none was given).
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    logger.info("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    # Log the configuration
    logger.info("Configuration: {}".format(config))

    try:  # for unexpected error logging
        # Load the dataset
        logger.info("Training on dataset: {}".format(config['dataset_name']))
        train_dataset = Dataset(data_path=config['train_data_path'],
                                with_subfolder=config['data_with_subfolder'],
                                image_shape=config['image_shape'],
                                random_crop=config['random_crop'])
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=config['num_workers'])

        # Define the trainer
        trainer = Trainer(config)
        logger.info("\n{}".format(trainer.netG))
        logger.info("\n{}".format(trainer.localD))
        logger.info("\n{}".format(trainer.globalD))

        if cuda:
            trainer = nn.parallel.DataParallel(trainer, device_ids=device_ids)
            trainer_module = trainer.module
        else:
            trainer_module = trainer

        # Get the resume iteration to restart training
        # config['resume'] being the directory to checkpoints such as checkpoints/imagenet/hole_benchmark/
        start_iteration = trainer_module.resume(
            config['resume']) if config['resume'] else 1
        print("start at iteration {}".format(start_iteration))
        iterable_train_loader = iter(train_loader)

        log_losses = ['l1', 'ae', 'wgan_g', 'wgan_d', 'wgan_gp', 'g', 'd']
        last_print_time = time.time()

        for iteration in range(start_iteration, config['niter'] + 1):
            # Cycle through the loader indefinitely, restarting at epoch end.
            try:
                ground_truth = next(iterable_train_loader)
            except StopIteration:
                iterable_train_loader = iter(train_loader)
                ground_truth = next(iterable_train_loader)

            # Prepare the inputs
            bboxes = random_bbox(config, batch_size=ground_truth.size(0))
            x, mask = mask_image(ground_truth, bboxes, config)
            if cuda:
                x = x.cuda()
                mask = mask.cuda()
                ground_truth = ground_truth.cuda()

            ###### Forward pass ######
            compute_g_loss = iteration % config['n_critic'] == 0
            losses, inpainted_result, offset_flow = trainer(
                x, bboxes, mask, ground_truth, compute_g_loss)
            # Scalars from different devices are gathered into vectors
            for k in losses.keys():
                if not losses[k].dim() == 0:
                    losses[k] = torch.mean(losses[k])

            ###### Backward pass ######
            # Update D
            trainer_module.optimizer_d.zero_grad()
            losses['d'] = losses['wgan_d'] + losses['wgan_gp'] * config['wgan_gp_lambda']
            losses['d'].backward()
            trainer_module.optimizer_d.step()

            # Update G (only on iterations where the G loss was computed)
            if compute_g_loss:
                trainer_module.optimizer_g.zero_grad()
                losses['g'] = losses['l1'] * config['l1_loss_alpha'] \
                              + losses['ae'] * config['ae_loss_alpha'] \
                              + losses['wgan_g'] * config['gan_loss_alpha']
                losses['g'].backward()
                trainer_module.optimizer_g.step()

            # Log and visualization
            if iteration % config['print_iter'] == 0:
                now = time.time()
                speed = config['print_iter'] / (now - last_print_time)
                last_print_time = now
                speed_msg = 'speed: %.2f batches/s ' % speed

                message = 'Iter: [%d/%d] ' % (iteration, config['niter'])
                for k in log_losses:
                    # Losses absent this iteration (e.g. 'g') are logged as 0.
                    v = losses.get(k, 0.)
                    writer.add_scalar(k, v, iteration)
                    message += '%s: %.6f ' % (k, v)
                message += speed_msg
                logger.info(message)

            if iteration % (config['viz_iter']) == 0:
                viz_max_out = config['viz_max_out']
                if x.size(0) > viz_max_out:
                    viz_images = torch.stack([
                        x[:viz_max_out], inpainted_result[:viz_max_out],
                        offset_flow[:viz_max_out]
                    ], dim=1)
                else:
                    viz_images = torch.stack(
                        [x, inpainted_result, offset_flow], dim=1)
                # Interleave (input, result, flow) triples into one flat batch.
                viz_images = viz_images.view(-1, *list(x.size())[1:])
                vutils.save_image(viz_images,
                                  '%s/niter_%03d.png' %
                                  (checkpoint_path, iteration),
                                  nrow=3 * 4,
                                  normalize=True)

            # Save the model
            if iteration % config['snapshot_save_iter'] == 0:
                trainer_module.save_model(checkpoint_path, iteration)
                # checkpoint_path is checkpoints/dtd/hole_benchmark
                # then you should use this check point to test raven score then store them in a file
            # if iteration % (config['snapshot_save_iter']//10) == 0:
            #     os.system("python test_raven.py --checkpoint_path {} >> tmp.out".format(checkpoint_path))

            # this could cause trouble if it's not python but python3 etc.

    except Exception as e:  # for unexpected error logging
        # exc_info keeps the traceback in the log file, not just the message.
        logger.error("{}".format(e), exc_info=True)
        raise
                raise TypeError("{} is not an image file.".format(args.mask))
            else:
                # Test a single ground-truth image with a random mask
                ground_truth = tif_loader(args.image)
                #ground_truth = transforms.Resize(config['image_shape'][:-1])(ground_truth)
                #ground_truth = transforms.CenterCrop(config['image_shape'][:-1])(ground_truth)
                #ground_truth = transfer2tensor(ground_truth)
                #ground_truth = transforms.ToTensor()(ground_truth)
                x = ground_truth[110:366, 110:366, :]
                x = np.transpose(x, [2, 0, 1])
                #print('Output min, max',x.min(),x.max())
                ground_truth = torch.from_numpy(x)
                ground_truth = normalize(ground_truth)
                ground_truth = ground_truth.unsqueeze(dim=0)
                bboxes = random_bbox(config, batch_size=ground_truth.size(0))
                x, mask = mask_image(ground_truth, bboxes, config)

            # Set checkpoint path
            if not args.checkpoint_path:
                checkpoint_path = os.path.join(
                    'checkpoints', config['dataset_name'],
                    config['mask_type'] + '_' + config['expname'])
            else:
                checkpoint_path = args.checkpoint_path

            # Define the trainer
            netG = Generator(config['netG'], cuda, device_ids)
            # Resume weight
            last_model_name = get_model_list(checkpoint_path,
                                             "gen",
                                             iteration=args.iter)