Example #3
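This excerpt omits the module's import header. A plausible reconstruction, assuming ParlAI's module layout (the parlai.* paths are assumptions, and `_path` is a helper defined elsewhere in the same module, not shown here):

import json
import os
import random

from parlai.core.image_featurizers import ImageLoader
from parlai.core.metrics import TeacherMetrics
from parlai.core.teachers import Teacher
from parlai.utils.io import PathManager

# `_path(opt)` is assumed to be a module-level helper that returns
# (questions_path, trainset_path, image_path) for the downloaded dataset.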
class SplitTeacher(Teacher):
    """
    FVQA Teacher, which loads the json VQA data and implements its own `act` method for
    interacting with student agent.

    Use "fvqa:split:X" to choose between splits 0-4 (inclusive), or just "fvqa" to use
    the default split (0).
    """
    def __init__(self, opt, shared=None):
        super().__init__(opt)

        dt = opt['datatype'].split(':')[0]
        if dt not in ('train', 'test'):
            raise RuntimeError('Not valid datatype (only train/test).')

        task = opt.get('task', 'fvqa:split:0')
        task_num = 0  # default to train/split 0
        split = task.split(':')
        if len(split) > 2:
            task_num = split[2]
            if task_num not in [str(i) for i in range(5)]:
                raise RuntimeError(
                    'Invalid train/test split ID (0-4 inclusive)')

        if not hasattr(self, 'factmetrics'):
            if shared and shared.get('factmetrics'):
                self.factmetrics = shared['factmetrics']
            else:
                self.factmetrics = TeacherMetrics(opt.get(
                    'metrics', 'default'))
            self.datatype = opt['datatype']
        questions_path, trainset_path, self.image_path = _path(opt)

        if shared and 'ques' in shared:
            self.ques = shared['ques']
        else:
            self._setup_data(questions_path, trainset_path, dt, task_num)
        self.len = len(self.ques)

        self.asked_question = False
        # for ordered data in batch mode (especially, for validation and
        # testing), each teacher in the batch gets a start index and a step
        # size so they all process disparate sets of the data
        self.step_size = opt.get('batchsize', 1)
        self.data_offset = opt.get('batchindex', 0)
        self.image_loader = ImageLoader(opt)

        self.reset()

    def num_examples(self):
        return self.len

    def num_episodes(self):
        return self.len

    def report(self):
        r = super().report()
        for k, v in self.factmetrics.report().items():
            r[f'factmetrics_{k}'] = v
        return r

    def reset(self):
        # Reset the dialog so that it is at the start of the epoch,
        # and all metrics are reset.
        super().reset()
        self.lastY = None
        self.episode_idx = self.data_offset - self.step_size
        self.epochDone = False

    def reset_metrics(self):
        super().reset_metrics()
        self.factmetrics.clear()

    def observe(self, observation):
        """
        Process observation for metrics.
        """
        if self.lastY is not None:
            if self.asked_question:
                self.metrics.evaluate_response(observation, self.lastY[0])
            else:
                self.factmetrics.evaluate_response(observation, self.lastY[1])
                self.lastY = None
        return observation

    def act(self):
        if self.asked_question:
            self.asked_question = False
            action = {
                'text': 'Which fact supports this answer?',
                'episode_done': True
            }
            if self.datatype.startswith('train'):
                action['labels'] = self.lastY[1]
            if (self.datatype != 'train' and
                    self.episode_idx + self.step_size >= self.num_episodes()):
                self.epochDone = True
            return action

        if self.datatype == 'train':
            self.episode_idx = random.randrange(self.len)
        else:
            self.episode_idx = (self.episode_idx +
                                self.step_size) % self.num_episodes()

        self.asked_question = True
        qa = self.ques[self.episode_idx]
        question = qa['question']
        img_path = self.image_path + qa['img_file']

        action = {
            'image': self.image_loader.load(img_path),
            'text': question,
            'episode_done': False,
        }

        human_readable = qa['fact_surface'].replace('[', '').replace(']', '')
        self.lastY = [[qa['answer']], [human_readable]]

        if self.datatype.startswith('train'):
            action['labels'] = self.lastY[0]

        return action

    def share(self):
        shared = super().share()
        shared['factmetrics'] = self.factmetrics
        shared['ques'] = self.ques
        if hasattr(self, 'facts'):
            shared['facts'] = self.facts
        return shared

    def _setup_data(self, questions_path, trainset_path, datatype, task_num):
        print('loading: ' + questions_path)
        with PathManager.open(questions_path) as questions_file:
            questions = json.load(questions_file)
        train_test_images = set()
        fn = os.path.join(trainset_path,
                          '{}_list_{}.txt'.format(datatype, task_num))
        with PathManager.open(fn) as imageset:
            for line in imageset:
                train_test_images.add(line.strip())
        self.ques = [
            questions[k] for k in sorted(questions.keys())
            if questions[k]['img_file'] in train_test_images
        ]
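A minimal, hypothetical driver for the two-turn episode structure this teacher implements: the question turn, then the supporting-fact turn. The `opt` dict here is an assumption (in ParlAI it would normally come from ParlaiParser); `datapath` and `image_mode` in particular are guessed keys, and the FVQA data must already be downloaded:

opt = {
    'task': 'fvqa:split:0',
    'datatype': 'train',
    'datapath': 'data',       # assumption: root directory of the FVQA files
    'image_mode': 'raw',      # assumption: a mode understood by ImageLoader
    'batchsize': 1,
    'batchindex': 0,
}
teacher = SplitTeacher(opt)

# Turn 1: the teacher asks the visual question; labels hold the answer.
act = teacher.act()
print(act['text'])
teacher.observe({'text': 'a guess'})    # scored against lastY[0] (the answer)

# Turn 2: the teacher asks for the supporting fact and ends the episode.
act = teacher.act()
print(act['text'])                      # 'Which fact supports this answer?'
teacher.observe({'text': 'a fact'})     # scored by factmetrics against lastY[1]

print(teacher.report())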
Example #4
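As with the previous example, the import header is omitted. A plausible reconstruction (the parlai.* paths are assumptions based on ParlAI's layout; BackgroundWorkerDynamicBatchWorld is assumed to be defined in the same module, and torch.multiprocessing is imported lazily inside the methods):

from parlai.core.metrics import TeacherMetrics, aggregate_unnamed_reports
from parlai.core.opt import Opt
from parlai.core.worlds import World
import parlai.utils.logging as logging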
class BackgroundDriverWorld(World):
    def __init__(self, opt: Opt, world: World):
        self.world = world
        super().__init__(opt, agents=world.agents, shared=None)

        import torch.multiprocessing as mp

        self._num_workers = self.opt['num_workers']
        # 4 per worker is somewhat arbitrary. A maxsize of 1 is potentially
        # too small, since it prevents every worker from queuing up multiple
        # batches, while an unbounded queue could consume too much memory.
        # So: 4 per worker. (A standalone sketch of this producer/consumer
        # pattern follows the class.)
        self._process_queue = mp.Queue(maxsize=4 * self._num_workers)
        self._process_pool = self._start_processes()

        self._batch_buffer = []
        self.metrics = TeacherMetrics()

    def _start_processes(self):
        import torch.multiprocessing as mp

        return mp.start_processes(
            fn=BackgroundWorkerDynamicBatchWorld.launch_process,
            nprocs=self._num_workers,
            # note that index is an implied argument added by start_processes
            args=(self.opt, self.get_model_agent(), self._process_queue),
            join=False,
            # launch in fork mode so that we can share the model agent easily
            # note that this prevents us from using ANY threads in ANY of the
            # subprocesses! (See ChunkTeacher for one example). Fortunately, we
            # CAN use threads in the MAIN process, and we exploit this at
            # times.
            start_method='fork',
        )

    def reset(self):
        """
        Reset all subworlds.
        """
        self.world.reset()

    def reset_metrics(self):
        """
        Reset metrics in all subworlds.
        """
        self.world.reset_metrics()
        self.metrics.clear()

    def get_task_agent(self):
        return self.world.get_task_agent()

    def get_model_agent(self):
        return self.world.get_model_agent()

    def num_examples(self):
        return self.world.num_examples()

    def num_episodes(self):
        return self.world.num_episodes()

    def _queue_get(self):
        import queue

        while True:
            try:
                return self._process_queue.get(timeout=10)
            except queue.Empty:
                # not getting anything; check for exceptions on the subprocess
                self._process_pool.join(timeout=0.1)

    def parley(self):
        index, batch = self._queue_get()
        response_object = self.get_model_agent().batch_act(batch)
        # compute metrics
        for response in response_object:
            self.metrics._consume_user_metrics(response)
        self.total_parleys += 1
        self.total_exs += batch.batchsize

    def get_total_exs(self):
        return self.total_exs

    def get_total_epochs(self):
        return self.total_exs / self.num_examples()

    def report(self):
        return aggregate_unnamed_reports(
            [self.world.report(), self.metrics.report()])

    def shutdown(self):
        logging.debug("Killing all the worker processes")
        for p in self._process_pool.processes:
            p.kill()
        super().shutdown()
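A standalone sketch of the pattern the class above relies on: forked workers feeding a bounded queue while the parent polls it and uses join(timeout) to surface worker crashes. Names here are illustrative, not ParlAI API, and 'fork' assumes a platform that supports it (e.g. Linux):

import queue

import torch.multiprocessing as mp


def _producer(index, q):
    # `index` is the implied first argument that start_processes prepends.
    for i in range(3):
        q.put((index, f'batch-{index}-{i}'))  # blocks once the queue is full


def main():
    num_workers = 2
    q = mp.Queue(maxsize=4 * num_workers)  # bounded => backpressure on workers
    ctx = mp.start_processes(
        fn=_producer, args=(q,), nprocs=num_workers,
        join=False, start_method='fork',
    )
    for _ in range(3 * num_workers):
        while True:
            try:
                index, batch = q.get(timeout=10)
                break
            except queue.Empty:
                ctx.join(timeout=0.1)  # re-raises if a worker has crashed
        print(index, batch)
    ctx.join()


if __name__ == '__main__':
    main()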