Ejemplo n.º 1
0
def infer(model, checkpoint, output_file):
    """Run inference from a checkpoint and let the model write the results.

    Args:
        model: model object exposing ``on_horovod``, ``hvd`` and
            ``finalize_inference``.
        checkpoint: checkpoint forwarded to ``restore_and_get_results``.
        output_file: destination forwarded to ``model.finalize_inference``.
    """
    batch_results = restore_and_get_results(model, checkpoint, mode="infer")

    # Under Horovod only rank 0 finalizes; every other rank is done here.
    if model.on_horovod and model.hvd.rank() != 0:
        return
    model.finalize_inference(batch_results, output_file)
    deco_print("Finished inference")
Ejemplo n.º 2
0
    def maybe_print_logs(self, input_values, output_values, training_step):
        """Print the first source/target pair of the current training batch.

        Args:
            input_values (dict): must contain ``'source_tensors'`` and
                ``'target_tensors'``, each a ``(sequences, lengths)`` pair.
            output_values: not used by this method.
            training_step: not used by this method.

        Returns:
            dict: always empty; nothing is logged beyond the printed text.
        """
        src_batch, src_lengths = input_values['source_tensors']
        tgt_batch, tgt_lengths = input_values['target_tensors']

        first_src = src_batch[0]
        first_src_len = src_lengths[0]
        first_tgt = tgt_batch[0]
        first_tgt_len = tgt_lengths[0]

        def _render(token_ids):
            # Vocabulary and delimiter both come from the data layer.
            return array_to_string(
                token_ids,
                vocab=self.get_data_layer().corp.dictionary.idx2word,
                delim=self.get_data_layer().params["delimiter"],
            )

        # Only the first ``length`` ids of each sample are rendered.
        deco_print(
            "Train Source[0]:     " + _render(first_src[:first_src_len]),
            offset=4,
        )
        deco_print(
            "Train Target[0]:     " + _render(first_tgt[:first_tgt_len]),
            offset=4,
        )

        return {}
Ejemplo n.º 3
0
    def finalize_inference(self, results_per_batch, output_file):
        """Write predictions to a tab-separated file and print test metrics.

        Args:
            results_per_batch: iterable of batches; each batch is an iterable
                of ``(source, prediction, label)`` triples where ``source``
                is a string.
            output_file (str): path of the TSV file to (over)write. A header
                row ``Source / Pred / Label`` is written first.

        Returns:
            dict: always empty.

        Metrics are only printed when labels are available, which is decided
        by checking that the first collected label is not ``None``.
        """
        preds, labels = [], []

        # The context manager guarantees the file is flushed and closed even
        # if a malformed result raises part-way through the write loop.
        with open(output_file, 'w') as out:
            out.write('\t'.join(['Source', 'Pred', 'Label']) + '\n')
            for results in results_per_batch:
                for x, pred, y in results:
                    out.write('\t'.join([x, str(pred), str(y)]) + '\n')
                    preds.append(pred)
                    labels.append(y)

        if labels and labels[0] is not None:
            preds = np.asarray(preds)
            labels = np.asarray(labels)
            deco_print(
                "TEST Accuracy: {:.4f}".format(metrics.accuracy(labels,
                                                                preds)),
                offset=4,
            )
            deco_print(
                "TEST Precision: {:.4f} | Recall: {:.4f} | F1: {:.4f}".format(
                    metrics.precision(labels, preds),
                    metrics.recall(labels, preds), metrics.f1(labels, preds)),
                offset=4,
            )
        return {}
Ejemplo n.º 4
0
    def finalize_evaluation(self, results_per_batch, training_step=None):
        """Aggregate per-batch evaluation results and print summary metrics.

        Args:
            results_per_batch: iterable of dicts. Each must contain
                ``'accuracy'``; ``'true_pos'``, ``'pred_pos'`` and
                ``'actual_pos'`` are accumulated when ``'true_pos'`` is
                present.
            training_step: not used by this method.

        Returns:
            dict: always empty. If any batch lacks ``'accuracy'`` the method
            returns immediately without printing anything.
        """
        batch_accuracies = []
        tp_total = 0.0
        predicted_total = 0.0
        actual_total = 0.0

        for batch in results_per_batch:
            if 'accuracy' not in batch:
                return {}
            batch_accuracies.append(batch['accuracy'])
            if 'true_pos' not in batch:
                continue
            tp_total += batch['true_pos']
            predicted_total += batch['pred_pos']
            actual_total += batch['actual_pos']

        accuracy_line = "EVAL Accuracy: {:.4f}".format(
            np.mean(batch_accuracies))
        deco_print(accuracy_line, offset=4)

        # Precision/recall/F1 are only reported when at least one true
        # positive was counted.
        if tp_total > 0:
            precision = tp_total / predicted_total
            recall = tp_total / actual_total
            f1_score = 2.0 * precision * recall / (recall + precision)
            prf_line = (
                "EVAL Precision: {:.4f} | Recall: {:.4f} | F1: {:.4f} | True pos: {}"
                .format(precision, recall, f1_score, tp_total))
            deco_print(prf_line, offset=4)
        return {}
Ejemplo n.º 5
0
  def maybe_evaluate(self, inputs_per_batch, outputs_per_batch):
    """Compute and print the word error rate over a validation epoch.

    Args:
      inputs_per_batch: per-batch input values; index 2 holds the target
        sequences and index 3 their lengths, both indexed by
        ``[gpu_id][sample_id]``.
      outputs_per_batch: per-batch output values, indexed by ``gpu_id``, fed
        to ``sparse_tensor_to_chars``.

    Returns:
      dict: ``{"Eval WER": <total word edit distance / total word count>}``.
    """
    dataset_size = self.data_layer.get_size_in_samples()
    lev_sum = 0.0
    word_sum = 0.0
    seen_samples = 0

    for input_values, output_values in zip(inputs_per_batch, outputs_per_batch):
      for gpu_id in range(self.num_gpus):
        decoded_texts = sparse_tensor_to_chars(
          output_values[gpu_id],
          self.data_layer.params['idx2char'],
        )
        for sample_id in range(self.params['batch_size_per_gpu']):
          # Stop once every dataset sample has been counted, so samples
          # beyond the dataset size in the last batch are ignored.
          if seen_samples >= dataset_size:
            break
          seen_samples += 1

          # Targets are the third input value, their lengths the fourth.
          target_ids = input_values[2][gpu_id][sample_id]
          target_len = input_values[3][gpu_id][sample_id]
          idx2char = self.data_layer.params['idx2char']
          true_text = "".join(
            idx2char.get(idx) for idx in target_ids[:target_len])
          pred_text = "".join(decoded_texts[sample_id])

          true_words = true_text.split()
          lev_sum += levenshtein(true_words, pred_text.split())
          word_sum += len(true_words)

    total_wer = 1.0 * lev_sum / word_sum
    deco_print("Validation WER:  {:.4f}".format(total_wer), offset=4)
    return {"Eval WER": total_wer}
Ejemplo n.º 6
0
def evaluate(model, checkpoint):
  """Evaluate a restored checkpoint and return the model's metric dict.

  Args:
    model: model object exposing ``on_horovod``, ``hvd`` and
      ``finalize_evaluation``.
    checkpoint: checkpoint forwarded to ``restore_and_get_results``.

  Returns:
    The value of ``model.finalize_evaluation`` on rank 0 (or when Horovod is
    not used); ``None`` on every other Horovod rank.
  """
  batch_results = restore_and_get_results(model, checkpoint, mode="eval")

  # Non-zero Horovod ranks take part in the run but do not finalize.
  if model.on_horovod and model.hvd.rank() != 0:
    return None

  metrics_dict = model.finalize_evaluation(batch_results)
  deco_print("Finished evaluation")
  return metrics_dict
Ejemplo n.º 7
0
def evaluate(model, checkpoint):
  """Run evaluation from a checkpoint; only the master rank finalizes.

  Args:
    model: model object exposing ``on_horovod``, ``hvd`` and
      ``finalize_evaluation``.
    checkpoint: checkpoint forwarded to ``restore_and_get_results``.

  Returns:
    The dict produced by ``model.finalize_evaluation`` when Horovod is off
    or this process is rank 0, otherwise ``None``.
  """
  collected = restore_and_get_results(model, checkpoint, mode="eval")
  is_master = not model.on_horovod or model.hvd.rank() == 0
  if not is_master:
    return None
  summary = model.finalize_evaluation(collected)
  deco_print("Finished evaluation")
  return summary
Ejemplo n.º 8
0
def build_layer(inputs,
                layer,
                layer_params,
                data_format,
                regularizer,
                training,
                verbose=True):
    """Call ``layer`` on ``inputs``, injecting optional arguments it supports.

    ``layer_params`` is deep-copied and never modified. For each of the names
    "regularizer", "kernel_regularizer", "gamma_regularizer", "data_format"
    and "training" (checked in that order), the matching value is added to
    the copy when the name is absent from ``layer_params`` and appears in the
    call signature of ``layer``. Explicitly supplied parameters always win.

    Args:
        inputs: input Tensor passed to the layer as its first positional
            argument.
        layer: layer function or class with ``__call__`` method defined.
        layer_params (dict): parameters passed to ``layer``.
        data_format (string): data format ("channels_first" or
            "channels_last"), offered as the ``data_format`` argument.
        regularizer: regularizer instance offered under each supported
            regularizer argument name.
        training (bool): whether the layer is built in training mode,
            offered as the ``training`` argument.
        verbose (bool): whether to print information about the built layer.

    Returns:
        Tensor with layer output.
    """
    call_kwargs = copy.deepcopy(layer_params)

    optional_args = (
        ('regularizer', regularizer),
        ('kernel_regularizer', regularizer),
        ('gamma_regularizer', regularizer),
        ('data_format', data_format),
        ('training', training),
    )
    for arg_name, arg_value in optional_args:
        if arg_name in call_kwargs:
            continue
        # The signature is only inspected when the argument is still missing.
        if arg_name in signature(layer).parameters:
            call_kwargs[arg_name] = arg_value

    outputs = layer(inputs, **call_kwargs)

    if verbose:
        # Use the layer's first ``_tf_api_names`` entry when it has one,
        # otherwise fall back to the layer object itself.
        layer_name = (layer._tf_api_names[0]
                      if hasattr(layer, '_tf_api_names') else layer)
        rendered_kwargs = ", ".join(
            "{}={}".format(key, value) for key, value in call_kwargs.items())
        deco_print("Building layer: {}(inputs, {})".format(
            layer_name, rendered_kwargs))
    return outputs
Ejemplo n.º 9
0
def infer(model, checkpoint, output_file):
    """Run inference from a checkpoint with optional detailed outputs.

    Args:
        model: model object exposing ``params``, ``on_horovod`` and ``hvd``.
            ``params["verbose_inference"]`` and ``params["save_mels"]``
            (both defaulting to False) are forwarded to
            ``restore_and_get_results``.
        checkpoint: checkpoint forwarded to ``restore_and_get_results``.
        output_file: not used by this function.
    """
    model_params = model.params
    detailed_outputs = model_params.get("verbose_inference", False)
    save_mels = model_params.get("save_mels", False)

    # The returned results are not used here.
    restore_and_get_results(
        model,
        checkpoint,
        mode="infer",
        detailed_inference_outputs=detailed_outputs,
        save_mels=save_mels)

    if model.on_horovod and model.hvd.rank() != 0:
        return
    deco_print("Finished inference")
Ejemplo n.º 10
0
  def finalize_evaluation(self, results_per_batch):
    """Combine per-batch word statistics into a single word error rate.

    Args:
      results_per_batch: iterable of ``(word_lev, word_count)`` pairs.

    Returns:
      dict: ``{"Eval WER": sum(word_lev) / sum(word_count)}``.

    Raises:
      ZeroDivisionError: if ``results_per_batch`` is empty.
    """
    lev_sum = 0.0
    count_sum = 0.0
    for batch_lev, batch_count in results_per_batch:
      lev_sum += batch_lev
      count_sum += batch_count

    wer = lev_sum / count_sum
    message = "Validation WER:  {:.4f}".format(wer)
    deco_print(message, offset=4)
    return {"Eval WER": wer}
Ejemplo n.º 11
0
  def finalize_evaluation(self, results_per_batch, training_step=None):
    """Reduce per-batch word edit distances and counts to one WER value.

    Args:
      results_per_batch: iterable of ``(word_lev, word_count)`` pairs.
      training_step: not used by this method.

    Returns:
      dict: ``{"Eval WER": total word edit distance / total word count}``.

    Raises:
      ZeroDivisionError: if ``results_per_batch`` is empty.
    """
    totals = [0.0, 0.0]  # [edit distance, word count]
    for batch_lev, batch_words in results_per_batch:
      totals[0] += batch_lev
      totals[1] += batch_words

    error_rate = totals[0] / totals[1]
    deco_print("Validation WER:  {:.4f}".format(error_rate), offset=4)
    return {"Eval WER": error_rate}
Ejemplo n.º 12
0
def run():
    """Run calibration inference from a saved checkpoint.

    The infer dataset is overridden with ``calibration/sample.csv`` and the
    decoder flag ``infer_logits_to_pickle`` is switched on before the model
    is created. Per the original description, this is meant to process 50
    LibriSpeech dev-clean files whose alignments are stored in
    ``calibration/target.json`` and to save the logits as
    ``calibration/sample.pkl``; the pickling itself happens outside this
    function.

    Only ``args.mode == 'infer'`` is supported; any other mode prints a
    message and exits the process.

    Returns:
        ``args.calibration_out`` as parsed by ``get_calibration_config``.
    """
    args, base_config, base_model, config_module = get_calibration_config(
        sys.argv[1:])
    # Force inference over the calibration sample set and request logits.
    config_module["infer_params"]["data_layer_params"]["dataset_files"] = \
      ["calibration/sample.csv"]
    config_module["base_params"]["decoder_params"][
        "infer_logits_to_pickle"] = True
    load_model = base_config.get('load_model', None)
    restore_best_checkpoint = base_config.get('restore_best_checkpoint', False)
    base_ckpt_dir = check_base_model_logdir(load_model, args,
                                            restore_best_checkpoint)
    base_config['load_model'] = base_ckpt_dir

    # Check logdir and create it if necessary
    checkpoint = check_logdir(args, base_config, restore_best_checkpoint)

    # Initialize Horovod
    if base_config['use_horovod']:
        import horovod.tensorflow as hvd
        hvd.init()
        if hvd.rank() == 0:
            deco_print("Using horovod")
        # Wait until every MPI rank has reached this point.
        from mpi4py import MPI
        MPI.COMM_WORLD.Barrier()
    else:
        hvd = None

    if args.enable_logs:
        if hvd is None or hvd.rank() == 0:
            # NOTE(review): the returned stream handles are never used again
            # in this function, so stdout/stderr are not restored here.
            old_stdout, old_stderr, stdout_log, stderr_log = create_logdir(
                args, base_config)
            base_config['logdir'] = os.path.join(base_config['logdir'], 'logs')

    if args.mode == 'infer':
        if hvd is None or hvd.rank() == 0:
            deco_print("Loading model from {}".format(checkpoint))
    else:
        print("Run in infer mode only")
        sys.exit()
    with tf.Graph().as_default():
        model = create_model(args, base_config, config_module, base_model, hvd,
                             checkpoint)
        infer(model, checkpoint, args.infer_output_file)

    return args.calibration_out
Ejemplo n.º 13
0
    def _build_forward_pass_graph(self, input_tensors, gpu_id=0):
        """Build the encoder -> decoder -> loss graph for a seq2seq model.

        The encoder receives the source sequence and source length. The
        decoder receives the encoder output, the target sequence and a
        target length: the true one in "train" mode, otherwise
        ``int(1.2 * src_length)``. The loss is only built in "train" and
        "eval" modes and receives the decoder output together with the true
        target sequence and length.

        See :meth:`models.model.Model._build_forward_pass_graph` for
        description of arguments and return values.

        Args:
            input_tensors: ``(src_sequence, src_length)`` in "infer" mode,
                ``(src_sequence, src_length, tgt_sequence, tgt_length)``
                otherwise.
            gpu_id: not used by this method.

        Returns:
            tuple: ``(loss, decoder_samples)``; ``loss`` is ``None`` in
            inference mode and ``decoder_samples`` is the decoder output's
            ``"samples"`` entry (``None`` if absent).
        """
        if self.mode != "infer":
            src_sequence, src_length, tgt_sequence, tgt_length = input_tensors
        else:
            src_sequence, src_length = input_tensors
            tgt_sequence = None
            tgt_length = None

        with tf.variable_scope("ForwardPass"):
            encoder_output = self.encoder.encode(input_dict={
                "src_sequence": src_sequence,
                "src_length": src_length,
            })

            # Outside of training the true target length is replaced by a
            # somewhat increased source length.
            src_length_float = tf.cast(src_length, tf.float32)
            tgt_length_eval = tf.cast(1.2 * src_length_float, tf.int32)
            if self.mode == "train":
                decoder_tgt_length = tgt_length
            else:
                decoder_tgt_length = tgt_length_eval

            decoder_output = self.decoder.decode(input_dict={
                "encoder_output": encoder_output,
                "tgt_sequence": tgt_sequence,
                "tgt_length": decoder_tgt_length,
            })
            decoder_samples = decoder_output.get("samples", None)

            if self.mode in ("train", "eval"):
                with tf.variable_scope("Loss"):
                    loss = self.loss_computator.compute_loss({
                        "decoder_output": decoder_output,
                        "tgt_sequence": tgt_sequence,
                        "tgt_length": tgt_length,
                    })
            else:
                deco_print("Inference Mode. Loss part of graph isn't built.")
                loss = None
            return loss, decoder_samples
Ejemplo n.º 14
0
    def finalize_evaluation(self, results_per_batch, training_step=None):
        """Compute and print the BLEU score over all evaluation batches.

        Args:
            results_per_batch: iterable of ``(predictions, targets)`` pairs.
            training_step: not used by this method.

        Returns:
            dict: ``{'Eval_BLEU_Score': score}`` when the
            ``eval_using_bleu`` parameter is true (the default), otherwise
            an empty dict.
        """
        hypotheses, references = [], []
        for batch_preds, batch_targets in results_per_batch:
            if not self.params.get('eval_using_bleu', True):
                continue
            hypotheses.extend(batch_preds)
            references.extend(batch_targets)

        if not self.params.get('eval_using_bleu', True):
            return {}

        bleu = calculate_bleu(hypotheses, references)
        # NOTE(review): "BLUE" in the message is presumably a typo for
        # "BLEU"; kept unchanged in case log consumers match on it.
        deco_print("Eval BLUE score: {}".format(bleu), offset=4)
        return {'Eval_BLEU_Score': bleu}
Ejemplo n.º 15
0
  def finalize_evaluation(self, results_per_batch):
    """Gather predictions/targets from every batch and report BLEU.

    Args:
      results_per_batch: iterable of ``(predictions, targets)`` pairs.

    Returns:
      dict: ``{'Eval_BLEU_Score': score}`` when the ``eval_using_bleu``
      parameter is true (the default), otherwise an empty dict.
    """
    all_preds = []
    all_targets = []
    for cur_preds, cur_targets in results_per_batch:
      if not self.params.get('eval_using_bleu', True):
        continue
      all_preds.extend(cur_preds)
      all_targets.extend(cur_targets)

    if not self.params.get('eval_using_bleu', True):
      return {}

    score = calculate_bleu(all_preds, all_targets)
    # NOTE(review): "BLUE" in the message is presumably a typo for "BLEU";
    # kept unchanged in case log consumers match on it.
    deco_print("Eval BLUE score: {}".format(score), offset=4)
    return {'Eval_BLEU_Score': score}
Ejemplo n.º 16
0
  def after_run(self, run_context, run_values):
    """Print the step, loss and average time per step after a session run.

    Args:
      run_context: not used by this method.
      run_values: object whose ``results`` attribute unpacks into
        ``(results, step)``; ``results[0]`` is printed as the loss.

    Nothing is printed when ``results`` is falsy; the iteration counter is
    still updated in that case.
    """
    fetched, global_step = run_values.results
    self._iter_count = global_step

    if not fetched:
      return
    self._timer.update_last_triggered_step(self._iter_count - 1)

    if self._model.steps_in_epoch is not None:
      header = "Epoch {}, global step {}:".format(
        global_step // self._model.steps_in_epoch, global_step)
    else:
      header = "Global step {}:".format(global_step)
    deco_print(header, end=" ")

    deco_print("loss = {:.4f}".format(fetched[0]), start="", end=", ")

    # Average wall time per step since the previous report, split into
    # hours / minutes / seconds.
    step_time = (time.time() - self._last_time) / self._every_steps
    minutes, seconds = divmod(step_time, 60)
    hours, minutes = divmod(minutes, 60)
    timing = "time per step = {}:{:02}:{:.3f}".format(
      int(hours), int(minutes), seconds)
    deco_print(timing, start="")
    self._last_time = time.time()
Ejemplo n.º 17
0
 def infer(self, input_values, output_values):
     """Print the generated sequence for every configured seed token.

     Args:
         input_values: not used by this method.
         output_values: ``output_values[0][i]`` is the generated id array
             for the i-th seed token.
     """
     vocab = self.get_data_layer().corp.dictionary.idx2word
     seed_tokens = self.params['encoder_params']['seed_tokens']
     for i, seed in enumerate(seed_tokens):
         generated = output_values[0][i]
         print(generated.shape)
         print('Seed:', vocab[seed] + '\n')
         rendered = array_to_string(
             generated,
             vocab=self.get_data_layer().corp.dictionary.idx2word,
             delim=self.get_data_layer().params["delimiter"],
         )
         deco_print("Output: " + rendered, offset=4)
Ejemplo n.º 18
0
    def after_run(self, run_context, run_values):
        """Run a full validation pass when the hook is triggered.

        The pass runs when ``self._triggered`` is set or when ``step`` is the
        final training step (``self._last_step - 1``). It iterates one epoch
        of the model's data layer, averages the loss over batches, lets the
        model compute extra metrics via ``maybe_evaluate``, saves a
        checkpoint whenever the validation loss improves, and logs the
        resulting values as summaries.

        Args:
            run_context: provides ``session`` used to run the fetches and to
                save checkpoints.
            run_values: object whose ``results`` attribute unpacks into
                ``(results, step)``.

        If the data layer yields no batches the evaluation is skipped with a
        message instead of failing on an undefined batch counter.
        """
        _, step = run_values.results
        self._iter_count = step

        if not self._triggered and step != self._last_step - 1:
            return
        self._timer.update_last_triggered_step(self._iter_count - 1)

        deco_print("Running evaluation on a validation set:")

        inputs_per_batch, outputs_per_batch = [], []
        total_loss = 0.0

        for feed_dict in self._model.data_layer.iterate_one_epoch(
                cross_over=True):
            loss, inputs, outputs = run_context.session.run(
                self._fetches,
                feed_dict,
            )
            inputs_per_batch.append(inputs)
            outputs_per_batch.append(outputs)
            total_loss += loss

        # Counting collected batches (rather than reusing the loop index)
        # keeps this well-defined when the epoch is empty.
        num_batches = len(inputs_per_batch)
        if num_batches == 0:
            deco_print(
                "Validation set produced no batches, skipping evaluation",
                offset=4,
            )
            return

        total_loss /= num_batches
        deco_print("Validation loss: {:.4f}".format(total_loss), offset=4)
        dict_to_log = self._model.maybe_evaluate(
            inputs_per_batch,
            outputs_per_batch,
        )
        dict_to_log['eval_loss'] = total_loss

        # saving the best validation model
        if total_loss < self._best_eval_loss:
            self._best_eval_loss = total_loss
            self._eval_saver.save(
                run_context.session,
                os.path.join(self._model.params['logdir'], 'best_models',
                             'val_loss={:.4f}-step'.format(total_loss)),
                global_step=step + 1,
            )

        # optionally logging to tensorboard any values
        # returned from maybe_evaluate (plus the eval loss added above)
        if dict_to_log:
            log_summaries_from_dict(
                dict_to_log,
                self._model.params['logdir'],
                step,
            )
Ejemplo n.º 19
0
def main():
    """Entry point: parse the config, then train, evaluate or infer.

    Steps, in order: parse command-line arguments and config, reject the
    ``interactive_infer`` mode, optionally initialize Horovod, resolve the
    checkpoint / log directory, optionally redirect stdout/stderr to log
    files (rank 0 only), build the model inside a fresh graph and dispatch on
    ``args.mode``. Redirected streams are restored and closed at the end.

    Raises:
        ValueError: if ``args.mode`` is ``"interactive_infer"``.
    """
    # Parse args and create config
    args, base_config, base_model, config_module = get_base_config(
        sys.argv[1:])

    if args.mode == "interactive_infer":
        # A single message string: passing two positional strings would make
        # the exception render as a tuple instead of one sentence.
        raise ValueError(
            "Interactive infer is meant to be run from an IPython "
            "notebook not from run.py.")

    # Initialize Horovod
    if base_config['use_horovod']:
        import horovod.tensorflow as hvd
        hvd.init()
        if hvd.rank() == 0:
            deco_print("Using horovod")
    else:
        hvd = None

    restore_best_checkpoint = base_config.get('restore_best_checkpoint', False)

    # Check logdir and create it if necessary
    checkpoint = check_logdir(args, base_config, restore_best_checkpoint)
    if args.enable_logs:
        # Only rank 0 (or the single process) redirects its output streams;
        # every rank switches to the 'logs' subdirectory.
        if hvd is None or hvd.rank() == 0:
            old_stdout, old_stderr, stdout_log, stderr_log = create_logdir(
                args, base_config)
        base_config['logdir'] = os.path.join(base_config['logdir'], 'logs')

    if args.mode == 'train' or args.mode == 'train_eval' or args.benchmark:
        if hvd is None or hvd.rank() == 0:
            if checkpoint is None or args.benchmark:
                deco_print("Starting training from scratch")
            else:
                deco_print(
                    "Restored checkpoint from {}. Resuming training".format(
                        checkpoint), )
    elif args.mode == 'eval' or args.mode == 'infer':
        if hvd is None or hvd.rank() == 0:
            deco_print("Loading model from {}".format(checkpoint))

    # Create model and train/eval/infer
    with tf.Graph().as_default():
        model = create_model(args, base_config, config_module, base_model, hvd)
        if args.mode == "train_eval":
            # In train_eval mode create_model returns a (train, eval) pair.
            train(model[0], model[1], debug_port=args.debug_port)
        elif args.mode == "train":
            train(model, None, debug_port=args.debug_port)
        elif args.mode == "eval":
            evaluate(model, checkpoint)
        elif args.mode == "infer":
            infer(model, checkpoint, args.infer_output_file, args.use_trt)

    # Undo the stream redirection performed by create_logdir above.
    if args.enable_logs and (hvd is None or hvd.rank() == 0):
        sys.stdout = old_stdout
        sys.stderr = old_stderr
        stdout_log.close()
        stderr_log.close()
Ejemplo n.º 20
0
    def __init__(self, params, model, name="ctc_loss"):
        """Create the CTC loss and force it to compute in float32.

        See parent class for arguments description.

        Config parameters:

        * **mask_nan** (bool) --- whether to mask nans in the loss output.
          Defaults to True.
        """
        super(CTCLoss, self).__init__(params, model, name)
        self._mask_nan = self.params.get("mask_nan", True)

        # CTC can only operate in full precision: any other configured dtype
        # is overridden, with a warning, rather than rejected.
        configured_dtype = self.params['dtype']
        if configured_dtype != tf.float32:
            deco_print("Warning: defaulting CTC loss to work in float32")
        self.params['dtype'] = tf.float32
Ejemplo n.º 21
0
  def __init__(self, params, model, name="ctc_loss"):
    """Build the CTC loss object; its dtype is always set to float32.

    See parent class for arguments description.

    Config parameters:

    * **mask_nan** (bool) --- whether to mask nans in the loss output.
      Defaults to True.
    """
    super(CTCLoss, self).__init__(params, model, name)
    self._mask_nan = self.params.get("mask_nan", True)

    # This loss only works in full precision, so warn when a different
    # dtype was requested and then override it unconditionally.
    requested_dtype = self.params['dtype']
    if requested_dtype != tf.float32:
      deco_print("Warning: defaulting CTC loss to work in float32")
    self.params['dtype'] = tf.float32
Ejemplo n.º 22
0
  def finalize_evaluation(self, results_per_batch, training_step=None):
    """Aggregate per-batch hit counts into top-1 / top-5 accuracy.

    Args:
      results_per_batch: iterable of ``(total, top1_hits, top5_hits)``
        triples.
      training_step: not used by this method.

    Returns:
      dict: ``{"Eval top-1": ..., "Eval top-5": ...}``, each the summed hits
      divided by the summed totals.

    Raises:
      ZeroDivisionError: if ``results_per_batch`` is empty.
    """
    sample_count = 0.0
    top1_hits = 0.0
    top5_hits = 0.0
    for batch_total, batch_top1, batch_top5 in results_per_batch:
      sample_count += batch_total
      top1_hits += batch_top1
      top5_hits += batch_top5

    top1_acc = top1_hits / sample_count
    top5_acc = top5_hits / sample_count
    for label, value in (("top-1", top1_acc), ("top-5", top5_acc)):
      deco_print("Validation {}: {:.4f}".format(label, value), offset=4)
    return {"Eval top-1": top1_acc, "Eval top-5": top5_acc}
Ejemplo n.º 23
0
  def maybe_print_logs(self, input_values, output_values, training_step):
    """Print and return top-1 / top-5 accuracy of the current train batch.

    Args:
      input_values (dict): ``input_values['target_tensors'][0]`` holds the
        labels; class ids are taken as the column indices where the label
        equals 1 (presumably one-hot rows, confirm against the data layer).
      output_values: ``output_values[0]`` holds the logits, one row per
        sample.
      training_step: not used by this method.

    Returns:
      dict: ``{"Train batch top-1": ..., "Train batch top-5": ...}``.
    """
    label_matrix = input_values['target_tensors'][0]
    logits = output_values[0]

    class_ids = np.where(label_matrix == 1)[1]
    batch_size = logits.shape[0]

    top1_hits = np.sum(np.argmax(logits, axis=1) == class_ids)
    # The last five columns after partitioning hold the five largest logits.
    top5_candidates = np.argpartition(logits, -5)[:, -5:]
    top5_hits = np.sum(class_ids[:, np.newaxis] == top5_candidates)

    top1_acc = 1.0 * top1_hits / batch_size
    top5_acc = 1.0 * top5_hits / batch_size
    deco_print("Train batch top-1: {:.4f}".format(top1_acc), offset=4)
    deco_print("Train batch top-5: {:.4f}".format(top5_acc), offset=4)
    return {"Train batch top-1": top1_acc, "Train batch top-5": top5_acc}
Ejemplo n.º 24
0
    def after_run(self, run_context, run_values):
        """Evaluate on the validation set when the hook is triggered.

        Evaluation runs when ``self._triggered`` is set or when ``step`` is
        the final training step (``self._last_step - 1``). Every process
        takes part in ``get_results_for_epoch``; printing, finalization,
        best-model checkpointing and summary logging happen only when
        Horovod is off or on rank 0.

        Args:
            run_context: provides ``session`` for evaluation and saving.
            run_values: object whose ``results`` attribute unpacks into
                ``(results, step)``.
        """
        _, step = run_values.results
        self._iter_count = step

        if not (self._triggered or step == self._last_step - 1):
            return
        self._timer.update_last_triggered_step(self._iter_count - 1)

        model = self._model

        def _is_master():
            return not model.on_horovod or model.hvd.rank() == 0

        if _is_master():
            deco_print("Running evaluation on a validation set:")

        results_per_batch, total_loss = get_results_for_epoch(
            model,
            run_context.session,
            mode="eval",
            compute_loss=True,
        )

        if not _is_master():
            return

        deco_print("Validation loss: {:.4f}".format(total_loss), offset=4)

        dict_to_log = model.finalize_evaluation(results_per_batch, step)
        dict_to_log['eval_loss'] = total_loss

        # Keep a checkpoint of the best validation loss so far, but only
        # when checkpointing is enabled in the config.
        improved = model.params['save_checkpoint_steps'] and \
            total_loss < self._best_eval_loss
        if improved:
            self._best_eval_loss = total_loss
            checkpoint_prefix = os.path.join(
                model.params['logdir'], 'best_models',
                'val_loss={:.4f}-step'.format(total_loss))
            self._eval_saver.save(
                run_context.session,
                checkpoint_prefix,
                global_step=step + 1,
            )

        # Log the values returned by finalize_evaluation (plus the eval
        # loss) as summaries when summary saving is enabled.
        if model.params['save_summaries_steps']:
            log_summaries_from_dict(
                dict_to_log,
                model.params['logdir'],
                step,
            )
Ejemplo n.º 25
0
    def infer(self, inputs_per_batch, outputs_per_batch, output_file):
        """Write decoded output sequences to a UTF-8 file, one per line.

        This function assumes it is run on 1 GPU with a batch size of 1:
        only the first element of each batch is decoded. Every 200th step
        the input and output sequences are also printed.

        Args:
            inputs_per_batch: per-step input values;
                ``inputs_per_batch[step][0][0][0]`` is the source id
                sequence.
            outputs_per_batch: per-step output values;
                ``outputs_per_batch[step][0][0]`` is the predicted id
                sequence.
            output_file (str): path of the file to (over)write.
        """
        decoder_params = self.decoder.params if inputs_per_batch else None

        def _decode(ids, idx2seq):
            # Special symbols are stripped and tokens joined by spaces.
            return text_ids_to_string(
                ids,
                idx2seq,
                S_ID=decoder_params['GO_SYMBOL'],
                EOS_ID=decoder_params['END_SYMBOL'],
                PAD_ID=decoder_params['PAD_SYMBOL'],
                ignore_special=True,
                delim=' ',
            )

        with codecs.open(output_file, 'w', 'utf-8') as fout:
            for step, batch_inputs in enumerate(inputs_per_batch):
                input_values = batch_inputs[0][0]
                output_values = outputs_per_batch[step][0]
                output_string = _decode(
                    output_values[0],
                    self.data_layer.params['target_idx2seq'])
                input_string = _decode(
                    input_values[0],
                    self.data_layer.params['source_idx2seq'])
                fout.write(output_string + "\n")

                if step % 200 != 0:
                    continue
                if six.PY2:
                    input_string = input_string.encode('utf-8')
                    output_string = output_string.encode('utf-8')

                deco_print("Input sequence:  {}".format(input_string))
                deco_print("Output sequence: {}".format(output_string))
                deco_print("")
Ejemplo n.º 26
0
def get_batches_for_epoch(model, checkpoint):
    """Restore ``checkpoint`` and run ``model`` over one epoch of its data.

    Every batch produced by ``model.data_layer.iterate_one_epoch`` is run
    through a fresh session, and the fetched input and output values are
    collected. Steps with index ``>= bench_start`` (``model.params``, default
    10) are timed and the average time per timed step is printed at the end.

    Args:
      model: model exposing ``params``, ``data_layer`` and
          ``get_output_tensors()``.
      checkpoint: checkpoint path passed to ``tf.train.Saver.restore``.

    Returns:
      tuple: ``(inputs_per_batch, outputs_per_batch)``, two lists with one
      entry per processed batch.
    """
    total_time = 0.0
    bench_start = model.params.get('bench_start', 10)
    # number of processed batches; stays 0 when the epoch yields nothing, so
    # the benchmark report below never reads an unbound loop variable
    num_steps = 0

    saver = tf.train.Saver()
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, checkpoint)
        inputs_per_batch, outputs_per_batch = [], []
        fetches = [
            model.data_layer.get_input_tensors(),
            model.get_output_tensors()
        ]
        total_batches = model.data_layer.get_size_in_batches()
        for step, feed_dict in enumerate(
                model.data_layer.iterate_one_epoch(cross_over=True)):
            tm = time.time()
            inputs, outputs = sess.run(fetches, feed_dict)
            # the first bench_start steps are treated as warm-up
            if step >= bench_start:
                total_time += time.time() - tm
            inputs_per_batch.append(inputs)
            outputs_per_batch.append(outputs)
            num_steps = step + 1

            # overwrite the progress line until the last batch is reached
            ending = '\r' if step < total_batches - 1 else '\n'
            deco_print("Processed {}/{} batches".format(
                step + 1, total_batches),
                       end=ending)

    # steps bench_start .. num_steps - 1 were timed
    timed_steps = num_steps - bench_start
    if timed_steps > 0:
        deco_print("Avg time per step: {:.3}s".format(1.0 * total_time /
                                                      timed_steps))
    else:
        deco_print("Not enough steps for benchmarking")
    return inputs_per_batch, outputs_per_batch
Ejemplo n.º 27
0
    def maybe_print_logs(self, input_values, output_values):
        """Print the first source, target and predicted sequence of a batch.

        Args:
          input_values (dict): holds ``source_tensors`` and
              ``target_tensors``, each a ``(sequences, lengths)`` pair.
          output_values: sequence whose first element holds the predicted
              samples.

        Returns:
          dict: always empty.
        """
        src, src_len = input_values['source_tensors']
        tgt, tgt_len = input_values['target_tensors']
        predictions = output_values[0]

        first_src, first_src_len = src[0], src_len[0]
        first_tgt, first_tgt_len = tgt[0], tgt_len[0]

        def show(prefix, token_ids, vocab_key):
            # decode the ids with the requested vocabulary and print them
            layer_params = self.get_data_layer().params
            deco_print(
                prefix + array_to_string(
                    token_ids,
                    vocab=layer_params[vocab_key],
                    delim=layer_params["delimiter"],
                ),
                offset=4,
            )

        # source and target are clipped to their true lengths
        show("Train Source[0]:     ",
             first_src[:first_src_len], 'source_idx2seq')
        show("Train Target[0]:     ",
             first_tgt[:first_tgt_len], 'target_idx2seq')
        show("Train Prediction[0]: ",
             predictions[0, :], 'target_idx2seq')
        return {}
Ejemplo n.º 28
0
  def maybe_print_logs(self, input_values, output_values):
    """Print target, prediction and word error rate for one training sample.

    Args:
      input_values: 4-tuple ``(x, len_x, y, len_y)``; only the targets and
          their lengths are used.
      output_values: decoded sequences; only the first entry is used.

    Returns:
      dict: ``{'Sample WER': <word error rate of the sample>}``.
    """
    _, _, y, len_y = input_values
    # only the first sample is inspected; without Horovod the values carry
    # one more leading index, hence y[0][0]
    if self.on_horovod:
      target_ids, target_len = y[0], len_y[0]
    else:
      target_ids, target_len = y[0][0], len_y[0][0]
    first_decoded = output_values[0]

    idx2char = self.data_layer.params['idx2char']
    # the target is clipped to its true length before decoding
    reference = "".join(idx2char.get(idx) for idx in target_ids[:target_len])
    hypothesis = "".join(sparse_tensor_to_chars(first_decoded, idx2char)[0])

    edit_distance = levenshtein(reference.split(), hypothesis.split())
    sample_wer = edit_distance / len(reference.split())

    deco_print("Sample WER: {:.4f}".format(sample_wer), offset=4)
    deco_print("Sample target:     " + reference, offset=4)
    deco_print("Sample prediction: " + hypothesis, offset=4)
    return {'Sample WER': sample_wer}
Ejemplo n.º 29
0
  def maybe_print_logs(self, input_values, output_values):
    """Print the first source, target and predicted sequence of a batch.

    Args:
      input_values (dict): holds ``source_tensors`` and ``target_tensors``,
          each a ``(sequences, lengths)`` pair.
      output_values: sequence whose first element holds the predicted samples.

    Returns:
      dict: always empty.
    """
    sources, source_lengths = input_values['source_tensors']
    targets, target_lengths = input_values['target_tensors']
    predictions = output_values[0]

    first_source = sources[0]
    first_source_length = source_lengths[0]
    first_target = targets[0]
    first_target_length = target_lengths[0]

    # source: clipped to its true length, decoded with the source vocabulary
    source_text = array_to_string(
      first_source[:first_source_length],
      vocab=self.get_data_layer().params['source_idx2seq'],
      delim=self.get_data_layer().params["delimiter"],
    )
    deco_print("Train Source[0]:     " + source_text, offset=4)

    # target: clipped to its true length, decoded with the target vocabulary
    target_text = array_to_string(
      first_target[:first_target_length],
      vocab=self.get_data_layer().params['target_idx2seq'],
      delim=self.get_data_layer().params["delimiter"],
    )
    deco_print("Train Target[0]:     " + target_text, offset=4)

    # prediction: the whole first row is decoded, no length clipping
    prediction_text = array_to_string(
      predictions[0, :],
      vocab=self.get_data_layer().params['target_idx2seq'],
      delim=self.get_data_layer().params["delimiter"],
    )
    deco_print("Train Prediction[0]: " + prediction_text, offset=4)

    return {}
Ejemplo n.º 30
0
    def maybe_print_logs(self, input_values, output_values):
        """Print target, prediction and word error rate for the first sample.

        Args:
          input_values (dict): holds ``target_tensors`` as a
              ``(targets, lengths)`` pair.
          output_values: decoded sequences; only the first entry is used.

        Returns:
          dict: ``{'Sample WER': <word error rate of the sample>}``.
        """
        targets, target_lengths = input_values['target_tensors']
        first_target = targets[0]
        first_length = target_lengths[0]
        first_decoded = output_values[0]

        # the target is clipped to its true length before decoding
        lookup_char = self.get_data_layer().params['idx2char'].get
        true_text = "".join(
            lookup_char(char_id) for char_id in first_target[:first_length]
        )

        decoded_chars = sparse_tensor_to_chars(
            first_decoded,
            self.get_data_layer().params['idx2char'],
        )
        pred_text = "".join(decoded_chars[0])

        # word-level edit distance normalised by the reference word count
        word_errors = levenshtein(true_text.split(), pred_text.split())
        sample_wer = word_errors / len(true_text.split())

        deco_print("Sample WER: {:.4f}".format(sample_wer), offset=4)
        deco_print("Sample target:     " + true_text, offset=4)
        deco_print("Sample prediction: " + pred_text, offset=4)

        logs = {'Sample WER': sample_wer}
        return logs
Ejemplo n.º 31
0
    def evaluate(self, input_values, output_values):
        """Print the first source, target and prediction of an eval batch.

        All three sequences are decoded with the data layer's corpus
        dictionary (``corp.dictionary.idx2word``). Nothing is returned.

        Args:
          input_values (dict): holds ``source_tensors`` and
              ``target_tensors``, each a ``(sequences, lengths)`` pair.
          output_values: nested sequence; ``output_values[0][0]`` is printed
              as the prediction.
        """
        sources, source_lengths = input_values['source_tensors']
        targets, target_lengths = input_values['target_tensors']

        first_source = sources[0]
        first_source_length = source_lengths[0]
        first_target = targets[0]
        first_target_length = target_lengths[0]

        def report(prefix, token_ids):
            # decode the ids with the corpus dictionary and print the line
            text = array_to_string(
                token_ids,
                vocab=self.get_data_layer().corp.dictionary.idx2word,
                delim=self.get_data_layer().params["delimiter"],
            )
            deco_print(prefix + text, offset=4)

        # source and target are clipped to their true lengths
        report("*****EVAL Source[0]:     ",
               first_source[:first_source_length])
        report("*****EVAL Target[0]:     ",
               first_target[:first_target_length])
        report("*****EVAL Prediction[0]: ", output_values[0][0])
Ejemplo n.º 32
0
  def after_run(self, run_context, run_values):
    """Run a full evaluation pass when the hook fires or on the last step.

    The evaluation itself runs on every worker. Printing, saving of the best
    model and summary logging happen only without Horovod or on rank 0.

    Args:
      run_context: session run context; its ``session`` is used for the
          evaluation and for saving.
      run_values: run values whose ``results`` hold ``(results, step)``.
    """
    _, step = run_values.results
    self._iter_count = step

    # evaluate only when triggered, or on the final training step
    should_evaluate = self._triggered or step == self._last_step - 1
    if not should_evaluate:
      return
    self._timer.update_last_triggered_step(self._iter_count - 1)

    model = self._model
    is_master = not model.on_horovod or model.hvd.rank() == 0
    if is_master:
      deco_print("Running evaluation on a validation set:")

    results_per_batch, total_loss = get_results_for_epoch(
      model, run_context.session, mode="eval", compute_loss=True,
    )

    if not is_master:
      return

    deco_print("Validation loss: {:.4f}".format(total_loss), offset=4)

    dict_to_log = model.finalize_evaluation(results_per_batch)
    dict_to_log['eval_loss'] = total_loss

    # keep a checkpoint of the model with the lowest validation loss so far
    if total_loss < self._best_eval_loss:
      self._best_eval_loss = total_loss
      best_model_prefix = os.path.join(
        model.params['logdir'],
        'best_models',
        'val_loss={:.4f}-step'.format(total_loss),
      )
      self._eval_saver.save(
        run_context.session,
        best_model_prefix,
        global_step=step + 1,
      )

    # log the values returned from finalize_evaluation (plus eval_loss)
    if dict_to_log:
      log_summaries_from_dict(dict_to_log, model.params['logdir'], step)
Ejemplo n.º 33
0
  def after_run(self, run_context, run_values):
    """Print the current step, the training loss and the time per step.

    Nothing is printed when the fetched ``results`` are empty. The loss line
    is printed only without Horovod or on rank 0.

    Args:
      run_context: session run context (unused).
      run_values: run values whose ``results`` hold ``(results, step)``;
          ``results[0]`` is the loss.
    """
    results, step = run_values.results
    self._iter_count = step

    if not results:
      return
    self._timer.update_last_triggered_step(self._iter_count - 1)

    # header: global step, prefixed by the epoch when its size is known
    steps_in_epoch = self._model.steps_in_epoch
    if steps_in_epoch is None:
      header = "Global step {}:".format(step)
    else:
      header = "Epoch {}, global step {}:".format(step // steps_in_epoch, step)
    deco_print(header, end=" ")

    loss = results[0]
    on_master = not self._model.on_horovod or self._model.hvd.rank() == 0
    if on_master and self._print_ppl:
      # perplexity and bits-per-character are both derived from the loss
      loss_line = "Train loss: {:.4f} | ppl = {:.4f} | bpc = {:.4f}".format(
        loss, math.exp(loss), loss/math.log(2))
      deco_print(loss_line, start="", end=", ")
    elif on_master:
      deco_print("Train loss: {:.4f} ".format(loss), offset=4)

    # average wall-clock time per step since the previous report
    seconds_per_step = (time.time() - self._last_time) / self._every_steps
    minutes, seconds = divmod(seconds_per_step, 60)
    hours, minutes = divmod(minutes, 60)

    timing_line = "time per step = {}:{:02}:{:.3f}".format(
      int(hours), int(minutes), seconds)
    deco_print(timing_line, start="")
    self._last_time = time.time()
Ejemplo n.º 34
0
    def evaluate(self, input_values, output_values):
        """Print one eval sample and collect BLEU inputs for the batch.

        Args:
          input_values (dict): holds ``source_tensors`` and
              ``target_tensors``, each a ``(sequences, lengths)`` pair.
          output_values: sequence whose first element holds the predicted
              samples.

        Returns:
          tuple: ``(preds, targets)``. ``preds`` has one transformed
          prediction per sample; ``targets`` has one single-element list per
          target. Both are empty when ``eval_using_bleu`` is disabled.
        """
        ex, elen_x = input_values['source_tensors']
        ey, elen_y = input_values['target_tensors']

        first_source = ex[0]
        first_source_len = elen_x[0]
        first_target = ey[0]
        first_target_len = elen_y[0]

        source_text = array_to_string(
            first_source[:first_source_len],
            vocab=self.get_data_layer().params['source_idx2seq'],
            delim=self.get_data_layer().params["delimiter"],
        )
        deco_print("*****EVAL Source[0]:     " + source_text, offset=4)

        target_text = array_to_string(
            first_target[:first_target_len],
            vocab=self.get_data_layer().params['target_idx2seq'],
            delim=self.get_data_layer().params["delimiter"],
        )
        deco_print("*****EVAL Target[0]:     " + target_text, offset=4)

        samples = output_values[0]
        prediction_text = array_to_string(
            samples[0, :],
            vocab=self.get_data_layer().params['target_idx2seq'],
            delim=self.get_data_layer().params["delimiter"],
        )
        deco_print("*****EVAL Prediction[0]: " + prediction_text, offset=4)

        # BLEU inputs are only collected when enabled (the default)
        if not self.params.get('eval_using_bleu', True):
            return [], []

        def to_bleu(token_ids):
            # convert one id sequence into the form expected for BLEU
            return transform_for_bleu(
                token_ids,
                vocab=self.get_data_layer().params['target_idx2seq'],
                ignore_special=True,
                delim=self.get_data_layer().params["delimiter"],
                bpe_used=self.params.get('bpe_used', False),
            )

        preds = [to_bleu(sample) for sample in samples]
        # every prediction gets a list with a single reference
        targets = [[to_bleu(reference)] for reference in ey]
        return preds, targets
Ejemplo n.º 35
0
    def infer(self, input_values, output_values):
        """Post-process one inference batch.

        In the language-model phase the generated continuation for every seed
        token is printed and nothing is collected. Otherwise one
        ``(source_text, predicted_class, true_class)`` tuple is built per
        sample.

        Args:
          input_values (dict): holds ``source_tensors`` as a
              ``(sequences, lengths)`` pair and, optionally,
              ``target_tensors``.
          output_values: sequence whose first element holds the per-sample
              model outputs.

        Returns:
          list: empty in the language-model phase; otherwise a list of
          ``(source_text, predicted_class, true_class)`` tuples, where
          ``true_class`` is ``None`` when no targets were provided.
        """
        if self._lm_phase:
            vocab = self.get_data_layer().corp.dictionary.idx2word
            seed_tokens = self.params['encoder_params']['seed_tokens']
            for i, seed_token in enumerate(seed_tokens):
                print('Seed:', vocab[seed_token] + '\n')
                deco_print(
                    "Output: " + array_to_string(
                        output_values[0][i],
                        vocab=self.get_data_layer().corp.dictionary.idx2word,
                        delim=self.delimiter,
                    ),
                    offset=4,
                )
            return []

        ex, elen_x = input_values['source_tensors']
        # targets are optional at inference time
        ey = None
        if 'target_tensors' in input_values:
            ey, _ = input_values['target_tensors']

        results = []
        for i in range(len(ex)):
            # the source is clipped to its true length before decoding
            current_x = array_to_string(
                ex[i][:elen_x[i]],
                vocab=self.get_data_layer().corp.dictionary.idx2word,
                delim=self.delimiter,
            )
            current_pred = np.argmax(output_values[0][i])
            # the label stays None when the batch carries no targets
            current_y = None
            if ey is not None:
                current_y = np.argmax(ey[i])

            results.append((current_x, current_pred, current_y))
        return results
Ejemplo n.º 36
0
 def finalize_inference(self, results_per_batch, output_file):
   """Write all decoded outputs to ``output_file``, one sequence per line.

   Every 200th input/output pair is also printed with ``deco_print``.

   Args:
     results_per_batch: iterable of ``(input_strings, output_strings)``
         pairs, one per batch.
     output_file: path of the UTF-8 text file to write.
   """
   with codecs.open(output_file, 'w', 'utf-8') as fout:
     # flatten the batches into one stream of (input, output) pairs
     pairs = (
       pair
       for input_strings, output_strings in results_per_batch
       for pair in zip(input_strings, output_strings)
     )
     for step, (input_string, output_string) in enumerate(pairs):
       fout.write(output_string + "\n")
       if step % 200 == 0:
         deco_print("Input sequence:  {}".format(input_string))
         deco_print("Output sequence: {}".format(output_string))
         deco_print("")
Ejemplo n.º 37
0
 def finalize_inference(self, results_per_batch, output_file):
   """Dump decoded output strings to ``output_file``, one per line.

   A sample input/output pair is printed for every 200th written line.

   Args:
     results_per_batch: iterable of ``(input_strings, output_strings)``
         pairs, one per batch.
     output_file: path of the UTF-8 text file to write.
   """
   with codecs.open(output_file, 'w', 'utf-8') as fout:
     written = 0
     for batch in results_per_batch:
       input_strings, output_strings = batch
       for input_string, output_string in zip(input_strings, output_strings):
         fout.write(output_string + "\n")
         # decide on echoing before the counter moves on
         should_echo = written % 200 == 0
         written += 1
         if not should_echo:
           continue
         deco_print("Input sequence:  {}".format(input_string))
         deco_print("Output sequence: {}".format(output_string))
         deco_print("")
Ejemplo n.º 38
0
  def evaluate(self, input_values, output_values):
    """Print one eval sample and collect BLEU inputs for the batch.

    Args:
      input_values (dict): holds ``source_tensors`` and ``target_tensors``,
          each a ``(sequences, lengths)`` pair.
      output_values: sequence whose first element holds the predicted samples.

    Returns:
      tuple: ``(preds, targets)``. ``preds`` has one transformed prediction
      per sample; ``targets`` has one single-element list per target. Both
      are empty when ``eval_using_bleu`` is disabled.
    """
    ex, elen_x = input_values['source_tensors']
    ey, elen_y = input_values['target_tensors']

    first_x, first_x_len = ex[0], elen_x[0]
    first_y, first_y_len = ey[0], elen_y[0]

    def report(prefix, token_ids, vocab_key):
      # decode the ids with the requested vocabulary and print the line
      layer_params = self.get_data_layer().params
      deco_print(
        prefix + array_to_string(
          token_ids,
          vocab=layer_params[vocab_key],
          delim=layer_params["delimiter"],
        ),
        offset=4,
      )

    # source and target are clipped to their true lengths
    report("*****EVAL Source[0]:     ", first_x[:first_x_len], 'source_idx2seq')
    report("*****EVAL Target[0]:     ", first_y[:first_y_len], 'target_idx2seq')
    samples = output_values[0]
    report("*****EVAL Prediction[0]: ", samples[0, :], 'target_idx2seq')

    preds, targets = [], []
    if self.params.get('eval_using_bleu', True):
      for sample in samples:
        preds.append(transform_for_bleu(
          sample,
          vocab=self.get_data_layer().params['target_idx2seq'],
          ignore_special=True,
          delim=self.get_data_layer().params["delimiter"],
          bpe_used=self.params.get('bpe_used', False),
        ))
      for reference in ey:
        # every prediction gets a list with a single reference
        targets.append([transform_for_bleu(
          reference,
          vocab=self.get_data_layer().params['target_idx2seq'],
          ignore_special=True,
          delim=self.get_data_layer().params["delimiter"],
          bpe_used=self.params.get('bpe_used', False),
        )])

    return preds, targets
Ejemplo n.º 39
0
  def maybe_print_logs(self, input_values, output_values):
    """Print target, prediction and word error rate for the first sample.

    Args:
      input_values (dict): holds ``target_tensors`` as a
          ``(targets, lengths)`` pair.
      output_values: decoded sequences; only the first entry is used.

    Returns:
      dict: ``{'Sample WER': <word error rate of the sample>}``.
    """
    y, len_y = input_values['target_tensors']
    reference_ids = y[0]
    reference_len = len_y[0]
    first_decoded = output_values[0]

    # the reference is clipped to its true length before decoding
    reference_chars = map(
      self.get_data_layer().params['idx2char'].get,
      reference_ids[:reference_len],
    )
    true_text = "".join(reference_chars)

    predicted_chars = sparse_tensor_to_chars(
      first_decoded, self.get_data_layer().params['idx2char'])
    pred_text = "".join(predicted_chars[0])

    # word-level edit distance normalised by the reference word count
    distance = levenshtein(true_text.split(), pred_text.split())
    sample_wer = distance / len(true_text.split())

    for line in (
      "Sample WER: {:.4f}".format(sample_wer),
      "Sample target:     " + true_text,
      "Sample prediction: " + pred_text,
    ):
      deco_print(line, offset=4)
    return {'Sample WER': sample_wer}
Ejemplo n.º 40
0
def infer(model, checkpoint, output_file):
  """Run inference from ``checkpoint`` and let the model write the results.

  Args:
    model: model to run; must provide ``finalize_inference``.
    checkpoint: checkpoint passed on to ``restore_and_get_results``.
    output_file: destination handed to ``model.finalize_inference``.
  """
  results_per_batch = restore_and_get_results(model, checkpoint, mode="infer")
  # under Horovod only rank 0 finalizes the results
  if model.on_horovod and model.hvd.rank() != 0:
    return
  model.finalize_inference(results_per_batch, output_file)
  deco_print("Finished inference")
Ejemplo n.º 41
0
def train(train_model, eval_model=None, debug_port=None):
  """Run the training loop for ``train_model`` and print benchmark numbers.

  A ``tf.train.MonitoredTrainingSession`` is created with hooks for stopping,
  optional evaluation, checkpointing, loss printing and sample printing, and
  ``train_model.train_op`` is run until the session asks to stop or the data
  is exhausted.

  Args:
    train_model: model to train; provides ``hvd``, ``params``, ``last_step``,
        ``train_op``, ``num_gpus`` and the data layers.
    eval_model: optional model evaluated periodically by
        ``RunEvaluationHook``; its params must contain ``eval_steps``.
    debug_port: optional port; when set, a ``TensorBoardDebugHook`` pointing
        at ``localhost:<debug_port>`` is added.

  Raises:
    ValueError: if ``eval_model`` is given without ``eval_steps`` in its
        params.
  """
  if eval_model is not None and 'eval_steps' not in eval_model.params:
    raise ValueError("eval_steps parameter has to be specified "
                     "if eval_model is provided")
  hvd = train_model.hvd
  # without Horovod the single process is the master worker
  if hvd:
    master_worker = hvd.rank() == 0
  else:
    master_worker = True

  # initializing session parameters
  sess_config = tf.ConfigProto(allow_soft_placement=True)
  sess_config.gpu_options.allow_growth = True
  if hvd is not None:
    # each Horovod process only sees the GPU matching its local rank
    sess_config.gpu_options.visible_device_list = str(hvd.local_rank())

  # defining necessary hooks
  hooks = [tf.train.StopAtStepHook(last_step=train_model.last_step)]
  if hvd is not None:
    hooks.append(BroadcastGlobalVariablesHook(0))

  # only the master worker gets a checkpoint directory
  if master_worker:
    checkpoint_dir = train_model.params['logdir']
  else:
    checkpoint_dir = None

  # the evaluation hook is added on every worker
  if eval_model is not None:
    # noinspection PyTypeChecker
    hooks.append(
      RunEvaluationHook(
        every_steps=eval_model.params['eval_steps'],
        model=eval_model,
        last_step=train_model.last_step,
      ),
    )

  # checkpointing and console logging hooks are master-only
  if master_worker:
    if train_model.params['save_checkpoint_steps'] is not None:
      # noinspection PyTypeChecker
      saver = tf.train.Saver(save_relative_paths=True)
      hooks.append(tf.train.CheckpointSaverHook(
        checkpoint_dir,
        saver=saver,
        save_steps=train_model.params['save_checkpoint_steps']),
      )
    if train_model.params['print_loss_steps'] is not None:
      # noinspection PyTypeChecker
      hooks.append(PrintLossAndTimeHook(
        every_steps=train_model.params['print_loss_steps'],
        model=train_model,
      ))
    if train_model.params['print_samples_steps'] is not None:
      # noinspection PyTypeChecker
      hooks.append(PrintSamplesHook(
        every_steps=train_model.params['print_samples_steps'],
        model=train_model,
      ))

  # benchmarking state: steps before bench_start are not timed
  total_time = 0.0
  bench_start = train_model.params.get('bench_start', 10)

  if debug_port:
    hooks.append(
      tf_debug.TensorBoardDebugHook("localhost:{}".format(debug_port))
    )

  # op that (re-)initializes the data iterator(s): one data layer under
  # Horovod, one per GPU otherwise
  if train_model.on_horovod:
    init_data_layer = train_model.get_data_layer().iterator.initializer
  else:
    init_data_layer = tf.group(
      [train_model.get_data_layer(i).iterator.initializer
       for i in range(train_model.num_gpus)]
    )

  # the data iterators are initialized together with the local variables
  scaffold = tf.train.Scaffold(
    local_init_op=tf.group(tf.local_variables_initializer(), init_data_layer)
  )
  # fetches[0] is the train op; fetches[1:] are optional per-worker object
  # counts used for the throughput report
  fetches = [train_model.train_op]
  try:
    total_objects = 0.0
    # on horovod num_gpus is 1
    for worker_id in range(train_model.num_gpus):
      fetches.append(train_model.get_num_objects_per_step(worker_id))
  except NotImplementedError:
    deco_print("WARNING: Can't compute number of objects per step, since "
               "train model does not define get_num_objects_per_step method.")

  # starting training
  with tf.train.MonitoredTrainingSession(
    scaffold=scaffold,
    checkpoint_dir=checkpoint_dir,
    save_summaries_steps=train_model.params['save_summaries_steps'],
    config=sess_config,
    save_checkpoint_secs=None,
    log_step_count_steps=train_model.params['save_summaries_steps'],
    stop_grace_period_secs=300,
    hooks=hooks,
  ) as sess:
    # number of completed sess.run calls in this process
    step = 0
    while True:
      if sess.should_stop():
        break
      tm = time.time()
      try:
        fetches_vals = sess.run(fetches)
      except tf.errors.OutOfRangeError:
        # the data iterator is exhausted
        break
      if step >= bench_start:
        total_time += time.time() - tm
        if len(fetches) > 1:
          # fetches_vals[0] belongs to the train op, hence the i + 1 offset
          for i in range(train_model.num_gpus):
            total_objects += np.sum(fetches_vals[i + 1])
      step += 1

  if hvd is not None:
    deco_print("Finished training on rank {}".format(hvd.rank()))
  else:
    deco_print("Finished training")

  # suffix that identifies the worker in the benchmark messages
  if train_model.on_horovod:
    ending = " on worker {}".format(hvd.rank())
  else:
    ending = ""
  # step - bench_start is the number of timed steps
  if step > bench_start:
    deco_print(
      "Avg time per step{}: {:.3f}s".format(
        ending, 1.0 * total_time / (step - bench_start))
    )
    if len(fetches) > 1:
      deco_print(
        "Avg objects per second{}: {:.3f}".format(
          ending, 1.0 * total_objects / total_time)
      )
  else:
    deco_print("Not enough steps for benchmarking{}".format(ending))
Ejemplo n.º 42
0
  def _build_forward_pass_graph(self, input_tensors, gpu_id=0):
    """TensorFlow graph for encoder-decoder-loss model is created here.
    This function connects encoder, decoder and loss together. As an input for
    encoder it will specify source tensors (as returned from
    the data layer). As an input for decoder it will specify target tensors
    as well as all output returned from encoder. For loss it
    will also specify target tensors and all output returned from
    decoder. Note that loss will only be built for mode == "train" or "eval".

    Args:
      input_tensors (dict): ``input_tensors`` dictionary that has to contain
          ``source_tensors`` key with the list of all source tensors, and
          ``target_tensors`` with the list of all target tensors. Note that
          ``target_tensors`` only need to be provided if mode is
          "train" or "eval".
      gpu_id (int, optional): id of the GPU where the current copy of the model
          is constructed. For Horovod this is always zero.

    Returns:
      tuple: tuple containing loss tensor as returned from
      ``loss.compute_loss()`` and samples tensor, which is taken from
      ``decoder.decode()['samples']``. When ``mode == 'infer'``, loss will
      be None.

    Raises:
      ValueError: if ``input_tensors`` is not a dict with a
          ``source_tensors`` list, or if ``target_tensors`` is missing or
          not a list when mode is "train" or "eval".
    """
    if not isinstance(input_tensors, dict) or \
       'source_tensors' not in input_tensors:
      raise ValueError('Input tensors should be a dict containing '
                       '"source_tensors" key')

    if not isinstance(input_tensors['source_tensors'], list):
      raise ValueError('source_tensors should be a list')

    source_tensors = input_tensors['source_tensors']
    if self.mode == "train" or self.mode == "eval":
      if 'target_tensors' not in input_tensors:
        # the trailing space keeps the two literals from fusing into
        # "keywhen"
        raise ValueError('Input tensors should contain "target_tensors" key '
                         'when mode != "infer"')
      if not isinstance(input_tensors['target_tensors'], list):
        raise ValueError('target_tensors should be a list')
      target_tensors = input_tensors['target_tensors']

    with tf.variable_scope("ForwardPass"):
      encoder_input = {"source_tensors": source_tensors}
      encoder_output = self.encoder.encode(input_dict=encoder_input)

      decoder_input = {"encoder_output": encoder_output}
      # the decoder only receives the targets in train mode, not in eval
      if self.mode == "train":
        decoder_input['target_tensors'] = target_tensors
      decoder_output = self.decoder.decode(input_dict=decoder_input)
      # samples are optional: None when the decoder does not return them
      decoder_samples = decoder_output.get("samples", None)

      if self.mode == "train" or self.mode == "eval":
        with tf.variable_scope("Loss"):
          loss_input_dict = {
            "decoder_output": decoder_output,
            "target_tensors": target_tensors,
          }
          loss = self.loss_computator.compute_loss(loss_input_dict)
      else:
        deco_print("Inference Mode. Loss part of graph isn't built.")
        loss = None
      return loss, decoder_samples
Ejemplo n.º 43
0
  def compile(self, force_var_reuse=False):
    """TensorFlow graph is built here.

    Builds the data layer and forward pass on every GPU (or on a single
    device under Horovod), stores the outputs and losses, and in "train" mode
    also creates the train op, the summaries and a printout of all trainable
    variables.

    Args:
      force_var_reuse (bool, optional): passed as ``reuse`` to the variable
          scope, so that variables are re-used instead of created.

    Raises:
      ValueError: if the decoder samples returned from the forward pass are
          neither None nor a list.
    """
    # optional variable initializer, constructed from the model params
    if 'initializer' not in self.params:
      initializer = None
    else:
      init_dict = self.params.get('initializer_params', {})
      initializer = self.params['initializer'](**init_dict)

    if not self.on_horovod:  # not using Horovod
      # below we follow data parallelism for multi-GPU training
      losses = []
      for gpu_cnt, gpu_id in enumerate(self._gpu_ids):
        with tf.device("/gpu:{}".format(gpu_id)), tf.variable_scope(
          name_or_scope=tf.get_variable_scope(),
          # re-using variables across GPUs.
          reuse=force_var_reuse or (gpu_cnt > 0),
          initializer=initializer,
          dtype=self.get_tf_dtype(),
        ):
          deco_print("Building graph on GPU:{}".format(gpu_id))

          # every GPU has its own data layer, indexed by position
          self.get_data_layer(gpu_cnt).build_graph()
          input_tensors = self.get_data_layer(gpu_cnt).input_tensors

          loss, self._outputs[gpu_cnt] = self._build_forward_pass_graph(
            input_tensors,
            gpu_id=gpu_cnt,
          )
          if self._outputs[gpu_cnt] is not None and \
             not isinstance(self._outputs[gpu_cnt], list):
            raise ValueError('Decoder samples have to be either None or list')
          if self._mode == "train" or self._mode == "eval":
            losses.append(loss)
      # end of for gpu_ind loop
      # training uses the mean loss over GPUs, evaluation keeps them separate
      if self._mode == "train":
        self.loss = tf.reduce_mean(losses)
      if self._mode == "eval":
        self.eval_losses = losses
    else:  # is using Horovod
      # gpu_id should always be zero, since Horovod takes care of isolating
      # different processes to 1 GPU only
      with tf.device("/gpu:0"), tf.variable_scope(
          name_or_scope=tf.get_variable_scope(),
          reuse=force_var_reuse,
          initializer=initializer,
          dtype=self.get_tf_dtype(),
      ):
        deco_print(
          "Building graph in Horovod rank: {}".format(self._hvd.rank())
        )
        self.get_data_layer().build_graph()
        input_tensors = self.get_data_layer().input_tensors

        loss, self._output = self._build_forward_pass_graph(input_tensors,
                                                            gpu_id=0)
        if self._output is not None and not isinstance(self._output, list):
          raise ValueError('Decoder samples have to be either None or list')

        if self._mode == "train":
          self.loss = loss
        if self._mode == "eval":
          self.eval_losses = [loss]

    if self._mode == "train":
      if 'lr_policy' not in self.params:
        lr_policy = None
      else:
        lr_params = self.params.get('lr_policy_params', {})
        # adding default decay_steps = max_steps if lr_policy supports it and
        # different value is not provided
        if 'decay_steps' in self.params['lr_policy'].__code__.co_varnames and \
           'decay_steps' not in lr_params:
          lr_params['decay_steps'] = self._last_step
        # likewise default steps_per_epoch, but only when num_epochs is set
        if 'steps_per_epoch' in self.params['lr_policy'].__code__.co_varnames and \
           'steps_per_epoch' not in lr_params and 'num_epochs' in self.params:
          lr_params['steps_per_epoch'] = self.steps_in_epoch
        # bind the collected params; only the global step stays open
        lr_policy = lambda gs: self.params['lr_policy'](global_step=gs,
                                                        **lr_params)

      # the loss is cast to float32 before the regularization loss is added
      self.train_op = optimize_loss(
        loss=tf.cast(self.loss, tf.float32) + get_regularization_loss(),
        dtype=self.params['dtype'],
        optimizer=self.params['optimizer'],
        optimizer_params=self.params.get('optimizer_params', {}),
        gradient_noise_scale=None,
        gradient_multipliers=None,
        clip_gradients=self.params.get('max_grad_norm', None),
        learning_rate_decay_fn=lr_policy,
        update_ops=None,
        variables=None,
        name="Loss_Optimization",
        summaries=self.params.get('summaries', None),
        colocate_gradients_with_ops=True,
        increment_global_step=True,
        larc_params=self.params.get('larc_params', None),
        loss_scale=self.params.get('loss_scale', 1.0),
        automatic_loss_scaling=self.params.get('automatic_loss_scaling', None),
        on_horovod=self.on_horovod,
      )
      tf.summary.scalar(name="train_loss", tensor=self.loss)
      # the epoch summary is only added when steps_in_epoch is set
      if self.steps_in_epoch:
        tf.summary.scalar(
          name="epoch",
          tensor=tf.floor(tf.train.get_global_step() /
                          tf.constant(self.steps_in_epoch, dtype=tf.int64)),
        )

      # variable summary is printed once: without Horovod or on rank 0
      if not self.on_horovod or self._hvd.rank() == 0:
        deco_print("Trainable variables:")
        total_params = 0
        unknown_shape = False
        for var in tf.trainable_variables():
          var_params = 1
          deco_print('{}'.format(var.name), offset=2)
          deco_print('shape: {}, {}'.format(var.get_shape(), var.dtype),
                     offset=4)
          if var.get_shape():
            # parameter count of a variable is the product of its dimensions
            for dim in var.get_shape():
              var_params *= dim.value
            total_params += var_params
          else:
            unknown_shape = True
        if unknown_shape:
          deco_print("Encountered unknown variable shape, can't compute total "
                     "number of parameters.")
        else:
          deco_print('Total trainable parameters: {}'.format(total_params))