Example #1
0
 def __init__(self,
              max_val=1.0,
              power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
              filter_size=11,
              filter_sigma=1.5,
              k1=0.01,
              k2=0.03):
     """Validate the MS-SSIM hyper-parameters and build the per-scale operators.

     One frozen convolution (window from ``_create_window(filter_size,
     filter_sigma)``) is created per entry of ``power_factors``; the factors
     themselves are kept as a float32 ``Tensor`` in ``self.weight_tensor``.

     Args:
         max_val: positive int/float stored as ``self.max_val``.
         power_factors: tuple or list; its length sets ``self.level``.
         filter_size: integer >= 1, the convolution window size.
         filter_sigma: positive float passed to ``_create_window``.
         k1: float constant stored as ``self.k1``.
         k2: float constant stored as ``self.k2``.
     """
     super(MSSSIM, self).__init__()
     cls_name = self.cls_name

     # Argument validation, in the same order the checks are reported.
     validator.check_value_type('max_val', max_val, [int, float], cls_name)
     validator.check_number('max_val', max_val, 0.0, Rel.GT, cls_name)
     self.max_val = max_val
     validator.check_value_type('power_factors', power_factors,
                                [tuple, list], cls_name)
     self.filter_size = validator.check_integer('filter_size', filter_size,
                                                1, Rel.GE, cls_name)
     self.filter_sigma = validator.check_float_positive(
         'filter_sigma', filter_sigma, cls_name)
     self.k1 = validator.check_value_type('k1', k1, [float], cls_name)
     self.k2 = validator.check_value_type('k2', k2, [float], cls_name)

     # One convolution per scale, all sharing the same window values.
     window = _create_window(filter_size, filter_sigma)
     self.level = len(power_factors)
     self.conv = []
     for _ in range(self.level):
         scale_conv = _conv2d(1, 1, filter_size, Tensor(window))
         # The window is fixed: its weight must not be trained.
         scale_conv.weight.requires_grad = False
         self.conv.append(scale_conv)
     self.multi_convs_list = CellList(self.conv)
     self.weight_tensor = Tensor(power_factors, mstype.float32)

     # Operators kept on the cell for later use.
     self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid')
     self.relu = ReLU()
     self.reduce_mean = P.ReduceMean()
     self.prod = P.ReduceProd()
     self.pow = P.Pow()
     self.pack = P.Pack(axis=-1)
     self.concat = P.Concat(axis=1)
 def construct(self, x, query_hidden_state, attention_mask, layer_past=None):
     """
     Attention whose queries come from ``query_hidden_state`` while keys and
     values are projected from ``x``.

     Inputs:
         x: output of previous layer; flattened to (-1, last_dim) before the
             key/value projections.
         query_hidden_state: hidden state projected into the queries; flattened
             with the same trailing size as ``x``.
         attention_mask: the attention mask, passed unchanged to ``self._attn``.
         layer_past: stacked (key, value) from a previous call; only read when
             ``self.use_past`` is set (it must then be provided, since the
             default None cannot be indexed).

     Returns:
         output: Tensor, the output of this layer after projection and dropout.
         layer_present: Tensor, the stacked (key, value) feature map of this layer.
     """
     original_shape = F.shape(x)
     # Split-heads target shape. Axis 1 reuses x's dim 1 (presumably
     # seq_length), so query_hidden_state is assumed to share it — TODO confirm.
     heads_shape = (-1, original_shape[1], self.n_head, self.size_per_head)
     x = F.reshape(x, (-1, original_shape[-1]))
     query_hidden_state = F.reshape(query_hidden_state, (-1, original_shape[-1]))
     query = self.dense1(query_hidden_state)
     key = self.dense2(x)
     value = self.dense3(x)
     # query/value -> (-1, n_head, original_shape[1], size_per_head);
     # key -> (-1, n_head, size_per_head, original_shape[1]), presumably so
     # self._attn can multiply query by key directly.
     query = self.transpose(F.reshape(query, heads_shape), (0, 2, 1, 3))
     key = self.transpose(F.reshape(key, heads_shape), (0, 2, 3, 1))
     value = self.transpose(F.reshape(value, heads_shape), (0, 2, 1, 3))
     if self.use_past:
         past_value = layer_past[1]
         # Keys are stored in layer_present with axes 2 and 3 swapped (see the
         # Pack below); swap them back so past and current keys share a layout.
         past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
         key = self.concat_k((past_key, key))
         # Concat ops take a single tuple of tensors, exactly like concat_k
         # above; passing the two tensors positionally was a bug.
         value = self.concat_v((past_value, value))
     layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
     attention = self._attn(query, key, value, attention_mask)
     attention_merge = self.merge_heads(attention)
     output = self.projection(attention_merge)
     output = self.dropout(output)
     return output, layer_present
Example #3
0
    def construct(self, x, attention_mask, layer_past=None):
        """
        self-attention

        Inputs:
            x: output of previous layer; flattened to (-1, last_dim) before the
                query/key/value projections.
            attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
            layer_past: the previous feature map, a stacked (key, value); only
                read when ``self.use_past`` is set (it must then be provided,
                since the default None cannot be indexed).

        Returns:
            output: Tensor, the output logit of this layer
            layer_present: Tensor, the feature map of current layer
        """
        original_shape = F.shape(x)
        # Split-heads target shape; axis 1 reuses x's dim 1 (presumably seq_length).
        heads_shape = (-1, original_shape[1], self.n_head, self.size_per_head)
        x = F.reshape(x, (-1, original_shape[-1]))
        query = self.dense1(x)
        key = self.dense2(x)
        value = self.dense3(x)
        # query/value -> (-1, n_head, original_shape[1], size_per_head);
        # key -> (-1, n_head, size_per_head, original_shape[1]), presumably so
        # self._attn can multiply query by key directly.
        query = self.transpose(F.reshape(query, heads_shape), (0, 2, 1, 3))
        key = self.transpose(F.reshape(key, heads_shape), (0, 2, 3, 1))
        value = self.transpose(F.reshape(value, heads_shape), (0, 2, 1, 3))
        if self.use_past:
            past_value = layer_past[1]
            # Keys are stored in layer_present with axes 2 and 3 swapped (see
            # the Pack below); swap them back so past and current keys match.
            past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
            key = self.concat_k((past_key, key))
            # Concat ops take a single tuple of tensors, exactly like concat_k
            # above; passing the two tensors positionally was a bug.
            value = self.concat_v((past_value, value))
        layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
        attention = self._attn(query, key, value, attention_mask)
        attention_merge = self.merge_heads(attention)
        output = self.projection(attention_merge)
        output = self.dropout(output)
        return output, layer_present
Example #4
0
 def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None, is_parameter=True):
     """Build a sharded Pack + Mul test network.

     Args:
         weight1: value for ``self.weight1``; wrapped as Parameter "w1" only
             when ``is_parameter`` is true, otherwise stored as given.
         weight2: value wrapped as Parameter "w2".
         axis: axis for the ``P.Pack`` primitive.
         strategy1: shard strategy for the Pack primitive.
         strategy2: shard strategy for the Mul primitive.
         is_parameter: whether ``weight1`` is wrapped in a Parameter.
     """
     super(Net, self).__init__()
     self.pack = P.Pack(axis=axis).shard(strategy1)
     self.mul = P.Mul().shard(strategy2)
     self.weight1 = Parameter(weight1, "w1") if is_parameter else weight1
     self.weight2 = Parameter(weight2, "w2")
Example #5
0
 def __init__(self,
              weight1,
              weight2,
              axis=0,
              strategy1=None,
              strategy2=None):
     """Build a sharded Pack + Mul test network with two Parameter weights.

     Args:
         weight1: value wrapped as Parameter "w1".
         weight2: value wrapped as Parameter "w2".
         axis: axis for the ``P.Pack`` primitive.
         strategy1: shard strategy for the Pack primitive.
         strategy2: shard strategy for the Mul primitive.
     """
     super(Net1, self).__init__()
     pack_op = P.Pack(axis=axis)
     mul_op = P.Mul()
     self.pack = pack_op.shard(strategy1)
     self.mul = mul_op.shard(strategy2)
     self.weight1 = Parameter(weight1, "w1")
     self.weight2 = Parameter(weight2, "w2")
Example #6
0
 def __init__(self,
              dense_in_channel,
              dense_out_channel,
              axis=0,
              shape=None,
              strategy=None):
     """Build a Flatten/Dense/Mul/Pack test network with constant weights.

     Args:
         dense_in_channel: input size of the Dense layer.
         dense_out_channel: output size of the Dense layer.
         axis: axis for the ``P.Pack`` primitive.
         shape: shape of the constant tensor ``self.pack_con`` (filled with 0.01).
         strategy: optional shard strategy applied to the Pack primitive.
     """
     super().__init__()
     fill_value = 0.01
     # Every weight, bias and constant entry is initialised to 0.01 (float32).
     weight_np = np.full((dense_out_channel, dense_in_channel), fill_value,
                         dtype=np.float32)
     bias_np = np.full(dense_out_channel, fill_value, dtype=np.float32)
     self.pack_con = Tensor(np.full(shape, fill_value, dtype=np.float32))
     self.flat = Flatten()
     self.dense = Dense(in_channels=dense_in_channel,
                        out_channels=dense_out_channel,
                        weight_init=Tensor(weight_np),
                        bias_init=Tensor(bias_np),
                        has_bias=True)
     self.mul = P.Mul()
     self.pack = P.Pack(axis)
     # Sharding is only configured when a strategy is supplied.
     if strategy is not None:
         self.pack.shard(strategy)
Example #7
0
 def __init__(self, x, axis):
     """Hold a ``P.Pack`` primitive along ``axis`` together with the input ``x``."""
     super(Net, self).__init__()
     pack_op = P.Pack(axis)
     self.pack = pack_op
     self.x = x
Example #8
0
    def __init__(
        self,
        dim_atom_embed,
        num_rbf,
        n_heads=8,
        activation=Swish(),
        max_cycles=10,
        time_embedding=0,
        use_pondering=True,
        fixed_cycles=False,
        use_filter=True,
        inside_filter=None,
        act_threshold=0.9,
        fixed_neigh=False,
    ):
        """Build one AirNet interaction block.

        Args:
            dim_atom_embed: atom embedding size; must be divisible by ``n_heads``.
            num_rbf: number of radial basis inputs to the optional Dense filter.
            n_heads: number of attention heads for the multi-head attention.
            activation: accepted for interface compatibility; not used by this
                constructor.
            max_cycles: maximum number of cycles; must be at least 1.
            time_embedding: stored as ``self.time_embedding``.
            use_pondering: when ``max_cycles > 1``, build the pondering network
                and the ACT weight.
            fixed_cycles: when False, cycles are flexible, which requires a
                pondering network if ``max_cycles > 1``.
            use_filter: build a Dense filter from ``num_rbf`` to ``dim_atom_embed``.
            inside_filter: accepted for interface compatibility; not used here.
            act_threshold: ACT threshold; ``act_epsilon`` is ``1 - act_threshold``.
            fixed_neigh: forwarded to the base class.

        Raises:
            ValueError: if ``dim_atom_embed`` is not divisible by ``n_heads``,
                if ``max_cycles < 1``, or if flexible cycles are requested
                (``max_cycles > 1``) without a pondering network.
        """
        super().__init__(gather_dim=dim_atom_embed, fixed_neigh=fixed_neigh)
        if dim_atom_embed % n_heads != 0:
            raise ValueError('The term "dim_atom_embed" must be divisible ' +
                             'by the term "n_heads" in AirNetInteraction! ')
        # Guard the 1.0 / max_cycles below against division by zero and
        # against silently producing a negative fixed weight.
        if max_cycles < 1:
            raise ValueError('The term "max_cycles" must be at least 1 ' +
                             'in AirNetInteraction! ')

        self.n_heads = n_heads
        self.max_cycles = max_cycles
        self.dim_atom_embed = dim_atom_embed
        self.num_rbf = num_rbf
        self.time_embedding = time_embedding

        # Attribute name keeps its original spelling; other code may read it.
        self.flexable_cycels = not fixed_cycles

        self.use_filter = use_filter
        if self.use_filter:
            self.filter = Dense(num_rbf,
                                dim_atom_embed,
                                has_bias=True,
                                activation=None)

        self.positional_embedding = PositionalEmbedding(dim_atom_embed)
        self.multi_head_attention = MultiheadAttention(dim_atom_embed, n_heads)

        self.act_threshold = act_threshold
        self.act_epsilon = 1.0 - act_threshold

        self.use_pondering = use_pondering
        self.pondering = None
        self.act_weight = None
        if self.max_cycles > 1:
            if self.use_pondering:
                self.pondering = Pondering(dim_atom_embed * 3, bias_const=3)
                self.act_weight = ACTWeight(self.act_threshold)
            elif self.flexable_cycels:
                # Flexible cycle counts need the pondering network.
                raise ValueError(
                    'The term "fixed_cycles" must be True ' +
                    'when the pondering network is None in AirNetInteraction! '
                )
        self.fixed_weight = Tensor(1.0 / max_cycles, ms.float32)

        self.max = P.Maximum()
        self.min = P.Minimum()
        self.concat = P.Concat(-1)
        self.pack = P.Pack()
        self.reducesum = P.ReduceSum()
        self.squeeze = P.Squeeze(-1)
        self.ones_like = P.OnesLike()
        self.zeros_like = P.ZerosLike()
        self.zeros = P.Zeros()
Example #9
0
     'desc_bprop': [[4, 2]]}),
 ('ConcatV2_4', {
     'block': P.Concat(axis=0),
     'desc_inputs': [
         (Tensor(np.ones((3, 2, 3), np.float32)),
          Tensor(np.ones((5, 2, 3), np.float32)),
          Tensor(np.ones((6, 2, 3), np.float32)))],
     'desc_bprop': [[14, 2, 3]]}),
 ('ConcatV2_5', {
     'block': P.Concat(axis=-1),
     'desc_inputs': [(Tensor(np.array([1], np.float32)),
                      Tensor(np.array([1], np.float32)),
                      Tensor(np.array([1], np.float32)))],
     'desc_bprop': [[3,]]}),
 ('Pack_0', {
     'block': NetForPackInput(P.Pack()),
     'desc_inputs':[[2, 2], [2, 2], [2, 2]],
     'desc_bprop':[[3, 2, 2]],
 }),
 ('Pack_1', {
     'block': NetForPackInput(P.Pack(axis=-2)),
     'desc_inputs':[[3, 2, 3], [3, 2, 3], [3, 2, 3]],
     'desc_bprop':[[3, 2, 3, 3]],
 }),
 ('Pack_2', {
     'block': NetForPackInput(P.Pack()),
     'desc_inputs':[[2, 2]],
     'desc_bprop':[[2, 2, 2]],
 }),
 ('Pack_3', {
     'block': NetForPackInput(P.Pack()),
Example #10
0
 def __init__(self):
     """Hold a default-constructed ``P.Pack`` primitive as ``self.pack``."""
     super(PackNet, self).__init__()
     pack_op = P.Pack()
     self.pack = pack_op
Example #11
0
def stack(inputs: List[Tensor], axis: int) -> Tensor:
    """Stack ``inputs`` into one tensor along ``axis`` using a ``Pack`` op.

    Args:
        inputs: tensors to stack.
        axis: axis along which the new dimension is inserted.

    Returns:
        The packed tensor produced by ``op.Pack(axis)``.
    """
    return op.Pack(axis)(inputs)
Example #12
0
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.ops import operations as P
from mindspore.ops import Primitive

# Module-level primitive instances, bound to plain names; presumably referenced
# by graph-construction functions elsewhere in this file (not visible here).
pack = P.Pack()
concat = P.Concat()
# Generic primitive looked up by its registered name 'make_tuple'.
make_tuple = Primitive('make_tuple')


class FnDict:
    """Registry of functions keyed by their ``__name__``.

    An instance can be used as a decorator: ``@fns`` records the decorated
    function and hands it back unchanged, so the decorated name stays usable.
    Registered functions are retrieved with ``fns['name']``.
    """

    def __init__(self):
        # Maps function name -> function object; re-registering a name
        # replaces the earlier entry.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged."""
        self.fnDict[fn.__name__] = fn
        # Returning fn keeps decorator usage from rebinding the name to None.
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name``; raise KeyError if absent."""
        return self.fnDict[name]

Example #13
0
    def __init__(
        self,
        num_atomtypes=100,
        dim_atomembedding=64,
        min_rbf_dis=0.05,
        max_rbf_dis=1,
        num_rbf=32,
        n_interactions=3,
        n_heads=8,
        max_cycles=10,
        activation=Swish(),
        output_dim=1,
        self_dis=None,
        rbf_sigma=None,
        distance_expansion=LogGaussianDistribution,
        cutoff=None,
        cutoff_network=SmoothCutoff,
        public_filter=True,
        coupled_interactions=False,
        trainable_gaussians=False,
        use_pondering=True,
        fixed_cycles=False,
        rescale_rbf=True,
        use_time_embedding=True,
        use_all_interactions=True,
        use_mcr=False,
        debug=False,
    ):
        """Build the AirNet model: embedding/RBF base, interactions and readout.

        Args (the less obvious ones):
            self_dis: self-distance value; defaults to ``min_rbf_dis`` when None.
            public_filter: when true, one ``Filter`` is built here and the
                interactions are created with ``use_filter=False``; otherwise
                each interaction is created with ``use_filter=True``.
            coupled_interactions: when true, the same AirNetInteraction instance
                is repeated ``n_interactions`` times (shared weights); otherwise
                one independent instance is built per interaction.
            use_time_embedding: use ``self._get_time_signal`` for the per-cycle
                embedding instead of a list of ``max_cycles`` zeros.
            use_mcr: gather interaction outputs with
                MultipleChannelRepresentation instead of TensorSum (only when
                all interactions are used and there is more than one).
            debug: when true, ``self.debug_fun`` is bound to ``self._debug_fun``.
            trainable_gaussians: accepted but not used by this constructor.
        """
        super().__init__(
            num_atomtypes=num_atomtypes,
            dim_atomembedding=dim_atomembedding,
            min_rbf_dis=min_rbf_dis,
            max_rbf_dis=max_rbf_dis,
            num_rbf=num_rbf,
            output_dim=output_dim,
            rbf_sigma=rbf_sigma,
            distance_expansion=distance_expansion,
            cutoff=cutoff,
            cutoff_network=cutoff_network,
            rescale_rbf=rescale_rbf,
            use_all_interactions=use_all_interactions,
        )
        self.network_name = 'AirNet'
        self.max_distance = max_rbf_dis
        self.min_distance = min_rbf_dis

        # Fall back to the smallest RBF distance when no self-distance is given.
        self.self_dis = self.min_distance if self_dis is None else self_dis
        self.self_dis_tensor = Tensor([self.self_dis], ms.float32)

        self.n_heads = n_heads

        if use_time_embedding:
            time_embedding = self._get_time_signal(max_cycles, dim_atomembedding)
        else:
            time_embedding = [0] * max_cycles

        # Either one public filter lives here, or each interaction builds its own.
        inter_filter = not public_filter
        self.filter = Filter(num_rbf, dim_atomembedding, None) if public_filter else None

        self.n_interactions = n_interactions

        # Keyword arguments shared by every AirNetInteraction built below.
        interaction_kwargs = dict(
            dim_atom_embed=dim_atomembedding,
            num_rbf=num_rbf,
            n_heads=n_heads,
            activation=activation,
            max_cycles=max_cycles,
            time_embedding=time_embedding,
            use_filter=inter_filter,
            use_pondering=use_pondering,
            fixed_cycles=fixed_cycles,
        )
        if coupled_interactions:
            # A single AirNetInteraction repeated: every layer shares its weights.
            shared_block = AirNetInteraction(**interaction_kwargs)
            interaction_blocks = [shared_block] * n_interactions
        else:
            # An independent AirNetInteraction (own weights) for each layer.
            interaction_blocks = [
                AirNetInteraction(**interaction_kwargs)
                for _ in range(n_interactions)
            ]
        self.interactions = nn.CellList(interaction_blocks)

        # How the per-interaction outputs are combined before the readout.
        if self.use_all_interactions and n_interactions > 1:
            if use_mcr:
                self.gather_interactions = MultipleChannelRepresentation(
                    n_interactions, dim_atomembedding, 1, activation)
            else:
                self.gather_interactions = TensorSum()
        else:
            self.gather_interactions = None

        readoutdim = int(dim_atomembedding / 2)
        self.readout = AtomwiseReadout(dim_atomembedding, self.output_dim,
                                       [readoutdim], activation)

        if debug:
            self.debug_fun = self._debug_fun

        # One label per interaction: 'l0_cycles', 'l1_cycles', ...
        cycle_labels = ['l' + str(i) + '_cycles' for i in range(n_interactions)]
        self.lmax_label = cycle_labels

        self.fill = P.Fill()
        self.concat = P.Concat(-1)
        self.pack = P.Pack(-1)
        self.reducesum = P.ReduceSum()
        self.reducemax = P.ReduceMax()
        self.tensor_summary = P.TensorSummary()
        self.scalar_summary = P.ScalarSummary()
Example #14
0
     'desc_bprop': [[4, 2]]}),
 ('ConcatV2_4', {
     'block': P.Concat(axis=0),
     'desc_inputs': [
         (Tensor(np.ones((3, 2, 3), np.float32)),
          Tensor(np.ones((5, 2, 3), np.float32)),
          Tensor(np.ones((6, 2, 3), np.float32)))],
     'desc_bprop': [[14, 2, 3]]}),
 ('ConcatV2_5', {
     'block': P.Concat(axis=-1),
     'desc_inputs': [(Tensor(np.array([1], np.float32)),
                      Tensor(np.array([1], np.float32)),
                      Tensor(np.array([1], np.float32)))],
     'desc_bprop': [[3, ]]}),
 ('Pack_0', {
     'block': NetForPackInput(P.Pack()),
     'desc_inputs': [[2, 2], [2, 2], [2, 2]],
     'desc_bprop': [[3, 2, 2]],
 }),
 ('Pack_1', {
     'block': NetForPackInput(P.Pack(axis=-2)),
     'desc_inputs': [[3, 2, 3], [3, 2, 3], [3, 2, 3]],
     'desc_bprop': [[3, 2, 3, 3]],
 }),
 ('Pack_2', {
     'block': NetForPackInput(P.Pack()),
     'desc_inputs': [[128, 128], [128, 128]],
     'desc_bprop': [[2, 128, 128]],
 }),
 ('Unpack_0', {
     'block': NetForUnpackInput(P.Unpack(axis=0)),