예제 #1
0
파일: wizard.py 프로젝트: simplecoka/cortx
    def _setup_data(self, opt):
        """
        Load the examples and optionally append unknown-class copies of them.

        The base examples come from ``self._load_gender_data(opt)``. When
        ``self.add_unknown_classes`` is truthy, two deep copies are made of
        every base example: one with ``label`` set to ``SELF:<UNKNOWN>`` and
        ``class_type`` ``'self'``, and one with ``label`` set to
        ``PARTNER:<UNKNOWN>`` and ``class_type`` ``'partner'``, where
        ``<UNKNOWN>`` is ``gend_utils.UNKNOWN``.

        If any copies were produced, ``self.opt['unknown_temp']`` decides how
        many are kept: a value below 1.0 keeps a random sample of
        ``int(rate * number_of_copies)`` of them, anything else keeps them
        all. The kept copies are appended to the loaded list in place.

        The list is shuffled in place when ``self.is_train`` is truthy, then
        handed to ``gend_utils.get_data_stats`` (its result is not used).

        :param opt: options forwarded to ``_load_gender_data``. Note that the
            sampling rate is read from ``self.opt``, not from this argument.
        :return: the (possibly augmented and shuffled) list of examples.
        """
        examples = self._load_gender_data(opt)

        unknown_examples = []
        if self.add_unknown_classes:
            for example in examples:
                # Self copy first, then partner copy, for every base example.
                variants = (
                    (f'SELF:{gend_utils.UNKNOWN}', 'self'),
                    (f'PARTNER:{gend_utils.UNKNOWN}', 'partner'),
                )
                for unknown_label, role in variants:
                    clone = deepcopy(example)
                    clone['label'] = unknown_label
                    clone['class_type'] = role
                    unknown_examples.append(clone)

        if unknown_examples:
            # Optionally down-sample the unknown-class copies.
            keep_rate = self.opt['unknown_temp']
            if keep_rate < 1.0:
                n_keep = int(keep_rate * len(unknown_examples))
                examples += random.sample(unknown_examples, n_keep)
            else:
                examples += unknown_examples

        if self.is_train:
            random.shuffle(examples)

        gend_utils.get_data_stats(examples, key='label', lst=False)

        return examples
예제 #2
0
    def _setup_data(self, opt):
        """
        Load the examples and optionally append unknown-label copies of them.

        The base examples come from ``self._load_gender_data(opt)``. When
        ``self.add_unknown_classes`` is truthy, two deep copies are made of
        every base example: a ``'self'`` copy whose ``labels`` is
        ``gend_utils.UNKNOWN_LABELS['self']`` and a ``'partner'`` copy whose
        ``labels`` is ``gend_utils.UNKNOWN_LABELS['partner']``. The copies are
        then sampled according to ``self.opt['unknown_temp']``: a value below
        1.0 keeps a random sample of ``int(rate * number_of_copies)`` of them,
        anything else keeps them all. The kept copies are appended to the
        loaded list in place.

        The list is shuffled in place when ``self.is_train`` is truthy, then
        handed to ``gend_utils.get_data_stats`` (its result is not used).

        :param opt: options forwarded to ``_load_gender_data``. Note that the
            sampling rate is read from ``self.opt``, not from this argument.
        :return: the (possibly augmented and shuffled) list of examples.
        """
        examples = self._load_gender_data(opt)

        if self.add_unknown_classes:
            unknown_examples = []
            for example in examples:
                # Copy attributed to the speaker ("self").
                own_copy = deepcopy(example)
                own_copy['class_type'] = 'self'
                own_copy['labels'] = gend_utils.UNKNOWN_LABELS['self']

                # Copy attributed to the conversation partner.
                partner_copy = deepcopy(example)
                partner_copy['labels'] = gend_utils.UNKNOWN_LABELS['partner']
                partner_copy['class_type'] = 'partner'

                unknown_examples.append(own_copy)
                unknown_examples.append(partner_copy)

            # Down-sample the generated copies if a rate below 1.0 is set.
            keep_rate = self.opt['unknown_temp']
            if keep_rate < 1.0:
                n_keep = int(keep_rate * len(unknown_examples))
                examples += random.sample(unknown_examples, n_keep)
            else:
                examples += unknown_examples

        if self.is_train:
            random.shuffle(examples)

        gend_utils.get_data_stats(examples, key='labels')

        return examples
예제 #3
0
    def _setup_data(self, opt):
        """
        Collect the conversations whose class type is selected for use.

        Every item returned by ``self._get_convos(opt)`` is kept when
        ``self.labels_to_use`` equals ``'all'`` or equals the item's
        ``'class_type'`` entry. The resulting list is shuffled in place when
        ``self.is_train`` is truthy, then handed to
        ``gend_utils.get_data_stats`` (its result is not used).

        :param opt: options forwarded to ``_get_convos``.
        :return: the filtered (and possibly shuffled) list of conversations.
        """
        # TODO: proper train/test/valid splits
        selected = [
            convo
            for convo in self._get_convos(opt)
            if self.labels_to_use in ('all', convo['class_type'])
        ]

        if self.is_train:
            random.shuffle(selected)

        gend_utils.get_data_stats(selected, key='labels')

        return selected