Code example #1
def evaluate(n_token, cutoffs, ps_device):
    # Get input function and model function
    eval_input_fn, eval_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split=FLAGS.eval_split,
        per_host_bsz=FLAGS.eval_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1,
        use_tpu=False)

    num_batch = eval_record_info["num_batch"]
    if FLAGS.max_eval_batch > 0:
        num_batch = FLAGS.max_eval_batch
    tf.logging.info("num of batches {}".format(num_batch))

    # Create computational graph
    eval_set = eval_input_fn({
        "batch_size": FLAGS.eval_batch_size,
        "data_dir": FLAGS.data_dir})

    input_feed, label_feed = eval_set.make_one_shot_iterator().get_next()

    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = FLAGS.eval_batch_size // FLAGS.num_core_per_host
    tower_mems, tower_losses, tower_new_mems = [], [], []

    for i in range(FLAGS.num_core_per_host):
        with tf.device(assign_to_gpu(i, ps_device)), \
             tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            mems_i = [tf.placeholder(tf.float32,
                                     [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                      for _ in range(FLAGS.n_layer)]

            loss_i, new_mems_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=False,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)

    # sum losses across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
    else:
        loss = tower_losses[0]

    # Evaluation loop
    tower_mems_np = [
        [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
         for layer in range(FLAGS.n_layer)]
        for core in range(FLAGS.num_core_per_host)
    ]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.eval_ckpt_path is None:
            eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
        else:
            eval_ckpt_path = FLAGS.eval_ckpt_path
        tf.logging.info("Evaluate {}".format(eval_ckpt_path))
        saver.restore(sess, eval_ckpt_path)

        fetches = [loss, tower_new_mems, tf.size(label_feed)]

        format_str = "  >> processing batch {{:{0}d}}/{{:{0}d}} ..".format(
            len(str(num_batch)))

        total_loss, total_cnt = 0, 0
        for step in range(num_batch):
            if step % max(num_batch // 10, 1) == 0:  # guard against num_batch < 10
                tf.logging.info(format_str.format(step, num_batch))

            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)

            loss_np, tower_mems_np, cnt_np = fetched[:3]
            total_loss += loss_np * cnt_np
            total_cnt += cnt_np

        avg_loss = total_loss / total_cnt
        tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
            avg_loss, math.exp(avg_loss), avg_loss / math.log(2)))
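
The closing log line converts the average per-token loss (in nats) into perplexity and bits-per-character. A minimal, standalone sketch of that conversion (plain Python, no TensorFlow required; the function name is illustrative):

import math

def loss_to_metrics(avg_loss):
    # avg_loss: mean cross-entropy per token, in nats
    pplx = math.exp(avg_loss)        # perplexity = e^loss
    bpc = avg_loss / math.log(2)     # bits per character = loss / ln(2)
    return pplx, bpc

# e.g. an average loss of 1.0 nat gives pplx ~2.72 and bpc ~1.44
print(loss_to_metrics(1.0))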
Code example #2
def inference(n_token, cutoffs, ps_device):
    dataset_name = "doupo"
    tmp_Vocab = Vocab()
    tmp_Vocab.count_file("../data/{}/train.txt".format(dataset_name), add_eos=False)
    tmp_Vocab.build_vocab()

    n_token = len(tmp_Vocab)
    # print(tmp_Vocab.idx2sym)

    test_list = tf.placeholder(tf.int64, shape=[1, None])
    dataset = tf.data.Dataset.from_tensors(test_list)
    # dataset = dataset.batch(1, drop_remainder=True)

    iterator = dataset.make_initializable_iterator()
    input_feed = iterator.get_next()

    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    # inputs = input_feed

    per_core_bsz = 1
    tower_mems, tower_losses, tower_new_mems = [], [], []
    tower_output = []
    tower_mems_id = []
    tower_new_mems_id = []
    tower_attn_prob = []

    for i in range(FLAGS.num_core_per_host):
        with tf.device(assign_to_gpu(i, ps_device)), \
             tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            mems_i = [tf.placeholder(tf.float32,
                                     [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                      for _ in range(FLAGS.n_layer)]

            mems_i_id = [
                tf.placeholder(tf.int64, [FLAGS.mem_len, per_core_bsz])
                for _ in range(FLAGS.n_layer)
            ]

            new_mems_i, output_i, new_mems_i_id, attn_prob_i = single_core_graph_for_inference(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=False,
                inp=inputs[i],
                mems=mems_i,
                mems_id=mems_i_id)

            tower_mems.append(mems_i)
            tower_new_mems.append(new_mems_i)
            tower_output.append(output_i)
            tower_mems_id.append(mems_i_id)
            tower_new_mems_id.append(new_mems_i_id)
            tower_attn_prob.append(attn_prob_i)

    # Evaluation loop
    tower_mems_np = [
        [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
         for layer in range(FLAGS.n_layer)]
        for core in range(FLAGS.num_core_per_host)
    ]

    tower_mems_id_np = [
        [np.zeros([FLAGS.mem_len, per_core_bsz], dtype=np.int64)
         for layer in range(FLAGS.n_layer)]
        for core in range(FLAGS.num_core_per_host)
    ]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.eval_ckpt_path is None:
            eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
        else:
            eval_ckpt_path = FLAGS.eval_ckpt_path
        print('eval_ckpt_path:', eval_ckpt_path)
        saver.restore(sess, eval_ckpt_path)

        # attention_score = tf.get_variable('transformer/layer_2/rel_attn/transpose_1:0')

        fetches = [tower_new_mems,
                   tower_output,
                   tower_new_mems_id,
                   tower_attn_prob,
                   'transformer/adaptive_embed/lookup_table:0']

        while True:
            input_text = input("seed text >>> ")
            while not input_text:
                print('Prompt should not be empty!')
                input_text = input("Model prompt >>> ")
            encoded_input = tmp_Vocab.encode_sents(input_text, ordered=True)

            with open('{}.txt'.format(dataset_name), 'a') as f:
                f.write('-' * 100+'\n')
                f.write('input:\n')
                f.write(input_text+'\n')

            output_len = 200
            progress = ProgressBar()
            for step in progress(range(output_len)):
                time.sleep(0.01)
                feed_dict = {}
                for i in range(FLAGS.num_core_per_host):
                    for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                        feed_dict[m] = m_np

                    for id, id_np in zip(tower_mems_id[i], tower_mems_id_np[i]):
                        feed_dict[id] = id_np

                sess.run(iterator.initializer, feed_dict={test_list: [encoded_input]})
                fetched = sess.run(fetches, feed_dict=feed_dict)

                tower_mems_np, output = fetched[:2]

                tower_mems_id_np = fetched[2]

                attn_prob = fetched[3]
                lookup_table = fetched[4]
                # print(attention_score)
                # print(np.array(lookup_table).shape)
                # print(np.array(tower_mems_id_np).shape)

                tmp_list = output[0][-1][0]
                tmp_list = tmp_list.tolist()

                # Below are 6 ways to post-process the output; keep the one you need and comment out the others
                # todo take top-1
                index = top_one_result(tmp_list)
                # todo diversity
                # index = gen_diversity(tmp_list)
                # todo base on keyword
                # index = gen_on_keyword(tmp_Vocab, '喜', tmp_list, lookup_table)

                # # todo visualize candidate tokens
                # visualize_prob(tmp_Vocab, tmp_list,
                # '../exp_result/{}/candidates'.format(dataset_name+'mem_len500'), len(input_text))

                # # # todo visualize attention per layer
                # visualize_attention_per_layer(tmp_Vocab, tower_mems_id_np, attn_prob, index,
                #                               '../exp_result/{}/attention_per_layer'.format(dataset_name+'mem_len500'),
                #                               len(input_text))

                # # # todo visualize attention per head
                # visualize_attention_per_head(tmp_Vocab, tower_mems_id_np, attn_prob, index,
                #                              '../exp_result/{}/attention_per_head'.format(dataset_name+'_repeat'),
                #                              len(input_text))

                input_text += tmp_Vocab.get_sym(index) if tmp_Vocab.get_sym(index) != '<eos>' else '\n'
                encoded_input = [index]

            print(input_text)

            with open('{}.txt'.format(dataset_name), 'a') as f:
                f.write('output:\n')
                f.write(input_text+'\n')
                f.write('-'*100+'\n')
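
The generation loop above emits one token per step and, in the un-commented branch, picks the next token with top_one_result, whose implementation is not shown in this snippet. Assuming it is a plain argmax over the output logits (and that gen_diversity is a temperature-sampling variant), a minimal sketch would be:

import numpy as np

def top_one_result(logits):
    # Greedy decoding: return the index of the highest-scoring token.
    return int(np.argmax(logits))

def gen_diversity(logits, temperature=1.0):
    # Hypothetical sampling variant: softmax with temperature, then sample.
    logits = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return int(np.random.choice(len(probs), p=probs))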
Code example #3
File: train_gpu.py  Project: luhuaei/xlnet
def train(ps_device):
  ##### Get input function and model function

  train_input_fn, record_info_dict = data_utils.get_input_fn(
      tfrecord_dir=FLAGS.record_info_dir,
      split="train",
      bsz_per_host=FLAGS.train_batch_size,
      seq_len=FLAGS.seq_len,
      reuse_len=FLAGS.reuse_len,
      bi_data=FLAGS.bi_data,
      num_hosts=1,
      num_core_per_host=1, # set to one no matter how many GPUs
      perm_size=FLAGS.perm_size,
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      uncased=FLAGS.uncased,
      num_passes=FLAGS.num_passes,
      use_bfloat16=FLAGS.use_bfloat16,
      num_predict=FLAGS.num_predict)

  # for key, info in record_info_dict.items():
  tf.compat.v1.logging.info("num of batches {}".format(record_info_dict["num_batch"]))

  ##### Create input tensors / placeholders
  bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

  params = {
      "batch_size": FLAGS.train_batch_size # the whole batch
  }
  train_set = train_input_fn(params)

  example = train_set.make_one_shot_iterator().get_next()

  if FLAGS.num_core_per_host > 1:
    examples = [{} for _ in range(FLAGS.num_core_per_host)]
    for key in example.keys():
      vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
      for device_id in range(FLAGS.num_core_per_host):
        examples[device_id][key] = vals[device_id]
  else:
    examples = [example]

  ##### Create computational graph
  tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

  for i in range(FLAGS.num_core_per_host):
    reuse = True if i > 0 else None
    with tf.device(assign_to_gpu(i, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

      # The mems for each tower is a dictionary
      mems_i = {}
      if FLAGS.mem_len:
        mems_i["mems"] = create_mems_tf(bsz_per_core)

      loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
          is_training=True,
          features=examples[i],
          mems=mems_i)

      tower_mems.append(mems_i)
      tower_losses.append(loss_i)
      tower_new_mems.append(new_mems_i)
      tower_grads_and_vars.append(grads_and_vars_i)

  ## average losses and gradients across towers
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
    grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
  else:
    loss = tower_losses[0]
    grads_and_vars = tower_grads_and_vars[0]

  ## get train op
  train_op, learning_rate, gnorm = model_utils.get_train_op(FLAGS, None,
      grads_and_vars=grads_and_vars)
  global_step = tf.train.get_global_step()

  ##### Training loop
  # initialize mems
  tower_mems_np = []
  for i in range(FLAGS.num_core_per_host):
    mems_i_np = {}
    for key in tower_mems[i].keys():
      mems_i_np[key] = initialize_mems_np(bsz_per_core)
    tower_mems_np.append(mems_i_np)

  saver = tf.train.Saver()

  gpu_options = tf.GPUOptions(allow_growth=True)

  model_utils.init_from_checkpoint(FLAGS, global_vars=True)

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
      gpu_options=gpu_options)) as sess:
    sess.run(tf.global_variables_initializer())

    fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]

    total_loss, prev_step = 0., -1
    while True:
      feed_dict = {}
      for i in range(FLAGS.num_core_per_host):
        for key in tower_mems_np[i].keys():
          for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
            feed_dict[m] = m_np

      fetched = sess.run(fetches, feed_dict=feed_dict)

      loss_np, tower_mems_np, curr_step = fetched[:3]
      total_loss += loss_np

      if curr_step > 0 and curr_step % FLAGS.iterations == 0:
        curr_loss = total_loss / (curr_step - prev_step)
        tf.compat.v1.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
            "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
            curr_step, fetched[-3], fetched[-2],
            curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
        total_loss, prev_step = 0., curr_step

      if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        tf.compat.v1.logging.info("Model saved in path: {}".format(save_path))

      if curr_step >= FLAGS.train_steps:
        break
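
Note the feed/fetch recurrence in the training loop: the memory lives in NumPy on the host, is fed into the graph through placeholders at every session.run, and is then replaced by the fetched new_mems so state carries over to the next step. A stripped-down sketch of that cycle (TF 1.x graph mode; the shapes and the "+ 1.0" update are toy stand-ins, not the real model):

import numpy as np
import tensorflow as tf

mem_len, bsz, d_model, n_layer = 16, 2, 8, 3

# Placeholders for the incoming memory and a dummy "updated memory" op.
mems = [tf.placeholder(tf.float32, [mem_len, bsz, d_model]) for _ in range(n_layer)]
new_mems = [m + 1.0 for m in mems]  # stand-in for the model's returned memory

mems_np = [np.zeros([mem_len, bsz, d_model], np.float32) for _ in range(n_layer)]

with tf.Session() as sess:
    for step in range(3):
        feed_dict = {m: m_np for m, m_np in zip(mems, mems_np)}
        mems_np = sess.run(new_mems, feed_dict=feed_dict)  # carry state forward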
Code example #4
def train(n_token, cutoffs, ps_device):
    # os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

    # Get input function and model function
    train_input_fn, train_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1,
        use_tpu=False)

    tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))

    # Create computational graph
    train_set = train_input_fn({
        "batch_size": FLAGS.train_batch_size,
        "data_dir": FLAGS.data_dir})

    input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    print_op = tf.print(inputs)

    per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_core_per_host):
        reuse = True if i > 0 else None
        #todo  review here
        with tf.device(assign_to_gpu(i, ps_device)), \
             tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
            mems_i = [tf.placeholder(tf.float32,
                                     [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                      for _ in range(FLAGS.n_layer)]

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=True,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    # average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]
    grads, all_vars = zip(*grads_and_vars)

    # clip gradient
    clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
    grads_and_vars = list(zip(clipped, all_vars))

    # configure the optimizer
    global_step = tf.train.get_or_create_global_step()

    # warmup stage: increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                    * FLAGS.learning_rate
    else:
        warmup_lr = 0.0

    # decay stage: decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    # choose warmup or decay
    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)

    # get the train op
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    # Training loop
    tower_mems_np = [
        [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
         for layer in range(FLAGS.n_layer)]
        for core in range(FLAGS.num_core_per_host)
    ]

    saver = tf.train.Saver()

    tf.summary.scalar('learning_rate', learning_rate)
    tf.summary.scalar('loss', loss)
    # tf.summary.scalar('pplx', math.exp(curr_loss))
    merged = tf.summary.merge_all()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        # todo placed here so the trainer directory is not created repeatedly and the summary variables can be displayed
        train_writer = tf.summary.FileWriter(os.path.join(FLAGS.model_dir, "log"), sess.graph)

        if FLAGS.warm_start_path is not None:
            tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
            saver.restore(sess, FLAGS.warm_start_path)

        fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]

        total_loss, prev_step = 0., -1
        while True:
            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            #old
            # fetched = sess.run(fetches, feed_dict=feed_dict)

            # with tf.control_dependencies([print_op]):
            summary, fetched = sess.run([merged, fetches], feed_dict=feed_dict)

            loss_np, tower_mems_np, curr_step = fetched[:3]
            total_loss += loss_np

            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                curr_loss = total_loss / (curr_step - prev_step)
                tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                                "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(curr_step, fetched[-3], fetched[-2], curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
                total_loss, prev_step = 0., curr_step
                train_writer.add_summary(summary, curr_step)

            if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                save_path = os.path.join(FLAGS.model_dir, "model-{}.ckpt".format(curr_step))
                saver.save(sess, save_path)
                tf.logging.info("Model saved in path: {}".format(save_path))

            if curr_step == FLAGS.train_steps:
                train_writer.close()
                break
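
The learning-rate logic above does a linear warmup into tf.train.cosine_decay with a floor of min_lr_ratio * learning_rate. The same schedule can be written as a plain Python function, which is handy for plotting it before training; a sketch assuming the same flag semantics:

import math

def lr_at_step(step, base_lr, warmup_steps, train_steps, min_lr_ratio):
    # Linear warmup: 0 -> base_lr over the first warmup_steps steps.
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * step / warmup_steps
    # Cosine decay from base_lr down to min_lr_ratio * base_lr.
    progress = (step - warmup_steps) / max(1, train_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return base_lr * ((1.0 - min_lr_ratio) * cosine + min_lr_ratio)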
Code example #5
def train(ps_device):
    ##### Get input function and model function

    train_input_fn, record_info_dict = data_utils.get_input_fn(
        info_dir=os.path.join(FLAGS.record_info_dir, "train"),
        split="train",
        bsz_per_host=FLAGS.train_batch_size,
        seq_len=FLAGS.seq_len,
        reuse_len=FLAGS.reuse_len,
        bi_data=FLAGS.bi_data,
        num_hosts=1,
        num_core_per_host=1,  # set to one no matter how many GPUs
        perm_size=FLAGS.perm_size,
        mask_alpha=FLAGS.mask_alpha,
        mask_beta=FLAGS.mask_beta,
        use_bfloat16=FLAGS.use_bfloat16,
        num_predict=FLAGS.num_predict)

    valid_input_fn, record_info_dict_valid = data_utils.get_input_fn(
        info_dir=os.path.join(FLAGS.record_info_dir, "valid"),
        split="valid",
        bsz_per_host=FLAGS.train_batch_size,
        seq_len=FLAGS.seq_len,
        reuse_len=FLAGS.reuse_len,
        bi_data=FLAGS.bi_data,
        num_hosts=1,
        num_core_per_host=1,
        perm_size=FLAGS.perm_size,
        mask_alpha=FLAGS.mask_alpha,
        mask_beta=FLAGS.mask_beta,
        use_bfloat16=FLAGS.use_bfloat16,
        num_predict=FLAGS.num_predict)

    # for key, info in record_info_dict.items():
    num_train_batches = record_info_dict["num_batch"]
    tf.logging.info("num of train batches {}".format(
        record_info_dict["num_batch"]))
    tf.logging.info("num of validation batches {}".format(
        record_info_dict_valid["num_batch"]))

    ##### Create input tensors / placeholders
    bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

    params = {
        "batch_size": FLAGS.train_batch_size  # the whole batch
    }
    train_set = train_input_fn(params)
    valid_set = valid_input_fn(params)

    t_iter = train_set.make_initializable_iterator()
    example = t_iter.get_next()
    v_iter = valid_set.make_initializable_iterator()
    v_example = v_iter.get_next()

    if FLAGS.num_core_per_host > 1:
        # train set
        examples = [{} for _ in range(FLAGS.num_core_per_host)]
        for key in example.keys():
            vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
            for device_id in range(FLAGS.num_core_per_host):
                examples[device_id][key] = vals[device_id]

        # validation set
        v_examples = [{} for _ in range(FLAGS.num_core_per_host)]
        for key in v_example.keys():
            vals = tf.split(v_example[key], FLAGS.num_core_per_host, 0)
            for device_id in range(FLAGS.num_core_per_host):
                v_examples[device_id][key] = vals[device_id]
    else:
        examples = [example]
        v_examples = [v_example]

    ##### Create computational graph
    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []
    v_tower_mems, v_tower_losses, v_tower_new_mems = [], [], []

    for i in range(FLAGS.num_core_per_host):
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
            tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            # The mems for each tower is a dictionary
            mems_i = {}
            v_mems_i = {}
            if FLAGS.mem_len:
                mems_i["mems"] = create_mems_tf(bsz_per_core)
                v_mems_i["mems"] = create_mems_tf(bsz_per_core)

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                is_training=True, features=examples[i], mems=mems_i)

            v_loss_i, v_new_mems_i = single_core_graph(is_training=False,
                                                       features=v_examples[i],
                                                       mems=v_mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

            v_tower_mems.append(v_mems_i)
            v_tower_losses.append(v_loss_i)
            v_tower_new_mems.append(v_new_mems_i)

    ## average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]

    if len(v_tower_losses) > 1:
        v_loss = tf.add_n(v_tower_losses) / len(v_tower_losses)
    else:
        v_loss = v_tower_losses[0]

    ## get train op
    train_op, learning_rate, gnorm = model_utils.get_train_op(
        FLAGS, None, num_train_batches, grads_and_vars=grads_and_vars)
    global_step = tf.train.get_global_step()

    ##### Training loop
    # initialize mems
    tower_mems_np = []
    v_tower_mems_np = []
    for i in range(FLAGS.num_core_per_host):
        mems_i_np = {}
        v_mems_i_np = {}
        for key in tower_mems[i].keys():
            mems_i_np[key] = initialize_mems_np(bsz_per_core)
            v_mems_i_np[key] = initialize_mems_np(bsz_per_core)
        tower_mems_np.append(mems_i_np)
        v_tower_mems_np.append(v_mems_i_np)

    saver = tf.train.Saver()

    gpu_options = tf.GPUOptions(allow_growth=True)

    model_utils.init_from_checkpoint(FLAGS, global_vars=True)

    # Create performance summaries for Tensorboard logging
    training_performance_summaries, valid_performance_summaries = tb.tensorboard_setup(
    )

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())

        # variables that are run in the session
        fetches = [
            loss, tower_new_mems, global_step, gnorm, learning_rate, train_op
        ]
        v_fetches = [v_loss, v_tower_new_mems]

        # Create writers for Tensorboard logging
        info_dict = {
            "id": FLAGS.run_id,
            "n_layers": FLAGS.n_layers,
            "d_model": FLAGS.d_model,
            "n_heads": FLAGS.n_head
        }
        train_summary_writer, valid_summary_writer = tb.create_writers(
            sess, info_dict, logging_dir=FLAGS.tb_logging_dir)

        total_loss, prev_step = 0., -1
        for epoch in range(FLAGS.epochs):

            # Train loop
            try:
                sess.run(t_iter.initializer)
                while True:
                    feed_dict = {}
                    for i in range(FLAGS.num_core_per_host):
                        for key in tower_mems_np[i].keys():
                            for m, m_np in zip(tower_mems[i][key],
                                               tower_mems_np[i][key]):
                                feed_dict[m] = m_np

                    fetched = sess.run(fetches, feed_dict=feed_dict)
                    loss_np, tower_mems_np, curr_step = fetched[:3]
                    total_loss += loss_np
                    print(curr_step)

                    # Log training progress
                    if curr_step > 0 and curr_step % FLAGS.log_steps == 0:
                        curr_loss = total_loss / (curr_step - prev_step)
                        summ = tb.run_train(sess,
                                            training_performance_summaries,
                                            curr_loss)
                        train_summary_writer.add_summary(summ, curr_step)
                        tf.logging.info(
                            "[{}] | gnorm {:.2f} lr {:8.6f} "
                            "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".
                            format(curr_step, fetched[-3], fetched[-2],
                                   curr_loss, math.exp(curr_loss),
                                   curr_loss / math.log(2)))
                        total_loss, prev_step = 0., curr_step

                    # Save checkpoint
                    if curr_step > 0 and FLAGS.save_steps is not None and curr_step % FLAGS.save_steps == 0:
                        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                        saver.save(sess, save_path)
                        tf.logging.info(
                            "Model saved in path: {}".format(save_path))

            except tf.errors.OutOfRangeError:
                pass

            # Validation loop
            try:
                sess.run(v_iter.initializer)
                v_total_loss, v_steps = 0., 0
                while True:
                    v_feed_dict = {}
                    for i in range(FLAGS.num_core_per_host):
                        for key in v_tower_mems_np[i].keys():
                            for m, m_np in zip(v_tower_mems[i][key],
                                               v_tower_mems_np[i][key]):
                                v_feed_dict[m] = m_np

                    v_fetched = sess.run(v_fetches, feed_dict=v_feed_dict)
                    v_loss_np, v_tower_mems_np = v_fetched[:]
                    v_total_loss += v_loss_np
                    v_steps += 1

            except tf.errors.OutOfRangeError:
                val_loss = v_total_loss / v_steps
                v_pplx = math.exp(val_loss)
                tf.logging.info(
                    "Validation: [{}] | loss {:.2f} | pplx {:>7.2f}".format(
                        curr_step, val_loss, v_pplx))

                summ_valid = tb.run_valid(sess, valid_performance_summaries,
                                          val_loss, v_pplx)
                valid_summary_writer.add_summary(summ_valid, curr_step)

            tf.logging.info("------------ Epoch {} ------------".format(i))
Code example #6
File: train_gpu.py  Project: fqararyah/xlnet
def train(ps_device):
    # Get input function and model function

    train_input_fn, record_info_dict = data_utils.get_input_fn(
        tfrecord_dir=FLAGS.record_info_dir,
        split="train",
        bsz_per_host=FLAGS.train_batch_size,
        seq_len=FLAGS.seq_len,
        reuse_len=FLAGS.reuse_len,
        bi_data=FLAGS.bi_data,
        num_hosts=1,
        num_core_per_host=1,  # set to one no matter how many GPUs
        perm_size=FLAGS.perm_size,
        mask_alpha=FLAGS.mask_alpha,
        mask_beta=FLAGS.mask_beta,
        uncased=FLAGS.uncased,
        num_passes=FLAGS.num_passes,
        use_bfloat16=FLAGS.use_bfloat16,
        num_predict=FLAGS.num_predict)

    # for key, info in record_info_dict.items():
    tf.logging.info("num of batches {}".format(record_info_dict["num_batch"]))

    # Create input tensors / placeholders
    bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

    params = {
        "batch_size": FLAGS.train_batch_size  # the whole batch
    }
    train_set = train_input_fn(params)

    example = train_set.make_one_shot_iterator().get_next()

    if FLAGS.num_core_per_host > 1:
        examples = [{} for _ in range(FLAGS.num_core_per_host)]
        for key in example.keys():
            vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
            for device_id in range(FLAGS.num_core_per_host):
                examples[device_id][key] = vals[device_id]
    else:
        examples = [example]

    # Create computational graph
    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_core_per_host):
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
                tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            # The mems for each tower is a dictionary
            mems_i = {}
            if FLAGS.mem_len:
                mems_i["mems"] = create_mems_tf(bsz_per_core)

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                is_training=True,
                features=examples[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    # average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]

    # get train op
    train_op, learning_rate, gnorm = model_utils.get_train_op(FLAGS, None,
                                                              grads_and_vars=grads_and_vars)
    global_step = tf.train.get_global_step()

    # Training loop
    # initialize mems
    tower_mems_np = []
    for i in range(FLAGS.num_core_per_host):
        mems_i_np = {}
        for key in tower_mems[i].keys():
            mems_i_np[key] = initialize_mems_np(bsz_per_core)
        tower_mems_np.append(mems_i_np)

    saver = tf.train.Saver()

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.97)  # allow_growth=True

    model_utils.init_from_checkpoint(FLAGS, global_vars=True)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False,
                                          gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())
        sess.graph.finalize()
        run_metadata = tf.RunMetadata()
        options = tf.RunOptions(trace_level=tf.RunOptions.SOFTWARE_TRACE)

        dot_rep = graph_to_dot(tf.get_default_graph())
        # s = Source(dot_rep, filename="test.gv", format="PNG")
        with open('profs/xln.dot', 'w') as fwr:
            fwr.write(str(dot_rep))

        operations_tensors = {}
        operations_attributes = {}
        operations_names = tf.get_default_graph().get_operations()
        count1 = 0
        count2 = 0

        for operation in operations_names:
            operation_name = operation.name
            operations_info = tf.get_default_graph(
            ).get_operation_by_name(operation_name).values()

            try:
                operations_attributes[operation_name] = []
                operations_attributes[operation_name].append(
                    operation.type)
                operations_attributes[operation_name].append(tf.get_default_graph(
                ).get_tensor_by_name(operation_name + ':0').dtype._is_ref_dtype)
            except Exception:
                pass

            if len(operations_info) > 0:
                if not (operations_info[0].shape.ndims is None):
                    operation_shape = operations_info[0].shape.as_list(
                    )
                    operation_dtype_size = operations_info[0].dtype.size
                    if not (operation_dtype_size is None):
                        operation_no_of_elements = 1
                        for dim in operation_shape:
                            if not(dim is None):
                                operation_no_of_elements = operation_no_of_elements * dim
                        total_size = operation_no_of_elements * operation_dtype_size
                        operations_tensors[operation_name] = total_size
                    else:
                        count1 = count1 + 1
                else:
                    count1 = count1 + 1
                    operations_tensors[operation_name] = -1

                #   print('no shape_1: ' + operation_name)
                #  print('no shape_2: ' + str(operations_info))
                #  operation_namee = operation_name + ':0'
                # tensor = tf.get_default_graph().get_tensor_by_name(operation_namee)
                # print('no shape_3:' + str(tf.shape(tensor)))
                # print('no shape:' + str(tensor.get_shape()))

            else:
                # print('no info :' + operation_name)
                # operation_namee = operation.name + ':0'
                count2 = count2 + 1
                operations_tensors[operation_name] = -1

                # try:
                #   tensor = tf.get_default_graph().get_tensor_by_name(operation_namee)
                # print(tensor)
                # print(tf.shape(tensor))
                # except:
                # print('no tensor: ' + operation_namee)
        print(count1)
        print(count2)

        with open('./profs/tensors_sz_32.txt', 'w') as f:
            for tensor, size in operations_tensors.items():
                f.write('"' + tensor + '"::' + str(size) + '\n')

        with open('./profs/operations_attributes.txt', 'w') as f:
            for op, attrs in operations_attributes.items():
                strr = op
                for attr in attrs:
                    strr += '::' + str(attr)
                strr += '\n'
                f.write(strr)

        fetches = [loss, tower_new_mems, global_step,
                   gnorm, learning_rate, train_op]
        iter = 0
        total_loss, prev_step = 0., -1
        while True:
            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for key in tower_mems_np[i].keys():
                    for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
                        feed_dict[m] = m_np
            if iter % 10 == 7 or iter == 0:
                fetched = sess.run(fetches, feed_dict=feed_dict, options=options, run_metadata=run_metadata)
                #if iter > 0:
                profile(run_metadata, iter)
            else:
                t0 = time.time()
                fetched = sess.run(fetches, feed_dict=feed_dict)
                print(time.time() - t0)
            if iter == 0:
                mem_options = tf.profiler.ProfileOptionBuilder.time_and_memory()
                mem_options["min_bytes"] = 0
                mem_options["min_micros"] = 0
                mem_options["output"] = 'file:outfile=./profs/mem.txt'
                mem_options["select"] = ("bytes", "peak_bytes", "output_bytes",
                          "residual_bytes")
                mem = tf.profiler.profile(
                  tf.Graph(), run_meta=run_metadata, cmd="scope", options=mem_options)
                with open('profs/mem2.txt', 'w') as f:
                  f.write(str(mem))
            iter += 1

            loss_np, tower_mems_np, curr_step = fetched[:3]
            total_loss += loss_np

            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                curr_loss = total_loss / (curr_step - prev_step)
                tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                                "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                                    curr_step, fetched[-3], fetched[-2],
                                    curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
                total_loss, prev_step = 0., curr_step

            if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                saver.save(sess, save_path)
                tf.logging.info("Model saved in path: {}".format(save_path))

            if curr_step >= FLAGS.train_steps:
                break
Code example #7
def sent_gen(tmp_Vocab, input_txt, n_token, cutoffs, ps_device):

    test_list = tf.placeholder(tf.int64, shape=[1, None])
    dataset = tf.data.Dataset.from_tensors(test_list)
    # dataset = dataset.batch(1, drop_remainder=True)

    iterator = dataset.make_initializable_iterator()
    input_feed = iterator.get_next()

    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = 1
    tower_mems, tower_losses, tower_new_mems = [], [], []
    tower_output = []
    tower_mems_id = []
    tower_new_mems_id = []
    tower_attn_prob = []

    for i in range(FLAGS.num_core_per_host):
        with tf.device(assign_to_gpu(i, ps_device)), \
             tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            mems_i = [
                tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)
            ]

            mems_i_id = [
                tf.placeholder(tf.int64, [FLAGS.mem_len, per_core_bsz])
                for _ in range(FLAGS.n_layer)
            ]

            new_mems_i, output_i, new_mems_i_id, attn_prob_i = single_core_graph_for_inference(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=False,
                inp=inputs[i],
                mems=mems_i,
                mems_id=mems_i_id)

            tower_mems.append(mems_i)
            tower_new_mems.append(new_mems_i)
            tower_output.append(output_i)
            tower_mems_id.append(mems_i_id)
            tower_new_mems_id.append(new_mems_i_id)
            tower_attn_prob.append(attn_prob_i)

    # Evaluation loop
    tower_mems_np = [[
        np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                 dtype=np.float32) for layer in range(FLAGS.n_layer)
    ] for core in range(FLAGS.num_core_per_host)]

    tower_mems_id_np = [[
        np.zeros([FLAGS.mem_len, per_core_bsz], dtype=np.int64)
        for layer in range(FLAGS.n_layer)
    ] for core in range(FLAGS.num_core_per_host)]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)

        saver.restore(sess, eval_ckpt_path)

        if input_txt == "":
            txt_gen = tmp_Vocab.get_sym(
                random.randint(3,
                               len(tmp_Vocab.idx2sym) - 1))
        else:
            txt_gen = input_txt

        fetches = [
            tower_new_mems, tower_output, tower_new_mems_id, tower_attn_prob,
            'transformer/adaptive_embed/lookup_table:0'
        ]

        encoded_input = tmp_Vocab.encode_sents(txt_gen, ordered=True)

        progress = ProgressBar()
        for _ in progress(range(FLAGS.gen_len)):
            time.sleep(0.01)
            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

                for id, id_np in zip(tower_mems_id[i], tower_mems_id_np[i]):
                    feed_dict[id] = id_np

            sess.run(iterator.initializer,
                     feed_dict={test_list: [encoded_input]})
            fetched = sess.run(fetches, feed_dict=feed_dict)

            tower_mems_np, output = fetched[:2]

            tower_mems_id_np = fetched[2]

            tmp_list = output[0][-1][0]
            tmp_list = tmp_list.tolist()

            index = top_one_result(tmp_list)

            txt_gen += tmp_Vocab.get_sym(index)
            if tmp_Vocab.get_sym(index) == "<eos>":
                break
            else:
                encoded_input = [index]

        return txt_gen
Code example #8
def sent_ppl(input_txt_list, n_token, cutoffs, ps_device):

    test_list = tf.placeholder(tf.int64, shape=[1, None])
    dataset = tf.data.Dataset.from_tensors(test_list)
    # dataset = dataset.batch(1, drop_remainder=True)

    iterator = dataset.make_initializable_iterator()
    input_feed = iterator.get_next()

    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = 1
    tower_mems, tower_losses, tower_new_mems = [], [], []
    tower_output = []
    tower_mems_id = []
    tower_new_mems_id = []
    tower_attn_prob = []

    for i in range(FLAGS.num_core_per_host):
        with tf.device(assign_to_gpu(i, ps_device)), \
             tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            mems_i = [
                tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)
            ]

            mems_i_id = [
                tf.placeholder(tf.int64, [FLAGS.mem_len, per_core_bsz])
                for _ in range(FLAGS.n_layer)
            ]

            new_mems_i, output_i, new_mems_i_id, attn_prob_i = single_core_graph_for_inference(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=False,
                inp=inputs[i],
                mems=mems_i,
                mems_id=mems_i_id)

            tower_mems.append(mems_i)
            tower_new_mems.append(new_mems_i)
            tower_output.append(output_i)
            tower_mems_id.append(mems_i_id)
            tower_new_mems_id.append(new_mems_i_id)
            tower_attn_prob.append(attn_prob_i)

    # Evaluation loop
    tower_mems_np = [[
        np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                 dtype=np.float32) for layer in range(FLAGS.n_layer)
    ] for core in range(FLAGS.num_core_per_host)]

    tower_mems_id_np = [[
        np.zeros([FLAGS.mem_len, per_core_bsz], dtype=np.int64)
        for layer in range(FLAGS.n_layer)
    ] for core in range(FLAGS.num_core_per_host)]

    saver = tf.train.Saver()

    #with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    gpu_config = tf.ConfigProto(allow_soft_placement=True)
    gpu_config.gpu_options.allow_growth = True  # allocate GPU memory on demand
    gpu_config.gpu_options.per_process_gpu_memory_fraction = 0.2  # limit this process to a fraction of the GPU memory
    with tf.Session(config=gpu_config) as sess:
        sess.run(tf.global_variables_initializer())

        eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)

        saver.restore(sess, eval_ckpt_path)

        fetches = [
            tower_new_mems, tower_output, tower_new_mems_id, tower_attn_prob,
            'transformer/adaptive_embed/lookup_table:0'
        ]

        sent_ppl_list = []

        def _cal_ppl(log_prob, sent_len):
            ppl = pow(math.exp((-1) * log_prob), 1 / (sent_len - 1))

            return ppl

        for i in range(len(input_txt_list)):
            #tf.logging.info('#time: {}'.format(time.time()))
            input_txt = input_txt_list[i]

            tower_mems_np = [[
                np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                         dtype=np.float32) for layer in range(FLAGS.n_layer)
            ] for core in range(FLAGS.num_core_per_host)]

            tower_mems_id_np = [[
                np.zeros([FLAGS.mem_len, per_core_bsz], dtype=np.int64)
                for layer in range(FLAGS.n_layer)
            ] for core in range(FLAGS.num_core_per_host)]

            #print("Encoded Input:", encoded_input)

            log_prob = 0

            for token in range(1, len(input_txt)):
                tf.logging.info('#time: {}'.format(time.time()))
                feed_dict = {}
                for i in range(FLAGS.num_core_per_host):
                    for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                        feed_dict[m] = m_np

                    for id, id_np in zip(tower_mems_id[i],
                                         tower_mems_id_np[i]):
                        feed_dict[id] = id_np

                sess.run(iterator.initializer,
                         feed_dict={test_list: [[input_txt[token - 1]]]})
                fetched = sess.run(fetches, feed_dict=feed_dict)

                tower_mems_np, output = fetched[:2]

                tower_mems_id_np = fetched[2]

                tmp_list = output[0][-1][0]
                tmp_list = tmp_list.tolist()

                # numerically stable log-softmax over the output logits
                max_logit = max(tmp_list)
                e_sum = sum(math.exp(i - max_logit) for i in tmp_list)
                log_prob_list = [
                    i - max_logit - math.log(e_sum) for i in tmp_list
                ]

                log_prob = log_prob + log_prob_list[input_txt[token]]

            sent_ppl_list.append(_cal_ppl(log_prob, len(input_txt)))

        return sent_ppl_list
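
For reference, the per-sentence perplexity computed by _cal_ppl reduces to exp(-log_prob / (sent_len - 1)), i.e. the exponentiated average negative log-likelihood over the predicted positions. A compact NumPy sketch of the same computation, with a numerically stable log-softmax (names are illustrative):

import numpy as np

def sentence_ppl(logits_per_step, target_ids):
    # logits_per_step: one logit vector per predicted position
    # target_ids: the token id actually observed at each of those positions
    log_prob = 0.0
    for logits, tok in zip(logits_per_step, target_ids):
        logits = np.asarray(logits, dtype=np.float64)
        log_probs = logits - logits.max()
        log_probs -= np.log(np.exp(log_probs).sum())  # stable log-softmax
        log_prob += log_probs[tok]
    return float(np.exp(-log_prob / len(target_ids)))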
Code example #9
def test(ps_device):

    test_input_fn, record_info_dict_test = data_utils.get_input_fn(
          info_dir=os.path.join(FLAGS.record_info_dir, "test"),
          split="test",
          bsz_per_host=FLAGS.test_batch_size,
          seq_len=FLAGS.seq_len,
          reuse_len=FLAGS.reuse_len,
          bi_data=FLAGS.bi_data,
          num_hosts=1,
          num_core_per_host=1,
          perm_size=FLAGS.perm_size,
          mask_alpha=FLAGS.mask_alpha,
          mask_beta=FLAGS.mask_beta,
          use_bfloat16=FLAGS.use_bfloat16,
          num_predict=FLAGS.num_predict)

    tf.logging.info("num of test batches {}".format(record_info_dict_test["num_batch"]))

    ##### Create input tensors / placeholders
    bsz_per_core = FLAGS.test_batch_size // FLAGS.num_core_per_host

    params = {
        "batch_size": FLAGS.test_batch_size # the whole batch
    }
    test_set = test_input_fn(params)

    t_iter = test_set.make_initializable_iterator()
    t_example = t_iter.get_next()

    if FLAGS.num_core_per_host > 1:
        # test set
        t_examples = [{} for _ in range(FLAGS.num_core_per_host)]
        for key in t_example.keys():
            vals = tf.split(t_example[key], FLAGS.num_core_per_host, 0)
            for device_id in range(FLAGS.num_core_per_host):
                t_examples[device_id][key] = vals[device_id]
    else:
        t_examples = [t_example]

    ##### Create computational graph
    v_tower_mems, v_tower_losses, v_tower_new_mems = [], [], []

    for i in range(FLAGS.num_core_per_host):
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
            tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            # The mems for each tower is a dictionary
            v_mems_i = {}
            if FLAGS.mem_len:
                v_mems_i["mems"] = create_mems_tf(bsz_per_core)
            
            v_loss_i, v_new_mems_i = single_core_graph(
                is_training=False,
                features=t_examples[i],
                mems=v_mems_i)
            
            v_tower_mems.append(v_mems_i)
            v_tower_losses.append(v_loss_i)
            v_tower_new_mems.append(v_new_mems_i)

    ## average losses and gradients across towers
    if len(v_tower_losses) > 1:
      v_loss = tf.add_n(v_tower_losses) / len(v_tower_losses)
    else:
      v_loss = v_tower_losses[0]

    gpu_options = tf.GPUOptions(allow_growth=True)

    model_utils.init_from_checkpoint(FLAGS, global_vars=True)

    # Create performance summaries for Tensorboard logging
    test_performance_summaries = tb.tensorboard_setup_test()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
        gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())

        # Create writers for Tensorboard logging
        test_summary_writer = tb.create_test_writer(sess, logging_dir=FLAGS.tb_logging_dir)

        # initialize mems
        v_tower_mems_np = []
        for i in range(FLAGS.num_core_per_host):
            v_mems_i_np = {}
            for key in v_tower_mems[i].keys():
                v_mems_i_np[key] = initialize_mems_np(bsz_per_core)
            v_tower_mems_np.append(v_mems_i_np)
        
        v_fetches = [v_loss, v_tower_new_mems]
        
        sess.run(t_iter.initializer)
        v_total_loss = 0.
        v_steps = 0

        try:
            while True:
                v_feed_dict = {}
                for i in range(FLAGS.num_core_per_host):
                    for key in v_tower_mems_np[i].keys():
                        for m, m_np in zip(v_tower_mems[i][key], v_tower_mems_np[i][key]):
                            v_feed_dict[m] = m_np
                    
                v_fetched = sess.run(v_fetches, feed_dict=v_feed_dict)
                v_loss_np, v_tower_mems_np = v_fetched[:]
                v_total_loss += v_loss_np
                v_steps += 1
                print(v_steps)
            
        except tf.errors.OutOfRangeError:
            test_loss = v_total_loss/v_steps
            t_pplx = math.exp(test_loss)
            tf.logging.info("Test: loss {:.2f} | pplx {:>7.2f}".format(
                            test_loss,  t_pplx))
            
            summ_test = tb.run_test(sess, test_performance_summaries, test_loss, t_pplx)
            test_summary_writer.add_summary(summ_test, 1)
Code example #10
def main(_):
    if FLAGS.server_ip and FLAGS.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(FLAGS.server_ip, FLAGS.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    tf.set_random_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    tf.logging.set_verbosity(tf.logging.INFO)

    #### Validate flags
    if FLAGS.save_steps is not None:
        FLAGS.log_step_count_steps = min(FLAGS.log_step_count_steps,
                                         FLAGS.save_steps)

    if FLAGS.do_predict:
        predict_dir = FLAGS.predict_dir
        if not tf.gfile.Exists(predict_dir):
            tf.gfile.MakeDirs(predict_dir)

    processors = {
        "mnli_matched": MnliMatchedProcessor,
        "mnli_mismatched": MnliMismatchedProcessor,
        'sts-b': StsbProcessor,
        'imdb': ImdbProcessor,
        "yelp5": Yelp5Processor
    }

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval, `do_predict` or "
            "`do_submit` must be True.")

    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    if not tf.gfile.Exists(FLAGS.model_dir):
        tf.gfile.MakeDirs(FLAGS.model_dir)

#   ########################### LOAD PT model
#   ########################### LOAD PT model
#   import torch
#   from pytorch_transformers import CONFIG_NAME, TF_WEIGHTS_NAME, XLNetTokenizer, XLNetConfig, XLNetForSequenceClassification

#   save_path = os.path.join(FLAGS.model_dir, TF_WEIGHTS_NAME)
#   tf.logging.info("Model loaded from path: {}".format(save_path))

#   device = torch.device("cuda", 4)
#   config = XLNetConfig.from_pretrained('xlnet-large-cased', finetuning_task=u'sts-b')
#   config_path = os.path.join(FLAGS.model_dir, CONFIG_NAME)
#   config.to_json_file(config_path)
#   pt_model = XLNetForSequenceClassification.from_pretrained(FLAGS.model_dir, from_tf=True, num_labels=1)
#   pt_model.to(device)
#   pt_model = torch.nn.DataParallel(pt_model, device_ids=[4, 5, 6, 7])

#   from torch.optim import Adam
#   optimizer = Adam(pt_model.parameters(), lr=0.001, betas=(0.9, 0.999),
#                     eps=FLAGS.adam_epsilon, weight_decay=FLAGS.weight_decay,
#                     amsgrad=False)
#   ########################### LOAD PT model
#   ########################### LOAD PT model

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels() if not FLAGS.is_regression else None

    sp = spm.SentencePieceProcessor()
    sp.Load(FLAGS.spiece_model_file)

    def tokenize_fn(text):
        text = preprocess_text(text, lower=FLAGS.uncased)
        return encode_ids(sp, text)

    # run_config = model_utils.configure_tpu(FLAGS)


#   model_fn = get_model_fn(len(label_list) if label_list is not None else None)

    spm_basename = os.path.basename(FLAGS.spiece_model_file)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    # estimator = tf.estimator.Estimator(
    #     model_fn=model_fn,
    #     config=run_config)

    if FLAGS.do_train:
        train_file_base = "{}.len-{}.train.tf_record".format(
            spm_basename, FLAGS.max_seq_length)
        train_file = os.path.join(FLAGS.output_dir, train_file_base)
        tf.logging.info("Use tfrecord file {}".format(train_file))

        train_examples = processor.get_train_examples(FLAGS.data_dir)
        tf.logging.info("Num of train samples: {}".format(len(train_examples)))

        file_based_convert_examples_to_features(train_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenize_fn, train_file,
                                                FLAGS.num_passes)

        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)

        # estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)

        ##### Create input tensors / placeholders
        bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

        params = {
            "batch_size": FLAGS.train_batch_size  # the whole batch
        }
        train_set = train_input_fn(params)

        example = train_set.make_one_shot_iterator().get_next()
        if FLAGS.num_core_per_host > 1:
            examples = [{} for _ in range(FLAGS.num_core_per_host)]
            for key in example.keys():
                vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
                for device_id in range(FLAGS.num_core_per_host):
                    examples[device_id][key] = vals[device_id]
        else:
            examples = [example]
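        # e.g. with train_batch_size = 8 and num_core_per_host = 4, each
        # examples[i]["input_ids"] is a [2, max_seq_length] slice of the batch.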

        ##### Create computational graph
        tower_losses, tower_grads_and_vars, tower_inputs, tower_hidden_states, tower_logits = [], [], [], [], []

        for i in range(FLAGS.num_core_per_host):
            reuse = True if i > 0 else None
            with tf.device(assign_to_gpu(i, "/gpu:0")), \
                tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

                loss_i, grads_and_vars_i, inputs_i, hidden_states_i, logits_i = single_core_graph(
                    is_training=True,
                    features=examples[i],
                    label_list=label_list)

                tower_losses.append(loss_i)
                tower_grads_and_vars.append(grads_and_vars_i)
                tower_inputs.append(inputs_i)
                tower_hidden_states.append(hidden_states_i)
                tower_logits.append(logits_i)

        ## average losses and gradients across towers
        if len(tower_losses) > 1:
            loss = tf.add_n(tower_losses) / len(tower_losses)
            grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
            inputs = dict((n, tf.concat([t[n] for t in tower_inputs], 0))
                          for n in tower_inputs[0])
            hidden_states = list(
                tf.concat(t, 0) for t in zip(*tower_hidden_states))
            logits = tf.concat(tower_logits, 0)
        else:
            loss = tower_losses[0]
            grads_and_vars = tower_grads_and_vars[0]
            inputs = tower_inputs[0]
            hidden_states = tower_hidden_states[0]
            logits = tower_logits[0]

        # Summaries
        merged = tf.summary.merge_all()

        ## get train op
        train_op, learning_rate, gnorm = model_utils.get_train_op(
            FLAGS, None, grads_and_vars=grads_and_vars)
        global_step = tf.train.get_global_step()

        ##### Training loop
        saver = tf.train.Saver(max_to_keep=FLAGS.max_save)

        gpu_options = tf.GPUOptions(allow_growth=True)

        #### load pretrained models
        model_utils.init_from_checkpoint(FLAGS, global_vars=True)

        writer = tf.summary.FileWriter(logdir=FLAGS.model_dir,
                                       graph=tf.get_default_graph())
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True, gpu_options=gpu_options)) as sess:
            sess.run(tf.global_variables_initializer())

            #########
            ##### PYTORCH
            import torch
            from torch.optim import Adam
            from pytorch_transformers import CONFIG_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, XLNetTokenizer, XLNetConfig, XLNetForSequenceClassification

            save_path = os.path.join(FLAGS.model_dir, TF_WEIGHTS_NAME + '-00')
            saver.save(sess, save_path)
            tf.logging.info("Model saved in path: {}".format(save_path))

            device = torch.device("cuda", 4)
            config = XLNetConfig.from_pretrained('xlnet-large-cased',
                                                 finetuning_task=u'sts-b',
                                                 num_labels=1)
            tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')

            # pt_model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased', num_labels=1)
            pt_model = XLNetForSequenceClassification.from_pretrained(
                save_path, from_tf=True, config=config)
            pt_model.to(device)
            pt_model = torch.nn.DataParallel(pt_model, device_ids=[4, 5, 6, 7])

            optimizer = Adam(pt_model.parameters(),
                             lr=0.001,
                             betas=(0.9, 0.999),
                             eps=FLAGS.adam_epsilon,
                             weight_decay=FLAGS.weight_decay,
                             amsgrad=False)
            # optimizer = BertAdam(pt_model.parameters(), lr=FLAGS.learning_rate, t_total=FLAGS.train_steps, warmup=FLAGS.warmup_steps / FLAGS.train_steps,
            #                      eps=FLAGS.adam_epsilon, weight_decay=FLAGS.weight_decay)
            ##### PYTORCH
            #########

            fetches = [
                loss, global_step, gnorm, learning_rate, train_op, merged,
                inputs, hidden_states, logits
            ]

            total_loss, total_loss_pt, prev_step, gnorm_pt = 0., 0., -1, 0.0
            total_logits = None
            total_labels = None
            while True:
                feed_dict = {}
                # for i in range(FLAGS.num_core_per_host):
                #   for key in tower_mems_np[i].keys():
                #     for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
                #       feed_dict[m] = m_np

                fetched = sess.run(fetches)

                (loss_np, curr_step, gnorm_np, learning_rate_np, _, summary_np,
                 inputs_np, hidden_states_np, logits_np) = fetched
                total_loss += loss_np

                if total_logits is None:
                    total_logits = logits_np
                    total_labels = inputs_np['label_ids']
                else:
                    total_logits = np.append(total_logits, logits_np, axis=0)
                    total_labels = np.append(total_labels,
                                             inputs_np['label_ids'],
                                             axis=0)

                #########
                ##### PYTORCH
                f_inp = torch.tensor(inputs_np["input_ids"],
                                     dtype=torch.long,
                                     device=device)
                f_seg_id = torch.tensor(inputs_np["segment_ids"],
                                        dtype=torch.long,
                                        device=device)
                f_inp_mask = torch.tensor(inputs_np["input_mask"],
                                          dtype=torch.float,
                                          device=device)
                f_label = torch.tensor(inputs_np["label_ids"],
                                       dtype=torch.float,
                                       device=device)

                # with torch.no_grad():
                #   _, hidden_states_pt, _ = pt_model.transformer(f_inp, f_seg_id, f_inp_mask)
                # logits_pt, _ = pt_model(f_inp, token_type_ids=f_seg_id, input_mask=f_inp_mask)

                pt_model.train()
                outputs = pt_model(f_inp,
                                   token_type_ids=f_seg_id,
                                   input_mask=f_inp_mask,
                                   labels=f_label)
                loss_pt = outputs[0]
                loss_pt = loss_pt.mean()
                total_loss_pt += loss_pt.item()

                # # hidden_states_pt = list(t.detach().cpu().numpy() for t in hidden_states_pt)
                # # special_pt = special_pt.detach().cpu().numpy()

                # # Optimizer pt
                pt_model.zero_grad()
                loss_pt.backward()
                gnorm_pt = torch.nn.utils.clip_grad_norm_(
                    pt_model.parameters(), FLAGS.clip)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rate_np
                optimizer.step()
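                # The PyTorch optimizer mirrors the TF learning-rate schedule
                # each step so the two training runs stay comparable.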
                ##### PYTORCH
                #########

                if curr_step > 0 and curr_step % FLAGS.log_step_count_steps == 0:
                    curr_loss = total_loss / (curr_step - prev_step)
                    curr_loss_pt = total_loss_pt / (curr_step - prev_step)
                    tf.logging.info(
                        "[{}] | gnorm {:.2f} lr {:8.6f} "
                        "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                            curr_step, gnorm_np, learning_rate_np, curr_loss,
                            math.exp(curr_loss), curr_loss / math.log(2)))

                    #########
                    ##### PYTORCH
                    tf.logging.info(
                        "  PT [{}] | gnorm PT {:.2f} lr PT {:8.6f} "
                        "| loss PT {:.2f} | pplx PT {:>7.2f}, bpc PT {:>7.4f}".
                        format(curr_step, gnorm_pt, learning_rate_np,
                               curr_loss_pt, math.exp(curr_loss_pt),
                               curr_loss_pt / math.log(2)))
                    ##### PYTORCH
                    #########

                    total_loss, total_loss_pt, prev_step = 0., 0., curr_step
                    writer.add_summary(summary_np, global_step=curr_step)

                if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                    save_path = os.path.join(FLAGS.model_dir,
                                             "model.ckpt-{}".format(curr_step))
                    saver.save(sess, save_path)
                    tf.logging.info(
                        "Model saved in path: {}".format(save_path))

                    #########
                    ##### PYTORCH
                    # Save a trained model, configuration and tokenizer
                    model_to_save = pt_model.module if hasattr(
                        pt_model,
                        'module') else pt_model  # Only save the model itself
                    # If we save using the predefined names, we can load using `from_pretrained`
                    output_dir = os.path.join(
                        FLAGS.output_dir, "pytorch-ckpt-{}".format(curr_step))
                    if not tf.gfile.Exists(output_dir):
                        tf.gfile.MakeDirs(output_dir)
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    tf.logging.info(
                        "PyTorch Model saved in path: {}".format(output_dir))
                    ##### PYTORCH
                    #########

                if curr_step >= FLAGS.train_steps:
                    break

    if FLAGS.do_eval:
        # Load the eval examples (dev split assumed here) before padding below.
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        tf.logging.info("Num of eval samples: {}".format(len(eval_examples)))

        # TPU requires a fixed batch size for all batches, therefore the number
        # of examples must be a multiple of the batch size, or else examples
        # will get dropped. So we pad with fake examples which are ignored
        # later on. These do NOT count towards the metric (all tf.metrics
        # support a per-instance weight, and these get a weight of 0.0).
        #
        # Modified in XL: We also adopt the same mechanism for GPUs.
        while len(eval_examples) % FLAGS.eval_batch_size != 0:
            eval_examples.append(PaddingInputExample())
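        # e.g. with 103 eval examples and eval_batch_size = 8, one
        # PaddingInputExample is appended so that 104 % 8 == 0; the padded
        # rows get weight 0.0 in the metrics and do not affect the result.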

        eval_file_base = "{}.len-{}.{}.eval.tf_record".format(
            spm_basename, FLAGS.max_seq_length, FLAGS.eval_split)
        eval_file = os.path.join(FLAGS.output_dir, eval_file_base)

        file_based_convert_examples_to_features(eval_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenize_fn, eval_file)

        assert len(eval_examples) % FLAGS.eval_batch_size == 0
        eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=True)

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True, gpu_options=gpu_options)) as sess:
            sess.run(tf.global_variables_initializer())

            ########################### LOAD PT model
            #   import torch
            #   from pytorch_transformers import CONFIG_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, XLNetTokenizer, XLNetConfig, XLNetForSequenceClassification, BertAdam

            #   save_path = os.path.join(FLAGS.model_dir, TF_WEIGHTS_NAME)
            #   saver.save(sess, save_path)
            #   tf.logging.info("Model saved in path: {}".format(save_path))

            #   device = torch.device("cuda", 4)
            #   config = XLNetConfig.from_pretrained('xlnet-large-cased', finetuning_task=u'sts-b', num_labels=1)
            #   tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
            #   config_path = os.path.join(FLAGS.model_dir, CONFIG_NAME)
            #   config.to_json_file(config_path)
            #   # pt_model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased', num_labels=1)
            #   pt_model = XLNetForSequenceClassification.from_pretrained(FLAGS.model_dir, from_tf=True)
            #   pt_model.to(device)
            #   pt_model = torch.nn.DataParallel(pt_model, device_ids=[4, 5, 6, 7])
            #   from torch.optim import Adam
            #   optimizer = Adam(pt_model.parameters(), lr=0.001, betas=(0.9, 0.999),
            #                    eps=FLAGS.adam_epsilon, weight_decay=FLAGS.weight_decay,
            #                    amsgrad=False)
            #   optimizer = BertAdam(pt_model.parameters(), lr=FLAGS.learning_rate, t_total=FLAGS.train_steps, warmup=FLAGS.warmup_steps / FLAGS.train_steps,
            #                        eps=FLAGS.adam_epsilon, weight_decay=FLAGS.weight_decay)

            ##### PYTORCH
            #########

            fetches = [
                loss, global_step, gnorm, learning_rate, train_op, merged,
                inputs, hidden_states, logits
            ]

            total_loss, total_loss_pt, prev_step, gnorm_pt = 0., 0., -1, 0.0
            total_logits = None
            total_labels = None
            while True:
                feed_dict = {}
                # for i in range(FLAGS.num_core_per_host):
                #   for key in tower_mems_np[i].keys():
                #     for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
                #       feed_dict[m] = m_np

                fetched = sess.run(fetches)

                (loss_np, curr_step, gnorm_np, learning_rate_np, _, summary_np,
                 inputs_np, hidden_states_np, logits_np) = fetched
                total_loss += loss_np

                if total_logits is None:
                    total_logits = logits_np
                    total_labels = inputs_np['label_ids']
                else:
                    total_logits = np.append(total_logits, logits_np, axis=0)
                    total_labels = np.append(total_labels,
                                             inputs_np['label_ids'],
                                             axis=0)
コード例 #11
0
def train(n_token, cutoffs, ps_device):
    # get TF logger
    log = logging.getLogger('tensorflow')
    log.setLevel(logging.INFO)

    # create formatter and add it to the handlers
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # create file handler which logs even debug messages
    fh = logging.FileHandler('run_train.log')
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    log.addHandler(fh)

    ##### Get input function and model function
    train_input_fn, train_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1,
        use_tpu=False)

    tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))

    ##### Create computational graph
    train_set = train_input_fn({
        "batch_size": FLAGS.train_batch_size,
        "data_dir": FLAGS.data_dir
    })

    input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_core_per_host):
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
            tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            mems_i = [
                tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)
            ]

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=True,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    ## average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]
    grads, all_vars = zip(*grads_and_vars)

    ## clip gradient
    clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
    grads_and_vars = list(zip(clipped, all_vars))

    ## configure the optimizer
    global_step = tf.train.get_or_create_global_step()

    # warmup stage: increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                    * FLAGS.learning_rate
    else:
        warmup_lr = 0.0

    # decay stage: decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    # choose warmup or decay
    learning_rate = tf.where(global_step < FLAGS.warmup_steps, warmup_lr,
                             decay_lr)
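    # e.g. with learning_rate = 2.5e-4 and warmup_steps = 1000, the rate at
    # step 500 is 1.25e-4; after warmup the cosine schedule decays it from
    # 2.5e-4 down to min_lr_ratio * 2.5e-4 by train_steps.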

    # get the train op
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    ##### Training loop
    tower_mems_np = [[
        np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                 dtype=np.float32) for layer in range(FLAGS.n_layer)
    ] for core in range(FLAGS.num_core_per_host)]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.warm_start_path is not None:
            tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
            saver.restore(sess, FLAGS.warm_start_path)

        fetches = [
            loss, tower_new_mems, global_step, gnorm, learning_rate, train_op
        ]

        total_loss, prev_step = 0., -1
        while True:
            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)

            loss_np, tower_mems_np, curr_step = fetched[:3]
            total_loss += loss_np

            if curr_step % 100 == 0:
                print("Current step:", curr_step)

            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                curr_loss = total_loss / (curr_step - prev_step)
                tf.logging.info(
                    "[{}] | gnorm {:.2f} lr {:8.6f} "
                    "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                        curr_step, fetched[-3], fetched[-2], curr_loss,
                        math.exp(curr_loss), curr_loss / math.log(2)))
                total_loss, prev_step = 0., curr_step

            if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                saver.save(sess, save_path)
                tf.logging.info("Model saved in path: {}".format(save_path))

            if curr_step == FLAGS.train_steps:
                break
コード例 #12
0
def dynamic_eval(n_token, cutoffs, ps_device):
    ##### Get input function and model function
    if FLAGS.rms:
        ##using training data to collect gradient statistics
        train_input_fn, train_record_info = data_utils.get_input_fn(
            record_info_dir=FLAGS.record_info_dir,
            split="train",
            per_host_bsz=FLAGS.train_batch_size,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=1,
            use_tpu=False)

        num_batch = train_record_info["num_batch"]

        tf.logging.info("num of batches {}".format(num_batch))

        ##### Create computational graph
        train_set = train_input_fn({
            "batch_size": FLAGS.train_batch_size,
            "data_dir": FLAGS.data_dir
        })

        input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

        inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
        labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

        per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host


        tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

        for i in range(FLAGS.num_core_per_host):
            reuse = True if i > 0 else None
            with tf.device(assign_to_gpu(i, ps_device)), \
                tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):

                mems_i = [
                    tf.placeholder(
                        tf.float32,
                        [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                    for _ in range(FLAGS.n_layer)
                ]

                loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                    n_token=n_token,
                    cutoffs=cutoffs,
                    is_training=True,
                    inp=inputs[i],
                    tgt=labels[i],
                    mems=mems_i)

                tower_mems.append(mems_i)
                tower_losses.append(loss_i)
                tower_new_mems.append(new_mems_i)
                tower_grads_and_vars.append(grads_and_vars_i)

        ## sum losses across towers
        if len(tower_losses) > 1:
            loss = tf.add_n(tower_losses) / len(tower_losses)
            grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
        else:
            loss = tower_losses[0]
            grads_and_vars = tower_grads_and_vars[0]

        global_step = tf.train.get_or_create_global_step()

        optimizer = DynamicEvalOpt(learning_rate=FLAGS.learning_rate,
                                   decay_rate=FLAGS.decay_rate,
                                   eps=FLAGS.epsilon)
        optimizer.gradstat = True
        train_op = optimizer.apply_gradients(grads_and_vars, global_step)
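        # With gradstat enabled the optimizer is assumed to only accumulate
        # squared-gradient statistics on this pass; the adaptive updates
        # happen later in the dynamic-eval loop below.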

        tower_mems_np = [[
            np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                     dtype=np.float32) for layer in range(FLAGS.n_layer)
        ] for core in range(FLAGS.num_core_per_host)]

        saver = tf.train.Saver()

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(tf.global_variables_initializer())

            if FLAGS.eval_ckpt_path is None:
                eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
            else:
                eval_ckpt_path = FLAGS.eval_ckpt_path

            tf.logging.info("Evaluate {}".format(eval_ckpt_path))
            saver.restore(sess, eval_ckpt_path)

            fetches = [loss, tower_new_mems, tf.size(label_feed), train_op]

            total_loss, prev_step = 0., -1

            total_loss, total_cnt = 0, 0

            format_str = "  >> processing batch for gradient statistics {{:{0}d}}/{{:{0}d}} ..".format(
                len(str(num_batch // 5000)))

            ## only small subset of training set used for gradient stats to save time
            for step in range(num_batch // 5000):
                if step % (num_batch // 50000) == 0:
                    tf.logging.info(format_str.format(step, num_batch // 5000))

                feed_dict = {}
                for i in range(FLAGS.num_core_per_host):
                    for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                        feed_dict[m] = m_np

                fetched = sess.run(fetches, feed_dict=feed_dict)

                loss_np, tower_mems_np, cnt_np = fetched[:3]
                total_loss += loss_np * cnt_np
                total_cnt += cnt_np

            avg_loss = total_loss / total_cnt
            ## tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
            ##     avg_loss, math.exp(avg_loss), avg_loss / math.log(2)))

    ##### Done gradstat

    ### starting dynamic eval

    eval_input_fn, eval_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split=FLAGS.eval_split,
        per_host_bsz=FLAGS.eval_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1,
        use_tpu=False)

    num_batch = eval_record_info["num_batch"]

    tf.logging.info("num of batches {}".format(num_batch))

    ##### Create computational graph
    eval_set = eval_input_fn({
        "batch_size": FLAGS.eval_batch_size,
        "data_dir": FLAGS.data_dir
    })

    input_feed, label_feed = eval_set.make_one_shot_iterator().get_next()

    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = FLAGS.eval_batch_size // FLAGS.num_core_per_host


    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_core_per_host):
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
            tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):

            mems_i = [
                tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)
            ]

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=True,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    ## sum losses across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]

    ## configure the optimizer
    global_step = tf.train.get_or_create_global_step()
    if not FLAGS.rms:
        # alternative: DynamicEvalPS(learning_rate=FLAGS.learning_rate)
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=FLAGS.learning_rate)
    else:
        optimizer.gradstat = False
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    ##### Evaluation loop
    tower_mems_np = [[
        np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                 dtype=np.float32) for layer in range(FLAGS.n_layer)
    ] for core in range(FLAGS.num_core_per_host)]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.eval_ckpt_path is None:
            eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
        else:
            eval_ckpt_path = FLAGS.eval_ckpt_path

        tf.logging.info("Evaluate {}".format(eval_ckpt_path))
        saver.restore(sess, eval_ckpt_path)

        fetches = [loss, tower_new_mems, tf.size(label_feed), train_op]

        total_loss, prev_step = 0., -1

        total_loss, total_cnt = 0, 0
        format_str = "  >> processing batch {{:{0}d}}/{{:{0}d}} ..".format(
            len(str(num_batch)))
        for step in range(num_batch // FLAGS.ratio):
            if step % (num_batch // (10 * FLAGS.ratio)) == 0:
                tf.logging.info(format_str.format(step, num_batch))

            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)

            loss_np, tower_mems_np, cnt_np = fetched[:3]
            total_loss += loss_np * cnt_np
            total_cnt += cnt_np

        avg_loss = total_loss / total_cnt
        tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
            avg_loss, math.exp(avg_loss), avg_loss / math.log(2)))
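

# DynamicEvalOpt itself is not shown in this snippet. As a rough reference, the
# sketch below illustrates in plain NumPy the kind of RMS-normalized update such
# a dynamic-evaluation optimizer typically applies: a gradient-statistics pass
# accumulates mean squared gradients, and the evaluation pass then adapts the
# weights while decaying them back toward the pretrained values. All names here
# are illustrative assumptions, not the actual implementation used above.
import numpy as np


class DynamicEvalSketch:
    """Minimal, hypothetical sketch of an RMS-normalized dynamic-eval update."""

    def __init__(self, params, learning_rate, decay_rate, eps=1e-5):
        self.params0 = {k: v.copy() for k, v in params.items()}  # pretrained weights
        self.ms = {k: np.zeros_like(v) for k, v in params.items()}  # mean squared grads
        self.count = 0
        self.lr, self.decay, self.eps = learning_rate, decay_rate, eps

    def grad_stat_step(self, grads):
        # Phase 1 (gradstat=True): accumulate squared-gradient statistics on a
        # small slice of the training data.
        self.count += 1
        for k, g in grads.items():
            self.ms[k] += (g ** 2 - self.ms[k]) / self.count

    def dynamic_eval_step(self, params, grads):
        # Phase 2 (gradstat=False): adapt the weights during evaluation,
        # normalizing the step by the RMS gradient and decaying the weights
        # back toward their pretrained values.
        for k, g in grads.items():
            rms = np.sqrt(self.ms[k]) + self.eps
            params[k] -= self.lr * g / rms + self.decay * (params[k] - self.params0[k])
        return params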
コード例 #13
0
def train_epoch(epoch, csv_logger, n_token, cutoffs):
    ps_device = "/gpu:0"

    train_input_fn, train_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_gpu,
        num_hosts=1,
        use_tpu=False)

    tf.logging.info("-" * 30)
    tf.logging.info("Starting epoch {}!".format(epoch))
    tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))
    num_batch = train_record_info["num_batch"]

    train_set = train_input_fn({
        "batch_size": FLAGS.train_batch_size,
        "data_dir": FLAGS.data_dir})

    input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

    inputs = tf.split(input_feed, FLAGS.num_gpu, 0)
    labels = tf.split(label_feed, FLAGS.num_gpu, 0)

    per_core_bsz = FLAGS.train_batch_size // FLAGS.num_gpu
    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_gpu):
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
                tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            mems_i = [tf.placeholder(tf.float32,
                                     [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                      for _ in range(FLAGS.n_layer)]

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=True,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]

    grads, all_vars = zip(*grads_and_vars)

    clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
    grads_and_vars = list(zip(clipped, all_vars))

    global_step = tf.train.get_or_create_global_step()
    total_steps = FLAGS.epochs * num_batch

    if FLAGS.warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                    * FLAGS.learning_rate
    else:
        warmup_lr = 0.0

    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step-FLAGS.warmup_steps,
        decay_steps=total_steps-FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    tower_mems_np = [
        [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
            for layer in range(FLAGS.n_layer)]
        for core in range(FLAGS.num_gpu)
    ]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        latest_ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
        if latest_ckpt is not None:
            tf.logging.info("loading saved model from {}".format(latest_ckpt))
            saver.restore(sess, latest_ckpt)
        else:
            tf.logging.info("No previously saved model. Starting from scratch!")

        fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]
        total_loss, prev_step = 0., -1

        for ba in range(num_batch):
            feed_dict = {}
            for i in range(FLAGS.num_gpu):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np
            fetched = sess.run(fetches, feed_dict=feed_dict)
            loss_np, tower_mems_np, curr_step = fetched[:3]
            total_loss += loss_np
            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                curr_loss = total_loss / (curr_step - prev_step)
                tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                    "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                    curr_step, fetched[-3], fetched[-2],
                    curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
                log_dict = {
                    'train_loss': curr_loss,
                    'train_ppl': math.exp(curr_loss),
                    'train_bpc': curr_loss / math.log(2),
                    'lr': fetched[-2],
                    'global_step': curr_step,
                    'epoch': epoch
                }
                csv_logger.writerow(log_dict)
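                # csv_logger is assumed to be a csv.DictWriter configured with
                # these field names.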
                total_loss, prev_step = 0., curr_step

            if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                saver.save(sess, save_path)
                tf.logging.info("Finished Step : {}".format(curr_step))
                tf.logging.info("Model saved in path: {}".format(save_path))


        curr_loss = total_loss / (curr_step - prev_step)
        tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
            "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
            curr_step, fetched[-3], fetched[-2],
            curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))

        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        tf.logging.info("Finished Epoch {}".format(curr_step))
        tf.logging.info("Model saved in path: {}".format(save_path))
        tf.logging.info("-" * 30)
コード例 #14
0
def condition_inference(n_token, ps_device):
    tf.logging.info("Now,starting conditional mode")
    inference_length = 100
    inference_bsz = FLAGS.inference_bsz

    # Read the saved corpus, which maps words to word indices.
    # The corpus must already exist before the inference stage!
    print("reading the saved corpus object")
    corpus = simple_data_utils.get_saved_corpus(
        "data/taobao/")  # data path / dataset name
    print("reading corpus done")

    # raw_text = input(" Model prompt >>> ")
    raw_text = '我 喜欢 这件'
    raw_text = raw_text.split()
    while not raw_text:
        print('Prompt should not be empty!')
        # The prompt must be entered pre-segmented into words, e.g. "杭州 艾耕 科技".
        raw_text = input("Model prompt >>> ").split()

    text_indices = np.array([corpus.vocab.get_indices(raw_text)])

    # Map the indices back to a sentence
    generated_sentence = corpus.vocab.get_symbols(text_indices[0])
    print(generated_sentence)
    print(text_indices)

    # Build the model graph; inference runs on a single machine
    with tf.device(assign_to_gpu(0, ps_device)), tf.variable_scope(
            tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        all_mems = [
            tf.placeholder(tf.float32,
                           [FLAGS.mem_len, inference_bsz, FLAGS.d_model])
            for _ in range(FLAGS.n_layer)
        ]
        input_data = tf.placeholder(tf.int32, [inference_bsz, None],
                                    name='input_data')
        new_all_mems, logits_out = inference_graph(n_token=n_token,
                                                   inp=input_data,
                                                   mems=all_mems)
        all_mems_past = [
            np.zeros([FLAGS.mem_len, inference_bsz, FLAGS.d_model],
                     dtype=np.float32) for layer in range(FLAGS.n_layer)
        ]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        if FLAGS.eval_ckpt_path is None:
            eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
        else:
            eval_ckpt_path = FLAGS.eval_ckpt_path
        tf.logging.info("Evaluate {}".format(eval_ckpt_path))

        saver.restore(sess, eval_ckpt_path)
        print("restore ckpt Done")

        feed_dict = {input_data: text_indices}
        fetches = [new_all_mems, logits_out]

        for i in range(inference_length):
            for m, m_np in zip(all_mems, all_mems_past):
                feed_dict[m] = m_np
            all_mems_past, _logits_out = sess.run(
                fetches,
                feed_dict=feed_dict)  # _logits_out: [position, sentence, probabilities over the vocabulary]
            next_word_index = np.argmax(_logits_out[-1],
                                        axis=1)  # [word_index]
            text_indices = np.concatenate([text_indices, [next_word_index]],
                                          axis=1)
            feed_dict[input_data] = text_indices
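
        # Greedy decoding: each step appends the single most probable next word;
        # swapping np.argmax for top-k sampling would give more varied output
        # (not done here).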

        for sentence in text_indices:
            print(corpus.vocab.get_symbols(sentence))
コード例 #15
0
def train(n_token, ps_device):
    #  Get input function and model function
    train_input_fn, train_record_info = simple_data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1)

    num_batches = train_record_info["num_batch"]
    tf.logging.info("num of batches {}".format(num_batches))
    tf.logging.info("run {} epochs:".format(TRAIN_STEPS / num_batches))

    # Create computational graph
    train_set = train_input_fn({
        "batch_size": FLAGS.train_batch_size,
        "data_dir": FLAGS.data_dir
    })

    input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

    # A batch has to be distributed across several devices, so it is split here.
    # tf.split(value, num_split, axis) returns a list of num_split pieces cut along the given axis.
    inputs = tf.split(input_feed, FLAGS.num_core_per_host,
                      0)  # axis 0 is the batch dimension
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)
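    # e.g. an input_feed of shape [train_batch_size, tgt_len] becomes
    # num_core_per_host tensors of shape [per_core_bsz, tgt_len].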

    per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    # assign_to_gpu(i, ps_device)
    for i in range(FLAGS.num_core_per_host):
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
             tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
            all_mems = [
                tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)
            ]  # all_mems: a list holding the memory placeholder for every layer

            loss_i, new_all_mems, grads_and_vars_i = single_core_graph(  # per-tower grads_and_vars are returned so they can be averaged across devices below
                n_token=n_token,  # vocabulary size: 27 (the letters a..z plus '_')
                is_training=True,
                inp=inputs[i],
                labels=labels[i],
                mems=all_mems)

            tower_mems.append(all_mems)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_all_mems)
            tower_grads_and_vars.append(grads_and_vars_i)

    #  average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]
    grads, all_vars = zip(*grads_and_vars)

    # clip gradient
    clipped, gnorm = tf.clip_by_global_norm(grads,
                                            FLAGS.clip)  # clipped gradients and the global norm over all gradient tensors
    grads_and_vars = list(zip(clipped, all_vars))

    # configure the optimizer
    global_step = tf.train.get_or_create_global_step()

    # warmup stage: increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                    * FLAGS.learning_rate
    else:
        warmup_lr = 0.0

    # decay stage: decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    # choose warmup or decay
    learning_rate = tf.where(global_step < FLAGS.warmup_steps, warmup_lr,
                             decay_lr)

    # get the train op
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    # Training loop
    tower_mems_np = [[
        np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                 dtype=np.float32) for layer in range(FLAGS.n_layer)
    ] for core in range(FLAGS.num_core_per_host)]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.warm_start_path is not None:
            tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
            saver.restore(sess, FLAGS.warm_start_path)

        fetches = [
            loss, tower_new_mems, global_step, gnorm, learning_rate, train_op
        ]

        total_loss, prev_step = 0., -1
        epoch = 0
        while True:
            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)

            loss_np, tower_mems_np, curr_step = fetched[:3]
            total_loss += loss_np

            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                curr_loss = total_loss / (curr_step - prev_step)
                tf.logging.info(
                    "[{}] | gnorm {:.2f} lr {:8.6f} "
                    "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                        curr_step, fetched[-3], fetched[-2], curr_loss,
                        math.exp(curr_loss), curr_loss / math.log(2)))
                total_loss, prev_step = 0., curr_step

            if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                saver.save(sess, save_path)
                tf.logging.info("Model saved in path: {}".format(save_path))

            if curr_step > 0 and curr_step % num_batches == 0:  # every num_batches steps is one full epoch
                epoch += 1
                tf.logging.info("epoch: {} Done".format(epoch))
            if curr_step == FLAGS.train_steps:
                break

        tf.logging.info("run {} epochs:".format(TRAIN_STEPS / num_batches))