def create_model(variant, pretrained=False, **kwargs):
    """Instantiate an EfficientNet model by registered variant name.

    Args:
        variant: registered model name used to look up arch and default configs.
        pretrained: if True, load pretrained weights into the built model.
        **kwargs: overrides merged on top of the generated arch args.

    Returns:
        The constructed EfficientNet with ``default_cfg`` attached.
    """
    cfg = get_model_cfg(variant)
    build_args = cfg['arch_fn'](variant, **cfg['arch_cfg'])
    build_args.update(kwargs)

    # Resolve special layer configs that the model constructor cannot consume.
    se_cfg = build_args.pop('se_cfg', {})
    if 'se_layer' not in build_args:
        # Convert any string activation names in the SE config to callables.
        for key in ('bound_act_fn', 'gate_fn'):
            if key in se_cfg:
                se_cfg[key] = get_act_fn(se_cfg[key])
        build_args['se_layer'] = partial(SqueezeExcite, **se_cfg)

    bn_cfg = build_args.pop('bn_cfg')  # not consumable by model
    if 'norm_layer' not in build_args:
        build_args['norm_layer'] = partial(BatchNorm2d, **bn_cfg)

    # Convert the main activation from str -> fn.
    build_args['act_fn'] = get_act_fn(build_args.pop('act_fn', 'relu'))

    model = EfficientNet(**build_args)
    model.default_cfg = cfg['default_cfg']
    if pretrained:
        load_pretrained(model, default_cfg=model.default_cfg)
    return model
def main():
    """Validate a single model, or every pretrained model matching a pattern."""
    args = parser.parse_args()
    logging.set_verbosity(logging.ERROR)
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True)

    if get_model_cfg(args.model) is not None:
        # Exact registered model name given — validate just that one.
        validate(args)
        return

    # Otherwise treat args.model as a wildcard over the pretrained model list.
    candidates = list_models(pretrained=True)
    if args.model != 'all':
        candidates = fnmatch.filter(candidates, args.model)
    if not candidates:
        print(f'ERROR: No models found to validate with pattern ({args.model}).')
        exit(1)

    print('Validating:', ', '.join(candidates))
    results = []
    for name in candidates:
        args.model = name
        res = validate(args)
        res['model'] = name
        results.append(res)

    print('Results:')
    for r in results:
        print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")
def create_model(variant, pretrained=False, rng=None, input_shape=None,
                 dtype=jnp.float32, **kwargs):
    """Build a flax EfficientNet and initialize its variable collections.

    Args:
        variant: registered model name used to look up arch and default configs.
        pretrained: if True, load pretrained weights into the variables.
        rng: optional PRNG key; a fixed ``PRNGKey(0)`` is used when None.
        input_shape: optional CHW input size; defaults to the model config's
            ``input_size``.
        dtype: parameter/computation dtype passed through to the model.
        **kwargs: overrides merged on top of the generated arch args.

    Returns:
        ``(model, variables)`` tuple.
    """
    cfg = get_model_cfg(variant)
    build_args = cfg['arch_fn'](variant, **cfg['arch_cfg'])
    build_args.update(kwargs)

    # Resolve special layer configs that the model constructor cannot consume.
    se_cfg = build_args.pop('se_cfg', {})
    if 'se_layer' not in build_args:
        # Convert any string activation names in the SE config to callables.
        for key in ('bound_act_fn', 'gate_fn'):
            if key in se_cfg:
                se_cfg[key] = get_act_fn(se_cfg[key])
        build_args['se_layer'] = partial(SqueezeExcite, **se_cfg)

    bn_cfg = build_args.pop('bn_cfg')  # not consumable by model
    if 'norm_layer' not in build_args:
        build_args['norm_layer'] = partial(batchnorm2d, **bn_cfg)

    # Convert the main activation from str -> fn.
    build_args['act_fn'] = get_act_fn(build_args.pop('act_fn', 'relu'))

    model = EfficientNet(dtype=dtype, default_cfg=cfg['default_cfg'], **build_args)

    # Deterministic default key unless the caller supplies one.
    if rng is None:
        rng = jax.random.PRNGKey(0)
    params_rng, dropout_rng = jax.random.split(rng)

    if input_shape is None:
        input_shape = cfg['default_cfg']['input_size']
    # CHW -> NHWC with a batch of one for initialization.
    init_shape = (1, input_shape[1], input_shape[2], input_shape[0])

    variables = model.init(
        {'params': params_rng, 'dropout': dropout_rng},
        jnp.ones(init_shape, dtype=dtype),
        training=False)

    if pretrained:
        variables = load_pretrained(
            variables, default_cfg=model.default_cfg, filter_fn=_filter)
    return model, variables
def main():
    """Validate one model or a wildcard set, halving batch size on failure."""
    args = parser.parse_args()
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True)
    jax.config.enable_omnistaging()

    def _try_validate(args):
        # Retry validation with progressively smaller batches — a failed run
        # surfaces as RuntimeError here; give up once batch size reaches 1.
        res = None
        batch_size = args.batch_size
        while res is None:
            try:
                print(f'Setting validation batch size to {batch_size}')
                args.batch_size = batch_size
                res = validate(args)
            except RuntimeError as e:
                if batch_size <= 1:
                    print(
                        "Validation failed with no ability to reduce batch size. Exiting."
                    )
                    raise e
                batch_size = max(batch_size // 2, 1)
                print("Validation failed, reducing batch size by 50%")
        return res

    if get_model_cfg(args.model) is not None:
        # Exact registered model name given — validate just that one.
        _try_validate(args)
        return

    # Otherwise treat args.model as a wildcard over the pretrained model list.
    candidates = list_models(pretrained=True)
    if args.model != 'all':
        candidates = fnmatch.filter(candidates, args.model)
    if not candidates:
        print(
            f'ERROR: No models found to validate with pattern {args.model}.'
        )
        exit(1)

    print('Validating:', ', '.join(candidates))
    results = []
    start_batch_size = args.batch_size
    for name in candidates:
        args.batch_size = start_batch_size  # reset in case reduced for retry
        args.model = name
        res = _try_validate(args)
        res['model'] = name
        results.append(res)

    print('Results:')
    for r in results:
        print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")
def main():
    """Validate a single model, or every pretrained model matching a pattern.

    ``args.model`` is first looked up in the model registry; if found, only
    that model is validated. Otherwise it is treated as an fnmatch pattern
    ('all' matches everything) over the pretrained model list.
    """
    args = parser.parse_args()
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True)
    jax.config.enable_omnistaging()
    if get_model_cfg(args.model) is not None:
        validate(args)
    else:
        models = list_models(pretrained=True)
        if args.model != 'all':
            models = fnmatch.filter(models, args.model)
        # Fix: an empty pattern match previously fell through silently,
        # printing an empty 'Validating:' line and an empty results table.
        # Fail loudly instead, consistent with the sibling main() variants.
        if not models:
            print(f'ERROR: No models found to validate with pattern ({args.model}).')
            exit(1)
        print('Validating:', ', '.join(models))
        results = []
        for m in models:
            args.model = m
            res = validate(args)
            res.update(dict(model=m))  # tag each result with its model name
            results.append(res)
        print('Results:')
        for r in results:
            print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")