if __name__ == "__main__":
    # Command line: a single option naming the run mode, converted to
    # experimental.Mode before being handed to estimator.init.
    cli = argparse.ArgumentParser()
    cli.add_argument("--mode", dest="mode", default="cluster")
    cli_args = cli.parse_args()

    # Hard-coded hyperparameters for this XOR example. "shuffle" is read back
    # through context.get_hparam further down.
    hparams = {
        "hidden_size": 2,
        "learning_rate": 0.1,
        "global_batch_size": 4,
        "optimizer": "sgd",
        "shuffle": False,
    }
    experiment_config = {"hyperparameters": hparams}

    # Same argument evaluation order as a single inline call: mode, then cwd.
    run_mode = experimental.Mode(cli_args.mode)
    working_dir = str(pathlib.Path.cwd())
    context = estimator.init(config=experiment_config, mode=run_mode, context_dir=working_dir)

    per_slot_batch = context.get_per_slot_batch_size()
    shuffle_train = context.get_hparam("shuffle")
    context.serving_input_receiver_fns = build_serving_input_receiver_fns()

    # The estimator is built before either input_fn is created. Training is
    # capped at one step (presumably a smoke test — confirm). Evaluation never
    # shuffles, regardless of the "shuffle" hyperparameter.
    model = build_estimator(context)
    train_spec = tf.estimator.TrainSpec(
        input_fn=xor_input_fn(context=context, batch_size=per_slot_batch, shuffle=shuffle_train),
        max_steps=1,
    )
    eval_spec = tf.estimator.EvalSpec(
        input_fn=xor_input_fn(context=context, batch_size=per_slot_batch, shuffle=False),
    )
    context.train_and_evaluate(model, train_spec, eval_spec)
# Example #2 (score: 0)
    # Two boolean switches, forwarded verbatim to estimator.init below.
    for flag in ("--local", "--test"):
        parser.add_argument(flag, action="store_true")
    args = parser.parse_args()

    # Hard-coded hyperparameters for this XOR example. "shuffle" is read back
    # through context.get_hparam further down.
    xor_hparams = {
        "hidden_size": 2,
        "learning_rate": 0.1,
        "global_batch_size": 4,
        "optimizer": "sgd",
        "shuffle": False,
    }
    config = {"hyperparameters": xor_hparams}

    context = estimator.init(
        config=config,
        local=args.local,
        test=args.test,
        context_dir=str(pathlib.Path.cwd()),
    )

    batch_size = context.get_per_slot_batch_size()
    shuffle = context.get_hparam("shuffle")
    context.serving_input_receiver_fns = build_serving_input_receiver_fns()

    # The estimator is built before either input_fn is created. Training is
    # capped at one step (presumably a smoke test — confirm). Evaluation never
    # shuffles, regardless of the "shuffle" hyperparameter.
    xor_estimator = build_estimator(context)
    train_spec = tf.estimator.TrainSpec(
        input_fn=xor_input_fn(context=context, batch_size=batch_size, shuffle=shuffle),
        max_steps=1,
    )
    eval_spec = tf.estimator.EvalSpec(
        input_fn=xor_input_fn(context=context, batch_size=batch_size, shuffle=False),
    )
    context.train_and_evaluate(xor_estimator, train_spec, eval_spec)