Esempio n. 1
0
    def run_default_encoding(self, input, encoding_type, num_iterations,
                             display_intermediate_results):
        """Encode a single image with the model registered for ``encoding_type``.

        Args:
            input: Path (or path-like object) of the image to encode.
            encoding_type: Key into ``self.checkpoints`` and
                ``self.model_paths``. ``"horses"`` selects the e4e network
                (everything else uses pSp), ``"cars"`` selects the car
                transforms and output size, and ``"faces"`` triggers face
                alignment.
            num_iterations: Stored as ``opts.n_iters_per_batch`` before
                calling ``run_on_batch``.
            display_intermediate_results: Forwarded unchanged to
                ``self.get_final_output``.

        Returns:
            ``pathlib.Path`` of ``output.png`` inside a newly created
            temporary directory.
        """
        # load model
        print(f'Loading {encoding_type} model...')
        ckpt = self.checkpoints[encoding_type]
        opts = ckpt['opts']
        opts['checkpoint_path'] = self.model_paths[encoding_type]
        opts = Namespace(**opts)
        net = e4e(opts) if encoding_type == "horses" else pSp(opts)
        net.eval()
        net.cuda()
        print('Done!')

        # define some arguments
        opts.n_iters_per_batch = num_iterations
        opts.resize_outputs = False

        # define transforms
        image_transforms = self.cars_transforms if encoding_type == "cars" else self.default_transforms

        # if working on faces load and align the image
        if encoding_type == "faces":
            print('Aligning image...')
            input_image = self.run_alignment(str(input))
            print('Done!')
        # otherwise simply load the image
        else:
            input_image = Image.open(str(input)).convert("RGB")

        # preprocess image
        transformed_image = image_transforms(input_image)

        # run inference
        print("Running inference...")
        with torch.no_grad():
            start = time.time()
            avg_image = self.get_avg_image(net, encoding_type)
            result_batch, result_latents = run_on_batch(
                transformed_image.unsqueeze(0).cuda(), net, opts, avg_image)
            total_time = time.time() - start
        print(f"Finished inference in {total_time} seconds!")

        # post-processing
        print("Preparing result...")
        # The car domain is identified as "cars" everywhere else in this
        # method; the previous check against "cars_encode" alone never matched
        # it, so car outputs fell through to the square size. Both spellings
        # are accepted to stay backward-compatible.
        is_cars = encoding_type in ("cars", "cars_encode")
        if is_cars:
            resize_amount = (512, 384)
        else:
            resize_amount = (opts.output_size, opts.output_size)
        res = self.get_final_output(result_batch, resize_amount,
                                    display_intermediate_results, opts)

        # write the result to a fresh temporary directory and return its path
        out_path = Path(tempfile.mkdtemp()) / "output.png"
        imageio.imwrite(str(out_path), res)
        return out_path
Esempio n. 2
0
    def __init__(self, opts, prev_train_checkpoint=None):
        """Build the network, losses, optimizers, data loaders and logging.

        Args:
            opts: Training options namespace. It is stored as ``self.opts``
                and mutated here: ``device`` is set, and ``save_interval``
                defaults to ``max_steps`` when it is ``None``.
            prev_train_checkpoint: Optional previous training checkpoint. When
                given, it is passed to ``self.load_from_train_checkpoint``
                after everything else has been constructed.

        Raises:
            ValueError: If both ``id_lambda`` and ``moco_lambda`` are positive.
        """
        self.opts = opts

        self.global_step = 0

        self.device = 'cuda:0'
        self.opts.device = self.device

        # Initialize network
        self.net = e4e(self.opts).to(self.device)

        # Estimate latent_avg via dense sampling if latent_avg is not available
        if self.net.latent_avg is None:
            self.net.latent_avg = self.net.decoder.mean_latent(
                int(1e5))[0].detach()

        # get the image corresponding to the latent average
        self.avg_image = self.net(self.net.latent_avg.unsqueeze(0),
                                  input_code=True,
                                  randomize_noise=False,
                                  return_latents=False,
                                  average_code=True)[0]
        self.avg_image = self.avg_image.to(self.device).float().detach()
        if self.opts.dataset_type == "cars_encode":
            self.avg_image = self.avg_image[:, 32:224, :]
        # The experiment directory is only created implicitly further down
        # (via its 'logs' / 'checkpoints' sub-directories), so make sure it
        # exists before writing the average image into it.
        os.makedirs(self.opts.exp_dir, exist_ok=True)
        common.tensor2im(self.avg_image).save(
            os.path.join(self.opts.exp_dir, 'avg_image.jpg'))

        # Initialize loss
        if self.opts.id_lambda > 0 and self.opts.moco_lambda > 0:
            raise ValueError(
                'Both ID and MoCo loss have lambdas > 0! Please select only one to have non-zero lambda!'
            )
        self.mse_loss = nn.MSELoss().to(self.device).eval()
        # The optional losses are only instantiated when their weight is
        # positive, so the corresponding attributes may not exist.
        if self.opts.lpips_lambda > 0:
            self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
        if self.opts.id_lambda > 0:
            self.id_loss = id_loss.IDLoss().to(self.device).eval()
        if self.opts.moco_lambda > 0:
            self.moco_loss = moco_loss.MocoLoss()

        # Initialize optimizer
        self.optimizer = self.configure_optimizers()

        # Initialize discriminator (with its own optimizer and latent pools)
        if self.opts.w_discriminator_lambda > 0:
            self.discriminator = LatentCodesDiscriminator(512,
                                                          4).to(self.device)
            self.discriminator_optimizer = torch.optim.Adam(
                list(self.discriminator.parameters()),
                lr=opts.w_discriminator_lr)
            self.real_w_pool = LatentCodesPool(self.opts.w_pool_size)
            self.fake_w_pool = LatentCodesPool(self.opts.w_pool_size)

        # Initialize dataset
        self.train_dataset, self.test_dataset = self.configure_datasets()
        self.train_dataloader = DataLoader(self.train_dataset,
                                           batch_size=self.opts.batch_size,
                                           shuffle=True,
                                           num_workers=int(self.opts.workers),
                                           drop_last=True)
        self.test_dataloader = DataLoader(self.test_dataset,
                                          batch_size=self.opts.test_batch_size,
                                          shuffle=False,
                                          num_workers=int(
                                              self.opts.test_workers),
                                          drop_last=True)

        # Initialize logger
        log_dir = os.path.join(opts.exp_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        self.logger = SummaryWriter(log_dir=log_dir)

        # Initialize checkpoint dir
        self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.best_val_loss = None
        if self.opts.save_interval is None:
            self.opts.save_interval = self.opts.max_steps

        # Restore the previous training state last, once every component it
        # may need to populate has been created.
        if prev_train_checkpoint is not None:
            self.load_from_train_checkpoint(prev_train_checkpoint)
def run():
    """Run two-model (encoder bootstrapping) inference over a dataset.

    Two checkpoints are loaded: the first network is used to obtain the
    average image, and both networks are handed to ``run_on_batch``. For each
    input, the per-step results followed by the input image are saved side by
    side under ``<exp_dir>/inference_results``. The mean and standard deviation
    of the per-batch runtime are printed and written to ``<exp_dir>/stats.txt``.
    """
    test_opts = TestOptions().parse()

    results_dir = os.path.join(test_opts.exp_dir, 'inference_results')
    os.makedirs(results_dir, exist_ok=True)

    def load_network(checkpoint_path):
        # Training options come from the checkpoint; test-time options
        # override them and the checkpoint path is recorded on top.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        merged = checkpoint['opts']
        merged.update(vars(test_opts))
        merged['checkpoint_path'] = checkpoint_path
        merged = Namespace(**merged)
        if merged.encoder_type in ENCODER_TYPES['pSp']:
            network = pSp(merged)
        else:
            network = e4e(merged)
        network.eval()
        network.cuda()
        return network, merged

    # model used for initializing encoder bootstrapping
    net1, _ = load_network(test_opts.model_1_checkpoint_path)
    # model used for translating the input image after initialization;
    # its options drive the rest of the script
    net2, opts = load_network(test_opts.model_2_checkpoint_path)

    print('Loading dataset for {}'.format(opts.dataset_type))
    dataset_config = data_configs.DATASETS[opts.dataset_type]
    inference_transform = dataset_config['transforms'](
        opts).get_transforms()['transform_inference']
    dataset = InferenceDataset(root=opts.data_path,
                               transform=inference_transform,
                               opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    # image corresponding to the latent average of the first network
    avg_image = get_average_image(net1, opts)

    if opts.resize_outputs:
        resize_amount = (256, 256)
    else:
        resize_amount = (opts.output_size, opts.output_size)

    n_done = 0
    batch_times = []
    for input_batch in tqdm(dataloader):
        if n_done >= opts.n_images:
            break

        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            started = time.time()
            result_batch = run_on_batch(input_cuda, net1, net2, opts,
                                        avg_image)
            finished = time.time()
            batch_times.append(finished - started)

        for i in range(input_batch.shape[0]):
            steps = [
                tensor2im(result_batch[i][step_idx])
                for step_idx in range(opts.n_iters_per_batch + 1)
            ]
            im_path = dataset.paths[n_done]
            input_im = tensor2im(input_batch[i])

            # step-by-step results side by side, input image last
            panels = [np.array(steps[0].resize(resize_amount))]
            panels.extend(
                np.array(step.resize(resize_amount)) for step in steps[1:])
            panels.append(np.array(input_im.resize(resize_amount)))
            res = np.concatenate(panels, axis=1)

            save_path = os.path.join(results_dir, os.path.basename(im_path))
            Image.fromarray(res).save(save_path)

            n_done += 1

    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(batch_times),
                                                 np.std(batch_times))
    print(result_str)

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    with open(stats_path, 'w') as f:
        f.write(result_str)
Esempio n. 4
0
def run():
    """Run iterative inference over a dataset and save coupled results.

    A single checkpoint is loaded (pSp or e4e depending on ``encoder_type``).
    The average image is rendered and saved as ``<exp_dir>/avg_image.jpg``.
    For every input, the per-step results followed by the input image are
    saved side by side under ``<exp_dir>/inference_coupled``. The mean and
    standard deviation of the per-batch runtime are printed and written to
    ``<exp_dir>/stats.txt``.
    """
    test_opts = TestOptions().parse()

    coupled_dir = os.path.join(test_opts.exp_dir, 'inference_coupled')
    os.makedirs(coupled_dir, exist_ok=True)

    # options used during training, overridden by the test options
    checkpoint = torch.load(test_opts.checkpoint_path, map_location='cpu')
    merged_opts = checkpoint['opts']
    merged_opts.update(vars(test_opts))
    opts = Namespace(**merged_opts)

    net = pSp(opts) if opts.encoder_type in ENCODER_TYPES['pSp'] else e4e(opts)
    net.eval()
    net.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    dataset_config = data_configs.DATASETS[opts.dataset_type]
    inference_transform = dataset_config['transforms'](
        opts).get_transforms()['transform_inference']
    dataset = InferenceDataset(root=opts.data_path,
                               transform=inference_transform,
                               opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    is_cars = opts.dataset_type == "cars_encode"

    # image corresponding to the latent average
    avg_image = net(net.latent_avg.unsqueeze(0),
                    input_code=True,
                    randomize_noise=False,
                    return_latents=False,
                    average_code=True)[0]
    avg_image = avg_image.to('cuda').float().detach()
    if is_cars:
        # the car domain keeps only rows 32..223 of the average image
        avg_image = avg_image[:, 32:224, :]
    tensor2im(avg_image).save(os.path.join(opts.exp_dir, 'avg_image.jpg'))

    # output size: cars use fixed non-square sizes, other domains are square
    if opts.resize_outputs:
        resize_amount = (256, 192) if is_cars else (256, 256)
    else:
        resize_amount = (512, 384) if is_cars else (opts.output_size,
                                                    opts.output_size)

    n_done = 0
    batch_times = []
    for input_batch in tqdm(dataloader):
        if n_done >= opts.n_images:
            break

        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            started = time.time()
            result_batch, result_latents = run_on_batch(
                input_cuda, net, opts, avg_image)
            finished = time.time()
            batch_times.append(finished - started)

        for i in range(input_batch.shape[0]):
            steps = [
                tensor2im(result_batch[i][step_idx])
                for step_idx in range(opts.n_iters_per_batch)
            ]
            im_path = dataset.paths[n_done]
            input_im = tensor2im(input_batch[i])

            # step-by-step results side by side, input image last
            panels = [np.array(steps[0].resize(resize_amount))]
            panels.extend(
                np.array(step.resize(resize_amount)) for step in steps[1:])
            panels.append(np.array(input_im.resize(resize_amount)))
            res = np.concatenate(panels, axis=1)

            save_path = os.path.join(coupled_dir, os.path.basename(im_path))
            Image.fromarray(res).save(save_path)

            n_done += 1

    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(batch_times),
                                                 np.std(batch_times))
    print(result_str)

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    with open(stats_path, 'w') as f:
        f.write(result_str)
Esempio n. 5
0
def project(image_path: str, output_path: str, network: str,
            NUM_OUTPUT_IMAGES: int):
    """Invert a single image with a pretrained ReStyle model.

    The checkpoint at ``network`` is loaded (a download is attempted first if
    the file is missing or smaller than 1,000,000 bytes), the image is resized
    to 256x256 and normalized, and ``run_on_batch`` is run for
    ``NUM_OUTPUT_IMAGES`` iterations.

    Args:
        image_path: Path of the input image. It is opened as RGB and used
            as-is; no face alignment is performed here.
        output_path: Currently unused; this function does not write the
            results to disk.
        network: Path to the ReStyle checkpoint file.
        NUM_OUTPUT_IMAGES: Stored as ``n_iters_per_batch``; one result image
            is produced per iteration.

    Returns:
        A ``(res, result_images)`` tuple. ``res`` is a PIL image with every
        step's result side by side followed by the input image, and
        ``result_images`` is the list of per-step images.

    Raises:
        ValueError: If the checkpoint is still missing or too small after the
            download attempt.
    """
    # Only the FFHQ face model is configured below. Other experiment types
    # named in the original code: 'cars_encode', 'church_encode',
    # 'horse_encode', 'afhq_wild_encode', 'toonify'.
    experiment_type = 'ffhq_encode'

    CODE_DIR = 'restyle-encoder'
    # Files below this size are treated as failed/incomplete downloads.
    MIN_MODEL_BYTES = 1000000

    def get_download_model_command(file_id, file_name):
        """Return a complete wget shell command that saves the model under
        ``<parent of cwd>/<CODE_DIR>/pretrained_models`` (created if missing)."""
        current_directory = os.getcwd()
        save_path = os.path.join(os.path.dirname(current_directory), CODE_DIR,
                                 "pretrained_models")
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(
            FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
        return url

    def model_is_present(path):
        """Return True when ``path`` exists and is at least MIN_MODEL_BYTES."""
        return os.path.exists(path) and os.path.getsize(
            path) >= MIN_MODEL_BYTES

    MODEL_PATHS = {
        "ffhq_encode": {
            "id": "1sw6I2lRIB0MpuJkpc8F5BJiSZrc0hjfE",
            "name": "restyle_psp_ffhq_encode.pt"
        },
        # "church_encode": {"id": "1bcxx7mw-1z7dzbJI_z7oGpWG1oQAvMaD", "name": "restyle_psp_church_encode.pt"},
        # "horse_encode": {"id": "19_sUpTYtJmhSAolKLm3VgI-ptYqd-hgY", "name": "restyle_e4e_horse_encode.pt"},
        # "afhq_wild_encode": {"id": "1GyFXVTNDUw3IIGHmGS71ChhJ1Rmslhk7", "name": "restyle_psp_afhq_wild_encode.pt"},
        # "toonify": {"id": "1GtudVDig59d4HJ_8bGEniz5huaTSGO_0", "name": "restyle_psp_toonify.pt"}
    }

    path = MODEL_PATHS[experiment_type]
    download_command = get_download_model_command(file_id=path["id"],
                                                  file_name=path["name"])

    EXPERIMENT_DATA_ARGS = {
        "ffhq_encode": {
            # caller-supplied checkpoint path
            "model_path": network,
            "transform": transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])
        },
    }

    EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]
    model_path = EXPERIMENT_ARGS['model_path']

    if not model_is_present(model_path):
        print(f'Downloading ReStyle model for {experiment_type}...')
        # download_command is already a full `wget ...` invocation, so it is
        # executed as-is rather than being prefixed with another `wget`.
        # NOTE(review): the command saves into the pretrained_models directory
        # built above, which is not necessarily where `network` points --
        # confirm the two locations agree.
        os.system(download_command)
        # if google drive receives too many requests, we'll reach the quota
        # limit and be unable to download the model; a missing file is
        # reported the same way as a truncated one
        if not model_is_present(model_path):
            raise ValueError(
                "Pretrained model was unable to be downloaded correctly!")
        print('Done.')
    else:
        print(f'ReStyle model for {experiment_type} already exists!')

    def get_avg_image(net):
        """Render the image corresponding to the network's average latent."""
        avg_image = net(net.latent_avg.unsqueeze(0),
                        input_code=True,
                        randomize_noise=False,
                        return_latents=False,
                        average_code=True)[0]
        avg_image = avg_image.to('cuda').float().detach()
        if experiment_type == "cars_encode":
            avg_image = avg_image[:, 32:224, :]
        return avg_image

    from utils.inference_utils import run_on_batch

    def get_coupled_results(result_batch, transformed_image):
        """Build the side-by-side strip: every step's output from left to
        right, with the input image on the far right.

        Reads ``opts`` and ``resize_amount`` from the enclosing scope; both
        are assigned further down, before this helper is called.
        """
        result_tensors = result_batch[0]  # there's one image in our batch
        result_images = [
            tensor2im(result_tensors[iter_idx])
            for iter_idx in range(opts.n_iters_per_batch)
        ]
        input_im = tensor2im(transformed_image)
        res = np.array(result_images[0].resize(resize_amount))
        for result in result_images[1:]:
            res = np.concatenate(
                [res, np.array(result.resize(resize_amount))], axis=1)
        res = np.concatenate([res, input_im.resize(resize_amount)], axis=1)
        res = Image.fromarray(res)
        return res, result_images

    time_before = time.time()

    # LOAD PRETRAINED MODEL
    ckpt = torch.load(model_path, map_location='cpu')

    opts = ckpt['opts']

    # number of output images
    opts['n_iters_per_batch'] = NUM_OUTPUT_IMAGES
    opts['resize_outputs'] = False  # generate outputs at full resolution

    # update the training options
    opts['checkpoint_path'] = model_path

    opts = Namespace(**opts)

    # get_coupled_results needs this; it was previously never assigned inside
    # this function (the assignment was commented out).
    resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size,
                                                            opts.output_size)

    if experiment_type == 'horse_encode':
        net = e4e(opts)
    else:
        net = pSp(opts)

    net.eval()
    net.cuda()
    print('Model successfully loaded!')

    time_after_loading = time.time()
    print('Time to load model took {:.4f} seconds.'.format(time_after_loading -
                                                           time_before))

    # LOAD INPUT
    # NOTE: the image is used without alignment or pre-resizing; the
    # transform below resizes it to 256x256.
    original_image = Image.open(image_path).convert("RGB")
    input_image = original_image

    # PERFORM INFERENCE
    img_transforms = EXPERIMENT_ARGS['transform']
    transformed_image = img_transforms(input_image)

    with torch.no_grad():
        avg_image = get_avg_image(net)
        tic = time.time()
        result_batch, result_latents = run_on_batch(
            transformed_image.unsqueeze(0).cuda(), net, opts, avg_image)
        toc = time.time()
        print('Inference took {:.4f} seconds.'.format(toc - tic))

    # ASSEMBLE RESULT
    res, result_images = get_coupled_results(result_batch, transformed_image)

    time_after = time.time()
    print(
        'Time to load model, perform projection, and save out the results took {:.4f} seconds.'
        .format(time_after - time_before))

    return res, result_images
def run():
    """
    This script can be used to perform inversion and editing. Please note that this script supports editing using
    only the ReStyle-e4e model and currently supports editing using three edit directions found using InterFaceGAN
    (age, smile, and pose) on the faces domain.
    For performing the edits please provide the arguments `--edit_directions` and `--factor_ranges`. For example,
    setting these values to be `--edit_directions=age,smile,pose` and `--factor_ranges=5,5,5` will use a lambda range
    between -5 and 5 for each of the attributes. These should be comma-separated lists of the same length. You may
    get better results by playing around with the factor ranges for each edit.

    Per-edit results are saved under `<exp_dir>/editing_results/<edit_name>/`. When all factor ranges are equal,
    the per-edit rows are additionally stacked vertically and saved under `<exp_dir>/editing_coupled/`. The mean and
    standard deviation of the per-batch runtime are printed and written to `<exp_dir>/stats.txt`.

    Raises:
        ValueError: If the dataset type is not "ffhq_encode", or if the number of edit directions differs from the
            number of factor ranges.
    """
    test_opts = TestOptions().parse()

    out_path_results = os.path.join(test_opts.exp_dir, 'editing_results')
    out_path_coupled = os.path.join(test_opts.exp_dir, 'editing_coupled')

    os.makedirs(out_path_results, exist_ok=True)
    os.makedirs(out_path_coupled, exist_ok=True)

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts = Namespace(**opts)
    net = e4e(opts)
    net.eval()
    net.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    if opts.dataset_type != "ffhq_encode":
        raise ValueError(
            "Editing script only supports edits on the faces domain!")
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(
        root=opts.data_path,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    latent_editor = LatentEditor(net.decoder)
    # both options arrive as comma-separated strings
    opts.edit_directions = opts.edit_directions.split(',')
    opts.factor_ranges = [
        int(factor) for factor in opts.factor_ranges.split(',')
    ]
    if len(opts.edit_directions) != len(opts.factor_ranges):
        raise ValueError(
            "Invalid edit directions and factor ranges. Please provide a single factor range for each "
            f"edit direction. Given: {opts.edit_directions} and {opts.factor_ranges}"
        )

    avg_image = get_average_image(net, opts)

    # The output size does not change between batches, so compute it once.
    if opts.resize_outputs:
        resize_amount = (256, 256)
    else:
        resize_amount = (opts.output_size, opts.output_size)

    # The coupled image stacks the per-edit rows vertically, which is only
    # done when every edit uses the same factor range.
    all_ranges_equal = opts.factor_ranges.count(
        opts.factor_ranges[0]) == len(opts.factor_ranges)

    global_i = 0
    global_time = []
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break
        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            tic = time.time()
            result_batch = edit_batch(input_cuda, net, avg_image,
                                      latent_editor, opts)
            toc = time.time()
            global_time.append(toc - tic)

        for i in range(input_batch.shape[0]):

            im_path = dataset.paths[global_i]
            results = result_batch[i]

            # 'inversion' is removed so that only the edits remain in results
            inversion = results.pop('inversion')
            input_im = tensor2im(input_batch[i])

            all_edit_results = []
            for edit_name, edit_res in results.items():
                # row layout: input image, inversion, then each edited result
                res = np.array(
                    input_im.resize(resize_amount))  # set the input image
                res = np.concatenate(
                    [res, np.array(inversion.resize(resize_amount))],
                    axis=1)  # set the inversion
                for result in edit_res:
                    res = np.concatenate(
                        [res, np.array(result.resize(resize_amount))], axis=1)
                res_im = Image.fromarray(res)
                all_edit_results.append(res_im)

                edit_save_dir = os.path.join(out_path_results, edit_name)
                os.makedirs(edit_save_dir, exist_ok=True)
                res_im.save(
                    os.path.join(edit_save_dir, os.path.basename(im_path)))

            # save final concatenated result if all factor ranges are equal
            if all_ranges_equal:
                coupled_res = np.concatenate(all_edit_results, axis=0)
                im_save_path = os.path.join(out_path_coupled,
                                            os.path.basename(im_path))
                Image.fromarray(coupled_res).save(im_save_path)

            global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time),
                                                 np.std(global_time))
    print(result_str)

    with open(stats_path, 'w') as f:
        f.write(result_str)