def __init__(self, dim_rbf, dim_sbf, dim_msg, n_output, n_res_interaction, n_res_msg, n_dense_output, dim_bi_linear,
             activation, uncertainty_modify):
    """Build one interaction block: a message-passing layer, two residual stacks and an output head.

    Args:
        dim_rbf: radial-basis feature size, forwarded to DimeNetMPN and OutputLayer.
        dim_sbf: spherical-basis feature size, forwarded to DimeNetMPN.
        dim_msg: message width shared by every sub-layer.
        n_output: number of outputs of the OutputLayer.
        n_res_interaction: number of ResidualLayers registered as ``res_interaction{i}``.
        n_res_msg: number of ResidualLayers registered as ``res_msg{i}``.
        n_dense_output: number of dense layers inside the OutputLayer.
        dim_bi_linear: forwarded as the first positional argument of DimeNetMPN.
        activation: activation spec, resolved with ``activation_getter`` and forwarded to sub-layers.
        uncertainty_modify: uncertainty mode string; only ``'concreteDropoutOutput'`` changes behaviour here.
    """
    super().__init__()
    self.uncertainty_modify = uncertainty_modify
    self.activation = activation_getter(activation)

    # Learnable gate of shape (1, dim_msg), initialised to all ones.
    gate_init = torch.ones(1, dim_msg).type(floating_type)
    self.register_parameter('gate', nn.Parameter(gate_init, requires_grad=True))

    self.message_pass_layer = DimeNetMPN(dim_bi_linear, dim_msg, dim_rbf, dim_sbf, activation)

    # Keep the registration order fixed: it defines parameters()/state_dict() ordering.
    self.n_res_interaction = n_res_interaction
    for idx in range(n_res_interaction):
        self.add_module(f'res_interaction{idx}', ResidualLayer(dim_msg, activation))

    self.lin_interacted_msg = nn.Linear(dim_msg, dim_msg)

    self.n_res_msg = n_res_msg
    for idx in range(n_res_msg):
        self.add_module(f'res_msg{idx}', ResidualLayer(dim_msg, activation))

    # NOTE(review): exact string match, so a value with bracketed options
    # (e.g. 'concreteDropoutOutput[...]') leaves concrete dropout disabled here, whereas
    # another output-layer constructor in this file compares only the part before '['.
    # Confirm which behaviour is intended.
    use_concrete_dropout = uncertainty_modify == 'concreteDropoutOutput'
    self.output_layer = OutputLayer(dim_msg, dim_rbf, n_output, n_dense_output, activation,
                                    concrete_dropout=use_concrete_dropout)
def __init__(self, embedding_dim, rbf_dim, n_output, n_dense, activation, concrete_dropout):
    """Build the output head: ``n_dense`` dense layers, an RBF projection and a zero-initialised readout.

    Args:
        embedding_dim: width of the ``dense{i}`` layers and of the RBF projection output.
        rbf_dim: input size of ``lin_rbf`` (bias-free projection to ``embedding_dim``).
        n_output: number of output columns of ``out_dense``.
        n_dense: number of ``dense{i}`` layers to register.
        activation: activation spec resolved with ``activation_getter``.
        concrete_dropout: if truthy, every dense layer and ``out_dense`` are wrapped in ConcreteDropout.
    """
    super().__init__()
    self.concrete_dropout = concrete_dropout
    self.embedding_dim = embedding_dim
    self.rbf_dim = rbf_dim
    self.activation = activation_getter(activation)
    self.n_dense = n_dense

    for idx in range(n_dense):
        layer = nn.Linear(embedding_dim, embedding_dim)
        if self.concrete_dropout:
            layer = ConcreteDropout(layer, module_type='Linear')
        self.add_module(f'dense{idx}', layer)

    self.lin_rbf = nn.Linear(rbf_dim, embedding_dim, bias=False)
    self.scatter_fn = _MPNScatter()

    # The readout has no bias and its weight is zeroed before any wrapping,
    # so the underlying Linear initially maps every input to zeros.
    readout = nn.Linear(embedding_dim, n_output, bias=False)
    readout.weight.data.zero_()
    if self.concrete_dropout:
        readout = ConcreteDropout(readout, module_type='Linear')
    self.out_dense = readout
def __init__(self, F, activation, concrete_dropout=False, batch_norm=False, dropout=False):
    """Build a residual block made of two F -> F linear layers.

    Args:
        F: feature width (input and output size of both linear layers).
        activation: activation spec resolved with ``activation_getter``.
        concrete_dropout: if truthy, ``lin1`` and ``lin2`` are wrapped in ConcreteDropout
            after both have been created.
        batch_norm: if truthy, registers ``bn1`` and ``bn2`` (BatchNorm1d of width F).
        dropout: accepted for interface compatibility; not used by this constructor.
    """
    super().__init__()
    self.batch_norm = batch_norm
    self.concrete_dropout = concrete_dropout
    self.activation = activation_getter(activation)

    def make_linear():
        # F -> F layer with semi-orthogonal Glorot weights and a zero bias.
        layer = nn.Linear(F, F)
        layer.weight.data = semi_orthogonal_glorot_weights(F, F)
        layer.bias.data.zero_()
        return layer

    # NOTE: momentum=1. makes BatchNorm's running statistics track only the latest batch.
    self.lin1 = make_linear()
    if batch_norm:
        self.bn1 = nn.BatchNorm1d(F, momentum=1.)
    self.lin2 = make_linear()
    if batch_norm:
        self.bn2 = nn.BatchNorm1d(F, momentum=1.)

    # Wrap only after both linears exist, preserving the original construction order.
    if concrete_dropout:
        for attr_name in ('lin1', 'lin2'):
            setattr(self, attr_name, ConcreteDropout(getattr(self, attr_name), module_type='Linear'))
def __init__(self, n_tensor, dim_msg, dim_rbf, dim_sbf, activation):
    """Build the message-passing layers and the bi-linear weight tensor.

    Args:
        n_tensor: output size of ``lin_sbf`` and size of the last axis of ``W_bi_linear``.
        dim_msg: message width (input/output of ``lin_source``/``lin_target``).
        dim_rbf: input size of ``lin_rbf`` (bias-free projection to ``dim_msg``).
        dim_sbf: input size of ``lin_sbf`` (bias-free projection to ``n_tensor``).
        activation: activation spec resolved with ``activation_getter``.
    """
    super().__init__()
    self.n_tensor = n_tensor
    self.dim_msg = dim_msg

    self.lin_source = nn.Linear(dim_msg, dim_msg)
    self.lin_target = nn.Linear(dim_msg, dim_msg)
    self.lin_rbf = nn.Linear(dim_rbf, dim_msg, bias=False)
    self.lin_sbf = nn.Linear(dim_sbf, n_tensor, bias=False)

    # Bi-linear layer weight (no bias): shape (dim_msg, dim_msg, n_tensor),
    # sampled from U(-1/dim_msg, 1/dim_msg).
    bound = 1 / dim_msg
    bilinear_init = torch.zeros(dim_msg, dim_msg, n_tensor).type(floating_type).uniform_(-bound, bound)
    self.W_bi_linear = torch.nn.Parameter(bilinear_init, requires_grad=True)

    self.activation = activation_getter(activation)
def __init__(self, dim_rbf, dim_edge, activation):
    """Build the layers of this edge module.

    Args:
        dim_rbf: input size of ``lin_rbf`` (projected to ``dim_edge``).
        dim_edge: edge feature width; ``lin_concat`` maps ``3 * dim_edge`` back to ``dim_edge``.
        activation: activation spec resolved with ``activation_getter``.
    """
    super().__init__()
    self.lin_rbf = nn.Linear(dim_rbf, dim_edge)
    # Presumably applied to three concatenated dim_edge-wide tensors — confirm against forward().
    concat_width = 3 * dim_edge
    self.lin_concat = nn.Linear(concat_width, dim_edge)
    self.activation = activation_getter(activation)
def __init__(self, F, n_output, n_res_output, activation, uncertainty_modify, n_read_out=0, batch_norm=False,
             dropout=False):
    """Build the output head: residual layers, width-halving read-out layers and a zero-initialised final layer.

    Sub-modules, in registration order:
        * ``res_layer{i}``: ``n_res_output`` ResidualLayers of width ``F``.
        * ``read_out{i}``: ``n_read_out`` Linear layers, each mapping ``d -> ceil(d / 2)``
          (optionally wrapped in ConcreteDropout).
        * ``bn_{i}``: optional BatchNorm1d sized to each read-out layer's output.
        * ``lin``: bias-free Linear to ``n_output`` with a zeroed weight
          (optionally wrapped in ConcreteDropout).
        * ``bn_last``: optional BatchNorm1d sized to ``lin``'s input.

    Args:
        F: input feature width.
        n_output: number of outputs of ``lin``.
        n_res_output: number of residual layers.
        activation: activation spec, resolved with ``activation_getter`` and forwarded to each ResidualLayer.
        uncertainty_modify: string; when the part before the first ``'['`` equals ``"concreteDropoutOutput"``,
            the read-out layers and ``lin`` are wrapped in ConcreteDropout. The whole string is also parsed
            by ``option_solver`` into keyword options passed to ConcreteDropout.
        n_read_out: number of read-out layers.
        batch_norm: if truthy, register the ``bn_{i}`` and ``bn_last`` BatchNorm1d layers
            and enable batch norm in the ResidualLayers.
        dropout: forwarded to every ResidualLayer.
    """
    # Run nn.Module.__init__ before assigning any attributes; PyTorch expects this
    # before submodules/parameters are registered, and earlier assignment is fragile.
    super().__init__()
    self.batch_norm = batch_norm
    self.concrete_dropout = uncertainty_modify.split('[')[0] == "concreteDropoutOutput"
    self.dropout_options = option_solver(uncertainty_modify)

    # Option values are treated as strings here; convert the known keys to bool/float
    # before they are forwarded to ConcreteDropout.
    converters = {
        'train_p': lambda text: text.lower() == 'true',
        'normal_dropout': lambda text: text.lower() == 'true',
        'init_min': float,
        'init_max': float,
    }
    for key, convert in converters.items():
        if key in self.dropout_options:
            self.dropout_options[key] = convert(self.dropout_options[key])

    self.n_res_output = n_res_output
    self.n_read_out = n_read_out
    for i in range(n_res_output):
        self.add_module('res_layer' + str(i),
                        ResidualLayer(F, activation, concrete_dropout=False, batch_norm=batch_norm,
                                      dropout=dropout))

    # Read-out layers: each one halves the width, rounding up.
    last_dim = F
    for i in range(n_read_out):
        this_dim = ceil(last_dim / 2)
        read_out_i = torch.nn.Linear(last_dim, this_dim)
        if self.concrete_dropout:
            read_out_i = ConcreteDropout(read_out_i, module_type='Linear', **self.dropout_options)
        self.add_module('read_out{}'.format(i), read_out_i)
        if self.batch_norm:
            self.add_module("bn_{}".format(i), torch.nn.BatchNorm1d(this_dim, momentum=1.))
        last_dim = this_dim

    # Zero the final weight before (optionally) wrapping it; with no bias, the Linear initially outputs zeros.
    self.lin = torch.nn.Linear(last_dim, n_output, bias=False)
    self.lin.weight.data.zero_()
    if self.concrete_dropout:
        self.lin = ConcreteDropout(self.lin, module_type='Linear', **self.dropout_options)
    if self.batch_norm:
        self.bn_last = torch.nn.BatchNorm1d(last_dim, momentum=1.)

    self.activation = activation_getter(activation)