def main():
    """Validate one model, or every pretrained model matching a pattern.

    If ``get_model_cfg(args.model)`` finds a config, that single model is
    validated directly. Otherwise ``args.model`` is treated as an ``fnmatch``
    pattern over ``list_models(pretrained=True)`` (the literal ``'all'``
    selects every name). Each match is validated in turn, and a
    Top1/Top5 summary is printed at the end.

    Exits with status 1 when the pattern matches no models.
    """
    args = parser.parse_args()
    logging.set_verbosity(logging.ERROR)
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True)

    if get_model_cfg(args.model) is not None:
        validate(args)
    else:
        models = list_models(pretrained=True)
        if args.model != 'all':
            models = fnmatch.filter(models, args.model)
        if not models:
            print(f'ERROR: No models found to validate with pattern ({args.model}).')
            # The builtin ``exit`` is a site-module convenience for the
            # interactive shell (absent under ``python -S``); raising
            # SystemExit is exactly what sys.exit(1) does.
            raise SystemExit(1)

        print('Validating:', ', '.join(models))
        results = []
        for m in models:
            # validate() is only handed ``args``, so select the model there.
            args.model = m
            res = validate(args)
            # Assumes validate() returns a dict with 'top1'/'top5' keys
            # (read below); tag it with the model name for the summary.
            res.update(dict(model=m))
            results.append(res)
        print('Results:')
        for r in results:
            print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")
예제 #2
0
def main():
    """Export pretrained model(s) to ``args.output``.

    ``--model all`` exports every name returned by
    ``list_models(pretrained=True)``. Any other value is passed straight to
    ``export_model`` as a single model name.
    """
    args = parser.parse_args()

    pretrained_names = list_models(pretrained=True)
    if args.model == 'all':
        targets = pretrained_names
    else:
        targets = [args.model]

    for name in targets:
        export_model(name, args.output)
def main():
    """Validate one model, or every matching pretrained model, with batch-size back-off.

    If ``get_model_cfg(args.model)`` finds a config, that single model is
    validated. Otherwise ``args.model`` is an ``fnmatch`` pattern over
    ``list_models(pretrained=True)`` (``'all'`` selects every name). Each
    match is validated in turn, and a Top1/Top5 summary is printed.

    A ``RuntimeError`` from ``validate`` halves the batch size and retries,
    down to a batch size of 1. Presumably this is meant to recover from
    device out-of-memory failures; TODO confirm.

    Exits with status 1 when the pattern matches no models.
    """
    args = parser.parse_args()
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()),
          flush=True)
    jax.config.enable_omnistaging()

    def _try_validate(args):
        """Run ``validate(args)``, halving ``args.batch_size`` after each RuntimeError.

        Returns whatever ``validate`` returns on the first successful call.
        Re-raises the RuntimeError once the batch size is already 1.
        Leaves ``args.batch_size`` at the last value tried.
        """
        batch_size = args.batch_size
        while True:
            try:
                print(f'Setting validation batch size to {batch_size}')
                args.batch_size = batch_size
                # Only a RuntimeError triggers a retry. A successful call
                # returns immediately, even if validate() returns None; a
                # ``res is None`` sentinel loop would spin forever then.
                return validate(args)
            except RuntimeError:
                if batch_size <= 1:
                    print(
                        "Validation failed with no ability to reduce batch size. Exiting."
                    )
                    raise  # bare raise re-raises the active exception unchanged
                batch_size = max(batch_size // 2, 1)
                print("Validation failed, reducing batch size by 50%")

    if get_model_cfg(args.model) is not None:
        _try_validate(args)
    else:
        models = list_models(pretrained=True)
        if args.model != 'all':
            models = fnmatch.filter(models, args.model)
        if not models:
            print(
                f'ERROR: No models found to validate with pattern {args.model}.'
            )
            # Equivalent to sys.exit(1); the builtin ``exit`` is intended for
            # the interactive shell only.
            raise SystemExit(1)

        print('Validating:', ', '.join(models))
        results = []
        start_batch_size = args.batch_size
        for m in models:
            args.batch_size = start_batch_size  # reset in case reduced for retry
            args.model = m
            res = _try_validate(args)
            # Assumes validate() returns a dict with 'top1'/'top5' keys.
            res.update(dict(model=m))
            results.append(res)
        print('Results:')
        for r in results:
            print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")
예제 #4
0
def main():
    """Validate one model, or every pretrained model matching a pattern.

    If ``get_model_cfg(args.model)`` finds a config, that single model is
    validated directly. Otherwise ``args.model`` is treated as an ``fnmatch``
    pattern over ``list_models(pretrained=True)`` (the literal ``'all'``
    selects every name). Each match is validated in turn, and a
    Top1/Top5 summary is printed at the end.

    Exits with status 1 when the pattern matches no models.
    """
    args = parser.parse_args()
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()),
          flush=True)
    jax.config.enable_omnistaging()

    if get_model_cfg(args.model) is not None:
        validate(args)
    else:
        models = list_models(pretrained=True)
        if args.model != 'all':
            models = fnmatch.filter(models, args.model)
        if not models:
            # Without this guard an unmatched pattern silently prints empty
            # "Validating:" / "Results:" sections and exits with status 0.
            print(
                f'ERROR: No models found to validate with pattern {args.model}.'
            )
            raise SystemExit(1)
        print('Validating:', ', '.join(models))
        results = []
        for m in models:
            # validate() is only handed ``args``, so select the model there.
            args.model = m
            res = validate(args)
            # Assumes validate() returns a dict with 'top1'/'top5' keys.
            res.update(dict(model=m))
            results.append(res)
        print('Results:')
        for r in results:
            print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")