Example #1
0
def create_model(variant, pretrained=False, **kwargs):
    """Build an EfficientNet for the named variant.

    The variant's config is fetched with ``get_model_cfg``; its ``arch_fn``
    generates the constructor arguments, which ``kwargs`` then override.

    Args:
        variant: Variant name handed to ``get_model_cfg`` and to ``arch_fn``.
        pretrained: When true, ``load_pretrained`` is called on the new model
            using the model's ``default_cfg``.
        **kwargs: Overrides merged on top of the generated model arguments.

    Returns:
        The constructed ``EfficientNet`` with ``default_cfg`` attached.

    Note:
        ``'bn_cfg'`` is popped without a default, so it must be present in
        the merged arguments.
    """
    cfg = get_model_cfg(variant)
    build_args = cfg['arch_fn'](variant, **cfg['arch_cfg'])
    build_args.update(kwargs)

    # 'se_cfg' only parameterises the default SE layer; the model constructor
    # does not accept it, so it is removed even when a custom layer is given.
    se_cfg = build_args.pop('se_cfg', {})
    if 'se_layer' not in build_args:
        # Activation entries are converted in place via get_act_fn.
        for fn_key in ('bound_act_fn', 'gate_fn'):
            if fn_key in se_cfg:
                se_cfg[fn_key] = get_act_fn(se_cfg[fn_key])
        build_args['se_layer'] = partial(SqueezeExcite, **se_cfg)

    # Same treatment for batch norm: config is consumed here, not by the model.
    bn_cfg = build_args.pop('bn_cfg')
    if 'norm_layer' not in build_args:
        build_args['norm_layer'] = partial(BatchNorm2d, **bn_cfg)

    # The activation is resolved through get_act_fn; 'relu' is the fallback.
    act_spec = build_args.pop('act_fn', 'relu')
    build_args['act_fn'] = get_act_fn(act_spec)

    model = EfficientNet(**build_args)
    model.default_cfg = cfg['default_cfg']

    if pretrained:
        load_pretrained(model, default_cfg=model.default_cfg)

    return model
def main():
    """Validate one model, or every pretrained model matching a pattern.

    If ``args.model`` resolves to a model config it is validated directly.
    Otherwise it is used as an ``fnmatch`` pattern (or the literal ``'all'``)
    over ``list_models(pretrained=True)``; each match is validated in turn
    and a per-model Top1/Top5 summary is printed.

    Raises:
        SystemExit: With status 1 when no model matches the pattern.
    """
    args = parser.parse_args()
    logging.set_verbosity(logging.ERROR)
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True)

    # Exact model name: validate it and stop.
    if get_model_cfg(args.model) is not None:
        validate(args)
        return

    models = list_models(pretrained=True)
    if args.model != 'all':
        models = fnmatch.filter(models, args.model)
    if not models:
        print(f'ERROR: No models found to validate with pattern ({args.model}).')
        # The bare `exit` name is injected by the `site` module for interactive
        # use and may be missing (e.g. `python -S`); raise SystemExit directly.
        raise SystemExit(1)

    print('Validating:', ', '.join(models))
    results = []
    for m in models:
        # presumably validate() selects the model from args.model -- confirm
        args.model = m
        res = validate(args)
        res.update(model=m)
        results.append(res)
    print('Results:')
    for r in results:
        print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")
def create_model(variant,
                 pretrained=False,
                 rng=None,
                 input_shape=None,
                 dtype=jnp.float32,
                 **kwargs):
    """Build an EfficientNet module and initialise its variables.

    Args:
        variant: Variant name handed to ``get_model_cfg`` and to ``arch_fn``.
        pretrained: When true, the initialised variables are replaced by the
            result of ``load_pretrained``.
        rng: PRNG key to split into params/dropout keys; ``PRNGKey(0)`` is
            used when omitted.
        input_shape: Channel-first (C, H, W) shape; defaults to the
            ``'input_size'`` entry of the variant's default config.
        dtype: Passed to the model and used for the dummy init input.
        **kwargs: Overrides merged on top of the generated model arguments.

    Returns:
        A ``(model, variables)`` tuple.

    Note:
        ``'bn_cfg'`` is popped without a default, so it must be present in
        the merged arguments.
    """
    cfg = get_model_cfg(variant)
    build_args = cfg['arch_fn'](variant, **cfg['arch_cfg'])
    build_args.update(kwargs)

    # 'se_cfg' only parameterises the default SE layer; the model constructor
    # does not accept it, so it is removed even when a custom layer is given.
    se_cfg = build_args.pop('se_cfg', {})
    if 'se_layer' not in build_args:
        # Activation entries are converted in place via get_act_fn.
        for fn_key in ('bound_act_fn', 'gate_fn'):
            if fn_key in se_cfg:
                se_cfg[fn_key] = get_act_fn(se_cfg[fn_key])
        build_args['se_layer'] = partial(SqueezeExcite, **se_cfg)

    # Same treatment for batch norm: config is consumed here, not by the model.
    bn_cfg = build_args.pop('bn_cfg')
    if 'norm_layer' not in build_args:
        build_args['norm_layer'] = partial(batchnorm2d, **bn_cfg)

    # The activation is resolved through get_act_fn; 'relu' is the fallback.
    act_spec = build_args.pop('act_fn', 'relu')
    build_args['act_fn'] = get_act_fn(act_spec)

    model = EfficientNet(dtype=dtype,
                         default_cfg=cfg['default_cfg'],
                         **build_args)

    if rng is None:
        rng = jax.random.PRNGKey(0)
    params_rng, dropout_rng = jax.random.split(rng)

    if input_shape is None:
        input_shape = cfg['default_cfg']['input_size']
    # CHW -> HWC by default, with a leading batch dimension of 1.
    height, width, channels = input_shape[1], input_shape[2], input_shape[0]
    init_shape = (1, height, width, channels)

    # FIXME (from the original author): is jit-ing the init worthwhile here?
    # The idea was to wrap model.init in a jax.jit-decorated function.
    init_rngs = {'params': params_rng, 'dropout': dropout_rng}
    dummy_input = jnp.ones(init_shape, dtype=dtype)
    variables = model.init(init_rngs, dummy_input, training=False)

    if pretrained:
        variables = load_pretrained(variables,
                                    default_cfg=model.default_cfg,
                                    filter_fn=_filter)

    return model, variables
def main():
    """Validate one model, or every pretrained model matching a pattern.

    Each validation is retried with a halved batch size whenever
    ``validate`` raises ``RuntimeError``, until it succeeds or the batch
    size can no longer be reduced.

    If ``args.model`` resolves to a model config it is validated directly.
    Otherwise it is used as an ``fnmatch`` pattern (or the literal ``'all'``)
    over ``list_models(pretrained=True)``; each match is validated starting
    from the originally requested batch size and a per-model Top1/Top5
    summary is printed.

    Raises:
        SystemExit: With status 1 when no model matches the pattern.
        RuntimeError: Re-raised from ``validate`` once the batch size is 1
            or less and validation still fails.
    """
    args = parser.parse_args()
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()),
          flush=True)
    jax.config.enable_omnistaging()

    def _try_validate(args):
        """Run ``validate(args)``, halving the batch size after each failure."""
        batch_size = args.batch_size
        # Loop until validate() returns; keying the loop on the result being
        # non-None would retry forever if validate() ever returned None.
        while True:
            print(f'Setting validation batch size to {batch_size}')
            args.batch_size = batch_size
            try:
                return validate(args)
            except RuntimeError:
                # NOTE(review): presumably an out-of-memory failure -- confirm.
                if batch_size <= 1:
                    print(
                        "Validation failed with no ability to reduce batch size. Exiting."
                    )
                    raise  # bare raise keeps the original traceback intact
                batch_size = max(batch_size // 2, 1)
                print("Validation failed, reducing batch size by 50%")

    # Exact model name: validate it and stop.
    if get_model_cfg(args.model) is not None:
        _try_validate(args)
        return

    models = list_models(pretrained=True)
    if args.model != 'all':
        models = fnmatch.filter(models, args.model)
    if not models:
        print(
            f'ERROR: No models found to validate with pattern {args.model}.'
        )
        # The bare `exit` name is injected by the `site` module for interactive
        # use and may be missing (e.g. `python -S`); raise SystemExit directly.
        raise SystemExit(1)

    print('Validating:', ', '.join(models))
    results = []
    start_batch_size = args.batch_size
    for m in models:
        args.batch_size = start_batch_size  # reset in case reduced for retry
        args.model = m
        res = _try_validate(args)
        res.update(model=m)
        results.append(res)
    print('Results:')
    for r in results:
        print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")
Example #5
0
def main():
    """Validate one model, or every pretrained model matching a pattern.

    If ``args.model`` resolves to a model config it is validated directly.
    Otherwise it is used as an ``fnmatch`` pattern (or the literal ``'all'``)
    over ``list_models(pretrained=True)``; each match is validated in turn
    and a per-model Top1/Top5 summary is printed.

    Raises:
        SystemExit: With status 1 when no model matches the pattern.
    """
    args = parser.parse_args()
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()),
          flush=True)
    jax.config.enable_omnistaging()

    # Exact model name: validate it and stop.
    if get_model_cfg(args.model) is not None:
        validate(args)
        return

    models = list_models(pretrained=True)
    if args.model != 'all':
        models = fnmatch.filter(models, args.model)
    if not models:
        # Without this guard a pattern that matches nothing prints an empty
        # "Validating:" / "Results:" pair and exits with status 0.
        print(f'ERROR: No models found to validate with pattern ({args.model}).')
        raise SystemExit(1)

    print('Validating:', ', '.join(models))
    results = []
    for m in models:
        # presumably validate() selects the model from args.model -- confirm
        args.model = m
        res = validate(args)
        res.update(model=m)
        results.append(res)
    print('Results:')
    for r in results:
        print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")