def test_shuffle(self):
    """Check that the dataloader is built with the expected data sampler.

    A PytorchDataTeacher is created for the 'babi:task1k:1' task for every
    combination of:

    - a datatype: 'train', 'valid' or 'test', each with one of the
      suffixes '', ':stream', ':ordered' or ':stream:ordered';
    - a shuffle flag: False or True.

    The teacher's dataloader must use a ``Sequential`` sampler when the
    datatype is ordered, is a stream without shuffling, or is not a
    training datatype. Otherwise it must use a ``RandomSampler``.
    """
    task = 'babi:task1k:1'
    combos = (
        (split + suffix, flag)
        for split in ['train', 'valid', 'test']
        for suffix in ['', ':stream', ':ordered', ':stream:ordered']
        for flag in [False, True]
    )
    for datatype, flag in combos:
        opt_defaults = {
            'pytorch_teacher_task': task,
            'datatype': datatype,
            'shuffle': flag,
        }
        # Keep argument-parsing / task-creation output out of the test log.
        with testing_utils.capture_output():
            parser = display_setup_args()
            parser.set_defaults(**opt_defaults)
            opt = parser.parse_args()
            teacher = create_task_agent_from_taskname(opt)[0]

        wants_sequential = (
            'ordered' in datatype
            or ('stream' in datatype and not opt.get('shuffle'))
            or 'train' not in datatype
        )
        expected_sampler = Sequential if wants_sequential else RandomSampler
        self.assertIsInstance(
            teacher.pytorch_dataloader.sampler,
            expected_sampler,
            'PytorchDataTeacher failed with args: {}'.format(opt),
        )
def test_pytd_teacher(self):
    """Check that PytorchDataTeacher works with a given PyTorch Dataset.

    Two teachers are built, both with the 'train:stream' datatype and
    'ascii' image mode:

    - one from the 'flickr30k' PyTorch dataset;
    - one regular teacher from the 'flickr30k' task.

    The test takes the first act of each and asserts that every key the
    two acts share carries the same value.
    """
    defaults = parser_defaults.copy()
    defaults['datatype'] = 'train:stream'
    defaults['image_mode'] = 'ascii'

    # Discard setup output; only the two acts matter here.
    with redirect_stdout(io.StringIO()):
        # Teacher backed by the PyTorch dataset.
        parser = display_setup_args()
        defaults['pytorch_teacher_dataset'] = 'flickr30k'
        del defaults['pytorch_teacher_task']
        parser.set_defaults(**defaults)
        opt = parser.parse_args()
        teacher = create_task_agent_from_taskname(opt)[0]
        pytorch_teacher_act = teacher.act()

        # Regular task teacher for the same data.
        parser = display_setup_args()
        defaults['task'] = 'flickr30k'
        del defaults['pytorch_teacher_dataset']
        parser.set_defaults(**defaults)
        opt = parser.parse_args()
        teacher = create_task_agent_from_taskname(opt)[0]
        regular_teacher_act = teacher.act()

    keys = set(pytorch_teacher_act.keys()).intersection(
        regular_teacher_act.keys())
    self.assertTrue(
        keys, 'PytorchDataTeacher and regular teacher acts share no keys')
    for key in keys:
        # assertEqual, unlike assertTrue(a == b), shows both values on failure.
        self.assertEqual(
            pytorch_teacher_act[key],
            regular_teacher_act[key],
            'PytorchDataTeacher does not have the same value '
            'as regular teacher for act key: {}'.format(key))
    print('\n------Passed `test_pytd_teacher`------\n')
def run_display_test(defaults, ep_and_ex_counts):
    """Run ``display_data`` with ``defaults`` and check the reported counts.

    ``ep_and_ex_counts`` holds the expected number of episodes (index 0)
    and examples (index 1). The captured output must contain the line
    "[ loaded <episodes> episodes with a total of <examples> examples ]".

    NOTE(review): ``self`` is not a parameter here; it is presumably a
    closure variable from the enclosing test method.
    """
    with testing_utils.capture_output() as f:
        parser = display_setup_args()
        parser.set_defaults(**defaults)
        opt = parser.parse_args()
        display_data(opt)
        # Read while the capture is still active, so this does not depend on
        # whether the buffer stays readable after the context exits.
        str_output = f.getvalue()
    expected = '[ loaded {} episodes with a total of {} examples ]'.format(
        ep_and_ex_counts[0], ep_and_ex_counts[1])
    # assertIn reports both the expected line and the actual output on failure.
    self.assertIn(
        expected,
        str_output,
        'PytorchDataTeacher multitasking failed with '
        'following args: {}'.format(opt))
def test_pytd_teacher(self):
    """Compare PytorchDataTeacher against the regular task teacher.

    Two teachers are built on the 'integration_tests' data, both with the
    'train:stream' datatype and 'ascii' image mode:

    - one from the PyTorch dataset;
    - one ordinary task teacher.

    Their first acts must agree on every key they have in common, and
    they must share at least one key.
    """
    defaults = integration_test_parser_defaults.copy()
    defaults['datatype'] = 'train:stream'
    defaults['image_mode'] = 'ascii'

    def first_act(added_key, removed_key):
        # Swap one option in the shared ``defaults`` dict, then build a
        # teacher from the resulting options and return its first act.
        parser = display_setup_args()
        defaults[added_key] = 'integration_tests'
        del defaults[removed_key]
        parser.set_defaults(**defaults)
        opt = parser.parse_args()
        return create_task_agent_from_taskname(opt)[0].act()

    with testing_utils.capture_output():
        pyt_act = first_act('pytorch_teacher_dataset', 'pytorch_teacher_task')
        reg_act = first_act('task', 'pytorch_teacher_dataset')

    shared_keys = set(pyt_act.keys()) & set(reg_act.keys())
    self.assertTrue(len(shared_keys) > 0)
    for key in shared_keys:
        pyt_val, reg_val = pyt_act[key], reg_act[key]
        self.assertTrue(
            pyt_val == reg_val,
            'PytorchDataTeacher does not have the same value '
            'as regular teacher for act key: {}. '
            'Values: {}; {}'.format(key, pyt_val, reg_val),
        )
def test_shuffle(self):
    """Check that the dataloader is initialized with the correct data sampler.

    A PytorchDataTeacher is built for the 'babi:task1k:1' task for every
    combination of:

    - a datatype: 'train', 'valid' or 'test', each with one of the
      suffixes '', ':stream', ':ordered' or ':stream:ordered';
    - a shuffle flag: False or True.

    The dataloader's sampler must be a ``Sequential`` sampler when the
    datatype is ordered, is a stream without shuffling, or is not a
    training datatype. Otherwise it must be a ``RandomSampler``.
    """
    dts = ['train', 'valid', 'test']
    exts = ['', ':stream', ':ordered', ':stream:ordered']
    shuffle_opts = [False, True]
    task = 'babi:task1k:1'
    for dt in dts:
        for ext in exts:
            datatype = dt + ext
            for shuffle in shuffle_opts:
                opt_defaults = {
                    'pytorch_teacher_task': task,
                    'datatype': datatype,
                    'shuffle': shuffle
                }
                print('Testing test_shuffle with args {}'.format(
                    opt_defaults))
                # Discard setup output from parsing and task creation.
                with redirect_stdout(io.StringIO()):
                    parser = display_setup_args()
                    parser.set_defaults(**opt_defaults)
                    opt = parser.parse_args()
                    teacher = create_task_agent_from_taskname(opt)[0]
                if ('ordered' in datatype or
                        ('stream' in datatype and not opt.get('shuffle')) or
                        'train' not in datatype):
                    expected_sampler = Sequential
                else:
                    expected_sampler = RandomSampler
                # assertIsInstance, unlike assertTrue(type(x) is C), reports
                # the actual sampler type on failure.
                self.assertIsInstance(
                    teacher.pytorch_dataloader.sampler,
                    expected_sampler,
                    'PytorchDataTeacher failed with args: {}'.format(opt))
    print('\n------Passed `test_shuffle`------\n')