Пример #1
0
    def test_2_evaluate(self):
        """Check that model.evaluate() runs on the 'val' split.

        A deep copy of the test config is pointed at the 'default' model
        version with ``load_best`` enabled, a model is built from it and
        evaluated on a non-shuffled 'val' dataloader.
        """
        set_logger(join(self.cfg.log_dir, 'train.log'))

        # operate on a copy so that self.cfg itself is not modified
        eval_cfg = deepcopy(self.cfg)
        eval_cfg.model['load']['version'] = 'default'
        eval_cfg.model['load']['load_best'] = True

        classifier = BinaryClassificationModel(eval_cfg)

        val_loader, _ = get_dataloader(eval_cfg.data,
                                       'val',
                                       eval_cfg.model['batch_size'],
                                       num_workers=4,
                                       shuffle=False,
                                       drop_last=False)

        # third positional argument is presumably use_wandb - TODO confirm
        classifier.evaluate(val_loader, 'val', False)
Пример #2
0
    def test_classification_dataloader_2d(self):
        """Test get_dataloader for classification with each input being 2D

        Builds a 'test'-mode dataloader over CIFAR10 with a Permute and a
        Resize signal transform, pulls the first batch and checks the type,
        dtype and shape of the returned signals and labels.
        """
        cfg = {
            'root': DATA_ROOT,
            'data_type': 'image',
            'dataset': {
                'name':
                'classification_dataset',
                'params': {
                    'test': {
                        'fraction': 0.1
                    }
                },
                'config': [{
                    'name': 'CIFAR10',
                    'version': 'default',
                    'mode': 'test'
                }]
            },
            'target_transform': {
                'name': 'classification',
                'params': {
                    'classes': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
                }
            },
            'signal_transform': {
                'test': [{
                    'name': 'Permute',
                    'params': {
                        'order': [2, 0, 1]
                    }
                }, {
                    'name': 'Resize',
                    'params': {
                        'size': [30, 30]
                    }
                }]
            },
            'sampler': {
                'test': {
                    'name': 'default'
                }
            },
            'collate_fn': {
                'name': 'classification_collate'
            }
        }
        batch_size = 8

        dataloader, _ = get_dataloader(cfg,
                                       'test',
                                       batch_size=batch_size,
                                       shuffle=False,
                                       drop_last=False)

        # only the first batch is inspected
        iterator = iter(dataloader)
        batch = next(iterator)
        signals, labels = batch['signals'], batch['labels']

        self.assertIsInstance(signals, torch.Tensor)
        self.assertIsInstance(labels, torch.Tensor)
        self.assertEqual(signals.dtype, torch.float32)
        self.assertEqual(labels.dtype, torch.float32)
        self.assertEqual(len(signals), len(labels))
        self.assertEqual(len(signals.shape), 4)

        # assertTrue(value, msg) would treat the expected shape as the
        # failure message and pass for any non-empty shape, so compare
        # the shapes explicitly instead
        self.assertEqual(tuple(signals.shape), (batch_size, 3, 30, 30))
Пример #3
0
def evaluate(config, mode, use_wandb, ignore_cache, n_tta):
    """Run the actual evaluation

    The model is evaluated ``n_tta`` times on the same dataloader; the
    sigmoid of the predictions from each run is averaged, metrics are
    computed on the average and the per-item results are written to
    ``<config.log_dir>/evaluation/<mode>.csv``.

    :param config: config for the model to evaluate
    :type config: Config
    :param mode: data mode to evaluate on
    :type mode: str
    :param use_wandb: whether to log values to wandb
    :type use_wandb: bool
    :param ignore_cache: whether to ignore cached predictions
    :type ignore_cache: bool
    :param n_tta: number of TTA runs to average over; must be at least 1
    :type n_tta: int
    :raises ValueError: if ``n_tta`` is less than 1
    """
    # with zero runs there would be nothing to stack and no targets/items
    # to report, so fail early with a clear message
    if n_tta < 1:
        raise ValueError(f'n_tta must be at least 1, got {n_tta}')

    model = model_factory.create(config.model['name'], **{'config': config})
    logging.info(color(f'Evaluating on mode: {mode}'))

    # reset sampler to default
    config.data['sampler'].update({mode: {'name': 'default'}})

    dataloader, _ = get_dataloader(config.data,
                                   mode,
                                   config.model['batch_size'],
                                   num_workers=config.num_workers,
                                   shuffle=False,
                                   drop_last=False)

    # set to eval mode
    model.network.eval()

    all_predictions = []
    for run_index in range(n_tta):
        logging.info(f'TTA run #{run_index + 1}')
        results = model.evaluate(dataloader,
                                 mode,
                                 use_wandb,
                                 ignore_cache,
                                 data_only=True,
                                 log_summary=False)

        logging.info(f'AUC = {results["auc-roc"]}')

        # logits
        predictions = results['predictions']

        # squash the logits with a sigmoid
        predictions = torch.sigmoid(predictions)

        # add to list of all predictions across each TTA run
        all_predictions.append(predictions)

    # new trailing axis indexes the TTA run
    all_predictions = torch.stack(all_predictions, -1)

    # take the mean across several TTA runs
    predictions = all_predictions.mean(-1)

    # calculate the metrics on the TTA predictions; targets and items are
    # taken from the last run (assumes item order is identical across runs
    # since shuffle=False - TODO confirm)
    metrics = model.compute_epoch_metrics(predictions,
                                          results['targets'],
                                          as_logits=False)

    print(f'TTA auc: {metrics["auc-roc"]}')

    # get the file names (base name without extension)
    names = [splitext(basename(item.path))[0] for item in results['items']]

    # convert to data frame
    data_frame = pd.DataFrame({
        'image_name': names,
        'target': predictions.tolist()
    })

    # save the results
    save_path = join(config.log_dir, 'evaluation', f'{mode}.csv')
    os.makedirs(dirname(save_path), exist_ok=True)
    logging.info(color(f'Saving results to {save_path}'))
    data_frame.to_csv(save_path, index=False)
Пример #4
0
    def fit(self,
            debug: bool = False,
            overfit_batch: bool = False,
            use_wandb: bool = True):
        """Train the network for the configured number of epochs

        :param debug: when True, no train dataloader is built and the
            train epoch is skipped, so only the val epoch runs,
            defaults to False
        :type debug: bool, optional
        :param overfit_batch: when True, the train set is not shuffled and
            all val operations (val epoch, saving, per-epoch scheduler
            step) are skipped, defaults to False
        :type overfit_batch: bool, optional
        :param use_wandb: flag for whether to log visualizations to wandb,
            defaults to True
        :type use_wandb: bool, optional
        """
        run_train = not debug
        run_val = not overfit_batch

        if run_train:
            # shuffling is disabled only when overfitting on a batch
            train_dataloader, _ = get_dataloader(
                self.data_config,
                self.config.train_mode,
                self.model_config['batch_size'],
                num_workers=self.config.num_workers,
                shuffle=not overfit_batch,
                drop_last=False)

        if run_val:
            val_dataloader, _ = get_dataloader(
                self.data_config,
                self.config.val_mode,
                self.model_config['batch_size'],
                num_workers=self.config.num_workers,
                shuffle=False,
                drop_last=False)
        else:
            # no val dataloader is needed when overfitting on a batch
            logging.info(color('Overfitting a single batch', 'blue'))

        if use_wandb:
            # track gradients and weights in wandb
            self.network.watch()

        for _ in range(self.model_config['epochs']):
            if run_train:
                self.process_epoch(train_dataloader,
                                   self.config.train_mode,
                                   training=True,
                                   use_wandb=use_wandb,
                                   overfit_batch=overfit_batch)

            if run_val:
                val_results = self.process_epoch(val_dataloader,
                                                 self.config.val_mode,
                                                 training=False,
                                                 use_wandb=use_wandb)

                # save best model
                self.save(val_results, use_wandb=use_wandb)

                # step schedulers that operate per epoch
                # (like ReduceLROnPlateau), if any are registered
                if 'epoch' in getattr(self, 'update_freq', ()):
                    logging.info('Running scheduler step')
                    self.update_optimizer_params(val_results, 'epoch')

            # the epoch counter advances even when an epoch did no work
            self.epoch_counter += 1
Пример #5
0
# load the experiment config from its YAML file
config_path = join('/workspace/coreml', config_name + '.yml')
config = Config(config_path)


# In[19]:


# send log output to debug.log inside the run's log directory
debug_log_path = join(config.log_dir, 'debug.log')
set_logger(debug_log_path)


# In[10]:


# non-shuffled dataloader over the 'val' split, keeping the last batch
val_dataloader, _ = get_dataloader(config.data,
                                   'val',
                                   config.model['batch_size'],
                                   num_workers=10,
                                   shuffle=False,
                                   drop_last=False)


# In[16]:


# point the load settings at this config's own version and
# request the best checkpoint
load_settings = config.model['load']
load_settings['version'] = config_name
load_settings['load_best'] = True


# In[39]: