def init_from_dataset(self, dataset):
    """
    Set up the extern data entries and the default input/target keys from a dataset.

    * default target: "classes" if it is a target key, otherwise the first target key
      (left unchanged if the dataset has no target keys).
    * default input: "data" if it is a non-target data key, otherwise the first
      non-target data key (left unchanged if there is none).
    * every data key gets a :class:`Data` entry with ``batch_dim_axis=0`` and
      ``time_dim_axis=1``; target keys are marked as not available for inference.

    :param Dataset.Dataset dataset:
    """
    target_keys = list(dataset.get_target_list())
    # Query the target list only once and use a set for the membership tests below,
    # instead of calling dataset.get_target_list() again for every data key.
    target_key_set = set(target_keys)
    if target_keys:
        self.default_target = "classes" if "classes" in target_key_set else target_keys[0]

    data_keys = list(dataset.get_data_keys())
    input_keys = [key for key in data_keys if key not in target_key_set]
    if input_keys:
        self.default_input = "data" if "data" in input_keys else input_keys[0]

    for key in data_keys:
        dim = dataset.get_data_dim(key)
        # Prepend a None axis to the dataset's per-frame shape; presumably this is the
        # variable-length time axis (time_dim_axis=1 with batch_dim_axis=0) -- confirm
        # against the Data class.
        shape = [None] + list(dataset.get_data_shape(key))
        sparse = dataset.is_data_sparse(key)
        dtype = dataset.get_data_dtype(key)
        self.data[key] = Data(
            name=key, auto_create_placeholders=True, batch_dim_axis=0, time_dim_axis=1,
            shape=shape, dim=dim, sparse=sparse, dtype=dtype,
            available_for_inference=key not in target_key_set)
def register_data_from_dict(self, data):
    """
    Create one :class:`Data` (with auto-created placeholders) per entry of ``data``
    and store it in ``self.data`` under the same key.

    :param dict[str,dict[str]] data: maps data key -> init kwargs for Data
    """
    for data_key, init_kwargs in data.items():
        new_entry = Data(name=data_key, auto_create_placeholders=True, **init_kwargs)
        self.data[data_key] = new_entry
def get_extern_data(self, key, mark_data_key_as_used=True):
    """
    Return the extern :class:`Data` for ``key``, optionally recording the key as used.

    The special keys "seq_idx" (scalar int32) and "seq_tag" (scalar string) are
    created lazily here if they are not registered in ``self.extern_data`` yet.

    :param str key: data key to look up
    :param bool mark_data_key_as_used: if True, add ``key`` to ``self.used_data_keys``
      after the lookup has succeeded
    :rtype: Data
    """
    extern_data_dict = self.extern_data.data
    if key not in extern_data_dict:
        if key == "seq_idx":
            extern_data_dict[key] = Data(
                name="seq_idx", shape=(), dtype="int32", sparse=False, auto_create_placeholders=True)
        elif key == "seq_tag":
            extern_data_dict[key] = Data(
                name="seq_tag", shape=(), dtype="string", auto_create_placeholders=True)
    data = self.extern_data.get_data(key)
    if mark_data_key_as_used:
        # Mark only after a successful lookup: if get_data raises, the key must not
        # be left behind in used_data_keys.
        self.used_data_keys.add(key)
    return data
def init_from_config(self, config):
    """
    Register extern data entries as described by the config, then set the default
    target from the config's "target" option (falling back to "classes").

    :param Config.Config config:
    """
    from NetworkDescription import LayerNetworkDescription
    extern_types = LayerNetworkDescription.tf_extern_data_types_from_config(config)
    for data_name, data_kwargs in extern_types.items():
        # Layout note: Returnn with Theano usually uses (time,batch,feature), whereas the
        # TensorFlow default used here is (batch,time,feature), i.e.
        # batch_dim_axis=0, time_dim_axis=1. See TFEngine.DataProvider._get_next_batch().
        entry = Data(name=data_name, auto_create_placeholders=True, **data_kwargs)
        self.data[data_name] = entry
    self.default_target = config.value('target', 'classes')