Example #1
0
def experiment_fn(run_config, params):
    """Assemble a tf.contrib.learn Experiment for the Conversation model.

    Builds the estimator, loads the vocabulary (recording its size on the
    global Config), constructs batched train/test input pipelines, and wires
    training monitors and eval hooks into the returned Experiment.
    """
    # Estimator wrapping the seq2seq conversation graph.
    model = Conversation()
    estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config)

    # Vocabulary must be loaded before the model graph is built so that
    # Config.data.vocab_size is populated for embedding/projection layers.
    vocab = data_loader.load_vocab("vocab")
    Config.data.vocab_size = len(vocab)

    train_X, test_X, train_y, test_y = data_loader.make_train_and_test_set()

    train_input_fn, train_input_hook = data_loader.make_batch(
        (train_X, train_y), batch_size=Config.model.batch_size)
    test_input_fn, test_input_hook = data_loader.make_batch(
        (test_X, test_y), batch_size=Config.model.batch_size, scope="test")

    # Periodically decode and print sample encoder/decoder/prediction tensors.
    print_hook = hook.print_variables(
        variables=['train/enc_0', 'train/dec_0', 'train/pred_0'],
        vocab=vocab,
        every_n_iter=Config.train.check_hook_n_iter)

    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=test_input_fn,
        train_steps=Config.train.train_steps,
        min_eval_frequency=Config.train.min_eval_frequency,
        train_monitors=[train_input_hook, print_hook],
        eval_hooks=[test_input_hook],
        eval_delay_secs=0)
Example #2
0
def main():
    """Entry point for distributed training driven by the TF_CONFIG env var.

    Parses the cluster spec from TF_CONFIG, starts a tf.train.Server for this
    task, and either parks a parameter server on join() or runs
    train_and_evaluate on a worker under a replica device setter.
    """
    # Hyperparameters come from the project-level Config object.
    params = tf.contrib.training.HParams(**Config.model.to_dict())

    run_config = tf.estimator.RunConfig(
        model_dir=Config.train.model_dir,
        save_checkpoints_steps=Config.train.save_checkpoints_steps,
    )

    # TF_CONFIG carries the cluster topology and this process's role,
    # e.g. {"cluster": {...}, "task": {"type": "worker", "index": 0}}.
    # Defaults to "{}" when unset.
    tf_config = os.environ.get('TF_CONFIG', '{}')
    tf_config_json = json.loads(tf_config)

    cluster = tf_config_json.get('cluster')
    job_name = tf_config_json.get('task', {}).get('type')
    task_index = tf_config_json.get('task', {}).get('index')

    # NOTE(review): with an empty TF_CONFIG, `cluster`/`job_name`/`task_index`
    # are None here; ClusterSpec(None) is expected to fail — this code assumes
    # a fully-populated TF_CONFIG. Verify against how the job is launched.
    cluster_spec = tf.train.ClusterSpec(cluster)
    server = tf.train.Server(cluster_spec,
                             job_name=job_name,
                             task_index=task_index)

    if job_name == "ps":
        tf.logging.info("Started server!")
        # join() blocks forever, so a "ps" task never reaches the worker branch.
        server.join()

    if job_name == "worker":
        # Pin variables/ops via replica_device_setter so parameters land on
        # the ps tasks while compute stays on this worker.
        with tf.Session(server.target):
            with tf.device(
                    tf.train.replica_device_setter(
                        worker_device="/job:worker/task:%d" % task_index,
                        cluster=cluster)):
                tf.logging.info("Initializing Estimator")
                conversation = Conversation()
                estimator = tf.estimator.Estimator(
                    model_fn=conversation.model_fn,
                    model_dir=Config.train.model_dir,
                    params=params,
                    config=run_config)

                tf.logging.info("Initializing vocabulary")
                # Vocab size must be recorded on Config before the model
                # graph is built.
                vocab = data_loader.load_vocab("vocab")
                Config.data.vocab_size = len(vocab)

                train_X, test_X, train_y, test_y = data_loader.make_train_and_test_set(
                )
                train_input_fn, train_input_hook = data_loader.make_batch(
                    (train_X, train_y), batch_size=Config.model.batch_size)
                test_input_fn, test_input_hook = data_loader.make_batch(
                    (test_X, test_y),
                    batch_size=Config.model.batch_size,
                    scope="test")

                tf.logging.info("Initializing Specifications")
                # NOTE(review): max_steps is hard-coded to 1000 rather than
                # taken from Config.train.train_steps — confirm intentional.
                train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                                    max_steps=1000)
                eval_spec = tf.estimator.EvalSpec(input_fn=test_input_fn)
                tf.logging.info("Run training")
                tf.estimator.train_and_evaluate(estimator, train_spec,
                                                eval_spec)
Example #3
0
File: main.py  Project: PhantomGrapes/ds
def experiment_fn(run_config, params):
    """Build the Experiment: estimator, vocab, data pipelines, and hooks.

    Optional hooks are gated on Config.train flags: print_verbose adds a
    token-printing monitor on the training side, and debug attaches the
    tfdbg local CLI to both training and evaluation.
    """
    # Define the estimator first.
    estimator = tf.estimator.Estimator(
        model_fn=Conversation().model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config)

    # Load vocabulary and record its size for graph construction.
    vocab = data_loader.load_vocab("vocab")
    Config.data.vocab_size = len(vocab)

    # Prepare the train/test splits.
    train_X, test_X, train_y, test_y = data_loader.make_train_and_test_set()

    train_input_fn, train_input_hook = data_loader.make_batch(
        (train_X, train_y), batch_size=Config.model.batch_size)
    test_input_fn, test_input_hook = data_loader.make_batch(
        (test_X, test_y), batch_size=Config.model.batch_size, scope="test")

    # Training-side hooks: pipeline initializer plus optional extras.
    train_hooks = [train_input_hook]
    if Config.train.print_verbose:
        verbose_hook = hook.print_variables(
            variables=['train/enc_0', 'train/dec_0', 'train/pred_0'],
            rev_vocab=utils.get_rev_vocab(vocab),
            every_n_iter=Config.train.check_hook_n_iter)
        train_hooks.append(verbose_hook)
    if Config.train.debug:
        train_hooks.append(tf_debug.LocalCLIDebugHook())

    # Evaluation-side hooks mirror the debug setting.
    eval_hooks = [test_input_hook]
    if Config.train.debug:
        eval_hooks.append(tf_debug.LocalCLIDebugHook())

    # Define the experiment itself.
    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=test_input_fn,
        train_steps=Config.train.train_steps,
        min_eval_frequency=Config.train.min_eval_frequency,
        train_monitors=train_hooks,
        eval_hooks=eval_hooks,
        eval_delay_secs=0)
Example #4
0
def experiment_fn(run_config, params):
    """Construct an Experiment for the generic Model.

    Loads the vocabulary, builds train/test batch pipelines (data_loader
    returns paired data here rather than separate X/y splits), and attaches
    optional verbose-printing and tfdbg hooks based on Config.train flags.
    """
    net = Model()
    estimator = tf.estimator.Estimator(
        model_fn=net.model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config)

    # Vocab size is published on Config before any graph is built.
    vocab = data_loader.load_vocab("vocab")
    Config.data.vocab_size = len(vocab)

    train_data, test_data = data_loader.make_train_and_test_set()
    train_input_fn, train_input_hook = data_loader.make_batch(
        train_data, batch_size=Config.model.batch_size, scope="train")
    test_input_fn, test_input_hook = data_loader.make_batch(
        test_data, batch_size=Config.model.batch_size, scope="test")

    train_hooks = [train_input_hook]
    if Config.train.print_verbose:
        # Decode sample inputs back to tokens, and print raw target/pred ids.
        input_printer = hook.print_variables(
            variables=['train/input_0'],
            rev_vocab=get_rev_vocab(vocab),
            every_n_iter=Config.train.check_hook_n_iter)
        target_printer = hook.print_target(
            variables=['train/target_0', 'train/pred_0'],
            every_n_iter=Config.train.check_hook_n_iter)
        train_hooks.append(input_printer)
        train_hooks.append(target_printer)
    if Config.train.debug:
        train_hooks.append(tf_debug.LocalCLIDebugHook())

    eval_hooks = [test_input_hook]
    if Config.train.debug:
        eval_hooks.append(tf_debug.LocalCLIDebugHook())

    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=test_input_fn,
        train_steps=Config.train.train_steps,
        min_eval_frequency=Config.train.min_eval_frequency,
        train_monitors=train_hooks,
        eval_hooks=eval_hooks)
Example #5
0
def experiment_fn(run_config, params):
    """Build an Experiment for a VAE-style Model (no vocabulary step).

    Training hooks: the pipeline initializer, then (when Config.train.debug)
    the tfdbg CLI, then (when print_verbose) a LoggingTensorHook tracking the
    reconstruction-error and KL-divergence loss terms.
    """
    net = Model()
    estimator = tf.estimator.Estimator(
        model_fn=net.model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config)

    train_data, test_data = data_loader.make_train_and_test_set()

    train_input_fn, train_input_hook = data_loader.make_batch(
        train_data, batch_size=Config.model.batch_size, scope="train")
    test_input_fn, test_input_hook = data_loader.make_batch(
        test_data, batch_size=Config.model.batch_size, scope="test")

    # Hook order preserved: initializer, debug CLI, then loss logging.
    train_hooks = [train_input_hook]
    if Config.train.debug:
        train_hooks.append(tf_debug.LocalCLIDebugHook())
    if Config.train.print_verbose:
        loss_logger = tf.train.LoggingTensorHook(
            ["loss/reconstruction_error", "loss/kl_divergence"],
            every_n_iter=Config.train.check_hook_n_iter)
        train_hooks.append(loss_logger)

    eval_hooks = [test_input_hook]
    if Config.train.debug:
        eval_hooks.append(tf_debug.LocalCLIDebugHook())

    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=test_input_fn,
        train_steps=Config.train.train_steps,
        min_eval_frequency=Config.train.min_eval_frequency,
        train_monitors=train_hooks,
        eval_hooks=eval_hooks)