def get(identifier, **kwargs):
    """Resolve an activation function from a callable or a string spec.

    Args:
      identifier: a callable, which is returned unchanged, or a string parsed
        by `Parser.parse`. The lower-cased parsed name selects:
        'relu'; 'lrelu' / 'leakyrelu' / 'leaky-relu' (leak read via
        `p.get_arg(float, default=0.1)`); 'softmax'; 'cumax'; 'sigmoid'.
        Any other name is looked up as a callable attribute of `tf`, then
        of `tf.nn`.
      kwargs: must be empty (checked by an assertion).

    Returns:
      A callable activation.

    Raises:
      ValueError: if a string identifier cannot be resolved.
      TypeError: if identifier is neither callable nor a string.
    """
    # Sanity check
    assert len(kwargs) == 0
    # Return identifier directly if it is callable
    if callable(identifier):
        return identifier
    if not isinstance(identifier, six.string_types):
        raise TypeError('identifier must be callable or a string')

    # Parse identifier
    p = Parser.parse(identifier)
    identifier = p.name.lower()
    if identifier in ['relu']:
        return relu
    elif identifier in ['lrelu', 'leakyrelu', 'leaky-relu']:
        leak = p.get_arg(float, default=0.1)
        return lambda x: leaky_relu(x, leak=leak)
    elif identifier in ['softmax']:
        return softmax
    elif identifier in ['cumax']:
        return cumax
    elif identifier in ['sigmoid']:
        return lambda x: sigmoid(x, **kwargs)

    # Fall back to TensorFlow. `tf` is checked first to keep the original
    # resolution order; `tf.nn` is checked next so that names that exist only
    # there (e.g. `tf.nn.elu`) also resolve. Non-callable attributes (dtypes,
    # submodules) are not valid activations and are skipped.
    for namespace in (tf, tf.nn):
        activation = getattr(namespace, identifier, None)
        if callable(activation):
            return activation
    raise ValueError('Can not resolve {}'.format(identifier))
def get_global_constraint(self):
    """Build the Keras constraint described by `self.global_constraint`.

    Returns None when `self.global_constraint` is None or ''. Otherwise the
    string is parsed by `Parser.parse`. Only the name 'max_norm' is
    supported: `max_value` is read via `p.get_arg(float, default=2.0)` and
    `axis` via `p.get_kwarg('axis', int, default=0)`, and both are passed to
    `tf.keras.constraints.max_norm`.

    Raises:
      KeyError: if the parsed constraint name is not recognised.
    """
    if self.global_constraint in [None, '']:
        return None
    p = Parser.parse(self.global_constraint)
    if p.name in ['max_norm']:
        max_value = p.get_arg(float, default=2.0)
        axis = p.get_kwarg('axis', int, default=0)
        return tf.keras.constraints.max_norm(max_value=max_value, axis=axis)
    # Unknown names must fail loudly instead of silently returning None
    # (the exception used to be constructed but never raised).
    raise KeyError('Unknown constraint name `{}`'.format(p.name))
def __init__(self, identifier, set_logits=False):
    """Wrap an activation resolved from `identifier`.

    Args:
      identifier: a string spec or a callable. It is passed to
        `Parser.parse` (for `_id`) and to `activations.get` (for the
        activation itself).
      set_logits: stored unchanged as `self._set_logits`.
    """
    # NOTE(review): Parser.parse is called even when identifier is a
    # callable; confirm Parser accepts non-string input.
    parsed = Parser.parse(identifier)
    self._id = parsed.name

    # String specs are named by their parsed name, callables by __name__.
    if isinstance(identifier, six.string_types):
        short_name = parsed.name
    else:
        short_name = identifier.__name__
    self.abbreviation = short_name
    self.full_name = self.abbreviation

    self._activation = activations.get(identifier)
    self._set_logits = set_logits
def get(identifier):
    """Resolve a regularizer from None, a callable, or a string spec.

    None and callables are returned unchanged. Strings are parsed with
    `Parser.parse`. The lower-cased name selects `L1` ('l1') or `L2` ('l2'),
    which is constructed with `penalty=p.get_arg(float)`.

    Raises:
      TypeError: if identifier is not None, callable, or a string.
      ValueError: if the parsed name is neither 'l1' nor 'l2'.
    """
    # Pass-through cases need no parsing.
    if identifier is None or callable(identifier):
        return identifier
    if not isinstance(identifier, six.string_types):
        raise TypeError('identifier must be a function or a string')

    parsed = Parser.parse(identifier)
    name = parsed.name.lower()
    regularizer_classes = {'l1': L1, 'l2': L2}
    if name not in regularizer_classes:
        raise ValueError('Can not resolve "{}"'.format(name))
    return regularizer_classes[name](penalty=parsed.get_arg(float))
def image_augmentation_processor(data_batch: DataSet, is_training: bool):
    """Apply the augmentations configured in `tfr.hub.aug_config`.

    Nothing happens unless `is_training` is true and `aug_config` is not
    None. The config is a '|'-separated string. Each non-blank segment is
    parsed by `Parser.parse`, and its name selects a method ('rotate' ->
    `_rotate`, 'flip' -> `_flip`). The method is called as
    `method(features, *arg_list, **arg_dict)`, and its result is assigned
    back to `data_batch.features`.

    Returns:
      data_batch (the same object, with features possibly replaced).

    Raises:
      KeyError: on an unknown augmentation name.
    """
    # Get hub
    th = tfr.hub
    if not is_training or th.aug_config is None:
        return data_batch

    # Parse augmentation setting
    assert isinstance(th.aug_config, str)
    # str.split never returns an empty list, so blank segments are dropped
    # explicitly. Otherwise '' (or a trailing '|') would be handed to
    # Parser.parse as an empty spec.
    configs = [Parser.parse(s) for s in th.aug_config.split('|') if s.strip()]
    if not configs:
        return data_batch

    # Apply each method according to configs
    methods = {'rotate': _rotate, 'flip': _flip}
    for cfg in configs:
        method = methods.get(cfg.name)
        if method is None:
            raise KeyError('!! Unknown augmentation option {}'.format(
                cfg.name))
        data_batch.features = method(
            data_batch.features, *cfg.arg_list, **cfg.arg_dict)
    return data_batch
def developer_options(self):
    """Return `self.developer_args` parsed by tframe's `Parser`.

    Returns None when `developer_args` is falsy (e.g. None or '').
    """
    raw_args = self.developer_args
    if raw_args:
        # Imported lazily, only when there is something to parse.
        from tframe.utils.arg_parser import Parser
        return Parser.parse(raw_args)
    return None