Example #1
    def __getitem__(self, idx):
        """
        Item getter for batch number idx.

        Returns
        -------

        - src: torch.LongTensor of maximum size self.bptt x self.batch_size
        - trg: torch.LongTensor of maximum size self.bptt x self.batch_size,
            corresponding to a shifted batch
        """
        src, trg = None, None
        # multi-input
        if isinstance(self.data, tuple):
            src, trg = tuple(zip(*(self._get_batch(d, idx)
                                   for d in self.data)))
            # decompress from table
            if self.table is not None:
                # source
                src_pre, src_target, src_post = destruct(src, self.table_idx)
                src_target = tuple(
                    wrap_variables(t, self.evaluation, self.gpu)
                    for t in self.table.expand(src_target.data))
                src = tuple(src_pre + src_target + src_post)
                # target
                trg_pre, trg_target, trg_post = destruct(trg, self.table_idx)
                trg_target = tuple(
                    wrap_variables(t, self.evaluation, self.gpu)
                    for t in self.table.expand(trg_target.data))
                trg = tuple(trg_pre + trg_target + trg_post)
        # single-input
        else:
            src, trg = self._get_batch(self.data, idx)
        return src, trg
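The multi-input branch above relies on zip(*...) to regroup the per-dataset (src, trg) pairs returned by _get_batch into one tuple of sources and one tuple of targets. The toy sketch below illustrates just that regrouping; the stand-ins for self.data and self._get_batch are invented for illustration and are not part of the original code.

def toy_get_batch(dataset, idx):
    # stand-in for self._get_batch: one (src, trg) pair per input stream
    return ("src(%s, %d)" % (dataset, idx), "trg(%s, %d)" % (dataset, idx))

data = ("words", "pos_tags")  # hypothetical multi-input dataset
src, trg = tuple(zip(*(toy_get_batch(d, 0) for d in data)))
print(src)  # ('src(words, 0)', 'src(pos_tags, 0)')
print(trg)  # ('trg(words, 0)', 'trg(pos_tags, 0)')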
Example #2
    def _get_batch(self, data, idx):
        """
        General function to get the source data to compute the batch. This
        should be overridden by subclasses in which the source data isn't
        always stored in self.data, e.g. the case of cyclical subset access.
        """
        idx *= self.bptt
        seq_len = min(self.bptt, len(data) - 1 - idx)
        src_data, trg_data = data[idx:idx+seq_len], data[idx+1:idx+seq_len+1]
        src = wrap_variables(src_data, self.evaluation, self.gpu)
        trg = wrap_variables(trg_data, self.evaluation, self.gpu)
        return src, trg
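The slicing in _get_batch is the usual BPTT setup for language modelling: the target window is the source window shifted right by one token. Below is a minimal, self-contained sketch of that slicing on a toy corpus tensor; the values and bptt=3 are made up, and the wrap_variables call is left out.

import torch

corpus = torch.LongTensor([10, 11, 12, 13, 14, 15, 16])  # toy data
bptt, idx = 3, 1
start = idx * bptt
seq_len = min(bptt, len(corpus) - 1 - start)
src = corpus[start:start + seq_len]           # tensor([13, 14, 15])
trg = corpus[start + 1:start + seq_len + 1]   # tensor([14, 15, 16])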
Example #3
def embed_single(model, target):
    model.eval()
    src_dict = model.encoder.embeddings.d
    inp = torch.LongTensor(list(src_dict.transform([target]))).transpose(0, 1)
    length = torch.LongTensor([len(target)]) + 2
    inp, length = u.wrap_variables((inp, length), volatile=True, gpu=False)
    _, embedding = model.encoder.forward(inp, lengths=None)
    return embedding.data.numpy()[0].flatten()
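embed_single (and translate below) hand a tuple of tensors to u.wrap_variables before calling the model. The snippet below is only a guess at what such a helper might look like on pre-0.4 PyTorch, where Variable and volatile still exist: it recursively wraps tuples and moves tensors to the GPU on request. It is a sketch, not the library's actual implementation.

import torch
from torch.autograd import Variable

def wrap_variables_sketch(data, volatile=False, gpu=False):
    # recursively wrap (nested) tuples of tensors into Variables
    if isinstance(data, tuple):
        return tuple(wrap_variables_sketch(d, volatile, gpu) for d in data)
    var = Variable(data, volatile=volatile)
    return var.cuda() if gpu else var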
Example #4
    def _pack(self, batch, dicts):
        # multi-input dataset
        if isinstance(batch[0], tuple):
            batches = list(zip(*batch))  # unpack batches
            if isinstance(dicts, MultiDict):
                dicts = dicts.dicts.values()
            out = tuple(d.pack(b) for (d, b) in zip(dicts, batches))
        else:
            out = dicts.pack(batch)

        return wrap_variables(out, volatile=self.evaluation, gpu=self.gpu)
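For a multi-input dataset, _pack first transposes the batch of per-example tuples into one batch per field and then lets each field's dict pack it. Here is a toy transposition-plus-packing sketch; the field names, vocabularies, and the index lookup standing in for Dict.pack are all invented for illustration.

batch = [("the", "DET"), ("cat", "NOUN"), ("sat", "VERB")]
fields = list(zip(*batch))  # [('the', 'cat', 'sat'), ('DET', 'NOUN', 'VERB')]

vocabs = [{"the": 0, "cat": 1, "sat": 2},   # stand-ins for the per-field dicts
          {"DET": 0, "NOUN": 1, "VERB": 2}]
packed = tuple([v[tok] for tok in field] for v, field in zip(vocabs, fields))
print(packed)  # ([0, 1, 2], [0, 1, 2])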
Example #5
def translate(model, target, gpu, beam=True, max_len=4):
    model.eval()
    src_dict = model.encoder.embeddings.d
    inp = torch.LongTensor(list(src_dict.transform([target]))).transpose(0, 1)
    length = torch.LongTensor([len(target)]) + 2
    inp, length = u.wrap_variables((inp, length), volatile=True, gpu=False)
    if beam:
        scores, hyps, _ = model.translate_beam(
            inp, length, beam_width=5, max_decode_len=max_len)
    else:
        scores, hyps, _ = model.translate(inp, length, max_decode_len=max_len)

    return scores, hyps