def forward(self, web_page, examples):
    """Ensemble the encoding and alignment models by mixing their
    log-probabilities, then compute losses and top-k predictions.

    Args:
        web_page (WebPage): The web page of the examples
        examples (list[PhraseNodeExample]): Must be from the same web page.
    Returns:
        logits, losses, predictions
    """
    e_logits = self._encoding_model(web_page, examples, logits_only=True)
    a_logits = self._alignment_model(web_page, examples, logits_only=True)
    # Normalize
    e_logprobs = F.log_softmax(e_logits, dim=1)
    a_logprobs = F.log_softmax(a_logits, dim=1)
    logits = e_logprobs * self._weight[0] + a_logprobs * self._weight[1]

    # Filter the candidates
    node_filter_mask = self.node_filter(web_page, examples[0].web_page_code)
    log_node_filter_mask = V(FT([0. if x else -999999. for x in node_filter_mask]))
    logits = logits + log_node_filter_mask

    # Losses and predictions
    targets = V(LT([web_page.xid_to_ref.get(x.target_xid, 0) for x in examples]))
    mask = V(FT([int(
        x.target_xid in web_page.xid_to_ref
        and node_filter_mask[web_page.xid_to_ref[x.target_xid]]
    ) for x in examples]))
    losses = self.loss(logits, targets) * mask
    #print '=' * 20, examples[0].web_page_code
    #print [node_filter_mask[web_page.xid_to_ref.get(x.target_xid, 0)] for x in examples]
    #print [logits.data[i, web_page.xid_to_ref.get(x.target_xid, 0)] for (i, x) in enumerate(examples)]
    #print logits, targets, mask, losses
    if not np.isfinite(losses.data.sum()):
        #raise ValueError('Losses has NaN')
        logging.warning('Losses has NaN')
        #print losses

    # num_phrases x top_k
    top_k = min(self.top_k, len(web_page.nodes))
    predictions = torch.topk(logits, top_k, dim=1)[1]
    return logits, losses, predictions
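# Usage sketch: a minimal, self-contained illustration of the ensembling step
# above, assuming hypothetical 1 x 4 logit tensors in place of the real
# num_phrases x num_nodes model outputs and equal per-model weights of 1.0.
import torch
import torch.nn.functional as F

e_logits = torch.tensor([[2.0, 0.5, -1.0, 0.3]])
a_logits = torch.tensor([[0.1, 1.5, 0.2, -0.5]])
weight = torch.tensor([1.0, 1.0])

# Mix the two models in log-probability space.
mixed = (F.log_softmax(e_logits, dim=1) * weight[0]
         + F.log_softmax(a_logits, dim=1) * weight[1])

# Filtered-out candidates get a large negative additive mask, as with
# log_node_filter_mask above.
node_filter_mask = [True, True, False, True]
log_mask = torch.tensor([0. if keep else -999999. for keep in node_filter_mask])
top2 = torch.topk(mixed + log_mask, k=2, dim=1)[1]
print(top2)  # indices of the two best surviving nodes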
def forward(self, memory_cells, query):
    """Performs sentinel attention with a sentinel of 0. Returns the
    AttentionOutput where the weights do not include the sentinel weight.

    Args:
        memory_cells (Variable[FloatTensor]): batch x num_cells x cell_dim
        query (Variable[FloatTensor]): batch x query_dim

    Returns:
        AttentionOutput: weights do not include sentinel weights
    """
    batch_size, _, cell_dim = memory_cells.values.size()
    sentinel = self._sentinel_embed.expand(batch_size, 1, cell_dim)
    sentinel_mask = V(torch.ones(batch_size, 1))

    cell_values_with_sentinel = torch.cat([memory_cells.values, sentinel], 1)
    cell_masks_with_sentinel = torch.cat([memory_cells.mask, sentinel_mask], 1)

    cells_with_sentinel = SequenceBatch(
        cell_values_with_sentinel, cell_masks_with_sentinel, left_justify=False)
    attention_output = super(SentinelAttention, self).forward(
        cells_with_sentinel, query)

    # weights_with_sentinel = attention_output.weights
    # TODO: Bring this line in after torch v0.2.0
    # weights_without_sentinel = weights_with_sentinel[:, :-1]
    # attention_output = AttentionOutput(
    #     weights=weights_without_sentinel, context=attention_output.context)
    return attention_output
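# Sketch of the sentinel trick, assuming plain dot-product attention: a zero
# sentinel cell is appended so the softmax can put probability mass on
# "attend to nothing". All tensors here are hypothetical.
import torch
import torch.nn.functional as F

batch_size, num_cells, cell_dim = 2, 3, 4
cells = torch.randn(batch_size, num_cells, cell_dim)
query = torch.randn(batch_size, cell_dim)

sentinel = torch.zeros(batch_size, 1, cell_dim)
cells_with_sentinel = torch.cat([cells, sentinel], dim=1)  # batch x (num_cells + 1) x dim

logits = torch.bmm(cells_with_sentinel, query.unsqueeze(2)).squeeze(2)
weights = F.softmax(logits, dim=1)           # includes the sentinel weight
weights_without_sentinel = weights[:, :-1]   # what the TODO above intends to return
print(weights_without_sentinel.sum(dim=1))   # < 1 where mass went to the sentinel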
def _get_neighbors(self, web_page):
    """Get indices of at most |max_neighbors| neighbors for each relation.

    Args:
        web_page (WebPage)
    Returns:
        neighbors: SequenceBatch of shape num_nodes x ??? containing the
            neighbor refs (??? is at most max_neighbors * len(neighbor_rels))
        rels: SequenceBatch of shape num_nodes x ??? containing the
            relation indices
    """
    g = web_page.graph
    batch_neighbors = [[] for _ in range(len(web_page.nodes))]
    batch_rels = [[] for _ in range(len(web_page.nodes))]
    for src, tgts in g.nodes.items():
        # Group by relation
        rel_to_tgts = defaultdict(list)
        for tgt, rels in tgts.items():
            for rel in rels:
                rel_to_tgts[rel].append(tgt)
        # Sample if needed
        for rel, index in self._neighbor_rels.items():
            tgts = rel_to_tgts[rel]
            random.shuffle(tgts)
            if not tgts:
                continue
            if len(tgts) > self._max_neighbors:
                tgts = tgts[:self._max_neighbors]
            batch_neighbors[src].extend(tgts)
            batch_rels[src].extend([index] * len(tgts))
    # Create SequenceBatches
    max_len = max(len(x) for x in batch_neighbors)
    batch_mask = []
    for neighbors, rels in zip(batch_neighbors, batch_rels):
        assert len(neighbors) == len(rels)
        this_len = len(neighbors)
        batch_mask.append([1.] * this_len + [0.] * (max_len - this_len))
        neighbors.extend([0] * (max_len - this_len))
        rels.extend([0] * (max_len - this_len))
    return (SequenceBatch(
                V(torch.tensor(batch_neighbors, dtype=torch.long)),
                V(torch.tensor(batch_mask, dtype=torch.float32))),
            SequenceBatch(
                V(torch.tensor(batch_rels, dtype=torch.long)),
                V(torch.tensor(batch_mask, dtype=torch.float32))))
def forward(self, nodes):
    """Embeds a batch of Nodes.

    Args:
        nodes (list[Node])
    Returns:
        embeddings (Tensor): num_nodes x embed_dim
    """
    texts = []
    for node in nodes:
        if not self.ablate_text:
            if self._recursive_texts:
                text = ' '.join(node.all_texts(max_words=self._max_words))
            else:
                text = node.text or ''
            texts.append(word_tokenize2(text))
        else:
            texts.append([])
    text_embeddings = self._utterance_embedder(texts)

    # num_nodes x attr_embed_dim
    tags = [node.tag for node in nodes]
    tag_embeddings = self._tag_embedder.embed_tokens(tags)

    # num_nodes x attr_embed_dim
    if not self.ablate_attrs:
        ids = [word_tokenize2(node.id_) for node in nodes]
    else:
        ids = [[] for node in nodes]
    id_embeddings = self._id_embedder(ids)

    # num_nodes x attr_embed_dim
    if not self.ablate_attrs:
        classes = [word_tokenize2(' '.join(node.classes)) for node in nodes]
    else:
        classes = [[] for node in nodes]
    class_embeddings = self._classes_embedder(classes)

    if not self.ablate_attrs:
        other = [word_tokenize2(semantic_attrs(node.attributes)) for node in nodes]
    else:
        other = [[] for node in nodes]
    other_embeddings = self._other_embedder(other)

    # num_nodes x 3
    coords = V(FT([[node.x_ratio, node.y_ratio, float(node.visible)]
                   for node in nodes]))

    # num_nodes x dom_embed_dim
    dom_embeddings = torch.cat(
        (text_embeddings, tag_embeddings, id_embeddings,
         class_embeddings, other_embeddings, coords), dim=1)
    #dom_embeddings = text_embeddings
    return self.fc(dom_embeddings)
def __init__(self, encoding_model, alignment_model, node_filter, top_k=5):
    super(EnsembleModel, self).__init__()
    self._encoding_model = encoding_model
    self._alignment_model = alignment_model
    self._weight = V(torch.tensor([1.0, 1.0], dtype=torch.float32))
    self.node_filter = node_filter
    self.loss = nn.CrossEntropyLoss(reduction="none")
    self.top_k = top_k
def from_sequences(cls, sequences, vocab_or_vocabs, min_seq_length=0):
    """Convert a batch of sequences into a SequenceBatch.

    Args:
        sequences (list[list[unicode]])
        vocab_or_vocabs (WordVocab|list[WordVocab]): either a single vocab,
            or a list of vocabs, one per sequence
        min_seq_length (int): enforce that the Tensor representing the
            SequenceBatch have at least this many columns.

    Returns:
        SequenceBatch
    """
    # determine dimensions
    batch_size = len(sequences)
    if batch_size == 0:
        seq_length = 0
    else:
        seq_length = max(len(seq) for seq in sequences)  # max seq length in batch
    seq_length = max(seq_length, min_seq_length)  # make sure it is at least min_seq_length
    shape = (batch_size, seq_length)

    # set up vocabs
    if isinstance(vocab_or_vocabs, list):
        vocabs = vocab_or_vocabs
        assert len(vocabs) == batch_size
    else:
        # duplicate a single vocab
        assert isinstance(vocab_or_vocabs, Vocab)
        vocabs = [vocab_or_vocabs] * batch_size

    # build arrays
    values = np.zeros(shape, dtype=np.int64)  # pad with zeros
    mask = np.zeros(shape, dtype=np.float32)
    for i, (seq, vocab) in enumerate(zip(sequences, vocabs)):
        for j, word in enumerate(seq):
            values[i, j] = vocab.word2index(word)
            mask[i, j] = 1.0

    return SequenceBatch(V(torch.from_numpy(values)), V(torch.from_numpy(mask)))
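# Sketch of the padding scheme above with a hypothetical toy vocab; word2index
# is assumed to map unknown words to index 0. The real method uses WordVocab
# objects, but the values/mask layout is the same.
import numpy as np
import torch

vocab = {'<pad>': 0, 'click': 1, 'the': 2, 'button': 3}
sequences = [['click', 'the', 'button'], ['click']]

batch_size = len(sequences)
seq_length = max(len(seq) for seq in sequences)
values = np.zeros((batch_size, seq_length), dtype=np.int64)
mask = np.zeros((batch_size, seq_length), dtype=np.float32)
for i, seq in enumerate(sequences):
    for j, word in enumerate(seq):
        values[i, j] = vocab.get(word, 0)
        mask[i, j] = 1.0

print(torch.from_numpy(values))  # [[1, 2, 3], [1, 0, 0]]
print(torch.from_numpy(mask))    # [[1., 1., 1.], [1., 0., 0.]]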
def forward(self, old_embeds, neighbors, rels):
    """Update each node embedding by max-pooling over the projected node
    embedding and the (masked) embeddings of its neighbors."""
    batch_size = len(old_embeds)
    projected = F.relu(self._proj(self._dropout(old_embeds)))
    # Gather each node's neighbor embeddings by ref
    neighbor_embeds = torch.index_select(projected, 0, neighbors.values.view(-1))
    neighbor_embeds = neighbor_embeds.view(batch_size, neighbors.values.shape[1], -1)
    # Pool over the node itself plus its neighbors, masking out padding
    combined = torch.cat((projected.unsqueeze(1), neighbor_embeds), dim=1)
    mask = torch.cat((V(torch.ones(batch_size, 1)), neighbors.mask), dim=1)
    combined = SequenceBatch(combined, mask)
    return SequenceBatch.reduce_max(combined)
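# Sketch of the gather-and-pool pattern above: neighbor embeddings are gathered
# with index_select, then max-pooled together with the node itself. The tensors
# below are hypothetical, and the masked max is emulated with an additive
# log-mask instead of SequenceBatch.reduce_max.
import torch

embeds = torch.tensor([[1., 0.], [0., 2.], [3., 1.]])  # 3 nodes x 2 dims
neighbors = torch.tensor([[1, 2], [0, 0], [0, 1]])     # 3 nodes x 2 neighbor refs
mask = torch.tensor([[1., 1.], [1., 0.], [1., 1.]])    # 0. marks padding

gathered = torch.index_select(embeds, 0, neighbors.view(-1)).view(3, 2, -1)
combined = torch.cat((embeds.unsqueeze(1), gathered), dim=1)    # self + neighbors
combined_mask = torch.cat((torch.ones(3, 1), mask), dim=1)
# Masked max: padded entries are pushed to -inf before pooling.
pooled = (combined + torch.log(combined_mask).unsqueeze(2)).max(dim=1)[0]
print(pooled)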
def tile_state(h, batch_size):
    """Tile a given hidden state batch_size times.

    Args:
        h (Variable): a single hidden state of shape (hidden_dim,)
        batch_size (int)

    Returns:
        a Variable of shape (batch_size, hidden_dim)
    """
    tiler = V(torch.ones(batch_size, 1))
    return torch.mm(tiler, h.unsqueeze(0))  # (batch_size, hidden_dim)
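# Equivalent sketch: multiplying by a column of ones copies h into every row,
# which matches expanding (broadcasting) the unsqueezed state. Values are
# hypothetical.
import torch

h = torch.tensor([0.5, -1.0, 2.0])
tiled = torch.mm(torch.ones(4, 1), h.unsqueeze(0))  # 4 x 3
expanded = h.unsqueeze(0).expand(4, -1)             # same values, no copy
print(torch.equal(tiled, expanded))                 # True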
def _mask_logits(cls, logits, mask):
    no_cells = cls._no_cells(mask)  # (batch_size, num_cells)
    suppress = V(torch.zeros(*mask.size()))

    # send the logit of non-cells to -infinity
    suppress[mask == 0] = float('-inf')
    # but if an entire row has no cells, just leave the cells alone
    suppress[no_cells == 1] = 0.0

    logits = logits + suppress  # -inf + anything = -inf
    return logits
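# Sketch of the masking trick above with hypothetical logits: padded cells get
# an additive -inf so softmax assigns them exactly zero weight.
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 2.0, 0.5]])
mask = torch.tensor([[1.0, 1.0, 0.0]])  # third cell is padding

suppress = torch.zeros_like(mask)
suppress[mask == 0] = float('-inf')
print(F.softmax(logits + suppress, dim=1))  # last weight is 0.0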
def embed_tokens(self, tokens):
    """Embed list of tokens.

    Args:
        tokens (list[unicode])

    Returns:
        embeds (Variable[FloatTensor]): of shape (len(tokens), embed_dim)
    """
    vocab = self.vocab
    indices = V(torch.tensor([vocab.word2index(t) for t in tokens],
                             dtype=torch.long))
    return self._embedding(indices)
def forward(self, nodes):
    """Embeds a batch of Nodes.

    Args:
        nodes (list[Node])
    Returns:
        embeddings (Tensor): num_nodes x embed_dim
    """
    texts = []
    utterance_embedder = self._utterance_embedder
    for node in nodes:
        if self._recursive_texts:
            text = ' '.join(node.all_texts(max_words=self._max_words))
        else:
            text = node.text or ''
        texts.append(utterance_embedder.tokenize(text.lower()))
    text_embeddings = self._utterance_embedder(texts)

    # num_nodes x attr_embed_dim
    tag_embeddings = self._tag_embedder.embed_tokens([node.tag for node in nodes])

    # num_nodes x attr_embed_dim
    id_embedder = self._id_embedder
    id_embeddings = self._id_embedder(
        [id_embedder.tokenize(node.id_) for node in nodes])

    # num_nodes x attr_embed_dim
    classes_embedder = self._classes_embedder
    class_embeddings = self._classes_embedder(
        [classes_embedder.tokenize(' '.join(node.classes)) for node in nodes])

    # num_nodes x 3
    coords = V(torch.tensor([[elem.x_ratio, elem.y_ratio, float(elem.visible)]
                             for elem in nodes], dtype=torch.float32))

    # num_nodes x dom_embed_dim
    dom_embeddings = torch.cat(
        (text_embeddings, tag_embeddings, id_embeddings, class_embeddings, coords),
        dim=1)
    # dom_embeddings = text_embeddings
    return self.fc(dom_embeddings)
def __init__(self, num_embeddings, embedding_dim, initial_embeddings, **kwargs):
    """Constructs TrainFlagEmbedding with embeddings initialized with
    initial_embeddings.

    Args:
        num_embeddings (int)
        embedding_dim (int)
        initial_embeddings (np.array): (num_embeddings, embedding_dim)
        trainable (bool): if False, the weight matrix will not change.
            (default True)
        kwargs: all other supported keywords of torch.nn.Embedding.
    """
    super(TrainFlagEmbedding, self).__init__()
    trainable = kwargs.pop("trainable", True)
    self._trainable = trainable
    if trainable:
        embedding = Embedding(num_embeddings, embedding_dim, **kwargs)
        embedding.weight.data.copy_(torch.from_numpy(initial_embeddings))
        self._embedding = embedding
    else:
        self._fixed_weight = V(torch.from_numpy(initial_embeddings))
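# Sketch of the trainable/frozen distinction above: a trainable copy lives in a
# registered nn.Embedding whose weights receive gradients, while the frozen
# copy is a plain tensor used only for lookups. The values below are
# hypothetical.
import numpy as np
import torch
import torch.nn as nn

initial = np.random.rand(5, 3).astype(np.float32)

trainable = nn.Embedding(5, 3)
trainable.weight.data.copy_(torch.from_numpy(initial))
print(trainable.weight.requires_grad)     # True: updated by the optimizer

fixed_weight = torch.from_numpy(initial)  # plain tensor, never updated
indices = torch.tensor([0, 2])
lookup = torch.index_select(fixed_weight, 0, indices)  # frozen "embedding"
print(lookup.shape)                       # torch.Size([2, 3])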
def reduce_sum(cls, seq_batch):
    weights = V(torch.ones(*seq_batch.mask.size()))
    return cls.weighted_sum(seq_batch, weights)
def forward(self, memory_cells, query):
    batch_size, num_cells = memory_cells.mask.size()
    logits = V(torch.zeros(batch_size, num_cells))
    weights = V(torch.zeros(batch_size, num_cells))
    context = V(torch.zeros(batch_size, self.memory_dim))
    return AttentionOutput(weights=weights, context=context, logits=logits)
def __init__(self, input_dim, hidden_dim):
    super(AdditionCell, self).__init__()
    self.W = V(torch.eye(input_dim, hidden_dim))
    # truncates input if input_dim > hidden_dim
    # pads with zeros if input_dim < hidden_dim
    self.hidden_size = hidden_dim
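# Sketch of what the rectangular identity above does: multiplying by
# torch.eye(input_dim, hidden_dim) copies the first min(input_dim, hidden_dim)
# coordinates and zero-pads (or truncates) the rest.
import torch

x = torch.tensor([[1., 2., 3.]])  # input_dim = 3

W_pad = torch.eye(3, 5)           # hidden_dim = 5
print(torch.mm(x, W_pad))         # [[1., 2., 3., 0., 0.]]

W_trunc = torch.eye(3, 2)         # hidden_dim = 2
print(torch.mm(x, W_trunc))       # [[1., 2.]]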
def _mask_weights(cls, weights, mask):
    # if a given row has no memory cells, weights should be all zeros
    no_cells = cls._no_cells(mask)
    all_zeros = V(torch.zeros(*mask.size()))
    weights = conditional(no_cells, all_zeros, weights)
    return weights
def forward(self, web_page, examples, logits_only=False):
    """Compute predictions and loss.

    Args:
        web_page (WebPage): The web page of the examples
        examples (list[PhraseNodeExample]): Must be from the same web page.
        logits_only (bool)
    Returns:
        logits (Tensor): num_phrases x num_nodes
            Each entry (i, j) is the logit for p(node_j | phrase_i)
        losses (Tensor): num_phrases
        predictions (Tensor): num_phrases
    """
    phrase_embedder = self.phrase_embedder

    def max_scorer(pairwise_scores):
        """
        Args:
            pairwise_scores: num_nodes x phrase_len x max_text_len
        """
        scores = torch.max(pairwise_scores, dim=1)[0]
        return torch.max(scores, dim=1)[0]

    def cnn_scorer(pairwise_scores):
        """
        Args:
            pairwise_scores: num_nodes x phrase_len x max_text_len
        """
        scores = torch.unsqueeze(pairwise_scores, dim=1)
        scores = self.conv2d(scores)
        scores = self.conv2d_dilated(scores)
        scores = self.pooler(scores)
        scores = torch.squeeze(scores, dim=1)
        # dim = scores.shape[1] * scores.shape[2]
        scores = scores.view(-1, self.score_dim)
        if self.use_tags:
            tags = [node.tag for node in web_page.nodes]
            tag_embeddings = self._tag_embedder.embed_tokens(tags)
            scores = torch.cat((scores, tag_embeddings), dim=1)
            scores = self.project_tag(scores)
        scores = self.scorer(scores)
        scores = torch.squeeze(scores, dim=1)
        return scores

    def neighbor_cnn_scorer(pairwise_scores):
        """
        Args:
            pairwise_scores: num_nodes x phrase_len x max_text_len
        """
        scores = torch.unsqueeze(pairwise_scores, dim=1)
        scores = self.conv2d(scores)
        scores = self.conv2d_dilated(scores)
        scores = self.pooler(scores)
        scores = torch.squeeze(scores, dim=1)
        # dim = scores.shape[1] * scores.shape[2]
        scores = scores.view(-1, self.score_dim)
        if self.use_tags:
            tags = [node.tag for node in web_page.nodes]
            tag_embeddings = self._tag_embedder.embed_tokens(tags)
            scores = torch.cat((scores, tag_embeddings), dim=1)
            scores = self.project_tag(scores)
        return scores

    # Tokenize the nodes
    # num_nodes x text_length x embed_dim
    texts = []
    for node in web_page.nodes:
        text = ' '.join(node.all_texts(max_words=self.max_words))
        output = []
        if not self.ablate_text:
            output += phrase_embedder.tokenize(text)
        if not self.ablate_attrs:
            # TODO: better way to include attributes?
            output += phrase_embedder.tokenize(semantic_attrs(node.attributes))
        texts.append(output)
    embedded_texts = embed_tokens(self.token_embedder, self.max_words, texts)
    embedded_texts_values = self.dropout(embedded_texts.values)
    embedded_texts = embedded_texts_values * embedded_texts.mask.unsqueeze(2)

    # Tokenize the phrases
    # num_phrases x phrase_length x embed_dim
    logits = []
    if not self.use_neighbors:
        for example in examples:
            phrase = [phrase_embedder.tokenize(example.phrase)]
            embedded_phrase = embed_tokens(self.token_embedder, self.max_words, phrase)
            embedded_phrase_values = self.dropout(embedded_phrase.values)
            # expand: num_nodes x phrase_len x embed_dim
            batch_phrase = embedded_phrase_values.expand(len(texts), -1, -1)
            # permute embedded_texts: num_nodes x embed_dim x max_text_len
            pairwise_scores = torch.bmm(batch_phrase, embedded_texts.permute(0, 2, 1))
            # compute scores
            scores = cnn_scorer(pairwise_scores)
            logits.append(torch.unsqueeze(scores, dim=0))
    else:
        intermediate_scores = []
        for example in examples:
            phrase = [phrase_embedder.tokenize(example.phrase)]
            embedded_phrase = embed_tokens(self.token_embedder, self.max_words, phrase)
            embedded_phrase_values = self.dropout(embedded_phrase.values)
            # expand: num_nodes x phrase_len x embed_dim
            batch_phrase = embedded_phrase_values.expand(len(texts), -1, -1)
            # permuted embedded_texts: num_nodes x embed_dim x max_text_len
            pairwise_scores = torch.bmm(batch_phrase, embedded_texts.permute(0, 2, 1))
            node_score = neighbor_cnn_scorer(pairwise_scores)
            intermediate_scores.append(node_score)

        neighbors, masks = web_page.get_spatial_neighbors()
        neighbors = V(torch.tensor(neighbors, dtype=torch.long))
        masks = V(torch.tensor(masks, dtype=torch.float32))
        masks = masks.unsqueeze(dim=2)

        # each node_score tensor is parameterized by phrase
        for node_score in intermediate_scores:
            # get pairwise_scores for all neighbors...
            # neighbors, rels = self._get_neighbors(web_page)
            batch_size = len(node_score)
            neighbor_scores = torch.index_select(node_score, 0, neighbors.view(-1))
            neighbor_scores = neighbor_scores.view(batch_size, neighbors.shape[1], -1)
            neighbor_scores = neighbor_scores * masks
            if neighbor_scores.shape[1] < self.num_rels:
                more = self.num_rels - neighbor_scores.shape[1]
                num_nodes, _, embed_dim = neighbor_scores.shape
                padding = V(torch.zeros(num_nodes, more, embed_dim))
                neighbor_scores = torch.cat((neighbor_scores, padding), dim=1)
            # num_nodes x num_neighbors x intermediate_score_dim
            node_score = torch.unsqueeze(node_score, dim=1)
            scores = torch.cat((node_score, neighbor_scores), dim=1)
            scores = scores.view(node_score.shape[0], -1)
            scores = self._final_neighbor_linear(scores)
            scores = torch.squeeze(scores, dim=1)
            logits.append(torch.unsqueeze(scores, dim=0))

    logits = torch.cat(logits, dim=0)

    # Filter the candidates
    node_filter_mask = self.node_filter(web_page, examples[0].web_page_code)  # what does this do?
    log_node_filter_mask = V(torch.tensor(
        [0. if x else -999999. for x in node_filter_mask],
        dtype=torch.float32))
    logits = logits + log_node_filter_mask
    if logits_only:
        return logits

    # Losses and predictions
    targets = V(torch.tensor(
        [web_page.xid_to_ref.get(x.target_xid, 0) for x in examples],
        dtype=torch.long))
    mask = V(torch.tensor(
        [int(x.target_xid in web_page.xid_to_ref
             and node_filter_mask[web_page.xid_to_ref[x.target_xid]])
         for x in examples],
        dtype=torch.float32))
    losses = self.loss(logits, targets) * mask
    # print '=' * 20, examples[0].web_page_code
    # print [node_filter_mask[web_page.xid_to_ref.get(x.target_xid, 0)] for x in examples]
    # print [logits.detach()[i, web_page.xid_to_ref.get(x.target_xid, 0)] for (i, x) in enumerate(examples)]
    # print logits, targets, mask, losses
    if not torch.isfinite(losses.detach().sum()):
        # raise ValueError('Losses has NaN')
        logging.warning('Losses has NaN')
        # print losses

    # num_phrases x top_k
    top_k = min(self.top_k, len(web_page.nodes))
    predictions = torch.topk(logits, top_k, dim=1)[1]
    return logits, losses, predictions
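# Sketch of how the pairwise_scores tensor above is formed, with hypothetical
# dimensions: one phrase of 2 tokens scored against 3 nodes whose texts are
# padded to 4 tokens, all in a 5-dim embedding space. A simple max-pooling
# scorer (the max_scorer variant) stands in for the CNN scorer.
import torch

num_nodes, phrase_len, text_len, embed_dim = 3, 2, 4, 5
phrase = torch.randn(1, phrase_len, embed_dim)
node_texts = torch.randn(num_nodes, text_len, embed_dim)

batch_phrase = phrase.expand(num_nodes, -1, -1)                  # num_nodes x phrase_len x dim
pairwise = torch.bmm(batch_phrase, node_texts.permute(0, 2, 1))  # num_nodes x phrase_len x text_len
node_scores = pairwise.max(dim=2)[0].max(dim=1)[0]               # best token match per node
print(pairwise.shape, node_scores.shape)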
def forward(self, web_page, examples, logits_only=False):
    """Compute predictions and loss.

    Args:
        web_page (WebPage): The web page of the examples
        examples (list[PhraseNodeExample]): Must be from the same web page.
        logits_only (bool)
    Returns:
        logits (Tensor): num_phrases x num_nodes
            Each entry (i, j) is the logit for p(node_j | phrase_i)
        losses (Tensor): num_phrases
        predictions (Tensor): num_phrases
    """
    # Embed the nodes + normalize
    # num_nodes x dim
    node_embeddings = self.node_embedder(web_page.nodes)
    node_embeddings = node_embeddings / torch.clamp(
        node_embeddings.norm(p=2, dim=1, keepdim=True), min=1e-8)

    # Embed the phrases + normalize
    phrases = []
    for example in examples:
        phrases.append(word_tokenize(example.phrase.lower()))
    # num_phrases x dim
    phrase_embeddings = self.phrase_embedder(phrases)
    if self.proj is not None:
        phrase_embeddings = torch.sigmoid(self.proj(phrase_embeddings))
    phrase_embeddings = phrase_embeddings / torch.clamp(
        phrase_embeddings.norm(p=2, dim=1, keepdim=True), min=1e-8)

    ps = torch.split(phrase_embeddings, 1, dim=0)
    logits = []
    # only loop on phrases
    if not self.use_neighbors:
        for p in ps:
            p = p.expand(node_embeddings.shape[0], -1)
            encoding = torch.cat((p, node_embeddings, p * node_embeddings), dim=1)
            encoding = self.dropout(encoding)
            scores = self.score(encoding)
            logits.append(scores)
    else:
        neighbors, masks = web_page.get_spatial_neighbors()
        neighbors, masks = V(LT(neighbors)), V(FT(masks))
        masks = masks.unsqueeze(dim=2)
        for p in ps:
            batch_size = node_embeddings.shape[0]
            p = p.expand(batch_size, -1)
            neighbor_scores = torch.index_select(node_embeddings, 0, neighbors.view(-1))
            neighbor_scores = neighbor_scores.view(batch_size, neighbors.shape[1], -1)
            neighbor_scores = neighbor_scores * masks
            encoding = [p]
            neighbor_scores = torch.split(neighbor_scores, 1, dim=1)
            neighbor_scores = [torch.squeeze(x, dim=1) for x in neighbor_scores]
            for n in [node_embeddings] + neighbor_scores:
                encoding += [n, p * n]
            encoding = torch.cat(encoding, dim=1)
            encoding = self.dropout(encoding)
            scores = self.score(encoding)
            logits.append(scores)
    logits = torch.cat(logits, dim=1)
    logits = logits.permute(1, 0)

    # Filter the candidates
    node_filter_mask = self.node_filter(web_page, examples[0].web_page_code)
    log_node_filter_mask = V(FT([0. if x else -999999. for x in node_filter_mask]))
    logits = logits + log_node_filter_mask
    if logits_only:
        return logits

    # Losses and predictions
    targets = V(LT([web_page.xid_to_ref.get(x.target_xid, 0) for x in examples]))
    mask = V(FT([
        int(x.target_xid in web_page.xid_to_ref
            and node_filter_mask[web_page.xid_to_ref[x.target_xid]])
        for x in examples
    ]))
    losses = self.loss(logits, targets) * mask
    #print '=' * 20, examples[0].web_page_code
    #print [node_filter_mask[web_page.xid_to_ref.get(x.target_xid, 0)] for x in examples]
    #print [logits.data[i, web_page.xid_to_ref.get(x.target_xid, 0)] for (i, x) in enumerate(examples)]
    #print logits, targets, mask, losses
    if not np.isfinite(losses.data.sum().item()):
        #raise ValueError('Losses has NaN')
        logging.warning('Losses has NaN')
        #print losses

    # num_phrases x top_k
    top_k = min(self.top_k, len(web_page.nodes))
    predictions = torch.topk(logits, top_k, dim=1)[1]
    return logits, losses, predictions
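# Sketch of the phrase-node scoring above: L2-normalized embeddings are
# combined as [phrase, node, phrase * node] and passed through a small scoring
# layer. The shapes and the Linear layer below are hypothetical stand-ins for
# self.score.
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, num_nodes = 8, 5
phrase = F.normalize(torch.randn(1, dim), dim=1)
nodes = F.normalize(torch.randn(num_nodes, dim), dim=1)

p = phrase.expand(num_nodes, -1)
encoding = torch.cat((p, nodes, p * nodes), dim=1)  # num_nodes x 3*dim
score = nn.Linear(3 * dim, 1)                       # stand-in for self.score
logits = score(encoding).squeeze(1)                 # one logit per node for this phrase
print(logits.shape)                                 # torch.Size([5])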