Example #1
    def forward(self, web_page, examples):
        e_logits = self._encoding_model(web_page, examples, logits_only=True)
        a_logits = self._alignment_model(web_page, examples, logits_only=True)

        # Normalize
        e_logprobs = F.log_softmax(e_logits, dim=1)
        a_logprobs = F.log_softmax(a_logits, dim=1)

        logits = e_logprobs * self._weight[0] + a_logprobs * self._weight[1]

        # Filter the candidates
        node_filter_mask = self.node_filter(web_page, examples[0].web_page_code)
        log_node_filter_mask = V(FT([0. if x else -999999. for x in node_filter_mask]))
        logits = logits + log_node_filter_mask
        # Losses and predictions
        targets = V(LT([web_page.xid_to_ref.get(x.target_xid, 0) for x in examples]))
        mask = V(FT([int(
                x.target_xid in web_page.xid_to_ref
                and node_filter_mask[web_page.xid_to_ref[x.target_xid]]
            ) for x in examples]))
        losses = self.loss(logits, targets) * mask
        #print '=' * 20, examples[0].web_page_code
        #print [node_filter_mask[web_page.xid_to_ref.get(x.target_xid, 0)] for x in examples]
        #print [logits.data[i, web_page.xid_to_ref.get(x.target_xid, 0)] for (i, x) in enumerate(examples)]
        #print logits, targets, mask, losses
        if not np.isfinite(losses.data.sum().item()):
            #raise ValueError('Losses has NaN')
            logging.warning('Losses has NaN')
            #print losses
        # num_phrases x top_k
        top_k = min(self.top_k, len(web_page.nodes))
        predictions = torch.topk(logits, top_k, dim=1)[1]
        return logits, losses, predictions
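A minimal, self-contained sketch of the ensembling step above, with random tensors standing in for the two sub-models' outputs: each model's per-node logits are turned into log-probabilities, combined with fixed weights, and nodes rejected by a boolean filter are pushed to a very negative score before top-k ranking. All sizes are illustrative.

import torch
import torch.nn.functional as F

# Stand-ins for the two sub-models' outputs: 2 phrases x 5 candidate nodes.
e_logits = torch.randn(2, 5)
a_logits = torch.randn(2, 5)
weight = torch.tensor([1.0, 1.0])

combined = (F.log_softmax(e_logits, dim=1) * weight[0]
            + F.log_softmax(a_logits, dim=1) * weight[1])

# Boolean node filter: True keeps a node, False effectively removes it.
node_filter_mask = [True, False, True, True, False]
log_node_filter_mask = torch.tensor(
    [0. if keep else -999999. for keep in node_filter_mask])
combined = combined + log_node_filter_mask

top_k = min(3, combined.shape[1])
predictions = torch.topk(combined, top_k, dim=1)[1]  # node indices, best first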
Example #2
    def forward(self, memory_cells, query):
        """Performs sentinel attention with a sentinel of 0. Returns the
        AttentionOutput where the weights do not include the sentinel weight.

        Args:
            memory_cells (Variable[FloatTensor]): batch x num_cells x cell_dim
            query (Variable[FloatTensor]): batch x query_dim

        Returns:
            AttentionOutput: weights do not include sentinel weights
        """
        batch_size, _, cell_dim = memory_cells.values.size()
        sentinel = self._sentinel_embed.expand(
                batch_size, 1, cell_dim)
        sentinel_mask = V(torch.ones(batch_size, 1))

        cell_values_with_sentinel = torch.cat(
            [memory_cells.values, sentinel], 1)
        cell_masks_with_sentinel = torch.cat(
            [memory_cells.mask, sentinel_mask], 1)
        cells_with_sentinel = SequenceBatch(
            cell_values_with_sentinel, cell_masks_with_sentinel,
            left_justify=False)

        attention_output = super(SentinelAttention, self).forward(
                cells_with_sentinel, query)
        # weights_with_sentinel = attention_output.weights

        # TODO: Bring this line in after torch v0.2.0
        # weights_without_sentinel = weights_with_sentinel[batch_size, :-1]
        # attention_output = AttentionOutput(
        #   weights=weights_without_sentinel, context=attention_output.context)
        return attention_output
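The sentinel trick above can be sketched with plain tensors instead of SequenceBatch: a learned sentinel vector is tiled across the batch and appended as one extra, always-unmasked memory cell, so the attention distribution has somewhere to put its mass when no real cell is relevant. Shapes below are made up for illustration.

import torch

batch_size, num_cells, cell_dim = 4, 6, 8        # illustrative sizes
memory_values = torch.randn(batch_size, num_cells, cell_dim)
memory_mask = torch.ones(batch_size, num_cells)

# Learned sentinel vector (a Parameter in the real module), tiled across the batch.
sentinel_embed = torch.zeros(1, 1, cell_dim, requires_grad=True)
sentinel = sentinel_embed.expand(batch_size, 1, cell_dim)
sentinel_mask = torch.ones(batch_size, 1)        # the sentinel is always attendable

values_with_sentinel = torch.cat([memory_values, sentinel], 1)   # batch x (num_cells+1) x cell_dim
mask_with_sentinel = torch.cat([memory_mask, sentinel_mask], 1)  # batch x (num_cells+1)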
Example #3
    def _get_neighbors(self, web_page):
        """Get indices of at most |max_neighbors| neighbors for each relation

        Args:
            web_page (WebPage)
        Returns:
            neighbors: SequenceBatch of shape num_nodes x ???
                containing the neighbor refs
                (??? is at most max_neighbors * len(neighbor_rels))
            rels: SequenceBatch of shape num_nodes x ???
                containing the relation indices
        """
        g = web_page.graph
        batch_neighbors = [[] for _ in range(len(web_page.nodes))]
        batch_rels = [[] for _ in range(len(web_page.nodes))]
        for src, tgts in g.nodes.items():
            # Group by relation
            rel_to_tgts = defaultdict(list)
            for tgt, rels in tgts.items():
                for rel in rels:
                    rel_to_tgts[rel].append(tgt)
            # Sample if needed
            for rel, index in self._neighbor_rels.items():
                tgts = rel_to_tgts[rel]
                random.shuffle(tgts)
                if not tgts:
                    continue
                if len(tgts) > self._max_neighbors:
                    tgts = tgts[:self._max_neighbors]
                batch_neighbors[src].extend(tgts)
                batch_rels[src].extend([index] * len(tgts))
        # Create SequenceBatches
        max_len = max(len(x) for x in batch_neighbors)
        batch_mask = []
        for neighbors, rels in zip(batch_neighbors, batch_rels):
            assert len(neighbors) == len(rels)
            this_len = len(neighbors)
            batch_mask.append([1.] * this_len + [0.] * (max_len - this_len))
            neighbors.extend([0] * (max_len - this_len))
            rels.extend([0] * (max_len - this_len))
        neighbors_batch = SequenceBatch(
            V(torch.tensor(batch_neighbors, dtype=torch.long)),
            V(torch.tensor(batch_mask, dtype=torch.float32)))
        rels_batch = SequenceBatch(
            V(torch.tensor(batch_rels, dtype=torch.long)),
            V(torch.tensor(batch_mask, dtype=torch.float32)))
        return neighbors_batch, rels_batch
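The padding step at the end can be isolated with toy ragged lists in place of the graph traversal: each node's neighbor list is right-padded with a dummy ref 0 and a parallel 0/1 mask records which entries are real.

import torch

# Toy ragged neighbor lists, one per node (stand-in for the graph traversal above).
batch_neighbors = [[1, 2], [0], [], [0, 1, 3]]
max_len = max(len(x) for x in batch_neighbors)

batch_mask = []
for neighbors in batch_neighbors:
    this_len = len(neighbors)
    batch_mask.append([1.] * this_len + [0.] * (max_len - this_len))
    neighbors.extend([0] * (max_len - this_len))    # pad in place with a dummy ref

neighbor_tensor = torch.tensor(batch_neighbors, dtype=torch.long)   # num_nodes x max_len
mask_tensor = torch.tensor(batch_mask, dtype=torch.float32)         # num_nodes x max_len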
Example #4
    def forward(self, nodes):
        """Embeds a batch of Nodes.

        Args:
            nodes (list[Node])
        Returns:
            embeddings (Tensor): num_nodes x embed_dim
        """
        texts = []
        for node in nodes:
            if not self.ablate_text:
                if self._recursive_texts:
                    text = ' '.join(node.all_texts(max_words=self._max_words))
                else:
                    text = node.text or ''
                texts.append(word_tokenize2(text))
            else:
                texts.append([])
        text_embeddings = self._utterance_embedder(texts)

        # num_nodes x attr_embed_dim
        tags = [node.tag for node in nodes]
        tag_embeddings = self._tag_embedder.embed_tokens(tags)
        # num_nodes x attr_embed_dim
        if not self.ablate_attrs:
            ids = [word_tokenize2(node.id_) for node in nodes]
        else:
            ids = [[] for node in nodes]
        id_embeddings = self._id_embedder(ids)
        # num_nodes x attr_embed_dim
        if not self.ablate_attrs:
            classes = [
                word_tokenize2(' '.join(node.classes)) for node in nodes
            ]
        else:
            classes = [[] for node in nodes]
        class_embeddings = self._classes_embedder(classes)

        if not self.ablate_attrs:
            other = [
                word_tokenize2(semantic_attrs(node.attributes))
                for node in nodes
            ]
        else:
            other = [[] for node in nodes]
        other_embeddings = self._other_embedder(other)
        # num_nodes x 3
        coords = V(
            FT([[node.x_ratio, node.y_ratio,
                 float(node.visible)] for node in nodes]))

        # num_nodes x dom_embed_dim
        dom_embeddings = torch.cat(
            (text_embeddings, tag_embeddings, id_embeddings, class_embeddings,
             other_embeddings, coords),
            dim=1)
        #dom_embeddings = text_embeddings
        return self.fc(dom_embeddings)
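A stripped-down sketch of the concatenate-then-project pattern used at the end, with random feature embeddings and made-up dimensions in place of the real embedders.

import torch
import torch.nn as nn

num_nodes = 5                                   # illustrative sizes throughout
text_embeddings = torch.randn(num_nodes, 32)
tag_embeddings = torch.randn(num_nodes, 8)
coords = torch.randn(num_nodes, 3)

dom_embeddings = torch.cat((text_embeddings, tag_embeddings, coords), dim=1)  # num_nodes x 43
fc = nn.Linear(43, 16)                          # stand-in for the module's self.fc
node_embeddings = fc(dom_embeddings)            # num_nodes x 16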
Example #5
    def __init__(self, encoding_model, alignment_model, node_filter, top_k=5):
        super(EnsembleModel, self).__init__()
        self._encoding_model = encoding_model
        self._alignment_model = alignment_model
        self._weight = V(torch.tensor([1.0, 1.0], dtype=torch.float32))

        self.node_filter = node_filter
        self.loss = nn.CrossEntropyLoss(reduction="none")
        self.top_k = top_k
Example #6
    def __init__(self, encoding_model, alignment_model, node_filter, top_k=5):
        super(EnsembleModel, self).__init__()
        self._encoding_model = encoding_model
        self._alignment_model = alignment_model
        self._weight = V(FT([1.0, 1.0]))

        self.node_filter = node_filter
        self.loss = nn.CrossEntropyLoss(reduction="none")
        self.top_k = top_k
Example #7
    def from_sequences(cls, sequences, vocab_or_vocabs, min_seq_length=0):
        """Convert a batch of sequences into a SequenceBatch.

        Args:
            sequences (list[list[unicode]])
            vocab_or_vocabs (WordVocab|list[WordVocab]): either a single vocab, or a list of vocabs, one per sequence
            min_seq_length (int): enforce that the Tensor representing the SequenceBatch have at least
                this many columns.

        Returns:
            SequenceBatch
        """
        # determine dimensions
        batch_size = len(sequences)
        if batch_size == 0:
            seq_length = 0
        else:
            seq_length = max(len(seq)
                             for seq in sequences)  # max seq length in batch
        seq_length = max(
            seq_length,
            min_seq_length)  # make sure it is at least min_seq_length
        shape = (batch_size, seq_length)

        # set up vocabs
        if isinstance(vocab_or_vocabs, list):
            vocabs = vocab_or_vocabs
            assert len(vocabs) == batch_size
        else:
            # duplicate a single vocab
            assert isinstance(vocab_or_vocabs, Vocab)
            vocabs = [vocab_or_vocabs] * batch_size

        # build arrays
        values = np.zeros(shape, dtype=np.int64)  # pad with zeros
        mask = np.zeros(shape, dtype=np.float32)
        for i, (seq, vocab) in enumerate(zip(sequences, vocabs)):
            for j, word in enumerate(seq):
                values[i, j] = vocab.word2index(word)
                mask[i, j] = 1.0

        return SequenceBatch(V(torch.from_numpy(values)),
                             V(torch.from_numpy(mask)))
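A self-contained sketch of the array-building loop, using a toy dict (toy_vocab, made up here) in place of the project's WordVocab; index 0 doubles as padding, and the mask marks which positions hold real tokens.

import numpy as np
import torch

# toy_vocab stands in for WordVocab; index 0 doubles as the padding index.
toy_vocab = {'<pad>': 0, 'click': 1, 'the': 2, 'button': 3}
sequences = [['click', 'the', 'button'], ['click']]

batch_size = len(sequences)
seq_length = max(len(seq) for seq in sequences)

values = np.zeros((batch_size, seq_length), dtype=np.int64)   # pad with zeros
mask = np.zeros((batch_size, seq_length), dtype=np.float32)
for i, seq in enumerate(sequences):
    for j, word in enumerate(seq):
        values[i, j] = toy_vocab[word]
        mask[i, j] = 1.0

values_tensor = torch.from_numpy(values)   # batch_size x seq_length, long
mask_tensor = torch.from_numpy(mask)       # batch_size x seq_length, float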
Example #8
    def forward(self, old_embeds, neighbors, rels):
        batch_size = len(old_embeds)
        projected = F.relu(self._proj(self._dropout(old_embeds)))
        neighbor_embeds = torch.index_select(projected, 0,
                                             neighbors.values.view(-1))
        neighbor_embeds = neighbor_embeds.view(batch_size,
                                               neighbors.values.shape[1], -1)
        combined = torch.cat((projected.unsqueeze(1), neighbor_embeds), dim=1)
        mask = torch.cat((V(torch.ones(batch_size, 1)), neighbors.mask), dim=1)
        combined = SequenceBatch(combined, mask)
        return SequenceBatch.reduce_max(combined)
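The gather-and-pool pattern above can be sketched with plain tensors: neighbor embeddings are looked up with index_select, reshaped back to num_nodes x max_neighbors x dim, the node's own embedding is prepended, and the neighborhood is max-pooled (the real SequenceBatch.reduce_max additionally respects the mask). All shapes are illustrative.

import torch

num_nodes, embed_dim, max_neighbors = 4, 6, 3
node_embeds = torch.randn(num_nodes, embed_dim)
neighbor_ids = torch.randint(0, num_nodes, (num_nodes, max_neighbors))

# Gather: (num_nodes * max_neighbors) x dim, then reshape per node.
neighbor_embeds = torch.index_select(node_embeds, 0, neighbor_ids.view(-1))
neighbor_embeds = neighbor_embeds.view(num_nodes, max_neighbors, -1)

# Prepend each node's own embedding, then max-pool over the neighborhood.
combined = torch.cat((node_embeds.unsqueeze(1), neighbor_embeds), dim=1)
pooled = combined.max(dim=1)[0]   # num_nodes x embed_dim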
Example #9
def tile_state(h, batch_size):
    """Tile a given hidden state batch_size times.

    Args:
        h (Variable): a single hidden state of shape (hidden_dim,)
        batch_size (int)

    Returns:
        a Variable of shape (batch_size, hidden_dim)
    """
    tiler = V(torch.ones(batch_size, 1))
    return torch.mm(tiler, h.unsqueeze(0))  # (batch_size, hidden_dim)
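The matrix-multiply trick can be checked against expand, which yields the same values as a view rather than a copy:

import torch

hidden_dim, batch_size = 5, 3
h = torch.randn(hidden_dim)

tiler = torch.ones(batch_size, 1)
tiled = torch.mm(tiler, h.unsqueeze(0))   # (batch_size, hidden_dim)

assert torch.allclose(tiled, h.unsqueeze(0).expand(batch_size, hidden_dim))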
Example #10
    def _mask_logits(cls, logits, mask):
        no_cells = cls._no_cells(mask)  # (batch_size, num_cells)
        suppress = V(torch.zeros(*mask.size()))

        # send the logit of non-cells to -infinity
        suppress[mask == 0] = float('-inf')
        # but if an entire row has no cells, just leave the cells alone
        suppress[no_cells == 1] = 0.0

        logits = logits + suppress  # -inf + anything = -inf

        return logits
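A sketch of the same masking logic with explicit tensors, using a simple all-zeros-row check in place of the class's _no_cells helper:

import torch

logits = torch.randn(3, 4)
mask = torch.tensor([[1., 1., 0., 0.],
                     [0., 0., 0., 0.],   # a row with no valid cells
                     [1., 1., 1., 1.]])

suppress = torch.zeros_like(mask)
suppress[mask == 0] = float('-inf')                 # masked cells go to -inf
no_cells = (mask.sum(dim=1, keepdim=True) == 0).expand_as(mask)
suppress[no_cells] = 0.0                            # leave empty rows alone to avoid an all-NaN softmax
masked_logits = logits + suppress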
Example #11
    def embed_tokens(self, tokens):
        """Embed list of tokens.

        Args:
            tokens (list[unicode])

        Returns:
            embeds (Variable[FloatTensor]): of shape (len(tokens), embed_dim)
        """
        vocab = self.vocab
        indices = V(
            torch.tensor([vocab.word2index(t) for t in tokens],
                         dtype=torch.long))
        return self._embedding(indices)
Example #12
    def forward(self, nodes):
        """Embeds a batch of Nodes.

        Args:
            nodes (list[Node])
        Returns:
            embeddings (Tensor): num_nodes x embed_dim
        """
        texts = []
        utterance_embedder = self._utterance_embedder
        for node in nodes:
            if self._recursive_texts:
                text = ' '.join(node.all_texts(max_words=self._max_words))
            else:
                text = node.text or ''
            texts.append(utterance_embedder.tokenize(text.lower()))
        text_embeddings = self._utterance_embedder(texts)

        # num_nodes x attr_embed_dim
        tag_embeddings = self._tag_embedder.embed_tokens(
            [node.tag for node in nodes])

        # num_nodes x attr_embed_dim
        id_embedder = self._id_embedder
        id_embeddings = self._id_embedder(
            [id_embedder.tokenize(node.id_) for node in nodes])

        # num_nodes x attr_embed_dim
        classes_embedder = self._classes_embedder
        class_embeddings = self._classes_embedder([
            classes_embedder.tokenize(' '.join(node.classes)) for node in nodes
        ])

        # num_nodes x 3
        coords = V(
            torch.tensor([[elem.x_ratio, elem.y_ratio,
                           float(elem.visible)] for elem in nodes],
                         dtype=torch.float32))

        # num_nodes x dom_embed_dim
        dom_embeddings = torch.cat((text_embeddings, tag_embeddings,
                                    id_embeddings, class_embeddings, coords),
                                   dim=1)
        # dom_embeddings = text_embeddings
        return self.fc(dom_embeddings)
Example #13
    def __init__(self, num_embeddings, embedding_dim, initial_embeddings,
                 **kwargs):
        """Constructs TrainFlagEmbedding with embeddings initialized with
        initial_embeddings.

        Args:
            num_embeddings (int)
            embedding_dim (int)
            initial_embeddings (np.array): (num_embeddings, embedding_dim)
            trainable (bool): if False, weights matrix will not change.
                (default True)
            kwargs: all other supported keywords in torch.nn.Embedding.
        """
        super(TrainFlagEmbedding, self).__init__()
        trainable = kwargs.pop("trainable", True)
        self._trainable = trainable
        if trainable:
            embedding = Embedding(num_embeddings, embedding_dim, **kwargs)
            embedding.weight.data.copy_(torch.from_numpy(initial_embeddings))
            self._embedding = embedding
        else:
            self._fixed_weight = V(torch.from_numpy(initial_embeddings))
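With current PyTorch, the same trainable/frozen switch can be sketched using the stock nn.Embedding.from_pretrained helper, which exposes a freeze argument:

import numpy as np
import torch
import torch.nn as nn

initial_embeddings = np.random.rand(100, 16).astype(np.float32)
weights = torch.from_numpy(initial_embeddings)

trainable = nn.Embedding.from_pretrained(weights, freeze=False)  # weights receive gradients
frozen = nn.Embedding.from_pretrained(weights, freeze=True)      # weights stay fixed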
Example #14
    def reduce_sum(cls, seq_batch):
        weights = V(torch.ones(*seq_batch.mask.size()))
        return cls.weighted_sum(seq_batch, weights)
Example #15
    def forward(self, memory_cells, query):
        batch_size, num_cells = memory_cells.mask.size()
        logits = V(torch.zeros(batch_size, num_cells))
        weights = V(torch.zeros(batch_size, num_cells))
        context = V(torch.zeros(batch_size, self.memory_dim))
        return AttentionOutput(weights=weights, context=context, logits=logits)
Example #16
    def __init__(self, input_dim, hidden_dim):
        super(AdditionCell, self).__init__()
        self.W = V(torch.eye(input_dim, hidden_dim))
        # truncates input if input_dim > hidden_dim
        # pads with zeros if input_dim < hidden_dim
        self.hidden_size = hidden_dim
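The identity-matrix cell can be exercised directly; torch.eye(n, m) builds a rectangular identity, so multiplying by it keeps the first hidden_dim input features (and zero-pads in the opposite case). Sizes below are illustrative.

import torch

W = torch.eye(4, 3)          # input_dim=4, hidden_dim=3: rectangular identity
x = torch.randn(2, 4)
h = x.mm(W)                  # (2, 3): keeps the first 3 input features, drops the rest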
Example #17
    def _mask_weights(cls, weights, mask):
        # if a given row has no memory cells, weights should be all zeros
        no_cells = cls._no_cells(mask)
        all_zeros = V(torch.zeros(*mask.size()))
        weights = conditional(no_cells, all_zeros, weights)
        return weights
Example #18
    def forward(self, web_page, examples, logits_only=False):
        """Compute predictions and loss.

        Args:
            web_page (WebPage): The web page of the examples
            examples (list[PhraseNodeExample]): Must be from the same web page.
            logits_only (bool)
        Returns:
            logits (Tensor): num_phrases x num_nodes
                Each entry (i,j) is the logit for p(node_j | phrase_i)
            losses (Tensor): num_phrases
            predictions (Tensor): num_phrases
        """
        phrase_embedder = self.phrase_embedder

        def max_scorer(pairwise_scores):
            """
            Args:
                pairwise_scores: num_nodes x phrase_len x max_text_len
            """
            scores = torch.max(pairwise_scores, dim=1)[0]
            return torch.max(scores, dim=1)[0]

        def cnn_scorer(pairwise_scores):
            """
            Args:
                pairwise_scores: num_nodes x phrase_len x max_text_len
            """
            scores = torch.unsqueeze(pairwise_scores, dim=1)
            scores = self.conv2d(scores)
            scores = self.conv2d_dilated(scores)
            scores = self.pooler(scores)
            scores = torch.squeeze(scores, dim=1)
            # dim = scores.shape[1]*scores.shape[2]
            scores = scores.view(-1, self.score_dim)
            if self.use_tags:
                tags = [node.tag for node in web_page.nodes]
                tag_embeddings = self._tag_embedder.embed_tokens(tags)
                scores = torch.cat((scores, tag_embeddings), dim=1)
                scores = self.project_tag(scores)
            scores = self.scorer(scores)
            scores = torch.squeeze(scores, dim=1)
            return scores

        def neighbor_cnn_scorer(pairwise_scores):
            """
            Args:
                pairwise_scores: num_nodes x phrase_len x max_text_len
            """
            scores = torch.unsqueeze(pairwise_scores, dim=1)
            scores = self.conv2d(scores)
            scores = self.conv2d_dilated(scores)
            scores = self.pooler(scores)
            scores = torch.squeeze(scores, dim=1)
            # dim = scores.shape[1]*scores.shape[2]
            scores = scores.view(-1, self.score_dim)
            if self.use_tags:
                tags = [node.tag for node in web_page.nodes]
                tag_embeddings = self._tag_embedder.embed_tokens(tags)
                scores = torch.cat((scores, tag_embeddings), dim=1)
                scores = self.project_tag(scores)
            return scores

        # Tokenize the nodes
        # num_nodes x text_length x embed_dim
        texts = []
        for node in web_page.nodes:
            text = ' '.join(node.all_texts(max_words=self.max_words))
            output = []
            if not self.ablate_text:
                output += phrase_embedder.tokenize(text)
            if not self.ablate_attrs:
                # TODO better way to include attributes?
                output += phrase_embedder.tokenize(semantic_attrs(node.attributes))
            texts.append(output)

        embedded_texts = embed_tokens(self.token_embedder, self.max_words, texts)
        embedded_texts_values = self.dropout(embedded_texts.values)

        embedded_texts = embedded_texts_values * embedded_texts.mask.unsqueeze(2)

        # Tokenize the phrases
        # num_phrases x phrase_length x embed_dim
        logits = []

        if not self.use_neighbors:
            for example in examples:
                phrase = [phrase_embedder.tokenize(example.phrase)]
                embedded_phrase = embed_tokens(self.token_embedder, self.max_words, phrase)

                embedded_phrase_values = self.dropout(embedded_phrase.values)

                # expand: num_nodes x phrase_len x embed_dim
                batch_phrase = embedded_phrase_values.expand(len(texts), -1, -1)
                # permute embedded_texts: num_nodes x embed_dim x max_text_len
                pairwise_scores = torch.bmm(batch_phrase, embedded_texts.permute(0, 2, 1))

                # compute scores
                scores = cnn_scorer(pairwise_scores)
                logits.append(torch.unsqueeze(scores, dim=0))
        else:
            intermediate_scores = []
            for example in examples:
                phrase = [phrase_embedder.tokenize(example.phrase)]
                embedded_phrase = embed_tokens(self.token_embedder, self.max_words, phrase)

                embedded_phrase_values = self.dropout(embedded_phrase.values)

                # expand: num_nodes x phrase_len x embed_dim
                batch_phrase = embedded_phrase_values.expand(len(texts), -1, -1)
                # permuted embedded_texts: num_nodes x embed_dim x max_text_len
                pairwise_scores = torch.bmm(batch_phrase, embedded_texts.permute(0, 2, 1))
                node_score = neighbor_cnn_scorer(pairwise_scores)
                intermediate_scores.append(node_score)

            neighbors, masks = web_page.get_spatial_neighbors()
            neighbors, masks = V(torch.tensor(neighbors, dtype=torch.long)), V(torch.tensor(masks, dtype=torch.float32))
            masks = masks.unsqueeze(dim=2)

            # each node_score tensor is parameterized by phrase
            for node_score in intermediate_scores:
                # get pairwise_scores for all neighbors...
                # neighbors, rels = self._get_neighbors(web_page)
                batch_size = len(node_score)
                neighbor_scores = torch.index_select(node_score, 0, neighbors.view(-1))
                neighbor_scores = neighbor_scores.view(batch_size, neighbors.shape[1], -1)
                neighbor_scores = neighbor_scores * masks

                if neighbor_scores.shape[1] < self.num_rels:
                    more = self.num_rels - neighbor_scores.shape[1]
                    num_nodes, _, embed_dim = neighbor_scores.shape
                    padding = V(torch.zeros(num_nodes, more, embed_dim))
                    neighbor_scores = torch.cat((neighbor_scores, padding), dim=1)
                # num_nodes x num_neighbors x intermediate_score_dim

                node_score = torch.unsqueeze(node_score, dim=1)
                scores = torch.cat((node_score, neighbor_scores), dim=1)

                scores = scores.view(node_score.shape[0], -1)
                scores = self._final_neighbor_linear(scores)
                scores = torch.squeeze(scores, dim=1)

                logits.append(torch.unsqueeze(scores, dim=0))

        logits = torch.cat(logits, dim=0)

        # Filter the candidates
        node_filter_mask = self.node_filter(web_page, examples[0].web_page_code)  # what does this do?
        log_node_filter_mask = V(torch.tensor([0. if x else -999999. for x in node_filter_mask], dtype=torch.float32))
        logits = logits + log_node_filter_mask
        if logits_only:
            return logits

        # Losses and predictions
        targets = V(torch.tensor([web_page.xid_to_ref.get(x.target_xid, 0) for x in examples], dtype=torch.long))
        mask = V(torch.tensor([int(x.target_xid in web_page.xid_to_ref and node_filter_mask[web_page.xid_to_ref[x.target_xid]])
                     for x in examples], dtype=torch.float32))
        losses = self.loss(logits, targets) * mask
        # print '=' * 20, examples[0].web_page_code
        # print [node_filter_mask[web_page.xid_to_ref.get(x.target_xid, 0)] for x in examples]
        # print [logits.detach()[i, web_page.xid_to_ref.get(x.target_xid, 0)] for (i, x) in enumerate(examples)]
        # print logits, targets, mask, losses
        if not isfinite(losses.detach().sum()):
            # raise ValueError('Losses has NaN')
            logging.warning('Losses has NaN')
            # print losses
        # num_phrases x top_k
        top_k = min(self.top_k, len(web_page.nodes))
        predictions = torch.topk(logits, top_k, dim=1)[1]
        return logits, losses, predictions
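The phrase-versus-node-text comparison at the heart of this model reduces to an expand plus a batched matrix multiply; a sketch with random embeddings and made-up sizes, ending with the simple max_scorer reduction defined above:

import torch

num_nodes, phrase_len, max_text_len, embed_dim = 7, 4, 10, 16
embedded_phrase = torch.randn(1, phrase_len, embed_dim)        # one phrase
embedded_texts = torch.randn(num_nodes, max_text_len, embed_dim)

# Compare the phrase against every node's text: expand across nodes, then
# batch-multiply to get a token-by-token similarity grid per node.
batch_phrase = embedded_phrase.expand(num_nodes, phrase_len, embed_dim)
pairwise_scores = torch.bmm(batch_phrase, embedded_texts.permute(0, 2, 1))
# pairwise_scores: num_nodes x phrase_len x max_text_len

# max_scorer from above: best-matching token pair per node.
scores = pairwise_scores.max(dim=1)[0].max(dim=1)[0]   # num_nodes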
Example #19
    def forward(self, web_page, examples, logits_only=False):
        """Compute predictions and loss.

        Args:
            web_page (WebPage): The web page of the examples
            examples (list[PhraseNodeExample]): Must be from the same web page.
            logits_only (bool): If True, return only the logits.
        Returns:
            logits (Tensor): num_phrases x num_nodes
                Each entry (i,j) is the logit for p(node_j | phrase_i)
            losses (Tensor): num_phrases
            predictions (Tensor): num_phrases
        """
        # Embed the nodes + normalize
        # num_nodes x dim
        node_embeddings = self.node_embedder(web_page.nodes)
        node_embeddings = node_embeddings / torch.clamp(
            node_embeddings.norm(p=2, dim=1, keepdim=True), min=1e-8)
        # Embed the phrases + normalize
        phrases = []
        for example in examples:
            phrases.append(word_tokenize(example.phrase.lower()))
        # num_phrases x dim
        phrase_embeddings = self.phrase_embedder(phrases)
        if self.proj is not None:
            phrase_embeddings = torch.sigmoid(self.proj(phrase_embeddings))
        phrase_embeddings = phrase_embeddings / torch.clamp(
            phrase_embeddings.norm(p=2, dim=1, keepdim=True), min=1e-8)

        ps = torch.split(phrase_embeddings, 1, dim=0)
        logits = []

        # only loop on phrases
        if not self.use_neighbors:
            for p in ps:
                p = p.expand(node_embeddings.shape[0], -1)
                encoding = torch.cat((p, node_embeddings, p * node_embeddings),
                                     dim=1)
                encoding = self.dropout(encoding)
                scores = self.score(encoding)
                logits.append(scores)
        else:
            neighbors, masks = web_page.get_spatial_neighbors()
            neighbors, masks = V(LT(neighbors)), V(FT(masks))
            masks = masks.unsqueeze(dim=2)

            for p in ps:
                batch_size = node_embeddings.shape[0]
                p = p.expand(batch_size, -1)

                neighbor_scores = torch.index_select(node_embeddings, 0,
                                                     neighbors.view(-1))
                neighbor_scores = neighbor_scores.view(batch_size,
                                                       neighbors.shape[1], -1)
                neighbor_scores = neighbor_scores * masks

                encoding = [p]
                neighbor_scores = torch.split(neighbor_scores, 1, dim=1)
                neighbor_scores = [
                    torch.squeeze(x, dim=1) for x in neighbor_scores
                ]
                for n in [node_embeddings] + neighbor_scores:
                    encoding += [n, p * n]
                encoding = torch.cat(encoding, dim=1)
                encoding = self.dropout(encoding)
                scores = self.score(encoding)
                logits.append(scores)

        logits = torch.cat(logits, dim=1)
        logits = logits.permute(1, 0)

        # Filter the candidates
        node_filter_mask = self.node_filter(web_page,
                                            examples[0].web_page_code)
        log_node_filter_mask = V(
            FT([0. if x else -999999. for x in node_filter_mask]))
        logits = logits + log_node_filter_mask
        if logits_only:
            return logits
        # Losses and predictions
        targets = V(
            LT([web_page.xid_to_ref.get(x.target_xid, 0) for x in examples]))
        mask = V(
            FT([
                int(x.target_xid in web_page.xid_to_ref
                    and node_filter_mask[web_page.xid_to_ref[x.target_xid]])
                for x in examples
            ]))
        losses = self.loss(logits, targets) * mask
        #print '=' * 20, examples[0].web_page_code
        #print [node_filter_mask[web_page.xid_to_ref.get(x.target_xid, 0)] for x in examples]
        #print [logits.data[i, web_page.xid_to_ref.get(x.target_xid, 0)] for (i, x) in enumerate(examples)]
        #print logits, targets, mask, losses
        if not np.isfinite(losses.data.sum().item()):
            #raise ValueError('Losses has NaN')
            logging.warning('Losses has NaN')
            #print losses
        # num_phrases x top_k
        top_k = min(self.top_k, len(web_page.nodes))
        predictions = torch.topk(logits, top_k, dim=1)[1]
        return logits, losses, predictions
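The per-phrase scoring inside the non-neighbor branch can be sketched on its own: the phrase embedding is tiled over the nodes, concatenated with the node embeddings and their element-wise product, and passed through a linear scoring head. Dimensions and the Linear head below are illustrative stand-ins for the model's self.score.

import torch
import torch.nn as nn

num_nodes, dim = 6, 12
node_embeddings = torch.randn(num_nodes, dim)
phrase_embedding = torch.randn(1, dim)

p = phrase_embedding.expand(num_nodes, dim)
encoding = torch.cat((p, node_embeddings, p * node_embeddings), dim=1)  # num_nodes x 3*dim

score = nn.Linear(3 * dim, 1)         # stand-in for the model's self.score head
logits = score(encoding).squeeze(1)   # one logit per node for this phrase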