def main(_):
  params = dict(
      data_root=FLAGS.data_root,
      batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      query_seq_len=FLAGS.query_seq_len,
      block_seq_len=FLAGS.block_seq_len,
      learning_rate=FLAGS.learning_rate,
      num_classes=FLAGS.num_classes,
      num_train_steps=FLAGS.num_train_steps,
      retriever_module_path=FLAGS.retriever_module_path,
      reader_module_path=FLAGS.reader_module_path,
      retriever_beam_size=FLAGS.retriever_beam_size,
      reader_beam_size=FLAGS.reader_beam_size,
      reader_seq_len=FLAGS.reader_seq_len,
      span_hidden_size=FLAGS.span_hidden_size,
      max_span_width=FLAGS.max_span_width,
      block_records_path=FLAGS.block_records_path,
      num_block_records=FLAGS.num_block_records)
  train_input_fn = functools.partial(
      text_classifier_model.input_fn, name=FLAGS.dataset_name, is_train=True)
  eval_input_fn = functools.partial(
      text_classifier_model.input_fn, name=FLAGS.dataset_name, is_train=False)
  experiment_utils.run_experiment(
      model_fn=text_classifier_model.model_fn,
      params=params,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      exporters=text_classifier_model.exporter(),
      params_fname="params.json")
def test_run_experiment_tpu(self):
  params = dict(use_tpu=True)
  experiment_utils.run_experiment(
      model_fn=self._simple_model_fn,
      train_input_fn=self._simple_input_function,
      eval_input_fn=self._simple_input_function,
      params=params)
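# A minimal sketch of the fixtures the tests here assume. The real
# _simple_model_fn/_simple_input_function are methods on the test case; the
# bodies below are illustrative guesses, not the repo's actual code.
def _simple_input_function(params):
  del params  # batch size is hard-coded in this toy input pipeline
  dataset = tf.data.Dataset.from_tensor_slices(
      ({"x": [1.0, 2.0, 3.0, 4.0]}, [2.0, 4.0, 6.0, 8.0]))
  return dataset.repeat().batch(2, drop_remainder=True)


def _simple_model_fn(features, labels, mode, params):
  del params  # unused in this toy model
  w = tf.get_variable("w", shape=[], dtype=tf.float32)
  predictions = features["x"] * w
  loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(
      mode=mode, loss=loss, predictions=predictions, train_op=train_op)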
def train(): """Train the model.""" embeddings = load_embeddings() # Need a named parameter `param` since this will be called # with named arguments, so pylint: disable=unused-argument def model_function(features, labels, mode, params): """Builds the `tf.estimator.EstimatorSpec` to train/eval with.""" is_train = mode == tf.estimator.ModeKeys.TRAIN logits = predict(is_train, embeddings, features["premise"], features["hypothesis"]) loss = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=tf.to_int32(labels), logits=logits) loss = tf.reduce_mean(loss) if mode == tf.estimator.ModeKeys.TRAIN: train_op = get_train_op(loss) else: # Don't build the train_op unnecessarily, since the ADAM variables can # cause problems with loading checkpoints on CPUs. train_op = None metrics = dict( accuracy=tf.metrics.accuracy( tf.argmax(logits, 1, output_type=tf.int32), tf.to_int32(labels))) checkpoint_file = FLAGS.checkpoint_file if checkpoint_file is None: scaffold = None else: saver = tf.train.Saver(tf.trainable_variables()) def _init_fn(_, sess): saver.restore(sess, checkpoint_file) scaffold = tf.train.Scaffold(init_fn=_init_fn) return tf.estimator.EstimatorSpec( mode=mode, scaffold=scaffold, loss=loss, predictions=None, train_op=train_op, eval_metric_ops=metrics) def compare_fn(best_eval_result, current_eval_result): return best_eval_result["accuracy"] < current_eval_result["accuracy"] exporter = best_checkpoint_exporter.BestCheckpointExporter( event_file_pattern="eval_default/*.tfevents.*", compare_fn=compare_fn, ) experiment_utils.run_experiment( model_fn=model_function, train_input_fn=lambda: load_batched_dataset(True, embeddings), eval_input_fn=lambda: load_batched_dataset(False, embeddings), exporters=[exporter])
def main(_):
  model_function, train_input_fn, eval_input_fn, serving_input_receiver_fn = (
      nq_short_pipeline_model.experiment_functions())
  best_exporter = tf_estimator.BestExporter(
      name="best",
      serving_input_receiver_fn=serving_input_receiver_fn,
      event_file_pattern="eval_default/*.tfevents.*",
      compare_fn=nq_short_pipeline_model.compare_metrics)
  experiment_utils.run_experiment(
      model_fn=model_function,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      exporters=[best_exporter])
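# A minimal sketch of the kind of serving_input_receiver_fn that BestExporter
# expects. The "question" feature name is hypothetical; the real function
# comes from nq_short_pipeline_model.experiment_functions() above.
def serving_input_receiver_fn():
  question = tf.placeholder(dtype=tf.string, shape=[None], name="question")
  return tf.estimator.export.ServingInputReceiver(
      features={"question": question},
      receiver_tensors={"question": question})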
def main(_):
  params = dict(
      batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      bert_hub_module_path=FLAGS.bert_hub_module_path,
      query_seq_len=FLAGS.query_seq_len,
      block_seq_len=FLAGS.block_seq_len,
      projection_size=FLAGS.projection_size,
      learning_rate=FLAGS.learning_rate,
      examples_path=FLAGS.examples_path,
      mask_rate=FLAGS.mask_rate,
      num_train_steps=FLAGS.num_train_steps,
      num_block_records=FLAGS.num_block_records,
      num_input_threads=FLAGS.num_input_threads)
  experiment_utils.run_experiment(
      model_fn=ict_model.model_fn,
      train_input_fn=functools.partial(ict_model.input_fn, is_train=True),
      eval_input_fn=functools.partial(ict_model.input_fn, is_train=False),
      exporters=ict_model.exporter(),
      params=params)
def main(_):
  params = dict(
      batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      bert_hub_module_handle=FLAGS.bert_hub_module_handle,
      embedder_hub_module_handle=FLAGS.embedder_hub_module_handle,
      vocab_path=FLAGS.vocab_path,
      do_lower_case=FLAGS.do_lower_case,
      query_seq_len=FLAGS.query_seq_len,
      candidate_seq_len=FLAGS.candidate_seq_len,
      max_masks=FLAGS.max_masks,
      learning_rate=FLAGS.learning_rate,
      num_input_threads=FLAGS.num_input_threads,
      num_candidates=FLAGS.num_candidates,
      num_train_steps=FLAGS.num_train_steps,
      train_preprocessing_servers=FLAGS.train_preprocessing_servers,
      eval_preprocessing_servers=FLAGS.eval_preprocessing_servers,
      share_embedders=FLAGS.share_embedders,
      separate_candidate_segments=FLAGS.separate_candidate_segments)
  experiment_utils.run_experiment(
      model_fn=model.model_fn,
      train_input_fn=functools.partial(model.input_fn, is_train=True),
      eval_input_fn=functools.partial(model.input_fn, is_train=False),
      params=params,
      params_fname="estimator_params.json",
      exporters=model.get_exporters(params))

  # Write a "done" file from the trainer. As in experiment_utils, we currently
  # use 'use_tpu' as a proxy for whether this is a train or eval node.
  #
  # We could also use the 'type' field in the 'task' of the TF_CONFIG
  # environment variable, but we would generally like to get away from
  # TF_CONFIG in the future.
  #
  # This file is checked for existence by refresh_doc_embeds.
  if experiment_utils.FLAGS.use_tpu:
    model_dir = experiment_utils.EstimatorSettings.from_flags().model_dir
    training_done_filename = os.path.join(model_dir, "TRAINING_DONE")
    with tf.gfile.GFile(training_done_filename, "w") as f:
      f.write("done")
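# Hedged sketch of the consumer side: refresh_doc_embeds (referenced in the
# comment above) presumably blocks on the TRAINING_DONE marker. The helper
# below is an assumption for illustration, not the repo's actual code.
import time


def wait_for_training_done(model_dir, poll_secs=60):
  done_path = os.path.join(model_dir, "TRAINING_DONE")
  while not tf.gfile.Exists(done_path):
    time.sleep(poll_secs)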
def test_run_experiment(self):
  experiment_utils.run_experiment(
      model_fn=self._simple_model_fn,
      train_input_fn=self._simple_input_function,
      eval_input_fn=self._simple_input_function)
def main(argv):
  del argv  # unused
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.config.set_soft_device_placement(True)

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  train_input_fn = None
  ft_known_train_file = None
  train_file = None

  if FLAGS.do_train:
    current_seed = 0
    num_known_classes = FLAGS.num_domains * FLAGS.num_labels_per_domain
    data_output_dir = FLAGS.data_output_dir
    if not tf.gfile.Exists(data_output_dir):
      tf.gfile.MakeDirs(data_output_dir)
    known_ft_path = os.path.join(data_output_dir, "known_ft_train.tf_record")
    unknown_ft_path = os.path.join(data_output_dir,
                                   "unknown_ft_train.tf_record")
    if not tf.gfile.Glob(known_ft_path):
      preprocess_few_shot_training_data(tokenizer, known_ft_path,
                                        unknown_ft_path, current_seed)

    if FLAGS.continual_learning is None:
      raise NotImplementedError("continual_learning must be specified.")
    elif FLAGS.continual_learning == "pretrain":
      train_file = os.path.join(FLAGS.data_output_dir,
                                "known_ft_train.tf_record")
      num_classes = num_known_classes
      num_train_examples = num_known_classes * FLAGS.known_num_shots
      num_shots_per_class = FLAGS.known_num_shots
    elif FLAGS.continual_learning == "few_shot":
      train_file = os.path.join(FLAGS.data_output_dir,
                                "unknown_ft_train.tf_record")
      ft_known_train_file = os.path.join(FLAGS.data_output_dir,
                                         "known_ft_train.tf_record")
      num_unknown_classes = NUM_CLASSES - num_known_classes
      num_classes = num_unknown_classes
      num_train_examples = num_unknown_classes * FLAGS.few_shot
      num_shots_per_class = FLAGS.few_shot

    tpu_split = FLAGS.tpu_split if FLAGS.use_tpu else 1
    if num_shots_per_class < tpu_split:
      steps_per_epoch = 1
    else:
      steps_per_epoch = num_shots_per_class // tpu_split
    num_train_steps = int(steps_per_epoch * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
    FLAGS.num_train_steps = num_train_steps
    FLAGS.save_checkpoints_steps = int(steps_per_epoch *
                                       FLAGS.save_every_epoch)

    tf.logging.info("***** Running training *****")
    tf.logging.info(" train_file: %s", train_file)
    tf.logging.info(" use_tpu: %s", FLAGS.use_tpu)
    tf.logging.info(" Num examples = %d", num_train_examples)
    tf.logging.info(" Batch size = %d", FLAGS.batch_size)
    tf.logging.info(" Save checkpoints steps = %d",
                    FLAGS.save_checkpoints_steps)
    tf.logging.info(" warmup steps = %d", num_warmup_steps)
    tf.logging.info(" Num epochs = %d", FLAGS.num_train_epochs)
    tf.logging.info(" Num steps = %d", num_train_steps)
    tf.logging.info(" Reduce method = %s", FLAGS.reduce_method)
    tf.logging.info(" Max Seq Length = %d", FLAGS.max_seq_length)
    tf.logging.info(" learning_rate = %.7f", FLAGS.learning_rate)
    tf.logging.info(" dropout rate = %.4f", DROPOUT_PROB)

    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        ft_known_train_file=ft_known_train_file,
        use_tpu=FLAGS.use_tpu)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

  FLAGS.do_eval = False
  eval_input_fn = None
  params = _get_hparams()
  params.update(num_train_steps=num_train_steps)

  if not FLAGS.do_train:
    train_input_fn = eval_input_fn

  # Evaluation is disabled above, so the train input_fn doubles as the eval
  # input_fn placeholder.
  experiment_utils.run_experiment(
      model_fn=model_fn,
      train_input_fn=train_input_fn,
      eval_input_fn=train_input_fn,
      params=params)
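# Worked example of the step arithmetic above, with hypothetical flag values:
# in few_shot mode with num_shots_per_class = 32 and tpu_split = 8,
#   steps_per_epoch = 32 // 8 = 4
#   num_train_steps = int(4 * FLAGS.num_train_epochs)  # 40 for 10 epochs
# and whenever num_shots_per_class < tpu_split, steps_per_epoch is clamped
# to 1 so training still takes at least one step per epoch.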
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.model == "seq2seq":
    assert FLAGS.rnn_cell == "lstm"
    assert FLAGS.att_type != "hyper"
  if FLAGS.model == "hypernet" and FLAGS.rank != FLAGS.decoder_dim:
    print("WARNING: recommended rank value: decoder_dim.")
  if FLAGS.att_neighbor:
    assert FLAGS.neighbor_dim == FLAGS.encoder_dim or FLAGS.att_type == "my"
  if FLAGS.use_copy or FLAGS.att_neighbor:
    assert FLAGS.att_type == "my"

  # Fail fast on an unsupported dataset rather than hitting a NameError when
  # output_size is read below.
  if FLAGS.dataset not in ("nyt", "giga", "cnnd"):
    raise ValueError("Unknown dataset: %s" % FLAGS.dataset)

  # These numbers are the target vocabulary sizes of the datasets. Keeping
  # them separate allows different vocabularies for source and target,
  # following the implementation in OpenNMT.
  # TODO: move these into command-line arguments.
  if FLAGS.use_bpe:
    if FLAGS.dataset == "nyt":
      output_size = 10013
    elif FLAGS.dataset == "giga":
      output_size = 24654
    elif FLAGS.dataset == "cnnd":
      output_size = 10232
  else:
    if FLAGS.dataset == "nyt":
      output_size = 68885
    elif FLAGS.dataset == "giga":
      output_size = 107389
    elif FLAGS.dataset == "cnnd":
      output_size = 21000

  vocab = data.Vocab(FLAGS.vocab_path, FLAGS.vocab_size, FLAGS.dataset)
  hps = tf.contrib.training.HParams(
      sample_neighbor=FLAGS.sample_neighbor,
      use_cluster=FLAGS.use_cluster,
      binary_neighbor=FLAGS.binary_neighbor,
      att_neighbor=FLAGS.att_neighbor,
      encode_neighbor=FLAGS.encode_neighbor,
      sum_neighbor=FLAGS.sum_neighbor,
      dataset=FLAGS.dataset,
      rnn_cell=FLAGS.rnn_cell,
      output_size=output_size + vocab.offset,
      train_path=FLAGS.train_path,
      dev_path=FLAGS.dev_path,
      tie_embedding=FLAGS.tie_embedding,
      use_bpe=FLAGS.use_bpe,
      use_copy=FLAGS.use_copy,
      reuse_attention=FLAGS.reuse_attention,
      use_bridge=FLAGS.use_bridge,
      use_residual=FLAGS.use_residual,
      att_type=FLAGS.att_type,
      random_neighbor=FLAGS.random_neighbor,
      num_neighbors=FLAGS.num_neighbors,
      model=FLAGS.model,
      trainer=FLAGS.trainer,
      learning_rate=FLAGS.learning_rate,
      lr_schedule=FLAGS.lr_schedule,
      total_steps=FLAGS.total_steps,
      emb_dim=FLAGS.emb_dim,
      binary_dim=FLAGS.binary_dim,
      neighbor_dim=FLAGS.neighbor_dim,
      drop=FLAGS.drop,
      emb_drop=FLAGS.emb_drop,
      out_drop=FLAGS.out_drop,
      encoder_drop=FLAGS.encoder_drop,
      decoder_drop=FLAGS.decoder_drop,
      weight_decay=FLAGS.weight_decay,
      encoder_dim=FLAGS.encoder_dim,
      num_encoder_layers=FLAGS.num_encoder_layers,
      decoder_dim=FLAGS.decoder_dim,
      num_decoder_layers=FLAGS.num_decoder_layers,
      num_mlp_layers=FLAGS.num_mlp_layers,
      rank=FLAGS.rank,
      sigma_norm=FLAGS.sigma_norm,
      batch_size=FLAGS.batch_size,
      sampling_probability=FLAGS.sampling_probability,
      beam_width=FLAGS.beam_width,
      max_enc_steps=FLAGS.max_enc_steps,
      max_dec_steps=FLAGS.max_dec_steps,
      vocab_size=FLAGS.vocab_size,
      max_grad_norm=FLAGS.max_grad_norm,
      length_norm=FLAGS.length_norm,
      cp=FLAGS.coverage_penalty,
      predict_mode=FLAGS.predict_mode)
  train_input_fn = partial(
      data.input_function, is_train=True, vocab=vocab, hps=hps)
  eval_input_fn = partial(
      data.input_function, is_train=False, vocab=vocab, hps=hps)
  model_fn = partial(model_function.model_function, vocab=vocab, hps=hps)
  experiment_utils.run_experiment(
      model_fn=model_fn,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn)
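# Usage note (assumption): output_size above is the raw target-vocabulary
# size, and vocab.offset presumably reserves slots for special symbols
# (padding, unknown, sentence boundaries), so e.g. BPE on "giga" yields a
# softmax over 24654 + vocab.offset entries.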