def main(args=None):
    """Main entry point of the script."""
    # Parse command-line arguments
    args = parse_args(args)
    planet = args.planet
    topdir = mypaths.sadir / f"{planet}_grcs_ensemble"
    for label in ENS_LABELS:
        L.info(f"label = {label}")
        if label == "base":
            config_str = ""
        else:
            config = configparser.ConfigParser()
            config.read(Path(str(topdir) + "_conf") / f"rose-app-{label}.conf")
            config_str = pprint_dict(dict(config["namelist:run_convection"]))
        # Make a list of files matching the file mask and the start day threshold
        fnames = get_filename_list(topdir / label, ts_start=args.startday)
        L.debug(f"fnames = {fnames[0]} ... {fnames[-1]}")

        # Create a subdirectory for processed data
        outdir = topdir / label / "_processed"
        outdir.mkdir(parents=True, exist_ok=True)

        # Initialise a `Run` by loading data from the selected files
        run = Run(
            files=fnames,
            description=config_str,
            name=label,
            planet=planet,
            timestep=GLM_MODEL_TIMESTEP,
        )

        # Regrid & interpolate data
        run.proc_data(process_cubes, timestep=run.timestep)

        # Write the data to a netCDF file
        fname_out = outdir / f"{planet}_{run.name}.nc"
        run.to_netcdf(fname_out)
        L.success(f"Saved to {fname_out}")
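
The parse_args helper used by main() is not shown in this example. A minimal sketch of what it might look like, assuming only the two attributes the script actually reads (args.planet and args.startday); the flag names, defaults and help strings are hypothetical:

import argparse

def parse_args(args=None):
    """Hypothetical CLI parser matching the attributes used in main()."""
    parser = argparse.ArgumentParser(
        description="Post-process GRCS ensemble model output")
    parser.add_argument("planet",
                        help="Planet label used to locate the ensemble directory")
    parser.add_argument("-s", "--startday", type=int, default=0,
                        help="Only include files from this start day onwards")
    return parser.parse_args(args)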
Example #2
                        logs_all[db][k][size].append(v)

    # Replace each list of runtimes with its average
    for db in ["mongodb", "couchdb"]:
        for k in logs_all[db].keys():
            for size, rt_list in logs_all[db][k].items():
                logs_all[db][k][size] = sum(rt_list) / len(rt_list)

    return logs_all


if __name__ == "__main__":

    # input data
    import sys
    assert len(sys.argv) == 2, """
Command not properly formatted; correct format:
```
python3 main.py {dataset_directory}
```
"""
    dataset_dir = sys.argv[1]
    tweets = load_dataset(dataset_dir)

    # Get logs
    logs = generate_logs(dataset=tweets,
                         sample_sizes=[10000, 20000, 50000, 70000, 90000],
                         trials=4)
    pprint_dict(logs)
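
The top of generate_logs is truncated above, but the averaging loop implies the structure of logs_all: a dict keyed by database name, then by operation, then by sample size, with each leaf holding the list of runtimes recorded over the trials. A hypothetical instance (all names and numbers are illustrative):

logs_all = {
    "mongodb": {"insert": {10000: [1.2, 1.4], 20000: [2.4, 2.6]}},
    "couchdb": {"insert": {10000: [1.8, 2.0], 20000: [3.1, 3.3]}},
}
# After the averaging loop each runtime list collapses to its mean, e.g.
# logs_all["mongodb"]["insert"][10000] == 1.3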
Example #3
    def train(self, config, **kwargs):
        """Trains a given model specified in the config file or passed as the --model parameter.
        All options in the config file can be overwritten as needed by passing --PARAM
        Options with variable lengths ( e.g., kwargs can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}'

        :param config: yaml config file
        :param **kwargs: parameters to overwrite yaml config
        """

        config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
        outputdir = os.path.join(
            config_parameters['outputpath'], config_parameters['model'],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex))
        # Initialise the checkpoint handler early so that it creates the output dir
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            'run',
            n_saved=3,
            require_empty=False,
            create_dir=True,
            score_function=self._negative_loss,
            score_name='loss')
        logger = utils.getfile_outlogger(os.path.join(outputdir, 'train.log'))
        logger.info("Storing files in {}".format(outputdir))
        utils.pprint_dict(config_parameters, logger.info)
        logger.info("Running on device {}".format(DEVICE))
        label_df = pd.read_csv(config_parameters['label'], sep=r'\s+')
        data_df = pd.read_csv(config_parameters['data'], sep=r'\s+')
        # Keep only the filenames present in both the data and label frames
        merged = data_df.merge(label_df, on='filename')
        common_idxs = merged['filename']
        data_df = data_df[data_df['filename'].isin(common_idxs)]
        label_df = label_df[label_df['filename'].isin(common_idxs)]

        train_df, cv_df = utils.split_train_cv(
            label_df, **config_parameters['data_args'])
        train_label = utils.df_to_dict(train_df)
        cv_label = utils.df_to_dict(cv_df)
        data = utils.df_to_dict(data_df)

        transform = utils.parse_transforms(config_parameters['transforms'])
        torch.save(config_parameters, os.path.join(outputdir,
                                                   'run_config.pth'))
        logger.info("Transforms:")
        utils.pprint_dict(transform, logger.info, formatter='pretty')
        assert len(cv_df) > 0, "CV split is empty; is the train fraction too large?"

        trainloader = dataset.gettraindataloader(
            h5files=data,
            h5labels=train_label,
            transform=transform,
            label_type=config_parameters['label_type'],
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'],
            shuffle=True,
        )

        cvdataloader = dataset.gettraindataloader(
            h5files=data,
            h5labels=cv_label,
            label_type=config_parameters['label_type'],
            transform=None,
            shuffle=False,
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'],
        )
        # Fall back to models.CRNN if the configured model name is missing
        model_cls = getattr(models, config_parameters['model'], models.CRNN)
        model = model_cls(inputdim=trainloader.dataset.datadim,
                          outputdim=2,
                          **config_parameters['model_args'])
        if 'pretrained' in config_parameters and config_parameters[
                'pretrained'] is not None:
            model_dump = torch.load(config_parameters['pretrained'],
                                    map_location='cpu')
            model_state = model.state_dict()
            pretrained_state = {
                k: v
                for k, v in model_dump.items()
                if k in model_state and v.size() == model_state[k].size()
            }
            model_state.update(pretrained_state)
            model.load_state_dict(model_state)
            logger.info("Loading pretrained model {}".format(
                config_parameters['pretrained']))

        model = model.to(DEVICE)
        optimizer = getattr(
            torch.optim,
            config_parameters['optimizer'],
        )(model.parameters(), **config_parameters['optimizer_args'])

        utils.pprint_dict(optimizer, logger.info, formatter='pretty')
        utils.pprint_dict(model, logger.info, formatter='pretty')
        if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
            model = torch.nn.DataParallel(model)
        criterion = getattr(losses, config_parameters['loss'])().to(DEVICE)

        def _train_batch(_, batch):
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                output = self._forward(
                    model, batch)  # output is tuple (clip, frame, target)
                loss = criterion(*output)
                loss.backward()
                # Single loss
                optimizer.step()
                return loss.item()

        def _inference(_, batch):
            model.eval()
            with torch.no_grad():
                return self._forward(model, batch)

        def thresholded_output_transform(output):
            # Output is a 5-tuple: (clip_pred, frame_pred, frame_target, clip_target, lengths)
            _, y_pred, y, y_clip, length = output
            batchsize, timesteps, ndim = y.shape
            idxs = torch.arange(timesteps,
                                device='cpu').repeat(batchsize).view(
                                    batchsize, timesteps)
            mask = (idxs < length.view(-1, 1)).to(y.device)
            y = y * mask.unsqueeze(-1)
            y_pred = torch.round(y_pred)
            y = torch.round(y)
            return y_pred, y

        metrics = {
            # Reimplementation of Loss that supports the 3-way loss
            'Loss': losses.Loss(criterion),
            'Precision': Precision(thresholded_output_transform),
            'Recall': Recall(thresholded_output_transform),
            'Accuracy': Accuracy(thresholded_output_transform),
        }
        train_engine = Engine(_train_batch)
        inference_engine = Engine(_inference)
        for name, metric in metrics.items():
            metric.attach(inference_engine, name)

        def compute_metrics(engine):
            inference_engine.run(cvdataloader)
            results = inference_engine.state.metrics
            output_str_list = [
                "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
            ]
            for metric in metrics:
                output_str_list.append("{} {:<5.2f}".format(
                    metric, results[metric]))
            logger.info(" ".join(output_str_list))
            pbar.n = pbar.last_print_n = 0

        pbar = ProgressBar(persist=False)
        pbar.attach(train_engine)

        train_engine.add_event_handler(Events.ITERATION_COMPLETED(every=5000),
                                       compute_metrics)
        train_engine.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics)

        early_stop_handler = EarlyStopping(
            patience=config_parameters['early_stop'],
            score_function=self._negative_loss,
            trainer=train_engine)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           early_stop_handler)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {
                                               'model': model,
                                           })

        train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
        return outputdir
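
The _negative_loss score function passed to ModelCheckpoint and EarlyStopping is not defined in this snippet. ignite treats higher score values as better, so it presumably just negates the validation loss, matching the lambda used in the next example. A minimal sketch:

    @staticmethod
    def _negative_loss(engine):
        # ignite maximises score functions, so negate the loss to keep
        # the usual "lower loss is better" checkpoint/early-stopping behaviour
        return -engine.state.metrics['Loss']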
Example #4
    def train(self, config, **kwargs):
        """Trains a given model specified in the config file or passed as the --model parameter.
        All options in the config file can be overwritten as needed by passing --PARAM
        Options with variable lengths ( e.g., kwargs can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}'

        :param config: yaml config file
        :param **kwargs: parameters to overwrite yaml config
        """

        config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
        outputdir = Path(
            config_parameters['outputpath'], config_parameters['model'],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex[:8]))
        # Initialise the checkpoint handler early so that it creates the output dir
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            'run',
            n_saved=1,
            require_empty=False,
            create_dir=True,
            score_function=lambda engine: -engine.state.metrics['Loss'],
            save_as_state_dict=False,
            score_name='loss')
        logger = utils.getfile_outlogger(Path(outputdir, 'train.log'))
        logger.info("Storing files in {}".format(outputdir))
        utils.pprint_dict(config_parameters, logger.info)
        logger.info("Running on device {}".format(DEVICE))
        labels_df = pd.read_csv(config_parameters['trainlabel'], sep=' ')
        labels_df['encoded'], encoder = utils.encode_labels(
            labels=labels_df['bintype'])
        train_df, cv_df = utils.split_train_cv(labels_df)

        transform = utils.parse_transforms(config_parameters['transforms'])
        utils.pprint_dict({'Classes': encoder.classes_},
                          logger.info,
                          formatter='pretty')
        utils.pprint_dict(transform, logger.info, formatter='pretty')
        if 'sampler' in config_parameters and config_parameters[
                'sampler'] == 'MinimumOccupancySampler':
            # Asserts that each "batch" contains at least one instance
            train_sampler = dataset.MinimumOccupancySampler(
                np.stack(train_df['encoded'].values))

            sampling_kwargs = {"sampler": train_sampler, "shuffle": False}
        elif 'shuffle' in config_parameters and config_parameters['shuffle']:
            sampling_kwargs = {"shuffle": True}
        else:
            sampling_kwargs = {"shuffle": False}

        logger.info("Using Sampler {}".format(sampling_kwargs))

        colname = config_parameters.get('colname', ('filename', 'encoded'))
        trainloader = dataset.getdataloader(
            train_df,
            config_parameters['traindata'],
            transform=transform,
            batch_size=config_parameters['batch_size'],
            colname=colname,  # For other datasets with different key names
            num_workers=config_parameters['num_workers'],
            **sampling_kwargs)
        cvdataloader = dataset.getdataloader(
            cv_df,
            config_parameters['traindata'],
            transform=None,
            shuffle=False,
            colname=colname,  # For other datasets with different key names
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'])
        if 'pretrained' in config_parameters and config_parameters[
                'pretrained'] is not None:
            model = models.load_pretrained(config_parameters['pretrained'],
                                           outputdim=len(encoder.classes_))
        else:
            # Fall back to models.LightCNN if the configured model name is missing
            model_cls = getattr(models, config_parameters['model'], models.LightCNN)
            model = model_cls(inputdim=trainloader.dataset.datadim,
                              outputdim=len(encoder.classes_),
                              **config_parameters['model_args'])

        if config_parameters['optimizer'] == 'AdaBound':
            try:
                import adabound
                optimizer = adabound.AdaBound(
                    model.parameters(), **config_parameters['optimizer_args'])
            except ImportError:
                logger.info(
                    "AdaBound package not found, install via pip install adabound. Using Adam instead"
                )
                config_parameters['optimizer'] = 'Adam'
                config_parameters['optimizer_args'] = {}
                # Fall back to default Adam since adabound is not installed
                optimizer = torch.optim.Adam(model.parameters())
        else:
            optimizer = getattr(
                torch.optim,
                config_parameters['optimizer'],
            )(model.parameters(), **config_parameters['optimizer_args'])

        utils.pprint_dict(optimizer, logger.info, formatter='pretty')
        utils.pprint_dict(model, logger.info, formatter='pretty')
        if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
            model = torch.nn.DataParallel(model)
        criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
        model = model.to(DEVICE)

        precision = Precision()
        recall = Recall()
        f1_score = (precision * recall * 2 / (precision + recall)).mean()
        metrics = {
            'Loss': Loss(criterion),
            'Precision': precision.mean(),
            'Recall': recall.mean(),
            'Accuracy': Accuracy(),
            'F1': f1_score,
        }

        # batch contains 3 elements, X,Y and filename. Filename is only used
        # during evaluation
        def _prep_batch(batch, device=DEVICE, non_blocking=False):
            x, y, _ = batch
            return (convert_tensor(x, device=device,
                                   non_blocking=non_blocking),
                    convert_tensor(y, device=device,
                                   non_blocking=non_blocking))

        train_engine = create_supervised_trainer(model,
                                                 optimizer=optimizer,
                                                 loss_fn=criterion,
                                                 prepare_batch=_prep_batch,
                                                 device=DEVICE)
        inference_engine = create_supervised_evaluator(
            model, metrics=metrics, prepare_batch=_prep_batch, device=DEVICE)

        RunningAverage(output_transform=lambda x: x).attach(
            train_engine, 'run_loss')  # Showing progressbar during training
        pbar = ProgressBar(persist=False)
        pbar.attach(train_engine, ['run_loss'])
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               patience=3,
                                                               factor=0.1)

        @inference_engine.on(Events.COMPLETED)
        def update_reduce_on_plateau(engine):
            val_loss = engine.state.metrics['Loss']
            if 'ReduceLROnPlateau' == scheduler.__class__.__name__:
                scheduler.step(val_loss)
            else:
                scheduler.step()

        early_stop_handler = EarlyStopping(
            patience=5,
            score_function=lambda engine: -engine.state.metrics['Loss'],
            trainer=train_engine)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           early_stop_handler)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {
                                               'model': model,
                                               'encoder': encoder,
                                               'config': config_parameters,
                                           })

        @train_engine.on(Events.EPOCH_COMPLETED)
        def compute_validation_metrics(engine):
            inference_engine.run(cvdataloader)
            results = inference_engine.state.metrics
            output_str_list = [
                "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
            ]
            for metric in metrics:
                output_str_list.append("{} {:<5.3f}".format(
                    metric, results[metric]))
            logger.info(" ".join(output_str_list))
            pbar.n = pbar.last_print_n = 0

        train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
        return outputdir
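
The F1 metric above is built by composing ignite metric objects: arithmetic on Precision and Recall (which yield per-class tensors when average=False, the default) produces a lazily evaluated MetricsLambda, and .mean() turns the per-class F1 into a macro average. The same computation spelled out explicitly with MetricsLambda, as a sketch (the small epsilon guarding against division by zero is an addition):

from ignite.metrics import MetricsLambda, Precision, Recall

precision = Precision(average=False)  # per-class precision tensor
recall = Recall(average=False)        # per-class recall tensor
# Macro-F1: harmonic mean per class, then averaged over classes
f1_score = MetricsLambda(
    lambda p, r: (2 * p * r / (p + r + 1e-20)).mean().item(),
    precision, recall)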
Example #5
    def train(self, config, **kwargs):
        """Trains a given model specified in the config file or passed as the --model parameter.
        All options in the config file can be overwritten as needed by passing --PARAM
        Options with variable lengths ( e.g., kwargs can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}'

        :param config: yaml config file
        :param **kwargs: parameters to overwrite yaml config
        """

        config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
        outputdir = os.path.join(
            config_parameters['outputpath'], config_parameters['model'],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex))
        # Create base dir
        Path(outputdir).mkdir(exist_ok=True, parents=True)

        logger = utils.getfile_outlogger(os.path.join(outputdir, 'train.log'))
        logger.info("Storing files in {}".format(outputdir))
        utils.pprint_dict(config_parameters, logger.info)
        logger.info("Running on device {}".format(DEVICE))
        labels_df = pd.read_csv(config_parameters['label'],
                                sep=r'\s+').convert_dtypes()
        # Filenames in the AVE dataset are plain integer indices; for other
        # datasets, reduce absolute paths to their basenames
        if not np.all(labels_df['filename'].str.isnumeric()):
            labels_df.loc[:, 'filename'] = labels_df['filename'].apply(
                os.path.basename)
        encoder = utils.train_labelencoder(labels=labels_df['event_labels'])
        # This label array is only needed for the stratified split mode
        label_array, _ = utils.encode_labels(labels_df['event_labels'],
                                             encoder)
        if 'cv_label' in config_parameters:
            cv_df = pd.read_csv(config_parameters['cv_label'],
                                sep=r'\s+').convert_dtypes()
            if not np.all(cv_df['filename'].str.isnumeric()):
                cv_df.loc[:, 'filename'] = cv_df['filename'].apply(
                    os.path.basename)
            train_df = labels_df
            logger.info(
                f"Using CV labels from {config_parameters['cv_label']}")
        else:
            train_df, cv_df = utils.split_train_cv(
                labels_df, y=label_array, **config_parameters['data_args'])

        if 'cv_data' in config_parameters:
            cv_data = config_parameters['cv_data']
            logger.info(f"Using CV data {config_parameters['cv_data']}")
        else:
            cv_data = config_parameters['data']

        train_label_array, _ = utils.encode_labels(train_df['event_labels'],
                                                   encoder)
        cv_label_array, _ = utils.encode_labels(cv_df['event_labels'], encoder)

        transform = utils.parse_transforms(config_parameters['transforms'])
        utils.pprint_dict({'Classes': encoder.classes_},
                          logger.info,
                          formatter='pretty')
        torch.save(encoder, os.path.join(outputdir, 'run_encoder.pth'))
        torch.save(config_parameters, os.path.join(outputdir,
                                                   'run_config.pth'))
        logger.info("Transforms:")
        utils.pprint_dict(transform, logger.info, formatter='pretty')
        # For the unbalanced AudioSet this is typically the case
        if 'sampler' in config_parameters and config_parameters[
                'sampler'] == 'MultiBalancedSampler':
            # Training sampler that oversamples the dataset so that classes are
            # roughly equally represented. Calculates the mean over multiple
            # instances, which is useful when the number of classes is large
            train_sampler = dataset.MultiBalancedSampler(
                train_label_array,
                num_samples=1 * train_label_array.shape[0],
                replacement=True)
            sampling_kwargs = {"shuffle": False, "sampler": train_sampler}
        elif 'sampler' in config_parameters and config_parameters[
                'sampler'] == 'MinimumOccupancySampler':
            # Asserts that each "batch" contains at least one instance
            train_sampler = dataset.MinimumOccupancySampler(
                train_label_array, sampling_mode='same')
            sampling_kwargs = {"shuffle": False, "sampler": train_sampler}
        else:
            sampling_kwargs = {"shuffle": True}

        logger.info("Using Sampler {}".format(sampling_kwargs))

        trainloader = dataset.getdataloader(
            {
                'filename': train_df['filename'].values,
                'encoded': train_label_array
            },
            config_parameters['data'],
            transform=transform,
            batch_size=config_parameters['batch_size'],
            colname=config_parameters['colname'],
            num_workers=config_parameters['num_workers'],
            **sampling_kwargs)

        cvdataloader = dataset.getdataloader(
            {
                'filename': cv_df['filename'].values,
                'encoded': cv_label_array
            },
            cv_data,
            transform=None,
            shuffle=False,
            colname=config_parameters['colname'],
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'])
        # Fall back to models.CRNN if the configured model name is missing
        model_cls = getattr(models, config_parameters['model'], models.CRNN)
        model = model_cls(inputdim=trainloader.dataset.datadim,
                          outputdim=len(encoder.classes_),
                          **config_parameters['model_args'])
        if 'pretrained' in config_parameters and config_parameters[
                'pretrained'] is not None:
            models.load_pretrained(model,
                                   config_parameters['pretrained'],
                                   outputdim=len(encoder.classes_))
            logger.info("Loading pretrained model {}".format(
                config_parameters['pretrained']))

        model = model.to(DEVICE)
        if config_parameters['optimizer'] == 'AdaBound':
            try:
                import adabound
                optimizer = adabound.AdaBound(
                    model.parameters(), **config_parameters['optimizer_args'])
            except ImportError:
                config_parameters['optimizer'] = 'Adam'
                config_parameters['optimizer_args'] = {}
                # Fall back to default Adam since adabound is not installed
                optimizer = torch.optim.Adam(model.parameters())
        else:
            optimizer = getattr(
                torch.optim,
                config_parameters['optimizer'],
            )(model.parameters(), **config_parameters['optimizer_args'])

        utils.pprint_dict(optimizer, logger.info, formatter='pretty')
        utils.pprint_dict(model, logger.info, formatter='pretty')
        if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
            model = torch.nn.DataParallel(model)
        criterion = getattr(losses, config_parameters['loss'])().to(DEVICE)

        def _train_batch(_, batch):
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                output = self._forward(
                    model, batch)  # output is tuple (clip, frame, target)
                loss = criterion(*output)
                loss.backward()
                # Single loss
                optimizer.step()
                return loss.item()

        def _inference(_, batch):
            model.eval()
            with torch.no_grad():
                return self._forward(model, batch)

        def thresholded_output_transform(output):
            # Output is (clip, frame, target)
            y_pred, _, y = output
            y_pred = torch.round(y_pred)
            return y_pred, y

        precision = Precision(thresholded_output_transform, average=False)
        recall = Recall(thresholded_output_transform, average=False)
        f1_score = (precision * recall * 2 / (precision + recall)).mean()
        metrics = {
            # Reimplementation of Loss that supports the 3-way loss
            'Loss': losses.Loss(criterion),
            'Precision': Precision(thresholded_output_transform),
            'Recall': Recall(thresholded_output_transform),
            'Accuracy': Accuracy(thresholded_output_transform),
            'F1': f1_score,
        }
        train_engine = Engine(_train_batch)
        inference_engine = Engine(_inference)
        for name, metric in metrics.items():
            metric.attach(inference_engine, name)

        def compute_metrics(engine):
            inference_engine.run(cvdataloader)
            results = inference_engine.state.metrics
            output_str_list = [
                "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
            ]
            for metric in metrics:
                output_str_list.append("{} {:<5.2f}".format(
                    metric, results[metric]))
            logger.info(" ".join(output_str_list))

        pbar = ProgressBar(persist=False)
        pbar.attach(train_engine)

        if 'itercv' in config_parameters and config_parameters[
                'itercv'] is not None:
            train_engine.add_event_handler(
                Events.ITERATION_COMPLETED(every=config_parameters['itercv']),
                compute_metrics)
        train_engine.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics)

        # The default scheduler uses patience=3, factor=0.1
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, **config_parameters['scheduler_args'])

        @inference_engine.on(Events.EPOCH_COMPLETED)
        def update_reduce_on_plateau(engine):
            logger.info(f"Scheduling epoch {engine.state.epoch}")
            val_loss = engine.state.metrics['Loss']
            if 'ReduceLROnPlateau' == scheduler.__class__.__name__:
                scheduler.step(val_loss)
            else:
                scheduler.step()

        early_stop_handler = EarlyStopping(
            patience=config_parameters['early_stop'],
            score_function=self._negative_loss,
            trainer=train_engine)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           early_stop_handler)
        if config_parameters['save'] == 'everyepoch':
            checkpoint_handler = ModelCheckpoint(outputdir,
                                                 'run',
                                                 n_saved=5,
                                                 require_empty=False)
            train_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {
                                               'model': model,
                                           })
            train_engine.add_event_handler(
                Events.ITERATION_COMPLETED(every=config_parameters['itercv']),
                checkpoint_handler, {
                    'model': model,
                })
        else:
            checkpoint_handler = ModelCheckpoint(
                outputdir,
                'run',
                n_saved=1,
                require_empty=False,
                score_function=self._negative_loss,
                global_step_transform=global_step_from_engine(
                    train_engine),  # Just so that model is saved with epoch...
                score_name='loss')
            inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                               checkpoint_handler, {
                                                   'model': model,
                                               })

        train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
        return outputdir
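
The --PARAM override style described in the docstrings suggests these train methods are exposed through a python-fire style entry point. A hypothetical sketch of such a wrapper (the Runner name and file layout are assumptions):

import fire

class Runner:
    # The train() variants above would be methods of a class like this;
    # fire turns each public method into a CLI subcommand.
    def train(self, config, **kwargs):
        ...

if __name__ == '__main__':
    fire.Fire(Runner)

# Hypothetical invocation mirroring the docstrings' override style:
#   python run.py train config/train.yaml --model CRNN --batch_size 32 \
#       --model_args '{"PARAM1": 128}'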