Example #1

import torch

import torch_geometric.graphgym.register as register


def test_register():
    assert len(register.act_dict) == 8
    assert list(register.act_dict.keys()) == [
        'relu', 'selu', 'prelu', 'elu', 'lrelu_01', 'lrelu_025', 'lrelu_05',
        'identity'
    ]
    assert str(register.act_dict['relu']) == 'ReLU()'

    register.register_act('lrelu_03', torch.nn.LeakyReLU(0.3))
    assert len(register.act_dict) == 9
    assert 'lrelu_03' in register.act_dict
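
A minimal sketch of the mechanism the test above exercises, assuming only that act_dict is a plain dict; the library's actual register_act also supports decorator-style use and may differ in detail:

act_dict = {}


def register_act(key, module=None):
    # Record an activation (a module instance or a factory) under a string
    # key; refuse duplicates so a typo cannot silently replace an entry.
    if key in act_dict:
        raise KeyError(f"Activation '{key}' is already registered")
    act_dict[key] = module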
Example #2
import torch
import torch.nn as nn

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act


class SWISH(nn.Module):
    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        # Swish: x * sigmoid(x). The in-place variant reuses x's storage to
        # save memory; sigmoid(x) is evaluated before the in-place multiply.
        if self.inplace:
            x.mul_(torch.sigmoid(x))
            return x
        else:
            return x * torch.sigmoid(x)


register_act('swish', SWISH(inplace=cfg.mem.inplace))
register_act('lrelu_03', nn.LeakyReLU(0.3, inplace=cfg.mem.inplace))
Example #3
import torch.nn as nn

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act

register_act('relu', nn.ReLU(inplace=cfg.mem.inplace))
register_act('selu', nn.SELU(inplace=cfg.mem.inplace))
register_act('prelu', nn.PReLU())
register_act('elu', nn.ELU(inplace=cfg.mem.inplace))
register_act('lrelu_01', nn.LeakyReLU(0.1, inplace=cfg.mem.inplace))
register_act('lrelu_025', nn.LeakyReLU(0.25, inplace=cfg.mem.inplace))
register_act('lrelu_05', nn.LeakyReLU(0.5, inplace=cfg.mem.inplace))
Example #4

import torch.nn as nn

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act


def relu():
    return nn.ReLU(inplace=cfg.mem.inplace)


def selu():
    return nn.SELU(inplace=cfg.mem.inplace)

def prelu():
    return nn.PReLU()


def elu():
    return nn.ELU(inplace=cfg.mem.inplace)


def lrelu_01():
    return nn.LeakyReLU(0.1, inplace=cfg.mem.inplace)


def lrelu_025():
    return nn.LeakyReLU(0.25, inplace=cfg.mem.inplace)


def lrelu_05():
    return nn.LeakyReLU(0.5, inplace=cfg.mem.inplace)


if cfg is not None:
    register_act('relu', relu)
    register_act('selu', selu)
    register_act('prelu', prelu)
    register_act('elu', elu)
    register_act('lrelu_01', lrelu_01)
    register_act('lrelu_025', lrelu_025)
    register_act('lrelu_05', lrelu_05)
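
With this factory style, looking an activation up and calling it yields a fresh module each time. A small usage sketch; act_dict is the registry mapping in torch_geometric.graphgym.register, while the helper name new_act is hypothetical:

from torch_geometric.graphgym.register import act_dict


def new_act(name):
    # Build a new activation instance from its registered factory. This only
    # works for entries registered as callables, not as module instances.
    return act_dict[name]()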
Example #5
import torch
import torch.nn as nn

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act


class SWISH(nn.Module):
    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            x.mul_(torch.sigmoid(x))
            return x
        else:
            return x * torch.sigmoid(x)


register_act('swish', SWISH(inplace=cfg.mem.inplace))

register_act('lrelu_03',
             nn.LeakyReLU(negative_slope=0.3, inplace=cfg.mem.inplace))
Example #6
from functools import partial

import torch
import torch.nn as nn

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act


class SWISH(nn.Module):
    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            x.mul_(torch.sigmoid(x))
            return x
        else:
            return x * torch.sigmoid(x)


register_act('swish', partial(SWISH, inplace=cfg.mem.inplace))
register_act('lrelu_03', partial(nn.LeakyReLU, 0.3, inplace=cfg.mem.inplace))
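
Registering the class through partial, rather than an instance as in Examples #2 and #5, means every consumer can construct its own module. That matters for activations with learnable parameters such as PReLU, where a shared instance would share weights. A minimal illustration, independent of the registry:

from functools import partial

import torch.nn as nn

# Instance style: one module object, shared by everyone who looks it up.
shared = nn.PReLU()

# Factory style: each call builds a module with its own parameters.
factory = partial(nn.PReLU)
a, b = factory(), factory()
assert a is not b  # independent instances, hence independent weights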