def _load_yaml_config(path):
    """Parse the YAML file at ``path`` and return its contents as a dict.

    The file handle is closed deterministically via ``with``; an empty file
    yields ``{}`` rather than ``None`` so it can be ``**``-merged safely.
    """
    with open(path, encoding="utf-8") as stream:
        # NOTE(review): FullLoader is kept for compatibility with the existing
        # configs. These are local files; switch to yaml.safe_load if a config
        # could ever come from an untrusted source.
        return yaml.load(stream, Loader=yaml.FullLoader) or {}


def instantiate_model(model_name, vocab_size, embeddings):
    """Build the text-classification model selected by ``model_name``.

    Model-specific hyper-parameters are read from ``./configs/<name>.yml`` and
    merged with the shared ``./configs/multi_layer.yml``. On duplicate keys
    the shared multi-layer settings take precedence.

    Args:
        model_name: One of "rcnn", "textcnn", "textrnn", "attention_rnn" or
            "transformer". Any other value falls back to FastText.
        vocab_size: Passed as the first positional argument to the model.
        embeddings: Passed as the second positional argument to the model.

    Returns:
        The instantiated model.
    """
    multi_layer_args = _load_yaml_config('./configs/multi_layer.yml')

    # Only the selected branch's class name is resolved, exactly as in an
    # if/elif chain, so unused model classes need not be importable.
    if model_name == "rcnn":
        model_cls, config_path = RCNN, './configs/rcnn.yml'
    elif model_name == "textcnn":
        model_cls, config_path = TextCNN, './configs/textcnn.yml'
    elif model_name == "textrnn":
        model_cls, config_path = TextRNN, './configs/textrnn.yml'
    elif model_name == "attention_rnn":
        model_cls, config_path = AttentionRNN, './configs/attention_rnn.yml'
    elif model_name == "transformer":
        model_cls, config_path = Transformer, './configs/transformer.yml'
    else:
        # Unrecognised names silently fall back to FastText.
        model_cls, config_path = FastText, './configs/fasttext.yml'

    model_args = _load_yaml_config(config_path)
    # Later keys win: multi-layer settings override model-specific ones.
    model = model_cls(vocab_size, embeddings, **{**model_args, **multi_layer_args})

    logger = get_logger(__name__)
    logger.info("A model of {} is instantiated.".format(model.__class__.__name__))
    return model
def model_selector(config, model_id, use_element):
    """Instantiate the model identified by ``model_id``.

    Args:
        config: Configuration object forwarded to the model constructor.
        model_id: 0 FastText, 1 TextCNN, 2 TextRCNN, 3 TextRNN, 4 HAN,
            5 CNNWithDoc2Vec, 6 RCNNWithDoc2Vec, 7 CNNwithInception.
        use_element: When truthy, a ``ModelWithElement(config, model_id)`` is
            returned regardless of ``model_id``.

    Returns:
        The constructed model.

    Exits:
        Prints "Input ERROR!" and terminates the process with status 2 for an
        unknown ``model_id``.
    """
    import sys

    if use_element:
        print("use element")
        return ModelWithElement(config, model_id)
    if model_id == 0:
        return FastText(config)
    if model_id == 1:
        return TextCNN(config)
    if model_id == 2:
        return TextRCNN(config)
    if model_id == 3:
        return TextRNN(config)
    if model_id == 4:
        return HAN(config)
    if model_id == 5:
        return CNNWithDoc2Vec(config)
    if model_id == 6:
        return RCNNWithDoc2Vec(config)
    if model_id == 7:
        return CNNwithInception(config)

    print("Input ERROR!")
    # sys.exit, not the interactive-only site builtin exit(), which may be
    # missing (python -S, frozen apps) and also closes sys.stdin.
    sys.exit(2)
def load_multi_model(model_path, model_id, config):
    """Construct the model for ``model_id`` and load saved weights into it.

    Args:
        model_path: Checkpoint path passed to ``torch.load``; its contents are
            passed to ``load_state_dict`` of the selected model.
        model_id: 0 FastText, 1 TextCNN, 2 TextRCNN, 3 TextRNN, 4 HAN.
        config: Model configuration; ``config.has_cuda`` decides whether the
            loaded model is moved to the GPU.

    Returns:
        The model with weights loaded, moved to CUDA when ``config.has_cuda``.

    Raises:
        ValueError: If ``model_id`` is not one of 0-4. Previously this
            surfaced as an ``UnboundLocalError``.
    """
    if model_id == 0:
        model = FastText(config)
    elif model_id == 1:
        model = TextCNN(config)
    elif model_id == 2:
        model = TextRCNN(config)
    elif model_id == 3:
        model = TextRNN(config)
    elif model_id == 4:
        model = HAN(config)
    else:
        raise ValueError(
            "Unsupported model_id for load_multi_model: {!r}".format(model_id))

    print("load model data:", model_path)
    # Deserialize onto the CPU. The freshly built model lives on the CPU and
    # load_state_dict copies the tensors into it, so this is equivalent for
    # GPU hosts and lets GPU-saved checkpoints load on CPU-only machines.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    if config.has_cuda:
        model = model.cuda()
    return model
def _model_selector(config, model_id):
    """Instantiate the model for ``model_id`` (subset of ``model_selector``).

    Supported ids: 0 FastText, 1 TextCNN, 2 TextRCNN, 4 HAN,
    5 CNNWithDoc2Vec, 6 RCNNWithDoc2Vec.

    Exits:
        Prints "Input ERROR!" and terminates the process with status 2 for any
        other id.
    """
    import sys

    # NOTE(review): id 3 (TextRNN) and id 7 are accepted by model_selector but
    # rejected here; confirm whether that omission is intentional.
    if model_id == 0:
        return FastText(config)
    if model_id == 1:
        return TextCNN(config)
    if model_id == 2:
        return TextRCNN(config)
    if model_id == 4:
        return HAN(config)
    if model_id == 5:
        return CNNWithDoc2Vec(config)
    if model_id == 6:
        return RCNNWithDoc2Vec(config)

    print("Input ERROR!")
    # sys.exit, not the interactive-only site builtin exit(), which may be
    # missing (python -S, frozen apps) and also closes sys.stdin.
    sys.exit(2)
def get_models(
        vocab_size,   # vocabulary size
        n_class=10,   # number of classes
        seq_len=38,   # sentence length
        device=None):  # device
    """Build every model that should be trained.

    Args:
        vocab_size: Vocabulary size shared by all models.
        n_class: Number of output classes.
        seq_len: Sentence length; only the Transformer receives it.
        device: Device passed to the Transformer; defaults to CUDA when
            available, otherwise CPU.

    Returns:
        A list ``[FastText, TextCNN, TextRNN, TextRCNN, Transformer]``,
        constructed in that order.
    """
    if device is None:
        has_gpu = torch.cuda.is_available()
        device = torch.device('cuda' if has_gpu else 'cpu')

    shared_kwargs = dict(vocab_size=vocab_size, n_class=n_class)
    # List elements are evaluated left to right, preserving construction order.
    models = [
        FastText(**shared_kwargs),
        TextCNN(**shared_kwargs),
        TextRNN(**shared_kwargs),
        TextRCNN(**shared_kwargs),
        Transformer(vocab_size=vocab_size, seq_len=seq_len,
                    n_class=n_class, device=device),
    ]
    return models
sh = logging.StreamHandler() sh.setFormatter(fmt) sh.setLevel(logging.DEBUG) logger.addHandler(sh) # 设置文件日志 fh = logging.FileHandler(path, encoding='utf-8') fh.setFormatter(fmt) fh.setLevel(logging.DEBUG) logger.addHandler(fh) return logger if __name__ == '__main__': from dataprocess import get_vocab, load_dataset, DataIter from models.fasttext import FastText device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') c2i = get_vocab() model = FastText(vocab_size=len(c2i), n_class=10) # data, labels = load_dataset('train') dev_samples = load_dataset('dev') trian_iter = DataIter(dev_samples) test_iter = DataIter(dev_samples) train(model, trian_iter, test_iter, device=device) pass