Ejemplo n.º 1
0
 def __init__(self, tag_tuple=None, scalar=1):
     """Create the summary operators and the add operator, and keep the options.

     Args:
         tag_tuple: Stored unchanged on ``self.tag_tuple`` (default ``None``).
         scalar: Stored unchanged on ``self.scalar`` (default ``1``).
     """
     super().__init__()
     # One operator per summary kind: scalar, image and tensor.
     self.summary_s, self.summary_i, self.summary_t = (
         P.ScalarSummary(),
         P.ImageSummary(),
         P.TensorSummary(),
     )
     self.add = P.TensorAdd()
     # Caller-supplied options are only recorded here.
     self.tag_tuple, self.scalar = tag_tuple, scalar
Ejemplo n.º 2
0
 def __init__(self, num_class=10, channel=1):
     """Create the layers, helper cells and summary operators.

     Args:
         num_class: Stored on ``self.num_class`` and passed as the second
             argument when building ``fc3`` (default ``10``).
         channel: Passed as the first argument when building ``conv1`` and
             wrapped in a ``Tensor`` stored on ``self.channel`` (default ``1``).
     """
     super().__init__()
     self.num_class = num_class

     # Convolution layers, built by the file's ``conv`` helper.
     self.conv1 = conv(channel, 6, 5)
     self.conv2 = conv(6, 16, 5)

     # Fully connected layers, built by the file's ``fc_with_initialize``
     # helper; the last one is sized by ``num_class``.
     flattened_size = 16 * 5 * 5
     self.fc1 = fc_with_initialize(flattened_size, 120)
     self.fc2 = fc_with_initialize(120, 84)
     self.fc3 = fc_with_initialize(84, num_class)

     # Activation, pooling and flatten cells.
     self.relu = nn.ReLU()
     self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
     self.flatten = nn.Flatten()

     # Summary operators: scalar, image, histogram and tensor.
     (
         self.scalar_summary,
         self.image_summary,
         self.histogram_summary,
         self.tensor_summary,
     ) = (
         P.ScalarSummary(),
         P.ImageSummary(),
         P.HistogramSummary(),
         P.TensorSummary(),
     )
     self.channel = Tensor(channel)
Ejemplo n.º 3
0
    def __init__(self, num_class=10, num_channel=1, include_top=True):
        """Create the convolution layers, the optional dense head and summary operators.

        Args:
            num_class: Output size of ``fc3``; only used when ``include_top``
                is true (default ``10``).
            num_channel: Input channel count of ``conv1``; also wrapped in a
                ``Tensor`` stored on ``self.channel`` (default ``1``).
            include_top: Stored on ``self.include_top``. When true, the
                flatten cell and the three dense layers are created; when
                false, those attributes are not set here.
        """
        super().__init__()

        def dense(n_in, n_out):
            # Every dense layer gets its own Normal(0.02) weight initializer.
            return nn.Dense(n_in, n_out, weight_init=Normal(0.02))

        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

        self.include_top = include_top
        if include_top:
            self.flatten = nn.Flatten()
            self.fc1 = dense(16 * 5 * 5, 120)
            self.fc2 = dense(120, 84)
            self.fc3 = dense(84, num_class)

        # Summary operators: scalar, image and tensor.
        self.scalar_summary, self.image_summary, self.tensor_summary = (
            P.ScalarSummary(),
            P.ImageSummary(),
            P.TensorSummary(),
        )
        self.channel = Tensor(num_channel)
Ejemplo n.º 4
0
 def __init__(self):
     """Create the tensor summary operator (``self.s``) and the add operator."""
     super().__init__()
     self.s, self.add = P.TensorSummary(), P.Add()
Ejemplo n.º 5
0
 def __init__(self):
     """Create one summary operator each for scalar, image, tensor and histogram."""
     super().__init__()
     (
         self.scalar_summary,
         self.image_summary,
         self.tensor_summary,
         self.histogram_summary,
     ) = (
         P.ScalarSummary(),
         P.ImageSummary(),
         P.TensorSummary(),
         P.HistogramSummary(),
     )
Ejemplo n.º 6
0
    def __init__(
        self,
        num_atomtypes=100,
        dim_atomembedding=64,
        min_rbf_dis=0.05,
        max_rbf_dis=1,
        num_rbf=32,
        n_interactions=3,
        n_heads=8,
        max_cycles=10,
        activation=Swish(),
        output_dim=1,
        self_dis=None,
        rbf_sigma=None,
        distance_expansion=LogGaussianDistribution,
        cutoff=None,
        cutoff_network=SmoothCutoff,
        public_filter=True,
        coupled_interactions=False,
        trainable_gaussians=False,
        use_pondering=True,
        fixed_cycles=False,
        rescale_rbf=True,
        use_time_embedding=True,
        use_all_interactions=True,
        use_mcr=False,
        debug=False,
    ):
        """Set up the AirNet network: filter, interaction blocks, readout and operators.

        Args:
            num_atomtypes: Forwarded to the base-class initializer.
            dim_atomembedding: Forwarded to the base-class initializer; also
                used to size the filter, the interaction blocks, the optional
                ``MultipleChannelRepresentation`` and the readout.
            min_rbf_dis: Forwarded to the base-class initializer and stored on
                ``self.min_distance``.
            max_rbf_dis: Forwarded to the base-class initializer and stored on
                ``self.max_distance``.
            num_rbf: Forwarded to the base-class initializer; also passed to
                the filter and to every interaction block.
            n_interactions: Number of entries placed in ``self.interactions``.
            n_heads: Stored on ``self.n_heads`` and passed to every
                interaction block.
            max_cycles: Passed to every interaction block; also the length of
                the zero-filled time embedding used when
                ``use_time_embedding`` is false.
            activation: Passed to the interaction blocks, to the optional
                ``MultipleChannelRepresentation`` and to the readout. The
                default ``Swish()`` is created once, when the function is
                defined, so callers relying on the default share one instance.
            output_dim: Forwarded to the base-class initializer.
            self_dis: Stored on ``self.self_dis``; when ``None``,
                ``self.min_distance`` is used instead.
            rbf_sigma: Forwarded to the base-class initializer.
            distance_expansion: Forwarded to the base-class initializer.
            cutoff: Forwarded to the base-class initializer.
            cutoff_network: Forwarded to the base-class initializer.
            public_filter: When true, a single ``Filter`` is stored on
                ``self.filter`` and interaction blocks are built with
                ``use_filter=False``; otherwise ``self.filter`` is ``None``
                and blocks are built with ``use_filter=True``.
            coupled_interactions: When true, one interaction block instance is
                repeated ``n_interactions`` times; otherwise a separate
                instance is created for each position.
            trainable_gaussians: Not referenced in this initializer.
            use_pondering: Passed to every interaction block.
            fixed_cycles: Passed to every interaction block.
            rescale_rbf: Forwarded to the base-class initializer.
            use_time_embedding: When true, the time embedding comes from
                ``self._get_time_signal(max_cycles, dim_atomembedding)``;
                otherwise it is a list of ``max_cycles`` zeros.
            use_all_interactions: Forwarded to the base-class initializer.
            use_mcr: Chooses ``MultipleChannelRepresentation`` over
                ``TensorSum`` for ``self.gather_interactions`` (only relevant
                when ``self.use_all_interactions`` is true and
                ``n_interactions > 1``).
            debug: When true, ``self.debug_fun`` is bound to
                ``self._debug_fun``.
        """
        super().__init__(
            num_atomtypes=num_atomtypes,
            dim_atomembedding=dim_atomembedding,
            min_rbf_dis=min_rbf_dis,
            max_rbf_dis=max_rbf_dis,
            num_rbf=num_rbf,
            output_dim=output_dim,
            rbf_sigma=rbf_sigma,
            distance_expansion=distance_expansion,
            cutoff=cutoff,
            cutoff_network=cutoff_network,
            rescale_rbf=rescale_rbf,
            use_all_interactions=use_all_interactions,
        )
        self.network_name = 'AirNet'
        self.max_distance = max_rbf_dis
        self.min_distance = min_rbf_dis

        # Fall back to the minimum distance when no explicit value is given.
        self.self_dis = self.min_distance if self_dis is None else self_dis
        self.self_dis_tensor = Tensor([self.self_dis], ms.float32)

        self.n_heads = n_heads

        time_embedding = (
            self._get_time_signal(max_cycles, dim_atomembedding)
            if use_time_embedding
            else [0 for _ in range(max_cycles)]
        )

        # Either one Filter lives on this object, or each interaction block
        # is asked to use its own (use_filter=True).
        inter_filter = not public_filter
        self.filter = (
            Filter(num_rbf, dim_atomembedding, None) if public_filter else None
        )

        self.n_interactions = n_interactions

        def new_interaction():
            # Every interaction block is created with the same configuration.
            return AirNetInteraction(
                dim_atom_embed=dim_atomembedding,
                num_rbf=num_rbf,
                n_heads=n_heads,
                activation=activation,
                max_cycles=max_cycles,
                time_embedding=time_embedding,
                use_filter=inter_filter,
                use_pondering=use_pondering,
                fixed_cycles=fixed_cycles,
            )

        # block for computing interaction
        if coupled_interactions:
            # repeat one AirNetInteraction instance (hence the same weights)
            interaction_blocks = [new_interaction()] * n_interactions
        else:
            # create one AirNetInteraction instance for each interaction
            interaction_blocks = [
                new_interaction() for _ in range(n_interactions)
            ]
        self.interactions = nn.CellList(interaction_blocks)

        # readout layer
        gather = None
        if self.use_all_interactions and n_interactions > 1:
            if use_mcr:
                gather = MultipleChannelRepresentation(
                    n_interactions, dim_atomembedding, 1, activation)
            else:
                gather = TensorSum()
        self.gather_interactions = gather

        readoutdim = int(dim_atomembedding / 2)
        self.readout = AtomwiseReadout(
            dim_atomembedding, self.output_dim, [readoutdim], activation)

        if debug:
            self.debug_fun = self._debug_fun

        # One label per interaction: 'l0_cycles', 'l1_cycles', ...
        self.lmax_label = [
            'l' + str(index) + '_cycles' for index in range(n_interactions)
        ]

        self.fill = P.Fill()
        self.concat, self.pack = P.Concat(-1), P.Pack(-1)
        self.reducesum, self.reducemax = P.ReduceSum(), P.ReduceMax()
        self.tensor_summary, self.scalar_summary = (
            P.TensorSummary(),
            P.ScalarSummary(),
        )