Example 1
def main(_):
    params = dict(data_root=FLAGS.data_root,
                  batch_size=FLAGS.batch_size,
                  eval_batch_size=FLAGS.batch_size,
                  query_seq_len=FLAGS.query_seq_len,
                  block_seq_len=FLAGS.block_seq_len,
                  learning_rate=FLAGS.learning_rate,
                  num_classes=FLAGS.num_classes,
                  num_train_steps=FLAGS.num_train_steps,
                  retriever_module_path=FLAGS.retriever_module_path,
                  reader_module_path=FLAGS.reader_module_path,
                  retriever_beam_size=FLAGS.retriever_beam_size,
                  reader_beam_size=FLAGS.reader_beam_size,
                  reader_seq_len=FLAGS.reader_seq_len,
                  span_hidden_size=FLAGS.span_hidden_size,
                  max_span_width=FLAGS.max_span_width,
                  block_records_path=FLAGS.block_records_path,
                  num_block_records=FLAGS.num_block_records)

    train_input_fn = functools.partial(text_classifier_model.input_fn,
                                       name=FLAGS.dataset_name,
                                       is_train=True)
    eval_input_fn = functools.partial(text_classifier_model.input_fn,
                                      name=FLAGS.dataset_name,
                                      is_train=False)

    experiment_utils.run_experiment(model_fn=text_classifier_model.model_fn,
                                    params=params,
                                    train_input_fn=train_input_fn,
                                    eval_input_fn=eval_input_fn,
                                    exporters=text_classifier_model.exporter(),
                                    params_fname="params.json")
Example 2
 def test_run_experiment_tpu(self):
     params = dict(use_tpu=True)
     experiment_utils.run_experiment(
         model_fn=self._simple_model_fn,
         train_input_fn=self._simple_input_function,
         eval_input_fn=self._simple_input_function,
         params=params)
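The helpers referenced in this test (`_simple_model_fn`, `_simple_input_function`) are not part of the excerpt. A minimal sketch of what Estimator-compatible stand-ins could look like follows; the names and contents are assumptions written against the TF1 `tf.estimator` API, and the real test fixtures may differ (for example, returning a `TPUEstimatorSpec` when `use_tpu=True`):

import tensorflow.compat.v1 as tf


def _simple_input_function(params):
  # A tiny in-memory dataset; `params` is accepted because Estimator passes it.
  del params
  features = {"x": tf.constant([[1.0], [2.0]])}
  labels = tf.constant([0, 1])
  return tf.data.Dataset.from_tensors((features, labels)).repeat()


def _simple_model_fn(features, labels, mode, params):
  # One dense layer, softmax cross-entropy loss, plain SGD.
  del params
  logits = tf.layers.dense(features["x"], units=2)
  loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)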
Example 3
def train():
  """Train the model."""
  embeddings = load_embeddings()

  # Need a named parameter `params` since this will be called
  # with named arguments, so pylint: disable=unused-argument
  def model_function(features, labels, mode, params):
    """Builds the `tf.estimator.EstimatorSpec` to train/eval with."""
    is_train = mode == tf.estimator.ModeKeys.TRAIN
    logits = predict(is_train, embeddings, features["premise"],
                     features["hypothesis"])

    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.to_int32(labels), logits=logits)
    loss = tf.reduce_mean(loss)
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = get_train_op(loss)
    else:
      # Don't build the train_op unnecessarily, since the ADAM variables can
      # cause problems with loading checkpoints on CPUs.
      train_op = None
    metrics = dict(
        accuracy=tf.metrics.accuracy(
            tf.argmax(logits, 1, output_type=tf.int32), tf.to_int32(labels)))

    checkpoint_file = FLAGS.checkpoint_file
    if checkpoint_file is None:
      scaffold = None
    else:
      saver = tf.train.Saver(tf.trainable_variables())

      def _init_fn(_, sess):
        saver.restore(sess, checkpoint_file)

      scaffold = tf.train.Scaffold(init_fn=_init_fn)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        scaffold=scaffold,
        loss=loss,
        predictions=None,
        train_op=train_op,
        eval_metric_ops=metrics)

  def compare_fn(best_eval_result, current_eval_result):
    return best_eval_result["accuracy"] < current_eval_result["accuracy"]

  exporter = best_checkpoint_exporter.BestCheckpointExporter(
      event_file_pattern="eval_default/*.tfevents.*",
      compare_fn=compare_fn,
  )

  experiment_utils.run_experiment(
      model_fn=model_function,
      train_input_fn=lambda: load_batched_dataset(True, embeddings),
      eval_input_fn=lambda: load_batched_dataset(False, embeddings),
      exporters=[exporter])
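For reference, the `compare_fn` above simply prefers the checkpoint with the higher eval accuracy. A self-contained restatement, with made-up metric values, behaves like this:

def compare_fn(best_eval_result, current_eval_result):
  return best_eval_result["accuracy"] < current_eval_result["accuracy"]

assert compare_fn({"accuracy": 0.80}, {"accuracy": 0.85})      # new checkpoint becomes best
assert not compare_fn({"accuracy": 0.90}, {"accuracy": 0.85})  # existing best is kept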
Example 4
def main(_):
    model_function, train_input_fn, eval_input_fn, serving_input_receiver_fn = (
        nq_short_pipeline_model.experiment_functions())
    best_exporter = tf_estimator.BestExporter(
        name="best",
        serving_input_receiver_fn=serving_input_receiver_fn,
        event_file_pattern="eval_default/*.tfevents.*",
        compare_fn=nq_short_pipeline_model.compare_metrics)
    experiment_utils.run_experiment(model_fn=model_function,
                                    train_input_fn=train_input_fn,
                                    eval_input_fn=eval_input_fn,
                                    exporters=[best_exporter])
Example 5
def main(_):
    params = dict(batch_size=FLAGS.batch_size,
                  eval_batch_size=FLAGS.eval_batch_size,
                  bert_hub_module_path=FLAGS.bert_hub_module_path,
                  query_seq_len=FLAGS.query_seq_len,
                  block_seq_len=FLAGS.block_seq_len,
                  projection_size=FLAGS.projection_size,
                  learning_rate=FLAGS.learning_rate,
                  examples_path=FLAGS.examples_path,
                  mask_rate=FLAGS.mask_rate,
                  num_train_steps=FLAGS.num_train_steps,
                  num_block_records=FLAGS.num_block_records,
                  num_input_threads=FLAGS.num_input_threads)
    experiment_utils.run_experiment(
        model_fn=ict_model.model_fn,
        train_input_fn=functools.partial(ict_model.input_fn, is_train=True),
        eval_input_fn=functools.partial(ict_model.input_fn, is_train=False),
        exporters=ict_model.exporter(),
        params=params)
Example 6
def main(_):
  params = dict(
      batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      bert_hub_module_handle=FLAGS.bert_hub_module_handle,
      embedder_hub_module_handle=FLAGS.embedder_hub_module_handle,
      vocab_path=FLAGS.vocab_path,
      do_lower_case=FLAGS.do_lower_case,
      query_seq_len=FLAGS.query_seq_len,
      candidate_seq_len=FLAGS.candidate_seq_len,
      max_masks=FLAGS.max_masks,
      learning_rate=FLAGS.learning_rate,
      num_input_threads=FLAGS.num_input_threads,
      num_candidates=FLAGS.num_candidates,
      num_train_steps=FLAGS.num_train_steps,
      train_preprocessing_servers=FLAGS.train_preprocessing_servers,
      eval_preprocessing_servers=FLAGS.eval_preprocessing_servers,
      share_embedders=FLAGS.share_embedders,
      separate_candidate_segments=FLAGS.separate_candidate_segments)

  experiment_utils.run_experiment(
      model_fn=model.model_fn,
      train_input_fn=functools.partial(model.input_fn, is_train=True),
      eval_input_fn=functools.partial(model.input_fn, is_train=False),
      params=params,
      params_fname="estimator_params.json",
      exporters=model.get_exporters(params))

  # Write a "done" file from the trainer. As in experiment_utils, we currently
  # use 'use_tpu' as a proxy for whether this is a train or eval node.
  #
  # We could also use the 'type' field in the 'task' of the TF_CONFIG
  # environment variable, but we would generally like to get away from TF_CONFIG
  # in the future.
  #
  # This file is checked for existence by refresh_doc_embeds.
  if experiment_utils.FLAGS.use_tpu:
    model_dir = experiment_utils.EstimatorSettings.from_flags().model_dir
    training_done_filename = os.path.join(model_dir, "TRAINING_DONE")
    with tf.gfile.GFile(training_done_filename, "w") as f:
      f.write("done")
Example 7
 def test_run_experiment(self):
     experiment_utils.run_experiment(
         model_fn=self._simple_model_fn,
         train_input_fn=self._simple_input_function,
         eval_input_fn=self._simple_input_function)
Example 8
def main(argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.config.set_soft_device_placement(True)
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    train_input_fn = None
    ft_known_train_file = None
    train_file = None
    if FLAGS.do_train:
        current_seed = 0
        num_known_classes = FLAGS.num_domains * FLAGS.num_labels_per_domain
        data_output_dir = FLAGS.data_output_dir
        if not tf.gfile.Exists(data_output_dir):
            tf.gfile.MakeDirs(data_output_dir)
        known_ft_path = os.path.join(data_output_dir,
                                     "known_ft_train.tf_record")
        unknown_ft_path = os.path.join(data_output_dir,
                                       "unknown_ft_train.tf_record")
        if not tf.gfile.Glob(known_ft_path):
            preprocess_few_shot_training_data(tokenizer, known_ft_path,
                                              unknown_ft_path, current_seed)

        if FLAGS.continual_learning is None:
            raise NotImplementedError("FLAGS.continual_learning must be set.")
        elif FLAGS.continual_learning == "pretrain":
            train_file = os.path.join(FLAGS.data_output_dir,
                                      "known_ft_train.tf_record")
            num_classes = num_known_classes
            num_train_examples = num_known_classes * FLAGS.known_num_shots
            num_shots_per_class = FLAGS.known_num_shots
        elif FLAGS.continual_learning == "few_shot":
            train_file = os.path.join(FLAGS.data_output_dir,
                                      "unknown_ft_train.tf_record")
            ft_known_train_file = os.path.join(FLAGS.data_output_dir,
                                               "known_ft_train.tf_record")
            num_unknown_classes = NUM_CLASSES - num_known_classes
            num_classes = num_unknown_classes
            num_train_examples = num_unknown_classes * FLAGS.few_shot
            num_shots_per_class = FLAGS.few_shot

        tpu_split = FLAGS.tpu_split if FLAGS.use_tpu else 1
        if num_shots_per_class < tpu_split:
            steps_per_epoch = 1
        else:
            steps_per_epoch = num_shots_per_class // tpu_split
        num_train_steps = int(steps_per_epoch * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
        FLAGS.num_train_steps = num_train_steps
        FLAGS.save_checkpoints_steps = int(steps_per_epoch *
                                           FLAGS.save_every_epoch)

        tf.logging.info("***** Running training *****")
        tf.logging.info("  train_file: %s" % train_file)
        tf.logging.info("  use_tpu: %s" % FLAGS.use_tpu)
        tf.logging.info("  Num examples = %d", num_train_examples)
        tf.logging.info("  Batch size = %d", FLAGS.batch_size)
        tf.logging.info("  Save checkpoints steps = %d",
                        FLAGS.save_checkpoints_steps)
        tf.logging.info("  warmup steps = %d", num_warmup_steps)
        tf.logging.info("  Num epochs = %d", FLAGS.num_train_epochs)
        tf.logging.info("  Num steps = %d", num_train_steps)
        tf.logging.info("  Reduce method = %s", FLAGS.reduce_method)
        tf.logging.info("  Max Seq Length = %d", FLAGS.max_seq_length)
        tf.logging.info(" learning_rate = %.7f", FLAGS.learning_rate)
        tf.logging.info(" dropout rate = %.4f", DROPOUT_PROB)

        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            ft_known_train_file=ft_known_train_file,
            use_tpu=FLAGS.use_tpu)

    model_fn = model_fn_builder(bert_config=bert_config,
                                init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=FLAGS.learning_rate,
                                num_train_steps=num_train_steps,
                                num_warmup_steps=num_warmup_steps,
                                use_tpu=FLAGS.use_tpu,
                                use_one_hot_embeddings=FLAGS.use_tpu)

    FLAGS.do_eval = False
    eval_input_fn = None
    params = _get_hparams()
    params.update(num_train_steps=num_train_steps)
    if not FLAGS.do_train:
        train_input_fn = eval_input_fn

    experiment_utils.run_experiment(model_fn=model_fn,
                                    train_input_fn=train_input_fn,
                                    eval_input_fn=train_input_fn,
                                    params=params)
Example 9
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    if FLAGS.model == "seq2seq":
        assert FLAGS.rnn_cell == "lstm"
        assert FLAGS.att_type != "hyper"
    if FLAGS.model == "hypernet" and FLAGS.rank != FLAGS.decoder_dim:
        print("WARNING: recommended rank value: decoder_dim.")
    if FLAGS.att_neighbor:
        assert FLAGS.neighbor_dim == FLAGS.encoder_dim or FLAGS.att_type == "my"

    if FLAGS.use_copy or FLAGS.att_neighbor:
        assert FLAGS.att_type == "my"
    # These numbers are the target vocabulary sizes of the datasets.
    # This allows using different vocabularies for the source and target,
    # following the Open-NMT implementation.
    # These should later be moved into command-line arguments.
    if FLAGS.use_bpe:
        if FLAGS.dataset == "nyt":
            output_size = 10013
        elif FLAGS.dataset == "giga":
            output_size = 24654
        elif FLAGS.dataset == "cnnd":
            output_size = 10232
    else:
        if FLAGS.dataset == "nyt":
            output_size = 68885
        elif FLAGS.dataset == "giga":
            output_size = 107389
        elif FLAGS.dataset == "cnnd":
            output_size = 21000

    vocab = data.Vocab(FLAGS.vocab_path, FLAGS.vocab_size, FLAGS.dataset)
    hps = tf.contrib.training.HParams(
        sample_neighbor=FLAGS.sample_neighbor,
        use_cluster=FLAGS.use_cluster,
        binary_neighbor=FLAGS.binary_neighbor,
        att_neighbor=FLAGS.att_neighbor,
        encode_neighbor=FLAGS.encode_neighbor,
        sum_neighbor=FLAGS.sum_neighbor,
        dataset=FLAGS.dataset,
        rnn_cell=FLAGS.rnn_cell,
        output_size=output_size + vocab.offset,
        train_path=FLAGS.train_path,
        dev_path=FLAGS.dev_path,
        tie_embedding=FLAGS.tie_embedding,
        use_bpe=FLAGS.use_bpe,
        use_copy=FLAGS.use_copy,
        reuse_attention=FLAGS.reuse_attention,
        use_bridge=FLAGS.use_bridge,
        use_residual=FLAGS.use_residual,
        att_type=FLAGS.att_type,
        random_neighbor=FLAGS.random_neighbor,
        num_neighbors=FLAGS.num_neighbors,
        model=FLAGS.model,
        trainer=FLAGS.trainer,
        learning_rate=FLAGS.learning_rate,
        lr_schedule=FLAGS.lr_schedule,
        total_steps=FLAGS.total_steps,
        emb_dim=FLAGS.emb_dim,
        binary_dim=FLAGS.binary_dim,
        neighbor_dim=FLAGS.neighbor_dim,
        drop=FLAGS.drop,
        emb_drop=FLAGS.emb_drop,
        out_drop=FLAGS.out_drop,
        encoder_drop=FLAGS.encoder_drop,
        decoder_drop=FLAGS.decoder_drop,
        weight_decay=FLAGS.weight_decay,
        encoder_dim=FLAGS.encoder_dim,
        num_encoder_layers=FLAGS.num_encoder_layers,
        decoder_dim=FLAGS.decoder_dim,
        num_decoder_layers=FLAGS.num_decoder_layers,
        num_mlp_layers=FLAGS.num_mlp_layers,
        rank=FLAGS.rank,
        sigma_norm=FLAGS.sigma_norm,
        batch_size=FLAGS.batch_size,
        sampling_probability=FLAGS.sampling_probability,
        beam_width=FLAGS.beam_width,
        max_enc_steps=FLAGS.max_enc_steps,
        max_dec_steps=FLAGS.max_dec_steps,
        vocab_size=FLAGS.vocab_size,
        max_grad_norm=FLAGS.max_grad_norm,
        length_norm=FLAGS.length_norm,
        cp=FLAGS.coverage_penalty,
        predict_mode=FLAGS.predict_mode)

    train_input_fn = partial(data.input_function,
                             is_train=True,
                             vocab=vocab,
                             hps=hps)
    eval_input_fn = partial(data.input_function,
                            is_train=False,
                            vocab=vocab,
                            hps=hps)

    model_fn = partial(model_function.model_function, vocab=vocab, hps=hps)
    experiment_utils.run_experiment(model_fn=model_fn,
                                    train_input_fn=train_input_fn,
                                    eval_input_fn=eval_input_fn)