def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
  config = configdict.ConfigDict(config)

  # Create the estimator.
  estimator = estimator_util.create_estimator(
      model_class, config.hparams, model_dir=FLAGS.model_dir)

  # Create an input function that reads the evaluation dataset.
  input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.eval_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.EVAL)

  # Run evaluation. This will log the result to stderr and also write a
  # summary file in the model_dir.
  estimator_util.evaluate(estimator, input_fn, eval_name=FLAGS.eval_name)
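
# For context: this entry point reads module-level command-line flags defined
# elsewhere in the script. A minimal sketch of definitions consistent with the
# flag names used above, assuming absl-style flags (defaults and help strings
# are illustrative assumptions, not the project's actual definitions):

from absl import flags

flags.DEFINE_string("model", None, "Name of the model class to build.")
flags.DEFINE_string("config_name", None,
                    "Name of one of the model's predefined configurations.")
flags.DEFINE_string("config_json", None,
                    "JSON string or JSON file with the model configuration.")
flags.DEFINE_string("model_dir", None,
                    "Directory containing model checkpoints and summaries.")
flags.DEFINE_string("eval_files", None,
                    "File pattern matching the evaluation dataset.")
flags.DEFINE_string("eval_name", "val", "Name of this evaluation run.")

FLAGS = flags.FLAGS

# The training script further below would additionally define train_files,
# train_steps, and shuffle_buffer_size along the same lines.
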
def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
  config = configdict.ConfigDict(config)

  # Create the estimator.
  estimator = estimator_util.create_estimator(
      model_class, config.hparams, model_dir=FLAGS.model_dir)

  # Print the number of model parameters to the console. Note: this counts
  # every variable in the latest checkpoint (e.g. the global step and any
  # optimizer slots), so it can slightly overcount the trainable parameters.
  var_names = estimator.get_variable_names()
  n_params = np.sum(
      [estimator.get_variable_value(v).size for v in var_names])
  print("Trainable parameters in model:", int(n_params))

  # Create an input function that reads the evaluation dataset.
  input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.eval_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.EVAL)

  # Run evaluation. This will log the result to stderr and also write a
  # summary file in the model_dir.
  estimator_util.evaluate(estimator, input_fn, eval_name=FLAGS.eval_name)
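
# These main(_) functions are presumably launched through the standard TF1-era
# entry point, which parses the flags and then calls main(). A minimal sketch,
# assuming tf.app.run-style startup:

if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
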
def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
  config = configdict.ConfigDict(config)
  config_util.log_and_save_config(config, FLAGS.model_dir)

  # Create the estimator.
  run_config = tf.estimator.RunConfig(keep_checkpoint_max=1)
  estimator = estimator_util.create_estimator(
      model_class, config.hparams, run_config, FLAGS.model_dir)

  # Create an input function that reads the training dataset. We iterate
  # through the dataset one epoch at a time if we are alternating with
  # evaluation, otherwise we iterate infinitely.
  train_input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.train_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.TRAIN,
      shuffle_values_buffer=FLAGS.shuffle_buffer_size,
      repeat=1 if FLAGS.eval_files else None)

  if not FLAGS.eval_files:
    estimator.train(train_input_fn, max_steps=FLAGS.train_steps)
  else:
    eval_input_fn = estimator_util.create_input_fn(
        file_pattern=FLAGS.eval_files,
        input_config=config.inputs,
        mode=tf.estimator.ModeKeys.EVAL)
    eval_args = {
        "val": (eval_input_fn, None)  # eval_name: (input_fn, eval_steps)
    }

    for _ in estimator_runner.continuous_train_and_eval(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_args=eval_args,
        train_steps=FLAGS.train_steps):
      # continuous_train_and_eval() yields evaluation metrics after each
      # training epoch. We don't do anything here.
      pass
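
# The alternation logic lives in estimator_runner.continuous_train_and_eval().
# As a rough illustration of the contract the loop above relies on (a sketch,
# not the actual estimator_runner implementation): train for one pass over the
# dataset, evaluate each configured eval set, yield the metrics, and stop once
# the global step reaches train_steps.

def sketch_continuous_train_and_eval(estimator, train_input_fn, eval_args,
                                     train_steps=None):
  global_step = 0
  while train_steps is None or global_step < train_steps:
    # train_input_fn repeats the dataset once (repeat=1 above), so this trains
    # for a single epoch, capped at train_steps total steps.
    estimator.train(train_input_fn, max_steps=train_steps)
    for eval_name, (eval_input_fn, eval_steps) in eval_args.items():
      # Estimator.evaluate() returns a dict of metrics that includes
      # "global_step", which determines when to stop.
      metrics = estimator.evaluate(
          eval_input_fn, steps=eval_steps, name=eval_name)
      global_step = metrics["global_step"]
      yield eval_name, metrics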