def test_register():
    """Check the built-in activation registry and adding a custom entry.

    The test temporarily registers ``'lrelu_03'`` into the global
    ``register.act_dict`` and removes it again afterwards, so the registry
    size asserted at the top of this test still holds on a re-run and for
    any later test that inspects the registry.
    """
    assert len(register.act_dict) == 8
    assert list(register.act_dict.keys()) == [
        'relu', 'selu', 'prelu', 'elu', 'lrelu_01', 'lrelu_025', 'lrelu_05',
        'identity'
    ]

    relu = register.act_dict['relu']
    # Registry entries may be ready-built modules or zero-argument factories
    # that build one; normalize to a module before comparing its repr.
    if not isinstance(relu, torch.nn.Module):
        relu = relu()
    assert str(relu) == 'ReLU()'

    register.register_act('lrelu_03', torch.nn.LeakyReLU(0.3))
    try:
        assert len(register.act_dict) == 9
        assert 'lrelu_03' in register.act_dict
    finally:
        # Undo the registration so the shared global registry is left exactly
        # as this test found it.
        register.act_dict.pop('lrelu_03', None)
import torch
import torch.nn as nn

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act


class SWISH(nn.Module):
    r"""Swish activation function, computing :math:`x \cdot \sigma(x)`.

    Args:
        inplace (bool, optional): When :obj:`True`, the input tensor is
            overwritten with the result and that same tensor is returned;
            otherwise a new tensor is produced. (default: :obj:`False`)
    """
    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        gate = torch.sigmoid(x)
        if not self.inplace:
            return x * gate
        # ``mul_`` writes into ``x`` and returns that same tensor object.
        return x.mul_(gate)


# Each name is bound to one ready-built module instance.
# NOTE: ``cfg.mem.inplace`` is evaluated here, at import time; changing the
# config after this module is imported does not affect these instances.
register_act('swish', SWISH(inplace=cfg.mem.inplace))
register_act('lrelu_03', nn.LeakyReLU(0.3, inplace=cfg.mem.inplace))
import torch.nn as nn

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act


def _register_builtin_acts():
    """Register the built-in activation modules under their GraphGym names.

    Each name is bound to ONE module instance built here, so every lookup of
    a name yields that same object. ``cfg.mem.inplace`` is read when this
    runs (at import time); later config changes do not reach these modules.
    Registration order is significant: it is the key order of the registry.
    """
    inplace = cfg.mem.inplace
    entries = (
        ('relu', nn.ReLU(inplace=inplace)),
        ('selu', nn.SELU(inplace=inplace)),
        # NOTE(review): nn.PReLU has a learnable weight and takes no in-place
        # flag; this single instance is what every 'prelu' lookup returns, so
        # anything reusing it also shares that weight.
        ('prelu', nn.PReLU()),
        ('elu', nn.ELU(inplace=inplace)),
        ('lrelu_01', nn.LeakyReLU(0.1, inplace=inplace)),
        ('lrelu_025', nn.LeakyReLU(0.25, inplace=inplace)),
        ('lrelu_05', nn.LeakyReLU(0.5, inplace=inplace)),
    )
    for name, module in entries:
        register_act(name, module)


_register_builtin_acts()
def prelu():
    """Build a fresh ``nn.PReLU`` module (PReLU has no in-place option)."""
    return nn.PReLU()


def elu():
    """Build a fresh ``nn.ELU``; ``cfg.mem.inplace`` is read at call time."""
    return nn.ELU(inplace=cfg.mem.inplace)


def lrelu_01():
    """Build a fresh ``nn.LeakyReLU`` with negative slope 0.1."""
    return nn.LeakyReLU(negative_slope=0.1, inplace=cfg.mem.inplace)


def lrelu_025():
    """Build a fresh ``nn.LeakyReLU`` with negative slope 0.25."""
    return nn.LeakyReLU(negative_slope=0.25, inplace=cfg.mem.inplace)


def lrelu_05():
    """Build a fresh ``nn.LeakyReLU`` with negative slope 0.5."""
    return nn.LeakyReLU(negative_slope=0.5, inplace=cfg.mem.inplace)


# Register the factory functions themselves (not module instances), in this
# fixed order. Nothing is registered when ``cfg`` is None.
# ``relu`` and ``selu`` are defined elsewhere in this module.
if cfg is not None:
    for _key, _factory in (
        ('relu', relu),
        ('selu', selu),
        ('prelu', prelu),
        ('elu', elu),
        ('lrelu_01', lrelu_01),
        ('lrelu_025', lrelu_025),
        ('lrelu_05', lrelu_05),
    ):
        register_act(_key, _factory)
    # Keep the loop variables out of the module namespace.
    del _key, _factory
import torch
import torch.nn as nn

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act


class SWISH(nn.Module):
    r"""Swish (a.k.a. SiLU-style) gating: returns :math:`x \cdot \sigma(x)`.

    Args:
        inplace (bool, optional): If :obj:`True`, the result is written back
            into the input tensor, which is then returned.
            (default: :obj:`False`)
    """
    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            # ``mul_`` returns ``x`` itself after updating it.
            return x.mul_(torch.sigmoid(x))
        return x * torch.sigmoid(x)


# One module instance per name, built now. ``cfg.mem.inplace`` is read at
# import time, so later config changes do not affect these instances.
register_act('swish', SWISH(inplace=cfg.mem.inplace))
register_act('lrelu_03',
             nn.LeakyReLU(negative_slope=0.3, inplace=cfg.mem.inplace))
from functools import partial

import torch
import torch.nn as nn

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act


class SWISH(nn.Module):
    r"""Swish activation function, computing :math:`x \cdot \sigma(x)`.

    Args:
        inplace (bool, optional): If set to :obj:`True`, the input tensor is
            overwritten with the result and returned. (default: :obj:`False`)
    """
    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            x.mul_(torch.sigmoid(x))
            return x
        else:
            return x * torch.sigmoid(x)


def swish(inplace=None):
    """Build a new :class:`SWISH` module.

    ``cfg.mem.inplace`` is read on every call rather than once at import
    time, so config values set after this module is imported are honored.

    Args:
        inplace (bool, optional): Overrides ``cfg.mem.inplace`` when given.
    """
    if inplace is None:
        inplace = cfg.mem.inplace
    return SWISH(inplace=inplace)


def lrelu_03(inplace=None):
    """Build a new ``nn.LeakyReLU`` with negative slope 0.3.

    ``cfg.mem.inplace`` is read on every call rather than once at import
    time, so config values set after this module is imported are honored.

    Args:
        inplace (bool, optional): Overrides ``cfg.mem.inplace`` when given.
    """
    if inplace is None:
        inplace = cfg.mem.inplace
    return nn.LeakyReLU(0.3, inplace=inplace)


# Register factories (not instances): each call of a registered entry builds
# a fresh module. Unlike ``partial(..., inplace=cfg.mem.inplace)``, which
# froze the config value at import, these read the config lazily.
register_act('swish', swish)
register_act('lrelu_03', lrelu_03)