def __init__(self, config):
    super(GNNVirt, self).__init__()
    log.info("gnn_type is %s" % self.__class__.__name__)
    self.config = config

    self.atom_encoder = getattr(ME, self.config.atom_enc_type, ME.AtomEncoder)(
        self.config.emb_dim)

    self.virtualnode_embedding = self.create_parameter(
        shape=[1, self.config.emb_dim],
        dtype='float32',
        default_initializer=nn.initializer.Constant(value=0.0))

    self.convs = paddle.nn.LayerList()
    self.batch_norms = paddle.nn.LayerList()
    self.mlp_virtualnode_list = paddle.nn.LayerList()

    for layer in range(self.config.num_layers):
        self.convs.append(getattr(L, self.config.layer_type)(self.config))
        self.batch_norms.append(L.batch_norm_1d(self.config.emb_dim))

    for layer in range(self.config.num_layers - 1):
        self.mlp_virtualnode_list.append(
            nn.Sequential(
                L.Linear(self.config.emb_dim, self.config.emb_dim),
                L.batch_norm_1d(self.config.emb_dim),
                nn.Swish(),
                L.Linear(self.config.emb_dim, self.config.emb_dim),
                L.batch_norm_1d(self.config.emb_dim),
                nn.Swish()))

    self.pool = gnn.GraphPool(pool_type="sum")
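# A minimal construction sketch (not from the source): the namespace below
# carries only the attributes GNNVirt.__init__ reads; every value is an
# assumption for illustration.
from types import SimpleNamespace

gnn_virt_config = SimpleNamespace(
    atom_enc_type="AtomEncoder",   # looked up on ME via getattr
    emb_dim=256,
    num_layers=5,
    layer_type="CatGINConv")       # looked up on L via getattr
# model = GNNVirt(gnn_virt_config)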
def __init__(self, config):
    super(GNN, self).__init__()
    log.info("model_type is %s" % self.__class__.__name__)
    self.config = config
    self.pretrain_tasks = config.pretrain_tasks.split(',')
    self.num_layers = config.num_layers
    self.drop_ratio = config.drop_ratio
    self.JK = config.JK
    self.block_num = config.block_num
    self.emb_dim = config.emb_dim
    self.num_tasks = config.num_tasks
    self.residual = config.residual
    self.graph_pooling = config.graph_pooling

    if self.num_layers < 2:
        raise ValueError("Number of GNN layers must be greater than 1.")

    ### GNN to generate node embeddings
    self.gnn_blocks = paddle.nn.LayerList()
    for i in range(self.config.block_num):
        self.gnn_blocks.append(getattr(CONV, self.config.gnn_type)(config))

    hidden_size = self.emb_dim * self.block_num

    ### Pooling function to generate whole-graph embeddings
    if self.config.graph_pooling != "bisop":
        self.pool = MeanGlobalPool()

    if self.config.clf_layers == 3:
        log.info("clf_layers is 3")
        self.graph_pred_linear = nn.Sequential(
            L.Linear(hidden_size, hidden_size // 2),
            L.batch_norm_1d(hidden_size // 2),
            nn.Swish(),
            L.Linear(hidden_size // 2, hidden_size // 4),
            L.batch_norm_1d(hidden_size // 4),
            nn.Swish(),
            L.Linear(hidden_size // 4, self.num_tasks))
    elif self.config.clf_layers == 2:
        log.info("clf_layers is 2")
        self.graph_pred_linear = nn.Sequential(
            L.Linear(hidden_size, hidden_size // 2),
            L.batch_norm_1d(hidden_size // 2),
            nn.Swish(),
            L.Linear(hidden_size // 2, self.num_tasks))
    else:
        self.graph_pred_linear = L.Linear(hidden_size, self.num_tasks)

    if 'Con' in self.pretrain_tasks:
        self.context_loss = nn.CrossEntropyLoss()
        self.contextmlp = nn.Sequential(
            L.Linear(self.emb_dim, self.emb_dim // 2),
            L.batch_norm_1d(self.emb_dim // 2),
            nn.Swish(),
            L.Linear(self.emb_dim // 2, 5000))
    if 'Ba' in self.pretrain_tasks:
        self.pretrain_bond_angle = PretrainBondAngle(config)
    if 'Bl' in self.pretrain_tasks:
        self.pretrain_bond_length = PretrainBondLength(config)
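# A minimal sketch (assumed value) of how pretrain_tasks gates the heads
# above: the string is split on ',' and each head is enabled by membership.
pretrain_tasks = "Con,Ba,Bl".split(',')  # -> ['Con', 'Ba', 'Bl']
assert 'Con' in pretrain_tasks           # context head gets built
assert 'Bl' in pretrain_tasks            # bond-length head gets built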
def __init__(self, config):
    super(JuncGNNVirt, self).__init__()
    log.info("gnn_type is %s" % self.__class__.__name__)
    self.config = config
    self.num_layers = config.num_layers
    self.drop_ratio = config.drop_ratio
    self.JK = config.JK
    self.residual = config.residual
    self.emb_dim = config.emb_dim
    self.gnn_type = config.gnn_type
    self.layer_type = config.layer_type

    if self.num_layers < 2:
        raise ValueError("Number of GNN layers must be greater than 1.")

    self.atom_encoder = getattr(ME, self.config.atom_enc_type, ME.AtomEncoder)(
        self.emb_dim)
    self.junc_embed = paddle.nn.Embedding(6000, self.emb_dim)

    ### Set the initial virtual node embedding to 0.
    self.virtualnode_embedding = self.create_parameter(
        shape=[1, self.emb_dim],
        dtype='float32',
        default_initializer=nn.initializer.Constant(value=0.0))

    ### List of GNNs
    self.convs = nn.LayerList()
    ### Batch norms applied to node embeddings
    self.batch_norms = nn.LayerList()
    ### List of MLPs to transform the virtual node at every layer
    self.mlp_virtualnode_list = nn.LayerList()
    self.junc_convs = nn.LayerList()

    for layer in range(self.num_layers):
        self.convs.append(getattr(L, self.layer_type)(self.config))
        self.junc_convs.append(gnn.GINConv(self.emb_dim, self.emb_dim))
        self.batch_norms.append(L.batch_norm_1d(self.emb_dim))

    for layer in range(self.num_layers - 1):
        self.mlp_virtualnode_list.append(
            nn.Sequential(
                L.Linear(self.emb_dim, self.emb_dim),
                L.batch_norm_1d(self.emb_dim),
                nn.Swish(),
                L.Linear(self.emb_dim, self.emb_dim),
                L.batch_norm_1d(self.emb_dim),
                nn.Swish()))

    self.pool = gnn.GraphPool(pool_type="sum")
def __init__(self, config):
    super(PretrainBondLength, self).__init__()
    log.info("Using pretrain bond length")
    hidden_size = config.emb_dim
    self.bond_length_pred_linear = nn.Sequential(
        L.Linear(hidden_size * 2, hidden_size // 2),
        L.batch_norm_1d(hidden_size // 2),
        nn.Swish(),
        L.Linear(hidden_size // 2, hidden_size // 4),
        L.batch_norm_1d(hidden_size // 4),
        nn.Swish(),
        L.Linear(hidden_size // 4, 1))
    self.loss = nn.SmoothL1Loss(reduction='none')
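# A minimal sketch (not from the source) of how this head would plausibly be
# applied: the hidden_size * 2 input width suggests concatenating the
# embeddings of a bond's two endpoint atoms, regressing a scalar length, and
# scoring it with the elementwise SmoothL1 loss. The variable names
# (node_repr, bond_index, bond_length) are assumptions.
def bond_length_loss(head, node_repr, bond_index, bond_length):
    # node_repr: [num_nodes, emb_dim]; bond_index: [num_bonds, 2]
    h_i = paddle.gather(node_repr, bond_index[:, 0])
    h_j = paddle.gather(node_repr, bond_index[:, 1])
    pred = head.bond_length_pred_linear(paddle.concat([h_i, h_j], axis=1))
    return head.loss(pred, bond_length)  # [num_bonds, 1], reduction='none'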
def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1):
    """Create an activation layer from its (case-insensitive) name."""
    act = act_type.lower()
    if act == 'relu':
        layer = nn.ReLU()
    elif act == 'leakyrelu':
        # Paddle activations have no in-place variant; `inplace` is kept only
        # for signature compatibility and is ignored here.
        layer = nn.LeakyReLU(neg_slope)
    elif act == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    elif act == 'swish':
        layer = nn.Swish()
    else:
        raise NotImplementedError('activation layer [%s] is not found' % act)
    return layer
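# Example usage (a sketch): names are matched case-insensitively.
swish = act_layer('Swish')                    # -> nn.Swish()
leaky = act_layer('leakyrelu', neg_slope=0.1)
prelu = act_layer('prelu', n_prelu=4)         # 4 learnable slopes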
def __init__(self, config, with_efeat=True):
    super(CatGINConv, self).__init__()
    log.info("layer_type is %s" % self.__class__.__name__)
    self.config = config
    emb_dim = self.config.emb_dim
    self.with_efeat = with_efeat

    self.mlp = nn.Sequential(
        Linear(emb_dim, emb_dim),
        batch_norm_1d(emb_dim),
        nn.Swish(),
        Linear(emb_dim, emb_dim))

    self.send_mlp = nn.Sequential(
        nn.Linear(2 * emb_dim, 2 * emb_dim),
        nn.Swish(),
        Linear(2 * emb_dim, emb_dim))

    self.eps = self.create_parameter(
        shape=[1, 1],
        dtype='float32',
        default_initializer=nn.initializer.Constant(value=0))

    if self.with_efeat:
        self.bond_encoder = getattr(ME, self.config.bond_enc_type,
                                    ME.BondEncoder)(emb_dim=emb_dim)
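# For reference, the eps parameter above corresponds to the learnable epsilon
# of the standard GIN update (a sketch; this layer's forward is not shown):
#   h_v' = MLP((1 + eps) * h_v + aggregate({message(h_u, e_uv) : u in N(v)}))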
def __init__(self, emb_dim):
    super(CatAtomEncoder, self).__init__()
    log.info("atom encoder type is %s" % self.__class__.__name__)
    self.atom_embedding_list = nn.LayerList()
    for i, dim in enumerate(full_atom_feature_dims):
        weight_attr = nn.initializer.XavierUniform()
        emb = paddle.nn.Embedding(dim, emb_dim, weight_attr=weight_attr)
        self.atom_embedding_list.append(emb)

    self.mlp = nn.Sequential(
        nn.Linear(len(full_atom_feature_dims) * emb_dim, 2 * emb_dim),
        batch_norm_1d(2 * emb_dim),
        nn.Swish(),
        nn.Linear(2 * emb_dim, emb_dim))
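# A minimal forward sketch (assumed; the source's forward is not shown):
# embed each categorical atom-feature column, concatenate, and project back
# to emb_dim through the MLP, matching the input width declared above.
def cat_atom_encode(encoder, x):
    # x: int tensor of shape [num_atoms, len(full_atom_feature_dims)]
    embs = [emb(x[:, i]) for i, emb in enumerate(encoder.atom_embedding_list)]
    return encoder.mlp(paddle.concat(embs, axis=1))  # [num_atoms, emb_dim]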
def conv_bn_swish(out,
                  in_channels,
                  channels,
                  kernel=1,
                  stride=1,
                  pad=0,
                  num_group=1):
    """Append a Conv2D + BatchNorm2D + Swish block to the layer list `out`."""
    out.append(
        nn.Conv2D(in_channels,
                  channels,
                  kernel,
                  stride,
                  pad,
                  groups=num_group,
                  bias_attr=False))
    out.append(nn.BatchNorm2D(channels))
    out.append(nn.Swish())
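# Example usage (a sketch): accumulate blocks in a plain list, then wrap it.
layers = []
conv_bn_swish(layers, in_channels=3, channels=32, kernel=3, stride=2, pad=1)
conv_bn_swish(layers, in_channels=32, channels=64)
stem = nn.Sequential(*layers)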
def __init__(self,
             in_channels,
             out_channels=None,
             kernel_size=3,
             norm_type='bn',
             norm_groups=32,
             act='swish'):
    super(SeparableConvLayer, self).__init__()
    assert norm_type in ['bn', 'sync_bn', 'gn', None]
    assert act in ['swish', 'relu', None]
    self.in_channels = in_channels
    # Default to in_channels when out_channels is not given; the original
    # left self.out_channels undefined whenever out_channels was passed.
    self.out_channels = out_channels if out_channels is not None else in_channels
    self.norm_type = norm_type
    self.norm_groups = norm_groups

    self.depthwise_conv = nn.Conv2D(in_channels,
                                    in_channels,
                                    kernel_size,
                                    padding=kernel_size // 2,
                                    groups=in_channels,
                                    bias_attr=False)
    self.pointwise_conv = nn.Conv2D(in_channels, self.out_channels, 1)

    # norm type
    if self.norm_type == 'bn':
        self.norm = nn.BatchNorm2D(self.out_channels)
    elif self.norm_type == 'sync_bn':
        self.norm = nn.SyncBatchNorm(self.out_channels)
    elif self.norm_type == 'gn':
        self.norm = nn.GroupNorm(num_groups=self.norm_groups,
                                 num_channels=self.out_channels)

    # activation
    if act == 'swish':
        self.act = nn.Swish()
    elif act == 'relu':
        self.act = nn.ReLU()
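# Example construction sketch (assumed shapes): a depthwise 3x3 conv over 64
# channels followed by a 1x1 pointwise projection to 128, then BN + Swish.
sep = SeparableConvLayer(64, out_channels=128, kernel_size=3)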
def get_activation_layer(activation):
    """
    Create activation layer from string/function.

    Parameters
    ----------
    activation : function, or str, or nn.Module
        Activation function or name of activation function.

    Returns
    -------
    nn.Module
        Activation layer.
    """
    assert activation is not None
    if isfunction(activation):
        return activation()
    elif isinstance(activation, str):
        if activation == "relu":
            return nn.ReLU()
        elif activation == "relu6":
            return nn.ReLU6()
        elif activation == "swish":
            return nn.Swish()
        elif activation == "hswish":
            return nn.Hardswish()
        elif activation == "sigmoid":
            return nn.Sigmoid()
        elif activation == "hsigmoid":
            return nn.Hardsigmoid()
        elif activation == "identity":
            return Identity()
        else:
            raise NotImplementedError()
    else:
        assert isinstance(activation, nn.Layer)
        return activation
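# Example usage (a sketch; assumes `from inspect import isfunction` is in
# scope, as the function requires): all three accepted input forms.
act1 = get_activation_layer("swish")             # by name
act2 = get_activation_layer(nn.ReLU())           # a layer, passed through
act3 = get_activation_layer(lambda: nn.Swish())  # a factory function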
def __init__(self, config, with_efeat=True):
    super(LiteGEM, self).__init__()
    log.info("gnn_type is %s" % self.__class__.__name__)
    self.config = config
    self.with_efeat = with_efeat
    self.num_layers = config["num_layers"]
    self.drop_ratio = config["dropout_rate"]
    self.virtual_node = config["virtual_node"]
    self.emb_dim = config["emb_dim"]
    self.norm = config["norm"]
    self.num_tasks = config["num_tasks"]
    self.atom_names = config["atom_names"]
    self.atom_float_names = config["atom_float_names"]
    self.bond_names = config["bond_names"]

    self.gnns = paddle.nn.LayerList()
    self.norms = paddle.nn.LayerList()

    if self.virtual_node:
        log.info("using virtual node in %s" % self.__class__.__name__)
        self.mlp_virtualnode_list = paddle.nn.LayerList()
        self.virtualnode_embedding = self.create_parameter(
            shape=[1, self.emb_dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(value=0.0))
        for layer in range(self.num_layers - 1):
            self.mlp_virtualnode_list.append(
                MLP([self.emb_dim] * 3, norm=self.norm))

    for layer in range(self.num_layers):
        # When bond features are embedded once at the input (with_efeat=True),
        # the conv layers skip their own bond encoders, hence the inversion.
        self.gnns.append(LiteGEMConv(config, with_efeat=not self.with_efeat))
        self.norms.append(norm_layer(self.norm, self.emb_dim))

    self.atom_embedding = AtomEmbedding(self.atom_names, self.emb_dim)
    self.atom_float_embedding = AtomFloatEmbedding(self.atom_float_names,
                                                   self.emb_dim)

    if self.with_efeat:
        self.init_bond_embedding = BondEmbedding(self.config["bond_names"],
                                                 self.emb_dim)

    self.pool = gnn.GraphPool(pool_type="sum")

    if self.config["graphnorm"]:
        self.gn = gnn.GraphNorm()

    hidden_size = self.emb_dim
    if self.config["clf_layers"] == 3:
        log.info("clf_layers is 3")
        self.graph_pred_linear = nn.Sequential(
            Linear(hidden_size, hidden_size // 2),
            batch_norm_1d(hidden_size // 2),
            nn.Swish(),
            Linear(hidden_size // 2, hidden_size // 4),
            batch_norm_1d(hidden_size // 4),
            nn.Swish(),
            Linear(hidden_size // 4, self.num_tasks))
    elif self.config["clf_layers"] == 2:
        log.info("clf_layers is 2")
        self.graph_pred_linear = nn.Sequential(
            Linear(hidden_size, hidden_size // 2),
            batch_norm_1d(hidden_size // 2),
            nn.Swish(),
            Linear(hidden_size // 2, self.num_tasks))
    else:
        self.graph_pred_linear = Linear(hidden_size, self.num_tasks)
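# A minimal construction sketch (not from the source): this dict carries only
# the keys LiteGEM.__init__ reads; every value below is an assumption.
lite_gem_config = {
    "num_layers": 5,
    "dropout_rate": 0.2,
    "virtual_node": True,
    "emb_dim": 256,
    "norm": "batch",
    "num_tasks": 1,
    "atom_names": ["atomic_num", "chiral_tag"],
    "atom_float_names": ["mass"],
    "bond_names": ["bond_dir", "bond_type"],
    "graphnorm": True,
    "clf_layers": 3,
}
# model = LiteGEM(lite_gem_config)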
def func_test_layer_str(self):
    module = nn.ELU(0.2)
    self.assertEqual(str(module), 'ELU(alpha=0.2)')

    module = nn.CELU(0.2)
    self.assertEqual(str(module), 'CELU(alpha=0.2)')

    module = nn.GELU(True)
    self.assertEqual(str(module), 'GELU(approximate=True)')

    module = nn.Hardshrink()
    self.assertEqual(str(module), 'Hardshrink(threshold=0.5)')

    module = nn.Hardswish(name="Hardswish")
    self.assertEqual(str(module), 'Hardswish(name=Hardswish)')

    module = nn.Tanh(name="Tanh")
    self.assertEqual(str(module), 'Tanh(name=Tanh)')

    module = nn.Hardtanh(name="Hardtanh")
    self.assertEqual(str(module),
                     'Hardtanh(min=-1.0, max=1.0, name=Hardtanh)')

    module = nn.PReLU(1, 0.25, name="PReLU", data_format="NCHW")
    self.assertEqual(
        str(module),
        'PReLU(num_parameters=1, data_format=NCHW, init=0.25, dtype=float32, name=PReLU)'
    )

    module = nn.ReLU()
    self.assertEqual(str(module), 'ReLU()')

    module = nn.ReLU6()
    self.assertEqual(str(module), 'ReLU6()')

    module = nn.SELU()
    self.assertEqual(
        str(module),
        'SELU(scale=1.0507009873554805, alpha=1.6732632423543772)')

    module = nn.LeakyReLU()
    self.assertEqual(str(module), 'LeakyReLU(negative_slope=0.01)')

    module = nn.Sigmoid()
    self.assertEqual(str(module), 'Sigmoid()')

    module = nn.Hardsigmoid()
    self.assertEqual(str(module), 'Hardsigmoid()')

    module = nn.Softplus()
    self.assertEqual(str(module), 'Softplus(beta=1, threshold=20)')

    module = nn.Softshrink()
    self.assertEqual(str(module), 'Softshrink(threshold=0.5)')

    module = nn.Softsign()
    self.assertEqual(str(module), 'Softsign()')

    module = nn.Swish()
    self.assertEqual(str(module), 'Swish()')

    module = nn.Tanhshrink()
    self.assertEqual(str(module), 'Tanhshrink()')

    module = nn.ThresholdedReLU()
    self.assertEqual(str(module), 'ThresholdedReLU(threshold=1.0)')

    module = nn.LogSigmoid()
    self.assertEqual(str(module), 'LogSigmoid()')

    module = nn.Softmax()
    self.assertEqual(str(module), 'Softmax(axis=-1)')

    module = nn.LogSoftmax()
    self.assertEqual(str(module), 'LogSoftmax(axis=-1)')

    module = nn.Maxout(groups=2)
    self.assertEqual(str(module), 'Maxout(groups=2, axis=1)')

    module = nn.Linear(2, 4, name='linear')
    self.assertEqual(
        str(module),
        'Linear(in_features=2, out_features=4, dtype=float32, name=linear)')

    module = nn.Upsample(size=[12, 12])
    self.assertEqual(
        str(module),
        'Upsample(size=[12, 12], mode=nearest, align_corners=False, align_mode=0, data_format=NCHW)'
    )

    module = nn.UpsamplingNearest2D(size=[12, 12])
    self.assertEqual(
        str(module), 'UpsamplingNearest2D(size=[12, 12], data_format=NCHW)')

    module = nn.UpsamplingBilinear2D(size=[12, 12])
    self.assertEqual(
        str(module), 'UpsamplingBilinear2D(size=[12, 12], data_format=NCHW)')

    module = nn.Bilinear(in1_features=5, in2_features=4, out_features=1000)
    self.assertEqual(
        str(module),
        'Bilinear(in1_features=5, in2_features=4, out_features=1000, dtype=float32)'
    )

    module = nn.Dropout(p=0.5)
    self.assertEqual(str(module),
                     'Dropout(p=0.5, axis=None, mode=upscale_in_train)')

    module = nn.Dropout2D(p=0.5)
    self.assertEqual(str(module), 'Dropout2D(p=0.5, data_format=NCHW)')

    module = nn.Dropout3D(p=0.5)
    self.assertEqual(str(module), 'Dropout3D(p=0.5, data_format=NCDHW)')

    module = nn.AlphaDropout(p=0.5)
    self.assertEqual(str(module), 'AlphaDropout(p=0.5)')

    module = nn.Pad1D(padding=[1, 2], mode='constant')
    self.assertEqual(
        str(module),
        'Pad1D(padding=[1, 2], mode=constant, value=0.0, data_format=NCL)')

    module = nn.Pad2D(padding=[1, 0, 1, 2], mode='constant')
    self.assertEqual(
        str(module),
        'Pad2D(padding=[1, 0, 1, 2], mode=constant, value=0.0, data_format=NCHW)'
    )

    module = nn.ZeroPad2D(padding=[1, 0, 1, 2])
    self.assertEqual(str(module),
                     'ZeroPad2D(padding=[1, 0, 1, 2], data_format=NCHW)')

    module = nn.Pad3D(padding=[1, 0, 1, 2, 0, 0], mode='constant')
    self.assertEqual(
        str(module),
        'Pad3D(padding=[1, 0, 1, 2, 0, 0], mode=constant, value=0.0, data_format=NCDHW)'
    )

    module = nn.CosineSimilarity(axis=0)
    self.assertEqual(str(module), 'CosineSimilarity(axis=0, eps=1e-08)')

    module = nn.Embedding(10, 3, sparse=True)
    self.assertEqual(str(module), 'Embedding(10, 3, sparse=True)')

    module = nn.Conv1D(3, 2, 3)
    self.assertEqual(str(module),
                     'Conv1D(3, 2, kernel_size=[3], data_format=NCL)')

    module = nn.Conv1DTranspose(2, 1, 2)
    self.assertEqual(
        str(module), 'Conv1DTranspose(2, 1, kernel_size=[2], data_format=NCL)')

    module = nn.Conv2D(4, 6, (3, 3))
    self.assertEqual(str(module),
                     'Conv2D(4, 6, kernel_size=[3, 3], data_format=NCHW)')

    module = nn.Conv2DTranspose(4, 6, (3, 3))
    self.assertEqual(
        str(module),
        'Conv2DTranspose(4, 6, kernel_size=[3, 3], data_format=NCHW)')

    module = nn.Conv3D(4, 6, (3, 3, 3))
    self.assertEqual(
        str(module), 'Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)')

    module = nn.Conv3DTranspose(4, 6, (3, 3, 3))
    self.assertEqual(
        str(module),
        'Conv3DTranspose(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)')

    module = nn.PairwiseDistance()
    self.assertEqual(str(module), 'PairwiseDistance(p=2.0)')

    module = nn.InstanceNorm1D(2)
    self.assertEqual(str(module),
                     'InstanceNorm1D(num_features=2, epsilon=1e-05)')

    module = nn.InstanceNorm2D(2)
    self.assertEqual(str(module),
                     'InstanceNorm2D(num_features=2, epsilon=1e-05)')

    module = nn.InstanceNorm3D(2)
    self.assertEqual(str(module),
                     'InstanceNorm3D(num_features=2, epsilon=1e-05)')

    module = nn.GroupNorm(num_channels=6, num_groups=6)
    self.assertEqual(
        str(module), 'GroupNorm(num_groups=6, num_channels=6, epsilon=1e-05)')

    module = nn.LayerNorm([2, 2, 3])
    self.assertEqual(
        str(module), 'LayerNorm(normalized_shape=[2, 2, 3], epsilon=1e-05)')

    module = nn.BatchNorm1D(1)
    self.assertEqual(
        str(module),
        'BatchNorm1D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCL)'
    )

    module = nn.BatchNorm2D(1)
    self.assertEqual(
        str(module),
        'BatchNorm2D(num_features=1, momentum=0.9, epsilon=1e-05)')

    module = nn.BatchNorm3D(1)
    self.assertEqual(
        str(module),
        'BatchNorm3D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCDHW)'
    )

    module = nn.SyncBatchNorm(2)
    self.assertEqual(
        str(module),
        'SyncBatchNorm(num_features=2, momentum=0.9, epsilon=1e-05)')

    module = nn.LocalResponseNorm(size=5)
    self.assertEqual(
        str(module),
        'LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=1.0)')

    module = nn.AvgPool1D(kernel_size=2, stride=2, padding=0)
    self.assertEqual(str(module),
                     'AvgPool1D(kernel_size=2, stride=2, padding=0)')

    module = nn.AvgPool2D(kernel_size=2, stride=2, padding=0)
    self.assertEqual(str(module),
                     'AvgPool2D(kernel_size=2, stride=2, padding=0)')

    module = nn.AvgPool3D(kernel_size=2, stride=2, padding=0)
    self.assertEqual(str(module),
                     'AvgPool3D(kernel_size=2, stride=2, padding=0)')

    module = nn.MaxPool1D(kernel_size=2, stride=2, padding=0)
    self.assertEqual(str(module),
                     'MaxPool1D(kernel_size=2, stride=2, padding=0)')

    module = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
    self.assertEqual(str(module),
                     'MaxPool2D(kernel_size=2, stride=2, padding=0)')

    module = nn.MaxPool3D(kernel_size=2, stride=2, padding=0)
    self.assertEqual(str(module),
                     'MaxPool3D(kernel_size=2, stride=2, padding=0)')

    module = nn.AdaptiveAvgPool1D(output_size=16)
    self.assertEqual(str(module), 'AdaptiveAvgPool1D(output_size=16)')

    module = nn.AdaptiveAvgPool2D(output_size=3)
    self.assertEqual(str(module), 'AdaptiveAvgPool2D(output_size=3)')

    module = nn.AdaptiveAvgPool3D(output_size=3)
    self.assertEqual(str(module), 'AdaptiveAvgPool3D(output_size=3)')

    module = nn.AdaptiveMaxPool1D(output_size=16, return_mask=True)
    self.assertEqual(
        str(module), 'AdaptiveMaxPool1D(output_size=16, return_mask=True)')

    module = nn.AdaptiveMaxPool2D(output_size=3, return_mask=True)
    self.assertEqual(str(module),
                     'AdaptiveMaxPool2D(output_size=3, return_mask=True)')

    module = nn.AdaptiveMaxPool3D(output_size=3, return_mask=True)
    self.assertEqual(str(module),
                     'AdaptiveMaxPool3D(output_size=3, return_mask=True)')

    module = nn.SimpleRNNCell(16, 32)
    self.assertEqual(str(module), 'SimpleRNNCell(16, 32)')

    module = nn.LSTMCell(16, 32)
    self.assertEqual(str(module), 'LSTMCell(16, 32)')

    module = nn.GRUCell(16, 32)
    self.assertEqual(str(module), 'GRUCell(16, 32)')

    module = nn.PixelShuffle(3)
    self.assertEqual(str(module), 'PixelShuffle(upscale_factor=3)')

    module = nn.SimpleRNN(16, 32, 2)
    self.assertEqual(
        str(module),
        'SimpleRNN(16, 32, num_layers=2\n (0): RNN(\n (cell): SimpleRNNCell(16, 32)\n )\n (1): RNN(\n (cell): SimpleRNNCell(32, 32)\n )\n)'
    )

    module = nn.LSTM(16, 32, 2)
    self.assertEqual(
        str(module),
        'LSTM(16, 32, num_layers=2\n (0): RNN(\n (cell): LSTMCell(16, 32)\n )\n (1): RNN(\n (cell): LSTMCell(32, 32)\n )\n)'
    )

    module = nn.GRU(16, 32, 2)
    self.assertEqual(
        str(module),
        'GRU(16, 32, num_layers=2\n (0): RNN(\n (cell): GRUCell(16, 32)\n )\n (1): RNN(\n (cell): GRUCell(32, 32)\n )\n)'
    )

    module1 = nn.Sequential(('conv1', nn.Conv2D(1, 20, 5)),
                            ('relu1', nn.ReLU()),
                            ('conv2', nn.Conv2D(20, 64, 5)),
                            ('relu2', nn.ReLU()))
    self.assertEqual(
        str(module1),
        'Sequential(\n '
        '(conv1): Conv2D(1, 20, kernel_size=[5, 5], data_format=NCHW)\n '
        '(relu1): ReLU()\n '
        '(conv2): Conv2D(20, 64, kernel_size=[5, 5], data_format=NCHW)\n '
        '(relu2): ReLU()\n)')

    module2 = nn.Sequential(
        nn.Conv3DTranspose(4, 6, (3, 3, 3)),
        nn.AvgPool3D(kernel_size=2, stride=2, padding=0),
        nn.Tanh(name="Tanh"), module1, nn.Conv3D(4, 6, (3, 3, 3)),
        nn.MaxPool3D(kernel_size=2, stride=2, padding=0), nn.GELU(True))
    self.assertEqual(
        str(module2),
        'Sequential(\n '
        '(0): Conv3DTranspose(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)\n '
        '(1): AvgPool3D(kernel_size=2, stride=2, padding=0)\n '
        '(2): Tanh(name=Tanh)\n '
        '(3): Sequential(\n (conv1): Conv2D(1, 20, kernel_size=[5, 5], data_format=NCHW)\n (relu1): ReLU()\n'
        ' (conv2): Conv2D(20, 64, kernel_size=[5, 5], data_format=NCHW)\n (relu2): ReLU()\n )\n '
        '(4): Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)\n '
        '(5): MaxPool3D(kernel_size=2, stride=2, padding=0)\n '
        '(6): GELU(approximate=True)\n)')
def set_swish(self, memory_efficient=True):
    """Sets swish function as memory efficient (for training) or standard (for export)."""
    self._swish = nn.Hardswish() if memory_efficient else nn.Swish()
    for block in self._blocks:
        block.set_swish(memory_efficient)
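# Usage sketch (assumes an EfficientNet-style `model` exposing set_swish):
# model.set_swish()                        # training configuration
# model.set_swish(memory_efficient=False)  # plain nn.Swish() for export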