def test_execute_non_placeholder():
    """
    Building an executor whose arguments include a non-input tensor
    (here a Variable and a temporary) must raise ValueError.
    """
    axis = ng.make_axis(length=1)
    lhs = ng.temporary([axis])
    rhs = ng.variable([axis])
    with pytest.raises(ValueError):
        # Construction itself should fail; the bare `ex` only keeps the
        # context-manager result referenced.
        with executor(lhs + rhs, lhs, rhs) as ex:
            ex
def __call__(self, in_obj, init_state=None):
    """
    Build the LSTM computation graph for the given input.

    Arguments:
        in_obj (int, tuple, Layer or Tensor): object that provides shape
            information for the layer
        init_state (tuple of Tensor): optional initial state; for an LSTM
            this is (hidden state, cell state), both with identical axes

    Returns:
        rnn_out (Tensor): output
    """
    # Derive the layer axes from the input (and, when supplied, the initial
    # state); an explicit state also seeds h_init / c_init directly.
    if init_state is not None:
        assert len(init_state) == 2 and init_state[0].axes == init_state[1].axes
        self.interpret_axes(in_obj, init_state[0])
        self.h_init = init_state[0]
        self.c_init = init_state[1]
    else:
        self.interpret_axes(in_obj, init_state)
        # With reset_cells the state is scratch space (temporary); otherwise
        # it is a persistent variable carried across calls.
        make_state = ng.temporary if self.reset_cells else ng.variable
        self.h_init = make_state(initial_value=0,
                                 axes=self.out_axes).named('h_init')
        self.c_init = make_state(initial_value=0,
                                 axes=self.out_axes).named('c_init')

    gate_keys = self.metadata['gates']

    def make_param(axes, value, name_fmt, gate):
        # One scoped, named trainable variable per gate (i, f, o, g).
        return ng.variable(axes=axes, initial_value=value,
                           scope=self.scope).named(name_fmt.format(gate))

    self.W_input = {g: make_param(self.w_in_axes, self.init, "W_in_{}", g)
                    for g in gate_keys}
    self.W_recur = {g: make_param(self.w_re_axes, self.init_inner, "W_re_{}", g)
                    for g in gate_keys}
    self.b = {g: make_param(self.out_feature_axes, 0, "bias_{}", g)
              for g in gate_keys}

    # Feed-forward weighted inputs. Batch norm is applied only to these,
    # as in https://arxiv.org/abs/1510.01378
    h_ff = {}
    for g in self.metadata["gates"]:
        weighted = ng.dot(self.W_input[g], in_obj)
        h_ff[g] = weighted if self.batch_norm is None else self.batch_norm[g](weighted)

    # Slice the weighted inputs into per-timestep views.
    h_ff = get_steps(h_ff, self.recurrent_axis, self.backward)

    # Unroll the recurrence along the sequence axis.
    hidden, cell = self.h_init, self.c_init
    hidden_seq = []
    cell_seq = []
    for t in range(self.recurrent_axis.length):
        with ng.metadata(recurrent_step=str(t)):
            [hidden, cell] = self._step(h_ff[t], [hidden, cell])
            hidden_seq.append(hidden)
            cell_seq.append(cell)

    if self.return_sequence is True:
        if self.backward:
            hidden_seq = hidden_seq[::-1]
            cell_seq = cell_seq[::-1]
        lstm_out = ng.stack(hidden_seq, self.recurrent_axis,
                            pos=self.recurrent_axis_idx)
    else:
        lstm_out = hidden_seq[-1]

    if self.reset_cells is True:
        return lstm_out
    # Persist the final state so a subsequent call resumes from it.
    return ng.sequential([
        ng.doall([
            ng.assign(self.h_init, hidden_seq[-1]),
            ng.assign(self.c_init, cell_seq[-1])
        ]),
        lstm_out
    ])
def train_outputs(self, in_obj, init_state=None):
    """
    Build the LSTM training graph for the given input.

    Arguments:
        in_obj (int, tuple, Layer or Tensor): object that provides shape
            information for the layer
        init_state (tuple of Tensor): optional initial state; for an LSTM
            this is (hidden state, cell state), both with identical axes

    Returns:
        rnn_out (Tensor): output
    """
    # Derive the layer axes from the input (and, when supplied, the initial
    # state); an explicit state also seeds h_init / c_init directly.
    if init_state is not None:
        assert len(init_state) == 2 and init_state[0].axes == init_state[1].axes
        self.interpret_axes(in_obj, init_state[0])
        self.h_init = init_state[0]
        self.c_init = init_state[1]
    else:
        self.interpret_axes(in_obj, init_state)
        # With reset_cells the state is scratch space (temporary); otherwise
        # it is a persistent variable carried across calls.
        make_state = ng.temporary if self.reset_cells else ng.variable
        self.h_init = make_state(initial_value=0,
                                 axes=self.hidden_state_axes).named('h_init')
        self.c_init = make_state(initial_value=0,
                                 axes=self.hidden_state_axes).named('c_init')

    # Trainable parameters, one per gate (i, f, o, g).
    gate_keys = self.metadata['gates']
    self.W_input = {
        g: ng.variable(axes=self.w_in_axes,
                       initial_value=self.init).named("W_in_{}".format(g))
        for g in gate_keys
    }
    self.W_recur = {
        g: ng.variable(axes=self.w_re_axes,
                       initial_value=self.init_inner).named("W_re_{}".format(g))
        for g in gate_keys
    }
    self.b = {
        g: ng.variable(axes=self.hidden_axes,
                       initial_value=0).named("bias_{}".format(g))
        for g in gate_keys
    }

    # Slice the input into per-timestep views.
    steps = get_steps(in_obj, self.recurrent_axis, self.backward)

    # Unroll the recurrence along the sequence axis.
    hidden, cell = self.h_init, self.c_init
    hidden_seq = []
    cell_seq = []
    for t in range(self.recurrent_axis.length):
        with ng.metadata(recurrent_step=str(t)):
            [hidden, cell] = self._step(steps[t], [hidden, cell])
            hidden_seq.append(hidden)
            cell_seq.append(cell)

    if self.return_sequence is True:
        if self.backward:
            hidden_seq = hidden_seq[::-1]
            cell_seq = cell_seq[::-1]
        lstm_out = ng.stack(hidden_seq, self.recurrent_axis,
                            pos=self.recurrent_axis_idx)
    else:
        lstm_out = hidden_seq[-1]

    if self.reset_cells is True:
        return lstm_out
    # Persist the final state so a subsequent call resumes from it.
    return ng.sequential([
        ng.doall([
            ng.assign(self.h_init, hidden_seq[-1]),
            ng.assign(self.c_init, cell_seq[-1])
        ]),
        lstm_out
    ])