def validate(args):
    """Evaluate a pretrained model on a dataset and report top-1/top-5 accuracy.

    Args:
        args: Parsed command-line namespace. Attributes read directly here are
            ``model``, ``no_jit``, ``data`` and ``batch_size``; the whole
            namespace is also passed to ``resolve_data_config`` via ``vars(args)``.

    Returns:
        dict with ``top1`` and ``top5`` accuracy percentages as Python floats.

    Raises:
        ValueError: if the data loader yields no examples.
    """
    rng = jax.random.PRNGKey(0)
    model, variables = create_model(args.model, pretrained=True, rng=rng)
    print(f'Created {args.model} model. Validating...')

    def _forward(images, labels):
        return eval_forward(model, variables, images, labels)

    eval_step = _forward if args.no_jit else jax.jit(_forward)

    # An existing file with a `.tar` extension is read through DatasetTar;
    # every other path goes through Dataset.
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data)
    else:
        dataset = Dataset(args.data)

    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=8,
        crop_pct=data_config['crop_pct'])

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, (images, labels) in enumerate(loader):
        # Move axis 1 (presumably channels) to the end before feeding the model.
        images = images.numpy().transpose(0, 2, 3, 1)
        labels = labels.numpy()
        top1_count, top5_count = eval_step(images, labels)
        # int() copies the counts to the host, which blocks until this batch
        # has actually been computed. Without it JAX's asynchronous dispatch
        # makes the timings below measure only dispatch, not device work.
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]
        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}] '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    if total_examples == 0:
        raise ValueError(f'No examples were evaluated from {args.data!r}.')

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(
        f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
def validate(args):
    """Evaluate a pretrained model and report top-1/top-5 accuracy.

    Args:
        args: Parsed command-line namespace. Attributes read here are
            ``model``, ``half_precision``, ``no_jit``, ``data`` and
            ``batch_size``.

    Returns:
        dict with ``top1`` and ``top5`` accuracy percentages as Python floats.

    Raises:
        ValueError: if the evaluation iterator yields no examples.
    """
    rng = jax.random.PRNGKey(0)
    platform = jax.local_devices()[0].platform
    if args.half_precision:
        # Half precision means bfloat16 on TPU and float16 everywhere else.
        model_dtype = jax.numpy.bfloat16 if platform == 'tpu' else jax.numpy.float16
    else:
        model_dtype = jax.numpy.float32
    model, variables = create_model(
        args.model, pretrained=True, dtype=model_dtype, rng=rng)
    print(f'Created {args.model} model. Validating...')

    def _forward(images, labels):
        return eval_forward(model.apply, variables, images, labels)

    eval_step = _forward if args.no_jit else jax.jit(_forward)

    image_size = model.default_cfg['input_size'][-1]
    # default_cfg mean/std are multiplied by 255 before reaching the input
    # pipeline (presumably it normalizes 0-255 pixel values -- confirm).
    eval_iter, num_batches = create_eval_iter(
        args.data,
        args.batch_size,
        image_size,
        half_precision=args.half_precision,
        mean=tuple(x * 255 for x in model.default_cfg['mean']),
        std=tuple(x * 255 for x in model.default_cfg['std']),
        interpolation=model.default_cfg['interpolation'],
    )

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, batch in enumerate(eval_iter):
        images, labels = batch['image'], batch['label']
        top1_count, top5_count = eval_step(images, labels)
        # int() pulls the counts to the host, blocking until the batch is done,
        # so the timings reflect device work rather than async dispatch.
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]
        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{num_batches}] '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    if total_examples == 0:
        raise ValueError(f'No examples were evaluated from {args.data!r}.')

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
          f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=acc_1, top5=acc_5)
def validate(args):
    """Evaluate a pretrained model on the test split and report top-1/top-5 accuracy.

    Args:
        args: Parsed command-line namespace. Attributes read here are
            ``model``, ``no_jit``, ``data`` (passed as ``tfds_data_dir``) and
            ``batch_size``.

    Returns:
        dict with ``top1`` and ``top5`` accuracy percentages as Python floats.

    Raises:
        ValueError: if the test dataset yields no examples.
    """
    rng = jax.random.PRNGKey(0)
    model, variables = create_model(args.model, pretrained=True, rng=rng)
    print(f'Created {args.model} model. Validating...')

    def _forward(images, labels):
        return eval_forward(model, variables, images, labels)

    eval_step = _forward if args.no_jit else jax.jit(_forward)

    image_size = model.default_cfg['input_size'][-1]
    # default_cfg mean/std are multiplied by 255 before reaching the loader
    # (presumably it normalizes 0-255 pixel values -- confirm).
    test_ds, num_batches = imagenet_data.load(
        imagenet_data.Split.TEST,
        is_training=False,
        image_size=image_size,
        batch_dims=[args.batch_size],
        mean=tuple(x * 255 for x in model.default_cfg['mean']),
        std=tuple(x * 255 for x in model.default_cfg['std']),
        tfds_data_dir=args.data)

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, batch in enumerate(test_ds):
        images, labels = batch['images'], batch['labels']
        top1_count, top5_count = eval_step(images, labels)
        # int() pulls the counts to the host. This (a) keeps the accumulators,
        # and therefore the returned accuracies, as plain Python numbers
        # instead of device arrays, and (b) blocks until the batch is computed
        # so the timings reflect device work rather than async dispatch.
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]
        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{num_batches}] '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    if total_examples == 0:
        raise ValueError(f'No examples were evaluated from {args.data!r}.')

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(
        f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=acc_1, top5=acc_5)
def _evaluate(p_eval_fn, state, eval_iter, num_steps):
    """Run ``num_steps`` batches from ``eval_iter`` through ``p_eval_fn``.

    Returns:
        Tuple ``(eval_metrics, summary)``. ``eval_metrics`` is the output of
        ``common_utils.get_metrics`` over the per-step metrics; ``summary`` maps
        each metric key to its mean.
    """
    eval_metrics = []
    for _ in range(num_steps):
        eval_batch = next(eval_iter)
        eval_metrics.append(p_eval_fn(state, eval_batch))
    eval_metrics = common_utils.get_metrics(eval_metrics)
    summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
    return eval_metrics, summary


def train_and_evaluate(config: ml_collections.ConfigDict, resume: str):
    """Execute model training and evaluation loop.

    Args:
        config: Hyperparameter configuration for training and evaluation.
        resume: Resume from checkpoints at specified dir if set
            (TODO: support specific checkpoint file/step).

    Raises:
        ValueError: if the train or eval batch size is not divisible by the
            number of devices.
    """
    rng = random.PRNGKey(42)

    if config.batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = config.batch_size // jax.host_count()

    config.eval_batch_size = config.eval_batch_size or config.batch_size
    if config.eval_batch_size % jax.device_count() > 0:
        raise ValueError(
            'Validation batch size must be divisible by the number of devices')
    local_eval_batch_size = config.eval_batch_size // jax.host_count()

    platform = jax.local_devices()[0].platform
    half_prec = config.half_precision
    if half_prec:
        # Half precision means bfloat16 on TPU and float16 everywhere else.
        model_dtype = jnp.bfloat16 if platform == 'tpu' else jnp.float16
    else:
        model_dtype = jnp.float32

    rng, model_create_rng = random.split(rng)
    model, variables = create_model(
        config.model,
        dtype=model_dtype,
        drop_rate=config.drop_rate,
        drop_path_rate=config.drop_path_rate,
        rng=model_create_rng)
    image_size = config.image_size or model.default_cfg['input_size'][-1]

    dataset_builder = tfds.builder(config.dataset, data_dir=config.data_dir)
    train_iter = create_input_iter(
        dataset_builder, local_batch_size, train=True, image_size=image_size,
        augment_name=config.autoaugment,
        randaug_magnitude=config.randaug_magnitude,
        randaug_num_layers=config.randaug_num_layers,
        half_precision=half_prec, cache=config.cache)
    eval_iter = create_input_iter(
        dataset_builder, local_eval_batch_size, train=False,
        image_size=image_size, half_precision=half_prec, cache=config.cache)

    steps_per_epoch = dataset_builder.info.splits[
        'train'].num_examples // config.batch_size

    if config.num_train_steps == -1:
        num_steps = steps_per_epoch * config.num_epochs
    else:
        num_steps = config.num_train_steps

    if config.steps_per_eval == -1:
        num_validation_examples = dataset_builder.info.splits[
            'validation'].num_examples
        steps_per_eval = num_validation_examples // config.eval_batch_size
    else:
        steps_per_eval = config.steps_per_eval

    steps_per_checkpoint = steps_per_epoch * 1

    # Linear LR scaling relative to a reference batch size of 256.
    base_lr = config.lr * config.batch_size / 256.
    lr_fn = create_lr_schedule_epochs(
        base_lr, config.lr_schedule,
        steps_per_epoch=steps_per_epoch,
        total_epochs=config.num_epochs,
        decay_rate=config.lr_decay_rate,
        decay_epochs=config.lr_decay_epochs,
        warmup_epochs=config.lr_warmup_epochs,
        min_lr=config.lr_minimum)

    state = create_train_state(config, variables, lr_fn)
    if resume:
        state = restore_checkpoint(state, resume)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = flax.jax_utils.replicate(state)

    p_train_step = jax.pmap(
        functools.partial(
            train_step, model.apply, lr_fn=lr_fn,
            label_smoothing=config.label_smoothing,
            weight_decay=config.weight_decay),
        axis_name='batch')
    p_eval_step = jax.pmap(
        functools.partial(eval_step, model.apply), axis_name='batch')
    p_eval_step_ema = None
    if config.ema_decay != 0.:
        p_eval_step_ema = jax.pmap(
            functools.partial(eval_step_ema, model.apply), axis_name='batch')

    # Only host 0 gets an output directory and a summary writer; other hosts
    # keep these as None and must not use them.
    output_dir = None
    summary_writer = None
    if jax.host_id() == 0:
        if resume and step_offset > 0:
            output_dir = resume
        else:
            output_base = config.output_base_dir if config.output_base_dir else './output'
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                config.model,
            ])
            output_dir = get_outdir(output_base, exp_name)
        summary_writer = tensorboard.SummaryWriter(output_dir)
        summary_writer.hparams(dict(config))

    epoch_metrics = []
    t_loop_start = time.time()
    num_samples = 0
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        step_p1 = step + 1
        rng, step_rng = random.split(rng)
        sharded_rng = common_utils.shard_prng_key(step_rng)

        num_samples += config.batch_size
        state, metrics = p_train_step(state, batch, dropout_rng=sharded_rng)
        epoch_metrics.append(metrics)

        if step_p1 % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            samples_per_sec = num_samples / (time.time() - t_loop_start)
            logging.info(
                'train epoch: %d, loss: %.4f, img/sec %.2f, top1: %.2f, top5: %.3f',
                epoch, summary['loss'], samples_per_sec, summary['top1'],
                summary['top5'])

            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val, step_p1 - len(vals) + i)
                summary_writer.scalar('samples per second', samples_per_sec, step)

            epoch_metrics = []
            state = sync_batch_stats(state)  # sync batch statistics across replicas

            eval_metrics, summary = _evaluate(
                p_eval_step, state, eval_iter, steps_per_eval)
            logging.info('eval epoch: %d, loss: %.4f, top1: %.2f, top5: %.3f',
                         epoch, summary['loss'], summary['top1'], summary['top5'])

            if p_eval_step_ema is not None:
                # NOTE running both ema and non-ema eval while improving this script
                eval_metrics, summary = _evaluate(
                    p_eval_step_ema, state, eval_iter, steps_per_eval)
                logging.info(
                    'eval epoch ema: %d, loss: %.4f, top1: %.2f, top5: %.3f',
                    epoch, summary['loss'], summary['top1'], summary['top5'])

            if jax.host_id() == 0:
                # NOTE: when EMA eval runs, its metrics have replaced the
                # non-EMA ones, so only the EMA results are written here.
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()

            t_loop_start = time.time()
            num_samples = 0

        elif step_p1 % 100 == 0:
            summary = jax.tree_map(
                lambda x: x.mean(), common_utils.get_metrics(epoch_metrics))
            samples_per_sec = num_samples / (time.time() - t_loop_start)
            logging.info('train steps: %d, loss: %.4f, img/sec: %.2f',
                         step_p1, summary['loss'], samples_per_sec)

        if step_p1 % steps_per_checkpoint == 0 or step_p1 == num_steps:
            # Keep the batch-stat sync on every host (presumably a
            # cross-replica operation that all hosts must join -- confirm).
            state = sync_batch_stats(state)
            # output_dir only exists on host 0. Passing it from other hosts
            # raised NameError in multi-host runs.
            if jax.host_id() == 0:
                save_checkpoint(state, output_dir)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()