Example #1
    def forward(self, tgt_seq, src_seq, enc_output):

        dec_slf_attn_list, dec_enc_attn_list = [], []

        # -- Prepare masks
        non_pad_mask = get_non_pad_mask(tgt_seq)  # marks real (non-padding) target positions

        # Combine the causal (future-blocking) mask with the key-padding mask.
        slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)  # hide padded source keys in cross-attention
            
        # Position indices 1..len for every target position
        tgt_pos = (torch.arange(1, tgt_seq.size(-1) + 1, device=self.device)
                   .unsqueeze(0).repeat(tgt_seq.size(0), 1))
        # -- Forward
        dec_output = self.embedding(tgt_seq) + self.position_enc(tgt_pos)
        
        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output, enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)
        
        return self.last_linear(dec_output)
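The helpers get_non_pad_mask, get_subsequent_mask, and get_attn_key_pad_mask are used throughout these examples but are not shown. A minimal sketch of how such helpers are commonly written, assuming a padding index PAD = 0 (the padding value is an assumption here):

import torch

PAD = 0  # assumed padding index


def get_non_pad_mask(seq):
    # (batch, len) -> (batch, len, 1); 1.0 where the token is not padding
    return seq.ne(PAD).type(torch.float).unsqueeze(-1)


def get_attn_key_pad_mask(seq_k, seq_q):
    # (batch, len_q, len_k); 1 where the key position is padding and must be masked
    len_q = seq_q.size(1)
    return seq_k.eq(PAD).unsqueeze(1).expand(-1, len_q, -1)


def get_subsequent_mask(seq):
    # (batch, len, len); 1 above the diagonal, blocking attention to future positions
    sz_b, len_s = seq.size()
    mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
    return mask.unsqueeze(0).expand(sz_b, -1, -1)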
Example #2
def log_likelihood(model, data, time, types):
    """ Log-likelihood of sequence. """

    non_pad_mask = get_non_pad_mask(types).squeeze(2)  # (batch, len), 1 for real events

    # One-hot mask over event types; types are 1-indexed, 0 marks padding.
    type_mask = torch.zeros([*types.size(), model.num_types],
                            device=data.device)
    for i in range(model.num_types):
        type_mask[:, :, i] = (types == i + 1).bool().to(data.device)

    # Conditional intensities for every type; softplus keeps them positive.
    all_hid = model.linear(data)
    all_lambda = F.softplus(all_hid, threshold=10)
    type_lambda = torch.sum(all_lambda * type_mask, dim=2)  # intensity of the observed type at each event

    # event log-likelihood
    event_ll = compute_event(type_lambda, non_pad_mask)
    event_ll = torch.sum(event_ll, dim=-1)

    # non-event log-likelihood, either numerical integration or MC integration
    # non_event_ll = compute_integral_biased(type_lambda, time, non_pad_mask)
    non_event_ll = compute_integral_unbiased(model, data, time, non_pad_mask,
                                             type_mask)
    non_event_ll = torch.sum(non_event_ll, dim=-1)

    return event_ll, non_event_ll
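log_likelihood returns the event term and the integral (non-event) term separately. A hedged sketch of how they are typically combined into a negative log-likelihood for a temporal point process (the reduction is an assumption, not taken from this code):

# event_ll: event term of the point-process log-likelihood
# non_event_ll: Monte-Carlo estimate of the intensity integral (compensator)
event_ll, non_event_ll = log_likelihood(model, data, time, types)
nll = -torch.sum(event_ll - non_event_ll)  # minimize this during training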
Example #3
    def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False):

        dec_slf_attn_list, dec_enc_attn_list = [], []

        # -- Prepare masks
        non_pad_mask = get_non_pad_mask(tgt_seq)

        # Integer masks so the causal and key-padding masks can be summed, then thresholded.
        slf_attn_mask_subseq = get_subsequent_mask(tgt_seq).long()
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq).long()
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)

        # -- Forward
        dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos)

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output, enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]
                dec_enc_attn_list += [dec_enc_attn]

        if return_attns:
            return dec_output, dec_slf_attn_list, dec_enc_attn_list
        return dec_output,
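The decoder in Example #3 returns a one-element tuple unless return_attns=True, so a caller unpacks it with a trailing comma. A hypothetical caller (decoder and tgt_word_prj are assumed names, not taken from this code):

dec_output, = decoder(tgt_seq, tgt_pos, src_seq, enc_output)
logits = tgt_word_prj(dec_output)  # assumed linear projection to vocabulary size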
Example #4
def time_loss(prediction, event_time):
    """ Time prediction loss. """
    non_pad_mask = get_non_pad_mask(event_time).squeeze(2)[:, 1:]  # mask for the len-1 inter-event gaps
    prediction = prediction.squeeze(-1)  # (batch, len, 1) -> (batch, len)

    # True inter-event gaps; the final prediction has no following event, so it is dropped.
    true = event_time[:, 1:] - event_time[:, :-1]
    prediction = prediction[:, :-1]

    # event time gap prediction
    diff = prediction - true
    diff *= non_pad_mask
    se = torch.sum(diff * diff)
    return se
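A tiny worked example of time_loss (the values are made up, and it assumes the padding-index-0 convention from the mask-helper sketch above): two real events at t = 1.0 and t = 2.5 followed by one padded slot, so only the first gap contributes to the squared error.

import torch

event_time = torch.tensor([[1.0, 2.5, 0.0]])                 # 0.0 marks padding
prediction = torch.tensor([[1.4, 0.0, 0.0]]).unsqueeze(-1)   # (batch, len, 1)
se = time_loss(prediction, event_time)
# true gap is 1.5, predicted 1.4, the padded gap is masked out -> se ≈ 0.01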
Example #5
    def forward(self, src_seq, src_pos, return_attns=False):

        enc_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
        non_pad_mask = get_non_pad_mask(src_seq)

        # -- Forward
        src_word_emb = self.src_enc_dropout(self.src_word_emb(src_seq))
        enc_output = torch.tanh(
            self.src_word_enc(src_word_emb)) + self.position_enc(src_pos)

        k = enc_output  # pre-layer-stack states, passed (detached) to self.gcl below

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output,
                                                 non_pad_mask=non_pad_mask,
                                                 slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        x = enc_output  # final encoder states, the second input to self.gcl below

        # k, x: (batch, len, d_model), e.g. torch.Size([64, 72, 512])
        # self.gcl signature: forward(self, k=None, x=None, bidirectional=False, save_attn=False)
        # Max-pool over the sequence dimension: (batch, len, d_model) -> (batch, d_model)
        k = F.max_pool1d(k.permute(0, 2, 1), k.shape[-2]).squeeze(-1)
        x = F.max_pool1d(x.permute(0, 2, 1), x.shape[-2]).squeeze(-1)

        gcl_output = self.gcl(k.unsqueeze(0).detach(),
                              x.unsqueeze(0),
                              bidirectional=False,
                              save_attn=False)
        batch_size = x.shape[0]
        gcl_output = gcl_output.view(batch_size, 1, -1)

        # Broadcast the pooled gcl context to every position and fuse it with the
        # per-position encoder states through the adapter projection.
        n_positions = enc_output.shape[1]
        gcl_output = torch.cat(n_positions * [gcl_output], 1)
        enc_output = self.adapter(torch.cat([enc_output, gcl_output], -1))

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output,
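A quick shape check for the max-pooling step in Example #5, using the sizes noted in its inline comments (64, 72, 512 are just those values):

import torch
import torch.nn.functional as F

enc_output = torch.randn(64, 72, 512)                      # (batch, len, d_model)
pooled = F.max_pool1d(enc_output.permute(0, 2, 1),         # (batch, d_model, len)
                      enc_output.shape[-2]).squeeze(-1)    # pool over len
print(pooled.shape)                                        # torch.Size([64, 512])

Pooling with a kernel that spans the whole sequence is equivalent to taking enc_output.max(dim=1).values.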
Example #6
        def forward_seq(src_seq, src_pos, return_attns=return_attns):  # default binds the enclosing return_attns at definition time
            enc_slf_attn_list = []

            # -- Prepare masks
            slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
            non_pad_mask = get_non_pad_mask(src_seq)

            # -- Forward
            src_word_emb = self.src_enc_dropout(self.src_word_emb(src_seq))
            enc_output = torch.tanh(self.src_word_enc(src_word_emb)) + self.position_enc(src_pos)
            enc_input = enc_output.clone()  # keep the pre-layer-stack representation to return alongside the output

            for enc_layer in self.layer_stack:
                enc_output, enc_slf_attn = enc_layer(
                    enc_output,
                    non_pad_mask=non_pad_mask,
                    slf_attn_mask=slf_attn_mask)
                if return_attns:
                    enc_slf_attn_list += [enc_slf_attn]

            if return_attns:
                return enc_output, enc_slf_attn_list, enc_input
            return enc_output, enc_input