class Generator(nn.Module):
    # sents/trees/paths are specific to one sentence
    def forward(self, sents, trees, paths):
        hiddens, prediction = self.tree_model(sents, trees)
        return self.seqback_model(paths, hiddens[0])

    def __init__(self, vocab, embed, device):
        super().__init__()
        self.vocab = vocab
        self.device = device
        self.embed = embed
        # set seed for embedding metrics
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        # initialize tree_model, criterion/loss_function, optimizer
        self.tree_model = TreeLSTM(
            self.vocab.size(),
            args.input_dim,
            args.mem_dim,
            args.hidden_dim,
            args.num_classes,
            args.sparse,
            args.freeze_embed,
            device=self.device)
        self.tree_criterion = nn.KLDivLoss()  # todo: tree criterion might be useless
        self.tree_model.to(self.device)
        self.tree_criterion.to(self.device)
        # plug the pretrained embeddings into the embedding matrix inside tree_model
        self.tree_model.emb.weight.data.copy_(self.embed)
        if args.optim == 'adam':
            self.optimizer = optim.Adam(
                filter(lambda p: p.requires_grad, self.tree_model.parameters()),
                lr=args.lr, weight_decay=args.wd)
        elif args.optim == 'adagrad':
            self.optimizer = optim.Adagrad(
                filter(lambda p: p.requires_grad, self.tree_model.parameters()),
                lr=args.lr, weight_decay=args.wd)
        elif args.optim == 'sgd':
            self.optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self.tree_model.parameters()),
                lr=args.lr, weight_decay=args.wd)

        self.seqback_model = SeqbackLSTM(self.vocab, self.device)
        self.seqback_criterion = nn.CrossEntropyLoss()
        self.seqback_model.to(self.device)
        self.seqback_criterion.to(self.device)
        self.seqback_model.emb.weight.data.copy_(self.embed)
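
# A minimal usage sketch for this sentence-level Generator (names below are
# illustrative; `vocab`, `emb`, and one sentence's `sents`/`trees`/`paths`
# are assumed to come from the existing dataset / preprocessing code):
#
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#     generator = Generator(vocab, emb, device)
#     word_scores = generator(sents, trees, paths)  # TreeLSTM encode, then SeqbackLSTM decode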

torch.backends.cudnn.benchmark = True

# get vocab object from vocab file previously written
imdb_vocab_file = classificationConfig.vocab
vocab = Vocab(filename=imdb_vocab_file,
              data=[Constants.PAD_WORD, Constants.UNK_WORD,
                    Constants.BOS_WORD, Constants.EOS_WORD])
logger.debug('==> imdb vocabulary size : %d ' % vocab.size())

emb_file = classificationConfig.embed
emb = torch.load(emb_file)

# build TreeLSTM model
tree_model = TreeLSTM(
    vocab.size(),
    args.input_dim,
    args.mem_dim,
    args.hidden_dim,
    args.num_classes,
    args.sparse,
    args.freeze_embed,
    device)
criterion = nn.CrossEntropyLoss()
tree_model.to(device)
criterion.to(device)
tree_model.emb.weight.data.copy_(emb)

# restore the trained TreeLSTM checkpoint and freeze it for evaluation
with open('%s.pt' % os.path.join(args.save, args.expname), 'rb') as f:
    tree_model.load_state_dict(torch.load(f)['model'])
tree_model.eval()

# build dataset for seqbackLSTM
# train_dir = classificationConfig.token_file_labels[0]
# seqback_train_file = os.path.join(Global.external_tools, 'imdb_seqback_train.pth')
# if os.path.isfile(seqback_train_file):
#     seqback_train_data = torch.load(seqback_train_file)
# else:
#     seqback_train_data = seqbackDataset(train_dir, vocab, device, tree_model)
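
# For reference, the checkpoint loaded earlier in this script is expected to be
# a dict keyed by 'model'. A sketch of how such a file would be written
# (inferred from torch.load(f)['model']; the training script's actual save call
# may include additional keys such as the optimizer state):
#
#     torch.save({'model': tree_model.state_dict()},
#                '%s.pt' % os.path.join(args.save, args.expname))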

class Generator(nn.Module):
    """
    Take a paragraph as input
    """

    def forward(self, sents, trees, masks=None):
        hiddens, prediction = self.tree_model(sents, trees, masks)
        outputs = []
        for i in range(len(trees)):
            output = {}
            # replicate the sentence encoding across decoder layers and use it
            # as the initial (h, c) state of the tree decoder
            trees[i].hidden = torch.cat(
                [hiddens[i] for _ in range(self.seqback_model.nlayers)]).view(
                self.seqback_model.nlayers, 1, self.seqback_model.hidden_dim)
            trees[i].hidden = (trees[i].hidden, trees[i].hidden)
            if args.decode_word:
                sentence = self.seqback_model.emb(sents[i])
                self.traverse(trees[i], output, sentence)
            else:
                self.traverse(trees[i], output)
            try:
                output = torch.cat([output[j] for j in range(len(output))])
            except KeyError:
                print(self.seqback_model.sentences)
                self.debug(trees[i])
            outputs.append(output)
        return outputs

    def debug(self, trees):
        print(trees.idx)
        for i in range(trees.num_children):
            child = trees.children[i]
            self.debug(child)

    def traverse(self, node: TreeNode, output, sentence=None):
        # print(node.relation)
        # print(node.num_children)
        relations = torch.tensor(node.relation, device=self.device, dtype=torch.long)
        prev = node
        for idx in range(node.num_children):
            child = node.children[idx]
            # print("idx: {}".format(idx))
            # print("child {}".format(child.idx))
            if sentence is not None:
                # teacher forcing: feed the gold word embedding or the
                # previously predicted word, chosen at random per step
                is_teacher = random.random() < self.teacher_forcing_ratio
                if prev.idx == -1:
                    word_emb = self.null_token.clone().detach()
                else:
                    predicted_word = self.tree_model.emb(torch.argmax(output[prev.idx]))
                    assert predicted_word.shape == sentence[prev.idx].shape
                    word_emb = sentence[prev.idx] if is_teacher else predicted_word
                output[child.idx], child.hidden = self.seqback_model(
                    relations[idx], node.hidden, word=word_emb)
            else:
                output[child.idx], child.hidden = self.seqback_model(
                    relations[idx], node.hidden)
            prev = child
            self.traverse(child, output, sentence)

    def __init__(self, path, vocab=None, embed=None, data_set=None):
        super().__init__()
        # GPU select
        args.cuda = args.cuda and torch.cuda.is_available()
        self.device = torch.device("cuda:0" if args.cuda else "cpu")

        self.data_set = data_set
        if embed is not None:
            self.embed = embed
        else:
            self.embed = self.data_set.build_embedding()
            # todo: torch save dataset

        if args.sparse and args.wd != 0:
            logger.error('Sparsity and weight decay are incompatible, pick one!')
            exit()
        # debugging args
        logger.debug(args)
        # set seed for embedding metrics
        torch.manual_seed(args.seed)
        random.seed(args.seed)

        self.vocab = vocab
        rel_vocab = "rel_vocab.txt"
        self.rel_emb_size = 50
        self.rel_vocab = Vocab(filename=rel_vocab, data=[util.UNK_WORD])
        self.rel_emb = torch.nn.Embedding(
            self.rel_vocab.size(), self.rel_emb_size).to(self.device)

        # initialize tree_model, criterion/loss_function, optimizer
        # assume mem_dim = hidden_dim
        if args.encode_rel:
            self.tree_model = TreeLSTM(
                self.vocab.size(), self.embed.shape[1], args.hidden_dim,
                args.hidden_dim, args.sparse, device=self.device,
                rel_dim=self.rel_emb_size, rel_emb=self.rel_emb)
        else:
            self.tree_model = TreeLSTM(
                self.vocab.size(), self.embed.shape[1], args.mem_dim,
                args.hidden_dim, args.num_classes, args.sparse, device=self.device)
        # self.tree_criterion = nn.KLDivLoss()
        self.tree_model.to(self.device)
        # plug the pretrained embeddings into the embedding matrix inside tree_model
        self.tree_model.emb.weight.data.copy_(self.embed)

        if args.decode_word:
            self.seqback_model = TreeDecoder(
                self.vocab, self.device, self.rel_emb,
                embedding_dim=self.embed.shape[1], hidden_dim=args.hidden_dim)
        else:
            self.seqback_model = TreeDecoder(
                self.vocab, self.device, self.rel_emb,
                hidden_dim=args.hidden_dim, embedding_dim=0)
        self.seqback_criterion = nn.CrossEntropyLoss()
        self.seqback_model.to(self.device)
        self.seqback_criterion.to(self.device)
        self.teacher_forcing_ratio = args.tr
        self.null_token = self.tree_model.emb(
            torch.LongTensor([self.vocab.getIndex(util.SOS_WORD)]).to(self.device))
        if args.decode_word:
            self.seqback_model.emb.weight.data.copy_(self.embed)

        if args.optim == 'adam':
            self.optimizer = optim.Adam(
                filter(lambda p: p.requires_grad, self.parameters()),
                lr=args.lr, weight_decay=args.wd)
        elif args.optim == 'adagrad':
            self.optimizer = optim.Adagrad(
                filter(lambda p: p.requires_grad, self.parameters()),
                lr=args.lr, weight_decay=args.wd)
        elif args.optim == 'sgd':
            self.optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self.parameters()),
                lr=args.lr, weight_decay=args.wd)

    def get_tree(self, tri_case):
        """
        Build a dependency tree of TreeNode objects for one sentence.

        :param tri_case: dependency triples for one sentence
        :return: root node
        """
        tri_case.sort(key=lambda x: x[1][1])
        Nodes = dict()
        root = None
        for i in range(len(tri_case)):
            # if i not in Nodes.keys() and tri_case[i][0][1] != -1:
            if i not in Nodes.keys():
                idx = i
                prev = None
                rel = None
                # walk up the head chain until reaching an already-built node
                # or the synthetic root (head index -1)
                while True:
                    tree = TreeNode()
                    Nodes[idx] = tree
                    tree.idx = idx
                    if prev is not None:
                        tree.add_child(
                            prev, self.rel_vocab.getIndex(rel, util.UNK_WORD))
                    parent = tri_case[idx][0][1]
                    parent_rel = tri_case[idx][2]
                    if parent in Nodes.keys():
                        Nodes[parent].add_child(
                            tree, self.rel_vocab.getIndex(parent_rel, util.UNK_WORD))
                        break
                    elif parent == -1:
                        root = TreeNode()
                        root.idx = -1
                        Nodes[-1] = root
                        root.add_child(
                            tree, self.rel_vocab.getIndex(parent_rel, util.UNK_WORD))
                        break
                    else:
                        prev = tree
                        rel = tri_case[idx][2]
                        idx = parent
        if root is None:
            print(tri_case)
        return root
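
# A minimal usage sketch for get_tree. The triple layout is inferred from the
# indexing in get_tree: element [0] is (head_word, head_index), element [1] is
# (dependent_word, dependent_index), element [2] is the relation label, and a
# head index of -1 attaches a token to the synthetic root. The sentence and
# the `generator` instance below are illustrative, not repo code.
#
#     tri_case = [
#         (('ROOT', -1), ('bark', 1), 'root'),
#         (('bark', 1), ('dogs', 0), 'nsubj'),
#         (('bark', 1), ('loudly', 2), 'advmod'),
#     ]
#     root = generator.get_tree(tri_case)  # TreeNode with idx == -1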