Ejemplo n.º 1
0
 def verify_batch_lengths(defaults):
     """Assert that each collected batch holds examples of similar length.

     Builds a dictionary, an agent and a task world from ``defaults``
     inside a temporary directory, collects the first act of every
     ``parley()`` over two passes, and checks that within each batch the
     spread of ``ep_length`` over the ``defaults['batch_sort_field']``
     field does not exceed ``max_range``.

     ``self``, ``max_range`` and ``ep_length`` are not defined here; they
     are resolved from the enclosing scope.

     Args:
         defaults: dict of parser defaults. Must contain
             ``'batch_sort_field'``. It is modified in place:
             ``'model_file'`` and ``'dict_file'`` are pointed at the
             temporary directory.
     """
     with testing_utils.capture_output() as _, testing_utils.tempdir() as tmpdir:
         # Get processed act from agent
         parser = train_setup_args()
         defaults['model_file'] = os.path.join(tmpdir, 'model')
         defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
         parser.set_defaults(**defaults)
         opt = parser.parse_args()
         build_dict(opt)
         agent = create_agent(opt)
         world_data = create_task(opt, agent)
         batch_sort_acts = []
         try:
             # first epoch: collect acts until 900/50 (= 18) are gathered
             # NOTE(review): presumably 900 examples / batch size 50 --
             # confirm against the task configuration.
             while len(batch_sort_acts) < 900 / 50:
                 world_data.parley()
                 batch_sort_acts.append(world_data.acts[0])
             teacher = world_data.world.get_agents()[0]
             teacher.reset_data()
             # second epoch: keep collecting until 1800/50 (= 36) in total
             while len(batch_sort_acts) < 1800 / 50:
                 world_data.parley()
                 batch_sort_acts.append(world_data.acts[0])
         finally:
             # Shut the world down even when a parley or reset fails, so
             # a failing run does not leave it un-shut-down.
             world_data.shutdown()
     field = defaults['batch_sort_field']
     # The slice drops the final TWO collected acts (the original comment
     # said "exclude last batch"); entries lacking the field are skipped.
     # NOTE(review): confirm whether dropping two entries is intended.
     lengths = [[ep_length(b[field]) for b in bb if field in b]
                for bb in batch_sort_acts[:-2]]
     # verify batch lengths
     for batch_lens in lengths:
         self.assertLessEqual(max(batch_lens) - min(batch_lens), max_range,
                              'PytorchDataTeacher batching does not give '
                              'batches with similar sized examples, when '
                              'sorting by `{}` field.'.format(field))
Ejemplo n.º 2
0
        def verify_batch_lengths(defaults):
            """Assert that each collected batch holds examples of similar length.

            Builds a dictionary, an agent and a task world from
            ``defaults`` with stdout redirected into an in-memory buffer,
            collects the first act of every ``parley()`` over two passes,
            and checks that within each batch the spread of ``ep_length``
            over the ``defaults['batch_sort_field']`` field does not
            exceed ``max_range``.

            ``self``, ``max_range``, ``ep_length`` and ``set_model_file``
            are not defined here; they are resolved from the enclosing
            scope.

            Args:
                defaults: dict of parser defaults. Must contain
                    ``'batch_sort_field'``. It is passed to
                    ``set_model_file`` before being applied to the parser.
            """
            f = io.StringIO()

            with redirect_stdout(f):
                # Get processed act from agent
                parser = train_setup_args()
                set_model_file(defaults)
                parser.set_defaults(**defaults)
                opt = parser.parse_args()
                build_dict(opt)
                agent = create_agent(opt)
                world_data = create_task(opt, agent)
                batch_sort_acts = []
                try:
                    # first epoch: collect acts until 900/50 (= 18) are
                    # gathered
                    # NOTE(review): presumably 900 examples / batch size
                    # 50 -- confirm against the task configuration.
                    while len(batch_sort_acts) < 900 / 50:
                        world_data.parley()
                        batch_sort_acts.append(world_data.acts[0])
                    teacher = world_data.world.get_agents()[0]
                    teacher.reset_data()
                    # second epoch: keep collecting until 1800/50 (= 36)
                    # in total
                    while len(batch_sort_acts) < 1800 / 50:
                        world_data.parley()
                        batch_sort_acts.append(world_data.acts[0])
                finally:
                    # Shut the world down on success and on failure, so
                    # the run does not leave it un-shut-down.
                    world_data.shutdown()
            field = defaults['batch_sort_field']
            # The slice drops the final TWO collected acts (the original
            # comment said "exclude last batch"); entries lacking the
            # field are skipped.
            # NOTE(review): confirm whether dropping two is intended.
            lengths = [[ep_length(b[field]) for b in bb if field in b]
                       for bb in batch_sort_acts[:-2]]
            # verify batch lengths
            for batch_lens in lengths:
                self.assertLessEqual(
                    max(batch_lens) - min(batch_lens), max_range,
                    'PytorchDataTeacher batching does not give '
                    'batches with similar sized examples, when '
                    'sorting by `{}` field.'.format(field))