Exemple #1
0
def main(option):
    """Embed events from text files with a pretrained event model and save them.

    Every file in ``option.input_files`` is read line by line. A line is
    expected to have the form ``subject | verb | object``; it is lower-cased
    and each part is split on single spaces before being passed to
    ``glove.transform``. The position of a file in ``option.input_files`` is
    stored as the label of the events it contributes. Events for which any
    ``glove.transform`` call returns a ``None`` id, and events whose text was
    already seen, are skipped. Blank lines are skipped as well.

    The result is written with ``torch.save`` to ``option.output_file`` as a
    dict with the keys ``'embeddings'``, ``'labels'`` and ``'event_texts'``.

    Args:
        option: Namespace-like object. Attributes read here: ``emb_file``,
            ``vocab_size``, ``emb_dim``, ``model_type``, ``em_k``, ``em_r``
            (only for ``'LowRankNTN'``), ``model_file``, ``device``,
            ``input_files`` and ``output_file``.

    Raises:
        ValueError: If ``option.model_type`` is neither ``'NTN'`` nor
            ``'LowRankNTN'``, or if a non-blank line does not split into
            exactly three ``' | '``-separated parts.
    """
    glove = Glove(option.emb_file)
    print(option.emb_file + ' loaded')

    embedding = Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model_type == 'NTN':
        model = NeuralTensorNetwork(embedding, option.em_k)
    elif option.model_type == 'LowRankNTN':
        model = LowRankNeuralTensorNetwork(embedding, option.em_k, option.em_r)
    else:
        # Fail here with a clear message instead of an UnboundLocalError later.
        raise ValueError('Unknown model type: ' + str(option.model_type))

    checkpoint = torch.load(option.model_file, map_location='cpu')
    # Exact type check on purpose: a bare state_dict is usually an OrderedDict
    # (a dict subclass) and must fall through to the last branch.
    if type(checkpoint) is dict:
        if 'event_model_state_dict' in checkpoint:
            state_dict = checkpoint['event_model_state_dict']
        else:
            state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    print(option.model_file + ' loaded')
    model.eval()
    model.to(option.device)

    all_subj_id = []
    all_subj_w = []
    all_verb_id = []
    all_verb_w = []
    all_obj_id = []
    all_obj_w = []
    all_labels = []
    all_event_texts = []
    # Set mirror of all_event_texts: O(1) duplicate checks while the list
    # keeps the insertion order that is saved to disk.
    seen_event_texts = set()
    for label, filename in enumerate(option.input_files):
        print('loading ' + filename)
        with open(filename, 'r') as input_file:
            for line in input_file:
                line = line.lower().strip()
                if not line:
                    continue
                subj, verb, obj = line.split(' | ')
                event_text = '(' + subj + ', ' + verb + ', ' + obj + ')'
                subj_id, subj_w = glove.transform(subj.split(' '), 10)
                verb_id, verb_w = glove.transform(verb.split(' '), 10)
                obj_id, obj_w = glove.transform(obj.split(' '), 10)
                if subj_id is None or verb_id is None or obj_id is None:
                    continue
                if event_text in seen_event_texts:
                    continue
                seen_event_texts.add(event_text)
                all_subj_id.append(subj_id)
                all_subj_w.append(subj_w)
                all_verb_id.append(verb_id)
                all_verb_w.append(verb_w)
                all_obj_id.append(obj_id)
                all_obj_w.append(obj_w)
                all_labels.append(label)
                all_event_texts.append(event_text)

    all_subj_id = torch.tensor(all_subj_id,
                               dtype=torch.long,
                               device=option.device)
    all_subj_w = torch.tensor(all_subj_w,
                              dtype=torch.float,
                              device=option.device)
    all_verb_id = torch.tensor(all_verb_id,
                               dtype=torch.long,
                               device=option.device)
    all_verb_w = torch.tensor(all_verb_w,
                              dtype=torch.float,
                              device=option.device)
    all_obj_id = torch.tensor(all_obj_id,
                              dtype=torch.long,
                              device=option.device)
    all_obj_w = torch.tensor(all_obj_w,
                             dtype=torch.float,
                             device=option.device)
    # Inference only: do not record an autograd graph for the whole input.
    with torch.no_grad():
        all_event_embeddings = model(all_subj_id, all_subj_w, all_verb_id,
                                     all_verb_w, all_obj_id,
                                     all_obj_w).detach().cpu()

    torch.save(
        {
            'embeddings': all_event_embeddings,
            'labels': torch.tensor(all_labels, dtype=torch.long),
            'event_texts': all_event_texts
        }, option.output_file)
    print('saved to ' + option.output_file)
        'weight_decay':
        option.weight_decay
    }, {
        'params': [
            param for param in sentiment_classifier.parameters()
            if id(param) not in embeddings_param_id
        ],
        'weight_decay':
        option.weight_decay
    }]
    optimizer = torch.optim.Adagrad(params, lr=option.lr)

    # load checkpoint if provided:
    if option.load_checkpoint != '':
        checkpoint = torch.load(option.load_checkpoint)
        event_model.load_state_dict(checkpoint['event_model_state_dict'])
        intent_model.load_state_dict(checkpoint['intent_model_state_dict'])
        event_scorer.load_state_dict(checkpoint['event_scorer_state_dict'])
        sentiment_classifier.load_state_dict(
            checkpoint['sentiment_classifier_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        logging.info('Loaded checkpoint: ' + option.load_checkpoint)
    # load pretrained event model instead:
    elif option.pretrained_event_model != '':
        checkpoint = torch.load(option.pretrained_event_model)
        event_model.load_state_dict(checkpoint['model_state_dict'])
        logging.info('Loaded pretrained event model: ' +
                     option.pretrained_event_model)

    train_dataset = EventIntentSentimentDataset()
    logging.info('Loading train dataset: ' + option.train_dataset)
Exemple #3
0
def main(option):
    """Jointly train the event/intent/sentiment models and YAGO relation models.

    Each epoch iterates once over the YAGO train loader; for every YAGO batch
    one batch of the event-intent-sentiment ("atomic") train loader is also
    consumed (that loader is cycled). Each of the two batches gets its own
    ``zero_grad`` / ``backward`` / ``step``. After the training pass the
    optional dev sets are evaluated as a single batch each, and a checkpoint
    is written when ``option.save_checkpoint`` is non-empty.

    Args:
        option: Namespace-like object. Attributes read here: ``random_seed``,
            ``emb_file``, ``train_dataset``, ``dev_dataset``,
            ``yago_train_dataset``, ``yago_dev_dataset``, ``batch_size``,
            ``vocab_size``, ``emb_dim``, ``model_type``, ``em_k``, ``em_r``
            (only for ``'LowRankNTN'``), ``im_hidden_size``,
            ``im_num_layers``, ``scorer_actv_func``, ``margin``,
            ``yago_margin``, ``update_embeddings``, ``use_gpu``,
            ``weight_decay``, ``lr``, ``load_checkpoint``,
            ``pretrained_event_model``, ``epochs``, ``report_every`` and
            ``save_checkpoint``. It is also forwarded to ``run_batch`` and
            ``run_yago_batch``.

    Raises:
        SystemExit: If ``option.model_type`` or ``option.scorer_actv_func``
            is not one of the supported values.
    """
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format=
        '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )
    torch.manual_seed(option.random_seed)

    glove = Glove(option.emb_file)
    logging.info('Embeddings loaded')

    train_dataset = EventIntentSentimentDataset()
    logging.info('Loading train dataset: ' + option.train_dataset)
    train_dataset.load(option.train_dataset, glove)
    logging.info('Loaded train dataset: ' + option.train_dataset)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        collate_fn=EventIntentSentimentDataset_collate_fn,
        batch_size=option.batch_size,
        shuffle=True)

    if option.dev_dataset is not None:
        dev_dataset = EventIntentSentimentDataset()
        logging.info('Loading dev dataset: ' + option.dev_dataset)
        dev_dataset.load(option.dev_dataset, glove)
        logging.info('Loaded dev dataset: ' + option.dev_dataset)
        # The whole dev set is evaluated as one batch.
        dev_data_loader = torch.utils.data.DataLoader(
            dev_dataset,
            collate_fn=EventIntentSentimentDataset_collate_fn,
            batch_size=len(dev_dataset),
            shuffle=False)

    yago_train_dataset = YagoDataset()
    logging.info('Loading YAGO train dataset: ' + option.yago_train_dataset)
    yago_train_dataset.load(option.yago_train_dataset, glove)
    yago_train_data_loader = torch.utils.data.DataLoader(
        yago_train_dataset,
        collate_fn=YagoDataset_collate_fn,
        batch_size=option.batch_size,
        shuffle=True)

    if option.yago_dev_dataset is not None:
        yago_dev_dataset = YagoDataset()
        logging.info('Loading YAGO dev dataset: ' + option.yago_dev_dataset)
        yago_dev_dataset.load(option.yago_dev_dataset, glove)
        yago_dev_data_loader = torch.utils.data.DataLoader(
            yago_dev_dataset,
            collate_fn=YagoDataset_collate_fn,
            batch_size=len(yago_dev_dataset),
            shuffle=False)

    # One embedding table is shared by every model built below.
    embeddings = nn.Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model_type == 'NTN':
        event_model = NeuralTensorNetwork(embeddings, option.em_k)
    elif option.model_type == 'RoleFactor':
        event_model = RoleFactoredTensorModel(embeddings, option.em_k)
    elif option.model_type == 'LowRankNTN':
        event_model = LowRankNeuralTensorNetwork(embeddings, option.em_k,
                                                 option.em_r)
    else:
        # Previously fell through and crashed later on an unbound name.
        logging.error('Unknown model type: ' + str(option.model_type))
        sys.exit(1)

    intent_model = BiLSTMEncoder(embeddings, option.im_hidden_size,
                                 option.im_num_layers)

    relation_models = {}
    for relation in yago_train_dataset.relations:
        relation_models[relation] = NeuralTensorNetwork_yago(
            embeddings, option.em_k)

    scorer_actv_funcs = {
        'sigmoid': nn.Sigmoid,
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
    }
    if option.scorer_actv_func not in scorer_actv_funcs:
        logging.error('Unknown scorer activation func: ' +
                      str(option.scorer_actv_func))
        sys.exit(1)
    scorer_actv_func = scorer_actv_funcs[option.scorer_actv_func]
    event_scorer = nn.Sequential(nn.Linear(option.em_k, 1), scorer_actv_func())
    relation_scorer = nn.Sequential(nn.Linear(option.em_k, 1),
                                    scorer_actv_func())

    intent_scorer = nn.CosineSimilarity(dim=1)
    sentiment_classifier = nn.Linear(option.em_k, 1)
    criterion = MarginLoss(option.margin)
    sentiment_criterion = nn.BCEWithLogitsLoss()
    yago_criterion = MarginLoss(option.yago_margin)

    # load pretrained embeddings
    embeddings.weight.data.copy_(torch.from_numpy(glove.embd).float())

    if not option.update_embeddings:
        event_model.embeddings.weight.requires_grad = False

    if option.use_gpu:
        event_model.cuda()
        intent_model.cuda()
        sentiment_classifier.cuda()
        event_scorer.cuda()
        relation_scorer.cuda()
        for relation_model in relation_models.values():
            relation_model.cuda()

    # The shared embedding gets its own parameter group (no weight decay
    # given); every other module contributes one group with weight decay.
    # The group order is kept stable because optimizer state dicts stored in
    # checkpoints depend on it.
    # NOTE(review): relation_scorer is used by run_yago_batch but is neither
    # optimized nor checkpointed here. Adding it would change the optimizer
    # layout of existing checkpoints, so it is left as-is; confirm intent.
    embeddings_param_id = {id(param) for param in embeddings.parameters()}
    decayed_modules = [
        event_model, event_scorer, intent_model, sentiment_classifier
    ]
    decayed_modules.extend(relation_models.values())
    params = [{'params': embeddings.parameters()}]
    for module in decayed_modules:
        params.append({
            'params': [
                param for param in module.parameters()
                if id(param) not in embeddings_param_id
            ],
            'weight_decay':
            option.weight_decay
        })
    optimizer = torch.optim.Adagrad(params, lr=option.lr)

    # load checkpoint if provided:
    if option.load_checkpoint != '':
        checkpoint = torch.load(option.load_checkpoint)
        event_model.load_state_dict(checkpoint['event_model_state_dict'])
        intent_model.load_state_dict(checkpoint['intent_model_state_dict'])
        event_scorer.load_state_dict(checkpoint['event_scorer_state_dict'])
        sentiment_classifier.load_state_dict(
            checkpoint['sentiment_classifier_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        for relation in relation_models:
            relation_models[relation].load_state_dict(
                checkpoint['relation_model_state_dict'][relation])
        logging.info('Loaded checkpoint: ' + option.load_checkpoint)
    # load pretrained event model instead:
    elif option.pretrained_event_model != '':
        checkpoint = torch.load(option.pretrained_event_model)
        event_model.load_state_dict(checkpoint['model_state_dict'])
        logging.info('Loaded pretrained event model: ' +
                     option.pretrained_event_model)

    for epoch in range(1, option.epochs + 1):
        logging.info('Epoch ' + str(epoch))

        # train set
        avg_loss_e = 0
        avg_loss_i = 0
        avg_loss_s = 0
        avg_loss = 0
        avg_loss_sum = 0
        avg_loss_event = 0
        avg_loss_attr = 0
        # assume yago dataset is larger than atomic
        atomic_iterator = itertools.cycle(iter(train_data_loader))
        # iterate over yago dataset (atomic dataset is cycled).
        # Batches are counted from 1 so that every report covers exactly
        # `report_every` batches; a 0-based counter reported after the very
        # first batch with a sum that was still divided by `report_every`.
        for i, yago_batch in enumerate(yago_train_data_loader, start=1):
            atomic_batch = next(atomic_iterator)

            optimizer.zero_grad()
            loss, loss_e, loss_i, loss_s = run_batch(
                option, atomic_batch, event_model, intent_model, event_scorer,
                intent_scorer, sentiment_classifier, criterion,
                sentiment_criterion)
            loss.backward()
            optimizer.step()
            avg_loss_e += loss_e.item() / option.report_every
            avg_loss_i += loss_i.item() / option.report_every
            avg_loss_s += loss_s.item() / option.report_every
            avg_loss += loss.item() / option.report_every
            if i % option.report_every == 0:
                logging.info(
                    'Atomic batch %d, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                    % (i, avg_loss_e, avg_loss_i, avg_loss_s, avg_loss))
                avg_loss_e = 0
                avg_loss_i = 0
                avg_loss_s = 0
                avg_loss = 0

            optimizer.zero_grad()
            loss_sum, loss_event, loss_attr = run_yago_batch(
                option, yago_batch, event_model, relation_models,
                yago_criterion, relation_scorer, event_scorer)
            loss_sum.backward()
            optimizer.step()
            avg_loss_sum += loss_sum.item() / option.report_every
            avg_loss_event += loss_event.item() / option.report_every
            avg_loss_attr += loss_attr.item() / option.report_every
            if i % option.report_every == 0:
                logging.info(
                    'YAGO batch %d, loss_event=%.4f, loss_attr=%.4f, loss=%.4f'
                    % (i, avg_loss_event, avg_loss_attr, avg_loss_sum))
                avg_loss_sum = 0
                avg_loss_event = 0
                avg_loss_attr = 0

        # dev set
        if option.dev_dataset is not None:
            event_model.eval()
            intent_model.eval()
            event_scorer.eval()
            sentiment_classifier.eval()
            batch = next(iter(dev_data_loader))
            with torch.no_grad():
                loss, loss_e, loss_i, loss_s = run_batch(
                    option, batch, event_model, intent_model, event_scorer,
                    intent_scorer, sentiment_classifier, criterion,
                    sentiment_criterion)
            logging.info(
                'Eval on dev set, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                % (loss_e.item(), loss_i.item(), loss_s.item(), loss.item()))
            event_model.train()
            intent_model.train()
            event_scorer.train()
            sentiment_classifier.train()

        # dev set (yago)
        if option.yago_dev_dataset is not None:
            for relation_model in relation_models.values():
                relation_model.eval()
            relation_scorer.eval()
            event_model.eval()
            yago_dev_batch = next(iter(yago_dev_data_loader))
            with torch.no_grad():
                # Uses the YAGO margin loss defined above; the previous name
                # `criterion_yago` was never defined and raised NameError.
                loss_sum, loss_event, loss_attr = run_yago_batch(
                    option, yago_dev_batch, event_model, relation_models,
                    yago_criterion, relation_scorer, event_scorer)
            logging.info(
                'Eval on yago dev set, loss_sum=%.4f, loss_event=%.4f, loss_attr=%.4f, '
                % (loss_sum.item(), loss_event.item(), loss_attr.item()))
            for relation_model in relation_models.values():
                relation_model.train()
            relation_scorer.train()
            # event_model was switched to eval above and must be restored,
            # otherwise the next epoch trains it in eval mode.
            event_model.train()

        if option.save_checkpoint != '':
            checkpoint = {
                'event_model_state_dict':
                event_model.state_dict(),
                'intent_model_state_dict':
                intent_model.state_dict(),
                'event_scorer_state_dict':
                event_scorer.state_dict(),
                'sentiment_classifier_state_dict':
                sentiment_classifier.state_dict(),
                'optimizer_state_dict':
                optimizer.state_dict(),
                'relation_model_state_dict': {
                    relation: relation_models[relation].state_dict()
                    for relation in relation_models
                }
            }
            torch.save(checkpoint, option.save_checkpoint + '_' + str(epoch))
            logging.info('Saved checkpoint: ' + option.save_checkpoint + '_' +
                         str(epoch))
        model = NN(embeddings, 2 * option.em_k, option.em_k)
    elif option.model == 'EMC':
        model = EMC(embeddings, 2 * option.em_k, option.em_k)
    else:
        logging.info('Unknown model type: ' + option.model)
        exit(1)

    checkpoint = torch.load(option.model_file)
    if type(checkpoint) == dict:
        if 'event_model_state_dict' in checkpoint:
            state_dict = checkpoint['event_model_state_dict']
        else:
            state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    logging.info(option.model_file + ' loaded')

    # embeddings = nn.Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    # embeddings.weight.data = torch.from_numpy(glove.embd).float()
    # model = Averaging(embeddings)

    if option.use_gpu:
        model.cuda()
    model.eval()

    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=TransitiveSentenceSimilarityDataset_collate_fn,
        shuffle=False,
        batch_size=len(dataset))
Exemple #5
0
def main(option):
    """Evaluate a pretrained event model on the hard-similarity dataset.

    The whole dataset is loaded as one batch of 24 tensors: six inputs
    (subject/verb/object ids and weights) for each of the four events
    ``pos_e1``, ``pos_e2``, ``neg_e1`` and ``neg_e2``. An item counts as
    correct when the distance between the positive pair is smaller than the
    distance between the negative pair. Accuracy is logged; per-item
    distances and event vectors are optionally written to disk.

    Args:
        option: Namespace-like object. Attributes read here:
            ``dataset_cache``, ``emb_file``, ``dataset_file``,
            ``vocab_size``, ``emb_dim``, ``model``, ``em_k``, ``em_r`` (only
            for ``'LowRankNTN'``), ``model_file``, ``use_gpu``,
            ``distance_metric``, ``output_file`` and ``output_vectors``.

    Raises:
        SystemExit: If ``option.model`` or ``option.distance_metric`` is not
            a supported value.
        ValueError: If the collated batch does not hold exactly 24 items.
    """
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format=
        '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )

    dataset = HardSimilarityDataset()
    if option.dataset_cache is None:
        glove = Glove(option.emb_file)
        logging.info('Embeddings loaded')
        dataset.load(option.dataset_file, glove)
    else:
        dataset.load_cache(option.dataset_cache)
    logging.info('Dataset loaded')

    embeddings = nn.Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model == 'NTN':
        model = NeuralTensorNetwork(embeddings, option.em_k)
    elif option.model == 'LowRankNTN':
        model = LowRankNeuralTensorNetwork(embeddings, option.em_k,
                                           option.em_r)
    elif option.model == 'RoleFactor':
        model = RoleFactoredTensorModel(embeddings, option.em_k)
    elif option.model == 'Predicate':
        model = PredicateTensorModel(embeddings)
    elif option.model == 'NN':
        model = NN(embeddings, 2 * option.em_k, option.em_k)
    elif option.model == 'EMC':
        model = EMC(embeddings, 2 * option.em_k, option.em_k)
    else:
        logging.error('Unknown model type: ' + str(option.model))
        sys.exit(1)

    # Validate the metric before any expensive work; an unknown value used to
    # surface only after the forward pass, as an unbound-name error.
    if option.distance_metric == 'cosine':
        distance_func = cosine_distance
    elif option.distance_metric == 'euclid':
        distance_func = euclid_distance
    else:
        logging.error('Unknown distance metric: ' +
                      str(option.distance_metric))
        sys.exit(1)

    checkpoint = torch.load(option.model_file, map_location='cpu')
    # Exact type check on purpose: a bare state_dict is usually an OrderedDict
    # (a dict subclass) and must fall through to the last branch.
    if type(checkpoint) is dict:
        if 'event_model_state_dict' in checkpoint:
            state_dict = checkpoint['event_model_state_dict']
        else:
            state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    logging.info(option.model_file + ' loaded')

    if option.use_gpu:
        model.cuda()
    model.eval()

    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=HardSimilarityDataset_collate_fn,
        shuffle=False,
        batch_size=len(dataset))
    batch = tuple(next(iter(data_loader)))
    if len(batch) != 24:
        raise ValueError('expected a batch of 24 tensors, got %d' %
                         (len(batch), ))
    if option.use_gpu:
        batch = tuple(tensor.cuda() for tensor in batch)
    # Six consecutive tensors per event, in the order
    # subj_id, subj_w, verb_id, verb_w, obj_id, obj_w.
    pos_e1, pos_e2, neg_e1, neg_e2 = (batch[start:start + 6]
                                      for start in range(0, 24, 6))

    # Inference only: skip autograd bookkeeping for the full-dataset batch.
    with torch.no_grad():
        pos_e1_emb = model(*pos_e1)
        pos_e2_emb = model(*pos_e2)
        neg_e1_emb = model(*neg_e1)
        neg_e2_emb = model(*neg_e2)
        pos_dist = distance_func(pos_e1_emb, pos_e2_emb)
        neg_dist = distance_func(neg_e1_emb, neg_e2_emb)
    is_correct = pos_dist < neg_dist
    num_correct = is_correct.sum().item()
    accuracy = num_correct / len(dataset)

    if option.output_file.strip() != '':
        # One line per item: positive distance, negative distance, verdict.
        with open(option.output_file, 'w') as output_file:
            for pos, neg, correct in zip(pos_dist, neg_dist, is_correct):
                output_file.write(' '.join(
                    [str(pos.item()),
                     str(neg.item()),
                     str(correct.item())]) + '\n')
        logging.info('Output saved to ' + option.output_file)

    logging.info('Num correct: %d' % (num_correct, ))
    logging.info('Num total: %d' % (len(dataset), ))
    logging.info('Accuracy: %.4f' % (accuracy, ))

    if option.output_vectors is not None:
        vectors = torch.stack([pos_e1_emb, pos_e2_emb, neg_e1_emb, neg_e2_emb],
                              dim=1).cpu()
        torch.save(vectors, option.output_vectors)
        print('Vectors saved to %s' % (option.output_vectors, ))
Exemple #6
0
def main(option):
    """Train the event, intent and sentiment models on one dataset.

    Builds an event model, a BiLSTM intent encoder, an event scorer and a
    sentiment classifier on top of one shared embedding table, optionally
    restores a checkpoint or a pretrained event model, then trains with
    Adagrad for ``option.epochs`` epochs. After each epoch the optional dev
    set is evaluated as a single batch and a checkpoint is written when
    ``option.save_checkpoint`` is non-empty.

    Args:
        option: Namespace-like object. Attributes read here: ``random_seed``,
            ``emb_file``, ``vocab_size``, ``emb_dim``, ``model_type``,
            ``em_k``, ``em_r`` (only for ``'LowRankNTN'``),
            ``im_hidden_size``, ``im_num_layers``, ``em_actv_func``,
            ``margin``, ``update_embeddings``, ``use_gpu``, ``weight_decay``,
            ``lr``, ``load_checkpoint``, ``pretrained_event_model``,
            ``train_dataset``, ``dev_dataset``, ``batch_size``, ``epochs``,
            ``report_every`` and ``save_checkpoint``. It is also forwarded to
            ``run_batch``.

    Raises:
        SystemExit: If ``option.model_type`` or ``option.em_actv_func`` is
            not one of the supported values.
    """
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format=
        '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )
    torch.manual_seed(option.random_seed)

    glove = Glove(option.emb_file)
    logging.info('Embeddings loaded')

    # One embedding table is shared by the event and intent models.
    embeddings = nn.Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model_type == 'NTN':
        event_model = NeuralTensorNetwork(embeddings, option.em_k)
    elif option.model_type == 'RoleFactor':
        event_model = RoleFactoredTensorModel(embeddings, option.em_k)
    elif option.model_type == 'LowRankNTN':
        event_model = LowRankNeuralTensorNetwork(embeddings, option.em_k,
                                                 option.em_r)
    else:
        # Previously fell through and crashed later on an unbound name.
        logging.error('Unknown model type: ' + str(option.model_type))
        sys.exit(1)

    intent_model = BiLSTMEncoder(embeddings, option.im_hidden_size,
                                 option.im_num_layers)

    if option.em_actv_func == 'sigmoid':
        em_actv_func = nn.Sigmoid()
    elif option.em_actv_func == 'relu':
        em_actv_func = nn.ReLU()
    elif option.em_actv_func == 'tanh':
        em_actv_func = nn.Tanh()
    else:
        logging.error('Unknown event activation func: ' +
                      str(option.em_actv_func))
        sys.exit(1)
    event_scorer = nn.Sequential(nn.Linear(option.em_k, 1), em_actv_func)

    intent_scorer = nn.CosineSimilarity(dim=1)
    sentiment_classifier = nn.Linear(option.em_k, 1)
    criterion = MarginLoss(option.margin)
    sentiment_criterion = nn.BCEWithLogitsLoss()

    # load pretrained embeddings
    embeddings.weight.data.copy_(torch.from_numpy(glove.embd).float())

    if not option.update_embeddings:
        event_model.embeddings.weight.requires_grad = False

    if option.use_gpu:
        event_model.cuda()
        intent_model.cuda()
        sentiment_classifier.cuda()
        event_scorer.cuda()

    # The shared embedding gets its own parameter group (no weight decay
    # given); every other module contributes one group with weight decay.
    # The group order is kept stable because optimizer state dicts stored in
    # checkpoints depend on it.
    embeddings_param_id = {id(param) for param in embeddings.parameters()}
    params = [{'params': embeddings.parameters()}]
    for module in (event_model, event_scorer, intent_model,
                   sentiment_classifier):
        params.append({
            'params': [
                param for param in module.parameters()
                if id(param) not in embeddings_param_id
            ],
            'weight_decay':
            option.weight_decay
        })
    optimizer = torch.optim.Adagrad(params, lr=option.lr)

    # load checkpoint if provided:
    if option.load_checkpoint != '':
        checkpoint = torch.load(option.load_checkpoint)
        event_model.load_state_dict(checkpoint['event_model_state_dict'])
        intent_model.load_state_dict(checkpoint['intent_model_state_dict'])
        event_scorer.load_state_dict(checkpoint['event_scorer_state_dict'])
        sentiment_classifier.load_state_dict(
            checkpoint['sentiment_classifier_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        logging.info('Loaded checkpoint: ' + option.load_checkpoint)
    # load pretrained event model instead:
    elif option.pretrained_event_model != '':
        checkpoint = torch.load(option.pretrained_event_model)
        event_model.load_state_dict(checkpoint['model_state_dict'])
        logging.info('Loaded pretrained event model: ' +
                     option.pretrained_event_model)

    train_dataset = EventIntentSentimentDataset()
    logging.info('Loading train dataset: ' + option.train_dataset)
    train_dataset.load(option.train_dataset, glove)
    logging.info('Loaded train dataset: ' + option.train_dataset)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        collate_fn=EventIntentSentimentDataset_collate_fn,
        batch_size=option.batch_size,
        shuffle=True)

    if option.dev_dataset is not None:
        dev_dataset = EventIntentSentimentDataset()
        logging.info('Loading dev dataset: ' + option.dev_dataset)
        dev_dataset.load(option.dev_dataset, glove)
        logging.info('Loaded dev dataset: ' + option.dev_dataset)
        # The whole dev set is evaluated as one batch.
        dev_data_loader = torch.utils.data.DataLoader(
            dev_dataset,
            collate_fn=EventIntentSentimentDataset_collate_fn,
            batch_size=len(dev_dataset),
            shuffle=False)

    for epoch in range(1, option.epochs + 1):
        logging.info('Epoch ' + str(epoch))

        # train set
        avg_loss_e = 0
        avg_loss_i = 0
        avg_loss_s = 0
        avg_loss = 0
        # Batches are counted from 1 so each report covers `report_every`
        # batches.
        for i, batch in enumerate(train_data_loader, start=1):
            optimizer.zero_grad()
            loss, loss_e, loss_i, loss_s = run_batch(
                option, batch, event_model, intent_model, event_scorer,
                intent_scorer, sentiment_classifier, criterion,
                sentiment_criterion)
            loss.backward()
            optimizer.step()

            avg_loss_e += loss_e.item() / option.report_every
            avg_loss_i += loss_i.item() / option.report_every
            avg_loss_s += loss_s.item() / option.report_every
            avg_loss += loss.item() / option.report_every
            if i % option.report_every == 0:
                logging.info(
                    'Batch %d, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                    % (i, avg_loss_e, avg_loss_i, avg_loss_s, avg_loss))
                avg_loss_e = 0
                avg_loss_i = 0
                avg_loss_s = 0
                avg_loss = 0

        # dev set
        if option.dev_dataset is not None:
            event_model.eval()
            intent_model.eval()
            event_scorer.eval()
            sentiment_classifier.eval()
            batch = next(iter(dev_data_loader))
            # Evaluation only: without no_grad an autograd graph for the
            # entire dev set was built and then thrown away.
            with torch.no_grad():
                loss, loss_e, loss_i, loss_s = run_batch(
                    option, batch, event_model, intent_model, event_scorer,
                    intent_scorer, sentiment_classifier, criterion,
                    sentiment_criterion)
            logging.info(
                'Eval on dev set, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                % (loss_e.item(), loss_i.item(), loss_s.item(), loss.item()))
            event_model.train()
            intent_model.train()
            event_scorer.train()
            sentiment_classifier.train()

        if option.save_checkpoint != '':
            checkpoint = {
                'event_model_state_dict':
                event_model.state_dict(),
                'intent_model_state_dict':
                intent_model.state_dict(),
                'event_scorer_state_dict':
                event_scorer.state_dict(),
                'sentiment_classifier_state_dict':
                sentiment_classifier.state_dict(),
                'optimizer_state_dict':
                optimizer.state_dict()
            }
            torch.save(checkpoint, option.save_checkpoint + '_' + str(epoch))
            logging.info('Saved checkpoint: ' + option.save_checkpoint + '_' +
                         str(epoch))
Exemple #7
0
def main(option):
    """Evaluate a pretrained event model on transitive sentence similarity.

    The whole dataset is loaded as one batch of 13 items: six inputs
    (subject/verb/object ids and weights) for each of the two events,
    followed by the gold scores. The predicted similarity of a pair is the
    negated distance between its two event embeddings; the Spearman
    correlation between predictions and gold scores is logged, and the
    predictions are optionally written to ``option.output_file``, one per
    line.

    Args:
        option: Namespace-like object. Attributes read here:
            ``dataset_cache``, ``emb_file``, ``dataset_file``,
            ``vocab_size``, ``emb_dim``, ``model``, ``em_k``, ``em_r`` (only
            for ``'LowRankNTN'``), ``model_file``, ``use_gpu``,
            ``distance_metric`` and ``output_file``.

    Raises:
        SystemExit: If ``option.model`` or ``option.distance_metric`` is not
            a supported value.
        ValueError: If the collated batch does not hold exactly 13 items.
    """
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format=
        '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )

    dataset = TransitiveSentenceSimilarityDataset()
    if option.dataset_cache is None:
        glove = Glove(option.emb_file)
        logging.info('Embeddings loaded')
        dataset.load(option.dataset_file, glove)
    else:
        dataset.load_cache(option.dataset_cache)
    logging.info('Dataset loaded')

    embeddings = nn.Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model == 'NTN':
        model = NeuralTensorNetwork(embeddings, option.em_k)
    elif option.model == 'LowRankNTN':
        model = LowRankNeuralTensorNetwork(embeddings, option.em_k,
                                           option.em_r)
    elif option.model == 'RoleFactor':
        model = RoleFactoredTensorModel(embeddings, option.em_k)
    elif option.model == 'Predicate':
        model = PredicateTensorModel(embeddings)
    elif option.model == 'NN':
        model = NN(embeddings, 2 * option.em_k, option.em_k)
    elif option.model == 'EMC':
        model = EMC(embeddings, 2 * option.em_k, option.em_k)
    else:
        logging.error('Unknown model type: ' + str(option.model))
        sys.exit(1)

    # Validate the metric before any expensive work; an unknown value used to
    # surface only after the forward pass, as an unbound-name error.
    if option.distance_metric == 'cosine':
        distance_func = cosine_distance
    elif option.distance_metric == 'euclid':
        distance_func = euclid_distance
    else:
        logging.error('Unknown distance metric: ' +
                      str(option.distance_metric))
        sys.exit(1)

    checkpoint = torch.load(option.model_file, map_location='cpu')
    # Exact type check on purpose: a bare state_dict is usually an OrderedDict
    # (a dict subclass) and must fall through to the last branch.
    if type(checkpoint) is dict:
        if 'event_model_state_dict' in checkpoint:
            state_dict = checkpoint['event_model_state_dict']
        else:
            state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    logging.info(option.model_file + ' loaded')

    if option.use_gpu:
        model.cuda()
    model.eval()

    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=TransitiveSentenceSimilarityDataset_collate_fn,
        shuffle=False,
        batch_size=len(dataset))
    batch = tuple(next(iter(data_loader)))
    if len(batch) != 13:
        raise ValueError('expected a batch of 13 items, got %d' %
                         (len(batch), ))
    # Six tensors per event (subj_id, subj_w, verb_id, verb_w, obj_id,
    # obj_w), then the gold scores, which stay on the CPU.
    e1_inputs, e2_inputs, gold = batch[:6], batch[6:12], batch[12]
    if option.use_gpu:
        e1_inputs = tuple(tensor.cuda() for tensor in e1_inputs)
        e2_inputs = tuple(tensor.cuda() for tensor in e2_inputs)

    # Inference only: skip autograd bookkeeping for the full-dataset batch.
    with torch.no_grad():
        e1_emb = model(*e1_inputs)
        e2_emb = model(*e2_inputs)
        # Smaller distance means more similar, hence the negation.
        pred = -distance_func(e1_emb, e2_emb)

    if option.use_gpu:
        pred = pred.cpu()
    pred = pred.detach().numpy()
    gold = gold.numpy()
    spearman_correlation, spearman_p = scipy.stats.spearmanr(pred, gold)

    if option.output_file.strip() != '':
        with open(option.output_file, 'w') as output_file:
            for score in pred:
                output_file.write(str(score) + '\n')
        logging.info('Output saved to ' + option.output_file)

    logging.info('Spearman correlation: %.4f' % (spearman_correlation, ))