示例#1
0
 def init_from_dataset(self, dataset):
   """
   Set the default keys and register one Data entry per data key of a dataset.

   ``self.default_target`` becomes "classes" if that is a target key, otherwise the
   first target key. ``self.default_input`` becomes "data" if that is a non-target
   data key, otherwise the first non-target data key. Either attribute is left
   untouched when there is no candidate key.
   Every data key is then stored in ``self.data`` as a Data with
   ``batch_dim_axis=0`` and ``time_dim_axis=1``; target keys are marked as not
   available for inference.

   :param Dataset.Dataset dataset:
   """
   # Snapshot the target keys once; the same list is reused for every check below.
   target_keys = list(dataset.get_target_list())
   if target_keys:
     self.default_target = "classes" if "classes" in target_keys else target_keys[0]
   data_keys = list(dataset.get_data_keys())
   input_keys = [key for key in data_keys if key not in target_keys]
   if input_keys:
     self.default_input = "data" if "data" in input_keys else input_keys[0]
   for key in data_keys:
     dim = dataset.get_data_dim(key)
     # A leading None is prepended to the dataset's own shape.
     # NOTE(review): presumably this is the variable-length time axis - confirm against Data.
     shape = [None] + list(dataset.get_data_shape(key))
     sparse = dataset.is_data_sparse(key)
     dtype = dataset.get_data_dtype(key)
     self.data[key] = Data(
       name=key, auto_create_placeholders=True, batch_dim_axis=0, time_dim_axis=1,
       shape=shape, dim=dim, sparse=sparse, dtype=dtype,
       available_for_inference=key not in target_keys)
示例#2
0
 def register_data_from_dict(self, data):
     """
     Create one Data per entry of ``data`` and store it in ``self.data`` under the same key.

     The key is passed to Data as ``name`` and ``auto_create_placeholders=True`` is
     always set; everything else comes from the per-key kwargs.

     :param dict[str,dict[str]] data: data key -> init kwargs for Data
     """
     for name, init_kwargs in data.items():
         entry = Data(name=name, auto_create_placeholders=True, **init_kwargs)
         self.data[name] = entry
示例#3
0
 def get_extern_data(self, key, mark_data_key_as_used=True):
     """
     Look up the extern Data for ``key``, optionally recording the key as used.

     The keys "seq_idx" and "seq_tag" are created on first request if
     ``self.extern_data.data`` does not contain them yet: both get an empty shape
     and auto-created placeholders, with dtype "int32" (non-sparse) and "string"
     respectively.

     :param str key:
     :param bool mark_data_key_as_used: if set, add ``key`` to ``self.used_data_keys``
     :rtype: Data
     """
     if mark_data_key_as_used:
         self.used_data_keys.add(key)
     creatable_on_demand = key in ("seq_idx", "seq_tag")
     if creatable_on_demand and key not in self.extern_data.data:
         if key == "seq_idx":
             created = Data(
                 name="seq_idx", shape=(), dtype="int32", sparse=False,
                 auto_create_placeholders=True)
         else:
             created = Data(
                 name="seq_tag", shape=(), dtype="string",
                 auto_create_placeholders=True)
         self.extern_data.data[key] = created
     return self.extern_data.get_data(key)
示例#4
0
 def init_from_config(self, config):
   """
   Fill ``self.data`` and ``self.default_target`` from a config.

   One Data is created per entry returned by
   ``LayerNetworkDescription.tf_extern_data_types_from_config(config)``.
   The default target is the config value "target", falling back to "classes".

   :param Config.Config config:
   """
   from NetworkDescription import LayerNetworkDescription
   extern_data_types = LayerNetworkDescription.tf_extern_data_types_from_config(config)
   # Layout note: Returnn with Theano usually uses (time,batch,feature), whereas the
   # TensorFlow default is (batch,time,feature). The latter is used here, i.e.
   # batch_dim_axis=0, time_dim_axis=1. See TFEngine.DataProvider._get_next_batch().
   for name, data_kwargs in extern_data_types.items():
     self.data[name] = Data(name=name, auto_create_placeholders=True, **data_kwargs)
   self.default_target = config.value('target', 'classes')