Example #1
    def __init__(self,
                 feature_embed_dim,
                 f_act="sigmoid",
                 dropout=0.5,
                 join_type="cat"):
        '''
        Args:
            feature_embed_dim: the feature embedding dimension
            f_act: the final activation function applied to get the final result
            dropout: dropout rate for the nn.Dropout layer
            join_type: how the relative and global embeddings are joined; "cat" concatenates them before the linear projection
        '''
        super(JointRelativeGlobalDecoder, self).__init__()
        self.feature_embed_dim = feature_embed_dim
        self.join_type = join_type

        if self.join_type == "cat":
            self.post_linear = nn.Linear(self.feature_embed_dim * 2,
                                         self.feature_embed_dim)
        else:
            self.post_linear = nn.Linear(self.feature_embed_dim,
                                         self.feature_embed_dim)

        self.dropout = nn.Dropout(p=dropout)

        self.f_act = get_activation_function(f_act,
                                             "JointRelativeGlobalDecoder")
Example #2
    def __init__(self,
                 pointset,
                 enc,
                 spa_enc,
                 g_spa_enc,
                 g_spa_dec,
                 init_dec,
                 dec,
                 joint_dec,
                 activation="sigmoid",
                 num_context_sample=10,
                 num_neg_resample=10):
        super(JointRelativeGlobalEncoderDecoder, self).__init__()
        self.pointset = pointset
        self.enc = enc
        self.init_dec = init_dec
        self.dec = dec
        self.joint_dec = joint_dec
        self.spa_enc = spa_enc
        self.g_spa_enc = g_spa_enc
        self.g_spa_dec = g_spa_dec
        self.num_context_sample = num_context_sample
        self.num_neg_resample = num_neg_resample  # e.g., given 100 negative samples, we resample 10 of them

        self.activation = get_activation_function(
            activation, "JointRelativeGlobalEncoderDecoder")
Example #3
    def __init__(self,
                 pointset,
                 enc,
                 g_spa_enc,
                 g_spa_dec,
                 activation="sigmoid",
                 num_neg_resample=10):
        super(GlobalPositionEncoderDecoder, self).__init__()
        self.pointset = pointset
        self.enc = enc  # point feature embedding encoder
        self.g_spa_enc = g_spa_enc  # one of the SpatialRelationEncoder classes
        self.g_spa_dec = g_spa_dec  # DirectPositionEmbeddingDecoder()

        self.activation = get_activation_function(
            activation, "GlobalPositionEncoderDecoder")

        self.num_neg_resample = num_neg_resample  # e.g., given 100 negative samples, we resample 10 of them
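
Based on the attributes wired up above (enc for point features, g_spa_enc for coordinates, g_spa_dec as a DirectPositionEmbeddingDecoder), a plausible scoring pattern is to decode the global position embedding into feature space and compare it with the point feature embedding. This is only a sketch under those assumptions; the random tensors and the nn.Sequential stand in for the encoder and decoder outputs, and the original forward() may differ.

import torch
import torch.nn as nn

batch, g_spa_embed_dim, feature_embed_dim = 32, 32, 64
g_spa_embed = torch.randn(batch, g_spa_embed_dim)      # stand-in for g_spa_enc(coords)
feature_embed = torch.randn(batch, feature_embed_dim)  # stand-in for enc(point features)

g_spa_dec = nn.Sequential(                             # stand-in for DirectPositionEmbeddingDecoder
    nn.Linear(g_spa_embed_dim, feature_embed_dim),
    nn.Dropout(p=0.5),
    nn.Sigmoid(),
)
score = (g_spa_dec(g_spa_embed) * feature_embed).sum(dim=-1)  # [batch] dot-product scores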
Example #4
    def __init__(self,
                 g_spa_embed_dim,
                 feature_embed_dim,
                 f_act="sigmoid",
                 dropout=0.5):
        '''
        Args:
            g_spa_embed_dim: the global position embedding dimension
            feature_embed_dim: the feature embedding dimension
            f_act: the final activation function applied to get the final result
            dropout: dropout rate for the nn.Dropout layer
        '''
        super(DirectPositionEmbeddingDecoder, self).__init__()
        self.g_spa_embed_dim = g_spa_embed_dim
        self.feature_embed_dim = feature_embed_dim

        self.post_linear = nn.Linear(self.g_spa_embed_dim,
                                     self.feature_embed_dim)
        self.dropout = nn.Dropout(p=dropout)

        self.f_act = get_activation_function(f_act,
                                             "DirectPositionEmbeddingDecoder")
Example #5
    def __init__(self,
                 query_dim,
                 key_dim,
                 spa_embed_dim,
                 g_spa_embed_dim,
                 have_query_embed=True,
                 num_attn=1,
                 activation="leakyrelu",
                 f_activation="sigmoid",
                 layernorm=False,
                 use_post_mat=False,
                 dropout=0.5):
        '''
        The GAT-style (Graph Attention Network) attention method, using LeakyReLU
        Args:
            query_dim: the center point feature embedding dimension
            key_dim: the N context point feature embedding dimension
            spa_embed_dim: the spatial relation embedding dimension
            g_spa_embed_dim: the global position embedding dimension
            have_query_embed: True/False, whether we use the query embedding in the attention
            num_attn: number of attention heads
            activation: the activation function applied to atten_vecs * torch.cat([query_embed, key_embed, spa_embed, g_spa_embed]), see GAT paper Eq. 3
            f_activation: the final activation function applied to get the final result, see GAT paper Eq. 6
            layernorm: True/False, whether to apply layer normalization
            use_post_mat: True/False, whether to apply a post linear layer with dropout
            dropout: dropout rate for the post linear layer
        '''
        super(GolbalPositionIntersectConcatAttention, self).__init__()

        self.query_dim = query_dim
        self.key_dim = key_dim
        self.spa_embed_dim = spa_embed_dim

        self.g_spa_embed_dim = g_spa_embed_dim

        self.num_attn = num_attn
        self.have_query_embed = have_query_embed

        self.activation = get_activation_function(
            activation, "GolbalPositionIntersectConcatAttention middle")

        self.f_activation = get_activation_function(
            f_activation, "GolbalPositionIntersectConcatAttention final")

        self.softmax = nn.Softmax(dim=1)

        self.layernorm = layernorm
        self.use_post_mat = use_post_mat
        if self.have_query_embed:
            assert key_dim == query_dim

            # define the layer normalization
            if self.layernorm:
                self.pre_ln = LayerNorm(query_dim)
                self.add_module("attn_preln", self.pre_ln)

            if self.use_post_mat:
                self.post_linear = nn.Linear(query_dim, query_dim)
                self.dropout = nn.Dropout(p=dropout)
                # self.register_parameter("attn_PostLinear", self.post_linear)

                # self.post_W = nn.Parameter(torch.FloatTensor(query_dim, query_dim))
                # init.xavier_uniform_(self.post_W)
                # self.register_parameter("attn_PostW", self.post_W)

                # self.post_B = nn.Parameter(torch.FloatTensor(1,query_dim))
                # init.xavier_uniform_(self.post_B)
                # self.register_parameter("attn_PostB",self.post_B)
                if self.layernorm:
                    self.post_ln = LayerNorm(query_dim)
                    self.add_module("attn_Postln", self.post_ln)

            # each column represents an attention vector for one attention head: [query_dim + key_dim + spa_embed_dim + g_spa_embed_dim, num_attn]
            self.atten_vecs = nn.Parameter(
                torch.FloatTensor(
                    query_dim + key_dim + spa_embed_dim + g_spa_embed_dim,
                    self.num_attn))
            init.xavier_uniform_(self.atten_vecs)
            self.register_parameter("attn_attenvecs", self.atten_vecs)
        else:
            # if we do not use the query embedding in the attention,
            # we just compute the initial query embedding
            # define the layer normalization
            if self.layernorm:
                self.pre_ln = LayerNorm(key_dim)
                self.add_module("attn_nq_preln", self.pre_ln)

            if self.use_post_mat:
                self.post_linear = nn.Linear(key_dim, key_dim)
                self.dropout = nn.Dropout(p=dropout)
                # self.register_parameter("attn_PostLinear", self.post_linear)

                # self.post_W = nn.Parameter(torch.FloatTensor(key_dim, key_dim))
                # init.xavier_uniform_(self.post_W)
                # self.register_parameter("attn_nq_PostW", self.post_W)

                # self.post_B = nn.Parameter(torch.FloatTensor(1,key_dim))
                # init.xavier_uniform_(self.post_B)
                # self.register_parameter("attn_nq_PostB",self.post_B)
                if self.layernorm:
                    self.post_ln = LayerNorm(key_dim)
                    self.add_module("attn_nq_Postln", self.post_ln)

            # when computing the initial query embedding, we only use the key embeddings and the spatial relation embeddings
            # each column represents an attention vector for one attention head: [key_dim + spa_embed_dim + g_spa_embed_dim, num_attn]
            self.atten_vecs = nn.Parameter(
                torch.FloatTensor(key_dim + spa_embed_dim + g_spa_embed_dim,
                                  self.num_attn))
            init.xavier_uniform_(self.atten_vecs)
            self.register_parameter("attn_nq_attenvecs", self.atten_vecs)
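
The atten_vecs parameter above holds one column per attention head over the concatenation of query, key, spatial relation, and global position embeddings, which matches the GAT-style scoring described in the docstring. Below is a hypothetical sketch of that scoring and aggregation: the shapes, the choice of the key embeddings as the values being aggregated, and the sigmoid at the end are assumptions, not the module's actual forward() code.

import torch
import torch.nn as nn

batch, N = 4, 10                       # N context points per center point
query_dim = key_dim = 64
spa_embed_dim, g_spa_embed_dim, num_attn = 32, 32, 1

atten_vecs = nn.Parameter(torch.empty(query_dim + key_dim + spa_embed_dim + g_spa_embed_dim, num_attn))
nn.init.xavier_uniform_(atten_vecs)

query = torch.randn(batch, query_dim).unsqueeze(1).expand(-1, N, -1)  # center point embedding, repeated per context point
keys = torch.randn(batch, N, key_dim)                                 # context point embeddings
spa = torch.randn(batch, N, spa_embed_dim)                            # relative spatial relation embeddings
g_spa = torch.randn(batch, N, g_spa_embed_dim)                        # global position embeddings

concat = torch.cat([query, keys, spa, g_spa], dim=-1)   # [batch, N, sum of the four dims]
scores = nn.LeakyReLU()(concat @ atten_vecs)            # [batch, N, num_attn], GAT Eq. 3 style
weights = torch.softmax(scores, dim=1)                  # normalize over the N context points
aggregated = weights.transpose(1, 2) @ keys             # [batch, num_attn, key_dim]
out = torch.sigmoid(aggregated.mean(dim=1))             # [batch, key_dim], GAT Eq. 6 style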