示例#1
0
    def forward_batch(self, data_base, edge_index_base, batch_base,
                      t_sne=False):
        """Encode a batch of graphs with the base branch and predict.

        Three stages of conv -> ReLU -> batch norm are applied in order. Node
        states are then pooled per graph with ``gmp`` (presumably global max
        pooling -- confirm against the file's imports). The result goes
        through ``self.fc`` and ``self.output_layer``.

        Args:
            data_base: Node inputs for the first convolution (presumably a
                node-feature matrix -- confirm against the caller).
            edge_index_base: Connectivity passed unchanged to every conv layer.
            batch_base: Graph-assignment argument handed to ``gmp``.
            t_sne: When truthy, also return the pooled graph embedding (the
                name suggests it is used for t-SNE plots).

        Returns:
            ``outputs`` alone, or ``(outputs, graph_embedding)`` when ``t_sne``
            is truthy.
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2_base, self.bn2_base),
            (self.conv3_base, self.bn3_base),
        )
        node_h = data_base
        for conv, norm in stages:
            # Normalisation is applied after the activation, not before it.
            node_h = norm(F.relu(conv(node_h, edge_index_base)))

        graph_embedding = gmp(node_h, batch_base)
        head = self.fc(graph_embedding)
        outputs = self.output_layer(head)

        if not t_sne:
            return outputs
        return outputs, graph_embedding
    def forward_batch(self, data_base, edge_index_base, batch_base,
                      data_residual, edge_index_residual, batch_residual):
        """Run the base and residual GNN branches and fuse their embeddings.

        Each branch applies three conv -> ReLU -> batch-norm stages, pools node
        states per graph with ``gmp`` and projects the pooled vector through a
        ReLU-activated linear layer. The two branch embeddings are
        concatenated on the last dimension (base first) and fed through
        ``linear_before_residual`` (+ReLU), ``linear_mean_residual`` and
        ``self.output``.

        Args:
            data_base, edge_index_base, batch_base: Node inputs, connectivity
                and pooling assignment for the base branch.
            data_residual, edge_index_residual, batch_residual: The same three
                inputs for the residual branch.

        Returns:
            The result of ``self.output`` on the fused prediction.
        """
        def encode(node_h, edge_index, batch, stages, projection):
            # The control flow is shared, but each branch passes its own
            # weights in ``stages`` and ``projection``.
            for conv, norm in stages:
                node_h = norm(F.relu(conv(node_h, edge_index)))
            return F.relu(projection(gmp(node_h, batch)))

        base_emb = encode(
            data_base, edge_index_base, batch_base,
            ((self.conv1, self.bn1),
             (self.conv2_base, self.bn2_base),
             (self.conv3_base, self.bn3_base)),
            self.linear_branch1,
        )
        residual_emb = encode(
            data_residual, edge_index_residual, batch_residual,
            ((self.conv1_residual, self.bn1_residual),
             (self.conv2_residual, self.bn2_residual),
             (self.conv3_residual, self.bn3_residual)),
            self.linear_branch2,
        )

        fused = torch.cat([base_emb, residual_emb], dim=-1)
        fused = F.relu(self.linear_before_residual(fused))
        return self.output(self.linear_mean_residual(fused))
示例#3
0
    def forward_batch(self, data, edge_index, batch):
        """Predict from a two-layer GNN with a concatenated max/mean readout.

        Two conv -> ReLU -> batch-norm stages are followed by a graph-level
        readout. The readout is ``gmp`` and ``gap`` concatenated on dim 1,
        presumably global max and mean pooling -- confirm against the
        imports. It then passes through a ReLU hidden layer with dropout
        (p=0.1, only while ``self.training``) and the two output layers.

        Args:
            data: Node inputs for ``conv1`` (presumably node features).
            edge_index: Connectivity shared by both conv layers.
            batch: Graph-assignment argument passed to both pooling calls.

        Returns:
            The output of ``self.out_layer``.
        """
        node_h = data
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            node_h = norm(F.relu(conv(node_h, edge_index)))

        # Max-pooled features come first, mean-pooled second.
        readout = torch.cat([gmp(node_h, batch), gap(node_h, batch)], dim=1)
        hidden = F.relu(self.linear_before(readout))
        hidden = F.dropout(hidden, p=0.1, training=self.training)

        return self.out_layer(self.linear_mean(hidden))
示例#4
0
    def forward_batch(self, data, edge_index, batch, alpha=None):
        """Predict from a three-layer GNN with max-pooled readout.

        Args:
            data: Node inputs for ``conv1`` (presumably node features).
            edge_index: Connectivity shared by all three conv layers.
            batch: Graph-assignment argument passed to ``gmp``.
            alpha: Accepted for interface compatibility; not used here.

        Returns:
            The output of ``self.out_layer`` applied to the ``linear_mean``
            projection.
        """
        node_h = data
        for conv, norm in ((self.conv1, self.bn1),
                           (self.conv2, self.bn2),
                           (self.conv3, self.bn3)):
            node_h = norm(F.relu(conv(node_h, edge_index)))

        pooled = gmp(node_h, batch)
        # Dropout only takes effect while the module is in training mode.
        hidden = F.dropout(F.relu(self.linear_before(pooled)),
                           p=0.1,
                           training=self.training)
        return self.out_layer(self.linear_mean(hidden))
示例#5
0
    def forward_batch(self, data, edge_index, batch, data_reverse,
                      edge_index_reverse):
        """Predict from a GNN that averages forward and reverse edge passes.

        Each of the three layers runs a forward conv on ``edge_index`` and a
        reverse conv on ``edge_index_reverse``. It ReLUs both, averages them,
        batch-normalises the mean and applies dropout (p=0.1 while
        training). Only the first reverse conv reads ``data_reverse``; later
        layers feed the shared hidden state to both convs.

        Args:
            data: Node inputs for the first forward conv.
            edge_index: Connectivity for the forward convs.
            batch: Graph-assignment argument passed to ``gmp``.
            data_reverse: Node inputs for the first reverse conv.
            edge_index_reverse: Connectivity for the reverse convs.

        Returns:
            The output of ``self.out_layer``.
        """
        layers = (
            (self.conv1, self.conv1_reverse, self.bn1),
            (self.conv2, self.conv2_reverse, self.bn2),
            (self.conv3, self.conv3_reverse, self.bn3),
        )
        fwd_in, rev_in = data, data_reverse
        for conv_fwd, conv_rev, norm in layers:
            # The forward conv runs before the reverse conv, as in the summed
            # expression this replaces.
            fwd = F.relu(conv_fwd(fwd_in, edge_index))
            rev = F.relu(conv_rev(rev_in, edge_index_reverse))
            node_h = norm((fwd + rev) * 0.5)
            node_h = F.dropout(node_h, training=self.training, p=0.1)
            fwd_in = rev_in = node_h

        pooled = gmp(node_h, batch)
        hidden = F.relu(self.linear(pooled))
        hidden = F.dropout(hidden, p=0.1, training=self.training)
        # NOTE(review): the attribute is spelled ``liner2``. It presumably
        # matches the module registered in __init__, so do not rename it here
        # alone.
        return self.out_layer(self.liner2(hidden))
示例#6
0
    def forward_batch(self, data, edge_index, batch, alpha=None):
        """Encode graphs and draw a sample from a predicted mean/std.

        A three-layer conv -> ReLU -> batch-norm encoder is max-pooled per
        graph with ``gmp``. The pooled vector feeds two heads. The mean head
        is ``linear_before`` + ReLU, then dropout (p=0.1 while training),
        then ``linear_mean``. The std head is ``linear_before_std`` + ReLU,
        then ``linear_std`` + ReLU, then ``exp(x / 2)``. ``gaussian_layer``
        combines mean, std and ``torch.randn_like`` noise; presumably this is
        the reparameterisation ``mean + std * eps`` -- confirm against its
        definition.

        Args:
            data: Node inputs for ``conv1`` (presumably node features).
            edge_index: Connectivity shared by all conv layers.
            batch: Graph-assignment argument passed to ``gmp``.
            alpha: Accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(x_sample, mean, std)``.
        """
        node_h = data
        for conv, norm in ((self.conv1, self.bn1),
                           (self.conv2, self.bn2),
                           (self.conv3, self.bn3)):
            node_h = norm(F.relu(conv(node_h, edge_index)))
        pooled = gmp(node_h, batch)

        mean_hidden = F.dropout(F.relu(self.linear_before(pooled)),
                                p=0.1,
                                training=self.training)
        mean = self.linear_mean(mean_hidden)

        std_hidden = F.relu(self.linear_before_std(pooled))
        std_raw = F.relu(self.linear_std(std_hidden))
        # NOTE(review): std_raw >= 0 after the ReLU, so exp(std_raw / 2) >= 1.
        # The predicted std can therefore never drop below 1. Confirm that
        # this lower bound is intended.
        std = torch.exp(std_raw / 2)
        eps = torch.randn_like(std)
        x_sample = gaussian_layer(mean, std, eps)

        return x_sample, mean, std
示例#7
0
    def forward_batch(self, data, edge_index, batch, t_sne=None):
        """Predict from a three-layer GNN, optionally exposing the embedding.

        Args:
            data: Node inputs for ``conv1`` (presumably node features).
            edge_index: Connectivity shared by all conv layers.
            batch: Graph-assignment argument passed to ``gmp``.
            t_sne: When truthy, also return the pooled graph embedding (taken
                before the hidden linear layer). The name suggests it is used
                for t-SNE plots.

        Returns:
            ``pred`` alone, or ``(pred, pooled_embedding)`` when ``t_sne`` is
            truthy.
        """
        node_h = data
        for conv, norm in ((self.conv1, self.bn1),
                           (self.conv2, self.bn2),
                           (self.conv3, self.bn3)):
            node_h = norm(F.relu(conv(node_h, edge_index)))

        pooled = gmp(node_h, batch)
        hidden = F.relu(self.linear_before(pooled))
        # Dropout only takes effect while the module is in training mode.
        hidden = F.dropout(hidden, p=0.1, training=self.training)
        pred = self.out_layer(self.linear_mean(hidden))

        return (pred, pooled) if t_sne else pred