Example #1
0
    def __init__(self, config):
        """Build the token/description encoders and the attention layers.

        Copies margin, n_token_words, n_desc_words, emb_size, n_hidden and
        dropout from ``config`` onto the instance, and consults the
        ``transform_every_modal`` and ``save_attn_weight`` flags to decide
        which optional members to create.
        """
        super(TokenEmbeder, self).__init__()
        self.conf = config
        for key in ('margin', 'n_token_words', 'n_desc_words', 'emb_size',
                    'n_hidden', 'dropout'):
            setattr(self, key, config[key])
        width = self.n_hidden

        def two_layer_mlp():
            # Linear -> Tanh -> Linear, all n_hidden wide.
            return nn.Sequential(nn.Linear(width, width), nn.Tanh(),
                                 nn.Linear(width, width))

        # Creation order below is kept as-is: it fixes parameter registration
        # order and the order in which initial weights are drawn.
        self.tok_encoder = SeqEncoder(self.n_token_words, self.emb_size, width)
        self.desc_encoder = SeqEncoder(self.n_desc_words, self.emb_size, width)
        self.linear_attn_out = two_layer_mlp()
        if config['transform_every_modal']:
            self.linear_single_modal = two_layer_mlp()
        if config['save_attn_weight']:
            # Buffers for attention weights / masks recorded during encoding.
            self.attn_weight_torch, self.node_mask_torch = [], []
        self.self_atten = nn.Linear(width, width)
        self.self_atten_scalar = nn.Linear(width, 1)
    def __init__(self, config):
        """Build the name/API/token/description encoders and fusion layers.

        Reads margin, n_words, emb_size, lstm_dims and n_hidden from
        ``config`` and finishes by calling ``self.init_weights()``.
        """
        super(JointEmbeder, self).__init__()
        self.conf = config
        self.margin = config['margin']

        vocab = config['n_words']
        emb = config['emb_size']
        lstm = config['lstm_dims']
        self.name_encoder = SeqEncoder(vocab, emb, lstm)
        self.api_encoder = SeqEncoder(vocab, emb, lstm)
        hidden = config['n_hidden']
        self.tok_encoder = BOWEncoder(vocab, emb, hidden)
        self.desc_encoder = SeqEncoder(vocab, emb, lstm)

        # 2 * lstm_dims is presumably the bidirectional sequence-encoder
        # output width -- TODO confirm against SeqEncoder.
        seq_width = 2 * lstm
        self.w_name = nn.Linear(seq_width, hidden)
        self.w_api = nn.Linear(seq_width, hidden)
        self.w_tok = nn.Linear(emb, hidden)
        self.fuse3 = nn.Linear(hidden, hidden)

        self.init_weights()
Example #3
0
    def __init__(self, config):
        """Build the token/AST/CFG/description encoders and attention layers.

        Copies margin, emb_size, n_hidden, dropout, n_desc_words and
        n_token_words from ``config``; TreeLSTM and GGNN receive the whole
        config dict.
        """
        super(MultiEmbeder, self).__init__()
        self.conf = config
        for key in ('margin', 'emb_size', 'n_hidden', 'dropout',
                    'n_desc_words', 'n_token_words'):
            setattr(self, key, config[key])
        width = self.n_hidden

        self.ast_encoder = TreeLSTM(config)
        self.cfg_encoder = GGNN(config)
        self.tok_encoder = SeqEncoder(self.n_token_words, self.emb_size, width)
        self.desc_encoder = SeqEncoder(self.n_desc_words, self.emb_size, width)

        # Per-modality additive attention: <m>_attn then <m>_attn_scalar,
        # registered in tok, ast, cfg order.
        for modality in ('tok', 'ast', 'cfg'):
            setattr(self, modality + '_attn', nn.Linear(width, width))
            setattr(self, modality + '_attn_scalar', nn.Linear(width, 1))

        # Fuses the three pooled modality vectors into one.
        self.attn_modal_fusion = nn.Linear(width * 3, width)
    def __init__(self, config):
        """Build the encoders and fusion layer, and ensure the model dir exists.

        Reads margin, n_words, emb_size, lstm_dims, n_hidden and workdir from
        ``config``. Side effect: creates ``config['workdir'] + 'models/'`` if
        it does not already exist.
        """
        super(JointEmbeder, self).__init__()
        self.conf = config
        self.margin = config['margin']

        self.name_encoder = SeqEncoder(config['n_words'], config['emb_size'],
                                       config['lstm_dims'])
        self.api_encoder = SeqEncoder(config['n_words'], config['emb_size'],
                                      config['lstm_dims'])
        self.tok_encoder = BOWEncoder(config['n_words'], config['emb_size'],
                                      config['n_hidden'])
        self.desc_encoder = SeqEncoder(config['n_words'], config['emb_size'],
                                       config['lstm_dims'])
        self.fuse = nn.Linear(config['emb_size'] + 4 * config['lstm_dims'],
                              config['n_hidden'])

        # Directory for model info. The path is built by plain concatenation
        # (as before), so 'workdir' is expected to end with a separator.
        # exist_ok=True replaces the racy exists()-then-makedirs() check,
        # which could fail when two processes start at the same time.
        model_dir = config['workdir'] + 'models/'
        os.makedirs(model_dir, exist_ok=True)
Example #5
0
    def __init__(self, config):
        """Build the name/token/description encoders and attention layers.

        Copies margin, dropout and n_hidden from ``config``, also reads
        n_words, emb_size and lstm_dims, and finishes by calling
        ``self.init_weights()``.
        """
        super(JointEmbeder, self).__init__()
        self.conf = config
        for key in ('margin', 'dropout', 'n_hidden'):
            setattr(self, key, config[key])
        width = self.n_hidden

        vocab = config['n_words']
        emb = config['emb_size']
        self.name_encoder = SeqEncoder(vocab, emb, config['lstm_dims'])
        self.tok_encoder = BOWEncoder(vocab, emb, width)
        self.desc_encoder = SeqEncoder2(vocab, emb, width)

        self.w_name = nn.Linear(2 * config['lstm_dims'], width)
        self.w_tok = nn.Linear(emb, width)
        self.fuse3 = nn.Linear(width, width)

        # Additive self-attention scorer (projection, then scalar score).
        self.self_attn2 = nn.Linear(width, width)
        self.self_attn_scalar2 = nn.Linear(width, 1)

        self.init_weights()
Example #6
0
class TokenEmbeder(nn.Module):
    """Embed code token sequences and descriptions into a shared space.

    Both sides are encoded with a ``SeqEncoder``. The code side can be pooled
    with masked additive self-attention (``conf['use_attn']``). ``forward``
    returns a margin ranking loss that pushes the code closer to a matching
    (anchor) description than to a non-matching (negative) one.

    Config keys read: margin, n_token_words, n_desc_words, emb_size, n_hidden,
    dropout, transform_every_modal, save_attn_weight, and at call time
    use_tanh, use_attn, sim_measure.
    """

    def __init__(self, config):
        super(TokenEmbeder, self).__init__()
        self.conf = config
        self.margin = config['margin']
        self.n_token_words = config['n_token_words']
        self.n_desc_words = config['n_desc_words']
        self.emb_size = config['emb_size']
        self.n_hidden = config['n_hidden']
        self.dropout = config['dropout']

        self.tok_encoder = SeqEncoder(self.n_token_words, self.emb_size,
                                      self.n_hidden)
        self.desc_encoder = SeqEncoder(self.n_desc_words, self.emb_size,
                                       self.n_hidden)

        # Projection applied to the attention-pooled code vector.
        self.linear_attn_out = nn.Sequential(
            nn.Linear(self.n_hidden, self.n_hidden), nn.Tanh(),
            nn.Linear(self.n_hidden, self.n_hidden))

        if self.conf['transform_every_modal']:
            # Projection shared by the code and description sides.
            self.linear_single_modal = nn.Sequential(
                nn.Linear(self.n_hidden, self.n_hidden), nn.Tanh(),
                nn.Linear(self.n_hidden, self.n_hidden))

        if self.conf['save_attn_weight']:
            # Per-sample attention weights and masks, appended in code_encoding.
            self.attn_weight_torch = []
            self.node_mask_torch = []

        # Additive self-attention scorer: tanh(W h), then a scalar score.
        self.self_atten = nn.Linear(self.n_hidden, self.n_hidden)
        self.self_atten_scalar = nn.Linear(self.n_hidden, 1)

    def code_encoding(self, tokens, tok_len):
        """Encode a batch of token sequences into ``[batch x n_hidden]``.

        Args:
            tokens: token-id batch whose first dimension is the batch size.
            tok_len: per-sample valid lengths, used to mask padding.

        Returns:
            With ``conf['use_attn']``: attention-pooled, projected vectors.
            Otherwise: the encoder's final hidden state, reshaped to
            ``[batch x n_hidden]`` exactly like ``desc_encoding`` does, so the
            two sides can be compared row-wise.
        """
        batch_size = tokens.size(0)
        code_enc_hidden = self.tok_encoder.init_hidden(batch_size)
        # code_feat: [batch_sz x seq_len x n_hidden]
        code_feat, code_enc_hidden = self.tok_encoder(tokens, tok_len,
                                                      code_enc_hidden)
        code_enc_hidden = code_enc_hidden[0]

        if self.conf['use_tanh']:
            code_feat = torch.tanh(code_feat)

        seq_len = code_feat.size(1)
        # Build the padding mask on the features' own device. A hard-coded
        # "cuda:0" breaks CPU models on CUDA machines and models on other GPUs.
        device = code_feat.device
        positions = torch.arange(seq_len, device=device)
        lengths = tok_len.long().to(device)
        # mask_1forgt0: [batch_sz x seq_len], True at real (unpadded) steps
        mask_1forgt0 = (positions[None, :] < lengths[:, None]).reshape(
            -1, seq_len)

        if self.conf['transform_every_modal']:
            code_feat = torch.tanh(
                self.linear_single_modal(
                    F.dropout(code_feat, self.dropout,
                              training=self.training)))

        if not self.conf['use_attn']:
            return code_enc_hidden.reshape(batch_size, self.n_hidden)

        # Attention logits, one per time step: [batch_sz x seq_len]
        attn_hidden = torch.tanh(
            self.self_atten(code_feat.reshape(-1, self.n_hidden)))
        attn_hidden = F.dropout(attn_hidden, self.dropout,
                                training=self.training)
        attn_scores = self.self_atten_scalar(attn_hidden).reshape(-1, seq_len)
        code_feat = code_feat.reshape(-1, seq_len, self.n_hidden)

        # Collect per-sample rows and concatenate once; repeated torch.cat
        # inside the loop would copy the growing result every iteration.
        pooled = []
        for i in range(batch_size):
            valid = mask_1forgt0[i]
            # Softmax restricted to this sample's real steps: [1 x real_len]
            attn_w_one = F.softmax(attn_scores[i][valid].reshape(1, -1), dim=1)
            if self.conf['save_attn_weight']:
                self.attn_weight_torch.append(attn_w_one.detach().cpu())
                self.node_mask_torch.append(
                    valid.detach().reshape(1, -1).cpu())
            # [1 x real_len] @ [real_len x n_hidden] -> [1 x n_hidden]
            pooled.append(torch.mm(attn_w_one, code_feat[i][valid]))
        self_attn_code_feat = torch.cat(pooled, 0)

        return torch.tanh(
            self.linear_attn_out(
                F.dropout(self_attn_code_feat,
                          self.dropout,
                          training=self.training)))

    def desc_encoding(self, desc, desc_len):
        """Encode a batch of descriptions into ``[batch x n_hidden]``.

        The final hidden state is used. With ``transform_every_modal`` it is
        passed through the shared projection; otherwise ``use_tanh`` applies
        a tanh.
        """
        batch_size = desc.size(0)
        desc_enc_hidden = self.desc_encoder.init_hidden(batch_size)
        # Pass the initial state explicitly, as code_encoding does with the
        # same encoder type (it was previously computed and then ignored).
        _, desc_enc_hidden = self.desc_encoder(desc, desc_len, desc_enc_hidden)
        # desc_feat: [batch_size x n_hidden]
        desc_feat = desc_enc_hidden[0].reshape(batch_size, self.n_hidden)

        if self.conf['transform_every_modal']:
            desc_feat = torch.tanh(
                self.linear_single_modal(
                    F.dropout(desc_feat, self.dropout,
                              training=self.training)))
        elif self.conf['use_tanh']:
            desc_feat = torch.tanh(desc_feat)

        return desc_feat

    def similarity(self, code_vec, desc_vec):
        """Return the row-wise similarity of two ``[batch x dim]`` tensors.

        The measure is chosen by ``conf['sim_measure']``:

        * 'cos': cosine similarity.
        * 'poly': (0.5 * <c, d> + 1) ** 2.
        * 'sigmoid': tanh(<c, d> + 1). Note that this applies tanh, as the
          original code did.
        * 'euc': 1 / (1 + ||c - d||).
        * 'gesd': the 'euc' value times sigmoid(<c, d> + 1).
        * 'aesd': the mean of the 'euc' value and sigmoid(<c, d> + 1).

        Every measure yields one value per row, shape ``[batch]``.
        """
        measure = self.conf['sim_measure']
        assert measure in [
            'cos', 'poly', 'euc', 'sigmoid', 'gesd', 'aesd'
        ], "invalid similarity measure"
        if measure == 'cos':
            return F.cosine_similarity(code_vec, desc_vec)
        # Row-wise dot product. This equals matmul(c, d.t()).diag() without
        # materialising the [batch x batch] matrix.
        dot = (code_vec * desc_vec).sum(dim=1)
        if measure == 'poly':
            return (0.5 * dot + 1)**2
        if measure == 'sigmoid':
            return torch.tanh(dot + 1)
        # Per-row Euclidean distance. torch.dist would reduce the whole batch
        # to a single scalar and apply it to every row.
        euc_dist = (code_vec - desc_vec).norm(p=2, dim=1)
        euc_sim = 1 / (1 + euc_dist)
        if measure == 'euc':
            return euc_sim
        sigmoid_sim = torch.sigmoid(dot + 1)
        if measure == 'gesd':
            return euc_sim * sigmoid_sim
        return 0.5 * (euc_sim + sigmoid_sim)

    def forward(self, tokens, tok_len, desc_anchor, desc_anchor_len, desc_neg,
                desc_neg_len):
        """Return the mean margin ranking loss for a training batch.

        The loss is mean(clamp(margin - sim(code, anchor) + sim(code, neg),
        min=1e-6)).
        """
        # code_repr: [batch_sz x n_hidden]
        code_repr = self.code_encoding(tokens, tok_len)
        # desc_repr: [batch_sz x n_hidden]
        desc_anchor_repr = self.desc_encoding(desc_anchor, desc_anchor_len)
        desc_neg_repr = self.desc_encoding(desc_neg, desc_neg_len)

        # sim: [batch_sz]
        anchor_sim = self.similarity(code_repr, desc_anchor_repr)
        neg_sim = self.similarity(code_repr, desc_neg_repr)

        loss = (self.margin - anchor_sim + neg_sim).clamp(min=1e-6).mean()

        return loss
Example #7
0
class MultiEmbeder(nn.Module):
    """Multi-modal embedder over code tokens, an AST and a CFG.

    Each modality is pooled separately into ``[batch x n_hidden]``:

    * tokens: ``SeqEncoder`` with softmax self-attention over real steps;
    * AST: ``TreeLSTM`` with softmax self-attention over each tree's nodes;
    * CFG: ``GGNN`` with per-node sigmoid gates over unmasked nodes.

    The three pooled vectors are concatenated and fused into one code vector.
    Descriptions use a ``SeqEncoder`` final hidden state followed by tanh.
    ``forward`` returns a cosine-similarity margin ranking loss.
    """

    def __init__(self, config):
        super(MultiEmbeder, self).__init__()
        self.conf = config

        self.margin = config['margin']
        self.emb_size = config['emb_size']
        self.n_hidden = config['n_hidden']
        self.dropout = config['dropout']

        self.n_desc_words = config['n_desc_words']
        self.n_token_words = config['n_token_words']

        self.ast_encoder = TreeLSTM(self.conf)
        self.cfg_encoder = GGNN(self.conf)
        self.tok_encoder = SeqEncoder(self.n_token_words, self.emb_size,
                                      self.n_hidden)
        self.desc_encoder = SeqEncoder(self.n_desc_words, self.emb_size,
                                       self.n_hidden)

        self.tok_attn = nn.Linear(self.n_hidden, self.n_hidden)
        self.tok_attn_scalar = nn.Linear(self.n_hidden, 1)
        self.ast_attn = nn.Linear(self.n_hidden, self.n_hidden)
        self.ast_attn_scalar = nn.Linear(self.n_hidden, 1)
        self.cfg_attn = nn.Linear(self.n_hidden, self.n_hidden)
        self.cfg_attn_scalar = nn.Linear(self.n_hidden, 1)

        self.attn_modal_fusion = nn.Linear(self.n_hidden * 3, self.n_hidden)

    @staticmethod
    def _length_mask(lengths, max_len, device):
        """Return a bool ``[batch x max_len]`` mask, True where position < length."""
        positions = torch.arange(max_len, device=device)
        return positions[None, :] < lengths.long().to(device)[:, None]

    def _token_attention(self, tokens, tok_len, batch_size):
        """Softmax self-attention pooling of token features -> ``[batch x n_hidden]``."""
        tok_enc_hidden = self.tok_encoder.init_hidden(batch_size)
        # tok_feat: [batch_size x seq_len x n_hidden]
        tok_feat, _ = self.tok_encoder(tokens, tok_len, tok_enc_hidden)
        seq_len = tok_feat.size(1)
        # Mask on the features' own device (not a hard-coded "cuda:0").
        tok_mask = self._length_mask(tok_len, seq_len, tok_feat.device)

        tok_sa_tanh = torch.tanh(
            self.tok_attn(tok_feat.reshape(-1, self.n_hidden)))
        tok_sa_tanh = F.dropout(tok_sa_tanh,
                                self.dropout,
                                training=self.training)
        # tok_scores: [batch_size x seq_len]
        tok_scores = self.tok_attn_scalar(tok_sa_tanh).reshape(-1, seq_len)
        tok_feat = tok_feat.reshape(-1, seq_len, self.n_hidden)

        pooled = []
        for i in range(batch_size):
            valid = tok_mask[i]
            # weights: [1 x real_len], softmax over real steps only
            weights = F.softmax(tok_scores[i][valid].reshape(1, -1), dim=1)
            pooled.append(torch.mm(weights, tok_feat[i][valid]))
        return torch.cat(pooled, 0)

    def _ast_attention(self, tree, tree_node_num, batch_size):
        """Softmax self-attention pooling over each tree's nodes -> ``[batch x n_hidden]``.

        The TreeLSTM returns the node states of all trees stacked along dim 0;
        ``tree_node_num[i]`` gives how many consecutive rows belong to tree i.
        """
        # tree.graph.number_of_nodes() sizes the initial (h_0, c_0).
        ast_enc_hidden = self.ast_encoder.init_hidden(
            tree.graph.number_of_nodes())
        # all_node_h_in_batch: [all_node_num_in_batch x n_hidden]
        all_node_h_in_batch, _ = self.ast_encoder(tree, ast_enc_hidden)

        pooled = []
        offset = 0
        for i in range(batch_size):
            node_num = tree_node_num[i]
            # this_sample_h: [node_num x n_hidden]
            this_sample_h = all_node_h_in_batch[offset:offset + node_num]
            offset += node_num

            ast_sa_tanh = torch.tanh(
                self.ast_attn(this_sample_h.reshape(-1, self.n_hidden)))
            ast_sa_tanh = F.dropout(ast_sa_tanh,
                                    self.dropout,
                                    training=self.training)
            scores = self.ast_attn_scalar(ast_sa_tanh).reshape(1, node_num)
            weights = F.softmax(scores, dim=1)
            pooled.append(
                torch.bmm(weights.reshape(1, 1, node_num),
                          this_sample_h.reshape(1, node_num,
                                                self.n_hidden)).reshape(
                                                    1, self.n_hidden))
        return torch.cat(pooled, 0)

    def _cfg_attention(self, cfg_init_input, cfg_adjmat, cfg_node_mask,
                       batch_size):
        """Sigmoid-gated pooling over unmasked CFG nodes -> ``[batch x n_hidden]``."""
        # cfg_feat: [batch_size x n_node x n_hidden]; the encoder is called as
        # (prop_state, A, node_mask).
        cfg_feat = self.cfg_encoder(cfg_init_input, cfg_adjmat, cfg_node_mask)
        node_num = cfg_feat.size(1)
        cfg_feat = cfg_feat.reshape(-1, node_num, self.n_hidden)
        # cfg_mask: [batch_size x n_node], True for real nodes
        cfg_mask = cfg_node_mask.bool().reshape(-1, node_num)

        cfg_sa_tanh = torch.tanh(
            self.cfg_attn(cfg_feat.reshape(-1, self.n_hidden)))
        cfg_sa_tanh = F.dropout(cfg_sa_tanh,
                                self.dropout,
                                training=self.training)
        # cfg_scores: [batch_size x n_node]
        cfg_scores = self.cfg_attn_scalar(cfg_sa_tanh).reshape(-1, node_num)

        pooled = []
        for i in range(batch_size):
            valid = cfg_mask[i]
            # Independent sigmoid gate per node; the weights need not sum to 1.
            weights = torch.sigmoid(cfg_scores[i][valid]).reshape(1, -1)
            pooled.append(torch.mm(weights, cfg_feat[i][valid]))
        return torch.cat(pooled, 0)

    def code_encoding(self, tokens, tok_len, tree, tree_node_num,
                      cfg_init_input, cfg_adjmat, cfg_node_mask):
        """Fuse token, AST and CFG representations into ``[batch x n_hidden]``.

        The batch size is taken from ``cfg_node_mask``.
        """
        batch_size = cfg_node_mask.size(0)
        # Keep the tok -> AST -> CFG order so encoder calls and dropout draws
        # happen in the same sequence as before.
        tok_feat_attn = self._token_attention(tokens, tok_len, batch_size)
        ast_feat_attn = self._ast_attention(tree, tree_node_num, batch_size)
        cfg_feat_attn = self._cfg_attention(cfg_init_input, cfg_adjmat,
                                            cfg_node_mask, batch_size)

        # concat_feat: [batch_size x (n_hidden * 3)]
        concat_feat = torch.cat((tok_feat_attn, ast_feat_attn, cfg_feat_attn),
                                1)
        # code_feat: [batch_size x n_hidden]
        code_feat = torch.tanh(
            self.attn_modal_fusion(
                F.dropout(concat_feat, self.dropout,
                          training=self.training))).reshape(-1, self.n_hidden)

        return code_feat

    def desc_encoding(self, desc, desc_len):
        """Encode descriptions as tanh of the final hidden state, ``[batch x n_hidden]``."""
        batch_size = desc.size(0)

        desc_enc_hidden = self.desc_encoder.init_hidden(batch_size)
        # Pass the initial state explicitly (it was computed and then ignored).
        _, desc_enc_hidden = self.desc_encoder(desc, desc_len, desc_enc_hidden)
        # desc_feat: [batch_size x n_hidden]
        desc_feat = desc_enc_hidden[0].reshape(batch_size, self.n_hidden)
        return torch.tanh(desc_feat)

    def forward(self, tokens, tok_len, tree, tree_node_num, cfg_init_input,
                cfg_adjmat, cfg_node_mask, desc_anchor, desc_anchor_len,
                desc_neg, desc_neg_len):
        """Return the mean cosine-similarity margin ranking loss for a batch."""
        # code_repr: [batch_size x n_hidden]
        code_repr = self.code_encoding(tokens, tok_len, tree, tree_node_num,
                                       cfg_init_input, cfg_adjmat,
                                       cfg_node_mask)
        # desc_repr: [batch_size x n_hidden]
        desc_anchor_repr = self.desc_encoding(desc_anchor, desc_anchor_len)
        desc_neg_repr = self.desc_encoding(desc_neg, desc_neg_len)

        # sim: [batch_size]
        anchor_sim = F.cosine_similarity(code_repr, desc_anchor_repr)
        neg_sim = F.cosine_similarity(code_repr, desc_neg_repr)

        loss = (self.margin - anchor_sim + neg_sim).clamp(min=1e-6).mean()

        return loss
Example #8
0
class CFGEmbeder(nn.Module):
    """Multi-modal embedder over code tokens, a DFG and a CFG.

    Tokens are encoded with a ``SeqEncoder`` and pooled with softmax
    self-attention over real (unpadded) steps. The data-flow and control-flow
    graphs are each encoded with a ``GGNN`` and pooled with per-node sigmoid
    gates over unmasked nodes. The three pooled vectors are fused into one
    code vector. Descriptions use a ``SeqEncoder``, optionally pooled with
    softmax self-attention (``conf['use_desc_attn']``). ``forward`` returns a
    margin ranking loss under ``conf['sim_measure']``.
    """

    def __init__(self, config):
        super(CFGEmbeder, self).__init__()
        self.conf = config

        self.margin = config['margin']
        self.emb_size = config['emb_size']
        self.n_hidden = config['n_hidden']
        self.dropout = config['dropout']

        self.n_desc_words = config['n_desc_words']
        self.n_token_words = config['n_token_words']

        self.dfg_encoder = GGNN(self.conf)
        self.cfg_encoder = GGNN(self.conf)
        self.tok_encoder = SeqEncoder(self.n_token_words, self.emb_size,
                                      self.n_hidden)
        self.desc_encoder = SeqEncoder(self.n_desc_words, self.emb_size,
                                       self.n_hidden)

        self.tok_attn = nn.Linear(self.n_hidden, self.n_hidden)
        self.tok_attn_scalar = nn.Linear(self.n_hidden, 1)
        self.dfg_attn = nn.Linear(self.n_hidden, self.n_hidden)
        self.dfg_attn_scalar = nn.Linear(self.n_hidden, 1)
        self.cfg_attn = nn.Linear(self.n_hidden, self.n_hidden)
        self.cfg_attn_scalar = nn.Linear(self.n_hidden, 1)

        # Description-side self-attention scorer.
        self.self_attn2 = nn.Linear(self.n_hidden, self.n_hidden)
        self.self_attn_scalar2 = nn.Linear(self.n_hidden, 1)

        self.attn_modal_fusion = nn.Linear(self.n_hidden * 3, self.n_hidden)

    @staticmethod
    def _length_mask(lengths, max_len, device):
        """Return a bool ``[batch x max_len]`` mask, True where position < length."""
        positions = torch.arange(max_len, device=device)
        return positions[None, :] < lengths.long().to(device)[:, None]

    @staticmethod
    def _softmax_pool(scores, feats, mask):
        """Pool features with a softmax over each sample's valid positions.

        Args:
            scores: attention logits, ``[batch x seq_len]``.
            feats: features, ``[batch x seq_len x n_hidden]``.
            mask: bool ``[batch x seq_len]``; False marks padding.

        Returns:
            ``[batch x n_hidden]`` pooled features.
        """
        pooled = []
        for sample_scores, sample_feats, valid in zip(scores, feats, mask):
            weights = F.softmax(sample_scores[valid].reshape(1, -1), dim=1)
            pooled.append(torch.mm(weights, sample_feats[valid]))
        return torch.cat(pooled, 0)

    @staticmethod
    def _sigmoid_pool(scores, feats, mask):
        """Like ``_softmax_pool``, but weights each valid node by sigmoid(score).

        The gates are independent, so they need not sum to 1.
        """
        pooled = []
        for sample_scores, sample_feats, valid in zip(scores, feats, mask):
            weights = torch.sigmoid(sample_scores[valid]).reshape(1, -1)
            pooled.append(torch.mm(weights, sample_feats[valid]))
        return torch.cat(pooled, 0)

    def _token_attention(self, tokens, tok_len, batch_size):
        """Softmax self-attention pooling of token features -> ``[batch x n_hidden]``."""
        tok_enc_hidden = self.tok_encoder.init_hidden(batch_size)
        # tok_feat: [batch_size x seq_len x n_hidden]
        tok_feat, _ = self.tok_encoder(tokens, tok_len, tok_enc_hidden)
        seq_len = tok_feat.size(1)
        # Mask on the features' own device (not a hard-coded "cuda:0").
        tok_mask = self._length_mask(tok_len, seq_len, tok_feat.device)

        tok_sa_tanh = torch.tanh(
            self.tok_attn(tok_feat.reshape(-1, self.n_hidden)))
        tok_sa_tanh = F.dropout(tok_sa_tanh,
                                self.dropout,
                                training=self.training)
        # tok_scores: [batch_size x seq_len]
        tok_scores = self.tok_attn_scalar(tok_sa_tanh).reshape(-1, seq_len)
        return self._softmax_pool(
            tok_scores, tok_feat.reshape(-1, seq_len, self.n_hidden), tok_mask)

    def _graph_attention(self, graph_feat, node_mask, attn, attn_scalar):
        """Sigmoid-gated pooling of GGNN node states -> ``[batch x n_hidden]``.

        Args:
            graph_feat: GGNN output, ``[batch x n_node x n_hidden]``.
            node_mask: node mask; non-zero entries mark real nodes.
            attn, attn_scalar: the modality's attention projection and scorer.
        """
        node_num = graph_feat.size(1)
        graph_feat = graph_feat.reshape(-1, node_num, self.n_hidden)
        mask = node_mask.bool().reshape(-1, node_num)

        sa_tanh = torch.tanh(attn(graph_feat.reshape(
            -1, self.n_hidden)))  # [(batch_size * n_node) x n_hidden]
        sa_tanh = F.dropout(sa_tanh, self.dropout, training=self.training)
        # scores: [batch_size x n_node]
        scores = attn_scalar(sa_tanh).reshape(-1, node_num)
        return self._sigmoid_pool(scores, graph_feat, mask)

    def code_encoding(self, tokens, tok_len, dfg_init_input, dfg_adjmat,
                      dfg_node_mask, cfg_init_input, cfg_adjmat,
                      cfg_node_mask):
        """Fuse token, DFG and CFG representations into ``[batch x n_hidden]``.

        The batch size used for the token encoder is taken from
        ``cfg_node_mask``.
        """
        batch_size = cfg_node_mask.size(0)
        # Keep the tok -> DFG -> CFG order so encoder calls and dropout draws
        # happen in the same sequence as before. Graph encoders are called as
        # (prop_state, A, node_mask).
        tok_feat_attn = self._token_attention(tokens, tok_len, batch_size)
        dfg_feat_attn = self._graph_attention(
            self.dfg_encoder(dfg_init_input, dfg_adjmat, dfg_node_mask),
            dfg_node_mask, self.dfg_attn, self.dfg_attn_scalar)
        cfg_feat_attn = self._graph_attention(
            self.cfg_encoder(cfg_init_input, cfg_adjmat, cfg_node_mask),
            cfg_node_mask, self.cfg_attn, self.cfg_attn_scalar)

        # concat_feat: [batch_size x (n_hidden * 3)]
        concat_feat = torch.cat((tok_feat_attn, dfg_feat_attn, cfg_feat_attn),
                                1)
        # code_feat: [batch_size x n_hidden]
        code_feat = torch.tanh(
            self.attn_modal_fusion(
                F.dropout(concat_feat, self.dropout,
                          training=self.training))).reshape(-1, self.n_hidden)

        return code_feat

    def desc_encoding(self, desc, desc_len):
        """Encode descriptions into ``[batch x n_hidden]``.

        With ``conf['use_desc_attn']`` the per-step outputs are pooled with
        softmax self-attention over real steps. Otherwise the final hidden
        state is used. ``conf['use_tanh']`` then applies a tanh.
        """
        batch_size = desc.size(0)
        desc_enc_hidden = self.desc_encoder.init_hidden(batch_size)
        desc_feat, desc_enc_hidden = self.desc_encoder(desc, desc_len,
                                                       desc_enc_hidden)
        desc_enc_hidden = desc_enc_hidden[0]

        if self.conf['use_desc_attn']:
            seq_len = desc_feat.size(1)
            mask = self._length_mask(desc_len, seq_len, desc_feat.device)

            desc_sa_tanh = torch.tanh(
                self.self_attn2(desc_feat.reshape(
                    -1, self.n_hidden)))  # [(batch_sz * seq_len) x n_hidden]
            desc_sa_tanh = F.dropout(desc_sa_tanh,
                                     self.dropout,
                                     training=self.training)
            scores = self.self_attn_scalar2(desc_sa_tanh).reshape(
                -1, seq_len)  # [batch_sz x seq_len]
            desc_vec = self._softmax_pool(
                scores, desc_feat.reshape(-1, seq_len, self.n_hidden), mask)
        else:
            desc_vec = desc_enc_hidden.reshape(batch_size, self.n_hidden)

        if self.conf['use_tanh']:
            desc_vec = torch.tanh(desc_vec)

        # desc_vec: [batch_size x n_hidden]
        return desc_vec

    def similarity(self, code_vec, desc_vec):
        """Return the row-wise similarity of two ``[batch x dim]`` tensors.

        The measure is chosen by ``conf['sim_measure']``:

        * 'cos': cosine similarity.
        * 'poly': (0.5 * <c, d> + 1) ** 2.
        * 'sigmoid': tanh(<c, d> + 1). Note that this applies tanh, as the
          original code did.
        * 'euc': 1 / (1 + ||c - d||).
        * 'gesd': the 'euc' value times sigmoid(<c, d> + 1).
        * 'aesd': the mean of the 'euc' value and sigmoid(<c, d> + 1).

        Every measure yields one value per row, shape ``[batch]``.
        """
        measure = self.conf['sim_measure']
        assert measure in [
            'cos', 'poly', 'euc', 'sigmoid', 'gesd', 'aesd'
        ], "invalid similarity measure"
        if measure == 'cos':
            return F.cosine_similarity(code_vec, desc_vec)
        # Row-wise dot product. This equals matmul(c, d.t()).diag() without
        # materialising the [batch x batch] matrix.
        dot = (code_vec * desc_vec).sum(dim=1)
        if measure == 'poly':
            return (0.5 * dot + 1)**2
        if measure == 'sigmoid':
            return torch.tanh(dot + 1)
        # Per-row Euclidean distance. torch.dist would reduce the whole batch
        # to a single scalar and apply it to every row.
        euc_dist = (code_vec - desc_vec).norm(p=2, dim=1)
        euc_sim = 1 / (1 + euc_dist)
        if measure == 'euc':
            return euc_sim
        sigmoid_sim = torch.sigmoid(dot + 1)
        if measure == 'gesd':
            return euc_sim * sigmoid_sim
        return 0.5 * (euc_sim + sigmoid_sim)

    def forward(self, tokens, tok_len, dfg_init_input, dfg_adjmat,
                dfg_node_mask, cfg_init_input, cfg_adjmat, cfg_node_mask,
                desc_anchor, desc_anchor_len, desc_neg, desc_neg_len):
        """Return the mean margin ranking loss for a training batch."""
        # code_repr: [batch_sz x n_hidden]
        code_repr = self.code_encoding(tokens, tok_len, dfg_init_input,
                                       dfg_adjmat, dfg_node_mask,
                                       cfg_init_input, cfg_adjmat,
                                       cfg_node_mask)
        # desc_repr: [batch_sz x n_hidden]
        desc_anchor_repr = self.desc_encoding(desc_anchor, desc_anchor_len)
        desc_neg_repr = self.desc_encoding(desc_neg, desc_neg_len)

        # sim: [batch_sz]
        anchor_sim = self.similarity(code_repr, desc_anchor_repr)
        neg_sim = self.similarity(code_repr, desc_neg_repr)

        loss = (self.margin - anchor_sim + neg_sim).clamp(min=1e-6).mean()

        return loss
class ASTEmbeder(nn.Module):
    """Score code ASTs against natural-language descriptions.

    Code side: a ``TreeLSTM`` over the batched ASTs, then self-attention
    pooling over the nodes of each tree. Description side: a ``SeqEncoder``,
    optionally followed by masked self-attention pooling over the tokens.
    ``forward`` returns a margin ranking loss over
    (code, anchor description, negative description) triples.
    """

    def __init__(self, config):
        super(ASTEmbeder, self).__init__()
        self.conf = config

        self.margin = config['margin']
        self.n_desc_words = config['n_desc_words']
        self.emb_size = config['emb_size']
        self.n_hidden = config['n_hidden']
        self.dropout = config['dropout']

        self.ast_encoder = TreeLSTM(self.conf)
        self.desc_encoder = SeqEncoder(self.n_desc_words, self.emb_size,
                                       self.n_hidden)

        # Optional MLP applied to the pooled code vector.
        if self.conf['transform_attn_out']:
            self.linear_attn_out = nn.Sequential(
                nn.Linear(self.n_hidden, self.n_hidden), nn.Tanh(),
                nn.Linear(self.n_hidden, self.n_hidden))

        # Optional MLP applied to every AST node state before pooling.
        if self.conf['transform_every_modal']:
            self.linear_single_modal = nn.Sequential(
                nn.Linear(self.n_hidden, self.n_hidden), nn.Tanh(),
                nn.Linear(self.n_hidden, self.n_hidden))

        # code_encoding appends each tree's attention weights here.
        if self.conf['save_attn_weight']:
            self.attn_weight_torch = []
            self.node_mask_torch = []

        # Attention scorer over AST nodes (code side).
        self.self_atten = nn.Linear(self.n_hidden, self.n_hidden)
        self.self_atten_scalar = nn.Linear(self.n_hidden, 1)
        # Attention scorer over description tokens.
        self.self_attn2 = nn.Linear(self.n_hidden, self.n_hidden)
        self.self_attn_scalar2 = nn.Linear(self.n_hidden, 1)

    def code_encoding(self, tree_batch, tree_node_num):
        """Encode a batch of ASTs into one vector per tree.

        Args:
            tree_batch: batched trees passed to ``self.ast_encoder``; must
                expose ``graph.number_of_nodes()`` (total nodes in the batch).
            tree_node_num: 1-D tensor with the node count of each tree. The
                node states of tree ``i`` are taken as the next
                ``tree_node_num[i]`` rows of the encoder output, so the order
                must match the layout of ``tree_batch``.

        Returns:
            Tensor of shape [batch_size x n_hidden].
        """
        # h_0 / c_0 are sized by the total node count of the batched graph.
        code_enc_hidden = self.ast_encoder.init_hidden(
            tree_batch.graph.number_of_nodes())
        # all_node_h_in_batch: [all_node_num_in_batch x n_hidden]
        all_node_h_in_batch, _ = self.ast_encoder(tree_batch, code_enc_hidden)

        if self.conf['transform_every_modal']:
            all_node_h_in_batch = torch.tanh(
                self.linear_single_modal(
                    F.dropout(all_node_h_in_batch,
                              self.dropout,
                              training=self.training)))
        elif self.conf['use_tanh']:
            all_node_h_in_batch = torch.tanh(all_node_h_in_batch)

        pooled = []
        offset = 0
        # One host transfer for all counts instead of one per tree.
        for node_num in (int(n) for n in tree_node_num.tolist()):
            # this_sample_h: [node_num x n_hidden]
            this_sample_h = all_node_h_in_batch[offset:offset + node_num]
            this_sample_h = this_sample_h.reshape(-1, self.n_hidden)
            offset += node_num

            code_sa_tanh = torch.tanh(self.self_atten(this_sample_h))
            code_sa_tanh = F.dropout(code_sa_tanh,
                                     self.dropout,
                                     training=self.training)
            # self_atten_weight: [node_num], softmax over this tree's nodes
            self_atten_weight = F.softmax(
                self.self_atten_scalar(code_sa_tanh).reshape(-1), dim=0)

            if self.conf['save_attn_weight']:
                self.attn_weight_torch.append(
                    self_atten_weight.detach().reshape(1, -1).cpu())

            # Attention-weighted sum of the node states: [n_hidden]
            pooled.append(self_atten_weight @ this_sample_h)

        # Concatenate once rather than growing a tensor inside the loop.
        if pooled:
            self_attn_code_feat = torch.stack(pooled, 0)
        else:
            self_attn_code_feat = all_node_h_in_batch.new_zeros(
                (0, self.n_hidden))

        if self.conf['transform_attn_out']:
            # self_attn_code_feat: [batch_size x n_hidden]
            self_attn_code_feat = torch.tanh(
                self.linear_attn_out(
                    F.dropout(self_attn_code_feat,
                              self.dropout,
                              training=self.training)))
        elif self.conf['use_tanh']:
            self_attn_code_feat = torch.tanh(self_attn_code_feat)

        return self_attn_code_feat

    def desc_encoding(self, desc, desc_len):
        """Encode a batch of descriptions into one vector each.

        Args:
            desc: [batch_size x seq_len] token ids.
            desc_len: 1-D tensor with the real length of each description;
                positions at or beyond it are excluded from attention.

        Returns:
            Tensor of shape [batch_size x n_hidden].
        """
        batch_size = desc.size(0)
        desc_enc_hidden = self.desc_encoder.init_hidden(batch_size)
        desc_feat, desc_enc_hidden = self.desc_encoder(desc, desc_len,
                                                       desc_enc_hidden)
        desc_enc_hidden = desc_enc_hidden[0]

        if self.conf['use_desc_attn']:
            seq_len = desc_feat.size(1)

            # Build the padding mask on the encoder output's device; a
            # hard-coded "cuda:0" breaks CPU runs and other GPUs.
            device = desc_feat.device
            lengths = desc_len.long().to(device)
            positions = torch.arange(seq_len, device=device)
            # mask_1forgt0: [batch_size x seq_len], True for real tokens
            mask_1forgt0 = (positions[None, :] < lengths[:, None]).reshape(
                -1, seq_len)

            # [(batch_size * seq_len) x n_hidden]
            desc_sa_tanh = torch.tanh(
                self.self_attn2(desc_feat.reshape(-1, self.n_hidden)))
            desc_sa_tanh = F.dropout(desc_sa_tanh,
                                     self.dropout,
                                     training=self.training)
            # desc_sa_tanh: [batch_size x seq_len]
            desc_sa_tanh = self.self_attn_scalar2(desc_sa_tanh).reshape(
                -1, seq_len)
            desc_feat = desc_feat.reshape(-1, seq_len, self.n_hidden)

            pooled = []
            for _i in range(batch_size):
                valid = mask_1forgt0[_i]
                # Softmax only over this description's real tokens.
                attn_w_one = F.softmax(desc_sa_tanh[_i][valid], dim=0)
                # [n_valid] @ [n_valid x n_hidden] -> [n_hidden]
                pooled.append(attn_w_one @ desc_feat[_i][valid])

            if pooled:
                self_attn_desc_feat = torch.stack(pooled, 0)
            else:
                self_attn_desc_feat = desc_feat.new_zeros((0, self.n_hidden))
        else:
            self_attn_desc_feat = desc_enc_hidden.reshape(
                batch_size, self.n_hidden)

        if self.conf['use_tanh']:
            self_attn_desc_feat = torch.tanh(self_attn_desc_feat)

        # [batch_size x n_hidden]
        return self_attn_desc_feat

    def similarity(self, code_vec, desc_vec):
        """Return one similarity score per row pair: a [batch_size] tensor.

        ``self.conf['sim_measure']`` selects the measure: 'cos', 'poly',
        'sigmoid' (tanh kernel), 'euc', 'gesd' or 'aesd'.
        """
        assert self.conf['sim_measure'] in [
            'cos', 'poly', 'euc', 'sigmoid', 'gesd', 'aesd'
        ], "invalid similarity measure"
        measure = self.conf['sim_measure']
        if measure == 'cos':
            return F.cosine_similarity(code_vec, desc_vec)

        # Row-wise dot product; equals matmul(code_vec, desc_vec.t()).diag()
        # without materialising the full [batch x batch] matrix.
        dot = (code_vec * desc_vec).sum(dim=1)
        if measure == 'poly':
            return (0.5 * dot + 1)**2
        if measure == 'sigmoid':
            return torch.tanh(dot + 1)

        # Per-pair Euclidean distance. torch.dist would collapse the whole
        # batch into a single scalar.
        euc_dist = torch.norm(code_vec - desc_vec, p=2, dim=1)
        euc_sim = 1 / (1 + euc_dist)
        if measure == 'euc':
            return euc_sim
        sigmoid_sim = torch.sigmoid(dot + 1)
        if measure == 'gesd':
            return euc_sim * sigmoid_sim
        return 0.5 * (euc_sim + sigmoid_sim)  # 'aesd'

    def forward(self, tree_batch, tree_node_num, desc_anchor, desc_anchor_len,
                desc_neg, desc_neg_len):
        """Return the margin ranking loss for a batch of triples.

        Loss is ``mean(clamp(margin - sim(code, anchor) + sim(code, neg),
        min=1e-6))``.
        """
        # code_repr: [batch_sz x n_hidden]
        code_repr = self.code_encoding(tree_batch, tree_node_num)
        # desc_repr: [batch_sz x n_hidden]
        desc_anchor_repr = self.desc_encoding(desc_anchor, desc_anchor_len)
        desc_neg_repr = self.desc_encoding(desc_neg, desc_neg_len)

        # sim: [batch_sz]
        anchor_sim = self.similarity(code_repr, desc_anchor_repr)
        neg_sim = self.similarity(code_repr, desc_neg_repr)

        loss = (self.margin - anchor_sim + neg_sim).clamp(min=1e-6).mean()

        return loss
Example #10
0
class CFGEmbeder(nn.Module):
    """Score control-flow graphs (CFGs) against natural-language descriptions.

    Code side: a ``GGNN`` over the batched CFGs, then per-graph attention
    pooling over the nodes selected by ``cfg_node_mask``. Description side:
    the first element of the ``SeqEncoder`` final state. ``forward`` returns
    a margin ranking loss over
    (code, anchor description, negative description) triples.
    """

    def __init__(self, config):
        super(CFGEmbeder, self).__init__()
        self.conf = config

        self.margin = config['margin']
        self.n_desc_words = config['n_desc_words']
        self.emb_size = config['emb_size']
        self.n_hidden = config['n_hidden']
        self.dropout = config['dropout']
        # 'sigmoid_scalar': independent sigmoid gate per node;
        # anything else: softmax over each graph's valid nodes.
        self.cfg_attn_mode = config['cfg_attn_mode']

        self.cfg_encoder = GGNN(self.conf)
        self.desc_encoder = SeqEncoder(self.n_desc_words, self.emb_size,
                                       self.n_hidden)

        # Not used by code_encoding at present. Removing it would change the
        # module's state_dict keys, so it stays.
        self.linear_attn_out = nn.Sequential(
            nn.Linear(self.n_hidden, self.n_hidden), nn.Tanh(),
            nn.Linear(self.n_hidden, self.n_hidden))

        # Optional MLP applied per node (code) and to the description vector.
        if self.conf['transform_every_modal']:
            self.linear_single_modal = nn.Sequential(
                nn.Linear(self.n_hidden, self.n_hidden), nn.Tanh(),
                nn.Linear(self.n_hidden, self.n_hidden))

        # code_encoding appends per-graph attention weights and masks here.
        if self.conf['save_attn_weight']:
            self.attn_weight_torch = []
            self.node_mask_torch = []

        self.self_atten = nn.Linear(self.n_hidden, self.n_hidden)
        self.self_atten_scalar = nn.Linear(self.n_hidden, 1)

    def code_encoding(self, cfg_init_input_batch, cfg_adjmat_batch,
                      cfg_node_mask):
        """Encode a batch of CFGs into one vector per graph.

        Args:
            cfg_init_input_batch: initial node states for ``self.cfg_encoder``.
            cfg_adjmat_batch: adjacency matrices for ``self.cfg_encoder``.
            cfg_node_mask: [batch_size x n_node] mask; non-zero entries mark
                the nodes that take part in attention pooling.

        Returns:
            Tensor of shape [batch_size x n_hidden], passed through tanh.
        """
        # code_feat: [batch_size x n_node x n_hidden]
        code_feat = self.cfg_encoder(
            cfg_init_input_batch, cfg_adjmat_batch,
            cfg_node_mask)  # forward(prop_state, A, node_mask)
        node_num = code_feat.size(1)  # n_node (padded)
        code_feat = code_feat.reshape(-1, node_num, self.n_hidden)
        batch_size = code_feat.size(0)
        # mask_1forgt0: [batch_size x n_node], True where cfg_node_mask != 0
        mask_1forgt0 = cfg_node_mask.bool().reshape(-1, node_num)

        if self.conf['transform_every_modal']:
            code_feat = torch.tanh(
                self.linear_single_modal(
                    F.dropout(code_feat, self.dropout,
                              training=self.training)))

        # [(batch_size * n_node) x n_hidden]
        code_sa_tanh = torch.tanh(
            self.self_atten(code_feat.reshape(-1, self.n_hidden)))
        code_sa_tanh = F.dropout(code_sa_tanh,
                                 self.dropout,
                                 training=self.training)
        # code_sa_tanh: [batch_size x n_node]
        code_sa_tanh = self.self_atten_scalar(code_sa_tanh).reshape(
            -1, node_num)

        pooled = []
        for _i in range(batch_size):
            valid = mask_1forgt0[_i]
            # scores: [real_node_num]
            scores = code_sa_tanh[_i][valid]

            if self.cfg_attn_mode == 'sigmoid_scalar':
                attn_w_one = torch.sigmoid(scores)
            else:
                attn_w_one = F.softmax(scores, dim=0)

            if self.conf['save_attn_weight']:
                self.attn_weight_torch.append(
                    attn_w_one.detach().reshape(1, -1).cpu())
                self.node_mask_torch.append(
                    valid.detach().reshape(1, -1).cpu())

            # [real_node_num] @ [real_node_num x n_hidden] -> [n_hidden]
            pooled.append(attn_w_one @ code_feat[_i][valid])

        # Concatenate once rather than growing a tensor inside the loop.
        if pooled:
            self_attn_code_feat = torch.stack(pooled, 0)
        else:
            self_attn_code_feat = code_feat.new_zeros((0, self.n_hidden))

        # An extra dropout -> linear_attn_out projection used to be applied
        # here and is disabled; only tanh is applied.
        self_attn_code_feat = torch.tanh(self_attn_code_feat)
        # self_attn_code_feat: [batch_size x n_hidden]
        return self_attn_code_feat

    def desc_encoding(self, desc, desc_len):
        """Encode a batch of descriptions into one vector each.

        Uses the first element of the encoder's returned final state.

        Returns:
            Tensor of shape [batch_size x n_hidden].
        """
        batch_size = desc.size(0)
        # NOTE(review): this initial hidden state is not passed to
        # desc_encoder (ASTEmbeder passes it as a third argument); confirm
        # which call form is intended. The call is kept unchanged.
        self.desc_encoder.init_hidden(batch_size)
        _, desc_enc_hidden = self.desc_encoder(desc, desc_len)
        # desc_feat: [batch_size x n_hidden]
        desc_feat = desc_enc_hidden[0].reshape(batch_size, self.n_hidden)

        if self.conf['transform_every_modal']:
            desc_feat = torch.tanh(
                self.linear_single_modal(
                    F.dropout(desc_feat, self.dropout,
                              training=self.training)))
        elif self.conf['use_tanh']:
            desc_feat = torch.tanh(desc_feat)

        # desc_feat: [batch_size x n_hidden]
        return desc_feat

    def similarity(self, code_vec, desc_vec):
        """Return one similarity score per row pair: a [batch_size] tensor.

        ``self.conf['sim_measure']`` selects the measure: 'cos', 'poly',
        'sigmoid' (tanh kernel), 'euc', 'gesd' or 'aesd'.
        """
        assert self.conf['sim_measure'] in [
            'cos', 'poly', 'euc', 'sigmoid', 'gesd', 'aesd'
        ], "invalid similarity measure"
        measure = self.conf['sim_measure']
        if measure == 'cos':
            return F.cosine_similarity(code_vec, desc_vec)

        # Row-wise dot product; equals matmul(code_vec, desc_vec.t()).diag()
        # without materialising the full [batch x batch] matrix.
        dot = (code_vec * desc_vec).sum(dim=1)
        if measure == 'poly':
            return (0.5 * dot + 1)**2
        if measure == 'sigmoid':
            return torch.tanh(dot + 1)

        # Per-pair Euclidean distance. torch.dist would collapse the whole
        # batch into a single scalar.
        euc_dist = torch.norm(code_vec - desc_vec, p=2, dim=1)
        euc_sim = 1 / (1 + euc_dist)
        if measure == 'euc':
            return euc_sim
        sigmoid_sim = torch.sigmoid(dot + 1)
        if measure == 'gesd':
            return euc_sim * sigmoid_sim
        return 0.5 * (euc_sim + sigmoid_sim)  # 'aesd'

    def forward(self, cfg_init_input_batch, cfg_adjmat_batch, cfg_node_mask,
                desc_anchor, desc_anchor_len, desc_neg, desc_neg_len):
        """Return the margin ranking loss for a batch of triples.

        Loss is ``mean(clamp(margin - sim(code, anchor) + sim(code, neg),
        min=1e-6))``.
        """
        # cfg_repr: [batch_sz x n_hidden]
        cfg_repr = self.code_encoding(cfg_init_input_batch, cfg_adjmat_batch,
                                      cfg_node_mask)
        # desc_repr: [batch_sz x n_hidden]
        desc_anchor_repr = self.desc_encoding(desc_anchor, desc_anchor_len)
        desc_neg_repr = self.desc_encoding(desc_neg, desc_neg_len)

        # sim: [batch_sz]
        anchor_sim = self.similarity(cfg_repr, desc_anchor_repr)
        neg_sim = self.similarity(cfg_repr, desc_neg_repr)

        loss = (self.margin - anchor_sim + neg_sim).clamp(min=1e-6).mean()

        return loss