def __init__(self,
                 dataset_paths,
                 model_paths,
                 eval_refine=False,
                 eval_pretrained=False,
                 perform_adjustment=True,
                 evaluation_path=None,
                 imagenet_path=None):
        """Set up the evaluation pipeline: dataset, loader and frozen nets.

        Args:
            dataset_paths: dataset descriptors forwarded to ``Dataset``.
            model_paths: checkpoint path(s) for the disparity/refine nets,
                or None to keep the freshly constructed weights.
            eval_refine: NOTE(review): accepted but never read in this
                method -- confirm it is used elsewhere or drop it.
            eval_pretrained: stored on the instance for later use.
            perform_adjustment: stored on the instance for later use.
            evaluation_path: NOTE(review): accepted but never read here.
            imagenet_path: forwarded to ``Dataset``.
        """
        self.dataset = Dataset(dataset_paths, imagenet_path=imagenet_path)
        self.dataset.mode = 'eval'  # evaluation-specific dataset behaviour
        self.dataset.padding = True
        # Evaluate one sample at a time.
        self.data_loader = self.dataset.get_dataloader(batch_size=1)
        self.perform_adjustment = perform_adjustment
        self.dataset_paths = dataset_paths

        # Seed both RNGs *before* building the models so any random weight
        # initialization below is reproducible.
        torch.manual_seed(42)
        np.random.seed(42)

        # All networks are frozen (eval mode) for evaluation.
        self.moduleSemantics = Semantics().to(device).eval()
        self.moduleDisparity = Disparity().to(device).eval()
        self.moduleMaskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(
            pretrained=True).to(device).eval()
        self.moduleRefine = Refine().to(device).eval()

        self.eval_pretrained = eval_pretrained

        if model_paths is not None:
            # Restore disparity and refinement weights from checkpoints.
            models_list = [{
                'model': self.moduleDisparity,
                'type': 'disparity'
            }, {
                'model': self.moduleRefine,
                'type': 'refine'
            }]
            load_models(models_list, model_paths)
# --- Example #2 ---
 def __init__(self, config=None):
     """Prepare DialogRE datasets, vocabulary, hyper-params and the model.

     NOTE(review): the default ``config=None`` is not actually usable --
     ``config.__dict__`` below would raise AttributeError. Confirm that
     all callers pass an args/config object.
     """
     super().__init__()
     self.config = config
     # Standard train/val/test splits plus the shared vocabulary.
     trn_data, val_data, tst_data, self.vocab = load_dataset()
     self.ds_trn, self.ds_val, self.ds_tst = Dataset(
         trn_data,
         self.vocab), Dataset(val_data,
                              self.vocab), Dataset(tst_data, self.vocab)
     # "_c" variants come from a separate dataset loader -- presumably the
     # DialogRE-c evaluation sets; verify against load_dataset_c.
     val_data_c, test_data_c = load_dataset_c()
     self.ds_val_c, self.ds_tst_c = Dataset(val_data_c,
                                            self.vocab), Dataset(
                                                test_data_c, self.vocab)
     # Copy config entries into hparams, skipping dunders and module-level
     # junk that sneaks into an argparse-derived namespace.
     exclude = [
         'np', 'torch', 'random', 'args', 'os', 'argparse', 'parser',
         'Namespace', 'sys'
     ]
     self.hparams = Namespace(
         **{
             k: v
             for k, v in config.__dict__.items()
             if k[:2] != '__' and k not in exclude
         })
     self.model = Graph_DialogRe(config, self.vocab)
     # Logits-based loss: the model is expected to output raw scores.
     self.loss_fn = nn.BCEWithLogitsLoss()  # nn.BCELoss()
     self.f1_metric = f1_score
     self.f1c_metric = f1c_score
     self.acc_metric = acc_score
def main():
    """
    Entry point: initialize Horovod, build the model and dataset, then
    dispatch to train/evaluate/predict according to ``exec_mode``.
    """
    hvd.init()

    params = parse_args(PARSER.parse_args())
    set_flags(params)
    params.model_dir = prepare_model_dir(params)
    logger = get_logger(params)

    model = Unet()
    dataset = Dataset(data_dir=params.data_dir,
                      batch_size=params.batch_size,
                      fold=params.crossvalidation_idx,
                      augment=params.augment,
                      gpu_id=hvd.rank(),
                      num_gpus=hvd.size(),
                      seed=params.seed)

    if 'train' in params.exec_mode:
        train(params, model, dataset, logger)

    # Evaluation and prediction only run on the root rank.
    if hvd.rank() == 0:
        if 'evaluate' in params.exec_mode:
            evaluate(params, model, dataset, logger)
        if 'predict' in params.exec_mode:
            predict(params, model, dataset, logger)
def main():
    """
    Entry point: export a trained UNet checkpoint to the target format
    selected by ``--to`` (savedmodel, tensorrt or onnx).
    """
    flags = PARSER.parse_args()

    if flags.to == 'savedmodel':
        to_savedmodel(input_shape=flags.input_shape,
                      model_fn=unet_fn,
                      src_dir=flags.checkpoint_dir,
                      dst_dir='./saved_model',
                      input_names=['IteratorGetNext'],
                      output_names=['total_loss_ref'],
                      use_amp=flags.use_amp,
                      use_xla=flags.use_xla,
                      compress=flags.compress)

    if flags.to == 'tensorrt':
        # A single un-augmented test sample is enough to feed TF-TRT.
        dataset = Dataset(data_dir=flags.data_dir,
                          batch_size=1,
                          augment=False,
                          gpu_id=0,
                          num_gpus=1,
                          seed=42)
        sample = dataset.test_fn(count=1).make_one_shot_iterator().get_next()

        sess = tf.Session()

        def feed_fn():
            return {'input_tensor:0': sess.run(sample)}

        to_tensorrt(src_dir=flags.savedmodel_dir,
                    dst_dir='./tf_trt_model',
                    precision=flags.precision,
                    feed_dict_fn=feed_fn,
                    num_runs=1,
                    output_tensor_names=['Softmax:0'],
                    compress=flags.compress)

    if flags.to == 'onnx':
        to_onnx(src_dir=flags.savedmodel_dir,
                dst_dir='./onnx_model',
                compress=flags.compress)
# --- Example #5 ---
    def __init__(self, dataset_paths, models_paths=None, partial_conv=False):
        """Build the inpainting evaluation setup: dataset, model, loader.

        The inpainting network is left in train mode; weights are restored
        from ``models_paths`` when provided.
        """
        self.dataset_paths = dataset_paths
        self.dataset = Dataset(dataset_paths, mode='inpaint-eval')
        self.dataset_length = len(self.dataset)

        # Choose the inpainting backbone: partial convolutions or default.
        inpaint_cls = PartialInpaint if partial_conv else Inpaint
        self.moduleInpaint = inpaint_cls().to(device).train()

        if models_paths is not None:
            print('Loading model state from ' + str(models_paths))
            load_models([{'model': self.moduleInpaint, 'type': 'inpaint'}],
                        models_paths)

        self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                       batch_size=1,
                                                       shuffle=True,
                                                       pin_memory=True,
                                                       num_workers=1)
    def __init__(self, params):
        """Set up the Horovod-aware Estimator, dataset and training hooks.

        ``params`` is a dict carrying data/model directories, batch size,
        augmentation, benchmark and seeding options.
        """
        hvd.init()

        LOGGER.log(str(params))

        self._model_dir = params['model_dir']
        self._max_steps = params['max_steps']

        # Only the root rank writes checkpoints; other ranks never save.
        checkpoint_steps = self._max_steps if hvd.rank() == 0 else None
        run_config = tf.estimator.RunConfig(
            tf_random_seed=None,
            session_config=self._get_session_config(),
            save_checkpoints_steps=checkpoint_steps,
            keep_checkpoint_max=1)

        self._classifier = tf.estimator.Estimator(
            model_fn=_model_fn,
            model_dir=self._model_dir,
            params=params,
            config=run_config)

        self._dataset = Dataset(data_dir=params['data_dir'],
                                batch_size=params['batch_size'],
                                augment=params['augment'],
                                gpu_id=hvd.rank(),
                                num_gpus=hvd.size(),
                                seed=params['seed'])

        # Broadcast initial variables from rank 0 to every worker.
        self._training_hooks = [hvd.BroadcastGlobalVariablesHook(0)]

        if params['benchmark'] and hvd.rank() == 0:
            self._training_hooks.append(
                ProfilerHook(self._model_dir,
                             params['batch_size'],
                             log_every=params['log_every'],
                             warmup_steps=params['warmup_steps']))
# --- Example #7 ---
def main(argv):
    """Train the audio classifier with a WGAN-style generator/discriminator
    pair for sample augmentation, then dump train-set predictions.

    All settings are read from the module-level ``FLAGS``; ``argv`` is
    unused (absl-style entry point).

    NOTE(review): ``task_id``, ``test_data_dir``, ``train_data_dir``,
    ``weight_decay`` and ``n_val`` are assigned but never used below --
    confirm before removing.
    """
    # SLURM array jobs expose their index via the environment; fall back
    # to 0 when running outside SLURM.
    try:

        task_id = int(os.environ['SLURM_ARRAY_TASK_ID'])

    except KeyError:

        task_id = 0


    model_save_dir = FLAGS.model_dir
    data_dir = FLAGS.data_dir
    print("Saving model to : " + str(model_save_dir))
    print("Loading data from : " + str(data_dir))
    test_data_dir = data_dir
    train_data_dir = data_dir
    epochs = FLAGS.epochs
    batch_size = FLAGS.batch_size
    dropout_rate = FLAGS.dropout_rate
    weight_decay = FLAGS.weight_decay
    lr = FLAGS.learning_rate
    load_model = FLAGS.load_model
    training_percentage = FLAGS.training_percentage
    preload_samples = FLAGS.preload_samples

    ds = Dataset(data_dir,is_training_set = True)
    n_total = ds.n_samples

    # Bind ds.n_classes into the two-argument augmentation callback
    # expected by DataGenerator.
    def augment_fn(sample,training):
        return augment_input(sample,ds.n_classes,training)

    dg = DataGenerator(ds,augment_fn,
                             training_percentage = training_percentage,
                             preload_samples = preload_samples,
                             save_created_features = False,
                             max_samples_per_audio = 99,
                             is_training=True)

    # Train/validation split sizes (percentage-based).
    n_train = int(n_total*training_percentage/100)
    n_val = n_total-n_train

    #ResNet 18
    classifier_model = Classifier(ResBlockBasicLayer,
                 n_blocks = 4,
                 n_layers = [2,2,2,2],
                 strides = [2,2,2,2],
                 channel_base = [64,128,256,512],
                 n_classes = ds.n_classes+1,
                 init_ch = 64,
                 init_ksize = 7,
                 init_stride = 2,
                 use_max_pool = True,
                 kernel_regularizer = tf.keras.regularizers.l2(2e-4),
                 kernel_initializer = tf.keras.initializers.he_normal(),
                 name = "classifier",
                 dropout=dropout_rate)
    #Generator model used to augment to false samples
    generator_model = Generator(8,
                                [8,8,16,16,32,32,64,64],
                                kernel_regularizer = tf.keras.regularizers.l2(2e-4),
                                kernel_initializer = tf.keras.initializers.he_normal(),
                                name = "generator")

    #Discriminator for estimating the Wasserstein distance
    discriminator_model = Discriminator(3,
                                        [32,64,128],
                                        [4,4,4],
                                        name = "discriminator")

    data_gen = data_generator(dg.generate_all_samples_from_scratch,batch_size,
                                    is_training=True,
                                    n_classes = ds.n_classes)

    # One optimizer per sub-model; generator/discriminator learning rates
    # are scaled down relative to the classifier's.
    trainer = ModelTrainer(data_gen,
                    None,
                    None,
                    epochs,
                    EvalFunctions,
                    model_settings = [{'model':classifier_model,
                               'optimizer_type':tf.keras.optimizers.SGD,
                               'base_learning_rate':lr,
                               'learning_rate_fn':learning_rate_fn,
                               'init_data':tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS])},
                              {'model':generator_model,
                               'optimizer_type':tf.keras.optimizers.Adam,
                               'base_learning_rate':lr*0.0001,
                               'learning_rate_fn':learning_rate_fn,
                               'init_data':tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS])},
                              {'model':discriminator_model,
                               'optimizer_type':tf.keras.optimizers.Adam,
                               'base_learning_rate':lr*0.002,
                               'learning_rate_fn':learning_rate_fn,
                               'init_data':tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS])}],
                    summaries = None,
                    num_train_batches = int(n_train/batch_size),
                    load_model = load_model,
                    save_dir = model_save_dir,
                    input_keys = ["input_features","false_sample"],
                    label_keys = ["labels"],
                    start_epoch = 0)

    # Run inference over the training generator and persist the results.
    all_predictions = trainer.predict_dataset(data_gen)
    np.save(os.path.join(data_dir,"train_set_predictions.npy"),all_predictions)
# --- Example #8 ---
    def __init__(self,
                 dataset_paths,
                 training_params,
                 models_paths=None,
                 logs_path='runs/train_0',
                 continue_training=False):
        """Build the depth-training setup: data split/loaders, networks,
        optimizers, schedulers and the tensorboard writer.

        Args:
            dataset_paths: dataset descriptors forwarded to ``Dataset``.
            training_params: dict of hyper-parameters (batch_size,
                lr_estimation, lr_refine, gamma_lr, model_to_train, ...).
            models_paths: checkpoint path(s) to restore from, or None.
            logs_path: tensorboard run directory.
            continue_training: when True, resume iteration count from the
                restored checkpoint.
        """
        self.iter_nb = 0
        self.dataset_paths = dataset_paths
        self.training_params = training_params

        self.dataset = Dataset(
            dataset_paths,
            imagenet_path=self.training_params['mask_loss_path'])

        # Seed before random_split and weights_init so the data split and
        # weight initialization are reproducible.
        torch.manual_seed(111)
        np.random.seed(42)

        # Create training and validation set randomly
        dataset_length = len(self.dataset)
        train_set_length = int(0.99 * dataset_length)
        validation_set_length = dataset_length - train_set_length

        self.training_set, self.validation_set = torch.utils.data.random_split(
            self.dataset, [train_set_length, validation_set_length])

        self.data_loader = torch.utils.data.DataLoader(
            self.training_set,
            batch_size=self.training_params['batch_size'],
            shuffle=True,
            pin_memory=True,
            num_workers=2)

        self.data_loader_validation = torch.utils.data.DataLoader(
            self.validation_set,
            batch_size=self.training_params['batch_size'],
            shuffle=True,
            pin_memory=True,
            num_workers=2)

        self.moduleSemantics = Semantics().to(device).eval()
        self.moduleDisparity = Disparity().to(device).eval()

        weights_init(self.moduleDisparity)

        # Mask R-CNN is only used frozen, to produce instance masks.
        self.moduleMaskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(
            pretrained=True).to(device).eval()

        self.optimizer_disparity = torch.optim.Adam(
            self.moduleDisparity.parameters(),
            lr=self.training_params['lr_estimation'])

        # Exponential LR decay: lr * gamma_lr**epoch.
        lambda_lr = lambda epoch: self.training_params['gamma_lr']**epoch
        self.scheduler_disparity = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_disparity, lr_lambda=lambda_lr)

        # The refinement net (and its optimizer/scheduler) only exist when
        # it is part of this training run.
        if self.training_params[
                'model_to_train'] == 'refine' or self.training_params[
                    'model_to_train'] == 'both':
            self.moduleRefine = Refine().to(device).eval()
            weights_init(self.moduleRefine)
            self.optimizer_refine = torch.optim.Adam(
                self.moduleRefine.parameters(),
                lr=self.training_params['lr_refine'])
            self.scheduler_refine = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer_refine, lr_lambda=lambda_lr)

        if models_paths is not None:
            print('Loading model state from ' + str(models_paths))
            if self.training_params[
                    'model_to_train'] == 'refine' or self.training_params[
                        'model_to_train'] == 'both':
                models_list = [{
                    'model': self.moduleDisparity,
                    'type': 'disparity',
                    'opt': self.optimizer_disparity,
                    'schedule': self.scheduler_disparity
                }, {
                    'model': self.moduleRefine,
                    'type': 'refine',
                    'opt': self.optimizer_refine,
                    'schedule': self.scheduler_refine
                }]
            else:
                models_list = [{
                    'model': self.moduleDisparity,
                    'type': 'disparity',
                    'opt': self.optimizer_disparity,
                    'schedule': self.scheduler_disparity
                }]
            self.iter_nb = load_models(models_list,
                                       models_paths,
                                       continue_training=continue_training)

        # use tensorboard to keep track of the runs
        self.writer = CustomWriter(logs_path)
# --- Example #9 ---
class TrainerDepth():
    """Trainer for the depth (disparity) estimation and refinement nets.

    Owns the dataset split, data loaders, the frozen helper networks
    (semantics, Mask R-CNN), the trainable disparity/refine networks with
    their Adam optimizers and exponential LR schedulers, and a tensorboard
    writer. ``training_params['model_to_train']`` selects which network(s)
    are trained: 'disparity', 'refine' or 'both'.
    """

    def __init__(self,
                 dataset_paths,
                 training_params,
                 models_paths=None,
                 logs_path='runs/train_0',
                 continue_training=False):
        """Build datasets, loaders, models, optimizers and the writer.

        Args:
            dataset_paths: dataset descriptors forwarded to ``Dataset``.
            training_params: dict of hyper-parameters (batch_size,
                lr_estimation, lr_refine, gamma_lr, model_to_train, ...).
            models_paths: checkpoint path(s) to restore from, or None.
            logs_path: tensorboard run directory.
            continue_training: resume iteration count from checkpoint.
        """
        self.iter_nb = 0
        self.dataset_paths = dataset_paths
        self.training_params = training_params

        self.dataset = Dataset(
            dataset_paths,
            imagenet_path=self.training_params['mask_loss_path'])

        # Seed before random_split and weights_init for reproducibility.
        torch.manual_seed(111)
        np.random.seed(42)

        # Create training and validation set randomly
        dataset_length = len(self.dataset)
        train_set_length = int(0.99 * dataset_length)
        validation_set_length = dataset_length - train_set_length

        self.training_set, self.validation_set = torch.utils.data.random_split(
            self.dataset, [train_set_length, validation_set_length])

        self.data_loader = torch.utils.data.DataLoader(
            self.training_set,
            batch_size=self.training_params['batch_size'],
            shuffle=True,
            pin_memory=True,
            num_workers=2)

        self.data_loader_validation = torch.utils.data.DataLoader(
            self.validation_set,
            batch_size=self.training_params['batch_size'],
            shuffle=True,
            pin_memory=True,
            num_workers=2)

        self.moduleSemantics = Semantics().to(device).eval()
        self.moduleDisparity = Disparity().to(device).eval()

        weights_init(self.moduleDisparity)

        # Mask R-CNN is used frozen, only to produce instance masks.
        self.moduleMaskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(
            pretrained=True).to(device).eval()

        self.optimizer_disparity = torch.optim.Adam(
            self.moduleDisparity.parameters(),
            lr=self.training_params['lr_estimation'])

        # Exponential LR decay: lr * gamma_lr**epoch.
        lambda_lr = lambda epoch: self.training_params['gamma_lr']**epoch
        self.scheduler_disparity = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_disparity, lr_lambda=lambda_lr)

        # The refinement net only exists when it is part of this run.
        if self.training_params[
                'model_to_train'] == 'refine' or self.training_params[
                    'model_to_train'] == 'both':
            self.moduleRefine = Refine().to(device).eval()
            weights_init(self.moduleRefine)
            self.optimizer_refine = torch.optim.Adam(
                self.moduleRefine.parameters(),
                lr=self.training_params['lr_refine'])
            self.scheduler_refine = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer_refine, lr_lambda=lambda_lr)

        if models_paths is not None:
            print('Loading model state from ' + str(models_paths))
            if self.training_params[
                    'model_to_train'] == 'refine' or self.training_params[
                        'model_to_train'] == 'both':
                models_list = [{
                    'model': self.moduleDisparity,
                    'type': 'disparity',
                    'opt': self.optimizer_disparity,
                    'schedule': self.scheduler_disparity
                }, {
                    'model': self.moduleRefine,
                    'type': 'refine',
                    'opt': self.optimizer_refine,
                    'schedule': self.scheduler_refine
                }]
            else:
                models_list = [{
                    'model': self.moduleDisparity,
                    'type': 'disparity',
                    'opt': self.optimizer_disparity,
                    'schedule': self.scheduler_disparity
                }]
            self.iter_nb = load_models(models_list,
                                       models_paths,
                                       continue_training=continue_training)

        # use tensorboard to keep track of the runs
        self.writer = CustomWriter(logs_path)

    def train(self):
        """Dispatch to the right training loop per ``model_to_train`` and
        record the hyper-parameters in tensorboard when done."""

        print(
            'Starting training of estimation net on datasets: ',
            functools.reduce(lambda s1, s2: s1 + '\n ' + s2['path'],
                             self.dataset_paths, ""))

        # Switch the trainable module(s) into train mode.
        if self.training_params['model_to_train'] == 'disparity':
            print('Training disparity estimation network.')
            self.moduleDisparity.train()
        elif self.training_params['model_to_train'] == 'refine':
            print('Training disparity refinement network.')
            self.moduleRefine.train()
        elif self.training_params['model_to_train'] == 'both':
            print('Training disparity networks.')
            self.moduleDisparity.train()
            self.moduleRefine.train()

        if self.training_params[
                'model_to_train'] == 'refine' or self.training_params[
                    'model_to_train'] == 'both':
            self.train_refine()

        elif self.training_params['model_to_train'] == 'disparity':
            self.train_estimation()

        self.writer.add_hparams(self.training_params, {})

    def train_estimation(self):
        """Training loop for the disparity estimation network.

        Per batch: forward through semantics + disparity nets, combine an
        ordinal loss, a gradient loss and (optionally) a mask-based loss
        with iteration-dependent weights, then step optimizer/scheduler.
        Checkpoints and runs validation every 500 batches.
        """
        for epoch in range(self.training_params['n_epochs']):
            for idx, (tensorImage, GTdisparities, sparseMask, imageNetTensor,
                      dataset_ids) in enumerate(
                          tqdm(self.data_loader,
                               desc='Epoch %d/%d' %
                               (epoch + 1, self.training_params['n_epochs']))):
                # Periodic checkpoint + validation pass.
                if ((idx + 1) % 500) == 0:
                    save_model(
                        {
                            'disparity': {
                                'model': self.moduleDisparity,
                                'opt': self.optimizer_disparity,
                                'schedule': self.scheduler_disparity,
                                'save_name': self.training_params['save_name']
                            }
                        }, self.iter_nb)
                    self.validation()

                tensorImage = tensorImage.to(device, non_blocking=True)
                GTdisparities = GTdisparities.to(device, non_blocking=True)
                sparseMask = sparseMask.to(device, non_blocking=True)
                imageNetTensor = imageNetTensor.to(device, non_blocking=True)

                # Semantics net is frozen; no gradients needed.
                with torch.no_grad():
                    semantic_tensor = self.moduleSemantics(tensorImage)

                # forward pass
                tensorDisparity = self.moduleDisparity(
                    tensorImage, semantic_tensor)  # depth estimation
                # Clamp negative disparities to zero.
                tensorDisparity = F.threshold(tensorDisparity,
                                              threshold=0.0,
                                              value=0.0)

                # reconstruction loss computation
                estimation_loss_ord = compute_loss_ord(tensorDisparity,
                                                       GTdisparities,
                                                       sparseMask,
                                                       mode='logrmse')
                estimation_loss_grad = compute_loss_grad(
                    tensorDisparity, GTdisparities, sparseMask)

                # loss weights computation: annealed with iteration count
                beta = 0.015
                gamma_ord = 0.03 * (1 + 2 * np.exp(-beta * self.iter_nb)
                                    )  # for scale-invariant Loss
                # gamma_ord = 0.001 * (1+ 200 * np.exp( - beta * self.iter_nb)) # for L1 loss
                gamma_grad = 1 - np.exp(-beta * self.iter_nb)
                gamma_mask = 0.0001 * (1 - np.exp(-beta * self.iter_nb))

                if self.training_params['mask_loss'] == 'same':
                    # when mask_loss is 'same' masks are computed on the same images
                    with torch.no_grad():
                        objectPredictions = self.moduleMaskrcnn(tensorImage)

                    masks_tensor_list = list(
                        map(
                            lambda object_pred: resize_image(
                                object_pred['masks'], max_size=256),
                            objectPredictions))
                    estimation_masked_loss = 0
                    for i, masks_tensor in enumerate(masks_tensor_list):
                        if masks_tensor is not None:
                            estimation_masked_loss += compute_masked_grad_loss(
                                tensorDisparity[i].view(
                                    1, *tensorDisparity[i].shape),
                                masks_tensor, [1], 0.5)

                    loss_depth = gamma_ord * estimation_loss_ord + gamma_grad * estimation_loss_grad + gamma_mask * estimation_masked_loss

                else:  # No mask loss in this case
                    loss_depth = gamma_ord * estimation_loss_ord + gamma_grad * estimation_loss_grad

                # compute gradients and update net
                self.optimizer_disparity.zero_grad()
                loss_depth.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.moduleDisparity.parameters(), 1)
                self.optimizer_disparity.step()
                self.scheduler_disparity.step()

                # keep track of loss values
                self.writer.add_scalar('Estimation/Loss ord',
                                       estimation_loss_ord, self.iter_nb)
                self.writer.add_scalar('Estimation/Loss grad',
                                       estimation_loss_grad, self.iter_nb)
                self.writer.add_scalar('Estimation/Loss depth', loss_depth,
                                       self.iter_nb)

                if self.training_params['mask_loss'] == 'same':
                    self.writer.add_scalar('Estimation/Loss mask',
                                           estimation_masked_loss,
                                           self.iter_nb)
                elif self.training_params['mask_loss'] == 'other':
                    self.step_imagenet(
                        imageNetTensor
                    )  # when mask loss is computed on another dataset
                else:
                    self.writer.add_scalar('Estimation/Loss mask', 0,
                                           self.iter_nb)

                # keep track of gradient magnitudes
                # for i, m in enumerate(self.moduleDisparity.modules()):
                #     if m.__class__.__name__ == 'Conv2d':
                #         g = m.weight.grad
                #         # print(g)
                #         if g is not None:
                #             self.writer.add_scalar('Estimation gradients/Conv {}'.format(i), torch.norm(g/g.size(0), p=1).item(), self.iter_nb)

                self.iter_nb += 1

            self.validation()

    def train_refine(self):
        """Training loop for the refinement network (optionally updating
        the disparity net too when training 'both').

        First estimates low-resolution disparity without gradients, then
        refines it at full resolution and optimizes ordinal + gradient
        losses. Checkpoints and validates every 500 batches.
        """

        # Refinement uses a different dataset mode and batch size.
        self.dataset.mode = 'refine'
        self.data_loader = self.dataset.get_dataloader(batch_size=2)

        for epoch in range(self.training_params['n_epochs']):
            for idx, (tensorImage, GTdisparities, masks, imageNetTensor,
                      dataset_ids) in enumerate(
                          tqdm(self.data_loader,
                               desc='Epoch %d/%d' %
                               (epoch + 1, self.training_params['n_epochs']))):
                # Periodic checkpoint + validation pass.
                if ((idx + 1) % 500) == 0:
                    save_model(
                        {
                            'refine': {
                                'model': self.moduleRefine,
                                'opt': self.optimizer_refine,
                                'schedule': self.scheduler_refine,
                                'save_name': self.training_params['save_name']
                            }
                        }, self.iter_nb)
                    self.validation(refine_training=True)

                tensorImage = tensorImage.to(device, non_blocking=True)
                GTdisparities = GTdisparities.to(device, non_blocking=True)
                masks = masks.to(device, non_blocking=True)

                # first estimate depth with estimation net
                with torch.no_grad():
                    tensorResized = resize_image(tensorImage, max_size=512)
                    tensorDisparity = self.moduleDisparity(
                        tensorResized, self.moduleSemantics(
                            tensorResized))  # depth estimation
                    tensorResized = None

                # forward pass with refinement net
                tensorDisparity = self.moduleRefine(tensorImage,
                                                    tensorDisparity)

                # compute losses
                refine_loss_ord = compute_loss_ord(tensorDisparity,
                                                   GTdisparities, masks)
                refine_loss_grad = compute_loss_grad(tensorDisparity,
                                                     GTdisparities, masks)

                loss_depth = 0.0001 * refine_loss_ord + refine_loss_grad

                if self.training_params['model_to_train'] == 'both':
                    self.optimizer_disparity.zero_grad()

                # backward pass
                self.optimizer_refine.zero_grad()
                loss_depth.backward()
                torch.nn.utils.clip_grad_norm_(self.moduleRefine.parameters(),
                                               1)
                self.optimizer_refine.step()
                self.scheduler_refine.step()

                # NOTE(review): in 'both' mode the disparity net was run
                # under no_grad above, so this step relies on gradients
                # flowing from the refine forward pass only -- confirm.
                if self.training_params['model_to_train'] == 'both':
                    self.optimizer_disparity.step()

                ## keep track of loss on tensorboard
                self.writer.add_scalar('Refine/Loss ord', refine_loss_ord,
                                       self.iter_nb)
                self.writer.add_scalar('Refine/Loss grad', refine_loss_grad,
                                       self.iter_nb)
                self.writer.add_scalar('Refine/Loss depth', loss_depth,
                                       self.iter_nb)

                ## keep track of gradient magnitudes
                # for i, m in enumerate(self.moduleRefine.modules()):
                #     if m.__class__.__name__ == 'Conv2d':
                #         g = m.weight.grad.view(-1)
                #         if g is not None:
                #             self.writer.add_scalar('Refine gradients/Conv {}'.format(i), torch.norm(g/g.size(0), p=1).item(), self.iter_nb)

                self.iter_nb += 1

    def step_imagenet(self, tensorImage):
        """One disparity-net update using only the Mask R-CNN mask loss,
        computed on a batch from a different (ImageNet) dataset."""
        with torch.no_grad():
            semantic_tensor = self.moduleSemantics(tensorImage)

        tensorDisparity = self.moduleDisparity(
            tensorImage, semantic_tensor)  # depth estimation

        # compute segmentation masks on batch
        with torch.no_grad():
            objectPredictions = self.moduleMaskrcnn(tensorImage)

        masks_tensor_list = list(
            map(
                lambda object_pred: resize_image(object_pred['masks'],
                                                 max_size=256),
                objectPredictions))

        # compute mask loss
        estimation_masked_loss = 0
        for i, masks_tensor in enumerate(masks_tensor_list):
            if masks_tensor is not None:
                estimation_masked_loss += 0.0001 * compute_masked_grad_loss(
                    tensorDisparity[i].view(1, *tensorDisparity[i].shape),
                    resize_image(masks_tensor.view(-1, 1, 256, 256),
                                 max_size=128), [1], 1)

        # Skip the update entirely when no instance produced a mask.
        if estimation_masked_loss != 0:
            # backward pass for mask loss only
            self.optimizer_disparity.zero_grad()
            estimation_masked_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.moduleDisparity.parameters(),
                                           0.1)
            self.optimizer_disparity.step()
            self.scheduler_disparity.step()

            self.writer.add_scalar('Estimation/Loss mask',
                                   estimation_masked_loss, self.iter_nb)

    def validation(self, refine_training=False):
        """Compute depth metrics on the validation split and log them.

        Args:
            refine_training: when True, also run (and restore train mode
                on) the refinement net.

        NOTE(review): when refine_training is True, moduleDisparity is put
        in eval mode here but never switched back to train -- confirm this
        is intended for 'both' training.
        """
        # compute metrics on the validation set
        self.moduleDisparity.eval()
        if refine_training:
            self.moduleRefine.eval()

        measures = []

        metrics = {}
        metrics_list = [
            'Abs rel', 'Sq rel', 'RMSE', 'log RMSE', 's1', 's2', 's3'
        ]
        MSELoss = nn.MSELoss()

        with torch.no_grad():
            for idx, (tensorImage, disparities, masks, imageNetTensor,
                      dataset_ids) in enumerate(
                          tqdm(self.data_loader_validation,
                               desc='Validation')):
                tensorImage = tensorImage.to(device, non_blocking=True)
                disparities = disparities.to(device, non_blocking=True)
                masks = masks.to(device, non_blocking=True)

                tensorResized = resize_image(tensorImage)
                tensorDisparity = self.moduleDisparity(
                    tensorResized,
                    self.moduleSemantics(tensorResized))  # depth estimation

                if refine_training:
                    tensorDisparity = self.moduleRefine(
                        tensorImage, tensorDisparity)  # increase resolution
                else:
                    disparities = resize_image(disparities, max_size=256)
                    masks = resize_image(masks, max_size=256)

                # Clamp negative disparities to zero before scoring.
                tensorDisparity = F.threshold(tensorDisparity,
                                              threshold=0.0,
                                              value=0.0)

                masks = masks.clamp(0, 1)
                measures.append(
                    np.array(
                        compute_metrics(tensorDisparity, disparities, masks)))

        # Average each metric over the validation batches.
        measures = np.array(measures).mean(axis=0)

        for i, name in enumerate(metrics_list):
            metrics[name] = measures[i]
            self.writer.add_scalar('Validation/' + name, measures[i],
                                   self.iter_nb)

        # Restore train mode on the net that is being trained.
        if refine_training:
            self.moduleRefine.train()
        else:
            self.moduleDisparity.train()
def main(_):
    """
    Starting point of the application.

    Configures TensorFlow/Horovod, builds a tf.estimator-based UNet and
    runs train / evaluate / predict depending on ``params.exec_mode``.
    """

    flags = PARSER.parse_args()
    params = _cmd_params(flags)
    np.random.seed(params.seed)
    tf.compat.v1.random.set_random_seed(params.seed)
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

    # Log to stdout, and optionally to a JSON stream file as well.
    backends = [StdOutBackend(Verbosity.VERBOSE)]
    if params.log_dir is not None:
        backends.append(JSONStreamBackend(Verbosity.VERBOSE, params.log_dir))
    logger = Logger(backends)

    # Optimization flags — must be set before TF initializes the GPU context.
    os.environ['CUDA_CACHE_DISABLE'] = '0'

    os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'

    # FIX: the following four are boolean flags; the original assigned the
    # string 'data', which TF's boolean env-var parsing rejects, silently
    # leaving the optimizations at their defaults. '1' enables them as
    # intended (matches NVIDIA's reference scripts).
    os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'

    os.environ['TF_ADJUST_HUE_FUSED'] = '1'
    os.environ['TF_ADJUST_SATURATION_FUSED'] = '1'
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    os.environ['TF_SYNC_ON_FINISH'] = '0'
    os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'

    if params.use_amp:
        os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
    else:
        os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '0'
    hvd.init()

    # Build run config: one visible GPU per Horovod rank, growable memory.
    gpu_options = tf.compat.v1.GPUOptions()
    config = tf.compat.v1.ConfigProto(gpu_options=gpu_options,
                                      allow_soft_placement=True)

    if params.use_xla:
        config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1

    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    run_config = tf.estimator.RunConfig(
        save_summary_steps=1,
        tf_random_seed=None,
        session_config=config,
        save_checkpoints_steps=params.max_steps // hvd.size(),
        keep_checkpoint_max=1)

    # Build the estimator model
    estimator = tf.estimator.Estimator(model_fn=unet_fn,
                                       model_dir=params.model_dir,
                                       config=run_config,
                                       params=params)

    # Each Horovod rank reads its own shard of the data.
    dataset = Dataset(data_dir=params.data_dir,
                      batch_size=params.batch_size,
                      fold=params.crossvalidation_idx,
                      augment=params.augment,
                      gpu_id=hvd.rank(),
                      num_gpus=hvd.size(),
                      seed=params.seed)

    if 'train' in params.exec_mode:
        # In benchmark mode every rank runs the full step count.
        max_steps = params.max_steps // (1 if params.benchmark else hvd.size())
        hooks = [
            hvd.BroadcastGlobalVariablesHook(0),
            TrainingHook(logger,
                         max_steps=max_steps,
                         log_every=params.log_every)
        ]

        if params.benchmark and hvd.rank() == 0:
            hooks.append(
                ProfilingHook(logger,
                              batch_size=params.batch_size,
                              log_every=params.log_every,
                              warmup_steps=params.warmup_steps,
                              mode='train'))

        estimator.train(input_fn=dataset.train_fn,
                        steps=max_steps,
                        hooks=hooks)

    if 'evaluate' in params.exec_mode:
        # Only rank 0 evaluates/logs to avoid duplicated work.
        if hvd.rank() == 0:
            results = estimator.evaluate(input_fn=dataset.eval_fn,
                                         steps=dataset.eval_size)
            logger.log(step=(),
                       data={
                           "eval_ce_loss": float(results["eval_ce_loss"]),
                           "eval_dice_loss": float(results["eval_dice_loss"]),
                           "eval_total_loss":
                           float(results["eval_total_loss"]),
                           "eval_dice_score": float(results["eval_dice_score"])
                       })

    if 'predict' in params.exec_mode:
        if hvd.rank() == 0:
            predict_steps = dataset.test_size
            hooks = None
            if params.benchmark:
                hooks = [
                    ProfilingHook(logger,
                                  batch_size=params.batch_size,
                                  log_every=params.log_every,
                                  warmup_steps=params.warmup_steps,
                                  mode="test")
                ]
                # Repeat the test set enough times to cover benchmark steps.
                predict_steps = params.warmup_steps * 2 * params.batch_size

            predictions = estimator.predict(input_fn=lambda: dataset.test_fn(
                count=math.ceil(predict_steps / dataset.test_size)),
                                            hooks=hooks)
            # Argmax over logits -> per-pixel class, scaled to a 0/255 mask.
            binary_masks = [
                np.argmax(p['logits'], axis=-1).astype(np.uint8) * 255
                for p in predictions
            ]

            if not params.benchmark:
                multipage_tif = [
                    Image.fromarray(mask).resize(size=(512, 512),
                                                 resample=Image.BILINEAR)
                    for mask in binary_masks
                ]

                output_dir = os.path.join(params.model_dir, 'pred')

                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)

                # Save all predicted masks as a single multi-page TIFF.
                multipage_tif[0].save(os.path.join(output_dir,
                                                   'test-masks.tif'),
                                      compression="tiff_deflate",
                                      save_all=True,
                                      append_images=multipage_tif[1:])
def main():
    """
    Starting point of the application.

    Configures TensorFlow 2 / Horovod, builds a Keras UNet and runs
    train / evaluate / predict depending on ``params.exec_mode``.
    """

    flags = PARSER.parse_args()
    params = _cmd_params(flags)

    # Log to stdout, and optionally to a JSON stream file as well.
    backends = [StdOutBackend(Verbosity.VERBOSE)]
    if params.log_dir is not None:
        backends.append(JSONStreamBackend(Verbosity.VERBOSE, params.log_dir))
    logger = Logger(backends)

    # Optimization flags — must be set before TF initializes the GPU context.
    os.environ['CUDA_CACHE_DISABLE'] = '0'

    os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'

    # FIX: the following four are boolean flags; the original assigned the
    # string 'data', which TF's boolean env-var parsing rejects, silently
    # leaving the optimizations at their defaults. '1' enables them as
    # intended (matches NVIDIA's reference scripts).
    os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'

    os.environ['TF_ADJUST_HUE_FUSED'] = '1'
    os.environ['TF_ADJUST_SATURATION_FUSED'] = '1'
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    os.environ['TF_SYNC_ON_FINISH'] = '0'
    os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'

    hvd.init()

    if params.use_xla:
        tf.config.optimizer.set_jit(True)

    # One visible GPU per Horovod rank, with growable memory.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()],
                                                   'GPU')

    if params.use_amp:
        tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
    else:
        os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '0'

    # Build the  model
    model = Unet()

    # Each Horovod rank reads its own shard of the data.
    dataset = Dataset(data_dir=params.data_dir,
                      batch_size=params.batch_size,
                      fold=params.crossvalidation_idx,
                      augment=params.augment,
                      gpu_id=hvd.rank(),
                      num_gpus=hvd.size(),
                      seed=params.seed)

    if 'train' in params.exec_mode:
        train(params, model, dataset, logger)

    # Evaluation and prediction run on rank 0 only, from the saved checkpoint.
    if 'evaluate' in params.exec_mode:
        if hvd.rank() == 0:
            model = restore_checkpoint(model, params.model_dir)
            evaluate(params, model, dataset, logger)

    if 'predict' in params.exec_mode:
        if hvd.rank() == 0:
            model = restore_checkpoint(model, params.model_dir)
            predict(params, model, dataset, logger)
def main(_):
    """
    Starting point of the application.

    TF1-style variant: configures TF/Horovod, builds a tf.estimator UNet
    and runs train / predict depending on ``params['exec_mode']``.
    """

    flags = PARSER.parse_args()

    params = _cmd_params(flags)

    tf.logging.set_verbosity(tf.logging.ERROR)

    # Optimization flags — must be set before TF initializes the GPU context.
    os.environ['CUDA_CACHE_DISABLE'] = '0'

    os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'

    # FIX: the following four are boolean flags; the original assigned the
    # string 'data', which TF's boolean env-var parsing rejects, silently
    # leaving the optimizations at their defaults. '1' enables them as
    # intended (matches NVIDIA's reference scripts).
    os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'

    os.environ['TF_ADJUST_HUE_FUSED'] = '1'
    os.environ['TF_ADJUST_SATURATION_FUSED'] = '1'
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    os.environ['TF_SYNC_ON_FINISH'] = '0'
    os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'

    if params['use_amp']:
        os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'

    hvd.init()

    # Build run config: one visible GPU per Horovod rank, growable memory.
    gpu_options = tf.GPUOptions()
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.force_gpu_compatible = True
    config.intra_op_parallelism_threads = 1
    config.inter_op_parallelism_threads = max(2, 40 // hvd.size() - 2)

    run_config = tf.estimator.RunConfig(
        save_summary_steps=1,
        tf_random_seed=None,
        session_config=config,
        save_checkpoints_steps=params['max_steps'],
        keep_checkpoint_max=1)

    # Build the estimator model
    estimator = tf.estimator.Estimator(model_fn=unet_fn,
                                       model_dir=params['model_dir'],
                                       config=run_config,
                                       params=params)

    # Each Horovod rank reads its own shard of the data.
    dataset = Dataset(data_dir=params['data_dir'],
                      batch_size=params['batch_size'],
                      augment=params['augment'],
                      gpu_id=hvd.rank(),
                      num_gpus=hvd.size(),
                      seed=params['seed'])

    if 'train' in params['exec_mode']:
        hooks = [
            hvd.BroadcastGlobalVariablesHook(0),
            TrainingHook(params['log_every'])
        ]

        if params['benchmark']:
            hooks.append(
                ProfilingHook(params['batch_size'], params['log_every'],
                              params['warmup_steps']))

        LOGGER.log('Begin Training...')

        LOGGER.log(tags.RUN_START)
        estimator.train(input_fn=dataset.train_fn,
                        steps=params['max_steps'],
                        hooks=hooks)
        LOGGER.log(tags.RUN_STOP)

    if 'predict' in params['exec_mode']:
        # Only rank 0 predicts to avoid duplicated work.
        if hvd.rank() == 0:
            predict_steps = dataset.test_size
            hooks = None
            if params['benchmark']:
                hooks = [
                    ProfilingHook(params['batch_size'], params['log_every'],
                                  params['warmup_steps'])
                ]
                # Repeat the test set enough times to cover benchmark steps.
                predict_steps = params['warmup_steps'] * 2 * params[
                    'batch_size']

            LOGGER.log('Begin Predict...')
            LOGGER.log(tags.RUN_START)

            predictions = estimator.predict(input_fn=lambda: dataset.test_fn(
                count=math.ceil(predict_steps / dataset.test_size)),
                                            hooks=hooks)

            # Argmax over logits -> per-pixel class, scaled to a 0/255 mask.
            binary_masks = [
                np.argmax(p['logits'], axis=-1).astype(np.uint8) * 255
                for p in predictions
            ]
            LOGGER.log(tags.RUN_STOP)

            multipage_tif = [
                Image.fromarray(mask).resize(size=(512, 512),
                                             resample=Image.BILINEAR)
                for mask in binary_masks
            ]

            output_dir = os.path.join(params['model_dir'], 'pred')

            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

            # Save all predicted masks as a single multi-page TIFF.
            multipage_tif[0].save(os.path.join(output_dir, 'test-masks.tif'),
                                  compression="tiff_deflate",
                                  save_all=True,
                                  append_images=multipage_tif[1:])

            LOGGER.log("Predict finished")
            LOGGER.log("Results available in: {}".format(output_dir))
class DepthEval():
    """Evaluation harness for the depth-estimation pipeline.

    Loads an evaluation dataset and four inference-mode modules
    (semantics, disparity, Mask R-CNN, refinement), then either computes
    aggregate depth metrics over the dataset (`eval`) or returns raw
    inputs/predictions for a single batch (`get_depths`).
    """

    def __init__(self,
                 dataset_paths,
                 model_paths,
                 eval_refine=False,
                 eval_pretrained=False,
                 perform_adjustment=True,
                 evaluation_path=None,
                 imagenet_path=None):
        """Build the dataset loader and the evaluation networks.

        Args:
            dataset_paths: list of dataset descriptors (dicts with at least
                'path' and 'params' entries).
            model_paths: checkpoint paths forwarded to `load_models`, or
                None to keep the freshly constructed weights.
            eval_refine: kept for interface compatibility; unused here.
            eval_pretrained: if True, inputs are rescaled from [-1, 1] to
                [0, 1] before inference (pretrained 3D KBE convention).
            perform_adjustment: stored for callers; not used in this class.
            evaluation_path: kept for interface compatibility; unused here.
            imagenet_path: forwarded to `Dataset`.
        """
        self.dataset = Dataset(dataset_paths, imagenet_path=imagenet_path)
        self.dataset.mode = 'eval'
        self.dataset.padding = True
        self.data_loader = self.dataset.get_dataloader(batch_size=1)
        self.perform_adjustment = perform_adjustment
        self.dataset_paths = dataset_paths

        # Fixed seeds so evaluation runs are reproducible.
        torch.manual_seed(42)
        np.random.seed(42)

        # All modules run in inference mode.
        self.moduleSemantics = Semantics().to(device).eval()
        self.moduleDisparity = Disparity().to(device).eval()
        self.moduleMaskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(
            pretrained=True).to(device).eval()
        self.moduleRefine = Refine().to(device).eval()

        self.eval_pretrained = eval_pretrained

        if model_paths is not None:
            models_list = [{
                'model': self.moduleDisparity,
                'type': 'disparity'
            }, {
                'model': self.moduleRefine,
                'type': 'refine'
            }]
            load_models(models_list, model_paths)

    def eval(self):
        """Compute depth metrics on the provided dataset.

        Returns:
            dict mapping metric name ('Abs rel', 'Sq rel', 'RMSE',
            'log RMSE', 's1', 's2', 's3') to its mean value over the
            dataset.
        """
        measures = []

        metrics = {}
        metrics_list = [
            'Abs rel', 'Sq rel', 'RMSE', 'log RMSE', 's1', 's2', 's3'
        ]

        # FIX: the original used functools.reduce with a lambda indexing
        # s1['path'], which raises TypeError with more than two datasets
        # (the accumulator becomes a str) and prints a raw dict when there
        # is only one dataset. A plain join handles every length.
        print('Starting evaluation on datasets: ',
              ', '.join(d['path'] for d in self.dataset_paths))

        # Inference only: disable autograd to save memory and time.
        with torch.no_grad():
            for idx, (tensorImage, disparities, masks, imageNetTensor,
                      dataset_ids) in enumerate(tqdm(self.data_loader)):
                tensorImage = tensorImage.to(device, non_blocking=True)
                disparities = disparities.to(device, non_blocking=True)
                masks = masks.to(device, non_blocking=True)

                # Pretrained networks from 3D KBE were trained with images
                # normalized between 0 and 1.
                if self.eval_pretrained:
                    tensorImage = (tensorImage + 1) / 2

                tensorResized = resize_image(tensorImage)

                tensorDisparity = self.moduleDisparity(
                    tensorResized,
                    self.moduleSemantics(tensorResized))  # depth estimation
                tensorDisparity = self.moduleRefine(
                    tensorImage, tensorDisparity)  # increase resolution
                # Clamp negative predictions: disparities are non-negative.
                tensorDisparity = F.threshold(tensorDisparity,
                                              threshold=0.0,
                                              value=0.0)

                masks = masks.clamp(0, 1)
                measures.append(
                    np.array(
                        compute_metrics(tensorDisparity, disparities, masks)))

        # Average each metric over all batches.
        measures = np.array(measures).mean(axis=0)

        for i, name in enumerate(metrics_list):
            metrics[name] = measures[i]

        return metrics

    def get_depths(self):
        """Run the pipeline on one batch and return inputs and predictions.

        Returns:
            tuple of (mostly) numpy arrays: raw disparity, adjusted
            disparity, refined disparity, ground-truth disparities,
            resized ground truth, input images as HWC in [0, 1], raw
            Mask R-CNN predictions, masks, resized masks.
        """

        def detach_tensor(tensor):
            return tensor.cpu().detach().numpy()

        # Inference only: disable autograd to save memory and time.
        with torch.no_grad():
            tensorImage, disparities, masks, imageNetTensor, dataset_ids = next(
                iter(self.data_loader))
            tensorImage = tensorImage.to(device, non_blocking=True)
            disparities = disparities.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            # Pretrained networks from 3D KBE were trained with images
            # normalized between 0 and 1.
            if self.eval_pretrained:
                tensorImage = (tensorImage + 1) / 2

            tensorResized = resize_image(tensorImage)

            # Per-image camera parameters (focal, baseline) broadcast to the
            # disparity shape. NOTE(review): currently unused below — the
            # adjustment step is disabled (tensorDisparityAdjusted is a plain
            # alias of tensorDisparity); kept to document the intended use.
            tensorFocal = torch.Tensor([
                self.dataset_paths[int(id.item())]['params']['focal']
                for id in dataset_ids
            ])
            tensorBaseline = torch.Tensor([
                self.dataset_paths[int(id.item())]['params']['baseline']
                for id in dataset_ids
            ])
            tensorFocal = tensorFocal.view(-1, 1).repeat(
                1, 1,
                tensorImage.size(2) *
                tensorImage.size(3)).view(*disparities.size())
            tensorBaseline = tensorBaseline.view(-1, 1).repeat(
                1, 1,
                tensorImage.size(2) *
                tensorImage.size(3)).view(*disparities.size())

            tensorBaseline = tensorBaseline.to(device)
            tensorFocal = tensorFocal.to(device)

            tensorDisparity = self.moduleDisparity(
                tensorResized,
                self.moduleSemantics(tensorResized))  # depth estimation

            objectPredictions = self.moduleMaskrcnn(
                tensorImage)  # segment image in mask using Mask-RCNN

            tensorDisparityAdjusted = tensorDisparity
            # Refinement is only applied to the first two samples;
            # presumably to bound memory — TODO confirm intent.
            tensorDisparityRefined = self.moduleRefine(
                tensorImage[:2, :, :, :],
                tensorDisparityAdjusted[:2, :, :, :])  # increase resolution

        return (detach_tensor(tensorDisparity),
                detach_tensor(tensorDisparityAdjusted),
                detach_tensor(tensorDisparityRefined),
                detach_tensor(disparities),
                detach_tensor(resize_image(disparities, max_size=256)),
                detach_tensor((tensorImage.permute(0, 2, 3, 1) + 1) / 2),
                objectPredictions, detach_tensor(masks),
                detach_tensor(resize_image(masks, max_size=256)))
Exemple #14
0
    def __init__(self,
                 dataset_paths,
                 training_params,
                 models_paths=None,
                 logs_path='runs/train_0',
                 continue_training=False):
        """Set up data loaders, inpainting model, optimizer, losses and logging.

        Args:
            dataset_paths: dataset description forwarded to `Dataset`.
            training_params: dict of hyper-parameters; this method reads
                'batch_size', 'model_to_train', 'lr_inpaint', 'gamma_lr',
                'adversarial' and (if adversarial) 'lr_D'.
            models_paths: optional checkpoint paths forwarded to
                `load_models`; None trains from scratch.
            logs_path: directory for the `CustomWriter` logs.
            continue_training: kept for interface compatibility; not read
                in this method.
        """
        self.iter_nb = 0
        self.dataset_paths = dataset_paths
        self.training_params = training_params

        self.dataset = Dataset(dataset_paths, mode='inpainting')

        # Create training and validation set randomly (99% / 1% split).
        dataset_length = len(self.dataset)
        train_set_length = int(0.99 * dataset_length)
        validation_set_length = dataset_length - train_set_length

        self.training_set, self.validation_set = torch.utils.data.random_split(
            self.dataset, [train_set_length, validation_set_length])

        self.data_loader = torch.utils.data.DataLoader(
            self.training_set,
            batch_size=self.training_params['batch_size'],
            shuffle=True,
            pin_memory=True,
            num_workers=2)

        # Validation uses a doubled batch size (no gradients are kept there).
        self.data_loader_validation = torch.utils.data.DataLoader(
            self.validation_set,
            batch_size=self.training_params['batch_size'] * 2,
            shuffle=True,
            pin_memory=True,
            num_workers=2)

        # Select the inpainting variant to train. NOTE(review): any other
        # 'model_to_train' value leaves self.moduleInpaint unset and the
        # weights_init call below will fail — confirm upstream validation.
        if self.training_params['model_to_train'] == 'inpainting':
            self.moduleInpaint = Inpaint().to(device).train()
        elif self.training_params['model_to_train'] == 'partial inpainting':
            self.moduleInpaint = PartialInpaint().to(device).train()

        weights_init(self.moduleInpaint, init_gain=0.01)

        self.optimizer_inpaint = torch.optim.Adam(
            self.moduleInpaint.parameters(),
            lr=self.training_params['lr_inpaint'])

        self.loss_inpaint = InpaintingLoss(kbe_only=False, perceptual=True)

        # Relative weights for the individual inpainting loss terms.
        self.loss_weights = {
            'hole': 6,
            'valid': 1,
            'prc': 0.05,
            'tv': 0.1,
            'style': 120,
            'grad': 10,
            'ord': 0.0001,
            'color': 0,
            'mask': 0.0001,
            'valid_depth': 1,
            'joint_edge': 1
        }

        # Exponential learning-rate decay: lr scaled by gamma_lr ** epoch.
        lambda_lr = lambda epoch: self.training_params['gamma_lr']**epoch
        self.scheduler_inpaint = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_inpaint, lr_lambda=lambda_lr)

        if models_paths is not None:
            models_list = [{'model': self.moduleInpaint, 'type': 'inpaint'}]
            load_models(models_list, models_paths)

        # Optional adversarial training: discriminator + its own optimizer
        # and scheduler (same decay schedule as the generator).
        if self.training_params['adversarial']:
            ## Train with view B
            self.discriminator = MPDDiscriminator().to(
                device)  # other type of discriminator can be used here

            ## Train with view C
            # self.discriminator = MultiScalePerceptualDiscriminator().to(device)

            spectral_norm_switch(self.discriminator, on=True)

            self.optimizerD = torch.optim.Adam(self.discriminator.parameters(),
                                               lr=self.training_params['lr_D'])
            self.schedulerD = torch.optim.lr_scheduler.LambdaLR(
                self.optimizerD, lr_lambda=lambda_lr)

            # discriminator balancing parameters
            self.balanceSteps = 5  # number of D steps per G step
            self.pretrainSteps = 1000  # number of pretraining steps for D
            self.stopG = 10000  # restart pretraining of D every stopG steps
def main(_):
    """
    Starting point of the application

    Refactored estimator variant: initializes Horovod/TF, builds the
    estimator and dataset via helpers, then runs train / evaluate /
    predict according to ``params.exec_mode``.
    """
    hvd.init()
    set_flags()
    params = parse_args(PARSER.parse_args())
    model_dir = prepare_model_dir(params)
    logger = get_logger(params)

    estimator = build_estimator(params, model_dir)

    # Each Horovod rank reads its own shard of the data.
    dataset = Dataset(data_dir=params.data_dir,
                      batch_size=params.batch_size,
                      fold=params.crossvalidation_idx,
                      augment=params.augment,
                      gpu_id=hvd.rank(),
                      num_gpus=hvd.size(),
                      seed=params.seed)

    if 'train' in params.exec_mode:
        # In benchmark mode every rank runs the full step count.
        max_steps = params.max_steps // (1 if params.benchmark else hvd.size())
        hooks = [hvd.BroadcastGlobalVariablesHook(0),
                 TrainingHook(logger,
                              max_steps=max_steps,
                              log_every=params.log_every)]

        if params.benchmark and hvd.rank() == 0:
            hooks.append(ProfilingHook(logger,
                                       batch_size=params.batch_size,
                                       log_every=params.log_every,
                                       warmup_steps=params.warmup_steps,
                                       mode='train'))

        estimator.train(
            input_fn=dataset.train_fn,
            steps=max_steps,
            hooks=hooks)

    if 'evaluate' in params.exec_mode:
        # Only rank 0 evaluates/logs to avoid duplicated work.
        if hvd.rank() == 0:
            results = estimator.evaluate(input_fn=dataset.eval_fn, steps=dataset.eval_size)
            logger.log(step=(),
                       data={"eval_ce_loss": float(results["eval_ce_loss"]),
                             "eval_dice_loss": float(results["eval_dice_loss"]),
                             "eval_total_loss": float(results["eval_total_loss"]),
                             "eval_dice_score": float(results["eval_dice_score"])})

    if 'predict' in params.exec_mode:
        if hvd.rank() == 0:
            predict_steps = dataset.test_size
            hooks = None
            if params.benchmark:
                hooks = [ProfilingHook(logger,
                                       batch_size=params.batch_size,
                                       log_every=params.log_every,
                                       warmup_steps=params.warmup_steps,
                                       mode="test")]
                # Repeat the test set enough times to cover benchmark steps.
                predict_steps = params.warmup_steps * 2 * params.batch_size

            predictions = estimator.predict(
                input_fn=lambda: dataset.test_fn(count=math.ceil(predict_steps / dataset.test_size)),
                hooks=hooks)
            # Argmax over logits -> per-pixel class, scaled to a 0/255 mask.
            binary_masks = [np.argmax(p['logits'], axis=-1).astype(np.uint8) * 255 for p in predictions]

            if not params.benchmark:
                multipage_tif = [Image.fromarray(mask).resize(size=(512, 512), resample=Image.BILINEAR)
                                 for mask in binary_masks]

                output_dir = os.path.join(params.model_dir, 'pred')

                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)

                # Save all predicted masks as a single multi-page TIFF.
                multipage_tif[0].save(os.path.join(output_dir, 'test-masks.tif'),
                                      compression="tiff_deflate",
                                      save_all=True,
                                      append_images=multipage_tif[1:])