Esempio n. 1
0
def _cycle_batches(data_loader):
    """Yield batches from ``data_loader`` endlessly, restarting it when exhausted.

    Unlike ``itertools.cycle``, nothing is cached: every pass starts a fresh
    iteration over the loader, so a loader built with ``shuffle=True`` is
    reshuffled on each pass and consumed batches are not kept in memory.
    The generator stops if a full pass yields no batch at all, so an empty
    loader cannot cause an infinite loop.
    """
    while True:
        produced = False
        for batch in data_loader:
            produced = True
            yield batch
        if not produced:
            return


def main(option):
    """Jointly train the event/intent/sentiment models and the YAGO relation models.

    Each training step consumes one batch from the event-intent-sentiment
    ("atomic") dataset and one batch from the YAGO dataset, with a separate
    optimizer step for each.  One epoch is one pass over the YAGO loader;
    atomic batches are drawn from an endlessly restarting iterator.

    Args:
        option: configuration object.  Attributes read here: ``random_seed``,
            ``emb_file``, ``train_dataset``, ``dev_dataset``,
            ``yago_train_dataset``, ``yago_dev_dataset``, ``batch_size``,
            ``vocab_size``, ``emb_dim``, ``model_type``, ``em_k``, ``em_r``,
            ``im_hidden_size``, ``im_num_layers``, ``scorer_actv_func``,
            ``margin``, ``yago_margin``, ``update_embeddings``, ``use_gpu``,
            ``weight_decay``, ``lr``, ``load_checkpoint``,
            ``pretrained_event_model``, ``epochs``, ``report_every`` and
            ``save_checkpoint``.  It is also forwarded to ``run_batch`` and
            ``run_yago_batch``.

    Raises:
        SystemExit: if ``option.model_type`` or ``option.scorer_actv_func``
            is not one of the supported values.
    """
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format=
        '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )
    torch.manual_seed(option.random_seed)

    glove = Glove(option.emb_file)
    logging.info('Embeddings loaded')

    # ---- datasets -------------------------------------------------------
    train_dataset = EventIntentSentimentDataset()
    logging.info('Loading train dataset: ' + option.train_dataset)
    train_dataset.load(option.train_dataset, glove)
    logging.info('Loaded train dataset: ' + option.train_dataset)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        collate_fn=EventIntentSentimentDataset_collate_fn,
        batch_size=option.batch_size,
        shuffle=True)

    if option.dev_dataset is not None:
        dev_dataset = EventIntentSentimentDataset()
        logging.info('Loading dev dataset: ' + option.dev_dataset)
        dev_dataset.load(option.dev_dataset, glove)
        logging.info('Loaded dev dataset: ' + option.dev_dataset)
        # A single batch holds the whole dev set.
        dev_data_loader = torch.utils.data.DataLoader(
            dev_dataset,
            collate_fn=EventIntentSentimentDataset_collate_fn,
            batch_size=len(dev_dataset),
            shuffle=False)

    yago_train_dataset = YagoDataset()
    logging.info('Loading YAGO train dataset: ' + option.yago_train_dataset)
    yago_train_dataset.load(option.yago_train_dataset, glove)
    yago_train_data_loader = torch.utils.data.DataLoader(
        yago_train_dataset,
        collate_fn=YagoDataset_collate_fn,
        batch_size=option.batch_size,
        shuffle=True)

    if option.yago_dev_dataset is not None:
        yago_dev_dataset = YagoDataset()
        logging.info('Loading YAGO dev dataset: ' + option.yago_dev_dataset)
        yago_dev_dataset.load(option.yago_dev_dataset, glove)
        # A single batch holds the whole YAGO dev set.
        yago_dev_data_loader = torch.utils.data.DataLoader(
            yago_dev_dataset,
            collate_fn=YagoDataset_collate_fn,
            batch_size=len(yago_dev_dataset),
            shuffle=False)

    # ---- models ---------------------------------------------------------
    # NOTE: modules are created in a fixed order so that, for a given
    # random seed, parameter initialisation stays reproducible.
    embeddings = nn.Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model_type == 'NTN':
        event_model = NeuralTensorNetwork(embeddings, option.em_k)
    elif option.model_type == 'RoleFactor':
        event_model = RoleFactoredTensorModel(embeddings, option.em_k)
    elif option.model_type == 'LowRankNTN':
        event_model = LowRankNeuralTensorNetwork(embeddings, option.em_k,
                                                 option.em_r)
    else:
        # Fail fast instead of hitting an unbound `event_model` later on.
        logging.error('Unknown model type: ' + str(option.model_type))
        sys.exit(1)

    intent_model = BiLSTMEncoder(embeddings, option.im_hidden_size,
                                 option.im_num_layers)

    # One relation model per YAGO relation, all sharing the same embeddings.
    relation_models = {}
    for relation in yago_train_dataset.relations:
        relation_models[relation] = NeuralTensorNetwork_yago(
            embeddings, option.em_k)

    scorer_actv_funcs = {
        'sigmoid': nn.Sigmoid,
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
    }
    if option.scorer_actv_func not in scorer_actv_funcs:
        logging.error('Unknown scorer activation func: ' +
                      str(option.scorer_actv_func))
        sys.exit(1)
    scorer_actv_func = scorer_actv_funcs[option.scorer_actv_func]
    event_scorer = nn.Sequential(nn.Linear(option.em_k, 1), scorer_actv_func())
    relation_scorer = nn.Sequential(nn.Linear(option.em_k, 1),
                                    scorer_actv_func())

    intent_scorer = nn.CosineSimilarity(dim=1)
    sentiment_classifier = nn.Linear(option.em_k, 1)
    criterion = MarginLoss(option.margin)
    sentiment_criterion = nn.BCEWithLogitsLoss()
    yago_criterion = MarginLoss(option.yago_margin)

    # load pretrained embeddings
    embeddings.weight.data.copy_(torch.from_numpy(glove.embd).float())

    if not option.update_embeddings:
        event_model.embeddings.weight.requires_grad = False

    if option.use_gpu:
        event_model.cuda()
        intent_model.cuda()
        sentiment_classifier.cuda()
        event_scorer.cuda()
        relation_scorer.cuda()
        for relation_model in relation_models.values():
            relation_model.cuda()

    # ---- optimizer ------------------------------------------------------
    # The shared embedding parameters get their own group without weight
    # decay; they are filtered out of every other group so that no
    # parameter is registered twice.
    embeddings_param_id = {id(param) for param in embeddings.parameters()}

    def decayed_group(module):
        """Build a weight-decayed param group of `module`'s non-embedding params."""
        return {
            'params': [
                param for param in module.parameters()
                if id(param) not in embeddings_param_id
            ],
            'weight_decay': option.weight_decay,
        }

    params = [{'params': list(embeddings.parameters())}]
    for module in (event_model, event_scorer, intent_model,
                   sentiment_classifier):
        params.append(decayed_group(module))
    for relation in relation_models:
        params.append(decayed_group(relation_models[relation]))
    # relation_scorer is used by run_yago_batch, so it must be optimised
    # too.  It is appended last so the earlier group order is unchanged.
    params.append(decayed_group(relation_scorer))
    optimizer = torch.optim.Adagrad(params, lr=option.lr)

    # ---- restore state --------------------------------------------------
    # load checkpoint if provided:
    if option.load_checkpoint != '':
        # map_location='cpu' lets a checkpoint saved on GPU load on any host;
        # load_state_dict copies the values onto each module's own device.
        checkpoint = torch.load(option.load_checkpoint, map_location='cpu')
        event_model.load_state_dict(checkpoint['event_model_state_dict'])
        intent_model.load_state_dict(checkpoint['intent_model_state_dict'])
        event_scorer.load_state_dict(checkpoint['event_scorer_state_dict'])
        sentiment_classifier.load_state_dict(
            checkpoint['sentiment_classifier_state_dict'])
        for relation in relation_models:
            relation_models[relation].load_state_dict(
                checkpoint['relation_model_state_dict'][relation])
        # Checkpoints written before relation_scorer was saved lack this key.
        if 'relation_scorer_state_dict' in checkpoint:
            relation_scorer.load_state_dict(
                checkpoint['relation_scorer_state_dict'])
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        except ValueError:
            # Raised when the saved param groups do not match the current
            # ones (e.g. a checkpoint without the relation_scorer group).
            logging.warning(
                'Optimizer state in checkpoint does not match the current '
                'parameter groups; starting with a fresh optimizer state')
        logging.info('Loaded checkpoint: ' + option.load_checkpoint)
    # load pretrained event model instead:
    elif option.pretrained_event_model != '':
        checkpoint = torch.load(option.pretrained_event_model,
                                map_location='cpu')
        event_model.load_state_dict(checkpoint['model_state_dict'])
        logging.info('Loaded pretrained event model: ' +
                     option.pretrained_event_model)

    # ---- training -------------------------------------------------------
    for epoch in range(1, option.epochs + 1):
        logging.info('Epoch ' + str(epoch))

        # Each accumulator sums loss / report_every, so after report_every
        # batches it holds the mean loss over that window.
        avg_loss_e = 0
        avg_loss_i = 0
        avg_loss_s = 0
        avg_loss = 0
        avg_loss_sum = 0
        avg_loss_event = 0
        avg_loss_attr = 0

        # The epoch length is set by the YAGO loader (the original author
        # assumes it is the larger dataset); atomic batches restart as needed.
        atomic_iterator = _cycle_batches(train_data_loader)
        # 1-based batch index so that the first report covers a full window.
        for i, yago_batch in enumerate(yago_train_data_loader, start=1):
            atomic_batch = next(atomic_iterator)

            # -- atomic (event / intent / sentiment) step --
            optimizer.zero_grad()
            loss, loss_e, loss_i, loss_s = run_batch(
                option, atomic_batch, event_model, intent_model, event_scorer,
                intent_scorer, sentiment_classifier, criterion,
                sentiment_criterion)
            loss.backward()
            optimizer.step()
            avg_loss_e += loss_e.item() / option.report_every
            avg_loss_i += loss_i.item() / option.report_every
            avg_loss_s += loss_s.item() / option.report_every
            avg_loss += loss.item() / option.report_every
            if i % option.report_every == 0:
                logging.info(
                    'Atomic batch %d, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                    % (i, avg_loss_e, avg_loss_i, avg_loss_s, avg_loss))
                avg_loss_e = 0
                avg_loss_i = 0
                avg_loss_s = 0
                avg_loss = 0

            # -- YAGO step --
            optimizer.zero_grad()
            loss_sum, loss_event, loss_attr = run_yago_batch(
                option, yago_batch, event_model, relation_models,
                yago_criterion, relation_scorer, event_scorer)
            loss_sum.backward()
            optimizer.step()
            avg_loss_sum += loss_sum.item() / option.report_every
            avg_loss_event += loss_event.item() / option.report_every
            avg_loss_attr += loss_attr.item() / option.report_every
            if i % option.report_every == 0:
                logging.info(
                    'YAGO batch %d, loss_event=%.4f, loss_attr=%.4f, loss=%.4f'
                    % (i, avg_loss_event, avg_loss_attr, avg_loss_sum))
                avg_loss_sum = 0
                avg_loss_event = 0
                avg_loss_attr = 0

        # dev set
        if option.dev_dataset is not None:
            event_model.eval()
            intent_model.eval()
            event_scorer.eval()
            sentiment_classifier.eval()
            batch = next(iter(dev_data_loader))
            with torch.no_grad():
                loss, loss_e, loss_i, loss_s = run_batch(
                    option, batch, event_model, intent_model, event_scorer,
                    intent_scorer, sentiment_classifier, criterion,
                    sentiment_criterion)
            logging.info(
                'Eval on dev set, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                % (loss_e.item(), loss_i.item(), loss_s.item(), loss.item()))
            event_model.train()
            intent_model.train()
            event_scorer.train()
            sentiment_classifier.train()

        # dev set (yago)
        if option.yago_dev_dataset is not None:
            for relation_model in relation_models.values():
                relation_model.eval()
            relation_scorer.eval()
            event_model.eval()
            yago_dev_batch = next(iter(yago_dev_data_loader))
            with torch.no_grad():
                loss_sum, loss_event, loss_attr = run_yago_batch(
                    option, yago_dev_batch, event_model, relation_models,
                    yago_criterion, relation_scorer, event_scorer)
            logging.info(
                'Eval on yago dev set, loss_sum=%.4f, loss_event=%.4f, loss_attr=%.4f, '
                % (loss_sum.item(), loss_event.item(), loss_attr.item()))
            for relation_model in relation_models.values():
                relation_model.train()
            relation_scorer.train()
            # Put the event model back in training mode for the next epoch.
            event_model.train()

        if option.save_checkpoint != '':
            checkpoint = {
                'event_model_state_dict':
                event_model.state_dict(),
                'intent_model_state_dict':
                intent_model.state_dict(),
                'event_scorer_state_dict':
                event_scorer.state_dict(),
                'sentiment_classifier_state_dict':
                sentiment_classifier.state_dict(),
                'optimizer_state_dict':
                optimizer.state_dict(),
                'relation_model_state_dict': {
                    relation: relation_models[relation].state_dict()
                    for relation in relation_models
                },
                'relation_scorer_state_dict':
                relation_scorer.state_dict(),
            }
            torch.save(checkpoint, option.save_checkpoint + '_' + str(epoch))
            logging.info('Saved checkpoint: ' + option.save_checkpoint + '_' +
                         str(epoch))
Esempio n. 3
0
def main(option):
    """Train the event, intent and sentiment models on one dataset.

    Builds the shared embedding layer and the models, optionally restores a
    full checkpoint or a pretrained event model, then trains for
    ``option.epochs`` epochs.  After each epoch it evaluates on the dev set
    (when one is configured) and writes a checkpoint (when a path is set).

    Args:
        option: configuration object.  Attributes read here: ``random_seed``,
            ``emb_file``, ``vocab_size``, ``emb_dim``, ``model_type``,
            ``em_k``, ``em_r``, ``im_hidden_size``, ``im_num_layers``,
            ``em_actv_func``, ``margin``, ``update_embeddings``, ``use_gpu``,
            ``weight_decay``, ``lr``, ``load_checkpoint``,
            ``pretrained_event_model``, ``train_dataset``, ``dev_dataset``,
            ``batch_size``, ``epochs``, ``report_every`` and
            ``save_checkpoint``.  It is also forwarded to ``run_batch``.

    Raises:
        SystemExit: if ``option.model_type`` or ``option.em_actv_func`` is
            not one of the supported values.
    """
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format=
        '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )
    torch.manual_seed(option.random_seed)

    glove = Glove(option.emb_file)
    logging.info('Embeddings loaded')

    # ---- models ---------------------------------------------------------
    # NOTE: modules are created in a fixed order so that, for a given
    # random seed, parameter initialisation stays reproducible.
    embeddings = nn.Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model_type == 'NTN':
        event_model = NeuralTensorNetwork(embeddings, option.em_k)
    elif option.model_type == 'RoleFactor':
        event_model = RoleFactoredTensorModel(embeddings, option.em_k)
    elif option.model_type == 'LowRankNTN':
        event_model = LowRankNeuralTensorNetwork(embeddings, option.em_k,
                                                 option.em_r)
    else:
        # Fail fast instead of hitting an unbound `event_model` later on.
        logging.error('Unknown model type: ' + str(option.model_type))
        sys.exit(1)

    intent_model = BiLSTMEncoder(embeddings, option.im_hidden_size,
                                 option.im_num_layers)

    if option.em_actv_func == 'sigmoid':
        em_actv_func = nn.Sigmoid()
    elif option.em_actv_func == 'relu':
        em_actv_func = nn.ReLU()
    elif option.em_actv_func == 'tanh':
        em_actv_func = nn.Tanh()
    else:
        logging.error('Unknown event activation func: ' + option.em_actv_func)
        # sys.exit is used because the bare `exit` builtin is only injected
        # by the `site` module and is not guaranteed to exist.
        sys.exit(1)
    event_scorer = nn.Sequential(nn.Linear(option.em_k, 1), em_actv_func)

    intent_scorer = nn.CosineSimilarity(dim=1)
    sentiment_classifier = nn.Linear(option.em_k, 1)
    criterion = MarginLoss(option.margin)
    sentiment_criterion = nn.BCEWithLogitsLoss()

    # load pretrained embeddings
    embeddings.weight.data.copy_(torch.from_numpy(glove.embd).float())

    if not option.update_embeddings:
        event_model.embeddings.weight.requires_grad = False

    if option.use_gpu:
        event_model.cuda()
        intent_model.cuda()
        sentiment_classifier.cuda()
        event_scorer.cuda()

    # ---- optimizer ------------------------------------------------------
    # The shared embedding parameters get their own group without weight
    # decay; they are filtered out of every other group so that no
    # parameter is registered twice.
    embeddings_param_id = {id(param) for param in embeddings.parameters()}
    params = [{'params': list(embeddings.parameters())}]
    for module in (event_model, event_scorer, intent_model,
                   sentiment_classifier):
        params.append({
            'params': [
                param for param in module.parameters()
                if id(param) not in embeddings_param_id
            ],
            'weight_decay': option.weight_decay,
        })
    optimizer = torch.optim.Adagrad(params, lr=option.lr)

    # ---- restore state --------------------------------------------------
    # load checkpoint if provided:
    if option.load_checkpoint != '':
        # map_location='cpu' lets a checkpoint saved on GPU load on any host;
        # load_state_dict copies the values onto each module's own device.
        checkpoint = torch.load(option.load_checkpoint, map_location='cpu')
        event_model.load_state_dict(checkpoint['event_model_state_dict'])
        intent_model.load_state_dict(checkpoint['intent_model_state_dict'])
        event_scorer.load_state_dict(checkpoint['event_scorer_state_dict'])
        sentiment_classifier.load_state_dict(
            checkpoint['sentiment_classifier_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        logging.info('Loaded checkpoint: ' + option.load_checkpoint)
    # load pretrained event model instead:
    elif option.pretrained_event_model != '':
        checkpoint = torch.load(option.pretrained_event_model,
                                map_location='cpu')
        event_model.load_state_dict(checkpoint['model_state_dict'])
        logging.info('Loaded pretrained event model: ' +
                     option.pretrained_event_model)

    # ---- datasets -------------------------------------------------------
    train_dataset = EventIntentSentimentDataset()
    logging.info('Loading train dataset: ' + option.train_dataset)
    train_dataset.load(option.train_dataset, glove)
    logging.info('Loaded train dataset: ' + option.train_dataset)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        collate_fn=EventIntentSentimentDataset_collate_fn,
        batch_size=option.batch_size,
        shuffle=True)

    if option.dev_dataset is not None:
        dev_dataset = EventIntentSentimentDataset()
        logging.info('Loading dev dataset: ' + option.dev_dataset)
        dev_dataset.load(option.dev_dataset, glove)
        logging.info('Loaded dev dataset: ' + option.dev_dataset)
        # A single batch holds the whole dev set.
        dev_data_loader = torch.utils.data.DataLoader(
            dev_dataset,
            collate_fn=EventIntentSentimentDataset_collate_fn,
            batch_size=len(dev_dataset),
            shuffle=False)

    # ---- training -------------------------------------------------------
    for epoch in range(1, option.epochs + 1):
        logging.info('Epoch ' + str(epoch))

        # train set
        # Each accumulator sums loss / report_every, so after report_every
        # batches it holds the mean loss over that window.
        avg_loss_e = 0
        avg_loss_i = 0
        avg_loss_s = 0
        avg_loss = 0
        # 1-based batch index so that the first report covers a full window.
        for i, batch in enumerate(train_data_loader, start=1):
            optimizer.zero_grad()
            loss, loss_e, loss_i, loss_s = run_batch(
                option, batch, event_model, intent_model, event_scorer,
                intent_scorer, sentiment_classifier, criterion,
                sentiment_criterion)
            loss.backward()
            optimizer.step()

            avg_loss_e += loss_e.item() / option.report_every
            avg_loss_i += loss_i.item() / option.report_every
            avg_loss_s += loss_s.item() / option.report_every
            avg_loss += loss.item() / option.report_every
            if i % option.report_every == 0:
                logging.info(
                    'Batch %d, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                    % (i, avg_loss_e, avg_loss_i, avg_loss_s, avg_loss))
                avg_loss_e = 0
                avg_loss_i = 0
                avg_loss_s = 0
                avg_loss = 0

        # dev set
        if option.dev_dataset is not None:
            event_model.eval()
            intent_model.eval()
            event_scorer.eval()
            sentiment_classifier.eval()
            batch = next(iter(dev_data_loader))
            # No gradients are needed for evaluation; skipping autograd
            # avoids building a graph for the full-dataset batch.
            with torch.no_grad():
                loss, loss_e, loss_i, loss_s = run_batch(
                    option, batch, event_model, intent_model, event_scorer,
                    intent_scorer, sentiment_classifier, criterion,
                    sentiment_criterion)
            logging.info(
                'Eval on dev set, loss_e=%.4f, loss_i=%.4f, loss_s=%.4f, loss=%.4f'
                % (loss_e.item(), loss_i.item(), loss_s.item(), loss.item()))
            event_model.train()
            intent_model.train()
            event_scorer.train()
            sentiment_classifier.train()

        if option.save_checkpoint != '':
            checkpoint = {
                'event_model_state_dict':
                event_model.state_dict(),
                'intent_model_state_dict':
                intent_model.state_dict(),
                'event_scorer_state_dict':
                event_scorer.state_dict(),
                'sentiment_classifier_state_dict':
                sentiment_classifier.state_dict(),
                'optimizer_state_dict':
                optimizer.state_dict()
            }
            torch.save(checkpoint, option.save_checkpoint + '_' + str(epoch))
            logging.info('Saved checkpoint: ' + option.save_checkpoint + '_' +
                         str(epoch))