Example #1
    def sort_and_run_forward(self, module, inputs, mask, hidden_state=None):
        """
        Parameters
        ----------
        module : ``Callable[[PackedSequence, Optional[RnnState]],
                            Tuple[Union[PackedSequence, torch.Tensor], RnnState]]``, required.
            A function to run on the inputs. In most cases, this is a ``torch.nn.Module``.
        inputs : ``torch.Tensor``, required.
            A tensor of shape ``(batch_size, sequence_length, embedding_size)`` representing
            the inputs to the Encoder.
        mask : ``torch.Tensor``, required.
            A tensor of shape ``(batch_size, sequence_length)``, representing masked and
            non-masked elements of the sequence for each element in the batch.
        hidden_state : ``Optional[RnnState]``, (default = None).
            A single tensor of shape ``(num_layers, batch_size, hidden_size)`` representing the
            state of an RNN, or a tuple of two tensors of shapes
            ``(num_layers, batch_size, hidden_size)`` and ``(num_layers, batch_size, memory_size)``,
            representing the hidden state and memory state of an LSTM-like RNN.

        Returns
        -------
        module_output : ``Union[torch.Tensor, PackedSequence]``.
            A Tensor or PackedSequence representing the output of the Pytorch Module.
            The batch size dimension will be equal to ``num_valid``, as sequences of zero
            length are clipped off before the module is called, as Pytorch cannot handle
            zero length sequences.
        final_states : ``Optional[RnnState]``
            A Tensor representing the hidden state of the Pytorch Module. This can either
            be a single tensor of shape (num_layers, num_valid, hidden_size), for instance in
            the case of a GRU, or a tuple of tensors, such as those required for an LSTM.
        restoration_indices : ``torch.LongTensor``
            A tensor of shape ``(batch_size,)``, describing the re-indexing required to transform
            the outputs back to their original batch order.
        """
        xp = self.xp
        xs = inputs

        batch_lengths = [m.sum() for m in mask]
        # Indices that sort the batch by descending sequence length.
        indices = argsort_list_descent(batch_lengths)
        indices_array = xp.asarray(indices)

        xs = F.permutate(xs, indices_array, axis=0, inv=False)
        mask = mask[indices_array]
        if hidden_state:
            h, c = hidden_state
            h = F.permutate(h, indices_array, axis=1, inv=False)
            c = F.permutate(c, indices_array, axis=1, inv=False)
            initial_state = (h, c)
            # TODO: test
        else:
            initial_state = None

        # Recompute the lengths in the new (descending-length) order.
        batch_lengths = [m.sum() for m in mask]
        module_output, final_states = module(xs,
                                             batch_lengths=batch_lengths,
                                             initial_state=initial_state)

        restoration_indices = indices_array
        return module_output, final_states, restoration_indices
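
The restoration_indices returned above are simply the sorting indices, so the original batch order can be recovered by applying the inverse permutation. Below is a minimal sketch of that round trip with F.permutate(..., inv=True); the toy module_output, its shape, and the hard-coded indices are illustrative assumptions, not part of the snippet above.

import numpy as np
import chainer.functions as F

# Toy stand-in for the module output: a batch of 3 sequences, already in
# descending-length order (real shapes depend on the encoder).
module_output = np.random.rand(3, 5, 8).astype(np.float32)

# restoration_indices as returned above: position i holds the original index
# of the i-th (sorted) sequence.
restoration_indices = np.asarray([2, 0, 1], dtype='i')

# inv=True applies the inverse permutation, so restored[k] corresponds to the
# k-th sequence of the original, unsorted batch.
restored = F.permutate(module_output, restoration_indices, axis=0, inv=True)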
Example #2
def argmax_crf1d(cost, xs):
    indices = argsort_list_descent(xs)
    xs = permutate_list(xs, indices, inv=False)
    xs = F.transpose_sequence(xs)
    score, path = F.argmax_crf1d(cost, xs)
    path = F.transpose_sequence(path)
    path = permutate_list(path, indices, inv=True)
    score = F.permutate(score, indices, inv=True)
    return score, path
Example #3
    def check_forward(self, x_data, ind_data):
        x = chainer.Variable(x_data)
        indices = chainer.Variable(ind_data)
        y = functions.permutate(x, indices, axis=self.axis, inv=self.inv)

        y_cpu = cuda.to_cpu(y.data)
        y_cpu = numpy.rollaxis(y_cpu, axis=self.axis)
        x_data = numpy.rollaxis(self.x, axis=self.axis)
        for i, ind in enumerate(self.indices):
            if self.inv:
                numpy.testing.assert_array_equal(y_cpu[ind], x_data[i])
            else:
                numpy.testing.assert_array_equal(y_cpu[i], x_data[ind])
Example #4
    def __call__(self, c, xs, train=True):
        """
        The API is (almost) equivalent to NStepLSTM's.
        Just pass the list of variables, and they are encoded.
        """
        inds = np.argsort([-len(x.data) for x in xs]).astype('i')
        xs_ = [xs[i] for i in inds]
        pool_in = self.convolution(xs_, train)
        c, hs = self.pooling(c, pool_in, train)

        # permutate the list back
        ret = [None] * len(inds)
        for i, idx in enumerate(inds):
            ret[idx] = hs[i]
        # permutate the cell state, too
        c = F.permutate(c, indices=inds, axis=0)
        return c, ret
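
The docstring above notes that the API mirrors NStepLSTM's: the input list is sorted by descending length before encoding, and that ordering is undone afterwards. Here is a minimal sketch of the round trip with toy data; xs, hs, and c below are illustrative stand-ins for the module's inputs, encoder outputs, and stacked state, not the original objects.

import numpy as np
import chainer.functions as F

# Toy list of variable-length sequences (lengths 2, 4, 3).
xs = [np.random.rand(n, 4).astype(np.float32) for n in (2, 4, 3)]

# Sort by descending length, as __call__ does before convolution/pooling.
inds = np.argsort([-len(x) for x in xs]).astype('i')
xs_sorted = [xs[i] for i in inds]

hs = xs_sorted  # stand-in for the per-sequence encoder outputs (sorted order)
c = np.random.rand(len(xs), 8).astype(np.float32)  # stand-in state, one row per sorted sequence

# Scatter the list back so ret[k] again corresponds to xs[k].
ret = [None] * len(inds)
for i, idx in enumerate(inds):
    ret[idx] = hs[i]

# Assuming the stacked state is indexed in the sorted order, the inverse
# permutation restores the original order.
c_restored = F.permutate(c, inds, axis=0, inv=True)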
Example #5
    def forward(self, ws, ss, ps, ls, dep_ts=None):
        batchsize, slen = ws.shape
        xp = chainer.cuda.get_array_module(ws[0])

        wss = self.emb_word(ws)
        sss = F.reshape(self.emb_suf(ss), (batchsize, slen, 4 * self.afix_dim))
        pss = F.reshape(self.emb_prf(ps), (batchsize, slen, 4 * self.afix_dim))
        ins = F.dropout(F.concat([wss, sss, pss], 2), self.dropout_ratio, train=self.train)
        xs_f = F.transpose(ins, (1, 0, 2))
        xs_b = xs_f[::-1]

        cx_f, hx_f, cx_b, hx_b = self._init_state(xp, batchsize)
        _, _, hs_f = self.lstm_f(hx_f, cx_f, xs_f, train=self.train)
        _, _, hs_b = self.lstm_b(hx_b, cx_b, xs_b, train=self.train)

        # (batch, length, hidden_dim)
        hs = F.transpose(F.concat([hs_f, hs_b[::-1]], 2), (1, 0, 2))

        dep_ys = self.biaffine_arc(
            F.elu(F.dropout(self.arc_dep(hs), 0.32, train=self.train)),
            F.elu(F.dropout(self.arc_head(hs), 0.32, train=self.train)))

        if dep_ts is not None and random.random() >= 0.5:
            heads = dep_ts
        else:
            heads = F.flatten(F.argmax(dep_ys, axis=2)) + \
                    xp.repeat(xp.arange(0, batchsize * slen, slen), slen)

        hs = F.reshape(hs, (batchsize * slen, -1))
        heads = F.permutate(
                    F.elu(F.dropout(
                        self.rel_head(hs), 0.32, train=self.train)), heads)

        childs = F.elu(F.dropout(self.rel_dep(hs), 0.32, train=self.train))
        cat_ys = self.biaffine_tag(childs, heads)

        dep_ys = F.split_axis(dep_ys, batchsize, 0) if batchsize > 1 else [dep_ys]
        dep_ys = [F.reshape(v, v.shape[1:])[:l, :l] for v, l in zip(dep_ys, ls)]

        cat_ys = F.split_axis(cat_ys, batchsize, 0) if batchsize > 1 else [cat_ys]
        cat_ys = [v[:l] for v, l in zip(cat_ys, ls)]

        return cat_ys, dep_ys
Example #6
 def forward(self, x, indices):
     return F.permutate(x, indices, **self.kwargs)
Example #7
 def check_invalid(self, x_data, ind_data):
     x = chainer.Variable(x_data)
     ind = chainer.Variable(ind_data)
     with self.assertRaises(ValueError):
         functions.permutate(x, ind)
Example #8
 def fun(x, ind):
     return functions.permutate(x, ind, self.axis, self.inv)
Example #9
 def forward(self, inputs, device):
     x, indices = inputs
     y = functions.permutate(x, indices, axis=self.axis, inv=self.inv)
     return y,
Example #10
 def permutate(self, order):
     for link in [self.enc1, self.enc2]:
         link.c = F.permutate(link.c, order)
         link.h = F.permutate(link.h, order)
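
For reference, the gather/scatter semantics that the check_forward tests above assert can be verified directly on a tiny array; the values below are arbitrary and only meant as a sketch.

import numpy as np
import chainer.functions as F

x = np.arange(8, dtype=np.float32).reshape(4, 2)
ind = np.array([2, 0, 3, 1], dtype='i')

y = F.permutate(x, ind, axis=0)            # gather: y[i] == x[ind[i]]
z = F.permutate(y, ind, axis=0, inv=True)  # scatter: z[ind[i]] == y[i]

np.testing.assert_array_equal(y.data[0], x[2])
np.testing.assert_array_equal(z.data, x)   # forward followed by inverse is the identity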