Example #1
Score: 0
File: main.py  Project: buvata/ZaloQA
def train_model_bert(args):
    """Load a pretrained BERT QA model/tokenizer and run SQuAD-style training.

    The pretrained config is round-tripped through a dict so run-specific
    options (device, pooler usage, class weights, hidden-state output) can
    be injected before the model is instantiated — needed e.g. to train on
    a different CUDA device than the one the config was saved with.
    """
    cfg_dict = BertConfig.from_pretrained(args.folder_model).to_dict()
    cfg_dict.update({
        "device": args.device,
        "use_pooler": args.use_pooler,
        "weight_class": args.weight_class,
        "output_hidden_states": args.output_hidden_states,
    })
    config = BertConfig.from_dict(cfg_dict)

    tokenizer = BertTokenizer.from_pretrained(args.folder_model)
    model = BERTQa.from_pretrained(args.folder_model, config=config).to(args.device)
    train_squad(args, tokenizer, model)
Example #2
Score: 0
File: ZaloBert.py  Project: buvata/ZaloQA
    def loss(self, input_ids, attention_mask, token_type_ids, label):
        """Compute the weighted cross-entropy loss plus prediction/target lists.

        Returns a tuple ``(loss, list_predict, list_target)`` where the two
        lists are plain Python lists (for metric accumulation by the caller).
        """
        hidden = self.compute(input_ids, attention_mask, token_type_ids)

        # Two classifier heads exist; pick the one matching the config flag.
        head = self.qa_outputs if self.use_pooler else self.qa_outputs_cat
        logits = head(hidden)

        weights = torch.FloatTensor(self.weight_class).to(self.device)
        ce_loss = F.cross_entropy(logits, label, weight=weights)

        predictions = torch.max(logits, 1)[1]
        return (
            ce_loss,
            predictions.cpu().numpy().tolist(),
            label.cpu().numpy().tolist(),
        )


if __name__ == '__main__':
    from transformers.configuration_bert import BertConfig

    # Smoke check: fetch the base multilingual config and inject the custom
    # "weight_class" option via the same dict round-trip used in training.
    base_config = BertConfig.from_pretrained("bert-base-multilingual-uncased",
                                             cache_dir="../resources/cache_model")
    cfg_dict = base_config.to_dict()
    cfg_dict["weight_class"] = [1, 1]
    config = BertConfig.from_dict(cfg_dict)
    # model = BERTQa.from_pretrained("bert-base-multilingual-uncased",
    #                                cache_dir="../resources/cache_model", config=config)