Example No. 1
 def forward(self, inputs, lengths, hidden=None):
     lens, indices = torch.sort(inputs.data.new(lengths).long(), 0, True)
     inputs = inputs[indices] if self.batch_first else inputs[:, indices] 
     outputs, (h, c) = self.rnn(pack(inputs, lens.tolist(), 
         batch_first=self.batch_first), hidden)
     outputs = unpack(outputs, batch_first=self.batch_first)[0]
     _, _indices = torch.sort(indices, 0)
     outputs = outputs[_indices] if self.batch_first else outputs[:, _indices]
     h, c = h[:, _indices, :], c[:, _indices, :]
     return outputs, (h, c)
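Throughout these examples, pack and unpack are the customary aliases for torch.nn.utils.rnn.pack_padded_sequence and torch.nn.utils.rnn.pad_packed_sequence. Below is a minimal, self-contained sketch of the same sort / pack / run RNN / unpack / unsort idiom used above; the layer sizes and data are illustrative assumptions, not taken from the example.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

rnn = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)  # hypothetical sizes
inputs = torch.randn(4, 10, 8)                # (batch, max_len, features)
lengths = torch.tensor([10, 7, 5, 3])         # true length of each sequence

lens, indices = torch.sort(lengths, descending=True)  # pack() expects descending lengths
packed = pack(inputs[indices], lens.tolist(), batch_first=True)
outputs, (h, c) = rnn(packed)
outputs, _ = unpack(outputs, batch_first=True)         # back to a padded tensor
_, rev = torch.sort(indices)                           # permutation that undoes the sort
outputs, h, c = outputs[rev], h[:, rev], c[:, rev]     # restore the original batch order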
Example No. 2
 def forward(self, enc_input, hidden=None):
     if isinstance(enc_input, tuple):
         # Lengths data is wrapped inside a Variable.
         lengths = enc_input[1].data.view(-1).tolist()
         emb = pack(self.embedding(enc_input[0]), lengths)
     else:
         emb = self.embedding(enc_input)
     outputs, hidden_t = self.rnn(emb, hidden)
     if isinstance(enc_input, tuple):
         outputs = unpack(outputs)[0]
     return outputs, hidden_t
Example No. 3
    def forward(self, src, lengths=None):
        """See :func:`onmt.encoders.encoder.EncoderBase.forward()`"""
        batch_size, _, nfft, t = src.size()
        src = src.transpose(0, 1).transpose(0, 3).contiguous() \
                 .view(t, batch_size, nfft)
        orig_lengths = lengths
        lengths = lengths.view(-1).tolist()

        for l in range(self.enc_layers):
            rnn = getattr(self, 'rnn_%d' % l)
            pool = getattr(self, 'pool_%d' % l)
            batchnorm = getattr(self, 'batchnorm_%d' % l)
            stride = self.enc_pooling[l]
            packed_emb = pack(src, lengths)
            memory_bank, tmp = rnn(packed_emb)
            memory_bank = unpack(memory_bank)[0]
            t, _, _ = memory_bank.size()
            memory_bank = memory_bank.transpose(0, 2)
            memory_bank = pool(memory_bank)
            lengths = [int(math.floor((length - stride) / stride + 1))
                       for length in lengths]
            memory_bank = memory_bank.transpose(0, 2)
            src = memory_bank
            t, _, num_feat = src.size()
            src = batchnorm(src.contiguous().view(-1, num_feat))
            src = src.view(t, -1, num_feat)
            if self.dropout and l + 1 != self.enc_layers:
                src = self.dropout(src)

        memory_bank = memory_bank.contiguous().view(-1, memory_bank.size(2))
        memory_bank = self.W(memory_bank).view(-1, batch_size,
                                               self.dec_rnn_size)

        state = memory_bank.new_full((self.dec_layers * self.num_directions,
                                      batch_size, self.dec_rnn_size_real), 0)
        if self.rnn_type == 'LSTM':
            # The encoder hidden is  (layers*directions) x batch x dim.
            encoder_final = (state, state)
        else:
            encoder_final = state
        return encoder_final, memory_bank, orig_lengths.new_tensor(lengths)
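A note on the length bookkeeping in Example No. 3: with the pooling kernel equal to the stride, the update floor((length - stride) / stride + 1) reduces to floor(length / stride); for example, a stride of 2 maps a length of 17 to floor((17 - 2) / 2 + 1) = 8.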
Example No. 4
    def forward(self, src, lengths=None, encoder_state=None):
        "See :obj:`EncoderBase.forward()`"
        self._check_args(src, lengths, encoder_state)

        emb = self.embeddings(src)
        s_len, batch, emb_dim = emb.size()

        packed_emb = emb
        if lengths is not None and not self.no_pack_padded_seq:
            # Lengths data is wrapped inside a Variable.
            lengths = lengths.view(-1).tolist()
            packed_emb = pack(emb, lengths)

        memory_bank, encoder_final = self.rnn(packed_emb, encoder_state)

        if lengths is not None and not self.no_pack_padded_seq:
            memory_bank = unpack(memory_bank)[0]

        if self.use_bridge:
            encoder_final = self._bridge(encoder_final)
        return encoder_final, memory_bank
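Example No. 4 is the time-major variant: without batch_first, the embeddings are shaped (seq_len, batch, dim) and the lengths tensor is flattened to a Python list before packing. A compact sketch of just that path, with assumed layer sizes:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

rnn = nn.GRU(input_size=8, hidden_size=16)    # assumed sizes, bidirectionality omitted
emb = torch.randn(10, 4, 8)                   # (max_len, batch, features)
lengths = torch.tensor([10, 8, 6, 2])         # must be descending when enforce_sorted=True

memory_bank, encoder_final = rnn(pack(emb, lengths.view(-1).tolist()))
memory_bank = unpack(memory_bank)[0]          # (max_len, batch, hidden)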
Example No. 5
    def forward(self, docs_input, titles_input, keywords=None, topics=None):
        docs_len, titles_len = get_seq_lenth(docs_input), get_seq_lenth(
            titles_input)
        docs_mask, titles_mask = creat_mask(docs_len), creat_mask(titles_len)

        s_docs, s_docs_len, reverse_docs_idx = sort_batch(docs_input, docs_len)
        s_titles, s_titles_len, reverse_titles_idx = sort_batch(
            titles_input, titles_len)

        # embedding
        docs_embedding = pack(self.embedding(s_docs),
                              list(s_docs_len.data),
                              batch_first=True)
        titles_embedding = pack(self.embedding(s_titles),
                                list(s_titles_len.data),
                                batch_first=True)

        # GRU encoder
        docs_outputs, _ = self.gru(docs_embedding, None)
        titles_outputs, _ = self.gru(titles_embedding, None)

        # unpack
        docs_outputs, _ = unpack(docs_outputs, batch_first=True)
        titles_outputs, _ = unpack(titles_outputs, batch_first=True)

        # unsort
        docs_outputs = docs_outputs[reverse_docs_idx]
        titles_outputs = titles_outputs[reverse_titles_idx]

        # calculate attention matrix
        dos = docs_outputs
        doc_mask = docs_mask.unsqueeze(2)
        tos = torch.transpose(titles_outputs, 1, 2)
        title_mask = titles_mask.unsqueeze(2)

        M = torch.bmm(dos, tos)
        M_mask = torch.bmm(doc_mask, title_mask.transpose(1, 2))
        alpha = softmax_mask(M, M_mask, axis=1)
        beta = softmax_mask(M, M_mask, axis=2)

        sum_beta = torch.sum(beta, dim=1, keepdim=True)
        docs_len = docs_len.unsqueeze(1).unsqueeze(2).expand_as(sum_beta)
        average_beta = sum_beta / docs_len.float()
        # doc-level attention
        s = torch.bmm(alpha, average_beta.transpose(1, 2))

        # predict keywords
        kws_probs = None
        if keywords is not None:
            kws_probs = []
            for i, kws in enumerate(keywords):
                document = docs_input[i].squeeze()
                cur_prob = 1.
                for j, kw in enumerate(kws):
                    if kw.data[0] == Constants.PAD: continue
                    kw = kws[j].squeeze()
                    pointer = document == kw.expand_as(document)
                    cur_prob *= torch.sum(
                        torch.masked_select(s[i].squeeze(), pointer))
                kws_probs.append(cur_prob + 1e-10)
            kws_probs = torch.cat(kws_probs, 0).squeeze()

        # predict prob of topic
        fc_feature = torch.sum(docs_outputs * s, dim=1)
        topic_probs = self.fc(fc_feature)

        return topic_probs, kws_probs, s
Example No. 6
    def forward(self, x1, x2, cate, seq_lens, _, attn_mode=False):
        batch_size = x1.size(0)
        max_seq_length = x1.size(1)
        embed_size = x1.size(2)

        outputs = torch.zeros(
            [max_seq_length, batch_size, embed_size],
            device=torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu'))

        #
        cate = pack(cate, seq_lens, batch_first=True).data

        # x2: [batch_size, seq_len, num_of_temporal, embed_dim]
        x2 = pack(x2, seq_lens, batch_first=True).data

        mha_input = torch.transpose(x2, 0, 1)
        _, x2_score = self._mha(mha_input, mha_input, mha_input)
        x2_score = torch.softmax(torch.mean(x2_score, 2, keepdim=False), dim=1)
        x2_score = torch.unsqueeze(x2_score, dim=1)

        x2 = torch.squeeze(torch.bmm(x2_score, x2), dim=1)

        #x2 = torch.mean(x2, 0, keepdim=False)
        #x2 = torch.mean(x2.data, 1, keepdim=False)
        x2 = self._mlp_mha(x2)

        # sequence embedding
        x1 = pack(x1, seq_lens, batch_first=True)
        x1, _ = self.rnn(x1)

        sequence_lenths = x1.batch_sizes.cpu().numpy()
        cursor = 0
        prev_x1s = []

        if attn_mode:
            attn_score_save = torch.zeros(
                [max_seq_length, batch_size, max_seq_length],
                device=torch.device(
                    'cuda' if torch.cuda.is_available() else 'cpu'))

        for step in range(sequence_lenths.shape[0]):
            sequence_lenth = sequence_lenths[step]

            x1_step = x1.data[cursor:cursor + sequence_lenth]
            x2_step = x2[cursor:cursor + sequence_lenth]

            prev_x1s.append(x1_step)

            prev_x1s = [prev_x1[:sequence_lenth] for prev_x1 in prev_x1s]

            prev_hs = torch.stack(prev_x1s, dim=1)
            attn_score = []
            for prev in range(prev_hs.size(1)):
                attn_inter = torch.sum(prev_hs[:, prev, :] * x2_step,
                                       dim=1,
                                       keepdim=True)
                attn_score.append(attn_inter + self._b_attn)
            attn_score = torch.softmax(torch.stack(attn_score, dim=1), dim=1)

            if attn_mode:
                attn_score_save[step, :sequence_lenth, :attn_score.
                                shape[1]] = torch.squeeze(attn_score, dim=2)

            x1_step = torch.squeeze(torch.bmm(
                torch.transpose(attn_score, 1, 2), prev_hs),
                                    dim=1)

            x_step = torch.cat((x1_step, x2_step), dim=1)
            x_step = self.mlp(x_step)

            outputs[step][:sequence_lenth] = x_step

            cursor += sequence_lenth

        if attn_mode:
            prev_cates = []
            attn_good_data = []

            attn_score_save = torch.transpose(attn_score_save, 0, 1)
            batch_total = 0
            cursor = 0
            for step in range(sequence_lenths.shape[0]):
                sequence_lenth = sequence_lenths[step]

                cate_step = cate[cursor:cursor + sequence_lenth]
                cursor += sequence_lenth

                prev_cates.append(cate_step)
                prev_cates = [prev[:sequence_lenth] for prev in prev_cates]

                batch_total += 1

                min_len = 14
                min_vari_cate_prev = 6
                min_vari_cate_pred = 4
                target_len = 10

                #                min_len = 5
                #                min_vari_cate_prev = 0
                #                min_vari_cate_pred = 0
                #                target_len = 2

                if step < min_len:
                    continue

                prev_cates_t = torch.stack(prev_cates, dim=0)
                prev_cates_t = torch.argmax(torch.transpose(
                    prev_cates_t, 0, 1),
                                            dim=2)

                for batch in range(prev_cates_t.shape[0]):
                    #attn_score_c = attn_score_save[batch, step-1, :step].cpu().numpy().tolist()
                    prev_cates_c = prev_cates_t[
                        batch, :].cpu().numpy().tolist()

                    prev_cate_set = set([])
                    pred_cate_set = set([])

                    for i, c in enumerate(prev_cates_c):
                        if i < target_len:
                            prev_cate_set.update([c])
                        else:
                            pred_cate_set.update([c])

                    if len(prev_cate_set) < min_vari_cate_prev or len(
                            pred_cate_set) < min_vari_cate_pred:
                        continue

                    ok = 0
                    for c in pred_cate_set:
                        if c not in prev_cate_set:
                            continue
                        ok += 1

                    if ok < 4:
                        continue

                    attn_good_data.append(
                        ['==================================='])
                    attn_good_data.append(prev_cates_c[:target_len])
                    attn_good_data.append(prev_cates_c[target_len:])
                    for s in range(target_len, step + 1):
                        attn_score_c = attn_score_save[
                            batch, s - 1, :s].cpu().numpy().tolist()
                        attn_good_data.append(attn_score_c[:target_len])

        outputs = torch.transpose(outputs, 0, 1)
        #outputs = self._dropout(outputs)

        if attn_mode:
            with open('attn_data_cate.txt', 'a') as f_att:
                attn_data_str = ''
                for data_line in attn_good_data:
                    data_line = [str(d) for d in data_line]
                    attn_data_str += ','.join(data_line) + '\n'
                f_att.write(attn_data_str)

        return outputs
Example No. 7
    def forward(self, x, src_len=None, ivec_feat=None):
        """
        x : (batch x seq x ndim) 
        mask : (batch x seq)
        """
        batchsize, seqlen, ndim = x.size()

        if src_len is None:
            src_len = [seqlen] * batchsize

        ### FNN ###
        # apply augmentation first layer #
        # TODO : dynamic layer #
        assert self.ivec_dim >= ivec_feat.size(1)
        ivec_feat = ivec_feat[:, 0:self.ivec_dim]
        if self.ivec_cfg['type'] == 'concat':
            # calculate h from ivec #
            res_ivec = self.aug_layer(ivec_feat)
            res_ivec = res_ivec.unsqueeze(1).expand(batchsize, seqlen,
                                                    res_ivec.size(1))
            res_ivec = res_ivec.contiguous().view(seqlen * batchsize, -1)

            res_main = x.contiguous().view(seqlen * batchsize, ndim)
            res_main = self.fnn[0](res_main)

            res = res_main + res_ivec
        else:
            _res_list = []
            for ii in range(batchsize):
                _aug_param = self.fn_gen_params(ivec_feat[ii:ii + 1])
                _main_aug_param = self.fn_aug_params(self.fnn[0].weight,
                                                     _aug_param)
                _main_bias = self.fnn[0].bias
                res_ii = F.linear(x[ii], _main_aug_param, _main_bias)
                _res_list.append(res_ii)
            res = torch.stack(_res_list)
            pass
        res = getattr(F, self.fnn_act)(res)
        res = F.dropout(res, self.do_fnn[0], training=self.training)
        prev_size = res.size(1)

        for ii in range(1, len(self.fnn_sizes)):
            res = getattr(F, self.fnn_act)(self.fnn[ii](res))
            res = F.dropout(res, self.do_fnn[ii], training=self.training)

        ### RNN ###
        # convert shape for RNN #
        res = res.view(batchsize, seqlen, -1)
        for ii in range(len(self.rnn_sizes)):
            # AVOID pack, slow !!! #
            if self.use_pack:
                res = pack(res, src_len, batch_first=True)
                res = self.rnn[ii](res)[0]  # get h only #
                res, _ = unpack(res, batch_first=True)
            else:
                res = self.rnn[ii](res)[0]  # get h only #

            if self.downsampling[ii] == True:
                res = res[:, 1::2]
                src_len = [x // 2 for x in src_len]
                pass

            res = F.dropout(res, self.do_rnn[ii], training=self.training)

        ### PRE SOFTMAX ###
        batchsize, seqlen_final, ndim_final = res.size()
        res = res.view(seqlen_final * batchsize, ndim_final)

        res = self.pre_softmax(res)
        res = res.view(batchsize, seqlen_final, -1)
        res = res.transpose(1, 0)
        return res, Variable(torch.IntTensor(src_len))
Example No. 8
    def forward(self,
                confnets,
                scores,
                src_lengths=None,
                par_arc_lengths=None):
        """
        Based on the paper NEURAL CONFNET CLASSIFICATION (http://150.162.46.34:8080/icassp2018/ICASSP18_USB/pdfs/0006039.pdf)
        """
        #self._check_args(confnets, src_lengths)

        #print('confnet size', confnets.size())

        emb = self.embeddings(confnets.permute(
            1, 0, 2, 3))  #(slen, batch, max_par_arc, emb_dim)
        emb_trans = emb.squeeze(
            3)  #.permute(1,0,2,3) #(batch, slen, max_par_arc, emb_dim)
        output_list = torch.tensor([]).cuda()
        #### FEATS NOT SUPPORTED ###
        confnets_ = confnets.squeeze(-1).permute(
            1, 0, 2)  # (max_sent_len, batch, max_par_arc_len, emb_dim)
        scores_ = scores.squeeze(-1).permute(
            1, 0, 2)  # (max_sent_len, batch, max_par_arc_len)
        par_arc_lengths_ = par_arc_lengths.permute(1,
                                                   0)  # (max_sent_lens, batch)
        for em, score, lengths in zip(emb_trans, scores_, par_arc_lengths_):
            # word embedding
            # s_len, batch, emb_dim = emb.size()
            # output = self.dropout(output)
            sc = score.unsqueeze(-1).expand(
                em.size())  #(batch, max_par_arc, emb_sz)

            # confnet score weighted word embedding
            q = em.float() * sc.float()  #(batch, max_par_arc, emb_sz)
            batch_size, max_par_arcs, emb_sz = q.size()
            v = torch.tanh(self.thetav(q))  #(batch, max_par_arc, emb_sz)
            v_bar = self.v_bar(v).squeeze(-1)
            #### masking: Mask the padding ####
            mask = torch.arange(max_par_arcs)[None, :].to(
                "cuda") < lengths[:, None].to("cuda").type(torch.float)
            mask = mask.type(torch.float)
            masked_v_bar = torch.where(
                mask == False,
                torch.tensor([float("-inf") - 1e-10], device=q.device), v_bar)
            attention = torch.softmax(masked_v_bar, dim=1)
            final_attention = attention.masked_fill(torch.isnan(attention), 0)
            # apply attention weights
            output = q * final_attention.unsqueeze(-1).expand(q.size())

            # most attented words
            most_attentive_arc = torch.argmax(final_attention, dim=1)
            # highest attention weights
            most_attentive_arc_weights, _ = torch.max(final_attention, dim=1)
            # a = output

            a = torch.sum(output, dim=1)
            output_list = torch.cat((output_list, a.unsqueeze(0)), dim=0)
        # a = self.dropout(a)

        output_confnet = output_list.permute(
            1, 0, 2)  # (batch, max_sent_len, hid_dim)
        #return a, output_confnet, src_lengths #most_attentive_arc, attention, most_attentive_arc_weights  # output, h_output

        # Fall back to the unpacked tensor when no lengths are given.
        packed_emb = output_confnet.permute(1, 0, 2)
        if src_lengths is not None and not self.no_pack_padded_seq:
            # Lengths data is wrapped inside a Tensor.
            lengths_list = src_lengths.view(-1).tolist()
            packed_emb = pack(packed_emb, lengths_list)

        memory_bank, encoder_final = self.rnn(packed_emb)

        if src_lengths is not None and not self.no_pack_padded_seq:
            memory_bank = unpack(memory_bank)[0]

        if self.use_bridge:
            encoder_final = self._bridge(encoder_final)
        return encoder_final, memory_bank, src_lengths
Example No. 9
    def beam_search(self, src: Dict[str,
                                    torch.Tensor], key: Dict[str,
                                                             torch.Tensor],
                    lpos: Dict[str, torch.Tensor], rpos: Dict[str,
                                                              torch.Tensor],
                    max_decoding_step: int) -> Dict[str, torch.Tensor]:

        beam_size = self.beam_size
        src = src['tokens']

        if self.use_feature:
            keys, lpos, rpos = key['keys'], lpos['lpos'], rpos['rpos']
        lengths = self._get_lengths(src)
        batch_size = src.size(0)
        lengths, indices = lengths.sort(dim=0, descending=True)
        rev_indices = indices.sort()[1]
        src = src.index_select(dim=0, index=indices)

        if self.use_feature:
            keys = keys.index_select(dim=0, index=indices)
            lpos = lpos.index_select(dim=0, index=indices)
            rpos = rpos.index_select(dim=0, index=indices)

            src_embs = torch.cat([
                self.src_embedding(src),
                self.key_embedding(keys),
                self.lpos_embedding(lpos),
                self.rpos_embedding(rpos)
            ],
                                 dim=-1)
        else:
            src_embs = self.src_embedding(src)

        src_embs = pack(src_embs, lengths, batch_first=True)
        encode_outputs = self.encoder(src_embs)
        contexts, encState = encode_outputs['hidden_outputs'], encode_outputs[
            'final_state']

        contexts = contexts.repeat(beam_size, 1, 1)
        decState = encState[0].repeat(1, beam_size,
                                      1), encState[1].repeat(1, beam_size, 1)
        beam = [
            modules.beam.Beam(beam_size,
                              bos=self._bos,
                              eos=self._eos,
                              n_best=1,
                              minimum_length=self.minimum_length)
            for _ in range(batch_size)
        ]

        for i in range(max_decoding_step):

            if all((b.done() for b in beam)):
                break

            inp = torch.stack([b.getCurrentState()
                               for b in beam]).t().contiguous().view(-1)

            outputs = self.decoder.decode_step(self.tgt_embedding(inp),
                                               decState, contexts)
            output, decState, attn = outputs['hidden_output'], outputs[
                'state'], outputs['attention_weights']
            logits = self.generator(output)

            output = torch.nn.functional.log_softmax(logits, dim=-1).view(
                beam_size, batch_size, -1)
            attn = attn.view(beam_size, batch_size, -1)

            for j, b in enumerate(beam):
                b.advance(output.data[:, j], attn.data[:, j])
                b.beam_update(decState, j)

        allHyps, allScores, allAttn = [], [], []

        for j in rev_indices:
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.getHyp(times, k)
                hyps.append(hyp)
                attn.append(att.max(1)[1])
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])

        outputs = {'output_ids': allHyps, 'alignments': allAttn}
        return outputs
Example No. 10
    def greedy_search(self, src: Dict[str,
                                      torch.Tensor], key: Dict[str,
                                                               torch.Tensor],
                      lpos: Dict[str, torch.Tensor], rpos: Dict[str,
                                                                torch.Tensor],
                      max_decoding_step: int) -> Dict[str, torch.Tensor]:

        src = src['tokens']

        if self.use_feature:
            keys, lpos, rpos = key['keys'], lpos['lpos'], rpos['rpos']

        lengths = self._get_lengths(src)
        lengths, indices = lengths.sort(dim=0, descending=True)
        rev_indices = indices.sort()[1]
        src = src.index_select(dim=0, index=indices)
        bos = torch.ones(src.size(0)).long().fill_(self._bos).cuda()

        if self.use_feature:
            keys = keys.index_select(dim=0, index=indices)
            lpos = lpos.index_select(dim=0, index=indices)
            rpos = rpos.index_select(dim=0, index=indices)

            src_embs = torch.cat([
                self.src_embedding(src),
                self.key_embedding(keys),
                self.lpos_embedding(lpos),
                self.rpos_embedding(rpos)
            ],
                                 dim=-1)
        else:
            src_embs = self.src_embedding(src)

        src_embs = pack(src_embs, lengths, batch_first=True)
        encode_outputs = self.encoder(src_embs)

        inputs, state, contexts = [
            bos
        ], encode_outputs['final_state'], encode_outputs['hidden_outputs']
        output_ids, attention_weights = [], []

        for i in range(max_decoding_step):
            outputs = self.decoder.decode_step(self.tgt_embedding(inputs[i]),
                                               state, contexts)
            hidden_output, state, attn_weight = outputs[
                'hidden_output'], outputs['state'], outputs[
                    'attention_weights']
            logits = self.generator(hidden_output)
            next_id = logits.max(1)[1]
            inputs += [next_id]
            output_ids += [next_id]
            attention_weights += [attn_weight]

        output_ids = torch.stack(output_ids, dim=1)
        attention_weights = torch.stack(attention_weights, dim=1)

        alignments = attention_weights.max(2)[1]
        output_ids = output_ids.index_select(dim=0, index=rev_indices)
        alignments = alignments.index_select(dim=0, index=rev_indices)
        outputs = {
            'output_ids': output_ids.tolist(),
            'alignments': alignments.tolist()
        }

        return outputs
Example No. 11
    def forward(self, label_seqs, location_seqs, lengths):
        # sort label sequences and location sequences in batch dimension according to length
        batch_idx = sorted(range(len(lengths)),
                           key=lambda k: lengths[k],
                           reverse=True)
        reverse_batch_idx = torch.LongTensor(
            [batch_idx.index(i) for i in range(len(batch_idx))])

        lens_sorted = sorted(lengths, reverse=True)
        label_seqs_sorted = torch.index_select(label_seqs, 0,
                                               torch.LongTensor(batch_idx))
        location_seqs_sorted = torch.index_select(location_seqs, 0,
                                                  torch.LongTensor(batch_idx))

        # assert torch.equal(torch.index_select(label_seqs_sorted, 0, reverse_batch_idx), label_seqs)
        # assert torch.equal(torch.index_select(location_seqs_sorted, 0, reverse_batch_idx), location_seqs)

        if torch.cuda.is_available():
            reverse_batch_idx = reverse_batch_idx.cuda()
            label_seqs_sorted = label_seqs_sorted.cuda()
            location_seqs_sorted = location_seqs_sorted.cuda()

        # create Variables
        label_seqs_sorted_var = Variable(label_seqs_sorted,
                                         requires_grad=False)
        location_seqs_sorted_var = Variable(location_seqs_sorted,
                                            requires_grad=False)

        # encode label sequences
        label_encoding = self.label_encoder(label_seqs_sorted_var)

        # encode location sequences
        location_seqs_sorted_var = location_seqs_sorted_var.view(-1, 4)
        location_encoding = self.location_encoder(location_seqs_sorted_var)
        location_encoding = location_encoding.view(label_encoding.size(0), -1,
                                                   location_encoding.size(1))

        # layout encoding - batch_size x max_seq_len x embed_size
        layout_encoding = label_encoding + location_encoding
        packed = pack(layout_encoding, lens_sorted, batch_first=True)
        hiddens, _ = self.lstm(packed)

        # unpack hiddens and get last hidden vector
        hiddens_unpack = unpack(
            hiddens,
            batch_first=True)[0]  # batch_size x max_seq_len x embed_size
        last_hidden_idx = torch.zeros(hiddens_unpack.size(0), 1,
                                      hiddens_unpack.size(2)).long()
        for i in range(hiddens_unpack.size(0)):
            last_hidden_idx[i, 0, :] = lens_sorted[i] - 1
        if torch.cuda.is_available():
            last_hidden_idx = last_hidden_idx.cuda()
        last_hidden = torch.gather(
            hiddens_unpack, 1,
            Variable(last_hidden_idx,
                     requires_grad=False))  # batch_size x 1 x embed_size
        last_hidden = torch.squeeze(last_hidden, 1)  # batch_size x embed_size

        # convert back to original batch order
        last_hidden = torch.index_select(
            last_hidden, 0, Variable(reverse_batch_idx, requires_grad=False))

        return last_hidden
Example No. 12
    def forward(self, x1, x2, __, seq_lens, ___, ____):
        """
            forwarding function of the model called by the pytorch
            :x: input tensor
            :x2: input tensor for global temporal preferences
            :seq_lens: list cataining the length of each seqeunce
            :return: output tensor
        """
        batch_size = x1.size(0)
        max_seq_length = x1.size(1)
        embed_size = x1.size(2)

        outputs = torch.zeros(
            [max_seq_length, batch_size, embed_size],
            device=torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu'))

        # x2: [batch_size, seq_len, num_of_temporal, embed_dim]
        x2 = pack(x2, seq_lens, batch_first=True).data

        mha_input = torch.transpose(x2, 0, 1)
        _, x2_score = self._mha(mha_input, mha_input, mha_input)
        x2_score = torch.softmax(torch.mean(x2_score, 2, keepdim=False), dim=1)
        x2_score = torch.unsqueeze(x2_score, dim=1)

        x2 = torch.squeeze(torch.bmm(x2_score, x2), dim=1)

        #x2 = torch.mean(x2, 0, keepdim=False)
        #x2 = torch.mean(x2.data, 1, keepdim=False)
        x2 = self._mlp_mha(x2)

        # sequence embedding
        x1 = pack(x1, seq_lens, batch_first=True)
        x1, _ = self.rnn(x1)

        sequence_lenths = x1.batch_sizes.cpu().numpy()
        cursor = 0
        prev_x1s = []
        for step in range(sequence_lenths.shape[0]):
            sequence_lenth = sequence_lenths[step]

            x1_step = x1.data[cursor:cursor + sequence_lenth]
            x2_step = x2[cursor:cursor + sequence_lenth]

            prev_x1s.append(x1_step)

            prev_x1s = [prev_x1[:sequence_lenth] for prev_x1 in prev_x1s]

            prev_hs = torch.stack(prev_x1s, dim=1)
            #            attn_score = []
            #            for prev in range(prev_hs.size(1)):
            #                attn_input = torch.cat((prev_hs[:,prev,:], x2_step), dim=1)
            #                attn_score.append(torch.matmul(attn_input, self._W_attn) + self._b_attn)
            #            attn_score = torch.softmax(torch.stack(attn_score, dim=1), dim=1)
            #            x1_step = torch.squeeze(torch.bmm(torch.transpose(attn_score, 1, 2), prev_hs), dim=1)

            x1_step = torch.mean(prev_hs, dim=1, keepdim=False)

            x_step = torch.cat((x1_step, x2_step), dim=1)
            x_step = self.mlp(x_step)

            outputs[step][:sequence_lenth] = x_step

            cursor += sequence_lenth

        outputs = torch.transpose(outputs, 0, 1)
        #outputs = self._dropout(outputs)

        return outputs
Example No. 13
 def forward(self, seq, length):
     emb = self.emb(seq)
     packed = pack(emb, length, batch_first=True, enforce_sorted=False)
     out, (h, c) = self.lstm(packed)
     return out, h, c
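Example No. 13 shows the newer idiom: with enforce_sorted=False (available since PyTorch 1.1), pack_padded_sequence sorts internally and the unpacked output comes back in the original batch order, so the manual sort/unsort bookkeeping of Examples 1, 11, and 14 is unnecessary. A small sketch under assumed vocabulary and layer sizes:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

emb = nn.Embedding(100, 32)                   # assumed vocabulary and embedding sizes
lstm = nn.LSTM(32, 64, batch_first=True)

seq = torch.randint(1, 100, (4, 12))          # (batch, max_len) token ids
length = torch.tensor([12, 9, 4, 11])         # unsorted lengths are fine here
packed = pack(emb(seq), length, batch_first=True, enforce_sorted=False)
out, (h, c) = lstm(packed)
out, out_lens = unpack(out, batch_first=True) # padded output, original batch order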
Example No. 14
    def forward(self, emb, lengths=None, init_states=None):
        "See :obj:`EncoderBase.forward()`"
        self._check_args(emb, lengths)

        packed_emb = emb
        if lengths is not None:
            # Lengths data is wrapped inside a Tensor.
            lengths, indices = torch.sort(lengths, 0,
                                          True)  # Sort by length (keep idx)
            packed_emb = pack(packed_emb[indices],
                              lengths.tolist(),
                              batch_first=True)
            _, _indices = torch.sort(indices, 0)  # Un-sort by length

        istates = []
        if init_states:
            if isinstance(init_states, tuple):
                hidden_states, cell_states = init_states
                hidden_states = hidden_states.split(self.nlayers, dim=0)
                cell_states = cell_states.split(self.nlayers, dim=0)
            else:
                hidden_states = init_states
                hidden_states = hidden_states.split(self.nlayers, dim=0)

            for i in range(self.nlayers):
                if isinstance(init_states, tuple):
                    istates.append((hidden_states[i], cell_states[i]))
                else:
                    istates.append(hidden_states[i])

        memory_bank, encoder_final = [], {'h_n': [], 'c_n': []}
        for i in range(self.nlayers):
            if i != 0:
                packed_emb = self.dropout(packed_emb)
                if lengths is not None:
                    packed_emb = pack(packed_emb,
                                      lengths.tolist(),
                                      batch_first=True)

            if init_states:
                packed_emb, states = self.rnns[i](packed_emb, istates[i])
            else:
                packed_emb, states = self.rnns[i](packed_emb)

            if isinstance(states, tuple):
                h_n, c_n = states
                encoder_final['c_n'].append(c_n)
            else:
                h_n = states
            encoder_final['h_n'].append(h_n)

            packed_emb = unpack(
                packed_emb,
                batch_first=True)[0] if lengths is not None else packed_emb
            if not self.use_last or i == self.nlayers - 1:
                memory_bank += [packed_emb[_indices]
                                ] if lengths is not None else [packed_emb]

        assert len(encoder_final['h_n']) != 0
        if self.use_last:
            memory_bank = memory_bank[-1]
            if len(encoder_final['c_n']) == 0:
                encoder_final = encoder_final['h_n'][-1]
            else:
                encoder_final = encoder_final['h_n'][-1], encoder_final['c_n'][
                    -1]
        else:
            memory_bank = torch.cat(memory_bank, dim=2)
            if len(encoder_final['c_n']) == 0:
                encoder_final = torch.cat(encoder_final['h_n'], dim=0)
            else:
                encoder_final = torch.cat(encoder_final['h_n'], dim=0), \
                                torch.cat(encoder_final['c_n'], dim=0)

        if self.use_bridge:
            encoder_final = self._bridge(encoder_final)

        # TODO: temporary hack to remain compatible with DataParallel
        # reference: https://github.com/pytorch/pytorch/issues/1591
        if memory_bank.size(1) < emb.size(1):
            dummy_tensor = torch.zeros(
                memory_bank.size(0),
                emb.size(1) - memory_bank.size(1),
                memory_bank.size(2)).type_as(memory_bank)
            memory_bank = torch.cat([memory_bank, dummy_tensor], 1)

        return encoder_final, memory_bank
Example No. 15
    def forward(self, src, lengths=None, is_knowledge=False):
        """
        run transformer encoder
        :param src: source input
        :param lengths: sorted lengths
        :return: output and state (if with rnn)
        """
        if self.config.conditioned and not is_knowledge:
            # HACK: recover the original sentence without the condition
            conditions_1 = src[[length - 1 for length in lengths],
                               range(src.shape[1])]
            conditions_2 = src[[length - 2 for length in lengths],
                               range(src.shape[1])]
            src[[length - 1 for length in lengths],
                range(src.shape[1])] = self.padding_idx
            src[[length - 2 for length in lengths],
                range(src.shape[1])] = self.padding_idx
            lengths = [length - 2 for length in lengths]
            assert all([length > 0 for length in lengths])
            # print(conditions.shape) # batch_size
            # print(src.shape) # max_len X batch_size
            conditions_1 = conditions_1.unsqueeze(0)  # 1 X batch_size
            conditions_2 = conditions_2.unsqueeze(0)  # 1 X batch_size
        embed = self.embedding(src)

        if self.config.embed_only:
            return embed

        # RNN for positional information
        if self.config.positional:
            emb = self.position_embedding(embed)  # [len, batch, size]
        else:
            emb, state = self.rnn(pack(embed, lengths))
            emb = unpack(emb)[0]  # [len, batch, 2*size]
            emb = emb[:, :, :self.config.hidden_size] + \
                emb[:, :, self.config.hidden_size:]  # [len, batch, size]
            emb = emb + embed  # [len, batch, size]
            state = (state[0][0], state[1][0])  # LSTM states

        if self.config.conditioned and not is_knowledge:
            assert self.config.positional
            conditions_1_embed = self.embedding(conditions_1)
            conditions_1_embed = conditions_1_embed.expand_as(embed)
            conditions_2_embed = self.embedding(conditions_2)
            conditions_2_embed = conditions_2_embed.expand_as(embed)
            # Concat
            # emb = torch.cat([emb, conditions_embed], dim=-1)
            # emb = self.embed_transform(emb)
            # emb = torch.cat([emb, conditions_1_embed + conditions_2_embed], dim=-1)
            # emb = self.embed_transform(emb)
            # Add
            # emb = emb + conditions_embed
            emb = emb + conditions_1_embed + conditions_2_embed
            # Remove condition
            # emb = emb

        out = emb.transpose(0, 1).contiguous()  # [batch, len, size]
        src_words = src.transpose(0, 1)  # [batch, len]
        src_batch, src_len = src_words.size()
        padding_idx = self.padding_idx
        mask = src_words.data.eq(padding_idx).unsqueeze(1) \
            .expand(src_batch, src_len, src_len)    # [batch, len, len]

        for i in range(self.num_layers):
            out = self.transformer[i](out, mask)
        out = self.layer_norm(out)  # [batch, len, size]

        assert self.config.positional
        if self.config.positional:
            # out = self.condition_context_attn(out, conditions_embed)
            # out = self.bi_attn_control_exp(out)
            return out.transpose(0, 1)
        else:
            return out.transpose(0, 1), state  # [len, batch, size]
Example No. 16
    def forward(self,
                sents,
                lengths,
                fts=[],
                rel_idxs=[],
                lidx_start=[],
                lidx_end=[],
                ridx_start=[],
                ridx_end=[],
                pred_ind=True,
                flip=False,
                causal=False,
                token_type_ids=None,
                task='relation'):

        batch_size = sents.size(0)
        # dropout
        out = self.dropout(sents)
        # pack and lstm layer
        out, _ = self.lstm(pack(out, lengths, batch_first=True))
        # unpack
        out, _ = unpack(out, batch_first=True)

        ### entity prediction - predict each input token
        if task == 'entity':
            out_ent = self.linear1_ent(self.dropout(out))
            out_ent = self.act(out_ent)
            out_ent = self.linear2_ent(out_ent)
            prob_ent = self.softmax_ent(out_ent)
            return out_ent, prob_ent

        ### relation prediction - flatten hidden vars into a long vector
        if task == 'relation':

            ltar_f = torch.cat([
                out[b, lidx_start[b][r], :self.hid_size].unsqueeze(0)
                for b, r in rel_idxs
            ],
                               dim=0)
            ltar_b = torch.cat([
                out[b, lidx_end[b][r], self.hid_size:].unsqueeze(0)
                for b, r in rel_idxs
            ],
                               dim=0)
            rtar_f = torch.cat([
                out[b, ridx_start[b][r], :self.hid_size].unsqueeze(0)
                for b, r in rel_idxs
            ],
                               dim=0)
            rtar_b = torch.cat([
                out[b, ridx_end[b][r], self.hid_size:].unsqueeze(0)
                for b, r in rel_idxs
            ],
                               dim=0)

            out = self.dropout(
                torch.cat((ltar_f, ltar_b, rtar_f, rtar_b), dim=1))
            out = torch.cat((out, fts), dim=1)

            # linear prediction
            out = self.linear1(out)
            out = self.act(out)
            out = self.dropout(out)
            out = self.linear2(out)
            prob = self.softmax(out)
            return out, prob
Example No. 17
    def test(self):
        self.ds.set_split("test", self.args.num_samples)
        thresh = 1. / 50.
        prec = 0.
        reca = 0.
        acc = 0.
        num_batches = len(self.dl)
        num_labels = len(self.ds.labels_dict)
        infer_outputs = []
        counter_array = np.zeros((num_labels, 6)) # tgts, preds, tp, fp, tn, fn
        if any(x in self.model_name for x in ["resnet", "squeezenet"]):
            m = self.model_list[0]
            # set model(s) into eval mode
            m.eval()
            with tqdm(total=num_batches, leave=False, position=1,
                      postfix={"accuracy": acc, "precision": prec}) as t:
                for mb, tgts in self.dl:
                    mb = mb.to(self.device)
                    tgts = tgts.to(torch.device("cpu"))
                    # run inference
                    out = m(mb)
                    # move output to cpu for analysis / numpy
                    out = out.to(torch.device("cpu"))
                    infer_outputs.append((out.numpy().tolist(), tgts.numpy().tolist()))
                    if self.loss_criterion == "crossentropy":
                        out = F.softmax(out, dim = 1)
                    else:
                        out = F.sigmoid(out)
                        #out = F.softmax(out, dim = 1)
                    # out is either size (N, C) or (N, )
                    for tgt, o in zip(tgts, out):
                        o_mask = torch.zeros_like(o)
                        o_mask[torch.topk(o, tgt.sum().int().item())[1]] = 1.
                        o_mask = o_mask.numpy()
                        o_mask = o_mask.astype(np.bool)
                        tgt = tgt.numpy()
                        tgt_mask = tgt == 1.
                        counter_array[tgt_mask, 0] += 1
                        #print(o_mask); break;

                        counter_array[o_mask, 1] += 1
                        tp = np.logical_and(tgt_mask==True, o_mask==True)  # this will be deflated for cross entropy
                        fp = np.logical_and(tgt_mask==False, o_mask==True)
                        tn = np.logical_and(tgt_mask==False, o_mask==False)
                        fn = np.logical_and(tgt_mask==True, o_mask==False)

                        counter_array[tp, 2] += 1
                        counter_array[fp, 3] += 1
                        counter_array[tn, 4] += 1
                        counter_array[fn, 5] += 1

                        k = int(np.sum(tgt_mask))
                        tmp1 = torch.topk(o, k)[1]  # get indices
                        tmp2 = np.where(tgt == 1.)[0]
                    #acc = counter_array[:, 0].sum() / counter_array[:, 0].sum()
                    ttp = counter_array[:, 2].sum()
                    tfp = counter_array[:, 3].sum()
                    ttn = counter_array[:, 4].sum()
                    tfn = counter_array[:, 5].sum()
                    prec = ttp / (ttp + tfp)
                    reca = ttp / (ttp + tfn)
                    acc = (ttp + ttn) / (ttp + tfp + ttn + tfn)
                    t.set_postfix({"accuracy": "{0:.4f}".format(acc * 100.), "precision": "{0:.4f}".format(prec * 100.)})
                    t.update()
                    #correct += (out_valid.detach().max(1)[1] == tgts_valid.detach()).sum()
        elif "attn" in self.model_name:
            encoder = self.model_list[0]
            decoder = self.model_list[1]
            # set model(s) into eval mode
            encoder.eval()
            decoder.eval()
            with tqdm(total=num_batches, leave=True, position=1,
                      postfix={"accuracy": acc, "precision": prec}) as t:
                for i, ((mb, lengths), tgts) in enumerate(self.dl):
                    # set model into train mode and clear gradients

                    # move inputs to cuda if required
                    mb = mb.to(self.device)
                    tgts = tgts.to(torch.device("cpu"))
                    # init hidden before packing
                    encoder_hidden = encoder.initHidden(mb)

                    # set inputs and targets
                    mb = pack(mb, lengths, batch_first=True)
                    #print(mb.size(), tgts.size())
                    encoder_output, encoder_hidden = encoder(mb, encoder_hidden)

                    #print(encoder_output.detach().new(dec_size).size())
                    #enc_out_var, enc_out_len = unpack(encoder_output, batch_first=True)
                    #dec_i = enc_out_var.new_zeros((self.batch_size, 1, encoder.hidden_size))
                    dec_h = encoder_hidden # Use last (forward) hidden state from encoder
                    #print(decoder.n_layers, encoder_hidden.size(), dec_i.size(), dec_h.size())

                    # run through decoder in one shot
                    mb, _ = unpack(mb, batch_first=True)
                    out, dec_h, dec_attn = decoder(mb, dec_h, encoder_output)
                    # calculate loss
                    out = out.to(torch.device("cpu"))
                    out.squeeze_()
                    infer_outputs.append((out.numpy().tolist(), tgts.numpy().tolist()))
                    if self.loss_criterion == "crossentropy":
                        out = F.softmax(out, dim = 1)
                    else:
                        out = F.sigmoid(out)
                    # out is either size (N, C) or (N, )
                    for tgt, o in zip(tgts, out):
                        o_mask = torch.zeros_like(o)
                        o_mask[torch.topk(o, tgt.sum().int().item())[1]] = 1.
                        o_mask = o_mask.numpy()
                        o_mask = o_mask.astype(np.bool)
                        tgt = tgt.numpy()
                        tgt_mask = tgt == 1.
                        counter_array[tgt_mask, 0] += 1
                        #print(o_mask); break;

                        counter_array[o_mask, 1] += 1
                        tp = np.logical_and(tgt_mask==True, o_mask==True)  # this will be deflated for cross entropy
                        fp = np.logical_and(tgt_mask==False, o_mask==True)
                        tn = np.logical_and(tgt_mask==False, o_mask==False)
                        fn = np.logical_and(tgt_mask==True, o_mask==False)

                        counter_array[tp, 2] += 1
                        counter_array[fp, 3] += 1
                        counter_array[tn, 4] += 1
                        counter_array[fn, 5] += 1
                    ttp = counter_array[:, 2].sum()
                    tfp = counter_array[:, 3].sum()
                    ttn = counter_array[:, 4].sum()
                    tfn = counter_array[:, 5].sum()
                    prec = ttp / (ttp + tfp)
                    reca = ttp / (ttp + tfn)
                    acc = (ttp + ttn) / (ttp + tfp + ttn + tfn)
                    t.set_postfix({"accuracy": "{0:.4f}".format(acc * 100.), "precision": "{0:.4f}".format(prec * 100.)})
                    t.update()
        elif "bytenet" in self.model_name:
            encoder = self.model_list[0]
            decoder = self.model_list[1]
            # set model(s) into eval mode
            encoder.eval()
            decoder.eval()
            with tqdm(total=num_batches, leave=True, position=1,
                      postfix={"accuracy": acc, "precision": prec}) as t:
                for i, (mb, tgts) in enumerate(self.dl):
                    # set inputs and targets
                    mb, tgts = mb.to(self.device), tgts.to(torch.device("cpu"))
                    mb = encoder(mb)
                    out = decoder(mb)
                    out = out.to(torch.device("cpu"))
                    infer_outputs.append((out.numpy().tolist(), tgts.numpy().tolist()))
                    if self.loss_criterion == "crossentropy":
                        out = F.softmax(out, dim = 1)
                    else:
                        out = F.sigmoid(out)
                    # out is either size (N, C) or (N, )
                    for tgt, o in zip(tgts, out):
                        o_mask = torch.zeros_like(o)
                        o_mask[torch.topk(o, tgt.sum().int().item())[1]] = 1.
                        o_mask = o_mask.numpy()
                        o_mask = o_mask.astype(np.bool)
                        tgt = tgt.numpy()
                        tgt_mask = tgt == 1.
                        counter_array[tgt_mask, 0] += 1
                        #print(o_mask); break;

                        counter_array[o_mask, 1] += 1
                        tp = np.logical_and(tgt_mask==True, o_mask==True)  # this will be deflated for cross entropy
                        fp = np.logical_and(tgt_mask==False, o_mask==True)
                        tn = np.logical_and(tgt_mask==False, o_mask==False)
                        fn = np.logical_and(tgt_mask==True, o_mask==False)

                        counter_array[tp, 2] += 1
                        counter_array[fp, 3] += 1
                        counter_array[tn, 4] += 1
                        counter_array[fn, 5] += 1
                    ttp = counter_array[:, 2].sum()
                    tfp = counter_array[:, 3].sum()
                    ttn = counter_array[:, 4].sum()
                    tfn = counter_array[:, 5].sum()
                    prec = ttp / (ttp + tfp)
                    reca = ttp / (ttp + tfn)
                    acc = (ttp + ttn) / (ttp + tfp + ttn + tfn)
                    t.set_postfix({"accuracy": "{0:.4f}".format(acc * 100.), "precision": "{0:.4f}".format(prec * 100.)})
                    t.update()

        else:
            raise NotImplementedError
        self.infer_stats = counter_array
        self.infer_outputs = infer_outputs
Example No. 18
    def validate(self, epoch):
        self.ds.set_split("valid", self.args.num_samples)
        running_validation_loss = []
        accuracies = []
        acc = 0
        threshold = 1 - (1. / 3.)
        num_batches = len(self.dl)
        if any(x in self.model_name for x in ["resnet", "squeezenet"]):
            m = self.model_list[0]
            # set model(s) into eval mode
            m.eval()
            with tqdm(total=num_batches, leave=True, position=2,
                      postfix={"acc": acc, "loss": "{0:.6f}".format(0.)}) as t:
                for mb_valid, tgts_valid in self.dl:
                    mb_valid = mb_valid.to(self.device)
                    tgts_valid = tgts_valid.to(torch.device("cpu"))
                    out_valid = m(mb_valid)
                    out_valid = out_valid.to(torch.device("cpu"))
                    if "margin" in self.loss_criterion:
                        out_valid = F.sigmoid(out_valid)
                    if self.loss_criterion == "margin":
                        tgts_valid = tgts_valid.long()
                    loss_valid = self.criterion(out_valid, tgts_valid)
                    running_validation_loss += [loss_valid.item()]
                    if "margin" not in self.loss_criterion:
                        out_valid = F.sigmoid(out_valid)
                    if self.loss_criterion == "crossentropy":
                        out_pred = out_valid.max(1)[1]
                        acc = (out_pred == tgts_valid).sum().item() / tgts_valid.size(0)
                    else:
                        acc = 0.
                        num_out = out_valid.size(0)
                        for ov, tgt in zip(out_valid, tgts_valid):
                            tgt = torch.LongTensor([i for i, x in enumerate(tgt) if x == 1])
                            num_tgt = tgt.size(0)
                            ov = torch.topk(ov, num_tgt)[1]
                            correct = len(np.intersect1d(tgt.numpy(), ov.numpy()))
                            acc += (correct / num_tgt) / num_out
                    accuracies.append(acc)
                    t.set_postfix({"acc": acc, "loss": "{0:.6f}".format(running_validation_loss[-1])})
                    t.update()
                    #correct += (out_valid.detach().max(1)[1] == tgts_valid.detach()).sum()
        elif "attn" in self.model_name:
            encoder = self.model_list[0]
            decoder = self.model_list[1]
            # set model(s) into eval mode
            encoder.eval()
            decoder.eval()
            with tqdm(total=num_batches, leave=True, position=2,
                      postfix={"acc": acc, "loss": "{0:.6f}".format(0.)}) as t:
                for i, ((mb_valid, lengths), tgts_valid) in enumerate(self.dl):
                    # set model into train mode and clear gradients

                    # move inputs to cuda if required
                    mb_valid = mb_valid.to(self.device)
                    tgts_valid = tgts_valid.to(torch.device("cpu"))
                    # init hidden before packing
                    encoder_hidden = encoder.initHidden(mb_valid)

                    # set inputs and targets
                    mb_valid = pack(mb_valid, lengths, batch_first=True)
                    #print(mb.size(), tgts.size())
                    encoder_output, encoder_hidden = encoder(mb_valid, encoder_hidden)

                    #print(encoder_output.detach().new(dec_size).size())
                    #enc_out_var, enc_out_len = unpack(encoder_output, batch_first=True)
                    #dec_i = enc_out_var.new_zeros((self.batch_size, 1, encoder.hidden_size))
                    dec_h = encoder_hidden # Use last (forward) hidden state from encoder
                    #print(decoder.n_layers, encoder_hidden.size(), dec_i.size(), dec_h.size())

                    # run through decoder in one shot
                    mb_valid, _ = unpack(mb_valid, batch_first=True)
                    out_valid, dec_h, dec_attn = decoder(mb_valid, dec_h, encoder_output)
                    # calculate loss
                    out_valid = out_valid.to(torch.device("cpu"))
                    out_valid.squeeze_()
                    if "margin" in self.loss_criterion:
                        out_valid = F.sigmoid(out_valid)
                    if self.loss_criterion == "margin":
                        tgts_valid = tgts_valid.long()
                    loss_valid = self.criterion(out_valid, tgts_valid)
                    running_validation_loss += [loss_valid.item()]
                    if "margin" not in self.loss_criterion:
                        out_valid = F.sigmoid(out_valid)
                    if self.loss_criterion == "crossentropy":
                        out_pred = out_valid.max(1)[1]
                        acc = (out_pred == tgts_valid).sum().item() / tgts_valid.size(0)
                    else:
                        acc = 0.
                        num_out = out_valid.size(0)
                        for ov, tgt in zip(out_valid, tgts_valid):
                            tgt = torch.LongTensor([i for i, x in enumerate(tgt) if x == 1])
                            num_tgt = tgt.size(0)
                            ov = torch.topk(ov, num_tgt)[1]
                            correct = len(np.intersect1d(tgt.numpy(), ov.numpy()))
                            acc += (correct / num_tgt) / num_out
                    accuracies.append(acc)
                    t.set_postfix({"acc": acc, "loss": "{0:.6f}".format(running_validation_loss[-1])})
                    t.update()
                    #correct += (dec_o.detach().max(1)[1] == tgts.detach()).sum()
        elif "bytenet" in self.model_name:
            encoder = self.model_list[0]
            decoder = self.model_list[1]
            # set model(s) into eval mode
            encoder.eval()
            decoder.eval()
            with tqdm(total=num_batches, leave=True, position=2,
                      postfix={"acc": acc, "loss": "{0:.6f}".format(0.)}) as t:
                for i, (mb_valid, tgts_valid) in enumerate(self.dl):
                    # set inputs and targets
                    mb_valid, tgts_valid = mb_valid.to(self.device), tgts_valid.to(torch.device("cpu"))
                    mb_valid = encoder(mb_valid)
                    out_valid = decoder(mb_valid)
                    if "margin" in self.loss_criterion:
                        out_valid = F.sigmoid(out_valid)
                    if self.loss_criterion == "margin":
                        tgts_valid = tgts_valid.long()
                    out_valid = out_valid.to(torch.device("cpu"))
                    loss_valid = self.criterion(out_valid, tgts_valid)
                    running_validation_loss += [loss_valid.item()]
                    if "margin" not in self.loss_criterion:
                        out_valid = F.sigmoid(out_valid)
                    if self.loss_criterion == "crossentropy":
                        out_pred = out_valid.max(1)[1]
                        acc = (out_pred == tgts_valid).sum().item() / tgts_valid.size(0)
                    else:
                        acc = 0.
                        num_out = out_valid.size(0)
                        for ov, tgt in zip(out_valid, tgts_valid):
                            tgt = torch.LongTensor([i for i, x in enumerate(tgt) if x == 1])
                            num_tgt = tgt.size(0)
                            ov = torch.topk(ov, num_tgt)[1]
                            correct = len(np.intersect1d(tgt.numpy(), ov.numpy()))
                            acc += (correct / num_tgt) / num_out
                    accuracies.append(acc)
                    t.set_postfix({"acc": acc, "loss": "{0:.6f}".format(running_validation_loss[-1])})
                    t.update()
                    #correct += (dec_o.detach().max(1)[1] == tgts.detach()).sum()

        self.valid_losses.append((running_validation_loss, accuracies))
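# ---------------------------------------------------------------------------
# Hedged sketch (not part of the validate() method above): the multi-label
# accuracy used in the non-crossentropy branches, factored into a standalone
# helper. For each sample it takes the top-k scores, where k is the number of
# positive labels in the multi-hot target, and measures the overlap with the
# true label indices. `scores` and `targets` are assumed to be CPU tensors of
# shape (batch, num_labels); the helper name is illustrative only.
import torch

def topk_multilabel_accuracy(scores, targets):
    batch_size = scores.size(0)
    acc = 0.
    for score_row, tgt_row in zip(scores, targets):
        positives = tgt_row.nonzero().view(-1)        # indices of the 1s
        num_tgt = positives.size(0)
        if num_tgt == 0:
            continue                                   # skip empty targets
        pred = torch.topk(score_row, num_tgt)[1]       # top-k predicted indices
        correct = len(set(positives.tolist()) & set(pred.tolist()))
        acc += (correct / num_tgt) / batch_size
    return acc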
Exemplo n.º 19
0
    def forward(self, word_inputs, feat_inputs, word_seq_length, char_inputs,
                char_seq_length, char_recover, dict_inputs, mask, batch_bert):
        """

             word_inputs: (batch_size, seq_len)
             word_seq_length: (batch_size,) true length of each sequence
        """
        batch_size = word_inputs.size(0)
        seq_len = word_inputs.size(1)
        word_emb = self.word_embedding(word_inputs)
        if self.args.use_elmo:
            elmo_emb = self.elmo_embedding(word_inputs)
        # if self.args.use_bert:
        #     word_emb = torch.cat((word_emb,torch.squeeze(batch_bert,2)),2)
        #elmo_emb = self.drop(elmo_emb)

        # word_rep = word_emb
        if self.args.use_char:
            size = char_inputs.size(0)
            char_emb = self.char_embedding(char_inputs)
            char_emb = pack(char_emb,
                            char_seq_length.cpu().numpy(),
                            batch_first=True)
            char_lstm_out, char_hidden = self.char_feature(char_emb)
            char_lstm_out, _ = pad(char_lstm_out, batch_first=True)
            char_hidden = char_hidden[0].transpose(1, 0).contiguous().view(
                size, -1)
            char_hidden = char_hidden[char_recover]
            char_hidden = char_hidden.view(batch_size, seq_len, -1)
            if self.args.attention:
                word_rep = torch.tanh(
                    self.attn1(word_emb) + self.attn2(char_hidden))
                z = torch.sigmoid(self.attn3(word_rep))
                x = 1 - z
                # gated combination of word and char representations
                word_rep = torch.mul(z, word_emb) + torch.mul(x, char_hidden)
            else:
                word_rep = torch.cat((word_emb, char_hidden), 2)
                word_rep = self.word_drop(word_rep)  #word represent dropout
        #if self.args.use_elmo:
        #    word_rep = torch.cat((word_rep, elmo_emb), 2)
        if self.args.feature:
            for idx in range(self.feature_num):
                word_rep = torch.cat(
                    (word_rep, self.feature_embeddings[idx](feat_inputs[idx])),
                    2)
        # batch_bert = torch.split(batch_bert,1,dim=2)
        # normed_weights = F.softmax(self.scalar_parameters, dim=0)
        # y = self.gamma * sum(weight * tensor.squeeze(2) for weight, tensor in zip(normed_weights,batch_bert))

        # x = F.softmax(torch.mean(batch_bert,dim=2))
        x = F.softmax(torch.mean(batch_bert, dim=2))
        if self.args.use_bert:
            word_rep = torch.cat((word_rep, x), 2)
        word_rep = pack(word_rep,
                        word_seq_length.cpu().numpy(),
                        batch_first=True)
        out, hidden = self.word_feature(word_rep)
        out, _ = pad(out, batch_first=True)
        if self.args.use_elmo:
            out = torch.cat((out, elmo_emb), 2)
        if self.args.out_dict:
            dict_rep = pack(dict_inputs,
                            word_seq_length.cpu().numpy(),
                            batch_first=True)
            dict_out, hidden = self.dict_feature(dict_rep)
            dict_out, _ = pad(dict_out, batch_first=True)
            #dict_out = self.dict_fc(dict_inputs)
            out = torch.cat((out, dict_out), 2)
        if self.args.lstm_attention:

            out_list, weight_list = [], []
            for idx in range(seq_len):
                # slice_out = out[:,0:idx+1,:]
                if idx + 2 > seq_len:
                    slice_out = out
                else:
                    slice_out = out[:, 0:idx + 2, :]
                # slice_out = out
                slice_out, weights = self.attention(slice_out)
                # slice_out, weights = SelfAttention(self.args.hidden_dim*2).forward(slice_out)
                out_list.append(slice_out.unsqueeze(1))
                weight_list.append(weights)
            out = torch.cat(out_list, dim=1)
        out = self.drop(out)
        out = self.hidden2tag(out)
        return out
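# ---------------------------------------------------------------------------
# Hedged sketch (not part of the example above): the pack/pad pattern used for
# the char-level LSTM, written with enforce_sorted=False (PyTorch >= 1.1) so
# that no external sort / char_recover index is needed. `pack` and `pad` are
# assumed to alias pack_padded_sequence and pad_packed_sequence, as in the
# other examples here; all sizes are illustrative.
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as pad

char_lstm = nn.LSTM(input_size=30, hidden_size=25, batch_first=True,
                    bidirectional=True)
char_emb = torch.randn(8, 12, 30)          # (num_words, max_chars, char_emb)
char_lens = torch.randint(1, 13, (8,))     # true char length of each word

packed = pack(char_emb, char_lens, batch_first=True, enforce_sorted=False)
packed_out, (h, c) = char_lstm(packed)
char_lstm_out, _ = pad(packed_out, batch_first=True)  # (num_words, max_len, 50)
# final forward/backward states, concatenated per word: (num_words, 50)
char_hidden = h.transpose(0, 1).contiguous().view(8, -1)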
Exemplo n.º 20
0
    def forward(self, x, lens, k, kx):
        # model takes as input the text, aspect, and location
        # runs BLSTM over text using embedding(location, aspect) as
        # the initial hidden state, as opposed to a different lstm for every pair???
        # output sentiment

        # DBG
        words = x

        emb = self.drop(self.lut(x))
        p_emb = pack(emb, lens, True)

        l, a = k
        N = l.shape[0]
        T = x.shape[1]
        # factor this out, for sure. POSSIBLE BUGS
        y_idx = l * len(self.A) + a
        s = (self.lut_la(y_idx)
            .view(N, 2, 2 * self.nlayers, self.rnn_sz)
            .permute(1, 2, 0, 3)
            .contiguous())
        state = (s[0], s[1])
        x, (h, c) = self.rnn(p_emb, state)
        # h: L * D x N x H
        x = unpack(x, True)[0]
        # Get the last hidden states for both directions, POSSIBLE BUGS
        phi_s = self.proj_s(x)
        #"""
        idxs = torch.arange(0, int(lens.max())).to(lens.device)
        # mask: N x T, True on padded timesteps
        mask = (idxs.repeat(len(lens), 1) >= lens.unsqueeze(-1))
        phi_s[:, :, -1].masked_fill_(~mask, float("-inf"))
        phi_s[:,:,:3].masked_fill_(mask.unsqueeze(-1), float("-inf"))
        #"""
        """
        h = (h
            .view(self.nlayers, 2, -1, self.rnn_sz)[-1]
            .permute(1, 0, 2)
            .contiguous()
            .view(-1, 2 * self.rnn_sz))
        phi_y = self.proj_y(h)
        """
        phi_y = torch.zeros(N, len(self.S)).to(self.psi_ys.device)
        psi_ys = torch.cat(
            [torch.diag(self.psi_ys), torch.zeros(len(self.S), 1).to(self.psi_ys)],
            dim=-1,
        ).expand(T, len(self.S), len(self.S)+1)
        #psi_ys = torch.diag(self.psi_ys).repeat(T, 1, 1)
        # Z is really weird here
        Z, hy = ubersum("nts,tys,ny->n,ny", phi_s, psi_ys, phi_y, batch_dims="t", modulo_total=True)
        #Z, hy = ubersum("nts,tys,ny->n,ny", phi_s, psi_ys, phi_y, batch_dims="t", modulo_total=True)
        def stuff(i):
            loc = self.L.itos[l[i]]
            asp = self.A.itos[a[i]]
            return self.tostr(words[i]), loc, asp, xp[i], yp[i]
        if self.training:
            self._N += 1
        if self._N > 100 and self.training:
            Zx, hx = ubersum("nts,tys->nt,nts", phi_s, psi_ys, batch_dims="t", modulo_total=True)
            xp = (hx - Zx.unsqueeze(-1)).exp()
            yp = (hy - Z.unsqueeze(-1)).exp()
            #Zx, hx = ubersum("nts,ys->nt,nts", phi_s, self.psi_ys, batch_dims="t")
            import pdb; pdb.set_trace()
            pass
            # text, loc, asp, xpi, ypi = stuff(10)
        #import pdb; pdb.set_trace()
        return hy  # - Z.unsqueeze(-1)
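# ---------------------------------------------------------------------------
# Hedged sketch (not part of the example above): using an embedding of the
# (location, aspect) pair as the initial hidden state of a bidirectional
# LSTM, as the comment at the top of this forward() describes. Module names
# and sizes here are illustrative, not the original model's.
import torch
import torch.nn as nn

nlayers, rnn_sz, num_loc, num_asp = 1, 32, 5, 4
lut_la = nn.Embedding(num_loc * num_asp, 2 * 2 * nlayers * rnn_sz)
rnn = nn.LSTM(input_size=16, hidden_size=rnn_sz, num_layers=nlayers,
              bidirectional=True, batch_first=True)

emb = torch.randn(8, 20, 16)               # (batch, time, word_emb)
loc = torch.randint(0, num_loc, (8,))
asp = torch.randint(0, num_asp, (8,))

# one embedding row per (location, aspect) pair, reshaped into (h0, c0)
s = (lut_la(loc * num_asp + asp)
     .view(8, 2, 2 * nlayers, rnn_sz)      # (N, h-vs-c, layers*dirs, H)
     .permute(1, 2, 0, 3)
     .contiguous())
out, (h, c) = rnn(emb, (s[0], s[1]))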
Exemplo n.º 21
0
 def forward(self, inputs, hidden=None):
     emb = pack(self.word_lut(inputs[0]), inputs[1])
     outputs, hidden_t = self.rnn(emb, hidden)
     outputs = unpack(outputs)[0]
     return hidden_t, outputs
Exemplo n.º 22
0
    def fit(self, epoch, early_stop=None):
        epoch_losses = []
        self.ds.set_split("train")
        self.adjust_opt_params(epoch)
        self.scheduler.step()
        #self.optimizer = self.get_optimizer(epoch)
        num_batches = len(self.dl)
        if any(x in self.model_name for x in ["resnet", "squeezenet"]):
            if self.use_precompute:
                pass # TODO implement network precomputation
                #self.precompute(self.L["fc_layer"]["precompute"])
            m = self.model_list[0]
            with tqdm(total=num_batches, leave=False, position=1) as t:
                for i, (mb, tgts) in enumerate(self.dl):
                    if i == early_stop: break
                    m.train()
                    mb, tgts = mb.to(self.device), tgts.to(self.device)
                    m.zero_grad()
                    out = m(mb)
                    if "margin" in self.loss_criterion:
                        out = F.sigmoid(out)
                    if self.loss_criterion == "margin":
                        tgts = tgts.long()
                    #print(tgts)
                    loss = self.criterion(out, tgts)
                    loss.backward()
                    self.optimizer.step()
                    epoch_losses.append(loss.item())
                    if self.tqdmiter:
                        self.tqdmiter.set_postfix({"loss": "{0:.6f}".format(epoch_losses[-1])})
                        self.tqdmiter.refresh()
                    else:
                        print(epoch_losses[-1])
                    if i % self.log_interval == 0 and self.do_validate and i != 0:
                        with torch.no_grad():
                            self.validate(epoch)
                            self.ds.set_split("train")
                    t.update()
        elif "attn" in self.model_name:
            encoder = self.model_list[0]
            decoder = self.model_list[1]
            with tqdm(total=num_batches, leave=False, position=1) as t:
                for i, ((mb, lengths), tgts) in enumerate(self.dl):
                    # set model into train mode and clear gradients
                    encoder.train()
                    decoder.train()
                    encoder.zero_grad()
                    decoder.zero_grad()

                    # set inputs and targets
                    mb, tgts = mb.to(self.device), tgts.to(self.device)

                    # create the initial hidden input before packing sequence
                    encoder_hidden = encoder.initHidden(mb)

                    # pack sequence
                    mb = pack(mb, lengths, batch_first=True)
                    #print(mb.size(), tgts.size())
                    # encode sequence
                    encoder_output, encoder_hidden = encoder(mb, encoder_hidden)

                    # Prepare input and output variables for decoder
                    #dec_size = [[[0] * encoder.hidden_size]*1]*self.batch_size
                    #print(encoder_output.detach().new(dec_size).size())
                    #enc_out_var, enc_out_len = unpack(encoder_output, batch_first=True)
                    #dec_i = enc_out_var.new_zeros((self.batch_size, 1, encoder.hidden_size))
                    dec_h = encoder_hidden # Use last (forward) hidden state from encoder
                    #print(decoder.n_layers, encoder_hidden.size(), dec_i.size(), dec_h.size())

                    # run through decoder in one shot
                    mb, _ = unpack(mb, batch_first=True)
                    dec_o, dec_h, dec_attn = decoder(mb, dec_h, encoder_output)
                    dec_o.squeeze_()
                    #print(dec_o)
                    #print(dec_o.size(), dec_h.size(), dec_attn.size(), tgts.size())
                    #print(dec_o.view(-1, decoder.output_size).size(), tgts.view(-1).size())

                    # calculate loss and backprop
                    if "margin" in self.loss_criterion:
                        dec_o = F.sigmoid(dec_o)
                    if self.loss_criterion == "margin":
                        tgts = tgts.long()
                    loss = self.criterion(dec_o, tgts)
                    #nn.utils.clip_grad_norm(encoder.parameters(), 0.05)
                    #nn.utils.clip_grad_norm(decoder.parameters(), 0.05)
                    loss.backward()
                    self.optimizer.step()
                    epoch_losses.append(loss.item())
                    if self.tqdmiter:
                        self.tqdmiter.set_postfix({"loss": "{0:.6f}".format(epoch_losses[-1])})
                        self.tqdmiter.refresh()
                    else:
                        print(epoch_losses[-1])
                    if i % self.log_interval == 0 and self.do_validate and i != 0:
                        with torch.no_grad():
                            self.validate(epoch)
                            self.ds.set_split("train")
                    t.update()
        elif "bytenet" in self.model_name:
            encoder = self.model_list[0]
            decoder = self.model_list[1]
            with tqdm(total=num_batches, leave=False, position=1) as t:
                for i, (mb, tgts) in enumerate(self.dl):
                    # set model into train mode and clear gradients
                    encoder.train()
                    decoder.train()
                    encoder.zero_grad()
                    decoder.zero_grad()
                    # set inputs and targets
                    mb, tgts = mb.to(self.device), tgts.to(self.device)
                    mb = encoder(mb)
                    out = decoder(mb)
                    if "margin" in self.loss_criterion:
                        out = F.sigmoid(out)
                    if self.loss_criterion == "margin":
                        tgts = tgts.long()
                    loss = self.criterion(out, tgts)
                    loss.backward()
                    self.optimizer.step()
                    epoch_losses.append(loss.item())
                    if self.tqdmiter:
                        self.tqdmiter.set_postfix({"loss": "{0:.6f}".format(epoch_losses[-1])})
                        self.tqdmiter.refresh()
                    else:
                        print(epoch_losses[-1])
                    if i % self.log_interval == 0 and self.do_validate and i != 0:
                        with torch.no_grad():
                            self.validate(epoch)
                            self.ds.set_split("train")
                    t.update()
        self.train_losses.append(epoch_losses)
        if epoch % 10 == 0 and epoch != 0 and self.use_cache:
            self.ds.init_cache()
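# ---------------------------------------------------------------------------
# Hedged sketch (not part of the trainer above): the encode/decode flow used
# in the "attn" branch of fit()/validate() -- pack the padded batch, encode,
# unpack, then run the decoder over the whole sequence in one shot with the
# encoder's final hidden state as its initial state. Plain GRUs stand in for
# the real encoder/decoder classes, whose interfaces are not shown here.
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

encoder = nn.GRU(input_size=16, hidden_size=32, batch_first=True)
decoder = nn.GRU(input_size=16, hidden_size=32, batch_first=True)

mb = torch.randn(4, 9, 16)                  # (batch, time, features)
lengths = torch.tensor([9, 6, 5, 3])        # already sorted, longest first

packed = pack(mb, lengths, batch_first=True)
encoder_output, encoder_hidden = encoder(packed)
encoder_output, _ = unpack(encoder_output, batch_first=True)

# run through the decoder in one shot, reusing the encoder's final hidden state
mb_unpacked, _ = unpack(packed, batch_first=True)
dec_o, dec_h = decoder(mb_unpacked, encoder_hidden)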