コード例 #1
0
def get_activation(fn_name, only_layer=False):
    """Resolve an activation specifier into an activation callable or layer.

    Args:
        fn_name: ``None``, a string name, or an object that is already a
            function, a class, or a ``Layer`` instance.
        only_layer (bool): When True, name lookups are restricted to
            ``trident.layers.pytorch_activations`` and string names are
            always resolved as classes and instantiated.

    Returns:
        * ``None`` when ``fn_name`` is ``None``, when nothing matches, or
          when any exception is raised during resolution (the exception is
          printed, not propagated).
        * For strings: ``PRelu()`` for ``'p_relu'``/``'prelu'``, the result
          of ``get_function`` for other snake_case / lowercase names (only
          when ``only_layer`` is False), otherwise an instance of the class
          returned by ``get_class``.
        * For objects whose ``__module__`` is
          ``trident.layers.pytorch_activations``: ``partial(fn)`` for
          functions, a new instance for plain classes, the object itself
          for ``Layer`` instances.
        * For functions from ``trident.backend.pytorch_ops``: the function
          itself, or (with ``only_layer``) an instance of the class named
          ``snake2camel(fn.__name__)``.
        * For any other callable taking one or two positional arguments:
          the callable itself if it is a function, else the result of
          calling it with no arguments.

    Example:
        ``get_activation('swish')``
    """
    if fn_name is None:
        return None

    layers_module = 'trident.layers.pytorch_activations'
    ops_module = 'trident.backend.pytorch_ops'
    if only_layer:
        class_modules = [layers_module]
        function_modules = [layers_module]
    else:
        class_modules = [layers_module, ops_module, 'torch.nn.functional']
        function_modules = [layers_module, ops_module]

    try:
        if isinstance(fn_name, str):
            # A name that is unchanged by camel2snake (or is all lowercase)
            # is treated as a function name rather than a class name.
            if not only_layer and (camel2snake(fn_name) == fn_name or fn_name.lower() == fn_name):
                if fn_name in ('p_relu', 'prelu'):
                    return PRelu()
                return get_function(fn_name, function_modules)
            # Class lookup: try the CamelCase form first, then the raw name.
            try:
                return get_class(snake2camel(fn_name), class_modules)()
            except Exception:
                return get_class(fn_name, class_modules)()

        origin = getattr(fn_name, '__module__', None)

        if origin == layers_module:
            if inspect.isfunction(fn_name):
                return partial(fn_name)
            if inspect.isclass(fn_name) and fn_name.__class__.__name__ == "type":
                return fn_name()
            if isinstance(fn_name, Layer):
                return fn_name
            return None

        if inspect.isfunction(fn_name) and origin == ops_module:
            if not only_layer:
                return fn_name
            layer_cls = get_class(snake2camel(fn_name.__name__), function_modules)
            return layer_cls()

        if callable(fn_name):
            spec = inspect.getfullargspec(fn_name)
            if not 1 <= len(spec.args) <= 2:
                # Caught by the handler below, so callers see None plus a
                # printed message rather than the exception itself.
                raise ValueError('Unknown activation function/ class')
            return fn_name if inspect.isfunction(fn_name) else fn_name()

        return None
    except Exception as e:
        # Best-effort contract: report the failure and return None.
        print(e)
        return None
コード例 #2
0
def get_loss(loss_name):
    """Look up a loss by name in ``trident.optims.tensorflow_losses``.

    Args:
        loss_name: Name of the loss, or ``None``.

    Returns:
        ``None`` when ``loss_name`` is ``None``; otherwise whatever
        ``get_class`` returns for the name. Names listed in the module's
        ``__all__`` are looked up verbatim. Other names are first tried in
        their ``camel2snake`` form and, if that lookup raises, looked up
        verbatim.
    """
    if loss_name is None:
        return None

    search_modules = ['trident.optims.tensorflow_losses']

    # Exported names are resolved directly, without any case conversion.
    if loss_name in __all__:
        return get_class(loss_name, search_modules)

    try:
        return get_class(camel2snake(loss_name), search_modules)
    except Exception:
        # Fall back to the name exactly as given.
        return get_class(loss_name, search_modules)
コード例 #3
0
def get_normalization(fn_name):
    """Resolve a normalization specifier into a normalization layer or class.

    Args:
        fn_name: ``None``, a ``Layer`` instance, a class, or a string that
            is either one of the short aliases handled below or a name to
            look up in ``trident.layers.pytorch_normalizations``.

    Returns:
        * ``None`` when ``fn_name`` is ``None``.
        * ``fn_name`` itself when it is a ``Layer`` instance whose class
          name contains ``'Norm'``, or when it is a class (classes are
          returned uninstantiated).
        * A new layer instance for the recognised string aliases, except
          the spectral aliases, which return the ``SpectralNorm`` class
          itself.
        * Otherwise whatever ``get_class(fn_name, ...)`` returns.
    """
    if fn_name is None:
        return None
    if isinstance(fn_name, Layer) and 'Norm' in fn_name.__class__.__name__:
        return fn_name
    if inspect.isclass(fn_name):
        # Classes are handed back as-is. An earlier branch that tried to
        # instantiate them compared the metaclass *name* (a str) with the
        # builtin ``type`` object, so it could never be taken; it has been
        # removed without changing behavior.
        return fn_name

    fn_modules = ['trident.layers.pytorch_normalizations']

    if isinstance(fn_name, str):
        alias = fn_name.lower().strip()
        if alias in ('instance_norm', 'instance', 'in', 'i'):
            return InstanceNorm()
        if alias in ('batch_norm', 'batch', 'bn', 'b'):
            return BatchNorm()
        if alias in ('layer_norm', 'layer', 'ln', 'l'):
            # Previously returned BatchNorm() (copy-paste slip). Resolved
            # through get_class so no extra name has to be in scope here.
            # NOTE(review): assumes the module above exposes `LayerNorm` -
            # confirm.
            return get_class('LayerNorm', fn_modules)()
        if alias in ('group_norm', 'group', 'gn', 'g'):
            return GroupNorm(num_groups=16)
        if alias in ('evo_normb0', 'evo-b0', 'evob0'):
            return EvoNormB0()
        if alias in ('evo_norms0', 'evo-s0', 'evos0'):
            return EvoNormS0()
        if alias in ('spectral_norm', 'spectral', 'spec', 'sp', 's'):
            # NOTE(review): returns the class, not an instance - presumably
            # because it must wrap another module; confirm with callers.
            return SpectralNorm
        if alias in ('l2_norm', 'l2'):
            return L2Norm()

    # Unrecognised strings (and any other object) go to the generic lookup.
    return get_class(fn_name, fn_modules)
コード例 #4
0
def get_normalization(fn_name):
    """Resolve a normalization specifier into a normalization layer or class.

    Args:
        fn_name: ``None``, a ``Layer`` instance, a class, or a string that
            is either one of the short aliases handled below or a name to
            look up in ``trident.layers.tensorflow_normalizations``.

    Returns:
        * ``None`` when ``fn_name`` is ``None``, or when it is one of the
          instance / group aliases (those are disabled here).
        * ``fn_name`` itself when it is a ``Layer`` instance whose class
          name contains ``'Norm'``, or when it is a class.
        * ``BatchNorm2d()`` for the batch aliases.
        * Otherwise whatever ``get_class(fn_name, ...)`` returns.
    """
    if fn_name is None:
        return None
    if isinstance(fn_name, Layer) and 'Norm' in fn_name.__class__.__name__:
        return fn_name
    if inspect.isclass(fn_name):
        # A second, identical isclass check used to follow the string
        # handling; it was unreachable and has been dropped.
        return fn_name

    if isinstance(fn_name, str):
        alias = fn_name.lower().strip()
        if alias in ('instance', 'in', 'i'):
            # Deliberately yields no layer: the original code left
            # `InstanceNorm()` commented out at this point.
            return None
        if alias in ('batch_norm', 'batch', 'bn', 'b'):
            return BatchNorm2d()
        if alias in ('group', 'g'):
            # Deliberately yields no layer: the original code left
            # `GroupNorm(num_groups=16)` commented out at this point.
            return None

    # Unrecognised strings (and any other object) go to the generic lookup.
    fn_modules = ['trident.layers.tensorflow_normalizations']
    return get_class(fn_name, fn_modules)