def test_setnamedtupledefaults(self):
    """Check that ``set_namedtuple_defaults`` makes every field optional.

    A freshly created namedtuple requires all of its fields, so calling it
    with no arguments must raise ``TypeError``. After
    ``set_namedtuple_defaults(NT)`` every field of an argument-less
    construction should be ``None``; after
    ``set_namedtuple_defaults(NT, default=1)`` every field should be ``1``.

    Uses ``TestCase`` assertion methods rather than bare ``assert`` so the
    checks still run under ``python -O`` and report informative failures.
    """
    from collections import namedtuple

    NT = namedtuple("NT", ("a", "b", "c"))

    # Shouldn't be able to construct a namedtuple without providing info.
    with self.assertRaises(
        TypeError, msg="Shouldn't be able to construct namedtuple"
    ):
        NT()

    # Test setting the default value (no explicit default given).
    set_namedtuple_defaults(NT)
    nt = NT()
    self.assertIsNone(nt.a)
    self.assertIsNone(nt.b)
    self.assertIsNone(nt.c)

    # Test setting it with something else; this re-applies defaults to the
    # same NT class, overriding the previous None defaults.
    set_namedtuple_defaults(NT, default=1)
    nt = NT()
    self.assertEqual(nt.a, 1)
    self.assertEqual(nt.b, 1)
    self.assertEqual(nt.c, 1)
set_namedtuple_defaults, argsort, padded_tensor, warn_once, round_sigfigs ) from parlai.core.distributed_utils import is_primary_worker try: import torch except ImportError: raise ImportError('Need to install Pytorch: go to pytorch.org') Batch = namedtuple('Batch', [ 'text_vec', 'text_lengths', 'label_vec', 'label_lengths', 'labels', 'valid_indices', 'candidates', 'candidate_vecs', 'image', 'memory_vecs', 'observations' ]) set_namedtuple_defaults(Batch, default=None) Batch.__doc__ = """ Batch is a namedtuple containing data being sent to an agent. This is the input type of the train_step and eval_step functions. Agents can override the batchify function to return an extended namedtuple with additional fields if they would like, though we recommend calling the parent function to set up these fields as a base. .. py:attribute:: text_vec bsz x seqlen tensor containing the parsed text data. .. py:attribute:: text_lengths list of length bsz containing the lengths of the text in same order as