示例#1
0
def instantiate_model(model_name, vocab_size, embeddings):
    """Build the model selected by ``model_name`` from its YAML configuration.

    The shared settings in ``./configs/multi_layer.yml`` are merged over the
    model-specific settings (shared keys win on conflict), and the merged
    mapping is passed to the model constructor as keyword arguments.

    Args:
        model_name: ``"rcnn"``, ``"textcnn"``, ``"textrnn"``,
            ``"attention_rnn"`` or ``"transformer"``. Any other value
            falls back to FastText.
        vocab_size: Forwarded as the first positional constructor argument.
        embeddings: Forwarded as the second positional constructor argument.

    Returns:
        The newly constructed model instance.
    """

    def _load_config(path):
        # Close the file deterministically instead of leaking the handle.
        # An empty YAML document parses to None; treat it as "no options"
        # so the ``**`` merge below does not fail.
        # NOTE(review): FullLoader is fine for trusted, repo-local configs;
        # switch to yaml.safe_load if these files can come from users.
        with open(path) as config_file:
            return yaml.load(config_file, Loader=yaml.FullLoader) or {}

    multi_layer_args = _load_config('./configs/multi_layer.yml')

    # The class name is looked up only inside the branch that is taken,
    # so a class that is not imported only matters when it is requested.
    if model_name == "rcnn":
        model_cls, config_path = RCNN, './configs/rcnn.yml'
    elif model_name == "textcnn":
        model_cls, config_path = TextCNN, './configs/textcnn.yml'
    elif model_name == "textrnn":
        model_cls, config_path = TextRNN, './configs/textrnn.yml'
    elif model_name == "attention_rnn":
        model_cls, config_path = AttentionRNN, './configs/attention_rnn.yml'
    elif model_name == "transformer":
        model_cls, config_path = Transformer, './configs/transformer.yml'
    else:
        # Any other name keeps the existing fallback to FastText.
        model_cls, config_path = FastText, './configs/fasttext.yml'

    model_args = _load_config(config_path)
    model = model_cls(vocab_size, embeddings, **{**model_args, **multi_layer_args})

    logger = get_logger(__name__)
    logger.info("A model of {} is instantiated.".format(model.__class__.__name__))

    return model
def model_selector(config, model_id, use_element):
    """Construct the model identified by ``model_id``.

    Args:
        config: Passed unchanged to the selected model constructor.
        model_id: Integer selector: 0 FastText, 1 TextCNN, 2 TextRCNN,
            3 TextRNN, 4 HAN, 5 CNNWithDoc2Vec, 6 RCNNWithDoc2Vec,
            7 CNNwithInception.
        use_element: When truthy, ``ModelWithElement(config, model_id)`` is
            built instead and ``model_id`` is not validated here.

    Returns:
        The constructed model.

    Raises:
        SystemExit: With status 2 when ``use_element`` is falsy and
            ``model_id`` is not one of the ids listed above.
    """
    if use_element:
        print("use element")
        return ModelWithElement(config, model_id)

    if model_id == 0:
        return FastText(config)
    if model_id == 1:
        return TextCNN(config)
    if model_id == 2:
        return TextRCNN(config)
    if model_id == 3:
        return TextRNN(config)
    if model_id == 4:
        return HAN(config)
    if model_id == 5:
        return CNNWithDoc2Vec(config)
    if model_id == 6:
        return RCNNWithDoc2Vec(config)
    if model_id == 7:
        return CNNwithInception(config)

    print("Input ERROR!")
    # The bare ``exit()`` helper is injected by the ``site`` module for
    # interactive use and is missing under ``python -S`` or frozen builds.
    # Raising SystemExit directly yields the same exit status everywhere.
    raise SystemExit(2)
def load_multi_model(model_path, model_id, config):
    """Build the model for ``model_id`` and load saved weights into it.

    Args:
        model_path: Checkpoint path handed to ``torch.load``; the loaded
            object is passed to ``load_state_dict``.
        model_id: Integer selector: 0 FastText, 1 TextCNN, 2 TextRCNN,
            3 TextRNN, 4 HAN.
        config: Passed to the model constructor. ``config.has_cuda``
            decides whether the model is moved to the GPU.

    Returns:
        The model with the checkpoint weights loaded, moved to CUDA when
        ``config.has_cuda`` is truthy.

    Raises:
        ValueError: If ``model_id`` is not one of the supported ids.
    """
    if model_id == 0:
        model = FastText(config)
    elif model_id == 1:
        model = TextCNN(config)
    elif model_id == 2:
        model = TextRCNN(config)
    elif model_id == 3:
        model = TextRNN(config)
    elif model_id == 4:
        model = HAN(config)
    else:
        # Previously ``model`` stayed unbound here and the function died
        # later with an unhelpful UnboundLocalError.
        raise ValueError(
            "Unsupported model_id {!r}; expected an integer in 0..4".format(model_id)
        )

    print("load model data:", model_path)
    # Deserialize onto the CPU so a checkpoint saved on a GPU machine still
    # loads on a CPU-only host; the model is moved to CUDA below if needed.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    if config.has_cuda:
        model = model.cuda()

    return model
示例#4
0
def _model_selector(config, model_id):
    model = None
    if model_id == 0:
        model = FastText(config)
    elif model_id == 1:
        model = TextCNN(config)
    elif model_id == 2:
        model = TextRCNN(config)
    elif model_id == 4:
        model = HAN(config)
    elif model_id == 5:
        model = CNNWithDoc2Vec(config)
    elif model_id == 6:
        model = RCNNWithDoc2Vec(config)
    else:
        print("Input ERROR!")
        exit(2)
    return model
示例#5
0
def get_models(
        vocab_size,  # vocabulary size
        n_class=10,  # number of classes
        seq_len=38,  # sentence length
        device=None):  # device
    """Return every model that needs to be trained.

    The list holds, in order, a FastText, TextCNN, TextRNN, TextRCNN and
    Transformer instance. Only the Transformer receives ``seq_len`` and
    ``device``; when ``device`` is None it is resolved to CUDA if
    available, otherwise to the CPU.
    """
    if device is None:
        backend = 'cuda' if torch.cuda.is_available() else 'cpu'
        device = torch.device(backend)

    # The first four models share the same two keyword arguments.
    shared_kwargs = dict(vocab_size=vocab_size, n_class=n_class)
    models = [
        model_cls(**shared_kwargs)
        for model_cls in (FastText, TextCNN, TextRNN, TextRCNN)
    ]
    models.append(
        Transformer(vocab_size=vocab_size,
                    seq_len=seq_len,
                    n_class=n_class,
                    device=device)
    )
    return models
示例#6
0
        sh = logging.StreamHandler()
        sh.setFormatter(fmt)
        sh.setLevel(logging.DEBUG)
        logger.addHandler(sh)
    # 设置文件日志
    fh = logging.FileHandler(path, encoding='utf-8')
    fh.setFormatter(fmt)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)
    return logger


if __name__ == '__main__':
    # Script entry point: train a FastText model, feeding the 'dev' split
    # to both the training and the evaluation iterator.
    from dataprocess import get_vocab, load_dataset, DataIter
    from models.fasttext import FastText

    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')

    # The length of the vocabulary mapping is used as the vocabulary size.
    c2i = get_vocab()
    model = FastText(vocab_size=len(c2i), n_class=10)

    dev_samples = load_dataset('dev')
    train_iter = DataIter(dev_samples)
    test_iter = DataIter(dev_samples)

    train(model, train_iter, test_iter, device=device)