Example #1
0
    def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
        """Perform T steps of message passing over the molecular graph.

        parent layers: atom_features, pair_features, atom_to_pair

        Returns the per-atom hidden states after ``self.T`` rounds of
        message/update, with the feature dimension zero-padded up to
        ``self.n_hidden`` if the input features are narrower.

        Raises
        ------
        ValueError
            If the incoming atom feature width exceeds ``self.n_hidden``.
        """
        if in_layers is None:
            in_layers = self.in_layers
        in_layers = convert_to_layers(in_layers)

        # Extract inputs produced by the parent layers.
        atom_features = in_layers[0].out_tensor
        pair_features = in_layers[1].out_tensor
        atom_to_pair = in_layers[2].out_tensor
        # NOTE(review): get_shape().as_list() is TensorFlow API, not torch;
        # for torch tensors this would be atom_features.shape[-1] -- confirm
        # what tensor type out_tensor actually holds here.
        n_atom_features = atom_features.get_shape().as_list()[-1]
        n_pair_features = pair_features.get_shape().as_list()[-1]
        # Add trainable weights
        self.build(pair_features, n_pair_features)

        if n_atom_features < self.n_hidden:
            pad_length = self.n_hidden - n_atom_features
            # BUG FIX: torch.nn.functional.pad takes a flat tuple that
            # describes padding for the LAST dimension first; the previous
            # TF-style nested tuple ((0, 0), (0, pad_length)) is rejected
            # by torch. (0, pad_length) zero-pads only the feature dim.
            out = F.pad(atom_features, (0, pad_length), mode='constant')
        elif n_atom_features > self.n_hidden:
            raise ValueError("Too large initial feature vector")
        else:
            out = atom_features

        # T rounds of message passing: compute messages along pairs, then
        # update each atom's hidden state with its aggregated message.
        for i in range(self.T):
            message = self.message_function.forward(out, atom_to_pair)
            out = self.update_function.forward(out, message)

        out_tensor = out

        if set_tensors:
            # Cache outputs/weights on the layer for the computation graph.
            self.variables = self.trainable_weights
            self.out_tensor = out_tensor
        return out_tensor
Example #2
0
    def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
        """Pool per-atom outputs into per-molecule outputs.

    parent layers: atom_features, atom_split
    """
        if in_layers is None:
            in_layers = self.in_layers
        in_layers = convert_to_layers(in_layers)

        self.build()
        # Per-atom activations and, presumably, the molecule id of each
        # atom row (the name suggests a segment-index tensor).
        outputs = in_layers[0].out_tensor
        atom_split = in_layers[1].out_tensor

        if self.gaussian_expand:
            # Expand each scalar feature into Gaussian-membership bins.
            outputs = self.gaussian_histogram(outputs)

        # NOTE(review): torch.sum expects an int dim, not an index tensor.
        # This looks like an unconverted tf.segment_sum(outputs, atom_split)
        # from a TF->torch port; confirm and replace with an
        # index_add_/scatter-based segment sum.
        output_molecules = torch.sum(outputs, atom_split)

        if self.gaussian_expand:
            # Mix the expanded features back down with a learned affine map.
            output_molecules = torch.matmul(output_molecules, self.W) + self.b
            output_molecules = self.activation(output_molecules)

        out_tensor = output_molecules
        if set_tensors:
            # Cache outputs/weights on the layer for the computation graph.
            self.variables = self.trainable_weights
            self.out_tensor = out_tensor
        return out_tensor
Example #3
0
 def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
     """Slice out this task's column from the parent layer's output.

     parent layers: a single layer whose out_tensor has one column per task.

     Returns a 2-D column slice ``output[:, task_id:task_id + 1]`` so that
     downstream layers see a (batch, 1) tensor rather than a 1-D vector.
     """
     if in_layers is None:
         in_layers = self.in_layers
     in_layers = convert_to_layers(in_layers)
     output = in_layers[0].out_tensor
     out_tensor = output[:, self.task_id:self.task_id + 1]
     # Consistency fix: every sibling create_tensor only caches out_tensor
     # when set_tensors is True; this method previously ignored the flag
     # and always wrote self.out_tensor.
     if set_tensors:
         self.out_tensor = out_tensor
     return out_tensor
Example #4
0
    def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
        """Creates weave tensors.
    parent layers: [atom_features, pair_features], pair_split, atom_to_pair
    """
        activation = activations.get(self.activation)  # Get activations
        if in_layers is None:
            in_layers = self.in_layers
        in_layers = convert_to_layers(in_layers)

        self.build()

        atom_features = in_layers[0].out_tensor
        pair_features = in_layers[1].out_tensor

        # pair_split: presumably the atom index each pair message belongs
        # to; atom_to_pair: the (i, j) atom indices for each pair. Verify
        # against the featurizer that produces these tensors.
        pair_split = in_layers[2].out_tensor
        atom_to_pair = in_layers[3].out_tensor

        # Atom -> atom dense transform.
        AA = torch.matmul(atom_features, self.W_AA) + self.b_AA
        AA = activation(AA)
        # Pair -> atom dense transform, then aggregation per atom.
        PA = torch.matmul(pair_features, self.W_PA) + self.b_PA
        PA = activation(PA)
        # NOTE(review): torch.sum takes an int dim, not an index tensor --
        # this looks like an unconverted tf.segment_sum(PA, pair_split).
        PA = torch.sum(PA, pair_split)

        # Combine both atom-update branches into the new atom features.
        A = torch.matmul(torch.cat([AA, PA], 1), self.W_A) + self.b_A
        A = activation(A)

        if self.update_pair:
            # Atom-pair contribution for ordered pair (i, j).
            # NOTE(review): torch.gather requires a dim argument and
            # torch.transpose takes two int dims -- these look like
            # unconverted tf.gather / tf.transpose(..., perm=[1, 0]) calls
            # from a TF->torch port; confirm before relying on this path.
            AP_ij = torch.matmul(
                torch.reshape(torch.gather(atom_features, atom_to_pair),
                              [-1, 2 * self.n_atom_input_feat]),
                self.W_AP) + self.b_AP
            AP_ij = activation(AP_ij)
            # Same transform for the reversed pair (j, i).
            AP_ji = torch.matmul(
                torch.reshape(
                    torch.gather(atom_features,
                                 torch.transpose(atom_to_pair, [1])),
                    [-1, 2 * self.n_atom_input_feat]), self.W_AP) + self.b_AP
            AP_ji = activation(AP_ji)

            # Pair -> pair dense transform.
            PP = torch.matmul(pair_features, self.W_PP) + self.b_PP
            PP = activation(PP)
            # Symmetrized atom-pair contribution (AP_ij + AP_ji) combined
            # with the pair-pair branch to form the new pair features.
            P = torch.matmul(torch.cat([AP_ij + AP_ji, PP], 1),
                             self.W_P) + self.b_P
            P = activation(P)
        else:
            # Pair features pass through unchanged when updates are off.
            P = pair_features

        # Expose both updated tensors; out_tensor follows the single-output
        # convention of sibling layers and holds the atom tensor A.
        self.out_tensors = [A, P]
        if set_tensors:
            self.variables = self.trainable_weights
            self.out_tensor = A
        return self.out_tensors
Example #5
0
    def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
        """Look up an embedding row for each atom by its atomic number.

    parent layers: atom_number
    """
        if in_layers is None:
            in_layers = self.in_layers
        in_layers = convert_to_layers(in_layers)

        self.build()
        atom_number = in_layers[0].out_tensor
        # BUG FIX: nn.Embedding(self.embedding_list, atom_number) constructs
        # a brand-new randomly initialised module (and passes a tensor where
        # an int is expected) instead of looking rows up in the existing
        # table. The TF original was an embedding lookup; plain integer
        # indexing is the torch equivalent.
        # NOTE(review): assumes self.embedding_list is the embedding weight
        # tensor created by build() -- confirm.
        atom_features = self.embedding_list[atom_number]
        out_tensor = atom_features
        if set_tensors:
            self.variables = self.trainable_weights
            self.out_tensor = out_tensor
        # Consistency fix: sibling create_tensor methods return the tensor;
        # this one previously returned None.
        return out_tensor
Example #6
0
    def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
        """Apply a stack of dense layers per atom, then pool per molecule.

    parent layers: atom_features, atom_membership
    """
        if in_layers is None:
            in_layers = self.in_layers
        in_layers = convert_to_layers(in_layers)

        self.build()
        output = in_layers[0].out_tensor
        # atom_membership: presumably the molecule id of each atom row
        # (a segment-index tensor) -- verify against the caller.
        atom_membership = in_layers[1].out_tensor
        # Hidden dense layers, each followed by the activation.
        for i, W in enumerate(self.W_list[:-1]):
            output = torch.matmul(output, W) + self.b_list[i]
            output = self.activation(output)
        # Final dense layer; activation only if explicitly requested.
        output = torch.matmul(output, self.W_list[-1]) + self.b_list[-1]
        if self.output_activation:
            output = self.activation(output)
        # NOTE(review): torch.sum expects an int dim, not an index tensor --
        # this looks like an unconverted tf.segment_sum(output,
        # atom_membership); confirm and port to index_add_/scatter.
        output = torch.sum(output, atom_membership)
        out_tensor = output
        if set_tensors:
            # Cache outputs/weights on the layer for the computation graph.
            self.variables = self.trainable_weights
            self.out_tensor = out_tensor
        return out_tensor
Example #7
0
    def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
        """Mix distance information into atom embeddings (DTNN-style step).

    parent layers: atom_features, distance, distance_membership_i, distance_membership_j
    """
        if in_layers is None:
            in_layers = self.in_layers
        in_layers = convert_to_layers(in_layers)

        self.build()
        atom_features = in_layers[0].out_tensor
        # distance: per-pair distance features; the membership tensors
        # presumably index which atom is on the i / j side of each pair --
        # verify against the featurizer.
        distance = in_layers[1].out_tensor
        distance_membership_i = in_layers[2].out_tensor
        distance_membership_j = in_layers[3].out_tensor
        # Project distances and atom features into the same hidden space.
        distance_hidden = torch.matmul(distance, self.W_df) + self.b_df
        atom_features_hidden = torch.matmul(atom_features,
                                            self.W_cf) + self.b_cf
        # NOTE(review): torch.gather requires a dim argument -- this looks
        # like an unconverted tf.gather(atom_features_hidden,
        # distance_membership_j); confirm.
        outputs = torch.multiply(
            distance_hidden,
            torch.gather(atom_features_hidden, distance_membership_j))

        # for atom i in a molecule m, this step multiplies together distance info of atom pair(i,j)
        # and embeddings of atom j(both gone through a hidden layer)
        outputs = torch.matmul(outputs, self.W_fc)
        outputs = self.activation(outputs)

        # Self-interaction term (i paired with itself) that is subtracted
        # from the aggregate below so atoms do not message themselves.
        output_ii = torch.multiply(self.b_df, atom_features_hidden)
        output_ii = torch.matmul(output_ii, self.W_fc)
        output_ii = self.activation(output_ii)

        # for atom i, sum the influence from all other atom j in the molecule
        # NOTE(review): torch.sum expects an int dim -- this looks like an
        # unconverted tf.segment_sum(outputs, distance_membership_i).
        # The trailing "+ atom_features" is a residual connection.
        outputs = torch.sum(outputs,
                            distance_membership_i) - output_ii + atom_features
        out_tensor = outputs
        if set_tensors:
            # Cache outputs/weights on the layer for the computation graph.
            self.variables = self.trainable_weights
            self.out_tensor = out_tensor
        return out_tensor
Example #8
0
    def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
        """Pool atom features per graph, then run the DAG output network.

    parent layers: atom_features, membership
    """
        if in_layers is None:
            in_layers = self.in_layers
        in_layers = convert_to_layers(in_layers)

        # Add trainable weights
        self.build()

        # Extract atom_features
        atom_features = in_layers[0].out_tensor
        # membership: presumably the graph id of each atom row -- verify.
        membership = in_layers[1].out_tensor
        # Extract atom_features
        # NOTE(review): torch.sum expects an int dim, not an index tensor --
        # this looks like an unconverted tf.segment_sum(atom_features,
        # membership); confirm and port to index_add_/scatter.
        graph_features = torch.sum(atom_features, membership)
        # sum all graph outputs
        outputs = self.DAGgraph_step(graph_features, self.W_list, self.b_list,
                                     **kwargs)
        out_tensor = outputs
        if set_tensors:
            # Cache outputs/weights on the layer for the computation graph.
            self.variables = self.trainable_weights
            self.out_tensor = out_tensor
        return out_tensor
Example #9
0
    def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
        """ Perform M steps of set2set gather,
        detailed descriptions in: https://arxiv.org/abs/1511.06391 """
        if in_layers is None:
            in_layers = self.in_layers
        in_layers = convert_to_layers(in_layers)

        self.build()
        # Extract atom_features
        atom_features = in_layers[0].out_tensor
        # atom_split: presumably the molecule id of each atom row -- verify.
        atom_split = in_layers[1].out_tensor

        # LSTM cell state and hidden state (query vector), reset per call.
        self.c = torch.zeros((self.batch_size, self.n_hidden))
        self.h = torch.zeros((self.batch_size, self.n_hidden))

        for i in range(self.M):
            # Broadcast each molecule's query to its atoms, then score
            # atoms against the query (dot product).
            # NOTE(review): torch.gather requires a dim argument -- this
            # looks like an unconverted tf.gather(self.h, atom_split).
            q_expanded = torch.gather(self.h, atom_split)
            e = torch.sum(atom_features * q_expanded, 1)
            # NOTE(review): torch.Tensor.unfold takes (dim, size, step) --
            # this looks like an unconverted tf.dynamic_partition that
            # splits scores into per-molecule groups; confirm.
            e_mols = torch.Tensor.unfold(e, atom_split, self.batch_size)
            # Add another value(~-Inf) to prevent error in softmax
            # NOTE(review): nn.init.constant_ takes (tensor, val); as
            # written this call cannot produce the intended [-1000.] tensor.
            e_mols = [
                torch.cat([e_mol, nn.init.constant_([-1000.])], 0)
                for e_mol in e_mols
            ]
            # Per-molecule attention weights; the appended sentinel score is
            # dropped again with [:-1]. NOTE(review): F.softmax should be
            # given an explicit dim.
            a = torch.cat([F.softmax(e_mol)[:-1] for e_mol in e_mols], 0)
            # Attention-weighted readout per molecule.
            # NOTE(review): torch.sum with an index tensor again looks like
            # an unconverted tf.segment_sum.
            r = torch.sum(
                torch.reshape(a, [-1, 1]) * atom_features, atom_split)
            # Model using this layer must set pad_batches=True
            q_star = torch.cat([self.h, r], dim=1)
            self.h, self.c = self.LSTMStep(q_star, self.c)

        # NOTE(review): q_star is undefined if self.M == 0.
        out_tensor = q_star
        if set_tensors:
            # Cache outputs/weights on the layer for the computation graph.
            self.variables = self.trainable_weights
            self.out_tensor = out_tensor
        return out_tensor
Example #10
0
    def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
        """Iteratively build DAG graph features, one topological step at a time.

    parent layers: atom_features, parents, calculation_orders, calculation_masks, n_atoms
    """
        if in_layers is None:
            in_layers = self.in_layers
        in_layers = convert_to_layers(in_layers)

        # Add trainable weights
        self.build()

        atom_features = in_layers[0].out_tensor
        # each atom corresponds to a graph, which is represented by the `max_atoms*max_atoms` int32 matrix of index
        # each gragh include `max_atoms` of steps(corresponding to rows) of calculating graph features
        parents = in_layers[1].out_tensor
        # target atoms for each step: (batch_size*max_atoms) * max_atoms
        calculation_orders = in_layers[2].out_tensor
        # boolean mask of which rows are real atoms (not padding) per step.
        calculation_masks = in_layers[3].out_tensor

        n_atoms = in_layers[4].out_tensor
        # initialize graph features for each graph
        graph_features_initial = torch.zeros(
            (self.max_atoms * self.batch_size, self.max_atoms + 1,
             self.n_graph_feat))
        # initialize graph features for each graph
        # another row of zeros is generated for padded dummy atoms
        # NOTE(review): model_ops.Variable is a project/TF-style wrapper;
        # the torch analogue would be a plain tensor (no grad needed since
        # trainable=False) -- confirm what model_ops provides here.
        graph_features = model_ops.Variable(graph_features_initial,
                                            trainable=False)

        for count in range(self.max_atoms):
            # `count`-th step
            # extracting atom features of target atoms: (batch_size*max_atoms) * n_atom_features
            mask = calculation_masks[:, count]
            current_round = torch.masked_select(calculation_orders[:, count],
                                                mask)
            # NOTE(review): torch.gather requires a dim argument -- this
            # looks like an unconverted tf.gather(atom_features,
            # current_round); confirm.
            batch_atom_features = torch.gather(atom_features, current_round)

            # generating index for graph features used in the inputs
            # NOTE(review): torch.arange expects scalar bounds; n_atoms is a
            # layer output tensor here -- likely unconverted tf.range.
            index = torch.stack([
                torch.reshape(
                    torch.stack(
                        [torch.masked_select(torch.arange(n_atoms), mask)] *
                        (self.max_atoms - 1),
                        dim=1), [-1]),
                torch.reshape(torch.masked_select(parents[:, count, 1:], mask),
                              [-1])
            ],
                                dim=1)
            # extracting graph features for parents of the target atoms, then flatten
            # shape: (batch_size*max_atoms) * [(max_atoms-1)*n_graph_features]
            # NOTE(review): 2-D index gather here matches tf.gather_nd, not
            # torch.gather semantics -- confirm.
            batch_graph_features = torch.reshape(
                torch.gather(graph_features, index),
                [-1, (self.max_atoms - 1) * self.n_graph_feat])

            # concat into the input tensor: (batch_size*max_atoms) * n_inputs
            # NOTE(review): torch.cat takes (tensors, dim); the values=
            # keyword is TF's tf.concat signature.
            batch_inputs = torch.cat(
                dim=1, values=[batch_atom_features, batch_graph_features])
            # DAGgraph_step maps from batch_inputs to a batch of graph_features
            # of shape: (batch_size*max_atoms) * n_graph_features
            # representing the graph features of target atoms in each graph
            batch_outputs = self.DAGgraph_step(batch_inputs, self.W_list,
                                               self.b_list, **kwargs)

            # index for targe atoms
            # NOTE(review): axis= is the TF keyword; torch.stack uses dim=.
            target_index = torch.stack(
                [torch.arange(n_atoms), parents[:, count, 0]], axis=1)
            target_index = torch.masked_select(target_index, mask)
            # update the graph features for target atoms
            # NOTE(review): this matches tf.scatter_update semantics, not
            # torch.Tensor.scatter_ (which needs dim and an index tensor
            # shaped like src) -- confirm before relying on this.
            graph_features = torch.Tensor.scatter_(graph_features,
                                                   target_index, batch_outputs)

        # The final step's outputs (the DAG roots) are the layer output.
        out_tensor = batch_outputs
        if set_tensors:
            # Cache outputs/weights on the layer for the computation graph.
            self.variables = self.trainable_weights
            self.out_tensor = out_tensor
        return out_tensor