def main(option):
    """Embed unique (subj, verb, obj) events from option.input_files with a trained
    (LowRank)NTN event model and save embeddings, labels, and event texts to disk.

    Expects ``option`` to carry: emb_file, vocab_size, emb_dim, model_type, em_k,
    em_r (LowRankNTN only), model_file, device, input_files, output_file.
    """
    glove = Glove(option.emb_file)
    print(option.emb_file + ' loaded')

    # padding_idx=1 matches the vocabulary convention used at training time
    # (presumably index 1 is <pad> — confirm against the Glove/Vocab setup).
    embedding = Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model_type == 'NTN':
        model = NeuralTensorNetwork(embedding, option.em_k)
    elif option.model_type == 'LowRankNTN':
        model = LowRankNeuralTensorNetwork(embedding, option.em_k, option.em_r)
    else:
        # Fail loudly instead of hitting an UnboundLocalError on `model` below.
        raise ValueError('unknown model type: %s' % option.model_type)

    checkpoint = torch.load(option.model_file, map_location='cpu')
    # Checkpoints come in three flavors: a dict with the event model under
    # 'event_model_state_dict', a dict with a plain 'model_state_dict', or a
    # bare state dict saved directly.
    if isinstance(checkpoint, dict):
        if 'event_model_state_dict' in checkpoint:
            state_dict = checkpoint['event_model_state_dict']
        else:
            state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    print(option.model_file + ' loaded')
    model.eval()
    model.to(option.device)

    all_subj_id = []
    all_subj_w = []
    all_verb_id = []
    all_verb_w = []
    all_obj_id = []
    all_obj_w = []
    all_labels = []
    all_event_texts = []
    seen_event_texts = set()  # O(1) dedup instead of O(n) list membership

    # Each input file contributes one class label (its enumeration index).
    for label, filename in enumerate(option.input_files):
        print('loading ' + filename)
        with open(filename, 'r') as f:
            lines = f.readlines()
        for line in lines:
            subj, verb, obj = line.lower().strip().split(' | ')
            event_text = '(' + subj + ', ' + verb + ', ' + obj + ')'
            subj = subj.split(' ')
            verb = verb.split(' ')
            obj = obj.split(' ')
            # glove.transform returns (None, None) for OOV-only phrases; 10 is
            # the max token count per slot — TODO confirm against Glove API.
            subj_id, subj_w = glove.transform(subj, 10)
            verb_id, verb_w = glove.transform(verb, 10)
            obj_id, obj_w = glove.transform(obj, 10)
            if (subj_id is not None and verb_id is not None
                    and obj_id is not None and event_text not in seen_event_texts):
                all_subj_id.append(subj_id)
                all_subj_w.append(subj_w)
                all_verb_id.append(verb_id)
                all_verb_w.append(verb_w)
                all_obj_id.append(obj_id)
                all_obj_w.append(obj_w)
                all_labels.append(label)
                all_event_texts.append(event_text)
                seen_event_texts.add(event_text)

    all_subj_id = torch.tensor(all_subj_id, dtype=torch.long, device=option.device)
    all_subj_w = torch.tensor(all_subj_w, dtype=torch.float, device=option.device)
    all_verb_id = torch.tensor(all_verb_id, dtype=torch.long, device=option.device)
    all_verb_w = torch.tensor(all_verb_w, dtype=torch.float, device=option.device)
    all_obj_id = torch.tensor(all_obj_id, dtype=torch.long, device=option.device)
    all_obj_w = torch.tensor(all_obj_w, dtype=torch.float, device=option.device)

    # Inference only: no_grad avoids building the autograd graph for the batch.
    with torch.no_grad():
        all_event_embeddings = model(
            all_subj_id, all_subj_w,
            all_verb_id, all_verb_w,
            all_obj_id, all_obj_w,
        ).cpu()

    torch.save(
        {
            'embeddings': all_event_embeddings,
            'labels': torch.tensor(all_labels, dtype=torch.long),
            'event_texts': all_event_texts,
        },
        option.output_file,
    )
    print('saved to ' + option.output_file)
def main(option):
    """Train an NTN encoder + RNN decoder on intent prediction, then print
    predictions for a handful of hand-picked sample events."""
    # Reproducibility: seed both RNGs before any data shuffling happens.
    random.seed(option.random_seed)
    torch.manual_seed(option.random_seed)

    LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level='INFO', stream=sys.stdout)

    # Pretrained word vectors; the source vocabulary mirrors the GloVe index.
    glove = Glove(option.emb_file)
    logging.info('loaded embeddings from ' + option.emb_file)
    src_vocab = Vocab.build_from_glove(glove)
    tgt_vocab = Vocab.load(option.intent_vocab)

    train_dataset = load_intent_prediction_dataset(
        option.train_dataset, src_vocab, tgt_vocab, device=option.device)
    dev_dataset = load_intent_prediction_dataset(
        option.dev_dataset, src_vocab, tgt_vocab, device=option.device)
    train_data_loader = DataLoader(
        train_dataset, batch_size=option.batch_size, shuffle=True)
    # The dev set is evaluated in a single batch.
    dev_data_loader = DataLoader(
        dev_dataset, batch_size=len(dev_dataset), shuffle=False)

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # Prepare loss: uniform class weights, padding excluded via pad index.
    weight = torch.ones(tgt_vocab_size)
    pad = tgt_vocab.stoi[tgt_vocab.pad_token]
    loss = Perplexity(weight, pad)
    loss.criterion.to(option.device)

    # Initialize model: NTN event encoder feeding a plain (no-attention) RNN decoder.
    encoder = NeuralTensorNetwork(
        nn.Embedding(src_vocab_size, option.emb_dim), option.em_k)
    decoder = DecoderRNN(
        tgt_vocab_size,
        option.im_max_len,
        option.im_hidden_size,
        use_attention=False,
        bidirectional=False,
        eos_id=tgt_vocab.stoi[tgt_vocab.eos_token],
        sos_id=tgt_vocab.stoi[tgt_vocab.bos_token])
    encoder.to(option.device)
    decoder.to(option.device)
    init_model(encoder)
    init_model(decoder)
    # Overwrite the randomly initialized embedding table with GloVe vectors.
    encoder.embeddings.weight.data.copy_(torch.from_numpy(glove.embd).float())

    optimizer_params = [{'params': module.parameters()}
                        for module in (encoder, decoder)]
    optimizer = Optimizer(
        optim.Adam(optimizer_params, lr=option.lr), max_grad_norm=5)

    trainer = NTNTrainer(
        loss, print_every=option.report_every, device=option.device)
    encoder, decoder = trainer.train(
        encoder,
        decoder,
        optimizer,
        train_data_loader,
        num_epochs=option.epochs,
        dev_data_loader=dev_data_loader,
        teacher_forcing_ratio=option.im_teacher_forcing_ratio)

    # Sanity-check the trained model on a few fixed events.
    predictor = NTNPredictor(encoder, decoder, src_vocab, tgt_vocab, option.device)
    samples = [
        ("PersonX", "eventually told", "___"),
        ("PersonX", "tells", "PersonY 's tale"),
        ("PersonX", "always played", " ___"),
        ("PersonX", "would teach", "PersonY"),
        ("PersonX", "gets", "a ride"),
    ]
    for sample in samples:
        subj, verb, obj = sample
        subj = subj.lower().split(' ')
        verb = verb.lower().split(' ')
        obj = obj.lower().split(' ')
        print(sample, predictor.predict(subj, verb, obj))