Example #1
0
def infer(config):
  """
  Implements inference mode

  Builds the data layer and the model in "infer" mode, restores the latest
  checkpoint found in FLAGS.logdir, runs the model over one epoch of the data
  layer and writes the first decoded sequence of every batch, one per line,
  to FLAGS.inference_out (the special value "stdout" writes to sys.stdout).
  :param config: python dictionary describing model and data layer
  :return: nothing
  :raises ValueError: if no checkpoint can be found in FLAGS.logdir
  """
  deco_print("Executing inference mode")
  deco_print("Creating data layer")
  dl = data_layer.ParallelDataInRamInputLayer(params=config)
  if config.get('pad_vocabs_to_eight'):
    # Divide by 8.0 so the ceiling is taken on a true quotient even under
    # Python 2 integer division; int() keeps the vocab size an integer.
    config['src_vocab_size'] = int(math.ceil(len(dl.source_seq2idx) / 8.0) * 8)
    config['tgt_vocab_size'] = int(math.ceil(len(dl.target_seq2idx) / 8.0) * 8)
  else:
    config['src_vocab_size'] = len(dl.source_seq2idx)
    config['tgt_vocab_size'] = len(dl.target_seq2idx)
  use_beam_search = config.get("decoder_type") == "beam_search"
  deco_print("Data layer created")

  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()
    model = seq2seq_model.BasicSeq2SeqWithAttention(model_params=config,
                                                    global_step=global_step,
                                                    tgt_max_size=max(config["bucket_tgt"]),
                                                    mode="infer")
    fetches = [model._final_outputs]
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    saver = tf.train.Saver()
    with tf.train.MonitoredTrainingSession(config=sess_config) as sess:
      checkpoint = tf.train.latest_checkpoint(FLAGS.logdir)
      deco_print("Trying to restore from: {}".format(checkpoint))
      if checkpoint is None:
        raise ValueError("No checkpoint found in: {}".format(FLAGS.logdir))
      saver.restore(sess, checkpoint)
      deco_print("Saving inference results to: " + FLAGS.inference_out)
      write_to_stdout = FLAGS.inference_out == "stdout"
      fout = sys.stdout if write_to_stdout else open(FLAGS.inference_out, 'w')

      try:
        for i, (x, _, _, len_x, _) in enumerate(dl.iterate_one_epoch()):
          # need to check outputs for beam search, and if required, make a common approach
          # to handle both greedy and beam search decoding methods
          samples = sess.run(fetches=fetches,
                             feed_dict={
                                 model.x: x,
                                 model.x_length: len_x,
                             })
          if use_beam_search:
            decoded = samples[0].predicted_ids[:, :, 0][0]
          else:
            decoded = samples[0].sample_id[0]
          # When writing to a file, echo every 200th result to the console
          # (with special tokens kept) as a progress indicator.
          if i % 200 == 0 and not write_to_stdout:
            print(utils.pretty_print_array(decoded,
                                           vocab=dl.target_idx2seq,
                                           ignore_special=False,
                                           delim=config["delimiter"]))
          fout.write(utils.pretty_print_array(decoded,
                                              vocab=dl.target_idx2seq,
                                              ignore_special=True,
                                              delim=config["delimiter"]) + "\n")
      finally:
        # Close the output file even if decoding fails; never close sys.stdout.
        if not write_to_stdout:
          fout.close()
  deco_print("Inference finished")
Example #2
0
def _run_eval_epoch(sess, e_model, eval_fetches, eval_dl, config, eval_using_bleu, bpe_used):
  """
  Runs the eval model over one epoch of the eval data layer
  :param sess: session to run the eval fetches in
  :param e_model: eval-mode model whose placeholders are fed
  :param eval_fetches: fetches returning (targets, samples) for a batch
  :param eval_dl: eval data layer; it is re-bucketized once the epoch is done
  :param config: python dictionary describing model and data layer
  :param eval_using_bleu: whether to collect text for BLEU calculation
  :param bpe_used: forwarded to utils.transform_for_bleu
  :return: (preds, targets) lists; both stay empty when eval_using_bleu is False
  """
  preds = []
  targets = []
  # Batch variables are local to this helper so that they can never clobber
  # the training batch of the caller's loop.
  for e_x, e_y, _, e_len_x, e_len_y in eval_dl.iterate_one_epoch():
    tgt, samples = sess.run(fetches=eval_fetches,
                            feed_dict={
                              e_model.x: e_x,
                              e_model.y: e_y,
                              e_model.x_length: e_len_x,
                              e_model.y_length: e_len_y
                            })
    if eval_using_bleu:
      preds.extend([utils.transform_for_bleu(si,
                                             vocab=eval_dl.target_idx2seq,
                                             ignore_special=True,
                                             delim=config["delimiter"],
                                             bpe_used=bpe_used)
                    for sample in samples for si in sample.sample_id])
      targets.extend([[utils.transform_for_bleu(yi,
                                                vocab=eval_dl.target_idx2seq,
                                                ignore_special=True,
                                                delim=config["delimiter"],
                                                bpe_used=bpe_used)]
                      for yii in tgt for yi in yii])
  eval_dl.bucketize()
  return preds, targets


def train(config):
  """
  Implements training mode

  Trains for config['num_epochs'] epochs. When both 'source_file_eval' and
  'target_file_eval' are present in config, an eval model (built with
  force_var_reuse=True) is also created and run every FLAGS.eval_frequency
  steps of each epoch. Checkpoints and summaries are written to FLAGS.logdir;
  training resumes from the latest checkpoint there if one exists.
  :param config: python dictionary describing model and data layer
  :return: nothing
  """
  deco_print("Executing training mode")
  deco_print("Creating data layer")
  dl = data_layer.ParallelDataInRamInputLayer(params=config)
  config['src_vocab_size'] = len(dl.source_seq2idx)
  config['tgt_vocab_size'] = len(dl.target_seq2idx)

  eval_using_bleu = config.get("eval_bleu", True)
  bpe_used = config.get("bpe_used", False)

  #create eval config
  do_eval = 'source_file_eval' in config and 'target_file_eval' in config
  if do_eval:
    eval_config = copy.deepcopy(config)
    eval_config['source_file'] = eval_config['source_file_eval']
    eval_config['target_file'] = eval_config['target_file_eval']
    deco_print('Creating eval data layer')
    eval_dl = data_layer.ParallelDataInRamInputLayer(params=eval_config)

  deco_print("Data layer created")
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()
    #create model
    model = seq2seq_model.BasicSeq2SeqWithAttention(model_params=config,
                                                    global_step=global_step,
                                                    mode="train")
    #create eval model
    if do_eval:
      e_model = seq2seq_model.BasicSeq2SeqWithAttention(model_params=eval_config,
                                                        global_step=global_step,
                                                        force_var_reuse=True,
                                                        mode="eval")
      eval_fetches = [e_model._eval_y, e_model._eval_ops]

    tf.summary.scalar(name="loss", tensor=model.loss)
    summary_op = tf.summary.merge_all()

    fetches = [model.loss, model.train_op, model._lr]
    fetches_s = [model.loss, model.train_op, model._final_outputs, summary_op, model._lr]

    sess_config = tf.ConfigProto(allow_soft_placement=True)

    # regular checkpoint saver
    saver = tf.train.Saver()
    # eval checkpoint saver
    epoch_saver = tf.train.Saver(max_to_keep=FLAGS.max_eval_checkpoints)

    with tf.Session(config=sess_config) as sess:
      sw = tf.summary.FileWriter(logdir=FLAGS.logdir, graph=sess.graph, flush_secs=60)
      try:
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.logdir)
        if latest_checkpoint is not None:
          saver.restore(sess, latest_checkpoint)
          deco_print("Restored checkpoint. Resuming training")
        else:
          sess.run(tf.global_variables_initializer())

        #begin training
        for epoch in range(config['num_epochs']):
          deco_print("\n\n")
          deco_print("Doing epoch {}".format(epoch))
          epoch_start = time.time()
          total_train_loss = 0.0
          t_cnt = 0
          for i, (x, y, _, len_x, len_y) in enumerate(dl.iterate_one_epoch()):
            # run evaluation
            if do_eval and i % FLAGS.eval_frequency == 0:
              deco_print("Evaluation on validation set")
              preds, targets = _run_eval_epoch(sess, e_model, eval_fetches, eval_dl,
                                               config, eval_using_bleu, bpe_used)

              if eval_using_bleu:
                eval_bleu = calculate_bleu(preds, targets)
                bleu_value = summary_pb2.Summary.Value(tag="Eval_BLEU_Score", simple_value=eval_bleu)
                bleu_summary = summary_pb2.Summary(value=[bleu_value])
                sw.add_summary(summary=bleu_summary, global_step=sess.run(global_step))
                sw.flush()

              if i > 0:
                deco_print("Saving EVAL checkpoint")
                epoch_saver.save(sess, save_path=os.path.join(FLAGS.logdir, "model-eval"), global_step=global_step)

            # save model
            if i % FLAGS.checkpoint_frequency == 0 and i > 0: # save freq arg
              deco_print("Saving checkpoint")
              saver.save(sess, save_path=os.path.join(FLAGS.logdir, "model"), global_step=global_step)

            train_feed = {
              model.x: x,
              model.y: y,
              model.x_length: len_x,
              model.y_length: len_y
            }
            # print sample
            if i % FLAGS.summary_frequency == 0: # print arg
              loss, _, samples, sm, _ = sess.run(fetches=fetches_s, feed_dict=train_feed)
              sw.add_summary(sm, global_step=sess.run(global_step))
              deco_print("In epoch {}, step {} the loss is {}".format(epoch, i, loss))
              deco_print("Train Source[0]:     " + utils.pretty_print_array(x[0, :],
                                                                        vocab=dl.source_idx2seq,
                                                                        delim=config["delimiter"]))
              deco_print("Train Target[0]:     " + utils.pretty_print_array(y[0, :],
                                                                        vocab=dl.target_idx2seq,
                                                                        delim=config["delimiter"]))
              deco_print("Train Prediction[0]: " + utils.pretty_print_array(samples.sample_id[0, :],
                                                                            vocab=dl.target_idx2seq,
                                                                            delim=config["delimiter"]))
            else:
              loss, _, _ = sess.run(fetches=fetches, feed_dict=train_feed)
            total_train_loss += loss
            t_cnt += 1

          # epoch finished
          epoch_end = time.time()
          if t_cnt > 0:
            epoch_loss = total_train_loss / t_cnt
            deco_print('Epoch {} training loss: {}'.format(epoch, epoch_loss))
            value = summary_pb2.Summary.Value(tag="TrainEpochLoss", simple_value=epoch_loss)
            summary = summary_pb2.Summary(value=[value])
            sw.add_summary(summary=summary, global_step=epoch)
            sw.flush()
          else:
            # No batches were produced, so there is no mean loss to report.
            deco_print("Epoch {} produced no training batches".format(epoch))
          deco_print("Did epoch {} in {} seconds".format(epoch, epoch_end - epoch_start))
          dl.bucketize()

        # end of epoch loop
        deco_print("Saving last checkpoint")
        saver.save(sess, save_path=os.path.join(FLAGS.logdir, "model"), global_step=global_step)
      finally:
        # Flush pending events and release the event file even on failure.
        sw.close()
Example #3
0
def train(config):
    """
    Implements training mode using Horovod

    Every process builds the data layer and the train model and runs training
    steps until the monitored session asks to stop. Only rank 0 is given a
    checkpoint directory and prints training samples.
    :param config: python dictionary describing model and data layer
    :return: nothing
    """
    hvd.init()
    is_chief = hvd.rank() == 0
    utils.deco_print("Executing training mode")
    utils.deco_print("Creating data layer")
    if FLAGS.split_data_per_rank:
        dl = data_layer.ParallelDataInRamInputLayer(params=config,
                                                    num_workers=hvd.size(),
                                                    worker_id=hvd.rank())
    else:
        dl = data_layer.ParallelDataInRamInputLayer(params=config)
    if config.get('pad_vocabs_to_eight'):
        # Divide by 8.0 so the ceiling is taken on a true quotient even under
        # Python 2 integer division.
        config['src_vocab_size'] = int(
            math.ceil(len(dl.source_seq2idx) / 8.0) * 8)
        config['tgt_vocab_size'] = int(
            math.ceil(len(dl.target_seq2idx) / 8.0) * 8)
    else:
        config['src_vocab_size'] = len(dl.source_seq2idx)
        config['tgt_vocab_size'] = len(dl.target_seq2idx)
    utils.deco_print("Data layer created")

    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()
        # Create train model
        model = seq2seq_model.BasicSeq2SeqWithAttention(
            model_params=config,
            global_step=global_step,
            mode="train",
            gpu_ids="horovod")
        fetches = [model.loss, model.train_op, model.lr]

        if is_chief:
            tf.summary.scalar(name="loss", tensor=model.loss)
            summary_op = tf.summary.merge_all()
            fetches_s = [
                model.loss, model.train_op, model.final_outputs, summary_op,
                model.lr
            ]
        # done constructing graph at this point

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
        sess_config.gpu_options.allow_growth = True

        hooks = [
            # Broadcast from rank 0: it is the only rank with a checkpoint
            # directory, so it is the only one that may hold restored
            # variables. Any other root would overwrite them on resume.
            hvd.BroadcastGlobalVariablesHook(0),
            tf.train.StepCounterHook(every_n_steps=FLAGS.summary_frequency),
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps)
        ]
        checkpoint_dir = FLAGS.logdir if is_chief else None

        # NOTE(review): FLAGS.checkpoint_frequency is passed as a number of
        # seconds here (save_checkpoint_secs) - confirm that is the intended unit.
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=checkpoint_dir,
                save_summaries_steps=FLAGS.summary_frequency,
                config=sess_config,
                save_checkpoint_secs=FLAGS.checkpoint_frequency,
                log_step_count_steps=FLAGS.summary_frequency,
                stop_grace_period_secs=300,
                hooks=hooks) as sess:
            #begin training
            for i, (x, y, _, len_x, len_y) in enumerate(dl.iterate_forever()):
                if sess.should_stop():
                    utils.deco_print("Finished training on rank {}".format(
                        hvd.rank()))
                    break

                feed_dict = {
                    model.x: x,
                    model.y: y,
                    model.x_length: len_x,
                    model.y_length: len_y
                }
                if is_chief and i % FLAGS.summary_frequency == 0:
                    # Same training step, additionally fetching the decoded
                    # outputs so that a sample can be printed.
                    _, _, samples, _, _ = sess.run(fetches=fetches_s,
                                                   feed_dict=feed_dict)
                    utils.deco_print("Step: " + str(i))
                    utils.deco_print("Train Source[0]:     " +
                                     utils.pretty_print_array(
                                         x[0, :],
                                         vocab=dl.source_idx2seq,
                                         delim=config["delimiter"]))
                    utils.deco_print("Train Target[0]:     " +
                                     utils.pretty_print_array(
                                         y[0, :],
                                         vocab=dl.target_idx2seq,
                                         delim=config["delimiter"]))
                    utils.deco_print("Train Prediction[0]: " +
                                     utils.pretty_print_array(
                                         samples.sample_id[0, :],
                                         vocab=dl.target_idx2seq,
                                         delim=config["delimiter"]))
                else:
                    sess.run(fetches=fetches, feed_dict=feed_dict)