def get_activation(fn_name, only_layer=False):
    """Resolve *fn_name* into an activation function or activation layer.

    Accepts several input forms and dispatches on them:
      * ``None``                      -> returns ``None``
      * a string (snake_case/lower)   -> looked up as a function in the trident
                                         modules (``'p_relu'``/``'prelu'`` special-cased
                                         to a ``PRelu()`` layer)
      * a string (CamelCase)          -> looked up as a class and instantiated
      * a trident activation object   -> function wrapped in ``partial``, class
                                         instantiated, ``Layer`` instance passed through
      * a ``trident.backend.pytorch_ops`` function -> returned as-is, or converted
                                         to its layer class when ``only_layer`` is True
      * any other callable            -> returned (or instantiated) if it takes 1–2 args

    Args:
        fn_name (None | str | callable | type | Layer): the activation identifier.
        only_layer (bool): when True, restrict lookup to the layer module only and
            always return a layer instance rather than a bare function.

    Returns:
        A callable/layer instance, or ``None`` when *fn_name* is ``None`` or
        resolution fails (any exception is printed and swallowed — see below).

    Examples:
        >>> print(get_activation('swish'))
    """
    if fn_name is None:
        return None
    # Search paths for class lookup (fn_modules) vs. function lookup
    # (trident_fn_modules); torch.nn.functional is only consulted for classes.
    fn_modules = ['trident.layers.pytorch_activations', 'trident.backend.pytorch_ops', 'torch.nn.functional']
    trident_fn_modules = ['trident.layers.pytorch_activations', 'trident.backend.pytorch_ops']
    if only_layer:
        # Layer-only mode: never resolve to a bare function module.
        fn_modules = ['trident.layers.pytorch_activations']
        trident_fn_modules = ['trident.layers.pytorch_activations']
    try:
        if isinstance(fn_name, str):
            # snake_case / all-lowercase names are treated as function names
            # (unless only_layer forces a class lookup).
            if not only_layer and (camel2snake(fn_name) == fn_name or fn_name.lower() == fn_name):
                if fn_name == 'p_relu' or fn_name == 'prelu':
                    # PRelu holds learnable parameters, so it must be a layer
                    # instance even in function-name form.
                    return PRelu()
                activation_fn = get_function(fn_name, trident_fn_modules)
                return activation_fn
            else:
                # CamelCase (or only_layer): resolve as a class and instantiate.
                try:
                    activation_fn = get_class(snake2camel(fn_name), fn_modules)
                    return activation_fn()
                except Exception:
                    # Fall back to the name exactly as given.
                    activation_fn = get_class(fn_name, fn_modules)
                    return activation_fn()
        elif getattr(fn_name, '__module__', None) == 'trident.layers.pytorch_activations':
            if inspect.isfunction(fn_name):
                # NOTE(review): partial() with no bound arguments is a no-op
                # wrapper — presumably kept for a uniform callable type; confirm.
                return partial(fn_name)
            elif inspect.isclass(fn_name) and fn_name.__class__.__name__ == "type":
                # A plain (default-metaclass) class: instantiate it.
                return fn_name()
            elif isinstance(fn_name, Layer):
                # Already an activation layer instance: pass through.
                return fn_name
        elif inspect.isfunction(fn_name) and getattr(fn_name, '__module__', None) == 'trident.backend.pytorch_ops':
            if only_layer:
                # Convert the op function to its CamelCase layer counterpart.
                activation_layer = get_class(snake2camel(fn_name.__name__), trident_fn_modules)
                return activation_layer()
            else:
                return fn_name
        else:
            # Last resort: accept any callable with a 1–2 argument signature
            # (functions returned as-is, classes instantiated).
            if callable(fn_name):
                result = inspect.getfullargspec(fn_name)
                if 1 <= len(result.args) <= 2:
                    return fn_name if inspect.isfunction(fn_name) else fn_name()
                else:
                    # NOTE(review): this ValueError is caught by the outer
                    # except below, printed, and turned into None — it never
                    # reaches the caller.
                    raise ValueError('Unknown activation function/ class')
    except Exception as e:
        # NOTE(review): broad swallow — resolution failures are printed and
        # reported as None rather than raised; callers must handle None.
        print(e)
        return None
def get_loss(loss_name):
    """Resolve a loss identifier into a loss class.

    Args:
        loss_name (str | None): the loss name, either exactly as exported in
            this module's ``__all__`` or in a form that ``camel2snake`` can map
            onto an exported name.

    Returns:
        The resolved loss class from ``trident.optims.tensorflow_losses``, or
        ``None`` when *loss_name* is ``None``.
    """
    if loss_name is None:
        return None
    search_modules = ['trident.optims.tensorflow_losses']
    # Exact exported names resolve directly.
    if loss_name in __all__:
        return get_class(loss_name, search_modules)
    # Otherwise try the snake_case form first, then the name as given.
    try:
        return get_class(camel2snake(loss_name), search_modules)
    except Exception:
        return get_class(loss_name, search_modules)
def get_normalization(fn_name):
    """Resolve *fn_name* into a (pytorch) normalization layer.

    Args:
        fn_name (None | Layer | type | str): ``None``, an already-built
            normalization ``Layer``, a class, or a string alias such as
            ``'bn'``, ``'in'``, ``'gn'``, ``'evo-b0'``, ``'l2'``.

    Returns:
        A normalization layer instance (``SpectralNorm`` is returned as a
        class — see note below), the input itself for ``Layer``/class inputs,
        or ``None`` when *fn_name* is ``None``. Unrecognized names fall back
        to a class lookup in ``trident.layers.pytorch_normalizations``.
    """
    if fn_name is None:
        return None
    if isinstance(fn_name, Layer) and 'Norm' in fn_name.__class__.__name__:
        # Already an instantiated normalization layer: pass through.
        return fn_name
    if inspect.isclass(fn_name) and fn_name.__class__.__name__ == 'type':
        # FIX: the original compared __name__ (a str) against the builtin
        # `type` object — always False — so plain classes were never
        # instantiated here. Compare against the string 'type' as the
        # sibling get_activation does.
        return fn_name()
    if inspect.isclass(fn_name):
        # Classes with a non-default metaclass are returned un-instantiated.
        return fn_name
    if isinstance(fn_name, str):
        alias = fn_name.lower().strip()  # hoisted: computed once, not per branch
        if alias in ['instance_norm', 'instance', 'in', 'i']:
            return InstanceNorm()
        elif alias in ['batch_norm', 'batch', 'bn', 'b']:
            return BatchNorm()
        elif alias in ['layer_norm', 'layer', 'ln', 'l']:
            # NOTE(review): the layer-norm aliases return BatchNorm(), faithfully
            # preserving the original code — looks like a copy-paste bug; confirm
            # whether a LayerNorm class exists and was intended here.
            return BatchNorm()
        elif alias in ['group_norm', 'group', 'gn', 'g']:
            return GroupNorm(num_groups=16)
        elif alias in ['evo_normb0', 'evo-b0', 'evob0']:
            return EvoNormB0()
        elif alias in ['evo_norms0', 'evo-s0', 'evos0']:
            return EvoNormS0()
        elif alias in ['spectral_norm', 'spectral', 'spec', 'sp', 's']:
            # Returned as a class, not an instance, unlike the other branches
            # (preserved from the original; presumably applied to a weight later).
            return SpectralNorm
        elif alias in ['l2_norm', 'l2']:
            return L2Norm()
    # Fallback: dynamic class lookup by name.
    fn_modules = ['trident.layers.pytorch_normalizations']
    return get_class(fn_name, fn_modules)
def get_normalization(fn_name):
    """Resolve *fn_name* into a (tensorflow) normalization layer.

    Args:
        fn_name (None | Layer | type | str): ``None``, an already-built
            normalization ``Layer``, a class, or a string alias such as
            ``'bn'``, ``'in'``, ``'g'``.

    Returns:
        A normalization layer instance, the input itself for ``Layer``/class
        inputs, ``None`` for aliases whose tensorflow implementation is not
        wired up yet, or the result of a dynamic class lookup in
        ``trident.layers.tensorflow_normalizations``.
    """
    if fn_name is None:
        return None
    if isinstance(fn_name, Layer) and 'Norm' in fn_name.__class__.__name__:
        # Already an instantiated normalization layer: pass through.
        return fn_name
    if inspect.isclass(fn_name):
        return fn_name
    if isinstance(fn_name, str):
        alias = fn_name.lower().strip()  # hoisted: computed once, not per branch
        if alias in ['instance', 'in', 'i']:
            # Not available for the tensorflow backend yet.
            return None  # return InstanceNorm()
        elif alias in ['batch_norm', 'batch', 'bn', 'b']:
            return BatchNorm2d()
        elif alias in ['group', 'g']:
            # Not available for the tensorflow backend yet.
            return None  # return GroupNorm(num_groups=16)
        # FIX: removed a dead `elif inspect.isclass(fn_name)` branch here —
        # inside this isinstance(fn_name, str) block fn_name is a str instance
        # and can never be a class, so the branch was unreachable.
    # Fallback: dynamic class lookup by name.
    fn_modules = ['trident.layers.tensorflow_normalizations']
    return get_class(fn_name, fn_modules)