예제 #1
0
    def _get_predict_outputs(self, batch_outputs):
        """Decode predicted id sequences into text.

        For every example, the first id is skipped and every id equal to 0
        is discarded; the remaining ids are mapped to tokens with
        ``self.tokenizer`` and joined into text.

        Args:
            batch_outputs: sequence of per-batch model outputs; element 0 of
                each batch entry holds the predicted token ids.

        Returns:
            dict with key 'preds': one decoded string per input.
        """
        first_field = list(self.data.values())[0]
        n_inputs = len(first_field)
        # Transpose so columns[j] gathers output j across all batches.
        columns = list(zip(*batch_outputs))
        id_rows = utils.transform(columns[0], n_inputs).tolist()

        def _decode(row):
            # Skip position 0, then drop every 0 id (presumably the padding
            # id -- TODO confirm against the tokenizer vocabulary).
            kept_ids = [token_id for token_id in row[1:] if token_id != 0]
            tokens = self.tokenizer.convert_ids_to_tokens(kept_ids)
            return utils.convert_tokens_to_text(tokens)

        return {'preds': [_decode(row) for row in id_rows]}
예제 #2
0
    def _get_predict_outputs(self, batch_outputs):
        """Convert masked-LM predicted ids into tokens.

        For each example, predicted ids are kept up to (not including) the
        first slot whose ``masked_lm_positions`` entry is 0; the kept ids are
        mapped to tokens with ``self.tokenizer``.

        Args:
            batch_outputs: sequence of per-batch model outputs; element 0 of
                each batch entry holds the predicted ids per masked slot.

        Returns:
            dict with key 'mlm_preds': one token list per input.
        """
        n_inputs = len(list(self.data.values())[0])
        columns = list(zip(*batch_outputs))

        positions = self.data['masked_lm_positions']
        pred_rows = utils.transform(columns[0], n_inputs)

        def _leading_ids(row_idx, row):
            # Collect ids until the first slot whose position is 0.
            ids = []
            for slot, token_id in enumerate(row):
                if positions[row_idx][slot] == 0:
                    break
                ids.append(token_id)
            return ids

        mlm_preds = [
            self.tokenizer.convert_ids_to_tokens(_leading_ids(row_idx, row))
            for row_idx, row in enumerate(pred_rows)
        ]
        return {'mlm_preds': mlm_preds}
예제 #3
0
    def _get_predict_outputs(self, batch_outputs):
        """Decode predicted id sequences into text, cut at the first '<eos>'.

        Args:
            batch_outputs: sequence of per-batch model outputs; element 0 of
                each batch entry holds the predicted token ids.

        Returns:
            dict with key 'preds': one decoded string per input. Tokens from
            the first '<eos>' onward are dropped; a prediction without
            '<eos>' is decoded in full.
        """
        n_inputs = len(list(self.data.values())[0])
        output_arrays = list(zip(*batch_outputs))

        # preds
        all_preds = utils.transform(output_arrays[0], n_inputs).tolist()
        preds = []
        for _pred_ids in all_preds:
            _pred_tokens = self.tokenizer.convert_ids_to_tokens(_pred_ids)
            # Search the whole token sequence rather than a fixed
            # range(self.max_seq_length) window.
            # - A fixed window raises IndexError on a shorter prediction that
            #   contains no '<eos>'.
            # - On a longer prediction, it never sees an '<eos>' past the
            #   window.
            if '<eos>' in _pred_tokens:
                _pred_tokens = _pred_tokens[:_pred_tokens.index('<eos>')]
            _pred_text = utils.convert_tokens_to_text(_pred_tokens)
            preds.append(_pred_text)

        outputs = {}
        outputs['preds'] = preds

        return outputs
예제 #4
0
    def _get_score_outputs(self, batch_outputs):
        """Compute accuracy and the mean sample-weighted cross-entropy loss.

        Args:
            batch_outputs: sequence of per-batch model outputs; element 0 of
                each batch entry holds per-class probabilities.

        Returns:
            dict with:
                'accuracy': fraction of examples whose argmax class equals
                    ``self.data['label_ids']``.
                'loss': mean over examples of
                    ``-log(p_true) * self.data['sample_weight']``, where
                    ``p_true`` is floored at ``eps``.
        """
        n_inputs = len(list(self.data.values())[0])
        output_arrays = list(zip(*batch_outputs))

        # accuracy
        probs = utils.transform(output_arrays[0], n_inputs)
        preds = np.argmax(probs, axis=-1)
        labels = self.data['label_ids']
        accuracy = np.mean(preds == labels)

        # loss
        # Floor the true-label probability before taking the log. Without
        # it, a probability of exactly 0 gives an infinite loss (plus a
        # NumPy divide-by-zero warning). For an example with sample_weight
        # 0, 0 * inf = nan, which would turn the whole mean into nan.
        eps = 1e-20
        true_probs = np.array(
            [probs[i][label] for i, label in enumerate(labels)])
        losses = -np.log(np.maximum(true_probs, eps))
        sample_weight = self.data['sample_weight']
        losses = losses * sample_weight
        loss = np.mean(losses)

        outputs = {}
        outputs['accuracy'] = accuracy
        outputs['loss'] = loss

        return outputs
예제 #5
0
    def _get_predict_outputs(self, batch_outputs):
        """Build prediction outputs for the per-layer early-exit classifier.

        For each batch, the probability rows in the first model output are
        re-assembled by ``_permutate``. Each example receives the row of the
        layer at which it was taken as finished, together with that layer's
        index. Per-batch results are then merged with ``utils.transform``.

        Args:
            batch_outputs: sequence of per-batch model outputs; element 0 of
                each batch entry holds the probability rows consumed by
                ``_permutate``.

        Returns:
            dict with keys:
                'preds': argmax class id per example, mapped through
                    ``self._id_to_label`` when it is set.
                'probs': the selected probability row per example.
                'sources': the layer index each example's row came from.
        """
        n_inputs = len(list(self.data.values())[0])
        output_arrays = list(zip(*batch_outputs))

        def _uncertainty(prob):
            # Computes (p*log p + (1-p)*log(1-p)) / log(1/label_size): the
            # binary entropy of `prob` scaled by log(label_size). It is
            # non-negative for label_size > 1 and close to 0 when `prob` is
            # close to 0 or 1.
            # Values within 1e-20 of 0 or 1 are replaced by 1e-20 to avoid
            # log(0). The expression is symmetric in p <-> 1-p, so this also
            # covers p ~= 1.
            # NOTE(review): only the class-0 probability is passed in, so for
            # label_size > 2 this is not the entropy of the full
            # distribution. This rule presumably has to match the exit
            # decision made inside the model; otherwise rows would be
            # re-assembled wrongly. Confirm before changing it.
            if prob < 1e-20 or 1 - prob < 1e-20:
                prob = 1e-20
            return (prob * np.log(prob) + (1 - prob) * np.log(1 - prob)) / \
                np.log(1 / self.label_size)

        def _permutate(batch_probs):
            # Re-assemble one batch into example order.
            # Assumed layout of `batch_probs` (TODO confirm against the model
            # graph): rows are grouped device by device; within a device,
            # kept layer by kept layer; within a layer, one row per example
            # still unfinished at that layer, in example order.
            # Returns (probs, sources): probs has shape
            # (batch_size, label_size) and sources has shape (batch_size,).
            n_device = max(len(self._gpu_ids), 1)
            d_batch_size = self.batch_size // n_device
            probs = np.zeros((self.batch_size, self.label_size))
            sources = np.zeros((self.batch_size), dtype=np.int32)
            max_loop = \
                self.bert_config.num_hidden_layers + 1 - len(self._ignore_cls)
            # Layer indices 0..num_hidden_layers not listed in
            # self._ignore_cls; keep_cls[loop] is recorded as the source of
            # every row finished during pass `loop`.
            keep_cls = [
                cls_idx for cls_idx
                in range(self.bert_config.num_hidden_layers + 1)
                if cls_idx not in self._ignore_cls]
            i = 0  # index of the next unread row of batch_probs

            for d in range(n_device):
                # Output slots owned by device d. They are offset by the
                # device's position in the batch, not by `i`. `i` counts rows
                # read so far, so it exceeds d * d_batch_size as soon as any
                # earlier example survives past the first kept layer. Using
                # it would index past `probs` or overwrite another device's
                # slots.
                unfinished = [
                    d * d_batch_size + k for k in range(d_batch_size)]

                for loop in range(max_loop):
                    source = keep_cls[loop]
                    next_unfinished = []

                    for slot in unfinished:
                        # Finish the example here if this layer's row is
                        # confident enough, or if this is the last kept layer.
                        if _uncertainty(batch_probs[i][0]) < self._speed or \
                                loop == max_loop - 1:
                            probs[slot] = batch_probs[i]
                            sources[slot] = source
                        else:
                            next_unfinished.append(slot)
                        i += 1
                    unfinished = next_unfinished
            # Every row of batch_probs must have been consumed exactly once.
            assert i == len(batch_probs)
            return probs, sources

        # probs
        probs_arrays = []
        sources_arrays = []
        for batch_probs in output_arrays[0]:
            probs_array, sources_array = _permutate(batch_probs)
            probs_arrays.append(probs_array)
            sources_arrays.append(sources_array)
        # Merge the per-batch arrays. utils.transform also receives n_inputs;
        # presumably it trims padded rows beyond the real inputs (confirm).
        probs = utils.transform(probs_arrays, n_inputs)
        sources = utils.transform(sources_arrays, n_inputs).tolist()

        # preds
        preds = np.argmax(probs, axis=-1).tolist()
        if self._id_to_label:
            preds = [self._id_to_label[idx] for idx in preds]

        outputs = {}
        outputs['preds'] = preds
        outputs['probs'] = probs
        outputs['sources'] = sources

        return outputs