def sort_and_run_forward(self, module, inputs, mask, hidden_state=None):
    """
    Parameters
    ----------
    module : ``Callable[[PackedSequence, Optional[RnnState]], Tuple[Union[PackedSequence, torch.Tensor], RnnState]]``, required.
        A function to run on the inputs. In most cases, this is a ``torch.nn.Module``.
    inputs : ``torch.Tensor``, required.
        A tensor of shape ``(batch_size, sequence_length, embedding_size)`` representing
        the inputs to the Encoder.
    mask : ``torch.Tensor``, required.
        A tensor of shape ``(batch_size, sequence_length)``, representing masked and
        non-masked elements of the sequence for each element in the batch.
    hidden_state : ``Optional[RnnState]``, (default = None).
        A single tensor of shape (num_layers, batch_size, hidden_size) representing the
        state of an RNN, or a tuple of tensors of shapes (num_layers, batch_size, hidden_size)
        and (num_layers, batch_size, memory_size), representing the hidden state and memory
        state of an LSTM-like RNN.

    Returns
    -------
    module_output : ``Union[torch.Tensor, PackedSequence]``.
        A Tensor or PackedSequence representing the output of the Pytorch Module.
        The batch size dimension will be equal to ``num_valid``, as sequences of zero
        length are clipped off before the module is called, as Pytorch cannot handle
        zero length sequences.
    final_states : ``Optional[RnnState]``
        A Tensor representing the hidden state of the Pytorch Module. This can either be
        a single tensor of shape (num_layers, num_valid, hidden_size), for instance in
        the case of a GRU, or a tuple of tensors, such as those required for an LSTM.
    restoration_indices : ``torch.LongTensor``
        A tensor of shape ``(batch_size,)``, describing the re-indexing required to
        transform the outputs back to their original batch order.
    """
    xp = self.xp
    xs = inputs
    batch_lengths = [m.sum() for m in mask]
    indices = argsort_list_descent(batch_lengths)
    indices_array = xp.asarray(indices)

    # Sort the inputs (and the mask) by descending sequence length.
    xs = F.permutate(xs, indices_array, axis=0, inv=False)
    mask = mask[indices_array]

    if hidden_state:
        h, c = hidden_state
        h = F.permutate(h, indices_array, axis=1, inv=False)
        c = F.permutate(c, indices_array, axis=1, inv=False)
        initial_state = (h, c)  # TODO: test
    else:
        initial_state = None

    batch_lengths = [m.sum() for m in mask]
    module_output, final_states = module(xs,
                                         batch_lengths=batch_lengths,
                                         initial_state=initial_state)
    restoration_indices = indices_array
    return module_output, final_states, restoration_indices
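# A minimal caller-side sketch (``encoder`` is a hypothetical owner of the method
# above, not taken from the source): since ``restoration_indices`` is the sorting
# permutation itself, the original batch order can be recovered by applying
# F.permutate with ``inv=True``.
output, final_states, restoration_indices = encoder.sort_and_run_forward(
    module, inputs, mask, hidden_state)
output = F.permutate(output, restoration_indices, axis=0, inv=True)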
def argmax_crf1d(cost, xs):
    indices = argsort_list_descent(xs)
    xs = permutate_list(xs, indices, inv=False)
    xs = F.transpose_sequence(xs)
    score, path = F.argmax_crf1d(cost, xs)
    path = F.transpose_sequence(path)
    path = permutate_list(path, indices, inv=True)
    score = F.permutate(score, indices, inv=True)
    return score, path
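# A minimal usage sketch with assumed shapes and values (not from the source):
# the wrapper above accepts sequences in any order, sorts them for
# F.transpose_sequence, decodes with the CRF, and returns scores and label
# paths in the original order.
import numpy as np
import chainer

n_label = 5
cost = np.zeros((n_label, n_label), dtype=np.float32)              # transition costs
xs = [chainer.Variable(np.random.randn(n, n_label).astype(np.float32))
      for n in (7, 3, 5)]                                          # emission scores
score, path = argmax_crf1d(cost, xs)
assert len(path) == len(xs)   # one decoded label sequence per input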
def check_forward(self, x_data, ind_data):
    x = chainer.Variable(x_data)
    indices = chainer.Variable(ind_data)
    y = functions.permutate(x, indices, axis=self.axis, inv=self.inv)

    y_cpu = cuda.to_cpu(y.data)
    y_cpu = numpy.rollaxis(y_cpu, axis=self.axis)
    x_data = numpy.rollaxis(self.x, axis=self.axis)
    for i, ind in enumerate(self.indices):
        if self.inv:
            numpy.testing.assert_array_equal(y_cpu[ind], x_data[i])
        else:
            numpy.testing.assert_array_equal(y_cpu[i], x_data[ind])
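# A standalone sketch of the semantics the test above checks (example values are
# assumed, not part of the test suite): with inv=False, y[i] == x[indices[i]];
# with inv=True the permutation is undone, i.e. y[indices[i]] == x[i].
import numpy as np
import chainer.functions as F

x = np.arange(12, dtype=np.float32).reshape(4, 3)
indices = np.array([2, 0, 3, 1], dtype=np.int32)

y = F.permutate(x, indices, axis=0, inv=False)
np.testing.assert_array_equal(y.data[0], x[2])       # y[i] == x[indices[i]]

restored = F.permutate(y, indices, axis=0, inv=True)
np.testing.assert_array_equal(restored.data, x)      # round trip restores x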
def __call__(self, c, xs, train=True):
    """
    The API is (almost) equivalent to NStepLSTM's.
    Just pass the list of variables, and they are encoded.
    """
    inds = np.argsort([-len(x.data) for x in xs]).astype('i')
    xs_ = [xs[i] for i in inds]

    pool_in = self.convolution(xs_, train)
    c, hs = self.pooling(c, pool_in, train)

    # permutate the list back
    ret = [None] * len(inds)
    for i, idx in enumerate(inds):
        ret[idx] = hs[i]

    # permutate the cell state, too
    c = F.permutate(c, indices=inds, axis=0)

    return c, ret
def forward(self, ws, ss, ps, ls, dep_ts=None):
    batchsize, slen = ws.shape
    xp = chainer.cuda.get_array_module(ws[0])

    wss = self.emb_word(ws)
    sss = F.reshape(self.emb_suf(ss), (batchsize, slen, 4 * self.afix_dim))
    pss = F.reshape(self.emb_prf(ps), (batchsize, slen, 4 * self.afix_dim))
    ins = F.dropout(F.concat([wss, sss, pss], 2),
                    self.dropout_ratio, train=self.train)
    xs_f = F.transpose(ins, (1, 0, 2))
    xs_b = xs_f[::-1]

    cx_f, hx_f, cx_b, hx_b = self._init_state(xp, batchsize)
    _, _, hs_f = self.lstm_f(hx_f, cx_f, xs_f, train=self.train)
    _, _, hs_b = self.lstm_b(hx_b, cx_b, xs_b, train=self.train)

    # (batch, length, hidden_dim)
    hs = F.transpose(F.concat([hs_f, hs_b[::-1]], 2), (1, 0, 2))

    dep_ys = self.biaffine_arc(
        F.elu(F.dropout(self.arc_dep(hs), 0.32, train=self.train)),
        F.elu(F.dropout(self.arc_head(hs), 0.32, train=self.train)))

    # Use the gold heads roughly half the time when they are provided,
    # otherwise fall back to the predicted heads.
    if dep_ts is not None and random.random() >= 0.5:
        heads = dep_ts
    else:
        heads = F.flatten(F.argmax(dep_ys, axis=2)) + \
            xp.repeat(xp.arange(0, batchsize * slen, slen), slen)

    hs = F.reshape(hs, (batchsize * slen, -1))
    heads = F.permutate(
        F.elu(F.dropout(
            self.rel_head(hs), 0.32, train=self.train)), heads)

    childs = F.elu(F.dropout(self.rel_dep(hs), 0.32, train=self.train))
    cat_ys = self.biaffine_tag(childs, heads)

    dep_ys = F.split_axis(dep_ys, batchsize, 0) if batchsize > 1 else [dep_ys]
    dep_ys = [F.reshape(v, v.shape[1:])[:l, :l] for v, l in zip(dep_ys, ls)]

    cat_ys = F.split_axis(cat_ys, batchsize, 0) if batchsize > 1 else [cat_ys]
    cat_ys = [v[:l] for v, l in zip(cat_ys, ls)]

    return cat_ys, dep_ys
def forward(self, x, indices):
    return F.permutate(x, indices, **self.kwargs)
def check_invalid(self, x_data, ind_data):
    x = chainer.Variable(x_data)
    ind = chainer.Variable(ind_data)
    with self.assertRaises(ValueError):
        functions.permutate(x, ind)
def fun(x, ind):
    return functions.permutate(x, ind, self.axis, self.inv)
def forward(self, inputs, device):
    x, indices = inputs
    y = functions.permutate(x, indices, axis=self.axis, inv=self.inv)
    return y,
def permutate(self, order):
    for link in [self.enc1, self.enc2]:
        link.c = F.permutate(link.c, order)
        link.h = F.permutate(link.h, order)
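# A hypothetical call sketch (``model`` and the batch ``xs`` are assumed for
# illustration, not from the source): ``order`` is expected to be an int32
# permutation over the batch dimension, so that the hidden and cell states of
# both encoder links stay aligned with a reordered batch, e.g. one sorted by
# descending length.
import numpy as np

order = np.argsort([-len(x) for x in xs]).astype('i')
model.permutate(order)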