def get_raw_strings(dataset, options):
  """
  Collect the raw string stored under ``options.key`` for every sequence in the requested range.

  Sequences are loaded one by one, from ``options.startseq`` up to and including ``options.endseq``,
  or until the dataset reports no more sequences.
  An entry may be a ``str``, ``bytes``, a zero-dim array wrapping one of those,
  or a 1D uint8 array, whose bytes are decoded as UTF-8.

  Side effect: a negative ``options.endseq`` is replaced in-place by ``float("inf")``.

  :param Dataset dataset:
  :param options: argparse.Namespace. ``startseq``, ``endseq`` and ``key`` are read.
  :return: list of (seq tag, string)
  :rtype: list[(str,str)]
  """
  refs = []
  start_time = time.time()
  seq_len_stats = Stats()
  seq_idx = options.startseq
  if options.endseq < 0:
    options.endseq = float("inf")  # negative means: no upper limit
  # The progress bar and the per-seq verbose output are mutually exclusive.
  interactive = Util.is_tty() and not log.verbose[5]
  print("Iterating over %r." % dataset, file=log.v2)
  while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
    dataset.load_seqs(seq_idx, seq_idx + 1)
    complete_frac = dataset.get_complete_frac(seq_idx)
    try:
      num_seqs_s = str(dataset.num_seqs)
    except NotImplementedError:
      try:
        num_seqs_s = "~%i" % dataset.estimated_num_seqs
      except TypeError:  # a number is required, not NoneType
        num_seqs_s = "?"
    progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
    seq_tag = dataset.get_tag(seq_idx)
    assert isinstance(seq_tag, str)
    ref = dataset.get_data(seq_idx, options.key)
    if isinstance(ref, numpy.ndarray):
      assert ref.shape == () or (ref.ndim == 1 and ref.dtype == numpy.uint8)
      if ref.shape == ():
        ref = ref.flatten()[0]  # get the entry itself (str or bytes)
      else:
        ref = ref.tobytes()  # raw uint8 buffer, decoded below
    if isinstance(ref, bytes):
      ref = ref.decode("utf8")
    assert isinstance(ref, str)
    seq_len_stats.collect([len(ref)])
    refs.append((seq_tag, ref))
    if interactive:
      Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
    elif log.verbose[5]:
      print(progress_prefix, "seq tag %r, ref len %i chars" % (seq_tag, len(ref)))
    seq_idx += 1
  # Report the number of seqs actually collected. seq_idx is the next index,
  # which is only equal to the count when we started at seq 0.
  print("Done. Num seqs %i. Total time %s." % (
    len(refs), hms(time.time() - start_time)), file=log.v1)
  print("More seqs which we did not dumped: %s." % (
    dataset.is_less_than_num_seqs(seq_idx),), file=log.v1)
  seq_len_stats.dump(stream_prefix="Seq-length %r " % (options.key,), stream=log.v2)
  return refs
def get_raw_strings(dataset, options):
  """
  Collect the raw string stored under ``options.key`` for every sequence in the requested range.

  Sequences are loaded one by one, from ``options.startseq`` up to and including ``options.endseq``,
  or until the dataset reports no more sequences.
  An entry may be a ``str``, ``bytes``, a zero-dim array wrapping one of those,
  or a 1D uint8 array, whose bytes are decoded as UTF-8.

  Side effect: a negative ``options.endseq`` is replaced in-place by ``float("inf")``.

  :param Dataset dataset:
  :param options: argparse.Namespace. ``startseq``, ``endseq`` and ``key`` are read.
  :return: list of (seq tag, string)
  :rtype: list[(str,str)]
  """
  refs = []
  start_time = time.time()
  seq_len_stats = Stats()
  seq_idx = options.startseq
  if options.endseq < 0:
    options.endseq = float("inf")  # negative means: no upper limit
  # The progress bar and the per-seq verbose output are mutually exclusive.
  interactive = Util.is_tty() and not log.verbose[5]
  print("Iterating over %r." % dataset, file=log.v2)
  while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
    dataset.load_seqs(seq_idx, seq_idx + 1)
    complete_frac = dataset.get_complete_frac(seq_idx)
    try:
      num_seqs_s = str(dataset.num_seqs)
    except NotImplementedError:
      try:
        num_seqs_s = "~%i" % dataset.estimated_num_seqs
      except TypeError:  # a number is required, not NoneType
        num_seqs_s = "?"
    progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
    seq_tag = dataset.get_tag(seq_idx)
    assert isinstance(seq_tag, str)
    ref = dataset.get_data(seq_idx, options.key)
    if isinstance(ref, numpy.ndarray):
      assert ref.shape == () or (ref.ndim == 1 and ref.dtype == numpy.uint8)
      if ref.shape == ():
        ref = ref.flatten()[0]  # get the entry itself (str or bytes)
      else:
        ref = ref.tobytes()  # raw uint8 buffer, decoded below
    if isinstance(ref, bytes):
      ref = ref.decode("utf8")
    assert isinstance(ref, str)
    seq_len_stats.collect([len(ref)])
    refs.append((seq_tag, ref))
    if interactive:
      Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
    elif log.verbose[5]:
      print(progress_prefix, "seq tag %r, ref len %i chars" % (seq_tag, len(ref)))
    seq_idx += 1
  # Report the number of seqs actually collected. seq_idx is the next index,
  # which is only equal to the count when we started at seq 0.
  print("Done. Num seqs %i. Total time %s." % (
    len(refs), hms(time.time() - start_time)), file=log.v1)
  print("More seqs which we did not dumped: %s." % (
    dataset.is_less_than_num_seqs(seq_idx),), file=log.v1)
  seq_len_stats.dump(stream_prefix="Seq-length %r " % (options.key,), stream=log.v2)
  return refs
# Example no. 3
# 0
def main(argv):
    """
    Forward a dataset through the network and dump the outputs of the selected
    recurrent sub layers (by default the attention weights) into ``--dump_dir``,
    either as one ``.npy`` file per batch or as one ``.png`` plot per sequence and layer.

    Exits the process with status 1 if the runner did not finalize.

    :param list[str] argv: command line, including the program name at index 0 (which is skipped)
    """
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument("config_file", type=str)
    argparser.add_argument("--epoch", required=False, type=int)
    argparser.add_argument(
        '--data',
        default="train",
        help=
        "e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'"
    )
    argparser.add_argument('--do_search', default=False, action='store_true')
    argparser.add_argument('--beam_size', default=12, type=int)
    argparser.add_argument('--dump_dir', required=True)
    argparser.add_argument("--device", default="gpu")
    # NOTE(review): with action="append" and a non-empty default, values given on the
    # command line are appended to ["att_weights"] instead of replacing it.
    argparser.add_argument("--layers",
                           default=["att_weights"],
                           action="append",
                           help="Layer of subnet to grab")
    argparser.add_argument("--rec_layer",
                           default="output",
                           help="Subnet layer to grab from; decoder")
    argparser.add_argument("--enc_layer", default="encoder")
    argparser.add_argument("--batch_size", type=int, default=5000)
    argparser.add_argument("--seq_list",
                           default=[],
                           action="append",
                           help="predefined list of seqs")
    argparser.add_argument("--min_seq_len",
                           default="0",
                           help="can also be dict")
    argparser.add_argument("--output_format", default="npy", help="npy or png")
    argparser.add_argument("--dropout",
                           default=None,
                           type=float,
                           help="if set, overwrites all dropout values")
    argparser.add_argument("--train_flag", action="store_true")
    args = argparser.parse_args(argv[1:])

    layers = args.layers
    assert isinstance(layers, list)
    # Config file basename without its last extension; used in the dump file names.
    model_name = ".".join(args.config_file.split("/")[-1].split(".")[:-1])

    init_returnn(config_fn=args.config_file,
                 cmd_line_opts=["--device", args.device],
                 args=args)

    if args.do_search:
        raise NotImplementedError
    # NOTE(review): eval() of a command line string executes arbitrary code.
    # Acceptable only as long as the command line is trusted.
    min_seq_length = NumbersDict(eval(args.min_seq_len))

    if not os.path.exists(args.dump_dir):
        os.makedirs(args.dump_dir)
    assert args.output_format in ["npy", "png"]
    if args.output_format == "png":
        import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
        import matplotlib.ticker as ticker
    dataset_str = args.data
    if dataset_str in ["train", "dev", "eval"]:
        dataset_str = "config:%s" % dataset_str
    dataset = init_dataset(dataset_str)
    init_net(args, layers)

    network = rnn.engine.network

    # Everything in here is evaluated per batch and passed to fetch_callback by keyword.
    extra_fetches = {}
    for rec_ret_layer in ["rec_%s" % l for l in layers]:
        extra_fetches[rec_ret_layer] = rnn.engine.network.layers[
            rec_ret_layer].output.get_placeholder_as_batch_major()
    extra_fetches.update({
        "output":
        network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
        "output_len":
        network.layers[
            args.rec_layer].output.get_sequence_lengths(),  # decoder length
        "encoder_len":
        network.layers[
            args.enc_layer].output.get_sequence_lengths(),  # encoder length
        "seq_idx":
        network.get_extern_data("seq_idx"),
        "seq_tag":
        network.get_extern_data("seq_tag"),
        "target_data":
        network.get_extern_data(network.extern_data.default_input),
        "target_classes":
        network.get_extern_data(network.extern_data.default_target),
    })
    dataset.init_seq_order(epoch=rnn.engine.epoch,
                           seq_list=args.seq_list or None)
    dataset_batch = dataset.generate_batches(
        recurrent_net=network.recurrent,
        batch_size=args.batch_size,
        max_seqs=rnn.engine.max_seqs,
        max_seq_length=sys.maxsize,
        min_seq_length=min_seq_length,
        used_data_keys=network.used_data_keys)

    stats = {l: Stats() for l in layers}

    # (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str])->None
    def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output,
                       output_len, encoder_len, **kwargs):
        """
        Called with the fetched values of one batch; collects statistics and writes the dump.

        The positional names match the keys of ``extra_fetches``.

        :param kwargs: contains "rec_%s" % l for l in layers
        """
        n_batch = len(seq_idx)
        for i in range(n_batch):
            for l in layers:
                att_weights = kwargs["rec_%s" % l][i]
                stats[l].collect(att_weights.flatten())
        if args.output_format == "npy":
            data = {}
            for i in range(n_batch):
                data[i] = {
                    'tag': seq_tag[i],
                    'data': target_data[i],
                    'classes': target_classes[i],
                    'output': output[i],
                    'output_len': output_len[i],
                    'encoder_len': encoder_len[i],
                }
                for rec_key in [("rec_%s" % l) for l in layers]:
                    assert rec_key in kwargs
                    out = kwargs[rec_key][i]
                    assert out.ndim >= 2
                    assert out.shape[0] >= output_len[i] and out.shape[
                        1] >= encoder_len[i]
                    # Cut off the padding in both the decoder and the encoder axis.
                    data[i][rec_key] = out[:output_len[i], :encoder_len[i]]
            if data:
                # One file per batch. Written once after all seqs are collected,
                # instead of rewriting the same file after every single seq.
                fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (
                    model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
                np.save(fname, data)
        elif args.output_format == "png":
            # The file name postfix depends only on the command line options.
            extra_postfix = ""
            if args.dropout is not None:
                extra_postfix += "_dropout%.2f" % args.dropout
            elif args.train_flag:
                extra_postfix += "_train"
            for i in range(n_batch):
                for l in layers:
                    fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
                        model_name, rnn.engine.epoch, seq_idx[i], l,
                        extra_postfix)
                    att_weights = kwargs["rec_%s" % l][i]
                    att_weights = att_weights.squeeze(axis=2)  # (out,enc)
                    assert att_weights.shape[0] >= output_len[
                        i] and att_weights.shape[1] >= encoder_len[i]
                    att_weights = att_weights[:output_len[i], :encoder_len[i]]
                    print("Seq %i, %s: Dump att weights with shape %r to: %s" %
                          (seq_idx[i], seq_tag[i], att_weights.shape, fname))
                    plt.matshow(att_weights)
                    title = seq_tag[i]
                    if dataset.can_serialize_data(
                            network.extern_data.default_target):
                        title += "\n" + dataset.serialize_data(
                            network.extern_data.default_target,
                            target_classes[i][:output_len[i]])
                        ax = plt.gca()
                        # One y tick label per target label.
                        tick_labels = [
                            dataset.serialize_data(
                                network.extern_data.default_target,
                                np.array([x], dtype=target_classes[i].dtype))
                            for x in target_classes[i][:output_len[i]]
                        ]
                        ax.set_yticklabels([''] + tick_labels, fontsize=8)
                        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
                    plt.title(title)
                    plt.savefig(fname)
                    plt.close()
        else:
            raise NotImplementedError("output format %r" % args.output_format)

    runner = Runner(engine=rnn.engine,
                    dataset=dataset,
                    batches=dataset_batch,
                    train=False,
                    train_flag=args.dropout is not None or args.train_flag,
                    extra_fetches=extra_fetches,
                    extra_fetches_callback=fetch_callback)
    runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)
    for l in layers:
        stats[l].dump(stream_prefix="Layer %r " % l)
    if not runner.finalized:
        print("Some error occured, not finalized.")
        sys.exit(1)

    rnn.finalize()
def analyze_dataset(options):
  """
  Step through the batches generated for the module-level ``dataset``
  (batching options are taken from the module-level ``config``)
  and print statistics about the number of seqs and (used) frames per batch.

  Side effect: a negative ``options.endseq`` is replaced in-place by ``float("inf")``.

  :param options: argparse.Namespace. ``epoch``, ``key`` and ``endseq`` are read.
  """
  print("Epoch: %i" % options.epoch, file=log.v3)
  print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
  print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
  assert options.key in dataset.get_data_keys()

  term_width, _ = Util.terminal_size()
  # Only show the progress bar when the per-batch log lines (v5) are not shown anyway.
  show_interactive_process_bar = log.verbose[3] and not log.verbose[5] and term_width >= 0

  start_time = time.time()
  num_seqs_stats = Stats()
  if options.endseq < 0:
    options.endseq = float("inf")

  used_data_keys = dataset.get_data_keys()
  batches = dataset.generate_batches(
    recurrent_net=True,
    batch_size=config.typed_value('batch_size', 1),
    max_seqs=config.int('max_seqs', -1),
    max_seq_length=config.typed_value('max_seq_length', None) or config.float('max_seq_length', 0),
    max_pad_size=config.typed_value("max_pad_size", None),
    seq_drop=config.float('seq_drop', 0.0),
    used_data_keys=used_data_keys)

  def _num_seqs_repr():
    """
    :return: exact num seqs if known, otherwise "~" + estimate, otherwise "?"
    :rtype: str
    """
    try:
      return str(dataset.num_seqs)
    except NotImplementedError:
      pass
    try:
      return "~%i" % dataset.estimated_num_seqs
    except TypeError:  # a number is required, not NoneType
      return "?"

  step = 0
  total_num_seqs = 0
  total_num_frames = NumbersDict()
  total_num_used_frames = NumbersDict()

  try:
    while batches.has_more():
      # See FeedDictDataProvider.
      [batch] = batches.peek_next_n(1)
      assert isinstance(batch, Batch)
      if batch.start_seq > options.endseq:
        break
      dataset.load_seqs(batch.start_seq, batch.end_seq)
      complete_frac = batches.completed_frac()
      elapsed = time.time() - start_time
      progress_prefix = "%i/%s" % (batch.start_seq, _num_seqs_repr())
      progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
      if complete_frac > 0:
        # Linear extrapolation of the remaining time.
        remaining = elapsed / complete_frac - elapsed
        progress += " (%s)" % hms(remaining)

      num_seqs_in_batch = len(batch.seqs)
      seq_frame_lens = [seq.frame_length for seq in batch.seqs]
      # Frames of the padded batch: longest seq times number of seqs.
      batch_max_time = NumbersDict.max(seq_frame_lens) * num_seqs_in_batch
      batch_num_used_frames = sum(seq_frame_lens, NumbersDict())
      total_num_seqs += num_seqs_in_batch
      num_seqs_stats.collect(numpy.array([num_seqs_in_batch]))
      total_num_frames += batch_max_time
      total_num_used_frames += batch_num_used_frames

      batch_info = (
        progress, step, num_seqs_in_batch,
        batch_max_time, batch_num_used_frames, batch_num_used_frames / batch_max_time)
      print("%s, batch %i, num seqs %i, frames %s, used %s (%s)" % batch_info, file=log.v5)
      if show_interactive_process_bar:
        Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)

      step += 1
      batches.advance(1)

  finally:
    # The summary is printed even when the loop was interrupted by an exception.
    total_time = hms(time.time() - start_time)
    print("Done. Total time %s. More seqs which we did not dumped: %s" % (
      total_time, batches.has_more()), file=log.v2)
    print("Dataset epoch %i, order %r." % (dataset.epoch, dataset.seq_ordering))
    print("Num batches (steps): %i" % step, file=log.v1)
    print("Num seqs: %i" % total_num_seqs, file=log.v1)
    num_seqs_stats.dump(stream=log.v1, stream_prefix="Batch num seqs ")
    for key in used_data_keys:
      print("Data key %r:" % key, file=log.v1)
      print("  Num frames: %s" % total_num_frames[key], file=log.v1)
      print("  Num used frames: %s" % total_num_used_frames[key], file=log.v1)
      print("  Fraction used frames: %s" % (total_num_used_frames / total_num_frames)[key], file=log.v1)
    dataset.finish_epoch()
# Example no. 5
# 0
def main(argv):
    """
    Forward a dataset through the network and dump the outputs of the selected
    sub layers of the rec layer (e.g. attention weights).

    Depending on ``--output_format``, this writes one ``.npy`` file per batch or
    one ``.png`` plot per sequence and layer into ``--dump_dir``,
    or a single HDF file to ``--output_file``.
    ``config_file`` may also be a model dir, in which case the last
    (in sorted order) ``train-*/*.config`` inside it is used.

    Exits the process with status 1 if the runner did not finalize.

    :param list[str] argv: command line, including the program name at index 0 (which is skipped)
    """
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument("config_file",
                           type=str,
                           help="RETURNN config, or model-dir")
    argparser.add_argument("--epoch", type=int)
    argparser.add_argument(
        '--data',
        default="train",
        help=
        "e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'"
    )
    argparser.add_argument('--do_search', default=False, action='store_true')
    argparser.add_argument('--beam_size', default=12, type=int)
    argparser.add_argument('--dump_dir', help="for npy or png")
    argparser.add_argument("--output_file", help="hdf")
    argparser.add_argument("--device", help="gpu or cpu (default: automatic)")
    argparser.add_argument("--layers",
                           default=[],
                           action="append",
                           help="Layer of subnet to grab")
    argparser.add_argument("--rec_layer",
                           default="output",
                           help="Subnet layer to grab from; decoder")
    argparser.add_argument("--enc_layer", default="encoder")
    argparser.add_argument("--batch_size", type=int, default=5000)
    argparser.add_argument("--seq_list",
                           default=[],
                           action="append",
                           help="predefined list of seqs")
    argparser.add_argument("--min_seq_len",
                           default="0",
                           help="can also be dict")
    argparser.add_argument("--num_seqs",
                           default=-1,
                           type=int,
                           help="stop after this many seqs")
    argparser.add_argument("--output_format",
                           default="npy",
                           help="npy, png or hdf")
    argparser.add_argument("--dropout",
                           default=None,
                           type=float,
                           help="if set, overwrites all dropout values")
    argparser.add_argument("--train_flag", action="store_true")
    argparser.add_argument("--reset_partition_epoch", type=int, default=None)
    argparser.add_argument("--reset_seq_ordering", default="default")
    argparser.add_argument("--reset_epoch_wise_filter", default=None)
    argparser.add_argument('--hmm_fac_fo', default=False, action='store_true')
    argparser.add_argument('--encoder_sa', default=False, action='store_true')
    argparser.add_argument('--tf_log_dir', help="for npy or png", default=None)
    argparser.add_argument("--instead_save_encoder_decoder",
                           action="store_true")
    args = argparser.parse_args(argv[1:])

    layers = args.layers
    assert isinstance(layers, list)
    config_fn = args.config_file
    explicit_model_dir = None
    if os.path.isdir(config_fn):
        # Assume we gave a model dir.
        explicit_model_dir = config_fn
        train_log_dir_config_pattern = "%s/train-*/*.config" % config_fn
        train_log_dir_configs = sorted(glob(train_log_dir_config_pattern))
        assert train_log_dir_configs
        config_fn = train_log_dir_configs[-1]
        print("Using this config via model dir:", config_fn)
    else:
        assert os.path.isfile(config_fn)
    # Config file basename without its last extension; used in the dump file names.
    model_name = ".".join(config_fn.split("/")[-1].split(".")[:-1])

    init_returnn(config_fn=config_fn, args=args)
    if explicit_model_dir:
        # Redirect the model path into the given model dir, keeping the file name from the config.
        config.set(
            "model", "%s/%s" %
            (explicit_model_dir, os.path.basename(config.value('model', ''))))
    print("Model file prefix:", config.value('model', ''))

    # NOTE(review): eval() of a command line string executes arbitrary code.
    # Acceptable only as long as the command line is trusted.
    min_seq_length = NumbersDict(eval(args.min_seq_len))

    assert args.output_format in ["npy", "png", "hdf"]
    if args.output_format in ["npy", "png"]:
        assert args.dump_dir
        if not os.path.exists(args.dump_dir):
            os.makedirs(args.dump_dir)
    plt = ticker = None
    if args.output_format == "png":
        import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
        import matplotlib.ticker as ticker

    dataset_str = args.data
    if dataset_str in ["train", "dev", "eval"]:
        dataset_str = "config:%s" % dataset_str
    extra_dataset_kwargs = {}
    if args.reset_partition_epoch:
        print("NOTE: We are resetting partition epoch to %i." %
              (args.reset_partition_epoch, ))
        extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch
    if args.reset_seq_ordering:
        print("NOTE: We will use %r seq ordering." %
              (args.reset_seq_ordering, ))
        extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering
    if args.reset_epoch_wise_filter:
        # NOTE(review): eval() of a command line string, same caveat as for --min_seq_len.
        extra_dataset_kwargs["epoch_wise_filter"] = eval(
            args.reset_epoch_wise_filter)
    dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs)
    if hasattr(dataset,
               "epoch_wise_filter") and args.reset_epoch_wise_filter is None:
        if dataset.epoch_wise_filter:
            print("NOTE: Resetting epoch_wise_filter to None.")
            dataset.epoch_wise_filter = None
    # Check that the overrides were really applied by the dataset.
    if args.reset_partition_epoch:
        assert dataset.partition_epoch == args.reset_partition_epoch
    if args.reset_seq_ordering:
        assert dataset.seq_ordering == args.reset_seq_ordering

    init_net(args, layers)
    network = rnn.engine.network

    hdf_writer = None
    if args.output_format == "hdf":
        assert args.output_file
        assert len(layers) == 1
        sub_layer = network.get_layer("%s/%s" % (args.rec_layer, layers[0]))
        from HDFDataset import SimpleHDFWriter
        hdf_writer = SimpleHDFWriter(filename=args.output_file,
                                     dim=sub_layer.output.dim,
                                     ndim=sub_layer.output.ndim)

    # Everything in here is evaluated per batch and passed to fetch_callback by keyword.
    extra_fetches = {
        "output":
        network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
        "output_len":
        network.layers[
            args.rec_layer].output.get_sequence_lengths(),  # decoder length
        "encoder_len":
        network.layers[
            args.enc_layer].output.get_sequence_lengths(),  # encoder length
        "seq_idx":
        network.get_extern_data("seq_idx"),
        "seq_tag":
        network.get_extern_data("seq_tag"),
        "target_data":
        network.get_extern_data(network.extern_data.default_input),
        "target_classes":
        network.get_extern_data(network.extern_data.default_target),
    }

    if args.instead_save_encoder_decoder:
        extra_fetches["encoder"] = network.layers["encoder"]
        extra_fetches["decoder"] = network.layers["output"].get_sub_layer(
            "decoder")
    else:
        for l in layers:
            sub_layer = rnn.engine.network.get_layer("%s/%s" %
                                                     (args.rec_layer, l))
            extra_fetches[
                "rec_%s" %
                l] = sub_layer.output.get_placeholder_as_batch_major()
            if args.do_search:
                o_layer = rnn.engine.network.get_layer("output")
                extra_fetches["beam_scores_" +
                              l] = o_layer.get_search_choices().beam_scores

    dataset.init_seq_order(
        epoch=1, seq_list=args.seq_list
        or None)  # use always epoch 1, such that we have same seqs
    dataset_batch = dataset.generate_batches(
        recurrent_net=network.recurrent,
        batch_size=args.batch_size,
        max_seqs=rnn.engine.max_seqs,
        max_seq_length=sys.maxsize,
        min_seq_length=min_seq_length,
        max_total_num_seqs=args.num_seqs,
        used_data_keys=network.used_data_keys)

    stats = {l: Stats() for l in layers}

    # (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str])->None
    def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output,
                       output_len, encoder_len, **kwargs):
        """
        Called with the fetched values of one batch; collects statistics and writes the dump.

        :param list[int] seq_idx: len is n_batch
        :param list[str] seq_tag: len is n_batch
        :param numpy.ndarray target_data: extern data default input (e.g. "data"), shape e.g. (B,enc-T,...)
        :param numpy.ndarray target_classes: extern data default target (e.g. "classes"), shape e.g. (B,dec-T,...)
        :param numpy.ndarray output: rec layer output, shape e.g. (B,dec-T,...)
        :param numpy.ndarray output_len: rec layer seq len, i.e. decoder length, shape (B,)
        :param numpy.ndarray encoder_len: encoder seq len, shape (B,)
        :param kwargs: contains "rec_%s" % l for l in layers, the sub layers (e.g att weights) we are interested in
        """
        n_batch = len(seq_idx)

        for i in range(n_batch):
            if not args.instead_save_encoder_decoder:
                for l in layers:
                    att_weights = kwargs["rec_%s" % l][i]
                    stats[l].collect(att_weights.flatten())
        if args.output_format == "npy":
            data = {}
            if args.do_search:
                assert not args.instead_save_encoder_decoder, "Not implemented"

                # find axis with correct beam size
                axis_beam_size = n_batch * args.beam_size
                corr_axis = None
                first_rec = kwargs["rec_%s" % layers[0]]
                num_axes = len(first_rec.shape)
                for a in range(num_axes):
                    if first_rec.shape[a] == axis_beam_size:
                        corr_axis = a
                        break
                assert corr_axis is not None, "Att Weights Decoding: correct axis not found! maybe check beam size."

                # set dimensions correctly: move the batch*beam axis to the front.
                # The permutation is the same for all layers.
                swap = list(range(num_axes))
                del swap[corr_axis]
                swap.insert(0, corr_axis)
                rec_keys = [("rec_%s" % l) for l in layers]
                for rec_key in rec_keys:
                    kwargs[rec_key] = np.transpose(kwargs[rec_key], swap)

                for i in range(n_batch):
                    # The first beam contains the score with the highest beam
                    i_beam = args.beam_size * i
                    data[i] = {
                        'tag': seq_tag[i],
                        'data': target_data[i],
                        'classes': target_classes[i],
                        'output': output[i_beam],
                        'output_len': output_len[i_beam],
                        'encoder_len': encoder_len[i],
                    }

                    #if args.hmm_fac_fo is False:
                    for rec_key in rec_keys:
                        assert rec_key in kwargs
                        out = kwargs[rec_key][i_beam]
                        # Do search for multihead
                        # out is [I, H, 1, J] is new version
                        # out is [I, J, H, 1] for old version
                        # NOTE(review): the condition checks for 3 axes but the transpose
                        # passes 4 axes, which numpy rejects for a 3-D array.
                        # Verify the intended rank before relying on this branch.
                        if len(out.shape) == 3 and min(out.shape) > 1:
                            out = np.transpose(
                                out,
                                axes=(0, 3, 1, 2))  # [I, J, H, 1] new version
                            out = np.squeeze(out, axis=-1)
                        else:
                            out = np.squeeze(out, axis=1)
                        assert out.shape[0] >= output_len[
                            i_beam] and out.shape[1] >= encoder_len[i]
                        data[i][rec_key] = out[:output_len[i_beam], :encoder_len[i]]
            else:
                for i in range(n_batch):
                    data[i] = {
                        'tag': seq_tag[i],
                        'data': target_data[i],
                        'classes': target_classes[i],
                        'output': output[i],
                        'output_len': output_len[i],
                        'encoder_len': encoder_len[i],
                    }
                    #if args.hmm_fac_fo is False:

                    if args.instead_save_encoder_decoder:
                        out = kwargs["encoder"][i]
                        data[i]["encoder"] = out[:encoder_len[i]]
                        out_2 = kwargs["decoder"][i]
                        data[i]["decoder"] = out_2[:output_len[i]]
                    else:
                        for rec_key in [("rec_%s" % l) for l in layers]:
                            assert rec_key in kwargs
                            out = kwargs[rec_key][i]  # []
                            # multi-head attention
                            if len(out.shape) == 3 and min(out.shape) > 1:
                                # Multihead attention
                                out = np.transpose(
                                    out,
                                    axes=(1, 2, 0))  # (I, J, H) new version
                                #out = np.transpose(out, axes=(2, 0, 1))  # [I, J, H] old version
                            else:
                                # RNN
                                out = np.squeeze(out, axis=1)
                            assert out.ndim >= 2
                            assert out.shape[0] >= output_len[i] and out.shape[
                                1] >= encoder_len[i]
                            # Cut off the padding in both the decoder and the encoder axis.
                            data[i][rec_key] = out[:output_len[i], :encoder_len[i]]
            if data:
                # One file per batch. Written once after all seqs are collected,
                # instead of rewriting the same file after every single seq.
                fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (
                    model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
                np.save(fname, data)
        elif args.output_format == "png":
            # The file name postfix depends only on the command line options.
            extra_postfix = ""
            if args.dropout is not None:
                extra_postfix += "_dropout%.2f" % args.dropout
            elif args.train_flag:
                extra_postfix += "_train"
            for i in range(n_batch):
                for l in layers:
                    fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
                        model_name, rnn.engine.epoch, seq_idx[i], l,
                        extra_postfix)
                    att_weights = kwargs["rec_%s" % l][i]
                    att_weights = att_weights.squeeze(axis=2)  # (out,enc)
                    assert att_weights.shape[0] >= output_len[
                        i] and att_weights.shape[1] >= encoder_len[i]
                    att_weights = att_weights[:output_len[i], :encoder_len[i]]
                    print("Seq %i, %s: Dump att weights with shape %r to: %s" %
                          (seq_idx[i], seq_tag[i], att_weights.shape, fname))
                    plt.matshow(att_weights)
                    title = seq_tag[i]
                    if dataset.can_serialize_data(
                            network.extern_data.default_target):
                        title += "\n" + dataset.serialize_data(
                            network.extern_data.default_target,
                            target_classes[i][:output_len[i]])
                        ax = plt.gca()
                        # One y tick label per target label.
                        tick_labels = [
                            dataset.serialize_data(
                                network.extern_data.default_target,
                                np.array([x], dtype=target_classes[i].dtype))
                            for x in target_classes[i][:output_len[i]]
                        ]
                        ax.set_yticklabels([''] + tick_labels, fontsize=8)
                        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
                    plt.title(title)
                    plt.savefig(fname)
                    plt.close()
        elif args.output_format == "hdf":
            assert len(layers) == 1
            att_weights = kwargs["rec_%s" % layers[0]]
            hdf_writer.insert_batch(inputs=att_weights,
                                    seq_len={
                                        0: output_len,
                                        1: encoder_len
                                    },
                                    seq_tag=seq_tag)
        else:
            raise Exception("output format %r" % args.output_format)

    runner = Runner(engine=rnn.engine,
                    dataset=dataset,
                    batches=dataset_batch,
                    train=False,
                    train_flag=bool(args.dropout) or args.train_flag,
                    extra_fetches=extra_fetches,
                    extra_fetches_callback=fetch_callback,
                    eval=False)
    runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)

    if not args.instead_save_encoder_decoder:
        for l in layers:
            stats[l].dump(stream_prefix="Layer %r " % l)
    if not runner.finalized:
        # NOTE(review): on this path the HDF writer (if any) is not closed explicitly.
        print("Some error occured, not finalized.")
        sys.exit(1)

    if hdf_writer:
        hdf_writer.close()
    rnn.finalize()
# Example no. 6
# 0
def dump_dataset(dataset, options):
  """
  Iterates over the sequences of the dataset and dumps them in the way selected by ``options.type``.

  If ``options.get_num_seqs`` is set, only the number of sequences is reported
  (counted by iterating over the dataset if ``dataset.num_seqs`` raises), and nothing is dumped.
  Otherwise the sequences from ``options.startseq`` up to and including ``options.endseq``
  (negative means no upper limit) are visited. At the end, sequence length statistics for all
  data keys are dumped, and, if ``options.stats`` or ``options.dump_stats`` is set,
  also statistics about the data of ``options.key``.

  Note that ``options.endseq`` is overwritten with ``float("inf")`` if it is negative.

  :type dataset: Dataset.Dataset
  :param options: argparse.Namespace
  :raises Exception: if ``options.type`` is not one of the known dump types
  """
  print("Epoch: %i" % options.epoch, file=log.v3)
  dataset.init_seq_order(epoch=options.epoch)
  print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
  print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
  assert options.key in dataset.get_data_keys()

  if options.get_num_seqs:
    print("Get num seqs.")
    print("estimated_num_seqs: %r" % dataset.estimated_num_seqs)
    try:
      print("num_seqs: %r" % dataset.num_seqs)
    except Exception as exc:
      # The dataset cannot tell its size in advance, so count by iterating over all seqs.
      print("num_seqs exception %r, which is valid, so we count." % exc)
      seq_idx = 0
      if dataset.get_target_list():
        default_target = dataset.get_target_list()[0]
      else:
        default_target = None
      while dataset.is_less_than_num_seqs(seq_idx):
        dataset.load_seqs(seq_idx, seq_idx + 1)
        if seq_idx % 10000 == 0:
          # Periodic sign of life, with a short preview of the targets if there are any.
          if default_target:
            targets = dataset.get_targets(default_target, seq_idx)
            postfix = " (targets = %r...)" % (targets[:10],)
          else:
            postfix = ""
          print("%i ...%s" % (seq_idx, postfix))
        seq_idx += 1
      print("accumulated num seqs: %i" % seq_idx)
    print("Done.")
    return

  dump_file = None  # only used by the "dump_tag" and "dump_seq_len" types
  try:
    if options.type == "numpy":
      print("Dump files: %r*%r" % (options.dump_prefix, options.dump_postfix), file=log.v3)
    elif options.type == "stdout":
      print("Dump to stdout", file=log.v3)
      if options.stdout_limit is not None:
        Util.set_pretty_print_default_limit(options.stdout_limit)
        numpy.set_printoptions(
          threshold=sys.maxsize if options.stdout_limit == float("inf") else int(options.stdout_limit))
      if options.stdout_as_bytes:
        Util.set_pretty_print_as_bytes(options.stdout_as_bytes)
    elif options.type == "print_tag":
      print("Dump seq tag to stdout", file=log.v3)
    elif options.type == "dump_tag":
      dump_file = open("%sseq-tags.txt" % options.dump_prefix, "w")
      print("Dump seq tag to file: %s" % (dump_file.name,), file=log.v3)
    elif options.type == "dump_seq_len":
      dump_file = open("%sseq-lens.txt" % options.dump_prefix, "w")
      print("Dump seq lens to file: %s" % (dump_file.name,), file=log.v3)
      # The file content is a dict literal "seq tag: seq len", closed after the loop.
      dump_file.write("{\n")
    elif options.type == "print_shape":
      print("Dump shape to stdout", file=log.v3)
    elif options.type == "plot":
      print("Plot.", file=log.v3)
    elif options.type == "interactive":
      print("Interactive debug shell.", file=log.v3)
    elif options.type == "null":
      print("No dump.")
    else:
      raise Exception("unknown dump option type %r" % options.type)

    start_time = time.time()
    stats = Stats() if (options.stats or options.dump_stats) else None
    seq_len_stats = {key: Stats() for key in dataset.get_data_keys()}
    seq_idx = options.startseq
    if options.endseq < 0:
      options.endseq = float("inf")
    while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
      dataset.load_seqs(seq_idx, seq_idx + 1)
      complete_frac = dataset.get_complete_frac(seq_idx)
      start_elapsed = time.time() - start_time
      try:
        num_seqs_s = str(dataset.num_seqs)
      except NotImplementedError:
        try:
          num_seqs_s = "~%i" % dataset.estimated_num_seqs
        except TypeError:  # a number is required, not NoneType
          num_seqs_s = "?"
      progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
      progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
      if complete_frac > 0:
        # Linear extrapolation of the remaining time from the fraction done so far.
        total_time_estimated = start_elapsed / complete_frac
        remaining_estimated = total_time_estimated - start_elapsed
        progress += " (%s)" % hms(remaining_estimated)
      data = None  # the tag/length-only dump types below do not load the data
      if options.type == "print_tag":
        print("seq %s tag:" % (progress if log.verbose[2] else progress_prefix), dataset.get_tag(seq_idx))
      elif options.type == "dump_tag":
        seq_tag = dataset.get_tag(seq_idx)
        print("seq %s tag:" % (progress if log.verbose[2] else progress_prefix), seq_tag)
        dump_file.write("%s\n" % seq_tag)
      elif options.type == "dump_seq_len":
        seq_tag = dataset.get_tag(seq_idx)
        seq_len = dataset.get_seq_length(seq_idx)[options.key]
        print(
          "seq %s tag:" % (progress if log.verbose[2] else progress_prefix),
          seq_tag, "%r len:" % options.key, seq_len)
        dump_file.write("%r: %r,\n" % (seq_tag, seq_len))
      else:
        data = dataset.get_data(seq_idx, options.key)
        if options.type == "numpy":
          numpy.savetxt("%s%i.data%s" % (options.dump_prefix, seq_idx, options.dump_postfix), data)
        elif options.type == "stdout":
          print("seq %s tag:" % progress, dataset.get_tag(seq_idx))
          print("seq %s data:" % progress, pretty_print(data))
        elif options.type == "print_shape":
          print("seq %s data shape:" % progress, data.shape)
        elif options.type == "plot":
          plot(data)
        for target in dataset.get_target_list():
          targets = dataset.get_targets(target, seq_idx)
          if options.type == "numpy":
            numpy.savetxt(
              "%s%i.targets.%s%s" % (options.dump_prefix, seq_idx, target, options.dump_postfix),
              targets, fmt='%i')
          elif options.type == "stdout":
            extra = ""
            if target in dataset.labels and len(dataset.labels[target]) > 1:
              assert dataset.can_serialize_data(target)
              extra += " (%r)" % dataset.serialize_data(key=target, data=targets)
            print("seq %i target %r: %s%s" % (seq_idx, target, pretty_print(targets), extra))
          elif options.type == "print_shape":
            print("seq %i target %r shape:" % (seq_idx, target), targets.shape)
        if options.type == "interactive":
          from Debug import debug_shell
          debug_shell(locals())
      seq_len = dataset.get_seq_length(seq_idx)
      for key in dataset.get_data_keys():
        seq_len_stats[key].collect([seq_len[key]])
      if stats:
        if data is None:
          # Stats were requested together with a tag/length-only dump type.
          # Load the data now; using the unset `data` here was a NameError before.
          data = dataset.get_data(seq_idx, options.key)
        stats.collect(data)
      if options.type == "null":
        Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)

      seq_idx += 1

    print("Done. Total time %s. More seqs which we did not dumped: %s" % (
      hms_fraction(time.time() - start_time), dataset.is_less_than_num_seqs(seq_idx)), file=log.v2)
    for key in dataset.get_data_keys():
      seq_len_stats[key].dump(stream_prefix="Seq-length %r " % key, stream=log.v2)
    if stats:
      stats.dump(output_file_prefix=options.dump_stats, stream_prefix="Data %r " % options.key, stream=log.v1)
    if options.type == "dump_seq_len":
      dump_file.write("}\n")
    if dump_file is not None:
      print("Dumped to file:", dump_file.name, file=log.v2)
  finally:
    # Close the dump file also when anything above raised, to not leak the file handle.
    if dump_file is not None:
      dump_file.close()
Esempio n. 7
0
def calc_wer_on_dataset(dataset, refs, options, hyps):
    """
    Computes the word error rate (WER) of the hypotheses against the references.

    The references are taken either from ``dataset`` (data key ``options.key``) or,
    if no dataset is given, from ``refs``, which are then visited sorted by the length
    of the reference string.
    Sequences with index ``options.startseq`` up to and including ``options.endseq``
    (negative means no upper limit) are visited.
    The WER itself is computed by the module-level ``wer_compute`` using ``session``;
    the value returned by its last ``step`` call is the result.

    Note that ``options.endseq`` is overwritten with ``float("inf")`` if it is negative.

    :param Dataset|None dataset:
    :param dict[str,str]|None refs: seq tag -> ref string (words delimited by space)
    :param options: argparse.Namespace
    :param dict[str,str] hyps: seq tag -> hyp string (words delimited by space)
    :return: WER. This is the initial value 1.0 if no sequence was visited.
    :rtype: float
    :raises KeyError: if a visited seq tag has no (remaining) entry in ``hyps``
    """
    assert dataset or refs
    start_time = time.time()
    seq_len_stats = {"refs": Stats(), "hyps": Stats()}
    seq_idx = options.startseq
    if options.endseq < 0:
        options.endseq = float("inf")
    wer = 1.0
    remaining_hyp_seq_tags = set(hyps.keys())
    interactive = Util.is_tty() and not log.verbose[5]
    collected = {"hyps": [], "refs": []}
    max_num_collected = 1  # number of seqs to collect before each WER computation step
    if dataset:
        dataset.init_seq_order(epoch=1)
    else:
        refs = sorted(refs.items(), key=lambda item: len(item[1]))
    while True:
        if seq_idx > options.endseq:
            break
        if dataset:
            if not dataset.is_less_than_num_seqs(seq_idx):
                break
            dataset.load_seqs(seq_idx, seq_idx + 1)
            complete_frac = dataset.get_complete_frac(seq_idx)
            seq_tag = dataset.get_tag(seq_idx)
            assert isinstance(seq_tag, str)
            ref = dataset.get_data(seq_idx, options.key)
            if isinstance(ref, numpy.ndarray):
                # Same handling as in get_raw_strings:
                # either a scalar array holding the str/bytes object, or the raw bytes as uint8 vector.
                assert ref.shape == () or (ref.ndim == 1 and ref.dtype == numpy.uint8)
                if ref.shape == ():
                    ref = ref.flatten()[0]  # get the entry itself (str or bytes)
                else:
                    ref = ref.tobytes()
            if isinstance(ref, bytes):
                ref = ref.decode("utf8")
            assert isinstance(ref, str)
            try:
                num_seqs_s = str(dataset.num_seqs)
            except NotImplementedError:
                try:
                    num_seqs_s = "~%i" % dataset.estimated_num_seqs
                except TypeError:  # a number is required, not NoneType
                    num_seqs_s = "?"
        else:
            if seq_idx >= len(refs):
                break
            complete_frac = (seq_idx + 1) / float(len(refs))
            seq_tag, ref = refs[seq_idx]
            assert isinstance(seq_tag, str)
            assert isinstance(ref, str)
            num_seqs_s = str(len(refs))

        # Shows the WER as of the previous step. (No time estimate is built here; it was never printed.)
        progress_prefix = "%i/%s (WER %.02f%%)" % (seq_idx, num_seqs_s, wer * 100)

        # Each hypothesis is consumed exactly once, to be able to report the leftover ones at the end.
        remaining_hyp_seq_tags.remove(seq_tag)
        hyp = hyps[seq_tag]
        seq_len_stats["hyps"].collect([len(hyp)])
        seq_len_stats["refs"].collect([len(ref)])
        collected["hyps"].append(hyp)
        collected["refs"].append(ref)

        if len(collected["hyps"]) >= max_num_collected:
            wer = wer_compute.step(session, **collected)
            # Clear in place, such that the same list objects are reused for the next step.
            del collected["hyps"][:]
            del collected["refs"][:]

        if interactive:
            Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
        elif log.verbose[5]:
            print(
                progress_prefix, "seq tag %r, ref/hyp len %i/%i chars" %
                (seq_tag, len(ref), len(hyp)))
        seq_idx += 1
    if collected["hyps"]:
        # Flush whatever was collected but not yet passed to the WER computation.
        wer = wer_compute.step(session, **collected)
    print("Done. Num seqs %i. Total time %s." %
          (seq_idx, hms(time.time() - start_time)),
          file=log.v1)
    print("Remaining num hyp seqs %i." % (len(remaining_hyp_seq_tags), ),
          file=log.v1)
    if dataset:
        print("More seqs which we did not dumped: %s." %
              dataset.is_less_than_num_seqs(seq_idx),
              file=log.v1)
    for key in ["hyps", "refs"]:
        seq_len_stats[key].dump(stream_prefix="Seq-length %r %r " %
                                (key, options.key),
                                stream=log.v2)
    if options.expect_full:
        assert not remaining_hyp_seq_tags, "There are still remaining hypotheses."
    return wer
Esempio n. 8
0
def dump_dataset(dataset, options):
    """
    Walks over the sequences of the dataset and dumps the data of ``options.key``
    and all targets in the way selected by ``options.type``
    ("numpy", "stdout", "print_shape", "plot" or "null").

    With ``options.get_num_seqs`` set, only the number of sequences is reported instead.
    A negative ``options.endseq`` is replaced by ``float("inf")`` on ``options`` itself.

    :type dataset: Dataset.Dataset
    :param options: argparse.Namespace
    """
    print("Epoch: %i" % options.epoch, file=log.v3)
    dataset.init_seq_order(epoch=options.epoch)
    print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
    print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
    assert options.key in dataset.get_data_keys()

    if options.get_num_seqs:
        print("Get num seqs.")
        print("estimated_num_seqs: %r" % dataset.estimated_num_seqs)
        try:
            print("num_seqs: %r" % dataset.num_seqs)
        except Exception as exc:
            # Size not known in advance: step through all seqs and count them.
            print("num_seqs exception %r, which is valid, so we count." % exc)
            first_target = dataset.get_target_list()[0] if dataset.get_target_list() else None
            count = 0
            while dataset.is_less_than_num_seqs(count):
                dataset.load_seqs(count, count + 1)
                if count % 10000 == 0:
                    preview = ""
                    if first_target:
                        targets = dataset.get_targets(first_target, count)
                        preview = " (targets = %r...)" % (targets[:10], )
                    print("%i ...%s" % (count, preview))
                count += 1
            print("accumulated num seqs: %i" % count)
        print("Done.")
        return

    dump_type = options.type
    if dump_type not in ("numpy", "stdout", "print_shape", "plot", "null"):
        raise Exception("unknown dump option type %r" % options.type)
    if dump_type == "null":
        print("No dump.")
    elif dump_type == "plot":
        print("Plot.", file=log.v3)
    elif dump_type == "print_shape":
        print("Dump shape to stdout", file=log.v3)
    elif dump_type == "stdout":
        print("Dump to stdout", file=log.v3)
    else:  # "numpy"
        print("Dump files: %r*%r" % (options.dump_prefix, options.dump_postfix), file=log.v3)

    t_begin = time.time()
    data_stats = None
    if options.stats or options.dump_stats:
        data_stats = Stats()
    len_stats = dict((key, Stats()) for key in dataset.get_data_keys())
    seq_idx = options.startseq
    if options.endseq < 0:
        options.endseq = float("inf")
    while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
        dataset.load_seqs(seq_idx, seq_idx + 1)
        complete_frac = dataset.get_complete_frac(seq_idx)
        elapsed = time.time() - t_begin
        try:
            total_s = str(dataset.num_seqs)
        except NotImplementedError:
            try:
                total_s = "~%i" % dataset.estimated_num_seqs
            except TypeError:  # a number is required, not NoneType
                total_s = "?"
        progress_prefix = "%i/%s" % (seq_idx, total_s)
        progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
        if complete_frac > 0:
            # Remaining time, linearly extrapolated from the fraction done so far.
            progress += " (%s)" % hms(elapsed / complete_frac - elapsed)

        data = dataset.get_data(seq_idx, options.key)
        if dump_type == "numpy":
            data_fn = "%s%i.data%s" % (options.dump_prefix, seq_idx, options.dump_postfix)
            numpy.savetxt(data_fn, data)
        elif dump_type == "stdout":
            print("seq %s data:" % progress, pretty_print(data))
        elif dump_type == "print_shape":
            print("seq %s data shape:" % progress, data.shape)
        elif dump_type == "plot":
            plot(data)

        for target in dataset.get_target_list():
            targets = dataset.get_targets(target, seq_idx)
            if dump_type == "numpy":
                targets_fn = "%s%i.targets.%s%s" % (
                    options.dump_prefix, seq_idx, target, options.dump_postfix)
                numpy.savetxt(targets_fn, targets, fmt='%i')
            elif dump_type == "stdout":
                print("seq %i target %r:" % (seq_idx, target), pretty_print(targets))
            elif dump_type == "print_shape":
                print("seq %i target %r shape:" % (seq_idx, target), targets.shape)

        seq_len = dataset.get_seq_length(seq_idx)
        for key in dataset.get_data_keys():
            len_stats[key].collect([seq_len[key]])
        if data_stats:
            data_stats.collect(data)
        if dump_type == "null":
            Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)

        seq_idx += 1

    total_time_s = hms(time.time() - t_begin)
    more_seqs = dataset.is_less_than_num_seqs(seq_idx)
    print("Done. Total time %s. More seqs which we did not dumped: %s" % (total_time_s, more_seqs), file=log.v1)
    for key in dataset.get_data_keys():
        len_stats[key].dump(stream_prefix="Seq-length %r " % key, stream=log.v2)
    if data_stats:
        data_stats.dump(
            output_file_prefix=options.dump_stats,
            stream_prefix="Data %r " % options.key,
            stream=log.v2)
Esempio n. 9
0
def dump_dataset(dataset, options):
  """
  Iterates over the sequences of the dataset and dumps them in the way selected by ``options.type``
  ("numpy", "stdout", "print_tag", "print_shape", "plot" or "null").

  If ``options.get_num_seqs`` is set, only the number of sequences is reported
  (counted by iterating over the dataset if ``dataset.num_seqs`` raises), and nothing is dumped.
  Otherwise the sequences from ``options.startseq`` up to and including ``options.endseq``
  (negative means no upper limit) are visited. At the end, sequence length statistics for all
  data keys are dumped, and, if ``options.stats`` or ``options.dump_stats`` is set,
  also statistics about the data of ``options.key``.

  Note that ``options.endseq`` is overwritten with ``float("inf")`` if it is negative.

  :type dataset: Dataset.Dataset
  :param options: argparse.Namespace
  :raises Exception: if ``options.type`` is not one of the known dump types
  """
  print("Epoch: %i" % options.epoch, file=log.v3)
  dataset.init_seq_order(epoch=options.epoch)
  print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
  print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
  assert options.key in dataset.get_data_keys()

  if options.get_num_seqs:
    print("Get num seqs.")
    print("estimated_num_seqs: %r" % dataset.estimated_num_seqs)
    try:
      print("num_seqs: %r" % dataset.num_seqs)
    except Exception as exc:
      # The dataset cannot tell its size in advance, so count by iterating over all seqs.
      print("num_seqs exception %r, which is valid, so we count." % exc)
      seq_idx = 0
      if dataset.get_target_list():
        default_target = dataset.get_target_list()[0]
      else:
        default_target = None
      while dataset.is_less_than_num_seqs(seq_idx):
        dataset.load_seqs(seq_idx, seq_idx + 1)
        if seq_idx % 10000 == 0:
          # Periodic sign of life, with a short preview of the targets if there are any.
          if default_target:
            targets = dataset.get_targets(default_target, seq_idx)
            postfix = " (targets = %r...)" % (targets[:10],)
          else:
            postfix = ""
          print("%i ...%s" % (seq_idx, postfix))
        seq_idx += 1
      print("accumulated num seqs: %i" % seq_idx)
    print("Done.")
    return

  if options.type == "numpy":
    print("Dump files: %r*%r" % (options.dump_prefix, options.dump_postfix), file=log.v3)
  elif options.type == "stdout":
    print("Dump to stdout", file=log.v3)
    if options.stdout_limit is not None:
      Util.set_pretty_print_default_limit(options.stdout_limit)
      numpy.set_printoptions(
        threshold=sys.maxsize if options.stdout_limit == float("inf") else int(options.stdout_limit))
    if options.stdout_as_bytes:
      Util.set_pretty_print_as_bytes(options.stdout_as_bytes)
  elif options.type == "print_tag":
    print("Dump seq tag to stdout", file=log.v3)
  elif options.type == "print_shape":
    print("Dump shape to stdout", file=log.v3)
  elif options.type == "plot":
    print("Plot.", file=log.v3)
  elif options.type == "null":
    print("No dump.")
  else:
    raise Exception("unknown dump option type %r" % options.type)

  start_time = time.time()
  stats = Stats() if (options.stats or options.dump_stats) else None
  seq_len_stats = {key: Stats() for key in dataset.get_data_keys()}
  seq_idx = options.startseq
  if options.endseq < 0:
    options.endseq = float("inf")
  while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
    dataset.load_seqs(seq_idx, seq_idx + 1)
    complete_frac = dataset.get_complete_frac(seq_idx)
    start_elapsed = time.time() - start_time
    try:
      num_seqs_s = str(dataset.num_seqs)
    except NotImplementedError:
      try:
        num_seqs_s = "~%i" % dataset.estimated_num_seqs
      except TypeError:  # a number is required, not NoneType
        num_seqs_s = "?"
    progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
    progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
    if complete_frac > 0:
      # Linear extrapolation of the remaining time from the fraction done so far.
      total_time_estimated = start_elapsed / complete_frac
      remaining_estimated = total_time_estimated - start_elapsed
      progress += " (%s)" % hms(remaining_estimated)
    data = None  # the "print_tag" type does not load the data
    if options.type == "print_tag":
      print("seq %s tag:" % (progress if log.verbose[2] else progress_prefix), dataset.get_tag(seq_idx))
    else:
      data = dataset.get_data(seq_idx, options.key)
      if options.type == "numpy":
        numpy.savetxt("%s%i.data%s" % (options.dump_prefix, seq_idx, options.dump_postfix), data)
      elif options.type == "stdout":
        print("seq %s tag:" % progress, dataset.get_tag(seq_idx))
        print("seq %s data:" % progress, pretty_print(data))
      elif options.type == "print_shape":
        print("seq %s data shape:" % progress, data.shape)
      elif options.type == "plot":
        plot(data)
      for target in dataset.get_target_list():
        targets = dataset.get_targets(target, seq_idx)
        if options.type == "numpy":
          numpy.savetxt(
            "%s%i.targets.%s%s" % (options.dump_prefix, seq_idx, target, options.dump_postfix),
            targets, fmt='%i')
        elif options.type == "stdout":
          extra = ""
          if target in dataset.labels and len(dataset.labels[target]) > 1:
            labels = dataset.labels[target]
            # Small vocabularies of single-char labels are joined without a delimiter,
            # everything else is joined with spaces.
            if len(labels) < 1000 and all(len(label) == 1 for label in labels):
              join_str = ""
            else:
              join_str = " "
            extra += " (%r)" % join_str.join(map(labels.__getitem__, targets))
          print("seq %i target %r: %s%s" % (seq_idx, target, pretty_print(targets), extra))
        elif options.type == "print_shape":
          print("seq %i target %r shape:" % (seq_idx, target), targets.shape)
    seq_len = dataset.get_seq_length(seq_idx)
    for key in dataset.get_data_keys():
      seq_len_stats[key].collect([seq_len[key]])
    if stats:
      if data is None:
        # Stats were requested together with the "print_tag" type.
        # Load the data now; using the unset `data` here was a NameError before.
        data = dataset.get_data(seq_idx, options.key)
      stats.collect(data)
    if options.type == "null":
      Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)

    seq_idx += 1

  print("Done. Total time %s. More seqs which we did not dumped: %s" % (
    hms(time.time() - start_time), dataset.is_less_than_num_seqs(seq_idx)), file=log.v2)
  for key in dataset.get_data_keys():
    seq_len_stats[key].dump(stream_prefix="Seq-length %r " % key, stream=log.v2)
  if stats:
    stats.dump(output_file_prefix=options.dump_stats, stream_prefix="Data %r " % options.key, stream=log.v1)