Beispiel #1
0
 def encode(self, s, z, temp, add_noise):
     """Map a (state, second-input) pair to a relaxed action code.

     Both inputs are flattened from dim 1 onward and given a singleton axis
     at dim 1, so every layer sees tensors shaped (batch, 1, features).
     The flattened state ``s`` is concatenated onto the input of every
     layer (fc1, fc2, fc3), acting as a skip connection.

     Args:
         s: state tensor; batch dimension assumed first.
         z: second input tensor (presumably a successor/latent code --
             TODO confirm against callers).
         temp: temperature forwarded to ``gumbel_softmax``.
         add_noise: forwarded unchanged to ``gumbel_softmax``.

     Returns:
         ``gumbel_softmax`` of the fc3 logits reshaped to
         (-1, 1, self.AAE_N_ACTION).
     """
     flat_s = torch.flatten(s, start_dim=1).unsqueeze(1)
     flat_z = torch.flatten(z, start_dim=1).unsqueeze(1)

     def with_state(features):
         # Prepend the flattened state on the feature axis (skip connection).
         return torch.cat([flat_s, features], dim=2)

     hidden = bn_and_dpt(torch.relu(self.fc1(with_state(flat_z))), self.bn1, self.dpt1)
     hidden = bn_and_dpt(torch.relu(self.fc2(with_state(hidden))), self.bn2, self.dpt2)
     action_logits = self.fc3(with_state(hidden)).view(-1, 1, self.AAE_N_ACTION)
     return gumbel_softmax(action_logits, temp, add_noise)
 def forward(self, input, temp):
     """Form, for each of ``U`` slots, ``A`` soft mixtures of ``N`` object vectors.

     A conv -> batch-norm -> dropout stack scores the objects. The scores are
     turned into relaxed one-hot weights over the ``N`` axis by
     ``gumbel_softmax``, and each of the ``A`` weight vectors is used to take
     a weighted sum of the object feature vectors.

     Args:
         input: object tensor; from the view/expand below it appears to be
             (batch, U, N, N_OBJ_FEATURE) -- TODO confirm. (The name shadows
             the builtin but is kept for keyword-call compatibility.)
         temp: temperature forwarded to ``gumbel_softmax``.

     Returns:
         Tensor of shape (batch, U, A, N_OBJ_FEATURE): the weighted sums.
     """
     packed = input.view(-1, U, N * N_OBJ_FEATURE)
     hidden = self.dpt1(self.bn1(self.conv1(packed)))
     hidden = hidden.view(-1, U, LAYER_SIZE)
     selection_logits = self.conv2(hidden).view(-1, U, A, N)

     # (.., A, N) -> (.., A, N, N_OBJ_FEATURE): copy each weight across features.
     weights = gumbel_softmax(selection_logits, temp).unsqueeze(-1)
     weights = weights.expand(-1, -1, -1, -1, N_OBJ_FEATURE)
     # (.., N, F) -> (.., A, N, F): one copy of the objects per selection.
     objects = input.unsqueeze(2).expand(-1, -1, A, -1, -1)

     # Collapse the object axis (dim 3) to get the weighted sum per selection.
     return (weights * objects).sum(dim=3)
Beispiel #3
0
    def decode(self, s, a, temp, add_noise):
        """Apply the action code ``a`` to the latent state ``s``.

        ``a`` passes through fc4 and fc5 (each with ReLU and
        ``bn_and_dpt``) and then fc6, producing an effect tensor that is
        combined with ``s`` in one of two ways:

        * ``self.back_to_logit`` truthy: ``s`` goes through
          ``self.bn_input``; the effect is viewed as
          (-1, AAE_LATENT_DIM, 1), goes through ``self.bn_effect``, is added
          to the normalized state, and the sum is passed to
          ``gumbel_softmax``.
        * otherwise: the effect is viewed as (-1, AAE_LATENT_DIM, 3) and
          passed to ``gumbel_softmax``. Category 0 is used as an ``add``
          mask and category 1 as a ``delete`` mask, and the new state is
          ``max(min(s, 1 - delete), add)``. Here ``s`` is presumably
          (batch, AAE_LATENT_DIM, 1) in [0, 1] -- TODO confirm.

        Returns:
            ``(next_state, add, delete)``. ``add`` and ``delete`` keep a
            trailing axis of size 1 and are ``None`` in the back-to-logit
            branch.
        """
        hidden = bn_and_dpt(torch.relu(self.fc4(a)), self.bn4, self.dpt4)
        hidden = bn_and_dpt(torch.relu(self.fc5(hidden)), self.bn5, self.dpt5)
        effect = self.fc6(hidden)

        if not self.back_to_logit:
            categories = gumbel_softmax(
                effect.view(-1, self.AAE_LATENT_DIM, 3), temp, add_noise)
            # List indices keep the trailing axis so the masks broadcast against s.
            add, delete = categories[:, :, [0]], categories[:, :, [1]]
            # Clear deleted bits first, then set added bits.
            next_state = torch.max(torch.min(s, 1 - delete), add)
            return next_state, add, delete

        normalized_state = self.bn_input(s)
        normalized_effect = self.bn_effect(effect.view(-1, self.AAE_LATENT_DIM, 1))
        # NOTE(review): the last axis here has size 1; if gumbel_softmax
        # normalizes over the last axis this output is constant -- confirm it
        # is a binary (sigmoid-style) relaxation.
        next_state = gumbel_softmax(normalized_effect + normalized_state, temp, add_noise)
        return next_state, None, None
 def forward(self, input, temp):
     """Score ``P`` two-way choices for each of ``U`` slots.

     The input is packed to (-1, U, A * N_OBJ_FEATURE) and passed through
     fc1 -> bn1 -> dpt1 (no activation in between) and then fc2. The fc2
     output is viewed as (-1, U, P, 2) logits and relaxed with
     ``gumbel_softmax``.

     Args:
         input: tensor whose trailing elements regroup into
             A * N_OBJ_FEATURE per slot (presumably (batch, U, A,
             N_OBJ_FEATURE) -- TODO confirm).
         temp: temperature forwarded to ``gumbel_softmax``.

     Returns:
         ``gumbel_softmax`` of the (-1, U, P, 2) logits.
     """
     packed = input.view(-1, U, A * N_OBJ_FEATURE)
     hidden = self.dpt1(self.bn1(self.fc1(packed)))
     pair_logits = self.fc2(hidden).view(-1, U, P, 2)
     return gumbel_softmax(pair_logits, temp)
Beispiel #5
0
 def forward(self, x, temp, add_noise):
     """Encode ``x``, draw a relaxed sample, and decode that sample.

     Args:
         x: model input passed to ``self.encode``.
         temp: temperature forwarded to ``gumbel_softmax``.
         add_noise: forwarded unchanged to ``gumbel_softmax``.

     Returns:
         ``(self.decode(sample), sample)``, where ``sample`` is
         ``gumbel_softmax(self.encode(x), temp, add_noise)``.
     """
     sample = gumbel_softmax(self.encode(x), temp, add_noise)
     reconstruction = self.decode(sample)
     return reconstruction, sample