コード例 #1
0
ファイル: activations.py プロジェクト: garthtrickett/tframe
def get(identifier, **kwargs):
    """Resolve an activation identifier into a callable.

    Args:
        identifier: either a callable, which is returned untouched, or a
            string specification that is handed to `Parser.parse`.
        **kwargs: must be empty; the function asserts this.

    Returns:
        A callable activation function.

    Raises:
        ValueError: if the parsed (lower-cased) name matches no built-in
            branch and is not an attribute of `tf`.
        TypeError: if `identifier` is neither callable nor a string.
    """
    # Extra keyword arguments are not accepted
    assert len(kwargs) == 0

    # Callables pass straight through
    if callable(identifier):
        return identifier
    if not isinstance(identifier, six.string_types):
        raise TypeError('identifier must be callable or a string')

    # Split the specification into a name plus optional arguments
    parsed = Parser.parse(identifier)
    name = parsed.name.lower()

    if name == 'relu':
        return relu
    if name in ('lrelu', 'leakyrelu', 'leaky-relu'):
        # Leak coefficient comes from the parsed spec, 0.1 when not given
        leak = parsed.get_arg(float, default=0.1)
        return lambda x: leaky_relu(x, leak=leak)
    if name == 'softmax':
        return softmax
    if name == 'cumax':
        return cumax
    if name == 'sigmoid':
        return lambda x: sigmoid(x, **kwargs)

    # Last resort: look the name up as an attribute of the `tf` module
    resolved = getattr(tf, name, None)
    if resolved is None:
        raise ValueError('Can not resolve {}'.format(name))
    return resolved
コード例 #2
0
 def get_global_constraint(self):
   """Build the Keras constraint described by `self.global_constraint`.

   Returns:
     None when `self.global_constraint` is None or an empty string;
     otherwise a `tf.keras.constraints.max_norm` object configured from
     the parsed specification.

   Raises:
     KeyError: if the parsed constraint name is not `max_norm`.
   """
   if self.global_constraint in [None, '']: return None
   p = Parser.parse(self.global_constraint)
   if p.name in ['max_norm']:
     # Fall back to max_value=2.0 and axis=0 when the spec omits them
     max_value = p.get_arg(float, default=2.0)
     axis = p.get_kwarg('axis', int, default=0)
     return tf.keras.constraints.max_norm(max_value=max_value, axis=axis)
   # An unknown name must be reported: constructing the exception without
   # `raise` would discard it and make the method return None silently
   raise KeyError('Unknown constraint name `{}`'.format(p.name))
コード例 #3
0
ファイル: common.py プロジェクト: garthtrickett/tframe
 def __init__(self, identifier, set_logits=False):
     """Set up the activation described by `identifier`.

     Args:
         identifier: a string specification or an object exposing
             `__name__`; it is forwarded to `Parser.parse` and
             `activations.get`.
         set_logits: flag stored as `self._set_logits`.
     """
     # NOTE(review): the identifier is parsed before the string check
     # below -- confirm Parser.parse tolerates non-string identifiers
     parsed = Parser.parse(identifier)
     self._id = parsed.name

     # Strings are labelled by their parsed name, anything else by __name__
     if isinstance(identifier, six.string_types):
         label = parsed.name
     else:
         label = identifier.__name__
     self.abbreviation = label
     self.full_name = self.abbreviation

     self._activation = activations.get(identifier)
     self._set_logits = set_logits
コード例 #4
0
ファイル: regularizers.py プロジェクト: winkywow/tframe
def get(identifier):
    """Resolve a regularizer identifier.

    Args:
        identifier: None, a callable, or a string specification that is
            handed to `Parser.parse`.

    Returns:
        `identifier` itself when it is None or callable; otherwise an `L1`
        or `L2` object whose penalty is read from the parsed specification.

    Raises:
        TypeError: if `identifier` is neither None, callable, nor a string.
        ValueError: if the parsed (lower-cased) name is not `l1` or `l2`.
    """
    # None and callables need no resolution
    if identifier is None or callable(identifier):
        return identifier
    if not isinstance(identifier, six.string_types):
        raise TypeError('identifier must be a function or a string')

    parsed = Parser.parse(identifier)
    name = parsed.name.lower()
    if name == 'l1':
        return L1(penalty=parsed.get_arg(float))
    if name == 'l2':
        return L2(penalty=parsed.get_arg(float))
    raise ValueError('Can not resolve "{}"'.format(name))
コード例 #5
0
def image_augmentation_processor(data_batch: DataSet, is_training: bool):
    """Apply the configured augmentations to `data_batch.features`.

    The configuration is read from `tfr.hub.aug_config`, a string of
    `|`-separated specifications, each parsed with `Parser.parse`. The
    methods run in the listed order and every result is written back to
    `data_batch.features`.

    Args:
        data_batch: batch whose `features` attribute is augmented in place.
        is_training: when false, the batch is returned unchanged.

    Returns:
        The same `data_batch` object.

    Raises:
        KeyError: if a parsed name is neither `rotate` nor `flip`.
    """
    th = tfr.hub
    # Nothing to do outside training or when no augmentation is configured
    if not is_training or th.aug_config is None:
        return data_batch

    # One parsed config per `|`-separated entry
    assert isinstance(th.aug_config, str)
    configs = [Parser.parse(spec) for spec in th.aug_config.split('|')]
    if not configs:
        return data_batch

    for config in configs:
        # Pick the augmentation routine by name
        name = config.name
        if name == 'rotate':
            augment = _rotate
        elif name == 'flip':
            augment = _flip
        else:
            raise KeyError('!! Unknown augmentation option {}'.format(name))
        # Positional and keyword arguments come straight from the config
        data_batch.features = augment(
            data_batch.features, *config.arg_list, **config.arg_dict)

    return data_batch
コード例 #6
0
ファイル: config_base.py プロジェクト: garthtrickett/tframe
 def developer_options(self):
   """Return `self.developer_args` parsed by `Parser.parse`.

   Returns:
     None when `self.developer_args` is falsy, otherwise the result of
     `Parser.parse(self.developer_args)`.
   """
   raw_args = self.developer_args
   if not raw_args:
     return None
   # Imported here rather than at module level, as in the original
   from tframe.utils.arg_parser import Parser
   return Parser.parse(raw_args)