Code example #1
File: agents.py Project: jojonki/ParlAI
    def __init__(self, opt, shared=None):
        super().__init__(opt)

        dt = opt['datatype'].split(':')[0]
        if dt not in ('train', 'test'):
            raise RuntimeError('Not valid datatype (only train/test).')

        task = opt.get('task', 'fvqa:split:0')
        task_num = 0  # default to train/split 0
        split = task.split(':')
        if len(split) > 2:
            task_num = split[2]
            if task_num not in [str(i) for i in range(5)]:
                raise RuntimeError('Invalid train/test split ID (0-4 inclusive)')

        if not hasattr(self, 'factmetrics'):
            if shared and shared.get('factmetrics'):
                self.factmetrics = shared['factmetrics']
            else:
                self.factmetrics = Metrics(opt)
            self.datatype = opt['datatype']
        questions_path, trainset_path, self.image_path = _path(opt)

        if shared and 'ques' in shared:
            self.ques = shared['ques']
        else:
            self._setup_data(questions_path, trainset_path, dt, task_num)
        self.len = len(self.ques)

        self.asked_question = False
        # for ordered data in batch mode (especially, for validation and
        # testing), each teacher in the batch gets a start index and a step
        # size so they all process disparate sets of the data
        self.step_size = opt.get('batchsize', 1)
        self.data_offset = opt.get('batchindex', 0)
        self.image_loader = ImageLoader(opt)

        self.reset()
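
The split selection above is plain string parsing of the task option; a standalone sketch of just that step (the task value 'fvqa:split:3' is a hypothetical example, outside ParlAI):

# Sketch of the split parsing done in __init__ above.
task = 'fvqa:split:3'  # hypothetical example value
task_num = 0  # default split
split = task.split(':')
if len(split) > 2:
    task_num = split[2]  # note: kept as a string, e.g. '3'
    if task_num not in [str(i) for i in range(5)]:
        raise RuntimeError('Invalid train/test split ID (0-4 inclusive)')
# _setup_data later loads the matching list file:
print('{}_list_{}.txt'.format('train', task_num))  # -> train_list_3.txt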
Code example #2
class TodMetrics(Metrics):
    """
    Helper container which encapsulates TOD metrics and does some basic
    preprocessing before handing messages off to the handlers that calculate
    said metrics.

    This class should generally not need to be changed; add new metrics handlers to
    `WORLD_METRIC_HANDLERS` (or otherwise override `self.handlers` of this class) to
    change metrics actively being used.
    """
    def __init__(self, shared: Dict[str, Any] = None) -> None:
        super().__init__(shared=shared)
        self.handlers = [x() for x in WORLD_METRIC_HANDLERS]
        self.convo_started = False
        self.last_episode_metrics = Metrics()

    def handle_message(self, message: Message, agent_type: TodAgentType):
        if "text" not in message:
            return
        if (agent_type == TodAgentType.GOAL_GROUNDING_AGENT
                and len(message["text"]) > len(STANDARD_GOAL)):
            # Only count a conversation as started if there is a goal.
            self.convo_started = True
        for handler in self.handlers:
            metrics = self._handle_message_impl(message, agent_type, handler)
            if metrics is not None:
                for name, metric in metrics.items():
                    if metric is not None:
                        self.add(name, metric)

    def _handle_message_impl(
        self,
        message: Message,
        agent_type: TodAgentType,
        handler: world_metrics_handlers.TodMetricsHandler,
    ):
        prefix_stripped_text = message["text"].replace(
            TOD_AGENT_TYPE_TO_PREFIX[agent_type], "")
        if agent_type is TodAgentType.API_SCHEMA_GROUNDING_AGENT:
            return handler.handle_api_schemas(
                message,
                SerializationHelpers.str_to_api_schemas(prefix_stripped_text))
        if agent_type is TodAgentType.GOAL_GROUNDING_AGENT:
            return handler.handle_goals(
                message,
                SerializationHelpers.str_to_goals(prefix_stripped_text))
        if agent_type is TodAgentType.USER_UTT_AGENT:
            return handler.handle_user_utt(message, prefix_stripped_text)
        if agent_type is TodAgentType.API_CALL_AGENT:
            return handler.handle_api_call(
                message,
                SerializationHelpers.str_to_api_dict(prefix_stripped_text))
        if agent_type is TodAgentType.API_RESP_AGENT:
            return handler.handle_api_resp(
                message,
                SerializationHelpers.str_to_api_dict(prefix_stripped_text))
        if agent_type is TodAgentType.SYSTEM_UTT_AGENT:
            return handler.handle_sys_utt(message, prefix_stripped_text)

    def get_last_episode_metrics(self):
        """
        This is a bit of a hack so that we can report whether or not a convo has
        successfully hit all goals and associate this with each episode for the
        purposes of filtering.
        """
        return self.last_episode_metrics

    def episode_reset(self):
        self.last_episode_metrics = None
        if self.convo_started:
            self.last_episode_metrics = Metrics()
            for handler in self.handlers:
                metrics = handler.get_episode_metrics()
                handler.episode_reset()
                if metrics is not None:
                    for name, metric in metrics.items():
                        if metric is not None:
                            self.add(name, metric)
                            self.last_episode_metrics.add(name, metric)
            self.convo_started = False
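
TodMetrics fans each message out to every handler through the handle_* hooks dispatched in _handle_message_impl above, so a custom handler only needs to implement the hooks it cares about plus the get_episode_metrics/episode_reset pair. A hedged sketch of such a handler (hypothetical; only the method names visible in the dispatch code above are assumed):

class TurnCountHandler(world_metrics_handlers.TodMetricsHandler):
    """Hypothetical handler that counts user and system turns per episode."""

    def __init__(self):
        self.turns = 0

    def handle_user_utt(self, message, prefix_stripped_text):
        self.turns += 1  # one user turn observed

    def handle_sys_utt(self, message, prefix_stripped_text):
        self.turns += 1  # one system turn observed

    def get_episode_metrics(self):
        # Collected once per episode by TodMetrics.episode_reset() above.
        return {'turn_count': SumMetric(self.turns)}

    def episode_reset(self):
        self.turns = 0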
Code example #3
    def test_multithreaded(self):
        m = Metrics(threadsafe=True)
        m2 = Metrics(threadsafe=True, shared=m.share())
        m3 = Metrics(threadsafe=True, shared=m.share())

        m2.add('key', SumMetric(1))
        m2.flush()
        m3.add('key', SumMetric(2))
        m3.flush()
        m.add('key', SumMetric(3))
        m.flush()
        assert m.report()['key'] == 6
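
Note the pattern: with threadsafe=True, each add lands in a per-instance buffer, and flush() is what publishes it to the shared aggregate, so every add above is flushed before the final report. Compare the non-threadsafe test below, where flush() is a no-op.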
Code example #4
    def test_shared(self):
        m = Metrics(threadsafe=False)
        m2 = Metrics(threadsafe=False, shared=m.share())
        m3 = Metrics(threadsafe=False, shared=m.share())

        m2.add('key', SumMetric(1))
        m3.add('key', SumMetric(2))
        m2.flush()  # just make sure this doesn't throw exception, it's a no-op
        m.add('key', SumMetric(3))

        assert m.report()['key'] == 6

        # shouldn't throw exception
        m.flush()
        m2.flush()
        m3.flush()
Code example #5
    def test_simpleadd(self):
        m = Metrics(threadsafe=False)
        m.add('key', SumMetric(1))
        m.add('key', SumMetric(2))
        assert m.report()['key'] == 3

        m.clear()
        assert 'key' not in m.report()

        m.add('key', SumMetric(1.5))
        m.add('key', SumMetric(2.5))
        assert m.report()['key'] == 4.0

        # shouldn't throw exception
        m.flush()
Code example #6
    def test_recent(self):
        m = Metrics()
        m2 = Metrics(shared=m.share())
        m.add('test', SumMetric(1))
        assert m.report() == {'test': 1}
        assert m.report_recent() == {'test': 1}
        m.clear_recent()
        m.add('test', SumMetric(2))
        assert m.report() == {'test': 3}
        assert m.report_recent() == {'test': 2}
        assert m2.report() == {'test': 3}
        assert m2.report_recent() == {}
        m2.add('test', SumMetric(3))
        assert m2.report() == {'test': 6}
        assert m.report() == {'test': 6}
        assert m2.report_recent() == {'test': 3}
        assert m.report_recent() == {'test': 2}
        m2.clear_recent()
        assert m2.report() == {'test': 6}
        assert m.report() == {'test': 6}
        assert m2.report_recent() == {}
        assert m.report_recent() == {'test': 2}
        m.clear_recent()
        assert m2.report() == {'test': 6}
        assert m.report() == {'test': 6}
        assert m.report_recent() == {}
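
The rule this test encodes: totals aggregate across every Metrics created from the same share(), while the "recent" window is per instance and reflects only that instance's own adds since its last clear_recent(). A condensed restatement (same assumed API as the test above):

m = Metrics()
m2 = Metrics(shared=m.share())
m.add('test', SumMetric(1))
m2.add('test', SumMetric(2))
assert m.report() == m2.report() == {'test': 3}  # totals are shared
assert m.report_recent() == {'test': 1}          # only m's own adds
assert m2.report_recent() == {'test': 2}         # only m2's own adds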
Code example #7
    def test_multithreaded(self):
        # legacy test, but left because it's just another test
        m = Metrics()
        m2 = Metrics(shared=m.share())
        m3 = Metrics(shared=m.share())

        m2.add('key', SumMetric(1))
        m3.add('key', SumMetric(2))
        m.add('key', SumMetric(3))
        assert m.report()['key'] == 6
Code example #8
    def test_shared(self):
        m = Metrics()
        m2 = Metrics(shared=m.share())
        m3 = Metrics(shared=m.share())

        m2.add('key', SumMetric(1))
        m3.add('key', SumMetric(2))
        m.add('key', SumMetric(3))

        assert m.report()['key'] == 6
Code example #9
    def test_simpleadd(self):
        m = Metrics()
        m.add('key', SumMetric(1))
        m.add('key', SumMetric(2))
        assert m.report()['key'] == 3

        m.clear()
        assert 'key' not in m.report()

        m.add('key', SumMetric(1.5))
        m.add('key', SumMetric(2.5))
        assert m.report()['key'] == 4.0
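
Taken together, these tests pin down the core Metrics contract: add() accumulates same-named metrics, report() returns a plain dict of values, clear() empties it, and share() yields a handle through which children aggregate into the same totals. A minimal sketch of that contract (not ParlAI's actual implementation; the recent/threadsafe machinery is omitted):

class SumMetric:
    """Sketch: a metric whose values combine by summation."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return SumMetric(self.value + other.value)


class Metrics:
    """Sketch: instances built from share() aggregate into one dict."""

    def __init__(self, shared=None):
        self._data = shared['data'] if shared else {}

    def add(self, key, metric):
        self._data[key] = self._data[key] + metric if key in self._data else metric

    def report(self):
        return {k: v.value for k, v in self._data.items()}

    def clear(self):
        self._data.clear()

    def share(self):
        return {'data': self._data}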
Code example #10
File: agents.py Project: xlrshop/Parl
class SplitTeacher(Teacher):
    """FVQA Teacher, which loads the json VQA data and implements its own
    `act` method for interacting with student agent.

    Use "fvqa:split:X" to choose between splits 0-4 (inclusive), or just
    "fvqa" to use the default split (0).
    """
    def __init__(self, opt, shared=None):
        super().__init__(opt)

        dt = opt['datatype'].split(':')[0]
        if dt not in ('train', 'test'):
            raise RuntimeError('Not valid datatype (only train/test).')

        task = opt.get('task', 'fvqa:split:0')
        task_num = 0  # default to train/split 0
        split = task.split(':')
        if len(split) > 2:
            task_num = split[2]
            if task_num not in [str(i) for i in range(5)]:
                raise RuntimeError(
                    'Invalid train/test split ID (0-4 inclusive)')

        if not hasattr(self, 'factmetrics'):
            if shared and shared.get('factmetrics'):
                self.factmetrics = shared['factmetrics']
            else:
                self.factmetrics = Metrics(opt)
            self.datatype = opt['datatype']
        questions_path, trainset_path, self.image_path = _path(opt)

        if shared and 'ques' in shared:
            self.ques = shared['ques']
        else:
            self._setup_data(questions_path, trainset_path, dt, task_num)
        self.len = len(self.ques)

        self.asked_question = False
        # for ordered data in batch mode (especially, for validation and
        # testing), each teacher in the batch gets a start index and a step
        # size so they all process disparate sets of the data
        self.step_size = opt.get('batchsize', 1)
        self.data_offset = opt.get('batchindex', 0)
        self.image_loader = ImageLoader(opt)

        self.reset()

    def num_examples(self):
        return self.len

    def num_episodes(self):
        return self.len

    def report(self):
        r = super().report()
        r['factmetrics'] = self.factmetrics.report()
        return r

    def reset(self):
        # Reset the dialog so that it is at the start of the epoch,
        # and all metrics are reset.
        super().reset()
        self.lastY = None
        self.episode_idx = self.data_offset - self.step_size
        self.epochDone = False

    def reset_metrics(self):
        super().reset_metrics()
        self.factmetrics.clear()

    def observe(self, observation):
        """Process observation for metrics."""
        if self.lastY is not None:
            if self.asked_question:
                self.metrics.update(observation, self.lastY[0])
            else:
                self.factmetrics.update(observation, self.lastY[1])
                self.lastY = None
        return observation

    def act(self):
        if self.asked_question:
            self.asked_question = False
            action = {
                'text': 'Which fact supports this answer?',
                'episode_done': True
            }
            if self.datatype.startswith('train'):
                action['labels'] = self.lastY[1]
            if (self.datatype != 'train'
                    and self.episode_idx + self.step_size >= self.num_episodes()):
                self.epochDone = True
            return action

        if self.datatype == 'train':
            self.episode_idx = random.randrange(self.len)
        else:
            self.episode_idx = (self.episode_idx +
                                self.step_size) % self.num_episodes()

        self.asked_question = True
        qa = self.ques[self.episode_idx]
        question = qa['question']
        img_path = self.image_path + qa['img_file']

        action = {
            'image': self.image_loader.load(img_path),
            'text': question,
            'episode_done': False
        }

        human_readable = qa['fact_surface'].replace('[', '').replace(']', '')
        self.lastY = [[qa['answer']], [human_readable]]

        if self.datatype.startswith('train'):
            action['labels'] = self.lastY[0]

        return action

    def share(self):
        shared = super().share()
        shared['factmetrics'] = self.factmetrics
        shared['ques'] = self.ques
        if hasattr(self, 'facts'):
            shared['facts'] = self.facts
        return shared

    def _setup_data(self, questions_path, trainset_path, datatype, task_num):
        print('loading: ' + questions_path)
        with open(questions_path) as questions_file:
            questions = json.load(questions_file)
        train_test_images = set()
        with open(os.path.join(trainset_path,
                               '{}_list_{}.txt'.format(datatype, task_num))) as imageset:
            for line in imageset:
                train_test_images.add(line.strip())
        self.ques = [
            questions[k] for k in sorted(questions.keys())
            if questions[k]['img_file'] in train_test_images
        ]
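
Each episode this teacher produces is two turns: act() first returns the image plus question (scored against lastY[0] by self.metrics in observe()), then asks for the supporting fact (scored against lastY[1] by factmetrics). A hedged driver sketch of that protocol (the teacher is an already-constructed SplitTeacher; the replies are hypothetical):

action = teacher.act()                   # turn 1: image + question, episode_done=False
teacher.observe({'text': 'an answer'})   # routed to self.metrics vs. lastY[0]
action = teacher.act()                   # turn 2: fact query, episode_done=True
teacher.observe({'text': 'a fact'})      # routed to factmetrics vs. lastY[1]
print(teacher.report()['factmetrics'])   # fact metrics are reported separately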
Code example #11
File: agents.py Project: jojonki/ParlAI
class SplitTeacher(Teacher):
    """FVQA Teacher, which loads the json VQA data and implements its own
    `act` method for interacting with student agent.

    Use "fvqa:split:X" to choose between splits 0-4 (inclusive), or just
    "fvqa" to use the default split (0).
    """

    def __init__(self, opt, shared=None):
        super().__init__(opt)

        dt = opt['datatype'].split(':')[0]
        if dt not in ('train', 'test'):
            raise RuntimeError('Not valid datatype (only train/test).')

        task = opt.get('task', 'fvqa:split:0')
        task_num = 0  # default to train/split 0
        split = task.split(':')
        if len(split) > 2:
            task_num = split[2]
            if task_num not in [str(i) for i in range(5)]:
                raise RuntimeError('Invalid train/test split ID (0-4 inclusive)')

        if not hasattr(self, 'factmetrics'):
            if shared and shared.get('factmetrics'):
                self.factmetrics = shared['factmetrics']
            else:
                self.factmetrics = Metrics(opt)
            self.datatype = opt['datatype']
        questions_path, trainset_path, self.image_path = _path(opt)

        if shared and 'ques' in shared:
            self.ques = shared['ques']
        else:
            self._setup_data(questions_path, trainset_path, dt, task_num)
        self.len = len(self.ques)

        self.asked_question = False
        # for ordered data in batch mode (especially, for validation and
        # testing), each teacher in the batch gets a start index and a step
        # size so they all process disparate sets of the data
        self.step_size = opt.get('batchsize', 1)
        self.data_offset = opt.get('batchindex', 0)
        self.image_loader = ImageLoader(opt)

        self.reset()

    def __len__(self):
        return self.len

    def report(self):
        r = super().report()
        r['factmetrics'] = self.factmetrics.report()
        return r

    def reset(self):
        # Reset the dialog so that it is at the start of the epoch,
        # and all metrics are reset.
        super().reset()
        self.lastY = None
        self.episode_idx = self.data_offset - self.step_size
        self.epochDone = False

    def reset_metrics(self):
        super().reset_metrics()
        self.factmetrics.clear()

    def observe(self, observation):
        """Process observation for metrics."""
        if self.lastY is not None:
            if self.asked_question:
                self.metrics.update(observation, self.lastY[0])
            else:
                self.factmetrics.update(observation, self.lastY[1])
                self.lastY = None
        return observation

    def act(self):
        if self.asked_question:
            self.asked_question = False
            action = {'text': 'Which fact supports this answer?', 'episode_done': True}
            if self.datatype.startswith('train'):
                action['labels'] = self.lastY[1]
            if self.datatype != 'train' and self.episode_idx + self.step_size >= len(self):
                self.epochDone = True
            return action

        if self.datatype == 'train':
            self.episode_idx = random.randrange(self.len)
        else:
            self.episode_idx = (self.episode_idx + self.step_size) % len(self)

        self.asked_question = True
        qa = self.ques[self.episode_idx]
        question = qa['question']
        img_path = self.image_path + qa['img_file']

        action = {
            'image': self.image_loader.load(img_path),
            'text': question,
            'episode_done': False
        }

        human_readable = qa['fact_surface'].replace('[', '').replace(']', '')
        self.lastY = [[qa['answer']], [human_readable]]

        if self.datatype.startswith('train'):
            action['labels'] = self.lastY[0]

        return action

    def share(self):
        shared = super().share()
        shared['factmetrics'] = self.factmetrics
        shared['ques'] = self.ques
        if hasattr(self, 'facts'):
            shared['facts'] = self.facts
        return shared

    def _setup_data(self, questions_path, trainset_path, datatype, task_num):
        print('loading: ' + questions_path)
        with open(questions_path) as questions_file:
            questions = json.load(questions_file)
        train_test_images = set()
        with open(os.path.join(trainset_path, '{}_list_{}.txt'.format(datatype, task_num))) as imageset:
            for line in imageset:
                train_test_images.add(line.strip())
        self.ques = [questions[k] for k in sorted(questions.keys()) if questions[k]['img_file'] in train_test_images]
Code example #12
class Teacher(Agent):
    """
    Basic Teacher agent that keeps track of how many times it's received messages.

    Teachers provide the ``report()`` method to get back metrics.
    """
    def __init__(self, opt, shared=None):
        if not hasattr(self, 'opt'):
            self.opt = copy.deepcopy(opt)
        if not hasattr(self, 'id'):
            self.id = opt.get('task', 'teacher')
        if not hasattr(self, 'metrics'):
            if shared and shared.get('metrics'):
                self.metrics = shared['metrics']
            else:
                self.metrics = Metrics(opt)
        self.epochDone = False

    # return state/action dict based upon passed state
    def act(self):
        """Act upon the previous observation."""
        # Default to an empty reply so `t` is always defined, even when no
        # observation with text has been seen yet.
        t = {}
        if self.observation is not None and 'text' in self.observation:
            t = {'text': 'Hello agent!'}
        return t

    def epoch_done(self):
        """Return whether the epoch is done."""
        return self.epochDone

    # Default unknown length
    def num_examples(self):
        """
        Return the number of examples (e.g. individual utterances) in the dataset.

        Default implementation returns `None`, indicating an unknown number.
        """
        return None

    def num_episodes(self):
        """
        Return the number of episodes (e.g. conversations) in the dataset.

        Default implementation returns `None`, indicating an unknown number.
        """
        return None

    def report(self):
        """Return metrics showing total examples and accuracy if available."""
        return self.metrics.report()

    def reset(self):
        """Reset the teacher."""
        super().reset()
        self.reset_metrics()
        self.epochDone = False

    def reset_metrics(self):
        """Reset metrics."""
        self.metrics.clear()

    def share(self):
        """In addition to default Agent shared parameters, share metrics."""
        shared = super().share()
        shared['metrics'] = self.metrics
        return shared
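
The share()/shared pair above is the whole metrics-sharing mechanism: passing a parent teacher's share() output as the child's shared argument makes both aggregate into one Metrics object. A minimal sketch using only the contract shown above (the opt contents are a hypothetical minimum):

opt = {'task': 'my_task', 'datatype': 'train'}  # hypothetical minimal opt
parent = Teacher(opt)
child = Teacher(opt, shared=parent.share())
assert child.metrics is parent.metrics  # both report into the same Metrics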