Exemplo n.º 1
0
 def test_shuffle(self):
     """Check that the dataloader gets the sampler the options call for.

     Every combination of train/valid/test with the stream/ordered
     suffixes is built with shuffling off and on. The teacher's dataloader
     must use a `Sequential` sampler when the datatype is ordered, is a
     stream without shuffling, or is not a training datatype; otherwise it
     must use a `RandomSampler`.
     """
     task = 'babi:task1k:1'
     for split in ['train', 'valid', 'test']:
         for suffix in ['', ':stream', ':ordered', ':stream:ordered']:
             datatype = split + suffix
             for shuffle in [False, True]:
                 overrides = {
                     'pytorch_teacher_task': task,
                     'datatype': datatype,
                     'shuffle': shuffle
                 }
                 # Keep parser/teacher construction chatter out of the
                 # test output.
                 with testing_utils.capture_output() as _:
                     parser = display_setup_args()
                     parser.set_defaults(**overrides)
                     opt = parser.parse_args()
                     teacher = create_task_agent_from_taskname(opt)[0]
                     wants_sequential = (
                         'ordered' in datatype
                         or ('stream' in datatype
                             and not opt.get('shuffle'))
                         or 'train' not in datatype
                     )
                     if wants_sequential:
                         expected_sampler = Sequential
                     else:
                         expected_sampler = RandomSampler
                     self.assertIsInstance(
                         teacher.pytorch_dataloader.sampler,
                         expected_sampler,
                         'PytorchDataTeacher failed with args: {}'.
                         format(opt))
Exemplo n.º 2
0
    def test_pytd_teacher(self):
        """
            Test that the pytorch teacher works with given Pytorch Datasets
            as well.

            The 'flickr30k' data is loaded twice: once through the
            `pytorch_teacher_dataset` option and once through the regular
            `task` option. Every key that the first act of both teachers has
            in common must carry the same value.
        """
        defaults = parser_defaults.copy()
        defaults['datatype'] = 'train:stream'
        defaults['image_mode'] = 'ascii'

        # Swallow whatever the parsers and teachers print while being set up.
        with redirect_stdout(io.StringIO()):
            # Get processed act from agent
            parser = display_setup_args()
            defaults['pytorch_teacher_dataset'] = 'flickr30k'
            del defaults['pytorch_teacher_task']
            parser.set_defaults(**defaults)
            opt = parser.parse_args()
            teacher = create_task_agent_from_taskname(opt)[0]
            pytorch_teacher_act = teacher.act()

            # Same data again, this time through the regular task option.
            parser = display_setup_args()
            defaults['task'] = 'flickr30k'
            del defaults['pytorch_teacher_dataset']
            parser.set_defaults(**defaults)
            opt = parser.parse_args()
            teacher = create_task_agent_from_taskname(opt)[0]
            regular_teacher_act = teacher.act()

        keys = set(pytorch_teacher_act.keys()).intersection(
            set(regular_teacher_act.keys()))
        # With no shared keys the loop below would pass without checking
        # anything, so require at least one.
        self.assertTrue(
            keys,
            'PytorchDataTeacher and regular teacher acts share no keys')
        for key in keys:
            # assertEqual (rather than assertTrue on `==`) reports both
            # values when they differ.
            self.assertEqual(
                pytorch_teacher_act[key], regular_teacher_act[key],
                'PytorchDataTeacher does not have the same value '
                'as regular teacher for act key: {}. '
                'Values: {}; {}'.format(key, pytorch_teacher_act[key],
                                        regular_teacher_act[key]))
        print('\n------Passed `test_pytd_teacher`------\n')
 def run_display_test(defaults, ep_and_ex_counts):
     """Run `display_data` and check the episode/example counts it prints.

     `defaults` is applied to a fresh display parser as its defaults.
     `ep_and_ex_counts[0]` and `ep_and_ex_counts[1]` are the episode and
     example counts expected in the captured "[ loaded ... ]" line.

     NOTE(review): `self` is not a parameter, so this presumably sits
     inside a test method and takes `self` from that scope -- confirm.
     """
     with testing_utils.capture_output() as captured:
         parser = display_setup_args()
         parser.set_defaults(**defaults)
         opt = parser.parse_args()
         display_data(opt)
     printed = captured.getvalue()
     num_episodes = ep_and_ex_counts[0]
     num_examples = ep_and_ex_counts[1]
     expected_line = (
         '[ loaded {} episodes with a total of {} examples ]'.format(
             num_episodes, num_examples))
     failure_msg = ('PytorchDataTeacher multitasking failed with '
                    'following args: {}'.format(opt))
     self.assertTrue(expected_line in printed, failure_msg)
    def test_pytd_teacher(self):
        """
        Verify that the pytorch teacher also handles a given Pytorch Dataset.

        The 'integration_tests' data is loaded once through the
        `pytorch_teacher_dataset` option and once through the regular `task`
        option; the first act of each teacher must agree on every key the
        two acts share.
        """
        def build_teacher(settings):
            # Fresh parser with `settings` as defaults; returns the first
            # agent created for the parsed options.
            parser = display_setup_args()
            parser.set_defaults(**settings)
            return create_task_agent_from_taskname(parser.parse_args())[0]

        settings = integration_test_parser_defaults.copy()
        settings['datatype'] = 'train:stream'
        settings['image_mode'] = 'ascii'

        with testing_utils.capture_output():
            # Act produced through the Pytorch Dataset option.
            settings['pytorch_teacher_dataset'] = 'integration_tests'
            del settings['pytorch_teacher_task']
            teacher = build_teacher(settings)
            pytorch_teacher_act = teacher.act()

            # Act produced through the regular task option.
            settings['task'] = 'integration_tests'
            del settings['pytorch_teacher_dataset']
            teacher = build_teacher(settings)
            regular_teacher_act = teacher.act()

        shared_keys = (set(pytorch_teacher_act.keys())
                       & set(regular_teacher_act.keys()))
        self.assertTrue(len(shared_keys) != 0)
        for key in shared_keys:
            pytorch_value = pytorch_teacher_act[key]
            regular_value = regular_teacher_act[key]
            self.assertTrue(
                pytorch_value == regular_value,
                'PytorchDataTeacher does not have the same value '
                'as regular teacher for act key: {}. '
                'Values: {}; {}'.format(key, pytorch_teacher_act[key],
                                        regular_teacher_act[key]),
            )
Exemplo n.º 5
0
 def test_shuffle(self):
     """Simple test to ensure that dataloader is initialized with correct
         data sampler

     Every combination of train/valid/test with the stream/ordered
     suffixes is built with shuffling off and on. A `Sequential` sampler
     is expected when the datatype is ordered, is a stream without
     shuffling, or is not a training datatype; a `RandomSampler` is
     expected otherwise.
     """
     dts = ['train', 'valid', 'test']
     exts = ['', ':stream', ':ordered', ':stream:ordered']
     shuffle_opts = [False, True]
     task = 'babi:task1k:1'
     for dt in dts:
         for ext in exts:
             datatype = dt + ext
             for shuffle in shuffle_opts:
                 opt_defaults = {
                     'pytorch_teacher_task': task,
                     'datatype': datatype,
                     'shuffle': shuffle
                 }
                 print('Testing test_shuffle with args {}'.format(
                     opt_defaults))
                 # Silence the output produced while the parser and
                 # teacher are being built.
                 with redirect_stdout(io.StringIO()):
                     parser = display_setup_args()
                     parser.set_defaults(**opt_defaults)
                     opt = parser.parse_args()
                     teacher = create_task_agent_from_taskname(opt)[0]
                 if ('ordered' in datatype or
                     ('stream' in datatype and not opt.get('shuffle'))
                         or 'train' not in datatype):
                     expected_sampler = Sequential
                 else:
                     expected_sampler = RandomSampler
                 # assertIsInstance names the actual and expected types
                 # on failure, unlike assertTrue on a `type(...) is` test.
                 self.assertIsInstance(
                     teacher.pytorch_dataloader.sampler,
                     expected_sampler,
                     'PytorchDataTeacher failed with args: {}'.format(
                         opt))
     print('\n------Passed `test_shuffle`------\n')