示例#1
0
def _load_yaml_config(path):
    """Parse the YAML file at ``path`` and return its contents.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes instead of being left for the garbage collector.
    """
    # NOTE(review): the config files are assumed to be trusted, project-local
    # files; yaml.FullLoader should not be pointed at untrusted input.
    with open(path) as config_file:
        return yaml.load(config_file, Loader=yaml.FullLoader)


def instantiate_model(model_name, vocab_size, embeddings):
    """Create the model selected by ``model_name`` from its YAML configuration.

    The model-specific settings are read from ``./configs/<model>.yml`` and
    merged with ``./configs/multi_layer.yml``; on duplicate keys the values
    from ``multi_layer.yml`` win.

    Args:
        model_name: one of "rcnn", "textcnn", "textrnn", "attention_rnn" or
            "transformer". Any other value falls back to FastText.
        vocab_size: passed as the first positional argument to the model.
        embeddings: passed as the second positional argument to the model.

    Returns:
        The instantiated model.
    """
    multi_layer_args = _load_yaml_config('./configs/multi_layer.yml')

    # Resolve the config file and the class together so the construction
    # below is written only once.
    if model_name == "rcnn":
        config_path, model_class = './configs/rcnn.yml', RCNN
    elif model_name == "textcnn":
        config_path, model_class = './configs/textcnn.yml', TextCNN
    elif model_name == "textrnn":
        config_path, model_class = './configs/textrnn.yml', TextRNN
    elif model_name == "attention_rnn":
        config_path, model_class = './configs/attention_rnn.yml', AttentionRNN
    elif model_name == "transformer":
        config_path, model_class = './configs/transformer.yml', Transformer
    else:
        # Unknown names deliberately fall back to FastText.
        config_path, model_class = './configs/fasttext.yml', FastText

    model_args = _load_yaml_config(config_path)
    # multi_layer_args is unpacked last, so it overrides model_args.
    model = model_class(vocab_size, embeddings,
                        **{**model_args, **multi_layer_args})

    logger = get_logger(__name__)
    logger.info("A model of {} is instantiated.".format(model.__class__.__name__))

    return model
def model_selector(config, model_id, use_element):
    """Instantiate the model identified by ``model_id``.

    Args:
        config: configuration object handed to the model constructor.
        model_id: integer id of the model (0-7, see the branches below).
        use_element: when truthy, ``ModelWithElement(config, model_id)`` is
            built instead, regardless of the value of ``model_id``.

    Returns:
        The constructed model.

    Raises:
        SystemExit: with exit code 2 when ``use_element`` is falsy and
            ``model_id`` is not a known id.
    """
    if use_element:
        print("use element")
        return ModelWithElement(config, model_id)

    if model_id == 0:
        return FastText(config)
    if model_id == 1:
        return TextCNN(config)
    if model_id == 2:
        return TextRCNN(config)
    if model_id == 3:
        return TextRNN(config)
    if model_id == 4:
        return HAN(config)
    if model_id == 5:
        return CNNWithDoc2Vec(config)
    if model_id == 6:
        return RCNNWithDoc2Vec(config)
    if model_id == 7:
        return CNNwithInception(config)

    print("Input ERROR!")
    # Raise SystemExit directly: the exit() builtin is injected by the site
    # module for interactive use and is not guaranteed to exist in programs.
    raise SystemExit(2)
示例#3
0
def load_model(model_path, model_id, config):
    """Build the model for ``model_id`` and restore its weights from disk.

    Args:
        model_path: path of the saved state dict, passed to ``torch.load``.
        model_id: integer id of the model; 0 (FastText), 1 (TextCNN),
            2 (TextRCNN) and 4 (HAN) are supported.
        config: configuration object handed to the model constructor; its
            ``has_cuda`` attribute decides whether the model is moved to GPU.

    Returns:
        The model with the loaded weights, on GPU when ``config.has_cuda``
        is truthy.

    Raises:
        ValueError: if ``model_id`` is not one of the supported ids.
    """
    if model_id == 0:
        model = FastText(config)
    elif model_id == 1:
        model = TextCNN(config)
    elif model_id == 2:
        model = TextRCNN(config)
    elif model_id == 4:
        model = HAN(config)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError("unsupported model_id: {!r}".format(model_id))

    print("load model data:", model_path)
    # Without CUDA, remap the tensors to CPU so a checkpoint that was saved
    # on a GPU can still be loaded; with CUDA keep torch.load's default.
    map_location = None if config.has_cuda else "cpu"
    model.load_state_dict(torch.load(model_path, map_location=map_location))
    if config.has_cuda:
        model = model.cuda()
    return model
示例#4
0
def _model_selector(config, model_id):
    model = None
    if model_id == 0:
        model = FastText(config)
    elif model_id == 1:
        model = TextCNN(config)
    elif model_id == 2:
        model = TextRCNN(config)
    elif model_id == 4:
        model = HAN(config)
    elif model_id == 5:
        model = CNNWithDoc2Vec(config)
    elif model_id == 6:
        model = RCNNWithDoc2Vec(config)
    else:
        print("Input ERROR!")
        exit(2)
    return model
示例#5
0
def get_models(
        vocab_size,  # vocabulary size
        n_class=10,  # number of classes
        seq_len=38,  # sentence length
        device=None):  # device; picked automatically when None
    """Build one instance of every model that should be trained.

    Args:
        vocab_size: vocabulary size, forwarded to every model.
        n_class: number of classes, forwarded to every model.
        seq_len: sentence length, forwarded to the Transformer only.
        device: torch device, forwarded to the Transformer only. When None,
            CUDA is used if ``torch.cuda.is_available()`` and CPU otherwise.

    Returns:
        A list ``[FastText, TextCNN, TextRNN, TextRCNN, Transformer]`` of
        model instances, in that order.
    """
    if device is None:
        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        device = torch.device(device_name)

    # These four models share the same constructor arguments; they are built
    # one after another in the order in which they are returned.
    shared_kwargs = dict(vocab_size=vocab_size, n_class=n_class)
    models = [
        model_class(**shared_kwargs)
        for model_class in (FastText, TextCNN, TextRNN, TextRCNN)
    ]

    # The Transformer additionally needs the sequence length and the device.
    models.append(
        Transformer(vocab_size=vocab_size,
                    seq_len=seq_len,
                    n_class=n_class,
                    device=device))
    return models
def main():
    """Load the datasets, build the model named in ``config`` and train it.

    The train, dev and test sets are loaded as ``MedicalData``; the dev and
    test sets reuse the tokenizer of the training set. The model is selected
    by ``config.model_name``: "bilstm", any name containing "bert", or
    "textcnn" (checked in that order).

    Raises:
        ValueError: if ``config.model_name`` matches none of the supported
            models.
    """
    logging.info("加载数据......")
    # Stop-word filtering is currently disabled, so an empty list is used.
    stopwords = []
    train_dataset = MedicalData(config.train_data_file,
                                config.max_seq_length,
                                char_level=config.char_level,
                                stopwords=stopwords)
    dev_dataset = MedicalData(config.eval_data_file,
                              config.max_seq_length,
                              char_level=config.char_level,
                              dictionary=train_dataset.tokenizer,
                              stopwords=stopwords)
    test_dataset = MedicalData(config.test_data_file,
                               config.max_seq_length,
                               char_level=config.char_level,
                               dictionary=train_dataset.tokenizer,
                               stopwords=stopwords)

    if config.model_name == "bilstm":
        vocab_size = train_dataset.tokenizer.vocabulary_size
        embeddings = torch.from_numpy(train_dataset.embeddings).to(
            torch.float32)
        model = LSTMModel(vocab_size=vocab_size,
                          embedding_dim=config.embedding_dim,
                          hidden_size=config.hidden_size,
                          num_classes=config.num_classes,
                          num_layers=config.num_layers,
                          bidirectional=config.bidirectional,
                          embeddings=embeddings)
    elif "bert" in config.model_name:
        config_class, model_class, tokenizer_class = MODEL_CLASSES[
            config.model_name]
        # A falsy cache_dir (e.g. an empty string) is normalised to None.
        cache_dir = config.cache_dir or None
        bert_config = config_class.from_pretrained(
            config.bert_path,
            num_labels=config.num_classes,
            hidden_dropout_prob=config.dropout,
            cache_dir=cache_dir,
        )
        tokenizer = tokenizer_class.from_pretrained(
            config.bert_path,
            do_lower_case=config.do_lower_case,
            cache_dir=cache_dir,
        )
        model = model_class.from_pretrained(
            config.bert_path,
            config=bert_config,
            cache_dir=cache_dir,
        )
        # All three datasets must tokenize with the pretrained tokenizer.
        train_dataset.tokenizer = tokenizer
        dev_dataset.tokenizer = tokenizer
        test_dataset.tokenizer = tokenizer
    elif config.model_name == "textcnn":
        vocab_size = train_dataset.tokenizer.vocabulary_size
        embeddings = torch.from_numpy(train_dataset.embeddings).to(
            torch.float32)
        model = TextCNN(vocab_size=vocab_size,
                        embedding_dim=config.embedding_dim,
                        feature_dim=config.feature_dim,
                        window_size=config.window_size,
                        max_seq_length=config.max_seq_length,
                        num_classes=config.num_classes,
                        dropout=config.dropout,
                        embeddings=embeddings,
                        fine_tune=config.fine_tune)
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # when ``model`` is used below.
        raise ValueError(
            "unsupported model_name: {!r}".format(config.model_name))

    # NOTE(review): this also runs init_network on the pretrained BERT
    # models; confirm that init_network does not overwrite their weights.
    # No branch above builds a 'transformer' model, so the guard is
    # currently always true.
    if config.model_name != 'transformer':
        init_network(model)
    train(model, train_dataset, dev_dataset, test_dataset)
示例#7
0
def create_model(args, num_classes, embedding_vector):
    """Build the network described by ``args`` and move it to its devices.

    Args:
        args: options object. The attributes read here are ``nonlin``
            (case-insensitive), ``ds``, ``arch``, ``tp_rule``, ``gpus`` and
            ``cuda``.
        num_classes: forwarded to the model constructor.
        embedding_vector: forwarded to the model constructor.

    Returns:
        The constructed model, wrapped in ``nn.DataParallel`` when
        ``args.gpus`` lists more than one entry.

    Raises:
        NotImplementedError: for an unsupported non-linearity, dataset or
            architecture.
    """
    # --- non-linearity ---------------------------------------------------
    nonlin_name = args.nonlin.lower()
    if nonlin_name == 'relu':
        nonlin = nn.ReLU
    elif nonlin_name == 'threshrelu':
        nonlin = ThresholdReLU
    elif nonlin_name == 'sign11':
        nonlin = partial(Sign11, targetprop_rule=args.tp_rule)
    elif nonlin_name == 'qrelu':
        nonlin = partial(qReLU, targetprop_rule=args.tp_rule, nsteps=3)
    else:
        raise NotImplementedError(
            'no other non-linearities currently supported')

    # --- input / target shapes, fixed per dataset ------------------------
    if args.ds in ('sentiment140', 'tsad'):
        input_shape = (1, 60, 50)
        target_shape = None
    elif args.ds == 'semeval':
        input_shape = (1, 60, 100)
        target_shape = (1, 6, 100)
    else:
        raise NotImplementedError('no other datasets currently supported')

    # --- architecture ----------------------------------------------------
    # Positional arguments shared by every architecture except BiLSTM,
    # which takes target_shape as an extra second argument.
    common = (input_shape, num_classes, embedding_vector)
    arch = args.arch
    if arch == 'cnn':
        model = CNN(*common, nonlin=nonlin)
    elif arch == 'lstm':
        # The LSTM is the only architecture built without the non-linearity.
        model = LSTM(*common)
    elif arch == 'cnn-lstm':
        model = CNN_LSTM(*common, nonlin=nonlin)
    elif arch == 'lstm-cnn':
        model = LSTM_CNN(*common, nonlin=nonlin)
    elif arch == 'textcnn':
        model = TextCNN(*common, nonlin=nonlin)
    elif arch == 'bilstm':
        model = BiLSTM(input_shape, target_shape, num_classes,
                       embedding_vector, nonlin=nonlin)
    else:
        raise NotImplementedError('other models not yet supported')

    # --- report ----------------------------------------------------------
    n_params = sum(p.data.nelement() for p in model.parameters())
    message = "{} model has {} parameters and non-linearity={} ({})".format(
        args.arch, n_params, nonlin_name, args.tp_rule.name)
    logging.info(message)

    # --- device placement ------------------------------------------------
    if len(args.gpus) > 1:
        model = nn.DataParallel(model)
    if args.cuda:
        model.cuda()

    return model
示例#8
0
    def __call__(self, num_labels: int):
        """Build the model selected by ``choose_pretrain`` and ``choose_model``.

        With ``choose_pretrain == "Bert"`` the model is created through
        ``from_pretrained`` from ``self.resume_path`` when it is set and from
        ``config.bert_model_dir`` otherwise. With "Word2vec" or "Nopretrain"
        a TextCNN or TextRCNN is built from the matching section of
        ``config``; when ``self.resume_path`` is set, its weights are
        restored from ``pytorch_model.bin`` inside that directory.

        Args:
            num_labels: number of output labels of the model.

        Returns:
            The constructed model.

        Raises:
            ModelNotDefinedError: if the combination of ``choose_pretrain``
                and ``choose_model`` is not supported.
        """
        if self.choose_pretrain == "Bert":
            # A resume path takes precedence over the stock BERT weights.
            model_dir = self.resume_path or config.bert_model_dir

            if self.choose_model == "BertFC":
                model_class = BertFCForMultiLable
            elif self.choose_model == "BertCNN":
                model_class = BertCNNForMultiLabel
            elif self.choose_model == "BertRCNN":
                model_class = BertRCNNForMultiLabel
            elif self.choose_model == "BertDPCNN":
                model_class = BertDPCNNForMultiLabel
            else:
                raise ModelNotDefinedError

            return model_class.from_pretrained(model_dir,
                                               num_labels=num_labels)

        if self.choose_pretrain not in ["Word2vec", "Nopretrain"]:
            raise ModelNotDefinedError

        # Pretrained embeddings are only used in the "Word2vec" mode.
        if self.choose_pretrain == "Word2vec":
            pretrain_model_dir = config.word2vec_model_dir
        else:
            pretrain_model_dir = None

        if self.choose_model == "TextCNN":
            # Note: this fills in the shared ``config.cnn`` object in place.
            cnn_config = config.cnn
            cnn_config.embedding_pretrained = pretrain_model_dir
            cnn_config.embedding_size = config.embedding_size
            cnn_config.vocab_size = config.vocab_size
            cnn_config.dropout = config.dropout
            cnn_config.num_labels = num_labels
            model = TextCNN(cnn_config)
        elif self.choose_model == "TextRCNN":
            rcnn_config = config.rcnn
            rcnn_config.num_labels = num_labels
            model = TextRCNN(rcnn_config)
        else:
            raise ModelNotDefinedError

        # Restore the weights only when resuming. Previously the TextRCNN
        # branch had this check inverted: it skipped loading when resuming
        # and called torch.load(None) when not.
        if self.resume_path:
            state_dict_file = os.path.join(self.resume_path,
                                           "pytorch_model.bin")
            model.load_state_dict(torch.load(state_dict_file))

        return model