Example #1
import logging
import os

import numpy as np
import tensorflow as tf

# NOTE: the project-local helpers used below (parse_args, setup_horovod,
# set_flags, prepare_model_dir, get_logger, Dataset, build_estimator,
# dump_callback, get_hooks, parse_evaluation_results) are assumed to be
# importable from the repository's own modules; their imports are omitted here.


def main():
    params = parse_args()
    tf.random.set_seed(params.seed)
    tf.get_logger().setLevel(logging.ERROR)

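    # Distributed setup: configure Horovod workers and runtime flags, then
    # prepare the model directory and the metrics logger.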
    params = setup_horovod(params)
    set_flags(params)
    model_dir = prepare_model_dir(params)
    logger = get_logger(params)

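    # Input pipeline; fold_idx / n_folds select the cross-validation split.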
    dataset = Dataset(data_dir=params.data_dir,
                      batch_size=params.batch_size,
                      fold_idx=params.fold,
                      n_folds=params.num_folds,
                      params=params,
                      seed=params.seed)

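    # Build the Estimator used for training, evaluation, and prediction below;
    # checkpoints and summaries go to model_dir.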
    estimator = build_estimator(params, model_dir)

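    # Record hyperparameters for TensorBoard on the chief worker (worker_id 0),
    # or on every worker when log_all_workers is set.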
    if params.tensorboard_logging and (params.worker_id == 0
                                       or params.log_all_workers):
        from TensorFlow.common.tb_utils import write_hparams_v1
        write_hparams_v1(params.log_dir, vars(params))

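    # Outside benchmark mode, split the total step budget across workers.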
    if not params.benchmark:
        params.max_steps = params.max_steps // params.num_workers
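    # Training: run under dump_callback (controlled by dump_config) and feed
    # either synthetic or real training data.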
    if 'train' in params.exec_mode:
        with dump_callback(params.dump_config):
            training_hooks = get_hooks(params, logger)
            dataset_fn = dataset.synth_train_fn if params.synth_data else dataset.train_fn

            estimator.train(input_fn=dataset_fn,
                            steps=params.max_steps,
                            hooks=training_hooks)

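    # Evaluation: run over the evaluation split and log the parsed metrics
    # from the chief worker only.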
    if 'evaluate' in params.exec_mode:
        result = estimator.evaluate(input_fn=dataset.eval_fn,
                                    steps=dataset.eval_size)
        data = parse_evaluation_results(result)
        if params.worker_id == 0:
            logger.log(step=(), data=data)

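    # Inference runs on the chief worker only; in benchmark mode the test set is
    # repeated so the run covers the warmup steps.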
    if 'predict' == params.exec_mode:
        inference_hooks = get_hooks(params, logger)
        if params.worker_id == 0:
            count = 1 if not params.benchmark else 2 * params.warmup_steps * params.batch_size // dataset.test_size
            predictions = estimator.predict(input_fn=lambda: dataset.test_fn(
                count=count, drop_remainder=params.benchmark),
                                            hooks=inference_hooks)

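            # Save each predicted volume to model_dir unless this is a benchmark run.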
            for idx, p in enumerate(predictions):
                volume = p['predictions']
                if not params.benchmark:
                    np.save(
                        os.path.join(params.model_dir,
                                     "vol_{}.npy".format(idx)), volume)
Example #2
import logging
import os

import horovod.tensorflow as hvd
import numpy as np
import tensorflow as tf

# NOTE: PARSER and the project-local helpers (prepare_model_dir, get_logger,
# Dataset, build_estimator, get_hooks, parse_evaluation_results, TrainingHook,
# ProfilingHook) are assumed to be importable from the repository's own modules.


def main():
    tf.get_logger().setLevel(logging.ERROR)
    hvd.init()
    params = PARSER.parse_args()
    model_dir = prepare_model_dir(params)
    logger = get_logger(params)

    dataset = Dataset(data_dir=params.data_dir,
                      batch_size=params.batch_size,
                      fold_idx=params.fold,
                      n_folds=params.num_folds,
                      params=params)

    estimator = build_estimator(params=params, model_dir=model_dir)

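    # Divide the step budget across Horovod ranks unless benchmarking.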
    max_steps = params.max_steps // (1 if params.benchmark else hvd.size())

    if 'train' in params.exec_mode:
        training_hooks = get_hooks(params, logger)
        estimator.train(input_fn=dataset.train_fn,
                        steps=max_steps,
                        hooks=training_hooks)

    if 'evaluate' in params.exec_mode:
        result = estimator.evaluate(input_fn=dataset.eval_fn,
                                    steps=dataset.eval_size)
        data = parse_evaluation_results(result)
        if hvd.rank() == 0:
            logger.log(step=(), data=data)

    if 'predict' == params.exec_mode:
        inference_hooks = get_hooks(params, logger)
        if hvd.rank() == 0:
            count = 1 if not params.benchmark else 2 * params.warmup_steps * params.batch_size // dataset.test_size
            predictions = estimator.predict(input_fn=lambda: dataset.test_fn(
                count=count, drop_remainder=params.benchmark),
                                            hooks=inference_hooks)

            for idx, p in enumerate(predictions):
                volume = p['predictions']
                if not params.benchmark:
                    np.save(
                        os.path.join(params.model_dir,
                                     "vol_{}.npy".format(idx)), volume)

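    # debug_train: train on synthetic data with variables broadcast from rank 0;
    # rank 0 also attaches loss-logging and profiling hooks.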
    if 'debug_train' == params.exec_mode:
        hooks = [hvd.BroadcastGlobalVariablesHook(0)]
        if hvd.rank() == 0:
            hooks += [
                TrainingHook(log_every=params.log_every,
                             logger=logger,
                             tensor_names=['total_loss_ref:0']),
                ProfilingHook(warmup_steps=params.warmup_steps,
                              global_batch_size=hvd.size() * params.batch_size,
                              logger=logger,
                              mode='train')
            ]

        estimator.train(input_fn=dataset.synth_train_fn,
                        steps=max_steps,
                        hooks=hooks)

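    # debug_predict: benchmark inference on synthetic data on rank 0 only,
    # discarding the predicted volumes.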
    if 'debug_predict' == params.exec_mode:
        if hvd.rank() == 0:
            hooks = [
                ProfilingHook(warmup_steps=params.warmup_steps,
                              global_batch_size=params.batch_size,
                              logger=logger,
                              mode='inference')
            ]
            count = 2 * params.warmup_steps
            predictions = estimator.predict(
                input_fn=lambda: dataset.synth_predict_fn(count=count),
                hooks=hooks)
            for p in predictions:
                _ = p['predictions']