Code example #1
    def build_graph_from_config(self, model_config, mode="encode"):
        """Builds the inference graph from a configuration object.

        Args:
          model_config: Object containing configuration for building the model;
            its checkpoint_path attribute names the checkpoint file or a
            directory containing one.
          mode: Either "train" (full training graph) or "encode" (encoder only).

        Returns:
          restore_fn: A function such that restore_fn(sess) loads model
            variables from the checkpoint file.
        """
        tf.logging.info("Building model.")
        model = s2v_model.s2v(model_config, mode=mode)
        if mode == "train":
            model.build()
        elif mode == "encode":
            model.build_enc()
        self._embeddings = model.word_embeddings
        saver = tf.train.Saver()
        # The checkpoint location comes from the config rather than an argument.
        checkpoint_path = model_config.checkpoint_path

        return self._create_restore_fn(checkpoint_path, saver)
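
A minimal sketch of how this method might be driven end to end; the Encoder class name and the encoder.json file name are assumptions for illustration, not part of the original code:

import json
import tensorflow as tf

with open("encoder.json") as f:          # assumed config file name
    model_config = configuration.model_config(json.load(f), mode="encode")

g = tf.Graph()
with g.as_default():
    encoder = Encoder()                  # assumed owner of build_graph_from_config
    restore_fn = encoder.build_graph_from_config(model_config)

with tf.Session(graph=g) as sess:
    restore_fn(sess)                     # restores variables from the checkpoint
    # sess.run(...) can now evaluate encoder outputs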
Code example #2
def main(unused_argv):
    if (not FLAGS.input_file_pattern_word or
            not FLAGS.input_file_pattern_POS):
        raise ValueError(
            "--input_file_pattern_word and --input_file_pattern_POS are required.")
    if not FLAGS.train_dir:
        raise ValueError("--train_dir is required.")

    with open(FLAGS.model_config) as json_config_file:
        model_config = json.load(json_config_file)

    model_config = configuration.model_config(model_config, mode="train")
    tf.logging.info("Building training graph.")
    g = tf.Graph()
    with g.as_default():
        model = s2v_model.s2v(model_config, mode="train")
        model.build()

        optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)

        train_tensor = tf.contrib.slim.learning.create_train_op(
            total_loss=model.total_loss,
            optimizer=optimizer,
            clip_gradient_norm=FLAGS.clip_gradient_norm)

        # tf.train.Saver defaults to max_to_keep=5, so one call covers both cases.
        saver = tf.train.Saver(max_to_keep=FLAGS.max_ckpts)

    # model.init, when set, is an (assign_op, placeholder, value) triple used to
    # load pre-trained word embeddings without baking them into the graph.
    load_words = model.init
    if load_words:

        def InitAssignFn(sess):
            sess.run(load_words[0], {load_words[1]: load_words[2]})

    nsteps = int(FLAGS.nepochs * (FLAGS.num_train_inst / FLAGS.batch_size))
    tf.contrib.slim.learning.train(
        train_op=train_tensor,
        logdir=FLAGS.train_dir,
        graph=g,
        number_of_steps=nsteps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        saver=saver,
        save_interval_secs=FLAGS.save_model_secs,
        init_fn=InitAssignFn if load_words else None)
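
The InitAssignFn above assumes model.init is an (assign_op, placeholder, value) triple. A self-contained sketch of that pattern, with illustrative sizes, shows why feeding a placeholder keeps the large embedding matrix out of the serialized graph:

import numpy as np
import tensorflow as tf

vocab_size, dim = 50000, 620                       # illustrative sizes
pretrained = np.random.rand(vocab_size, dim).astype(np.float32)  # stand-in for word2vec

embeddings = tf.get_variable("word_embeddings", [vocab_size, dim])
emb_ph = tf.placeholder(tf.float32, [vocab_size, dim])
assign_op = embeddings.assign(emb_ph)              # no giant constant in the graph

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Mirrors sess.run(load_words[0], {load_words[1]: load_words[2]}) above.
    sess.run(assign_op, {emb_ph: pretrained})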
Code example #3
def main(unused_argv):
  if not FLAGS.input_file_pattern:
    raise ValueError("--input_file_pattern is required.")
  if not FLAGS.train_dir:
    raise ValueError("--train_dir is required.")

  with open(FLAGS.model_config) as json_config_file:
    model_config = json.load(json_config_file)

  model_config = configuration.model_config(model_config, mode="train")
  tf.logging.info("Building training graph.")
  g = tf.Graph()
  with g.as_default():
    model = s2v_model.s2v(model_config, mode="train")
    model.build()

    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)

    train_tensor = slim.learning.create_train_op(
        total_loss=model.total_loss,
        optimizer=optimizer,
        clip_gradient_norm=FLAGS.clip_gradient_norm)

    saver = tf.train.Saver(max_to_keep=FLAGS.max_ckpts)

    # Warm-start all model variables from the latest checkpoint in the directory
    # named by the config.
    checkpoint_path = tf.train.latest_checkpoint(model_config.checkpoint_path)
    variables_to_restore = slim.get_model_variables()
    init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
        checkpoint_path, variables_to_restore)

    def InitAssignFn(sess):
      sess.run(init_assign_op, init_feed_dict)

  nsteps = int(FLAGS.nepochs * (FLAGS.num_train_inst / FLAGS.batch_size))
  slim.learning.train(
      train_op=train_tensor,
      logdir=FLAGS.train_dir,
      graph=g,
      number_of_steps=nsteps,
      save_summaries_secs=FLAGS.save_summaries_secs,
      saver=saver,
      save_interval_secs=FLAGS.save_model_secs,
      init_fn=InitAssignFn
  )
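
Example #3 builds the warm-start op and feed dict by hand; slim also ships assign_from_checkpoint_fn, which wraps the same logic in a single callable. A sketch of the equivalent, reusing the example's names:

# Equivalent warm start via slim's helper instead of a manual InitAssignFn.
checkpoint_path = tf.train.latest_checkpoint(model_config.checkpoint_path)
init_fn = slim.assign_from_checkpoint_fn(
    checkpoint_path,
    slim.get_model_variables(),
    ignore_missing_vars=False)  # fail loudly if the checkpoint is incomplete
# Then pass init_fn=init_fn to slim.learning.train as above.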
Code example #4
def main(unused_argv):
    if not FLAGS.input_file_pattern:
        raise ValueError("--input_file_pattern is required.")
    if not FLAGS.checkpoint_dir:
        raise ValueError("--checkpoint_dir is required.")
    if not FLAGS.eval_dir:
        raise ValueError("--eval_dir is required.")

    eval_dir = FLAGS.eval_dir
    if not tf.gfile.IsDirectory(eval_dir):
        tf.logging.info("Creating eval directory: %s", eval_dir)
        tf.gfile.MakeDirs(eval_dir)

    with open(FLAGS.model_config) as json_config_file:
        model_config = json.load(json_config_file)

    model_config = configuration.model_config(model_config, mode="eval")
    model = s2v_model.s2v(model_config, mode="eval")
    model.build()

    tf.summary.scalar("Loss", model.total_loss)
    summary_op = tf.summary.merge_all()

    # Let the GPU allocator grow on demand instead of reserving all memory.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    nsteps = int(FLAGS.num_eval_examples / FLAGS.batch_size)
    tf.contrib.slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.checkpoint_dir,
        logdir=FLAGS.eval_dir,
        num_evals=nsteps,
        eval_op=model.eval_op,
        summary_op=summary_op,
        eval_interval_secs=FLAGS.eval_interval_secs,
        session_config=config)
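
evaluation_loop blocks and re-runs every eval_interval_secs; for a one-shot evaluation of the newest checkpoint, slim's evaluate_once accepts nearly the same arguments. A sketch reusing the names defined in example #4:

# One-shot variant: evaluate the latest checkpoint once, then exit.
checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
tf.contrib.slim.evaluation.evaluate_once(
    master=FLAGS.master,
    checkpoint_path=checkpoint_path,
    logdir=FLAGS.eval_dir,
    num_evals=nsteps,
    eval_op=model.eval_op,
    summary_op=summary_op,
    session_config=config)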
Code example #5
    def build_graph_from_config(self, model_config):
        """Builds the inference graph from a configuration object.

        Args:
          model_config: Object containing configuration for building the model;
            its checkpoint_path attribute names the checkpoint file or a
            directory containing one.

        Returns:
          restore_fn: A function such that restore_fn(sess) loads model
            variables from the checkpoint file.
        """
        tf.logging.info("Building model.")
        model = s2v_model.s2v(model_config, mode="encode")
        model.build_enc()
        self._embeddings = model.word_embeddings
        saver = tf.train.Saver()
        # Default to the configured path; model_version selects a sub-directory.
        checkpoint_path = model_config.checkpoint_path
        if FLAGS.model_version == 0:  # original model
            checkpoint_path = model_config.checkpoint_path + "/original model"
        elif FLAGS.model_version == 1:  # improved model
            checkpoint_path = model_config.checkpoint_path + "/improved model"

        return self._create_restore_fn(checkpoint_path, saver)
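
The version branch above can also be written as a table lookup with os.path.join; the sub-directory names (including their spaces) come from the original code, while the _MODEL_DIRS table is an assumption for illustration:

import os

_MODEL_DIRS = {0: "original model", 1: "improved model"}  # hypothetical table

subdir = _MODEL_DIRS.get(FLAGS.model_version)
checkpoint_path = (os.path.join(model_config.checkpoint_path, subdir)
                   if subdir is not None else model_config.checkpoint_path)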
Code example #6
def main(input_file_pattern,
         train_dir,
         model_config,
         word2vec_path,
         learning_rate=0.005,
         clip_gradient_norm=5.0,
         uniform_init_scale=0.1,
         shuffle_input_data=False,
         input_queue_capacity=640000,
         num_input_reader_threads=1,
         dropout=False,
         dropout_rate=0.3,
         context_size=1,
         num_train_inst=800000,
         batch_size=128,
         nepochs=1,
         max_ckpts=5,
         save_summaries_secs=600,
         save_model_secs=600):

    start = time.time()
    if not input_file_pattern:
        raise ValueError("--input_file_pattern is required.")
    if not train_dir:
        raise ValueError("--train_dir is required.")

    with open(model_config) as json_config_file:
        model_config = json.load(json_config_file)

    model_config = configuration.model_config(model_config,
                                              mode="train",
                                              word2vec_path=word2vec_path)
    tf.logging.info("Building training graph.")
    g = tf.Graph()
    with g.as_default():
        model = s2v_model.s2v(model_config,
                              uniform_init_scale,
                              input_file_pattern,
                              shuffle_input_data,
                              input_queue_capacity,
                              num_input_reader_threads,
                              batch_size,
                              dropout,
                              dropout_rate,
                              context_size,
                              mode="train")
        model.build()
        optimizer = tf.train.AdamOptimizer(learning_rate)
        train_tensor = tf.contrib.slim.learning.create_train_op(
            total_loss=model.total_loss,
            optimizer=optimizer,
            clip_gradient_norm=clip_gradient_norm)

        # tf.train.Saver defaults to max_to_keep=5, so one call covers both cases.
        saver = tf.train.Saver(max_to_keep=max_ckpts)

    # model.init, when set, is an (assign_op, placeholder, value) triple for
    # initializing pre-trained (e.g. fixed word2vec) word embeddings.
    load_words = model.init
    if load_words:

        def InitAssignFn(sess):
            sess.run(load_words[0], {load_words[1]: load_words[2]})

    nsteps = int(nepochs * (num_train_inst / batch_size))

    tf.contrib.slim.learning.train(
        train_op=train_tensor,
        logdir=train_dir,
        graph=g,
        number_of_steps=nsteps,
        save_summaries_secs=save_summaries_secs,
        saver=saver,
        save_interval_secs=save_model_secs,
        init_fn=InitAssignFn if load_words else None)
    end = time.time()
    cost_time = end - start
    tf.logging.info("the cost time of training is %f ! ", cost_time)