コード例 #1
0
def train(n_token, cutoffs, ps_device):
    """Build the multi-tower training graph and run the training loop.

    One tower is built per core (``FLAGS.num_core_per_host``). Tower losses
    and gradients are averaged, gradients are clipped by global norm, and
    Adam is applied with a linear-warmup / cosine-decay learning rate. The
    loop logs every ``FLAGS.iterations`` steps, checkpoints every
    ``FLAGS.save_steps`` steps and stops once the global step reaches
    ``FLAGS.train_steps``.

    Args:
        n_token: forwarded to ``single_core_graph`` as ``n_token``.
        cutoffs: forwarded to ``single_core_graph`` as ``cutoffs``.
        ps_device: forwarded to ``assign_to_gpu`` when placing each tower.
    """
    # Get input function and model function
    train_input_fn, train_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1,
        use_tpu=False)

    tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))

    # Create computational graph
    train_set = train_input_fn({
        "batch_size": FLAGS.train_batch_size,
        "data_dir": FLAGS.data_dir})

    input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

    # Split the whole batch along axis 0, one slice per tower.
    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_core_per_host):
        # Variables are created by the first tower and reused by the others.
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
             tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
            # One memory placeholder per layer; fed from the previous step.
            mems_i = [tf.placeholder(tf.float32,
                                     [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                      for _ in range(FLAGS.n_layer)]

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=True,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    # average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]
    grads, all_vars = zip(*grads_and_vars)

    # clip gradient
    clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
    grads_and_vars = list(zip(clipped, all_vars))

    # configure the optimizer
    global_step = tf.train.get_or_create_global_step()

    # warmup stage: increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                    * FLAGS.learning_rate
    else:
        warmup_lr = 0.0

    # decay stage: decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    # choose warmup or decay
    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)

    # get the train op
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    # Initial (all-zero) memories fed to every tower on the first step.
    tower_mems_np = [
        [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
         for _ in range(FLAGS.n_layer)]
        for _ in range(FLAGS.num_core_per_host)
    ]

    saver = tf.train.Saver()

    tf.summary.scalar('learning_rate', learning_rate)
    tf.summary.scalar('loss', loss)
    merged = tf.summary.merge_all()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        # The writer is created once, after the session exists, so that the
        # session graph can be attached to the log directory.
        train_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.model_dir, "log"), sess.graph)

        # try/finally guarantees the event file is flushed and closed even
        # when training is interrupted or fails.
        try:
            if FLAGS.warm_start_path is not None:
                tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
                saver.restore(sess, FLAGS.warm_start_path)

            fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate,
                       train_op]

            # Count the steps accumulated into total_loss explicitly instead
            # of inferring them from the global step: the global step does
            # not start at 0 after a warm start, which would skew the mean.
            total_loss, steps_since_log = 0., 0
            while True:
                feed_dict = {}
                for i in range(FLAGS.num_core_per_host):
                    for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                        feed_dict[m] = m_np

                summary, fetched = sess.run([merged, fetches],
                                            feed_dict=feed_dict)

                # The fetched new memories become the next step's input.
                loss_np, tower_mems_np, curr_step, gnorm_np, lr_np = fetched[:5]
                total_loss += loss_np
                steps_since_log += 1

                if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                    curr_loss = total_loss / steps_since_log
                    tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                                    "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                                        curr_step, gnorm_np, lr_np, curr_loss,
                                        math.exp(curr_loss),
                                        curr_loss / math.log(2)))
                    total_loss, steps_since_log = 0., 0
                    train_writer.add_summary(summary, curr_step)

                if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                    save_path = os.path.join(
                        FLAGS.model_dir, "model-{}.ckpt".format(curr_step))
                    saver.save(sess, save_path)
                    tf.logging.info("Model saved in path: {}".format(save_path))

                # ">=" (not "==") so a warm start whose restored step is
                # already past train_steps still terminates.
                if curr_step >= FLAGS.train_steps:
                    break
        finally:
            train_writer.close()
コード例 #2
0
def train(train_data, valid_data, n_token_N, n_token_T, cutoffs, vocab_size,
          fout):
    """Train for ``FLAGS.epochs`` epochs and validate after each epoch.

    One tower is built per entry of the module-level ``gpu_list``. Tower
    losses and gradients are averaged, gradients are clipped by global norm,
    and Adam is applied with a linear-warmup / cosine-decay learning rate.
    Progress lines are written both to ``tf.logging`` and to ``fout``.

    Args:
        train_data: passed to ``reader.real_data_producer`` to produce
            training batches.
        valid_data: passed to ``reader.real_data_producer`` to produce
            validation batches.
        n_token_N: forwarded to ``single_core_graph`` as ``n_token_N``.
        n_token_T: forwarded to ``single_core_graph`` as ``n_token_T``.
        cutoffs: forwarded to ``single_core_graph`` as ``cutoffs``.
        vocab_size: forwarded to ``reader.real_data_producer``.
        fout: writable text stream; training and validation progress lines
            are printed to it and flushed.
    """
    batch_shape = [FLAGS.train_batch_size, FLAGS.tgt_len]

    input_dataN = tf.placeholder(dtype=tf.int32, shape=batch_shape,
                                 name="input_dataN")
    input_dataT = tf.placeholder(dtype=tf.int32, shape=batch_shape,
                                 name="input_dataT")
    input_dataPath = tf.placeholder(
        dtype=tf.int32,
        shape=[FLAGS.train_batch_size, FLAGS.tgt_len, 5],
        name="input_dataPath")
    targetsT = tf.placeholder(dtype=tf.int32, shape=batch_shape,
                              name="targetsT")

    # Split every input along axis 0, one slice per GPU tower.
    n_towers = len(gpu_list)
    inputsN = tf.split(input_dataN, n_towers, 0)
    inputsT = tf.split(input_dataT, n_towers, 0)
    inputsPath = tf.split(input_dataPath, n_towers, 0)
    labelsT = tf.split(targetsT, n_towers, 0)

    per_core_bsz = FLAGS.train_batch_size // n_towers
    mem_shape = [FLAGS.mem_len, per_core_bsz,
                 FLAGS.d_model_N + FLAGS.d_model_T]

    tower_mems, tower_losses, tower_new_mems = [], [], []
    tower_grads_and_vars, predictionT = [], []

    for i, gpu_id in enumerate(gpu_list):
        with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)):
            # Variables are created by the first tower and reused afterwards.
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i > 0)):
                # One memory placeholder per layer; fed from the previous step.
                mems_i = [
                    tf.placeholder(tf.float32, mem_shape)
                    for _ in range(FLAGS.n_layer)
                ]

                loss_i, new_mems_i, grads_and_vars_i, predictionT_i = single_core_graph(
                    n_token_N=n_token_N,
                    n_token_T=n_token_T,
                    cutoffs=cutoffs,
                    is_training=True,
                    inpN=inputsN[i],
                    inpT=inputsT[i],
                    tgtT=labelsT[i],
                    inputPath=inputsPath[i],
                    h_par=FLAGS.h_par,
                    mems=mems_i,
                    alpha=FLAGS.alpha)

                tower_mems.append(mems_i)
                tower_losses.append(loss_i)
                tower_new_mems.append(new_mems_i)
                tower_grads_and_vars.append(grads_and_vars_i)
                predictionT.append(predictionT_i)

    ## average losses and gradients across towers
    # The original guarded this with ``len(tower_losses) >= 1`` plus an
    # unreachable ``else`` that would have left ``acc_T`` undefined; the
    # averaging path is the only one that ever ran, so it is kept as the
    # single path.
    loss = tf.add_n(tower_losses) / len(tower_losses)
    predictionT = tf.concat(predictionT, axis=0)
    acc_T = tf.reduce_mean(tf.cast(predictionT, tf.float32))
    grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    grads, all_vars = zip(*grads_and_vars)

    ## clip gradient
    clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
    grads_and_vars = list(zip(clipped, all_vars))

    ## configure the optimizer
    global_step = tf.train.get_or_create_global_step()

    # warmup stage: increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                    * FLAGS.learning_rate
    else:
        warmup_lr = 0.0

    # decay stage: decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    # choose warmup or decay
    learning_rate = tf.where(global_step < FLAGS.warmup_steps, warmup_lr,
                             decay_lr)

    # get the train op
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    def _fresh_mems():
        # All-zero memories: one list of per-layer arrays for each tower.
        return [[np.zeros(mem_shape, dtype=np.float32)
                 for _ in range(FLAGS.n_layer)]
                for _ in gpu_list]

    def _feed(mems_np, dataN, dataT, dataPath, tT):
        # Build the feed dict for one step from a batch and current memories.
        feed_dict = {
            input_dataN: dataN,
            input_dataT: dataT,
            input_dataPath: dataPath,
            targetsT: tT,
        }
        for tower_placeholders, tower_values in zip(tower_mems, mems_np):
            for m, m_np in zip(tower_placeholders, tower_values):
                feed_dict[m] = m_np
        return feed_dict

    ##### Training loop
    tower_mems_np = _fresh_mems()

    saver = tf.train.Saver()

    fetches = [
        loss, tower_new_mems, global_step, gnorm, learning_rate, train_op,
        acc_T
    ]
    eval_fetches = [loss, tower_new_mems, acc_T]

    train_fmt = ("Epoch {}  |  [{}] | gnorm {:.2f} lr {:8.6f} "
                 "| loss {:.2f} | pplx {:>7.2f}, accT {:.4f}, bpc {:>7.4f}")
    valid_fmt = ("Validation:  Epoch {}  | loss {:.2f} | pplx {:>7.2f}, "
                 "accT {:.4f}, bpc {:>7.4f}")

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        # Restore the warm-start checkpoint once, before training. Doing it
        # at the top of every epoch (as before) threw away the previous
        # epoch's training progress.
        if FLAGS.warm_start_path is not None:
            tf.logging.info("warm start from {}".format(
                FLAGS.warm_start_path))
            saver.restore(sess, FLAGS.warm_start_path)

        for epoch in range(FLAGS.epochs):
            # Best logged loss within this epoch; reset every epoch so each
            # epoch writes at least one checkpoint.
            min_loss = float("INF")

            # Count accumulated steps explicitly: the global step keeps
            # growing across epochs while the running loss restarts at 0.
            total_loss, steps_since_log = 0., 0
            data_loader = reader.real_data_producer(train_data,
                                                    FLAGS.train_batch_size,
                                                    FLAGS.tgt_len, vocab_size)
            step = 0
            while True:
                dataN, _, dataT, tT, epoch_size, _, _, dataPath = next(
                    data_loader)

                fetched = sess.run(
                    fetches,
                    feed_dict=_feed(tower_mems_np, dataN, dataT, dataPath, tT))

                # The fetched new memories become the next step's input.
                (loss_np, tower_mems_np, curr_step, gnorm_np, lr_np, _,
                 accuracy_T) = fetched
                total_loss += loss_np
                steps_since_log += 1

                if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                    curr_loss = total_loss / steps_since_log
                    total_loss, steps_since_log = 0., 0

                    message = train_fmt.format(
                        epoch + 1, curr_step, gnorm_np, lr_np, curr_loss,
                        math.exp(curr_loss), accuracy_T,
                        curr_loss / math.log(2))
                    tf.logging.info(message)
                    print(message, file=fout)
                    fout.flush()

                    # Keep only the checkpoint with the lowest logged loss.
                    if FLAGS.model_dir and curr_loss < min_loss:
                        min_loss = curr_loss
                        save_path = os.path.join(FLAGS.model_dir,
                                                 "model.ckpt")
                        saver.save(sess, save_path)
                        tf.logging.info(
                            "Model saved in path: {}".format(save_path))

                step += 1
                if step >= epoch_size:
                    break

            # evaluate
            # Validation starts from fresh zero memories and keeps its own
            # copy, so it neither inherits training context nor overwrites
            # the memories used by the next training epoch.
            eval_mems_np = _fresh_mems()
            total_loss = 0.
            total_accT = 0.
            data_loader = reader.real_data_producer(valid_data,
                                                    FLAGS.train_batch_size,
                                                    FLAGS.tgt_len, vocab_size)
            eval_steps = 0
            while True:
                dataN, _, dataT, tT, epoch_size, _, _, dataPath = next(
                    data_loader)

                loss_np, eval_mems_np, accuracy_T = sess.run(
                    eval_fetches,
                    feed_dict=_feed(eval_mems_np, dataN, dataT, dataPath, tT))

                total_accT += accuracy_T
                total_loss += loss_np

                eval_steps += 1
                if eval_steps >= epoch_size:
                    break

            # Average over the batches actually evaluated (always >= 1).
            mean_loss = total_loss / eval_steps
            message = valid_fmt.format(epoch + 1, mean_loss,
                                       math.exp(mean_loss),
                                       total_accT / eval_steps,
                                       mean_loss / math.log(2))
            tf.logging.info(message)
            print(message, file=fout)
            fout.flush()
コード例 #3
0
ファイル: train_gpu.py プロジェクト: luhuaei/xlnet
def train(ps_device):
  """Build the multi-tower training graph and run the training loop.

  One tower is built per core (``FLAGS.num_core_per_host``); tower losses and
  gradients are averaged and handed to ``model_utils.get_train_op``. The loop
  logs every ``FLAGS.iterations`` steps, saves a checkpoint every
  ``FLAGS.save_steps`` steps and stops once the global step reaches
  ``FLAGS.train_steps``.

  Args:
    ps_device: forwarded to ``assign_to_gpu`` when placing each tower.
  """
  ##### Get input function and model function

  train_input_fn, record_info_dict = data_utils.get_input_fn(
      tfrecord_dir=FLAGS.record_info_dir,
      split="train",
      bsz_per_host=FLAGS.train_batch_size,
      seq_len=FLAGS.seq_len,
      reuse_len=FLAGS.reuse_len,
      bi_data=FLAGS.bi_data,
      num_hosts=1,
      num_core_per_host=1, # set to one no matter how many GPUs
      perm_size=FLAGS.perm_size,
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      uncased=FLAGS.uncased,
      num_passes=FLAGS.num_passes,
      use_bfloat16=FLAGS.use_bfloat16,
      num_predict=FLAGS.num_predict)

  tf.compat.v1.logging.info(
      "num of batches {}".format(record_info_dict["num_batch"]))

  ##### Create input tensors / placeholders
  bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

  params = {
      "batch_size": FLAGS.train_batch_size # the whole batch
  }
  train_set = train_input_fn(params)

  example = train_set.make_one_shot_iterator().get_next()

  if FLAGS.num_core_per_host > 1:
    # Split every feature along axis 0, one slice per tower.
    examples = [{} for _ in range(FLAGS.num_core_per_host)]
    for key in example.keys():
      vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
      for device_id in range(FLAGS.num_core_per_host):
        examples[device_id][key] = vals[device_id]
  else:
    examples = [example]

  ##### Create computational graph
  tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

  for i in range(FLAGS.num_core_per_host):
    # Variables are created by the first tower and reused by the others.
    reuse = True if i > 0 else None
    with tf.device(assign_to_gpu(i, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

      # The mems for each tower is a dictionary
      mems_i = {}
      if FLAGS.mem_len:
        mems_i["mems"] = create_mems_tf(bsz_per_core)

      loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
          is_training=True,
          features=examples[i],
          mems=mems_i)

      tower_mems.append(mems_i)
      tower_losses.append(loss_i)
      tower_new_mems.append(new_mems_i)
      tower_grads_and_vars.append(grads_and_vars_i)

  ## average losses and gradients across towers
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
    grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
  else:
    loss = tower_losses[0]
    grads_and_vars = tower_grads_and_vars[0]

  ## get train op
  train_op, learning_rate, gnorm = model_utils.get_train_op(FLAGS, None,
      grads_and_vars=grads_and_vars)
  global_step = tf.train.get_global_step()

  ##### Training loop
  # initialize mems
  tower_mems_np = []
  for i in range(FLAGS.num_core_per_host):
    mems_i_np = {}
    for key in tower_mems[i].keys():
      mems_i_np[key] = initialize_mems_np(bsz_per_core)
    tower_mems_np.append(mems_i_np)

  saver = tf.train.Saver()

  gpu_options = tf.GPUOptions(allow_growth=True)

  model_utils.init_from_checkpoint(FLAGS, global_vars=True)

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
      gpu_options=gpu_options)) as sess:
    sess.run(tf.global_variables_initializer())

    fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]

    # Count the steps accumulated into total_loss explicitly rather than
    # deriving them from the global step, which need not start at 0 when
    # training resumes from a checkpoint.
    total_loss, steps_since_log = 0., 0
    while True:
      feed_dict = {}
      for i in range(FLAGS.num_core_per_host):
        for key in tower_mems_np[i].keys():
          for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
            feed_dict[m] = m_np

      fetched = sess.run(fetches, feed_dict=feed_dict)

      # The fetched new memories become the next step's input.
      loss_np, tower_mems_np, curr_step, gnorm_np, lr_np = fetched[:5]
      total_loss += loss_np
      steps_since_log += 1

      if curr_step > 0 and curr_step % FLAGS.iterations == 0:
        curr_loss = total_loss / steps_since_log
        tf.compat.v1.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
            "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
            curr_step, gnorm_np, lr_np,
            curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
        total_loss, steps_since_log = 0., 0

      if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        tf.compat.v1.logging.info("Model saved in path: {}".format(save_path))

      if curr_step >= FLAGS.train_steps:
        break
コード例 #4
0
def train(ps_device):
    """Train for ``FLAGS.epochs`` epochs, validating after each epoch.

    A training tower and a validation tower are built per core
    (``FLAGS.num_core_per_host``) under a shared variable scope. Each epoch
    re-initialises the training iterator and trains until it is exhausted,
    then does the same for the validation iterator and reports the mean
    validation loss. Progress is sent to ``tf.logging`` and to the summary
    writers created through the ``tb`` helper module.

    Args:
        ps_device: forwarded to ``assign_to_gpu`` when placing each tower.
    """
    ##### Get input function and model function

    train_input_fn, record_info_dict = data_utils.get_input_fn(
        info_dir=os.path.join(FLAGS.record_info_dir, "train"),
        split="train",
        bsz_per_host=FLAGS.train_batch_size,
        seq_len=FLAGS.seq_len,
        reuse_len=FLAGS.reuse_len,
        bi_data=FLAGS.bi_data,
        num_hosts=1,
        num_core_per_host=1,  # set to one no matter how many GPUs
        perm_size=FLAGS.perm_size,
        mask_alpha=FLAGS.mask_alpha,
        mask_beta=FLAGS.mask_beta,
        use_bfloat16=FLAGS.use_bfloat16,
        num_predict=FLAGS.num_predict)

    valid_input_fn, record_info_dict_valid = data_utils.get_input_fn(
        info_dir=os.path.join(FLAGS.record_info_dir, "valid"),
        split="valid",
        bsz_per_host=FLAGS.train_batch_size,
        seq_len=FLAGS.seq_len,
        reuse_len=FLAGS.reuse_len,
        bi_data=FLAGS.bi_data,
        num_hosts=1,
        num_core_per_host=1,
        perm_size=FLAGS.perm_size,
        mask_alpha=FLAGS.mask_alpha,
        mask_beta=FLAGS.mask_beta,
        use_bfloat16=FLAGS.use_bfloat16,
        num_predict=FLAGS.num_predict)

    num_train_batches = record_info_dict["num_batch"]
    tf.logging.info("num of train batches {}".format(
        record_info_dict["num_batch"]))
    tf.logging.info("num of validation batches {}".format(
        record_info_dict_valid["num_batch"]))

    ##### Create input tensors / placeholders
    bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

    params = {
        "batch_size": FLAGS.train_batch_size  # the whole batch
    }
    train_set = train_input_fn(params)
    valid_set = valid_input_fn(params)

    # Initializable iterators so both datasets can be restarted every epoch.
    t_iter = train_set.make_initializable_iterator()
    example = t_iter.get_next()
    v_iter = valid_set.make_initializable_iterator()
    v_example = v_iter.get_next()

    if FLAGS.num_core_per_host > 1:
        # Split every feature along axis 0, one slice per tower.
        # train set
        examples = [{} for _ in range(FLAGS.num_core_per_host)]
        for key in example.keys():
            vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
            for device_id in range(FLAGS.num_core_per_host):
                examples[device_id][key] = vals[device_id]

        # validation set
        v_examples = [{} for _ in range(FLAGS.num_core_per_host)]
        for key in v_example.keys():
            vals = tf.split(v_example[key], FLAGS.num_core_per_host, 0)
            for device_id in range(FLAGS.num_core_per_host):
                v_examples[device_id][key] = vals[device_id]
    else:
        examples = [example]
        v_examples = [v_example]

    ##### Create computational graph
    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []
    v_tower_mems, v_tower_losses, v_tower_new_mems = [], [], []

    for core in range(FLAGS.num_core_per_host):
        # Variables are created by the first tower and reused by the others.
        reuse = True if core > 0 else None
        with tf.device(assign_to_gpu(core, ps_device)), \
            tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            # The mems for each tower is a dictionary
            mems_i = {}
            v_mems_i = {}
            if FLAGS.mem_len:
                mems_i["mems"] = create_mems_tf(bsz_per_core)
                v_mems_i["mems"] = create_mems_tf(bsz_per_core)

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                is_training=True, features=examples[core], mems=mems_i)

            v_loss_i, v_new_mems_i = single_core_graph(
                is_training=False,
                features=v_examples[core],
                mems=v_mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

            v_tower_mems.append(v_mems_i)
            v_tower_losses.append(v_loss_i)
            v_tower_new_mems.append(v_new_mems_i)

    ## average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]

    if len(v_tower_losses) > 1:
        v_loss = tf.add_n(v_tower_losses) / len(v_tower_losses)
    else:
        v_loss = v_tower_losses[0]

    ## get train op
    train_op, learning_rate, gnorm = model_utils.get_train_op(
        FLAGS, None, num_train_batches, grads_and_vars=grads_and_vars)
    global_step = tf.train.get_global_step()

    def _init_mems_np():
        # Freshly initialised memories: one dict per tower, keyed like the
        # training tower's memory placeholders.
        return [
            {key: initialize_mems_np(bsz_per_core)
             for key in tower_mems[core].keys()}
            for core in range(FLAGS.num_core_per_host)
        ]

    def _feed(placeholders, values):
        # Pair every memory placeholder with its current numpy value.
        feed_dict = {}
        for core in range(FLAGS.num_core_per_host):
            for key in values[core].keys():
                for m, m_np in zip(placeholders[core][key],
                                   values[core][key]):
                    feed_dict[m] = m_np
        return feed_dict

    ##### Training loop
    # initialize mems
    tower_mems_np = _init_mems_np()

    saver = tf.train.Saver()

    gpu_options = tf.GPUOptions(allow_growth=True)

    model_utils.init_from_checkpoint(FLAGS, global_vars=True)

    # Create performance summaries for Tensorboard logging
    training_performance_summaries, valid_performance_summaries = (
        tb.tensorboard_setup())

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())

        # variables that are run in the session
        fetches = [
            loss, tower_new_mems, global_step, gnorm, learning_rate, train_op
        ]
        v_fetches = [v_loss, v_tower_new_mems]

        # Create writers for Tensorboard logging
        info_dict = {
            "id": FLAGS.run_id,
            "n_layers": FLAGS.n_layers,
            "d_model": FLAGS.d_model,
            "n_heads": FLAGS.n_head
        }
        train_summary_writer, valid_summary_writer = tb.create_writers(
            sess, info_dict, logging_dir=FLAGS.tb_logging_dir)

        # Count the steps accumulated into total_loss explicitly rather than
        # deriving them from the global step, which need not start at 0 when
        # training resumes from a checkpoint.
        total_loss, steps_since_log = 0., 0
        # Defined up front so validation can report a step even if the
        # training iterator yields no batch.
        curr_step = 0

        # ``epoch`` is deliberately distinct from the per-core indices used
        # above: reusing ``i`` made the end-of-epoch log print a core index.
        for epoch in range(FLAGS.epochs):

            # Train loop: runs until the training iterator is exhausted.
            sess.run(t_iter.initializer)
            try:
                while True:
                    fetched = sess.run(
                        fetches, feed_dict=_feed(tower_mems, tower_mems_np))
                    # The fetched new memories become the next step's input.
                    loss_np, tower_mems_np, curr_step, gnorm_np, lr_np = (
                        fetched[:5])
                    total_loss += loss_np
                    steps_since_log += 1

                    # Log training progress
                    if curr_step > 0 and curr_step % FLAGS.log_steps == 0:
                        curr_loss = total_loss / steps_since_log
                        summ = tb.run_train(sess,
                                            training_performance_summaries,
                                            curr_loss)
                        train_summary_writer.add_summary(summ, curr_step)
                        tf.logging.info(
                            "[{}] | gnorm {:.2f} lr {:8.6f} "
                            "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".
                            format(curr_step, gnorm_np, lr_np,
                                   curr_loss, math.exp(curr_loss),
                                   curr_loss / math.log(2)))
                        total_loss, steps_since_log = 0., 0

                    # Save checkpoint
                    if (curr_step > 0 and FLAGS.save_steps is not None
                            and curr_step % FLAGS.save_steps == 0):
                        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                        saver.save(sess, save_path)
                        tf.logging.info(
                            "Model saved in path: {}".format(save_path))

            except tf.errors.OutOfRangeError:
                # End of the training data for this epoch.
                pass

            # Validation loop: every pass starts from fresh memories so it
            # does not inherit state from the previous epoch's validation.
            sess.run(v_iter.initializer)
            v_tower_mems_np = _init_mems_np()
            v_total_loss, v_steps = 0., 0
            try:
                while True:
                    v_loss_np, v_tower_mems_np = sess.run(
                        v_fetches,
                        feed_dict=_feed(v_tower_mems, v_tower_mems_np))
                    v_total_loss += v_loss_np
                    v_steps += 1
            except tf.errors.OutOfRangeError:
                # End of the validation data.
                pass

            # Report outside the exception handler and only when at least
            # one batch was evaluated, to avoid dividing by zero.
            if v_steps > 0:
                val_loss = v_total_loss / v_steps
                v_pplx = math.exp(val_loss)
                tf.logging.info(
                    "Validation: [{}] | loss {:.2f} | pplx {:>7.2f}".format(
                        curr_step, val_loss, v_pplx))

                summ_valid = tb.run_valid(sess, valid_performance_summaries,
                                          val_loss, v_pplx)
                valid_summary_writer.add_summary(summ_valid, curr_step)
            else:
                tf.logging.warning(
                    "Validation produced no batches; skipping summary.")

            tf.logging.info(
                "------------ Epoch {} ------------".format(epoch))
コード例 #5
0
ファイル: train_gpu.py プロジェクト: fqararyah/xlnet
def _collect_graph_stats(graph):
    """Collect per-operation output sizes and attributes from `graph`.

    Only the first output of each operation is inspected.

    Args:
        graph: graph object exposing `get_operations()`.

    Returns:
        A 4-tuple `(tensor_sizes, op_attributes, unsized_count,
        no_output_count)`:
          * tensor_sizes: op name -> size in bytes of the first output
            (unknown dimensions are skipped in the product), or -1 when the
            op has no output or its output has unknown rank. Ops whose dtype
            reports no size get no entry at all.
          * op_attributes: op name -> `[op_type]`, followed by the
            `_is_ref_dtype` flag of the first output when there is one.
          * unsized_count: number of ops whose first output has unknown rank
            or whose dtype reports no size.
          * no_output_count: number of ops without any output.
    """
    tensor_sizes = {}
    op_attributes = {}
    unsized_count = 0
    no_output_count = 0

    for operation in graph.get_operations():
        op_name = operation.name
        outputs = operation.values()

        # Reading the first output directly avoids the "<name>:0" lookup that
        # raised (and was silently swallowed) for ops without outputs.
        attrs = [operation.type]
        if outputs:
            attrs.append(outputs[0].dtype._is_ref_dtype)
        op_attributes[op_name] = attrs

        if not outputs:
            no_output_count += 1
            tensor_sizes[op_name] = -1
            continue

        first_output = outputs[0]
        if first_output.shape.ndims is None:
            unsized_count += 1
            tensor_sizes[op_name] = -1
            continue

        dtype_size = first_output.dtype.size
        if dtype_size is None:
            # Counted, but deliberately left without a size entry.
            unsized_count += 1
            continue

        num_elements = 1
        for dim in first_output.shape.as_list():
            if dim is not None:
                num_elements *= dim
        tensor_sizes[op_name] = num_elements * dtype_size

    return tensor_sizes, op_attributes, unsized_count, no_output_count


def train(ps_device):
    """Build the multi-tower training graph and run the profiled train loop.

    The whole batch is read from a single input pipeline and split evenly
    across `FLAGS.num_core_per_host` towers. Besides training, the function
    dumps graph statistics to the `profs` directory and traces step 0 and
    every step whose index satisfies `index % 10 == 7`.

    Args:
        ps_device: forwarded to `assign_to_gpu` together with the tower index
            to choose the device scope of each tower.
    """
    # Get input function and model function
    train_input_fn, record_info_dict = data_utils.get_input_fn(
        tfrecord_dir=FLAGS.record_info_dir,
        split="train",
        bsz_per_host=FLAGS.train_batch_size,
        seq_len=FLAGS.seq_len,
        reuse_len=FLAGS.reuse_len,
        bi_data=FLAGS.bi_data,
        num_hosts=1,
        num_core_per_host=1,  # set to one no matter how many GPUs
        perm_size=FLAGS.perm_size,
        mask_alpha=FLAGS.mask_alpha,
        mask_beta=FLAGS.mask_beta,
        uncased=FLAGS.uncased,
        num_passes=FLAGS.num_passes,
        use_bfloat16=FLAGS.use_bfloat16,
        num_predict=FLAGS.num_predict)

    tf.logging.info("num of batches {}".format(record_info_dict["num_batch"]))

    # Create input tensors / placeholders
    bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

    params = {
        "batch_size": FLAGS.train_batch_size  # the whole batch
    }
    train_set = train_input_fn(params)

    example = train_set.make_one_shot_iterator().get_next()

    # Split every feature along axis 0 so each tower gets its own slice.
    if FLAGS.num_core_per_host > 1:
        examples = [{} for _ in range(FLAGS.num_core_per_host)]
        for key in example.keys():
            vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
            for device_id in range(FLAGS.num_core_per_host):
                examples[device_id][key] = vals[device_id]
    else:
        examples = [example]

    # Create computational graph
    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_core_per_host):
        # Variables are created by tower 0 and reused by the others.
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
                tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            # The mems for each tower is a dictionary
            mems_i = {}
            if FLAGS.mem_len:
                mems_i["mems"] = create_mems_tf(bsz_per_core)

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                is_training=True,
                features=examples[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    # average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]

    # get train op
    train_op, learning_rate, gnorm = model_utils.get_train_op(
        FLAGS, None, grads_and_vars=grads_and_vars)
    global_step = tf.train.get_global_step()

    # Training loop
    # initialize mems
    tower_mems_np = []
    for i in range(FLAGS.num_core_per_host):
        mems_i_np = {}
        for key in tower_mems[i].keys():
            mems_i_np[key] = initialize_mems_np(bsz_per_core)
        tower_mems_np.append(mems_i_np)

    saver = tf.train.Saver()

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.97)

    model_utils.init_from_checkpoint(FLAGS, global_vars=True)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())
        sess.graph.finalize()
        run_metadata = tf.RunMetadata()
        options = tf.RunOptions(trace_level=tf.RunOptions.SOFTWARE_TRACE)

        # Every dump below lands in `profs`; create it up front so a missing
        # directory does not abort the run after the graph has been built.
        if not tf.gfile.Exists('profs'):
            tf.gfile.MakeDirs('profs')

        graph = tf.get_default_graph()

        dot_rep = graph_to_dot(graph)
        with open('profs/xln.dot', 'w') as fwr:
            fwr.write(str(dot_rep))

        (operations_tensors, operations_attributes,
         count1, count2) = _collect_graph_stats(graph)
        print(count1)
        print(count2)

        with open('./profs/tensors_sz_32.txt', 'w') as f:
            for tensor, size in operations_tensors.items():
                f.write('"' + tensor + '"::' + str(size) + '\n')

        with open('./profs/operations_attributes.txt', 'w') as f:
            for op, attrs in operations_attributes.items():
                f.write('::'.join([op] + [str(attr) for attr in attrs]) + '\n')

        fetches = [loss, tower_new_mems, global_step,
                   gnorm, learning_rate, train_op]
        step_idx = 0  # local iteration counter, independent of global_step
        total_loss, prev_step = 0., -1
        while True:
            # Feed the memories produced by the previous step back in.
            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for key in tower_mems_np[i].keys():
                    for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
                        feed_dict[m] = m_np

            if step_idx % 10 == 7 or step_idx == 0:
                # Traced step: collect run metadata and hand it to `profile`.
                fetched = sess.run(fetches, feed_dict=feed_dict,
                                   options=options, run_metadata=run_metadata)
                profile(run_metadata, step_idx)
            else:
                # Untraced step: print the wall-clock time of the run call.
                t0 = time.time()
                fetched = sess.run(fetches, feed_dict=feed_dict)
                print(time.time() - t0)

            if step_idx == 0:
                mem_options = tf.profiler.ProfileOptionBuilder.time_and_memory()
                mem_options["min_bytes"] = 0
                mem_options["min_micros"] = 0
                mem_options["output"] = 'file:outfile=./profs/mem.txt'
                mem_options["select"] = ("bytes", "peak_bytes", "output_bytes",
                                         "residual_bytes")
                # NOTE(review): a fresh, empty tf.Graph() is passed here rather
                # than the training graph, so the report relies on run_metadata
                # alone. Confirm whether this is intentional.
                mem = tf.profiler.profile(
                    tf.Graph(), run_meta=run_metadata, cmd="scope",
                    options=mem_options)
                with open('profs/mem2.txt', 'w') as f:
                    f.write(str(mem))
            step_idx += 1

            loss_np, tower_mems_np, curr_step = fetched[:3]
            total_loss += loss_np

            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                curr_loss = total_loss / (curr_step - prev_step)
                tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                                "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                                    curr_step, fetched[-3], fetched[-2],
                                    curr_loss, math.exp(curr_loss),
                                    curr_loss / math.log(2)))
                total_loss, prev_step = 0., curr_step

            # An unset (None) or zero save_steps disables checkpointing
            # instead of raising on the modulo.
            if (FLAGS.save_steps and curr_step > 0
                    and curr_step % FLAGS.save_steps == 0):
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                saver.save(sess, save_path)
                tf.logging.info("Model saved in path: {}".format(save_path))

            if curr_step >= FLAGS.train_steps:
                break
コード例 #6
0
def main(_):
    """Entry point: fine-tune on a classification/regression task.

    Visible flow:
      1. Optionally wait for a ptvsd debugger to attach.
      2. Seed TensorFlow and NumPy, validate flags, create output dirs.
      3. Pick the task processor and load the SentencePiece model.
      4. If `FLAGS.do_train`: write the training tfrecord, build a
         multi-tower TF graph, and train it while a PyTorch XLNet model
         (loaded from a TF checkpoint saved at session start) is trained
         on the same batches for side-by-side loss logging.
      5. If `FLAGS.do_eval`: see the NOTE(review) comments in that branch;
         as written it depends on names that are not defined here.

    Args:
        _: unused positional argument.

    Raises:
        ValueError: if none of do_train/do_eval/do_predict is set, or if
            `FLAGS.task_name` is not a key of `processors`.
    """
    if FLAGS.server_ip and FLAGS.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(FLAGS.server_ip, FLAGS.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Seed both frameworks' RNGs from the same flag.
    tf.set_random_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    tf.logging.set_verbosity(tf.logging.INFO)

    #### Validate flags
    # Never log less often than checkpoints are written.
    if FLAGS.save_steps is not None:
        FLAGS.log_step_count_steps = min(FLAGS.log_step_count_steps,
                                         FLAGS.save_steps)

    if FLAGS.do_predict:
        predict_dir = FLAGS.predict_dir
        if not tf.gfile.Exists(predict_dir):
            tf.gfile.MakeDirs(predict_dir)

    # Task name (lower-cased below) -> processor class.
    processors = {
        "mnli_matched": MnliMatchedProcessor,
        "mnli_mismatched": MnliMismatchedProcessor,
        'sts-b': StsbProcessor,
        'imdb': ImdbProcessor,
        "yelp5": Yelp5Processor
    }

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval, `do_predict` or "
            "`do_submit` must be True.")

    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    if not tf.gfile.Exists(FLAGS.model_dir):
        tf.gfile.MakeDirs(FLAGS.model_dir)

#   ########################### LOAD PT model
#   ########################### LOAD PT model
#   import torch
#   from pytorch_transformers import CONFIG_NAME, TF_WEIGHTS_NAME, XLNetTokenizer, XLNetConfig, XLNetForSequenceClassification

#   save_path = os.path.join(FLAGS.model_dir, TF_WEIGHTS_NAME)
#   tf.logging.info("Model loaded from path: {}".format(save_path))

#   device = torch.device("cuda", 4)
#   config = XLNetConfig.from_pretrained('xlnet-large-cased', finetuning_task=u'sts-b')
#   config_path = os.path.join(FLAGS.model_dir, CONFIG_NAME)
#   config.to_json_file(config_path)
#   pt_model = XLNetForSequenceClassification.from_pretrained(FLAGS.model_dir, from_tf=True, num_labels=1)
#   pt_model.to(device)
#   pt_model = torch.nn.DataParallel(pt_model, device_ids=[4, 5, 6, 7])

#   from torch.optim import Adam
#   optimizer = Adam(pt_model.parameters(), lr=0.001, betas=(0.9, 0.999),
#                     eps=FLAGS.adam_epsilon, weight_decay=FLAGS.weight_decay,
#                     amsgrad=False)
#   ########################### LOAD PT model
#   ########################### LOAD PT model

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    # Regression tasks carry no label list.
    label_list = processor.get_labels() if not FLAGS.is_regression else None

    sp = spm.SentencePieceProcessor()
    sp.Load(FLAGS.spiece_model_file)

    def tokenize_fn(text):
        # Normalise the text, then encode it with the SentencePiece model.
        text = preprocess_text(text, lower=FLAGS.uncased)
        return encode_ids(sp, text)

    # run_config = model_utils.configure_tpu(FLAGS)


#   model_fn = get_model_fn(len(label_list) if label_list is not None else None)

    # The tfrecord file names below embed the SentencePiece model file name.
    spm_basename = os.path.basename(FLAGS.spiece_model_file)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    # estimator = tf.estimator.Estimator(
    #     model_fn=model_fn,
    #     config=run_config)

    if FLAGS.do_train:
        train_file_base = "{}.len-{}.train.tf_record".format(
            spm_basename, FLAGS.max_seq_length)
        train_file = os.path.join(FLAGS.output_dir, train_file_base)
        tf.logging.info("Use tfrecord file {}".format(train_file))

        train_examples = processor.get_train_examples(FLAGS.data_dir)
        tf.logging.info("Num of train samples: {}".format(len(train_examples)))

        file_based_convert_examples_to_features(train_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenize_fn, train_file,
                                                FLAGS.num_passes)

        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)

        # estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)

        ##### Create input tensors / placeholders
        # NOTE(review): bsz_per_core is computed but not used in this function.
        bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

        params = {
            "batch_size": FLAGS.train_batch_size  # the whole batch
        }
        train_set = train_input_fn(params)

        example = train_set.make_one_shot_iterator().get_next()
        # Split every feature along axis 0 so each tower gets its own slice.
        if FLAGS.num_core_per_host > 1:
            examples = [{} for _ in range(FLAGS.num_core_per_host)]
            for key in example.keys():
                vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
                for device_id in range(FLAGS.num_core_per_host):
                    examples[device_id][key] = vals[device_id]
        else:
            examples = [example]

        ##### Create computational graph
        tower_losses, tower_grads_and_vars, tower_inputs, tower_hidden_states, tower_logits = [], [], [], [], []

        for i in range(FLAGS.num_core_per_host):
            # Variables are created by tower 0 and reused by the others.
            reuse = True if i > 0 else None
            with tf.device(assign_to_gpu(i, "/gpu:0")), \
                tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

                loss_i, grads_and_vars_i, inputs_i, hidden_states_i, logits_i = single_core_graph(
                    is_training=True,
                    features=examples[i],
                    label_list=label_list)

                tower_losses.append(loss_i)
                tower_grads_and_vars.append(grads_and_vars_i)
                tower_inputs.append(inputs_i)
                tower_hidden_states.append(hidden_states_i)
                tower_logits.append(logits_i)

        ## average losses and gradients across towers
        if len(tower_losses) > 1:
            loss = tf.add_n(tower_losses) / len(tower_losses)
            grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
            # Re-assemble the full batch by concatenating the per-tower
            # tensors along axis 0.
            inputs = dict((n, tf.concat([t[n] for t in tower_inputs], 0))
                          for n in tower_inputs[0])
            hidden_states = list(
                tf.concat(t, 0) for t in zip(*tower_hidden_states))
            logits = tf.concat(tower_logits, 0)
        else:
            loss = tower_losses[0]
            grads_and_vars = tower_grads_and_vars[0]
            inputs = tower_inputs[0]
            hidden_states = tower_hidden_states[0]
            logits = tower_logits[0]

        # Summaries
        # NOTE(review): merge_all() runs before get_train_op(), so any summary
        # registered by get_train_op would not be part of `merged` - confirm.
        merged = tf.summary.merge_all()

        ## get train op
        train_op, learning_rate, gnorm = model_utils.get_train_op(
            FLAGS, None, grads_and_vars=grads_and_vars)
        global_step = tf.train.get_global_step()

        ##### Training loop
        saver = tf.train.Saver(max_to_keep=FLAGS.max_save)

        gpu_options = tf.GPUOptions(allow_growth=True)

        #### load pretrained models
        model_utils.init_from_checkpoint(FLAGS, global_vars=True)

        writer = tf.summary.FileWriter(logdir=FLAGS.model_dir,
                                       graph=tf.get_default_graph())
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True, gpu_options=gpu_options)) as sess:
            sess.run(tf.global_variables_initializer())

            #########
            ##### PYTORCH
            import torch
            from torch.optim import Adam
            from pytorch_transformers import CONFIG_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, XLNetTokenizer, XLNetConfig, XLNetForSequenceClassification, BertAdam

            # Save the freshly initialised TF weights so the PyTorch model can
            # be loaded from that same checkpoint.
            save_path = os.path.join(FLAGS.model_dir, TF_WEIGHTS_NAME + '-00')
            saver.save(sess, save_path)
            tf.logging.info("Model saved in path: {}".format(save_path))

            # NOTE(review): CUDA device index 4 and device_ids [4, 5, 6, 7]
            # are hard-coded, as are the 'xlnet-large-cased' config and the
            # 'sts-b' fine-tuning task, regardless of FLAGS.task_name.
            device = torch.device("cuda", 4)
            config = XLNetConfig.from_pretrained('xlnet-large-cased',
                                                 finetuning_task=u'sts-b',
                                                 num_labels=1)
            tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')

            # pt_model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased', num_labels=1)
            pt_model = XLNetForSequenceClassification.from_pretrained(
                save_path, from_tf=True, config=config)
            pt_model.to(device)
            pt_model = torch.nn.DataParallel(pt_model, device_ids=[4, 5, 6, 7])

            # The lr given here is overwritten on every step with the learning
            # rate fetched from the TF graph (see the loop below).
            optimizer = Adam(pt_model.parameters(),
                             lr=0.001,
                             betas=(0.9, 0.999),
                             eps=FLAGS.adam_epsilon,
                             weight_decay=FLAGS.weight_decay,
                             amsgrad=False)
            # optimizer = BertAdam(pt_model.parameters(), lr=FLAGS.learning_rate, t_total=FLAGS.train_steps, warmup=FLAGS.warmup_steps / FLAGS.train_steps,
            #                      eps=FLAGS.adam_epsilon, weight_decay=FLAGS.weight_decay)
            ##### PYTORCH
            #########

            fetches = [
                loss, global_step, gnorm, learning_rate, train_op, merged,
                inputs, hidden_states, logits
            ]

            total_loss, total_loss_pt, prev_step, gnorm_pt = 0., 0., -1, 0.0
            # NOTE(review): total_logits/total_labels grow by one batch per
            # step and are not read again in this branch.
            total_logits = None
            total_labels = None
            while True:
                # NOTE(review): feed_dict is built empty and never passed to
                # sess.run below.
                feed_dict = {}
                # for i in range(FLAGS.num_core_per_host):
                #   for key in tower_mems_np[i].keys():
                #     for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
                #       feed_dict[m] = m_np

                # One TF training step; the batch it consumed is fetched too so
                # the PyTorch model can be trained on the same data.
                fetched = sess.run(fetches)

                loss_np, curr_step, gnorm_np, learning_rate_np, _, summary_np, inputs_np, hidden_states_np, logits_np = fetched
                total_loss += loss_np

                if total_logits is None:
                    total_logits = logits_np
                    total_labels = inputs_np['label_ids']
                else:
                    total_logits = np.append(total_logits, logits_np, axis=0)
                    total_labels = np.append(total_labels,
                                             inputs_np['label_ids'],
                                             axis=0)

                #########
                ##### PYTORCH
                # Convert the fetched NumPy batch into tensors on `device`.
                f_inp = torch.tensor(inputs_np["input_ids"],
                                     dtype=torch.long,
                                     device=device)
                f_seg_id = torch.tensor(inputs_np["segment_ids"],
                                        dtype=torch.long,
                                        device=device)
                f_inp_mask = torch.tensor(inputs_np["input_mask"],
                                          dtype=torch.float,
                                          device=device)
                f_label = torch.tensor(inputs_np["label_ids"],
                                       dtype=torch.float,
                                       device=device)

                # with torch.no_grad():
                #   _, hidden_states_pt, _ = pt_model.transformer(f_inp, f_seg_id, f_inp_mask)
                # logits_pt, _ = pt_model(f_inp, token_type_ids=f_seg_id, input_mask=f_inp_mask)

                pt_model.train()
                outputs = pt_model(f_inp,
                                   token_type_ids=f_seg_id,
                                   input_mask=f_inp_mask,
                                   labels=f_label)
                loss_pt = outputs[0]
                # Reduce the loss returned by the DataParallel wrapper to a
                # scalar.
                loss_pt = loss_pt.mean()
                total_loss_pt += loss_pt.item()

                # # hidden_states_pt = list(t.detach().cpu().numpy() for t in hidden_states_pt)
                # # special_pt = special_pt.detach().cpu().numpy()

                # # Optimizer pt
                pt_model.zero_grad()
                loss_pt.backward()
                gnorm_pt = torch.nn.utils.clip_grad_norm_(
                    pt_model.parameters(), FLAGS.clip)
                # Apply the learning rate value fetched from the TF graph.
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rate_np
                optimizer.step()
                ##### PYTORCH
                #########

                if curr_step > 0 and curr_step % FLAGS.log_step_count_steps == 0:
                    # Mean loss over the steps since the previous log line.
                    curr_loss = total_loss / (curr_step - prev_step)
                    curr_loss_pt = total_loss_pt / (curr_step - prev_step)
                    tf.logging.info(
                        "[{}] | gnorm {:.2f} lr {:8.6f} "
                        "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                            curr_step, gnorm_np, learning_rate_np, curr_loss,
                            math.exp(curr_loss), curr_loss / math.log(2)))

                    #########
                    ##### PYTORCH
                    tf.logging.info(
                        "  PT [{}] | gnorm PT {:.2f} lr PT {:8.6f} "
                        "| loss PT {:.2f} | pplx PT {:>7.2f}, bpc PT {:>7.4f}".
                        format(curr_step, gnorm_pt, learning_rate_np,
                               curr_loss_pt, math.exp(curr_loss_pt),
                               curr_loss_pt / math.log(2)))
                    ##### PYTORCH
                    #########

                    total_loss, total_loss_pt, prev_step = 0., 0., curr_step
                    writer.add_summary(summary_np, global_step=curr_step)

                # NOTE(review): save_steps is checked against None during flag
                # validation above but used unguarded as a modulus here.
                if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                    save_path = os.path.join(FLAGS.model_dir,
                                             "model.ckpt-{}".format(curr_step))
                    saver.save(sess, save_path)
                    tf.logging.info(
                        "Model saved in path: {}".format(save_path))

                    #########
                    ##### PYTORCH
                    # Save a trained model, configuration and tokenizer
                    model_to_save = pt_model.module if hasattr(
                        pt_model,
                        'module') else pt_model  # Only save the model it-self
                    # If we save using the predefined names, we can load using `from_pretrained`
                    output_dir = os.path.join(
                        FLAGS.output_dir, "pytorch-ckpt-{}".format(curr_step))
                    if not tf.gfile.Exists(output_dir):
                        tf.gfile.MakeDirs(output_dir)
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    tf.logging.info(
                        "PyTorch Model saved in path: {}".format(output_dir))
                    ##### PYTORCH
                    #########

                if curr_step >= FLAGS.train_steps:
                    break

    if FLAGS.do_eval:
        # NOTE(review): `eval_examples` is never assigned anywhere in this
        # function, so this branch raises NameError as written. It also uses
        # gpu_options, loss, global_step, gnorm, learning_rate, train_op,
        # merged, inputs, hidden_states and logits, which are only bound
        # inside the `FLAGS.do_train` branch above.
        #
        # TPU requires a fixed batch size for all batches, therefore the number
        # of examples must be a multiple of the batch size, or else examples
        # will get dropped. So we pad with fake examples which are ignored
        # later on. These do NOT count towards the metric (all tf.metrics
        # support a per-instance weight, and these get a weight of 0.0).
        #
        # Modified in XL: We also adopt the same mechanism for GPUs.
        while len(eval_examples) % FLAGS.eval_batch_size != 0:
            eval_examples.append(PaddingInputExample())

        eval_file_base = "{}.len-{}.{}.eval.tf_record".format(
            spm_basename, FLAGS.max_seq_length, FLAGS.eval_split)
        eval_file = os.path.join(FLAGS.output_dir, eval_file_base)

        file_based_convert_examples_to_features(eval_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenize_fn, eval_file)

        assert len(eval_examples) % FLAGS.eval_batch_size == 0
        # NOTE(review): eval_steps and eval_input_fn are built but not used in
        # the visible remainder of this function.
        eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=True)

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True, gpu_options=gpu_options)) as sess:
            sess.run(tf.global_variables_initializer())

            ########################### LOAD PT model
            #   import torch
            #   from pytorch_transformers import CONFIG_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, XLNetTokenizer, XLNetConfig, XLNetForSequenceClassification, BertAdam

            #   save_path = os.path.join(FLAGS.model_dir, TF_WEIGHTS_NAME)
            #   saver.save(sess, save_path)
            #   tf.logging.info("Model saved in path: {}".format(save_path))

            #   device = torch.device("cuda", 4)
            #   config = XLNetConfig.from_pretrained('xlnet-large-cased', finetuning_task=u'sts-b', num_labels=1)
            #   tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
            #   config_path = os.path.join(FLAGS.model_dir, CONFIG_NAME)
            #   config.to_json_file(config_path)
            #   # pt_model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased', num_labels=1)
            #   pt_model = XLNetForSequenceClassification.from_pretrained(FLAGS.model_dir, from_tf=True)
            #   pt_model.to(device)
            #   pt_model = torch.nn.DataParallel(pt_model, device_ids=[4, 5, 6, 7])
            #   from torch.optim import Adam
            #   optimizer = Adam(pt_model.parameters(), lr=0.001, betas=(0.9, 0.999),
            #                    eps=FLAGS.adam_epsilon, weight_decay=FLAGS.weight_decay,
            #                    amsgrad=False)
            #   optimizer = BertAdam(pt_model.parameters(), lr=FLAGS.learning_rate, t_total=FLAGS.train_steps, warmup=FLAGS.warmup_steps / FLAGS.train_steps,
            #                        eps=FLAGS.adam_epsilon, weight_decay=FLAGS.weight_decay)

            ##### PYTORCH
            #########

            # NOTE(review): these are the training fetches, including
            # train_op, so each "eval" step would also update the weights.
            fetches = [
                loss, global_step, gnorm, learning_rate, train_op, merged,
                inputs, hidden_states, logits
            ]

            total_loss, total_loss_pt, prev_step, gnorm_pt = 0., 0., -1, 0.0
            total_logits = None
            total_labels = None
            # NOTE(review): no break/return is visible in this loop; the
            # function body appears to be cut off after the accumulation below.
            while True:
                feed_dict = {}
                # for i in range(FLAGS.num_core_per_host):
                #   for key in tower_mems_np[i].keys():
                #     for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
                #       feed_dict[m] = m_np

                fetched = sess.run(fetches)

                loss_np, curr_step, gnorm_np, learning_rate_np, _, summary_np, inputs_np, hidden_states_np, logits_np = fetched
                total_loss += loss_np

                # Accumulate logits and labels batch by batch.
                if total_logits is None:
                    total_logits = logits_np
                    total_labels = inputs_np['label_ids']
                else:
                    total_logits = np.append(total_logits, logits_np, axis=0)
                    total_labels = np.append(total_labels,
                                             inputs_np['label_ids'],
                                             axis=0)
コード例 #7
0
def train(n_token, cutoffs, ps_device):
    """Build the multi-tower training graph and run the training loop.

    The input batch is split evenly across `FLAGS.num_core_per_host` towers.
    Tower losses and gradients are averaged, gradients are clipped by global
    norm, and Adam is applied with a linear-warmup / cosine-decay learning
    rate. Per-layer memories returned by each step are fed back into the
    next step. Log records of the 'tensorflow' logger are also written to
    'run_train.log'.

    Args:
        n_token: forwarded to `single_core_graph`.
        cutoffs: forwarded to `single_core_graph`.
        ps_device: forwarded to `assign_to_gpu` together with the tower index
            to choose the device scope of each tower.
    """
    # get TF logger
    log = logging.getLogger('tensorflow')
    log.setLevel(logging.INFO)

    # create formatter and add it to the handlers
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # create a file handler that mirrors INFO-and-above records to disk
    fh = logging.FileHandler('run_train.log')
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    log.addHandler(fh)

    ##### Get input function and model function
    train_input_fn, train_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1,
        use_tpu=False)

    tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))

    ##### Create computational graph
    train_set = train_input_fn({
        "batch_size": FLAGS.train_batch_size,
        "data_dir": FLAGS.data_dir
    })

    input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

    # Split inputs and labels along axis 0, one slice per tower.
    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_core_per_host):
        # Variables are created by tower 0 and reused by the others.
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
                tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            # One memory placeholder per layer for this tower.
            mems_i = [
                tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)
            ]

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=True,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    ## average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]
    grads, all_vars = zip(*grads_and_vars)

    ## clip gradient
    clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
    grads_and_vars = list(zip(clipped, all_vars))

    ## configure the optimizer
    global_step = tf.train.get_or_create_global_step()

    # warmup stage: increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                    * FLAGS.learning_rate
    else:
        warmup_lr = 0.0

    # decay stage: decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    # choose warmup or decay
    learning_rate = tf.where(global_step < FLAGS.warmup_steps, warmup_lr,
                             decay_lr)

    # get the train op
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    ##### Training loop
    # Zero-initialised memories: one array per layer per tower.
    tower_mems_np = [[
        np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                 dtype=np.float32) for _ in range(FLAGS.n_layer)
    ] for _ in range(FLAGS.num_core_per_host)]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.warm_start_path is not None:
            tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
            saver.restore(sess, FLAGS.warm_start_path)

        fetches = [
            loss, tower_new_mems, global_step, gnorm, learning_rate, train_op
        ]

        total_loss, prev_step = 0., -1
        while True:
            # Feed the memories produced by the previous step back in.
            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)

            loss_np, tower_mems_np, curr_step = fetched[:3]
            total_loss += loss_np

            if curr_step % 100 == 0:
                print("Current step:", curr_step)

            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                # Mean loss over the steps since the previous log line.
                curr_loss = total_loss / (curr_step - prev_step)
                tf.logging.info(
                    "[{}] | gnorm {:.2f} lr {:8.6f} "
                    "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                        curr_step, fetched[-3], fetched[-2], curr_loss,
                        math.exp(curr_loss), curr_loss / math.log(2)))
                total_loss, prev_step = 0., curr_step

            # An unset (None) or zero save_steps disables checkpointing
            # instead of raising on the modulo.
            if (FLAGS.save_steps and curr_step > 0
                    and curr_step % FLAGS.save_steps == 0):
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                saver.save(sess, save_path)
                tf.logging.info("Model saved in path: {}".format(save_path))

            # `>=` rather than `==`: a warm start can restore a global step
            # that is already past train_steps, and `==` would then never
            # fire, leaving the loop running forever.
            if curr_step >= FLAGS.train_steps:
                break
コード例 #8
0
def dynamic_eval(n_token, cutoffs, ps_device):
    """Evaluate a checkpoint while running an optimizer step on every batch.

    When ``FLAGS.rms`` is set, a first pass runs a ``DynamicEvalOpt`` optimizer
    with ``gradstat = True`` over ``num_batch // 5000`` training batches
    (presumably to accumulate gradient statistics -- confirm against
    ``DynamicEvalOpt``).  The second pass builds the graph on
    ``FLAGS.eval_split``, fetches the loss together with the optimizer's train
    op for ``num_batch // FLAGS.ratio`` batches, and logs the token-weighted
    average loss, perplexity and bits per character.

    Args:
        n_token: Forwarded to ``single_core_graph`` as ``n_token``.
        cutoffs: Forwarded to ``single_core_graph`` as ``cutoffs``.
        ps_device: Forwarded to ``assign_to_gpu`` for device placement.

    Returns:
        None.  Results are reported through ``tf.logging``.
    """
    ##### Get input function and model function
    if FLAGS.rms:
        ## using training data to collect gradient statistics
        train_input_fn, train_record_info = data_utils.get_input_fn(
            record_info_dir=FLAGS.record_info_dir,
            split="train",
            per_host_bsz=FLAGS.train_batch_size,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=1,
            use_tpu=False)

        num_batch = train_record_info["num_batch"]

        tf.logging.info("num of batches {}".format(num_batch))

        ##### Create computational graph
        train_set = train_input_fn({
            "batch_size": FLAGS.train_batch_size,
            "data_dir": FLAGS.data_dir
        })

        input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

        # One slice of the batch per core.
        inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
        labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

        per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

        tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

        for i in range(FLAGS.num_core_per_host):
            # AUTO_REUSE lets every tower (and the second pass below) share
            # the same variables.
            with tf.device(assign_to_gpu(i, ps_device)), \
                tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):

                mems_i = [
                    tf.placeholder(
                        tf.float32,
                        [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                    for _ in range(FLAGS.n_layer)
                ]

                loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                    n_token=n_token,
                    cutoffs=cutoffs,
                    is_training=True,
                    inp=inputs[i],
                    tgt=labels[i],
                    mems=mems_i)

                tower_mems.append(mems_i)
                tower_losses.append(loss_i)
                tower_new_mems.append(new_mems_i)
                tower_grads_and_vars.append(grads_and_vars_i)

        ## average losses and gradients across towers
        if len(tower_losses) > 1:
            loss = tf.add_n(tower_losses) / len(tower_losses)
            grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
        else:
            loss = tower_losses[0]
            grads_and_vars = tower_grads_and_vars[0]

        global_step = tf.train.get_or_create_global_step()

        # This optimizer object is reused by the evaluation pass below with
        # gradstat switched off.
        optimizer = DynamicEvalOpt(learning_rate=FLAGS.learning_rate,
                                   decay_rate=FLAGS.decay_rate,
                                   eps=FLAGS.epsilon)
        optimizer.gradstat = True
        train_op = optimizer.apply_gradients(grads_and_vars, global_step)

        # Memories start as zeros and are replaced by the fetched values
        # after every step.
        tower_mems_np = [[
            np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                     dtype=np.float32) for layer in range(FLAGS.n_layer)
        ] for core in range(FLAGS.num_core_per_host)]

        saver = tf.train.Saver()

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(tf.global_variables_initializer())

            if FLAGS.eval_ckpt_path is None:
                eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
            else:
                eval_ckpt_path = FLAGS.eval_ckpt_path

            tf.logging.info("Evaluate {}".format(eval_ckpt_path))
            saver.restore(sess, eval_ckpt_path)

            fetches = [loss, tower_new_mems, tf.size(label_feed), train_op]

            ## only small subset of training set used for gradient stats to save time
            num_stat_batch = num_batch // 5000
            # Clamp to 1: num_batch // 50000 is 0 for fewer than 50000
            # batches, which made the modulo below raise ZeroDivisionError.
            log_every = max(1, num_batch // 50000)

            format_str = "  >> processing batch for gradient statistics {{:{0}d}}/{{:{0}d}} ..".format(
                len(str(num_batch // 5000)))

            for step in range(num_stat_batch):
                if step % log_every == 0:
                    tf.logging.info(format_str.format(step, num_stat_batch))

                feed_dict = {}
                for i in range(FLAGS.num_core_per_host):
                    for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                        feed_dict[m] = m_np

                fetched = sess.run(fetches, feed_dict=feed_dict)

                # Only the new memories are carried over; the loss of this
                # pass is not reported.
                tower_mems_np = fetched[1]

    ##### Done gradstat, starting dynamic eval
    eval_input_fn, eval_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split=FLAGS.eval_split,
        per_host_bsz=FLAGS.eval_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1,
        use_tpu=False)

    num_batch = eval_record_info["num_batch"]

    tf.logging.info("num of batches {}".format(num_batch))

    ##### Create computational graph
    eval_set = eval_input_fn({
        "batch_size": FLAGS.eval_batch_size,
        "data_dir": FLAGS.data_dir
    })

    input_feed, label_feed = eval_set.make_one_shot_iterator().get_next()

    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = FLAGS.eval_batch_size // FLAGS.num_core_per_host

    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_core_per_host):
        with tf.device(assign_to_gpu(i, ps_device)), \
            tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):

            mems_i = [
                tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)
            ]

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=True,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    ## average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]

    ## configure the optimizer
    global_step = tf.train.get_or_create_global_step()
    if not FLAGS.rms:

        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=FLAGS.learning_rate
        )  # DynamicEvalPS(learning_rate=FLAGS.learning_rate )
    else:
        # Reuse the DynamicEvalOpt created in the statistics pass above.
        optimizer.gradstat = False
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    ##### Evaluation loop
    tower_mems_np = [[
        np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                 dtype=np.float32) for layer in range(FLAGS.n_layer)
    ] for core in range(FLAGS.num_core_per_host)]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.eval_ckpt_path is None:
            eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
        else:
            eval_ckpt_path = FLAGS.eval_ckpt_path

        tf.logging.info("Evaluate {}".format(eval_ckpt_path))
        saver.restore(sess, eval_ckpt_path)

        fetches = [loss, tower_new_mems, tf.size(label_feed), train_op]

        total_loss, total_cnt = 0, 0
        format_str = "  >> processing batch {{:{0}d}}/{{:{0}d}} ..".format(
            len(str(num_batch)))

        # Clamp to 1: the quotient is 0 when num_batch < 10 * FLAGS.ratio,
        # which made the modulo below raise ZeroDivisionError.
        log_every = max(1, num_batch // (10 * FLAGS.ratio))

        for step in range(num_batch // FLAGS.ratio):
            if step % log_every == 0:
                tf.logging.info(format_str.format(step, num_batch))

            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)

            # Weight each batch loss by its label count so the final average
            # is per token.
            loss_np, tower_mems_np, cnt_np = fetched[:3]
            total_loss += loss_np * cnt_np
            total_cnt += cnt_np

        if total_cnt == 0:
            # Nothing was evaluated (num_batch < FLAGS.ratio); avoid 0 / 0.
            tf.logging.warning("No evaluation batches were processed.")
            return

        avg_loss = total_loss / total_cnt
        tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
            avg_loss, math.exp(avg_loss), avg_loss / math.log(2)))
コード例 #9
0
def train_epoch(epoch, csv_logger, n_token, cutoffs):
    """Train for one pass over the training batches and save a checkpoint.

    The graph is built on ``FLAGS.num_gpu`` towers, the latest checkpoint in
    ``FLAGS.model_dir`` is restored if one exists, and ``num_batch`` training
    steps are run.  Every ``FLAGS.iterations`` global steps the running
    average loss is logged and written to ``csv_logger``; every
    ``FLAGS.save_steps`` global steps and once at the end the model is saved.

    Args:
        epoch: Epoch index, used in log messages and in the CSV rows.
        csv_logger: Object whose ``writerow`` is called with a dict holding
            the keys ``train_loss``, ``train_ppl``, ``train_bpc``, ``lr``,
            ``global_step`` and ``epoch``.
        n_token: Forwarded to ``single_core_graph`` as ``n_token``.
        cutoffs: Forwarded to ``single_core_graph`` as ``cutoffs``.

    Returns:
        None.
    """
    ps_device = "/gpu:0"

    train_input_fn, train_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_gpu,
        num_hosts=1,
        use_tpu=False)

    tf.logging.info("-" * 30)
    tf.logging.info("Starting epoch {}!".format(epoch))
    tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))
    num_batch = train_record_info["num_batch"]

    train_set = train_input_fn({
        "batch_size": FLAGS.train_batch_size,
        "data_dir": FLAGS.data_dir})

    input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

    # One slice of the batch per GPU.
    inputs = tf.split(input_feed, FLAGS.num_gpu, 0)
    labels = tf.split(label_feed, FLAGS.num_gpu, 0)

    per_core_bsz = FLAGS.train_batch_size // FLAGS.num_gpu
    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    for i in range(FLAGS.num_gpu):
        # The first tower creates the variables, later towers reuse them.
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
                tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

            mems_i = [tf.placeholder(tf.float32,
                                     [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                      for _ in range(FLAGS.n_layer)]

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=True,
                inp=inputs[i],
                tgt=labels[i],
                mems=mems_i)

            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)

    # Average losses and gradients across towers.
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]

    grads, all_vars = zip(*grads_and_vars)

    clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
    grads_and_vars = list(zip(clipped, all_vars))

    global_step = tf.train.get_or_create_global_step()
    total_steps = FLAGS.epochs * num_batch

    # Linear warm-up followed by cosine decay over the remaining steps.
    if FLAGS.warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                    * FLAGS.learning_rate
    else:
        warmup_lr = 0.0

    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step-FLAGS.warmup_steps,
        decay_steps=total_steps-FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    # Memories start as zeros and are replaced by the fetched values after
    # every step.
    tower_mems_np = [
        [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
            for layer in range(FLAGS.n_layer)]
        for core in range(FLAGS.num_gpu)
    ]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        latest_ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
        if latest_ckpt is not None:
            tf.logging.info("loading saved model from {}".format(latest_ckpt))
            saver.restore(sess, latest_ckpt)
        else:
            tf.logging.info("No previously saved model. Starting from scratch!")

        fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]

        # Count the steps run since the last report locally.  The restored
        # global_step is usually far from 0, so deriving the divisor from it
        # (curr_step - prev_step with prev_step = -1) gave a wrong average
        # for the first interval of every resumed epoch.
        total_loss, steps_since_log = 0., 0
        curr_step = None

        for _ in range(num_batch):
            feed_dict = {}
            for i in range(FLAGS.num_gpu):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)
            loss_np, tower_mems_np, curr_step, gnorm_np, lr_np = fetched[:5]
            total_loss += loss_np
            steps_since_log += 1

            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                curr_loss = total_loss / steps_since_log
                tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                    "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                    curr_step, gnorm_np, lr_np,
                    curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
                log_dict = {
                    'train_loss': curr_loss,
                    'train_ppl': math.exp(curr_loss),
                    'train_bpc': curr_loss / math.log(2),
                    'lr': lr_np,
                    'global_step': curr_step,
                    'epoch': epoch
                }
                csv_logger.writerow(log_dict)
                total_loss, steps_since_log = 0., 0

            if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                saver.save(sess, save_path)
                tf.logging.info("Finished Step : {}".format(curr_step))
                tf.logging.info("Model saved in path: {}".format(save_path))

        # Report whatever has accumulated since the last periodic log.  Skip
        # it when nothing is pending (the last step was just reported, or no
        # step ran at all) instead of dividing by zero.
        if steps_since_log > 0:
            curr_loss = total_loss / steps_since_log
            tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                curr_step, gnorm_np, lr_np,
                curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))

        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        # Report the epoch index here; the message used to print the step.
        tf.logging.info("Finished Epoch {}".format(epoch))
        tf.logging.info("Model saved in path: {}".format(save_path))
        tf.logging.info("-" * 30)
コード例 #10
0
def train(n_token, ps_device):
    """Build the multi-tower training graph and run the training loop.

    The loop runs until the fetched global step equals ``FLAGS.train_steps``.
    Every ``FLAGS.iterations`` global steps the running average loss is
    logged, and every ``FLAGS.save_steps`` global steps a checkpoint is
    written to ``FLAGS.model_dir``.  If ``FLAGS.warm_start_path`` is set, the
    variables are restored from it before training.

    Args:
        n_token: Forwarded to ``single_core_graph`` as ``n_token``.
        ps_device: Forwarded to ``assign_to_gpu`` for device placement.

    Returns:
        None.
    """
    #  Get input function and model function
    train_input_fn, train_record_info = simple_data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=1)

    num_batches = train_record_info["num_batch"]
    tf.logging.info("num of batches {}".format(num_batches))
    # NOTE(review): TRAIN_STEPS is a module-level name not visible here,
    # while the loop below stops on FLAGS.train_steps -- confirm they agree.
    tf.logging.info("run {} epochs:".format(TRAIN_STEPS / num_batches))

    # Create computational graph
    train_set = train_input_fn({
        "batch_size": FLAGS.train_batch_size,
        "data_dir": FLAGS.data_dir
    })

    input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

    # The batch is distributed over several devices, so it has to be split.
    # tf.split(input, num_split, dimension): num_split is the number of
    # pieces, dimension is the axis to split along; a list is returned.
    inputs = tf.split(input_feed, FLAGS.num_core_per_host,
                      0)  # dimension 0 is the batch size
    labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    # assign_to_gpu(i, ps_device)
    for i in range(FLAGS.num_core_per_host):
        # The first tower creates the variables, later towers reuse them.
        reuse = True if i > 0 else None
        with tf.device(assign_to_gpu(i, ps_device)), \
             tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
            all_mems = [
                tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)
            ]  # all_mems is a list holding the mems of every layer

            # The (gradient, variable) pairs are returned as well because
            # they are passed to optimizer.apply_gradients further down.
            loss_i, new_all_mems, grads_and_vars_i = single_core_graph(
                n_token=n_token,  # original note: vocabulary size, 27 = 26 + 1 (a..z plus _)
                is_training=True,
                inp=inputs[i],
                labels=labels[i],
                mems=all_mems)

            tower_mems.append(all_mems)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_all_mems)
            tower_grads_and_vars.append(grads_and_vars_i)

    #  average losses and gradients across towers
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]
    grads, all_vars = zip(*grads_and_vars)

    # clip gradient
    clipped, gnorm = tf.clip_by_global_norm(grads,
                                            FLAGS.clip)  # clipped gradient tensors and the global norm of all tensors
    grads_and_vars = list(zip(clipped, all_vars))

    # configure the optimizer
    global_step = tf.train.get_or_create_global_step()

    # warmup stage: increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                    * FLAGS.learning_rate
    else:
        warmup_lr = 0.0

    # decay stage: decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    # choose warmup or decay
    learning_rate = tf.where(global_step < FLAGS.warmup_steps, warmup_lr,
                             decay_lr)

    # get the train op
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    # Training loop
    # Memories start as zeros and are replaced by the fetched values after
    # every step.
    tower_mems_np = [[
        np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                 dtype=np.float32) for layer in range(FLAGS.n_layer)
    ] for core in range(FLAGS.num_core_per_host)]

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.warm_start_path is not None:
            tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
            saver.restore(sess, FLAGS.warm_start_path)

        fetches = [
            loss, tower_new_mems, global_step, gnorm, learning_rate, train_op
        ]

        # Count the steps run since the last report locally.  After a warm
        # start the restored global_step may be far from 0, so deriving the
        # divisor from it (curr_step - prev_step with prev_step = -1) gave a
        # wrong average for the first logging interval.
        total_loss, steps_since_log = 0., 0
        epoch = 0
        while True:
            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)

            loss_np, tower_mems_np, curr_step, gnorm_np, lr_np = fetched[:5]
            total_loss += loss_np
            steps_since_log += 1

            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                curr_loss = total_loss / steps_since_log
                tf.logging.info(
                    "[{}] | gnorm {:.2f} lr {:8.6f} "
                    "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                        curr_step, gnorm_np, lr_np, curr_loss,
                        math.exp(curr_loss), curr_loss / math.log(2)))
                total_loss, steps_since_log = 0., 0

            if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                saver.save(sess, save_path)
                tf.logging.info("Model saved in path: {}".format(save_path))

            # Every exact multiple of num_batches corresponds to one epoch.
            if curr_step > 0 and curr_step % num_batches == 0:
                epoch += 1
                tf.logging.info("epoch: {} Done".format(epoch))
            if curr_step == FLAGS.train_steps:
                break

        tf.logging.info("run {} epochs:".format(TRAIN_STEPS / num_batches))