예제 #1
0
def validate(args):
    """Evaluate a pretrained model on a validation dataset.

    Args:
        args: Options object. This function reads ``args.model``,
            ``args.no_jit``, ``args.data`` and ``args.batch_size``;
            ``vars(args)`` is also handed to ``resolve_data_config``.

    Returns:
        A dict with keys ``top1`` and ``top5`` holding the accuracy
        percentages as Python floats.
    """
    rng = jax.random.PRNGKey(0)
    model, variables = create_model(args.model, pretrained=True, rng=rng)
    print(f'Created {args.model} model. Validating...')

    def eval_step(images, labels):
        return eval_forward(model, variables, images, labels)

    if not args.no_jit:
        eval_step = jax.jit(eval_step)

    # A path ending in '.tar' that names an existing file is read as a tar
    # archive; anything else goes through the plain Dataset class.
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data)
    else:
        dataset = Dataset(args.data)

    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=False,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=8,
                           crop_pct=data_config['crop_pct'])

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, (images, labels) in enumerate(loader):
        # Move axis 1 to the last position before handing the batch to the
        # model (presumably NCHW -> NHWC; confirm against the loader).
        images = images.numpy().transpose(0, 2, 3, 1)
        labels = labels.numpy()

        top1_count, top5_count = eval_step(images, labels)
        # int() pulls the counts back to the host. JAX dispatches work
        # asynchronously, so without this the timing below could be taken
        # before the batch has actually been computed. Assumes eval_forward
        # returns scalar counts.
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]

        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}]  '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(
        f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
def validate(args):
    """Run evaluation of a pretrained model and return top-1/top-5 accuracy.

    Args:
        args: Options object. This function reads ``args.model``,
            ``args.half_precision``, ``args.no_jit``, ``args.data`` and
            ``args.batch_size``.

    Returns:
        A dict with keys ``top1`` and ``top5`` holding the accuracy
        percentages.
    """
    rng = jax.random.PRNGKey(0)
    platform = jax.local_devices()[0].platform

    # Half precision uses bfloat16 on TPU and float16 on any other platform.
    if not args.half_precision:
        model_dtype = jax.numpy.float32
    elif platform == 'tpu':
        model_dtype = jax.numpy.bfloat16
    else:
        model_dtype = jax.numpy.float16

    model, variables = create_model(args.model, pretrained=True, dtype=model_dtype, rng=rng)
    print(f'Created {args.model} model. Validating...')

    def eval_step(images, labels):
        return eval_forward(model.apply, variables, images, labels)

    if not args.no_jit:
        eval_step = jax.jit(eval_step)

    default_cfg = model.default_cfg
    image_size = default_cfg['input_size'][-1]

    eval_iter, num_batches = create_eval_iter(
        args.data, args.batch_size, image_size,
        half_precision=args.half_precision,
        # The configured mean/std are multiplied by 255 before being passed
        # on (presumably the pipeline works on 0-255 pixel values; confirm).
        mean=tuple(x * 255 for x in default_cfg['mean']),
        std=tuple(x * 255 for x in default_cfg['std']),
        interpolation=default_cfg['interpolation'],
    )

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, batch in enumerate(eval_iter):
        images, labels = batch['image'], batch['label']
        top1_count, top5_count = eval_step(images, labels)
        # Converting to int keeps the running totals as plain Python ints.
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]

        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{num_batches}]  '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
          f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=acc_1, top5=acc_5)
예제 #3
0
def validate(args):
    """Run evaluation of a pretrained model and return top-1/top-5 accuracy.

    Args:
        args: Options object. This function reads ``args.model``,
            ``args.no_jit``, ``args.batch_size`` and ``args.data`` (passed to
            the data loader as ``tfds_data_dir``).

    Returns:
        A dict with keys ``top1`` and ``top5`` holding the accuracy
        percentages as Python floats.
    """
    rng = jax.random.PRNGKey(0)

    model, variables = create_model(args.model, pretrained=True, rng=rng)
    print(f'Created {args.model} model. Validating...')

    def eval_step(images, labels):
        return eval_forward(model, variables, images, labels)

    if not args.no_jit:
        eval_step = jax.jit(eval_step)

    default_cfg = model.default_cfg
    image_size = default_cfg['input_size'][-1]
    test_ds, num_batches = imagenet_data.load(
        imagenet_data.Split.TEST,
        is_training=False,
        image_size=image_size,
        batch_dims=[args.batch_size],
        # The configured mean/std are multiplied by 255 before being passed
        # on (presumably the pipeline works on 0-255 pixel values; confirm).
        mean=tuple(x * 255 for x in default_cfg['mean']),
        std=tuple(x * 255 for x in default_cfg['std']),
        tfds_data_dir=args.data)

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, batch in enumerate(test_ds):
        images, labels = batch['images'], batch['labels']
        top1_count, top5_count = eval_step(images, labels)
        # int() pulls the counts back to the host so the running totals (and
        # the returned accuracies) are plain Python numbers, and so the
        # timing below is taken after the batch has been computed rather
        # than just dispatched. Assumes eval_forward returns scalar counts.
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]

        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{num_batches}]  '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(
        f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=acc_1, top5=acc_5)
예제 #4
0
def train_and_evaluate(config: ml_collections.ConfigDict, resume: str):
    """Execute model training and evaluation loop.

    Trains for the configured number of steps. At every epoch boundary it
    logs the epoch's training metrics, runs evaluation (and a second
    evaluation pass with the EMA step when ``config.ema_decay`` is non-zero),
    and writes summaries on host 0. A checkpoint is saved at every epoch
    boundary and at the final step.

    Args:
      config: Hyperparameter configuration for training and evaluation.
        ``config.eval_batch_size`` is overwritten with ``config.batch_size``
        when it is falsy.
      resume: Resume from checkpoints at specified dir if set (TODO: support
        specific checkpoint file/step)

    Raises:
      ValueError: If the train or eval batch size is not divisible by the
        number of devices.
    """
    rng = random.PRNGKey(42)

    if config.batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = config.batch_size // jax.host_count()
    config.eval_batch_size = config.eval_batch_size or config.batch_size
    if config.eval_batch_size % jax.device_count() > 0:
        raise ValueError(
            'Validation batch size must be divisible by the number of devices')
    local_eval_batch_size = config.eval_batch_size // jax.host_count()

    # Half precision uses bfloat16 on TPU and float16 on any other platform.
    platform = jax.local_devices()[0].platform
    half_prec = config.half_precision
    if half_prec:
        if platform == 'tpu':
            model_dtype = jnp.bfloat16
        else:
            model_dtype = jnp.float16
    else:
        model_dtype = jnp.float32

    rng, model_create_rng = random.split(rng)
    model, variables = create_model(config.model,
                                    dtype=model_dtype,
                                    drop_rate=config.drop_rate,
                                    drop_path_rate=config.drop_path_rate,
                                    rng=model_create_rng)
    image_size = config.image_size or model.default_cfg['input_size'][-1]

    dataset_builder = tfds.builder(config.dataset, data_dir=config.data_dir)

    train_iter = create_input_iter(
        dataset_builder,
        local_batch_size,
        train=True,
        image_size=image_size,
        augment_name=config.autoaugment,
        randaug_magnitude=config.randaug_magnitude,
        randaug_num_layers=config.randaug_num_layers,
        half_precision=half_prec,
        cache=config.cache)

    eval_iter = create_input_iter(dataset_builder,
                                  local_eval_batch_size,
                                  train=False,
                                  image_size=image_size,
                                  half_precision=half_prec,
                                  cache=config.cache)

    steps_per_epoch = dataset_builder.info.splits[
        'train'].num_examples // config.batch_size

    # -1 means "derive from the dataset size" for both step counts below.
    if config.num_train_steps == -1:
        num_steps = steps_per_epoch * config.num_epochs
    else:
        num_steps = config.num_train_steps

    if config.steps_per_eval == -1:
        num_validation_examples = dataset_builder.info.splits[
            'validation'].num_examples
        steps_per_eval = num_validation_examples // config.eval_batch_size
    else:
        steps_per_eval = config.steps_per_eval

    steps_per_checkpoint = steps_per_epoch * 1

    # The configured LR is scaled by (global batch size / 256).
    base_lr = config.lr * config.batch_size / 256.
    lr_fn = create_lr_schedule_epochs(base_lr,
                                      config.lr_schedule,
                                      steps_per_epoch=steps_per_epoch,
                                      total_epochs=config.num_epochs,
                                      decay_rate=config.lr_decay_rate,
                                      decay_epochs=config.lr_decay_epochs,
                                      warmup_epochs=config.lr_warmup_epochs,
                                      min_lr=config.lr_minimum)

    state = create_train_state(config, variables, lr_fn)
    if resume:
        state = restore_checkpoint(state, resume)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = flax.jax_utils.replicate(state)

    p_train_step = jax.pmap(functools.partial(
        train_step,
        model.apply,
        lr_fn=lr_fn,
        label_smoothing=config.label_smoothing,
        weight_decay=config.weight_decay),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step, model.apply),
                           axis_name='batch')
    p_eval_step_ema = None
    if config.ema_decay != 0.:
        p_eval_step_ema = jax.pmap(functools.partial(eval_step_ema,
                                                     model.apply),
                                   axis_name='batch')

    # Only host 0 owns an output directory and a summary writer. Both names
    # are bound up front so that later references on other hosts do not raise
    # NameError.
    is_lead_host = jax.host_id() == 0
    output_dir = None
    summary_writer = None
    if is_lead_host:
        if resume and step_offset > 0:
            output_dir = resume
        else:
            output_base = config.output_base_dir if config.output_base_dir else './output'
            exp_name = '-'.join(
                [datetime.now().strftime("%Y%m%d-%H%M%S"), config.model])
            output_dir = get_outdir(output_base, exp_name)
        summary_writer = tensorboard.SummaryWriter(output_dir)
        summary_writer.hparams(dict(config))

    epoch_metrics = []
    t_loop_start = time.time()
    num_samples = 0
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        step_p1 = step + 1
        rng, step_rng = random.split(rng)
        sharded_rng = common_utils.shard_prng_key(step_rng)

        num_samples += config.batch_size
        state, metrics = p_train_step(state, batch, dropout_rng=sharded_rng)
        epoch_metrics.append(metrics)

        if step_p1 % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            samples_per_sec = num_samples / (time.time() - t_loop_start)
            logging.info(
                'train epoch: %d, loss: %.4f, img/sec %.2f, top1: %.2f, top5: %.3f',
                epoch, summary['loss'], samples_per_sec, summary['top1'],
                summary['top5'])

            if is_lead_host:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    # Write one scalar per collected step, back-dated so the
                    # last value lands on the current step.
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step_p1 - len(vals) + i)
                summary_writer.scalar('samples per second', samples_per_sec,
                                      step)
            epoch_metrics = []
            state = sync_batch_stats(
                state)  # sync batch statistics across replicas

            eval_metrics = []
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)

            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, top1: %.2f, top5: %.3f',
                         epoch, summary['loss'], summary['top1'],
                         summary['top5'])

            if p_eval_step_ema is not None:
                # NOTE running both ema and non-ema eval while improving this script
                eval_metrics = []
                for _ in range(steps_per_eval):
                    eval_batch = next(eval_iter)
                    metrics = p_eval_step_ema(state, eval_batch)
                    eval_metrics.append(metrics)

                eval_metrics = common_utils.get_metrics(eval_metrics)
                summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
                logging.info(
                    'eval epoch ema: %d, loss: %.4f, top1: %.2f, top5: %.3f',
                    epoch, summary['loss'], summary['top1'], summary['top5'])

            if is_lead_host:
                # NOTE(review): when the EMA pass ran, eval_metrics was
                # reassigned to the EMA results, so those are what get
                # written under the 'eval_*' tags. Confirm this is intended.
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
            # Restart the throughput window for the next epoch.
            t_loop_start = time.time()
            num_samples = 0

        elif step_p1 % 100 == 0:
            summary = jax.tree_map(lambda x: x.mean(),
                                   common_utils.get_metrics(epoch_metrics))
            samples_per_sec = num_samples / (time.time() - t_loop_start)
            logging.info('train steps: %d, loss: %.4f, img/sec: %.2f', step_p1,
                         summary['loss'], samples_per_sec)

        if step_p1 % steps_per_checkpoint == 0 or step_p1 == num_steps:
            # The batch-stat sync stays outside the host check, as before
            # (presumably a cross-replica operation every host must join).
            state = sync_batch_stats(state)
            # output_dir exists only on host 0, so only that host saves.
            if is_lead_host:
                save_checkpoint(state, output_dir)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()