Beispiel #1
0
    def decode(self, data, items):
        """Decodes the data to return the tensors specified by the list of
        items.

        Args:
            data: The TFRecords data (serialized example) to decode.
            items: A list of strings, each of which is the name of the
                resulting tensors to retrieve.

        Returns:
            A list of tensors, each of which corresponds to each item.
        """
        # pylint: disable=too-many-branches
        feature_description = dict()
        for key, value in self._feature_original_types.items():
            # `value` is (dtype, [feature_type, [shape]]); the optional
            # third entry carries the fixed-length feature shape, either
            # as an int (vector length) or a list (full shape).
            shape = []
            if len(value) == 3:
                if isinstance(value[-1], int):
                    shape = [value[-1]]
                elif isinstance(value[-1], list):
                    # BUG FIX: use the shape list itself, not the whole
                    # (dtype, feature_type, shape) triple.
                    shape = value[-1]
            if len(value) < 2 or value[1] == 'FixedLenFeature':
                feature_description.update({
                    key:
                    tf.FixedLenFeature(shape, dtypes.get_tf_dtype(value[0]))
                })
            elif value[1] == 'VarLenFeature':
                feature_description.update(
                    {key: tf.VarLenFeature(dtypes.get_tf_dtype(value[0]))})
        decoded_data = tf.parse_single_example(data, feature_description)

        # Handle TFRecords containing images: image options may be a single
        # dict, an HParams object, or a list of per-image option dicts.
        if isinstance(self._image_options, dict):
            self._decode_image_str_byte(self._image_options, decoded_data)
        elif isinstance(self._image_options, HParams):
            self._decode_image_str_byte(self._image_options.todict(),
                                        decoded_data)
        elif isinstance(self._image_options, list):
            _ = list(
                map(lambda x: self._decode_image_str_byte(x, decoded_data),
                    self._image_options))

        # Convert dtypes as requested in `self._feature_convert_types`;
        # string conversions need dedicated ops rather than `tf.cast`.
        for key, value in self._feature_convert_types.items():
            from_type = decoded_data[key].dtype
            to_type = dtypes.get_tf_dtype(value)
            if from_type is to_type:
                continue
            elif to_type is tf.string:
                decoded_data[key] = tf.dtypes.as_string(decoded_data[key])
            elif from_type is tf.string:
                decoded_data[key] = tf.string_to_number(
                    decoded_data[key], to_type)
            else:
                decoded_data[key] = tf.cast(decoded_data[key], to_type)
        outputs = decoded_data
        return [outputs[item] for item in items]
Beispiel #2
0
    def _decode_numpy_ndarray_str_byte(self, numpy_option_feature,
                                       decoded_data):
        """Decodes a raw-byte ndarray feature in place.

        Looks up the feature name, target shape, and dtype in the options
        dict; if no ndarray name is configured, this is a no-op. Otherwise
        the raw bytes in `decoded_data` are replaced by the decoded,
        reshaped tensor.
        """
        key = numpy_option_feature.get('numpy_ndarray_name')
        if key is None:
            return

        target_shape = numpy_option_feature.get('shape')
        tf_dtype = dtypes.get_tf_dtype(numpy_option_feature.get('dtype'))

        raw_bytes = decoded_data.get(key)
        decoded_data[key] = tf.reshape(
            tf.decode_raw(raw_bytes, tf_dtype), target_shape)
    def _build(self, inputs, mode=None):
        """Takes in states and outputs actions.

        Args:
            inputs: Inputs to the policy network with the first dimension
                the batch dimension.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL`, and `PREDICT`. If `None`,
                :func:`texar.global_mode` is used.

        Returns:
            A `dict` including fields `"logits"`, `"action"`, and `"dist"`,
            where

            - **"logits"**: A Tensor of shape \
            `[batch_size] + action_space size` used for categorical \
            distribution sampling.
            - **"action"**: A Tensor of shape \
            `[batch_size] + action_space.shape`.
            - **"dist"**: The \
            :tf_main:`Categorical <distributions/Categorical>` based on the \
            logits.
        """
        logits = self._network(inputs, mode=mode)

        # Build the categorical distribution with the configured kwargs,
        # resolving the dtype string into an actual TF dtype first.
        dist_kwargs = self._hparams.distribution_kwargs.todict()
        dist_kwargs['dtype'] = get_tf_dtype(dist_kwargs['dtype'])
        dist = tf.distributions.Categorical(logits=logits, **dist_kwargs)

        # Sample, then restore the batch dimension (-1) before the
        # action-space shape.
        sample_shape = [-1] + list(self._action_space.shape)
        action = tf.reshape(dist.sample(), sample_shape)

        # One-time registration of trainable variables on first build.
        if not self._built:
            self._add_internal_trainable_variables()
            self._add_trainable_variable(self._network.trainable_variables)
            self._built = True

        return {"logits": logits, "action": action, "dist": dist}
Beispiel #4
0
    def _build(self, inputs, mode=None):
        """Maps states to a sampled action and its categorical distribution.

        Args:
            inputs: Inputs to the policy network.
            mode (optional): An estimator-mode tensor; passed through to the
                underlying network.

        Returns:
            A `dict` with fields `"logits"`, `"action"`, and `"dist"`.
        """
        logits = self._network(inputs, mode=mode)

        # Resolve the configured dtype string before constructing the
        # categorical distribution.
        dist_kwargs = self._hparams.distribution_kwargs.todict()
        dist_kwargs['dtype'] = get_tf_dtype(dist_kwargs['dtype'])
        dist = tf.distributions.Categorical(logits=logits, **dist_kwargs)

        # NOTE(review): reshapes to the bare action-space shape with no
        # leading batch dimension — confirm callers feed one state at a time.
        action = tf.reshape(dist.sample(), self._action_space.shape)

        # One-time registration of trainable variables on first build.
        if not self._built:
            self._add_internal_trainable_variables()
            self._add_trainable_variable(self._network.trainable_variables)
            self._built = True

        return {"logits": logits, "action": action, "dist": dist}
    def _decode_numpy_ndarray_str_byte(self, numpy_option_feature,
                                       decoded_data):
        """Decodes a raw-byte ndarray feature into a zero-padded square matrix.

        Reconstructs a square matrix whose side length is the square root of
        the decoded element count, then zero-pads it on the bottom and right
        up to ``shape[0] x shape[0]``. Mutates ``decoded_data`` in place.

        Args:
            numpy_option_feature: Options dict read for the keys
                'numpy_ndarray_name', 'shape', and 'dtype'.
            decoded_data: Dict of already-parsed feature tensors; the entry
                under the ndarray name is replaced with the decoded tensor.
        """
        numpy_key = numpy_option_feature.get('numpy_ndarray_name')
        # No ndarray feature configured — nothing to decode.
        if numpy_key is None:
            return

        shape = numpy_option_feature.get('shape')
        dtype = numpy_option_feature.get('dtype')
        dtype = dtypes.get_tf_dtype(dtype)

        numpy_byte = decoded_data.get(numpy_key)
        numpy_ndarray = tf.decode_raw(numpy_byte, dtype)

        # Side length of the (assumed square) matrix: sqrt of element count.
        # NOTE(review): assumes the flat buffer holds a perfect square of
        # elements and that shape[0] >= the side length (so the pad amount
        # below is non-negative) — confirm against the writer side.
        raw_shape = tf.cast(
            tf.sqrt(tf.cast(tf.shape(numpy_ndarray), tf.float32)),
            tf.int32)  #[]
        # Amount of zero-padding added along each trailing edge.
        minus_shape = shape[0] - raw_shape[0]  #int
        numpy_ndarray = tf.reshape(numpy_ndarray, [raw_shape[0], raw_shape[0]])
        numpy_ndarray = tf.pad(numpy_ndarray,
                               [[0, minus_shape], [0, minus_shape]],
                               "CONSTANT")
        #numpy_ndarray = tf.reshape(numpy_ndarray, shape) ###debug

        decoded_data[numpy_key] = numpy_ndarray