Esempio n. 1
0
def syn_new_segmentations(args, params, model):
    """Synthesize segmentations and the real images conditioned on them.

    Currently disabled: the function raises before doing any work. The code
    below the ``raise`` is unreachable and is kept only as an outline of the
    intended procedure until the refactoring is done.

    :param args: run arguments; the outline reads ``cond_mode``, ``model``,
        ``last_optim_step`` and ``trials``.
    :param params: configuration dict; the outline reads ``channels``,
        ``img_size``, ``n_block`` and ``samples_path``.
    :param model: object whose ``reverse`` method the outline calls.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError('Needs code refactoring')

    # ------------------------------------------------------------------
    # Unreachable outline (kept for the pending refactoring).
    # ------------------------------------------------------------------
    temperatures = [1.0, 0.7, 0.5, 0.3, 0.1, 0.0]
    n_samples = 2
    z_shapes = calc_z_shapes(params['channels'], params['img_size'],
                             params['n_block'])
    base_path = params['samples_path']['segment'][args.cond_mode][args.model]
    path = base_path + f'/syn_segs/optim_step={args.last_optim_step}'

    for trial in range(args.trials):
        trial_path = f'{path}/i={trial}'
        helper.make_dir_if_not_exists(trial_path)

        for temp in temperatures:
            # sample segmentations first, then real images conditioned on them
            syn_segmentations = model.reverse(
                z_a_samples=sample_z(z_shapes, n_samples, temp, device))
            syn_reals = model.reverse(
                x_a=syn_segmentations,
                z_b_samples=sample_z(z_shapes, n_samples, temp, device))

            grid = torch.cat([syn_segmentations, syn_reals], dim=0)
            utils.save_image(grid,
                             f'{trial_path}/temp={temp}.png',
                             nrow=n_samples)

            print(f'Temp={temp}: done')
        print(f'Trial={trial}: done')
Esempio n. 2
0
def diverse_samples(args, params):
    r"""
    Generate several samples for one Cityscapes training image at a fixed temperature.

    The segmentation and ground-truth images are copied into the output folder
    as 'segment.png' and 'real.png', the checkpoint for ``args.last_optim_step``
    is loaded, and ten samples are written as 'sample_1.png' ... 'sample_10.png',
    two per call to ``experiments.take_multiple_samples``.

    command:
    python3 main.py --local --run diverse --image_name strasbourg_000001_061472 --temp 0.9 --model improved_so_large --last_optim_step 276000 \
                    --img_size 512 1024 --dataset cityscapes --direction label2photo \
                    --n_block 4 --n_flow 10 10 10 10 --do_lu --reg_factor 0.0001 --grad_checkpoint

    :param args: run arguments; reads last_optim_step, temp, n_block and image_name.
    :param params: configuration forwarded to the path and model helpers.
    :return: None
    """
    optim_step = args.last_optim_step
    temperature = args.temp
    n_blocks = args.n_block
    image_name = args.image_name

    n_samples = 10         # total number of samples written
    samples_per_call = 2   # samples produced by each take_multiple_samples call
    img_size = [512, 1024]

    # the city is the first underscore-separated token of the image name
    city = image_name.split("_")[0]
    image_cond_path = f'/local_storage/datasets/moein/cityscapes/gtFine_trainvaltest/gtFine/train/{city}/{image_name}_gtFine_color.png'
    image_gt_path = f'/local_storage/datasets/moein/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/train/{city}/{image_name}_leftImg8bit.png'

    model_paths = helper.compute_paths(args, params)
    output_folder = os.path.join(model_paths['diverse_path'], image_name, f'temp={temperature}')
    print(f'out folder: {output_folder}')

    helper.make_dir_if_not_exists(output_folder)

    # keep the condition and the ground truth next to the generated samples
    for source, target_name in ((image_cond_path, 'segment.png'),
                                (image_gt_path, 'real.png')):
        shutil.copyfile(source, os.path.join(output_folder, target_name))
    print('Copy segment and real images: done')

    model = models.init_model(args, params)
    model = helper.load_checkpoint(model_paths['checkpoints_path'], optim_step, model, optimizer=None, resume_train=False)[0]

    # sample file names are 1-based: sample_1.png ... sample_10.png
    for first in range(1, n_samples + 1, samples_per_call):
        paths_list = [
            os.path.join(output_folder, f'sample_{first + offset}.png')
            for offset in range(samples_per_call)
        ]
        experiments.take_multiple_samples(model, n_blocks, temperature, samples_per_call, img_size, image_cond_path, paths_list)
def download_data(params):
    """Download the twelve image archives into ``params['download_path']``.

    The target folder is created through ``helper.make_dir_if_not_exists`` and
    the archives are stored as 'images_00.tar.gz' ... 'images_11.tar.gz',
    numbered in the order of the URL list below.

    :param params: configuration dict; reads 'download_path'.
    :return: None
    """
    # one URL per archive
    archive_urls = (
        'https://nihcc.box.com/shared/static/vfk49d74nhbxq3nqjg0900w5nvkorp5c.gz',
        'https://nihcc.box.com/shared/static/i28rlmbvmfjbl8p2n3ril0pptcmcu9d1.gz',
        'https://nihcc.box.com/shared/static/f1t00wrtdk94satdfb9olcolqx20z2jp.gz',
        'https://nihcc.box.com/shared/static/0aowwzs5lhjrceb3qp67ahp0rd1l1etg.gz',
        'https://nihcc.box.com/shared/static/v5e3goj22zr6h8tzualxfsqlqaygfbsn.gz',
        'https://nihcc.box.com/shared/static/asi7ikud9jwnkrnkj99jnpfkjdes7l6l.gz',
        'https://nihcc.box.com/shared/static/jn1b4mw4n6lnh74ovmcjb8y48h8xj07n.gz',
        'https://nihcc.box.com/shared/static/tvpxmn7qyrgl0w8wfh9kqfjskv6nmm1j.gz',
        'https://nihcc.box.com/shared/static/upyy3ml7qdumlgk2rfcvlb9k6gvqq2pj.gz',
        'https://nihcc.box.com/shared/static/l6nilvfa9cg3s28tqv1qc1olm3gnz54p.gz',
        'https://nihcc.box.com/shared/static/hhq8fkdgvcari67vfhs7ppg2w6ni4jze.gz',
        'https://nihcc.box.com/shared/static/ioqwiy20ihqwyr8pf4c24eazhh281pbu.gz',
    )

    target_dir = params['download_path']
    helper.make_dir_if_not_exists(target_dir)

    for part_no, url in enumerate(archive_urls):
        destination = f'{target_dir}/images_{part_no:02d}.tar.gz'

        print(f'In [download_data]: downloading at "{destination}"...')
        urllib.request.urlretrieve(url, destination)  # fetch the archive
    print("In [download_data]: download complete. \n")
Esempio n. 4
0
def evaluate_model(args, params):
    """Run the saved network over the test loader and plot ROC curves.

    Predictions and labels of every test batch are gathered and handed to
    ``plotting.plot_roc`` together with ``pathology_names`` and the results
    folder.

    :param args: run arguments; reads ``model`` and ``epoch``.
    :param params: configuration dict; reads 'batch_size' and 'num_workers'
        and is forwarded to the loader, path and network helpers.
    :return: None
    """
    batch_size = params['batch_size']
    loader_params = {
        'batch_size': batch_size,
        'shuffle': False,
        'num_workers': params['num_workers']
    }

    _, _, test_loader = data_handler.init_data_loaders(params, loader_params)

    # compute the paths once; both the checkpoint and the results folder come from here
    paths = helper.compute_paths(args, params)
    path_to_load = paths['save_path']

    net = networks.init_unified_net(args.model, params)
    net, _ = helper.load_model(path_to_load,
                               args.epoch,
                               net,
                               optimizer=None,
                               resume_train=False)
    # NOTE(review): the network is never switched to eval mode here; confirm
    # that helper.load_model(resume_train=False) takes care of that.

    print(
        f'In [evaluate_model]: loading the model done from: "{path_to_load}"')
    print(
        f'In [evaluate_model]: starting evaluation with {len(test_loader)} batches'
    )

    # Collect per-batch arrays and concatenate once at the end: appending to a
    # growing array inside the loop re-copies everything gathered so far.
    predicted_chunks = []
    label_chunks = []

    with torch.no_grad():
        for i_batch, batch in enumerate(test_loader):
            img_batch = batch['image'].to(device).float()
            label_batch = batch['label'].to(device).float()
            pred = net(img_batch)

            predicted_chunks.append(pred.cpu().detach().numpy())
            label_chunks.append(label_batch.cpu().detach().numpy())

            if i_batch % 50 == 0:
                print(
                    f'In [evaluate_model]: prediction done for batch {i_batch}'
                )

    if predicted_chunks:
        total_predicted = np.concatenate(predicted_chunks, axis=0)
        total_labels = np.concatenate(label_chunks, axis=0)
    else:
        # empty loader: keep the original all-zero placeholders (14 columns)
        total_predicted = np.zeros((batch_size, 14))
        total_labels = np.zeros((batch_size, 14))

    results_path = paths['results_path']
    helper.make_dir_if_not_exists(results_path)

    # plot roc
    print('In [evaluate_model]: starting plotting ROC...')
    plotting.plot_roc(total_predicted, total_labels, pathology_names,
                      results_path)
Esempio n. 5
0
def retrieve_data(sess, hps, args, params):
    """Build the train/val iterators, the first batch and the condition list.

    With ``args.exp`` set (inference), conditions are all images of the 'val'
    partition for the requested direction. Otherwise (training) the fixed
    condition paths from ``globals`` are used and their visual grid is written
    to 'conditions.png' in the samples folder if that file does not exist yet.

    :param sess: session passed to ``make_first_batch``.
    :param hps: hyper-parameters; reads ``batch_size``.
    :param args: run arguments; reads ``dataset``, ``direction`` and ``exp``.
    :param params: configuration dict; reads 'tfrecords_file' and 'data_folder'.
    :return: (train_iter, valid_iter, first_batch, conditions), where
        conditions is a list of dicts with keys 'image_path' and 'image_array'.
    """
    train_tfrecords_file = params['tfrecords_file']['train']
    val_tfrecords_file = params['tfrecords_file']['val']

    def _make_iterator(tfrecords_file, is_training):
        # both readers share every argument except the file and the mode flag
        return data_io.read_tfrecords(tfrecords_file,
                                      args.dataset,
                                      args.direction,
                                      hps.batch_size,
                                      is_training=is_training)

    train_iter = _make_iterator(train_tfrecords_file, True)
    valid_iter = _make_iterator(val_tfrecords_file, False)

    first_batch = make_first_batch(sess, train_iter)

    direction = args.direction
    assert args.dataset == 'cityscapes'  # get from helper

    if args.exp:
        # inference: condition on every validation image of the relevant type
        image_type = 'segment' if direction == 'label2photo' else 'real'
        conditions_paths = helper.get_all_data_folder_images(
            path=params['data_folder'][image_type],
            partition='val',
            image_type=image_type)
        cond_list = create_conditions(conditions_paths, also_visual_grid=False)
        print(
            f'In [retrieve_data]: conditions_paths for inference is of len: {len(conditions_paths)}'
        )
    else:
        # training: a fixed set of conditions
        if direction == 'label2photo':
            conditions_paths = globals.seg_conds_abs_paths
        else:
            conditions_paths = globals.real_conds_abs_path
        cond_list, visual_grid = create_conditions(conditions_paths)

        # write the condition grid once
        samples_path = helper.compute_paths(args, params)['samples_path']
        helper.make_dir_if_not_exists(samples_path)
        grid_file = os.path.join(samples_path, 'conditions.png')

        if not os.path.isfile(grid_file):
            Image.fromarray(visual_grid).save(grid_file)
            print(f'In [retrieve_data]: saved conditions to: "{grid_file}"')

    # pair every condition array with the path it was read from
    conditions = [
        {'image_path': conditions_paths[idx], 'image_array': cond_array}
        for idx, cond_array in enumerate(cond_list)
    ]
    return train_iter, valid_iter, first_batch, conditions
Esempio n. 6
0
def train(args, params, model, optimizer, tracker=None):
    """Train ``model`` for ``params['max_epochs']`` epochs.

    Every batch: forward pass, weighted loss via ``loss.compute_wcel``, loss
    printed and tracked, backward pass, optimizer step. After every epoch the
    model is saved and the validation loss is computed, printed and tracked.

    :param args: run arguments, forwarded to ``track`` and ``helper.compute_paths``.
    :param params: configuration dict; reads 'batch_size', 'shuffle',
        'num_workers' and 'max_epochs'.
    :param model: network being trained.
    :param optimizer: optimizer stepping the model parameters.
    :param tracker: optional tracker object forwarded to ``track``.
    :return: None
    """
    train_loader, val_loader, _ = data_handler.init_data_loaders(
        params,
        {
            'batch_size': params['batch_size'],
            'shuffle': params['shuffle'],
            'num_workers': params['num_workers']
        })

    max_epochs = params['max_epochs']
    epoch = 0

    while epoch < max_epochs:
        banner = "=" * 40
        print(f'{banner} In epoch: {epoch} {banner}')
        print(f'Training on {len(train_loader)} batches...')

        # one pass over the training data
        for step_idx, sample in enumerate(train_loader):
            # cast to float tensors on the target device (labels arrive as Long)
            images = sample['image'].to(device).float()
            targets = sample['label'].to(device).float()

            # reset the gradients accumulated in the previous step
            optimizer.zero_grad()

            # forward pass and loss
            train_loss = loss.compute_wcel(fx=model(images), labels=targets)
            rounded_loss = round(train_loss.item(), 3)
            print(f'Batch: {step_idx} - train loss: {rounded_loss}')

            # report the metric on every iteration
            track('train_loss', rounded_loss, args, tracker)

            # backward pass and parameter update
            train_loss.backward()
            optimizer.step()

        # checkpoint at the end of every epoch
        save_path = helper.compute_paths(args, params)['save_path']
        helper.make_dir_if_not_exists(save_path)
        helper.save_model(save_path, epoch, model, optimizer, train_loss)

        # validation loss for this epoch
        val_loss = loss.compute_val_loss(model, val_loader)
        print(f'In [train]: epoch={epoch} - val_loss = {val_loss}')
        track('val_loss', val_loss, args, tracker)
        epoch += 1
Esempio n. 7
0
def create_rev_cond(args, params, fixed_conds, also_save=True):
    """Build the batch of reverse conditions from side-by-side photo/map images.

    Every image is transformed once and split along its last axis at column
    ``w`` (the second entry of ``params['img_size']``): the left part is the
    photo, the right part is the map. For direction 'map2photo' the map is the
    condition and the photo the real image; otherwise the roles are swapped.

    :param args: run arguments; reads ``direction``.
    :param params: configuration dict; reads 'img_size' and 'data_folder'.
    :param fixed_conds: forwarded to ``read_imgs`` to select the images.
    :param also_save: when true, write 'conditions.png' and 'real.png' to the
        samples folder unless a file of that name is already there.
    :return: the stacked condition tensor (the real images are only saved).
    """
    trans = loader.init_transformers(params['img_size'])
    # only supports reading from train data now
    images = read_imgs(params['data_folder'],
                       split='train',
                       fixed_conds=fixed_conds)

    _, w = params['img_size']  # loop-invariant: only the width is needed

    conds_as_list = []
    real_as_list = []

    for image in images:
        # Open and transform each file once (it used to be decoded twice) and
        # release the file handle as soon as the tensor has been produced.
        with Image.open(image) as opened:
            both_halves = trans(opened)

        photo = both_halves[:, :, :w]
        the_map = both_halves[:, :, w:]

        if args.direction == 'map2photo':
            cond, real = the_map, photo
        else:
            cond, real = photo, the_map

        conds_as_list.append(cond)
        real_as_list.append(real)

    conds_as_tensor = torch.stack(conds_as_list, dim=0).to(device)
    real_as_tensor = torch.stack(real_as_list, dim=0).to(device)

    if also_save:
        save_path = helper.compute_paths(args, params)['samples_path']
        helper.make_dir_if_not_exists(save_path)

        # one directory listing serves both existence checks
        existing = set(os.listdir(save_path))

        if 'conditions.png' not in existing:
            utils.save_image(conds_as_tensor,
                             os.path.join(save_path, 'conditions.png'))
            print('In [create_rev_cond]: saved reverse conditions')

        if 'real.png' not in existing:
            utils.save_image(real_as_tensor,
                             os.path.join(save_path, 'real.png'))
            print('In [create_rev_cond]: saved real images')

    print(
        f'In [create_rev_cond]: returning cond_as_tensor of shape: {conds_as_tensor.shape}'
    )
    return conds_as_tensor
def extract_data(params):
    """Extract every gzip-compressed tar archive found in the download folder.

    :param params: configuration dict; reads 'download_path' (folder holding
        the archives) and 'extracted_path' (destination folder, created if
        missing through ``helper.make_dir_if_not_exists``).
    :return: None
    """
    archive_path, extract_path = params['download_path'], params[
        'extracted_path']
    helper.make_dir_if_not_exists(extract_path)

    for fname in os.listdir(archive_path):
        file_path = f'{archive_path}/{fname}'

        # The context manager closes the archive even if extraction fails.
        # NOTE(review): extractall trusts the member paths stored in the
        # archive; only use this on archives from a trusted source.
        with tarfile.open(file_path, 'r:gz') as tar:
            tar.extractall(path=extract_path)

        print(
            f'In [extract_data]: extracted "{archive_path}/{fname}" to "{extract_path}".'
        )
    print('In [extract_data]: all done.')
Esempio n. 9
0
def prepare_experiment(params, args, device, exp_name):
    """Create the output folder of an experiment and load the conditional model.

    :param params: configuration dict; reads 'checkpoints_path' and
        'samples_path' (their 'conditional' entries).
    :param args: run arguments; reads ``last_optim_step``.
    :param device: not used by this function.
    :param exp_name: name of the sub-folder created under the conditional
        samples path.
    :return: (model, save_path)
    """
    mode = 'conditional'  # experiments always use the conditional model

    checkpoint_pth = params['checkpoints_path'][mode]
    optim_step = args.last_optim_step

    save_path = params['samples_path'][mode] + f'/{exp_name}'
    make_dir_if_not_exists(save_path)

    # build a fresh model and restore its weights from the checkpoint
    fresh_model = init_glow(params)
    model, _, _ = load_checkpoint(checkpoint_pth, optim_step, fresh_model,
                                  None, resume_train=False)

    return model, save_path
Esempio n. 10
0
def sample_c_flow_conditional(args, params, model):
    """Sample real images conditioned on segmentations (currently disabled).

    The function raises before doing any work. The code below the ``raise`` is
    unreachable and is kept only as an outline of the intended procedure until
    the refactoring is done.

    :param args: run arguments; the outline reads ``cond_mode``, ``model``,
        ``last_optim_step`` and ``trials``.
    :param params: configuration dict; the outline reads 'samples_path',
        'n_samples', 'data_folder', 'img_size', 'channels' and 'n_block'.
    :param model: object whose ``reverse`` method the outline calls.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError('Needs code refactoring')

    # ------------------------------------------------------------------
    # Unreachable outline (kept for the pending refactoring).
    # ------------------------------------------------------------------
    temperatures = [1.0, 0.7, 0.5, 0.3, 0.1, 0.0]
    base_path = params['samples_path']['real'][args.cond_mode][args.model]
    trials_pth = base_path + f'/trials/optim_step={args.last_optim_step}'
    helper.make_dir_if_not_exists(trials_pth)

    segmentations, _, real_imgs = _create_cond(params['n_samples'],
                                               params['data_folder'],
                                               params['img_size'],
                                               device,
                                               save_path=trials_pth)

    z_shapes = calc_z_shapes(params['channels'], params['img_size'],
                             params['n_block'])
    # chunks of 5 images are easier to look at
    seg_splits = torch.split(segmentations, split_size_or_sections=5, dim=0)
    real_splits = torch.split(real_imgs, split_size_or_sections=5, dim=0)

    for i, seg_chunk in enumerate(seg_splits):
        print(
            f'====== Doing for the {i}th tensor in seg_splits and real_splits')
        real_chunk = real_splits[i]
        n_samples = seg_chunk.shape[0]

        for temp in temperatures:
            # grid rows: conditions, real images, then one row per trial
            all_imgs = torch.cat([seg_chunk.cpu().data, real_chunk.cpu().data],
                                 dim=0)

            for trial in range(args.trials):
                z_samples = sample_z(z_shapes, n_samples, temp, device)
                with torch.no_grad():
                    sampled_images = model.reverse(
                        x_a=seg_chunk, z_b_samples=z_samples).cpu().data
                    all_imgs = torch.cat([all_imgs, sampled_images], dim=0)
                print(f'Temp={temp} - Trial={trial}: done')

            # one grid image per temperature
            path = f'{trials_pth}/i={i}'
            helper.make_dir_if_not_exists(path)
            utils.save_image(all_imgs,
                             f'{path}/temp={temp}.png',
                             nrow=n_samples)
Esempio n. 11
0
def new_condition(img_list, params, args, device):
    """Re-render chosen images under each of the ten digit conditions.

    For every image number the latent vectors are obtained with a forward pass
    under the image's own label; they are then passed to ``model.reverse`` ten
    times, once per digit 0-9, and the ten results are saved as one grid image.

    :param img_list: iterable of image numbers passed to ``get_image``.
    :param params: configuration dict; reads 'checkpoints_path',
        'samples_path' and 'data_folder'.
    :param args: run arguments; reads ``last_optim_step``, ``img_size`` and
        ``dataset``.
    :param device: not used by this function.
    :return: None
    """
    mode = 'conditional'  # this experiment always uses the conditional model
    checkpoint_pth = params['checkpoints_path'][mode]
    optim_step = args.last_optim_step
    save_path = params['samples_path'][mode] + '/new_condition'
    make_dir_if_not_exists(save_path)

    # build a fresh model and restore its weights from the checkpoint
    fresh_model = init_glow(params)
    model, _, _ = load_checkpoint(checkpoint_pth, optim_step, fresh_model,
                                  None, resume_train=False)

    for img_num in img_list:
        img, label = get_image(img_num,
                               params['data_folder'],
                               args.img_size,
                               ret_type='batch')

        # latent vectors of the chosen image under its own label
        _, _, z_list = model(img, (args.dataset, label))

        all_sampled = []
        for digit in range(10):
            # same latent vectors, different digit as the condition
            sampled_img = model.reverse(z_list,
                                        reconstruct=True,
                                        coupling_conds=('mnist', digit, 1))
            # drop the leading batch axis before collecting the sample
            all_sampled.append(sampled_img.squeeze(dim=0))
            print(f'In [new_condition]: sample with digit={digit} done.')

        utils.save_image(all_sampled,
                         f'{save_path}/img={img_num}.png',
                         nrow=10)
        print(f'In [new_condition]: done for img_num {img_num}')
Esempio n. 12
0
def create_rev_cond(args, params, also_save=True):
    """Load the fixed condition images for ``args.direction`` as one batch.

    :param args: run arguments; reads ``direction``.
    :param params: configuration dict; reads 'img_size' and
        ``params['paths']['data_folder']``.
    :param also_save: when true, write the batch to 'conditions.png' in the
        samples folder unless a file of that name is already there.
    :return: the stacked condition tensor, moved to ``device``.
    """
    trans = util.init_transformer(params['img_size'])

    # file names of the fixed conditions for this direction
    names = util.fixed_conds[args.direction]

    conds = []
    for name in names:
        cond_image = Image.open(f"{params['paths']['data_folder']}/{name}")
        conds.append(trans(cond_image))

    conditions = torch.stack(conds, dim=0).to(device)  # (B, C, H, W)

    if not also_save:
        return conditions

    save_path = helper.compute_paths(args, params)['samples_path']
    helper.make_dir_if_not_exists(save_path)

    if 'conditions.png' not in os.listdir(save_path):
        utils.save_image(conditions, f'{save_path}/conditions.png')
        print('In [create_rev_cond]: saved reverse conditions')
    return conditions
Esempio n. 13
0
def create_tf_records(args, params):
    """Write the train and val tfrecords files for the Cityscapes dataset.

    For each split the segmentation folder ``<data_folder['segment']>/<split>``
    is passed to ``dual_glow.write_data_for_tf`` together with the target
    tfrecords file. The folder of each target file is created first, so the
    val file may live in a different folder than the train file.

    :param args: run arguments; reads ``dataset``.
    :param params: configuration dict; reads 'tfrecords_file' and 'data_folder'.
    :raises NotImplementedError: if ``args.dataset`` is not 'cityscapes'.
    :return: None
    """
    if args.dataset != 'cityscapes':
        raise NotImplementedError(
            f'tfrecords creation is not implemented for dataset: {args.dataset}')

    for split in ('train', 'val'):
        tfrecords_file = params['tfrecords_file'][split]
        gt_fine_path = os.path.join(params['data_folder']['segment'], split)

        # make sure the output folder exists for this split as well
        helper.make_dir_if_not_exists(os.path.split(tfrecords_file)[0])

        dual_glow.write_data_for_tf(tfrecords_file, gt_fine_path)
        print(
            f'In [create_tf_records]: creating tf_records for {split} data: done. Saved to: "{tfrecords_file}"'
        )
Esempio n. 14
0
def transfer(args, params):
    """Run one content-transfer experiment on a hard-coded image pair.

    The content image and its own segmentation (needed to extract z) are taken
    from ``pure_content``; the new segmentation is taken from ``pure_new_cond``.
    The output file path is handed to ``experiments.transfer_content``.

    :param args: run arguments, forwarded to ``experiments.transfer_content``.
    :param params: configuration, forwarded to ``experiments.transfer_content``.
    :return: None
    """
    content_basepath = '../data/cityscapes_complete_downsampled/all_reals'
    cond_basepath = '../data/cityscapes_complete_downsampled/all_segments'
    save_basepath = '../samples/content_transfer_local'

    # Pairs (content -> new condition) tried so far:
    #   jena_000011_000019   -> jena_000066_000019
    #   aachen_000028_000019 -> jena_000011_000019
    #   jena_000011_000019   -> aachen_000010_000019
    pure_content, pure_new_cond = 'aachen_000034_000019', 'bochum_000000_016260'

    content = f'{content_basepath}/{pure_content}.png'    # content image
    condition = f'{cond_basepath}/{pure_content}.png'     # its own condition, needed to extract z
    new_cond = f'{cond_basepath}/{pure_new_cond}.png'     # condition to transfer to

    helper.make_dir_if_not_exists(save_basepath)
    file_path = f'{save_basepath}/content={pure_content}_condition={pure_new_cond}.png'
    experiments.transfer_content(args, params, content, condition, new_cond, file_path)
Esempio n. 15
0
def sample_with_new_condition(args, params):
    """
    Combine the style of one original image with the segmentations of other images.

    The latent outputs of the original real image are obtained with a forward
    pass and then passed to ``model.reverse`` together with each new
    segmentation. There is no random sampling here, so temperature has no effect.
    :param args: run arguments; reads ``use_bmaps``.
    :param params: configuration forwarded to the model and data helpers.
    :return: None
    """
    suffix_len = len('_leftImg8bit.png')

    def _pure_name(image_path):
        # file name without its directory and without the '_leftImg8bit.png' suffix
        return image_path.split('/')[-1][:-suffix_len]

    model = models.init_and_load(args, params, run_mode='infer')

    # path of the original image (the one carrying the desired style)
    orig_real_name = new_cond_reals['orig_img']
    orig_pure_name = _pure_name(orig_real_name)
    print(f'In [sample_with_new_cond]: orig cond is: "{orig_pure_name}" \n')

    # ========= original segmentation, real image and boundary map: (1, C, H, W)
    orig_seg_batch, _, orig_real_batch, orig_bmap_batch = \
        data_handler._create_cond(params, fixed_conds=[orig_real_name], save_path=None)

    # drop the boundary map when it is not used
    if not args.use_bmaps:
        orig_bmap_batch = None

    # ========= every entry is a real image whose segmentation becomes the new condition
    for new_cond_name in new_cond_reals['cond_imgs']:
        new_cond_pure_name = _pure_name(new_cond_name)
        additional_info = {
            'orig_pure_name': orig_pure_name,
            'new_cond_pure_name': new_cond_pure_name,
            'exp_type': 'new_cond'
        }
        print(
            f'In [sample_with_new_cond]: doing for images: "{new_cond_pure_name}" {"=" * 50} \n'
        )

        # ========= new segmentation condition and the matching real image
        exp_path, new_rev_cond, new_real_batch = prep_for_sampling(
            args, params, new_cond_name, additional_info)

        new_seg_batch = new_rev_cond['segment']  # (1, C, H, W)
        new_boundary = new_rev_cond['boundary']
        new_bmap_batch = new_boundary if args.use_bmaps else None

        # ========= save the new and the original segmentation / real images
        helper.make_dir_if_not_exists(exp_path)

        inputs_to_save = (
            (new_seg_batch, 'new_seg.png'),
            (new_real_batch, 'new_real.png'),
            (orig_seg_batch, 'orig_seg.png'),
            (orig_real_batch, 'orig_real.png'),
        )
        for image_batch, file_name in inputs_to_save:
            utils.save_image(image_batch.clone(),
                             f"{exp_path}/{file_name}",
                             nrow=1,
                             padding=0)
        print(
            'In [sample_with_new_cond]: saved original and new segmentation images'
        )

        # =========== latent outputs of the original real image (the desired style)
        left_glow_outs, right_glow_outs = model(x_a=orig_seg_batch,
                                                x_b=orig_real_batch,
                                                b_map=orig_bmap_batch)
        z_real = right_glow_outs['z_outs']

        # =========== apply the new condition to the desired style: (1, C, H, W)
        new_real_syn = model.reverse(x_a=new_seg_batch,
                                     extra_cond=new_bmap_batch,
                                     z_b_samples=z_real,
                                     mode='new_condition')
        utils.save_image(new_real_syn.clone(),
                         f"{exp_path}/new_real_syn.png",
                         nrow=1,
                         padding=0)
        print('In [sample_with_new_cond]: save synthesized real image')

        # =========== all five images combined in one grid
        grid_members = [
            orig_seg_batch, new_seg_batch, orig_real_batch, new_real_batch,
            new_real_syn
        ]
        all_together = torch.cat(grid_members, dim=0)
        utils.save_image(all_together.clone(),
                         f"{exp_path}/all.png",
                         nrow=2,
                         padding=10)

        # the same grid is also collected in an 'all' folder next to exp_path
        exp_pure_path = exp_path.split('/')[-1]
        all_together_path = os.path.split(exp_path)[0] + '/all'
        helper.make_dir_if_not_exists(all_together_path)

        utils.save_image(all_together.clone(),
                         f"{all_together_path}/{exp_pure_path}.png",
                         nrow=2,
                         padding=10)
        print(
            f'In [sample_with_new_cond]: for images: "{new_cond_pure_name}": done {"=" * 50} \n'
        )
Esempio n. 16
0
def train(args,
          params,
          train_configs,
          model,
          optimizer,
          current_lr,
          comet_tracker=None,
          resume=False,
          last_optim_step=0,
          reverse_cond=None):
    """
    Optimization-step based training loop with periodic validation, sampling and checkpointing.

    :param args: run arguments; reads ``use_comet`` and is forwarded to the data, loss and sampling helpers.
    :param params: configuration dict; reads 'iter', 'lr', 'monitor_val', 'val_freq', 'sample_freq' and
        'checkpoint_freq'.
    :param train_configs: dict; reads 'reg_factor' (when not None, it scales the left loss).
    :param model: the model being trained.
    :param optimizer: optimizer whose param groups get the adjusted learning rate.
    :param current_lr: learning rate at entry; re-computed by ``adjust_lr`` before every pass over the loader.
    :param comet_tracker: tracker used only when ``args.use_comet`` is set.
    :param resume: when true, counting starts at ``last_optim_step + 1`` instead of 1.
    :param last_optim_step: last completed step of a previous run (only used with ``resume``).
    :param reverse_cond: forwarded to ``models.take_samples`` when samples are saved.
    :return: None. The function also ends the process via ``sys.exit(0)`` once ``current_lr`` is 0.
    """
    # getting data loaders
    train_loader, val_loader = data_handler.init_data_loaders(args, params)

    # adjusting optim step
    optim_step = last_optim_step + 1 if resume else 1
    max_optim_steps = params['iter']
    paths = helper.compute_paths(args, params)

    if resume:
        print(
            f'In [train]: resuming training from optim_step={optim_step} - max_step: {max_optim_steps}'
        )

    # optimization loop
    # NOTE(review): this condition stops at optim_step == max_optim_steps, while the check inside the
    # batch loop only stops once optim_step > max_optim_steps - confirm which bound is intended.
    while optim_step < max_optim_steps:
        # after each epoch, adjust learning rate accordingly
        current_lr = adjust_lr(
            current_lr,
            initial_lr=params['lr'],
            step=optim_step,
            epoch_steps=len(
                train_loader))  # now only supports with batch size 1
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr
        print(
            f'In [train]: optimizer learning rate adjusted to: {current_lr}\n')

        for i_batch, batch in enumerate(train_loader):
            if optim_step > max_optim_steps:
                print(
                    f'In [train]: reaching max_step or lr is zero. Terminating...'
                )
                return  # ============ terminate training if max steps reached

            begin_time = time.time()
            # forward pass
            left_batch, right_batch, extra_cond_batch = data_handler.extract_batches(
                batch, args)
            forward_output = forward_and_loss(args, params, model, left_batch,
                                              right_batch, extra_cond_batch)

            # regularize left loss
            if train_configs['reg_factor'] is not None:
                loss = train_configs['reg_factor'] * forward_output[
                    'loss_left'] + forward_output['loss_right']  # regularized
            else:
                loss = forward_output['loss']

            metrics = {'loss': loss}
            # also add left and right loss if available
            if 'loss_left' in forward_output.keys():
                metrics.update({
                    'loss_right': forward_output['loss_right'],
                    'loss_left': forward_output['loss_left']
                })

            # backward pass and optimizer step
            model.zero_grad()
            loss.backward()
            optimizer.step()
            print(f'In [train]: Step: {optim_step} => loss: {loss.item():.3f}')

            # validation loss (every 'val_freq' steps when monitoring is on, and once more when lr hits 0)
            if (params['monitor_val'] and optim_step % params['val_freq']
                    == 0) or current_lr == 0:
                val_loss_mean, _ = calc_val_loss(args, params, model,
                                                 val_loader)
                metrics['val_loss'] = val_loss_mean
                print(
                    f'====== In [train]: val_loss mean: {round(val_loss_mean, 3)}'
                )

            # tracking metrics
            # NOTE(review): every metric value, including val_loss, must expose .item() here - confirm
            # what calc_val_loss returns.
            if args.use_comet:
                for key, value in metrics.items():  # track all metric values
                    comet_tracker.track_metric(key, round(value.item(), 3),
                                               optim_step)

            # saving samples
            if (optim_step % params['sample_freq'] == 0) or current_lr == 0:
                samples_path = paths['samples_path']
                helper.make_dir_if_not_exists(samples_path)
                sampled_images = models.take_samples(args, params, model,
                                                     reverse_cond)
                utils.save_image(
                    sampled_images,
                    f'{samples_path}/{str(optim_step).zfill(6)}.png',
                    nrow=10)
                print(
                    f'\nIn [train]: Sample saved at iteration {optim_step} to: \n"{samples_path}"\n'
                )

            # saving checkpoint
            if (optim_step > 0 and optim_step % params['checkpoint_freq']
                    == 0) or current_lr == 0:
                checkpoints_path = paths['checkpoints_path']
                helper.make_dir_if_not_exists(checkpoints_path)
                helper.save_checkpoint(checkpoints_path, optim_step, model,
                                       optimizer, loss, current_lr)
                print("In [train]: Checkpoint saved at iteration", optim_step,
                      '\n')

            optim_step += 1
            end_time = time.time()
            print(f'Iteration took: {round(end_time - begin_time, 2)}')
            helper.show_memory_usage()
            print('\n')

            # a zero learning rate ends the whole process, after the final validation/sample/checkpoint above
            if current_lr == 0:
                print(
                    'In [train]: current_lr = 0, terminating the training...')
                sys.exit(0)
Esempio n. 17
0
def interpolate(cond_config,
                interp_config,
                params,
                args,
                device,
                mode='conditional'):
    """
    Walk between the latent codes of two images and save all decoded steps as a single image grid.

    :param cond_config: dict providing 'img_index1', 'img_index2' and 'reverse_cond'.
    :param interp_config: dict providing 'steps', 'type' and 'axis'; 'increment' is read as well whenever
        'type' is not 'limited'.
    :param params: dict providing 'data_folder', and 'checkpoints_path' / 'samples_path' indexed by `mode`
        (it is also handed to init_glow).
    :param args: reads `img_size`, `last_optim_step` and `dataset`.
    :param device: not referenced inside this function.
    :param mode: key selecting the checkpoint folder and the samples folder.
    :return: None; the grid is written to '<samples_path>/interp'.

    Notes:
        - With 'type' == 'limited' the coefficient is step / steps, so 'steps' == 0 raises ZeroDivisionError.
    """
    idx_a = cond_config['img_index1']
    idx_b = cond_config['img_index2']
    rev_cond = cond_config['reverse_cond']

    # both images (and their labels) are requested in 'batch' form
    img_a, label_a = get_image(idx_a,
                               params['data_folder'],
                               args.img_size,
                               ret_type='batch')
    img_b, _ = get_image(idx_b,
                         params['data_folder'],
                         args.img_size,
                         ret_type='batch')

    ckpt_dir = params['checkpoints_path'][mode]
    out_dir = params['samples_path'][mode] + '/interp'
    make_dir_if_not_exists(out_dir)

    # build the model and restore its weights (no optimizer, not resuming training)
    model = init_glow(params)
    model, _, _ = load_checkpoint(ckpt_dir,
                                  args.last_optim_step,
                                  model,
                                  None,
                                  resume_train=False)

    # assumption: both images share the same condition (label), so only the first label is used
    forward_cond = (args.dataset, label_a)
    _, _, z_start = model(img_a, forward_cond)
    _, _, z_stop = model(img_b, forward_cond)

    # per-level direction from the first latent code to the second
    z_delta = [z_end - z_begin for z_begin, z_end in zip(z_start, z_stop)]

    n_steps = interp_config['steps']
    axis_lookup = {'z1': 0, 'z2': 1}  # any other axis name falls back to index 2
    frames = []

    for step in range(n_steps + 1):
        limited = interp_config['type'] == 'limited'
        if limited:
            # fraction of the way from image 1 to image 2: e.g. 1/5, 2/5, ..., 5/5
            coeff = step / n_steps
        else:
            coeff = step * interp_config['increment']

        if interp_config['axis'] == 'all':
            # move every latent level at once
            z_interp = [z + coeff * d for z, d in zip(z_start, z_delta)]
        else:
            # move a single latent level; the others are reused as they are (shallow copy of the list)
            z_interp = list(z_start)
            axis = axis_lookup.get(interp_config['axis'], 2)
            z_interp[axis] = z_start[axis] + coeff * z_delta[axis]

        decoded = model.reverse(z_interp,
                                reconstruct=True,
                                coupling_conds=rev_cond).cpu().data
        frames.append(decoded.squeeze(dim=0))

        # fixed two-decimal text in 'limited' mode, a rounded number otherwise
        coeff_name = '%.2f' % coeff if limited else round(coeff, 2)
        print(f'In [interpolate]: done for coeff {coeff_name}')

    utils.save_image(
        frames,
        f'{out_dir}/{idx_a}-to-{idx_b}_[{interp_config["axis"]}].png',
        nrow=10)
Esempio n. 18
0
def infer_on_set(args, params, split_set='val'):
    """
    Sample from the loaded model for every batch of a data split and save each sample as its own image.

    The model name and paths should be equivalent to the name used in the resize_for_fcn function
    in evaluation.third_party.prepare.py module.

    :param args: run arguments; reads `direction`, `sampling_round` and `use_bmaps` (it is also forwarded to
        the path, model and sampling helpers).
    :param params: configuration dict; reads 'batch_size', 'data_folder' and 'img_size'.
    :param split_set: 'val' selects the validation loader and the 'val_path' folder; any other value selects
        the training loader and the 'train_vis' folder.
    :return: None; the samples are written under the (extended) save path.
    :raises ValueError: if `args.direction` is neither 'label2photo' nor 'photo2label'.
    """
    # the reverse condition can only be built for these two directions, so reject anything else
    # before creating folders or loading the model
    if args.direction not in ('label2photo', 'photo2label'):
        raise ValueError(
            f'In [infer_on_validation_set]: unsupported direction "{args.direction}" '
            f'(expected "label2photo" or "photo2label")')

    with torch.no_grad():
        paths = helper.compute_paths(args, params)
        checkpt_path, val_path, train_vis_path = paths[
            'checkpoints_path'], paths['val_path'], paths['train_vis']

        save_path = val_path if split_set == 'val' else train_vis_path
        save_path = helper.extend_path(save_path, args.sampling_round)

        print(
            f'In [infer_on_validation_set]:\n====== checkpt_path: "{checkpt_path}" \n====== save_path: "{save_path}" \n'
        )
        helper.make_dir_if_not_exists(save_path)

        # init and load model
        model = models.init_and_load(args, params, run_mode='infer')
        print(
            'In [infer_on_validation_set]: init model and load checkpoint: done'
        )

        # data loaders: fixed order, loaded in the main process
        batch_size = params['batch_size']
        loader_params = {
            'batch_size': batch_size,
            'shuffle': False,
            'num_workers': 0
        }
        train_loader, val_loader = data_handler.init_city_loader(
            data_folder=params['data_folder'],
            image_size=(params['img_size']),
            loader_params=loader_params)
        loader = val_loader if split_set == 'val' else train_loader
        print('In [infer_on_validation_set]: using loader of len:',
              len(loader))

        # inference on the selected split
        print(
            f'In [infer_on_validation_set]: starting inference on {split_set} set'
        )
        for i_batch, batch in enumerate(loader):
            img_batch = batch['real'].to(device)
            segment_batch = batch['segment'].to(device)
            # boundary maps are only moved to the device when they are requested
            boundary_batch = batch['boundary'].to(
                device) if args.use_bmaps else None
            # lists of source paths: used to save samples with the same name as the original images
            real_paths = batch['real_path']
            seg_paths = batch['segment_path']

            if args.direction == 'label2photo':
                reverse_cond = {
                    'segment': segment_batch,
                    'boundary': boundary_batch
                }
            else:  # 'photo2label' (the direction was validated on entry)
                reverse_cond = {'real': img_batch}

            # sampling from model
            # using batch.shape[0] is safer than batch_size since the last batch may be different in size
            samples = models.take_samples(args,
                                          params,
                                          model,
                                          reverse_cond,
                                          n_samples=segment_batch.shape[0])
            # save inferred images separately, named after the images of the target domain
            paths_list = real_paths if args.direction == 'label2photo' else seg_paths
            helper.save_one_by_one(samples, paths_list, save_path)

            print(
                f'In [infer_on_validation_set]: done for the {i_batch}th batch out of {len(loader)} batches (batch size: {segment_batch.shape[0]})'
            )
        print(
            f'In [infer_on_validation_set]: all done. Inferred images could be found at: {save_path} \n'
        )
Esempio n. 19
0
def run_model(mode, args, params, hps, sess, model, conditions, tracker):
    """
    Finalize the session graph, then either sample once (mode 'infer') or run the training loop.

    :param mode: 'infer' takes samples into the validation folder and returns; any other value trains.
    :param args: reads `sampling_round` and `direction` for inference, `resume_train` and `last_optim_step`
        for training.
    :param params: forwarded to helper.compute_paths together with `args`.
    :param hps: reads `train_its`, `lr`, `sample_freq`, `val_freq` and `val_its`.
    :param sess: session whose graph is finalized before anything else happens.
    :param model: object exposing train(lr), test() and save(path).
    :param conditions: forwarded to take_sample.
    :param tracker: optional metric tracker; metrics are only reported when it is not None.
    :return: None
    """
    sess.graph.finalize()
    paths = helper.compute_paths(args, params)

    # inference on validation set: one sampling pass, then done
    if mode == 'infer':
        val_path = helper.extend_path(paths['val_path'], args.sampling_round)
        helper.make_dir_if_not_exists(val_path)

        take_sample(model,
                    conditions,
                    val_path,
                    mode,
                    direction=args.direction,
                    iteration=None)
        print(f'In [run_model]: validation samples saved to: "{val_path}"')
        return

    # training
    checkpoint_path = paths['checkpoints_path']
    samples_path = paths['samples_path']
    for folder in (checkpoint_path, samples_path):
        helper.make_dir_if_not_exists(folder)

    # a resumed run continues right after the last stored optimization step
    iteration = args.last_optim_step + 1 if args.resume_train else 0
    while iteration <= hps.train_its:
        # original note: train returns [local_loss, bits_x_u, bits_x_o, bits_y]
        train_results = model.train(hps.lr)
        train_loss = round(train_results[0], 3)
        print(f'Step {iteration} - train loss: {train_loss}')

        # periodic samples
        # NOTE(review): iteration is passed positionally here but by keyword in the infer call above;
        # confirm it binds to the intended take_sample parameter
        if iteration % hps.sample_freq == 0:
            take_sample(model, conditions, samples_path, mode, iteration)

        if tracker is not None:
            tracker.track_metric('train_loss', round(train_loss, 3),
                                 iteration)

        # periodic validation, followed by a checkpoint
        if iteration % hps.val_freq == 0:
            print('Computing validation loss...')

            # hps.val_its test calls, averaged entry by entry
            val_batches = [model.test() for _ in range(hps.val_its)]
            val_means = np.mean(np.asarray(val_batches), axis=0)
            val_loss = round(val_means[0], 3)
            print(f'Step {iteration} - validation loss: {val_loss}')

            ckpt_file = os.path.join(checkpoint_path,
                                     f"step={iteration}.ckpt")
            model.save(ckpt_file)
            print(f'Checkpoint saved to: "{ckpt_file}"')

            if tracker is not None:
                tracker.track_metric('val_loss', round(val_loss, 3),
                                     iteration)

        iteration += 1
Esempio n. 20
0
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
    Write `state` to `filename` with torch.save and optionally keep a copy as the best model.

    :param state: object handed to torch.save.
    :param is_best: when truthy, the saved file is also copied to 'model_best.pth.tar' (a relative path).
    :param filename: destination of the checkpoint; its parent directory is created if it has one.
    :return: None
    """
    checkpoint_dir = os.path.dirname(filename)
    # a bare file name (such as the default) has an empty directory part: there is nothing to create,
    # and an empty string is not a valid directory to ask the helper for
    if checkpoint_dir:
        helper.make_dir_if_not_exists(checkpoint_dir)

    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
Esempio n. 21
0
def _create_cond(args,
                 params,
                 fixed_conds=None,
                 save_path=None,
                 direction='label2photo'):
    """
    Collect the first `n_samples` items of the conditioning dataset as batched tensors and optionally
    save them (images plus source paths) to disk.

    :param args: reads `do_ceil` (only consulted while saving boundaries).
    :param direction: 'label2photo' or 'bmap2label' additionally save the boundary maps.
    :param params: reads 'data_folder', 'img_size' and, when `fixed_conds` is None, 'n_samples'.
    :param fixed_conds: optional conditions forwarded to CityDataset; when given, its length is the
        number of samples.
    :param save_path: optional folder; when truthy, the segmentations, real images, (possibly) boundaries
        and the list of image paths are written there. The paths file is opened in append mode.
    :return: tuple (segmentations, id_repeats_batch, real_imgs, boundaries), all moved to `device`.
        id_repeats_batch is an all-zero placeholder that is no longer used.

    Notes:
        - The use of .clone() for saving in this function is very important (original author's note):
          in utils.save_image() the float values will be normalized to [0-255] integer and the values of
          the tensor change in-place, so the tensor does not have valid float values after this operation,
          unless we use .clone() to make a new copy of it for saving.
        - NOTE(review): the ceil of the boundaries is only applied inside the `save_path` branch, so the
          returned boundaries differ between saving and non-saving calls - confirm this is intended.
    """
    n_samples = params['n_samples'] if fixed_conds is None else len(
        fixed_conds)
    data_folder = params['data_folder']
    img_size = params['img_size']

    # this will not be used if fixed_conds is given
    train_df = {
        'segment': data_folder['segment'] + '/train',
        'real': data_folder['real'] + '/train'
    }

    city_dataset = CityDataset(train_df, img_size, fixed_cond=fixed_conds)
    print(
        f'In [create_cond]: created dataset of len {len(city_dataset)} for conditions'
    )

    # index the dataset once per sample and reuse the item for every field,
    # instead of repeating the lookup for each of the five fields
    samples = [city_dataset[i] for i in range(n_samples)]

    segmentations = torch.stack([item['segment'] for item in samples], dim=0)
    real_imgs = torch.stack([item['real'] for item in samples], dim=0)
    boundaries = torch.stack([item['boundary'] for item in samples], dim=0)
    seg_paths = [item['segment_path'] for item in samples]
    real_paths = [item['real_path'] for item in samples]
    id_repeats_batch = torch.zeros(
        (n_samples, 34, img_size[0],
         img_size[1]))  # 34 different IDs - no longer used

    # save conditions
    if save_path:
        helper.make_dir_if_not_exists(save_path)
        utils.save_image(segmentations.clone(),
                         f'{save_path}/segmentation.png',
                         nrow=10)  # .clone(): very important
        utils.save_image(real_imgs.clone(),
                         f'{save_path}/real_img.png',
                         nrow=10)

        if direction in ('label2photo', 'bmap2label'):
            if args.do_ceil or direction == 'bmap2label':
                boundaries = torch.ceil(boundaries).to(device)
            utils.save_image(boundaries.clone(),
                             f'{save_path}/boundary.png',
                             nrow=10)
        print(
            f'In [create_cond]: saved the condition and real images to: "{save_path}"'
        )

        # write paths (appended, so earlier entries in the file are kept)
        with open(f'{save_path}/img_paths.txt', 'a') as f:
            f.write("==== SEGMENTATIONS PATHS \n")
            for item in seg_paths:
                f.write("%s\n" % item)

            f.write("==== REAL IMAGES PATHS \n")
            for item in real_paths:
                f.write("%s\n" % item)
        print('In [create_cond]: saved the image paths \n')

    return segmentations.to(device), id_repeats_batch.to(device), real_imgs.to(
        device), boundaries.to(device)