def decode(self, data, items):
    """Decodes the data to return the tensors specified by the list of
    items.

    Features are parsed according to `self._feature_original_types`,
    image features are then handled according to `self._image_options`,
    and finally the features listed in `self._feature_convert_types` are
    converted to their target dtypes.

    Args:
        data: The TFRecords data (serialized example) to decode.
        items: A list of strings, each of which is the name of the
            resulting tensors to retrieve.

    Returns:
        A list of tensors, each of which corresponds to each item.
    """
    # pylint: disable=too-many-branches
    feature_description = {}
    for key, value in self._feature_original_types.items():
        # `value` is indexed as `[dtype, feature_type, shape]`: value[0] is
        # handed to `dtypes.get_tf_dtype`, the feature type (value[1]) and
        # the shape (value[2], an int or a list) are optional.
        shape = []
        if len(value) == 3:
            if isinstance(value[-1], int):
                shape = [value[-1]]
            elif isinstance(value[-1], list):
                # Use the declared shape itself, not the whole feature spec.
                shape = value[-1]
        if len(value) < 2 or value[1] == 'FixedLenFeature':
            feature_description[key] = tf.FixedLenFeature(
                shape, dtypes.get_tf_dtype(value[0]))
        elif value[1] == 'VarLenFeature':
            feature_description[key] = tf.VarLenFeature(
                dtypes.get_tf_dtype(value[0]))
        # NOTE(review): any other feature type is silently skipped here.
    decoded_data = tf.parse_single_example(data, feature_description)

    # Handle TFRecords containing images. `_image_options` may be a single
    # option dict, an `HParams`, or a list of options.
    if isinstance(self._image_options, dict):
        self._decode_image_str_byte(self._image_options, decoded_data)
    elif isinstance(self._image_options, HParams):
        self._decode_image_str_byte(
            self._image_options.todict(), decoded_data)
    elif isinstance(self._image_options, list):
        # Return values are discarded; the helper is presumably called for
        # its effect on `decoded_data`, so use a plain loop.
        for image_option in self._image_options:
            self._decode_image_str_byte(image_option, decoded_data)

    # Convert dtypes.
    for key, value in self._feature_convert_types.items():
        from_type = decoded_data[key].dtype
        to_type = dtypes.get_tf_dtype(value)
        if from_type == to_type:
            continue
        if to_type == tf.string:
            decoded_data[key] = tf.dtypes.as_string(decoded_data[key])
        elif from_type == tf.string:
            decoded_data[key] = tf.string_to_number(
                decoded_data[key], to_type)
        else:
            decoded_data[key] = tf.cast(decoded_data[key], to_type)

    return [decoded_data[item] for item in items]
def _decode_numpy_ndarray_str_byte(self, numpy_option_feature, decoded_data):
    """Replaces a raw-byte feature in `decoded_data` with its decoded array.

    Reads `"numpy_ndarray_name"`, `"shape"` and `"dtype"` from
    `numpy_option_feature` via `.get()`. When the name is `None` (or
    absent), this is a no-op. Otherwise:

    - the bytes stored under that name in `decoded_data` are decoded as a
      flat tensor of `dtype`;
    - the tensor is reshaped to `shape`;
    - the result is written back under the same key (in place).
    """
    numpy_key = numpy_option_feature.get('numpy_ndarray_name')
    if numpy_key is not None:
        target_shape = numpy_option_feature.get('shape')
        target_dtype = dtypes.get_tf_dtype(numpy_option_feature.get('dtype'))
        flat_values = tf.decode_raw(decoded_data.get(numpy_key), target_dtype)
        decoded_data[numpy_key] = tf.reshape(flat_values, target_shape)
def _build(self, inputs, mode=None):
    """Maps a batch of states to a categorical distribution and sampled
    actions.

    Args:
        inputs: Inputs to the policy network, with the first dimension
            being the batch dimension.
        mode (optional): A tensor taking value in
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
            `TRAIN`, `EVAL`, and `PREDICT`. If `None`,
            :func:`texar.global_mode` is used.

    Returns:
        A `dict` with the fields:

        - **"logits"**: The policy network output, used as the logits of
          the categorical distribution.
        - **"action"**: A Tensor of shape
          `[batch_size] + action_space.shape` sampled from the distribution.
        - **"dist"**: The
          :tf_main:`Categorical <distributions/Categorical>` built from the
          logits.
    """
    logits = self._network(inputs, mode=mode)

    dist_kwargs = self._hparams.distribution_kwargs.todict()
    dist_kwargs['dtype'] = get_tf_dtype(dist_kwargs['dtype'])
    dist = tf.distributions.Categorical(logits=logits, **dist_kwargs)

    # The leading -1 keeps the batch dimension through the reshape.
    action_shape = [-1] + list(self._action_space.shape)
    action = tf.reshape(dist.sample(), action_shape)

    if not self._built:
        self._add_internal_trainable_variables()
        self._add_trainable_variable(self._network.trainable_variables)
        self._built = True

    return {"logits": logits, "action": action, "dist": dist}
def _build(self, inputs, mode=None):
    """Takes in states and outputs actions.

    Args:
        inputs: Inputs to the policy network, with the first dimension
            being the batch dimension.
        mode (optional): Passed through to the policy network; a tensor
            taking value in `tf.estimator.ModeKeys`, or `None`.

    Returns:
        A `dict` with fields `"logits"` (the policy network output),
        `"action"` (sampled actions of shape
        `[batch_size] + action_space.shape`) and `"dist"` (the
        `tf.distributions.Categorical` built from the logits).
    """
    logits = self._network(inputs, mode=mode)

    dkwargs = self._hparams.distribution_kwargs.todict()
    dkwargs['dtype'] = get_tf_dtype(dkwargs['dtype'])
    dist = tf.distributions.Categorical(logits=logits, **dkwargs)

    action = dist.sample()
    # Reshaping directly to `action_space.shape` would drop the batch
    # dimension (and fail unless the batch size is 1); keep a leading -1
    # so each example in the batch keeps its own action.
    to_shape = [-1] + list(self._action_space.shape)
    action = tf.reshape(action, to_shape)

    outputs = dict(
        logits=logits,
        action=action,
        dist=dist
    )

    if not self._built:
        self._add_internal_trainable_variables()
        self._add_trainable_variable(self._network.trainable_variables)
        self._built = True

    return outputs
def _decode_numpy_ndarray_str_byte(self, numpy_option_feature, decoded_data):
    """Decodes a raw-byte square-matrix feature and zero-pads it, in place.

    The bytes stored in `decoded_data` under
    `numpy_option_feature.get('numpy_ndarray_name')` are processed as
    follows:

    - decoded as a flat tensor of the option's `dtype`;
    - reshaped into a square `[n, n]` matrix with `n = sqrt(num_elements)`;
    - zero-padded on the bottom and right to `[shape[0], shape[0]]`.

    The result replaces the raw bytes under the same key. Nothing is done
    when the name is `None` or absent.

    Args:
        numpy_option_feature: Options read via `.get()` for the keys
            `"numpy_ndarray_name"`, `"shape"` and `"dtype"`.
        decoded_data: Mapping of feature name to parsed tensor; updated in
            place.
    """
    numpy_key = numpy_option_feature.get('numpy_ndarray_name')
    if numpy_key is None:
        return
    shape = numpy_option_feature.get('shape')
    dtype = dtypes.get_tf_dtype(numpy_option_feature.get('dtype'))
    numpy_byte = decoded_data.get(numpy_key)
    numpy_ndarray = tf.decode_raw(numpy_byte, dtype)

    # Side length of the (assumed) flattened square matrix. Round the
    # square root before casting: plain truncation can give n - 1 when
    # float rounding puts sqrt(n * n) just below n, breaking the reshape.
    num_elements = tf.cast(tf.shape(numpy_ndarray)[0], tf.float64)
    side = tf.cast(tf.round(tf.sqrt(num_elements)), tf.int32)
    # NOTE(review): only shape[0] is used; assumes side <= shape[0], since
    # tf.pad rejects negative paddings.
    pad = shape[0] - side
    numpy_ndarray = tf.reshape(numpy_ndarray, [side, side])
    numpy_ndarray = tf.pad(numpy_ndarray, [[0, pad], [0, pad]], "CONSTANT")
    decoded_data[numpy_key] = numpy_ndarray