def run(self, config, **kwargs):
        """run

        Trains a model from ``config`` and evaluates it on every configured
        test set.

        :param config: Config for training and evaluation. Besides the
            training options consumed by ``self.train`` (e.g. --data for the
            training data (HDF5), --label for the training labels), it must
            provide ``testdata`` (test data, HDF5) and ``testlabel`` (the
            matching labels). Each may be a single path or a list of paths;
            entries are paired positionally.
        :param **kwargs: Overrides for entries of ``config``.
        :raises ValueError: if ``testdata`` and ``testlabel`` contain a
            different number of entries (checked before training starts).
        """
        config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
        test_sets = config_parameters['testdata']
        test_labels = config_parameters['testlabel']
        # A single path would otherwise be zipped character by character.
        if isinstance(test_sets, (str, Path)):
            test_sets = [test_sets]
        if isinstance(test_labels, (str, Path)):
            test_labels = [test_labels]
        test_sets, test_labels = list(test_sets), list(test_labels)
        # Validate before training so a config mistake does not cost a full
        # training run; zip() would silently drop unmatched entries.
        if len(test_sets) != len(test_labels):
            raise ValueError(
                'testdata ({}) and testlabel ({}) must have the same number '
                'of entries'.format(len(test_sets), len(test_labels)))

        experiment_path = self.train(config, **kwargs)
        evaluation_logger = utils.getfile_outlogger(
            Path(experiment_path, 'evaluation.log'))
        for testdata, testlabel in zip(test_sets, test_labels):
            evaluation_logger.info(
                f'Evaluting {testdata} with {testlabel} in {experiment_path}')
            stem = Path(testdata).stem
            # Written by self.score and read back by self.evaluate_eer.
            scores_file = Path(experiment_path, 'scores_' + stem + '.tsv')
            evaluation_result_file = Path(experiment_path,
                                          'evaluation_{}.txt'.format(stem))
            self.score(experiment_path,
                       result_file=scores_file,
                       label=testlabel,
                       data=testdata)
            self.evaluate_eer(scores_file,
                              ground_truth_file=testlabel,
                              evaluation_res_file=evaluation_result_file)
Beispiel #2
0
 def train_evaluate(self, config, test_data, test_label, **kwargs):
     """Train a model from ``config``, then evaluate it on one test set.

     After training, the saved model weights, config and label encoder are
     reloaded from the experiment directory to rebuild the model. A dummy
     forward pass sized from the first entry of ``test_data`` measures how
     many input frames map onto one output frame; this becomes the
     ``time_ratio`` passed to ``self.evaluate``.

     :param config: yaml config file used for training.
     :param test_data: HDF5 file with the test features.
     :param test_label: Labels belonging to ``test_data``.
     :param **kwargs: Overrides for ``config`` (e.g. --data, --label).
     :raises FileNotFoundError: if no run_model*/run_config*/run_encoder*
         file exists in the experiment directory.
     """
     experiment_path = self.train(config, **kwargs)
     from h5py import File

     def _find_artifact(prefix):
         # First file in the experiment directory whose name starts with
         # ``prefix``.
         # NOTE(review): glob order is arbitrary; if several files match
         # (e.g. several saved checkpoints) which one is used is undefined.
         matches = glob.glob("{}/{}*".format(experiment_path, prefix))
         if not matches:
             raise FileNotFoundError("No '{}*' file found in {}".format(
                 prefix, experiment_path))
         return matches[0]

     def _load_on_cpu(path):
         # Map every stored tensor onto the CPU regardless of saving device.
         return torch.load(path, map_location=lambda storage, loc: storage)

     model_path = _find_artifact("run_model")
     print(model_path)
     model_parameters = _load_on_cpu(model_path)
     config_param = _load_on_cpu(_find_artifact("run_config"))
     encoder = _load_on_cpu(_find_artifact("run_encoder"))
     # Shape of the first stored entry sizes the dummy input below.
     with File(test_data, 'r') as store:
         timedim, datadim = next(iter(store.values())).shape
     model = getattr(models,
                     config_param['model'])(inputdim=datadim,
                                            outputdim=len(encoder.classes_),
                                            **config_param['model_args'])
     model.load_state_dict(model_parameters)
     # Dummy forward pass to measure the model's temporal pooling factor.
     # Only the output shape matters: use eval mode and skip autograd.
     model.eval()
     with torch.no_grad():
         _, time_out = model(torch.randn(1, timedim, datadim))
     # NOTE(review): 0.02 looks like the input frame shift in seconds
     # (20 ms) -- confirm; time_ratio scales it by the pooling factor.
     time_ratio = max(0.02, 0.02 * np.round(timedim / time_out.shape[1]))
     # Re-parse the config so that overrides passed via kwargs (such as
     # --data / --label) also apply to the evaluation options.
     config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
     threshold = config_parameters.get('threshold', None)
     postprocessing = config_parameters.get('postprocessing', 'double')
     window_size = config_parameters.get('window_size', None)
     self.evaluate(experiment_path,
                   label=test_label,
                   data=test_data,
                   time_ratio=time_ratio,
                   postprocessing=postprocessing,
                   threshold=threshold,
                   window_size=window_size)
Beispiel #3
0
def main(config, rank, world_size, gpu_id, port, kwargs):
    """Initialise one distributed training process and start ``train``.

    :param config: Config file, parsed together with the ``kwargs``
        overrides via ``parse_config_or_kwargs``.
    :param rank: Global rank of this process (used when train_local == 1).
    :param world_size: Total number of processes (used when
        train_local == 1).
    :param gpu_id: Local GPU index (used when train_local == 1; otherwise
        replaced by SLURM_LOCALID).
    :param port: Rendezvous port for ``dist_init`` when train_local == 1;
        under SLURM the port is derived from SLURM_JOBID instead.
    :param kwargs: Dict of config overrides (a plain dict, not ``**kwargs``).
    """
    torch.backends.cudnn.benchmark = True

    conf = parse_config_or_kwargs(config, **kwargs)

    # --------- multi machine train set up --------------
    if conf['train_local'] == 1:
        host_addr = 'localhost'
        conf['rank'] = rank
        conf['local_rank'] = gpu_id  # specify the local gpu id
        conf['world_size'] = world_size
        dist_init(host_addr, conf['rank'], conf['local_rank'],
                  conf['world_size'], port)
    else:
        # Under SLURM the process topology comes from the job environment.
        host_addr = getoneNode()
        conf['rank'] = int(os.environ['SLURM_PROCID'])
        conf['local_rank'] = int(os.environ['SLURM_LOCALID'])
        conf['world_size'] = int(os.environ['SLURM_NTASKS'])
        # Port built from the job id (presumably to avoid clashes between
        # concurrent jobs on the same node).
        dist_init(host_addr, conf['rank'], conf['local_rank'],
                  conf['world_size'], '2' + os.environ['SLURM_JOBID'][-4:])
        gpu_id = conf['local_rank']
    # --------- multi machine train set up --------------

    # setup logger: rank 0 creates exp_dir first, the other ranks wait at the
    # barrier before opening the same log file.
    log_path = os.path.join(conf['exp_dir'], 'train.log')
    log_format = "[ %(asctime)s ] %(message)s"
    if conf['rank'] == 0:
        check_dir(conf['exp_dir'])
        logger = get_logger_2(log_path, log_format)
    dist.barrier()  # let the rank 0 mkdir first
    if conf['rank'] != 0:
        logger = get_logger_2(log_path, log_format)

    logger.info("Rank: {}/{}, local rank:{} is running".format(
        conf['rank'], conf['world_size'], conf['local_rank']))

    # write the config file to the exp_dir
    if conf['rank'] == 0:
        store_path = os.path.join(conf['exp_dir'], 'config.yaml')
        store_yaml(config, store_path, **kwargs)

    cuda_id = 'cuda:' + str(gpu_id)
    conf['device'] = torch.device(
        cuda_id if torch.cuda.is_available() else 'cpu')

    model_dir = os.path.join(conf['exp_dir'], 'models')
    if conf['rank'] == 0:
        check_dir(model_dir)
    conf['checkpoint_format'] = os.path.join(model_dir, '{}.th')

    # Different seed per rank.
    set_seed(666 + conf['rank'])

    # SECURITY NOTE: eval() executes conf['model_type'] as Python code;
    # only run with trusted config files.
    if 'R' in conf['model_type']:
        model = eval(conf['model_type'])(base_ch_num=conf['base_ch_num'],
                                         t=conf['t'])
    else:
        model = eval(conf['model_type'])(base_ch_num=conf['base_ch_num'])
    model = model.to(conf['device'])
    model = DDP(model,
                device_ids=[conf['local_rank']],
                output_device=conf['local_rank'])
    optimizer = optim.Adam(model.parameters(),
                           lr=conf['lr'],
                           betas=(0.5, 0.99))

    if conf['rank'] == 0:
        num_params = sum(param.numel() for param in model.parameters())
        logger.info("Model type: {} Base channel num:{}".format(
            conf['model_type'], conf['base_ch_num']))
        logger.info("Number of parameters: {:.4f}M".format(1.0 * num_params /
                                                           1e6))
        logger.info(optimizer)

    # NOTE(review): no DistributedSampler is used, so every rank iterates
    # the full training set -- confirm this is intended.
    train_set = ImageFolder(root=conf['root'],
                            mode='train',
                            augmentation_prob=conf['aug_prob'],
                            crop_size_min=conf['crop_size_min'],
                            crop_size_max=conf['crop_size_max'],
                            data_num=conf['data_num'],
                            gauss_size=conf['gauss_size'],
                            data_aug_list=conf['aug_list'])
    train_loader = DataLoader(dataset=train_set,
                              batch_size=conf['batch_size'],
                              shuffle=conf['shuffle'],
                              num_workers=conf['num_workers'])

    # Training images without augmentation.
    dev_set = ImageFolder(root=conf['root'],
                          mode='train',
                          augmentation_prob=0.0)
    dev_loader = DataLoader(dataset=dev_set,
                            batch_size=5,
                            shuffle=False,
                            num_workers=1)

    # NOTE(review): valid_loader is built but never passed to train();
    # confirm whether train() should receive it.
    valid_set = ImageFolder(root=conf['root'], mode='valid')
    valid_loader = DataLoader(dataset=valid_set,
                              batch_size=5,
                              shuffle=False,
                              num_workers=1)

    test_set = ImageFolder(root=conf['root'], mode='test')
    test_loader = DataLoader(dataset=test_set,
                             batch_size=5,
                             shuffle=False,
                             num_workers=1)

    dist.barrier()  # synchronize here
    train(model, train_loader, test_loader, dev_loader, optimizer, conf,
          logger)
Beispiel #4
0
    def train(self, config, **kwargs):
        """Trains a given model specified in the config file or passed as the --model parameter.
        All options in the config file can be overwritten as needed by passing --PARAM
        Options with variable lengths ( e.g., kwargs can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}'

        Data and labels are read from whitespace-separated tables and joined
        on their ``filename`` column; only files present in both are used.
        The labels are split into train and cross-validation parts.
        Validation runs every 5000 iterations and after every epoch, and
        drives early stopping and checkpointing.

        :param config: yaml config file
        :param **kwargs: parameters to overwrite yaml config
        :return: output directory holding the log, config and checkpoints
        :raises ValueError: if the cross-validation split is empty or the
            configured model is not found in ``models``
        """

        config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
        # Timestamp down to the second (%S); the uuid suffix keeps runs
        # started within the same second apart.
        outputdir = os.path.join(
            config_parameters['outputpath'], config_parameters['model'],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex))
        # Early init because it creates the output directory
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            'run',
            n_saved=3,
            require_empty=False,
            create_dir=True,
            score_function=self._negative_loss,
            score_name='loss')
        logger = utils.getfile_outlogger(os.path.join(outputdir, 'train.log'))
        logger.info("Storing files in {}".format(outputdir))
        utils.pprint_dict(config_parameters, logger.info)
        logger.info("Running on device {}".format(DEVICE))
        # Raw strings: '\s' in a plain literal is an invalid escape sequence.
        label_df = pd.read_csv(config_parameters['label'], sep=r'\s+')
        data_df = pd.read_csv(config_parameters['data'], sep=r'\s+')
        # Keep only files that appear in both tables.
        common_idxs = data_df.merge(label_df, on='filename')['filename']
        data_df = data_df[data_df['filename'].isin(common_idxs)]
        label_df = label_df[label_df['filename'].isin(common_idxs)]

        train_df, cv_df = utils.split_train_cv(
            label_df, **config_parameters['data_args'])
        # Explicit check (an assert would be stripped under python -O).
        if len(cv_df) == 0:
            raise ValueError(
                "Cross-validation split is empty. Fraction a bit too large?")
        train_label = utils.df_to_dict(train_df)
        cv_label = utils.df_to_dict(cv_df)
        data = utils.df_to_dict(data_df)

        transform = utils.parse_transforms(config_parameters['transforms'])
        torch.save(config_parameters, os.path.join(outputdir,
                                                   'run_config.pth'))
        logger.info("Transforms:")
        utils.pprint_dict(transform, logger.info, formatter='pretty')

        trainloader = dataset.gettraindataloader(
            h5files=data,
            h5labels=train_label,
            transform=transform,
            label_type=config_parameters['label_type'],
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'],
            shuffle=True,
        )

        cvdataloader = dataset.gettraindataloader(
            h5files=data,
            h5labels=cv_label,
            label_type=config_parameters['label_type'],
            transform=None,
            shuffle=False,
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'],
        )
        # Fail with a clear message on an unknown model name (a string
        # default passed to getattr would not be callable).
        model_cls = getattr(models, config_parameters['model'], None)
        if model_cls is None:
            raise ValueError("Unknown model '{}'".format(
                config_parameters['model']))
        model = model_cls(inputdim=trainloader.dataset.datadim,
                          outputdim=2,
                          **config_parameters['model_args'])
        if config_parameters.get('pretrained') is not None:
            model_dump = torch.load(config_parameters['pretrained'],
                                    map_location='cpu')
            model_state = model.state_dict()
            # Only copy weights whose name and shape match the new model.
            pretrained_state = {
                k: v
                for k, v in model_dump.items()
                if k in model_state and v.size() == model_state[k].size()
            }
            model_state.update(pretrained_state)
            model.load_state_dict(model_state)
            logger.info("Loading pretrained model {}".format(
                config_parameters['pretrained']))

        model = model.to(DEVICE)
        optimizer = getattr(
            torch.optim,
            config_parameters['optimizer'],
        )(model.parameters(), **config_parameters['optimizer_args'])

        utils.pprint_dict(optimizer, logger.info, formatter='pretty')
        utils.pprint_dict(model, logger.info, formatter='pretty')
        if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
            model = torch.nn.DataParallel(model)
        criterion = getattr(losses, config_parameters['loss'])().to(DEVICE)

        def _train_batch(_, batch):
            # One optimisation step; returns the scalar loss.
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                output = self._forward(model, batch)
                loss = criterion(*output)
                loss.backward()
                optimizer.step()
                return loss.item()

        def _inference(_, batch):
            # Forward pass without gradients for validation metrics.
            model.eval()
            with torch.no_grad():
                return self._forward(model, batch)

        def thresholded_output_transform(output):
            # Unpacked as (<unused>, y_pred, y, y_clip, length); only the
            # frame-level prediction/target and the lengths are used.
            _, y_pred, y, y_clip, length = output
            batchsize, timesteps, ndim = y.shape
            # Mask out target frames beyond each sequence's length. The index
            # tensor is created on length's device so CPU and GPU lengths
            # both work.
            # NOTE(review): y_pred is not masked -- confirm intended.
            idxs = torch.arange(timesteps, device=length.device)
            mask = (idxs.unsqueeze(0) < length.view(-1, 1)).to(y.device)
            y = y * mask.unsqueeze(-1)
            y_pred = torch.round(y_pred)
            y = torch.round(y)
            return y_pred, y

        metrics = {
            'Loss': losses.Loss(
                criterion),  # reimplementation of Loss, supports 3 way loss
            'Precision': Precision(thresholded_output_transform),
            'Recall': Recall(thresholded_output_transform),
            'Accuracy': Accuracy(thresholded_output_transform),
        }
        train_engine = Engine(_train_batch)
        inference_engine = Engine(_inference)
        for name, metric in metrics.items():
            metric.attach(inference_engine, name)

        def compute_metrics(engine):
            # Run validation and log one line with all metric values.
            inference_engine.run(cvdataloader)
            results = inference_engine.state.metrics
            output_str_list = [
                "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
            ]
            for metric in metrics:
                output_str_list.append("{} {:<5.2f}".format(
                    metric, results[metric]))
            logger.info(" ".join(output_str_list))
            pbar.n = pbar.last_print_n = 0

        pbar = ProgressBar(persist=False)
        pbar.attach(train_engine)

        train_engine.add_event_handler(Events.ITERATION_COMPLETED(every=5000),
                                       compute_metrics)
        train_engine.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics)

        early_stop_handler = EarlyStopping(
            patience=config_parameters['early_stop'],
            score_function=self._negative_loss,
            trainer=train_engine)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           early_stop_handler)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {
                                               'model': model,
                                           })

        train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
        return outputdir
    def train(self, config, **kwargs):
        """Trains a given model specified in the config file or passed as the --model parameter.
        All options in the config file can be overwritten as needed by passing --PARAM
        Options with variable lengths ( e.g., kwargs can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}'

        Labels are read from ``trainlabel`` (space separated) and their
        ``bintype`` column is encoded into class labels; the labels are split
        into train/cross-validation parts. The model is trained with
        cross-entropy and validated after each epoch. The validation loss
        drives ReduceLROnPlateau, early stopping (patience 5) and
        checkpointing (best single checkpoint by loss).

        :param config: yaml config file
        :param **kwargs: parameters to overwrite yaml config
        :return: output directory (``Path``) holding the log and checkpoints
        :raises ValueError: if the configured model is not found in ``models``
        """

        config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
        # Timestamp down to the second (%S); the uuid suffix keeps runs
        # started within the same second apart.
        outputdir = Path(
            config_parameters['outputpath'], config_parameters['model'],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex[:8]))

        def negative_val_loss(engine):
            # Higher is better for ModelCheckpoint and EarlyStopping.
            return -engine.state.metrics['Loss']

        # Early init because it creates the output directory
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            'run',
            n_saved=1,
            require_empty=False,
            create_dir=True,
            score_function=negative_val_loss,
            save_as_state_dict=False,
            score_name='loss')
        logger = utils.getfile_outlogger(Path(outputdir, 'train.log'))
        logger.info("Storing files in {}".format(outputdir))
        utils.pprint_dict(config_parameters, logger.info)
        logger.info("Running on device {}".format(DEVICE))
        labels_df = pd.read_csv(config_parameters['trainlabel'], sep=' ')
        labels_df['encoded'], encoder = utils.encode_labels(
            labels=labels_df['bintype'])
        train_df, cv_df = utils.split_train_cv(labels_df)

        transform = utils.parse_transforms(config_parameters['transforms'])
        utils.pprint_dict({'Classes': encoder.classes_},
                          logger.info,
                          formatter='pretty')
        utils.pprint_dict(transform, logger.info, formatter='pretty')
        if config_parameters.get('sampler') == 'MinimumOccupancySampler':
            # Asserts that each "batch" contains at least one instance
            train_sampler = dataset.MinimumOccupancySampler(
                np.stack(train_df['encoded'].values))
            sampling_kwargs = {"sampler": train_sampler, "shuffle": False}
        elif config_parameters.get('shuffle'):
            sampling_kwargs = {"shuffle": True}
        else:
            sampling_kwargs = {"shuffle": False}

        logger.info("Using Sampler {}".format(sampling_kwargs))

        # Column names for datasets with different key names.
        colname = config_parameters.get('colname', ('filename', 'encoded'))
        trainloader = dataset.getdataloader(
            train_df,
            config_parameters['traindata'],
            transform=transform,
            batch_size=config_parameters['batch_size'],
            colname=colname,
            num_workers=config_parameters['num_workers'],
            **sampling_kwargs)
        cvdataloader = dataset.getdataloader(
            cv_df,
            config_parameters['traindata'],
            transform=None,
            shuffle=False,
            colname=colname,
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'])
        if config_parameters.get('pretrained') is not None:
            model = models.load_pretrained(config_parameters['pretrained'],
                                           outputdim=len(encoder.classes_))
        else:
            # Fail with a clear message on an unknown model name (a string
            # default passed to getattr would not be callable).
            model_cls = getattr(models, config_parameters['model'], None)
            if model_cls is None:
                raise ValueError("Unknown model '{}'".format(
                    config_parameters['model']))
            model = model_cls(inputdim=trainloader.dataset.datadim,
                              outputdim=len(encoder.classes_),
                              **config_parameters['model_args'])

        if config_parameters['optimizer'] == 'AdaBound':
            try:
                import adabound
            except ImportError:
                logger.info(
                    "Adabound package not found, install via pip install adabound. Using Adam instead"
                )
                # Default Adam if adabound is not found; the torch.optim
                # branch below then builds the optimizer.
                config_parameters['optimizer'] = 'Adam'
                config_parameters['optimizer_args'] = {}
            else:
                optimizer = adabound.AdaBound(
                    model.parameters(), **config_parameters['optimizer_args'])
        if config_parameters['optimizer'] != 'AdaBound':
            optimizer = getattr(
                torch.optim,
                config_parameters['optimizer'],
            )(model.parameters(), **config_parameters['optimizer_args'])

        utils.pprint_dict(optimizer, logger.info, formatter='pretty')
        utils.pprint_dict(model, logger.info, formatter='pretty')
        if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
            model = torch.nn.DataParallel(model)
        criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
        model = model.to(DEVICE)

        precision = Precision()
        recall = Recall()
        f1_score = (precision * recall * 2 / (precision + recall)).mean()
        metrics = {
            'Loss': Loss(criterion),
            'Precision': precision.mean(),
            'Recall': recall.mean(),
            'Accuracy': Accuracy(),
            'F1': f1_score,
        }

        # batch contains 3 elements, X,Y and filename. Filename is only used
        # during evaluation
        def _prep_batch(batch, device=DEVICE, non_blocking=False):
            x, y, _ = batch
            return (convert_tensor(x, device=device,
                                   non_blocking=non_blocking),
                    convert_tensor(y, device=device,
                                   non_blocking=non_blocking))

        train_engine = create_supervised_trainer(model,
                                                 optimizer=optimizer,
                                                 loss_fn=criterion,
                                                 prepare_batch=_prep_batch,
                                                 device=DEVICE)
        inference_engine = create_supervised_evaluator(
            model, metrics=metrics, prepare_batch=_prep_batch, device=DEVICE)

        RunningAverage(output_transform=lambda x: x).attach(
            train_engine, 'run_loss')  # Showing progressbar during training
        pbar = ProgressBar(persist=False)
        pbar.attach(train_engine, ['run_loss'])
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               patience=3,
                                                               factor=0.1)

        @inference_engine.on(Events.COMPLETED)
        def update_reduce_on_plateau(engine):
            val_loss = engine.state.metrics['Loss']
            if 'ReduceLROnPlateau' == scheduler.__class__.__name__:
                scheduler.step(val_loss)
            else:
                scheduler.step()

        early_stop_handler = EarlyStopping(
            patience=5,
            score_function=negative_val_loss,
            trainer=train_engine)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           early_stop_handler)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {
                                               'model': model,
                                               'encoder': encoder,
                                               'config': config_parameters,
                                           })

        @train_engine.on(Events.EPOCH_COMPLETED)
        def compute_validation_metrics(engine):
            # Run validation and log one line with all metric values.
            inference_engine.run(cvdataloader)
            results = inference_engine.state.metrics
            output_str_list = [
                "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
            ]
            for metric in metrics:
                output_str_list.append("{} {:<5.3f}".format(
                    metric, results[metric]))
            logger.info(" ".join(output_str_list))
            pbar.n = pbar.last_print_n = 0

        train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
        return outputdir
Beispiel #6
0
    def train(self, config, **kwargs):
        """Trains a given model specified in the config file or passed as the --model parameter.
        All options in the config file can be overwritten as needed by passing --PARAM
        Options with variable lengths ( e.g., kwargs can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}'

        :param config: yaml config file
        :param **kwargs: parameters to overwrite yaml config
        """

        config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
        outputdir = os.path.join(
            config_parameters['outputpath'], config_parameters['model'],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%m'),
                uuid.uuid1().hex))
        # Create base dir
        Path(outputdir).mkdir(exist_ok=True, parents=True)

        logger = utils.getfile_outlogger(os.path.join(outputdir, 'train.log'))
        logger.info("Storing files in {}".format(outputdir))
        # utils.pprint_dict
        utils.pprint_dict(config_parameters, logger.info)
        logger.info("Running on device {}".format(DEVICE))
        labels_df = pd.read_csv(config_parameters['label'],
                                sep='\s+').convert_dtypes()
        # In case of ave dataset where index is int, we change the
        # absolute name to relname
        if not np.all(labels_df['filename'].str.isnumeric()):
            labels_df.loc[:, 'filename'] = labels_df['filename'].apply(
                os.path.basename)
        encoder = utils.train_labelencoder(labels=labels_df['event_labels'])
        # These labels are useless, only for mode == stratified
        label_array, _ = utils.encode_labels(labels_df['event_labels'],
                                             encoder)
        if 'cv_label' in config_parameters:
            cv_df = pd.read_csv(config_parameters['cv_label'],
                                sep='\s+').convert_dtypes()
            if not np.all(cv_df['filename'].str.isnumeric()):
                cv_df.loc[:, 'filename'] = cv_df['filename'].apply(
                    os.path.basename)
            train_df = labels_df
            logger.info(
                f"Using CV labels from {config_parameters['cv_label']}")
        else:
            train_df, cv_df = utils.split_train_cv(
                labels_df, y=label_array, **config_parameters['data_args'])

        if 'cv_data' in config_parameters:
            cv_data = config_parameters['cv_data']
            logger.info(f"Using CV data {config_parameters['cv_data']}")
        else:
            cv_data = config_parameters['data']

        train_label_array, _ = utils.encode_labels(train_df['event_labels'],
                                                   encoder)
        cv_label_array, _ = utils.encode_labels(cv_df['event_labels'], encoder)

        transform = utils.parse_transforms(config_parameters['transforms'])
        utils.pprint_dict({'Classes': encoder.classes_},
                          logger.info,
                          formatter='pretty')
        torch.save(encoder, os.path.join(outputdir, 'run_encoder.pth'))
        torch.save(config_parameters, os.path.join(outputdir,
                                                   'run_config.pth'))
        logger.info("Transforms:")
        utils.pprint_dict(transform, logger.info, formatter='pretty')
        # For Unbalanced Audioset, this is true
        if 'sampler' in config_parameters and config_parameters[
                'sampler'] == 'MultiBalancedSampler':
            # Training sampler that oversamples the dataset to be roughly equally sized
            # Calcualtes mean over multiple instances, rather useful when number of classes
            # is large
            train_sampler = dataset.MultiBalancedSampler(
                train_label_array,
                num_samples=1 * train_label_array.shape[0],
                replacement=True)
            sampling_kwargs = {"shuffle": False, "sampler": train_sampler}
        elif 'sampler' in config_parameters and config_parameters[
                'sampler'] == 'MinimumOccupancySampler':
            # Asserts that each "batch" contains at least one instance
            train_sampler = dataset.MinimumOccupancySampler(
                train_label_array, sampling_mode='same')
            sampling_kwargs = {"shuffle": False, "sampler": train_sampler}
        else:
            sampling_kwargs = {"shuffle": True}

        logger.info("Using Sampler {}".format(sampling_kwargs))

        trainloader = dataset.getdataloader(
            {
                'filename': train_df['filename'].values,
                'encoded': train_label_array
            },
            config_parameters['data'],
            transform=transform,
            batch_size=config_parameters['batch_size'],
            colname=config_parameters['colname'],
            num_workers=config_parameters['num_workers'],
            **sampling_kwargs)

        cvdataloader = dataset.getdataloader(
            {
                'filename': cv_df['filename'].values,
                'encoded': cv_label_array
            },
            cv_data,
            transform=None,
            shuffle=False,
            colname=config_parameters['colname'],
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'])
        model = getattr(models, config_parameters['model'],
                        'CRNN')(inputdim=trainloader.dataset.datadim,
                                outputdim=len(encoder.classes_),
                                **config_parameters['model_args'])
        if 'pretrained' in config_parameters and config_parameters[
                'pretrained'] is not None:
            models.load_pretrained(model,
                                   config_parameters['pretrained'],
                                   outputdim=len(encoder.classes_))
            logger.info("Loading pretrained model {}".format(
                config_parameters['pretrained']))

        model = model.to(DEVICE)
        if config_parameters['optimizer'] == 'AdaBound':
            try:
                import adabound
                optimizer = adabound.AdaBound(
                    model.parameters(), **config_parameters['optimizer_args'])
            except ImportError:
                config_parameters['optimizer'] = 'Adam'
                config_parameters['optimizer_args'] = {}
        else:
            optimizer = getattr(
                torch.optim,
                config_parameters['optimizer'],
            )(model.parameters(), **config_parameters['optimizer_args'])

        utils.pprint_dict(optimizer, logger.info, formatter='pretty')
        utils.pprint_dict(model, logger.info, formatter='pretty')
        if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
            model = torch.nn.DataParallel(model)
        criterion = getattr(losses, config_parameters['loss'])().to(DEVICE)

        def _train_batch(_, batch):
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                output = self._forward(
                    model, batch)  # output is tuple (clip, frame, target)
                loss = criterion(*output)
                loss.backward()
                # Single loss
                optimizer.step()
                return loss.item()

        def _inference(_, batch):
            model.eval()
            with torch.no_grad():
                return self._forward(model, batch)

        def thresholded_output_transform(output):
            """Binarise clip-level predictions for the classification metrics.

            ``output`` is the ``(clip, frame, target)`` triple returned by the
            inference step; the frame-level element is ignored. Returns
            ``(torch.round(clip), target)``.
            """
            clip_pred, _frame_pred, target = output
            return torch.round(clip_pred), target

        precision = Precision(thresholded_output_transform, average=False)
        recall = Recall(thresholded_output_transform, average=False)
        f1_score = (precision * recall * 2 / (precision + recall)).mean()
        metrics = {
            'Loss': losses.Loss(
                criterion),  #reimplementation of Loss, supports 3 way loss 
            'Precision': Precision(thresholded_output_transform),
            'Recall': Recall(thresholded_output_transform),
            'Accuracy': Accuracy(thresholded_output_transform),
            'F1': f1_score,
        }
        train_engine = Engine(_train_batch)
        inference_engine = Engine(_inference)
        for name, metric in metrics.items():
            metric.attach(inference_engine, name)

        def compute_metrics(engine):
            """Run a validation pass over ``cvdataloader`` and log all metrics.

            ``engine`` is the training engine that fired the event; only its
            epoch counter is read, for the log line.
            """
            inference_engine.run(cvdataloader)
            scores = inference_engine.state.metrics
            header = "Validation Results - Epoch : {:<5}".format(
                engine.state.epoch)
            parts = [header] + [
                "{} {:<5.2f}".format(metric_name, scores[metric_name])
                for metric_name in metrics
            ]
            logger.info(" ".join(parts))

        pbar = ProgressBar(persist=False)
        pbar.attach(train_engine)

        if 'itercv' in config_parameters and config_parameters[
                'itercv'] is not None:
            train_engine.add_event_handler(
                Events.ITERATION_COMPLETED(every=config_parameters['itercv']),
                compute_metrics)
        train_engine.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics)

        # Default scheduler is using patience=3, factor=0.1
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, **config_parameters['scheduler_args'])

        @inference_engine.on(Events.EPOCH_COMPLETED)
        def update_reduce_on_plateau(engine):
            """Step the learning-rate scheduler after every validation pass.

            ``engine`` is ``inference_engine``; its ``'Loss'`` metric is used
            as the monitored value for ReduceLROnPlateau.
            """
            # inference_engine.run() is called without max_epochs (ignite
            # defaults to 1), so its own epoch counter is always 1 here; log
            # the training engine's epoch, which is the meaningful one.
            logger.info(f"Scheduling epoch {train_engine.state.epoch}")
            val_loss = engine.state.metrics['Loss']
            # isinstance (rather than a class-name string match) so that
            # ReduceLROnPlateau subclasses still receive the monitored loss.
            if isinstance(scheduler,
                          torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()

        early_stop_handler = EarlyStopping(
            patience=config_parameters['early_stop'],
            score_function=self._negative_loss,
            trainer=train_engine)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           early_stop_handler)
        if config_parameters['save'] == 'everyepoch':
            checkpoint_handler = ModelCheckpoint(outputdir,
                                                 'run',
                                                 n_saved=5,
                                                 require_empty=False)
            train_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {
                                               'model': model,
                                           })
            train_engine.add_event_handler(
                Events.ITERATION_COMPLETED(every=config_parameters['itercv']),
                checkpoint_handler, {
                    'model': model,
                })
        else:
            checkpoint_handler = ModelCheckpoint(
                outputdir,
                'run',
                n_saved=1,
                require_empty=False,
                score_function=self._negative_loss,
                global_step_transform=global_step_from_engine(
                    train_engine),  # Just so that model is saved with epoch...
                score_name='loss')
            inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                               checkpoint_handler, {
                                                   'model': model,
                                               })

        train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
        return outputdir