Example no. 1
0
def prepare_tokenizer(args):
    """Build a fast BERT tokenizer whose max length matches the model config.

    Reads the BERT configuration from ``args.config_file`` and constructs a
    ``BertTokenizerFast`` over ``args.vocab_file``, capping the tokenizer's
    ``model_max_length`` at the config's ``max_position_embeddings`` so
    tokenized sequences never exceed what the model can embed.
    """
    bert_config = BertConfig.from_json_file(args.config_file)
    max_len = bert_config.max_position_embeddings
    return BertTokenizerFast(args.vocab_file, model_max_length=max_len)
    def build_model(self, args):
        """Construct a ``BertForELClassification`` model for fine-tuning.

        Args:
            args: parsed command-line arguments; must carry ``task``,
                ``config_file``, ``num_labels``, ``num_entity_labels``,
                ``dim_entity_emb``, ``EntityEmbedding``,
                ``load_state_dict_strict``, and optionally a checkpoint path
                in ``hetseq_state_dict`` or ``transformers_state_dict``.

        Returns:
            The initialized (and optionally checkpoint-loaded) model.

        Raises:
            ValueError: if ``args.task`` is not 'BertForELClassification'.
        """
        # Guard clause: only one fine-tuning task is supported here.
        if args.task != 'BertForELClassification':
            raise ValueError('Unknown fine_tunning task!')

        # obtain num_label from dataset before assign model
        config = BertConfig.from_json_file(args.config_file)

        # **YD** mention detection, num_label is by default 3
        assert hasattr(args, 'num_labels')
        assert hasattr(args, 'num_entity_labels')
        assert hasattr(args, 'dim_entity_emb')
        assert hasattr(args, 'EntityEmbedding')

        model = BertForELClassification(config, args)

        # **YD** add load state_dict from pre-trained model.
        # A hetseq checkpoint stores the weights under the 'model' key,
        # while a transformers checkpoint is the raw state dict.
        # could make only master model to load from state_dict, not quite
        # sure whether this works for single GPU
        if args.hetseq_state_dict is not None:
            state_dict = torch.load(args.hetseq_state_dict,
                                    map_location='cpu')['model']
        elif args.transformers_state_dict is not None:
            state_dict = torch.load(args.transformers_state_dict,
                                    map_location='cpu')
        else:
            state_dict = None

        if state_dict is not None:
            # Single call: the strict flag selects strict/non-strict loading
            # instead of duplicating load_state_dict in each branch.
            model.load_state_dict(state_dict,
                                  strict=args.load_state_dict_strict)

        return model
Example no. 3
0
def prepare_model(args):
    """Instantiate a BERT token-classification model and load its weights.

    Builds ``BertForTokenClassification`` from ``args.config_file`` with
    ``args.num_labels`` output labels, then non-strictly loads weights from
    a hetseq checkpoint (preferred) or a transformers checkpoint whenever
    the corresponding path argument is a non-empty string.
    """
    bert_config = BertConfig.from_json_file(args.config_file)
    model = BertForTokenClassification(bert_config, args.num_labels)

    if args.hetseq_state_dict != '':
        # hetseq checkpoints keep the weights under the 'model' key
        checkpoint = torch.load(args.hetseq_state_dict, map_location='cpu')
        model.load_state_dict(checkpoint['model'], strict=False)
    elif args.transformers_state_dict != '':
        checkpoint = torch.load(args.transformers_state_dict,
                                map_location='cpu')
        model.load_state_dict(checkpoint, strict=False)

    return model
Example no. 4
0
    def build_model(self, args):
        """Create a BERT pre-training model for the requested task.

        Only ``args.task == 'bert'`` is supported; any other value raises.

        Returns:
            A ``BertForPreTraining`` instance built from ``args.config_file``.

        Raises:
            ValueError: for any task other than 'bert'.
        """
        if args.task != 'bert':
            raise ValueError("Unsupported language modeling task: {}".format(
                args.task))

        # Lazy import keeps the hetseq dependency off non-'bert' code paths.
        from hetseq.bert_modeling import BertForPreTraining, BertConfig
        config = BertConfig.from_json_file(args.config_file)
        return BertForPreTraining(config)