def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
    """Unroll the fused RNN over ``length`` time steps as one ``symbol.RNN`` op.

    Parameters
    ----------
    length : int
        Number of time steps, forwarded to ``_normalize_sequence``.
    inputs : Symbol or list of Symbol
        Input sequence; it is always merged into a single symbol before use.
    begin_state : sequence of Symbol, optional
        Initial states. When omitted, ``self.begin_state()`` is used. Element 0
        is passed as ``state``; in ``'lstm'`` mode element 1 is passed as
        ``state_cell``.
    layout : str
        Layout of ``inputs`` and of the returned outputs. The time axis must be
        at position 0 or 1 (e.g. ``'TNC'`` or ``'NTC'``).
    merge_outputs : bool or None
        Forwarded to ``_normalize_sequence`` to decide the outputs' form.

    Returns
    -------
    (outputs, states)
        ``states`` is an empty list when ``self._get_next_state`` is falsy.
        Otherwise it holds the RNN op's extra outputs: two in ``'lstm'`` mode,
        one in every other mode.
    """
    self.reset()

    inputs, axis = _normalize_sequence(length, inputs, layout, True)
    batch_major = axis == 1
    if batch_major:
        # The fused op is fed time-major data, so batch-major input costs an
        # extra transpose in each direction.
        warnings.warn("NTC layout detected. Consider using "
                      "TNC for FusedRNNCell for faster speed")
        inputs = symbol.swapaxes(inputs, dim1=0, dim2=1)
    else:
        assert axis == 0, "Unsupported layout %s" % layout

    if begin_state is None:
        begin_state = self.begin_state()

    # Initial states are handed to the RNN op as keyword arguments.
    state_kwargs = {'state': begin_state[0]}
    if self._mode == 'lstm':
        state_kwargs['state_cell'] = begin_state[1]

    rnn = symbol.RNN(data=inputs,
                     parameters=self._parameter,
                     state_size=self._num_hidden,
                     num_layers=self._num_layers,
                     bidirectional=self._bidirectional,
                     p=self._dropout,
                     state_outputs=self._get_next_state,
                     mode=self._mode,
                     name=self._prefix + 'rnn',
                     **state_kwargs)

    if self._get_next_state:
        # Output 0 is the sequence output; the remaining outputs are the final
        # states (hidden + cell for LSTM, hidden only otherwise).
        num_states = 2 if self._mode == 'lstm' else 1
        outputs = rnn[0]
        states = [rnn[idx] for idx in range(1, num_states + 1)]
    else:
        outputs, states = rnn, []

    if batch_major:
        # Restore the caller's batch-major layout.
        outputs = symbol.swapaxes(outputs, dim1=0, dim2=1)

    outputs, _ = _normalize_sequence(length, outputs, layout, merge_outputs)
    return outputs, states
def _normalize_sequence(length, inputs, layout, merge, in_layout=None):
    """Convert a sequence input between merged-symbol and per-step-list forms.

    Parameters
    ----------
    length : int or None
        Number of time steps. Used as ``num_outputs`` when a merged symbol is
        split. When ``inputs`` is a list, its length is checked against it
        (skipped if ``length`` is None).
    inputs : Symbol or list of Symbol
        Either one symbol holding the whole sequence, or one symbol per step.
    layout : str
        Target layout; the index of ``'T'`` in it is the time axis returned.
        Note: if ``'T'`` is absent, ``str.find`` yields -1.
    merge : bool or None
        ``True``: a list input is stacked into one symbol along the time axis.
        ``False``: a single-output symbol is split into a list of per-step
        symbols (time axis squeezed).
        ``None``: ``inputs`` is left in whichever form it already has.
    in_layout : str, optional
        Layout of ``inputs`` when it differs from ``layout``; defaults to
        ``layout``.

    Returns
    -------
    (inputs, axis)
        The normalized inputs and the time-axis index in ``layout``.
    """
    assert inputs is not None, \
        "unroll(inputs=None) has been deprecated. " \
        "Please create input variables outside unroll."

    axis = layout.find('T')
    in_axis = in_layout.find('T') if in_layout is not None else axis

    if isinstance(inputs, symbol.Symbol):
        if merge is False:
            assert len(inputs.list_outputs()) == 1, \
                "unroll doesn't allow grouped symbol as input. Please convert " \
                "to list with list(inputs) first or let unroll handle splitting."
            # Split along the *input* time axis; squeezing removes that axis.
            inputs = list(symbol.split(inputs, axis=in_axis, num_outputs=length,
                                       squeeze_axis=1))
    else:
        assert length is None or len(inputs) == length
        if merge is True:
            inputs = [symbol.expand_dims(i, axis=axis) for i in inputs]
            inputs = symbol.Concat(*inputs, dim=axis)
            # Freshly concatenated along ``axis``, so it is already in ``layout``.
            in_axis = axis

    # A merged symbol still in ``in_layout`` needs its time axis moved.
    # SwapAxis takes ``dim1``/``dim2`` (there is no ``dim0`` argument).
    if isinstance(inputs, symbol.Symbol) and axis != in_axis:
        inputs = symbol.swapaxes(inputs, dim1=axis, dim2=in_axis)

    return inputs, axis