Ejemplo n.º 1
0
    def test_setnamedtupledefaults(self):
        """Check that ``set_namedtuple_defaults`` fills in missing fields.

        Before defaults are set, constructing the namedtuple with no
        arguments must raise ``TypeError``. After ``set_namedtuple_defaults``
        is called, every field takes the default value: ``None`` when no
        ``default`` is passed, or the supplied ``default`` otherwise.
        """
        from collections import namedtuple

        NT = namedtuple("NT", ("a", "b", "c"))

        # Shouldn't be able to construct a namedtuple without providing info.
        # assertRaises replaces the manual try / self.fail / except pattern.
        with self.assertRaises(TypeError):
            NT()

        # Test setting the default value (None when no default is given).
        # unittest assertion methods are used instead of bare ``assert`` so
        # the checks still run under ``python -O`` and report the offending
        # value on failure.
        set_namedtuple_defaults(NT)
        nt = NT()
        self.assertIsNone(nt.a)
        self.assertIsNone(nt.b)
        self.assertIsNone(nt.c)

        # Test setting it with something else; this replaces the defaults
        # installed above on the same NT class.
        set_namedtuple_defaults(NT, default=1)
        nt = NT()
        self.assertEqual(nt.a, 1)
        self.assertEqual(nt.b, 1)
        self.assertEqual(nt.c, 1)
Ejemplo n.º 2
0
    set_namedtuple_defaults, argsort, padded_tensor, warn_once, round_sigfigs
)
from parlai.core.distributed_utils import is_primary_worker

try:
    import torch
except ImportError:
    raise ImportError('Need to install Pytorch: go to pytorch.org')


# Record type for the data handed to an agent. Field order is significant
# (positional construction / unpacking), so keep it stable. After creation,
# set_namedtuple_defaults installs None as the default for every field so
# that fields which are not supplied come back as None.
Batch = namedtuple(
    'Batch',
    (
        'text_vec',
        'text_lengths',
        'label_vec',
        'label_lengths',
        'labels',
        'valid_indices',
        'candidates',
        'candidate_vecs',
        'image',
        'memory_vecs',
        'observations',
    ),
)
set_namedtuple_defaults(Batch, default=None)
Batch.__doc__ = """
Batch is a namedtuple containing data being sent to an agent.

This is the input type of the train_step and eval_step functions.
Agents can override the batchify function to return an extended namedtuple
with additional fields if they would like, though we recommend calling the
parent function to set up these fields as a base.

.. py:attribute:: text_vec

    bsz x seqlen tensor containing the parsed text data.

.. py:attribute:: text_lengths

    list of length bsz containing the lengths of the text in same order as