def main(config_file: str):
    """Trains an MRF model from the given configuration file.

    Sets up the experiment directories, seeds the RNGs, builds the data
    handler and model, and runs the training loop inside a TF1 session.

    Args:
        config_file: Path to the JSON configuration file.
    """
    config = cfg.load(config_file, cfg.Configuration)

    # set up directories and logging
    model_dir, result_dir = fs.prepare_directories(
        config_file, cfg.Configuration, lambda: fs.get_directory_name(config))
    config.model_dir = model_dir
    config.result_dir = result_dir
    print(config)

    # set seed before model instantiation so weight init is reproducible
    print('Set seed to {}'.format(config.seed))
    seed.set_seed(config.seed, config.cudnn_determinism)

    # load train and valid subjects from split file (also test but it is unused)
    subjects_train, subjects_valid, subjects_test = split.load_split(
        config.split_file)
    print('Train subjects:', subjects_train)
    print('Valid subjects:', subjects_valid)

    # set up data handling
    data_handler = hdlr.MRFDataHandler(
        config,
        subjects_train,
        subjects_valid,
        subjects_test,
        False,
        padding_size=mdl.get_padding_size(config))

    # extract a sample for model initialization
    data_handler.dataset.set_extractor(data_handler.extractor_train)
    sample = data_handler.dataset[0]

    with tf.Session() as sess:
        model = mdl.get_model(config)(sess, sample, config)
        print('Number of parameters:', model.get_number_parameters())

        logger = log.TensorFlowLogger(config.model_dir, sess,
                                      model.epoch_summaries(),
                                      model.batch_summaries(),
                                      model.visualization_summaries())

        trainer = train.AssemblingTesterTensorFlow(
            data_handler, logger, config, model,
            sess)  # use this class to test the pipeline
        #trainer = train.MRFTrainer(data_handler, logger, config, model, sess)

        # finalize to ensure that no ops are added during training,
        # which would lead to a growing graph
        tf.get_default_graph().finalize()

        try:
            trainer.train()
        finally:
            # close even when training raises, so event files get flushed
            logger.close()
# --- Example no. 2 (original listing separator "Esempio n. 2" / score "0") ---
def main(model_dir: str, result_dir: str, do_occlusion: bool):
    """Predicts MRF parameter maps with a trained model on the test subjects.

    Args:
        model_dir: Directory of a trained model; must contain a
            ``*config*.json`` and a ``*split*.json`` file.
        result_dir: Directory where the predictions are saved (created if
            it does not exist).
        do_occlusion: Whether to additionally run the occlusion analysis.

    Raises:
        RuntimeError: If ``model_dir`` does not exist or the config/split
            file cannot be found in it.
    """
    if not os.path.isdir(model_dir):
        raise RuntimeError('Model dir "{}" does not exist'.format(model_dir))

    # locate the config and split files that training copied into model_dir;
    # fail with a clear message instead of a bare IndexError on [0]
    config_files = glob.glob(os.path.join(model_dir, '*config*.json'))
    if not config_files:
        raise RuntimeError('No config file found in "{}"'.format(model_dir))
    config = cfg.load(config_files[0], cfg.Configuration)

    split_files = glob.glob(os.path.join(model_dir, '*split*.json'))
    if not split_files:
        raise RuntimeError('No split file found in "{}"'.format(model_dir))
    split_file = split_files[0]

    os.makedirs(result_dir, exist_ok=True)

    # load train, valid, and test subjects from split file
    subjects_train, subjects_valid, subjects_test = split.load_split(
        split_file)
    print('Test subjects:', subjects_test)

    # set up data handling
    data_handler = hdlr.MRFDataHandler(
        config,
        subjects_train,
        subjects_valid,
        subjects_test,
        False,
        padding_size=mdl.get_padding_size(config))

    # extract a sample for model initialization
    data_handler.dataset.set_extractor(data_handler.extractor_train)
    sample = data_handler.dataset[0]

    with tf.Session() as sess:
        model = mdl.get_model(config)(sess, sample, config)
        tester = test.MRFTensorFlowTester(data_handler, model, model_dir,
                                          result_dir, config.maps, sess)
        tester.load(os.path.join(model_dir, config.best_model_file_name))

        print('Predict...')
        tester.predict()
        if do_occlusion:
            occlude(tester, result_dir, sample[pymia_def.KEY_IMAGES].shape[-2])
def prepare_directories(config_file, config_cls,
                        directory_name_fn) -> 'tuple[str, str]':
    """Prepares the directories for an experiment.

    Creates the model and result directories and copies the config file
    (and, if it exists, the split file) into both so the experiment is
    reproducible from its outputs.

    Args:
        config_file: The config file path.
        config_cls: The config file class.
        directory_name_fn: Function that returns the name (a str) for the
            directories to create.

    Returns:
        A tuple with the paths to the created model and result directories.
    """
    config = cfg.load(config_file, config_cls)

    # directories are <configured base dir>/<experiment-specific suffix>
    suffix = directory_name_fn()
    model_dir = os.path.join(config.model_dir, suffix)
    result_dir = os.path.join(config.result_dir, suffix)

    # create directories (tolerate pre-existing ones)
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(result_dir, exist_ok=True)

    # copy config file
    shutil.copyfile(config_file, os.path.join(result_dir, 'config.json'))
    shutil.copyfile(config_file, os.path.join(model_dir, 'config.json'))

    # copy split file
    if os.path.exists(config.split_file):
        shutil.copyfile(
            config.split_file,
            os.path.join(result_dir, os.path.basename(config.split_file)))
        shutil.copyfile(
            config.split_file,
            os.path.join(model_dir, os.path.basename(config.split_file)))

    return model_dir, result_dir