def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
             filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
    super(MSSSIM, self).__init__()
    validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
    validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
    self.max_val = max_val
    validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
    self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
    self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
    self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
    self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
    window = _create_window(filter_size, filter_sigma)
    # One frozen Gaussian-window convolution per MS-SSIM scale.
    self.level = len(power_factors)
    self.conv = []
    for i in range(self.level):
        self.conv.append(_conv2d(1, 1, filter_size, Tensor(window)))
        self.conv[i].weight.requires_grad = False
    self.multi_convs_list = CellList(self.conv)
    self.weight_tensor = Tensor(power_factors, mstype.float32)
    # 2x2 average pooling downsamples the images between scales.
    self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid')
    self.relu = ReLU()
    self.reduce_mean = P.ReduceMean()
    self.prod = P.ReduceProd()
    self.pow = P.Pow()
    self.pack = P.Pack(axis=-1)
    self.concat = P.Concat(axis=1)
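# A minimal usage sketch, assuming the cell above follows the nn.MSSSIM
# interface: construct takes two image batches of shape (N, C, H, W) with
# values in [0, max_val]. With five power factors and an 11x11 window, the
# spatial size must stay >= 11 after four rounds of 2x2 pooling, hence the
# 192x192 inputs chosen here.
import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

net = MSSSIM(max_val=1.0)
img1 = Tensor(np.random.rand(1, 3, 192, 192), mstype.float32)
img2 = Tensor(np.random.rand(1, 3, 192, 192), mstype.float32)
score = net(img1, img2)  # MS-SSIM similarity per sample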
def construct(self, x, query_hidden_state, attention_mask, layer_past=None):
    """Cross-attention: query comes from query_hidden_state, key/value from x."""
    original_shape = F.shape(x)
    x = F.reshape(x, (-1, original_shape[-1]))
    query_hidden_state = F.reshape(query_hidden_state, (-1, original_shape[-1]))
    query = self.dense1(query_hidden_state)
    key = self.dense2(x)
    value = self.dense3(x)
    query = self.transpose(
        F.reshape(query, (-1, original_shape[1], self.n_head, self.size_per_head)),
        (0, 2, 1, 3))
    key = self.transpose(
        F.reshape(key, (-1, original_shape[1], self.n_head, self.size_per_head)),
        (0, 2, 3, 1))
    value = self.transpose(
        F.reshape(value, (-1, original_shape[1], self.n_head, self.size_per_head)),
        (0, 2, 1, 3))
    if self.use_past:
        past_value = layer_past[1]
        past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
        key = self.concat_k((past_key, key))
        # Concat expects a single tuple of tensors; the original passed two
        # positional arguments here.
        value = self.concat_v((past_value, value))
    layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
    attention = self._attn(query, key, value, attention_mask)
    attention_merge = self.merge_heads(attention)
    output = self.projection(attention_merge)
    output = self.dropout(output)
    return output, layer_present
def construct(self, x, attention_mask, layer_past=None):
    """
    self-attention

    Inputs:
        x: output of previous layer
        attention_mask: the attention mask matrix with shape
            (batch_size, 1, seq_length, seq_length)
        layer_past: the previous feature map

    Returns:
        output: Tensor, the output logit of this layer
        layer_present: Tensor, the feature map of current layer
    """
    original_shape = F.shape(x)
    x = F.reshape(x, (-1, original_shape[-1]))
    query = self.dense1(x)
    key = self.dense2(x)
    value = self.dense3(x)
    query = self.transpose(
        F.reshape(query, (-1, original_shape[1], self.n_head, self.size_per_head)),
        (0, 2, 1, 3))
    key = self.transpose(
        F.reshape(key, (-1, original_shape[1], self.n_head, self.size_per_head)),
        (0, 2, 3, 1))
    value = self.transpose(
        F.reshape(value, (-1, original_shape[1], self.n_head, self.size_per_head)),
        (0, 2, 1, 3))
    if self.use_past:
        past_value = layer_past[1]
        past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
        key = self.concat_k((past_key, key))
        # Concat expects a single tuple of tensors; the original passed two
        # positional arguments here.
        value = self.concat_v((past_value, value))
    layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
    attention = self._attn(query, key, value, attention_mask)
    attention_merge = self.merge_heads(attention)
    output = self.projection(attention_merge)
    output = self.dropout(output)
    return output, layer_present
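# Shape bookkeeping for the reshape/transpose dance above, checked in plain
# numpy (the sizes are illustrative, not from the source). After the
# transposes, query is (batch, n_head, seq, size_per_head) and key is
# (batch, n_head, size_per_head, seq), so query @ key yields the
# (batch, n_head, seq, seq) score matrix that the attention mask is applied to.
import numpy as np

batch, seq, n_head, size_per_head = 2, 4, 8, 16
hidden = n_head * size_per_head
x = np.zeros((batch * seq, hidden), np.float32)  # x after the initial flatten

query = x.reshape(-1, seq, n_head, size_per_head).transpose(0, 2, 1, 3)
key = x.reshape(-1, seq, n_head, size_per_head).transpose(0, 2, 3, 1)
scores = query @ key
assert scores.shape == (batch, n_head, seq, seq)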
def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None, is_parameter=True):
    super(Net, self).__init__()
    self.pack = P.Pack(axis=axis).shard(strategy1)
    self.mul = P.Mul().shard(strategy2)
    if is_parameter:
        self.weight1 = Parameter(weight1, "w1")
    else:
        self.weight1 = weight1
    self.weight2 = Parameter(weight2, "w2")
def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None):
    super(Net1, self).__init__()
    self.pack = P.Pack(axis=axis).shard(strategy1)
    self.mul = P.Mul().shard(strategy2)
    self.weight1 = Parameter(weight1, "w1")
    self.weight2 = Parameter(weight2, "w2")
def __init__(self, dense_in_channel, dense_out_channel, axis=0, shape=None, strategy=None):
    super().__init__()
    weight_np = np.full((dense_out_channel, dense_in_channel), 0.01, dtype=np.float32)
    bias_np = np.full((dense_out_channel,), 0.01, dtype=np.float32)
    self.pack_con = Tensor(np.full(shape, 0.01, dtype=np.float32))
    self.flat = Flatten()
    self.dense = Dense(in_channels=dense_in_channel,
                       out_channels=dense_out_channel,
                       weight_init=Tensor(weight_np),
                       bias_init=Tensor(bias_np),
                       has_bias=True)
    self.mul = P.Mul()
    self.pack = P.Pack(axis)
    if strategy is not None:
        self.pack.shard(strategy)
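# A hedged sketch of building the cell above under semi-auto parallel; the
# concrete values are assumptions, since the construct is not shown. Under the
# standard MindSpore shard semantics, the strategy for P.Pack holds one layout
# tuple per input tensor, e.g. splitting two 2-D inputs 2-way along dim 0:
strategy = ((2, 1), (2, 1))
net = Net(dense_in_channel=8, dense_out_channel=16,
          axis=0, shape=(16, 8), strategy=strategy)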
def __init__(self, x, axis):
    super(Net, self).__init__()
    self.pack = P.Pack(axis)
    self.x = x
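# P.Pack (renamed Stack in later MindSpore releases) joins tensors along a
# *new* axis, matching np.stack. A small equivalence check, assuming a
# MindSpore version that still exposes P.Pack and PyNative execution:
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

a = np.ones((2, 3), np.float32)
b = np.zeros((2, 3), np.float32)
packed = P.Pack(axis=0)((Tensor(a), Tensor(b)))
assert packed.shape == (2, 2, 3)
assert np.array_equal(packed.asnumpy(), np.stack([a, b], axis=0))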
def __init__(self,
             dim_atom_embed,
             num_rbf,
             n_heads=8,
             activation=Swish(),
             max_cycles=10,
             time_embedding=0,
             use_pondering=True,
             fixed_cycles=False,
             use_filter=True,
             inside_filter=None,
             act_threshold=0.9,
             fixed_neigh=False,
             ):
    super().__init__(gather_dim=dim_atom_embed, fixed_neigh=fixed_neigh)
    if dim_atom_embed % n_heads != 0:
        raise ValueError('"dim_atom_embed" must be divisible by '
                         '"n_heads" in AirNetInteraction!')
    self.n_heads = n_heads
    self.max_cycles = max_cycles
    self.dim_atom_embed = dim_atom_embed
    self.num_rbf = num_rbf
    self.time_embedding = time_embedding
    self.flexible_cycles = not fixed_cycles
    self.use_filter = use_filter
    if self.use_filter:
        # self.filter = Filter(num_rbf, dim_atom_embed, activation)
        self.filter = Dense(num_rbf, dim_atom_embed, has_bias=True, activation=None)
    self.positional_embedding = PositionalEmbedding(dim_atom_embed)
    self.multi_head_attention = MultiheadAttention(dim_atom_embed, n_heads)
    # Adaptive-computation-time (ACT) halting threshold and its complement.
    self.act_threshold = act_threshold
    self.act_epsilon = 1.0 - act_threshold
    self.use_pondering = use_pondering
    self.pondering = None
    self.act_weight = None
    if self.max_cycles > 1:
        if self.use_pondering:
            self.pondering = Pondering(dim_atom_embed * 3, bias_const=3)
            self.act_weight = ACTWeight(self.act_threshold)
        else:
            if self.flexible_cycles:
                raise ValueError('"fixed_cycles" must be True when the pondering '
                                 'network is None in AirNetInteraction!')
            # Without pondering, every cycle contributes an equal weight.
            self.fixed_weight = Tensor(1.0 / max_cycles, ms.float32)
    self.max = P.Maximum()
    self.min = P.Minimum()
    self.concat = P.Concat(-1)
    self.pack = P.Pack()
    self.reducesum = P.ReduceSum()
    self.squeeze = P.Squeeze(-1)
    self.ones_like = P.OnesLike()
    self.zeros_like = P.ZerosLike()
    self.zeros = P.Zeros()
    'desc_bprop': [[4, 2]]}),
('ConcatV2_4', {
    'block': P.Concat(axis=0),
    'desc_inputs': [(Tensor(np.ones((3, 2, 3), np.float32)),
                     Tensor(np.ones((5, 2, 3), np.float32)),
                     Tensor(np.ones((6, 2, 3), np.float32)))],
    'desc_bprop': [[14, 2, 3]]}),
('ConcatV2_5', {
    'block': P.Concat(axis=-1),
    'desc_inputs': [(Tensor(np.array([1], np.float32)),
                     Tensor(np.array([1], np.float32)),
                     Tensor(np.array([1], np.float32)))],
    'desc_bprop': [[3,]]}),
('Pack_0', {
    'block': NetForPackInput(P.Pack()),
    'desc_inputs': [[2, 2], [2, 2], [2, 2]],
    'desc_bprop': [[3, 2, 2]],
}),
('Pack_1', {
    'block': NetForPackInput(P.Pack(axis=-2)),
    'desc_inputs': [[3, 2, 3], [3, 2, 3], [3, 2, 3]],
    'desc_bprop': [[3, 2, 3, 3]],
}),
('Pack_2', {
    'block': NetForPackInput(P.Pack()),
    'desc_inputs': [[2, 2]],
    'desc_bprop': [[2, 2, 2]],
}),
('Pack_3', {
    'block': NetForPackInput(P.Pack()),
def __init__(self):
    super(PackNet, self).__init__()
    self.pack = P.Pack()
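# PackNet above defines only __init__; a minimal construct that would complete
# it (an assumed usage, not from the source) forwards a tuple of tensors to
# the Pack primitive:
def construct(self, x1, x2):
    return self.pack((x1, x2))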
from typing import List

from mindspore import Tensor
from mindspore.ops import operations as op


def stack(inputs: List[Tensor], axis: int) -> Tensor:
    """Packs (stacks) a list of tensors along a new axis at the given position."""
    pack_op = op.Pack(axis)
    outputs = pack_op(inputs)
    return outputs
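# Usage sketch for the stack wrapper above (the tensor values are illustrative):
import numpy as np

t1 = Tensor(np.ones((2, 3), np.float32))
t2 = Tensor(np.zeros((2, 3), np.float32))
out = stack([t1, t2], axis=0)  # new leading axis: shape (2, 2, 3)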
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import operations as P
from mindspore.ops import Primitive

pack = P.Pack()
concat = P.Concat()
make_tuple = Primitive('make_tuple')


class FnDict:
    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn

    def __getitem__(self, name):
        return self.fnDict[name]
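# FnDict is a small registry: calling the instance registers a function under
# its __name__, and indexing retrieves it by name. A hedged sketch of the
# tag-based pattern such graph-pass test files typically use (the function
# body here is illustrative, not from the source):
def test_pack_fusion(tag):
    fns = FnDict()

    @fns
    def before(x, y):
        return pack((x, y))

    return fns[tag]

# test_pack_fusion('before') then returns the registered `before` function.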
def __init__(self,
             num_atomtypes=100,
             dim_atomembedding=64,
             min_rbf_dis=0.05,
             max_rbf_dis=1,
             num_rbf=32,
             n_interactions=3,
             n_heads=8,
             max_cycles=10,
             activation=Swish(),
             output_dim=1,
             self_dis=None,
             rbf_sigma=None,
             distance_expansion=LogGaussianDistribution,
             cutoff=None,
             cutoff_network=SmoothCutoff,
             public_filter=True,
             coupled_interactions=False,
             trainable_gaussians=False,
             use_pondering=True,
             fixed_cycles=False,
             rescale_rbf=True,
             use_time_embedding=True,
             use_all_interactions=True,
             use_mcr=False,
             debug=False,
             ):
    super().__init__(
        num_atomtypes=num_atomtypes,
        dim_atomembedding=dim_atomembedding,
        min_rbf_dis=min_rbf_dis,
        max_rbf_dis=max_rbf_dis,
        num_rbf=num_rbf,
        output_dim=output_dim,
        rbf_sigma=rbf_sigma,
        distance_expansion=distance_expansion,
        cutoff=cutoff,
        cutoff_network=cutoff_network,
        rescale_rbf=rescale_rbf,
        use_all_interactions=use_all_interactions,
    )
    self.network_name = 'AirNet'
    self.max_distance = max_rbf_dis
    self.min_distance = min_rbf_dis
    if self_dis is None:
        self.self_dis = self.min_distance
    else:
        self.self_dis = self_dis
    self.self_dis_tensor = Tensor([self.self_dis], ms.float32)
    self.n_heads = n_heads
    if use_time_embedding:
        time_embedding = self._get_time_signal(max_cycles, dim_atomembedding)
    else:
        time_embedding = [0 for _ in range(max_cycles)]
    if public_filter:
        inter_filter = False
        self.filter = Filter(num_rbf, dim_atomembedding, None)
    else:
        inter_filter = True
        self.filter = None
    self.n_interactions = n_interactions
    # blocks for computing the interactions
    if coupled_interactions:
        # share a single AirNetInteraction instance (hence the same weights)
        self.interactions = nn.CellList([
            AirNetInteraction(
                dim_atom_embed=dim_atomembedding,
                num_rbf=num_rbf,
                n_heads=n_heads,
                activation=activation,
                max_cycles=max_cycles,
                time_embedding=time_embedding,
                use_filter=inter_filter,
                use_pondering=use_pondering,
                fixed_cycles=fixed_cycles,
            )
        ] * n_interactions)
    else:
        # use a separate AirNetInteraction instance for each interaction
        self.interactions = nn.CellList([
            AirNetInteraction(
                dim_atom_embed=dim_atomembedding,
                num_rbf=num_rbf,
                n_heads=n_heads,
                activation=activation,
                max_cycles=max_cycles,
                time_embedding=time_embedding,
                use_filter=inter_filter,
                use_pondering=use_pondering,
                fixed_cycles=fixed_cycles,
            ) for _ in range(n_interactions)
        ])
    # readout layer
    if self.use_all_interactions and n_interactions > 1:
        if use_mcr:
            self.gather_interactions = MultipleChannelRepresentation(
                n_interactions, dim_atomembedding, 1, activation)
        else:
            self.gather_interactions = TensorSum()
    else:
        self.gather_interactions = None
    readoutdim = int(dim_atomembedding / 2)
    self.readout = AtomwiseReadout(dim_atomembedding, self.output_dim,
                                   [readoutdim,], activation)
    if debug:
        self.debug_fun = self._debug_fun
    self.lmax_label = []
    for i in range(n_interactions):
        self.lmax_label.append('l' + str(i) + '_cycles')
    self.fill = P.Fill()
    self.concat = P.Concat(-1)
    self.pack = P.Pack(-1)
    self.reducesum = P.ReduceSum()
    self.reducemax = P.ReduceMax()
    self.tensor_summary = P.TensorSummary()
    self.scalar_summary = P.ScalarSummary()
    'desc_bprop': [[4, 2]]}),
('ConcatV2_4', {
    'block': P.Concat(axis=0),
    'desc_inputs': [(Tensor(np.ones((3, 2, 3), np.float32)),
                     Tensor(np.ones((5, 2, 3), np.float32)),
                     Tensor(np.ones((6, 2, 3), np.float32)))],
    'desc_bprop': [[14, 2, 3]]}),
('ConcatV2_5', {
    'block': P.Concat(axis=-1),
    'desc_inputs': [(Tensor(np.array([1], np.float32)),
                     Tensor(np.array([1], np.float32)),
                     Tensor(np.array([1], np.float32)))],
    'desc_bprop': [[3, ]]}),
('Pack_0', {
    'block': NetForPackInput(P.Pack()),
    'desc_inputs': [[2, 2], [2, 2], [2, 2]],
    'desc_bprop': [[3, 2, 2]],
}),
('Pack_1', {
    'block': NetForPackInput(P.Pack(axis=-2)),
    'desc_inputs': [[3, 2, 3], [3, 2, 3], [3, 2, 3]],
    'desc_bprop': [[3, 2, 3, 3]],
}),
('Pack_2', {
    'block': NetForPackInput(P.Pack()),
    'desc_inputs': [[128, 128], [128, 128]],
    'desc_bprop': [[2, 128, 128]],
}),
('Unpack_0', {
    'block': NetForUnpackInput(P.Unpack(axis=0)),