def train():
    """Train and/or validate the LSTM text classifier.

    Builds an ``LstmConfig`` (vocab size hard-coded to 68355 — presumably
    matched to the vocab file; TODO confirm), a ``FullTokenizer`` and a
    ``TextLstm`` model, then runs the training and/or evaluation loop
    depending on the command-line flags ``args.train`` / ``args.dev``.

    Side effects:
        Moves the model to the global ``device``; logs progress via
        ``logging``; delegates the actual loop to the module-level ``run``.
    """
    conf = configuration.LstmConfig(vocab_size=68355)
    tokenizer = tokenization.FullTokenizer(
        vocab_file=conf.file_config.vocab_file)
    model = models.TextLstm(conf)
    model = model.to(device)

    if args.train:
        model.train()
        train_dataset = datasets.OnlineShopping(
            mode='train',
            config=conf,
            tokenizer=tokenizer,
            auto_padding=conf.train_config.auto_padding)
        logging.info("***** Running training *****")
        logging.info(" Num examples = %d", len(train_dataset))
        logging.info(" Total training steps: {}".format(
            train_dataset.num_steps))
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=conf.train_config.train_batch_size,
            shuffle=True,
            collate_fn=collate_fn)
        run(config=conf,
            dataloader=train_dataloader,
            model=model,
            mode='train',
            total_steps=train_dataset.num_steps)

    if args.dev:
        model.eval()
        dev_dataset = datasets.OnlineShopping(
            mode='dev',
            config=conf,
            tokenizer=tokenizer,
            auto_padding=conf.train_config.auto_padding)
        logging.info("***** Running validating *****")
        logging.info(" Num examples = %d", len(dev_dataset))
        logging.info(" Total validating steps: {}".format(
            dev_dataset.num_steps))
        # Fix: the dev loader was previously stored in a variable named
        # `train_dataloader`, shadowing the training loader above — renamed
        # to `dev_dataloader` for clarity (matches the BERT train() variant).
        # NOTE(review): shuffle=True during eval and reuse of
        # train_batch_size are kept as-is to preserve behavior; consider
        # shuffle=False and a dedicated dev batch size.
        dev_dataloader = DataLoader(
            dev_dataset,
            batch_size=conf.train_config.train_batch_size,
            shuffle=True,
            collate_fn=collate_fn)
        run(config=conf,
            dataloader=dev_dataloader,
            model=model,
            mode='eval')
def train():
    """Train and/or validate the BERT sequence classifier.

    Loads ``BertForSequenceClassification`` from the local cache directory
    if one exists, otherwise downloads the pretrained weights and saves
    them locally. Then runs training and/or evaluation depending on the
    command-line flags ``args.train`` / ``args.dev``.

    Side effects:
        May create the pretrained-model directory and write model weights
        into it; moves the model to the global ``device``; logs progress;
        delegates the loop to the module-level ``run``.
    """
    conf = configuration.Config()
    tokenizer = BertTokenizer.from_pretrained(conf.pretrained_model_name)
    # Load BERT's pretrained weights; use a local folder as the cache so
    # subsequent runs do not re-download.
    pretrained_model = os.path.join(conf.pretrained_model_path,
                                    conf.pretrained_model_name)
    if not os.path.exists(pretrained_model):
        os.mkdir(pretrained_model)
        model = BertForSequenceClassification.from_pretrained(
            conf.pretrained_model_name,
            num_labels=conf.num_labels,
            cache_dir=os.path.join(pretrained_model, './cache'))
        model.save_pretrained(pretrained_model)
    else:
        model = BertForSequenceClassification.from_pretrained(
            pretrained_model, num_labels=conf.num_labels)
    model = model.to(device)

    if args.train:
        model.train()
        train_dataset = datasets.OnlineShopping(
            mode='train',
            config=conf,
            tokenizer=tokenizer,
            auto_padding=conf.auto_padding)
        logging.info("***** Running training *****")
        logging.info(" Num examples = %d", len(train_dataset))
        logging.info(" Total training steps: {}".format(
            train_dataset.num_steps))
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=conf.train_batch_size,
                                      shuffle=True,
                                      collate_fn=collate_fn)
        run(config=conf,
            dataloader=train_dataloader,
            model=model,
            mode='train')

    if args.dev:
        model.eval()
        dev_dataset = datasets.OnlineShopping(
            mode='dev',
            config=conf,
            tokenizer=tokenizer,
            auto_padding=conf.auto_padding)
        # Fix: these log messages previously said "Running training" /
        # "Total training steps" (copy-paste from the train branch); the
        # validating wording matches the LSTM train() variant.
        logging.info("***** Running validating *****")
        logging.info(" Num examples = %d", len(dev_dataset))
        logging.info(" Total validating steps: {}".format(
            dev_dataset.num_steps))
        dev_dataloader = DataLoader(dev_dataset,
                                    batch_size=conf.dev_batch_size,
                                    shuffle=True,
                                    collate_fn=collate_fn)
        run(config=conf,
            dataloader=dev_dataloader,
            model=model,
            mode='eval')
def predict(texts):
    """Predict labels for a batch of raw texts with the BERT classifier.

    Args:
        texts: iterable of input strings; all of them are run as a single
            batch (``batch_size=len(texts)``).

    Returns:
        A dict mapping each input text to
        ``{'result': <label value>, 'probability': <float>}``, or ``None``
        when no trained checkpoint is found on disk.

    Side effects:
        Loads model weights from ``conf.model_dir/conf.model_name`` and
        moves the model to the global ``device``.
    """
    conf = configuration.Config()
    tokenizer = BertTokenizer.from_pretrained(conf.pretrained_model_name)
    model = BertForSequenceClassification.from_pretrained(
        conf.pretrained_model_name, num_labels=conf.num_labels)
    model = model.to(device)
    if os.path.exists(os.path.join(conf.model_dir, conf.model_name)):
        model.load_state_dict(
            torch.load(os.path.join(conf.model_dir, conf.model_name)))
    else:
        logging.info(' *** No model available. *** ')
        return
    predict_dataset = datasets.OnlineShopping(mode='single_predict',
                                              config=conf,
                                              tokenizer=tokenizer,
                                              auto_padding=True,
                                              texts=texts)
    predict_dataloader = DataLoader(predict_dataset,
                                    batch_size=len(texts),
                                    collate_fn=collate_fn)
    # One batch covers every input text, so a single fetch is enough.
    data = next(iter(predict_dataloader))
    tokens_tensors, segments_tensors, masks_tensors, _ = [
        t.to(device) if t is not None else t for t in data
    ]
    outputs = model(input_ids=tokens_tensors,
                    token_type_ids=segments_tensors,
                    attention_mask=masks_tensors)
    # Fix: removed stray debug `print(outputs)` — the module already logs
    # through `logging`.
    probs, predictions = get_predictions(outputs, compute_acc=False)
    return dict(
        zip(texts, [{
            'result': label,
            'probability': prob
        } for label, prob in zip([
            predict_dataset.convert_label_id_to_value(prediction.item())
            for prediction in predictions
        ], [prob.item() for prob in probs])]))
def predict(texts):
    """Predict labels for a batch of raw texts with the LSTM classifier.

    Args:
        texts: iterable of input strings; the whole dataset built from
            them is run as a single batch.

    Returns:
        A dict mapping each input text to
        ``{'result': <label value>, 'probability': <float>}``, or ``None``
        when no trained checkpoint is found on disk.

    Side effects:
        Loads model weights from
        ``conf.train_config.model_dir/conf.train_config.model_name`` and
        moves the model to the global ``device``.
    """
    conf = configuration.LstmConfig(vocab_size=68355)
    tokenizer = tokenization.FullTokenizer(
        vocab_file=conf.file_config.vocab_file)
    model = models.TextLstm(conf)
    model = model.to(device)
    if os.path.exists(
            os.path.join(conf.train_config.model_dir,
                         conf.train_config.model_name)):
        logging.info(' *** Loading model ***')
        model.load_state_dict(
            torch.load(
                os.path.join(conf.train_config.model_dir,
                             conf.train_config.model_name)))
    else:
        logging.info(' *** No model available. *** ')
        return
    predict_dataset = datasets.OnlineShopping(mode='single_predict',
                                              config=conf,
                                              tokenizer=tokenizer,
                                              auto_padding=True,
                                              texts=texts)
    predict_dataloader = DataLoader(predict_dataset,
                                    batch_size=len(predict_dataset),
                                    collate_fn=collate_fn)
    # One batch covers the whole dataset, so a single fetch is enough.
    data = next(iter(predict_dataloader))
    text_ids, _ = [t.to(device) if t is not None else t for t in data]
    logits = model(text_ids)
    # Fix: removed stray debug `print(logits)` — the module already logs
    # through `logging`.
    probs, predictions = get_predictions(logits)
    return dict(
        zip(texts, [{
            'result': label,
            'probability': prob
        } for label, prob in zip([
            predict_dataset.convert_label_id_to_value(prediction.item())
            for prediction in predictions
        ], [prob.item() for prob in probs])]))