def _assert_same_size(outputs: TensorStruct,
                      output_size: OutputSize):
    r"""Check if outputs match output_size.

    Each flattened output is compared with the corresponding flattened
    entry of :attr:`output_size`. The comparison uses ``output[0]``, the
    first slice along dim 0, so the leading (batch) dimension itself is
    not checked.

    Args:
        outputs: A ``Tensor`` or a (nested) ``tuple`` of tensors.
        output_size: Can be an Integer, a ``torch.Size``, or a (nested)
            ``tuple`` of Integers or ``torch.Size``.

    Raises:
        ValueError: If the number of flattened outputs differs from the
            number of flattened sizes, or if any output's size does not
            match its required size.
    """
    flat_output_size = nest.flatten(output_size)
    flat_output = nest.flatten(outputs)

    # ``zip`` would otherwise silently skip unmatched trailing entries.
    if len(flat_output) != len(flat_output_size):
        raise ValueError(
            f"The number of outputs ({len(flat_output)}) does not match "
            f"the number of entries in output_size "
            f"({len(flat_output_size)})")

    for output, size in zip(flat_output, flat_output_size):
        actual = output[0].size()
        if isinstance(size, torch.Size):
            # A ``torch.Size`` must match every non-batch dimension.
            mismatch = actual != size
        else:
            # An integer size constrains only the last dimension.
            mismatch = actual[-1] != size
        if mismatch:
            raise ValueError(
                f"The output size does not match the required "
                f"output_size: expected {size}, got {tuple(actual)}")
def _sum_output_size(output_size: OutputSize) -> int:
    r"""Return sum of all dim values in :attr:`output_size`.

    Each ``torch.Size`` entry contributes the product of its dimensions
    (the number of elements a tensor of that size holds). Each integer
    entry contributes itself. Integers and ``torch.Size`` may be mixed
    within one (nested) structure.

    Args:
        output_size: Can be an ``Integer``, a ``torch.Size``, or a
            (nested) ``tuple`` of ``Integers`` or ``torch.Size``.

    Returns:
        The total number of output units, as a Python ``int``.
    """
    total = 0
    # Dispatch per element rather than on the first element only, so that
    # mixed structures such as ``(torch.Size([1, 2, 3]), 10)`` work.
    for size in nest.flatten(output_size):
        if isinstance(size, torch.Size):
            # np.prod returns a numpy scalar (and 1.0, a float, for an
            # empty size); int() normalizes it to the annotated type.
            total += int(np.prod(list(size)))
        else:
            total += int(size)
    return total
def forward(self,  # type: ignore
            inputs: TensorStruct) -> Any:
    r"""Transforms inputs to have the same structure as with
    :attr:`output_size`. Values of the inputs are not changed.

    :attr:`inputs` must either have the same structure, or have the same
    number of elements with :attr:`output_size`.

    Args:
        inputs: The input (structure of) tensor to pass forward.

    :returns:
        A (structure of) tensors that re-packs :attr:`inputs` to have
        the specified structure of :attr:`output_size`.
    """
    # Flatten to a plain list, then re-nest it following the structure
    # stored in ``self._output_size``; tensor values are passed through.
    flat_input = nest.flatten(inputs)
    return nest.pack_sequence_as(self._output_size, flat_input)
def test_constant_connector(self):
    r"""Tests the logic of
    :class:`~texar.torch.modules.connectors.ConstantConnector`.

    Checks the default and explicit fill values on a named-tuple state.
    It then checks the result shape (batch dimension prepended) and fill
    value for a single ``torch.Size`` and for tuples of ``torch.Size``,
    of integers, and of a mix of both.
    """
    batch_size = self._batch_size
    lstm_state_size = namedtuple('LSTMStateTuple', ['h', 'c'])(256, 256)

    default_state = ConstantConnector(lstm_state_size)(batch_size)
    ones_state = ConstantConnector(
        lstm_state_size, hparams={"value": 1.})(batch_size)
    self.assertEqual(nest.flatten(default_state)[0][0, 0], 0.)
    self.assertEqual(nest.flatten(ones_state)[0][0, 0], 1.)

    def check_filled(tensor, dims, fill):
        # Shape must be ``[batch_size, *dims]``; the first entry of the
        # first batch element must equal the fill value.
        self.assertEqual(
            torch.Size([batch_size]) + torch.Size(dims), tensor.size())
        self.assertEqual(tensor[0][(0,) * len(dims)], fill)

    # A bare ``torch.Size`` yields a single tensor, not a tuple.
    single_size = torch.Size([1, 2, 3])
    single_tensor = ConstantConnector(
        single_size, hparams={"value": 2.})(batch_size)
    check_filled(single_tensor, single_size, 2.)

    tuple_cases = [
        ((torch.Size([1, 2, 3]), torch.Size([4, 5, 6])), 3.,
         [(1, 2, 3), (4, 5, 6)]),
        ((5, 10), 4., [(5,), (10,)]),
        ((torch.Size([1, 2, 3]), 10), 4., [(1, 2, 3), (10,)]),
    ]
    for output_size, fill, expected_dims in tuple_cases:
        outputs = ConstantConnector(
            output_size, hparams={"value": fill})(batch_size)
        for index, dims in enumerate(expected_dims):
            check_filled(outputs[index], dims, fill)
def _mlp_transform(inputs: TensorStruct,
                   output_size: OutputSize,
                   linear_layer: Optional[LinearLayer] = None,
                   activation_fn: Optional[ActivationFn] = None,
                   ) -> Any:
    r"""Transforms inputs through a fully-connected layer that creates
    the output with specified size.

    Args:
        inputs: A ``Tensor`` of shape ``[batch_size, ..., final_state]``
            (i.e., batch-major), or a (nested) tuple of such elements.
            A Tensor or a (nested) tuple of Tensors with shape
            ``[max_time, batch_size, ...]`` (i.e., time-major) can be
            transposed to batch-major using
            :func:`~texar.torch.utils.transpose_batch_time` prior to this
            function.
        output_size: Can be an ``Integer``, a ``torch.Size``, or a
            (nested) ``tuple`` of ``Integers`` or ``torch.Size``.
            Integers and ``torch.Size`` may be mixed.
        linear_layer: Optional layer applied to the concatenated, flattened
            inputs. If ``None``, the inputs are used as-is, so their
            concatenated last dimension must already equal the total
            output size.
        activation_fn: Activation function applied to the output.

    :returns:
        If :attr:`output_size` is an ``Integer`` or a ``torch.Size``,
        returns a ``Tensor`` of shape ``[batch_size, *, output_size]``.
        If :attr:`output_size` is a ``tuple`` of Integers or torch.Size,
        returns a ``tuple`` having the same structure as
        :attr:`output_size`, where each element ``Tensor`` has the same
        size as defined in :attr:`output_size`.
    """
    # Flatten inputs and collapse all leading dims so they can be
    # concatenated along the feature dimension.
    flat_input = nest.flatten(inputs)
    if len(flat_input[0].size()) == 1:
        batch_size = 1
    else:
        batch_size = flat_input[0].size(0)
    flat_input = [x.view(-1, x.size(-1)) for x in flat_input]
    concat_input = torch.cat(flat_input, 1)

    # Number of output units consumed by each flattened output entry.
    # Each entry is handled on its own so that integers and torch.Size can
    # be mixed; int() keeps torch.split's section sizes plain Python ints
    # (np.prod returns numpy scalars, and 1.0 for an empty size).
    flat_output_size = nest.flatten(output_size)
    size_list = [
        int(np.prod(list(size))) if isinstance(size, torch.Size)
        else int(size)
        for size in flat_output_size
    ]

    fc_output = concat_input
    if linear_layer is not None:
        fc_output = linear_layer(fc_output)
    if activation_fn is not None:
        fc_output = activation_fn(fc_output)

    flat_output = []
    for chunk, size in zip(torch.split(fc_output, size_list, dim=1),
                           flat_output_size):
        # Restore the batch dimension; dim 1 is squeezed away when the
        # inputs carried no extra leading dimensions.
        final_state = chunk.size(-1)
        chunk = chunk.view(batch_size, -1, final_state)
        chunk = torch.squeeze(chunk, 1)
        if isinstance(size, torch.Size):
            # Unflatten the feature vector into the requested shape.
            chunk = torch.reshape(chunk, (-1,) + size)
        flat_output.append(chunk)

    output = nest.pack_sequence_as(structure=output_size,
                                   flat_sequence=flat_output)
    return output