def experiment_fn(run_config, params):
    """Assemble a tf.contrib.learn Experiment for the Conversation model.

    Args:
        run_config: tf.estimator.RunConfig passed through to the Estimator.
        params: hyper-parameters forwarded to the model_fn.

    Returns:
        A configured tf.contrib.learn.Experiment ready for training/eval.
    """
    model = Conversation()
    estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config)

    # The vocabulary size must be recorded before the input pipeline
    # and model graph are built.
    vocab = data_loader.load_vocab("vocab")
    Config.data.vocab_size = len(vocab)

    train_X, test_X, train_y, test_y = data_loader.make_train_and_test_set()
    train_input_fn, train_input_hook = data_loader.make_batch(
        (train_X, train_y), batch_size=Config.model.batch_size)
    test_input_fn, test_input_hook = data_loader.make_batch(
        (test_X, test_y), batch_size=Config.model.batch_size, scope="test")

    # Periodically prints sample encoder/decoder/prediction tensors,
    # decoded through the vocabulary.
    print_hook = hook.print_variables(
        variables=['train/enc_0', 'train/dec_0', 'train/pred_0'],
        vocab=vocab,
        every_n_iter=Config.train.check_hook_n_iter)

    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=test_input_fn,
        train_steps=Config.train.train_steps,
        min_eval_frequency=Config.train.min_eval_frequency,
        train_monitors=[train_input_hook, print_hook],
        eval_hooks=[test_input_hook],
        eval_delay_secs=0)
def main():
    """Distributed-training entry point driven by the TF_CONFIG env var.

    Reads the cluster spec and this process's role from TF_CONFIG, then:
      * "ps"     — starts a parameter server and blocks in server.join();
      * "worker" — builds the Conversation estimator and its input
        pipelines, and runs tf.estimator.train_and_evaluate.

    NOTE(review): this assumes TF_CONFIG always carries a 'cluster' entry;
    for a purely local run tf.train.ClusterSpec(None) raises — confirm the
    launcher always sets TF_CONFIG before calling main().
    """
    params = tf.contrib.training.HParams(**Config.model.to_dict())
    run_config = tf.estimator.RunConfig(
        model_dir=Config.train.model_dir,
        save_checkpoints_steps=Config.train.save_checkpoints_steps,
    )

    tf_config = os.environ.get('TF_CONFIG', '{}')
    tf_config_json = json.loads(tf_config)

    cluster = tf_config_json.get('cluster')
    job_name = tf_config_json.get('task', {}).get('type')
    task_index = tf_config_json.get('task', {}).get('index')

    cluster_spec = tf.train.ClusterSpec(cluster)
    server = tf.train.Server(cluster_spec,
                             job_name=job_name,
                             task_index=task_index)

    if job_name == "ps":
        # Parameter servers serve variable requests forever.
        tf.logging.info("Started server!")
        server.join()
    if job_name == "worker":
        with tf.Session(server.target):
            with tf.device(
                    tf.train.replica_device_setter(
                        worker_device="/job:worker/task:%d" % task_index,
                        cluster=cluster)):
                tf.logging.info("Initializing Estimator")
                conversation = Conversation()
                estimator = tf.estimator.Estimator(
                    model_fn=conversation.model_fn,
                    model_dir=Config.train.model_dir,
                    params=params,
                    config=run_config)

                tf.logging.info("Initializing vocabulary")
                # Vocabulary size must be known before the graph is built.
                vocab = data_loader.load_vocab("vocab")
                Config.data.vocab_size = len(vocab)

                train_X, test_X, train_y, test_y = \
                    data_loader.make_train_and_test_set()
                train_input_fn, train_input_hook = data_loader.make_batch(
                    (train_X, train_y), batch_size=Config.model.batch_size)
                test_input_fn, test_input_hook = data_loader.make_batch(
                    (test_X, test_y),
                    batch_size=Config.model.batch_size,
                    scope="test")

                tf.logging.info("Initializing Specifications")
                # Fix: honor the configured step budget instead of the
                # hard-coded 1000 steps used previously — every other
                # training path in this file trains for
                # Config.train.train_steps.
                train_spec = tf.estimator.TrainSpec(
                    input_fn=train_input_fn,
                    max_steps=Config.train.train_steps)
                eval_spec = tf.estimator.EvalSpec(input_fn=test_input_fn)

                tf.logging.info("Run training")
                tf.estimator.train_and_evaluate(
                    estimator, train_spec, eval_spec)
def experiment_fn(run_config, params):
    """Create a tf.contrib.learn Experiment for the Conversation model.

    Optionally attaches a variable-printing hook when
    Config.train.print_verbose is set, and local CLI debug hooks on both
    training and evaluation when Config.train.debug is set.

    Args:
        run_config: tf.estimator.RunConfig for the Estimator.
        params: hyper-parameters forwarded to the model_fn.

    Returns:
        A configured tf.contrib.learn.Experiment.
    """
    # Build the estimator first.
    model = Conversation()
    estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config)

    # Vocabulary must be loaded before the input pipelines are created.
    vocab = data_loader.load_vocab("vocab")
    Config.data.vocab_size = len(vocab)

    # Train/test input pipelines.
    train_X, test_X, train_y, test_y = data_loader.make_train_and_test_set()
    train_input_fn, train_input_hook = data_loader.make_batch(
        (train_X, train_y), batch_size=Config.model.batch_size)
    test_input_fn, test_input_hook = data_loader.make_batch(
        (test_X, test_y), batch_size=Config.model.batch_size, scope="test")

    train_hooks = [train_input_hook]
    if Config.train.print_verbose:
        # Decode and print sample enc/dec/pred ids via the reverse vocab.
        train_hooks.append(
            hook.print_variables(
                variables=['train/enc_0', 'train/dec_0', 'train/pred_0'],
                rev_vocab=utils.get_rev_vocab(vocab),
                every_n_iter=Config.train.check_hook_n_iter))
    if Config.train.debug:
        train_hooks.append(tf_debug.LocalCLIDebugHook())

    eval_hooks = [test_input_hook]
    if Config.train.debug:
        eval_hooks.append(tf_debug.LocalCLIDebugHook())

    # Wire everything into the Experiment.
    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=test_input_fn,
        train_steps=Config.train.train_steps,
        min_eval_frequency=Config.train.min_eval_frequency,
        train_monitors=train_hooks,
        eval_hooks=eval_hooks,
        eval_delay_secs=0)
def experiment_fn(run_config, params):
    """Build a tf.contrib.learn Experiment around Model.

    Adds optional verbose-printing hooks (input + target/prediction) and
    local CLI debug hooks according to Config.train flags.

    Args:
        run_config: tf.estimator.RunConfig for the Estimator.
        params: hyper-parameters forwarded to the model_fn.

    Returns:
        A configured tf.contrib.learn.Experiment.
    """
    estimator = tf.estimator.Estimator(
        model_fn=Model().model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config)

    # Record vocabulary size before pipelines/graph are constructed.
    vocab = data_loader.load_vocab("vocab")
    Config.data.vocab_size = len(vocab)

    train_data, test_data = data_loader.make_train_and_test_set()
    train_input_fn, train_input_hook = data_loader.make_batch(
        train_data, batch_size=Config.model.batch_size, scope="train")
    test_input_fn, test_input_hook = data_loader.make_batch(
        test_data, batch_size=Config.model.batch_size, scope="test")

    train_hooks = [train_input_hook]
    if Config.train.print_verbose:
        # Periodically decode a sample input, then print its target and
        # the model's prediction.
        train_hooks.extend([
            hook.print_variables(
                variables=['train/input_0'],
                rev_vocab=get_rev_vocab(vocab),
                every_n_iter=Config.train.check_hook_n_iter),
            hook.print_target(
                variables=['train/target_0', 'train/pred_0'],
                every_n_iter=Config.train.check_hook_n_iter),
        ])
    if Config.train.debug:
        train_hooks.append(tf_debug.LocalCLIDebugHook())

    eval_hooks = [test_input_hook]
    if Config.train.debug:
        eval_hooks.append(tf_debug.LocalCLIDebugHook())

    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=test_input_fn,
        train_steps=Config.train.train_steps,
        min_eval_frequency=Config.train.min_eval_frequency,
        train_monitors=train_hooks,
        eval_hooks=eval_hooks)
def experiment_fn(run_config, params):
    """Build a tf.contrib.learn Experiment for Model (no vocabulary step).

    Attaches CLI debug hooks when Config.train.debug is set, and a
    LoggingTensorHook for the loss components when
    Config.train.print_verbose is set.

    Args:
        run_config: tf.estimator.RunConfig for the Estimator.
        params: hyper-parameters forwarded to the model_fn.

    Returns:
        A configured tf.contrib.learn.Experiment.
    """
    estimator = tf.estimator.Estimator(
        model_fn=Model().model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config)

    train_data, test_data = data_loader.make_train_and_test_set()
    train_input_fn, train_input_hook = data_loader.make_batch(
        train_data, batch_size=Config.model.batch_size, scope="train")
    test_input_fn, test_input_hook = data_loader.make_batch(
        test_data, batch_size=Config.model.batch_size, scope="test")

    train_hooks = [train_input_hook]
    if Config.train.debug:
        train_hooks.append(tf_debug.LocalCLIDebugHook())
    if Config.train.print_verbose:
        # Log the reconstruction and KL loss terms every
        # check_hook_n_iter steps.
        train_hooks.append(
            tf.train.LoggingTensorHook(
                ["loss/reconstruction_error", "loss/kl_divergence"],
                every_n_iter=Config.train.check_hook_n_iter))

    eval_hooks = [test_input_hook]
    if Config.train.debug:
        eval_hooks.append(tf_debug.LocalCLIDebugHook())

    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=test_input_fn,
        train_steps=Config.train.train_steps,
        min_eval_frequency=Config.train.min_eval_frequency,
        train_monitors=train_hooks,
        eval_hooks=eval_hooks)