Example #1
class Generator(nn.Module):
    # sents/trees/paths are specific to one sentence
    def forward(self, sents, trees, paths):
        hiddens, prediction = self.tree_model(sents, trees)
        return self.seqback_model(paths, hiddens[0])

    def __init__(self, vocab, embed, device):
        super().__init__()

        self.vocab = vocab
        self.device = device
        self.embed = embed

        # set seed for embedding metrics
        torch.manual_seed(args.seed)
        random.seed(args.seed)

        # initialize tree_model, criterion/loss_function, optimizer
        self.tree_model = TreeLSTM(
            self.vocab.size(),
            args.input_dim,
            args.mem_dim,
            args.hidden_dim,
            args.num_classes,
            args.sparse,
            args.freeze_embed,
            device=self.device)

        self.tree_criterion = nn.KLDivLoss()
        # todo: tree criterion might be useless
        self.tree_model.to(self.device)
        self.tree_criterion.to(self.device)
        # plug these into embedding matrix inside tree_model
        self.tree_model.emb.weight.data.copy_(self.embed)

        if args.optim == 'adam':
            self.optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                               self.tree_model.parameters()), lr=args.lr, weight_decay=args.wd)
        elif args.optim == 'adagrad':
            self.optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
                                                  self.tree_model.parameters()), lr=args.lr, weight_decay=args.wd)
        elif args.optim == 'sgd':
            self.optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                              self.tree_model.parameters()), lr=args.lr, weight_decay=args.wd)

        self.seqback_model = SeqbackLSTM(self.vocab, self.device)
        self.seqback_criterion = nn.CrossEntropyLoss()
        self.seqback_model.to(self.device)
        self.seqback_criterion.to(self.device)
        self.seqback_model.emb.weight.data.copy_(self.embed)
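A minimal usage sketch for this Generator (hypothetical: `loader`, `targets` and the batch layout are assumptions, `vocab` and `embed` come from the surrounding script, and `args` is the same global namespace the constructor reads):

# Hypothetical training step; every name other than Generator itself is an assumption.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator = Generator(vocab, embed, device)

for sents, trees, paths, targets in loader:      # one sentence per batch item (assumed)
    generator.optimizer.zero_grad()
    logits = generator(sents, trees, paths)      # TreeLSTM encode, then SeqbackLSTM decode
    loss = generator.seqback_criterion(logits, targets)
    loss.backward()
    generator.optimizer.step()                   # note: this optimizer only covers tree_model parameters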
Example #2
    def __init__(self, path, vocab=None, embed=None, data_set=None):
        super().__init__()
        # GPU select
        args.cuda = args.cuda and torch.cuda.is_available()
        self.device = torch.device("cuda:0" if args.cuda else "cpu")

        self.data_set = data_set
        if embed is not None:
            self.embed = embed
        else:
            self.embed = self.data_set.build_embedding()
        # todo: torch save dataset

        if args.sparse and args.wd != 0:
            logger.error('Sparsity and weight decay are incompatible, pick one!')
            exit()

        # debugging args
        logger.debug(args)
        # set seed for embedding metrics
        torch.manual_seed(args.seed)
        random.seed(args.seed)

        self.vocab = vocab
        rel_vocab = "rel_vocab.txt"
        self.rel_emb_size = 50
        self.rel_vocab = Vocab(filename=rel_vocab, data=[util.UNK_WORD])
        self.rel_emb = torch.nn.Embedding(self.rel_vocab.size(), self.rel_emb_size).to(self.device)
        # initialize tree_model, criterion/loss_function, optimizer
        # assume mem_dim = hidden_dim
        if args.encode_rel:
            self.tree_model = TreeLSTM(
                self.vocab.size(),
                self.embed.shape[1],
                args.hidden_dim,
                args.hidden_dim,
                args.sparse,
                device=self.device,
                rel_dim=self.rel_emb_size,
                rel_emb=self.rel_emb)
        else:
            self.tree_model = TreeLSTM(
                self.vocab.size(),
                self.embed.shape[1],
                args.mem_dim,
                args.hidden_dim,
                args.num_classes,
                args.sparse,
                device=self.device)

        # self.tree_criterion = nn.KLDivLoss()
        self.tree_model.to(self.device)
        # plug these into embedding matrix inside tree_model
        self.tree_model.emb.weight.data.copy_(self.embed)

        if args.decode_word:
            self.seqback_model = TreeDecoder(
                self.vocab, self.device, self.rel_emb,
                embedding_dim=self.embed.shape[1], hidden_dim=args.hidden_dim)
        else:
            self.seqback_model = TreeDecoder(
                self.vocab, self.device, self.rel_emb, hidden_dim=args.hidden_dim,
                embedding_dim=0)
        self.seqback_criterion = nn.CrossEntropyLoss()
        self.seqback_model.to(self.device)
        self.seqback_criterion.to(self.device)

        self.teacher_forcing_ratio = args.tr
        self.null_token = self.tree_model.emb(torch.LongTensor([self.vocab.getIndex(util.SOS_WORD)]).to(self.device))

        if args.decode_word:
            self.seqback_model.emb.weight.data.copy_(self.embed)
        if args.optim == 'adam':
            self.optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                               self.parameters()), lr=args.lr, weight_decay=args.wd)
        elif args.optim == 'adagrad':
            self.optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
                                                  self.parameters()), lr=args.lr, weight_decay=args.wd)
        elif args.optim == 'sgd':
            self.optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                              self.parameters()), lr=args.lr, weight_decay=args.wd)
        torch.backends.cudnn.benchmark = True
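The if/elif chain over `args.optim` is repeated verbatim in each of these constructors; inside `__init__` it could be replaced by a lookup table (a sketch, assuming `args.optim` is always one of the three values handled above):

# Sketch of an equivalent optimizer selection; unlike the chain, an unknown
# args.optim raises a KeyError instead of silently leaving self.optimizer unset.
optim_classes = {'adam': optim.Adam, 'adagrad': optim.Adagrad, 'sgd': optim.SGD}
trainable = [p for p in self.parameters() if p.requires_grad]
self.optimizer = optim_classes[args.optim](trainable, lr=args.lr, weight_decay=args.wd)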

Example #3
    # get vocab object from vocab file previously written
    imdb_vocab_file = classificationConfig.vocab
    vocab = Vocab(filename=imdb_vocab_file,
                  data=[
                      Constants.PAD_WORD, Constants.UNK_WORD,
                      Constants.BOS_WORD, Constants.EOS_WORD
                  ])
    logger.debug('==> imdb vocabulary size : %d ' % vocab.size())
    emb_file = classificationConfig.embed
    emb = torch.load(emb_file)

    # build treeLSTM model
    tree_model = TreeLSTM(vocab.size(), args.input_dim, args.mem_dim,
                          args.hidden_dim, args.num_classes, args.sparse,
                          args.freeze_embed, device)
    criterion = nn.CrossEntropyLoss()
    tree_model.to(device)
    criterion.to(device)
    tree_model.emb.weight.data.copy_(emb)
    with open('%s.pt' % os.path.join(args.save, args.expname), 'rb') as f:
        tree_model.load_state_dict(torch.load(f)['model'])
    tree_model.eval()

    # build dataset for seqbackLSTM
    # train_dir = classificationConfig.token_file_labels[0]
    # seqback_train_file = os.path.join(Global.external_tools, 'imdb_seqback_train.pth')
    # if os.path.isfile(seqback_train_file):
    # seqback_train_data = torch.load(seqback_train_file)
    # else:
    # seqback_train_data = seqbackDataset(train_dir, vocab, device, tree_model)
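Written out, the cache-or-build pattern sketched by the commented lines would look roughly as follows (hypothetical; the names are taken from the comments above and are not verified against the repository):

# Hypothetical: load the cached seqback training set if it exists, otherwise build and cache it.
train_dir = classificationConfig.token_file_labels[0]
seqback_train_file = os.path.join(Global.external_tools, 'imdb_seqback_train.pth')
if os.path.isfile(seqback_train_file):
    seqback_train_data = torch.load(seqback_train_file)
else:
    seqback_train_data = seqbackDataset(train_dir, vocab, device, tree_model)
    torch.save(seqback_train_data, seqback_train_file)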
Example #4
class Generator(nn.Module):
    """
        Take a paragraph as input
    """

    def forward(self, sents, trees, masks=None):
        hiddens, prediction = self.tree_model(sents, trees, masks)
        outputs = []
        for i in range(len(trees)):
            output = {}
            trees[i].hidden = torch.cat([hiddens[i] for _ in range(self.seqback_model.nlayers)]).view(
                self.seqback_model.nlayers, 1, self.seqback_model.hidden_dim)
            trees[i].hidden = (trees[i].hidden, trees[i].hidden)
            if args.decode_word:
                sentence = self.seqback_model.emb(sents[i])
                self.traverse(trees[i], output, sentence)
            else:
                self.traverse(trees[i], output)

            try:
                output = torch.cat([output[i] for i in range(len(output))])
            except KeyError:
                print(self.seqback_model.sentences)
                self.debug(trees[i])
            outputs.append(output)
        return outputs

    def debug(self, trees):
        print(trees.idx)
        for i in range(trees.num_children):
            child = trees.children[i]
            self.debug(child)

    def traverse(self, node: TreeNode, output, sentence=None):
        """Decode the tree top-down: each child's word logits come from the seqback
        model given the relation embedding, the current hidden state and, when
        `sentence` is given, the previous word (gold with probability
        `teacher_forcing_ratio`, otherwise the model's own prediction)."""
        # print(node.relation)
        # print(node.num_children)
        relations = torch.tensor(node.relation, device=self.device, dtype=torch.long)
        prev = node
        for idx in range(node.num_children):
            child = node.children[idx]
            # print("idx: {}".format(idx))
            # print("child {}".format(child.idx))
            if sentence is not None:
                is_teacher = random.random() < self.teacher_forcing_ratio
                if prev.idx == -1:
                    word_emb = self.null_token.clone().detach()
                else:
                    predicted_word = self.tree_model.emb(torch.argmax(output[prev.idx]))
                    assert predicted_word.shape == sentence[prev.idx].shape
                    word_emb = sentence[prev.idx] if is_teacher else predicted_word
                output[child.idx], child.hidden = self.seqback_model(relations[idx],
                                                                     node.hidden, word=word_emb)
            else:
                output[child.idx], child.hidden = self.seqback_model(relations[idx], node.hidden)
            prev = child
            self.traverse(child, output, sentence)

    def __init__(self, path, vocab=None, embed=None, data_set=None):
        super().__init__()
        # GPU select
        args.cuda = args.cuda and torch.cuda.is_available()
        self.device = torch.device("cuda:0" if args.cuda else "cpu")

        self.data_set = data_set
        if embed is not None:
            self.embed = embed
        else:
            self.embed = self.data_set.build_embedding()
        # todo: torch save dataset

        if args.sparse and args.wd != 0:
            logger.error('Sparsity and weight decay are incompatible, pick one!')
            exit()

        # debugging args
        logger.debug(args)
        # set seed for embedding metrics
        torch.manual_seed(args.seed)
        random.seed(args.seed)

        self.vocab = vocab
        rel_vocab = "rel_vocab.txt"
        self.rel_emb_size = 50
        self.rel_vocab = Vocab(filename=rel_vocab, data=[util.UNK_WORD])
        self.rel_emb = torch.nn.Embedding(self.rel_vocab.size(), self.rel_emb_size).to(self.device)
        # initialize tree_model, criterion/loss_function, optimizer
        # assume mem_dim = hidden_dim
        if args.encode_rel:
            self.tree_model = TreeLSTM(
                self.vocab.size(),
                self.embed.shape[1],
                args.hidden_dim,
                args.hidden_dim,
                args.sparse,
                device=self.device,
                rel_dim=self.rel_emb_size,
                rel_emb=self.rel_emb)
        else:
            self.tree_model = TreeLSTM(
                self.vocab.size(),
                self.embed.shape[1],
                args.mem_dim,
                args.hidden_dim,
                args.num_classes,
                args.sparse,
                device=self.device)

        # self.tree_criterion = nn.KLDivLoss()
        self.tree_model.to(self.device)
        # plug these into embedding matrix inside tree_model
        self.tree_model.emb.weight.data.copy_(self.embed)

        if args.decode_word:
            self.seqback_model = TreeDecoder(
                self.vocab, self.device, self.rel_emb,
                embedding_dim=self.embed.shape[1], hidden_dim=args.hidden_dim)
        else:
            self.seqback_model = TreeDecoder(
                self.vocab, self.device, self.rel_emb, hidden_dim=args.hidden_dim,
                embedding_dim=0)
        self.seqback_criterion = nn.CrossEntropyLoss()
        self.seqback_model.to(self.device)
        self.seqback_criterion.to(self.device)

        self.teacher_forcing_ratio = args.tr
        self.null_token = self.tree_model.emb(torch.LongTensor([self.vocab.getIndex(util.SOS_WORD)]).to(self.device))

        if args.decode_word:
            self.seqback_model.emb.weight.data.copy_(self.embed)
        if args.optim == 'adam':
            self.optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                               self.parameters()), lr=args.lr, weight_decay=args.wd)
        elif args.optim == 'adagrad':
            self.optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
                                                  self.parameters()), lr=args.lr, weight_decay=args.wd)
        elif args.optim == 'sgd':
            self.optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                              self.parameters()), lr=args.lr, weight_decay=args.wd)

    def get_tree(self, tri_case):
        """
        :param tri_case: one sentence tri case
        :return: root node
        """
        tri_case.sort(key=lambda x: x[1][1])
        Nodes = dict()
        root = None
        for i in range(len(tri_case)):
            # if i not in Nodes.keys() and tri_case[i][0][1] != -1:
            if i not in Nodes.keys():
                idx = i
                prev = None
                rel = None
                while True:
                    tree = TreeNode()
                    Nodes[idx] = tree
                    tree.idx = idx
                    if prev is not None:
                        tree.add_child(
                            prev, self.rel_vocab.getIndex(rel, util.UNK_WORD))
                    parent = tri_case[idx][0][1]
                    parent_rel = tri_case[idx][2]
                    if parent in Nodes.keys():
                        Nodes[parent].add_child(
                            tree, self.rel_vocab.getIndex(parent_rel, util.UNK_WORD))
                        break
                    elif parent == -1:
                        root = TreeNode()
                        root.idx = -1
                        Nodes[-1] = root
                        root.add_child(
                            tree, self.rel_vocab.getIndex(parent_rel, util.UNK_WORD))
                        break
                    else:
                        prev = tree
                        rel = tri_case[idx][2]
                        idx = parent
        if root is None:
            print(tri_case)
        return root
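From the indexing in get_tree (tri_case[i][0][1] as the head index, tri_case[i][1][1] as the dependent position used for sorting, tri_case[i][2] as the relation label), the expected input appears to be dependency triples of the form ((head_word, head_idx), (dep_word, dep_idx), relation). A small hypothetical example:

# Hypothetical triples for "dogs bark"; the field layout is inferred from get_tree, not documented.
# `generator` is an instance of the Generator class above (construction omitted).
tri_case = [
    (('bark', 1), ('dogs', 0), 'nsubj'),   # token 0 "dogs" attaches to token 1 "bark"
    (('ROOT', -1), ('bark', 1), 'root'),   # head index -1 marks the artificial root
]
root = generator.get_tree(tri_case)        # returns the TreeNode with idx == -1
assert root.idx == -1 and root.children[0].idx == 1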