def validate(args):
    rng = jax.random.PRNGKey(0)
    model, variables = create_model(args.model, pretrained=True, rng=rng)
    print(f'Created {args.model} model. Validating...')

    # Optionally skip jax.jit (useful when debugging the forward pass).
    if args.no_jit:
        eval_step = lambda images, labels: eval_forward(
            model, variables, images, labels)
    else:
        eval_step = jax.jit(lambda images, labels: eval_forward(
            model, variables, images, labels))

    # Validate either directly from a .tar archive or from an image folder.
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data)
    else:
        dataset = Dataset(args.data)

    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=8,
        crop_pct=data_config['crop_pct'])

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, (images, labels) in enumerate(loader):
        # The loader yields NCHW torch tensors; this model expects NHWC numpy.
        images = images.numpy().transpose(0, 2, 3, 1)
        labels = labels.numpy()
        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]
        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}] '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(
        f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
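
# eval_forward is imported from elsewhere in the repo. The sketch below is a
# hypothetical stand-in showing the contract the call sites above assume
# (per-batch counts of top-1 / top-5 correct predictions); the train kwarg and
# top-k logic are assumptions, not the repo's exact implementation.
import jax
import jax.numpy as jnp

def eval_forward(model, variables, images, labels):
    # Forward pass in inference mode (no batch-norm statistics updates).
    logits = model.apply(variables, images, train=False)
    # Count examples whose argmax prediction matches the label.
    top1_count = jnp.sum(jnp.argmax(logits, axis=-1) == labels)
    # Count examples whose label appears among the 5 highest logits.
    _, top5_idx = jax.lax.top_k(logits, k=5)
    top5_count = jnp.sum(top5_idx == labels[:, None])
    return top1_count, top5_count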
def validate(args):
    """Runs evaluation and returns top-1/top-5 accuracy."""
    rng = jax.random.PRNGKey(0)

    # Pick the compute dtype: bfloat16 on TPU, float16 on GPU/CPU when half
    # precision is requested, otherwise float32.
    platform = jax.local_devices()[0].platform
    if args.half_precision:
        if platform == 'tpu':
            model_dtype = jax.numpy.bfloat16
        else:
            model_dtype = jax.numpy.float16
    else:
        model_dtype = jax.numpy.float32

    model, variables = create_model(args.model, pretrained=True, dtype=model_dtype, rng=rng)
    print(f'Created {args.model} model. Validating...')

    if args.no_jit:
        eval_step = lambda images, labels: eval_forward(model.apply, variables, images, labels)
    else:
        eval_step = jax.jit(lambda images, labels: eval_forward(model.apply, variables, images, labels))

    image_size = model.default_cfg['input_size'][-1]
    eval_iter, num_batches = create_eval_iter(
        args.data,
        args.batch_size,
        image_size,
        half_precision=args.half_precision,
        # default_cfg mean/std are in [0, 1]; the input pipeline expects [0, 255].
        mean=tuple([x * 255 for x in model.default_cfg['mean']]),
        std=tuple([x * 255 for x in model.default_cfg['std']]),
        interpolation=model.default_cfg['interpolation'],
    )

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, batch in enumerate(eval_iter):
        images, labels = batch['image'], batch['label']
        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]
        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{num_batches}] '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(
        f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
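
# A hypothetical minimal entry point wiring up the arguments validate() reads
# (model, data, batch_size, half_precision, no_jit). The flag names, defaults,
# and help strings are assumptions for illustration, not the repo's actual CLI.
import argparse

def main():
    parser = argparse.ArgumentParser(description='JAX ImageNet validation')
    parser.add_argument('data', metavar='DIR', help='path to dataset')
    parser.add_argument('--model', default='tf_efficientnet_b0', help='model architecture')
    parser.add_argument('--batch-size', type=int, default=256, help='batch size')
    parser.add_argument('--half-precision', action='store_true',
                        help='run in bfloat16 (TPU) / float16 (GPU, CPU)')
    parser.add_argument('--no-jit', action='store_true', help='disable jax.jit of the eval step')
    validate(parser.parse_args())

if __name__ == '__main__':
    main()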
def validate(args):
    model = create_model(args.model, pretrained=True)
    print(f'Created {args.model} model. Validating...')

    # objax.Jit is handed the model variables so it can trace the forward pass.
    eval_step = objax.Jit(
        lambda images, labels: eval_forward(model, images, labels),
        model.vars())

    dataset = create_dataset('imagenet', args.data)
    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=8,
        crop_pct=data_config['crop_pct'])

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, (images, labels) in enumerate(loader):
        # Unlike the NHWC variant above, no layout transpose is applied here.
        images = images.numpy()
        labels = labels.numpy()
        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]
        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}] '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(
        f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
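
# The Objax variant calls eval_forward with a different signature: no explicit
# variables argument, since Objax modules carry their own state. A hypothetical
# sketch of that contract, mirroring the Flax sketch earlier; the training
# kwarg and top-k logic are illustrative assumptions.
import jax.numpy as jnp

def eval_forward(model, images, labels):
    # Inference-mode forward pass through the Objax module.
    logits = model(images, training=False)
    top1_count = jnp.sum(jnp.argmax(logits, axis=-1) == labels)
    # A label counts as top-5 correct if it is among the 5 largest logits.
    top5_idx = jnp.argsort(logits, axis=-1)[:, -5:]
    top5_count = jnp.sum(top5_idx == labels[:, None])
    return top1_count, top5_count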
def validate(args):
    """Runs evaluation and returns top-1/top-5 accuracy."""
    rng = jax.random.PRNGKey(0)
    model, variables = create_model(args.model, pretrained=True, rng=rng)
    print(f'Created {args.model} model. Validating...')

    if args.no_jit:
        eval_step = lambda images, labels: eval_forward(
            model, variables, images, labels)
    else:
        eval_step = jax.jit(lambda images, labels: eval_forward(
            model, variables, images, labels))

    image_size = model.default_cfg['input_size'][-1]
    test_ds, num_batches = imagenet_data.load(
        imagenet_data.Split.TEST,
        is_training=False,
        image_size=image_size,
        batch_dims=[args.batch_size],
        # default_cfg mean/std are in [0, 1]; the pipeline expects [0, 255].
        mean=tuple([x * 255 for x in model.default_cfg['mean']]),
        std=tuple([x * 255 for x in model.default_cfg['std']]),
        tfds_data_dir=args.data)

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, batch in enumerate(test_ds):
        images, labels = batch['images'], batch['labels']
        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]
        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{num_batches}] '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(
        f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
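
# All four variants measure throughput with an AverageMeter exposing .val,
# .avg, and .update(). A minimal sketch in the style of the timm utility the
# naming suggests:
class AverageMeter:
    """Tracks the most recent value and the running average."""

    def __init__(self):
        self.val = 0.0    # last value passed to update()
        self.sum = 0.0    # running sum of values
        self.count = 0    # number of samples seen
        self.avg = 0.0    # running mean

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count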