Example #1
  def __init__(self,
               datasets,
               data_map, data_dims,
               data_dtypes=None,
               window=1, **kwargs):
    """
    :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs. including keyword 'class' and maybe 'files'
    :param dict[(str,str),str] data_map: (dataset-key, dataset-data-key) -> self-data-key.
      Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
    :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
    :param dict[str,str] data_dtypes: self-data-key -> dtype. automatic if not specified
    """
    assert window == 1  # not implemented
    super(CombinedDataset, self).__init__(**kwargs)
    assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

    self.rnd = Random(self.epoch)
#    self.data_map = data_map
    self.dataset_keys = set(datasets.keys()); ":type: set[str]"
    self.dataset_idxs = dict(enumerate(sorted(self.dataset_keys)))  # idx -> dataset-key
    self.data_keys = set(data_map.values()); ":type: set[str]"
    assert "data" in self.data_keys
    self.target_list = sorted(self.data_keys - {"data"})

    # Build target lookup table
    target_lookup_table = {}
    for dataset_key in self.dataset_keys:
      target_lookup_table[dataset_key] = {
        datamap_maps: datamap_keys[1]
        for datamap_keys, datamap_maps in data_map.items()
        if datamap_keys[0] == dataset_key}
      for key in self.data_keys:
        target_lookup_table[dataset_key].setdefault(key,None)

    self.target_lookup_table = target_lookup_table


    data_dims = convert_data_dims(data_dims)
    self.data_dims = data_dims
    assert "data" in data_dims
    for key in self.target_list:
      assert key in data_dims
    self.num_inputs = data_dims["data"][0]
    self.num_outputs = data_dims

    self.data_dtypes = {data_key: _select_dtype(data_key, data_dims, data_dtypes) for data_key in self.data_keys}

    # Will only init the needed datasets.
    self.datasets = {key: init_dataset(datasets[key]) for key in self.dataset_keys}

    try:
      self._num_seqs = sum([self.datasets[k].num_seqs for k in sorted(self.datasets.keys())])
      self.know_num_seqs_beforehand = True
#      print "Dont need to set estimations for num_seqs. Currently is {s}".format(s=[ds.num_seqs for ds in self.datasets.values()])
    except Exception:
      self._estimated_num_seqs = sum([self.datasets[k].estimated_num_seqs for k in sorted(self.datasets.keys())])
      self.estimated_num_seq_per_subset = [self.datasets[k].estimated_num_seqs for k in sorted(self.datasets.keys())]
#      TODO this estimate seems broken on a small test corpus; needs further testing
#      print "Need to set estimations for num_seqs. Currently is {s}".format(s=[ds.estimated_num_seqs for ds in self.datasets.values()])
      self.know_num_seqs_beforehand = False
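For orientation, the keyword arguments for this CombinedDataset variant could be assembled roughly as in the sketch below, inferred only from the docstring above. The HDFDataset class name, the dataset keys and the file names are placeholders, not taken from the examples.

# Hypothetical kwargs for the CombinedDataset variant above (class name, keys and paths are placeholders).
combined_kwargs = {
  "datasets": {  # dataset-key -> dataset-kwargs, including 'class' and maybe 'files'
    "corpus_a": {"class": "HDFDataset", "files": ["corpus_a.hdf"]},
    "corpus_b": {"class": "HDFDataset", "files": ["corpus_b.hdf"]},
  },
  # (dataset-key, dataset-data-key) -> self-data-key; something must map to "data"
  "data_map": {
    ("corpus_a", "data"): "data",
    ("corpus_a", "classes"): "classes",
    ("corpus_b", "data"): "data",
    ("corpus_b", "classes"): "classes",
  },
  # self-data-key -> (dim, ndim); ndim == 1 means sparse representation
  "data_dims": {"data": (40, 2), "classes": (5000, 1)},
}
# dataset = CombinedDataset(**combined_kwargs)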
Example #2
  def __init__(self,
               seq_list_file, seq_lens_file,
               datasets,
               data_map, data_dims,
               data_dtypes=None,
               window=1, **kwargs):
    """
    :param str seq_list_file: filename. line-separated
    :param str seq_lens_file: filename. json. dict[str,dict[str,int]], seq-tag -> data-key -> len
    :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs. including keyword 'class' and maybe 'files'
    :param dict[str,(str,str)] data_map: self-data-key -> (dataset-key, dataset-data-key).
      Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
    :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
    :param dict[str,str] data_dtypes: self-data-key -> dtype. automatic if not specified
    """
    assert window == 1  # not implemented
    super(MetaDataset, self).__init__(**kwargs)
    assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

    self.seq_list_original = open(seq_list_file).read().splitlines()
    self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original)}
    self._num_seqs = len(self.seq_list_original)

    self.data_map = data_map
    self.dataset_keys = set([m[0] for m in self.data_map.values()]); ":type: set[str]"
    self.data_keys = set(self.data_map.keys()); ":type: set[str]"
    assert "data" in self.data_keys
    self.target_list = sorted(self.data_keys - {"data"})

    data_dims = convert_data_dims(data_dims)
    self.data_dims = data_dims
    assert "data" in data_dims
    for key in self.target_list:
      assert key in data_dims
    self.num_inputs = data_dims["data"][0]
    self.num_outputs = data_dims

    self.data_dtypes = {data_key: _select_dtype(data_key, data_dims, data_dtypes) for data_key in self.data_keys}

    if seq_lens_file:
      seq_lens = load_json(filename=seq_lens_file)
      assert isinstance(seq_lens, dict)
      # dict[str,NumbersDict], seq-tag -> data-key -> len
      self._seq_lens = {tag: NumbersDict(l) for (tag, l) in seq_lens.items()}
    else:
      self._seq_lens = None

    if self._seq_lens:
      self._num_timesteps = sum([self._seq_lens[s] for s in self.seq_list_original])
    else:
      self._num_timesteps = None

    # Will only init the needed datasets.
    self.datasets = {key: init_dataset(datasets[key]) for key in self.dataset_keys}
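Analogously, a call to this MetaDataset variant might look roughly like the following sketch; note that here data_map goes the other way round (self-data-key -> (dataset-key, dataset-data-key)). File names and the dataset class are again placeholders, not taken from the examples.

# Hypothetical MetaDataset kwargs (file names and the HDFDataset class are placeholders).
meta_kwargs = {
  "seq_list_file": "seq-list.txt",   # line-separated sequence tags
  "seq_lens_file": "seq-lens.json",  # optional: seq-tag -> data-key -> len
  "datasets": {
    "features": {"class": "HDFDataset", "files": ["features.hdf"]},
    "targets": {"class": "HDFDataset", "files": ["targets.hdf"]},
  },
  # self-data-key -> (dataset-key, dataset-data-key); "data" is required
  "data_map": {
    "data": ("features", "data"),
    "classes": ("targets", "classes"),
  },
  "data_dims": {"data": (40, 2), "classes": (5000, 1)},
}
# dataset = MetaDataset(**meta_kwargs)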
Example #4
 def num_inputs_outputs_from_config(cls, config):
   """
   :type config: Config.Config
   :returns (num_inputs, num_outputs),
      where num_inputs is like num_outputs["data"][0],
      and num_outputs is a dict of data_key -> (dim, ndim),
        where data_key is e.g. "classes" or "data",
        dim is the feature dimension or the number of classes,
        and ndim is the ndim counted without batch-dim,
        i.e. ndim=1 means usually sparse data and ndim=2 means dense data.
   :rtype: (int,dict[str,(int,int)])
   """
   num_inputs = config.int('num_inputs', 0)
   target = config.value('target', 'classes')
   if config.is_typed('num_outputs'):
     num_outputs = config.typed_value('num_outputs')
     if not isinstance(num_outputs, dict):
       num_outputs = {target: num_outputs}
     num_outputs = num_outputs.copy()
     from Dataset import convert_data_dims
     from Util import BackendEngine
     num_outputs = convert_data_dims(num_outputs, leave_dict_as_is=BackendEngine.is_tensorflow_selected())
     if "data" in num_outputs:
       num_inputs = num_outputs["data"][0]
   elif config.has('num_outputs'):
     num_outputs = {target: [config.int('num_outputs', 0), 1]}
   else:
     num_outputs = None
   dataset = None
   if config.list('train') and ":" not in config.value('train', ''):
     dataset = config.list('train')[0]
   if not config.is_typed('num_outputs') and dataset:
     try:
       _num_inputs = hdf5_dimension(dataset, 'inputCodeSize') * config.int('window', 1)
     except Exception:
       _num_inputs = hdf5_dimension(dataset, 'inputPattSize') * config.int('window', 1)
     try:
       _num_outputs = {target: [hdf5_dimension(dataset, 'numLabels'), 1]}
     except Exception:
       _num_outputs = hdf5_group(dataset, 'targets/size')
       for k in _num_outputs:
         _num_outputs[k] = [_num_outputs[k], len(hdf5_shape(dataset, 'targets/data/' + k))]
     if num_inputs: assert num_inputs == _num_inputs
     if num_outputs: assert num_outputs == _num_outputs
     num_inputs = _num_inputs
     num_outputs = _num_outputs
   if not num_inputs and not num_outputs and config.has("load"):
     from Network import LayerNetwork
     import h5py
     model = h5py.File(config.value("load", ""), "r")
     num_inputs, num_outputs = LayerNetwork._n_in_out_from_hdf_model(model)
   assert num_inputs and num_outputs, "provide num_inputs/num_outputs directly or via train"
   return num_inputs, num_outputs
Example #5
    def __init__(self,
                 data,
                 target_list=None,
                 output_dim=None,
                 input_dim=None,
                 **kwargs):
        """
    :type data: list[dict[str,numpy.ndarray]]
    """
        assert len(data) > 0
        self.data = data
        num_seqs = len(data)
        first_data = data[0]
        assert "data" in first_data  # input
        if target_list is None:
            target_list = []
            for target in first_data.keys():
                if target == "data": continue
                target_list.append(target)
        else:
            for target in target_list:
                assert target in first_data
        self.target_list = target_list

        if output_dim is None:
            output_dim = {}
        output_dim = convert_data_dims(output_dim)

        first_data_input = first_data["data"]
        assert len(first_data_input.shape) <= 2  # (time[,dim])
        if input_dim is None:
            if "data" in output_dim:
                input_dim = output_dim["data"][0]
            else:
                input_dim = first_data_input.shape[1]

        for target in target_list:
            first_data_output = first_data[target]
            assert len(first_data_output.shape) <= 2  # (time[,dim])
            if target in output_dim:
                assert output_dim[target][1] == len(first_data_output.shape)
                if len(first_data_output.shape) >= 2:
                    assert output_dim[target][0] == first_data_output.shape[1]
            else:
                assert len(
                    first_data_output.shape
                ) == 2, "We expect not sparse. Or specify it explicitly in output_dim."
                output_dim[target] = [first_data_output.shape[1], 2]

        super(StaticDataset, self).__init__(input_dim=input_dim,
                                            output_dim=output_dim,
                                            num_seqs=num_seqs,
                                            **kwargs)
Example #6
  def __init__(self, input_dim, output_dim, window=1, num_seqs=float("inf"), fixed_random_seed=None, **kwargs):
    assert window == 1
    super(GeneratingDataset, self).__init__(window, **kwargs)
    assert self.shuffle_frames_of_nseqs == 0

    self.num_inputs = input_dim
    output_dim = convert_data_dims(output_dim)
    if "data" not in output_dim:
      output_dim["data"] = [input_dim, 2]  # not sparse
    self.num_outputs = output_dim
    self.expected_load_seq_start = 0
    self._num_seqs = num_seqs
    self.random = numpy.random.RandomState(1)
    self.fixed_random_seed = fixed_random_seed  # useful when used as eval dataset
Example #7
  def __init__(self, input_dim, output_dim, window=1, num_seqs=float("inf"), fixed_random_seed=None, **kwargs):
    assert window == 1
    super(GeneratingDataset, self).__init__(window=window, **kwargs)
    assert self.shuffle_frames_of_nseqs == 0

    self.num_inputs = input_dim
    output_dim = convert_data_dims(output_dim)
    if "data" not in output_dim:
      output_dim["data"] = [input_dim, 2]  # not sparse
    self.num_outputs = output_dim
    self.expected_load_seq_start = 0
    self._num_seqs = num_seqs
    self.random = numpy.random.RandomState(1)
    self.fixed_random_seed = fixed_random_seed  # useful when used as eval dataset
Example #8
  def __init__(self, data, target_list=None, output_dim=None, input_dim=None, **kwargs):
    """
    :type data: list[dict[str,numpy.ndarray]]
    """
    assert len(data) > 0
    self.data = data
    num_seqs = len(data)
    first_data = data[0]
    assert "data" in first_data  # input
    if target_list is None:
      target_list = []
      for target in first_data.keys():
        if target == "data": continue
        target_list.append(target)
    else:
      for target in target_list:
        assert target in first_data
    self.target_list = target_list

    if output_dim is None:
      output_dim = {}
    output_dim = convert_data_dims(output_dim)

    first_data_input = first_data["data"]
    assert len(first_data_input.shape) <= 2  # (time[,dim])
    if input_dim is None:
      if "data" in output_dim:
        input_dim = output_dim["data"][0]
      else:
        input_dim = first_data_input.shape[1]

    for target in target_list:
      first_data_output = first_data[target]
      assert len(first_data_output.shape) <= 2  # (time[,dim])
      if target in output_dim:
        assert output_dim[target][1] == len(first_data_output.shape)
        if len(first_data_output.shape) >= 2:
          assert output_dim[target][0] == first_data_output.shape[1]
      else:
        assert len(first_data_output.shape) == 2, "We expect not sparse. Or specify it explicitly in output_dim."
        output_dim[target] = [first_data_output.shape[1], 2]

    super(StaticDataset, self).__init__(input_dim=input_dim, output_dim=output_dim, num_seqs=num_seqs, **kwargs)
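To make the expected input format concrete, a minimal construction of this StaticDataset variant might look as follows. The array shapes and the number of classes are arbitrary; the sparse 'classes' dim has to be given via output_dim because it cannot be read off a 1-D array (see the assert above).

import numpy

# Two sequences; each dict must contain "data", all other keys become targets.
data = [
  {"data": numpy.zeros((7, 3), dtype="float32"),   # dense input: (time, feature-dim)
   "classes": numpy.zeros((7,), dtype="int32")},   # sparse target: (time,)
  {"data": numpy.zeros((5, 3), dtype="float32"),
   "classes": numpy.zeros((5,), dtype="int32")},
]
# Sparse targets need an explicit (dim, ndim) entry; 10 classes is a made-up value.
# dataset = StaticDataset(data=data, output_dim={"classes": (10, 1)})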
Example #9
 def num_inputs_outputs_from_config(cls, config):
   """
   :type config: Config.Config
   :rtype: (int,dict[str,(int,int)])
   """
   num_inputs = config.int('num_inputs', 0)
   target = config.value('target', 'classes')
   if config.is_typed('num_outputs'):
     num_outputs = config.typed_value('num_outputs')
     if not isinstance(num_outputs, dict):
       num_outputs = {target: num_outputs}
     num_outputs = num_outputs.copy()
     from Dataset import convert_data_dims
     num_outputs = convert_data_dims(num_outputs)
     if "data" in num_outputs:
       num_inputs = num_outputs["data"][0]
   elif config.has('num_outputs'):
     num_outputs = {target: [config.int('num_outputs', 0), 1]}
   else:
     num_outputs = None
   if not config.is_typed('num_outputs') and config.list('train') and ":" not in config.value('train', ''):
     try:
       _num_inputs = hdf5_dimension(config.list('train')[0], 'inputCodeSize') * config.int('window', 1)
     except Exception:
       _num_inputs = hdf5_dimension(config.list('train')[0], 'inputPattSize') * config.int('window', 1)
     try:
       _num_outputs = {target: [hdf5_dimension(config.list('train')[0], 'numLabels'), 1]}
     except Exception:
       _num_outputs = hdf5_group(config.list('train')[0], 'targets/size')
       for k in _num_outputs:
         _num_outputs[k] = [_num_outputs[k], len(hdf5_shape(config.list('train')[0], 'targets/data/' + k))]
     if num_inputs: assert num_inputs == _num_inputs
     if num_outputs: assert num_outputs == _num_outputs
     num_inputs = _num_inputs
     num_outputs = _num_outputs
   assert num_inputs and num_outputs, "provide num_inputs/num_outputs directly or via train"
   loss = cls.loss_from_config(config)
   #if loss in ('ctc', 'ce_ctc') or config.bool('add_blank', False):
   #  for k in num_outputs:
   #    num_outputs[k][0] += 1  # add blank
   return num_inputs, num_outputs
Example #10
  def __init__(self,
               datasets,
               data_map, data_dims,
               data_dtypes=None,
               window=1, **kwargs):
    """
    :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs. including keyword 'class' and maybe 'files'
    :param dict[str,(str,str)] data_map: self-data-key -> (dataset-key, dataset-data-key).
      Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
    :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
    :param dict[str,str] data_dtypes: self-data-key -> dtype. automatic if not specified
    """
    assert window == 1  # not implemented
    super(CombinedDataset, self).__init__(**kwargs)
    assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

    self.data_map = data_map
    self.dataset_keys = set([m[0] for m in self.data_map.values()]); ":type: set[str]"
    self.dataset_idxs = dict(enumerate(sorted(self.dataset_keys)))  # idx -> dataset-key
    self.data_keys = set(self.data_map.keys()); ":type: set[str]"
    assert "data" in self.data_keys
    self.target_list = sorted(self.data_keys - {"data"})

    data_dims = convert_data_dims(data_dims)
    self.data_dims = data_dims
    assert "data" in data_dims
    for key in self.target_list:
      assert key in data_dims
    self.num_inputs = data_dims["data"][0]
    self.num_outputs = data_dims

    self.data_dtypes = {data_key: _select_dtype(data_key, data_dims, data_dtypes) for data_key in self.data_keys}

    # Will only init the needed datasets.
    self.datasets = {key: init_dataset(datasets[key]) for key in self.dataset_keys}

    self._num_seqs = sum([ds.num_seqs for ds in self.datasets.values()])
Example #11
    def __init__(self,
                 input_dim,
                 output_dim,
                 num_seqs=float("inf"),
                 fixed_random_seed=None,
                 **kwargs):
        """
    :param int input_dim:
    :param int|dict[str,int|(int,int)|dict] output_dim:
    :param int|float num_seqs:
    :param int fixed_random_seed:
    """
        super(GeneratingDataset, self).__init__(**kwargs)
        assert self.shuffle_frames_of_nseqs == 0

        self.num_inputs = input_dim
        output_dim = convert_data_dims(output_dim)
        if "data" not in output_dim:
            output_dim["data"] = [input_dim, 2]  # not sparse
        self.num_outputs = output_dim
        self.expected_load_seq_start = 0
        self._num_seqs = num_seqs
        self.random = numpy.random.RandomState(1)
        self.fixed_random_seed = fixed_random_seed  # useful when used as eval dataset
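All of these constructors funnel output_dim through convert_data_dims. That function is not part of the examples, but the surrounding code consistently reads the result as data-key -> [dim, ndim], where ndim is counted without the batch dim (1 usually meaning sparse labels, 2 dense frames). A hedged illustration of the documented input form, with all numbers made up and the subclass name purely hypothetical:

# Placeholder dims; ndim follows the convention from the docstrings above (1 = sparse, 2 = dense).
output_dim = {
  "classes": (5000, 1),  # sparse label indices per frame
  "data": (40, 2),       # dense 40-dim feature frames; if omitted, the code inserts [input_dim, 2]
}
# e.g. dataset = SomeGeneratingSubclass(input_dim=40, output_dim=output_dim)  # hypothetical subclass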
Example #12
 def num_inputs_outputs_from_config(cls, config):
   """
   :type config: Config.Config
   :returns (num_inputs, num_outputs),
      where num_inputs is like num_outputs["data"][0],
      and num_outputs is a dict of data_key -> (dim, ndim),
        where data_key is e.g. "classes" or "data",
        dim is the feature dimension or the number of classes,
        and ndim is the ndim counted without batch-dim,
        i.e. ndim=1 means usually sparse data and ndim=2 means dense data.
   :rtype: (int,dict[str,(int,int)])
   """
   from Util import BackendEngine
   num_inputs = config.int('num_inputs', 0)
   target = config.value('target', 'classes')
   if config.is_typed('num_outputs'):
     num_outputs = config.typed_value('num_outputs')
     if not isinstance(num_outputs, dict):
       num_outputs = {target: num_outputs}
     num_outputs = num_outputs.copy()
     from Dataset import convert_data_dims
     num_outputs = convert_data_dims(num_outputs, leave_dict_as_is=BackendEngine.is_tensorflow_selected())
     if "data" in num_outputs:
       num_inputs = num_outputs["data"]
       if isinstance(num_inputs, (list, tuple)):
         num_inputs = num_inputs[0]
       elif isinstance(num_inputs, dict):
         if "dim" in num_inputs:
           num_inputs = num_inputs["dim"]
         else:
           num_inputs = num_inputs["shape"][-1]
       else:
         raise TypeError("data key %r" % num_inputs)
   elif config.has('num_outputs'):
     num_outputs = {target: [config.int('num_outputs', 0), 1]}
   else:
     num_outputs = None
   dataset = None
   if config.list('train') and ":" not in config.value('train', ''):
     dataset = config.list('train')[0]
   if not config.is_typed('num_outputs') and dataset:
     # noinspection PyBroadException
     try:
       _num_inputs = hdf5_dimension(dataset, 'inputCodeSize') * config.int('window', 1)
     except Exception:
       _num_inputs = hdf5_dimension(dataset, 'inputPattSize') * config.int('window', 1)
     # noinspection PyBroadException
     try:
       _num_outputs = {target: [hdf5_dimension(dataset, 'numLabels'), 1]}
     except Exception:
       _num_outputs = hdf5_group(dataset, 'targets/size')
       for k in _num_outputs:
         _num_outputs[k] = [_num_outputs[k], len(hdf5_shape(dataset, 'targets/data/' + k))]
     if num_inputs:
       assert num_inputs == _num_inputs
     if num_outputs:
       assert num_outputs == _num_outputs
     num_inputs = _num_inputs
     num_outputs = _num_outputs
   if not num_inputs and not num_outputs and config.has("load") and BackendEngine.is_theano_selected():
     from Network import LayerNetwork
     import h5py
     model = h5py.File(config.value("load", ""), "r")
     # noinspection PyProtectedMember
     num_inputs, num_outputs = LayerNetwork._n_in_out_from_hdf_model(model)
   assert num_inputs and num_outputs, "provide num_inputs/num_outputs directly or via train"
   return num_inputs, num_outputs
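As a rough picture of what num_inputs_outputs_from_config consumes and returns: all values below are placeholders, and the exact normalized container types depend on convert_data_dims, which is not shown here.

# Hypothetical config entries:
#   num_inputs  = 40
#   target      = "classes"
#   num_outputs = {"classes": (5000, 1)}  # dict of data_key -> (dim, ndim); a bare int gets wrapped under 'target'
# Expected result, per the docstring above:
#   num_inputs, num_outputs = cls.num_inputs_outputs_from_config(config)
#   # num_inputs == 40, num_outputs maps "classes" to (5000, 1): 5000 classes, ndim 1 (sparse)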
Example #13
    def __init__(self,
                 datasets,
                 data_map,
                 data_dims,
                 seq_list_file,
                 seq_lens_file=None,
                 data_dtypes=None,
                 window=1,
                 **kwargs):
        """
    :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs. including keyword 'class' and maybe 'files'
    :param dict[str,(str,str)] data_map: self-data-key -> (dataset-key, dataset-data-key).
      Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
    :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
    :param str seq_list_file: filename. pickle. dict[str,list[str]], dataset-key -> list of sequence tags. If tag is the same for all datasets a line-separated plain text file can be used.
    :param str seq_lens_file: filename. json. dict[str,dict[str,int]], seq-tag -> data-key -> len. Use if getting sequence length from loading data is too costly.
    :param dict[str,str] data_dtypes: self-data-key -> dtype. automatic if not specified
    """
        assert window == 1  # not implemented
        super(MetaDataset, self).__init__(**kwargs)
        assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

        self.data_map = data_map
        self.dataset_keys = set([m[0] for m in self.data_map.values()])
        ":type: set[str]"
        self.data_keys = set(self.data_map.keys())
        ":type: set[str]"
        assert "data" in self.data_keys
        self.target_list = sorted(self.data_keys - {"data"})
        self.default_dataset_key = self.data_map["data"][0]

        if seq_list_file.endswith(".pkl"):
            import pickle
            seq_list = pickle.load(open(seq_list_file, 'rb'))
        else:
            seq_list = open(seq_list_file).read().splitlines()
        assert isinstance(seq_list, (list, dict))
        if isinstance(seq_list, list):
            seq_list = {key: seq_list for key in self.dataset_keys}
        self.seq_list_original = seq_list  # type: dict[str,list[str]]  # dataset key -> seq list
        self._num_seqs = len(self.seq_list_original[self.default_dataset_key])
        for key in self.dataset_keys:
            assert len(self.seq_list_original[key]) == self._num_seqs
        self.tag_idx = {
            tag: idx
            for (idx, tag) in enumerate(self.seq_list_original[
                self.default_dataset_key])
        }

        data_dims = convert_data_dims(data_dims)
        self.data_dims = data_dims
        assert "data" in data_dims
        for key in self.target_list:
            assert key in data_dims
        self.num_inputs = data_dims["data"][0]
        self.num_outputs = data_dims

        self.data_dtypes = {
            data_key: _select_dtype(data_key, data_dims, data_dtypes)
            for data_key in self.data_keys
        }

        if seq_lens_file:
            seq_lens = load_json(filename=seq_lens_file)
            assert isinstance(seq_lens, dict)
            # dict[str,NumbersDict], seq-tag -> data-key -> len
            self._seq_lens = {
                tag: NumbersDict(l)
                for (tag, l) in seq_lens.items()
            }
        else:
            self._seq_lens = None

        if self._seq_lens:
            self._num_timesteps = sum([
                self._seq_lens[s]
                for s in self.seq_list_original[self.default_dataset_key]
            ])
        else:
            self._num_timesteps = None

        # Will only init the needed datasets.
        self.datasets = {
            key:
            init_dataset(datasets[key],
                         extra_kwargs={"name": "%s_%s" % (self.name, key)})
            for key in self.dataset_keys
        }
        for data_key in self.data_keys:
            dataset_key, dataset_data_key = self.data_map[data_key]
            dataset = self.datasets[dataset_key]
            if dataset_data_key in dataset.labels:
                self.labels[data_key] = dataset.labels[dataset_data_key]
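Unlike the earlier variants, this MetaDataset accepts seq_list_file as a pickled dict mapping dataset-key -> list of sequence tags (a plain line-separated text file still works when the tags agree across datasets). Below is a sketch of writing such a file; the tag names are invented for illustration, and the lists are assumed to correspond index by index.

import pickle

# Hypothetical per-dataset sequence-tag lists; the i-th entries are assumed to describe the same sequence.
seq_list = {
  "features": ["rec1/seq-0001", "rec1/seq-0002"],
  "targets": ["utt-0001", "utt-0002"],
}
with open("seq-list.pkl", "wb") as f:
  pickle.dump(seq_list, f)
# "seq-list.pkl" can then be passed as seq_list_file to the MetaDataset variant above.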