Example #1
File: conv.py Project: WenjinW/PGL
    def __init__(self, config):
        super(GNNVirt, self).__init__()
        log.info("gnn_type is %s" % self.__class__.__name__)
        self.config = config

        self.atom_encoder = getattr(ME, self.config.atom_enc_type, ME.AtomEncoder)(
                self.config.emb_dim)

        self.virtualnode_embedding = self.create_parameter(
            shape=[1, self.config.emb_dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(value=0.0))

        self.convs = paddle.nn.LayerList()
        self.batch_norms = paddle.nn.LayerList()
        self.mlp_virtualnode_list = paddle.nn.LayerList()

        for layer in range(self.config.num_layers):
            self.convs.append(getattr(L, self.config.layer_type)(self.config))
            self.batch_norms.append(L.batch_norm_1d(self.config.emb_dim))

        for layer in range(self.config.num_layers - 1):
            self.mlp_virtualnode_list.append(
                    nn.Sequential(L.Linear(self.config.emb_dim, self.config.emb_dim), 
                        L.batch_norm_1d(self.config.emb_dim), 
                        nn.Swish(),
                        L.Linear(self.config.emb_dim, self.config.emb_dim), 
                        L.batch_norm_1d(self.config.emb_dim), 
                        nn.Swish())
                    )

        self.pool = gnn.GraphPool(pool_type="sum")
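A hedged sketch of how these pieces typically interact in the forward pass, following the common OGB virtual-node pattern (this is an assumption, not the project's actual `forward`; `graph.graph_node_id`, `graph.num_graph`, and the `GraphPool` call signature are taken from PGL 2.x):

    def forward(self, graph):
        h = self.atom_encoder(graph.node_feat)
        # Every graph in the batch starts from the shared zero virtual-node embedding.
        vn = self.virtualnode_embedding.expand([graph.num_graph, self.config.emb_dim])
        for layer in range(self.config.num_layers):
            # Broadcast each graph's virtual node to its own nodes, then convolve.
            h = h + paddle.gather(vn, graph.graph_node_id)
            h = self.batch_norms[layer](self.convs[layer](graph, h))
            if layer < self.config.num_layers - 1:
                # Refresh the virtual node from the sum-pooled node states.
                vn = self.mlp_virtualnode_list[layer](vn + self.pool(graph, h))
        return h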
Example #2
File: model.py Project: WenjinW/PGL
    def __init__(self, config):
        super(GNN, self).__init__()
        log.info("model_type is %s" % self.__class__.__name__)

        self.config = config
        self.pretrain_tasks = config.pretrain_tasks.split(',')
        self.num_layers = config.num_layers
        self.drop_ratio = config.drop_ratio
        self.JK = config.JK
        self.block_num = config.block_num
        self.emb_dim = config.emb_dim
        self.num_tasks = config.num_tasks
        self.residual = config.residual
        self.graph_pooling = config.graph_pooling

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        self.gnn_blocks = paddle.nn.LayerList()
        for i in range(self.config.block_num):
            self.gnn_blocks.append(getattr(CONV, self.config.gnn_type)(config))

        hidden_size = self.emb_dim * self.block_num
        ### Pooling function to generate whole-graph embeddings
        if self.config.graph_pooling == "bisop":
            pass  # placeholder in the source: the "bisop" branch creates no pooling layer
        else:
            self.pool = MeanGlobalPool()

        if self.config.clf_layers == 3:
            log.info("clf_layers is 3")
            self.graph_pred_linear = nn.Sequential(
                L.Linear(hidden_size, hidden_size // 2),
                L.batch_norm_1d(hidden_size // 2), nn.Swish(),
                L.Linear(hidden_size // 2, hidden_size // 4),
                L.batch_norm_1d(hidden_size // 4), nn.Swish(),
                L.Linear(hidden_size // 4, self.num_tasks))
        elif self.config.clf_layers == 2:
            log.info("clf_layers is 2")
            self.graph_pred_linear = nn.Sequential(
                L.Linear(hidden_size, hidden_size // 2),
                L.batch_norm_1d(hidden_size // 2), nn.Swish(),
                L.Linear(hidden_size // 2, self.num_tasks))
        else:
            self.graph_pred_linear = L.Linear(hidden_size, self.num_tasks)

        if 'Con' in self.pretrain_tasks:
            self.context_loss = nn.CrossEntropyLoss()
            self.contextmlp = nn.Sequential(
                L.Linear(self.emb_dim, self.emb_dim // 2),
                L.batch_norm_1d(self.emb_dim // 2), nn.Swish(),
                L.Linear(self.emb_dim // 2, 5000))
        if 'Ba' in self.pretrain_tasks:
            self.pretrain_bond_angle = PretrainBondAngle(config)
        if 'Bl' in self.pretrain_tasks:
            self.pretrain_bond_length = PretrainBondLength(config)
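An illustrative forward sketch (an assumption pieced together from the fields above; the `MeanGlobalPool` call signature is assumed): each GNN block produces node embeddings, the block outputs are concatenated to `emb_dim * block_num` channels, mean-pooled per graph, and classified:

    def forward(self, graph):
        h = paddle.concat([block(graph) for block in self.gnn_blocks], axis=-1)
        g = self.pool(graph, h)            # mean over each graph's nodes (signature assumed)
        return self.graph_pred_linear(g)   # [num_graphs, num_tasks]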
Example #3
File: conv.py Project: WenjinW/PGL
    def __init__(self, config):
        super(JuncGNNVirt, self).__init__()
        log.info("gnn_type is %s" % self.__class__.__name__)
        self.config = config
        self.num_layers = config.num_layers
        self.drop_ratio = config.drop_ratio
        self.JK = config.JK
        self.residual = config.residual
        self.emb_dim = config.emb_dim
        self.gnn_type = config.gnn_type
        self.layer_type = config.layer_type

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = getattr(ME, self.config.atom_enc_type, ME.AtomEncoder)(
                self.emb_dim)

        self.junc_embed = paddle.nn.Embedding(6000, self.emb_dim)

        ### set the initial virtual node embedding to 0.
        #  self.virtualnode_embedding = paddle.nn.Embedding(1, emb_dim)
        #  torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
        self.virtualnode_embedding = self.create_parameter(
            shape=[1, self.emb_dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(value=0.0))

        ### List of GNNs
        self.convs = nn.LayerList()
        ### batch norms applied to node embeddings
        self.batch_norms = nn.LayerList()

        ### List of MLPs to transform virtual node at every layer
        self.mlp_virtualnode_list = nn.LayerList()

        self.junc_convs = nn.LayerList()

        for layer in range(self.num_layers):
            self.convs.append(getattr(L, self.layer_type)(self.config))
            self.junc_convs.append(gnn.GINConv(self.emb_dim, self.emb_dim))

            self.batch_norms.append(L.batch_norm_1d(self.emb_dim))

        for layer in range(self.num_layers - 1):
            self.mlp_virtualnode_list.append(
                    nn.Sequential(L.Linear(self.emb_dim, self.emb_dim), 
                        L.batch_norm_1d(self.emb_dim), 
                        nn.Swish(),
                        L.Linear(self.emb_dim, self.emb_dim), 
                        L.batch_norm_1d(self.emb_dim), 
                        nn.Swish())
                    )

        self.pool = gnn.GraphPool(pool_type="sum")
Example #4
File: model.py Project: WenjinW/PGL
    def __init__(self, config):
        super(PretrainBondLength, self).__init__()
        log.info("Using pretrain bond length")
        hidden_size = config.emb_dim
        self.bond_length_pred_linear = nn.Sequential(
            L.Linear(hidden_size * 2, hidden_size // 2),
            L.batch_norm_1d(hidden_size // 2), nn.Swish(),
            L.Linear(hidden_size // 2, hidden_size // 4),
            L.batch_norm_1d(hidden_size // 4), nn.Swish(),
            L.Linear(hidden_size // 4, 1))
        self.loss = nn.SmoothL1Loss(reduction='none')
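A hedged sketch of how this head is typically applied (`atom_i_emb`, `atom_j_emb`, and `bond_length` are illustrative names, not from the project): the embeddings of a bond's two endpoint atoms are concatenated, regressed to a scalar, and scored with the element-wise SmoothL1 loss:

    feat = paddle.concat([atom_i_emb, atom_j_emb], axis=-1)   # [num_bonds, 2 * emb_dim]
    pred = self.bond_length_pred_linear(feat)                 # [num_bonds, 1]
    loss = self.loss(pred, bond_length)                       # per-bond losses, reduction='none'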
Example #5
def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1):
    # activation layer
    act = act_type.lower()
    if act == 'relu':
        layer = nn.ReLU()
    elif act == 'leakyrelu':
        layer = nn.LeakyReLU(neg_slope)  # Paddle activations have no in-place mode; 'inplace' is kept in the signature for compatibility but ignored
    elif act == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    elif act == 'swish':
        layer = nn.Swish()
    else:
        raise NotImplementedError('activation layer [%s] is not found' % act)
    return layer
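A minimal usage sketch, assuming Paddle 2.x where `nn.Swish` is available:

    import paddle
    import paddle.nn as nn

    act = act_layer('swish')         # -> nn.Swish()
    y = act(paddle.randn([4, 16]))   # swish(x) = x * sigmoid(x)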
Example #6
    def __init__(self, config, with_efeat=True):
        super(CatGINConv, self).__init__()
        log.info("layer_type is %s" % self.__class__.__name__)
        self.config = config
        emb_dim = self.config.emb_dim

        self.with_efeat = with_efeat

        self.mlp = nn.Sequential(Linear(emb_dim, emb_dim),
                                 batch_norm_1d(emb_dim), nn.Swish(),
                                 Linear(emb_dim, emb_dim))

        self.send_mlp = nn.Sequential(nn.Linear(2 * emb_dim, 2 * emb_dim),
                                      nn.Swish(), Linear(2 * emb_dim, emb_dim))

        self.eps = self.create_parameter(
            shape=[1, 1],
            dtype='float32',
            default_initializer=nn.initializer.Constant(value=0))

        if self.with_efeat:
            self.bond_encoder = getattr(ME, self.config.bond_enc_type,
                                        ME.BondEncoder)(emb_dim=emb_dim)
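The `eps` parameter above follows the GIN update rule; a hedged sketch of the node update it implies (`msg_sum` is an illustrative name, not the project's send/recv code):

    # m_ij = send_mlp(concat(x_j, e_ij)), summed per destination node into msg_sum
    out = self.mlp((1 + self.eps) * x + msg_sum)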
Example #7
    def __init__(self, emb_dim):
        super(CatAtomEncoder, self).__init__()
        log.info("atom encoder type is %s" % self.__class__.__name__)

        self.atom_embedding_list = nn.LayerList()

        for i, dim in enumerate(full_atom_feature_dims):
            weight_attr = nn.initializer.XavierUniform()
            emb = paddle.nn.Embedding(dim, emb_dim, weight_attr=weight_attr)
            self.atom_embedding_list.append(emb)

        self.mlp = nn.Sequential(
            nn.Linear(len(full_atom_feature_dims) * emb_dim, 2 * emb_dim),
            batch_norm_1d(2 * emb_dim), nn.Swish(),
            nn.Linear(2 * emb_dim, emb_dim))
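An illustrative forward sketch (an assumption based on the layers above): each categorical atom feature is embedded separately, the per-feature embeddings are concatenated, and the MLP projects the result back to `emb_dim`:

    xs = [emb(x[:, i]) for i, emb in enumerate(self.atom_embedding_list)]
    out = self.mlp(paddle.concat(xs, axis=-1))   # [num_atoms, emb_dim]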
Example #8
def conv_bn_swish(out,
                  in_channels,
                  channels,
                  kernel=1,
                  stride=1,
                  pad=0,
                  num_group=1):
    out.append(
        nn.Conv2D(in_channels,
                  channels,
                  kernel,
                  stride,
                  pad,
                  groups=num_group,
                  bias_attr=False))
    out.append(nn.BatchNorm2D(channels))
    out.append(nn.Swish())
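A usage sketch: the helper appends in place to a plain Python list, which can then be unpacked into `nn.Sequential`:

    import paddle.nn as nn

    layers = []
    conv_bn_swish(layers, in_channels=3, channels=32, kernel=3, stride=2, pad=1)
    conv_bn_swish(layers, in_channels=32, channels=64, kernel=3, stride=1, pad=1)
    block = nn.Sequential(*layers)   # Conv2D -> BatchNorm2D -> Swish, twice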
Example #9
    def __init__(self,
                 in_channels,
                 out_channels=None,
                 kernel_size=3,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(SeparableConvLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn', None]
        assert act in ['swish', 'relu', None]

        self.in_channels = in_channels
        if out_channels is None:
            self.out_channels = self.in_channels
        else:
            self.out_channels = out_channels
        self.norm_type = norm_type
        self.norm_groups = norm_groups
        self.depthwise_conv = nn.Conv2D(in_channels,
                                        in_channels,
                                        kernel_size,
                                        padding=kernel_size // 2,
                                        groups=in_channels,
                                        bias_attr=False)
        self.pointwise_conv = nn.Conv2D(in_channels, self.out_channels, 1)

        # norm type
        if self.norm_type == 'bn':
            self.norm = nn.BatchNorm2D(self.out_channels)
        elif self.norm_type == 'sync_bn':
            self.norm = nn.SyncBatchNorm(self.out_channels)
        elif self.norm_type == 'gn':
            self.norm = nn.GroupNorm(num_groups=self.norm_groups,
                                     num_channels=self.out_channels)

        # activation
        if act == 'swish':
            self.act = nn.Swish()
        elif act == 'relu':
            self.act = nn.ReLU()
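A forward pass consistent with the layers built above (a sketch under that assumption, not necessarily the project's `forward`): depthwise then pointwise convolution, followed by the optional norm and activation:

    def forward(self, x):
        x = self.pointwise_conv(self.depthwise_conv(x))
        if self.norm_type is not None:
            x = self.norm(x)
        if hasattr(self, 'act'):   # only set when act is 'swish' or 'relu'
            x = self.act(x)
        return x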
Example #10
def get_activation_layer(activation):
    """
    Create activation layer from string/function.

    Parameters:
    ----------
    activation : function, or str, or nn.Module
        Activation function or name of activation function.

    Returns:
    -------
    nn.Module
        Activation layer.
    """
    assert activation is not None
    if isfunction(activation):
        return activation()
    elif isinstance(activation, str):
        if activation == "relu":
            return nn.ReLU()
        elif activation == "relu6":
            return nn.ReLU6()
        elif activation == "swish":
            return nn.Swish()
        elif activation == "hswish":
            return nn.Hardswish()
        elif activation == "sigmoid":
            return nn.Sigmoid()
        elif activation == "hsigmoid":
            return nn.Hardsigmoid()
        elif activation == "identity":
            return Identity()
        else:
            raise NotImplementedError()
    else:
        assert isinstance(activation, nn.Layer)
        return activation
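A minimal usage sketch (the module's own imports, e.g. `from inspect import isfunction`, are assumed):

    act = get_activation_layer("swish")              # string -> nn.Swish()
    act = get_activation_layer(nn.ReLU())            # an nn.Layer instance passes through
    act = get_activation_layer(lambda: nn.Swish())   # a function is called to build the layer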
Example #11
    def __init__(self, config, with_efeat=True):
        super(LiteGEM, self).__init__()
        log.info("gnn_type is %s" % self.__class__.__name__)

        self.config = config
        self.with_efeat = with_efeat
        self.num_layers = config["num_layers"]
        self.drop_ratio = config["dropout_rate"]
        self.virtual_node = config["virtual_node"]
        self.emb_dim = config["emb_dim"]
        self.norm = config["norm"]
        self.num_tasks = config["num_tasks"]

        self.atom_names = config["atom_names"]
        self.atom_float_names = config["atom_float_names"]
        self.bond_names = config["bond_names"]
        self.gnns = paddle.nn.LayerList()
        self.norms = paddle.nn.LayerList()

        if self.virtual_node:
            log.info("using virtual node in %s" % self.__class__.__name__)
            self.mlp_virtualnode_list = paddle.nn.LayerList()

            self.virtualnode_embedding = self.create_parameter(
                shape=[1, self.emb_dim],
                dtype='float32',
                default_initializer=nn.initializer.Constant(value=0.0))

            for layer in range(self.num_layers - 1):
                self.mlp_virtualnode_list.append(
                    MLP([self.emb_dim] * 3, norm=self.norm))

        for layer in range(self.num_layers):
            self.gnns.append(
                LiteGEMConv(config, with_efeat=not self.with_efeat))
            self.norms.append(norm_layer(self.norm, self.emb_dim))

        self.atom_embedding = AtomEmbedding(self.atom_names, self.emb_dim)
        self.atom_float_embedding = AtomFloatEmbedding(self.atom_float_names,
                                                       self.emb_dim)

        if self.with_efeat:
            self.init_bond_embedding = BondEmbedding(self.config["bond_names"],
                                                     self.emb_dim)

        self.pool = gnn.GraphPool(pool_type="sum")

        if not self.config["graphnorm"]:
            self.gn = gnn.GraphNorm()

        hidden_size = self.emb_dim

        if self.config["clf_layers"] == 3:
            log.info("clf_layers is 3")
            self.graph_pred_linear = nn.Sequential(
                Linear(hidden_size, hidden_size // 2),
                batch_norm_1d(hidden_size // 2), nn.Swish(),
                Linear(hidden_size // 2, hidden_size // 4),
                batch_norm_1d(hidden_size // 4), nn.Swish(),
                Linear(hidden_size // 4, self.num_tasks))
        elif self.config["clf_layers"] == 2:
            log.info("clf_layers is 2")
            self.graph_pred_linear = nn.Sequential(
                Linear(hidden_size, hidden_size // 2),
                batch_norm_1d(hidden_size // 2), nn.Swish(),
                Linear(hidden_size // 2, self.num_tasks))
        else:
            self.graph_pred_linear = Linear(hidden_size, self.num_tasks)
Example #12
    def func_test_layer_str(self):
        module = nn.ELU(0.2)
        self.assertEqual(str(module), 'ELU(alpha=0.2)')

        module = nn.CELU(0.2)
        self.assertEqual(str(module), 'CELU(alpha=0.2)')

        module = nn.GELU(True)
        self.assertEqual(str(module), 'GELU(approximate=True)')

        module = nn.Hardshrink()
        self.assertEqual(str(module), 'Hardshrink(threshold=0.5)')

        module = nn.Hardswish(name="Hardswish")
        self.assertEqual(str(module), 'Hardswish(name=Hardswish)')

        module = nn.Tanh(name="Tanh")
        self.assertEqual(str(module), 'Tanh(name=Tanh)')

        module = nn.Hardtanh(name="Hardtanh")
        self.assertEqual(str(module),
                         'Hardtanh(min=-1.0, max=1.0, name=Hardtanh)')

        module = nn.PReLU(1, 0.25, name="PReLU", data_format="NCHW")
        self.assertEqual(
            str(module),
            'PReLU(num_parameters=1, data_format=NCHW, init=0.25, dtype=float32, name=PReLU)'
        )

        module = nn.ReLU()
        self.assertEqual(str(module), 'ReLU()')

        module = nn.ReLU6()
        self.assertEqual(str(module), 'ReLU6()')

        module = nn.SELU()
        self.assertEqual(
            str(module),
            'SELU(scale=1.0507009873554805, alpha=1.6732632423543772)')

        module = nn.LeakyReLU()
        self.assertEqual(str(module), 'LeakyReLU(negative_slope=0.01)')

        module = nn.Sigmoid()
        self.assertEqual(str(module), 'Sigmoid()')

        module = nn.Hardsigmoid()
        self.assertEqual(str(module), 'Hardsigmoid()')

        module = nn.Softplus()
        self.assertEqual(str(module), 'Softplus(beta=1, threshold=20)')

        module = nn.Softshrink()
        self.assertEqual(str(module), 'Softshrink(threshold=0.5)')

        module = nn.Softsign()
        self.assertEqual(str(module), 'Softsign()')

        module = nn.Swish()
        self.assertEqual(str(module), 'Swish()')

        module = nn.Tanhshrink()
        self.assertEqual(str(module), 'Tanhshrink()')

        module = nn.ThresholdedReLU()
        self.assertEqual(str(module), 'ThresholdedReLU(threshold=1.0)')

        module = nn.LogSigmoid()
        self.assertEqual(str(module), 'LogSigmoid()')

        module = nn.Softmax()
        self.assertEqual(str(module), 'Softmax(axis=-1)')

        module = nn.LogSoftmax()
        self.assertEqual(str(module), 'LogSoftmax(axis=-1)')

        module = nn.Maxout(groups=2)
        self.assertEqual(str(module), 'Maxout(groups=2, axis=1)')

        module = nn.Linear(2, 4, name='linear')
        self.assertEqual(
            str(module),
            'Linear(in_features=2, out_features=4, dtype=float32, name=linear)'
        )

        module = nn.Upsample(size=[12, 12])
        self.assertEqual(
            str(module),
            'Upsample(size=[12, 12], mode=nearest, align_corners=False, align_mode=0, data_format=NCHW)'
        )

        module = nn.UpsamplingNearest2D(size=[12, 12])
        self.assertEqual(
            str(module),
            'UpsamplingNearest2D(size=[12, 12], data_format=NCHW)')

        module = nn.UpsamplingBilinear2D(size=[12, 12])
        self.assertEqual(
            str(module),
            'UpsamplingBilinear2D(size=[12, 12], data_format=NCHW)')

        module = nn.Bilinear(in1_features=5, in2_features=4, out_features=1000)
        self.assertEqual(
            str(module),
            'Bilinear(in1_features=5, in2_features=4, out_features=1000, dtype=float32)'
        )

        module = nn.Dropout(p=0.5)
        self.assertEqual(str(module),
                         'Dropout(p=0.5, axis=None, mode=upscale_in_train)')

        module = nn.Dropout2D(p=0.5)
        self.assertEqual(str(module), 'Dropout2D(p=0.5, data_format=NCHW)')

        module = nn.Dropout3D(p=0.5)
        self.assertEqual(str(module), 'Dropout3D(p=0.5, data_format=NCDHW)')

        module = nn.AlphaDropout(p=0.5)
        self.assertEqual(str(module), 'AlphaDropout(p=0.5)')

        module = nn.Pad1D(padding=[1, 2], mode='constant')
        self.assertEqual(
            str(module),
            'Pad1D(padding=[1, 2], mode=constant, value=0.0, data_format=NCL)')

        module = nn.Pad2D(padding=[1, 0, 1, 2], mode='constant')
        self.assertEqual(
            str(module),
            'Pad2D(padding=[1, 0, 1, 2], mode=constant, value=0.0, data_format=NCHW)'
        )

        module = nn.ZeroPad2D(padding=[1, 0, 1, 2])
        self.assertEqual(str(module),
                         'ZeroPad2D(padding=[1, 0, 1, 2], data_format=NCHW)')

        module = nn.Pad3D(padding=[1, 0, 1, 2, 0, 0], mode='constant')
        self.assertEqual(
            str(module),
            'Pad3D(padding=[1, 0, 1, 2, 0, 0], mode=constant, value=0.0, data_format=NCDHW)'
        )

        module = nn.CosineSimilarity(axis=0)
        self.assertEqual(str(module), 'CosineSimilarity(axis=0, eps=1e-08)')

        module = nn.Embedding(10, 3, sparse=True)
        self.assertEqual(str(module), 'Embedding(10, 3, sparse=True)')

        module = nn.Conv1D(3, 2, 3)
        self.assertEqual(str(module),
                         'Conv1D(3, 2, kernel_size=[3], data_format=NCL)')

        module = nn.Conv1DTranspose(2, 1, 2)
        self.assertEqual(
            str(module),
            'Conv1DTranspose(2, 1, kernel_size=[2], data_format=NCL)')

        module = nn.Conv2D(4, 6, (3, 3))
        self.assertEqual(str(module),
                         'Conv2D(4, 6, kernel_size=[3, 3], data_format=NCHW)')

        module = nn.Conv2DTranspose(4, 6, (3, 3))
        self.assertEqual(
            str(module),
            'Conv2DTranspose(4, 6, kernel_size=[3, 3], data_format=NCHW)')

        module = nn.Conv3D(4, 6, (3, 3, 3))
        self.assertEqual(
            str(module),
            'Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)')

        module = nn.Conv3DTranspose(4, 6, (3, 3, 3))
        self.assertEqual(
            str(module),
            'Conv3DTranspose(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)')

        module = nn.PairwiseDistance()
        self.assertEqual(str(module), 'PairwiseDistance(p=2.0)')

        module = nn.InstanceNorm1D(2)
        self.assertEqual(str(module),
                         'InstanceNorm1D(num_features=2, epsilon=1e-05)')

        module = nn.InstanceNorm2D(2)
        self.assertEqual(str(module),
                         'InstanceNorm2D(num_features=2, epsilon=1e-05)')

        module = nn.InstanceNorm3D(2)
        self.assertEqual(str(module),
                         'InstanceNorm3D(num_features=2, epsilon=1e-05)')

        module = nn.GroupNorm(num_channels=6, num_groups=6)
        self.assertEqual(
            str(module),
            'GroupNorm(num_groups=6, num_channels=6, epsilon=1e-05)')

        module = nn.LayerNorm([2, 2, 3])
        self.assertEqual(
            str(module),
            'LayerNorm(normalized_shape=[2, 2, 3], epsilon=1e-05)')

        module = nn.BatchNorm1D(1)
        self.assertEqual(
            str(module),
            'BatchNorm1D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCL)'
        )

        module = nn.BatchNorm2D(1)
        self.assertEqual(
            str(module),
            'BatchNorm2D(num_features=1, momentum=0.9, epsilon=1e-05)')

        module = nn.BatchNorm3D(1)
        self.assertEqual(
            str(module),
            'BatchNorm3D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCDHW)'
        )

        module = nn.SyncBatchNorm(2)
        self.assertEqual(
            str(module),
            'SyncBatchNorm(num_features=2, momentum=0.9, epsilon=1e-05)')

        module = nn.LocalResponseNorm(size=5)
        self.assertEqual(
            str(module),
            'LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=1.0)')

        module = nn.AvgPool1D(kernel_size=2, stride=2, padding=0)
        self.assertEqual(str(module),
                         'AvgPool1D(kernel_size=2, stride=2, padding=0)')

        module = nn.AvgPool2D(kernel_size=2, stride=2, padding=0)
        self.assertEqual(str(module),
                         'AvgPool2D(kernel_size=2, stride=2, padding=0)')

        module = nn.AvgPool3D(kernel_size=2, stride=2, padding=0)
        self.assertEqual(str(module),
                         'AvgPool3D(kernel_size=2, stride=2, padding=0)')

        module = nn.MaxPool1D(kernel_size=2, stride=2, padding=0)
        self.assertEqual(str(module),
                         'MaxPool1D(kernel_size=2, stride=2, padding=0)')

        module = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.assertEqual(str(module),
                         'MaxPool2D(kernel_size=2, stride=2, padding=0)')

        module = nn.MaxPool3D(kernel_size=2, stride=2, padding=0)
        self.assertEqual(str(module),
                         'MaxPool3D(kernel_size=2, stride=2, padding=0)')

        module = nn.AdaptiveAvgPool1D(output_size=16)
        self.assertEqual(str(module), 'AdaptiveAvgPool1D(output_size=16)')

        module = nn.AdaptiveAvgPool2D(output_size=3)
        self.assertEqual(str(module), 'AdaptiveAvgPool2D(output_size=3)')

        module = nn.AdaptiveAvgPool3D(output_size=3)
        self.assertEqual(str(module), 'AdaptiveAvgPool3D(output_size=3)')

        module = nn.AdaptiveMaxPool1D(output_size=16, return_mask=True)
        self.assertEqual(
            str(module), 'AdaptiveMaxPool1D(output_size=16, return_mask=True)')

        module = nn.AdaptiveMaxPool2D(output_size=3, return_mask=True)
        self.assertEqual(str(module),
                         'AdaptiveMaxPool2D(output_size=3, return_mask=True)')

        module = nn.AdaptiveMaxPool3D(output_size=3, return_mask=True)
        self.assertEqual(str(module),
                         'AdaptiveMaxPool3D(output_size=3, return_mask=True)')

        module = nn.SimpleRNNCell(16, 32)
        self.assertEqual(str(module), 'SimpleRNNCell(16, 32)')

        module = nn.LSTMCell(16, 32)
        self.assertEqual(str(module), 'LSTMCell(16, 32)')

        module = nn.GRUCell(16, 32)
        self.assertEqual(str(module), 'GRUCell(16, 32)')

        module = nn.PixelShuffle(3)
        self.assertEqual(str(module), 'PixelShuffle(upscale_factor=3)')

        module = nn.SimpleRNN(16, 32, 2)
        self.assertEqual(
            str(module),
            'SimpleRNN(16, 32, num_layers=2\n  (0): RNN(\n    (cell): SimpleRNNCell(16, 32)\n  )\n  (1): RNN(\n    (cell): SimpleRNNCell(32, 32)\n  )\n)'
        )

        module = nn.LSTM(16, 32, 2)
        self.assertEqual(
            str(module),
            'LSTM(16, 32, num_layers=2\n  (0): RNN(\n    (cell): LSTMCell(16, 32)\n  )\n  (1): RNN(\n    (cell): LSTMCell(32, 32)\n  )\n)'
        )

        module = nn.GRU(16, 32, 2)
        self.assertEqual(
            str(module),
            'GRU(16, 32, num_layers=2\n  (0): RNN(\n    (cell): GRUCell(16, 32)\n  )\n  (1): RNN(\n    (cell): GRUCell(32, 32)\n  )\n)'
        )

        module1 = nn.Sequential(
            ('conv1', nn.Conv2D(1, 20, 5)), ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2D(20, 64, 5)), ('relu2', nn.ReLU()))
        self.assertEqual(
            str(module1),
            'Sequential(\n  '\
            '(conv1): Conv2D(1, 20, kernel_size=[5, 5], data_format=NCHW)\n  '\
            '(relu1): ReLU()\n  '\
            '(conv2): Conv2D(20, 64, kernel_size=[5, 5], data_format=NCHW)\n  '\
            '(relu2): ReLU()\n)'
        )

        module2 = nn.Sequential(
            nn.Conv3DTranspose(4, 6, (3, 3, 3)),
            nn.AvgPool3D(kernel_size=2, stride=2, padding=0),
            nn.Tanh(name="Tanh"), module1, nn.Conv3D(4, 6, (3, 3, 3)),
            nn.MaxPool3D(kernel_size=2, stride=2, padding=0), nn.GELU(True))
        self.assertEqual(
            str(module2),
            'Sequential(\n  '\
            '(0): Conv3DTranspose(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)\n  '\
            '(1): AvgPool3D(kernel_size=2, stride=2, padding=0)\n  '\
            '(2): Tanh(name=Tanh)\n  '\
            '(3): Sequential(\n    (conv1): Conv2D(1, 20, kernel_size=[5, 5], data_format=NCHW)\n    (relu1): ReLU()\n'\
            '    (conv2): Conv2D(20, 64, kernel_size=[5, 5], data_format=NCHW)\n    (relu2): ReLU()\n  )\n  '\
            '(4): Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)\n  '\
            '(5): MaxPool3D(kernel_size=2, stride=2, padding=0)\n  '\
            '(6): GELU(approximate=True)\n)'
        )
Example #13
    def set_swish(self, memory_efficient=True):
        """Sets the swish function as memory efficient (for training) or standard (for export)"""
        self._swish = nn.Hardswish() if memory_efficient else nn.Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)
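Typical usage (a sketch): switch the whole model back to the standard Swish before export:

    model.set_swish(memory_efficient=False)   # every block now uses nn.Swish()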