Example No. 1
def eval_ssim(args, params):
    assert args.dataset == 'cityscapes'  # not implemented otherwise
    if args.direction == 'label2photo':
        ref_dir = os.path.join(params['data_folder']['real'], 'val')
    else:
        ref_dir = os.path.join(params['data_folder']['segment'], 'val')
    if args.model == 'pix2pix':
        syn_dir = f'/Midgard/home/sorkhei/glow2/checkpoints/cityscapes/256x256/pix2pix/label2photo/eval_{args.sampling_round}/val_imgs/pix2pix_cityscapes_label2photo/test_46/images'
        ssim_file = f'/Midgard/home/sorkhei/glow2/checkpoints/cityscapes/256x256/pix2pix/label2photo/ssim.txt'
        file_paths = helper.files_with_suffix(syn_dir, '_leftImg8bit.png')

        # print(syn_dir)
        # print('len file paths:', len(file_paths))
        # print('is dir:', os.path.isdir('/Midgard/home/sorkhei/glow2/checkpoints/cityscapes/256x256/pix2pix/label2photo'))
    else:
        syn_dir = helper.extend_path(
            helper.compute_paths(args, params)['val_path'],
            args.sampling_round)
        ssim_file = os.path.join(
            helper.compute_paths(args, params)['eval_path'], 'ssim.txt')
        file_paths = helper.absolute_paths(
            syn_dir)  # the absolute paths of all synthesized images

    # get the score
    ssim_score = compute_ssim(file_paths, ref_dir)

    with open(ssim_file, 'a') as file:
        file.write(f'Round {args.sampling_round}: {ssim_score}\n')

    print(
        f'In [eval_ssim]: writing ssim score of {ssim_score} to: "{ssim_file}"'
    )
    print(f'In [eval_ssim]: all done.')
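The compute_ssim helper called above is defined elsewhere in the project and is not shown here. Below is a minimal sketch of what it might look like, mirroring the multichannel SSIM call of Example No. 13 and assuming each synthesized image has a same-named reference image directly under ref_dir; the function name compute_ssim_sketch and the match-by-filename convention are assumptions, not the project's actual implementation.

import os
import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity as ssim

def compute_ssim_sketch(file_paths, ref_dir):
    # mean SSIM between each synthesized image and the same-named reference image
    scores = []
    for syn_path in file_paths:
        name = os.path.basename(syn_path)
        syn = np.asarray(Image.open(syn_path).convert('RGB'), dtype=np.float32) / 255.
        ref = np.asarray(Image.open(os.path.join(ref_dir, name)).convert('RGB'), dtype=np.float32) / 255.
        scores.append(ssim(ref, syn, multichannel=True, data_range=1.))
    return float(np.mean(scores))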
Example No. 2
def evaluate_model(args, params):
    loader_params = {
        'batch_size': params['batch_size'],
        'shuffle': False,
        'num_workers': params['num_workers']
    }
    batch_size = params['batch_size']

    _, _, test_loader = data_handler.init_data_loaders(params, loader_params)

    # placeholders for the stacked (N, 14) prediction and label arrays (14 pathology classes);
    # both are overwritten by the first batch's outputs in the loop below
    total_predicted = np.zeros((batch_size, 14))
    total_labels = np.zeros((batch_size, 14))

    path_to_load = helper.compute_paths(
        args, params)['save_path']  # compute path from args and params
    net = networks.init_unified_net(args.model, params)
    net, _ = helper.load_model(path_to_load,
                               args.epoch,
                               net,
                               optimizer=None,
                               resume_train=False)

    print(
        f'In [evaluate_model]: loading the model done from: "{path_to_load}"')
    print(
        f'In [evaluate_model]: starting evaluation with {len(test_loader)} batches'
    )

    with torch.no_grad():
        for i_batch, batch in enumerate(test_loader):
            img_batch = batch['image'].to(device).float()
            label_batch = batch['label'].to(device).float()
            pred = net(img_batch)

            if i_batch > 0:
                total_predicted = np.append(total_predicted,
                                            pred.cpu().detach().numpy(),
                                            axis=0)
                total_labels = np.append(total_labels,
                                         label_batch.cpu().detach().numpy(),
                                         axis=0)
            else:
                total_predicted = pred.cpu().detach().numpy()
                total_labels = label_batch.cpu().detach().numpy()

            if i_batch % 50 == 0:
                print(
                    f'In [evaluate_model]: prediction done for batch {i_batch}'
                )

    results_path = helper.compute_paths(args, params)['results_path']
    helper.make_dir_if_not_exists(results_path)

    # plot roc
    print('In [evaluate_model]: plotting ROC...')
    plotting.plot_roc(total_predicted, total_labels, pathology_names,
                      results_path)
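plotting.plot_roc is project code that is not included in these examples. As a rough companion, here is a hedged sketch of reporting per-pathology AUROC over the stacked (N, 14) arrays using sklearn; the helper name report_auroc and the printed format are assumptions.

from sklearn.metrics import roc_auc_score

def report_auroc(total_predicted, total_labels, pathology_names):
    # one ROC-AUC value per pathology column of the (N, 14) arrays
    for c, name in enumerate(pathology_names):
        auc = roc_auc_score(total_labels[:, c], total_predicted[:, c])
        print(f'{name}: AUROC = {auc:.3f}')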
Example No. 3
def diverse_samples(args, params):
    """
    command:
    python3 main.py --local --run diverse --image_name strasbourg_000001_061472 --temp 0.9 --model improved_so_large --last_optim_step 276000 \
                    --img_size 512 1024 --dataset cityscapes --direction label2photo \
                    --n_block 4 --n_flow 10 10 10 10 --do_lu --reg_factor 0.0001 --grad_checkpoint
    """
    optim_step = args.last_optim_step
    temperature = args.temp
    n_samples = 10
    img_size = [512, 1024]
    n_blocks = args.n_block
    image_name = args.image_name

    image_cond_path = f'/local_storage/datasets/moein/cityscapes/gtFine_trainvaltest/gtFine/train/{image_name.split("_")[0]}/{image_name}_gtFine_color.png'
    image_gt_path = f'/local_storage/datasets/moein/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/train/{image_name.split("_")[0]}/{image_name}_leftImg8bit.png'

    model_paths = helper.compute_paths(args, params)
    output_folder = os.path.join(model_paths['diverse_path'], image_name, f'temp={temperature}')
    print(f'out folder: {output_folder}')

    helper.make_dir_if_not_exists(output_folder)

    shutil.copyfile(image_cond_path, os.path.join(output_folder, 'segment.png'))
    shutil.copyfile(image_gt_path, os.path.join(output_folder, 'real.png'))
    print('Copy segment and real images: done')

    model = models.init_model(args, params)
    model = helper.load_checkpoint(model_paths['checkpoints_path'], optim_step, model, optimizer=None, resume_train=False)[0]

    step = 2
    for i in range(0, n_samples, step):
        paths_list = [os.path.join(output_folder, f'sample_{(i + 1) + j}.png') for j in range(step)]
        experiments.take_multiple_samples(model, n_blocks, temperature, step, img_size, image_cond_path, paths_list)
Example No. 4
def evaluate_fcn(args, params):
    assert args.dataset == 'cityscapes' and params['img_size'] == [
        256, 256
    ]  # not supported otherwise for now

    # specify the paths
    if args.direction == 'label2photo' and args.gt:  # evaluating ground-truth images
        all_paths = {  # keys match the 'val_path' / 'eval_path' lookups below
            'val_path':
            '/Midgard/home/sorkhei/glow2/data/cityscapes/resized/val',
            'eval_path': '/Midgard/home/sorkhei/glow2/gt_eval_results'
        }
    else:
        all_paths = helper.compute_paths(args, params)

    syn_dir = extend_path(all_paths['val_path'],
                          args.sampling_round)  # val_imgs_1, val_imgs_2, ...
    eval_dir = all_paths['eval_path']
    print(f'In [evaluate_fcn]: results will be read from: "{syn_dir}"')

    # evaluation
    if args.direction == 'label2photo':
        eval_real_imgs_with_temp(
            base_data_folder=params['data_folder']['base'],
            synthesized_dir=syn_dir,
            save_dir=eval_dir,
            sampling_round=args.sampling_round)
    else:
        eval_segmentations_with_temp(
            synthesized_dir=syn_dir,
            reference_dir=params['data_folder']['segment'],
            base_data_folder=params['data_folder']['base'],
            save_dir=eval_dir,
            sampling_round=args.sampling_round)
    print('In [evaluate_fcn]: evaluation done')
Example No. 5
def retrieve_data(sess, hps, args, params):
    train_tfrecords_file = params['tfrecords_file']['train']
    val_tfrecords_file = params['tfrecords_file']['val']

    train_iter = data_io.read_tfrecords(train_tfrecords_file,
                                        args.dataset,
                                        args.direction,
                                        hps.batch_size,
                                        is_training=True)
    valid_iter = data_io.read_tfrecords(val_tfrecords_file,
                                        args.dataset,
                                        args.direction,
                                        hps.batch_size,
                                        is_training=False)

    first_batch = make_first_batch(sess, train_iter)

    dataset_name, direction = args.dataset, args.direction
    assert dataset_name == 'cityscapes'  # only 'cityscapes' is supported here

    run_mode = 'infer' if args.exp else 'train'
    if run_mode == 'infer':
        if args.direction == 'label2photo':
            conditions_paths = helper.get_all_data_folder_images(
                path=params['data_folder']['segment'],
                partition='val',
                image_type='segment')
        else:
            conditions_paths = helper.get_all_data_folder_images(
                path=params['data_folder']['real'],
                partition='val',
                image_type='real')
        cond_list = create_conditions(conditions_paths, also_visual_grid=False)
        print(
            f'In [retrieve_data]: conditions_paths for inference is of len: {len(conditions_paths)}'
        )

    else:
        conditions_paths = globals.seg_conds_abs_paths if direction == 'label2photo' else globals.real_conds_abs_path
        cond_list, visual_grid = create_conditions(conditions_paths)

        # save condition if not exists
        samples_path = helper.compute_paths(args, params)['samples_path']
        helper.make_dir_if_not_exists(samples_path)
        path = os.path.join(samples_path, 'conditions.png')

        if not os.path.isfile(path):
            Image.fromarray(visual_grid).save(path)
            print(f'In [retrieve_data]: saved conditions to: "{path}"')

    # make a list of dicts containing both image path and image array
    conditions = [{
        'image_path': conditions_paths[i],
        'image_array': cond_list[i]
    } for i in range(len(cond_list))]
    return train_iter, valid_iter, first_batch, conditions
Example No. 6
def init_hps_for_dual_glow(args, params):
    class hps(object):  # a helper class just for carrying attributes among functions
        pass

    # running params
    hps.inference = False
    if args.resume_train or args.exp:
        hps.restore_path = os.path.join(helper.compute_paths(args, params)['checkpoints_path'],
                                        f"step={args.last_optim_step}.ckpt")
    else:
        hps.restore_path = None  # set to checkpoints path + step when resuming training or running inference

    # batch and image size
    hps.batch_size = 1
    hps.input_size = [256, 256, 3]
    hps.output_size = [256, 256, 3]
    hps.n_bits_x = 8

    # model config
    hps.n_levels, hps.depth = params['n_block'], params['n_flow']

    hps.n_l = 1  # mlp basic layers, default: 1 by the paper
    hps.flow_permutation = 2  # 0: reverse (RealNVP), 1: shuffle, 2: invconv (Glow)
    hps.flow_coupling = 1  # 0: additive, 1: affine

    hps.width = 512  # Width of hidden layers - default by the paper
    hps.eps_std = .7

    # other model configs
    hps.ycond = False  # Use y conditioning - default by the paper
    hps.learntop = True  # Learn spatial prior
    hps.n_y = 1  # always 1 in the original code
    hps.ycond_loss_type = 'l2'  # loss type of y inferred from z_in - default by the paper - not used by us as we do not have y conditioning

    # training config
    hps.train_its = 2000000 if args.max_step is None else args.max_step  # training iterations
    hps.val_its = 500  # 500 val iterations so we get full validation result with batch size 1 (val set size is 500)
    hps.val_freq = 1000  # get val result every 1000 iterations
    hps.sample_freq = 500 if args.sample_freq is None else args.sample_freq
    hps.direct_iterator = True  # default by the paper
    hps.weight_lambda = 0.001  # Weight of log p(x_o|x_u) in weighted loss, default by the paper
    hps.weight_y = 0.01  # Weight of log p(y|x) in weighted loss, default by the paper - not used by us as we do not have y conditioning

    # adam params
    hps.optimizer = 'adam'
    hps.gradient_checkpointing = 1  # default
    hps.beta1 = .9
    hps.beta2 = .999
    hps.lr = 0.0001
    hps.weight_decay = 1.  # Switched off by default
    hps.polyak_epochs = 1   # default by the code - not used by us
    return hps
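As a usage note, hps is consumed purely as an attribute bag by the rest of the dual_glow code. A hypothetical call site, assuming a TensorFlow session sess already exists (see Example No. 16), might look like this:

# hps is just a bag of attributes read by the dual_glow graph-building and training code
hps = init_hps_for_dual_glow(args, params)
print(hps.batch_size, hps.n_levels, hps.depth)  # 1, params['n_block'], params['n_flow']
train_iter, valid_iter, first_batch, conditions = retrieve_data(sess, hps, args, params)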
Example No. 7
def run_training(args, params):
    # print run info
    helper.print_info(args, params, model=None, which_info='params')

    # setting comet tracker
    tracker = None
    if args.use_comet:
        tracker = init_comet(args, params)
        print("In [run_training]: Comet experiment initialized...")

    if 'dual_glow' in args.model:
        models.train_dual_glow(args, params, tracker)
    else:
        model = models.init_model(args, params)
        optimizer = optim.Adam(model.parameters())
        reverse_cond = data_handler.retrieve_rev_cond(args,
                                                      params,
                                                      run_mode='train')
        train_configs = trainer.init_train_configs(args)

        # resume training
        if args.resume_train:
            optim_step = args.last_optim_step
            checkpoints_path = helper.compute_paths(args,
                                                    params)['checkpoints_path']
            model, optimizer, _, lr = load_checkpoint(checkpoints_path,
                                                      optim_step, model,
                                                      optimizer)

            if lr is None:  # if not saved in checkpoint
                lr = params['lr']
            trainer.train(args,
                          params,
                          train_configs,
                          model,
                          optimizer,
                          lr,
                          tracker,
                          resume=True,
                          last_optim_step=optim_step,
                          reverse_cond=reverse_cond)
        # train from scratch
        else:
            lr = params['lr']
            trainer.train(args,
                          params,
                          train_configs,
                          model,
                          optimizer,
                          lr,
                          tracker,
                          reverse_cond=reverse_cond)
Example No. 8
def train(args, params, model, optimizer, tracker=None):
    loader_params = {
        'batch_size': params['batch_size'],
        'shuffle': params['shuffle'],
        'num_workers': params['num_workers']
    }

    train_loader, val_loader, _ = data_handler.init_data_loaders(
        params, loader_params)

    epoch = 0
    max_epochs = params['max_epochs']

    while epoch < max_epochs:
        print(f'{"=" * 40} In epoch: {epoch} {"=" * 40}')
        print(f'Training on {len(train_loader)} batches...')

        # each for loop for one epoch
        for i_batch, batch in enumerate(train_loader):
            # convert the image and label batches to Float tensors (Long tensors won't work on the GPU)
            img_batch = batch['image'].to(device).float()
            label_batch = batch['label'].to(device).float()

            # making gradients zero in each optimization step
            optimizer.zero_grad()

            # getting the network prediction and computing the loss
            pred = model(img_batch)
            train_loss = loss.compute_wcel(fx=pred, labels=label_batch)
            print(
                f'Batch: {i_batch} - train loss: {round(train_loss.item(), 3)}'
            )

            # tracking the metrics using comet in each iteration
            track('train_loss', round(train_loss.item(), 3), args, tracker)

            # backward and optimization step
            train_loss.backward()
            optimizer.step()

        # save checkpoint after epoch
        save_path = helper.compute_paths(args, params)['save_path']
        helper.make_dir_if_not_exists(save_path)
        helper.save_model(save_path, epoch, model, optimizer,
                          train_loss)  # save the model every epoch

        val_loss = loss.compute_val_loss(model, val_loader)  # track val loss
        print(f'In [train]: epoch={epoch} - val_loss = {val_loss}')
        track('val_loss', val_loss, args, tracker)
        epoch += 1
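loss.compute_wcel is not shown in these examples. The sketch below is one plausible weighted binary cross-entropy for the (B, 14) multi-label setup, assuming fx contains logits and using a per-batch positive/negative re-weighting that is common for chest X-ray classification; the function name and the exact weighting scheme are assumptions, not the project's actual implementation.

import torch

def compute_wcel_sketch(fx, labels, eps=1e-8):
    # weighted BCE over a (B, 14) batch: up-weight the rarer class within the batch
    probs = torch.sigmoid(fx)
    n_pos = labels.sum()
    n_neg = labels.numel() - n_pos
    w_pos = (n_pos + n_neg) / (n_pos + eps)
    w_neg = (n_pos + n_neg) / (n_neg + eps)
    loss = -(w_pos * labels * torch.log(probs + eps) +
             w_neg * (1 - labels) * torch.log(1 - probs + eps))
    return loss.mean()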
Example No. 9
def create_rev_cond(args, params, fixed_conds, also_save=True):
    trans = loader.init_transformers(params['img_size'])
    # only supports reading from train data now
    images = read_imgs(params['data_folder'],
                       split='train',
                       fixed_conds=fixed_conds)

    conds_as_list = []
    real_as_list = []

    for image in images:
        h, w = params['img_size']
        # each file holds the photo and the map side by side:
        # the left half is the photo, the right half is the map
        photo = trans(Image.open(image))[:, :, :w]
        the_map = trans(Image.open(image))[:, :, w:]

        if args.direction == 'map2photo':
            cond = the_map
            real = photo
        else:
            cond = photo
            real = the_map

        conds_as_list.append(cond)
        real_as_list.append(real)

    conds_as_tensor = torch.stack(conds_as_list, dim=0).to(device)
    real_as_tensor = torch.stack(real_as_list, dim=0).to(device)

    if also_save:
        save_path = helper.compute_paths(args, params)['samples_path']
        helper.make_dir_if_not_exists(save_path)

        if 'conditions.png' not in os.listdir(save_path):
            utils.save_image(conds_as_tensor,
                             os.path.join(save_path, 'conditions.png'))
            print(f'In [create_rev_cond]: saved reverse conditions')

        if 'real.png' not in os.listdir(save_path):
            utils.save_image(real_as_tensor,
                             os.path.join(save_path, 'real.png'))
            print(f'In [create_rev_cond]: saved real images')

    print(
        f'In [create_rev_cond]: returning cond_as_tensor of shape: {conds_as_tensor.shape}'
    )
    return conds_as_tensor
Example No. 10
def init_and_load(args, params, run_mode):
    checkpoints_path = helper.compute_paths(args, params)['checkpoints_path']
    optim_step = args.last_optim_step
    model = init_model(args, params)

    if run_mode == 'infer':
        model, _, _, _ = helper.load_checkpoint(checkpoints_path, optim_step, model, None, resume_train=False)
        print(f'In [init_and_load]: returned model for inference')
        return model

    else:  # train
        optimizer = optim.Adam(model.parameters(), lr=params['lr'])
        print(f'In [init_and_load]: returned model and optimizer for training')
        model, optimizer, _, lr = helper.load_checkpoint(checkpoints_path, optim_step, model, optimizer, resume_train=True)
        return model, optimizer, lr
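Hypothetical call sites illustrating the two return shapes of init_and_load:

# inference: only the model is returned
model = init_and_load(args, params, run_mode='infer')

# resuming training: model, optimizer and the saved learning rate are returned
model, optimizer, lr = init_and_load(args, params, run_mode='train')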
Example No. 11
def prepare_city_reverse_cond(args, params, run_mode='train'):
    samples_path = helper.compute_paths(args, params)['samples_path']
    save_path = samples_path if run_mode == 'train' else None  # no need for reverse_cond at inference time
    direction = args.direction
    segmentations, _, real_imgs, boundaries = _create_cond(
        args,
        params,
        fixed_conds=real_conds_abs_path,
        save_path=save_path,
        direction=direction)
    return {
        'segment': segmentations,
        'boundary': boundaries,
        'real': real_imgs
    }
Example No. 12
def create_rev_cond(args, params, also_save=True):
    trans = util.init_transformer(params['img_size'])

    names = util.fixed_conds[args.direction]
    conds = [trans(Image.open(f"{params['paths']['data_folder']}/{name}")) for name in names]

    conditions = torch.stack(conds, dim=0).to(device)  # (B, C, H, W)

    if also_save:
        save_path = helper.compute_paths(args, params)['samples_path']
        helper.make_dir_if_not_exists(save_path)

        if 'conditions.png' not in os.listdir(save_path):
            utils.save_image(conditions, f'{save_path}/conditions.png')
            print(f'In [create_rev_cond]: saved reverse conditions')
    return conditions
Example No. 13
def compute_ssim_old(args, params):
    """
    images should be np arrays of shape (H, W, C). This function should be called only after the inference has been
    done on validation set. This function computes SSIM for the temperature specified in params['temperature'].
    :return:
    """
    # =========== init validation loader
    loader_params = {'batch_size': 1, 'shuffle': False, 'num_workers': 0}
    _, val_loader = data_handler.init_city_loader(
        data_folder=params['data_folder'],
        image_size=(params['img_size']),
        remove_alpha=True,  # removing the alpha channel
        loader_params=loader_params)
    print(
        f'In [compute_ssim]: init val data loader of len {len(val_loader)}: done \n'
    )

    # ============= computing the validation path for generated images
    if params['img_size'] != [256, 256]:
        raise NotImplementedError(
            'Now only supports 256x256 images'
        )  # should use paths['resized_path'] probably?

    paths = helper.compute_paths(args, params)
    val_path = paths['val_path']
    print(f'In [compute_ssim]: val_path to read from: \n"{val_path}" \n')

    # ============= SSIM calculation for all the images in validation set
    ssim_vals = []
    for i_batch, batch in enumerate(
            val_loader):  # for every single image (batch size 1)
        # reading reference (ground truth) and generated image
        ref_img = batch['real'].cpu().data.squeeze(dim=0).permute(
            1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
        pure_name = batch['real_path'][0].split('/')[
            -1]  # the file name of the image

        # compute corresponding path in generated images and read the image
        gen_img_path = val_path + f'/{pure_name}'
        gen_img = city_transforms(Image.open(gen_img_path)).cpu().data.permute(
            1, 2, 0).numpy()

        # ============= computing SSIM
        pixel_range = 1  # with data range 1 since pixels can take a value between 0 to 1
        ssim_val = ssim(ref_img,
                        gen_img,
                        multichannel=True,
                        data_range=pixel_range)  # expects (H, W, C) ordering
        ssim_vals.append(ssim_val)

        if i_batch % 100 == 0:
            print(f'In [compute_ssim]: evaluated {i_batch} images')

    # ssim_score = round(np.mean(ssim_vals), 2)
    ssim_score = np.mean(ssim_vals)
    print(f'In [compute_ssim]: ssim score on validation set: {ssim_score}')

    # ============= append result to ssim.txt
    eval_path_base = paths['eval_path_base']
    with open(f'{eval_path_base}/ssim.txt', 'a') as ssim_file:
        string = f'temp = {params["temperature"]}: {ssim_score} \n'
        ssim_file.write(string)
        print(f'In [compute_ssim]: ssim score appended to ssim.txt')
Example No. 14
def train(args,
          params,
          train_configs,
          model,
          optimizer,
          current_lr,
          comet_tracker=None,
          resume=False,
          last_optim_step=0,
          reverse_cond=None):
    # getting data loaders
    train_loader, val_loader = data_handler.init_data_loaders(args, params)

    # adjusting optim step
    optim_step = last_optim_step + 1 if resume else 1
    max_optim_steps = params['iter']
    paths = helper.compute_paths(args, params)

    if resume:
        print(
            f'In [train]: resuming training from optim_step={optim_step} - max_step: {max_optim_steps}'
        )

    # optimization loop
    while optim_step < max_optim_steps:
        # after each epoch, adjust learning rate accordingly
        current_lr = adjust_lr(
            current_lr,
            initial_lr=params['lr'],
            step=optim_step,
            epoch_steps=len(
                train_loader))  # note: currently assumes batch size 1
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr
        print(
            f'In [train]: optimizer learning rate adjusted to: {current_lr}\n')

        for i_batch, batch in enumerate(train_loader):
            if optim_step > max_optim_steps:
                print(
                    f'In [train]: reached max_step={max_optim_steps}. Terminating...'
                )
                return  # terminate training once max steps are reached

            begin_time = time.time()
            # forward pass
            left_batch, right_batch, extra_cond_batch = data_handler.extract_batches(
                batch, args)
            forward_output = forward_and_loss(args, params, model, left_batch,
                                              right_batch, extra_cond_batch)

            # regularize left loss
            if train_configs['reg_factor'] is not None:
                loss = train_configs['reg_factor'] * forward_output[
                    'loss_left'] + forward_output['loss_right']  # regularized
            else:
                loss = forward_output['loss']

            metrics = {'loss': loss}
            # also add left and right loss if available
            if 'loss_left' in forward_output.keys():
                metrics.update({
                    'loss_right': forward_output['loss_right'],
                    'loss_left': forward_output['loss_left']
                })

            # backward pass and optimizer step
            model.zero_grad()
            loss.backward()
            optimizer.step()
            print(f'In [train]: Step: {optim_step} => loss: {loss.item():.3f}')

            # validation loss
            if (params['monitor_val'] and optim_step % params['val_freq']
                    == 0) or current_lr == 0:
                val_loss_mean, _ = calc_val_loss(args, params, model,
                                                 val_loader)
                metrics['val_loss'] = val_loss_mean
                print(
                    f'====== In [train]: val_loss mean: {round(val_loss_mean, 3)}'
                )

            # tracking metrics
            if args.use_comet:
                for key, value in metrics.items():  # track all metric values
                    comet_tracker.track_metric(key, round(value.item(), 3),
                                               optim_step)

            # saving samples
            if (optim_step % params['sample_freq'] == 0) or current_lr == 0:
                samples_path = paths['samples_path']
                helper.make_dir_if_not_exists(samples_path)
                sampled_images = models.take_samples(args, params, model,
                                                     reverse_cond)
                utils.save_image(
                    sampled_images,
                    f'{samples_path}/{str(optim_step).zfill(6)}.png',
                    nrow=10)
                print(
                    f'\nIn [train]: Sample saved at iteration {optim_step} to: \n"{samples_path}"\n'
                )

            # saving checkpoint
            if (optim_step > 0 and optim_step % params['checkpoint_freq']
                    == 0) or current_lr == 0:
                checkpoints_path = paths['checkpoints_path']
                helper.make_dir_if_not_exists(checkpoints_path)
                helper.save_checkpoint(checkpoints_path, optim_step, model,
                                       optimizer, loss, current_lr)
                print("In [train]: Checkpoint saved at iteration", optim_step,
                      '\n')

            optim_step += 1
            end_time = time.time()
            print(f'Iteration took: {round(end_time - begin_time, 2)}')
            helper.show_memory_usage()
            print('\n')

            if current_lr == 0:
                print(
                    'In [train]: current_lr = 0, terminating the training...')
                sys.exit(0)
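adjust_lr, forward_and_loss and calc_val_loss are defined elsewhere in the trainer module. As an illustration only, one decay schedule consistent with adjust_lr's call signature and with the fact that training terminates once the returned rate reaches zero is a per-epoch linear decay; the total_epochs parameter and the exact schedule below are assumptions.

def adjust_lr_sketch(current_lr, initial_lr, step, epoch_steps, total_epochs=100):
    # hypothetical linear decay: reduce the rate after every completed epoch,
    # reaching exactly zero once total_epochs epochs have been run
    epochs_done = step // epoch_steps
    return max(0.0, initial_lr * (1 - epochs_done / total_epochs))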
Example No. 15
def infer_on_set(args, params, split_set='val'):
    """
    The model name and paths should match the names used by the resize_for_fcn function
    in the evaluation.third_party.prepare module.
    :param split_set:
    :param args:
    :param params:
    :return:
    """
    with torch.no_grad():
        paths = helper.compute_paths(args, params)
        checkpt_path, val_path, train_vis_path = paths[
            'checkpoints_path'], paths['val_path'], paths['train_vis']

        save_path = val_path if split_set == 'val' else train_vis_path
        save_path = helper.extend_path(save_path, args.sampling_round)

        print(
            f'In [infer_on_set]:\n====== checkpt_path: "{checkpt_path}" \n====== save_path: "{save_path}" \n'
        )
        helper.make_dir_if_not_exists(save_path)

        # init and load model
        model = models.init_and_load(args, params, run_mode='infer')
        print(
            'In [infer_on_set]: init model and load checkpoint: done'
        )

        # validation loader
        batch_size = params['batch_size']
        loader_params = {
            'batch_size': batch_size,
            'shuffle': False,
            'num_workers': 0
        }
        train_loader, val_loader = data_handler.init_city_loader(
            data_folder=params['data_folder'],
            image_size=(params['img_size']),
            loader_params=loader_params)
        loader = val_loader if split_set == 'val' else train_loader
        print('In [infer_on_set]: using loader of len:',
              len(loader))

        # inference on validation set
        print(
            f'In [infer_on_set]: starting inference on {split_set} set'
        )
        for i_batch, batch in enumerate(loader):
            img_batch = batch['real'].to(device)
            segment_batch = batch['segment'].to(device)
            boundary_batch = batch['boundary'].to(
                device) if args.use_bmaps else None
            real_paths = batch[
                'real_path']  # list: used to save samples with the same name as original images
            seg_paths = batch['segment_path']

            if args.direction == 'label2photo':
                reverse_cond = {
                    'segment': segment_batch,
                    'boundary': boundary_batch
                }
            elif args.direction == 'photo2label':  # 'photo2label'
                reverse_cond = {'real': img_batch}

            # sampling from model
            # using batch.shape[0] is safer than batch_size since the last batch may be different in size
            samples = models.take_samples(args,
                                          params,
                                          model,
                                          reverse_cond,
                                          n_samples=segment_batch.shape[0])
            # save inferred images separately
            paths_list = real_paths if args.direction == 'label2photo' else seg_paths
            helper.save_one_by_one(samples, paths_list, save_path)

            print(
                f'In [infer_on_set]: done for the {i_batch}th batch out of {len(loader)} batches (batch size: {segment_batch.shape[0]})'
            )
        print(
            f'In [infer_on_set]: all done. Inferred images can be found at: {save_path} \n'
        )
Example No. 16
def run_model(mode, args, params, hps, sess, model, conditions, tracker):
    sess.graph.finalize()

    # inference on validation set
    if mode == 'infer':
        # val_path = helper.compute_paths(args, params)['val_path']
        paths = helper.compute_paths(args, params)
        val_path = helper.extend_path(paths['val_path'], args.sampling_round)
        helper.make_dir_if_not_exists(val_path)

        take_sample(model,
                    conditions,
                    val_path,
                    mode,
                    direction=args.direction,
                    iteration=None)
        print(f'In [run_model]: validation samples saved to: "{val_path}"')

    # training
    else:
        paths = helper.compute_paths(args, params)
        checkpoint_path = paths['checkpoints_path']
        samples_path = paths['samples_path']
        helper.make_dir_if_not_exists(checkpoint_path)
        helper.make_dir_if_not_exists(samples_path)

        # compute_conditional_bpd(model, args.last_optim_step, hps)

        iteration = 0 if not args.resume_train else args.last_optim_step + 1
        while iteration <= hps.train_its:
            lr = hps.lr
            train_results = model.train(
                lr)  # returns [local_loss, bits_x_u, bits_x_o, bits_y]
            train_loss = round(train_results[0], 3)
            print(f'Step {iteration} - train loss: {train_loss}')

            # take sample
            if iteration % hps.sample_freq == 0:
                take_sample(model, conditions, samples_path, mode,
                            direction=args.direction, iteration=iteration)

            # track train loss
            if tracker is not None:
                tracker.track_metric('train_loss', round(train_loss, 3),
                                     iteration)

            # compute val loss
            if iteration % hps.val_freq == 0:
                test_results = []
                print('Computing validation loss...')

                for _ in range(hps.val_its):  # one loop over all validation examples
                    test_results += [model.test()]

                test_results = np.mean(np.asarray(test_results),
                                       axis=0)  # get the mean of val loss
                val_loss = round(test_results[0], 3)
                print(f'Step {iteration} - validation loss: {val_loss}')

                # save checkpoint
                path = os.path.join(checkpoint_path, f"step={iteration}.ckpt")
                model.save(path)
                print(f'Checkpoint saved to: "{path}"')

                # track val loss
                if tracker is not None:
                    tracker.track_metric('val_loss', round(val_loss, 3),
                                         iteration)

            iteration += 1
Example No. 17
def evaluate_iobb(args, params, img_name=None, img_disease=None):
    """
    Goes through every image in the BBox csv file and calculates IoBB.
    Results are saved to data/iou_blob_<model>.npy and .txt together with image name and disease type.

    :param img_name: If given, only load one image and plot bbox and heatmap.
    :param img_disease: Which disease to plot bbox for, used together with img_name.
    """
    pathologies = {
        'Atelectasis': 0,
        'Cardiomegaly': 1,
        'Effusion': 4,
        'Infiltrate': 8,
        'Mass': 9,
        'Nodule': 10,
        'Pneumonia': 12,
        'Pneumothorax': 13
    }

    ############################################
    ### Load model checkpoint and get heatmap.
    ############################################
    path_to_load = helper.compute_paths(args, params)['save_path']
    net = networks.init_unified_net(args.model, params)
    net, _ = helper.load_model(path_to_load,
                               args.epoch,
                               net,
                               optimizer=None,
                               resume_train=False)

    # Get bbox_data from csv file.
    f = open('../data/BBox_List_2017.csv', 'rt')
    reader = csv.reader(f)
    rows = list(reader)[1:]  # skip the header (title) row

    # A list of tuples (img_name, disease_index, iobb, num_bboxes)
    results = []

    for i, img_data in enumerate(rows):
        # Make sure image exists.
        file_path = f'../data/extracted/images/{img_data[0]}'
        if not os.path.isfile(file_path):
            continue

        # If only loading one image, check if this row contains the img and correct disease, otherwise continue.
        if img_name is not None:
            if img_data[0] != img_name:
                continue
            if img_disease is not None and pathologies[
                    img_data[1]] != img_disease:
                continue

        ############################################
        ### Load image and turn into tensor.
        ############################################
        xray_img = Image.open(file_path)
        rgb = Image.new('RGB', xray_img.size)
        rgb.paste(xray_img)

        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        img_tensor = preprocess(rgb)
        img_tensor.unsqueeze_(0)

        # Get heatmap for correct disease and turn into numpy array.
        heatmap = net.forward(img_tensor, return_cam=True)
        disease = pathologies[img_data[1]]
        heatmap = heatmap[0, disease].numpy()

        ground_truth = BBox(float(img_data[2]), float(img_data[3]),
                            float(img_data[4]), float(img_data[5]))

        # Save results if evaluating all images, not just plotting one.
        if img_name is None:
            iobb, num = evaluate_single_bbox(ground_truth, heatmap, iobb=True)
            results.append((img_data[0], disease, iobb, num))
            print(f'#{i}, n:{num}, {iobb}')
        else:
            iobb, _ = evaluate_single_bbox(ground_truth,
                                           heatmap,
                                           iobb=True,
                                           xray_img=xray_img)
            print(f'iobb: {iobb}')
            break

    if img_name is None:
        # Order results by IoBB value.
        results = sorted(results, key=lambda x: x[2], reverse=True)

        results = np.array(results)

        # Save as numpy array.
        np.save(f'../data/iou_blob_{args.model}.npy', results)

        # Save as txt
        with open(f'../data/iou_blob_{args.model}.txt', 'w') as txt:
            for i in range(results.shape[0]):
                txt.write(
                    f'{float(results[i][2])}, {int(results[i][1])}, {int(results[i][3])}, {results[i][0]}\n'
                )

    f.close()  # Close csv file.
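evaluate_single_bbox is project code not shown here. The sketch below illustrates one common way to compute IoBB (intersection over the detected bounding box) from a ground-truth box and a CAM heatmap: threshold the heatmap, take the tight box around the activated region, and divide the intersection area by the detected box area. The threshold, the choice of denominator, and the assumption that both inputs live on the same pixel grid are all assumptions, not the project's actual method.

import numpy as np

def iobb_sketch(gt_box, heatmap, thresh_ratio=0.5):
    # gt_box: (x, y, w, h); heatmap: 2D array on the same pixel grid as gt_box
    mask = heatmap >= thresh_ratio * heatmap.max()
    ys, xs = np.where(mask)
    if len(xs) == 0:
        return 0.0
    # tight box around the activated region
    dx0, dy0, dx1, dy1 = xs.min(), ys.min(), xs.max() + 1, ys.max() + 1
    gx0, gy0 = gt_box[0], gt_box[1]
    gx1, gy1 = gt_box[0] + gt_box[2], gt_box[1] + gt_box[3]
    # intersection area divided by the area of the detected box
    inter_w = max(0, min(dx1, gx1) - max(dx0, gx0))
    inter_h = max(0, min(dy1, gy1) - max(dy0, gy0))
    det_area = (dx1 - dx0) * (dy1 - dy0)
    return (inter_w * inter_h) / det_area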
Example No. 18
def prep_for_sampling(args, params, img_name, additional_info):
    """
    :param args:
    :param params:
    :param img_name: the (path of the) real image whose segmentation will be used for conditioning.
    :param additional_info:
    :return:
    """
    # ========== specifying experiment path
    paths = helper.compute_paths(args, params, additional_info)
    if additional_info['exp_type'] == 'random_samples':
        experiment_path = paths['random_samples_path']

    elif additional_info['exp_type'] == 'new_cond':
        experiment_path = paths['new_cond_path']

    else:
        raise NotImplementedError

    # ========== make the condition a single image
    fixed_conds = [img_name]
    # ========== create condition and save it to experiment path
    # no need to save for new_cond type
    path_to_save = None if additional_info[
        'exp_type'] == 'new_cond' else experiment_path
    seg_batch, _, real_batch, boundary_batch = data_handler._create_cond(
        params, fixed_conds=fixed_conds,
        save_path=path_to_save)  # (1, C, H, W)
    # ========== duplicate condition for n_samples times (not used by all exp_modes)
    seg_batch_dup = seg_batch.repeat(
        (params['n_samples'], 1, 1, 1))  # duplicated: (n_samples, C, H, W)
    boundary_dup = boundary_batch.repeat((params['n_samples'], 1, 1, 1))
    real_batch_dup = real_batch.repeat((params['n_samples'], 1, 1, 1))

    if additional_info['exp_type'] == 'random_samples':
        seg_rev_cond = seg_batch_dup  # (n_samples, C, H, W) - duplicate for random samples
        bmap_rev_cond = boundary_dup
        real_rev_cond = real_batch_dup

    elif additional_info['exp_type'] == 'new_cond':
        seg_rev_cond = seg_batch  # (1, C, H, W) - no duplicate needed
        bmap_rev_cond = boundary_batch

    else:
        raise NotImplementedError

    # ========== create reverse cond
    # photo2label conditions on the real image; label2photo conditions on the segmentation
    # (plus boundary maps if requested). Note: real_rev_cond is only prepared for 'random_samples'.
    if args.direction == 'photo2label':
        rev_cond = {'real': real_rev_cond}

    elif args.use_bmaps:
        rev_cond = {'segment': seg_rev_cond, 'boundary': bmap_rev_cond}

    else:
        rev_cond = {'segment': seg_rev_cond, 'boundary': None}

    # ========== specifying paths for saving samples
    if additional_info['exp_type'] == 'random_samples':
        exp_path = paths['random_samples_path']
        save_paths = [
            f'{exp_path}/sample {i + 1}.png'
            for i in range(params['n_samples'])
        ]

    elif additional_info['exp_type'] == 'new_cond':
        save_paths = experiment_path
    else:
        raise NotImplementedError
    return save_paths, rev_cond, real_batch