示例#1
0
    def __init__(self, opt, shared=None):
        """
        Set up the teacher and its (optionally balanced) examples.

        :param opt: option dict; reads ``datatype`` and ``add_unknown_classes``,
            plus ``balance`` / ``balance_valid`` when examples are built here.
        :param shared: state from a parent teacher; when not None its
            ``'data'`` entry is reused instead of calling ``_setup_data``.
        """
        self.opt = opt
        datatype = opt['datatype']
        self.is_train = 'train' in datatype and 'evalmode' not in datatype
        self.is_valid = 'valid' in datatype
        self.add_unknown_classes = opt['add_unknown_classes'] and self.is_train
        self.label_candidates = gend_utils.ALL_CANDS

        if shared is not None:
            # examples were already built by the parent teacher
            self.data = shared['data']
        else:
            self.data = self._setup_data(opt)
            balance_train = self.is_train and opt['balance']
            if balance_train or (self.is_valid and opt['balance_valid']):
                # the unknown-gender examples are kept out of the balancing
                unknown_labels = [
                    f'SELF:{gend_utils.UNKNOWN}',
                    f'PARTNER:{gend_utils.UNKNOWN}',
                ]
                self.data = gend_utils.balance_data(
                    self.data, key='label', exclude_labels=unknown_labels
                )
        super().__init__(opt, shared)
        self.reset()
示例#2
0
    def __init__(self, opt, shared=None):
        """
        Set up the teacher, its persona map and its examples.

        :param opt: option dict; reads ``convai2_use_probably``,
            ``labels_to_use``, ``datatype``, ``add_unknown_classes``, plus
            ``balance`` / ``balance_valid`` when examples are built here.
        :param shared: state from a parent teacher; reused only when it is
            non-empty and has a ``'data'`` entry (``'persona_map'`` is then
            read from it as well).
        """
        self.opt = opt
        self.fixed_random = random.Random(42)
        # NOTE(review): replaced by ALL_CANDS further down; kept here in case
        # _load_persona_map/_setup_data read it in between -- confirm.
        self.label_candidates = [
            gend_utils.MASC,
            gend_utils.FEM,
            gend_utils.NEUTRAL,
        ]
        self.use_probably = opt['convai2_use_probably']
        self.labels_to_use = opt['labels_to_use']
        datatype = opt['datatype']
        self.is_train = 'train' in datatype and 'evalmode' not in datatype
        self.is_valid = 'valid' in datatype
        self.add_unknown_classes = opt['add_unknown_classes'] and self.is_train

        if not shared or 'data' not in shared:
            self.missing_cnt = 0
            self._load_persona_map(opt)
            # presumably populates self.data (its return value is unused)
            self._setup_data(opt)
            balance_train = self.is_train and opt['balance']
            if balance_train or (self.is_valid and opt['balance_valid']):
                self.data = gend_utils.balance_data(
                    self.data, exclude_labels=gend_utils.ABOUT_CANDS
                )
        else:
            self.data = shared['data']
            self.persona_map = shared['persona_map']

        self.label_candidates = gend_utils.ALL_CANDS

        # the parent constructor is handed a deep copy of opt
        super().__init__(deepcopy(opt), shared)
        self.reset()
示例#3
0
    def __init__(self, opt, shared=None):
        """
        Set up the teacher and its (optionally balanced) examples.

        :param opt: option dict; reads ``labels_to_use`` and ``datatype``,
            plus ``balance`` / ``balance_valid`` when examples are built here.
        :param shared: state from a parent teacher; when not None its
            ``'data'`` entry is reused instead of calling ``_setup_data``.
        """
        self.opt = opt
        self.labels_to_use = opt['labels_to_use']
        self.is_train = 'train' in opt['datatype'] and 'evalmode' not in opt['datatype']
        self.is_valid = 'valid' in opt['datatype']
        self.label_candidates = gend_utils.ALL_CANDS

        if shared is None:
            self.data = self._setup_data(opt)
            if (self.is_train and opt['balance']) or (
                self.is_valid and opt['balance_valid']
            ):
                # There are not enough non-binary examples to balance, so keep
                # them out of the balancing. This must be a flat list of label
                # strings, like the exclusion lists of the other teachers.
                # Wrapping the list in a tuple would make a membership test
                # with a string label always fail, so nothing would be excluded.
                to_exclude = ['ABOUT:non-binary']
                self.data = gend_utils.balance_data(
                    self.data, key='labels', exclude_labels=to_exclude
                )
        else:
            self.data = shared['data']
        super().__init__(opt, shared)
        self.reset()
示例#4
0
    def __init__(self, opt, shared=None):
        """
        Set up the teacher and its (optionally balanced) examples.

        :param opt: option dict; reads ``datatype`` and ``add_unknown_classes``,
            plus ``balance`` / ``balance_valid`` when examples are built here.
        :param shared: state from a parent teacher; when not None its
            ``'data'`` entry is reused instead of calling ``_setup_data``.
        """
        self.opt = opt
        datatype = opt['datatype']
        self.is_train = 'train' in datatype and 'evalmode' not in datatype
        self.is_valid = 'valid' in datatype
        self.add_unknown_classes = opt['add_unknown_classes'] and self.is_train

        if shared is not None:
            self.data = shared['data']
        else:
            self.data = self._setup_data(opt)
            balance_train = self.is_train and opt['balance']
            if balance_train or (self.is_valid and opt['balance_valid']):
                # SELF and PARTNER labels are kept out of the balancing;
                # key=1 presumably indexes the label within each example tuple
                skipped = gend_utils.PARTNER_CANDS + gend_utils.SELF_CANDS
                self.data = gend_utils.balance_data(
                    self.data, key=1, exclude_labels=skipped
                )
        super().__init__(opt, shared)

        # assigned only after the parent constructor has run
        self.label_candidates = gend_utils.ALL_CANDS
        self.reset()
示例#5
0
    def __init__(self, opt, shared=None):
        """
        Set up the teacher and its (optionally balanced) examples.

        :param opt: option dict; reads ``labels_to_use``, ``datatype`` and
            ``add_unknown_classes``, plus ``balance`` / ``balance_valid`` when
            examples are built here.
        :param shared: state from a parent teacher; reused only when it is
            non-empty and has a ``'data'`` entry.
        """
        self.opt = opt
        self.fixed_random = random.Random(42)
        self.label_candidates = gend_utils.ALL_CANDS
        self.labels_to_use = opt['labels_to_use']
        datatype = opt['datatype']
        self.is_train = 'train' in datatype and 'evalmode' not in datatype
        self.is_valid = 'valid' in datatype
        self.add_unknown_classes = opt['add_unknown_classes'] and self.is_train

        if not shared or 'data' not in shared:
            # presumably populates self.data (its return value is unused)
            self._setup_data(opt)
            balance_train = self.is_train and opt['balance']
            if balance_train or (self.is_valid and opt['balance_valid']):
                self.data = gend_utils.balance_data(
                    self.data, exclude_labels=gend_utils.ABOUT_CANDS
                )
        else:
            self.data = shared['data']

        # the parent constructor is handed a deep copy of opt
        super().__init__(deepcopy(opt), shared)
        self.reset()
示例#6
0
    def load_from_chunk(self, chunk_idx: int):
        """
        Given the chunk index, load examples from that chunk.

        Each non-blank line of the chunk file is parsed as a JSON object with
        ``title`` and ``text`` fields. Anything from the first ``' ('`` onward
        is stripped from the title. Articles for which ``check_if_person``
        returns False are skipped.

        Every non-empty paragraph (``text`` split on newlines) of a kept
        article yields an ``ABOUT:<gender>`` example, where ``gender`` comes
        from ``get_gender(text)``. When ``self.add_unknown_classes`` is set,
        each paragraph also yields a ``SELF:<UNKNOWN>`` and a
        ``PARTNER:<NEUTRAL>`` example. Of those extra examples, a fraction
        ``opt['unknown_temp']`` is randomly kept when that value is below 1.0.

        The result may then be balanced via ``gend_utils.balance_data``.

        Return a list of ``(paragraph, title, label, gender, class_type)``
        tuples, where ``class_type`` is ``'about'``, ``'self'`` or
        ``'partner'``. The function `_create_message` will take these tuples
        to form the Message object that is returned by the teacher.

        :param chunk_idx: index into ``self.chunk_idx_to_file``.
        """
        output = []
        chunk_path = self.chunk_idx_to_file[chunk_idx]

        extra_data = []
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(chunk_path, encoding='utf-8') as chunk_file:
            for article_json in chunk_file:
                if not article_json.strip():
                    # tolerate blank lines (e.g. a trailing newline) instead
                    # of failing in json.loads
                    continue
                article = json.loads(article_json)
                title = article['title']
                text = article['text']

                # keep only the part before a parenthesised suffix
                title = title.split(' (')[0]
                if not check_if_person(title):
                    continue

                gender = get_gender(text)

                label = f'ABOUT:{gender}'
                for par in text.split('\n'):
                    if not par:
                        continue
                    output.append((par, title, label, gender, 'about'))
                    self.counts[gender] += 1

                    if self.add_unknown_classes:
                        # NOTE(review): partner label is NEUTRAL, not UNKNOWN
                        # (mirrored in the balance exclusion list below)
                        extra_data.append((
                            par,
                            title,
                            f'SELF:{gend_utils.UNKNOWN}',
                            gender,
                            'self',
                        ))
                        extra_data.append((
                            par,
                            title,
                            f'PARTNER:{gend_utils.NEUTRAL}',
                            gender,
                            'partner',
                        ))

        if extra_data:
            # possibly subsample the unknown-class examples
            sample_rate = self.opt['unknown_temp']
            if sample_rate < 1.0:
                to_samp = int(sample_rate * len(extra_data))
                output += random.sample(extra_data, to_samp)
            else:
                output += extra_data

        if DEBUG:
            print('\n\nGender count update:')
            for k, v in self.counts.items():
                print(f'{k}: {v}')

        if (self.is_train and self.opt['balance']) or (
            self.is_valid and self.opt['balance_valid']
        ):
            exclude_lst = [
                f'ABOUT:{gend_utils.NONBINARY}',
                f'SELF:{gend_utils.UNKNOWN}',
                f'PARTNER:{gend_utils.NEUTRAL}',
            ]  # not enough of each of these examples to balance
            # key=2: the label is the third element of each example tuple
            output = gend_utils.balance_data(
                output, key=2, exclude_labels=exclude_lst
            )

        if not output:
            warn_once(f'CHUNK {chunk_idx} is empty')

        return output