Example #1
0
def main():
    """Train and periodically evaluate the interpolation model.

    Options come from the module-level ``parser``. Builds the Vimeo90k
    training loader and the Middlebury test set, optionally resumes model
    weights and the epoch counter from ``--load``, appends a timestamped dump
    of every option to ``<out_dir>/config.txt``, then alternates train/test
    passes until the trainer reports it should terminate.
    """
    opts = parser.parse_args()
    torch.cuda.set_device(opts.gpu_id)

    crop = (opts.patch_size, opts.patch_size)
    train_set = DBreader_Vimeo90k(opts.train, random_crop=crop)
    test_db = Middlebury_other(opts.test_input, opts.gt)
    loader = DataLoader(dataset=train_set,
                        batch_size=opts.batch_size,
                        shuffle=True,
                        num_workers=0)
    net = models.Model(opts)
    criterion = losses.Loss(opts)

    # Resume weights and epoch counter when a checkpoint path is given.
    if opts.load is None:
        first_epoch = 0
    else:
        state = torch.load(opts.load)
        net.load(state['state_dict'])
        first_epoch = state['epoch']

    trainer = Trainer(opts, loader, test_db, net, criterion, first_epoch)

    # Append "<timestamp>\n\n", one "name: value" line per option, then "\n".
    stamp = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    config_lines = [stamp + '\n\n']
    config_lines.extend('{}: {}\n'.format(key, getattr(opts, key))
                        for key in vars(opts))
    config_lines.append('\n')
    with open(opts.out_dir + '/config.txt', 'a') as log_file:
        log_file.writelines(config_lines)

    while not trainer.terminate():
        trainer.train()
        trainer.test()

    trainer.close()
Example #2
0
def train():
    '''
    Load batches of training data (triplets) and train the CNN for
    ``parameters.EPOCHS`` epochs, saving a checkpoint after every epoch so
    training can be paused and resumed. Similar to predict.py.

    Data is read from ``../out``; the model and its checkpoints live under
    ``../model`` (checkpoints in ``../model/checkpoints``). If a checkpoint
    already exists there, training resumes from the latest one.

    TODO: add validation
    TODO: add argparse functionality
    '''

    # load training data
    train_data_dir = pathlib.Path('../out')
    train_data = data_processor.load(train_data_dir, training=True)
    train_data_batches = tf.data.experimental.cardinality(train_data).numpy()
    # cardinality() returns a negative sentinel (-1 infinite, -2 unknown)
    # when the dataset size cannot be determined; Progbar wants None then.
    progbar_target = train_data_batches if train_data_batches >= 0 else None

    print('{} batches of {} triplets each'.format(train_data_batches,
                                                  parameters.BATCH_SIZE))

    # prepare directories for the model
    model_dir = pathlib.Path('../model')
    model_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_dir = model_dir / 'checkpoints'

    # load the model and any checkpoints, if they exist
    model = Interpolator()
    # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
    optimizer = tf.keras.optimizers.Adam(learning_rate=parameters.ADAM_LR)
    loss_func = losses.Loss()
    checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                     optimizer=optimizer,
                                     net=model)

    latest_checkpoint = tf.train.latest_checkpoint(str(checkpoint_dir))
    if latest_checkpoint:
        checkpoint.restore(latest_checkpoint)
        print('Restored latest checkpoint!')

    for epoch in range(parameters.EPOCHS):
        print('Epoch: {}'.format(epoch))
        # New bar every epoch: Progbar accumulates running averages and keeps
        # its start time, so reusing one would mix epochs and skew the ETA.
        progress_bar = tf.keras.utils.Progbar(progbar_target)

        for batch_num, batch in enumerate(train_data):
            loss, ssim, psnr = train_batch(model, batch, optimizer, loss_func)
            # loss = total_loss, r_loss, p_loss, w_loss, s_loss

            progress_vals = [('Total loss', loss[0]), ('SSIM', ssim),
                             ('Rec. loss', loss[1]), ('Perc. loss', loss[2]),
                             ('Warp. loss', loss[3]),
                             ('Smooth. loss', loss[4])]

            progress_bar.update(batch_num + 1, progress_vals)

        checkpoint.save(str(checkpoint_dir / 'checkpoint'))
        print('Created checkpoint for epoch #{}!'.format(epoch))

    print('Training done!')
def main():
    """Build the TF1 graph(s) from the global options ``a`` and run them.

    Test mode runs the network ``a.maxImages`` times. Run ``i`` reads at
    most ``i + 1`` input images and writes to ``<output_dir>/<i>``, so results
    with 1, 2, 3, ... inputs can be compared. After the last run,
    ``helpers.writeGlobalHTML`` is called, but only when there was more than
    one run. "eval" mode also calls ``test`` (once).

    Train mode builds the network once. It also builds a weight-sharing
    test network (``reuse_bool=True``) fed from the test folder, then calls
    ``train``.

    NOTE(review): relies on module-level names defined elsewhere (``a``,
    ``tmpFolder``, ``CROP_SIZE``, ``last_convs_chans``, ``test``, ``train``).
    """

    if a.seed is None:
        a.seed = random.randint(0, 2**31 - 1)

    tf.set_random_seed(a.seed)
    np.random.seed(a.seed)
    random.seed(a.seed)
    # Load some options from the checkpoint if one was provided.
    loadCheckpointOption()
    # For a test run fed with in-network renderings, render the test inputs to
    # disk first (under tmpFolder) so other networks can later be compared on
    # the exact same test set; this run then reads those files instead.
    if a.mode == "test" and a.feedMethod == "render":
        testHelpers.renderTests(a.input_dir, a.testFolder, a.maxImages,
                                tmpFolder, a.imageFormat, CROP_SIZE,
                                a.nbTargets, a.input_size, a.batch_size,
                                a.renderingScene, a.jitterLightPos,
                                a.jitterViewPos, a.inputMode, a.mode,
                                a.output_dir)
        generateTmpData = True  # NOTE(review): assigned but never read here.
        a.nbInputs = a.maxImages
        a.feedMethod = "files"
        a.testFolder = tmpFolder
        a.input_size = CROP_SIZE

    backupOutputDir = a.output_dir
    # Run the network once when training...
    nbRun = 1
    # ...and, when testing, once per number of input images (1, 2, 3, ... up
    # to a.maxImages) to see how results improve with more images.
    if a.mode == "test":
        nbRun = a.maxImages  #1
        a.fixImageNb = True

    # Now run the network nbRun times.
    for runID in range(nbRun):
        maxInputNb = a.maxImages
        if a.mode == "test":
            maxInputNb = runID + 1  #a.maxImages
            # Each test run writes to its own sub-directory and gets a fresh graph.
            a.output_dir = os.path.join(backupOutputDir, str(runID))
            tf.reset_default_graph()

        # Create the output dir if it doesn't exist.
        if not os.path.exists(a.output_dir):
            os.makedirs(a.output_dir)

        # Write this run's parameters to the "options.json" file.
        with open(os.path.join(a.output_dir, "options.json"), "w") as f:
            f.write(json.dumps(vars(a), sort_keys=True, indent=4))

        # Create a dataset object.
        data = dataReader.dataset(
            a.input_dir,
            imageFormat=a.imageFormat,
            trainFolder=a.trainFolder,
            testFolder=a.testFolder,
            inputNumbers=a.nbInputs,
            maxInputToRead=maxInputNb,
            nbTargetsToRead=a.nbTargets,
            cropSize=CROP_SIZE,
            inputImageSize=a.input_size,
            batchSize=a.batch_size,
            fixCrop=(a.mode == "test"),
            mixMaterials=(a.mode == "train"),
            fixImageNb=a.fixImageNb,
            logInput=a.useLog,
            useAmbientLight=a.useAmbientLight,
            jitterRenderings=a.jitterRenderings,
            firstAsGuide=False,
            useAugmentationInRenderings=not a.NoAugmentationInRenderings,
            mode=a.mode)

        # Populate the list of files the dataset will contain.
        data.loadPathList(a.inputMode, a.mode, a.mode == "train")

        # Depending on whether we render the input data in-network or read it
        # directly from files, build the corresponding TensorFlow input pipeline.
        if a.feedMethod == "render":
            data.populateInNetworkFeedGraph(a.renderingScene,
                                            a.jitterLightPos,
                                            a.jitterViewPos,
                                            a.mode == "test",
                                            shuffle=a.mode == "train")
        elif a.feedMethod == "files":
            data.populateFeedGraph(shuffle=a.mode == "train")

        # Reshape the input to put all the images in the first dimension (to process them in parallel).
        inputReshaped, dyn_batch_size = helpers.input_reshape(
            data.inputBatch, a.NoMaxPooling, a.maxImages)

        if a.mode == "train":
            with tf.name_scope("recurrentTest"):
                # Separate, non-shuffled test data evaluated periodically during training.
                dataTest = dataReader.dataset(
                    a.input_dir,
                    imageFormat=a.imageFormat,
                    testFolder=a.testFolder,
                    inputNumbers=a.nbInputs,
                    maxInputToRead=a.maxImages,
                    nbTargetsToRead=a.nbTargets,
                    cropSize=CROP_SIZE,
                    inputImageSize=a.input_size,
                    batchSize=a.batch_size,
                    fixCrop=True,
                    mixMaterials=False,
                    fixImageNb=a.fixImageNb,
                    logInput=a.useLog,
                    useAmbientLight=a.useAmbientLight,
                    jitterRenderings=a.jitterRenderings,
                    firstAsGuide=a.firstAsGuide,
                    useAugmentationInRenderings=not a.
                    NoAugmentationInRenderings,
                    mode=a.mode)
                dataTest.loadPathList(a.inputMode, "test", False)
                if a.feedMethod == "render":
                    dataTest.populateInNetworkFeedGraph(a.renderingScene,
                                                        a.jitterLightPos,
                                                        a.jitterViewPos,
                                                        True,
                                                        shuffle=False)
                elif a.feedMethod == "files":
                    dataTest.populateFeedGraph(False)
                TestinputReshaped, test_dyn_batch_size = helpers.input_reshape(
                    dataTest.inputBatch, a.NoMaxPooling, a.maxImages)

        #Reshape the targets to [?(Batchsize), 256,256,12]
        targetsReshaped = helpers.target_reshape(data.targetBatch)

        # Create the object that holds the network model.
        model = mod.Model(inputReshaped,
                          dyn_batch_size,
                          last_convolutions_channels=last_convs_chans,
                          generatorOutputChannels=64,
                          useCoordConv=a.useCoordConv,
                          firstAsGuide=a.firstAsGuide,
                          NoMaxPooling=a.NoMaxPooling,
                          pooling_type=a.poolingtype)

        # Build the model graph (must precede the reuse_bool=True test model below).
        model.create_model()

        if a.mode == "train":
            # Build the periodic-test network on the test data; reuse_bool=True makes it share the training model's weights.
            testTargetsReshaped = helpers.target_reshape(dataTest.targetBatch)  # NOTE(review): not used below.
            testmodel = mod.Model(TestinputReshaped,
                                  test_dyn_batch_size,
                                  last_convolutions_channels=last_convs_chans,
                                  generatorOutputChannels=64,
                                  reuse_bool=True,
                                  useCoordConv=a.useCoordConv,
                                  firstAsGuide=a.firstAsGuide,
                                  NoMaxPooling=a.NoMaxPooling,
                                  pooling_type=a.poolingtype)
            testmodel.create_model()

            # Select the images to fetch from each test-network run.
            display_fetches_test, _ = helpers.display_images_fetches(
                dataTest.pathBatch, dataTest.inputBatch, dataTest.targetBatch,
                dataTest.gammaCorrectedInputsBatch, testmodel.output,
                a.nbTargets, a.logOutputAlbedos)

            # Compute the training network loss. The scalar "lr" placeholder is
            # presumably the learning rate, fed by train() — confirm.
            loss = losses.Loss(a.loss, model.output, targetsReshaped,
                               CROP_SIZE, a.batch_size,
                               tf.placeholder(tf.float64, shape=(),
                                              name="lr"), a.includeDiffuse)
            loss.createLossGraph()

            # Create the training part of the graph.
            loss.createTrainVariablesGraph()

        # Select the images to fetch from each training-network run.
        display_fetches, converted_images = helpers.display_images_fetches(
            data.pathBatch, data.inputBatch, data.targetBatch,
            data.gammaCorrectedInputsBatch, model.output, a.nbTargets,
            a.logOutputAlbedos)
        if a.mode == "train":
            # Register inputs, targets, renderings and loss in TensorBoard.
            helpers.registerTensorboard(data.pathBatch, converted_images,
                                        a.maxImages, a.nbTargets,
                                        loss.lossValue, a.batch_size,
                                        loss.targetsRenderings,
                                        loss.outputsRenderings)

        # Compute how many parameters the network has.
        with tf.name_scope("parameter_count"):
            parameter_count = tf.reduce_sum([
                tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()
            ])

        # Initialize a saver (keeps only the most recent checkpoint).
        saver = tf.train.Saver(max_to_keep=1)
        if a.checkpoint is not None:
            print("reading model from checkpoint : " + a.checkpoint)
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            # NOTE(review): presumably restores only variables present in both
            # the graph and the checkpoint, silently skipping the rest — confirm.
            partialSaver = helpers.optimistic_saver(checkpoint)
        logdir = a.output_dir if a.summary_freq > 0 else None
        sv = tf.train.Supervisor(logdir=logdir,
                                 save_summaries_secs=0,
                                 saver=None)
        #helpers.print_trainable()
        with sv.managed_session() as sess:
            print("parameter_count =", sess.run(parameter_count))

            # Load the checkpoint.
            if a.checkpoint is not None:
                print("restoring model from checkpoint : " + a.checkpoint)
                partialSaver.restore(sess, checkpoint)

            # Step limit: effectively unbounded by default; overridden by
            # max_epochs, and a.max_steps (if set) takes precedence over both.
            max_steps = 2**32
            if a.max_epochs is not None:
                max_steps = data.stepsPerEpoch * a.max_epochs
            if a.max_steps is not None:
                max_steps = a.max_steps

            # Test (or eval) run.
            if a.mode == "test" or a.mode == "eval":
                filesets = test(sess,
                                data,
                                max_steps,
                                display_fetches,
                                output_dir=a.output_dir)
                if runID == nbRun - 1 and runID >= 1:  # Last test iteration (and more than one run): write the global HTML over all runs.
                    helpers.writeGlobalHTML(backupOutputDir, filesets,
                                            a.nbTargets, a.mode, a.maxImages)
            # Training run.
            if a.mode == "train":
                train(sv, sess, data, max_steps, display_fetches,
                      display_fetches_test, dataTest, saver, loss)
def main():
    """Build the graph for a "train", "finetune" or "test" run and execute it.

    All options come from the global ``a``.

    - "train"/"finetune": builds a second network on the test folder that
      shares the main network's weights (``reuse_bool=True``), then calls
      ``train``.
    - "test": calls ``test`` on ``data``.

    NOTE(review): relies on module-level names defined elsewhere (``a``,
    ``TILE_SIZE``, ``inputpythonList``, ``test``, ``train``).
    """
    if a.seed is None:
        a.seed = random.randint(0, 2**31 - 1)

    tf.set_random_seed(a.seed)
    np.random.seed(a.seed)
    random.seed(a.seed)
    # Load the options stored with the checkpoint, so this run's options (and the
    # data generated from them) match the training run.
    loadCheckpointOption(a.mode, a.checkpoint)

    # Default session config, passed to managed_session below.
    config = tf.ConfigProto()

    if not os.path.exists(a.output_dir):
        os.makedirs(a.output_dir)

    # Record this run's options next to its outputs.
    with open(os.path.join(a.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(a), sort_keys=True, indent=4))

    data = dataReader.dataset(a.input_dir, imageFormat = a.imageFormat, trainFolder = a.trainFolder, testFolder = a.testFolder, nbTargetsToRead = a.nbTargets, tileSize=TILE_SIZE, inputImageSize=a.input_size, batchSize=a.batch_size, fixCrop = (a.mode == "test"), mixMaterials = (a.mode == "train" or a.mode == "finetune"), logInput = a.useLog, useAmbientLight = a.useAmbientLight, useAugmentationInRenderings = not a.NoAugmentationInRenderings)
    # Populate the list of files the dataset will read.
    data.loadPathList(a.inputMode, a.mode, a.mode == "train" or a.mode == "finetune", inputpythonList)

    # NOTE(review): with feedMethod "render" and a mode other than "train"/"finetune"
    # (e.g. "test"), neither populate method below is called.
    if a.feedMethod == "render":
        if a.mode == "train":
            # shuffle is always True here (we are inside the "train" branch).
            data.populateInNetworkFeedGraph(a.renderingScene, a.jitterLightPos, a.jitterViewPos,  shuffle = (a.mode == "train"  or a.mode == "finetune"))
        elif a.mode == "finetune":
            data.populateInNetworkFeedGraphSpatialMix(a.renderingScene, shuffle = False, imageSize = a.input_size)

    elif a.feedMethod == "files":
        data.populateFeedGraph(shuffle = (a.mode == "train"  or a.mode == "finetune"))


    if a.mode == "train" or a.mode == "finetune":
        # Separate, non-shuffled test data evaluated periodically during training.
        with tf.name_scope("recurrentTest"):
            dataTest = dataReader.dataset(a.input_dir, imageFormat = a.imageFormat, testFolder = a.testFolder, nbTargetsToRead = a.nbTargets, tileSize=TILE_SIZE, inputImageSize=a.test_input_size, batchSize=a.batch_size, fixCrop = True, mixMaterials = False, logInput = a.useLog, useAmbientLight = a.useAmbientLight, useAugmentationInRenderings = not a.NoAugmentationInRenderings)
            dataTest.loadPathList(a.inputMode, "test", False, inputpythonList)
            if a.testApproach == "render":
                #dataTest.populateInNetworkFeedGraphSpatialMix(a.renderingScene, shuffle = False, imageSize = TILE_SIZE, useSpatialMix=False)
                dataTest.populateInNetworkFeedGraph(a.renderingScene, a.jitterLightPos, a.jitterViewPos, shuffle = False)
            elif a.testApproach == "files":
                dataTest.populateFeedGraph(False) 

    targetsReshaped = helpers.target_reshape(data.targetBatch)

    # Create the model (must precede the reuse_bool=True test model below).
    model = mod.Model(data.inputBatch, generatorOutputChannels=9)
    model.create_model()
    if a.mode == "train" or a.mode == "finetune":
        testTargetsReshaped = helpers.target_reshape(dataTest.targetBatch)  # NOTE(review): not used below.

        # Test network sharing the training model's weights.
        testmodel = mod.Model(dataTest.inputBatch, generatorOutputChannels=9, reuse_bool=True)

        testmodel.create_model()
        display_fetches_test, _ = helpers.display_images_fetches(dataTest.pathBatch, dataTest.inputBatch, dataTest.targetBatch, dataTest.gammaCorrectedInputsBatch, testmodel.output, a.nbTargets, a.logOutputAlbedos)

        # The scalar "lr" placeholder is presumably the learning rate, fed by train() — confirm.
        loss = losses.Loss(a.loss, model.output, targetsReshaped, TILE_SIZE, a.batch_size, tf.placeholder(tf.float64, shape=(), name="lr"), a.includeDiffuse, a.nbSpecularRendering, a.nbDiffuseRendering)

        loss.createLossGraph()
        loss.createTrainVariablesGraph()

    # Register renderings and loss in TensorBoard.
    display_fetches, converted_images = helpers.display_images_fetches(data.pathBatch, data.inputBatch, data.targetBatch, data.gammaCorrectedInputsBatch, model.output, a.nbTargets, a.logOutputAlbedos)
    # NOTE(review): only registered in "train" mode, not in "finetune".
    if a.mode == "train":
        helpers.registerTensorboard(data.pathBatch, converted_images, a.nbTargets, loss.lossValue, a.batch_size, loss.targetsRenderings, loss.outputsRenderings)

    # Count trainable parameters, set up saving/restoring, then run training or testing.
    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])
    saver = tf.train.Saver(max_to_keep=1)

    if a.checkpoint is not None:
        print("reading model from checkpoint : " + a.checkpoint)
        checkpoint = tf.train.latest_checkpoint(a.checkpoint)
        partialSaver = helpers.optimistic_saver(checkpoint) #Be careful this will silently not load variables if they are missing from the graph or checkpoint

    logdir = a.output_dir if a.summary_freq > 0 else None
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)

    with sv.managed_session("", config= config) as sess:
        # NOTE(review): the data iterator is initialized here and again below,
        # after the checkpoint restore.
        sess.run(data.iterator.initializer)
        print("parameter_count =", sess.run(parameter_count))

        if a.checkpoint is not None:
            print("restoring model from checkpoint : " + a.checkpoint)
            partialSaver.restore(sess, checkpoint)

        # Step limit: effectively unbounded by default; overridden by
        # max_epochs, and a.max_steps (if set) takes precedence over both.
        max_steps = 2**32
        if a.max_epochs is not None:
            max_steps = data.stepsPerEpoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        sess.run(data.iterator.initializer)
        if a.mode == "test":
            filesets = test(sess, data, max_steps, display_fetches, output_dir = a.output_dir)  # NOTE(review): filesets is not used afterwards.

        if a.mode == "train"  or a.mode == "finetune":
           train(sv, sess, data, max_steps, display_fetches, display_fetches_test, dataTest, saver, loss, a.output_dir)
Example #5
0
 def Loss(self):
     """Return the attribute of a fresh ``losses.Loss`` named by ``self.opts.model``.

     A new ``losses.Loss`` is built from ``self.opts`` on every call.
     Presumably the attribute is the loss function for the configured model
     — confirm against ``losses.Loss``.
     """
     return getattr(losses.Loss(self.opts), self.opts.model)
Example #6
0
    def train(self, config, **kwargs):
        """Trains a given model specified in the config file or passed as the --model parameter.
        All options in the config file can be overwritten as needed by passing --PARAM
        Options with variable lengths ( e.g., kwargs can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}'

        The model class named by ``config_parameters['model']`` is looked up
        in ``models`` and trained with an ``Engine``.

        - Validation runs every 5000 iterations and at the end of every epoch.
        - ``EarlyStopping`` (patience ``early_stop``) and ``ModelCheckpoint``
          (keeping 3) are both driven by ``self._negative_loss`` on the
          validation engine.

        :param config: yaml config file
        :param **kwargs: parameters to overwrite yaml config
        :return: the run directory
            ``<outputpath>/<model>/<YYYY-mm-dd_HH-MM-SS>_<uuid1 hex>``
        :raises AttributeError: if ``config_parameters['model']`` is not an
            attribute of ``models``
        """

        config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
        # The timestamp ends in seconds (%S); '%m' there would be the month.
        outputdir = os.path.join(
            config_parameters['outputpath'], config_parameters['model'],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex))
        # Early init because ModelCheckpoint(create_dir=True) creates the
        # output directory, which the log file below is written into.
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            'run',
            n_saved=3,
            require_empty=False,
            create_dir=True,
            score_function=self._negative_loss,
            score_name='loss')
        logger = utils.getfile_outlogger(os.path.join(outputdir, 'train.log'))
        logger.info("Storing files in {}".format(outputdir))
        utils.pprint_dict(config_parameters, logger.info)
        logger.info("Running on device {}".format(DEVICE))
        # Raw strings: '\s' is an invalid escape sequence in a plain literal.
        label_df = pd.read_csv(config_parameters['label'], sep=r'\s+')
        data_df = pd.read_csv(config_parameters['data'], sep=r'\s+')
        # Keep only filenames present in both the data and the label table.
        merged = data_df.merge(label_df, on='filename')
        common_idxs = merged['filename']
        data_df = data_df[data_df['filename'].isin(common_idxs)]
        label_df = label_df[label_df['filename'].isin(common_idxs)]

        train_df, cv_df = utils.split_train_cv(
            label_df, **config_parameters['data_args'])
        train_label = utils.df_to_dict(train_df)
        cv_label = utils.df_to_dict(cv_df)
        data = utils.df_to_dict(data_df)

        transform = utils.parse_transforms(config_parameters['transforms'])
        torch.save(config_parameters, os.path.join(outputdir,
                                                   'run_config.pth'))
        logger.info("Transforms:")
        utils.pprint_dict(transform, logger.info, formatter='pretty')
        assert len(cv_df) > 0, "Fraction a bit too large?"

        trainloader = dataset.gettraindataloader(
            h5files=data,
            h5labels=train_label,
            transform=transform,
            label_type=config_parameters['label_type'],
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'],
            shuffle=True,
        )

        cvdataloader = dataset.gettraindataloader(
            h5files=data,
            h5labels=cv_label,
            label_type=config_parameters['label_type'],
            transform=None,
            shuffle=False,
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'],
        )
        # No string fallback here: a default of 'CRNN' would itself be called
        # and fail with "'str' object is not callable". An unknown model name
        # now raises an AttributeError that names the missing model.
        model_cls = getattr(models, config_parameters['model'])
        model = model_cls(inputdim=trainloader.dataset.datadim,
                          outputdim=2,
                          **config_parameters['model_args'])
        if config_parameters.get('pretrained') is not None:
            model_dump = torch.load(config_parameters['pretrained'],
                                    map_location='cpu')
            model_state = model.state_dict()
            # Only copy weights whose name and shape match the new model.
            pretrained_state = {
                k: v
                for k, v in model_dump.items()
                if k in model_state and v.size() == model_state[k].size()
            }
            model_state.update(pretrained_state)
            model.load_state_dict(model_state)
            logger.info("Loading pretrained model {}".format(
                config_parameters['pretrained']))

        model = model.to(DEVICE)
        optimizer = getattr(
            torch.optim,
            config_parameters['optimizer'],
        )(model.parameters(), **config_parameters['optimizer_args'])

        utils.pprint_dict(optimizer, logger.info, formatter='pretty')
        utils.pprint_dict(model, logger.info, formatter='pretty')
        if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
            model = torch.nn.DataParallel(model)
        criterion = getattr(losses, config_parameters['loss'])().to(DEVICE)

        def _train_batch(_, batch):
            """Run one optimisation step and return the scalar loss."""
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                output = self._forward(
                    model, batch)  # output is tuple (clip, frame, target)
                loss = criterion(*output)
                loss.backward()
                # Single loss
                optimizer.step()
                return loss.item()

        def _inference(_, batch):
            """Forward pass without gradients, for validation metrics."""
            model.eval()
            with torch.no_grad():
                return self._forward(model, batch)

        def thresholded_output_transform(output):
            """Mask padded frames beyond each sequence length and round to 0/1."""
            # Output is a 5-tuple unpacked as (_, y_pred, y, y_clip, length);
            # presumably (clip pred, frame pred, frame target, clip target,
            # lengths) — confirm against self._forward.
            _, y_pred, y, y_clip, length = output
            batchsize, timesteps, _ = y.shape
            idxs = torch.arange(timesteps,
                                device='cpu').repeat(batchsize).view(
                                    batchsize, timesteps)
            mask = (idxs < length.view(-1, 1)).to(y.device)
            y = y * mask.unsqueeze(-1)
            y_pred = torch.round(y_pred)
            y = torch.round(y)
            return y_pred, y

        metrics = {
            'Loss': losses.Loss(
                criterion),  #reimplementation of Loss, supports 3 way loss 
            'Precision': Precision(thresholded_output_transform),
            'Recall': Recall(thresholded_output_transform),
            'Accuracy': Accuracy(thresholded_output_transform),
        }
        train_engine = Engine(_train_batch)
        inference_engine = Engine(_inference)
        for name, metric in metrics.items():
            metric.attach(inference_engine, name)

        def compute_metrics(engine):
            """Run validation and log every metric for the current epoch."""
            inference_engine.run(cvdataloader)
            results = inference_engine.state.metrics
            output_str_list = [
                "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
            ]
            for metric in metrics:
                output_str_list.append("{} {:<5.2f}".format(
                    metric, results[metric]))
            logger.info(" ".join(output_str_list))
            # Reset the progress bar counters after validation output.
            pbar.n = pbar.last_print_n = 0

        pbar = ProgressBar(persist=False)
        pbar.attach(train_engine)

        train_engine.add_event_handler(Events.ITERATION_COMPLETED(every=5000),
                                       compute_metrics)
        train_engine.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics)

        early_stop_handler = EarlyStopping(
            patience=config_parameters['early_stop'],
            score_function=self._negative_loss,
            trainer=train_engine)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           early_stop_handler)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {
                                               'model': model,
                                           })

        train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
        return outputdir
Example #7
0
    def train(self, config, **kwargs):
        """Trains a given model specified in the config file or passed as the --model parameter.
        All options in the config file can be overwritten as needed by passing --PARAM
        Options with variable lengths ( e.g., kwargs can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}'

        :param config: yaml config file
        :param **kwargs: parameters to overwrite yaml config
        """

        config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
        outputdir = os.path.join(
            config_parameters['outputpath'], config_parameters['model'],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%m'),
                uuid.uuid1().hex))
        # Create base dir
        Path(outputdir).mkdir(exist_ok=True, parents=True)

        logger = utils.getfile_outlogger(os.path.join(outputdir, 'train.log'))
        logger.info("Storing files in {}".format(outputdir))
        # utils.pprint_dict
        utils.pprint_dict(config_parameters, logger.info)
        logger.info("Running on device {}".format(DEVICE))
        labels_df = pd.read_csv(config_parameters['label'],
                                sep='\s+').convert_dtypes()
        # In case of ave dataset where index is int, we change the
        # absolute name to relname
        if not np.all(labels_df['filename'].str.isnumeric()):
            labels_df.loc[:, 'filename'] = labels_df['filename'].apply(
                os.path.basename)
        encoder = utils.train_labelencoder(labels=labels_df['event_labels'])
        # These labels are useless, only for mode == stratified
        label_array, _ = utils.encode_labels(labels_df['event_labels'],
                                             encoder)
        if 'cv_label' in config_parameters:
            cv_df = pd.read_csv(config_parameters['cv_label'],
                                sep='\s+').convert_dtypes()
            if not np.all(cv_df['filename'].str.isnumeric()):
                cv_df.loc[:, 'filename'] = cv_df['filename'].apply(
                    os.path.basename)
            train_df = labels_df
            logger.info(
                f"Using CV labels from {config_parameters['cv_label']}")
        else:
            train_df, cv_df = utils.split_train_cv(
                labels_df, y=label_array, **config_parameters['data_args'])

        if 'cv_data' in config_parameters:
            cv_data = config_parameters['cv_data']
            logger.info(f"Using CV data {config_parameters['cv_data']}")
        else:
            cv_data = config_parameters['data']

        train_label_array, _ = utils.encode_labels(train_df['event_labels'],
                                                   encoder)
        cv_label_array, _ = utils.encode_labels(cv_df['event_labels'], encoder)

        transform = utils.parse_transforms(config_parameters['transforms'])
        utils.pprint_dict({'Classes': encoder.classes_},
                          logger.info,
                          formatter='pretty')
        torch.save(encoder, os.path.join(outputdir, 'run_encoder.pth'))
        torch.save(config_parameters, os.path.join(outputdir,
                                                   'run_config.pth'))
        logger.info("Transforms:")
        utils.pprint_dict(transform, logger.info, formatter='pretty')
        # For Unbalanced Audioset, this is true
        if 'sampler' in config_parameters and config_parameters[
                'sampler'] == 'MultiBalancedSampler':
            # Training sampler that oversamples the dataset to be roughly equally sized
            # Calcualtes mean over multiple instances, rather useful when number of classes
            # is large
            train_sampler = dataset.MultiBalancedSampler(
                train_label_array,
                num_samples=1 * train_label_array.shape[0],
                replacement=True)
            sampling_kwargs = {"shuffle": False, "sampler": train_sampler}
        elif 'sampler' in config_parameters and config_parameters[
                'sampler'] == 'MinimumOccupancySampler':
            # Asserts that each "batch" contains at least one instance
            train_sampler = dataset.MinimumOccupancySampler(
                train_label_array, sampling_mode='same')
            sampling_kwargs = {"shuffle": False, "sampler": train_sampler}
        else:
            sampling_kwargs = {"shuffle": True}

        logger.info("Using Sampler {}".format(sampling_kwargs))

        trainloader = dataset.getdataloader(
            {
                'filename': train_df['filename'].values,
                'encoded': train_label_array
            },
            config_parameters['data'],
            transform=transform,
            batch_size=config_parameters['batch_size'],
            colname=config_parameters['colname'],
            num_workers=config_parameters['num_workers'],
            **sampling_kwargs)

        cvdataloader = dataset.getdataloader(
            {
                'filename': cv_df['filename'].values,
                'encoded': cv_label_array
            },
            cv_data,
            transform=None,
            shuffle=False,
            colname=config_parameters['colname'],
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'])
        model = getattr(models, config_parameters['model'],
                        'CRNN')(inputdim=trainloader.dataset.datadim,
                                outputdim=len(encoder.classes_),
                                **config_parameters['model_args'])
        if 'pretrained' in config_parameters and config_parameters[
                'pretrained'] is not None:
            models.load_pretrained(model,
                                   config_parameters['pretrained'],
                                   outputdim=len(encoder.classes_))
            logger.info("Loading pretrained model {}".format(
                config_parameters['pretrained']))

        model = model.to(DEVICE)
        if config_parameters['optimizer'] == 'AdaBound':
            try:
                import adabound
                optimizer = adabound.AdaBound(
                    model.parameters(), **config_parameters['optimizer_args'])
            except ImportError:
                config_parameters['optimizer'] = 'Adam'
                config_parameters['optimizer_args'] = {}
        else:
            optimizer = getattr(
                torch.optim,
                config_parameters['optimizer'],
            )(model.parameters(), **config_parameters['optimizer_args'])

        utils.pprint_dict(optimizer, logger.info, formatter='pretty')
        utils.pprint_dict(model, logger.info, formatter='pretty')
        if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
            model = torch.nn.DataParallel(model)
        criterion = getattr(losses, config_parameters['loss'])().to(DEVICE)

        def _train_batch(_, batch):
            """Run one optimisation step on ``batch`` and return its scalar loss.

            Process function for ``train_engine``; the engine argument is
            unused. Uses ``model``, ``optimizer`` and ``criterion`` from the
            enclosing scope.
            """
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                # self._forward yields a (clip, frame, target) tuple that is
                # unpacked straight into the criterion (single combined loss).
                batch_loss = criterion(*self._forward(model, batch))
                batch_loss.backward()
                optimizer.step()
            return batch_loss.item()

        def _inference(_, batch):
            """Forward ``batch`` through the model in eval mode, without gradients.

            Process function for ``inference_engine``; the engine argument is
            unused. Returns whatever ``self._forward`` produces, which the
            attached metrics consume as a ``(clip, frame, target)`` tuple.
            """
            model.eval()
            with torch.no_grad():
                prediction = self._forward(model, batch)
            return prediction

        def thresholded_output_transform(output):
            """Binarise clip-level predictions for the ignite classification metrics.

            Args:
                output: ``(clip, frame, target)`` tuple from the inference step.

            Returns:
                ``(rounded_clip, target)``. The frame-level output is dropped.
                ``torch.round`` rounds half to even, so a score of exactly
                0.5 becomes 0.
            """
            clip_pred, _frame_pred, target = output
            return torch.round(clip_pred), target

        precision = Precision(thresholded_output_transform, average=False)
        recall = Recall(thresholded_output_transform, average=False)
        f1_score = (precision * recall * 2 / (precision + recall)).mean()
        metrics = {
            'Loss': losses.Loss(
                criterion),  #reimplementation of Loss, supports 3 way loss 
            'Precision': Precision(thresholded_output_transform),
            'Recall': Recall(thresholded_output_transform),
            'Accuracy': Accuracy(thresholded_output_transform),
            'F1': f1_score,
        }
        train_engine = Engine(_train_batch)
        inference_engine = Engine(_inference)
        for name, metric in metrics.items():
            metric.attach(inference_engine, name)

        def compute_metrics(engine):
            """Run a full validation pass over ``cvdataloader`` and log the metrics.

            Running ``inference_engine`` also fires its EPOCH_COMPLETED
            handlers registered in this method (LR scheduling, early stopping,
            and, in the non-``everyepoch`` save mode, checkpointing).

            Args:
                engine: the training engine that triggered validation; only its
                    ``state.epoch`` is read, for the log line.
            """
            inference_engine.run(cvdataloader)
            scores = inference_engine.state.metrics
            header = "Validation Results - Epoch : {:<5}".format(
                engine.state.epoch)
            parts = [header] + [
                "{} {:<5.2f}".format(name, scores[name]) for name in metrics
            ]
            logger.info(" ".join(parts))

        pbar = ProgressBar(persist=False)
        pbar.attach(train_engine)

        if 'itercv' in config_parameters and config_parameters[
                'itercv'] is not None:
            train_engine.add_event_handler(
                Events.ITERATION_COMPLETED(every=config_parameters['itercv']),
                compute_metrics)
        train_engine.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics)

        # Default scheduler is using patience=3, factor=0.1
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, **config_parameters['scheduler_args'])

        @inference_engine.on(Events.EPOCH_COMPLETED)
        def update_reduce_on_plateau(engine):
            """Step the learning-rate scheduler after every validation pass.

            Args:
                engine: ``inference_engine``. Its state is rebuilt on every
                    ``run()`` call, so its ``state.epoch`` is always 1. The
                    training engine's epoch is logged instead.
            """
            logger.info(f"Scheduling epoch {train_engine.state.epoch}")
            # isinstance (rather than comparing class-name strings) also routes
            # ReduceLROnPlateau subclasses to the metric-taking step().
            if isinstance(scheduler,
                          torch.optim.lr_scheduler.ReduceLROnPlateau):
                # Plateau schedulers monitor the validation loss.
                scheduler.step(engine.state.metrics['Loss'])
            else:
                scheduler.step()

        early_stop_handler = EarlyStopping(
            patience=config_parameters['early_stop'],
            score_function=self._negative_loss,
            trainer=train_engine)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           early_stop_handler)
        if config_parameters['save'] == 'everyepoch':
            checkpoint_handler = ModelCheckpoint(outputdir,
                                                 'run',
                                                 n_saved=5,
                                                 require_empty=False)
            train_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {
                                               'model': model,
                                           })
            train_engine.add_event_handler(
                Events.ITERATION_COMPLETED(every=config_parameters['itercv']),
                checkpoint_handler, {
                    'model': model,
                })
        else:
            checkpoint_handler = ModelCheckpoint(
                outputdir,
                'run',
                n_saved=1,
                require_empty=False,
                score_function=self._negative_loss,
                global_step_transform=global_step_from_engine(
                    train_engine),  # Just so that model is saved with epoch...
                score_name='loss')
            inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                               checkpoint_handler, {
                                                   'model': model,
                                               })

        train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
        return outputdir