def train_and_eval():
    """Build an Estimator from the command-line flags, then train and evaluate it.

    Checkpoints and summaries are written under ``FLAGS.checkpoint_dir`` at the
    step intervals given by the corresponding flags.
    """
    run_config = tf_estimator.RunConfig(
        save_summary_steps=FLAGS.save_summaries_steps,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count,
        keep_checkpoint_max=None,
    )

    # The model_fn receives a placeholder params dictionary.
    model = tf_estimator.Estimator(
        model_fn=model_fn,
        model_dir=FLAGS.checkpoint_dir,
        config=run_config,
        params={'dummy': 0},
    )

    training = tf_estimator.TrainSpec(
        input_fn=train_input_fn,
        max_steps=FLAGS.train_steps,
    )
    evaluation = tf_estimator.EvalSpec(
        input_fn=eval_input_fn,
        start_delay_secs=60,
        steps=FLAGS.eval_examples,
        throttle_secs=60,
    )
    tf_estimator.train_and_evaluate(model, training, evaluation)
def run():
    """Runs train_and_evaluate.

    When ``FLAGS.is_chief`` is set, the model directory is created and the
    hyperparameters read from ``FLAGS.hparams`` are written to ``hparams.json``
    inside it. The hyperparameters are then always loaded back from that file
    before the estimator and the train/eval specs are built.
    """
    hparams_path = os.path.join(FLAGS.model_dir, 'hparams.json')
    if FLAGS.is_chief:
        gfile.MakeDirs(FLAGS.model_dir)
        core.write_hparams(
            core.read_hparams(FLAGS.hparams, get_hparams()), hparams_path)

    # Always load HParams from model_dir.
    hparams = core.wait_for_hparams(hparams_path, get_hparams())
    grammar = grammar_utils.load_grammar(grammar_path=hparams.grammar_path)

    run_config = tf_estimator.RunConfig(
        save_checkpoints_secs=hparams.save_checkpoints_secs,
        keep_checkpoint_max=hparams.keep_checkpoint_max)
    estimator = tf_estimator.Estimator(
        model_fn=functools.partial(model_fn, grammar=grammar),
        params=hparams,
        config=run_config)

    train_input_fn = functools.partial(
        input_ops.input_fn,
        input_pattern=hparams.train_pattern,
        grammar=grammar)
    train_spec = tf_estimator.TrainSpec(
        input_fn=train_input_fn, max_steps=hparams.train_steps)

    # NOTE(leeley): The SavedModel will be stored under the
    # tf.saved_model.tag_constants.SERVING tag.
    serving_fn = functools.partial(
        input_ops.serving_input_receiver_fn,
        params=hparams,
        num_production_rules=grammar.num_production_rules)
    latest_exporter = tf_estimator.LatestExporter(
        name='latest_exported_model',
        serving_input_receiver_fn=serving_fn,
        exports_to_keep=hparams.exports_to_keep)

    # The generation hook is only attached when expressions are requested.
    eval_hooks = []
    if hparams.num_expressions_per_condition > 0:
        generation_hook = metrics.GenerationWithLeadingPowersHook(
            generation_leading_powers_abs_sums=core.hparams_list_value(
                hparams.generation_leading_powers_abs_sums),
            num_expressions_per_condition=(
                hparams.num_expressions_per_condition),
            max_length=hparams.max_length,
            grammar=grammar)
        eval_hooks.append(generation_hook)

    eval_input_fn = functools.partial(
        input_ops.input_fn,
        input_pattern=hparams.tune_pattern,
        grammar=grammar)
    eval_spec = tf_estimator.EvalSpec(
        input_fn=eval_input_fn,
        steps=hparams.eval_steps,
        exporters=latest_exporter,
        start_delay_secs=hparams.start_delay_secs,
        throttle_secs=hparams.throttle_secs,
        hooks=eval_hooks)

    tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)
def main(argv):
    """Train a model on Celeb-A, evaluating on the split labelled 1.

    Prepares ``FLAGS.output_dir`` (optionally wiping an existing one), reads
    the Celeb-A partition file, and then repeatedly calls
    ``train_and_evaluate`` ``FLAGS.max_steps // FLAGS.viz_steps`` times.

    Raises:
        ValueError: If ``FLAGS.output_dir`` or ``FLAGS.celeba_dir`` is None.
    """
    del argv  # Unused.

    if FLAGS.output_dir is None:
        raise ValueError("`output_dir` must be defined")
    wipe_requested = FLAGS.delete_existing
    if wipe_requested and tf.gfile.Exists(FLAGS.output_dir):
        tf.logging.warn(
            "Deleting old log directory at {}".format(FLAGS.output_dir))
        tf.gfile.DeleteRecursively(FLAGS.output_dir)
    tf.gfile.MakeDirs(FLAGS.output_dir)
    print("Logging to {}".format(FLAGS.output_dir))

    # Load the training or test split of the Celeb-A filenames.
    if FLAGS.celeba_dir is None:
        raise ValueError("`celeba_dir` must be defined")
    image_dir = os.path.join(FLAGS.celeba_dir, "Img/img_align_celeba/")
    partition_file = os.path.join(FLAGS.celeba_dir,
                                  "Eval/list_eval_partition.txt")
    with open(partition_file, "r") as handle:
        rows = handle.readlines()
    # Each row is "<filename> <split id>"; transpose into two sequences.
    names, split_ids = zip(*[row.split() for row in rows])
    filenames = np.array([os.path.join(image_dir, name) for name in names])
    splits = np.array([int(split_id) for split_id in split_ids])

    with tf.Graph().as_default():
        train_input_fn = prep_dataset_fn(filenames, splits, is_training=True)
        eval_input_fn = prep_dataset_fn(filenames, splits, is_training=False)

        run_config = tf_estimator.RunConfig(
            model_dir=FLAGS.output_dir,
            save_checkpoints_steps=FLAGS.viz_steps,
        )
        estimator = tf_estimator.Estimator(model_fn, config=run_config)

        train_spec = tf_estimator.TrainSpec(
            input_fn=train_input_fn, max_steps=FLAGS.max_steps)

        # Sad ugly hack here. Setting steps=None should go through all of the
        # validation set, but doesn't seem to, so I'm doing it manually.
        eval_batches = len(filenames[splits == 1]) // FLAGS.batch_size
        eval_spec = tf_estimator.EvalSpec(
            input_fn=eval_input_fn,
            steps=eval_batches,
            start_delay_secs=0,
            throttle_secs=0)

        rounds = FLAGS.max_steps // FLAGS.viz_steps
        for _ in range(rounds):
            tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)
def main(_):
    """Run distributed training and evaluation with a parameter-server strategy.

    Exports the cluster description to the ``TF_CONFIG`` environment variable,
    builds a session/run configuration sized from the local CPU count, and
    calls ``train_and_evaluate`` on the resulting estimator.

    Args:
        _: Unused positional argument.
    """
    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to a single CPU so the thread-count arithmetic below cannot fail.
    cpu = os.cpu_count() or 1

    tf_config = _tf_config(flags)  # (1)
    # Distributed training requires the TF_CONFIG environment variable.
    os.environ['TF_CONFIG'] = json.dumps(tf_config)  # (2)

    session_config = ConfigProto(device_count={'CPU': cpu},
                                 inter_op_parallelism_threads=cpu // 2,
                                 intra_op_parallelism_threads=cpu // 2,
                                 device_filters=flags.device_filters,
                                 allow_soft_placement=True)

    # The same strategy object is used for both training and evaluation.
    strategy = experimental.ParameterServerStrategy()
    run_config = estimator.RunConfig(
        **{
            'save_summary_steps': 100,
            'save_checkpoints_steps': 1000,
            'keep_checkpoint_max': 10,
            'log_step_count_steps': 100,
            'train_distribute': strategy,
            'eval_distribute': strategy,
        }).replace(session_config=session_config)

    model = estimator.Estimator(
        model_fn=model_fn,
        # In a real deployment this would be on a distributed file system.
        model_dir='/home/axing/din/checkpoints/din',
        config=run_config,
        params={
            'tf_config': tf_config,
            'decay_rate': 0.9,
            'decay_steps': 10000,
            'learning_rate': 0.1
        })

    train_spec = estimator.TrainSpec(
        input_fn=lambda: input_fn(mode='train',
                                  num_workers=flags.num_workers,
                                  worker_index=flags.worker_index,
                                  pattern='/home/axing/din/dataset/*'),  # (3)
        max_steps=1000  # (4)
    )

    # The validation set is assumed to live at the same path as the training
    # set here; in a real application the two would certainly differ.
    eval_spec = estimator.EvalSpec(
        input_fn=lambda: input_fn(mode='eval',
                                  pattern='/home/axing/din/dataset/*'),
        steps=100,  # Each evaluation runs over 100 batches.
        throttle_secs=60  # Wait at least 60 seconds between evaluations.
    )

    estimator.train_and_evaluate(model, train_spec, eval_spec)
def train(data_base_path, output_dir, label_vocab_path, hparams_set_name,
          train_fold, eval_fold):
    """Constructs, trains, and evaluates a model on the given input data.

    Args:
        data_base_path: str. Directory path containing tfrecords named like
            "train", "dev" and "test".
        output_dir: str. Path to save checkpoints.
        label_vocab_path: str. Path to tsv file containing columns
            _VOCAB_ITEM_COLUMN_NAME and _VOCAB_INDEX_COLUMN_NAME. See
            testdata/label_vocab.tsv for an example.
        hparams_set_name: name of a function in the hparams module which
            returns a tf.contrib.training.HParams object.
        train_fold: fold to use for training data (one of
            protein_dataset.DATA_FOLD_VALUES).
        eval_fold: fold to use for evaluation data (one of
            protein_dataset.DATA_FOLD_VALUES).

    Returns:
        The value returned by `train_and_evaluate`: a tuple of the evaluation
        metrics, and the exported objects from Estimator.
    """
    # Resolve the named hyperparameter set and the label vocabulary first;
    # both feed the estimator/input construction below.
    model_hparams = get_hparams(hparams_set_name)
    vocab = parse_label_vocab(label_vocab_path)

    estimator, train_spec, eval_spec = _make_estimator_and_inputs(
        hparams=model_hparams,
        label_vocab=vocab,
        data_base_path=data_base_path,
        output_dir=output_dir,
        train_fold=train_fold,
        eval_fold=eval_fold)

    outcome = tf_estimator.train_and_evaluate(
        estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
    return outcome
def main(unused_argv):
    """Train-and-evaluate, or evaluate only when a checkpoint dir is given.

    With ``FLAGS.checkpoint_dir`` unset, the estimator is trained and
    evaluated. With it set, evaluation is run either once on the latest
    checkpoint (``FLAGS.run_once``) or continuously.
    """
    flags.mark_flag_as_required('model_dir')
    flags.mark_flag_as_required('pipeline_config_path')

    run_config = tf_estimator.RunConfig(model_dir=FLAGS.model_dir)
    built = model_lib.create_estimator_and_inputs(
        run_config=run_config,
        pipeline_config_path=FLAGS.pipeline_config_path,
        train_steps=FLAGS.num_train_steps,
        sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
        sample_1_of_n_eval_on_train_examples=(
            FLAGS.sample_1_of_n_eval_on_train_examples))

    estimator = built['estimator']
    train_input_fn = built['train_input_fn']
    eval_input_fns = built['eval_input_fns']
    eval_on_train_input_fn = built['eval_on_train_input_fn']
    predict_input_fn = built['predict_input_fn']
    train_steps = built['train_steps']

    # Training path: no checkpoint directory was supplied.
    if not FLAGS.checkpoint_dir:
        train_spec, eval_specs = model_lib.create_train_and_eval_specs(
            train_input_fn,
            eval_input_fns,
            eval_on_train_input_fn,
            predict_input_fn,
            train_steps,
            eval_on_train_data=False)
        # Currently only a single Eval Spec is allowed.
        tf_estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
        return

    # Evaluation-only path: pick which data to evaluate on.
    if FLAGS.eval_training_data:
        name, input_fn = 'training_data', eval_on_train_input_fn
    else:
        # The first eval input will be evaluated.
        name, input_fn = 'validation_data', eval_input_fns[0]

    if FLAGS.run_once:
        latest = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        estimator.evaluate(input_fn, steps=None, checkpoint_path=latest)
    else:
        model_lib.continuous_eval(estimator, FLAGS.checkpoint_dir, input_fn,
                                  train_steps, name, FLAGS.max_eval_retries)
def main(_):
    """Build the estimator and input pipelines from flags, then train and evaluate.

    Args:
        _: Unused positional argument.
    """
    inference_fn = network.inference
    hparams = contrib_training.HParams(learning_rate=FLAGS.learning_rate)
    model_fn = estimator.create_model_fn(inference_fn, hparams)
    config = tf_estimator.RunConfig(FLAGS.model_dir)

    # The estimator instance must NOT be bound to the name `tf_estimator`:
    # assigning to that name anywhere in this function would make it a local
    # variable, so the `tf_estimator.RunConfig` lookup above would raise
    # UnboundLocalError instead of resolving the imported module.
    model_estimator = tf_estimator.Estimator(model_fn=model_fn, config=config)

    # Train and eval datasets share image size and batch size; only the file
    # pattern differs.
    train_dataset_fn = dataset.create_dataset_fn(FLAGS.train_pattern,
                                                 height=FLAGS.image_size,
                                                 width=FLAGS.image_size,
                                                 batch_size=FLAGS.batch_size)
    eval_dataset_fn = dataset.create_dataset_fn(FLAGS.test_pattern,
                                                height=FLAGS.image_size,
                                                width=FLAGS.image_size,
                                                batch_size=FLAGS.batch_size)

    train_spec, eval_spec = estimator.create_train_and_eval_specs(
        train_dataset_fn, eval_dataset_fn)

    tf.logging.set_verbosity(tf.logging.INFO)
    tf_estimator.train_and_evaluate(model_estimator, train_spec, eval_spec)
def continuous_train_and_eval(self, continuous_eval_predicate_fn=None):
    """Run train_and_evaluate on the stored specs, then return a final evaluation.

    Args:
        continuous_eval_predicate_fn: Unused; discarded immediately.

    Returns:
        Whatever `self.evaluate()` returns after training has finished.
    """
    del continuous_eval_predicate_fn  # Unused.
    tf_estimator.train_and_evaluate(
        estimator=self._estimator,
        train_spec=self._train_spec,
        eval_spec=self._eval_spec,
    )
    final_metrics = self.evaluate()
    return final_metrics
def run_model():
    """Run experiment with tf.estimator.

    Builds the data loader and estimator from flags, then dispatches on
    ``FLAGS.mode``: 'train' runs train_and_evaluate, 'eval' prints the
    evaluation result, and 'pred' feeds predictions to a metrics object.

    Raises:
        NotImplementedError: In 'pred' mode when ``FLAGS.name`` does not start
            with 'query2box'.
        ValueError: If ``FLAGS.mode`` is not one of 'train', 'eval', 'pred'.
    """
    params = {
        'kb_index': FLAGS.kb_index,
        'cm_width': FLAGS.cm_width,
        'cm_depth': FLAGS.cm_depth,
        'entity_emb_size': FLAGS.entity_emb_size,
        'relation_emb_size': FLAGS.relation_emb_size,
        'vocab_emb_size': FLAGS.vocab_emb_size,
        'max_set': FLAGS.max_set,
        'learning_rate': FLAGS.learning_rate,
        'gradient_clip': FLAGS.gradient_clip,
        'intermediate_top_k': FLAGS.intermediate_top_k,
        'use_cm_sketch': FLAGS.use_cm_sketch,
        'train_entity_emb': FLAGS.train_entity_emb,
        'train_relation_emb': FLAGS.train_relation_emb,
        'bert_handle': FLAGS.bert_handle,
        'train_bert': FLAGS.train_bert,
    }
    data_loader = DataLoader(params, FLAGS.name, get_root_dir(FLAGS.name),
                             FLAGS.kb_file, FLAGS.vocab_file)

    run_config = tf_estimator.RunConfig(
        save_checkpoints_steps=FLAGS.checkpoint_step)

    # Warm-start only the embedding matrices, and only when a source model
    # directory has been supplied.
    if FLAGS.load_model_dir is not None:
        warm_start = tf_estimator.WarmStartSettings(
            ckpt_to_initialize_from=FLAGS.load_model_dir,
            vars_to_warm_start=[
                'embeddings_mat/entity_embeddings_mat',
                'embeddings_mat/relation_embeddings_mat'
            ],
        )
    else:
        warm_start = None

    model_fn = build_model_fn(FLAGS.name, data_loader, FLAGS.eval_name,
                              FLAGS.eval_metric_at_k)
    estimator = tf_estimator.Estimator(
        model_fn=model_fn,
        model_dir=FLAGS.checkpoint_dir + FLAGS.model_name,
        config=run_config,
        params=params,
        warm_start_from=warm_start)

    # The training input is only needed in 'train' mode; the eval input is
    # used by every mode.
    if FLAGS.mode == 'train':
        train_input_fn = data_loader.build_input_fn(
            name=FLAGS.name,
            batch_size=FLAGS.batch_size,
            mode='train',
            epochs=FLAGS.epochs,
            n_take=-1,
            shuffle=True)
    eval_input_fn = data_loader.build_input_fn(
        name=FLAGS.name,
        batch_size=FLAGS.batch_size,
        mode='eval',
        epochs=1,
        n_take=FLAGS.num_eval,
        shuffle=False)

    # Define mode-specific operations
    if FLAGS.mode == 'train':
        train_spec = tf_estimator.TrainSpec(input_fn=train_input_fn)
        # Busy waiting for evaluation until new checkpoint comes out
        online_eval_spec = tf_estimator.EvalSpec(
            input_fn=eval_input_fn,
            steps=FLAGS.num_online_eval,
            start_delay_secs=0,
            throttle_secs=FLAGS.eval_time)
        tf_estimator.train_and_evaluate(estimator, train_spec,
                                        online_eval_spec)
    elif FLAGS.mode == 'eval':
        print(estimator.evaluate(eval_input_fn))
    elif FLAGS.mode == 'pred':
        predictions = estimator.predict(eval_input_fn)
        if not FLAGS.name.startswith('query2box'):
            raise NotImplementedError()
        # The task is the last underscore-separated component of the name.
        task = FLAGS.name.split('_')[-1]
        metrics = Query2BoxMetrics(task, FLAGS.root_dir, data_loader)
        for prediction in tqdm(predictions):
            metrics.eval(prediction)
        metrics.print_metrics()
    else:
        raise ValueError('mode not recognized: %s' % FLAGS.mode)
def main(_):
    """Train (optionally) a joint extraction model and write predictions.

    Steps, in order:
      1. Extend ``FLAGS.result_path`` (and, for cross-vertical runs,
         ``FLAGS.checkpoint_path``) with sub-directories derived from flags.
      2. Assemble the hyper-parameter dictionary and dump it to
         ``params.json`` in the results directory.
      3. Build the estimator, warm-starting from a checkpoint when
         ``FLAGS.checkpoint_path`` is set.
      4. When ``FLAGS.run == "train"``, run train_and_evaluate with an
         early-stopping hook on the "f1" metric.
      5. For the source website and every target website, write predictions
         and compute the page-level / site-level metrics.

    Args:
        _: Unused positional argument.
    """
    # Modify the paths to save results when tuning hyperparameters.
    if FLAGS.node_encoder == "lstm":
        FLAGS.result_path = os.path.join(FLAGS.result_path,
                                         str(FLAGS.node_lstm_size))
    if FLAGS.node_encoder == "transformer":
        FLAGS.result_path = os.path.join(
            FLAGS.result_path,
            "max_steps_" + str(FLAGS.max_steps_no_increase))
        FLAGS.result_path = os.path.join(
            FLAGS.result_path,
            "hidden_unit_" + str(FLAGS.transformer_hidden_unit))
    # Cross-vertical runs get their own results sub-directory and point the
    # checkpoint path at the results of the checkpoint vertical/websites.
    if FLAGS.cross_vertical:
        FLAGS.result_path = os.path.join(
            FLAGS.result_path,
            "CKP-{0}/{1}/".format(FLAGS.checkpoint_vertical,
                                  FLAGS.checkpoint_websites))
        FLAGS.checkpoint_path = os.path.join(
            FLAGS.checkpoint_path,
            "{0}/{1}-results/".format(FLAGS.checkpoint_vertical,
                                      FLAGS.checkpoint_websites))
    # Make sure the per-(vertical, source website) results directory exists
    # before anything is written into it.
    tf.gfile.MakeDirs(
        os.path.join(
            FLAGS.result_path,
            "{0}/{1}-results/".format(FLAGS.vertical, FLAGS.source_website)))
    tf.logging.set_verbosity(tf.logging.INFO)
    # Selects which vocabulary/embedding file prefix is used below: the shared
    # "all" files or the ones specific to the current vertical.
    if FLAGS.use_uniform_embedding:
        vocab_vertical = "all"
    else:
        vocab_vertical = FLAGS.vertical
    # Hyper-parameters.
    params = {
        "add_goldmine":
            FLAGS.add_goldmine,
        "add_leaf_types":
            FLAGS.add_leaf_types,
        "batch_size":
            FLAGS.batch_size,
        "buffer":
            1000,  # Buffer for shuffling. No need to care about.
        "chars":
            os.path.join(FLAGS.domtree_data_path,
                         "%s.vocab.chars.txt" % vocab_vertical),
        "circle_features":
            FLAGS.circle_features,
        "dim_word_embedding":
            FLAGS.dim_word_embedding,
        "dim_chars":
            FLAGS.dim_chars,
        "dim_label_embedding":
            FLAGS.dim_label_embedding,
        "dim_goldmine":
            30,
        "dim_leaf_type":
            20,
        "dim_positions":
            30,
        "dim_xpath_units":
            FLAGS.dim_xpath_units,
        "dropout":
            0.3,
        "epochs":
            FLAGS.epochs,
        "extract_node_emb":
            FLAGS.extract_node_emb,
        "filters":
            50,  # The dimension of char-level word representations.
        "friend_encoder":
            FLAGS.friend_encoder,
        "use_friend_semantic":
            FLAGS.use_friend_semantic,
        "goldmine_features":
            os.path.join(FLAGS.domtree_data_path,
                         "vocab.goldmine_features.txt"),
        "glove":
            os.path.join(
                FLAGS.domtree_data_path,
                "%s.%d.emb.npz" % (vocab_vertical, FLAGS.dim_word_embedding)),
        "friend_hidden_size":
            FLAGS.friend_hidden_size,
        "kernel_size":
            3,  # CNN window size to embed char sequences.
        "last_hidden_layer_size":
            FLAGS.last_hidden_layer_size,
        "leaf_types":
            os.path.join(FLAGS.domtree_data_path,
                         "%s.vocab.leaf_types.txt" % vocab_vertical),
        "lstm_size":
            100,
        "max_steps_no_increase":
            FLAGS.max_steps_no_increase,
        "node_encoder":
            FLAGS.node_encoder,
        "node_filters":
            100,
        "node_kernel_size":
            5,
        "node_lstm_size":
            FLAGS.node_lstm_size,
        "num_oov_buckets":
            1,
        "objective":
            FLAGS.objective,
        "positions":
            os.path.join(FLAGS.domtree_data_path, "vocab.positions.txt"),
        "running_mode":
            FLAGS.run,
        "semantic_encoder":
            FLAGS.semantic_encoder,
        "source_website":
            FLAGS.source_website,
        "tags":
            os.path.join(FLAGS.domtree_data_path,
                         "%s.vocab.tags.txt" % (FLAGS.vertical)),
        "tags-all":
            os.path.join(FLAGS.domtree_data_path, "all.vocab.tags.txt"),
        "target_website":
            FLAGS.target_website,
        "transformer_hidden_unit":
            FLAGS.transformer_hidden_unit,
        "transformer_head":
            FLAGS.transformer_head,
        "transformer_hidden_layer":
            FLAGS.transformer_hidden_layer,
        "use_crf":
            FLAGS.use_crf,
        "use_friends_cnn":
            FLAGS.use_friends_cnn,
        "use_friends_discrete_feature":
            FLAGS.use_friends_discrete_feature,
        "use_prev_text_lstm":
            FLAGS.use_prev_text_lstm,
        "use_xpath_lstm":
            FLAGS.use_xpath_lstm,
        "use_uniform_label":
            FLAGS.use_uniform_label,
        "use_position_embedding":
            FLAGS.use_position_embedding,
        "words":
            os.path.join(FLAGS.domtree_data_path,
                         "%s.vocab.words.txt" % vocab_vertical),
        "xpath_lstm_size":
            100,
        "xpath_units":
            os.path.join(FLAGS.domtree_data_path,
                         "%s.vocab.xpath_units.txt" % vocab_vertical),
    }
    # Persist the exact hyper-parameters next to the results.
    with tf.gfile.Open(
            os.path.join(
                FLAGS.result_path,
                "{0}/{1}-results/params.json".format(FLAGS.vertical,
                                                     FLAGS.source_website)),
            "w") as f:
        json.dump(params, f, indent=4, sort_keys=True)
    # Build estimator, train and evaluate.
    # The two data paths differ only in the `goldmine` switch of
    # get_data_path; both use the non-dev data of the source website.
    train_input_function = functools.partial(
        model_util.joint_input_fn,
        get_data_path(vertical=FLAGS.vertical,
                      website=FLAGS.source_website,
                      dev=False,
                      goldmine=False),
        get_data_path(vertical=FLAGS.vertical,
                      website=FLAGS.source_website,
                      dev=False,
                      goldmine=True),
        FLAGS.vertical,
        params,
        shuffle_and_repeat=True,
        mode="train")
    cfg = tf_estimator.RunConfig(save_checkpoints_steps=300,
                                 save_summary_steps=300,
                                 tf_random_seed=42)
    # Set up the checkpoint to load.
    if FLAGS.checkpoint_path:
        # The best model was always saved in "ckpt-601".
        checkpoint_file = FLAGS.checkpoint_path + "/model/model.ckpt-601"
        # Do not load parameters whose names contain "label_dense".
        # These parameters ought to be learned from scratch.
        ws = tf_estimator.WarmStartSettings(
            ckpt_to_initialize_from=checkpoint_file,
            vars_to_warm_start="^((?!label_dense).)*$")
        estimator = tf_estimator.Estimator(models.joint_extraction_model_fn,
                                           os.path.join(
                                               FLAGS.result_path,
                                               "{0}/{1}-results/model".format(
                                                   FLAGS.vertical,
                                                   FLAGS.source_website)),
                                           cfg,
                                           params,
                                           warm_start_from=ws)
    else:
        # No checkpoint given: same estimator, without warm-starting.
        estimator = tf_estimator.Estimator(
            models.joint_extraction_model_fn,
            os.path.join(
                FLAGS.result_path,
                "{0}/{1}-results/model".format(FLAGS.vertical,
                                               FLAGS.source_website)), cfg,
            params)
    # The eval directory is created up front, before the early-stopping hook
    # is built. NOTE(review): presumably the hook reads eval results from
    # this directory -- confirm against the early_stopping module.
    tf.gfile.MakeDirs(estimator.eval_dir())
    hook = early_stopping.stop_if_no_increase_hook(
        estimator,
        metric_name="f1",
        max_steps_without_increase=FLAGS.max_steps_no_increase,
        min_steps=300,
        run_every_steps=100,
        run_every_secs=None)
    train_spec = tf_estimator.TrainSpec(input_fn=train_input_function,
                                        hooks=[hook])
    if FLAGS.run == "train":
        # Evaluation during training uses the dev data of the source website.
        eval_input_function = functools.partial(
            model_util.joint_input_fn,
            get_data_path(vertical=FLAGS.vertical,
                          website=FLAGS.source_website,
                          dev=True,
                          goldmine=False),
            get_data_path(vertical=FLAGS.vertical,
                          website=FLAGS.source_website,
                          dev=True,
                          goldmine=True),
            FLAGS.vertical,
            mode="all")
        eval_spec = tf_estimator.EvalSpec(input_fn=eval_input_function,
                                          steps=300,
                                          throttle_secs=1)
        tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    # `FLAGS.target_website` is an underscore-separated list; the source
    # website is prepended when it is not already among the targets.
    target_websites = FLAGS.target_website.split("_")
    if FLAGS.source_website not in target_websites:
        target_websites = [FLAGS.source_website] + target_websites
    # For each target: write predictions first, then run the three
    # metric/post-processing helpers on the result path.
    for target_website in target_websites:
        write_predictions(estimator=estimator,
                          vertical=FLAGS.vertical,
                          source_website=FLAGS.source_website,
                          target_website=target_website)
        model_util.page_hits_level_metric(result_path=FLAGS.result_path,
                                          vertical=FLAGS.vertical,
                                          source_website=FLAGS.source_website,
                                          target_website=target_website)
        model_util.site_level_voting(result_path=FLAGS.result_path,
                                     vertical=FLAGS.vertical,
                                     source_website=FLAGS.source_website,
                                     target_website=target_website)
        model_util.page_level_constraint(
            domtree_data_path=FLAGS.domtree_data_path,
            result_path=FLAGS.result_path,
            vertical=FLAGS.vertical,
            source_website=FLAGS.source_website,
            target_website=target_website)
def run_experiment(model_fn,
                   train_input_fn,
                   eval_input_fn,
                   exporters=None,
                   params=None,
                   params_fname=None):
    """Run an experiment using estimators.

    This is a light wrapper around typical estimator usage to avoid boilerplate
    code. Please use the following components separately for more complex
    usages.

    The `params` dictionary passed in is never modified: the automatic
    'use_tpu' / 'batch_size' adjustments are applied to a shallow copy.

    Args:
        model_fn: A model function to be passed to the estimator. See
            https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#args_1
        train_input_fn: An input function to be passed to the estimator that
            corresponds to the training data. See
            https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#train
        eval_input_fn: An input function to be passed to the estimator that
            corresponds to the held-out eval data. See
            https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#evaluate
        exporters: (Optional) An tf.estimator.Exporter or a list of them.
        params: (Optional) A dictionary of parameters that will be accessible
            by the model_fn and input_fns. The 'batch_size' and 'use_tpu'
            values will be set automatically.
        params_fname: (Optional) If specified, `params` will be written to
            here under `FLAGS.model_dir` in JSON format.
    """
    # Shallow-copy so that setdefault/pop/assignment below do not leak into
    # the caller's dictionary (which may be reused across experiments).
    params = dict(params) if params is not None else {}
    params.setdefault("use_tpu", FLAGS.use_tpu)

    # The parameters are dumped before 'batch_size' is adjusted, i.e. as
    # supplied by the caller plus the 'use_tpu' default.
    if FLAGS.model_dir and params_fname:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        params_path = os.path.join(FLAGS.model_dir, params_fname)
        with tf.io.gfile.GFile(params_path, "w") as params_file:
            json.dump(params, params_file, indent=2, sort_keys=True)

    if params["use_tpu"]:
        if FLAGS.tpu_name:
            tpu_cluster_resolver = (
                tf.distribute.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
        else:
            tpu_cluster_resolver = None
        run_config = tf_estimator.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            master=FLAGS.master,
            model_dir=FLAGS.model_dir,
            tf_random_seed=FLAGS.tf_random_seed,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps,
            tpu_config=tf_estimator.tpu.TPUConfig(
                iterations_per_loop=FLAGS.save_checkpoints_steps))
        # Let the TPUEstimator fill in the batch size.
        params.pop("batch_size", None)
        estimator = tf_estimator.tpu.TPUEstimator(
            use_tpu=True,
            model_fn=model_fn,
            params=params,
            config=run_config,
            train_batch_size=FLAGS.batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            predict_batch_size=FLAGS.eval_batch_size)
    else:
        run_config = tf_estimator.RunConfig(
            model_dir=FLAGS.model_dir,
            tf_random_seed=FLAGS.tf_random_seed,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps,
            keep_checkpoint_max=FLAGS.keep_checkpoint_max)
        # Off TPU the batch size is handed to model_fn/input_fn via params.
        params["batch_size"] = FLAGS.batch_size
        estimator = tf_estimator.Estimator(
            config=run_config,
            model_fn=model_fn,
            params=params,
            model_dir=FLAGS.model_dir)

    train_spec = tf_estimator.TrainSpec(
        input_fn=train_input_fn,
        max_steps=FLAGS.num_train_steps)
    eval_spec = tf_estimator.EvalSpec(
        name="default",
        input_fn=eval_input_fn,
        exporters=exporters,
        start_delay_secs=FLAGS.eval_start_delay_secs,
        throttle_secs=FLAGS.eval_throttle_secs,
        steps=FLAGS.num_eval_steps)

    tf.logging.set_verbosity(tf.logging.INFO)
    tf_estimator.train_and_evaluate(
        estimator=estimator,
        train_spec=train_spec,
        eval_spec=eval_spec)
def run_model():
    """Instantiate and run model.

    Loads the dataset selected by ``FLAGS.dataset``, builds the estimator with
    additional fairness metrics, runs train_and_evaluate, and finally writes
    the result of one more evaluation pass to ``FLAGS.output_file_name``
    inside the model directory.

    Raises:
        ValueError: if model_name is not implemented.
        ValueError: if dataset is not implemented.
    """
    if FLAGS.model_name not in MODEL_KEYS:
        raise ValueError("Model {} is not implemented.".format(
            FLAGS.model_name))

    model_dir, model_name, print_dir = _initialize_model_dir()
    tf.logging.info(
        "Creating experiment, storing model files in {}".format(model_dir))

    # Instantiates dataset and gets input_fn
    if FLAGS.dataset == "law_school":
        load_dataset = LawSchoolInput(dataset_base_dir=FLAGS.dataset_base_dir,
                                      train_file=FLAGS.train_file,
                                      test_file=FLAGS.test_file)
    elif FLAGS.dataset == "compas":
        load_dataset = CompasInput(dataset_base_dir=FLAGS.dataset_base_dir,
                                   train_file=FLAGS.train_file,
                                   test_file=FLAGS.test_file)
    elif FLAGS.dataset == "uci_adult":
        load_dataset = UCIAdultInput(dataset_base_dir=FLAGS.dataset_base_dir,
                                     train_file=FLAGS.train_file,
                                     test_file=FLAGS.test_file)
    else:
        raise ValueError("Input_fn for {} dataset is not implemented.".format(
            FLAGS.dataset))

    train_input_fn = load_dataset.get_input_fn(
        mode=tf_estimator.ModeKeys.TRAIN, batch_size=FLAGS.batch_size)
    test_input_fn = load_dataset.get_input_fn(
        mode=tf_estimator.ModeKeys.EVAL, batch_size=FLAGS.batch_size)

    feature_columns, _, protected_groups, label_column_name = (
        load_dataset.get_feature_columns(
            embedding_dimension=FLAGS.embedding_dimension,
            include_sensitive_columns=FLAGS.include_sensitive_columns))

    # Constructs an int list enumerating the subgroups in the dataset.
    #
    # For example, if the dataset has two (binary) protected_groups, the
    # dataset has 2^2 = 4 subgroups, which we enumerate as [0, 1, 2, 3].
    #
    # If the dataset has two protected features ["race", "sex"] that are cast
    # as binary features race=["White"(0), "Black"(1)] and
    # sex=["Male"(0), "Female"(1)], we call their cartesian product
    # ["White Male" (00), "White Female" (01), "Black Male" (10),
    # "Black Female" (11)] the subgroups, enumerated as [0, 1, 2, 3].
    #
    # Assumes each protected_group has two possible values, so k groups give
    # 2**k subgroups. (The previous `k * 2` only matched this for k <= 2.)
    subgroups = np.arange(2 ** len(protected_groups))

    # Instantiates tf.estimator.Estimator object
    estimator = get_estimator(model_dir,
                              model_name,
                              feature_columns=feature_columns,
                              label_column_name=label_column_name)

    # Adds additional fairness metrics
    fairness_metrics = RobustFairnessMetrics(
        label_column_name=label_column_name,
        protected_groups=protected_groups,
        subgroups=subgroups,
        print_dir=print_dir)
    eval_metrics_fn = fairness_metrics.create_fairness_metrics_fn()
    estimator = tf_estimator.add_metrics(estimator, eval_metrics_fn)

    # Creates training and evaluation specifications
    train_steps = int(FLAGS.total_train_steps / FLAGS.batch_size)
    train_spec = tf_estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=train_steps)
    eval_spec = tf_estimator.EvalSpec(input_fn=test_input_fn,
                                      steps=FLAGS.test_steps)

    tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    tf.logging.info("Training completed.")

    # A final evaluation pass whose metrics are persisted to disk.
    eval_results = estimator.evaluate(input_fn=test_input_fn,
                                      steps=FLAGS.test_steps)
    eval_results_path = os.path.join(model_dir, FLAGS.output_file_name)
    write_to_output_file(eval_results, eval_results_path)