# Standard-library and PyTorch imports used below. The project-local names
# (Glove, EventIntentSentimentDataset, YagoDataset, the event model classes,
# BiLSTMEncoder, MarginLoss, run_batch, run_yago_batch, and the collate
# functions) are assumed to be importable from sibling modules in this repo;
# their module paths are not shown here.
import itertools
import logging
import sys

import torch
import torch.nn as nn


def main(option):
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
    torch.manual_seed(option.random_seed)

    glove = Glove(option.emb_file)
    logging.info('Embeddings loaded')

    train_dataset = EventIntentSentimentDataset()
    logging.info('Loading train dataset: ' + option.train_dataset)
    train_dataset.load(option.train_dataset, glove)
    logging.info('Loaded train dataset: ' + option.train_dataset)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        collate_fn=EventIntentSentimentDataset_collate_fn,
        batch_size=option.batch_size,
        shuffle=True)

    if option.dev_dataset is not None:
        dev_dataset = EventIntentSentimentDataset()
        logging.info('Loading dev dataset: ' + option.dev_dataset)
        dev_dataset.load(option.dev_dataset, glove)
        logging.info('Loaded dev dataset: ' + option.dev_dataset)
        # the whole dev set is evaluated as a single batch
        dev_data_loader = torch.utils.data.DataLoader(
            dev_dataset,
            collate_fn=EventIntentSentimentDataset_collate_fn,
            batch_size=len(dev_dataset),
            shuffle=False)

    yago_train_dataset = YagoDataset()
    logging.info('Loading YAGO train dataset: ' + option.yago_train_dataset)
    yago_train_dataset.load(option.yago_train_dataset, glove)
    yago_train_data_loader = torch.utils.data.DataLoader(
        yago_train_dataset,
        collate_fn=YagoDataset_collate_fn,
        batch_size=option.batch_size,
        shuffle=True)

    if option.yago_dev_dataset is not None:
        yago_dev_dataset = YagoDataset()
        logging.info('Loading YAGO dev dataset: ' + option.yago_dev_dataset)
        yago_dev_dataset.load(option.yago_dev_dataset, glove)
        yago_dev_data_loader = torch.utils.data.DataLoader(
            yago_dev_dataset,
            collate_fn=YagoDataset_collate_fn,
            batch_size=len(yago_dev_dataset),
            shuffle=False)

    embeddings = nn.Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model_type == 'NTN':
        event_model = NeuralTensorNetwork(embeddings, option.em_k)
    elif option.model_type == 'RoleFactor':
        event_model = RoleFactoredTensorModel(embeddings, option.em_k)
    elif option.model_type == 'LowRankNTN':
        event_model = LowRankNeuralTensorNetwork(embeddings, option.em_k, option.em_r)
    else:
        logging.info('Unknown model type: ' + option.model_type)
        exit(1)
    intent_model = BiLSTMEncoder(embeddings, option.im_hidden_size, option.im_num_layers)

    # one relation-specific tensor network per YAGO relation
    relation_models = {}
    for relation in yago_train_dataset.relations:
        relation_models[relation] = NeuralTensorNetwork_yago(embeddings, option.em_k)

    if option.scorer_actv_func == 'sigmoid':
        scorer_actv_func = nn.Sigmoid
    elif option.scorer_actv_func == 'relu':
        scorer_actv_func = nn.ReLU
    elif option.scorer_actv_func == 'tanh':
        scorer_actv_func = nn.Tanh
    else:
        logging.info('Unknown scorer activation func: ' + option.scorer_actv_func)
        exit(1)
    event_scorer = nn.Sequential(nn.Linear(option.em_k, 1), scorer_actv_func())
    relation_scorer = nn.Sequential(nn.Linear(option.em_k, 1), scorer_actv_func())
    intent_scorer = nn.CosineSimilarity(dim=1)
    sentiment_classifier = nn.Linear(option.em_k, 1)

    criterion = MarginLoss(option.margin)
    sentiment_criterion = nn.BCEWithLogitsLoss()
    yago_criterion = MarginLoss(option.yago_margin)

    # load pretrained embeddings
    embeddings.weight.data.copy_(torch.from_numpy(glove.embd).float())
    if not option.update_embeddings:
        event_model.embeddings.weight.requires_grad = False

    if option.use_gpu:
        event_model.cuda()
        intent_model.cuda()
        sentiment_classifier.cuda()
        event_scorer.cuda()
        relation_scorer.cuda()
        for relation_model in relation_models.values():
            relation_model.cuda()

    # apply weight decay to every parameter group except the shared embeddings
    embeddings_param_id = [id(param) for param in embeddings.parameters()]
    params = [{
        'params': embeddings.parameters()
    }, {
        'params': [
            param for param in event_model.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }, {
        'params': [
            param for param in event_scorer.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }, {
        'params': [
            param for param in intent_model.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }, {
        'params': [
            param for param in sentiment_classifier.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }]
    for relation in relation_models:
        params.append({
            'params': [
                param for param in relation_models[relation].parameters()
                if id(param) not in embeddings_param_id
            ],
            'weight_decay': option.weight_decay
        })
    optimizer = torch.optim.Adagrad(params, lr=option.lr)

    # load checkpoint if provided:
    if option.load_checkpoint != '':
        checkpoint = torch.load(option.load_checkpoint)
        event_model.load_state_dict(checkpoint['event_model_state_dict'])
        intent_model.load_state_dict(checkpoint['intent_model_state_dict'])
        event_scorer.load_state_dict(checkpoint['event_scorer_state_dict'])
        sentiment_classifier.load_state_dict(checkpoint['sentiment_classifier_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        for relation in relation_models:
            relation_models[relation].load_state_dict(
                checkpoint['relation_model_state_dict'][relation])
        logging.info('Loaded checkpoint: ' + option.load_checkpoint)
    # load pretrained event model instead:
    elif option.pretrained_event_model != '':
        checkpoint = torch.load(option.pretrained_event_model)
        event_model.load_state_dict(checkpoint['model_state_dict'])
        logging.info('Loaded pretrained event model: ' + option.pretrained_event_model)

    for epoch in range(1, option.epochs + 1):
        logging.info('Epoch ' + str(epoch))

        # train set
        avg_loss_e = 0
        avg_loss_i = 0
        avg_loss_s = 0
        avg_loss = 0
        avg_loss_sum = 0
        avg_loss_event = 0
        avg_loss_attr = 0

        # assume the YAGO dataset is larger than the ATOMIC one: iterate over
        # YAGO batches and cycle the ATOMIC loader alongside
        atomic_iterator = itertools.cycle(iter(train_data_loader))
        yago_iterator = iter(yago_train_data_loader)
        # 1-based index so the first report covers a full averaging window
        for i, yago_batch in enumerate(yago_iterator, 1):
            atomic_batch = next(atomic_iterator)

            optimizer.zero_grad()
            loss, loss_e, loss_i, loss_s = run_batch(
                option, atomic_batch, event_model, intent_model, event_scorer,
                intent_scorer, sentiment_classifier, criterion,
                sentiment_criterion)
            loss.backward()
            optimizer.step()
            avg_loss_e += loss_e.item() / option.report_every
            avg_loss_i += loss_i.item() / option.report_every
            avg_loss_s += loss_s.item() / option.report_every
            avg_loss += loss.item() / option.report_every
            if i % option.report_every == 0:
                logging.info(
                    'Atomic batch %d, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                    % (i, avg_loss_e, avg_loss_i, avg_loss_s, avg_loss))
                avg_loss_e = 0
                avg_loss_i = 0
                avg_loss_s = 0
                avg_loss = 0

            optimizer.zero_grad()
            loss_sum, loss_event, loss_attr = run_yago_batch(
                option, yago_batch, event_model, relation_models,
                yago_criterion, relation_scorer, event_scorer)
            loss_sum.backward()
            optimizer.step()
            avg_loss_sum += loss_sum.item() / option.report_every
            avg_loss_event += loss_event.item() / option.report_every
            avg_loss_attr += loss_attr.item() / option.report_every
            if i % option.report_every == 0:
                logging.info(
                    'YAGO batch %d, loss_event=%.4f, loss_attr=%.4f, loss=%.4f'
                    % (i, avg_loss_event, avg_loss_attr, avg_loss_sum))
                avg_loss_sum = 0
                avg_loss_event = 0
                avg_loss_attr = 0

        # dev set
        if option.dev_dataset is not None:
            event_model.eval()
            intent_model.eval()
            event_scorer.eval()
            sentiment_classifier.eval()
            batch = next(iter(dev_data_loader))
            with torch.no_grad():
                loss, loss_e, loss_i, loss_s = run_batch(
                    option, batch, event_model, intent_model, event_scorer,
                    intent_scorer, sentiment_classifier, criterion,
                    sentiment_criterion)
            logging.info(
                'Eval on dev set, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                % (loss_e.item(), loss_i.item(), loss_s.item(), loss.item()))
            event_model.train()
            intent_model.train()
            event_scorer.train()
            sentiment_classifier.train()

        # dev set (yago)
        if option.yago_dev_dataset is not None:
            for key in relation_models:
                relation_models[key].eval()
            relation_scorer.eval()
            event_model.eval()
            yago_dev_batch = next(iter(yago_dev_data_loader))
            with torch.no_grad():
                loss_sum, loss_event, loss_attr = run_yago_batch(
                    option, yago_dev_batch, event_model, relation_models,
                    yago_criterion, relation_scorer, event_scorer)
            logging.info(
                'Eval on yago dev set, loss_sum=%.4f, loss_event=%.4f, loss_attr=%.4f'
                % (loss_sum.item(), loss_event.item(), loss_attr.item()))
            for key in relation_models:
                relation_models[key].train()
            relation_scorer.train()
            event_model.train()

        if option.save_checkpoint != '':
            checkpoint = {
                'event_model_state_dict': event_model.state_dict(),
                'intent_model_state_dict': intent_model.state_dict(),
                'event_scorer_state_dict': event_scorer.state_dict(),
                'sentiment_classifier_state_dict': sentiment_classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'relation_model_state_dict': {
                    relation: relation_models[relation].state_dict()
                    for relation in relation_models
                }
            }
            torch.save(checkpoint, option.save_checkpoint + '_' + str(epoch))
            logging.info('Saved checkpoint: ' + option.save_checkpoint + '_' + str(epoch))
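# `Glove` is defined elsewhere in the repo. Based only on how it is used above
# (constructed from an embedding file, passed to the dataset loaders, and
# exposing a numpy matrix `embd` that fills nn.Embedding), a hypothetical
# stand-in could look like the sketch below; the real class likely also
# reserves special indices (e.g. the padding index 1 used above).
import numpy as np

class GloveSketch:
    """Hypothetical stand-in for the repo's Glove class: parses a GloVe-format
    text file (token followed by floats on each line) into a word->index map
    and an embedding matrix `embd`."""

    def __init__(self, emb_file):
        self.word2id = {}
        vectors = []
        with open(emb_file, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.rstrip().split(' ')
                self.word2id[parts[0]] = len(vectors)
                vectors.append([float(x) for x in parts[1:]])
        self.embd = np.array(vectors, dtype=np.float32)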
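# `MarginLoss` is also defined elsewhere in the repo. Given that it is
# constructed with a margin and used as the ranking criterion inside
# run_batch / run_yago_batch, a minimal max-margin sketch (the forward
# signature is an assumption, not the repo's actual API) might be:
class MarginLossSketch(nn.Module):
    """Hypothetical max-margin ranking loss: pushes positive scores above
    negative scores by at least `margin`."""

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, pos_score, neg_score):
        # hinge on the margin violation, averaged over the batch
        return torch.clamp(self.margin - pos_score + neg_score, min=0).mean()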
# Variant of main() without the YAGO multi-task objective; in the repo this
# appears to be a separate, single-task training script.
def main(option):
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
    torch.manual_seed(option.random_seed)

    glove = Glove(option.emb_file)
    logging.info('Embeddings loaded')

    embeddings = nn.Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model_type == 'NTN':
        event_model = NeuralTensorNetwork(embeddings, option.em_k)
    elif option.model_type == 'RoleFactor':
        event_model = RoleFactoredTensorModel(embeddings, option.em_k)
    elif option.model_type == 'LowRankNTN':
        event_model = LowRankNeuralTensorNetwork(embeddings, option.em_k, option.em_r)
    else:
        logging.info('Unknown model type: ' + option.model_type)
        exit(1)
    intent_model = BiLSTMEncoder(embeddings, option.im_hidden_size, option.im_num_layers)

    if option.em_actv_func == 'sigmoid':
        em_actv_func = nn.Sigmoid()
    elif option.em_actv_func == 'relu':
        em_actv_func = nn.ReLU()
    elif option.em_actv_func == 'tanh':
        em_actv_func = nn.Tanh()
    else:
        logging.info('Unknown event activation func: ' + option.em_actv_func)
        exit(1)
    event_scorer = nn.Sequential(nn.Linear(option.em_k, 1), em_actv_func)
    intent_scorer = nn.CosineSimilarity(dim=1)
    sentiment_classifier = nn.Linear(option.em_k, 1)

    criterion = MarginLoss(option.margin)
    sentiment_criterion = nn.BCEWithLogitsLoss()

    # load pretrained embeddings
    embeddings.weight.data.copy_(torch.from_numpy(glove.embd).float())
    if not option.update_embeddings:
        event_model.embeddings.weight.requires_grad = False

    if option.use_gpu:
        event_model.cuda()
        intent_model.cuda()
        sentiment_classifier.cuda()
        event_scorer.cuda()

    # apply weight decay to every parameter group except the shared embeddings
    embeddings_param_id = [id(param) for param in embeddings.parameters()]
    params = [{
        'params': embeddings.parameters()
    }, {
        'params': [
            param for param in event_model.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }, {
        'params': [
            param for param in event_scorer.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }, {
        'params': [
            param for param in intent_model.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }, {
        'params': [
            param for param in sentiment_classifier.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay': option.weight_decay
    }]
    optimizer = torch.optim.Adagrad(params, lr=option.lr)

    # load checkpoint if provided:
    if option.load_checkpoint != '':
        checkpoint = torch.load(option.load_checkpoint)
        event_model.load_state_dict(checkpoint['event_model_state_dict'])
        intent_model.load_state_dict(checkpoint['intent_model_state_dict'])
        event_scorer.load_state_dict(checkpoint['event_scorer_state_dict'])
        sentiment_classifier.load_state_dict(checkpoint['sentiment_classifier_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        logging.info('Loaded checkpoint: ' + option.load_checkpoint)
    # load pretrained event model instead:
    elif option.pretrained_event_model != '':
        checkpoint = torch.load(option.pretrained_event_model)
        event_model.load_state_dict(checkpoint['model_state_dict'])
        logging.info('Loaded pretrained event model: ' + option.pretrained_event_model)

    train_dataset = EventIntentSentimentDataset()
    logging.info('Loading train dataset: ' + option.train_dataset)
    train_dataset.load(option.train_dataset, glove)
    logging.info('Loaded train dataset: ' + option.train_dataset)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        collate_fn=EventIntentSentimentDataset_collate_fn,
        batch_size=option.batch_size,
        shuffle=True)

    if option.dev_dataset is not None:
        dev_dataset = EventIntentSentimentDataset()
        logging.info('Loading dev dataset: ' + option.dev_dataset)
        dev_dataset.load(option.dev_dataset, glove)
        logging.info('Loaded dev dataset: ' + option.dev_dataset)
        # the whole dev set is evaluated as a single batch
        dev_data_loader = torch.utils.data.DataLoader(
            dev_dataset,
            collate_fn=EventIntentSentimentDataset_collate_fn,
            batch_size=len(dev_dataset),
            shuffle=False)

    for epoch in range(1, option.epochs + 1):
        logging.info('Epoch ' + str(epoch))

        # train set
        avg_loss_e = 0
        avg_loss_i = 0
        avg_loss_s = 0
        avg_loss = 0
        for i, batch in enumerate(train_data_loader, 1):
            optimizer.zero_grad()
            loss, loss_e, loss_i, loss_s = run_batch(
                option, batch, event_model, intent_model, event_scorer,
                intent_scorer, sentiment_classifier, criterion,
                sentiment_criterion)
            loss.backward()
            optimizer.step()
            avg_loss_e += loss_e.item() / option.report_every
            avg_loss_i += loss_i.item() / option.report_every
            avg_loss_s += loss_s.item() / option.report_every
            avg_loss += loss.item() / option.report_every
            if i % option.report_every == 0:
                logging.info(
                    'Batch %d, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                    % (i, avg_loss_e, avg_loss_i, avg_loss_s, avg_loss))
                avg_loss_e = 0
                avg_loss_i = 0
                avg_loss_s = 0
                avg_loss = 0

        # dev set
        if option.dev_dataset is not None:
            event_model.eval()
            intent_model.eval()
            event_scorer.eval()
            sentiment_classifier.eval()
            batch = next(iter(dev_data_loader))
            with torch.no_grad():
                loss, loss_e, loss_i, loss_s = run_batch(
                    option, batch, event_model, intent_model, event_scorer,
                    intent_scorer, sentiment_classifier, criterion,
                    sentiment_criterion)
            logging.info(
                'Eval on dev set, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                % (loss_e.item(), loss_i.item(), loss_s.item(), loss.item()))
            event_model.train()
            intent_model.train()
            event_scorer.train()
            sentiment_classifier.train()

        if option.save_checkpoint != '':
            checkpoint = {
                'event_model_state_dict': event_model.state_dict(),
                'intent_model_state_dict': intent_model.state_dict(),
                'event_scorer_state_dict': event_scorer.state_dict(),
                'sentiment_classifier_state_dict': sentiment_classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }
            torch.save(checkpoint, option.save_checkpoint + '_' + str(epoch))
            logging.info('Saved checkpoint: ' + option.save_checkpoint + '_' + str(epoch))
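# Neither script shows how `option` is constructed. Every attribute accessed
# above is collected here, so a plausible argparse entry point would look like
# the sketch below; the defaults are guesses, not the repo's actual values.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--random_seed', type=int, default=1)
    parser.add_argument('--emb_file', type=str, required=True)
    parser.add_argument('--vocab_size', type=int, required=True)
    parser.add_argument('--emb_dim', type=int, default=100)
    parser.add_argument('--model_type', type=str, default='NTN',
                        choices=['NTN', 'RoleFactor', 'LowRankNTN'])
    parser.add_argument('--em_k', type=int, default=100)
    parser.add_argument('--em_r', type=int, default=10)
    parser.add_argument('--em_actv_func', type=str, default='tanh')
    parser.add_argument('--scorer_actv_func', type=str, default='sigmoid')
    parser.add_argument('--im_hidden_size', type=int, default=128)
    parser.add_argument('--im_num_layers', type=int, default=1)
    parser.add_argument('--margin', type=float, default=0.5)
    parser.add_argument('--yago_margin', type=float, default=0.5)
    parser.add_argument('--train_dataset', type=str, required=True)
    parser.add_argument('--dev_dataset', type=str, default=None)
    parser.add_argument('--yago_train_dataset', type=str, default=None)
    parser.add_argument('--yago_dev_dataset', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--update_embeddings', action='store_true')
    parser.add_argument('--use_gpu', action='store_true')
    parser.add_argument('--report_every', type=int, default=50)
    parser.add_argument('--load_checkpoint', type=str, default='')
    parser.add_argument('--pretrained_event_model', type=str, default='')
    parser.add_argument('--save_checkpoint', type=str, default='')
    option = parser.parse_args()
    main(option)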