if not option.update_embeddings: event_model.embeddings.weight.requires_grad = False if option.use_gpu: event_model.cuda() intent_model.cuda() sentiment_classifier.cuda() event_scorer.cuda() embeddings_param_id = [id(param) for param in embeddings.parameters()] params = [{ 'params': embeddings.parameters() }, { 'params': [ param for param in event_model.parameters() if id(param) not in embeddings_param_id ], 'weight_decay': option.weight_decay }, { 'params': [ param for param in event_scorer.parameters() if id(param) not in embeddings_param_id ], 'weight_decay': option.weight_decay }, { 'params': [ param for param in intent_model.parameters() if id(param) not in embeddings_param_id
def main(option):
    """Jointly train the event model on the ATOMIC-style intent/sentiment data
    and on the YAGO relation data.

    Each training step alternates one "atomic" batch (event + intent +
    sentiment losses) with one YAGO batch (event + attribute margin losses).
    The YAGO loader drives the epoch length; the atomic loader is cycled.

    Fixes vs. previous revision:
      * YAGO dev evaluation used the undefined name ``criterion_yago``
        (the criterion is ``yago_criterion``) -> NameError.
      * ``event_model`` was left in eval mode after the YAGO dev pass,
        so later epochs trained it with eval-mode behavior.
    """
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format=
        '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )
    torch.manual_seed(option.random_seed)

    glove = Glove(option.emb_file)
    logging.info('Embeddings loaded')

    # ATOMIC-style event/intent/sentiment data.
    train_dataset = EventIntentSentimentDataset()
    logging.info('Loading train dataset: ' + option.train_dataset)
    train_dataset.load(option.train_dataset, glove)
    logging.info('Loaded train dataset: ' + option.train_dataset)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        collate_fn=EventIntentSentimentDataset_collate_fn,
        batch_size=option.batch_size,
        shuffle=True)

    if option.dev_dataset is not None:
        dev_dataset = EventIntentSentimentDataset()
        logging.info('Loading dev dataset: ' + option.dev_dataset)
        dev_dataset.load(option.dev_dataset, glove)
        logging.info('Loaded dev dataset: ' + option.dev_dataset)
        # Single full-size batch: the whole dev set is evaluated at once.
        dev_data_loader = torch.utils.data.DataLoader(
            dev_dataset,
            collate_fn=EventIntentSentimentDataset_collate_fn,
            batch_size=len(dev_dataset),
            shuffle=False)

    # YAGO relation data.
    yago_train_dataset = YagoDataset()
    logging.info('Loading YAGO train dataset: ' + option.yago_train_dataset)
    yago_train_dataset.load(option.yago_train_dataset, glove)
    yago_train_data_loader = torch.utils.data.DataLoader(
        yago_train_dataset,
        collate_fn=YagoDataset_collate_fn,
        batch_size=option.batch_size,
        shuffle=True)

    if option.yago_dev_dataset is not None:
        yago_dev_dataset = YagoDataset()
        logging.info('Loading YAGO dev dataset: ' + option.yago_dev_dataset)
        yago_dev_dataset.load(option.yago_dev_dataset, glove)
        yago_dev_data_loader = torch.utils.data.DataLoader(
            yago_dev_dataset,
            collate_fn=YagoDataset_collate_fn,
            batch_size=len(yago_dev_dataset),
            shuffle=False)

    # Shared embedding table (padding_idx=1 keeps the pad row at zero grad).
    embeddings = nn.Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model_type == 'NTN':
        event_model = NeuralTensorNetwork(embeddings, option.em_k)
    elif option.model_type == 'RoleFactor':
        event_model = RoleFactoredTensorModel(embeddings, option.em_k)
    elif option.model_type == 'LowRankNTN':
        event_model = LowRankNeuralTensorNetwork(embeddings, option.em_k,
                                                 option.em_r)
    intent_model = BiLSTMEncoder(embeddings, option.im_hidden_size,
                                 option.im_num_layers)

    # One NTN per YAGO relation, all sharing the embedding table.
    relation_models = {}
    for relation in yago_train_dataset.relations:
        relation_model = NeuralTensorNetwork_yago(embeddings, option.em_k)
        relation_models[relation] = relation_model

    if option.scorer_actv_func == 'sigmoid':
        scorer_actv_func = nn.Sigmoid
    elif option.scorer_actv_func == 'relu':
        scorer_actv_func = nn.ReLU
    elif option.scorer_actv_func == 'tanh':
        scorer_actv_func = nn.Tanh
    event_scorer = nn.Sequential(nn.Linear(option.em_k, 1),
                                 scorer_actv_func())
    relation_scorer = nn.Sequential(nn.Linear(option.em_k, 1),
                                    scorer_actv_func())
    intent_scorer = nn.CosineSimilarity(dim=1)
    sentiment_classifier = nn.Linear(option.em_k, 1)
    criterion = MarginLoss(option.margin)
    sentiment_criterion = nn.BCEWithLogitsLoss()
    yago_criterion = MarginLoss(option.yago_margin)

    # load pretrained embeddings
    embeddings.weight.data.copy_(torch.from_numpy(glove.embd).float())
    if not option.update_embeddings:
        event_model.embeddings.weight.requires_grad = False

    if option.use_gpu:
        event_model.cuda()
        intent_model.cuda()
        sentiment_classifier.cuda()
        event_scorer.cuda()
        relation_scorer.cuda()
        for relation_model in relation_models.values():
            relation_model.cuda()

    # The embedding parameters appear in every model; register them once and
    # exclude them from per-model weight decay groups.
    embeddings_param_id = [id(param) for param in embeddings.parameters()]
    params = [{
        'params': embeddings.parameters()
    }, {
        'params': [
            param for param in event_model.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }, {
        'params': [
            param for param in event_scorer.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }, {
        'params': [
            param for param in intent_model.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }, {
        'params': [
            param for param in sentiment_classifier.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }]
    for relation in relation_models:
        params.append({
            'params': [
                param for param in relation_models[relation].parameters()
                if id(param) not in embeddings_param_id
            ],
            'weight_decay': option.weight_decay
        })
    optimizer = torch.optim.Adagrad(params, lr=option.lr)

    # load checkpoint if provided:
    if option.load_checkpoint != '':
        checkpoint = torch.load(option.load_checkpoint)
        event_model.load_state_dict(checkpoint['event_model_state_dict'])
        intent_model.load_state_dict(checkpoint['intent_model_state_dict'])
        event_scorer.load_state_dict(checkpoint['event_scorer_state_dict'])
        sentiment_classifier.load_state_dict(
            checkpoint['sentiment_classifier_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        for relation in relation_models:
            relation_models[relation].load_state_dict(
                checkpoint['relation_model_state_dict'][relation])
        logging.info('Loaded checkpoint: ' + option.load_checkpoint)
    # load pretrained event model instead:
    elif option.pretrained_event_model != '':
        checkpoint = torch.load(option.pretrained_event_model)
        event_model.load_state_dict(checkpoint['model_state_dict'])
        logging.info('Loaded pretrained event model: ' +
                     option.pretrained_event_model)

    for epoch in range(option.epochs):
        epoch += 1
        logging.info('Epoch ' + str(epoch))

        # train set
        avg_loss_e = 0
        avg_loss_i = 0
        avg_loss_s = 0
        avg_loss = 0
        avg_loss_sum = 0
        avg_loss_event = 0
        avg_loss_attr = 0
        # assume yago dataset is larger than atomic
        atomic_iterator = itertools.cycle(iter(train_data_loader))
        yago_iterator = iter(yago_train_data_loader)
        # iterate over yago dataset (atomic dataset is cycled)
        for i, yago_batch in enumerate(yago_iterator):
            atomic_batch = next(atomic_iterator)

            # Step 1: atomic batch (event + intent + sentiment losses).
            optimizer.zero_grad()
            loss, loss_e, loss_i, loss_s = run_batch(
                option, atomic_batch, event_model, intent_model, event_scorer,
                intent_scorer, sentiment_classifier, criterion,
                sentiment_criterion)
            loss.backward()
            optimizer.step()
            avg_loss_e += loss_e.item() / option.report_every
            avg_loss_i += loss_i.item() / option.report_every
            avg_loss_s += loss_s.item() / option.report_every
            avg_loss += loss.item() / option.report_every
            if i % option.report_every == 0:
                logging.info(
                    'Atomic batch %d, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                    % (i, avg_loss_e, avg_loss_i, avg_loss_s, avg_loss))
                avg_loss_e = 0
                avg_loss_i = 0
                avg_loss_s = 0
                avg_loss = 0

            # Step 2: YAGO batch (event + attribute margin losses).
            optimizer.zero_grad()
            loss_sum, loss_event, loss_attr = run_yago_batch(
                option, yago_batch, event_model, relation_models,
                yago_criterion, relation_scorer, event_scorer)
            loss_sum.backward()
            optimizer.step()
            avg_loss_sum += loss_sum.item() / option.report_every
            avg_loss_event += loss_event.item() / option.report_every
            avg_loss_attr += loss_attr.item() / option.report_every
            if i % option.report_every == 0:
                logging.info(
                    'YAGO batch %d, loss_event=%.4f, loss_attr=%.4f, loss=%.4f'
                    % (i, avg_loss_event, avg_loss_attr, avg_loss_sum))
                avg_loss_sum = 0
                avg_loss_event = 0
                avg_loss_attr = 0

        # dev set
        if option.dev_dataset is not None:
            event_model.eval()
            intent_model.eval()
            event_scorer.eval()
            sentiment_classifier.eval()
            batch = next(iter(dev_data_loader))
            with torch.no_grad():
                loss, loss_e, loss_i, loss_s = run_batch(
                    option, batch, event_model, intent_model, event_scorer,
                    intent_scorer, sentiment_classifier, criterion,
                    sentiment_criterion)
            logging.info(
                'Eval on dev set, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                % (loss_e.item(), loss_i.item(), loss_s.item(), loss.item()))
            event_model.train()
            intent_model.train()
            event_scorer.train()
            sentiment_classifier.train()

        # dev set (yago)
        if option.yago_dev_dataset is not None:
            for key in relation_models.keys():
                relation_models[key].eval()
            relation_scorer.eval()
            event_model.eval()
            yago_dev_batch = next(iter(yago_dev_data_loader))
            with torch.no_grad():
                # BUGFIX: was `criterion_yago`, an undefined name.
                loss_sum, loss_event, loss_attr = run_yago_batch(
                    option, yago_dev_batch, event_model, relation_models,
                    yago_criterion, relation_scorer, event_scorer)
            logging.info(
                'Eval on yago dev set, loss_sum=%.4f, loss_event=%.4f, loss_attr=%.4f, '
                % (loss_sum.item(), loss_event.item(), loss_attr.item()))
            for key in relation_models.keys():
                relation_models[key].train()
            relation_scorer.train()
            # BUGFIX: event_model was left in eval mode for later epochs.
            event_model.train()

        if option.save_checkpoint != '':
            checkpoint = {
                'event_model_state_dict': event_model.state_dict(),
                'intent_model_state_dict': intent_model.state_dict(),
                'event_scorer_state_dict': event_scorer.state_dict(),
                'sentiment_classifier_state_dict':
                sentiment_classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'relation_model_state_dict': {
                    relation: relation_models[relation].state_dict()
                    for relation in relation_models
                }
            }
            torch.save(checkpoint, option.save_checkpoint + '_' + str(epoch))
            logging.info('Saved checkpoint: ' + option.save_checkpoint + '_' +
                         str(epoch))
def main(option):
    """Train an NTN-encoder / RNN-decoder intent prediction model.

    Builds source/target vocabularies, trains the encoder-decoder pair with
    an NTNTrainer, then prints predictions for a handful of sample events.
    """
    # Seed both RNGs so runs are reproducible.
    random.seed(option.random_seed)
    torch.manual_seed(option.random_seed)
    logging.basicConfig(
        stream=sys.stdout,
        level='INFO',
        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s')

    glove = Glove(option.emb_file)
    logging.info('loaded embeddings from ' + option.emb_file)

    # Source vocab mirrors the GloVe table; target vocab is the intent vocab.
    src_vocab = Vocab.build_from_glove(glove)
    tgt_vocab = Vocab.load(option.intent_vocab)
    n_src_tokens = len(src_vocab)
    n_tgt_tokens = len(tgt_vocab)

    train_dataset = load_intent_prediction_dataset(option.train_dataset,
                                                   src_vocab,
                                                   tgt_vocab,
                                                   device=option.device)
    dev_dataset = load_intent_prediction_dataset(option.dev_dataset,
                                                 src_vocab,
                                                 tgt_vocab,
                                                 device=option.device)
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=option.batch_size,
                                   shuffle=True)
    # Whole dev set in one batch.
    dev_data_loader = DataLoader(dev_dataset,
                                 batch_size=len(dev_dataset),
                                 shuffle=False)

    # Prepare loss
    class_weights = torch.ones(n_tgt_tokens)
    pad_id = tgt_vocab.stoi[tgt_vocab.pad_token]
    loss_fn = Perplexity(class_weights, pad_id)
    loss_fn.criterion.to(option.device)

    # Initialize model
    encoder = NeuralTensorNetwork(nn.Embedding(n_src_tokens, option.emb_dim),
                                  option.em_k)
    decoder = DecoderRNN(n_tgt_tokens,
                         option.im_max_len,
                         option.im_hidden_size,
                         use_attention=False,
                         bidirectional=False,
                         eos_id=tgt_vocab.stoi[tgt_vocab.eos_token],
                         sos_id=tgt_vocab.stoi[tgt_vocab.bos_token])
    encoder.to(option.device)
    decoder.to(option.device)
    init_model(encoder)
    init_model(decoder)
    # Overwrite the freshly-initialized encoder embeddings with GloVe.
    encoder.embeddings.weight.data.copy_(
        torch.from_numpy(glove.embd).float())

    optimizer = Optimizer(optim.Adam([{
        'params': encoder.parameters()
    }, {
        'params': decoder.parameters()
    }],
                                     lr=option.lr),
                          max_grad_norm=5)
    trainer = NTNTrainer(loss_fn,
                         print_every=option.report_every,
                         device=option.device)
    encoder, decoder = trainer.train(
        encoder,
        decoder,
        optimizer,
        train_data_loader,
        num_epochs=option.epochs,
        dev_data_loader=dev_data_loader,
        teacher_forcing_ratio=option.im_teacher_forcing_ratio)

    # Sanity-check the trained model on a few hand-picked events.
    predictor = NTNPredictor(encoder, decoder, src_vocab, tgt_vocab,
                             option.device)
    samples = [
        ("PersonX", "eventually told", "___"),
        ("PersonX", "tells", "PersonY 's tale"),
        ("PersonX", "always played", " ___"),
        ("PersonX", "would teach", "PersonY"),
        ("PersonX", "gets", "a ride"),
    ]
    for sample in samples:
        subj, verb, obj = (part.lower().split(' ') for part in sample)
        print(sample, predictor.predict(subj, verb, obj))
def main(option):
    """Train the event model on intent/sentiment data (no YAGO relations).

    Builds the event/intent/sentiment models around a shared embedding
    table, optionally restores a checkpoint or a pretrained event model,
    then runs the training loop with per-epoch dev evaluation and
    checkpointing.

    Fixes vs. previous revision:
      * dev evaluation now runs under ``torch.no_grad()`` (it previously
        built an unused autograd graph), matching the joint-training script.
      * unknown activation function is logged at ERROR level before exit.
    """
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format=
        '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )
    torch.manual_seed(option.random_seed)

    glove = Glove(option.emb_file)
    logging.info('Embeddings loaded')

    # Shared embedding table (padding_idx=1 keeps the pad row at zero grad).
    embeddings = nn.Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model_type == 'NTN':
        event_model = NeuralTensorNetwork(embeddings, option.em_k)
    elif option.model_type == 'RoleFactor':
        event_model = RoleFactoredTensorModel(embeddings, option.em_k)
    elif option.model_type == 'LowRankNTN':
        event_model = LowRankNeuralTensorNetwork(embeddings, option.em_k,
                                                 option.em_r)
    intent_model = BiLSTMEncoder(embeddings, option.im_hidden_size,
                                 option.im_num_layers)

    if option.em_actv_func == 'sigmoid':
        em_actv_func = nn.Sigmoid()
    elif option.em_actv_func == 'relu':
        em_actv_func = nn.ReLU()
    elif option.em_actv_func == 'tanh':
        em_actv_func = nn.Tanh()
    else:
        # Fail fast on a misconfigured activation function.
        logging.error('Unknown event activation func: ' + option.em_actv_func)
        exit(1)
    event_scorer = nn.Sequential(nn.Linear(option.em_k, 1), em_actv_func)
    intent_scorer = nn.CosineSimilarity(dim=1)
    sentiment_classifier = nn.Linear(option.em_k, 1)
    criterion = MarginLoss(option.margin)
    sentiment_criterion = nn.BCEWithLogitsLoss()

    # load pretrained embeddings
    embeddings.weight.data.copy_(torch.from_numpy(glove.embd).float())
    if not option.update_embeddings:
        event_model.embeddings.weight.requires_grad = False

    if option.use_gpu:
        event_model.cuda()
        intent_model.cuda()
        sentiment_classifier.cuda()
        event_scorer.cuda()

    # The embedding parameters appear in every model; register them once and
    # exclude them from per-model weight decay groups.
    embeddings_param_id = [id(param) for param in embeddings.parameters()]
    params = [{
        'params': embeddings.parameters()
    }, {
        'params': [
            param for param in event_model.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }, {
        'params': [
            param for param in event_scorer.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }, {
        'params': [
            param for param in intent_model.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }, {
        'params': [
            param for param in sentiment_classifier.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }]
    optimizer = torch.optim.Adagrad(params, lr=option.lr)

    # load checkpoint if provided:
    if option.load_checkpoint != '':
        checkpoint = torch.load(option.load_checkpoint)
        event_model.load_state_dict(checkpoint['event_model_state_dict'])
        intent_model.load_state_dict(checkpoint['intent_model_state_dict'])
        event_scorer.load_state_dict(checkpoint['event_scorer_state_dict'])
        sentiment_classifier.load_state_dict(
            checkpoint['sentiment_classifier_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        logging.info('Loaded checkpoint: ' + option.load_checkpoint)
    # load pretrained event model instead:
    elif option.pretrained_event_model != '':
        checkpoint = torch.load(option.pretrained_event_model)
        event_model.load_state_dict(checkpoint['model_state_dict'])
        logging.info('Loaded pretrained event model: ' +
                     option.pretrained_event_model)

    train_dataset = EventIntentSentimentDataset()
    logging.info('Loading train dataset: ' + option.train_dataset)
    train_dataset.load(option.train_dataset, glove)
    logging.info('Loaded train dataset: ' + option.train_dataset)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        collate_fn=EventIntentSentimentDataset_collate_fn,
        batch_size=option.batch_size,
        shuffle=True)

    if option.dev_dataset is not None:
        dev_dataset = EventIntentSentimentDataset()
        logging.info('Loading dev dataset: ' + option.dev_dataset)
        dev_dataset.load(option.dev_dataset, glove)
        logging.info('Loaded dev dataset: ' + option.dev_dataset)
        # Single full-size batch: the whole dev set is evaluated at once.
        dev_data_loader = torch.utils.data.DataLoader(
            dev_dataset,
            collate_fn=EventIntentSentimentDataset_collate_fn,
            batch_size=len(dev_dataset),
            shuffle=False)

    for epoch in range(option.epochs):
        epoch += 1
        logging.info('Epoch ' + str(epoch))

        # train set
        avg_loss_e = 0
        avg_loss_i = 0
        avg_loss_s = 0
        avg_loss = 0
        for i, batch in enumerate(train_data_loader, 1):
            optimizer.zero_grad()
            loss, loss_e, loss_i, loss_s = run_batch(
                option, batch, event_model, intent_model, event_scorer,
                intent_scorer, sentiment_classifier, criterion,
                sentiment_criterion)
            loss.backward()
            optimizer.step()
            avg_loss_e += loss_e.item() / option.report_every
            avg_loss_i += loss_i.item() / option.report_every
            avg_loss_s += loss_s.item() / option.report_every
            avg_loss += loss.item() / option.report_every
            if i % option.report_every == 0:
                logging.info(
                    'Batch %d, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                    % (i, avg_loss_e, avg_loss_i, avg_loss_s, avg_loss))
                avg_loss_e = 0
                avg_loss_i = 0
                avg_loss_s = 0
                avg_loss = 0

        # dev set
        if option.dev_dataset is not None:
            event_model.eval()
            intent_model.eval()
            event_scorer.eval()
            sentiment_classifier.eval()
            batch = next(iter(dev_data_loader))
            # Forward-only evaluation: no autograd graph needed.
            with torch.no_grad():
                loss, loss_e, loss_i, loss_s = run_batch(
                    option, batch, event_model, intent_model, event_scorer,
                    intent_scorer, sentiment_classifier, criterion,
                    sentiment_criterion)
            logging.info(
                'Eval on dev set, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                % (loss_e.item(), loss_i.item(), loss_s.item(), loss.item()))
            event_model.train()
            intent_model.train()
            event_scorer.train()
            sentiment_classifier.train()

        if option.save_checkpoint != '':
            checkpoint = {
                'event_model_state_dict': event_model.state_dict(),
                'intent_model_state_dict': intent_model.state_dict(),
                'event_scorer_state_dict': event_scorer.state_dict(),
                'sentiment_classifier_state_dict':
                sentiment_classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }
            torch.save(checkpoint, option.save_checkpoint + '_' + str(epoch))
            logging.info('Saved checkpoint: ' + option.save_checkpoint + '_' +
                         str(epoch))