Example #1
    def __init__(
        self,
        n_in=128,
        n_out=1,
        aggregation_mode="sum",
        n_layers=2,
        n_neurons=None,
        activation=shifted_softplus,
        return_contributions=False,
        requires_dr=False,
        create_graph=False,
        mean=None,
        stddev=None,
        atomref=None,
        max_z=100,
        outnet=None,
        train_embeddings=False,
    ):
        super(Atomwise, self).__init__(requires_dr)

        self.n_layers = n_layers
        self.create_graph = create_graph
        self.return_contributions = return_contributions

        if atomref is not None:
            self.atomref = nn.Embedding.from_pretrained(
                torch.from_numpy(atomref.astype(np.float32)),
                freeze=not train_embeddings,
            )
        elif train_embeddings:
            self.atomref = nn.Embedding.from_pretrained(
                torch.from_numpy(np.zeros((max_z, 1), dtype=np.float32)),
                freeze=not train_embeddings,
            )
        else:
            self.atomref = None
        # network block
        if outnet is None:
            # assign fully connected feed-forward block as network block
            self.out_net = nn.Sequential(
                GetItem("representation"),
                MLP(n_in, n_out, n_neurons, n_layers, activation=activation),
            )
        else:
            self.out_net = outnet
        # assign mean and standard deviation of property
        mean = torch.FloatTensor([0.0]) if mean is None else mean
        stddev = torch.FloatTensor([1.0]) if stddev is None else stddev
        # standardization layer
        self.standardize = ScaleShift(mean, stddev)
        # pooling layer
        if aggregation_mode == "sum":
            self.atom_pool = Aggregate(axis=1, mean=False)
        elif aggregation_mode == "avg":
            self.atom_pool = Aggregate(axis=1, mean=True)
        else:
            raise ValueError("Invalid aggregation_mode={0}!".format(aggregation_mode))
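
For intuition, the two aggregation modes above reduce per-atom contributions of shape (batch, n_atoms, 1) to one molecular property per sample. A minimal self-contained sketch in plain PyTorch (the tensors are made up; Aggregate itself is not used here):

import torch

yi = torch.rand(8, 20, 1)   # hypothetical per-atom predictions: (batch, atoms, 1)
y_sum = yi.sum(dim=1)       # aggregation_mode="sum" -> shape (8, 1)
y_avg = yi.mean(dim=1)      # aggregation_mode="avg" -> shape (8, 1)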
Example #2
    def __init__(self, n_in, n_filters, n_out, filter_network,
                 cutoff_network=None,
                 activation=None, normalize_filter=False, axis=2):
        super(CFConv, self).__init__()
        # dense layers mapping into and out of the filter space
        self.in2f = Dense(n_in, n_filters, bias=False)
        self.f2out = Dense(n_filters, n_out, activation=activation)
        self.filter_network = filter_network
        self.cutoff_network = cutoff_network
        # pooling over the neighbor axis
        self.agg = Aggregate(axis=axis, mean=normalize_filter)
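
The constructor only wires the blocks together; the continuous-filter convolution itself multiplies neighbor features element-wise by generated filters and then pools over the neighbor axis, which is what Aggregate(axis=2) does. A minimal sketch with made-up shapes (an illustration of the pooling step, not the library's forward pass):

import torch

# hypothetical shapes: batch=2, atoms=5, neighbors=4, n_filters=16
y_nbh = torch.rand(2, 5, 4, 16)   # neighbor features after in2f and gathering
W = torch.rand(2, 5, 4, 16)       # filters produced by filter_network
y = (y_nbh * W).sum(dim=2)        # Aggregate(axis=2): pool over neighbors -> (2, 5, 16)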
Example #3
    def __init__(self,
                 n_in,
                 n_filters,
                 n_out,
                 filter_network,
                 cutoff_network=None,
                 activation=None,
                 normalize_filter=False,
                 axis=2,
                 n_heads_weights=0,
                 n_heads_conv=0,
                 device=torch.device("cpu"),
                 hyperparams=(0, 0),  # tuple instead of list: avoids a shared mutable default
                 dropout=0,
                 exp=False):
        super(CFConv, self).__init__()
        self.device = device
        self.n_heads_weights = n_heads_weights
        self.n_heads_conv = n_heads_conv
        self.atomic_embedding_dim = n_out
        self.in2f = Dense(n_in, n_filters, bias=False, activation=None)
        self.f2out = Dense(n_filters, n_out, bias=True, activation=activation)
        self.filter_network = filter_network
        self.cutoff_network = cutoff_network
        # sum over neighbor indices
        self.agg = Aggregate(axis=axis, mean=normalize_filter)
        # added multi-headed attention over the weights
        self.attention_dim = int(n_out / 4)  # arbitrary -> could modify at will
        if n_heads_weights > 0:
            self.Attention = AttentionHeads(
                n_in, self.attention_dim, n_heads=self.n_heads_weights,
                EXP=exp, atomic_embedding_dim=n_out, device=self.device,
                SM=False, hyperparams=hyperparams, dropout=dropout,
            )
        # added multi-headed attention over the convolution
        if n_heads_conv > 0:
            # for now this should be a single head
            self.AttentionConv = AttentionHeads(
                n_in, self.attention_dim, n_heads=self.n_heads_conv,
                EXP=exp, atomic_embedding_dim=n_out, device=self.device,
                SM=False, hyperparams=hyperparams, dropout=dropout,
            )
        # NOTE: EXP determines whether the scalar attention value is exp(A) or just A.
        # NOTE: exp(A) can be unstable, as can the softmax below.

        # add the possibility to use a softmax over the weights
        self.softmax = nn.Softmax(dim=3)

        # not currently used, but could be added if deemed beneficial
        self.dropout = nn.Dropout(dropout)
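
The stability caveat in the two NOTEs is easy to see numerically; a tiny self-contained sketch in plain PyTorch (nothing here comes from the snippet above):

import torch

A = torch.tensor([10.0, 50.0, 100.0])
print(torch.exp(A))             # exp overflows float32 near A ~ 88.7 -> inf
print(torch.softmax(A, dim=0))  # stays finite (shifted by max(A) internally),
                                # but saturates toward one-hot for large gaps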
Example #4
    def __init__(self,
                 n_in,
                 n_filters,
                 n_out,
                 filter_network,
                 cutoff_network=None,
                 activation=None,
                 normalize_filter=False,
                 axis=2,
                 weight_init=xavier_uniform_):
        super(CFConv, self).__init__()
        self.in2f = Dense(n_in,
                          n_filters,
                          bias=False,
                          activation=None,
                          weight_init=weight_init)
        self.f2out = Dense(n_filters,
                           n_out,
                           bias=True,
                           activation=activation,
                           weight_init=weight_init)
        self.filter_network = filter_network
        self.cutoff_network = cutoff_network
        self.agg = Aggregate(axis=axis, mean=normalize_filter)
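
The only difference from Example #2 is the pluggable weight_init hook. A minimal sketch of how such a hook is typically applied inside a Dense-style layer (this class is a stand-in written for illustration, not the library's Dense):

import torch.nn as nn
from torch.nn.init import xavier_uniform_

class Dense(nn.Linear):
    """Linear layer that applies a caller-supplied initializer to its weights."""
    def __init__(self, in_features, out_features, bias=True,
                 weight_init=xavier_uniform_):
        super().__init__(in_features, out_features, bias=bias)
        weight_init(self.weight)      # overrides nn.Linear's default init
        if bias:
            nn.init.zeros_(self.bias)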
Example #5
def test_shape_aggregate():
    model = Aggregate(axis=1)
    input_data = torch.rand((3, 4, 5))
    inputs = [input_data]
    out_shape = [3, 5]
    assert_equal_shape(model, inputs, out_shape)
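
assert_equal_shape is a helper from the surrounding test suite; a plausible minimal implementation, assuming it simply runs the model and compares shapes (an assumption, not the suite's actual helper):

def assert_equal_shape(model, inputs, out_shape):
    """Apply model to the unpacked inputs and check the output shape."""
    out = model(*inputs)
    assert list(out.shape) == out_shape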
Example #6
def test_nn_aggregate_axis():
    data = torch.ones((1, 5, 4, 3), dtype=torch.float)
    agg = Aggregate(axis=0, mean=False)
    assert torch.allclose(torch.ones((5, 4, 3)),
                          agg(data),
                          atol=0.0,
                          rtol=1.0e-7)
    assert list(agg.parameters()) == []
    agg = Aggregate(axis=1, mean=False)
    assert torch.allclose(5 * torch.ones((1, 4, 3)),
                          agg(data),
                          atol=0.0,
                          rtol=1.0e-7)
    assert list(agg.parameters()) == []
    agg = Aggregate(axis=2, mean=False)
    assert torch.allclose(4 * torch.ones((1, 5, 3)),
                          agg(data),
                          atol=0.0,
                          rtol=1.0e-7)
    assert list(agg.parameters()) == []
    agg = Aggregate(axis=3, mean=False)
    assert torch.allclose(3 * torch.ones((1, 5, 4)),
                          agg(data),
                          atol=0.0,
                          rtol=1.0e-7)
    assert list(agg.parameters()) == []
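
The four nearly identical blocks could be collapsed with pytest parametrization; an equivalent sketch (assuming pytest is available and Aggregate is imported as in the original tests):

import pytest
import torch

@pytest.mark.parametrize("axis, expected_shape, factor", [
    (0, (5, 4, 3), 1.0),
    (1, (1, 4, 3), 5.0),
    (2, (1, 5, 3), 4.0),
    (3, (1, 5, 4), 3.0),
])
def test_nn_aggregate_axis_parametrized(axis, expected_shape, factor):
    # summing ones over an axis of size k yields k * ones of the remaining shape
    data = torch.ones((1, 5, 4, 3), dtype=torch.float)
    agg = Aggregate(axis=axis, mean=False)
    assert torch.allclose(factor * torch.ones(expected_shape), agg(data),
                          atol=0.0, rtol=1.0e-7)
    assert list(agg.parameters()) == []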