def create_dataset(args):
    '''Create the training and validation datasets and wrap them in DataLoaders.

    Returns one training DataLoader and a list of validation DataLoaders
    (one per validation subject when ``args.sub_name`` is given).
    '''
    # Data loading code
    if args.task == 'syn_t1':
        dataset = MRISynDataset
    elif args.task == 'seg_oct':
        dataset = OCTSegDataset
    else:
        raise NotImplementedError(f'Unsupported task: {args.task}')
    train_dataset = dataset(args, train=True, augment=True)
    val_dataset_list = []
    if args.vimg_path:
        val_dataset = dataset(args, train=False, augment=False)
        # build a list of val_datasets, one per subject
        if getattr(args, 'sub_name', False):
            if os.path.isfile(args.sub_name):
                with open(args.sub_name) as f:
                    dataname = f.read().splitlines()
            else:
                raise FileNotFoundError(
                    f'Subject list file not found: {args.sub_name}')
            datalist = val_dataset.datalist
            labellist = val_dataset.labellist
            for name in dataname:
                # keep only the files that belong to this subject
                val_dataset.datalist = sorted([_ for _ in datalist if name in str(_)])
                val_dataset.labellist = sorted([_ for _ in labellist if name in str(_)])
                val_dataset_list.append(deepcopy(val_dataset))
        else:
            val_dataset_list = [val_dataset]
    elif args.split and not getattr(args, 'sub_name', False):
        train_dataset, val_dataset = split_data(dataset=train_dataset,
                                                split=args.split,
                                                switch=args.test or args.evaluate)
        val_dataset_list = [val_dataset]
    else:
        # no dedicated validation data or split given; fall back to the
        # dataset's non-training split
        val_dataset = dataset(args, train=False)
        val_dataset_list = [val_dataset]
    logger.info('Found {} training samples'.format(len(train_dataset)))
    logger.info('Found {} validation subjects with total {} samples'.format(
        len(val_dataset_list), sum(len(_) for _ in val_dataset_list)))

    train_loader = DataLoader(dataset=train_dataset, 
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers, 
                              pin_memory=True)
    val_loader = [DataLoader(dataset=_,
                             batch_size=args.batch_size,
                             shuffle=False,  # keep validation order deterministic
                             num_workers=args.workers,
                             pin_memory=True) for _ in val_dataset_list]
    return train_loader, val_loader
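

# Example usage (illustrative sketch; assumes ``args`` is an argparse.Namespace
# carrying the fields referenced above, e.g. task, batch_size, workers,
# vimg_path, split, sub_name; the parser itself lives elsewhere):
#
#   args = parser.parse_args()
#   train_loader, val_loaders = create_dataset(args)
#   for subject_loader in val_loaders:
#       ...  # evaluate one subject at a time
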
def run():

    # declare global variables
    global config
    global config_path

    # add config file to experiment
    experiment.add_artifact(config_path)

    # merge the linked dataset config into the main config, if one is given
    if config['dataset'].get('link', False):
        dataset_config_path = f'../configs/datasets/{config["dataset"]["link"]}'
        experiment.add_artifact(dataset_config_path)
        with open(dataset_config_path) as f:
            config['dataset'].update(json.load(f))

    # dataset specific variables
    folds = config['dataset']['split']
    data_directory = config['dataset']['path']

    # split dataset into k folds
    split_dirs = split_data(folds, data_directory, data_name)

    # total results dictionary
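    # f1: F1 score, rec: recall, acc: accuracy, mcc: Matthews correlation
    # coefficient, prec: precision, spec: specificity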
    results = {
        'f1': [],
        'rec': [],
        'acc': [],
        'mcc': [],
        'prec': [],
        'spec': []
    }

    # iterate over each dataset split
    for split_index in range(len(split_dirs)):

        # print current validation split index
        print(f'start validating on split {split_index}')

        # restart keras session
        K.clear_session()

        # prepare dataset by distributing the k splits
        # into training and validation sets
        training_directory, validation_directory = prepare_dataset(
            split_dirs, split_index, len(split_dirs))

        # optionally copy extra samples from a fixed directory into this
        # split's validation set, class folder by class folder
        if config['dataset'].get('validation_extension', False):

            extension_path = config['dataset']['validation_extension_path']

            for class_extension in os.listdir(extension_path):

                class_path = os.path.join(extension_path, class_extension)
                target_path = os.path.join(validation_directory,
                                           class_extension)
                os.makedirs(target_path, exist_ok=True)

                for filename in os.listdir(class_path):
                    shutil.copy(os.path.join(class_path, filename),
                                os.path.join(target_path, filename))

        # print training directories for sanity
        print(f'training on {training_directory}')
        print(f'validation on {validation_directory}')

        # load model from model file or build it using a build file.
        if config['model'].get('load_model', False):
            model = load_model(config['model']['model_splits'][split_index])
        else:
            model_builder_path = config['model']['build_file']
            model_builder = importlib.import_module(
                f'models.{model_builder_path}')
            model = model_builder.build(config)

        # train the model and keep the final weights
        if config['model'].get('train', True):
            print("Start training...")
            model = train(model, config, experiment, training_directory,
                          validation_directory, f'split_{split_index}')
            evaluate(model, config, experiment, validation_directory,
                     f'split_{split_index}')

        # if a fine-tuning config is linked, train the model again with it
        if config.get('fine_tuning', {}).get(
                'link', False) and config['model'].get('train', True):

            print("Start fine tuning...")

            # load the linked fine-tuning config
            fine_tuning_config_name = config['fine_tuning']['link']
            fine_tuning_config_path = f'../configs/links/{fine_tuning_config_name}'
            with open(fine_tuning_config_path) as f:
                fine_tuning_config = json.load(f)

            if fine_tuning_config['dataset'].get('link', False):
                dataset_config_path = f'../configs/datasets/{fine_tuning_config["dataset"]["link"]}'
                experiment.add_artifact(dataset_config_path)
                with open(dataset_config_path) as f:
                    fine_tuning_config['dataset'].update(json.load(f))

            # add link config to experiment
            experiment.add_artifact(fine_tuning_config_path)

            # train using new config
            model = train(model, fine_tuning_config, experiment,
                          training_directory, validation_directory,
                          f'fine_split_{split_index}')

        # evaluate the trained model and collect its metrics
        print("Start evaluation...")
        split_results = evaluate(model, config, experiment,
                                 validation_directory, f'split_{split_index}')

        # merge split results with total results
        for key in split_results:
            results[key].append(split_results[key])
            print(key, results[key])

    # log results
    log_cross_validation_results(full_kfold_summary_file_path, results,
                                 experiment_name, folds)
    log_to_results_comparison(results, experiment_name, folds)

    experiment.add_artifact(full_kfold_summary_file_path)
    experiment.add_artifact(all_results_file_path)
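

# Illustrative entry point (assumption: ``experiment`` behaves like a
# sacred.Experiment, which is what the add_artifact() calls above suggest;
# adapt to however this module actually constructs its experiment):
#
#   @experiment.automain
#   def main():
#       run()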