# Example #1 (score: 0)
    def init_model(self, path):
        """Create DFNet on the best available device and restore its weights.

        CUDA is chosen whenever ``torch.cuda.is_available()`` is true, the CPU
        otherwise. The checkpoint at *path* is mapped onto that device, loaded
        into the network, and the network is put into eval mode.
        """
        gpu_ok = torch.cuda.is_available()
        self.device = torch.device("cuda" if gpu_ok else "cpu")
        print("Using gpu." if gpu_ok else "Using cpu.")

        self.model = net = DFNet().to(self.device)
        state_dict = torch.load(path, map_location=self.device)
        net.load_state_dict(state_dict)
        net.eval()

        print("Model %s loaded." % path)
    def init_model(self, path, dev):
        """Create DFNet on an explicitly requested device and restore its weights.

        Args:
            path: checkpoint file handed to ``torch.load``.
            dev: ``'cpu'`` selects the CPU; any other value selects CUDA
                (CUDA availability is not checked here).
        """
        on_cpu = dev == 'cpu'
        self.device = torch.device('cpu' if on_cpu else 'cuda')
        print('Using cpu.' if on_cpu else 'Using gpu.')

        self.model = net = DFNet().to(self.device)
        state_dict = torch.load(path, map_location=self.device)
        net.load_state_dict(state_dict)
        net.eval()

        print('Model %s loaded.' % path)
# Example #3 (score: 0)
def build_model():
    """Build the DF classifier and compile it.

    Input shape, class count and optimizer are read from the ``parameters``
    module; the model is compiled with categorical cross-entropy and an
    accuracy metric and then returned.
    """
    print("Building and training DF model")

    net = DFNet.build(input_shape=parameters.INPUT_SHAPE,
                      classes=parameters.NB_CLASSES)

    compile_options = dict(loss="categorical_crossentropy",
                           optimizer=parameters.OPTIMIZER,
                           metrics=["accuracy"])
    net.compile(**compile_options)
    print("Model compiled")
    return net
# Example #4 (score: 0)
# File: run.py  Project: ywu40/BMBC
parser.add_argument('--second', type=str, required=True)
parser.add_argument('--output', type=str, required=True)
parser.add_argument('--time_step', type=float, default=0.5)

args = parser.parse_args()
# All networks used by the script are collected in this dict on ``args``.
args.dict = dict()

torch.backends.cudnn.benchmark = True

args.dict['context_layer'] = nn.Conv2d(3,
                                       64, (7, 7),
                                       stride=(1, 1),
                                       padding=(3, 3),
                                       bias=False)
args.dict['BMNet'] = BMNet()
args.dict['DF_Net'] = DFNet(32, 4, 16, 6)
args.dict['filtering'] = DynFilter()

# Deserialize every checkpoint onto the CPU. load_state_dict copies the values
# into the module's own parameters on whatever device those live, so the
# result is the same, but GPU-saved checkpoints no longer require CUDA to load.
args.dict['context_layer'].load_state_dict(
    torch.load('Weights/context_layer.pth', map_location='cpu'))
args.dict['BMNet'].load_state_dict(
    torch.load(args.ckpt_bm, map_location='cpu'))
args.dict['DF_Net'].load_state_dict(
    torch.load(args.ckpt_df, map_location='cpu'))
ReLU = torch.nn.ReLU()

# Freeze the pretrained networks so no gradients are tracked for them.
for _frozen in ('context_layer', 'BMNet', 'DF_Net'):
    for param in args.dict[_frozen].parameters():
        param.requires_grad = False
# Example #5 (score: 0)
class Tester:
    """Run a pretrained DFNet over image/mask pairs and save the outputs.

    :meth:`inpaint` accepts either one image file plus one mask file, or an
    image directory plus a mask directory. For every pair it writes
    ``result-<img>-<mask>.png``, ``raw-...`` and ``alpha-...`` into the
    ``result``, ``raw`` and ``alpha`` subfolders of the output directory.
    Masks are read as grayscale and divided by 255: 1 keeps the input pixel,
    0 marks a hole for the model to fill.
    """

    # Extensions accepted for single-file input (compared case-insensitively).
    _IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg")

    def __init__(self, model_path, input_size, batch_size):
        self.model_path = model_path
        # Non-positive means "choose from the checkpoint name"; see input_size.
        self._input_size = input_size
        self.batch_size = batch_size
        self.init_model(model_path)

    @property
    def input_size(self):
        """Square ``(size, size)`` that inputs are resized to.

        The explicit size wins when positive; otherwise 256 for checkpoints
        whose path contains ``"celeba"`` and 512 for everything else.
        """
        if self._input_size > 0:
            return (self._input_size, self._input_size)
        # str() so that pathlib.Path checkpoints work too.
        if "celeba" in str(self.model_path):
            return (256, 256)
        return (512, 512)

    def init_model(self, path):
        """Select CUDA if available (else CPU) and load DFNet weights from *path*."""
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Using gpu.")
        else:
            self.device = torch.device("cpu")
            print("Using cpu.")

        self.model = DFNet().to(self.device)
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint)
        self.model.eval()

        print("Model %s loaded." % path)

    def get_name(self, path):
        """Return the base name of *path* without its last extension.

        ``Path.stem`` keeps extension-less names intact; a split on "." turned
        them into "", so outputs of different inputs could overwrite each other.
        """
        return Path(path).stem

    def results_path(self, output, img_path, mask_path, prefix="result"):
        """Map an image/mask pair to the paths of its three output PNGs.

        ``output`` and ``prefix`` are accepted for API compatibility but unused:
        the root directory comes from ``self.output`` (set by :meth:`inpaint`).
        """
        tag = "{}-{}".format(self.get_name(img_path), self.get_name(mask_path))
        return {
            "{}_path".format(kind):
            self.sub_dir(kind).joinpath("{}-{}.png".format(kind, tag))
            for kind in ("result", "raw", "alpha")
        }

    def inpaint_instance(self, img, mask):
        """Inpaint a single image and return the raw model output.

        Args:
            img: HWC uint8 image, same layout as one item of :meth:`inpaint_batch`.
            mask: HW (or HW1) uint8 mask; 255 keeps a pixel, 0 marks a hole.

        Returns:
            HWC uint8 array.
        """
        imgs = np.asarray(img)[np.newaxis]
        mask = np.asarray(mask)
        if mask.ndim == 2:
            mask = mask[..., np.newaxis]
        masks = mask[np.newaxis]
        return self.inpaint_batch(imgs, masks)[0]

    def inpaint_batch(self, imgs, masks):
        """Inpaint a batch and return the raw model prediction.

        Args:
            imgs: NHWC uint8 array (BGR when read with OpenCV).
            masks: NHW1 uint8 array; 255 keeps a pixel, 0 marks a hole.

        Returns:
            NHWC uint8 array. Unlike :meth:`process_batch`, the prediction is
            not composited with the known input pixels.
        """
        imgs = torch.from_numpy(np.transpose(imgs, [0, 3, 1, 2])).to(self.device)
        masks = torch.from_numpy(np.transpose(masks, [0, 3, 1, 2])).to(self.device)
        imgs = imgs.float().div(255)
        masks = masks.float().div(255)
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            results = self.model(imgs * masks, masks)
        # The model may return a tensor, a list of tensors, or (as unpacked in
        # process_batch) a tuple of such lists; take the first tensor.
        while isinstance(results, (list, tuple)):
            results = results[0]
        return self.to_numpy(results)

    def _process_file(self, output, img_path, mask_path):
        """Queue one image/mask pair together with its output paths."""
        item = {"img_path": img_path, "mask_path": mask_path}
        item.update(self.results_path(output, img_path, mask_path))
        self.path_pair.append(item)

    def process_single_file(self, output, img_path, mask_path):
        """Reset the queue to the single pair (*img_path*, *mask_path*)."""
        self.path_pair = []
        self._process_file(output, img_path, mask_path)

    @staticmethod
    def _list_images(folder):
        """Return the sorted ``*.jpg`` and ``*.png`` files directly in *folder*."""
        folder = Path(folder)
        return sorted(list(folder.glob("*.jpg")) + list(folder.glob("*.png")))

    def process_dir(self, output, img_dir, mask_dir):
        """Queue images and masks paired by sorted order.

        Only ``min(#images, #masks)`` pairs are queued; surplus files on the
        longer side are ignored.
        """
        self.path_pair = []
        for img_path, mask_path in zip(self._list_images(img_dir),
                                       self._list_images(mask_dir)):
            self._process_file(output, img_path, mask_path)

    def get_process(self, input_size):
        """Build the per-pair loader that runs in the worker pool.

        The returned callable reads the pair's image (3-channel) and mask
        (grayscale) with OpenCV and resizes both to *input_size* when it is
        truthy. It stores them as CHW / 1HW uint8 arrays under ``"img"`` /
        ``"mask"`` in the pair dict and returns that dict.

        Raises (from the callable):
            OSError: if OpenCV cannot read either file.
        """
        def process(pair):
            img = cv2.imread(str(pair["img_path"]), cv2.IMREAD_COLOR)
            mask = cv2.imread(str(pair["mask_path"]), cv2.IMREAD_GRAYSCALE)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise OSError("Cannot read image: {}".format(pair["img_path"]))
            if mask is None:
                raise OSError("Cannot read mask: {}".format(pair["mask_path"]))
            if input_size:
                img = cv2.resize(img, input_size)
                mask = cv2.resize(mask, input_size)
            pair["img"] = np.ascontiguousarray(
                img.transpose(2, 0, 1)).astype(np.uint8)
            pair["mask"] = np.ascontiguousarray(
                mask[np.newaxis]).astype(np.uint8)
            return pair

        return process

    def _file_batch(self):
        """Yield per-batch dicts mapping each pair key to a list of values."""
        n_pair = len(self.path_pair)
        n_batch = (n_pair - 1) // self.batch_size + 1  # 0 when n_pair == 0
        process = self.get_process(self.input_size)  # loop-invariant

        # The context manager terminates the pool when iteration ends or the
        # generator is closed, instead of leaking its workers.
        with Pool() as pool:
            for i in tqdm.trange(n_batch, leave=False):
                start = i * self.batch_size
                chunk = self.path_pair[start:start + self.batch_size]
                _buffer = defaultdict(list)
                # Order inside a batch may differ from the queue, but each
                # instance carries its own paths, so the lists stay aligned.
                for instance in pool.imap_unordered(process, chunk):
                    for k, v in instance.items():
                        _buffer[k].append(v)
                yield _buffer

    def batch_generator(self):
        """Yield batches whose ``"img"`` / ``"mask"`` lists are stacked arrays."""
        for _buffer in self._file_batch():
            for key in ("img", "mask"):
                if key in _buffer:
                    _buffer[key] = list2nparray(_buffer[key])
            yield _buffer

    def to_numpy(self, tensor):
        """Convert an NCHW float tensor in [0, 1] to an NHWC uint8 array.

        Values are clamped before the uint8 cast so out-of-range outputs
        saturate instead of wrapping around.
        """
        tensor = tensor.detach().mul(255).clamp(0, 255).byte().cpu().numpy()
        return np.transpose(tensor, [0, 2, 3, 1])

    @staticmethod
    def _imwrite(path, image):
        """Write *image* with OpenCV, raising if it reports failure."""
        if not cv2.imwrite(str(path), image):
            raise OSError("Failed to write image: {}".format(path))

    def process_batch(self, batch, output):
        """Inpaint one stacked batch and write its result/raw/alpha PNGs.

        ``output`` is unused; destinations come from the batch's path lists.
        """
        imgs = torch.from_numpy(batch["img"]).to(self.device).float().div(255)
        masks = torch.from_numpy(batch["mask"]).to(self.device).float().div(255)

        with torch.no_grad():
            result, alpha, raw = self.model(imgs * masks, masks)
        result, alpha, raw = result[0], alpha[0], raw[0]
        # Keep the known pixels from the input; only the holes come from the model.
        result = imgs * masks + result * (1 - masks)

        result = self.to_numpy(result)
        alpha = self.to_numpy(alpha)
        raw = self.to_numpy(raw)

        for i in range(result.shape[0]):
            self._imwrite(batch["result_path"][i], result[i])
            self._imwrite(batch["raw_path"][i], raw[i])
            self._imwrite(batch["alpha_path"][i], alpha[i])

    @property
    def root(self):
        """Output root directory (``self.output`` as a Path)."""
        return Path(self.output)

    def sub_dir(self, sub):
        """Return ``root / sub``."""
        return self.root.joinpath(sub)

    def prepare_folders(self, folders):
        """Create each folder (and parents) if missing."""
        for folder in folders:
            Path(folder).mkdir(parents=True, exist_ok=True)

    def inpaint(self, output, img, mask, merge_result=False):
        """Inpaint *img* with *mask* and save the outputs under *output*.

        Args:
            output: root directory; ``result``, ``alpha`` and ``raw``
                subfolders are created in it.
            img, mask: two image files, or two directories of images.
            merge_result: for directory input, also build the ``miss`` and
                ``merge`` folders with ``gen_miss`` / ``merge_imgs``. It is
                skipped for file input, because ``merge_imgs`` receives a list
                of directories.

        Raises:
            NotImplementedError: unsupported file extension, or *img* and
                *mask* are not both files or both directories.
        """
        self.output = output
        self.prepare_folders([
            self.sub_dir("result"),
            self.sub_dir("alpha"),
            self.sub_dir("raw")
        ])

        if os.path.isfile(img) and os.path.isfile(mask):
            if not str(img).lower().endswith(self._IMAGE_SUFFIXES):
                raise NotImplementedError()
            self.process_single_file(output, img, mask)
            _type = "file"
        elif os.path.isdir(img) and os.path.isdir(mask):
            self.process_dir(output, img, mask)
            _type = "dir"
        else:
            print("Img: ", img)
            print("Mask: ", mask)
            raise NotImplementedError(
                "img and mask should be both file or directory.")

        print("# Inpainting...")
        print("Input size:", self.input_size)
        for batch in self.batch_generator():
            self.process_batch(batch, output)
        print("Inpainting finished.")

        if merge_result and _type != "dir":
            print("Merging skipped: only supported for directory input.")
        elif merge_result:
            miss = self.sub_dir("miss")
            merge = self.sub_dir("merge")

            print("# Preparing input images...")
            gen_miss(img, mask, miss)
            print("# Merging...")
            merge_imgs(
                [
                    miss,
                    self.sub_dir("raw"),
                    self.sub_dir("alpha"),
                    self.sub_dir("result"),
                    img,
                ],
                merge,
                res=self.input_size[0],
            )
            print("Merging finished.")
# Example #6 (score: 0)
size = (args.image_size,) * 2
# Training transforms: square resize, random horizontal flip, tensor conversion.
img_tf = transforms.Compose([
    transforms.Resize(size=size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

dataset = DS(args.root, img_tf)

# InfiniteSampler presumably yields indices without end, so batches are pulled
# from this iterator with next() rather than by looping over epochs.
iterator_train = iter(data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    sampler=InfiniteSampler(len(dataset)),
    num_workers=args.n_threads,
))
print(len(dataset))
model = DFNet().to(device)

lr = args.lr

start_iter = 0
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Multiply the learning rate by 0.1 every 5 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
criterion = InpaintingLoss().to(device)

if args.resume:
    # Only model weights are restored; optimizer, scheduler and start_iter are not.
    checkpoint = torch.load(args.resume, map_location=device)
    model.load_state_dict(checkpoint)

for i in tqdm(range(start_iter, args.max_iter)):
    model.train()
# Example #7 (score: 0)
class Tester:
    """Run a pretrained DFNet on image/mask pairs and return the result in memory.

    :meth:`inpaint` queues one file pair or a directory of pairs and returns the
    composited output of the first pair as an HWC uint8 array. Nothing is
    written to disk. Masks are divided by 255: 1 keeps the input pixel, 0 marks
    a hole for the model to fill.
    """

    # Extensions accepted for single-file input (compared case-insensitively).
    _IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg')

    def __init__(self, model_path, input_size, batch_size):
        self.model_path = model_path
        # Non-positive means "choose from the checkpoint name"; see input_size.
        self._input_size = input_size
        self.batch_size = batch_size
        self.init_model(model_path)
        # Empty result buffer sized from the *resolved* input size; using the
        # raw argument made np.zeros raise for negative values.
        width, height = self.input_size
        self.results = np.zeros((0, height, width, 3), dtype=np.uint8)
        self.results_ids = []

    @property
    def input_size(self):
        """Square ``(size, size)`` that inputs are resized to.

        The explicit size wins when positive; otherwise 256 for checkpoints
        whose path contains ``'celeba'`` and 512 for everything else.
        """
        if self._input_size > 0:
            return (self._input_size, self._input_size)
        if 'celeba' in str(self.model_path):
            return (256, 256)
        return (512, 512)

    def init_model(self, path):
        """Select CUDA if available (else CPU) and load DFNet weights from *path*."""
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            print('Using gpu.')
        else:
            self.device = torch.device('cpu')
            print('Using cpu.')

        self.model = DFNet().to(self.device)
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint)
        self.model.eval()

        print('Model %s loaded.' % path)

    def get_name(self, path):
        """Return the base name of *path* (str or Path) without its last extension."""
        return Path(path).stem

    def inpaint_instance(self, img, mask):
        """Inpaint a single image and return the raw model output.

        Args:
            img: HWC uint8 image, same layout as one item of :meth:`inpaint_batch`.
            mask: HW (or HW1) uint8 mask; 255 keeps a pixel, 0 marks a hole.

        Returns:
            HWC uint8 array.
        """
        imgs = np.asarray(img)[np.newaxis]
        mask = np.asarray(mask)
        if mask.ndim == 2:
            mask = mask[..., np.newaxis]
        return self.inpaint_batch(imgs, mask[np.newaxis])[0]

    def inpaint_batch(self, imgs, masks):
        """Inpaint a batch and return the raw (non-composited) model prediction.

        Args:
            imgs: NHWC uint8 array (BGR when read with OpenCV).
            masks: NHW1 uint8 array; 255 keeps a pixel, 0 marks a hole.

        Returns:
            NHWC uint8 array.
        """
        imgs = torch.from_numpy(np.transpose(imgs, [0, 3, 1, 2])).to(self.device)
        masks = torch.from_numpy(np.transpose(masks, [0, 3, 1, 2])).to(self.device)
        imgs = imgs.float().div(255)
        masks = masks.float().div(255)
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            results = self.model(imgs * masks, masks)
        # The model may return a tensor, a list of tensors, or (as unpacked in
        # process_batch) a tuple of such lists; take the first tensor.
        while isinstance(results, (list, tuple)):
            results = results[0]
        return self.to_numpy(results)

    def _process_file(self, img_path, mask_path):
        """Queue one image/mask pair."""
        self.path_pair.append({
            'img_path': img_path,
            'mask_path': mask_path,
        })

    def process_single_file(self, img_path, mask_path):
        """Reset the queue to the single pair (*img_path*, *mask_path*)."""
        self.path_pair = []
        self._process_file(img_path, mask_path)

    @staticmethod
    def _list_images(folder):
        """Return the sorted ``*.jpg`` and ``*.png`` files directly in *folder*."""
        folder = Path(folder)
        return sorted(list(folder.glob('*.jpg')) + list(folder.glob('*.png')))

    def process_dir(self, img_dir, mask_dir):
        """Queue images and masks paired by sorted order (surplus files ignored)."""
        self.path_pair = []
        for img_path, mask_path in zip(self._list_images(img_dir),
                                       self._list_images(mask_dir)):
            self._process_file(img_path, mask_path)

    def get_process(self, input_size):
        """Build the per-pair loader that runs in the worker pool.

        The callable reads the pair's image (3-channel) and mask (grayscale)
        with OpenCV and resizes both to *input_size* when it is truthy. It
        stores them as CHW / 1HW uint8 arrays under ``'img'`` / ``'mask'`` and
        returns the pair dict.

        Raises (from the callable):
            OSError: if OpenCV cannot read either file.
        """
        def process(pair):
            img = cv2.imread(str(pair['img_path']), cv2.IMREAD_COLOR)
            mask = cv2.imread(str(pair['mask_path']), cv2.IMREAD_GRAYSCALE)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise OSError('Cannot read image: {}'.format(pair['img_path']))
            if mask is None:
                raise OSError('Cannot read mask: {}'.format(pair['mask_path']))
            if input_size:
                img = cv2.resize(img, input_size)
                mask = cv2.resize(mask, input_size)
            pair['img'] = np.ascontiguousarray(
                img.transpose(2, 0, 1)).astype(np.uint8)
            pair['mask'] = np.ascontiguousarray(
                mask[np.newaxis]).astype(np.uint8)
            return pair
        return process

    def _file_batch(self):
        """Yield per-batch dicts mapping each pair key to a list of values."""
        n_pair = len(self.path_pair)
        n_batch = (n_pair - 1) // self.batch_size + 1  # 0 when n_pair == 0
        process = self.get_process(self.input_size)  # loop-invariant

        # The context manager terminates the pool when iteration ends or the
        # generator is closed early (inpaint stops after the first batch).
        with Pool() as pool:
            for i in tqdm.trange(n_batch, leave=False):
                start = i * self.batch_size
                chunk = self.path_pair[start:start + self.batch_size]
                _buffer = defaultdict(list)
                for instance in pool.imap_unordered(process, chunk):
                    for k, v in instance.items():
                        _buffer[k].append(v)
                yield _buffer

    def batch_generator(self):
        """Yield batches whose ``'img'`` / ``'mask'`` lists are stacked arrays."""
        for _buffer in self._file_batch():
            for key in ('img', 'mask'):
                if key in _buffer:
                    _buffer[key] = list2nparray(_buffer[key])
            yield _buffer

    def to_numpy(self, tensor):
        """Convert an NCHW float tensor in [0, 1] to an NHWC uint8 array.

        Values are clamped before the uint8 cast so out-of-range outputs
        saturate instead of wrapping around.
        """
        tensor = tensor.detach().mul(255).clamp(0, 255).byte().cpu().numpy()
        return np.transpose(tensor, [0, 2, 3, 1])

    def process_batch(self, batch):
        """Inpaint one stacked batch and return its first composited image.

        Returns:
            HWC uint8 array for the first item of the batch, or ``None`` if
            the batch is empty. Other items are computed but not returned.
        """
        imgs = torch.from_numpy(batch['img']).to(self.device).float().div(255)
        masks = torch.from_numpy(batch['mask']).to(self.device).float().div(255)

        with torch.no_grad():
            result, alpha, raw = self.model(imgs * masks, masks)
        result = result[0]
        # Keep the known pixels from the input; only the holes come from the model.
        result = self.to_numpy(imgs * masks + result * (1 - masks))

        return result[0] if result.shape[0] > 0 else None

    def inpaint(self, img, mask, merge_result=False):
        """Inpaint *img* with *mask* and return the composited result.

        Args:
            img, mask: two image files, or two directories of images paired
                by sorted file name.
            merge_result: accepted for API compatibility and ignored; this
                class keeps results in memory and has no output folders to merge.

        Returns:
            HWC uint8 array for the first queued pair (the only one in file
            mode), or ``None`` when no pair was found.

        Raises:
            NotImplementedError: unsupported file extension, or *img* and
                *mask* are not both files or both directories.
        """
        if os.path.isfile(img) and os.path.isfile(mask):
            if not str(img).lower().endswith(self._IMAGE_SUFFIXES):
                raise NotImplementedError()
            self.process_single_file(img, mask)
        elif os.path.isdir(img) and os.path.isdir(mask):
            self.process_dir(img, mask)
        else:
            print('Img: ', img)
            print('Mask: ', mask)
            raise NotImplementedError(
                'img and mask should be both file or directory.')

        print('# Inpainting...')
        print('Input size:', self.input_size)
        first_batch = next(self.batch_generator(), None)
        result = None if first_batch is None else self.process_batch(first_batch)
        print('Inpainting finished.')
        return result
# Example #8 (score: 0)
    elif config.data.data_flist[config.data.dataset][0].split(".")[1] == "mat":
        imgs.load_mat(
            expanduser(config.data.data_flist[config.data.dataset][0]),
            expanduser(config.data.data_flist[config.data.dataset][1]))
logger_.info("Data loading is finished.")

#
# Create model
# --------------------------------------------------------------------------
# Look the optimizer class up by name on tf.keras.optimizers and build it with
# the configured learning rate.
optimizer = getattr(tf.keras.optimizers, config.optimizer.name)(
    config.optimizer.args.lr)

# train the model
model = DFNet(en_ksize=config.model.en_ksize,
              de_ksize=config.model.de_ksize,
              fuse_index=config.model.blend_layers)
if config.model.use_sngan:
    # Use the DFNet as the generator of an SNGAN paired with SN_Discriminator.
    model = SNGAN(gen=model, dis=SN_Discriminator(c_num=3))
logger_.info("DFNet declaration is finished.")


#
# Model Training
# ----------------------------------------------------------------------------------------------------------------------
# NOTE(review): this definition is truncated/corrupted. The `for` loop below
# has no body, and the following `transforms.ToTensor()` / `])` lines appear
# to come from a different snippet (the tail of a `transforms.Compose([...])`
# list whose opening is not present). As written this is a syntax error;
# restore the missing lines from the original source before running.
def show_batch(image_batch):
    # Presumably meant to plot up to config.batch_size_infer images from
    # image_batch in a 10x10-inch figure — body missing, confirm upstream.
    plt.figure(figsize=(10, 10))
    for n in range(config.batch_size_infer):
    transforms.ToTensor()
])
    
print("IMAGE TRANSFORMS MADE")

dataset = DS(args.root, img_tf)

print("Loading iterator train")
# InfiniteSampler presumably yields indices without end, so training pulls
# batches from this iterator with next() rather than looping over epochs.
iterator_train = iter(
    data.DataLoader(dataset,
                    batch_size=args.batch_size,
                    sampler=InfiniteSampler(len(dataset)),
                    num_workers=args.n_threads))
print(len(dataset))

model = DFNet().to(device)

lr = args.lr

start_iter = 0
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Multiply the learning rate by 0.1 every 5 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
criterion = InpaintingLoss().to(device)

if args.resume:
    # Only model weights are restored; optimizer, scheduler and start_iter are not.
    checkpoint = torch.load(args.resume, map_location=device)
    model.load_state_dict(checkpoint)

for i in tqdm(range(start_iter, args.max_iter)):
    model.train()
# Example #10 (score: 0)
    sss = StratifiedShuffleSplit(n_splits=10, test_size=0.1, random_state=0)
    tps, wps, fps, ps, ns = 0, 0, 0, 0, 0
    start_time = time.time()
    folder_num = 1
    for train_index, test_index in sss.split(X, y):
        # logger.info('Testing fold %d'%folder_num)
        folder_num += 1
        #       if folder_num > 2:
        #           break
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        # initialize the optimizer and model
        # print (time.sleep(2))
        model = DFNet.build(input_shape=INPUT_SHAPE, classes=NB_CLASSES)

        model.compile(loss="categorical_crossentropy",
                      optimizer=OPTIMIZER,
                      metrics=["accuracy"])
        # print ("Model compiled")

        # Start training
        history = model.fit(X_train,
                            y_train,
                            batch_size=BATCH_SIZE,
                            epochs=NB_EPOCH,
                            verbose=VERBOSE,
                            validation_split=0.1)

        y_pred = model.predict(X_test)
class Tester:
    """Run a pretrained DFNet on an in-memory image/mask pair.

    :meth:`inpaint` takes the image and mask as arrays (stored on
    ``self.IMAGE`` / ``self.MASK``) and returns the composited inpainted image
    as an HWC uint8 array. Masks are divided by 255: 1 keeps the input pixel,
    0 marks a hole for the model to fill.
    """

    def __init__(self, model_path, input_size, batch_size, dev='gpu'):
        self.model_path = model_path
        # Non-positive means "choose from the checkpoint name"; see input_size.
        self._input_size = input_size
        self.batch_size = batch_size
        self.init_model(model_path, dev)

    @property
    def input_size(self):
        """Square ``(size, size)`` that inputs are resized to.

        The explicit size wins when positive; otherwise 256 for checkpoints
        whose path contains ``'celeba'`` and 512 for everything else.
        """
        if self._input_size > 0:
            return (self._input_size, self._input_size)
        if 'celeba' in str(self.model_path):
            return (256, 256)
        return (512, 512)

    def init_model(self, path, dev):
        """Build DFNet on the requested device and load weights from *path*.

        Args:
            path: checkpoint file handed to ``torch.load``.
            dev: ``'cpu'`` selects the CPU; any other value selects CUDA
                (availability is not checked).
        """
        if dev == 'cpu':
            self.device = torch.device('cpu')
            print('Using cpu.')
        else:
            self.device = torch.device('cuda')
            print('Using gpu.')

        self.model = DFNet().to(self.device)
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint)
        self.model.eval()

        print('Model %s loaded.' % path)

    def get_name(self, path):
        """Return the base name of *path* (str or Path) without its last extension."""
        return Path(path).stem

    def inpaint_instance(self, img, mask):
        """Inpaint a single image and return the raw model output.

        Args:
            img: HWC uint8 image, same layout as one item of :meth:`inpaint_batch`.
            mask: HW (or HW1) uint8 mask; 255 keeps a pixel, 0 marks a hole.

        Returns:
            HWC uint8 array.
        """
        imgs = np.asarray(img)[np.newaxis]
        mask = np.asarray(mask)
        if mask.ndim == 2:
            mask = mask[..., np.newaxis]
        return self.inpaint_batch(imgs, mask[np.newaxis])[0]

    def inpaint_batch(self, imgs, masks):
        """Inpaint a batch and return the raw (non-composited) model prediction.

        Args:
            imgs: NHWC uint8 array.
            masks: NHW1 uint8 array; 255 keeps a pixel, 0 marks a hole.

        Returns:
            NHWC uint8 array.
        """
        imgs = torch.from_numpy(np.transpose(imgs, [0, 3, 1, 2])).to(self.device)
        masks = torch.from_numpy(np.transpose(masks, [0, 3, 1, 2])).to(self.device)
        imgs = imgs.float().div(255)
        masks = masks.float().div(255)
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            results = self.model(imgs * masks, masks)
        # The model may return a tensor, a list of tensors, or (as unpacked in
        # process_batch) a tuple of such lists; take the first tensor.
        while isinstance(results, (list, tuple)):
            results = results[0]
        return self.to_numpy(results)

    def _process_file(self, img_path, mask_path):
        """Queue one image/mask pair."""
        self.path_pair.append({
            'img_path': img_path,
            'mask_path': mask_path,
        })

    def process_single_file(self, img_path, mask_path):
        """Reset the queue to the single pair (*img_path*, *mask_path*)."""
        self.path_pair = []
        self._process_file(img_path, mask_path)

    def process_dir(self, img_dir, mask_dir):
        """Queue ``*.jpg``/``*.png`` images and masks paired by sorted order.

        Only ``min(#images, #masks)`` pairs are queued. Note that
        :meth:`get_process` reads ``self.IMAGE`` / ``self.MASK`` rather than
        these files.
        """
        def list_images(folder):
            folder = Path(folder)
            return sorted(list(folder.glob('*.jpg')) + list(folder.glob('*.png')))

        self.path_pair = []
        for img_path, mask_path in zip(list_images(img_dir),
                                       list_images(mask_dir)):
            self._process_file(img_path, mask_path)

    def get_process(self, input_size):
        """Build the per-pair loader that runs in the worker pool.

        The callable ignores the paths in the pair. It takes the pixels from
        ``self.IMAGE`` (presumably HWC uint8 — transposed with (2, 0, 1) below)
        and ``self.MASK`` (presumably HW). Both are resized to *input_size*
        when it is truthy, stored as CHW / 1HW uint8 arrays under ``'img'`` /
        ``'mask'``, and the pair dict is returned.
        """
        def process(pair):
            # Start from the stored arrays so img/mask are always bound, even
            # when no resize is requested.
            img = np.asarray(self.IMAGE)
            mask = np.asarray(self.MASK)
            if input_size:
                img = cv2.resize(img, input_size)
                mask = cv2.resize(mask, input_size)
            pair['img'] = np.ascontiguousarray(
                img.transpose(2, 0, 1)).astype(np.uint8)
            pair['mask'] = np.ascontiguousarray(
                np.expand_dims(mask, 0)).astype(np.uint8)
            return pair

        return process

    def _file_batch(self):
        """Yield per-batch dicts mapping each pair key to a list of values."""
        n_pair = len(self.path_pair)
        n_batch = (n_pair - 1) // self.batch_size + 1  # 0 when n_pair == 0
        process = self.get_process(self.input_size)  # loop-invariant

        # The context manager terminates the pool when iteration ends or the
        # generator is closed early (inpaint stops after the first batch).
        with Pool() as pool:
            for i in range(n_batch):
                start = i * self.batch_size
                chunk = self.path_pair[start:start + self.batch_size]
                _buffer = defaultdict(list)
                for instance in pool.imap_unordered(process, chunk):
                    for k, v in instance.items():
                        _buffer[k].append(v)
                yield _buffer

    def batch_generator(self):
        """Yield batches whose ``'img'`` / ``'mask'`` lists are stacked arrays."""
        for _buffer in self._file_batch():
            for key in ('img', 'mask'):
                if key in _buffer:
                    _buffer[key] = list2nparray(_buffer[key])
            yield _buffer

    def to_numpy(self, tensor):
        """Convert an NCHW float tensor in [0, 1] to an NHWC uint8 array.

        Values are clamped before the uint8 cast so out-of-range outputs
        saturate instead of wrapping around.
        """
        tensor = tensor.detach().mul(255).clamp(0, 255).byte().cpu().numpy()
        return np.transpose(tensor, [0, 2, 3, 1])

    def process_batch(self, batch):
        """Inpaint one stacked batch and return its first composited image.

        Returns:
            HWC uint8 array for the first item of the batch, or ``None`` if
            the batch is empty.
        """
        imgs = torch.from_numpy(batch['img']).to(self.device).float().div(255)
        masks = torch.from_numpy(batch['mask']).to(self.device).float().div(255)

        with torch.no_grad():
            result, alpha, raw = self.model(imgs * masks, masks)
        result = result[0]
        # Keep the known pixels from the input; only the holes come from the model.
        result = self.to_numpy(imgs * masks + result * (1 - masks))

        return result[0] if result.shape[0] > 0 else None

    def inpaint(self, img, mask, merge_result=False):
        """Inpaint the in-memory *img* with *mask* and return the result.

        Args:
            img: image array (see :meth:`get_process` for the assumed layout).
            mask: mask array; 255 keeps a pixel, 0 marks a hole.
            merge_result: accepted for API compatibility and ignored.

        Returns:
            HWC uint8 composited image.
        """
        self.IMAGE = img
        self.MASK = mask

        # Placeholder paths: get_process reads the arrays above, not these files.
        self.process_single_file("img.png", "mask.png")

        first_batch = next(self.batch_generator(), None)
        return None if first_batch is None else self.process_batch(first_batch)