Example #1
import torch
from torch.nn import Embedding

# Glove, NeuralTensorNetwork, and LowRankNeuralTensorNetwork are assumed to be
# project-local classes from this example's repository; Embedding is assumed to
# be torch.nn.Embedding.


def main(option):
    # Load pre-trained GloVe word vectors.
    glove = Glove(option.emb_file)
    print(option.emb_file + ' loaded')

    # Word-embedding layer shared by the event model; index 1 is the padding index.
    embedding = Embedding(option.vocab_size, option.emb_dim, padding_idx=1)
    if option.model_type == 'NTN':
        model = NeuralTensorNetwork(embedding, option.em_k)
    elif option.model_type == 'LowRankNTN':
        model = LowRankNeuralTensorNetwork(embedding, option.em_k, option.em_r)
    else:
        raise ValueError('unsupported model_type: ' + option.model_type)

    # A checkpoint may be a wrapper dict holding one or more state dicts,
    # or a bare state dict saved directly.
    checkpoint = torch.load(option.model_file, map_location='cpu')
    if isinstance(checkpoint, dict) and 'event_model_state_dict' in checkpoint:
        state_dict = checkpoint['event_model_state_dict']
    elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    print(option.model_file + ' loaded')
    model.eval()
    model.to(option.device)

    # Accumulate per-event token ids, token weights, labels, and surface texts.
    all_subj_id = []
    all_subj_w = []
    all_verb_id = []
    all_verb_w = []
    all_obj_id = []
    all_obj_w = []
    all_labels = []
    all_event_texts = []
    # Each input file corresponds to one class label; every line encodes one
    # event as "subject | verb | object".
    for label, filename in enumerate(option.input_files):
        print('loading ' + filename)
        with open(filename, 'r') as f:
            lines = f.readlines()
        for line in lines:
            subj, verb, obj = line.lower().strip().split(' | ')
            event_text = '(' + subj + ', ' + verb + ', ' + obj + ')'
            subj = subj.split(' ')
            verb = verb.split(' ')
            obj = obj.split(' ')
            # glove.transform maps the token list to ids and weights
            # (10 is presumably the maximum token length).
            subj_id, subj_w = glove.transform(subj, 10)
            verb_id, verb_w = glove.transform(verb, 10)
            obj_id, obj_w = glove.transform(obj, 10)
            # Skip events with unmappable tokens and de-duplicate by surface text.
            if (subj_id is not None and verb_id is not None
                    and obj_id is not None
                    and event_text not in all_event_texts):
                all_subj_id.append(subj_id)
                all_subj_w.append(subj_w)
                all_verb_id.append(verb_id)
                all_verb_w.append(verb_w)
                all_obj_id.append(obj_id)
                all_obj_w.append(obj_w)
                all_labels.append(label)
                all_event_texts.append(event_text)

    # Stack the collected lists into batched tensors on the target device.
    all_subj_id = torch.tensor(all_subj_id,
                               dtype=torch.long,
                               device=option.device)
    all_subj_w = torch.tensor(all_subj_w,
                              dtype=torch.float,
                              device=option.device)
    all_verb_id = torch.tensor(all_verb_id,
                               dtype=torch.long,
                               device=option.device)
    all_verb_w = torch.tensor(all_verb_w,
                              dtype=torch.float,
                              device=option.device)
    all_obj_id = torch.tensor(all_obj_id,
                              dtype=torch.long,
                              device=option.device)
    all_obj_w = torch.tensor(all_obj_w,
                             dtype=torch.float,
                             device=option.device)
    # Embed all events in a single forward pass; inference only, so no gradients are needed.
    with torch.no_grad():
        all_event_embeddings = model(all_subj_id, all_subj_w, all_verb_id,
                                     all_verb_w, all_obj_id,
                                     all_obj_w).cpu()

    torch.save(
        {
            'embeddings': all_event_embeddings,
            'labels': torch.tensor(all_labels, dtype=torch.long),
            'event_texts': all_event_texts
        }, option.output_file)
    print('saved to ' + option.output_file)
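
The dictionary saved above can be loaded back directly with torch.load. A minimal sketch of reading and inspecting it, assuming a placeholder path 'event_embeddings.pt' that is not part of the original example:

import torch

# Load the artifact written by main() above; the keys mirror the saved dict.
saved = torch.load('event_embeddings.pt')  # placeholder path
embeddings = saved['embeddings']    # one event embedding per row
labels = saved['labels']            # index of the input file each event came from
event_texts = saved['event_texts']  # '(subj, verb, obj)' strings
print(embeddings.shape, labels.shape, len(event_texts))
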
Example #2
import logging
import random
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Glove, Vocab, load_intent_prediction_dataset, Perplexity, NeuralTensorNetwork,
# DecoderRNN, Optimizer, NTNTrainer, NTNPredictor, and init_model are assumed to
# be importable from this example's repository and its seq2seq dependencies.


def main(option):
    # Fix the random seeds for reproducibility.
    random.seed(option.random_seed)
    torch.manual_seed(option.random_seed)

    LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level='INFO', stream=sys.stdout)

    # Load pre-trained GloVe word vectors.
    glove = Glove(option.emb_file)
    logging.info('loaded embeddings from ' + option.emb_file)

    # The source vocabulary comes from GloVe; the target vocabulary holds the
    # intent tokens the decoder generates.
    src_vocab = Vocab.build_from_glove(glove)
    tgt_vocab = Vocab.load(option.intent_vocab)

    train_dataset = load_intent_prediction_dataset(option.train_dataset,
                                                   src_vocab,
                                                   tgt_vocab,
                                                   device=option.device)
    dev_dataset = load_intent_prediction_dataset(option.dev_dataset,
                                                 src_vocab,
                                                 tgt_vocab,
                                                 device=option.device)

    train_data_loader = DataLoader(train_dataset,
                                   batch_size=option.batch_size,
                                   shuffle=True)
    dev_data_loader = DataLoader(dev_dataset,
                                 batch_size=len(dev_dataset),
                                 shuffle=False)

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # Prepare loss
    weight = torch.ones(tgt_vocab_size)
    pad = tgt_vocab.stoi[tgt_vocab.pad_token]
    loss = Perplexity(weight, pad)
    loss.criterion.to(option.device)

    # Initialize model
    encoder = NeuralTensorNetwork(nn.Embedding(src_vocab_size, option.emb_dim),
                                  option.em_k)
    decoder = DecoderRNN(tgt_vocab_size,
                         option.im_max_len,
                         option.im_hidden_size,
                         use_attention=False,
                         bidirectional=False,
                         eos_id=tgt_vocab.stoi[tgt_vocab.eos_token],
                         sos_id=tgt_vocab.stoi[tgt_vocab.bos_token])
    encoder.to(option.device)
    decoder.to(option.device)

    init_model(encoder)
    init_model(decoder)

    # Initialize the encoder's word embeddings with the pre-trained GloVe vectors.
    encoder.embeddings.weight.data.copy_(torch.from_numpy(glove.embd).float())

    optimizer_params = [
        {'params': encoder.parameters()},
        {'params': decoder.parameters()},
    ]
    optimizer = Optimizer(optim.Adam(optimizer_params, lr=option.lr),
                          max_grad_norm=5)
    trainer = NTNTrainer(loss,
                         print_every=option.report_every,
                         device=option.device)
    encoder, decoder = trainer.train(
        encoder,
        decoder,
        optimizer,
        train_data_loader,
        num_epochs=option.epochs,
        dev_data_loader=dev_data_loader,
        teacher_forcing_ratio=option.im_teacher_forcing_ratio)

    # Sanity-check the trained model on a few hand-written (subject, verb, object) events.
    predictor = NTNPredictor(encoder, decoder, src_vocab, tgt_vocab,
                             option.device)
    samples = [
        ("PersonX", "eventually told", "___"),
        ("PersonX", "tells", "PersonY 's tale"),
        ("PersonX", "always played", " ___"),
        ("PersonX", "would teach", "PersonY"),
        ("PersonX", "gets", "a ride"),
    ]
    for sample in samples:
        subj, verb, obj = sample
        subj = subj.lower().split(' ')
        verb = verb.lower().split(' ')
        obj = obj.lower().split(' ')
        print(sample, predictor.predict(subj, verb, obj))
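
For reference, main() above reads its configuration from attributes of option. A minimal sketch of assembling such an object with types.SimpleNamespace; every path and hyperparameter value below is a placeholder assumption, not a setting from the original project:

from types import SimpleNamespace

# Placeholder configuration covering the option attributes accessed in main() above.
option = SimpleNamespace(
    random_seed=42,
    emb_file='glove.6B.100d.txt',     # placeholder path
    intent_vocab='intent_vocab.txt',  # placeholder path
    train_dataset='train.txt',        # placeholder path
    dev_dataset='dev.txt',            # placeholder path
    device='cpu',
    batch_size=128,
    emb_dim=100,
    em_k=100,
    im_max_len=10,
    im_hidden_size=100,
    lr=1e-3,
    report_every=100,
    epochs=10,
    im_teacher_forcing_ratio=0.5,
)
main(option)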