Esempio n. 1
0
def eval_during_train(n_token, cutoffs, ps_device, sess):
    """Set up the evaluation inputs used to report on the eval split while training.

    The visible body builds the eval input function, decides how many batches
    to process, creates a one-shot iterator and splits inputs/labels per core.
    It does not use ``n_token``, ``cutoffs``, ``ps_device`` or ``sess`` and
    returns nothing.

    NOTE(review): this snippet looks truncated -- presumably the tower graph
    and the evaluation loop follow in the full source; confirm before relying
    on it.
    """

    tf.logging.info("Reporting on valid during training")

    # Returns the dataset factory plus a record-info dict (read below for "num_batch").
    eval_input_fn, eval_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split=FLAGS.eval_split, # train or valid
        per_host_bsz=FLAGS.eval_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1,
        use_tpu=False)

    # A positive max_eval_batch overrides the batch count from the record info.
    num_batch = eval_record_info["num_batch"]
    if FLAGS.max_eval_batch > 0:
        num_batch = FLAGS.max_eval_batch
    tf.logging.info("num of batches {}".format(num_batch))

    eval_set = eval_input_fn({
        "batch_size": FLAGS.eval_batch_size,
        "data_dir": FLAGS.data_dir})
    
    input_feed, label_feed = eval_set.make_one_shot_iterator().get_next()

    # Split along axis 0 into one piece per core.
    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = FLAGS.eval_batch_size // FLAGS.num_core_per_host
    # Per-tower accumulators; nothing in the visible code fills them yet.
    tower_mems, tower_losses, tower_new_mems = [], [], []
Esempio n. 2
0
def get_input_fn(split):
    """Return ``(input_fn, record_info_dict)`` for the "train" or "valid" split.

    All pipeline settings are read from ``FLAGS``; the record info directory
    is ``FLAGS.record_info_dir/<split>`` and the per-host batch size is
    ``FLAGS.batch_size // FLAGS.num_hosts``.
    """
    assert split in ("train", "valid")

    info_dir = os.path.join(FLAGS.record_info_dir, split)
    global_bsz = FLAGS.batch_size

    # Keyword arguments are collected first so the call itself stays short.
    pipeline_kwargs = dict(
        info_dir=info_dir,
        split=split,
        bsz_per_host=global_bsz // FLAGS.num_hosts,
        seq_len=FLAGS.seq_len,
        reuse_len=FLAGS.reuse_len,
        bi_data=FLAGS.bi_data,
        num_hosts=FLAGS.num_hosts,
        num_core_per_host=FLAGS.num_core_per_host,
        perm_size=FLAGS.perm_size,
        mask_alpha=FLAGS.mask_alpha,
        mask_beta=FLAGS.mask_beta,
        use_bfloat16=FLAGS.use_bfloat16,
        num_predict=FLAGS.num_predict,
        use_tpu=FLAGS.use_tpu,
        bucket_uri=FLAGS.bucket_uri,
    )
    fn, info = data_utils.get_input_fn(**pipeline_kwargs)

    return fn, info
Esempio n. 3
0
def get_AlexNet_experiment(args):
    """Create a ``tf.contrib.learn.Experiment`` for the AlexNet model on ImageNet.

    Reads ``data_dir``, ``num_epochs``, ``batch_size``, ``model_dir`` and
    ``lr`` from ``args``. Training data comes from ``<data_dir>/train`` and
    validation data from ``<data_dir>/val`` (one epoch, twice the training
    batch size, unshuffled).
    """
    train_dir = os.path.join(args.data_dir, 'train')
    train_input_fn = data_utils.get_input_fn(
        data_dir=train_dir,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        shuffle=True)

    val_dir = os.path.join(args.data_dir, 'val')
    val_input_fn = data_utils.get_input_fn(
        data_dir=val_dir,
        num_epochs=1,
        batch_size=2 * args.batch_size,
        shuffle=False)

    net = model.AlexNet(num_classes=1000, scope='ImageNet_AlexNet')

    run_config_kwargs = {
        'log_device_placement': False,
        'gpu_memory_fraction': 0.98,
        'tf_random_seed': 1234,
        'save_summary_steps': 50,
        'save_checkpoints_secs': 300,
        'keep_checkpoint_max': 10000,
        'keep_checkpoint_every_n_hours': 10000,
        'log_step_count_steps': 10,
    }
    config = tf.contrib.learn.RunConfig(**run_config_kwargs)

    estimator = tf.estimator.Estimator(
        model_fn=net.get_model_fn(),
        model_dir=args.model_dir,
        config=config,
        params={'learning_rate': args.lr})

    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=val_input_fn,
        eval_metrics=None,
        train_steps=None,
        eval_steps=None,
        train_monitors=[],
        min_eval_frequency=1000,
        eval_delay_secs=240)
Esempio n. 4
0
def get_simple_nn_experiment(args):
    """Create a ``tf.contrib.learn.Experiment`` for the SimpleNN model on MNIST.

    Reads ``data_dir``, ``num_epochs``, ``batch_size``, ``normalize``,
    ``num_classes``, ``model_dir`` and ``lr`` from ``args``. The validation
    input runs for one epoch with twice the training batch size and without
    shuffling.
    """
    train_input_fn = data_utils.get_input_fn(
        data_dir=args.data_dir,
        is_training=True,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        shuffle=True,
        normalize=args.normalize)

    val_input_fn = data_utils.get_input_fn(
        data_dir=args.data_dir,
        is_training=False,
        num_epochs=1,
        batch_size=2 * args.batch_size,
        shuffle=False,
        normalize=args.normalize)

    net = model.SimpleMnistModel(num_classes=args.num_classes,
                                 scope='SimpleMnist')

    run_config_kwargs = {
        'keep_checkpoint_max': 10000,
        'tf_random_seed': 1234,
        'save_summary_steps': 50,
        'save_checkpoints_secs': 120,
    }
    config = tf.estimator.RunConfig(**run_config_kwargs)

    estimator = tf.estimator.Estimator(
        model_fn=net.get_model_fn(),
        model_dir=args.model_dir,
        config=config,
        params={'learning_rate': args.lr})

    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=val_input_fn,
        eval_metrics=None,
        train_steps=None,
        eval_steps=None,
        train_monitors=[],
        min_eval_frequency=1)
def test_data():
    """Check the first example produced by the "train" input fn on the "unit_test" dataset.

    Runs the input function once, decodes the first sequence with the
    tokenizer and compares it with the expected token list; also checks that
    the first label is 0.
    """
    h = copy(hparams)
    h.dataset = "unit_test"
    (sequences, seq_lens), labels = get_input_fn("train", h)()

    # The context manager releases the session even if sess.run raises.
    with tf.Session() as sess:
        seqs_np, _, labels_np = sess.run([sequences, seq_lens, labels])

    # NOTE(review): the tokenizer is built from the module-level ``hparams``,
    # not from the modified copy ``h`` -- confirm this is intended.
    tokenizer = Tokenizer.get_tokenizer(hparams)
    decoded = tokenizer.decode(seqs_np)[0]
    target = "aside from the terrific sea rescue _UNK of which there are very few i just did not care about any of the _UNK"

    assert decoded == target.split()
    assert labels_np[0][0] == 0
Esempio n. 6
0
File: train.py Progetto: w-h-m/coqa
def get_input_fn(split):
    """Return ``(input_fn, record_info_dict)`` for the "train" split.

    Only "train" is accepted. All pipeline settings are read from ``FLAGS``;
    the per-host batch size is ``FLAGS.train_batch_size // FLAGS.num_hosts``.
    """
    assert "train" == split
    global_bsz = FLAGS.train_batch_size

    # Keyword arguments are collected first so the call itself stays short.
    pipeline_kwargs = dict(
        tfrecord_dir=FLAGS.record_info_dir,
        split=split,
        bsz_per_host=global_bsz // FLAGS.num_hosts,
        seq_len=FLAGS.seq_len,
        reuse_len=FLAGS.reuse_len,
        bi_data=FLAGS.bi_data,
        num_hosts=FLAGS.num_hosts,
        num_core_per_host=FLAGS.num_core_per_host,
        perm_size=FLAGS.perm_size,
        mask_alpha=FLAGS.mask_alpha,
        mask_beta=FLAGS.mask_beta,
        uncased=FLAGS.uncased,
        num_passes=FLAGS.num_passes,
        use_bfloat16=FLAGS.use_bfloat16,
        num_predict=FLAGS.num_predict,
    )
    fn, info = data_utils.get_input_fn(**pipeline_kwargs)

    return fn, info
def evaluate(n_token, cutoffs, ps_device):
    """Evaluate a checkpoint on ``FLAGS.eval_split`` and log the results.

    Builds one evaluation tower per core, restores ``FLAGS.eval_ckpt_path``
    (or the latest checkpoint in ``FLAGS.model_dir`` when that flag is None),
    runs ``num_batch`` steps while feeding the memories fetched in one step
    back in at the next, and logs the label-count-weighted average loss with
    its perplexity and bits-per-character.

    Args:
        n_token: forwarded to ``single_core_graph``.
        cutoffs: forwarded to ``single_core_graph``.
        ps_device: forwarded to ``assign_to_gpu`` for device placement.

    Returns:
        None; results are only logged.
    """
    # Get input function and model function
    eval_input_fn, eval_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split=FLAGS.eval_split,
        per_host_bsz=FLAGS.eval_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1,
        use_tpu=False)

    # A positive max_eval_batch overrides the batch count from the record info.
    num_batch = eval_record_info["num_batch"]
    if FLAGS.max_eval_batch > 0:
        num_batch = FLAGS.max_eval_batch
    tf.logging.info("num of batches {}".format(num_batch))

    # Create computational graph
    eval_set = eval_input_fn({
        "batch_size": FLAGS.eval_batch_size,
        "data_dir": FLAGS.data_dir})

    input_feed, label_feed = eval_set.make_one_shot_iterator().get_next()

    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = FLAGS.eval_batch_size // FLAGS.num_core_per_host
    tower_mems, tower_losses, tower_new_mems = [], [], []

    for i in range(FLAGS.num_core_per_host):
        with tf.device(assign_to_gpu(i, ps_device)), \
             tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            # One memory placeholder per layer, fed from the previous step.
            mems_i = [tf.placeholder(tf.float32,
                                     [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                      for _ in range(FLAGS.n_layer)]

            loss_i, new_mems_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=False,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)

    # average losses across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
    else:
        loss = tower_losses[0]

    # Evaluation loop: memories start as zeros.
    tower_mems_np = [
        [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
         for _ in range(FLAGS.n_layer)]
        for _ in range(FLAGS.num_core_per_host)
    ]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.eval_ckpt_path is None:
            eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
        else:
            eval_ckpt_path = FLAGS.eval_ckpt_path
        tf.logging.info("Evaluate {}".format(eval_ckpt_path))
        saver.restore(sess, eval_ckpt_path)

        fetches = [loss, tower_new_mems, tf.size(label_feed)]

        format_str = "  >> processing batch {{:{0}d}}/{{:{0}d}} ..".format(
            len(str(num_batch)))

        # Report progress about ten times; clamp to 1 so that fewer than ten
        # batches does not cause a modulo-by-zero.
        log_every = max(1, num_batch // 10)

        total_loss, total_cnt = 0, 0
        for step in range(num_batch):
            if step % log_every == 0:
                tf.logging.info(format_str.format(step, num_batch))

            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)

            # The fetched new memories become the next step's input memories.
            loss_np, tower_mems_np, cnt_np = fetched[:3]
            total_loss += loss_np * cnt_np
            total_cnt += cnt_np

        if total_cnt == 0:
            # Nothing was evaluated; avoid dividing by zero below.
            tf.logging.warning("No evaluation batches were processed")
            return

        avg_loss = total_loss / total_cnt
        tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
            avg_loss, math.exp(avg_loss), avg_loss / math.log(2)))
def train(n_token, cutoffs, ps_device):
    """Train the model on one or more towers with periodic logging and checkpoints.

    Builds one training tower per core, averages losses and gradients across
    towers, clips gradients by global norm, applies Adam with a linear warmup
    followed by cosine decay, and runs the training loop until the global step
    reaches ``FLAGS.train_steps``. Summaries are written to
    ``<FLAGS.model_dir>/log`` and checkpoints to
    ``<FLAGS.model_dir>/model-<step>.ckpt``.

    Args:
        n_token: forwarded to ``single_core_graph``.
        cutoffs: forwarded to ``single_core_graph``.
        ps_device: forwarded to ``assign_to_gpu`` for device placement.

    Returns:
        None.
    """
    # Get input function and model function
    train_input_fn, train_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1,
        use_tpu=False)

    tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))

    # Create computational graph
    train_set = train_input_fn({
        "batch_size": FLAGS.train_batch_size,
        "data_dir": FLAGS.data_dir})

    input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_core_per_host):
        # Variables are created by the first tower and reused by the others.
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
             tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
            # One memory placeholder per layer, fed from the previous step.
            mems_i = [tf.placeholder(tf.float32,
                                     [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                      for _ in range(FLAGS.n_layer)]

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=True,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    # average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]
    grads, all_vars = zip(*grads_and_vars)

    # clip gradient
    clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
    grads_and_vars = list(zip(clipped, all_vars))

    # configure the optimizer
    global_step = tf.train.get_or_create_global_step()

    # warmup stage: increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                    * FLAGS.learning_rate
    else:
        warmup_lr = 0.0

    # decay stage: decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    # choose warmup or decay
    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)

    # get the train op
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    # Training loop: memories start as zeros.
    tower_mems_np = [
        [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
         for _ in range(FLAGS.n_layer)]
        for _ in range(FLAGS.num_core_per_host)
    ]

    saver = tf.train.Saver()

    tf.summary.scalar('learning_rate', learning_rate)
    tf.summary.scalar('loss', loss)
    merged = tf.summary.merge_all()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        # Created once here so the log directory is not re-created repeatedly
        # and the graph (with its variables) can be displayed.
        train_writer = tf.summary.FileWriter(os.path.join(FLAGS.model_dir, "log"), sess.graph)

        # try/finally guarantees the writer is flushed and closed even when
        # training is interrupted or fails.
        try:
            if FLAGS.warm_start_path is not None:
                tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
                saver.restore(sess, FLAGS.warm_start_path)

            fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]

            total_loss, prev_step = 0., -1
            while True:
                feed_dict = {}
                for i in range(FLAGS.num_core_per_host):
                    for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                        feed_dict[m] = m_np

                summary, fetched = sess.run([merged, fetches], feed_dict=feed_dict)

                # The fetched new memories become the next step's input memories.
                loss_np, tower_mems_np, curr_step = fetched[:3]
                total_loss += loss_np

                if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                    curr_loss = total_loss / (curr_step - prev_step)
                    tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                                    "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                                        curr_step, fetched[-3], fetched[-2],
                                        curr_loss, math.exp(curr_loss),
                                        curr_loss / math.log(2)))
                    total_loss, prev_step = 0., curr_step
                    train_writer.add_summary(summary, curr_step)

                if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                    save_path = os.path.join(FLAGS.model_dir, "model-{}.ckpt".format(curr_step))
                    saver.save(sess, save_path)
                    tf.logging.info("Model saved in path: {}".format(save_path))

                # ">=" rather than "==": a warm-started global step may already
                # be past train_steps, and "==" would then never terminate.
                if curr_step >= FLAGS.train_steps:
                    break
        finally:
            train_writer.close()
Esempio n. 9
0
def train(ps_device):
  """Run the multi-tower training loop configured by ``FLAGS``.

  Builds the input pipeline for the "train" split, creates one tower per
  core (splitting every feature of the batch along axis 0), averages losses
  and gradients across towers, and loops until the global step reaches
  ``FLAGS.train_steps``, logging every ``FLAGS.iterations`` steps and saving
  to ``<FLAGS.model_dir>/model.ckpt`` every ``FLAGS.save_steps`` steps.

  Args:
    ps_device: forwarded to ``assign_to_gpu`` for device placement.
  """
  ##### Input function
  train_input_fn, record_info_dict = data_utils.get_input_fn(
      tfrecord_dir=FLAGS.record_info_dir,
      split="train",
      bsz_per_host=FLAGS.train_batch_size,
      seq_len=FLAGS.seq_len,
      reuse_len=FLAGS.reuse_len,
      bi_data=FLAGS.bi_data,
      num_hosts=1,
      num_core_per_host=1, # set to one no matter how many GPUs
      perm_size=FLAGS.perm_size,
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      uncased=FLAGS.uncased,
      num_passes=FLAGS.num_passes,
      use_bfloat16=FLAGS.use_bfloat16,
      num_predict=FLAGS.num_predict)

  tf.compat.v1.logging.info(
      "num of batches {}".format(record_info_dict["num_batch"]))

  ##### Input tensors
  bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

  # The input function receives the size of the whole batch.
  train_set = train_input_fn({"batch_size": FLAGS.train_batch_size})
  example = train_set.make_one_shot_iterator().get_next()

  if FLAGS.num_core_per_host > 1:
    # Give every tower its own slice of each feature.
    examples = [{} for _ in range(FLAGS.num_core_per_host)]
    for name in example.keys():
      pieces = tf.split(example[name], FLAGS.num_core_per_host, 0)
      for core in range(FLAGS.num_core_per_host):
        examples[core][name] = pieces[core]
  else:
    examples = [example]

  ##### Towers
  tower_mems = []
  tower_losses = []
  tower_new_mems = []
  tower_grads_and_vars = []

  for core in range(FLAGS.num_core_per_host):
    # Variables are created by the first tower and reused by the others.
    reuse = (core > 0) or None
    with tf.device(assign_to_gpu(core, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

      # Each tower keeps its memories in a dictionary.
      tower_mem = {}
      if FLAGS.mem_len:
        tower_mem["mems"] = create_mems_tf(bsz_per_core)

      tower_loss, tower_new_mem, tower_gv = single_core_graph(
          is_training=True,
          features=examples[core],
          mems=tower_mem)

      tower_mems.append(tower_mem)
      tower_losses.append(tower_loss)
      tower_new_mems.append(tower_new_mem)
      tower_grads_and_vars.append(tower_gv)

  ## Average losses and gradients across towers
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
    grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
  else:
    loss = tower_losses[0]
    grads_and_vars = tower_grads_and_vars[0]

  ## Train op
  train_op, learning_rate, gnorm = model_utils.get_train_op(
      FLAGS, None, grads_and_vars=grads_and_vars)
  global_step = tf.train.get_global_step()

  ##### Training loop
  # Initial numpy memories, one dictionary per tower.
  tower_mems_np = []
  for core in range(FLAGS.num_core_per_host):
    tower_mems_np.append(
        {name: initialize_mems_np(bsz_per_core)
         for name in tower_mems[core].keys()})

  saver = tf.train.Saver()

  gpu_options = tf.GPUOptions(allow_growth=True)

  model_utils.init_from_checkpoint(FLAGS, global_vars=True)

  session_config = tf.ConfigProto(allow_soft_placement=True,
                                  gpu_options=gpu_options)
  with tf.Session(config=session_config) as sess:
    sess.run(tf.global_variables_initializer())

    fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]

    total_loss, prev_step = 0., -1
    while True:
      # Feed the memories produced by the previous step.
      feed_dict = {}
      for core in range(FLAGS.num_core_per_host):
        for name in tower_mems_np[core].keys():
          placeholders = tower_mems[core][name]
          values = tower_mems_np[core][name]
          for placeholder, value in zip(placeholders, values):
            feed_dict[placeholder] = value

      fetched = sess.run(fetches, feed_dict=feed_dict)

      loss_np, tower_mems_np, curr_step = fetched[:3]
      total_loss += loss_np

      if curr_step > 0 and curr_step % FLAGS.iterations == 0:
        curr_loss = total_loss / (curr_step - prev_step)
        message = ("[{}] | gnorm {:.2f} lr {:8.6f} "
                   "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}").format(
                       curr_step, fetched[-3], fetched[-2],
                       curr_loss, math.exp(curr_loss), curr_loss / math.log(2))
        tf.compat.v1.logging.info(message)
        total_loss, prev_step = 0., curr_step

      if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        tf.compat.v1.logging.info("Model saved in path: {}".format(save_path))

      if curr_step >= FLAGS.train_steps:
        break
Esempio n. 10
0
def main(unused_argv):
    """Train and/or evaluate the model with a TPUEstimator, driven by ``FLAGS``.

    Depending on ``FLAGS.do_eval_only`` / ``FLAGS.do_eval`` this either
    evaluates one checkpoint (``FLAGS.eval_ckpt_path``), evaluates every
    checkpoint found in ``FLAGS.model_dir`` and logs the one with the lowest
    perplexity, trains only, or alternates training and evaluation.

    Args:
        unused_argv: ignored.
    """
    del unused_argv  # Unused

    def log_results(title, results):
        """Log one evaluation result dict between two separator lines."""
        tf.logging.info("=" * 200)
        log_str = title
        for key, val in results.items():
            log_str += "{} {} | ".format(key, val)
        tf.logging.info(log_str)
        tf.logging.info("=" * 200)

    tf.logging.set_verbosity(tf.logging.INFO)

    # Get corpus info
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]

    # A value of 0 is mapped to None before being handed to RunConfig below.
    if FLAGS.save_steps == 0:
        FLAGS.save_steps = None

    if not FLAGS.do_eval_only:
        # Get train input function
        train_input_fn, train_record_info = data_utils.get_input_fn(
            record_info_dir=FLAGS.record_info_dir,
            split="train",
            per_host_bsz=FLAGS.train_batch_size // FLAGS.num_hosts,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=FLAGS.num_hosts,
            use_tpu=FLAGS.use_tpu)
        train_bin_sizes = train_record_info["bin_sizes"]
        num_train_batch = train_record_info["num_batch"]

        # Get train cache function
        train_cache_fn = get_cache_fn(FLAGS.mem_len)
    else:
        train_bin_sizes = []
        num_train_batch = None
        train_cache_fn = None

    if FLAGS.do_eval or FLAGS.do_eval_only:
        assert FLAGS.num_hosts == 1
        # Get eval input function
        eval_input_fn, eval_record_info = data_utils.get_input_fn(
            record_info_dir=FLAGS.record_info_dir,
            split=FLAGS.eval_split,
            per_host_bsz=FLAGS.eval_batch_size // FLAGS.num_hosts,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=FLAGS.num_hosts,
            use_tpu=FLAGS.use_tpu)
        eval_bin_sizes = eval_record_info["bin_sizes"]
        num_eval_batch = eval_record_info["num_batch"]

        if FLAGS.max_eval_batch > 0:
            num_eval_batch = min(FLAGS.max_eval_batch, num_eval_batch)

        # Get eval cache function
        eval_cache_fn = get_cache_fn(FLAGS.mem_len)
        model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes,
                                eval_bin_sizes)
    else:
        eval_cache_fn = None
        model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, [])

    ##### Create estimator
    # TPU Configuration
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=True),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations,
            num_shards=FLAGS.num_core_per_host * FLAGS.num_hosts,
            per_host_input_for_training=per_host_input),
        keep_checkpoint_max=100000,  # effectively save all checkpoints
        save_checkpoints_secs=None,
        save_checkpoints_steps=FLAGS.save_steps)

    # warm start
    warm_start_from = None
    if FLAGS.warm_start_path is not None:
        warm_start_from = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=FLAGS.warm_start_path)

    # TPU Estimator
    estimator = tpu_estimator.TPUEstimator(
        model_fn=model_fn,
        train_cache_fn=train_cache_fn,
        eval_cache_fn=eval_cache_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        params={
            "data_dir": FLAGS.data_dir,
            "track_mean": FLAGS.track_mean
        },
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        warm_start_from=warm_start_from)

    if FLAGS.do_eval_only:
        if FLAGS.eval_ckpt_path is not None:
            ret = estimator.evaluate(input_fn=eval_input_fn,
                                     steps=num_eval_batch,
                                     checkpoint_path=FLAGS.eval_ckpt_path)
            log_results("Eval results | ", ret)
        else:
            # get_checkpoint_state returns None when no checkpoint state is
            # available; treat that as "nothing to evaluate" instead of
            # failing on attribute access.
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.model_dir)
            if ckpt_state is None:
                checkpoint_paths = []
            else:
                checkpoint_paths = ckpt_state.all_model_checkpoint_paths

            eval_results = []
            for eval_checkpoint in checkpoint_paths:
                if not exists(eval_checkpoint + ".index"):
                    continue
                # The global step is the suffix after the last "-" in the path.
                global_step = int(eval_checkpoint.split("-")[-1])
                if global_step < FLAGS.start_eval_steps or global_step > FLAGS.train_steps:
                    continue
                ret = estimator.evaluate(input_fn=eval_input_fn,
                                         steps=num_eval_batch,
                                         checkpoint_path=eval_checkpoint)
                eval_results.append(ret)

            if eval_results:
                best = min(eval_results, key=lambda x: x["perplexity"])
                log_results("Best results | ", best)
            else:
                tf.logging.warning(
                    "No checkpoint in {} could be evaluated".format(
                        FLAGS.model_dir))
    else:
        if not FLAGS.do_eval:
            estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps)
        else:
            # Alternate: train for num_train_batch steps at a time, then evaluate.
            for step in range(0, FLAGS.train_steps, num_train_batch):
                train_steps = min(FLAGS.train_steps - step, num_train_batch)
                estimator.train(input_fn=train_input_fn, steps=train_steps)
                estimator.evaluate(input_fn=eval_input_fn,
                                   steps=num_eval_batch)
Esempio n. 11
0
def train(ps_device):
    """Run the multi-tower training loop and dump graph/profiling information.

    Besides training (same structure as a plain tower-based loop: build one
    tower per core, average losses and gradients, loop until the global step
    reaches ``FLAGS.train_steps``), this variant writes several files under
    ``profs/``: a dot representation of the graph, the byte size of each
    operation's first output, each operation's type and ref-dtype flag, and
    memory profiles of the first step. Steps 0, 7, 17, 27, ... are run with
    tracing enabled and passed to ``profile``; all other steps print their
    wall-clock duration.

    Args:
        ps_device: forwarded to ``assign_to_gpu`` for device placement.
    """
    # Get input function and model function

    train_input_fn, record_info_dict = data_utils.get_input_fn(
        tfrecord_dir=FLAGS.record_info_dir,
        split="train",
        bsz_per_host=FLAGS.train_batch_size,
        seq_len=FLAGS.seq_len,
        reuse_len=FLAGS.reuse_len,
        bi_data=FLAGS.bi_data,
        num_hosts=1,
        num_core_per_host=1,  # set to one no matter how many GPUs
        perm_size=FLAGS.perm_size,
        mask_alpha=FLAGS.mask_alpha,
        mask_beta=FLAGS.mask_beta,
        uncased=FLAGS.uncased,
        num_passes=FLAGS.num_passes,
        use_bfloat16=FLAGS.use_bfloat16,
        num_predict=FLAGS.num_predict)

    tf.logging.info("num of batches {}".format(record_info_dict["num_batch"]))

    # Create input tensors / placeholders
    bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

    params = {
        "batch_size": FLAGS.train_batch_size  # the whole batch
    }
    train_set = train_input_fn(params)

    example = train_set.make_one_shot_iterator().get_next()

    if FLAGS.num_core_per_host > 1:
        # Give every tower its own slice of each feature.
        examples = [{} for _ in range(FLAGS.num_core_per_host)]
        for key in example.keys():
            vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
            for device_id in range(FLAGS.num_core_per_host):
                examples[device_id][key] = vals[device_id]
    else:
        examples = [example]

    # Create computational graph
    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_core_per_host):
        # Variables are created by the first tower and reused by the others.
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
                tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            # The mems for each tower is a dictionary
            mems_i = {}
            if FLAGS.mem_len:
                mems_i["mems"] = create_mems_tf(bsz_per_core)

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                is_training=True,
                features=examples[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    # average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]

    # get train op
    train_op, learning_rate, gnorm = model_utils.get_train_op(
        FLAGS, None, grads_and_vars=grads_and_vars)
    global_step = tf.train.get_global_step()

    # Training loop
    # initialize mems
    tower_mems_np = []
    for i in range(FLAGS.num_core_per_host):
        mems_i_np = {}
        for key in tower_mems[i].keys():
            mems_i_np[key] = initialize_mems_np(bsz_per_core)
        tower_mems_np.append(mems_i_np)

    saver = tf.train.Saver()

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.97)

    model_utils.init_from_checkpoint(FLAGS, global_vars=True)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False,
                                          gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())
        sess.graph.finalize()
        run_metadata = tf.RunMetadata()
        options = tf.RunOptions(trace_level=tf.RunOptions.SOFTWARE_TRACE)

        graph = tf.get_default_graph()

        dot_rep = graph_to_dot(graph)
        with open('profs/xln.dot', 'w') as fwr:
            fwr.write(str(dot_rep))

        # operation name -> byte size of its first output (-1 when unknown)
        operations_tensors = {}
        # operation name -> [op type, ref-dtype flag of output 0 (if any)]
        operations_attributes = {}
        # count1: first output without a usable shape or dtype size;
        # count2: operations without any output.
        count1 = 0
        count2 = 0

        for operation in graph.get_operations():
            operation_name = operation.name
            operations_info = operation.values()

            attributes = [operation.type]
            operations_attributes[operation_name] = attributes
            try:
                attributes.append(
                    graph.get_tensor_by_name(
                        operation_name + ':0').dtype._is_ref_dtype)
            except (KeyError, ValueError):
                # No output tensor ":0" for this operation: record the type only.
                pass

            if len(operations_info) > 0:
                first_output = operations_info[0]
                if first_output.shape.ndims is not None:
                    operation_dtype_size = first_output.dtype.size
                    if operation_dtype_size is not None:
                        # Unknown (None) dimensions are skipped, i.e. counted as 1.
                        operation_no_of_elements = 1
                        for dim in first_output.shape.as_list():
                            if dim is not None:
                                operation_no_of_elements *= dim
                        operations_tensors[operation_name] = (
                            operation_no_of_elements * operation_dtype_size)
                    else:
                        count1 += 1
                else:
                    count1 += 1
                    operations_tensors[operation_name] = -1
            else:
                count2 += 1
                operations_tensors[operation_name] = -1
        print(count1)
        print(count2)

        with open('./profs/tensors_sz_32.txt', 'w') as f:
            for tensor, size in operations_tensors.items():
                f.write('"' + tensor + '"::' + str(size) + '\n')

        with open('./profs/operations_attributes.txt', 'w') as f:
            for op, attrs in operations_attributes.items():
                f.write('::'.join([op] + [str(attr) for attr in attrs]) + '\n')

        fetches = [loss, tower_new_mems, global_step,
                   gnorm, learning_rate, train_op]
        iteration = 0
        total_loss, prev_step = 0., -1
        while True:
            # Feed the memories produced by the previous step.
            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for key in tower_mems_np[i].keys():
                    for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
                        feed_dict[m] = m_np

            if iteration % 10 == 7 or iteration == 0:
                # Traced step: collect run metadata and hand it to the profiler.
                fetched = sess.run(fetches, feed_dict=feed_dict,
                                   options=options, run_metadata=run_metadata)
                profile(run_metadata, iteration)
            else:
                t0 = time.time()
                fetched = sess.run(fetches, feed_dict=feed_dict)
                print(time.time() - t0)

            if iteration == 0:
                mem_options = tf.profiler.ProfileOptionBuilder.time_and_memory()
                mem_options["min_bytes"] = 0
                mem_options["min_micros"] = 0
                mem_options["output"] = 'file:outfile=./profs/mem.txt'
                mem_options["select"] = ("bytes", "peak_bytes", "output_bytes",
                                         "residual_bytes")
                # NOTE(review): a fresh, empty tf.Graph() is passed here rather
                # than the session graph -- confirm this is intended.
                mem = tf.profiler.profile(
                    tf.Graph(), run_meta=run_metadata, cmd="scope",
                    options=mem_options)
                with open('profs/mem2.txt', 'w') as f:
                    f.write(str(mem))
            iteration += 1

            loss_np, tower_mems_np, curr_step = fetched[:3]
            total_loss += loss_np

            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                curr_loss = total_loss / (curr_step - prev_step)
                tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                                "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                                    curr_step, fetched[-3], fetched[-2],
                                    curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
                total_loss, prev_step = 0., curr_step

            if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                saver.save(sess, save_path)
                tf.logging.info("Model saved in path: {}".format(save_path))

            if curr_step >= FLAGS.train_steps:
                break
Esempio n. 12
0
def train(ps_device):
    """Train for `FLAGS.epochs` epochs, running a validation pass after each.

    One training tower and one validation tower are built per core
    (`FLAGS.num_core_per_host`); losses and gradients are averaged across
    towers. Every epoch performs a full pass over the training iterator
    followed by a full pass over the validation iterator. Progress is
    reported through `tf.logging` and the TensorBoard writers created by
    `tb.create_writers`; checkpoints are written to `FLAGS.model_dir` every
    `FLAGS.save_steps` steps when that flag is set.

    Args:
        ps_device: device spec forwarded to `assign_to_gpu` for every tower.
    """
    ##### Get input function and model function

    train_input_fn, record_info_dict = data_utils.get_input_fn(
        info_dir=os.path.join(FLAGS.record_info_dir, "train"),
        split="train",
        bsz_per_host=FLAGS.train_batch_size,
        seq_len=FLAGS.seq_len,
        reuse_len=FLAGS.reuse_len,
        bi_data=FLAGS.bi_data,
        num_hosts=1,
        num_core_per_host=1,  # set to one no matter how many GPUs
        perm_size=FLAGS.perm_size,
        mask_alpha=FLAGS.mask_alpha,
        mask_beta=FLAGS.mask_beta,
        use_bfloat16=FLAGS.use_bfloat16,
        num_predict=FLAGS.num_predict)

    valid_input_fn, record_info_dict_valid = data_utils.get_input_fn(
        info_dir=os.path.join(FLAGS.record_info_dir, "valid"),
        split="valid",
        bsz_per_host=FLAGS.train_batch_size,
        seq_len=FLAGS.seq_len,
        reuse_len=FLAGS.reuse_len,
        bi_data=FLAGS.bi_data,
        num_hosts=1,
        num_core_per_host=1,
        perm_size=FLAGS.perm_size,
        mask_alpha=FLAGS.mask_alpha,
        mask_beta=FLAGS.mask_beta,
        use_bfloat16=FLAGS.use_bfloat16,
        num_predict=FLAGS.num_predict)

    num_train_batches = record_info_dict["num_batch"]
    tf.logging.info("num of train batches {}".format(
        record_info_dict["num_batch"]))
    tf.logging.info("num of validation batches {}".format(
        record_info_dict_valid["num_batch"]))

    ##### Create input tensors / placeholders
    bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

    params = {
        "batch_size": FLAGS.train_batch_size  # the whole batch
    }
    train_set = train_input_fn(params)
    valid_set = valid_input_fn(params)

    # Initializable iterators so each epoch can restart the datasets.
    t_iter = train_set.make_initializable_iterator()
    example = t_iter.get_next()
    v_iter = valid_set.make_initializable_iterator()
    v_example = v_iter.get_next()

    if FLAGS.num_core_per_host > 1:
        # train set: split every feature along axis 0, one slice per core
        examples = [{} for _ in range(FLAGS.num_core_per_host)]
        for key in example.keys():
            vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
            for device_id in range(FLAGS.num_core_per_host):
                examples[device_id][key] = vals[device_id]

        # validation set
        v_examples = [{} for _ in range(FLAGS.num_core_per_host)]
        for key in v_example.keys():
            vals = tf.split(v_example[key], FLAGS.num_core_per_host, 0)
            for device_id in range(FLAGS.num_core_per_host):
                v_examples[device_id][key] = vals[device_id]
    else:
        examples = [example]
        v_examples = [v_example]

    ##### Create computational graph
    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []
    v_tower_mems, v_tower_losses, v_tower_new_mems = [], [], []

    for i in range(FLAGS.num_core_per_host):
        # Variables are created by the first tower and reused by the others.
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
            tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            # The mems for each tower is a dictionary
            mems_i = {}
            v_mems_i = {}
            if FLAGS.mem_len:
                mems_i["mems"] = create_mems_tf(bsz_per_core)
                v_mems_i["mems"] = create_mems_tf(bsz_per_core)

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                is_training=True, features=examples[i], mems=mems_i)

            v_loss_i, v_new_mems_i = single_core_graph(is_training=False,
                                                       features=v_examples[i],
                                                       mems=v_mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

            v_tower_mems.append(v_mems_i)
            v_tower_losses.append(v_loss_i)
            v_tower_new_mems.append(v_new_mems_i)

    ## average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]

    if len(v_tower_losses) > 1:
        v_loss = tf.add_n(v_tower_losses) / len(v_tower_losses)
    else:
        v_loss = v_tower_losses[0]

    ## get train op
    train_op, learning_rate, gnorm = model_utils.get_train_op(
        FLAGS, None, num_train_batches, grads_and_vars=grads_and_vars)
    global_step = tf.train.get_global_step()

    ##### Training loop
    # initialize mems
    tower_mems_np = []
    v_tower_mems_np = []
    for i in range(FLAGS.num_core_per_host):
        mems_i_np = {}
        v_mems_i_np = {}
        for key in tower_mems[i].keys():
            mems_i_np[key] = initialize_mems_np(bsz_per_core)
            v_mems_i_np[key] = initialize_mems_np(bsz_per_core)
        tower_mems_np.append(mems_i_np)
        v_tower_mems_np.append(v_mems_i_np)

    saver = tf.train.Saver()

    gpu_options = tf.GPUOptions(allow_growth=True)

    model_utils.init_from_checkpoint(FLAGS, global_vars=True)

    # Create performance summaries for Tensorboard logging
    training_performance_summaries, valid_performance_summaries = tb.tensorboard_setup(
    )

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())

        # variables that are run in the session
        fetches = [
            loss, tower_new_mems, global_step, gnorm, learning_rate, train_op
        ]
        v_fetches = [v_loss, v_tower_new_mems]

        # Create writers for Tensorboard logging
        info_dict = {
            "id": FLAGS.run_id,
            "n_layers": FLAGS.n_layers,
            "d_model": FLAGS.d_model,
            "n_heads": FLAGS.n_head
        }
        train_summary_writer, valid_summary_writer = tb.create_writers(
            sess, info_dict, logging_dir=FLAGS.tb_logging_dir)

        total_loss, prev_step = 0., -1
        # Bound up-front so the validation report below still has a step to
        # log against when the training iterator yields no batches.
        curr_step = 0
        # NOTE: the per-core loops below use `core`, not the epoch variable;
        # sharing one name made the epoch log print the last core index.
        for epoch in range(FLAGS.epochs):

            # Train loop: runs until the training iterator is exhausted.
            try:
                sess.run(t_iter.initializer)
                while True:
                    # Feed the memory returned by the previous step.
                    feed_dict = {}
                    for core in range(FLAGS.num_core_per_host):
                        for key in tower_mems_np[core].keys():
                            for m, m_np in zip(tower_mems[core][key],
                                               tower_mems_np[core][key]):
                                feed_dict[m] = m_np

                    fetched = sess.run(fetches, feed_dict=feed_dict)
                    loss_np, tower_mems_np, curr_step = fetched[:3]
                    total_loss += loss_np
                    print(curr_step)

                    # Log training progress
                    if curr_step > 0 and curr_step % FLAGS.log_steps == 0:
                        curr_loss = total_loss / (curr_step - prev_step)
                        summ = tb.run_train(sess,
                                            training_performance_summaries,
                                            curr_loss)
                        train_summary_writer.add_summary(summ, curr_step)
                        tf.logging.info(
                            "[{}] | gnorm {:.2f} lr {:8.6f} "
                            "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".
                            format(curr_step, fetched[-3], fetched[-2],
                                   curr_loss, math.exp(curr_loss),
                                   curr_loss / math.log(2)))
                        total_loss, prev_step = 0., curr_step

                    # Save checkpoint
                    if curr_step > 0 and FLAGS.save_steps is not None and curr_step % FLAGS.save_steps == 0:
                        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                        saver.save(sess, save_path)
                        tf.logging.info(
                            "Model saved in path: {}".format(save_path))

            except tf.errors.OutOfRangeError:
                pass

            # Validation loop: runs until the validation iterator is
            # exhausted. The validation memory is carried over from the
            # previous epoch's validation pass (it is never re-initialised).
            v_total_loss, v_steps = 0., 0
            try:
                sess.run(v_iter.initializer)
                while True:
                    v_feed_dict = {}
                    for core in range(FLAGS.num_core_per_host):
                        for key in v_tower_mems_np[core].keys():
                            for m, m_np in zip(v_tower_mems[core][key],
                                               v_tower_mems_np[core][key]):
                                v_feed_dict[m] = m_np

                    v_fetched = sess.run(v_fetches, feed_dict=v_feed_dict)
                    v_loss_np, v_tower_mems_np = v_fetched[:]
                    v_total_loss += v_loss_np
                    v_steps += 1

            except tf.errors.OutOfRangeError:
                pass

            if v_steps > 0:
                val_loss = v_total_loss / v_steps
                v_pplx = math.exp(val_loss)
                tf.logging.info(
                    "Validation: [{}] | loss {:.2f} | pplx {:>7.2f}".format(
                        curr_step, val_loss, v_pplx))

                summ_valid = tb.run_valid(sess, valid_performance_summaries,
                                          val_loss, v_pplx)
                valid_summary_writer.add_summary(summ_valid, curr_step)
            else:
                # Avoid dividing by zero when the validation set is empty.
                tf.logging.warning(
                    "Validation produced no batches; skipping report")

            tf.logging.info(
                "------------ Epoch {} ------------".format(epoch))
Esempio n. 13
0
def evaluate(n_token, cutoffs):
  """Evaluate a checkpoint and report loss, perplexity, throughput and latency.

  Restores `FLAGS.eval_ckpt_path` (or the latest checkpoint in
  `FLAGS.model_dir` when that flag is None), runs `num_batch` batches of the
  `FLAGS.eval_split` split, and logs averages plus the latency percentiles
  listed in `FLAGS.percentiles` through `tf.logging` and `dllogger`.

  Args:
    n_token: forwarded to `single_core_graph`.
    cutoffs: forwarded to `single_core_graph`.
  """
  ##### Get input function and model function
  eval_input_fn, eval_record_info = data_utils.get_input_fn(
      record_info_dir=FLAGS.record_info_dir,
      split=FLAGS.eval_split,
      per_host_bsz=FLAGS.eval_batch_size,
      tgt_len=FLAGS.tgt_len,
      num_core_per_host=FLAGS.num_core_per_host,
      num_hosts=1)

  meters = {}
  warmup = 2
  meters['eval_throughput'] = AverageMeter(warmup=warmup)
  meters['eval_latency'] = AverageMeter(warmup=warmup, keep=True)

  num_batch = eval_record_info["num_batch"]
  if FLAGS.max_eval_batch > 0:
      num_batch = FLAGS.max_eval_batch
  tf.logging.info("num of batches {}".format(num_batch))

  ##### Create computational graph
  eval_set = eval_input_fn({
      "batch_size": FLAGS.eval_batch_size,
      "data_dir": FLAGS.data_dir})

  inputs, labels = eval_set.make_one_shot_iterator().get_next()

  bsz = FLAGS.eval_batch_size

  with tf.variable_scope(tf.get_variable_scope()):
    mems = [tf.placeholder(tf.float32,
                             [FLAGS.mem_len, bsz, FLAGS.d_model])
              for _ in range(FLAGS.n_layer)]

    loss, new_mems = single_core_graph(
        n_token=n_token,
        cutoffs=cutoffs,
        is_training=False,
        inp=inputs,
        tgt=labels,
        mems=mems)

  # Number of label elements in the batch; fetched every step for throughput.
  target_tokens = tf.size(labels)
  ##### Evaluation loop
  mems_np = [np.zeros([FLAGS.mem_len, bsz, FLAGS.d_model], dtype=np.float32)
          for layer in range(FLAGS.n_layer)]

  saver = tf.train.Saver()

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.global_variables_initializer())

    if FLAGS.eval_ckpt_path is None:
      eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    else:
      eval_ckpt_path = FLAGS.eval_ckpt_path
    tf.logging.info("Evaluate {}".format(eval_ckpt_path))
    saver.restore(sess, eval_ckpt_path)

    fetches = [loss, new_mems, target_tokens]

    format_str = "  >> processing batch {{:{0}d}}/{{:{0}d}}".format(
        len(str(num_batch)))

    # Report progress about ten times per run. Clamp to 1 so that runs with
    # fewer than ten batches do not take a modulo by zero.
    log_every = max(1, num_batch // 10)

    total_loss, total_cnt = 0, 0
    start_time = time.time()
    for step in range(num_batch):
      # Feed the memory returned by the previous batch.
      feed_dict = dict(zip(mems, mems_np))

      loss_np, mems_np, batch_tokens = sess.run(fetches, feed_dict=feed_dict)
      total_loss += loss_np
      total_cnt += 1

      # Per-batch timing: start_time is reset at the end of every iteration.
      elapsed = time.time()-start_time
      throughput = batch_tokens / elapsed
      latency = elapsed*1000
      meters['eval_throughput'].update(throughput)
      meters['eval_latency'].update(latency)
      if (step+1) % log_every == 0:
        tf.logging.info(format_str.format(step+1, num_batch))
        dllogger_data = {
            'eval_latency': latency,
            'eval_throughput': throughput,
        }
        dllogger.log(step=step+1, data=dllogger_data)

      start_time = time.time()

    avg_loss = total_loss / total_cnt
    latency_data = np.array(meters['eval_latency'].vals)
    tf.logging.info("Evaluating with: bs {}, math {} ".format(FLAGS.eval_batch_size, "fp16" if FLAGS.fp16 else "fp32"))
    tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}, tok/s {:>6.1f}, ms/batch {:>4.2f}".format(
        avg_loss, math.exp(avg_loss), avg_loss / math.log(2), meters['eval_throughput'].avg, meters['eval_latency'].avg))
    summary = {
        'eval_loss': avg_loss,
        'eval_ppl': math.exp(avg_loss),
        'eval_avg_throughput': meters['eval_throughput'].avg,
        'eval_avg_latency': meters['eval_latency'].avg,
    }
    for p in FLAGS.percentiles:
      p = int(p)
      latency_p = np.percentile(latency_data, p)
      tf.logging.info("Latency {}%: {:>4.2f} ms".format(p, latency_p))
      summary[f'eval_{p}%_latency'] = latency_p
    dllogger.log(step=tuple(), data=summary)
Esempio n. 14
0
def train(n_token, cutoffs, rank, local_rank, size):
  """Train with LAMB, optional Horovod and gradient accumulation.

  The batch is split into `FLAGS.batch_chunk` chunks: the first
  `batch_chunk - 1` chunks only accumulate gradients (`acc_op`), and the last
  one applies the accumulated, clipped gradients. With `batch_chunk == 1` the
  plain clipped-gradient train op is used. Memory is kept in non-trainable
  variables and updated in-graph after every chunk.

  Args:
    n_token: forwarded to `single_core_graph`.
    cutoffs: forwarded to `single_core_graph`.
    rank: only rank 0 logs progress and writes checkpoints.
    local_rank: GPU index made visible to this process.
    size: multiplier applied to the local token count when computing
      throughput (presumably the number of workers; confirm with the caller).
  """

  meters = {}
  warmup = 2 + 12/size
  meters['train_throughput'] = AverageMeter(warmup=warmup)
  train_batch_size = FLAGS.train_batch_size // FLAGS.batch_chunk
  ##### Get input function and model function
  train_input_fn, train_record_info = data_utils.get_input_fn(
      record_info_dir=FLAGS.record_info_dir,
      split="train",
      per_host_bsz=train_batch_size,
      tgt_len=FLAGS.tgt_len,
      num_core_per_host=FLAGS.num_core_per_host,
      num_hosts=1)

  tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))

  ##### Create computational graph
  train_set = train_input_fn({
      "batch_size": train_batch_size,
      "data_dir": FLAGS.data_dir})

  inputs, labels = train_set.make_one_shot_iterator().get_next()

  per_core_bsz = train_batch_size // FLAGS.num_core_per_host

  with tf.variable_scope(tf.get_variable_scope()):
    mems = [tf.Variable(tf.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], tf.float32), trainable=False)
              for _ in range(FLAGS.n_layer)]

    loss, new_mems, all_vars = single_core_graph(
        n_token=n_token,
        cutoffs=cutoffs,
        is_training=True,
        inp=inputs,
        tgt=labels,
        mems=mems)

    # Write the new memory back into the variables after each step.
    assign_mems = [mems[i].assign(new_mems[i]) for i in range(FLAGS.n_layer)]

  target_tokens = tf.size(labels)

  ## configure the optimizer
  global_step = tf.train.get_or_create_global_step()

  # warmup stage: increase the learning rate linearly
  if FLAGS.warmup_steps > 0:
    warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                * FLAGS.learning_rate
  else:
    warmup_lr = 0.0

  # decay stage: decay the learning rate using the cosine schedule
  decay_lr = tf.train.cosine_decay(
      FLAGS.learning_rate,
      global_step=global_step-FLAGS.warmup_steps,
      decay_steps=FLAGS.train_steps-FLAGS.warmup_steps,
      alpha=FLAGS.min_lr_ratio)

  # choose warmup or decay
  learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                           warmup_lr, decay_lr)

  # get the train op
  optimizer = lamb.LAMBOptimizer(learning_rate=learning_rate)
  if FLAGS.horovod:
    optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True)
  grads_and_vars = optimizer.compute_gradients(loss/FLAGS.batch_chunk, all_vars)
  grads, all_vars = zip(*grads_and_vars)

  # Gradient accumulation state: `in_progress` is False at the start of each
  # accumulation window, so the first chunk overwrites the accumulators and
  # later chunks add to them.
  accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) for tv in all_vars]
  in_progress = tf.get_variable(name="in_progress", shape=[], dtype=tf.bool, trainable=False,
                               initializer=tf.zeros_initializer)
  accum_ops = tf.cond(in_progress,
                      lambda: [accum_vars[i].assign_add(grad) for i, grad in enumerate(grads)],
                      lambda: [accum_vars[i].assign(grad) for i, grad in enumerate(grads)])
  with tf.control_dependencies(accum_ops + assign_mems):
    acc_op = in_progress.assign(tf.ones_like(in_progress))
  # Final chunk: accumulated gradients plus this chunk's gradients.
  final_accum_vars = [accum_vars[i] + gv for i,gv in enumerate(grads)]
  acc_clipped, acc_gnorm = tf.clip_by_global_norm(final_accum_vars, FLAGS.clip)
  clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
  acc_train_op = optimizer.apply_gradients(list(zip(acc_clipped, all_vars)), global_step)
  grads_and_vars = list(zip(clipped, all_vars))
  if FLAGS.jit_optimizer:
    jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
    with jit_scope():
      train_op = optimizer.apply_gradients(grads_and_vars, global_step)
  else:
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)
  final_op = tf.group(train_op, assign_mems)
  acc_final_op = tf.group(acc_train_op, assign_mems, in_progress.assign(tf.zeros_like(in_progress)))
  ##### Training loop
  saver = tf.train.Saver()

  gpu_options = tf.GPUOptions(allow_growth = True, visible_device_list = str(local_rank))
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options = gpu_options)) as sess:
    sess.run(tf.global_variables_initializer())
    if FLAGS.horovod:
      sess.run(hvd.broadcast_global_variables(0))

    accum = [acc_op, target_tokens]
    fetches = [loss, global_step, target_tokens, learning_rate, final_op if FLAGS.batch_chunk == 1 else acc_final_op]
    total_loss, prev_step, target_tokens = 0., -1, 0
    start_time = time.time()
    while True:
      # Accumulate gradients for all but the last chunk of the batch.
      for _ in range(FLAGS.batch_chunk-1):
        _,tt = sess.run(accum)
        target_tokens += tt
      fetched = sess.run(fetches)

      # The last fetch is the grouped train op, which evaluates to None, so
      # the learning rate must be read from index 3 rather than fetched[-1].
      loss_np, curr_step, tt, lr_np = fetched[:4]
      total_loss += loss_np
      target_tokens += tt

      if curr_step > 0 and curr_step % FLAGS.log_interval == 0:
        curr_loss = total_loss / (curr_step - prev_step)
        throughput = target_tokens * size / (time.time()-start_time)
        meters['train_throughput'].update(throughput)
        if rank == 0:
          tf.logging.info("step {} | lr {:8.9f} "
                        "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}, tok/s {:>6.0f}".format(
                            curr_step, lr_np,
                            curr_loss, math.exp(curr_loss), curr_loss / math.log(2), throughput))
          dllogger_data = {
              'lr': float(lr_np),
              'train_loss': curr_loss,
              'train_perplexity': math.exp(curr_loss),
              'train_throughput': throughput,
          }
          dllogger.log(step=int(curr_step), data=dllogger_data)
        total_loss, prev_step, target_tokens = 0., curr_step, 0
        start_time = time.time()

      if curr_step > 0 and curr_step % FLAGS.save_steps == 0 and rank == 0:
        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        tf.logging.info("Model saved in path: {}".format(save_path))

      # `>=` so the loop still ends if the step counter is already past
      # train_steps.
      if curr_step >= FLAGS.train_steps:
        break
    if rank == 0:
      tf.logging.info("Training throughput: {:>6.0f} tok/s".format(meters['train_throughput'].avg))
      summary = {
          'train_throughput': meters['train_throughput'].avg,
      }
      dllogger.log(step=tuple(), data=summary)
Esempio n. 15
0
def test(ps_device):
    """Run one full pass over the test split and report loss and perplexity.

    Builds one evaluation tower per core (`FLAGS.num_core_per_host`), averages
    the tower losses, iterates the test set until it is exhausted, and logs
    the mean loss and its perplexity through `tf.logging` and the TensorBoard
    test writer.

    Args:
        ps_device: device spec forwarded to `assign_to_gpu` for every tower.
    """

    test_input_fn, record_info_dict_test = data_utils.get_input_fn(
          info_dir=os.path.join(FLAGS.record_info_dir, "test"),
          split="test",
          bsz_per_host=FLAGS.test_batch_size,
          seq_len=FLAGS.seq_len,
          reuse_len=FLAGS.reuse_len,
          bi_data=FLAGS.bi_data,
          num_hosts=1,
          num_core_per_host=1,
          perm_size=FLAGS.perm_size,
          mask_alpha=FLAGS.mask_alpha,
          mask_beta=FLAGS.mask_beta,
          use_bfloat16=FLAGS.use_bfloat16,
          num_predict=FLAGS.num_predict)

    tf.logging.info("num of test batches {}".format(record_info_dict_test["num_batch"]))

    ##### Create input tensors / placeholders
    bsz_per_core = FLAGS.test_batch_size // FLAGS.num_core_per_host

    params = {
        "batch_size": FLAGS.test_batch_size # the whole batch
    }
    test_set = test_input_fn(params)

    t_iter = test_set.make_initializable_iterator()
    t_example = t_iter.get_next()

    if FLAGS.num_core_per_host > 1:
        # test set: split every feature along axis 0, one slice per core
        t_examples = [{} for _ in range(FLAGS.num_core_per_host)]
        for key in t_example.keys():
            # Split the feature tensor from the iterator (`t_example`), not
            # the per-core list being filled (`t_examples`).
            vals = tf.split(t_example[key], FLAGS.num_core_per_host, 0)
            for device_id in range(FLAGS.num_core_per_host):
                t_examples[device_id][key] = vals[device_id]
    else:
        t_examples = [t_example]

    ##### Create computational graph
    v_tower_mems, v_tower_losses, v_tower_new_mems = [], [], []

    for i in range(FLAGS.num_core_per_host):
        # Variables are created by the first tower and reused by the others.
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
            tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            # The mems for each tower is a dictionary
            v_mems_i = {}
            if FLAGS.mem_len:
                v_mems_i["mems"] = create_mems_tf(bsz_per_core)

            v_loss_i, v_new_mems_i = single_core_graph(
                features=t_examples[i],
                mems=v_mems_i)

            v_tower_mems.append(v_mems_i)
            v_tower_losses.append(v_loss_i)
            v_tower_new_mems.append(v_new_mems_i)

    ## average losses across towers
    if len(v_tower_losses) > 1:
      v_loss = tf.add_n(v_tower_losses) / len(v_tower_losses)
    else:
      v_loss = v_tower_losses[0]

    gpu_options = tf.GPUOptions(allow_growth=True)

    model_utils.init_from_checkpoint(FLAGS, global_vars=True)

    # Create performance summaries for Tensorboard logging
    test_performance_summaries = tb.tensorboard_setup_test()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
        gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())

        # Create writers for Tensorboard logging
        test_summary_writer = tb.create_test_writer(sess, logging_dir=FLAGS.tb_logging_dir)

        # initialize mems: exactly one dict per core, even when a tower has
        # no memory keys, so the feed loop below can index every core.
        v_tower_mems_np = []
        for i in range(FLAGS.num_core_per_host):
            v_mems_i_np = {}
            for key in v_tower_mems[i].keys():
                v_mems_i_np[key] = initialize_mems_np(bsz_per_core)
            v_tower_mems_np.append(v_mems_i_np)

        v_fetches = [v_loss, v_tower_new_mems]

        sess.run(t_iter.initializer)
        v_total_loss = 0.
        v_steps = 0

        # Runs until the test iterator is exhausted.
        try:
            while True:
                # Feed the memory returned by the previous step.
                v_feed_dict = {}
                for i in range(FLAGS.num_core_per_host):
                    for key in v_tower_mems_np[i].keys():
                        for m, m_np in zip(v_tower_mems[i][key], v_tower_mems_np[i][key]):
                            v_feed_dict[m] = m_np

                v_fetched = sess.run(v_fetches, feed_dict=v_feed_dict)
                v_loss_np, v_tower_mems_np = v_fetched[:]
                v_total_loss += v_loss_np
                v_steps += 1
                print(v_steps)

        except tf.errors.OutOfRangeError:
            pass

        if v_steps > 0:
            test_loss = v_total_loss/v_steps
            t_pplx = math.exp(test_loss)
            tf.logging.info("Test: loss {:.2f} | pplx {:>7.2f}".format(
                            test_loss,  t_pplx))

            summ_test = tb.run_test(sess, test_performance_summaries, test_loss, t_pplx)
            test_summary_writer.add_summary(summ_test, 1)
        else:
            # Avoid dividing by zero when the test set is empty.
            tf.logging.warning("Test set produced no batches; skipping report")
Esempio n. 16
0
def train(n_token, cutoffs, ps_device):
    """Multi-tower training loop with Adam and a warmup + cosine LR schedule.

    Builds one tower per core (`FLAGS.num_core_per_host`), averages losses
    and gradients, clips gradients by global norm and trains until the global
    step reaches `FLAGS.train_steps`. TensorFlow log records are additionally
    written to `run_train.log`. Checkpoints go to `FLAGS.model_dir` every
    `FLAGS.save_steps` steps.

    Args:
        n_token: forwarded to `single_core_graph`.
        cutoffs: forwarded to `single_core_graph`.
        ps_device: device spec forwarded to `assign_to_gpu` for every tower.
    """
    # get TF logger
    log = logging.getLogger('tensorflow')
    log.setLevel(logging.INFO)

    # create formatter and add it to the handlers
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # create file handler that mirrors INFO-and-above records to a file
    fh = logging.FileHandler('run_train.log')
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    log.addHandler(fh)

    ##### Get input function and model function
    train_input_fn, train_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1,
        use_tpu=False)

    tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))

    ##### Create computational graph
    train_set = train_input_fn({
        "batch_size": FLAGS.train_batch_size,
        "data_dir": FLAGS.data_dir
    })

    input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

    # One slice of the batch (along axis 0) per core.
    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_core_per_host):
        # Variables are created by the first tower and reused by the others.
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
            tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            mems_i = [
                tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)
            ]

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=True,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    ## average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]
    grads, all_vars = zip(*grads_and_vars)

    ## clip gradient
    clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
    grads_and_vars = list(zip(clipped, all_vars))

    ## configure the optimizer
    global_step = tf.train.get_or_create_global_step()

    # warmup stage: increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                    * FLAGS.learning_rate
    else:
        warmup_lr = 0.0

    # decay stage: decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    # choose warmup or decay
    learning_rate = tf.where(global_step < FLAGS.warmup_steps, warmup_lr,
                             decay_lr)

    # get the train op
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    ##### Training loop
    # Zero-initialised memory for every layer of every tower.
    tower_mems_np = [[
        np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                 dtype=np.float32) for layer in range(FLAGS.n_layer)
    ] for core in range(FLAGS.num_core_per_host)]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.warm_start_path is not None:
            tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
            saver.restore(sess, FLAGS.warm_start_path)

        fetches = [
            loss, tower_new_mems, global_step, gnorm, learning_rate, train_op
        ]

        total_loss, prev_step = 0., -1
        while True:
            # Feed the memory returned by the previous step.
            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)

            loss_np, tower_mems_np, curr_step = fetched[:3]
            total_loss += loss_np

            if curr_step % 100 == 0:
                print("Current step:", curr_step)

            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                curr_loss = total_loss / (curr_step - prev_step)
                # fetched[-3] is gnorm and fetched[-2] is learning_rate.
                tf.logging.info(
                    "[{}] | gnorm {:.2f} lr {:8.6f} "
                    "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                        curr_step, fetched[-3], fetched[-2], curr_loss,
                        math.exp(curr_loss), curr_loss / math.log(2)))
                total_loss, prev_step = 0., curr_step

            if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                saver.save(sess, save_path)
                tf.logging.info("Model saved in path: {}".format(save_path))

            # `>=` rather than `==`: a warm start may restore a global step
            # already past train_steps, which `==` would never match.
            if curr_step >= FLAGS.train_steps:
                break
def _dynamic_eval_build_towers(n_token, cutoffs, ps_device, input_feed,
                               label_feed, per_core_bsz):
    """Build one model replica per core and merge their losses and gradients.

    Args:
      n_token: forwarded unchanged to `single_core_graph`.
      cutoffs: forwarded unchanged to `single_core_graph`.
      ps_device: forwarded to `assign_to_gpu` for device placement.
      input_feed: input batch tensor; split along axis 0 into
        `FLAGS.num_core_per_host` slices.
      label_feed: label batch tensor; split the same way as `input_feed`.
      per_core_bsz: second dimension of every memory placeholder.

    Returns:
      A tuple `(tower_mems, tower_new_mems, loss, grads_and_vars)` where
      `tower_mems` holds the per-core lists of memory placeholders,
      `tower_new_mems` the per-core memories returned by the model, `loss`
      the mean of the per-core losses and `grads_and_vars` the (averaged)
      gradient/variable pairs.
    """
    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_core_per_host):
        # AUTO_REUSE lets every tower (and every stage of dynamic_eval that
        # rebuilds the graph) share one set of variables.
        with tf.device(assign_to_gpu(i, ps_device)), \
                tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):

            mems_i = [
                tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)
            ]

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=True,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    ## average losses (and gradients) across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]

    return tower_mems, tower_new_mems, loss, grads_and_vars


def _dynamic_eval_zero_mems(per_core_bsz):
    """Return all-zero float32 memories, one list of layers per core."""
    return [[
        np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                 dtype=np.float32) for _ in range(FLAGS.n_layer)
    ] for _ in range(FLAGS.num_core_per_host)]


def _dynamic_eval_restore(saver, sess):
    """Restore `FLAGS.eval_ckpt_path`, or the latest checkpoint in model_dir."""
    if FLAGS.eval_ckpt_path is None:
        eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    else:
        eval_ckpt_path = FLAGS.eval_ckpt_path

    tf.logging.info("Evaluate {}".format(eval_ckpt_path))
    saver.restore(sess, eval_ckpt_path)


def _dynamic_eval_run_steps(sess, fetches, tower_mems, tower_mems_np,
                            num_steps, log_every, progress_fmt):
    """Run `fetches` for `num_steps` steps, carrying memories between steps.

    `fetches` must start with (loss, new memories, token count). Returns the
    count-weighted loss sum and the total count over all steps.
    """
    total_loss, total_cnt = 0, 0

    for step in range(num_steps):
        if step % log_every == 0:
            tf.logging.info(progress_fmt.format(step, num_steps))

        # Feed the memories produced by the previous step (zeros at first).
        feed_dict = {}
        for i in range(FLAGS.num_core_per_host):
            for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                feed_dict[m] = m_np

        fetched = sess.run(fetches, feed_dict=feed_dict)

        loss_np, tower_mems_np, cnt_np = fetched[:3]
        total_loss += loss_np * cnt_np
        total_cnt += cnt_np

    return total_loss, total_cnt


def dynamic_eval(n_token, cutoffs, ps_device):
    """Evaluate a checkpoint while applying an optimizer update on every batch.

    When `FLAGS.rms` is set, a first stage runs `DynamicEvalOpt` with
    `gradstat = True` over `num_batch // 5000` training batches
    (presumably to collect gradient statistics -- confirm against
    `DynamicEvalOpt`). The second stage then runs over
    `num_batch // FLAGS.ratio` batches of `FLAGS.eval_split`, applying
    `train_op` at every step, and logs the count-weighted average loss,
    perplexity and bits-per-character.

    Differences from the previous revision:
      * the progress-logging interval is clamped to at least 1, so small
        datasets no longer raise ZeroDivisionError in the modulo;
      * the unused average loss of the statistics stage is no longer computed
        (it divided by zero when no statistics batch was run);
      * the evaluation progress line reports the number of steps actually
        run instead of the full `num_batch`;
      * the final report is skipped with a warning when nothing was evaluated.

    Args:
      n_token: forwarded to `single_core_graph`.
      cutoffs: forwarded to `single_core_graph`.
      ps_device: forwarded to `assign_to_gpu`.
    """
    ##### Get input function and model function
    if FLAGS.rms:
        ##using training data to collect gradient statistics
        train_input_fn, train_record_info = data_utils.get_input_fn(
            record_info_dir=FLAGS.record_info_dir,
            split="train",
            per_host_bsz=FLAGS.train_batch_size,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=1,
            use_tpu=False)

        num_batch = train_record_info["num_batch"]

        tf.logging.info("num of batches {}".format(num_batch))

        ##### Create computational graph
        train_set = train_input_fn({
            "batch_size": FLAGS.train_batch_size,
            "data_dir": FLAGS.data_dir
        })

        input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

        per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

        tower_mems, tower_new_mems, loss, grads_and_vars = \
            _dynamic_eval_build_towers(n_token, cutoffs, ps_device,
                                       input_feed, label_feed, per_core_bsz)

        global_step = tf.train.get_or_create_global_step()

        optimizer = DynamicEvalOpt(learning_rate=FLAGS.learning_rate,
                                   decay_rate=FLAGS.decay_rate,
                                   eps=FLAGS.epsilon)
        optimizer.gradstat = True
        train_op = optimizer.apply_gradients(grads_and_vars, global_step)

        tower_mems_np = _dynamic_eval_zero_mems(per_core_bsz)

        saver = tf.train.Saver()

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(tf.global_variables_initializer())

            _dynamic_eval_restore(saver, sess)

            fetches = [loss, tower_new_mems, tf.size(label_feed), train_op]

            ## only small subset of training set used for gradient stats to save time
            num_steps = num_batch // 5000
            if num_steps == 0:
                tf.logging.warning(
                    "fewer than 5000 training batches: no gradient "
                    "statistics batch will be processed")

            format_str = "  >> processing batch for gradient statistics {{:{0}d}}/{{:{0}d}} ..".format(
                len(str(num_steps)))

            # max(1, ...) keeps the modulo valid when num_batch < 50000.
            _dynamic_eval_run_steps(sess, fetches, tower_mems, tower_mems_np,
                                    num_steps, max(1, num_batch // 50000),
                                    format_str)

    #####Done gradstat

    ###starting dynamic eval

    eval_input_fn, eval_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split=FLAGS.eval_split,
        per_host_bsz=FLAGS.eval_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1,
        use_tpu=False)

    num_batch = eval_record_info["num_batch"]

    tf.logging.info("num of batches {}".format(num_batch))

    ##### Create computational graph
    eval_set = eval_input_fn({
        "batch_size": FLAGS.eval_batch_size,
        "data_dir": FLAGS.data_dir
    })

    input_feed, label_feed = eval_set.make_one_shot_iterator().get_next()

    per_core_bsz = FLAGS.eval_batch_size // FLAGS.num_core_per_host

    tower_mems, tower_new_mems, loss, grads_and_vars = \
        _dynamic_eval_build_towers(n_token, cutoffs, ps_device,
                                   input_feed, label_feed, per_core_bsz)

    ## configure the optimizer
    global_step = tf.train.get_or_create_global_step()
    if not FLAGS.rms:

        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=FLAGS.learning_rate
        )  # DynamicEvalPS(learning_rate=FLAGS.learning_rate )
    else:
        # Reuse the DynamicEvalOpt instance created in the statistics stage.
        # NOTE(review): the session below re-runs the global initializer and
        # restores the checkpoint; whether the statistics gathered above
        # survive depends on how DynamicEvalOpt stores them -- confirm.
        optimizer.gradstat = False
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    ##### Evaluation loop
    tower_mems_np = _dynamic_eval_zero_mems(per_core_bsz)

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        _dynamic_eval_restore(saver, sess)

        fetches = [loss, tower_new_mems, tf.size(label_feed), train_op]

        num_steps = num_batch // FLAGS.ratio
        format_str = "  >> processing batch {{:{0}d}}/{{:{0}d}} ..".format(
            len(str(num_batch)))

        # max(1, ...) keeps the modulo valid for small evaluation sets.
        total_loss, total_cnt = _dynamic_eval_run_steps(
            sess, fetches, tower_mems, tower_mems_np, num_steps,
            max(1, num_batch // (10 * FLAGS.ratio)), format_str)

        if total_cnt == 0:
            tf.logging.warning(
                "no evaluation batch was processed; skipping loss report")
            return

        avg_loss = total_loss / total_cnt
        tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
            avg_loss, math.exp(avg_loss), avg_loss / math.log(2)))
Esempio n. 18
0
def train_epoch(epoch, csv_logger, n_token, cutoffs):
    """Build the multi-GPU training graph and run one pass over the train split.

    The latest checkpoint in `FLAGS.model_dir` (if any) is restored before
    training, and a checkpoint is written every `FLAGS.save_steps` global
    steps and once more at the end of the epoch. Every `FLAGS.iterations`
    global steps the mean loss of the steps run since the previous report is
    logged and written to `csv_logger`.

    The reported mean divides by the number of steps actually accumulated in
    the current window. The earlier `curr_step - prev_step` denominator was
    wrong after resuming from a checkpoint (it counted all prior global
    steps) and divided by zero at the end of an epoch whose last step fell on
    a reporting boundary.

    Args:
      epoch: epoch index; used in log messages and the CSV row.
      csv_logger: object with a `writerow(dict)` method receiving the keys
        train_loss, train_ppl, train_bpc, lr, global_step and epoch.
      n_token: forwarded to `single_core_graph`.
      cutoffs: forwarded to `single_core_graph`.
    """
    ps_device = "/gpu:0"

    train_input_fn, train_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_gpu,
        num_hosts=1,
        use_tpu=False)

    tf.logging.info("-" * 30)
    tf.logging.info("Starting epoch {}!".format(epoch))
    tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))
    num_batch = train_record_info["num_batch"]

    train_set = train_input_fn({
        "batch_size": FLAGS.train_batch_size,
        "data_dir": FLAGS.data_dir})

    input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

    inputs = tf.split(input_feed, FLAGS.num_gpu, 0)
    labels = tf.split(label_feed, FLAGS.num_gpu, 0)

    per_core_bsz = FLAGS.train_batch_size // FLAGS.num_gpu
    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_gpu):
        # The first tower creates the variables; later towers reuse them.
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
                tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            mems_i = [tf.placeholder(tf.float32,
                                     [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                      for _ in range(FLAGS.n_layer)]

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=True,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    ## average losses (and gradients) across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]

    grads, all_vars = zip(*grads_and_vars)

    ## clip gradients by global norm
    clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
    grads_and_vars = list(zip(clipped, all_vars))

    ## learning rate: linear warm-up followed by cosine decay
    global_step = tf.train.get_or_create_global_step()
    total_steps = FLAGS.epochs * num_batch

    if FLAGS.warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                    * FLAGS.learning_rate
    else:
        warmup_lr = 0.0

    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step-FLAGS.warmup_steps,
        decay_steps=total_steps-FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    ## memories start as zeros and are replaced by the model output each step
    tower_mems_np = [
        [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
            for _ in range(FLAGS.n_layer)]
        for _ in range(FLAGS.num_gpu)
    ]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        latest_ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
        if latest_ckpt is not None:
            tf.logging.info("loading saved model from {}".format(latest_ckpt))
            saver.restore(sess, latest_ckpt)
        else:
            tf.logging.info("No previously saved model. Starting from scratch!")

        fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]

        # window_steps counts the steps whose loss is currently in total_loss.
        total_loss, window_steps = 0., 0
        curr_step, fetched = None, None

        for _ in range(num_batch):
            feed_dict = {}
            for i in range(FLAGS.num_gpu):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)

            loss_np, tower_mems_np, curr_step = fetched[:3]
            total_loss += loss_np
            window_steps += 1

            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                curr_loss = total_loss / window_steps
                tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                    "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                    curr_step, fetched[-3], fetched[-2],
                    curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
                log_dict = {
                    'train_loss': curr_loss,
                    'train_ppl': math.exp(curr_loss),
                    'train_bpc': curr_loss / math.log(2),
                    'lr': fetched[-2],
                    'global_step': curr_step,
                    'epoch': epoch
                }
                csv_logger.writerow(log_dict)
                total_loss, window_steps = 0., 0

            if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                saver.save(sess, save_path)
                tf.logging.info("Finished Step : {}".format(curr_step))
                tf.logging.info("Model saved in path: {}".format(save_path))

        # Report whatever is left in the last, possibly partial, window. It
        # is empty when the final step coincided with a periodic report (or
        # when there were no batches at all), in which case nothing is logged.
        if window_steps > 0:
            curr_loss = total_loss / window_steps
            tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                curr_step, fetched[-3], fetched[-2],
                curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))

        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        tf.logging.info("Finished Epoch {}".format(epoch))
        tf.logging.info("Model saved in path: {}".format(save_path))
        tf.logging.info("-" * 30)