Example #1
def main(argv):
  argparser = argparse.ArgumentParser(description='Dump raw strings from dataset. Same format as in search.')
  argparser.add_argument('--config', help="filename to config-file. will use dataset 'eval' from it")
  argparser.add_argument("--dataset", help="dataset, overwriting config")
  argparser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
  argparser.add_argument('--endseq', type=int, default=-1, help='end seq idx (inclusive) or -1 (default: -1)')
  argparser.add_argument("--key", default="raw", help="data-key, e.g. 'data' or 'classes'. (default: 'raw')")
  argparser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
  argparser.add_argument("--out", required=True, help="out-file. py-format as in task=search")
  args = argparser.parse_args(argv[1:])
  assert args.config or args.dataset

  init(config_filename=args.config, log_verbosity=args.verbosity)
  if args.dataset:
    dataset = init_dataset(args.dataset)
  elif config.value("dump_data", "eval") in ["train", "dev", "eval"]:
    dataset = init_dataset(config.opt_typed_value(config.value("search_data", "eval")))
  else:
    dataset = init_dataset(config.opt_typed_value("wer_data"))
  dataset.init_seq_order(epoch=1)

  try:
    with open(args.out, "w") as output_file:
      refs = get_raw_strings(dataset=dataset, options=args)
      output_file.write("{\n")
      for seq_tag, ref in refs:
        output_file.write("%r: %r,\n" % (seq_tag, ref))
      output_file.write("}\n")
    print("Done. Wrote to %r." % args.out)
  except KeyboardInterrupt:
    print("KeyboardInterrupt")
    sys.exit(1)
  finally:
    rnn.finalize()
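A hedged usage sketch: main() only looks at argv[1:], so within the tool's own module context it could be driven programmatically as below. The script name and file names are placeholders, not taken from the source.

main(["dump-dataset-raw-strings.py",
      "--config", "returnn.config",
      "--key", "raw",
      "--out", "refs.py"])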
Example #2
def load_data(config, cache_byte_size, files_config_key, **kwargs):
  """
  :param Config config:
  :param int cache_byte_size:
  :param str files_config_key: such as "train" or "dev"
  :param kwargs: passed on to init_dataset() or init_dataset_via_str()
  :rtype: (Dataset,int)
  :returns: the dataset, and the cache byte size left over if we cache the whole dataset.
  """
  if not config.bool_or_other(files_config_key, None):
    return None, 0
  kwargs = kwargs.copy()
  kwargs.setdefault("name", files_config_key)
  if config.is_typed(files_config_key) and isinstance(config.typed_value(files_config_key), dict):
    config_opts = config.typed_value(files_config_key)
    assert isinstance(config_opts, dict)
    kwargs.update(config_opts)
    if 'cache_byte_size' not in config_opts:
      if kwargs.get('class', None) == 'HDFDataset':
        kwargs["cache_byte_size"] = cache_byte_size
    Dataset.kwargs_update_from_config(config, kwargs)
    data = init_dataset(kwargs)
  else:
    config_str = config.value(files_config_key, "")
    data = init_dataset_via_str(config_str, config=config, cache_byte_size=cache_byte_size, **kwargs)
  cache_leftover = 0
  if isinstance(data, HDFDataset):
    cache_leftover = data.definite_cache_leftover
  return data, cache_leftover
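For illustration, a hedged sketch of a call to this helper; the Config contents and the cache size are assumptions, not taken from the source.

# Assumes a parsed RETURNN Config in which the "train" entry holds a dataset dict,
# e.g. train = {"class": "HDFDataset", "files": ["train.hdf"]} inside the config file.
train_data, cache_left = load_data(config, cache_byte_size=2 * 1024 ** 3, files_config_key="train")
if train_data is not None:
  print("Train dataset: %r, cache bytes left over: %i" % (train_data, cache_left))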
Example #3
def init(config_filename, cmd_line_opts, dataset_config_str):
  """
  :param str config_filename: global config for CRNN
  :param list[str] cmd_line_opts: options for init_config method
  :param str dataset_config_str: dataset via init_dataset_via_str()
  """
  rnn.init_better_exchook()
  rnn.init_thread_join_hack()
  if config_filename:
    rnn.init_config(config_filename, cmd_line_opts)
    rnn.init_log()
  else:
    log.initialize(verbosity=[5])
  print("Returnn hdf_dump starting up.", file=log.v3)
  rnn.init_faulthandler()
  if config_filename:
    rnn.init_data()
    rnn.print_task_properties()
    assert isinstance(rnn.train_data, Dataset)
    dataset = rnn.train_data
  else:
    assert dataset_config_str
    dataset = init_dataset(dataset_config_str)
  print("Source dataset:", dataset.len_info(), file=log.v3)
  return dataset
Example #4
def load_data(config, cache_byte_size, files_config_key, **kwargs):
  """
  :param Config config:
  :param int cache_byte_size:
  :param str files_config_key: such as "train" or "dev"
  :param kwargs: passed on to init_dataset() or init_dataset_via_str()
  :rtype: (Dataset,int)
  :returns: the dataset, and the cache byte size left over if we cache the whole dataset.
  """
  if not config.has(files_config_key):
    return None, 0
  if config.is_typed(files_config_key) and isinstance(config.typed_value(files_config_key), dict):
    new_kwargs = config.typed_value(files_config_key)
    assert isinstance(new_kwargs, dict)
    kwargs.update(new_kwargs)
    if 'cache_byte_size' not in new_kwargs:
      if kwargs.get('class', None) == 'HDFDataset':
        kwargs["cache_byte_size"] = cache_byte_size
    Dataset.kwargs_update_from_config(config, kwargs)
    data = init_dataset(kwargs)
  else:
    config_str = config.value(files_config_key, "")
    data = init_dataset_via_str(config_str, config=config, cache_byte_size=cache_byte_size, **kwargs)
  cache_leftover = 0
  if isinstance(data, HDFDataset):
    cache_leftover = data.definite_cache_leftover
  return data, cache_leftover
Example #5
  def __init__(self,
               seq_list_file, seq_lens_file,
               datasets,
               data_map, data_dims,
               data_dtypes=None,
               window=1, **kwargs):
    """
    :param str seq_list_file: filename. line-separated
    :param str seq_lens_file: filename. json. dict[str,dict[str,int]], seq-tag -> data-key -> len
    :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs. including keyword 'class' and maybe 'files'
    :param dict[str,(str,str)] data_map: self-data-key -> (dataset-key, dataset-data-key).
      Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
    :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
    :param dict[str,str] data_dtypes: self-data-key -> dtype. automatic if not specified
    """
    assert window == 1  # not implemented
    super(MetaDataset, self).__init__(**kwargs)
    assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

    self.seq_list_original = open(seq_list_file).read().splitlines()
    self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original)}
    self._num_seqs = len(self.seq_list_original)

    self.data_map = data_map
    self.dataset_keys = set([m[0] for m in self.data_map.values()]); ":type: set[str]"
    self.data_keys = set(self.data_map.keys()); ":type: set[str]"
    assert "data" in self.data_keys
    self.target_list = sorted(self.data_keys - {"data"})

    data_dims = convert_data_dims(data_dims)
    self.data_dims = data_dims
    assert "data" in data_dims
    for key in self.target_list:
      assert key in data_dims
    self.num_inputs = data_dims["data"][0]
    self.num_outputs = data_dims

    self.data_dtypes = {data_key: _select_dtype(data_key, data_dims, data_dtypes) for data_key in self.data_keys}

    if seq_lens_file:
      seq_lens = load_json(filename=seq_lens_file)
      assert isinstance(seq_lens, dict)
      # dict[str,NumbersDict], seq-tag -> data-key -> len
      self._seq_lens = {tag: NumbersDict(l) for (tag, l) in seq_lens.items()}
    else:
      self._seq_lens = None

    if self._seq_lens:
      self._num_timesteps = sum([self._seq_lens[s] for s in self.seq_list_original])
    else:
      self._num_timesteps = None

    # Will only init the needed datasets.
    self.datasets = {key: init_dataset(datasets[key]) for key in self.dataset_keys}
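The docstring formats above are easiest to see in a concrete kwargs dict. The following is a hedged sketch only; all file names, dataset keys and dimensions are invented.

from Dataset import init_dataset  # import as used elsewhere on this page

meta = init_dataset({
  "class": "MetaDataset",
  "seq_list_file": "seq-tags.txt",    # line-separated seq tags (placeholder file)
  "seq_lens_file": None,
  "datasets": {                       # dataset-key -> dataset-kwargs
    "feat": {"class": "HDFDataset", "files": ["features.hdf"]},
    "trans": {"class": "HDFDataset", "files": ["targets.hdf"]}},
  "data_map": {                       # self-data-key -> (dataset-key, dataset-data-key)
    "data": ("feat", "data"),
    "classes": ("trans", "data")},
  "data_dims": {"data": (40, 2), "classes": (1000, 1)},  # dim and len(shape); 1 => sparse
})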
Example #6
 def __init__(self, datasets, **kwargs):
   """
   :param list[dict[str]] datasets: list of kwargs for init_dataset
   """
   super(ConcatDataset, self).__init__(**kwargs)
   self.datasets = [init_dataset(d_kwargs) for d_kwargs in datasets]
   assert self.datasets
   self.num_inputs = self.datasets[0].num_inputs
   self.num_outputs = self.datasets[0].num_outputs
   self.labels = self.datasets[0].labels
   for ds in self.datasets[1:]:
     assert ds.num_inputs == self.num_inputs
     assert ds.num_outputs == self.num_outputs
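A hedged sketch of the corresponding init_dataset call (file names are placeholders); all sub-datasets must agree in num_inputs/num_outputs, as the asserts above check.

concat = init_dataset({
  "class": "ConcatDataset",
  "datasets": [
    {"class": "HDFDataset", "files": ["part1.hdf"]},
    {"class": "HDFDataset", "files": ["part2.hdf"]}],
})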
Example #7
def generate_hdf_from_other(opts):
  """
  :param dict[str] opts:
  :return: hdf filename
  :rtype: str
  """
  # See test_hdf_dump.py and tools/hdf_dump.py.
  from Util import make_hashable
  cache_key = make_hashable(opts)
  if cache_key in _hdf_cache:
    return _hdf_cache[cache_key]
  fn = _get_tmp_file(suffix=".hdf")
  from Dataset import init_dataset
  dataset = init_dataset(opts)
  hdf_dataset = HDFDatasetWriter(fn)
  hdf_dataset.dump_from_dataset(dataset)
  hdf_dataset.close()
  _hdf_cache[cache_key] = fn
  return fn
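A hedged usage sketch: convert a small generated dataset into a cached HDF file. Task12AXDataset with a "num_seqs" option also appears in the framework demo further below; the seq count here is arbitrary.

hdf_filename = generate_hdf_from_other({"class": "Task12AXDataset", "num_seqs": 23})
print("Dumped to: %r" % hdf_filename)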
Example #8
 def __init__(self, dataset,
              chunk_shuffle_cache=1000,
              batch_gen_batch_size=5000, batch_gen_max_seqs=1,
              batch_gen_recurrent_net=True,
              **kwargs):
   """
   :param dict[str] dataset: kwargs for init_dataset
   """
   super(ChunkShuffleDataset, self).__init__(**kwargs)
   self.dataset = init_dataset(dataset)
   assert self.dataset
   self.dataset_last_load_seq_end = None
   self.chunk_shuffle_cache = chunk_shuffle_cache
   self.batch_gen = None
   self.batch_gen_batch_size = batch_gen_batch_size
   self.batch_gen_max_seqs = batch_gen_max_seqs
   self.batch_gen_recurrent_net = batch_gen_recurrent_net
   self.num_inputs = self.dataset.num_inputs
   self.num_outputs = self.dataset.num_outputs
   self.labels = self.dataset.labels
   self.rng = Random(0)
   self.load_seqs_end = None
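A hedged sketch of wrapping a base dataset (given as a kwargs dict) in this class via init_dataset; the HDF file name is a placeholder.

shuffled = init_dataset({
  "class": "ChunkShuffleDataset",
  "dataset": {"class": "HDFDataset", "files": ["train.hdf"]},
  "chunk_shuffle_cache": 1000,
})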
Example #9
  def __init__(self,
               datasets,
               data_map, data_dims,
               data_dtypes=None,
               window=1, **kwargs):
    """
    :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs. including keyword 'class' and maybe 'files'
    :param dict[str,(str,str)] data_map: self-data-key -> (dataset-key, dataset-data-key).
      Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
    :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
    :param dict[str,str] data_dtypes: self-data-key -> dtype. automatic if not specified
    """
    assert window == 1  # not implemented
    super(CombinedDataset, self).__init__(**kwargs)
    assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

    self.data_map = data_map
    self.dataset_keys = set([m[0] for m in self.data_map.values()]); ":type: set[str]"
    self.dataset_idxs = dict(enumerate(sorted(self.dataset_keys)))  # idx -> dataset-key
    self.data_keys = set(self.data_map.keys()); ":type: set[str]"
    assert "data" in self.data_keys
    self.target_list = sorted(self.data_keys - {"data"})

    data_dims = convert_data_dims(data_dims)
    self.data_dims = data_dims
    assert "data" in data_dims
    for key in self.target_list:
      assert key in data_dims
    self.num_inputs = data_dims["data"][0]
    self.num_outputs = data_dims

    self.data_dtypes = {data_key: _select_dtype(data_key, data_dims, data_dtypes) for data_key in self.data_keys}

    # Will only init the needed datasets.
    self.datasets = {key: init_dataset(datasets[key]) for key in self.dataset_keys}

    self._num_seqs = sum([ds.num_seqs for ds in self.datasets.values()])
Example #10
def benchmark(lstm_unit, use_gpu):
  """
  :param str lstm_unit: e.g. "LSTMBlock", one of LstmCellTypes
  :param bool use_gpu:
  :return: runtime in seconds of the training itself, excluding initialization
  :rtype: float
  """
  device = {True: "GPU", False: "CPU"}[use_gpu]
  key = "%s:%s" % (device, lstm_unit)
  print(">>> Start benchmark for %s." % key)
  config = Config()
  config.update(make_config_dict(lstm_unit=lstm_unit, use_gpu=use_gpu))
  dataset_kwargs = config.typed_value("train")
  Dataset.kwargs_update_from_config(config, dataset_kwargs)
  dataset = init_dataset(dataset_kwargs)
  engine = Engine(config=config)
  engine.init_train_from_config(config=config, train_data=dataset)
  print(">>> Start training now for %s." % key)
  start_time = time.time()
  engine.train()
  runtime = time.time() - start_time
  print(">>> Runtime of %s: %s" % (key, hms_fraction(runtime)))
  engine.finalize()
  return runtime
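A hedged usage sketch: compare one LSTM unit on CPU vs. GPU ("LSTMBlock" is the example unit named in the docstring above).

results = {}
for use_gpu in (False, True):
  results[use_gpu] = benchmark(lstm_unit="LSTMBlock", use_gpu=use_gpu)
print("CPU: %.1fs, GPU: %.1fs" % (results[False], results[True]))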
Example #11
def benchmark(lstm_unit, use_gpu):
    """
  :param str lstm_unit: e.g. "LSTMBlock", one of LstmCellTypes
  :param bool use_gpu:
  :return: runtime in seconds of the training itself, excluding initialization
  :rtype: float
  """
    device = {True: "GPU", False: "CPU"}[use_gpu]
    key = "%s:%s" % (device, lstm_unit)
    print(">>> Start benchmark for %s." % key)
    config = Config()
    config.update(make_config_dict(lstm_unit=lstm_unit, use_gpu=use_gpu))
    dataset_kwargs = config.typed_value("train")
    Dataset.kwargs_update_from_config(config, dataset_kwargs)
    dataset = init_dataset(dataset_kwargs)
    engine = Engine(config=config)
    engine.init_train_from_config(config=config, train_data=dataset)
    print(">>> Start training now for %s." % key)
    start_time = time.time()
    engine.train()
    runtime = time.time() - start_time
    print(">>> Runtime of %s: %s" % (key, hms_fraction(runtime)))
    engine.finalize()
    return runtime
Example #12
 def __init__(self,
              dataset,
              chunk_shuffle_cache=1000,
              batch_gen_batch_size=5000,
              batch_gen_max_seqs=1,
              batch_gen_recurrent_net=True,
              **kwargs):
     """
 :param dict[str] dataset: kwargs for init_dataset
 """
     super(ChunkShuffleDataset, self).__init__(**kwargs)
     self.dataset = init_dataset(dataset)
     assert self.dataset
     self.dataset_last_load_seq_end = None
     self.chunk_shuffle_cache = chunk_shuffle_cache
     self.batch_gen = None
     self.batch_gen_batch_size = batch_gen_batch_size
     self.batch_gen_max_seqs = batch_gen_max_seqs
     self.batch_gen_recurrent_net = batch_gen_recurrent_net
     self.num_inputs = self.dataset.num_inputs
     self.num_outputs = self.dataset.num_outputs
     self.labels = self.dataset.labels
     self.rng = Random(0)
     self.load_seqs_end = None
Example #13
def main(argv):
    argparser = argparse.ArgumentParser(
        description='Dump something from dataset.')
    argparser.add_argument(
        '--config',
        help="filename to config-file. will use dataset 'eval' from it")
    argparser.add_argument("--dataset", help="dataset, overwriting config")
    argparser.add_argument(
        "--refs",
        help="same format as hyps. alternative to providing dataset/config")
    argparser.add_argument("--hyps",
                           help="hypotheses, dumped via search in py format")
    argparser.add_argument('--startseq',
                           type=int,
                           default=0,
                           help='start seq idx (inclusive) (default: 0)')
    argparser.add_argument('--endseq',
                           type=int,
                           default=-1,
                           help='end seq idx (inclusive) or -1 (default: -1)')
    argparser.add_argument(
        "--key",
        default="raw",
        help="data-key, e.g. 'data' or 'classes'. (default: 'raw')")
    argparser.add_argument("--verbosity",
                           default=4,
                           type=int,
                           help="5 for all seqs (default: 4)")
    argparser.add_argument(
        "--out", help="if provided, will write WER% (as string) to this file")
    argparser.add_argument("--expect_full",
                           action="store_true",
                           help="full dataset should be scored")
    args = argparser.parse_args(argv[1:])
    assert args.config or args.dataset or args.refs

    init(config_filename=args.config, log_verbosity=args.verbosity)
    dataset = None
    refs = None
    if args.refs:
        refs = load_hyps_refs(args.refs)
    elif args.dataset:
        dataset = init_dataset(args.dataset)
    elif config.value("wer_data", "eval") in ["train", "dev", "eval"]:
        dataset = init_dataset(
            config.opt_typed_value(config.value("wer_data", "eval")))
    else:
        dataset = init_dataset(config.opt_typed_value("wer_data"))
    hyps = load_hyps_refs(args.hyps)

    global wer_compute
    wer_compute = WerComputeGraph()
    with tf.Session(config=tf.ConfigProto(
            device_count={"GPU": 0})) as _session:
        global session
        session = _session
        session.run(tf.global_variables_initializer())
        try:
            wer = calc_wer_on_dataset(dataset=dataset,
                                      refs=refs,
                                      options=args,
                                      hyps=hyps)
            print("Final WER: %.02f%%" % (wer * 100), file=log.v1)
            if args.out:
                with open(args.out, "w") as output_file:
                    output_file.write("%.02f\n" % (wer * 100))
                print("Wrote WER%% to %r." % args.out)
        except KeyboardInterrupt:
            print("KeyboardInterrupt")
            sys.exit(1)
        finally:
            rnn.finalize()
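A hedged usage sketch: main() only reads argv[1:], so the tool can be pointed at dumped references and hypotheses, bypassing any config/dataset. All file names and the script name are placeholders.

main(["calc-wer.py",
      "--refs", "refs.py",
      "--hyps", "search-output.py",
      "--out", "wer.txt"])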
Example #14
def main(argv):
  argparser = argparse.ArgumentParser(description=__doc__)
  argparser.add_argument("config_file", type=str, help="RETURNN config, or model-dir")
  argparser.add_argument("--epoch", type=int)
  argparser.add_argument('--data', default="train",
                         help="e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'")
  argparser.add_argument('--do_search', default=False, action='store_true')
  argparser.add_argument('--beam_size', default=12, type=int)
  argparser.add_argument('--dump_dir', help="for npy or png")
  argparser.add_argument("--output_file", help="hdf")
  argparser.add_argument("--device", help="gpu or cpu (default: automatic)")
  argparser.add_argument("--layers", default=["att_weights"], action="append",
                         help="Layer of subnet to grab")
  argparser.add_argument("--rec_layer", default="output", help="Subnet layer to grab from; decoder")
  argparser.add_argument("--enc_layer", default="encoder")
  argparser.add_argument("--batch_size", type=int, default=5000)
  argparser.add_argument("--seq_list", default=[], action="append", help="predefined list of seqs")
  argparser.add_argument("--min_seq_len", default="0", help="can also be dict")
  argparser.add_argument("--num_seqs", default=-1, type=int, help="stop after this many seqs")
  argparser.add_argument("--output_format", default="npy", help="npy, png or hdf")
  argparser.add_argument("--dropout", default=None, type=float, help="if set, overwrites all dropout values")
  argparser.add_argument("--train_flag", action="store_true")
  argparser.add_argument("--reset_partition_epoch", type=int, default=1)
  argparser.add_argument("--reset_seq_ordering", default="sorted_reverse")
  argparser.add_argument("--reset_epoch_wise_filter", default=None)
  args = argparser.parse_args(argv[1:])

  layers = args.layers
  assert isinstance(layers, list)
  config_fn = args.config_file
  explicit_model_dir = None
  if os.path.isdir(config_fn):
    # Assume we gave a model dir.
    explicit_model_dir = config_fn
    train_log_dir_config_pattern = "%s/train-*/*.config" % config_fn
    train_log_dir_configs = sorted(glob(train_log_dir_config_pattern))
    assert train_log_dir_configs
    config_fn = train_log_dir_configs[-1]
    print("Using this config via model dir:", config_fn)
  else:
    assert os.path.isfile(config_fn)
  model_name = ".".join(config_fn.split("/")[-1].split(".")[:-1])

  init_returnn(config_fn=config_fn, args=args)
  if explicit_model_dir:
    config.set("model", "%s/%s" % (explicit_model_dir, os.path.basename(config.value('model', ''))))
  print("Model file prefix:", config.value('model', ''))

  if args.do_search:
    raise NotImplementedError
  min_seq_length = NumbersDict(eval(args.min_seq_len))

  assert args.output_format in ["npy", "png", "hdf"]
  if args.output_format in ["npy", "png"]:
    assert args.dump_dir
    if not os.path.exists(args.dump_dir):
      os.makedirs(args.dump_dir)
  plt = ticker = None
  if args.output_format == "png":
    import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
    import matplotlib.ticker as ticker

  dataset_str = args.data
  if dataset_str in ["train", "dev", "eval"]:
    dataset_str = "config:%s" % dataset_str
  extra_dataset_kwargs = {}
  if args.reset_partition_epoch:
    print("NOTE: We are resetting partition epoch to %i." % (args.reset_partition_epoch,))
    extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch
  if args.reset_seq_ordering:
    print("NOTE: We will use %r seq ordering." % (args.reset_seq_ordering,))
    extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering
  if args.reset_epoch_wise_filter:
    extra_dataset_kwargs["epoch_wise_filter"] = eval(args.reset_epoch_wise_filter)
  dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs)
  if hasattr(dataset, "epoch_wise_filter") and args.reset_epoch_wise_filter is None:
    if dataset.epoch_wise_filter:
      print("NOTE: Resetting epoch_wise_filter to None.")
      dataset.epoch_wise_filter = None
  if args.reset_partition_epoch:
    assert dataset.partition_epoch == args.reset_partition_epoch
  if args.reset_seq_ordering:
    assert dataset.seq_ordering == args.reset_seq_ordering

  init_net(args, layers)
  network = rnn.engine.network

  hdf_writer = None
  if args.output_format == "hdf":
    assert args.output_file
    assert len(layers) == 1
    sub_layer = network.get_layer("%s/%s" % (args.rec_layer, layers[0]))
    from HDFDataset import SimpleHDFWriter
    hdf_writer = SimpleHDFWriter(filename=args.output_file, dim=sub_layer.output.dim, ndim=sub_layer.output.ndim)

  extra_fetches = {
    "output": network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
    "output_len": network.layers[args.rec_layer].output.get_sequence_lengths(),  # decoder length
    "encoder_len": network.layers[args.enc_layer].output.get_sequence_lengths(),  # encoder length
    "seq_idx": network.get_extern_data("seq_idx"),
    "seq_tag": network.get_extern_data("seq_tag"),
    "target_data": network.get_extern_data(network.extern_data.default_input),
    "target_classes": network.get_extern_data(network.extern_data.default_target),
  }
  for l in layers:
    sub_layer = rnn.engine.network.get_layer("%s/%s" % (args.rec_layer, l))
    extra_fetches["rec_%s" % l] = sub_layer.output.get_placeholder_as_batch_major()
  dataset.init_seq_order(epoch=1, seq_list=args.seq_list or None)  # use always epoch 1, such that we have same seqs
  dataset_batch = dataset.generate_batches(
    recurrent_net=network.recurrent,
    batch_size=args.batch_size,
    max_seqs=rnn.engine.max_seqs,
    max_seq_length=sys.maxsize,
    min_seq_length=min_seq_length,
    max_total_num_seqs=args.num_seqs,
    used_data_keys=network.used_data_keys)

  stats = {l: Stats() for l in layers}

  # (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str]]) -> None
  def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output, output_len, encoder_len, **kwargs):
    """
    :param list[int] seq_idx: len is n_batch
    :param list[str] seq_tag: len is n_batch
    :param numpy.ndarray target_data: extern data default input (e.g. "data"), shape e.g. (B,enc-T,...)
    :param numpy.ndarray target_classes: extern data default target (e.g. "classes"), shape e.g. (B,dec-T,...)
    :param numpy.ndarray output: rec layer output, shape e.g. (B,dec-T,...)
    :param numpy.ndarray output_len: rec layer seq len, i.e. decoder length, shape (B,)
    :param numpy.ndarray encoder_len: encoder seq len, shape (B,)
    :param kwargs: contains "rec_%s" % l for l in layers, the sub layers (e.g att weights) we are interested in
    """
    n_batch = len(seq_idx)
    for i in range(n_batch):
      for l in layers:
        att_weights = kwargs["rec_%s" % l][i]
        stats[l].collect(att_weights.flatten())
    if args.output_format == "npy":
      data = {}
      for i in range(n_batch):
        data[i] = {
          'tag': seq_tag[i],
          'data': target_data[i],
          'classes': target_classes[i],
          'output': output[i],
          'output_len': output_len[i],
          'encoder_len': encoder_len[i],
        }
        for l in [("rec_%s" % l) for l in layers]:
          assert l in kwargs
          out = kwargs[l][i]
          assert out.ndim >= 2
          assert out.shape[0] >= output_len[i] and out.shape[1] >= encoder_len[i]
          data[i][l] = out[:output_len[i], :encoder_len[i]]
        fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
        np.save(fname, data)
    elif args.output_format == "png":
      for i in range(n_batch):
        for l in layers:
          extra_postfix = ""
          if args.dropout is not None:
            extra_postfix += "_dropout%.2f" % args.dropout
          elif args.train_flag:
            extra_postfix += "_train"
          fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
            model_name, rnn.engine.epoch, seq_idx[i], l, extra_postfix)
          att_weights = kwargs["rec_%s" % l][i]
          att_weights = att_weights.squeeze(axis=2)  # (out,enc)
          assert att_weights.shape[0] >= output_len[i] and att_weights.shape[1] >= encoder_len[i]
          att_weights = att_weights[:output_len[i], :encoder_len[i]]
          print("Seq %i, %s: Dump att weights with shape %r to: %s" % (
            seq_idx[i], seq_tag[i], att_weights.shape, fname))
          plt.matshow(att_weights)
          title = seq_tag[i]
          if dataset.can_serialize_data(network.extern_data.default_target):
            title += "\n" + dataset.serialize_data(
              network.extern_data.default_target, target_classes[i][:output_len[i]])
            ax = plt.gca()
            tick_labels = [
              dataset.serialize_data(network.extern_data.default_target, np.array([x], dtype=target_classes[i].dtype))
              for x in target_classes[i][:output_len[i]]]
            ax.set_yticklabels([''] + tick_labels, fontsize=8)
            ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
          plt.title(title)
          plt.savefig(fname)
          plt.close()
    elif args.output_format == "hdf":
      assert len(layers) == 1
      att_weights = kwargs["rec_%s" % layers[0]]
      hdf_writer.insert_batch(inputs=att_weights, seq_len={0: output_len, 1: encoder_len}, seq_tag=seq_tag)
    else:
      raise Exception("output format %r" % args.output_format)

  runner = Runner(engine=rnn.engine, dataset=dataset, batches=dataset_batch,
                  train=False, train_flag=bool(args.dropout) or args.train_flag,
                  extra_fetches=extra_fetches,
                  extra_fetches_callback=fetch_callback)
  runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)
  for l in layers:
    stats[l].dump(stream_prefix="Layer %r " % l)
  if not runner.finalized:
    print("Some error occured, not finalized.")
    sys.exit(1)

  if hdf_writer:
    hdf_writer.close()
  rnn.finalize()
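A hedged usage sketch (config path, epoch and output file are placeholders): dump the default "output/att_weights" sub-layer for the dev set into a single HDF file.

main(["dump-att-weights.py", "my-setup/returnn.config",
      "--epoch", "80",
      "--data", "config:dev",
      "--output_format", "hdf",
      "--output_file", "att-weights.hdf"])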
Example #15
def demo():
  """
  Demo.
  """
  print("SprintDataset demo.")
  from argparse import ArgumentParser
  from Util import progress_bar_with_time
  from Log import log
  from Config import Config
  from Dataset import init_dataset
  arg_parser = ArgumentParser()
  arg_parser.add_argument("--config", help="config with ExternSprintDataset", required=True)
  arg_parser.add_argument("--sprint_cache_dataset", help="kwargs dict for SprintCacheDataset", required=True)
  arg_parser.add_argument("--max_num_seqs", default=sys.maxsize, type=int)
  arg_parser.add_argument("--action", default="compare", help="compare or benchmark")
  args = arg_parser.parse_args()
  log.initialize(verbosity=[4])
  sprint_cache_dataset_kwargs = eval(args.sprint_cache_dataset)
  assert isinstance(sprint_cache_dataset_kwargs, dict)
  sprint_cache_dataset = SprintCacheDataset(**sprint_cache_dataset_kwargs)
  print("SprintCacheDataset: %r" % sprint_cache_dataset)
  config = Config()
  config.load_file(args.config)
  dataset = init_dataset(config.typed_value("train"))
  print("Dataset via config: %r" % dataset)
  assert sprint_cache_dataset.num_inputs == dataset.num_inputs
  assert tuple(sprint_cache_dataset.num_outputs["classes"]) == tuple(dataset.num_outputs["classes"])
  sprint_cache_dataset.init_seq_order(epoch=1)

  if args.action == "compare":
    print("Iterating through dataset...")
    seq_idx = 0
    dataset.init_seq_order(epoch=1)
    while seq_idx < args.max_num_seqs:
      if not dataset.is_less_than_num_seqs(seq_idx):
        break
      dataset.load_seqs(seq_idx, seq_idx + 1)
      tag = dataset.get_tag(seq_idx)
      assert not tag.startswith("seq-"), "dataset does not provide tag-names for seqs"
      dataset_seq = sprint_cache_dataset.get_dataset_seq_for_name(tag)
      data = dataset.get_data(seq_idx, "data")
      targets = dataset.get_data(seq_idx, "classes")
      assert data.shape == dataset_seq.features["data"].shape
      assert targets.shape == dataset_seq.features["classes"].shape
      assert numpy.allclose(data, dataset_seq.features["data"])
      assert numpy.allclose(targets, dataset_seq.features["classes"])
      seq_idx += 1
      progress_bar_with_time(dataset.get_complete_frac(seq_idx))

    print("Finished through dataset. Num seqs: %i" % seq_idx)
    print("SprintCacheDataset has num seqs: %i." % sprint_cache_dataset.num_seqs)

  elif args.action == "benchmark":
    print("Iterating through dataset...")
    start_time = time.time()
    seq_tags = []
    seq_idx = 0
    dataset.init_seq_order(epoch=1)
    while seq_idx < args.max_num_seqs:
      if not dataset.is_less_than_num_seqs(seq_idx):
        break
      dataset.load_seqs(seq_idx, seq_idx + 1)
      tag = dataset.get_tag(seq_idx)
      assert not tag.startswith("seq-"), "dataset does not provide tag-names for seqs"
      seq_tags.append(tag)
      dataset.get_data(seq_idx, "data")
      dataset.get_data(seq_idx, "classes")
      seq_idx += 1
      progress_bar_with_time(dataset.get_complete_frac(seq_idx))
    print("Finished through dataset. Num seqs: %i, time: %f" % (seq_idx, time.time() - start_time))
    print("SprintCacheDataset has num seqs: %i." % sprint_cache_dataset.num_seqs)
    if hasattr(dataset, "exit_handler"):
      dataset.exit_handler()
    else:
      print("No way to stop any background tasks.")
    del dataset

    start_time = time.time()
    print("Iterating through SprintCacheDataset...")
    for i, tag in enumerate(seq_tags):
      sprint_cache_dataset.get_dataset_seq_for_name(tag)
      progress_bar_with_time(float(i) / len(seq_tags))
    print("Finished through SprintCacheDataset. time: %f" % (time.time() - start_time,))

  else:
    raise Exception("invalid action: %r" % args.action)
Example #16
  network={
    "fw0": {"class": "rec", "unit": "NativeLstm2", "dropout": 0.1, "n_out": 10},
    "output": {"class": "softmax", "loss": "ce", "from": ["fw0"]}
  },

  # training
  nadam=True,
  learning_rate=0.01,
  num_epochs=100,
  debug_add_check_numerics_ops=True,

  model="/tmp/%s/returnn-demo-as-framework/model" % get_login_username(),
  cleanup_old_models=True,

  learning_rate_control="newbob_multi_epoch",
  learning_rate_control_relative_error_relative_lr=True,
  newbob_multi_num_epochs=3, newbob_multi_update_interval=1, newbob_learning_rate_decay=0.9,
  learning_rate_file="/tmp/%s/returnn-demo-as-framework/newbob.data" % get_login_username(),

  # log
  log_verbosity=3
))

engine = Engine(config)

train_data = init_dataset({"class": "Task12AXDataset", "num_seqs": 1000, "name": "train"})
dev_data = init_dataset({"class": "Task12AXDataset", "num_seqs": 100, "name": "dev", "fixed_random_seed": 1})

engine.init_train_from_config(train_data=train_data, dev_data=dev_data)
engine.train()
Example #17
    def __init__(self,
                 datasets,
                 data_map,
                 data_dims,
                 data_dtypes=None,
                 window=1,
                 **kwargs):
        """
    :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs. including keyword 'class' and maybe 'files'
    :param dict[(str,str),str] data_map: (dataset-key, dataset-data-key) -> self-data-key.
      Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
    :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
    :param dict[str,str] data_dtypes: self-data-key -> dtype. automatic if not specified
    """
        assert window == 1  # not implemented
        super(CombinedDataset, self).__init__(**kwargs)
        assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

        self.rnd = Random(self.epoch)
        self.dataset_keys = set(datasets.keys())
        ":type: set[str]"
        self.dataset_idxs = dict(enumerate(sorted(
            self.dataset_keys)))  # idx -> dataset-key
        self.data_keys = set(data_map.values())
        ":type: set[str]"
        assert "data" in self.data_keys
        self.target_list = sorted(self.data_keys - {"data"})

        # Build target lookup table
        target_lookup_table = {}
        for dataset_key in self.dataset_keys:
            target_lookup_table[dataset_key] = {
                datamap_maps: datamap_keys[1]
                for datamap_keys, datamap_maps in data_map.items()
                if datamap_keys[0] == dataset_key
            }
            for key in self.data_keys:
                target_lookup_table[dataset_key].setdefault(key, None)

        self.target_lookup_table = target_lookup_table

        data_dims = convert_data_dims(data_dims)
        self.data_dims = data_dims
        assert "data" in data_dims
        for key in self.target_list:
            assert key in data_dims
        self.num_inputs = data_dims["data"][0]
        self.num_outputs = data_dims

        self.data_dtypes = {
            data_key: _select_dtype(data_key, data_dims, data_dtypes)
            for data_key in self.data_keys
        }

        # Will only init the needed datasets.
        self.datasets = {
            key: init_dataset(datasets[key])
            for key in self.dataset_keys
        }

        try:
            self._num_seqs = sum([
                self.datasets[k].num_seqs for k in sorted(self.datasets.keys())
            ])
            self.know_num_seqs_beforehand = True
        except Exception:
            self._estimated_num_seqs = sum([
                self.datasets[k].estimated_num_seqs
                for k in sorted(self.datasets.keys())
            ])
            self.estimated_num_seq_per_subset = [
                self.datasets[k].estimated_num_seqs
                for k in sorted(self.datasets.keys())
            ]
            self.know_num_seqs_beforehand = False
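Note that this variant inverts the data_map orientation used in the MetaDataset/CombinedDataset examples above: keys are (dataset-key, dataset-data-key) tuples and values are the combined dataset's own data keys. A hedged sketch with invented names and dimensions:

combined = init_dataset({
  "class": "CombinedDataset",
  "datasets": {"corpusA": {"class": "HDFDataset", "files": ["a.hdf"]},
               "corpusB": {"class": "HDFDataset", "files": ["b.hdf"]}},
  "data_map": {("corpusA", "data"): "data", ("corpusA", "classes"): "classes",
               ("corpusB", "data"): "data", ("corpusB", "classes"): "classes"},
  "data_dims": {"data": (40, 2), "classes": (1000, 1)},
})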
Example #18
    def __init__(self,
                 seq_list_file,
                 seq_lens_file,
                 datasets,
                 data_map,
                 data_dims,
                 data_dtypes=None,
                 window=1,
                 **kwargs):
        """
    :param str seq_list_file: filename. line-separated
    :param str seq_lens_file: filename. json. dict[str,dict[str,int]], seq-tag -> data-key -> len
    :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs. including keyword 'class' and maybe 'files'
    :param dict[str,(str,str)] data_map: self-data-key -> (dataset-key, dataset-data-key).
      Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
    :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
    :param dict[str,str] data_dtypes: self-data-key -> dtype. automatic if not specified
    """
        assert window == 1  # not implemented
        super(MetaDataset, self).__init__(**kwargs)
        assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

        self.data_map = data_map
        self.dataset_keys = set([m[0] for m in self.data_map.values()])
        ":type: set[str]"
        self.data_keys = set(self.data_map.keys())
        ":type: set[str]"
        assert "data" in self.data_keys
        self.target_list = sorted(self.data_keys - {"data"})
        self.default_dataset_key = self.data_map["data"][0]

        if seq_list_file.endswith(".pkl"):
            import pickle
            seq_list = pickle.load(open(seq_list_file, 'rb'))
        else:
            seq_list = open(seq_list_file).read().splitlines()
        assert isinstance(seq_list, (list, dict))
        if isinstance(seq_list, list):
            seq_list = {key: seq_list for key in self.dataset_keys}
        self.seq_list_original = seq_list  # type: dict[str,list[str]]  # dataset key -> seq list
        self._num_seqs = len(self.seq_list_original[self.default_dataset_key])
        for key in self.dataset_keys:
            assert len(self.seq_list_original[key]) == self._num_seqs
        self.tag_idx = {
            tag: idx
            for (idx, tag) in enumerate(self.seq_list_original[
                self.default_dataset_key])
        }

        data_dims = convert_data_dims(data_dims)
        self.data_dims = data_dims
        assert "data" in data_dims
        for key in self.target_list:
            assert key in data_dims
        self.num_inputs = data_dims["data"][0]
        self.num_outputs = data_dims

        self.data_dtypes = {
            data_key: _select_dtype(data_key, data_dims, data_dtypes)
            for data_key in self.data_keys
        }

        if seq_lens_file:
            seq_lens = load_json(filename=seq_lens_file)
            assert isinstance(seq_lens, dict)
            # dict[str,NumbersDict], seq-tag -> data-key -> len
            self._seq_lens = {
                tag: NumbersDict(l)
                for (tag, l) in seq_lens.items()
            }
        else:
            self._seq_lens = None

        if self._seq_lens:
            self._num_timesteps = sum([
                self._seq_lens[s]
                for s in self.seq_list_original[self.default_dataset_key]
            ])
        else:
            self._num_timesteps = None

        # Will only init the needed datasets.
        self.datasets = {
            key:
            init_dataset(datasets[key],
                         extra_kwargs={"name": "%s_%s" % (self.name, key)})
            for key in self.dataset_keys
        }
        for data_key in self.data_keys:
            dataset_key, dataset_data_key = self.data_map[data_key]
            dataset = self.datasets[dataset_key]
            if dataset_data_key in dataset.labels:
                self.labels[data_key] = dataset.labels[dataset_data_key]
Example #19
def main():
  arg_parser = ArgumentParser()
  arg_parser.add_argument("--action")
  arg_parser.add_argument("--print_seq", action='store_true')
  arg_parser.add_argument("--print_allos", action='store_true')
  arg_parser.add_argument("--print_targets", action='store_true')
  arg_parser.add_argument("--dataset")
  arg_parser.add_argument("--corpus")
  arg_parser.add_argument("--lexicon", help="filename")
  arg_parser.add_argument("--silence", type=int, help="index")
  arg_parser.add_argument("--context", default=1, type=int)
  arg_parser.add_argument("--hmm_states", default=3, type=int)
  arg_parser.add_argument("--state_tying_type", help="'monophone' or 'full'")
  arg_parser.add_argument("--state_tying_output", help="filename")
  arg_parser.add_argument("--allo_add_all", action="store_true")
  args = arg_parser.parse_args()

  dataset = init_dataset(config_str=args.dataset) if args.dataset else None
  corpus = dict(iter_bliss_orth(filename=args.corpus)) if args.corpus else None
  lexicon = Lexicon(filename=args.lexicon) if args.lexicon else None
  silence_label = args.silence

  if args.action == "show_corpus":
    pprint(corpus)
    return

  print("Num phones: %i" % len(lexicon.phonemes), file=log.v1)
  print("Phones: %r" % sorted(lexicon.phonemes.keys()), file=log.v1)

  orth_handler = OrthHandler(
    lexicon=lexicon,
    allo_context_len=args.context,
    allo_num_states=args.hmm_states)
  map_idx_to_allo = defaultdict(set)  # type: dict[int, set[AllophoneState]]
  map_allo_to_idx = {}  # type: dict[AllophoneState, int]
  if args.allo_add_all:
    orth_handler.allo_add_all = True

  print("Num HMM states: %i" % orth_handler.allo_num_states, file=log.v1)
  if args.state_tying_type == "monophone":
    print("Monophone state tying.", file=log.v1)
    num_labels = orth_handler.expected_num_labels_for_monophone_state_tying()
    all_label_idx_are_used = True
  elif args.state_tying_type == "full":
    print("Full state tying.", file=log.v1)
    phone_idxs = {k: i + 1 for (i, k) in enumerate(lexicon.phoneme_list)}  # +1 to keep 0 reserved as the term-symbol
    for phon in lexicon.phoneme_list:
      for allo in orth_handler.all_allophone_variations(phon, all_boundary_variations=True):
        allo_idx = allo.index(
          phone_idxs=phone_idxs,
          num_states=orth_handler.allo_num_states,
          context_length=orth_handler.allo_context_len)
        map_idx_to_allo[allo_idx].add(allo)
    num_labels = max(map_idx_to_allo.keys()) + 1
    all_label_idx_are_used = False
  else:
    raise Exception("invalid state tying type %r" % args.state_tying_type)
  print("Num labels: %i" % num_labels, file=log.v1)

  if dataset:
    count = 0
    for segment_name, targets in iter_dataset_targets(dataset):
      count += 1
      if silence_label is None or count == 1:
        likely_silence_label = collections.Counter(targets).most_common(1)[0][0]
        if silence_label is None:
          silence_label = likely_silence_label
        if silence_label != likely_silence_label:
          print("warning: silence %i but likely %i" % (silence_label, likely_silence_label), file=log.v2)
        print("Silence label: %i" % silence_label, file=log.v1)
        orth_handler.si_label = silence_label
        # Monophone state tying:
        for allo in orth_handler.all_allophone_variations(orth_handler.si_phone):
          map_idx_to_allo[silence_label].add(allo)
          map_allo_to_idx[allo] = silence_label
      assert segment_name in corpus
      orth = corpus[segment_name]
      allo_states = orth_handler.orth_to_allophone_states(orth=orth)
      if args.print_seq:
        print("%r %r" % (segment_name, orth))
      if args.print_allos:
        print("  allophone state seq: %r" % allo_states)
      tgt_seq = [t for t in uniq(targets) if t != silence_label]
      if args.print_targets:
        print("  target seq: %r" % (tgt_seq,))
      assert len(allo_states) == len(tgt_seq), "check --hmm_states or so"
      for allo, t in zip(allo_states, tgt_seq):
        allo.boundary = 0  # do not differ between boundaries
        allos = map_idx_to_allo[t]
        if allo in map_allo_to_idx:
          assert allo in allos, "bad mapping"
        else:
          assert allo not in allos
          allos.add(allo)
          map_allo_to_idx[allo] = t
      if len(map_idx_to_allo) >= num_labels:
        assert len(map_idx_to_allo) == num_labels
        assert 0 in map_idx_to_allo
        assert num_labels - 1 in map_idx_to_allo
        print("Finished with uniq mapping after %i sequences." % count, file=log.v1)
        break
      if count % 100 == 0:
        print("Have indices: %i (num labels: %i)" % (len(map_idx_to_allo), num_labels), file=log.v1)

    print("Finished. Have indices: %i (num labels: %i)" % (len(map_idx_to_allo), num_labels), file=log.v1)
    if len(map_idx_to_allo) < num_labels:
      found = []
      not_found = []
      for p in sorted(lexicon.phonemes.keys()):
        allo = AllophoneState(p, state=0)
        if allo in map_allo_to_idx:
          found.append(p)
        else:
          not_found.append(p)
      print("Phonemes found: %r" % found)
      print("Phonemes not found: %r" % not_found)

  if args.state_tying_output:
    assert not os.path.exists(args.state_tying_output)
    if all_label_idx_are_used:
      assert len(map_idx_to_allo) == num_labels
      assert 0 in map_idx_to_allo
      assert num_labels - 1 in map_idx_to_allo
    f = open(args.state_tying_output, "w")
    for i, allos in sorted(map_idx_to_allo.items()):
      for allo in allos:
        f.write("%s %i\n" % (allo.format(), i))
    f.close()
    print("Wrote state tying to %r." % args.state_tying_output, file=log.v1)

  print("The end.")
Example #20
def main(argv):
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument("config_file", type=str)
    argparser.add_argument("--epoch", required=False, type=int)
    argparser.add_argument(
        '--data',
        default="train",
        help=
        "e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'"
    )
    argparser.add_argument('--do_search', default=False, action='store_true')
    argparser.add_argument('--beam_size', default=12, type=int)
    argparser.add_argument('--dump_dir', required=True)
    argparser.add_argument("--device", default="gpu")
    argparser.add_argument("--layers",
                           default=["att_weights"],
                           action="append",
                           help="Layer of subnet to grab")
    argparser.add_argument("--rec_layer",
                           default="output",
                           help="Subnet layer to grab from; decoder")
    argparser.add_argument("--enc_layer", default="encoder")
    argparser.add_argument("--batch_size", type=int, default=5000)
    argparser.add_argument("--seq_list",
                           default=[],
                           action="append",
                           help="predefined list of seqs")
    argparser.add_argument("--min_seq_len",
                           default="0",
                           help="can also be dict")
    argparser.add_argument("--output_format", default="npy", help="npy or png")
    argparser.add_argument("--dropout",
                           default=None,
                           type=float,
                           help="if set, overwrites all dropout values")
    argparser.add_argument("--train_flag", action="store_true")
    args = argparser.parse_args(argv[1:])

    layers = args.layers
    assert isinstance(layers, list)
    model_name = ".".join(args.config_file.split("/")[-1].split(".")[:-1])

    init_returnn(config_fn=args.config_file,
                 cmd_line_opts=["--device", args.device],
                 args=args)

    if args.do_search:
        raise NotImplementedError
    min_seq_length = NumbersDict(eval(args.min_seq_len))

    if not os.path.exists(args.dump_dir):
        os.makedirs(args.dump_dir)
    assert args.output_format in ["npy", "png"]
    if args.output_format == "png":
        import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
        import matplotlib.ticker as ticker
    dataset_str = args.data
    if dataset_str in ["train", "dev", "eval"]:
        dataset_str = "config:%s" % dataset_str
    dataset = init_dataset(dataset_str)
    init_net(args, layers)

    network = rnn.engine.network

    extra_fetches = {}
    for rec_ret_layer in ["rec_%s" % l for l in layers]:
        extra_fetches[rec_ret_layer] = rnn.engine.network.layers[
            rec_ret_layer].output.get_placeholder_as_batch_major()
    extra_fetches.update({
        "output":
        network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
        "output_len":
        network.layers[
            args.rec_layer].output.get_sequence_lengths(),  # decoder length
        "encoder_len":
        network.layers[
            args.enc_layer].output.get_sequence_lengths(),  # encoder length
        "seq_idx":
        network.get_extern_data("seq_idx"),
        "seq_tag":
        network.get_extern_data("seq_tag"),
        "target_data":
        network.get_extern_data(network.extern_data.default_input),
        "target_classes":
        network.get_extern_data(network.extern_data.default_target),
    })
    dataset.init_seq_order(epoch=rnn.engine.epoch,
                           seq_list=args.seq_list or None)
    dataset_batch = dataset.generate_batches(
        recurrent_net=network.recurrent,
        batch_size=args.batch_size,
        max_seqs=rnn.engine.max_seqs,
        max_seq_length=sys.maxsize,
        min_seq_length=min_seq_length,
        used_data_keys=network.used_data_keys)

    stats = {l: Stats() for l in layers}

    # (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str]]) -> None
    def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output,
                       output_len, encoder_len, **kwargs):
        for i in range(len(seq_idx)):
            for l in layers:
                att_weights = kwargs["rec_%s" % l][i]
                stats[l].collect(att_weights.flatten())
        if args.output_format == "npy":
            data = {}
            for i in range(len(seq_idx)):
                data[i] = {
                    'tag': seq_tag[i],
                    'data': target_data[i],
                    'classes': target_classes[i],
                    'output': output[i],
                    'output_len': output_len[i],
                    'encoder_len': encoder_len[i],
                }
                for l in [("rec_%s" % l) for l in layers]:
                    assert l in kwargs
                    out = kwargs[l][i]
                    assert out.ndim >= 2
                    assert out.shape[0] >= output_len[i] and out.shape[
                        1] >= encoder_len[i]
                    data[i][l] = out[:output_len[i], :encoder_len[i]]
                fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (
                    model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
                np.save(fname, data)
        elif args.output_format == "png":
            for i in range(len(seq_idx)):
                for l in layers:
                    extra_postfix = ""
                    if args.dropout is not None:
                        extra_postfix += "_dropout%.2f" % args.dropout
                    elif args.train_flag:
                        extra_postfix += "_train"
                    fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
                        model_name, rnn.engine.epoch, seq_idx[i], l,
                        extra_postfix)
                    att_weights = kwargs["rec_%s" % l][i]
                    att_weights = att_weights.squeeze(axis=2)  # (out,enc)
                    assert att_weights.shape[0] >= output_len[
                        i] and att_weights.shape[1] >= encoder_len[i]
                    att_weights = att_weights[:output_len[i], :encoder_len[i]]
                    print("Seq %i, %s: Dump att weights with shape %r to: %s" %
                          (seq_idx[i], seq_tag[i], att_weights.shape, fname))
                    plt.matshow(att_weights)
                    title = seq_tag[i]
                    if dataset.can_serialize_data(
                            network.extern_data.default_target):
                        title += "\n" + dataset.serialize_data(
                            network.extern_data.default_target,
                            target_classes[i][:output_len[i]])
                        ax = plt.gca()
                        tick_labels = [
                            dataset.serialize_data(
                                network.extern_data.default_target,
                                np.array([x], dtype=target_classes[i].dtype))
                            for x in target_classes[i][:output_len[i]]
                        ]
                        ax.set_yticklabels([''] + tick_labels, fontsize=8)
                        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
                    plt.title(title)
                    plt.savefig(fname)
                    plt.close()
        else:
            raise NotImplementedError("output format %r" % args.output_format)

    runner = Runner(engine=rnn.engine,
                    dataset=dataset,
                    batches=dataset_batch,
                    train=False,
                    train_flag=args.dropout is not None or args.train_flag,
                    extra_fetches=extra_fetches,
                    extra_fetches_callback=fetch_callback)
    runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)
    for l in layers:
        stats[l].dump(stream_prefix="Layer %r " % l)
    if not runner.finalized:
        print("Some error occured, not finalized.")
        sys.exit(1)

    rnn.finalize()
Example #21
def execute_main_task():
    """
    Executes the main task (via config ``task`` option).
    """
    from Util import hms_fraction
    start_time = time.time()
    task = config.value('task', 'train')
    if config.is_true("dry_run"):
        print("Dry run, will not save anything.", file=log.v1)
    if task == 'train':
        assert train_data.have_seqs(
        ), "no train files specified, check train option: %s" % config.value(
            'train', None)
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        engine.train()
    elif task == "eval":
        epoch = config.int("epoch", -1)
        load_epoch = config.int("load_epoch", -1)
        if epoch >= 0:
            assert (load_epoch < 0) or (
                load_epoch == epoch), "epoch and load_epoch have to match"
            engine.epoch = epoch
            config.set('load_epoch', engine.epoch)
        else:
            assert load_epoch >= 0, "specify epoch or load_epoch"
            engine.epoch = load_epoch
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        print("Evaluate epoch", engine.epoch, file=log.v4)
        engine.eval_model(
            output_file=config.value("eval_output_file", None),
            output_per_seq_file=config.value("eval_output_file_per_seq", None),
            loss_name=config.value("loss_name", None),
            output_per_seq_format=config.list("output_per_seq_format",
                                              ["score"]),
            output_per_seq_file_format=config.value(
                "output_per_seq_file_format", "txt"))
    elif task in ['forward', 'hpx']:
        assert eval_data is not None, 'no eval data provided'
        combine_labels = config.value('combine_labels', '')
        engine.use_search_flag = config.bool("forward_use_search", False)
        if config.has("epoch"):
            config.set('load_epoch', config.int('epoch', 0))
        engine.init_network_from_config(config)
        output_file = config.value('output_file',
                                   'dump-fwd-epoch-%i.hdf' % engine.epoch)
        engine.forward_to_hdf(data=eval_data,
                              output_file=output_file,
                              combine_labels=combine_labels,
                              batch_size=config.int('forward_batch_size', 0))
    elif task == "search":
        engine.use_search_flag = True
        engine.init_network_from_config(config)
        if config.value("search_data", "eval") in ["train", "dev", "eval"]:
            data = {
                "train": train_data,
                "dev": dev_data,
                "eval": eval_data
            }[config.value("search_data", "eval")]
            assert data, "set search_data"
        else:
            data = init_dataset(config.opt_typed_value("search_data"))
        engine.search(
            data,
            do_eval=config.bool("search_do_eval", True),
            output_layer_names=config.typed_value("search_output_layer",
                                                  "output"),
            output_file=config.value("search_output_file", ""),
            output_file_format=config.value("search_output_file_format",
                                            "txt"))
    elif task == 'compute_priors':
        assert train_data is not None, 'train data for priors should be provided'
        engine.init_network_from_config(config)
        engine.compute_priors(dataset=train_data, config=config)
    elif task == 'theano_graph':
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        import theano.printing
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        import theano.compile.io
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        import theano.compile.function_module
        engine.start_epoch = 1
        engine.init_network_from_config(config)
        for task in config.list('theano_graph.task', ['train']):
            func = engine.devices[-1].get_compute_func(task)
            prefix = config.value("theano_graph.prefix", "current") + ".task"
            print("dumping to %s.* ..." % prefix, file=log.v1)
            theano.printing.debugprint(func,
                                       file=open(
                                           "%s.optimized_func.txt" % prefix,
                                           "w"))
            assert isinstance(func.maker,
                              theano.compile.function_module.FunctionMaker)
            for inp in func.maker.inputs:
                assert isinstance(inp, theano.compile.io.In)
                if inp.update:
                    theano.printing.debugprint(
                        inp.update,
                        file=open(
                            "%s.unoptimized.var_%s_update.txt" %
                            (prefix, inp.name), "w"))
            theano.printing.pydotprint(func,
                                       format='png',
                                       var_with_name_simple=True,
                                       outfile="%s.png" % prefix)
    elif task == 'analyze':  # anything based on the network + Device
        statistics = config.list('statistics', None)
        engine.init_network_from_config(config)
        engine.analyze(data=eval_data or dev_data, statistics=statistics)
    elif task == "analyze_data":  # anything just based on the data
        analyze_data(config)
    elif task == "classify":
        assert eval_data is not None, 'no eval data provided'
        assert config.has('label_file'), 'no output file provided'
        label_file = config.value('label_file', '')
        engine.init_network_from_config(config)
        engine.classify(engine.devices[0], eval_data, label_file)
    elif task == "hyper_param_tuning":
        import HyperParamTuning
        tuner = HyperParamTuning.Optimization(config=config,
                                              train_data=train_data)
        tuner.work()
    elif task == "cleanup_old_models":
        engine.cleanup_old_models(ask_for_confirmation=True)
    elif task == "daemon":
        engine.init_network_from_config(config)
        engine.daemon(config)
    elif task == "server":
        print("Server Initiating", file=log.v1)
        server.run()
    elif task == "search_server":
        engine.use_search_flag = True
        engine.init_network_from_config(config)
        engine.web_server(port=config.int("web_server_port", 12380))
    elif task.startswith("config:"):
        action = config.typed_dict[task[len("config:"):]]
        print("Task: %r" % action, file=log.v1)
        assert callable(action)
        action()
    elif task.startswith("optional-config:"):
        action = config.typed_dict.get(task[len("optional-config:"):], None)
        if action is None:
            print("No task found for %r, so just quitting." % task,
                  file=log.v1)
        else:
            print("Task: %r" % action, file=log.v1)
            assert callable(action)
            action()
    elif task == "nop":
        print("Task: No-operation", file=log.v1)
    elif task == "nop_init_net_train":
        print(
            "Task: No-operation, despite initializing the network (for training)",
            file=log.v1)
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
    elif task == "initialize_model":
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        engine.save_model(config.value('model', 'dummy'))
    else:
        assert False, "unknown task: %s" % task

    print(("elapsed: %s" % hms_fraction(time.time() - start_time)),
          file=log.v3)
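The dispatch above is driven entirely by config values. As a hedged illustration, a minimal RETURNN-style config (a plain Python file of globals) selecting the "eval" branch might look like the sketch below; the paths and numbers are placeholders, and only keys that actually appear in the code above are used.

# Sketch of config globals consumed by the "eval" branch above
# (keys taken from the code; paths and values are placeholders).
task = "eval"
epoch = 80                                    # must match load_epoch if both are set
model = "net-model/network"                   # checkpoint prefix (placeholder)
eval_output_file = "eval-scores.txt"          # summary scores
eval_output_file_per_seq = "eval-per-seq.txt"
output_per_seq_format = ["score"]             # what to report per sequence
output_per_seq_file_format = "txt"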
Example #22
def executeMainTask():
    st = time.time()
    task = config.value('task', 'train')
    if task == 'train':
        assert train_data.have_seqs(
        ), "no train files specified, check train option: %s" % config.value(
            'train', None)
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        engine.train()
    elif task == "eval":
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        engine.epoch = config.int("epoch", None)
        assert engine.epoch
        print("Evaluate epoch", engine.epoch, file=log.v4)
        engine.eval_model()
    elif task == 'forward':
        assert eval_data is not None, 'no eval data provided'
        assert config.has('output_file'), 'no output file provided'
        combine_labels = config.value('combine_labels', '')
        output_file = config.value('output_file', '')
        engine.init_network_from_config(config)
        engine.forward_to_hdf(data=eval_data,
                              output_file=output_file,
                              combine_labels=combine_labels,
                              batch_size=config.int('forward_batch_size', 0))
    elif task == "search":
        engine.use_search_flag = True
        engine.init_network_from_config(config)
        if config.value("search_data", "eval") in ["train", "dev", "eval"]:
            data = {
                "train": train_data,
                "dev": dev_data,
                "eval": eval_data
            }[config.value("search_data", "eval")]
            assert data, "set search_data"
        else:
            data = init_dataset(config.typed_value("search_data"))
        engine.search(data,
                      output_layer_name=config.value("search_output_layer",
                                                     "output"))
    elif task == 'compute_priors':
        assert train_data is not None, 'train data for priors should be provided'
        engine.init_network_from_config(config)
        engine.compute_priors(dataset=train_data, config=config)
    elif task == 'theano_graph':
        import theano.printing
        import theano.compile.io
        import theano.compile.function_module
        engine.start_epoch = 1
        engine.init_network_from_config(config)
        for task in config.list('theano_graph.task', ['train']):
            func = engine.devices[-1].get_compute_func(task)
            prefix = config.value("theano_graph.prefix", "current") + ".task"
            print("dumping to %s.* ..." % prefix, file=log.v1)
            theano.printing.debugprint(func,
                                       file=open(
                                           "%s.optimized_func.txt" % prefix,
                                           "w"))
            assert isinstance(func.maker,
                              theano.compile.function_module.FunctionMaker)
            for inp in func.maker.inputs:
                assert isinstance(inp, theano.compile.io.In)
                if inp.update:
                    theano.printing.debugprint(
                        inp.update,
                        file=open(
                            "%s.unoptimized.var_%s_update.txt" %
                            (prefix, inp.name), "w"))
            theano.printing.pydotprint(func,
                                       format='png',
                                       var_with_name_simple=True,
                                       outfile="%s.png" % prefix)
    elif task == 'analyze':  # anything based on the network + Device
        statistics = config.list('statistics', None)
        engine.init_network_from_config(config)
        engine.analyze(data=eval_data or dev_data, statistics=statistics)
    elif task == "analyze_data":  # anything just based on the data
        analyze_data(config)
    elif task == "classify":
        assert eval_data is not None, 'no eval data provided'
        assert config.has('label_file'), 'no output file provided'
        label_file = config.value('label_file', '')
        engine.init_network_from_config(config)
        engine.classify(engine.devices[0], eval_data, label_file)
    elif task == "daemon":
        engine.init_network_from_config(config)
        engine.daemon(config)
    elif task == "server":
        print("Server Initiating", file=log.v1)
    elif task.startswith("config:"):
        action = config.typed_dict[task[len("config:"):]]
        print("Task: %r" % action, file=log.v1)
        assert callable(action)
        action()
    elif task.startswith("optional-config:"):
        action = config.typed_dict.get(task[len("optional-config:"):], None)
        if action is None:
            print("No task found for %r, so just quitting." % task,
                  file=log.v1)
        else:
            print("Task: %r" % action, file=log.v1)
            assert callable(action)
            action()
    elif task == "nop":
        print("Task: No-operation", file=log.v1)
    else:
        assert False, "unknown task: %s" % task

    print(("elapsed: %f" % (time.time() - st)), file=log.v3)
Example #23
def demo():
    print("SprintDataset demo.")
    from argparse import ArgumentParser
    from Util import hms, progress_bar_with_time
    from Log import log
    from Config import Config
    from Dataset import init_dataset
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--config",
                            help="config with ExternSprintDataset",
                            required=True)
    arg_parser.add_argument("--sprint_cache_dataset",
                            help="kwargs dict for SprintCacheDataset",
                            required=True)
    arg_parser.add_argument("--max_num_seqs", default=sys.maxint, type=int)
    arg_parser.add_argument("--action",
                            default="compare",
                            help="compare or benchmark")
    args = arg_parser.parse_args()
    log.initialize(verbosity=[4])
    sprint_cache_dataset_kwargs = eval(args.sprint_cache_dataset)
    assert isinstance(sprint_cache_dataset_kwargs, dict)
    sprint_cache_dataset = SprintCacheDataset(**sprint_cache_dataset_kwargs)
    print("SprintCacheDataset: %r" % sprint_cache_dataset)
    config = Config()
    config.load_file(args.config)
    dataset = init_dataset(config.typed_value("train"))
    print("Dataset via config: %r" % dataset)
    assert sprint_cache_dataset.num_inputs == dataset.num_inputs
    assert tuple(sprint_cache_dataset.num_outputs["classes"]) == tuple(
        dataset.num_outputs["classes"])
    sprint_cache_dataset.init_seq_order(epoch=1)

    if args.action == "compare":
        print("Iterating through dataset...")
        seq_idx = 0
        dataset.init_seq_order(epoch=1)
        while seq_idx < args.max_num_seqs:
            if not dataset.is_less_than_num_seqs(seq_idx):
                break
            dataset.load_seqs(seq_idx, seq_idx + 1)
            tag = dataset.get_tag(seq_idx)
            assert not tag.startswith(
                "seq-"), "dataset does not provide tag-names for seqs"
            dataset_seq = sprint_cache_dataset.get_dataset_seq_for_name(tag)
            data = dataset.get_data(seq_idx, "data")
            targets = dataset.get_data(seq_idx, "classes")
            assert data.shape == dataset_seq.features.shape
            assert targets.shape == dataset_seq.targets["classes"].shape
            assert numpy.allclose(data, dataset_seq.features)
            assert numpy.allclose(targets, dataset_seq.targets["classes"])
            seq_idx += 1
            progress_bar_with_time(dataset.get_complete_frac(seq_idx))

        print("Finished through dataset. Num seqs: %i" % seq_idx)
        print("SprintCacheDataset has num seqs: %i." %
              sprint_cache_dataset.num_seqs)

    elif args.action == "benchmark":
        print("Iterating through dataset...")
        start_time = time.time()
        seq_tags = []
        seq_idx = 0
        dataset.init_seq_order(epoch=1)
        while seq_idx < args.max_num_seqs:
            if not dataset.is_less_than_num_seqs(seq_idx):
                break
            dataset.load_seqs(seq_idx, seq_idx + 1)
            tag = dataset.get_tag(seq_idx)
            assert not tag.startswith(
                "seq-"), "dataset does not provide tag-names for seqs"
            seq_tags.append(tag)
            dataset.get_data(seq_idx, "data")
            dataset.get_data(seq_idx, "classes")
            seq_idx += 1
            progress_bar_with_time(dataset.get_complete_frac(seq_idx))
        print("Finished through dataset. Num seqs: %i, time: %f" %
              (seq_idx, time.time() - start_time))
        print("SprintCacheDataset has num seqs: %i." %
              sprint_cache_dataset.num_seqs)
        if hasattr(dataset, "exit_handler"):
            dataset.exit_handler()
        else:
            print("No way to stop any background tasks.")
        del dataset

        start_time = time.time()
        print("Iterating through SprintCacheDataset...")
        for i, tag in enumerate(seq_tags):
            sprint_cache_dataset.get_dataset_seq_for_name(tag)
            progress_bar_with_time(float(i) / len(seq_tags))
        print("Finished through SprintCacheDataset. time: %f" %
              (time.time() - start_time, ))

    else:
        raise Exception("invalid action: %r" % args.action)
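The compare and benchmark branches above share the same sequential Dataset iteration pattern; condensed into a helper generator (a sketch, not part of the original script, using only the Dataset calls seen above):

# Condensed sketch of the iteration pattern used by both branches above.
def iterate_seqs(dataset, max_num_seqs):
    dataset.init_seq_order(epoch=1)
    seq_idx = 0
    while seq_idx < max_num_seqs and dataset.is_less_than_num_seqs(seq_idx):
        dataset.load_seqs(seq_idx, seq_idx + 1)      # load one seq at a time
        tag = dataset.get_tag(seq_idx)
        features = dataset.get_data(seq_idx, "data")
        targets = dataset.get_data(seq_idx, "classes")
        yield tag, features, targets
        seq_idx += 1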
Example #24
def main(argv):
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument("config_file",
                           type=str,
                           help="RETURNN config, or model-dir")
    argparser.add_argument("--epoch", type=int)
    argparser.add_argument(
        '--data',
        default="train",
        help="e.g. 'train', 'config:train', or something like \"config:get_dataset('dev')\"")
    argparser.add_argument('--do_search', default=False, action='store_true')
    argparser.add_argument('--beam_size', default=12, type=int)
    argparser.add_argument('--dump_dir', help="for npy or png")
    argparser.add_argument("--output_file", help="hdf")
    argparser.add_argument("--device", help="gpu or cpu (default: automatic)")
    argparser.add_argument("--layers",
                           default=[],
                           action="append",
                           help="Layer of subnet to grab")
    argparser.add_argument("--rec_layer",
                           default="output",
                           help="Subnet layer to grab from; decoder")
    argparser.add_argument("--enc_layer", default="encoder")
    argparser.add_argument("--batch_size", type=int, default=5000)
    argparser.add_argument("--seq_list",
                           default=[],
                           action="append",
                           help="predefined list of seqs")
    argparser.add_argument("--min_seq_len",
                           default="0",
                           help="can also be dict")
    argparser.add_argument("--num_seqs",
                           default=-1,
                           type=int,
                           help="stop after this many seqs")
    argparser.add_argument("--output_format",
                           default="npy",
                           help="npy, png or hdf")
    argparser.add_argument("--dropout",
                           default=None,
                           type=float,
                           help="if set, overwrites all dropout values")
    argparser.add_argument("--train_flag", action="store_true")
    argparser.add_argument("--reset_partition_epoch", type=int, default=None)
    argparser.add_argument("--reset_seq_ordering", default="default")
    argparser.add_argument("--reset_epoch_wise_filter", default=None)
    argparser.add_argument('--hmm_fac_fo', default=False, action='store_true')
    argparser.add_argument('--encoder_sa', default=False, action='store_true')
    argparser.add_argument('--tf_log_dir', help="for npy or png", default=None)
    argparser.add_argument("--instead_save_encoder_decoder",
                           action="store_true")
    args = argparser.parse_args(argv[1:])

    layers = args.layers
    assert isinstance(layers, list)
    config_fn = args.config_file
    explicit_model_dir = None
    if os.path.isdir(config_fn):
        # Assume we gave a model dir.
        explicit_model_dir = config_fn
        train_log_dir_config_pattern = "%s/train-*/*.config" % config_fn
        train_log_dir_configs = sorted(glob(train_log_dir_config_pattern))
        assert train_log_dir_configs
        config_fn = train_log_dir_configs[-1]
        print("Using this config via model dir:", config_fn)
    else:
        assert os.path.isfile(config_fn)
    model_name = ".".join(config_fn.split("/")[-1].split(".")[:-1])

    init_returnn(config_fn=config_fn, args=args)
    if explicit_model_dir:
        config.set(
            "model", "%s/%s" %
            (explicit_model_dir, os.path.basename(config.value('model', ''))))
    print("Model file prefix:", config.value('model', ''))

    min_seq_length = NumbersDict(eval(args.min_seq_len))

    assert args.output_format in ["npy", "png", "hdf"]
    if args.output_format in ["npy", "png"]:
        assert args.dump_dir
        if not os.path.exists(args.dump_dir):
            os.makedirs(args.dump_dir)
    plt = ticker = None
    if args.output_format == "png":
        import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
        import matplotlib.ticker as ticker

    dataset_str = args.data
    if dataset_str in ["train", "dev", "eval"]:
        dataset_str = "config:%s" % dataset_str
    extra_dataset_kwargs = {}
    if args.reset_partition_epoch:
        print("NOTE: We are resetting partition epoch to %i." %
              (args.reset_partition_epoch, ))
        extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch
    if args.reset_seq_ordering:
        print("NOTE: We will use %r seq ordering." %
              (args.reset_seq_ordering, ))
        extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering
    if args.reset_epoch_wise_filter:
        extra_dataset_kwargs["epoch_wise_filter"] = eval(
            args.reset_epoch_wise_filter)
    dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs)
    if hasattr(dataset,
               "epoch_wise_filter") and args.reset_epoch_wise_filter is None:
        if dataset.epoch_wise_filter:
            print("NOTE: Resetting epoch_wise_filter to None.")
            dataset.epoch_wise_filter = None
    if args.reset_partition_epoch:
        assert dataset.partition_epoch == args.reset_partition_epoch
    if args.reset_seq_ordering:
        assert dataset.seq_ordering == args.reset_seq_ordering

    init_net(args, layers)
    network = rnn.engine.network

    hdf_writer = None
    if args.output_format == "hdf":
        assert args.output_file
        assert len(layers) == 1
        sub_layer = network.get_layer("%s/%s" % (args.rec_layer, layers[0]))
        from HDFDataset import SimpleHDFWriter
        hdf_writer = SimpleHDFWriter(filename=args.output_file,
                                     dim=sub_layer.output.dim,
                                     ndim=sub_layer.output.ndim)

    extra_fetches = {
        "output":
        network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
        "output_len":
        network.layers[
            args.rec_layer].output.get_sequence_lengths(),  # decoder length
        "encoder_len":
        network.layers[
            args.enc_layer].output.get_sequence_lengths(),  # encoder length
        "seq_idx":
        network.get_extern_data("seq_idx"),
        "seq_tag":
        network.get_extern_data("seq_tag"),
        "target_data":
        network.get_extern_data(network.extern_data.default_input),
        "target_classes":
        network.get_extern_data(network.extern_data.default_target),
    }

    if args.instead_save_encoder_decoder:
        extra_fetches["encoder"] = network.layers["encoder"]
        extra_fetches["decoder"] = network.layers["output"].get_sub_layer(
            "decoder")
    else:
        for l in layers:
            sub_layer = rnn.engine.network.get_layer("%s/%s" %
                                                     (args.rec_layer, l))
            extra_fetches[
                "rec_%s" %
                l] = sub_layer.output.get_placeholder_as_batch_major()
            if args.do_search:
                o_layer = rnn.engine.network.get_layer("output")
                extra_fetches["beam_scores_" +
                              l] = o_layer.get_search_choices().beam_scores

    # Always use epoch 1, so that we get the same seqs across runs.
    dataset.init_seq_order(epoch=1, seq_list=args.seq_list or None)
    dataset_batch = dataset.generate_batches(
        recurrent_net=network.recurrent,
        batch_size=args.batch_size,
        max_seqs=rnn.engine.max_seqs,
        max_seq_length=sys.maxsize,
        min_seq_length=min_seq_length,
        max_total_num_seqs=args.num_seqs,
        used_data_keys=network.used_data_keys)

    stats = {l: Stats() for l in layers}

    # (**dict[str, numpy.ndarray|str|list[numpy.ndarray|str]]) -> None
    def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output,
                       output_len, encoder_len, **kwargs):
        """
    :param list[int] seq_idx: len is n_batch
    :param list[str] seq_tag: len is n_batch
    :param numpy.ndarray target_data: extern data default input (e.g. "data"), shape e.g. (B,enc-T,...)
    :param numpy.ndarray target_classes: extern data default target (e.g. "classes"), shape e.g. (B,dec-T,...)
    :param numpy.ndarray output: rec layer output, shape e.g. (B,dec-T,...)
    :param numpy.ndarray output_len: rec layer seq len, i.e. decoder length, shape (B,)
    :param numpy.ndarray encoder_len: encoder seq len, shape (B,)
    :param kwargs: contains "rec_%s" % l for l in layers, the sub layers (e.g att weights) we are interested in
    """
        n_batch = len(seq_idx)

        for i in range(n_batch):
            if not args.instead_save_encoder_decoder:
                for l in layers:
                    att_weights = kwargs["rec_%s" % l][i]
                    stats[l].collect(att_weights.flatten())
        if args.output_format == "npy":
            data = {}
            if args.do_search:
                assert not args.instead_save_encoder_decoder, "Not implemented"

                # find axis with correct beam size
                axis_beam_size = n_batch * args.beam_size
                corr_axis = None
                num_axes = len(kwargs["rec_%s" % layers[0]].shape)
                for a in range(len(kwargs["rec_%s" % layers[0]].shape)):
                    if kwargs["rec_%s" % layers[0]].shape[a] == axis_beam_size:
                        corr_axis = a
                        break
                assert corr_axis is not None, "Att Weights Decoding: correct axis not found! maybe check beam size."

                # set dimensions correctly
                for l, l_raw in zip([("rec_%s" % l) for l in layers], layers):
                    swap = list(range(num_axes))
                    del swap[corr_axis]
                    swap.insert(0, corr_axis)
                    kwargs[l] = np.transpose(kwargs[l], swap)

                for i in range(n_batch):
                    # The first entry within each beam is the hypothesis with the highest score.
                    i_beam = args.beam_size * i
                    data[i] = {
                        'tag': seq_tag[i],
                        'data': target_data[i],
                        'classes': target_classes[i],
                        'output': output[i_beam],
                        'output_len': output_len[i_beam],
                        'encoder_len': encoder_len[i],
                    }

                    #if args.hmm_fac_fo is False:
                    for l, l_raw in zip([("rec_%s" % l) for l in layers],
                                        layers):
                        assert l in kwargs
                        out = kwargs[l][i_beam]
                        # Multi-head attention in search mode:
                        # out is [I, H, 1, J] in the new version, [I, J, H, 1] in the old version.
                        if out.ndim == 4:
                            out = np.transpose(
                                out,
                                axes=(0, 3, 1, 2))  # -> [I, J, H, 1] (new-version layout)
                            out = np.squeeze(out, axis=-1)  # -> [I, J, H]
                        else:
                            out = np.squeeze(out, axis=1)
                        assert out.shape[0] >= output_len[
                            i_beam] and out.shape[1] >= encoder_len[i]
                        data[i][l] = out[:output_len[i_beam], :encoder_len[i]]
                    fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (
                        model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
                    np.save(fname, data)
            else:
                for i in range(n_batch):
                    data[i] = {
                        'tag': seq_tag[i],
                        'data': target_data[i],
                        'classes': target_classes[i],
                        'output': output[i],
                        'output_len': output_len[i],
                        'encoder_len': encoder_len[i],
                    }
                    #if args.hmm_fac_fo is False:

                    if args.instead_save_encoder_decoder:
                        out = kwargs["encoder"][i]
                        data[i]["encoder"] = out[:encoder_len[i]]
                        out_2 = kwargs["decoder"][i]
                        data[i]["decoder"] = out_2[:output_len[i]]
                    else:
                        for l in [("rec_%s" % l) for l in layers]:
                            assert l in kwargs
                            out = kwargs[l][i]
                            if len(out.shape) == 3 and min(out.shape) > 1:
                                # Multi-head attention
                                out = np.transpose(
                                    out,
                                    axes=(1, 2, 0))  # (I, J, H) new version
                                #out = np.transpose(out, axes=(2, 0, 1))  # [I, J, H] old version
                            else:
                                # RNN
                                out = np.squeeze(out, axis=1)
                            assert out.ndim >= 2
                            assert out.shape[0] >= output_len[i] and out.shape[
                                1] >= encoder_len[i]
                            data[i][l] = out[:output_len[i], :encoder_len[i]]
                    fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (
                        model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
                    np.save(fname, data)
        elif args.output_format == "png":
            for i in range(n_batch):
                for l in layers:
                    extra_postfix = ""
                    if args.dropout is not None:
                        extra_postfix += "_dropout%.2f" % args.dropout
                    elif args.train_flag:
                        extra_postfix += "_train"
                    fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
                        model_name, rnn.engine.epoch, seq_idx[i], l,
                        extra_postfix)
                    att_weights = kwargs["rec_%s" % l][i]
                    att_weights = att_weights.squeeze(axis=2)  # (out,enc)
                    assert att_weights.shape[0] >= output_len[
                        i] and att_weights.shape[1] >= encoder_len[i]
                    att_weights = att_weights[:output_len[i], :encoder_len[i]]
                    print("Seq %i, %s: Dump att weights with shape %r to: %s" %
                          (seq_idx[i], seq_tag[i], att_weights.shape, fname))
                    plt.matshow(att_weights)
                    title = seq_tag[i]
                    if dataset.can_serialize_data(
                            network.extern_data.default_target):
                        title += "\n" + dataset.serialize_data(
                            network.extern_data.default_target,
                            target_classes[i][:output_len[i]])
                        ax = plt.gca()
                        tick_labels = [
                            dataset.serialize_data(
                                network.extern_data.default_target,
                                np.array([x], dtype=target_classes[i].dtype))
                            for x in target_classes[i][:output_len[i]]
                        ]
                        ax.set_yticklabels([''] + tick_labels, fontsize=8)
                        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
                    plt.title(title)
                    plt.savefig(fname)
                    plt.close()
        elif args.output_format == "hdf":
            assert len(layers) == 1
            att_weights = kwargs["rec_%s" % layers[0]]
            hdf_writer.insert_batch(inputs=att_weights,
                                    seq_len={
                                        0: output_len,
                                        1: encoder_len
                                    },
                                    seq_tag=seq_tag)
        else:
            raise Exception("output format %r" % args.output_format)

    runner = Runner(engine=rnn.engine,
                    dataset=dataset,
                    batches=dataset_batch,
                    train=False,
                    train_flag=bool(args.dropout) or args.train_flag,
                    extra_fetches=extra_fetches,
                    extra_fetches_callback=fetch_callback,
                    eval=False)
    runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)

    if not args.instead_save_encoder_decoder:
        for l in layers:
            stats[l].dump(stream_prefix="Layer %r " % l)
    if not runner.finalized:
        print("Some error occured, not finalized.")
        sys.exit(1)

    if hdf_writer:
        hdf_writer.close()
    rnn.finalize()
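Each .npy file written by fetch_callback above stores a pickled dict keyed by in-batch index. A minimal sketch for reading one back; the file name follows the pattern used above and the layer key is a placeholder:

import numpy as np

# np.save above stored a Python dict, so allow_pickle=True is required and
# .item() unwraps the 0-d object array back into the dict.
data = np.load("mymodel_ep080_data_0_31.npy", allow_pickle=True).item()
for idx, entry in sorted(data.items()):
    att = entry["rec_att_weights"]  # key is "rec_%s" % layer; the layer name here is a placeholder
    print(entry["tag"], "att shape:", att.shape,
          "dec len:", entry["output_len"], "enc len:", entry["encoder_len"])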
Example #25
def execute_main_task():
  """
  Executes the main task (via config ``task`` option).
  """
  from Util import hms_fraction
  start_time = time.time()
  task = config.value('task', 'train')
  if task == 'train':
    assert train_data.have_seqs(), "no train files specified, check train option: %s" % config.value('train', None)
    engine.init_train_from_config(config, train_data, dev_data, eval_data)
    engine.train()
  elif task == "eval":
    epoch = config.int("epoch", -1)
    load_epoch = config.int("load_epoch", -1)
    if epoch >= 0:
      assert (load_epoch < 0) or (load_epoch == epoch), "epoch and load_epoch have to match"
      engine.epoch = epoch
      config.set('load_epoch', engine.epoch)
    else:
      assert load_epoch >= 0, "specify epoch or load_epoch"
      engine.epoch = load_epoch
    engine.init_train_from_config(config, train_data, dev_data, eval_data)
    print("Evaluate epoch", engine.epoch, file=log.v4)
    engine.eval_model(
      output_file=config.value("eval_output_file", None),
      output_per_seq_file=config.value("eval_output_file_per_seq", None),
      loss_name=config.value("loss_name", None),
      output_per_seq_format=config.list("output_per_seq_format", ["score"]),
      output_per_seq_file_format=config.value("output_per_seq_file_format", "txt"))
  elif task in ['forward', 'hpx']:
    assert eval_data is not None, 'no eval data provided'
    combine_labels = config.value('combine_labels', '')
    engine.use_search_flag = config.bool("forward_use_search", False)
    if config.has("epoch"):
      config.set('load_epoch', config.int('epoch', 0))
    engine.init_network_from_config(config)
    output_file = config.value('output_file', 'dump-fwd-epoch-%i.hdf' % engine.epoch)
    engine.forward_to_hdf(
      data=eval_data, output_file=output_file, combine_labels=combine_labels,
      batch_size=config.int('forward_batch_size', 0))
  elif task == "search":
    engine.use_search_flag = True
    engine.init_network_from_config(config)
    if config.value("search_data", "eval") in ["train", "dev", "eval"]:
      data = {"train": train_data, "dev": dev_data, "eval": eval_data}[config.value("search_data", "eval")]
      assert data, "set search_data"
    else:
      data = init_dataset(config.opt_typed_value("search_data"))
    engine.search(
      data,
      do_eval=config.bool("search_do_eval", True),
      output_layer_names=config.typed_value("search_output_layer", "output"),
      output_file=config.value("search_output_file", ""),
      output_file_format=config.value("search_output_file_format", "txt"))
  elif task == 'compute_priors':
    assert train_data is not None, 'train data for priors should be provided'
    engine.init_network_from_config(config)
    engine.compute_priors(dataset=train_data, config=config)
  elif task == 'theano_graph':
    import theano.printing
    import theano.compile.io
    import theano.compile.function_module
    engine.start_epoch = 1
    engine.init_network_from_config(config)
    for task in config.list('theano_graph.task', ['train']):
      func = engine.devices[-1].get_compute_func(task)
      prefix = config.value("theano_graph.prefix", "current") + ".task"
      print("dumping to %s.* ..." % prefix, file=log.v1)
      theano.printing.debugprint(func, file=open("%s.optimized_func.txt" % prefix, "w"))
      assert isinstance(func.maker, theano.compile.function_module.FunctionMaker)
      for inp in func.maker.inputs:
        assert isinstance(inp, theano.compile.io.In)
        if inp.update:
          theano.printing.debugprint(
            inp.update, file=open("%s.unoptimized.var_%s_update.txt" % (prefix, inp.name), "w"))
      theano.printing.pydotprint(func, format='png', var_with_name_simple=True,
                                 outfile="%s.png" % prefix)
  elif task == 'analyze':  # anything based on the network + Device
    statistics = config.list('statistics', None)
    engine.init_network_from_config(config)
    engine.analyze(data=eval_data or dev_data, statistics=statistics)
  elif task == "analyze_data":  # anything just based on the data
    analyze_data(config)
  elif task == "classify":
    assert eval_data is not None, 'no eval data provided'
    assert config.has('label_file'), 'no output file provided'
    label_file = config.value('label_file', '')
    engine.init_network_from_config(config)
    engine.classify(engine.devices[0], eval_data, label_file)
  elif task == "hyper_param_tuning":
    import HyperParamTuning
    tuner = HyperParamTuning.Optimization(config=config, train_data=train_data)
    tuner.work()
  elif task == "cleanup_old_models":
    engine.cleanup_old_models(ask_for_confirmation=True)
  elif task == "daemon":
    engine.init_network_from_config(config)
    engine.daemon(config)
  elif task == "server":
    print("Server Initiating", file=log.v1)
    server.run()
  elif task == "search_server":
    engine.use_search_flag = True
    engine.init_network_from_config(config)
    engine.web_server(port=config.int("web_server_port", 12380))
  elif task.startswith("config:"):
    action = config.typed_dict[task[len("config:"):]]
    print("Task: %r" % action, file=log.v1)
    assert callable(action)
    action()
  elif task.startswith("optional-config:"):
    action = config.typed_dict.get(task[len("optional-config:"):], None)
    if action is None:
      print("No task found for %r, so just quitting." % task, file=log.v1)
    else:
      print("Task: %r" % action, file=log.v1)
      assert callable(action)
      action()
  elif task == "nop":
    print("Task: No-operation", file=log.v1)
  elif task == "nop_init_net_train":
    print("Task: No-operation, despite initializing the network (for training)", file=log.v1)
    engine.init_train_from_config(config, train_data, dev_data, eval_data)
  elif task == "initialize_model":
    engine.init_train_from_config(config, train_data, dev_data, eval_data)
    engine.save_model(config.value('model', 'dummy'))
  else:
    assert False, "unknown task: %s" % task

  print(("elapsed: %s" % hms_fraction(time.time() - start_time)), file=log.v3)
Example #26
def init(config_str, config_dataset, use_pretrain, epoch, verbosity):
  """
  :param str config_str: either filename to config-file, or dict for dataset
  :param str|None config_dataset:
  :param bool use_pretrain: might overwrite config options, or even the dataset
  :param int epoch:
  :param int verbosity:
  """
  rnn.init_better_exchook()
  rnn.init_thread_join_hack()
  dataset_opts = None
  config_filename = None
  if config_str.strip().startswith("{"):
    print("Using dataset %s." % config_str)
    dataset_opts = eval(config_str.strip())
  elif config_str.endswith(".hdf"):
    dataset_opts = {"class": "HDFDataset", "files": [config_str]}
    print("Using dataset %r." % dataset_opts)
    assert os.path.exists(config_str)
  else:
    config_filename = config_str
    print("Using config file %r." % config_filename)
    assert os.path.exists(config_filename)
  rnn.init_config(config_filename=config_filename, default_config={"cache_size": "0"})
  global config
  config = rnn.config
  config.set("log", None)
  config.set("log_verbosity", verbosity)
  rnn.init_log()
  print("Returnn %s starting up." % __file__, file=log.v2)
  rnn.returnn_greeting()
  rnn.init_faulthandler()
  rnn.init_config_json_network()
  Util.BackendEngine.select_engine(config=config)
  if not dataset_opts:
    if config_dataset:
      dataset_opts = "config:%s" % config_dataset
    else:
      dataset_opts = "config:train"
  if use_pretrain:
    from Pretrain import pretrain_from_config
    pretrain = pretrain_from_config(config)
    if pretrain:
      print("Using pretrain %s, epoch %i" % (pretrain, epoch), file=log.v2)
      net_dict = pretrain.get_network_json_for_epoch(epoch=epoch)
      if "#config" in net_dict:
        config_overwrites = net_dict["#config"]
        print("Pretrain overwrites these config options:", file=log.v2)
        assert isinstance(config_overwrites, dict)
        for key, value in sorted(config_overwrites.items()):
          assert isinstance(key, str)
          orig_value = config.typed_dict.get(key, None)
          if isinstance(orig_value, dict) and isinstance(value, dict):
            diff_str = "\n" + Util.dict_diff_str(orig_value, value)
          elif isinstance(value, dict):
            diff_str = "\n%r ->\n%s" % (orig_value, pformat(value))
          else:
            diff_str = " %r -> %r" % (orig_value, value)
          print("Config key %r for epoch %i:%s" % (key, epoch, diff_str), file=log.v2)
          config.set(key, value)
      else:
        print("No config overwrites for this epoch.", file=log.v2)
    else:
      print("No pretraining used.", file=log.v2)
  elif config.typed_dict.get("pretrain", None):
    print("Not using pretrain.", file=log.v2)
  dataset_default_opts = {}
  Dataset.kwargs_update_from_config(config, dataset_default_opts)
  print("Using dataset:", dataset_opts, file=log.v2)
  global dataset
  dataset = init_dataset(dataset_opts, default_kwargs=dataset_default_opts)
  assert isinstance(dataset, Dataset)
  dataset.init_seq_order(epoch=epoch)
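The pretrain handling above reads per-epoch config overwrites from the "#config" key of the network dict returned by pretrain.get_network_json_for_epoch(). A minimal sketch of what such a dict might look like; the layer entries and values are illustrative only:

# Sketch of a pretrain network dict carrying per-epoch config overwrites,
# as consumed by the "#config" handling above (values are illustrative).
net_dict = {
    "#config": {"batch_size": 2000, "learning_rate": 0.0005},
    "lstm0": {"class": "rec", "unit": "lstm", "n_out": 512},
    "output": {"class": "softmax", "loss": "ce", "from": ["lstm0"]},
}
for key, value in sorted(net_dict["#config"].items()):
    print("epoch overwrite: %r -> %r" % (key, value))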