from mxnet import symbol  # assumed import; the original module may import symbol differently


def _normalize_sequence(length, inputs, layout, merge, in_layout=None):
    """Normalize `inputs` for unrolling: split a single symbol into
    per-step symbols, or merge a list of per-step symbols into one,
    depending on `merge`. Returns the inputs and the time-axis index."""
    assert inputs is not None, \
        "unroll(inputs=None) has been deprecated. " \
        "Please create input variables outside unroll."

    # Locate the time axis in the output layout (and the input layout, if given).
    axis = layout.find('T')
    in_axis = in_layout.find('T') if in_layout is not None else axis
    if isinstance(inputs, symbol.Symbol):
        if merge is False:
            assert len(inputs.list_outputs()) == 1, \
                "unroll doesn't allow grouped symbol as input. Please convert " \
                "to list with list(inputs) first or let unroll handle splitting."
            # Slice the single symbol into `length` per-step symbols.
            inputs = list(symbol.split(inputs, axis=in_axis,
                                       num_outputs=length, squeeze_axis=1))
    else:
        assert length is None or len(inputs) == length
        if merge is True:
            # Stack the per-step symbols into one symbol along the time axis.
            inputs = [symbol.expand_dims(i, axis=axis) for i in inputs]
            inputs = symbol.Concat(*inputs, dim=axis)
            in_axis = axis

    # Move the time axis into place if the input layout disagrees with the
    # requested output layout.
    if isinstance(inputs, symbol.Symbol) and axis != in_axis:
        inputs = symbol.swapaxes(inputs, dim0=axis, dim1=in_axis)

    return inputs, axis
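
# A minimal usage sketch (not part of the original source): assuming the
# MXNet symbol API, this splits a single (batch, T, hidden) symbol in
# 'NTC' layout into a list of per-step symbols, as unroll() would.
#
#   data = symbol.Variable('data')  # shape (batch, 5, hidden) under 'NTC'
#   steps, axis = _normalize_sequence(5, data, layout='NTC', merge=False)
#   # `steps` is a list of 5 symbols, each (batch, hidden); axis == 1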
def __call__(self, inputs, states):
    # inputs: (batch_size, decoder_num_hidden)
    # for dot attention decoder_num_hidden must equal encoder_num_hidden
    if len(states) > 1:
        states = [symbol.concat(*states, dim=1)]
    # source: (batch_size, seq_len, encoder_num_hidden)
    source = states[0]
    # (batch_size, decoder_num_hidden, 1)
    inputs = symbol.expand_dims(inputs, axis=2)
    # (batch_size, seq_len, 1)
    scores = symbol.batch_dot(source, inputs)
    # (batch_size, encoder_num_hidden)
    return _attention_pooling(source, scores), states
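
# A shape walk-through for the dot-attention scoring above (a sketch with
# plain NDArrays, not part of the original source): batch_dot of the
# (batch, seq_len, num_hidden) source against the expanded
# (batch, num_hidden, 1) query yields one score per source position.
#
#   import mxnet as mx
#   source = mx.nd.ones((4, 10, 16))  # (batch, seq_len, encoder_num_hidden)
#   query = mx.nd.ones((4, 16))       # (batch, decoder_num_hidden)
#   scores = mx.nd.batch_dot(source, query.expand_dims(axis=2))
#   scores.shape                      # (4, 10, 1)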
def __call__(self, inputs, states):
    # Pass inputs through unchanged, recording a (batch, 1, num_hidden)
    # slice so the full source sequence can be reassembled for attention.
    return inputs, states + [symbol.expand_dims(inputs, axis=1)]
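
# Sketch of how a recording cell like the one above composes with an
# unrolled encoder (`recorder` and `step_inputs` are illustrative names,
# not from the original source): each call passes `inputs` through
# unchanged and appends a (batch, 1, num_hidden) slice, so after T steps
# the states concatenate into the (batch, T, num_hidden) attention source
# consumed by the dot-attention cell above.
#
#   states = []
#   for step_input in step_inputs:          # each (batch, num_hidden)
#       step_input, states = recorder(step_input, states)
#   source = symbol.concat(*states, dim=1)  # (batch, T, num_hidden)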