コード例 #1
0
def run(blueoil_config_file, experiment_id):
    """Train a model from a Blueoil config file.

    Args:
        blueoil_config_file: The Blueoil config (YAML) file. It is copied into
            the experiment directory by rank 0 and converted into an lmnet
            config by every process.
        experiment_id: Experiment identifier. It is used as the sub-directory
            name under ``$OUTPUT_DIR`` (default ``saved``) and is forwarded to
            ``run_train``.

    """
    if horovod_util.is_enabled():
        horovod_util.setup()

    # Only rank 0 keeps a copy of the original Blueoil config YAML, stored in
    # $OUTPUT_DIR/<experiment_id> ("saved" when OUTPUT_DIR is unset).
    if horovod_util.is_rank0():
        experiment_dir = os.path.join(
            os.environ.get('OUTPUT_DIR', 'saved'),
            experiment_id,
        )
        save_config_file(blueoil_config_file, experiment_dir)

    # Convert the Blueoil config into an lmnet config. The generated file is a
    # named temporary file and cannot be shared between processes, so every
    # process (not just rank 0) generates its own copy here.
    lmnet_config_file = generate(blueoil_config_file)

    # Hand off to training. network/dataset are None, presumably so that the
    # classes named in the generated config are used without override.
    run_train(
        network=None,
        dataset=None,
        config_file=lmnet_config_file,
        experiment_id=experiment_id,
        recreate=False,
    )
コード例 #2
0
ファイル: train.py プロジェクト: ammarnajjar/blueoil
def run(network, dataset, config_file, experiment_id, recreate):
    """Load a training config, apply optional overrides, and start training.

    Args:
        network: Optional network identifier (presumably a class name/path).
            When truthy, it is resolved with
            ``module_loader.load_network_class`` and replaces
            ``config.NETWORK_CLASS``.
        dataset: Optional dataset identifier (presumably a class name/path).
            When truthy, it is resolved with
            ``module_loader.load_dataset_class`` and replaces
            ``config.DATASET_CLASS``.
        config_file: Config file loaded with ``config_util.load``. Rank 0 also
            copies it into the experiment directory.
        experiment_id: Identifier passed to ``environment.init``.
        recreate: Forwarded to ``executor.prepare_dirs`` by rank 0.
    """
    environment.init(experiment_id)
    config = config_util.load(config_file)

    # Explicit arguments take precedence over the classes named in the config.
    if network:
        config.NETWORK_CLASS = module_loader.load_network_class(network)
    if dataset:
        config.DATASET_CLASS = module_loader.load_dataset_class(dataset)

    if horovod_util.is_enabled():
        horovod_util.setup()

    # Only rank 0 prints the config, sets up logging, and writes the
    # experiment directory contents (config copy + YAML dump).
    if horovod_util.is_rank0():
        config_util.display(config)
        executor.init_logging(config)

        executor.prepare_dirs(recreate)
        config_util.copy_to_experiment_dir(config_file)
        config_util.save_yaml(environment.EXPERIMENT_DIR, config)

    start_training(config)