コード例 #1
0
def _assert_same_size(outputs: TensorStruct,
                      output_size: OutputSize):
    r"""Check that every tensor in ``outputs`` matches ``output_size``.

    Both arguments are flattened and compared pairwise. Each tensor is
    indexed with ``[0]`` before the comparison, so its first dimension is
    ignored. Pairs are formed with :func:`zip`, so if the two flattened
    sequences differ in length, the surplus elements are not checked.

    Args:
        outputs: A ``Tensor`` or a (nested) ``tuple`` of tensors
        output_size: Can be an Integer, a ``torch.Size``, or a (nested)
            ``tuple`` of Integers or ``torch.Size``.

    Raises:
        ValueError: If, for some pair, the tensor's shape without its first
            dimension differs from the ``torch.Size`` entry, or its last
            dimension differs from the Integer entry.
    """
    flat_output_size = nest.flatten(output_size)
    flat_output = nest.flatten(outputs)

    for (output, size) in zip(flat_output, flat_output_size):
        # ``output[0]`` strips the first (presumably batch) dimension.
        per_example_size = output[0].size()
        if isinstance(size, torch.Size):
            # A full shape was requested: every remaining dim must agree.
            matches = per_example_size == size
        else:
            # Only the innermost dimension is constrained by an Integer.
            matches = per_example_size[-1] == size
        if not matches:
            raise ValueError(
                "The output size does not match the required output_size")
コード例 #2
0
def _sum_output_size(output_size: OutputSize) -> int:
    r"""Return sum of all dim values in :attr:`output_size`

    Each entry of the flattened :attr:`output_size` contributes either
    its own value (Integer entry) or the product of its dimensions
    (``torch.Size`` entry). Entries are classified one by one, so a
    structure mixing Integers and ``torch.Size`` objects is supported.

    Args:
        output_size: Can be an ``Integer``, a ``torch.Size``, or a (nested)
            ``tuple`` of ``Integers`` or ``torch.Size``.

    :returns:
        The total number of elements described by :attr:`output_size`
        (``0`` if the flattened structure is empty).
    """
    total = 0
    for size in nest.flatten(output_size):
        if isinstance(size, torch.Size):
            # ``int`` turns the NumPy scalar (a float for an empty shape)
            # into a plain Python integer, matching the return annotation.
            total += int(np.prod(list(size)))
        else:
            total += size
    return total
コード例 #3
0
    def forward(self,  # type: ignore
                inputs: TensorStruct) -> Any:
        r"""Re-packs :attr:`inputs` into the structure of
        :attr:`output_size` without modifying any of the values.
        :attr:`inputs` must either share the structure of
        :attr:`output_size`, or contain the same number of elements.

        Args:
            inputs: The input (structure of) tensor to pass forward.

        :returns:
            A (structure of) tensors holding the elements of
            :attr:`inputs`, arranged in the structure specified by
            :attr:`output_size`.
        """
        # Flatten to a plain list, then rebuild it in the target structure.
        leaves = nest.flatten(inputs)
        return nest.pack_sequence_as(self._output_size, leaves)
コード例 #4
0
    def test_constant_connector(self):
        r"""Tests the logic of
        :class:`~texar.torch.modules.connectors.ConstantConnector`.
        """

        def check_filled(tensor, trailing_shape, first_index, fill):
            # The tensor's shape must be the batch size followed by the
            # requested shape, and its first entry must equal the constant.
            self.assertEqual(
                torch.Size([self._batch_size]) + trailing_shape,
                tensor.size())
            self.assertEqual(tensor[0][first_index], fill)

        # Named-tuple state size, default value and an explicit value.
        lstm_sizes = namedtuple('LSTMStateTuple', ['h', 'c'])(256, 256)
        zeros_state = ConstantConnector(lstm_sizes)(self._batch_size)
        ones_state = ConstantConnector(
            lstm_sizes, hparams={"value": 1.})(self._batch_size)
        self.assertEqual(nest.flatten(zeros_state)[0][0, 0], 0.)
        self.assertEqual(nest.flatten(ones_state)[0][0, 0], 1.)

        # A single ``torch.Size``.
        single_shape = torch.Size([1, 2, 3])
        single_filled = ConstantConnector(
            single_shape, hparams={"value": 2.})(self._batch_size)
        check_filled(single_filled, single_shape, (0, 0, 0), 2.)

        # A tuple of ``torch.Size`` objects.
        shape_pair = (torch.Size([1, 2, 3]), torch.Size([4, 5, 6]))
        filled_pair = ConstantConnector(
            shape_pair, hparams={"value": 3.})(self._batch_size)
        first, second = filled_pair[0], filled_pair[1]
        check_filled(first, torch.Size([1, 2, 3]), (0, 0, 0), 3.)
        check_filled(second, torch.Size([4, 5, 6]), (0, 0, 0), 3.)

        # A tuple of Integers.
        int_pair = (5, 10)
        filled_pair = ConstantConnector(
            int_pair, hparams={"value": 4.})(self._batch_size)
        first, second = filled_pair[0], filled_pair[1]
        check_filled(first, torch.Size([5]), 0, 4.)
        check_filled(second, torch.Size([10]), 0, 4.)

        # A tuple mixing ``torch.Size`` and Integer.
        mixed_pair = (torch.Size([1, 2, 3]), 10)
        filled_pair = ConstantConnector(
            mixed_pair, hparams={"value": 4.})(self._batch_size)
        first, second = filled_pair[0], filled_pair[1]
        check_filled(first, torch.Size([1, 2, 3]), (0, 0, 0), 4.)
        check_filled(second, torch.Size([10]), 0, 4.)
コード例 #5
0
def _mlp_transform(inputs: TensorStruct,
                   output_size: OutputSize,
                   linear_layer: Optional[LinearLayer] = None,
                   activation_fn: Optional[ActivationFn] = None,
                   ) -> Any:
    r"""Transforms inputs through a fully-connected layer that creates
    the output with specified size.

    The inputs are flattened, each reshaped to 2-D and concatenated along
    the feature dimension. The result is optionally passed through
    :attr:`linear_layer` and :attr:`activation_fn`, split into one chunk
    per entry of :attr:`output_size`, and packed back into the structure
    of :attr:`output_size`.

    Args:
        inputs: A ``Tensor`` of shape ``[batch_size, ..., final_state]``
            (i.e., batch-major), or a (nested) tuple of such elements.
            A Tensor or a (nested) tuple of Tensors with shape
            ``[max_time, batch_size, ...]`` (i.e., time-major) can
            be transposed to batch-major using
            :func:`~texar.torch.utils.transpose_batch_time` prior to this
            function.
        output_size: Can be an ``Integer``, a ``torch.Size``, or a (nested)
            ``tuple`` of ``Integers`` or ``torch.Size``. Integers and
            ``torch.Size`` entries may be mixed in one structure.
        linear_layer: Optional layer applied to the concatenated inputs.
            If ``None``, the concatenated inputs are used unchanged, so
            their feature dimension must already equal the total size
            requested by :attr:`output_size`.
        activation_fn: Activation function applied to the output.

    :returns:
        If :attr:`output_size` is an ``Integer`` or a ``torch.Size``,
        returns a ``Tensor`` of shape ``[batch_size, *, output_size]``.
        If :attr:`output_size` is a ``tuple`` of Integers or torch.Size,
        returns a ``tuple`` having the same structure as
        :attr:`output_size`, where each element ``Tensor`` has the same
        size as defined in :attr:`output_size`.
    """
    # Flatten inputs
    flat_input = nest.flatten(inputs)
    # A 1-D first input is treated as a single un-batched example.
    if len(flat_input[0].size()) == 1:
        batch_size = 1
    else:
        batch_size = flat_input[0].size(0)
    flat_input = [x.view(-1, x.size(-1)) for x in flat_input]
    concat_input = torch.cat(flat_input, 1)

    # Get output dimension: a ``torch.Size`` entry needs as many features
    # as the product of its dims; an Integer entry is used as is.  Each
    # entry is classified individually so mixed structures work, and the
    # product is cast to a plain ``int`` for ``torch.split``.
    flat_output_size = nest.flatten(output_size)
    size_list = [
        int(np.prod(list(size))) if isinstance(size, torch.Size) else size
        for size in flat_output_size]

    fc_output = concat_input
    if linear_layer is not None:
        fc_output = linear_layer(fc_output)
    if activation_fn is not None:
        fc_output = activation_fn(fc_output)

    flat_output = list(torch.split(fc_output, size_list, dim=1))
    for i, size in enumerate(flat_output_size):
        chunk = flat_output[i]
        # Restore the batch dimension, dropping the middle dimension
        # when it is of length one.
        chunk = chunk.view(batch_size, -1, chunk.size(-1))
        chunk = torch.squeeze(chunk, 1)
        if isinstance(size, torch.Size):
            # Unfold the flat feature vector into the requested shape.
            chunk = torch.reshape(chunk, (-1,) + tuple(size))
        flat_output[i] = chunk

    output = nest.pack_sequence_as(structure=output_size,
                                   flat_sequence=flat_output)
    return output