    def __load_metrics(self):
        precision = Precision(average=False)
        recall = Recall(average=False)
        F1 = precision * recall * 2 / (precision + recall + 1e-20)
        F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)
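        # Example of the composed F1 metric (illustrative values): per-class
        # precision [1.0, 0.5] and recall [0.5, 0.5] give per-class F1
        # [0.667, 0.5]; the reported scalar is their mean, 0.583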
        confusion_matrix = ConfusionMatrix(self.n_class, average='recall')
        # TODO: Add metric by patient
        self.metrics = {
            'accuracy': Accuracy(),
            'f1': F1,
            'confusion_matrix': confusion_matrix,
            'precision': precision.mean(),
            'recall': recall.mean(),
            'loss': Loss(self.loss),
        }
    def train(self, config, **kwargs):
        """Trains a given model specified in the config file or passed as the --model parameter.
        All options in the config file can be overridden as needed by passing --PARAM.
        Options with variable length (e.g., dict-valued kwargs) can be passed as --PARAM '{"PARAM1": VAR1, "PARAM2": VAR2}'.

        :param config: yaml config file
        :param **kwargs: parameters to overwrite yaml config
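
        Example (entry point invocation is illustrative):
            python run.py train config/train.yaml --batch_size 16 --model LightCNN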
        """

        config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
        outputdir = Path(
            config_parameters['outputpath'], config_parameters['model'],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex[:8]))
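        # Yields e.g. <outputpath>/<model>/2019-01-01_12-00-00_1a2b3c4d (illustrative)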
        # Instantiated early because this handler also creates the output dir
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            'run',
            n_saved=1,
            require_empty=False,
            create_dir=True,
            score_function=lambda engine: -engine.state.metrics['Loss'],
            save_as_state_dict=False,
            score_name='loss')
        logger = utils.getfile_outlogger(Path(outputdir, 'train.log'))
        logger.info("Storing files in {}".format(outputdir))
        utils.pprint_dict(config_parameters, logger.info)
        logger.info("Running on device {}".format(DEVICE))
        labels_df = pd.read_csv(config_parameters['trainlabel'], sep=' ')
        labels_df['encoded'], encoder = utils.encode_labels(
            labels=labels_df['bintype'])
        train_df, cv_df = utils.split_train_cv(labels_df)

        transform = utils.parse_transforms(config_parameters['transforms'])
        utils.pprint_dict({'Classes': encoder.classes_},
                          logger.info,
                          formatter='pretty')
        utils.pprint_dict(transform, logger.info, formatter='pretty')
        if 'sampler' in config_parameters and config_parameters[
                'sampler'] == 'MinimumOccupancySampler':
            # Ensures that each batch contains at least one instance of every class
            train_sampler = dataset.MinimumOccupancySampler(
                np.stack(train_df['encoded'].values))

            sampling_kwargs = {"sampler": train_sampler, "shuffle": False}
        elif 'shuffle' in config_parameters and config_parameters['shuffle']:
            sampling_kwargs = {"shuffle": True}
        else:
            sampling_kwargs = {"shuffle": False}

        logger.info("Using Sampler {}".format(sampling_kwargs))

        colname = config_parameters.get('colname', ('filename', 'encoded'))
        trainloader = dataset.getdataloader(
            train_df,
            config_parameters['traindata'],
            transform=transform,
            batch_size=config_parameters['batch_size'],
            colname=colname,  # For other datasets with different key names
            num_workers=config_parameters['num_workers'],
            **sampling_kwargs)
        cvdataloader = dataset.getdataloader(
            cv_df,
            config_parameters['traindata'],
            transform=None,
            shuffle=False,
            colname=colname,  # For other datasets with different key names
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'])
        if 'pretrained' in config_parameters and config_parameters[
                'pretrained'] is not None:
            model = models.load_pretrained(config_parameters['pretrained'],
                                           outputdim=len(encoder.classes_))
        else:
            # Falls back to models.LightCNN if the configured model is not found;
            # the getattr default must be the class itself, not a string
            model = getattr(models, config_parameters['model'],
                            models.LightCNN)(inputdim=trainloader.dataset.datadim,
                                             outputdim=len(encoder.classes_),
                                             **config_parameters['model_args'])

        if config_parameters['optimizer'] == 'AdaBound':
            try:
                import adabound
                optimizer = adabound.AdaBound(
                    model.parameters(), **config_parameters['optimizer_args'])
            except ImportError:
                logger.info(
                    "AdaBound package not found, install via pip install adabound. Using Adam instead."
                )
                # Fall back to Adam with its default settings; the optimizer is
                # then created by the generic branch below
                config_parameters['optimizer'] = 'Adam'
                config_parameters['optimizer_args'] = {}
        if config_parameters['optimizer'] != 'AdaBound':
            optimizer = getattr(
                torch.optim,
                config_parameters['optimizer'],
            )(model.parameters(), **config_parameters['optimizer_args'])
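        # e.g. optimizer 'Adam' with optimizer_args {'lr': 1e-3} resolves to
        # torch.optim.Adam(model.parameters(), lr=1e-3) (values illustrative)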

        utils.pprint_dict(optimizer, logger.info, formatter='pretty')
        utils.pprint_dict(model, logger.info, formatter='pretty')
        if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
            model = torch.nn.DataParallel(model)
        criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
        model = model.to(DEVICE)

        precision = Precision()
        recall = Recall()
        # Epsilon guards against division by zero for classes never predicted
        f1_score = (precision * recall * 2 /
                    (precision + recall + 1e-20)).mean()
        metrics = {
            'Loss': Loss(criterion),
            'Precision': precision.mean(),
            'Recall': recall.mean(),
            'Accuracy': Accuracy(),
            'F1': f1_score,
        }

        # A batch contains 3 elements: X, Y and filename; the filename is only
        # used during evaluation
        def _prep_batch(batch, device=DEVICE, non_blocking=False):
            x, y, _ = batch
            return (convert_tensor(x, device=device,
                                   non_blocking=non_blocking),
                    convert_tensor(y, device=device,
                                   non_blocking=non_blocking))

        train_engine = create_supervised_trainer(model,
                                                 optimizer=optimizer,
                                                 loss_fn=criterion,
                                                 prepare_batch=_prep_batch,
                                                 device=DEVICE)
        inference_engine = create_supervised_evaluator(
            model, metrics=metrics, prepare_batch=_prep_batch, device=DEVICE)

        RunningAverage(output_transform=lambda x: x).attach(
            train_engine, 'run_loss')  # Running loss shown in the progress bar
        pbar = ProgressBar(persist=False)
        pbar.attach(train_engine, ['run_loss'])
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               patience=3,
                                                               factor=0.1)
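        # With patience=3 and factor=0.1, the learning rate is cut tenfold once
        # validation loss has not improved for 3 consecutive evaluations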

        @inference_engine.on(Events.COMPLETED)
        def update_reduce_on_plateau(engine):
            val_loss = engine.state.metrics['Loss']
            if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
                scheduler.step(val_loss)
            else:
                scheduler.step()

        early_stop_handler = EarlyStopping(
            patience=5,
            score_function=lambda engine: -engine.state.metrics['Loss'],
            trainer=train_engine)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           early_stop_handler)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {
                                               'model': model,
                                               'encoder': encoder,
                                               'config': config_parameters,
                                           })
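        # Keeps only the single best checkpoint (n_saved=1); with
        # save_as_state_dict=False the model, encoder and config are pickled as
        # whole objects rather than state dicts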

        @train_engine.on(Events.EPOCH_COMPLETED)
        def compute_validation_metrics(engine):
            inference_engine.run(cvdataloader)
            results = inference_engine.state.metrics
            output_str_list = [
                "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
            ]
            for metric in metrics:
                output_str_list.append("{} {:<5.3f}".format(
                    metric, results[metric]))
            logger.info(" ".join(output_str_list))
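            # Emits lines like (values illustrative):
            # Validation Results - Epoch : 3     Loss 0.412 Precision 0.861 Recall 0.855 Accuracy 0.872 F1 0.858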
            pbar.n = pbar.last_print_n = 0

        train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
        return outputdir