Example 1
def tf_distribute_config(base_tf_server_port: int):
    """
    Generates TensorFlow cluster information and stores it in the TF_CONFIG
    environment variable. TF_CONFIG is left untouched if it was already set
    externally.
    """
    hls_addresses = str(os.environ.get(
        'MULTI_HLS_IPS', '127.0.0.1')).split(',')
    rank = comm_rank()
    size = comm_size()

    worker_hosts = ",".join([",".join([address + ':' + str(base_tf_server_port + r)
                                       for r in range(size//len(hls_addresses))])
                            for address in hls_addresses])

    configure_cluster(worker_hosts, rank)
    print(os.environ['TF_CONFIG'])
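To see what the nested join produces, the sketch below rebuilds worker_hosts with assumed values (two hosts, eight ranks total) and no MPI or Habana dependencies; the addresses are made up for illustration.

# Standalone sketch of the worker_hosts construction above (assumed values).
hls_addresses = ['10.1.0.1', '10.1.0.2']   # hypothetical MULTI_HLS_IPS
base_tf_server_port = 7850                 # the script's default port
size = 8                                   # hypothetical total rank count

worker_hosts = ",".join([",".join([address + ':' + str(base_tf_server_port + r)
                                   for r in range(size // len(hls_addresses))])
                        for address in hls_addresses])
print(worker_hosts)
# 10.1.0.1:7850,10.1.0.1:7851,10.1.0.1:7852,10.1.0.1:7853,
# 10.1.0.2:7850,10.1.0.2:7851,10.1.0.2:7852,10.1.0.2:7853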
Example 2
def load_pretrained(size, weights, pretrained_top, model):
    """Load model weights for a known configuration."""
    fname = f"ViT-{size}_{weights}.npz"
    origin = f"{BASE_URL}/{fname}"
    from config import config
    # If more than one process exists, rank 0 downloads the weights and the
    # other ranks reuse the same cached file.
    if comm_size() > 1:
        from mpi4py import MPI
        if comm_rank() == 0:
            tf.keras.utils.get_file(fname,
                                    origin,
                                    cache_subdir="weights",
                                    cache_dir=config.WEIGHTS_DIR)
        MPI.COMM_WORLD.Barrier()
    local_filepath = tf.keras.utils.get_file(fname,
                                             origin,
                                             cache_subdir="weights",
                                             cache_dir=config.WEIGHTS_DIR)
    utils.load_weights_numpy(model, local_filepath, pretrained_top)
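The rank-0-downloads-then-barrier pattern generalizes to any resource on a shared filesystem. Below is a minimal sketch assuming mpi4py is available; download_fn is a hypothetical callable that caches its result on disk, as tf.keras.utils.get_file does.

# Minimal sketch of the download-once pattern (assumes mpi4py and a shared
# filesystem; download_fn is a hypothetical cached download callable).
from mpi4py import MPI

def fetch_once(download_fn):
    comm = MPI.COMM_WORLD
    if comm.Get_rank() == 0:
        download_fn()         # only rank 0 touches the network
    comm.Barrier()            # remaining ranks wait until the file exists
    return download_fn()      # cache hit on every rank, including rank 0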
Example 3
def get_dataset(tfrecords_dir, subset, batch_size, is_training, distributed):
    """Read TFRecords files and turn them into a TFRecordDataset.

    Args:
        tfrecords_dir: dataset directory
        subset: pattern to detect subset in dataset directory
        batch_size: Global batch size
        is_training (bool): use True if dataset will be used for training
        distributed (bool): use True if used in distributed environment

    Returns:
        TFRecordDataset: Dataset.
    """
    filenames = tf.io.matching_files(
        os.path.join(tfrecords_dir, '%s-*' % subset))
    ds = tf.data.Dataset.from_tensor_slices(filenames)

    # Sharding should be used only for training and in distributed environments.
    if distributed and is_training:
        from habana_frameworks.tensorflow.multinode_helpers import comm_rank, comm_size
        ds = ds.shard(comm_size(), comm_rank())

    if is_training:
        num_files = tf.cast(tf.shape(input=filenames)[0], tf.int64)
        ds = ds.shuffle(buffer_size=num_files)

    ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=10)

    if is_training:
        ds = ds.shuffle(buffer_size=10000)
    ds = ds.repeat()

    parser = partial(_parse_fn, is_training=is_training)
    ds = ds.map(map_func=parser,
                num_parallel_calls=config.NUM_DATA_WORKERS, deterministic=False)
    ds = ds.batch(batch_size=batch_size, drop_remainder=True)

    # Sharding is already done, so disable autosharding.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

    return ds.with_options(options)
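The shard call is what splits the file list across workers: each worker keeps every comm_size()-th file starting at its own rank. A self-contained illustration with made-up values:

# Standalone illustration of tf.data sharding (no Habana imports needed).
import tensorflow as tf

files = tf.data.Dataset.from_tensor_slices(['f0', 'f1', 'f2', 'f3', 'f4', 'f5'])
num_workers, worker_rank = 2, 0             # stand-ins for comm_size()/comm_rank()
shard = files.shard(num_workers, worker_rank)
print([f.numpy().decode() for f in shard])  # ['f0', 'f2', 'f4']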
Example 4
def main():
    parser = DenseNetArgumentParser(
        description=(
            "train.py is the main training/evaluation script for DenseNet. "
            "In order to run training on multiple Gaudi cards, use demo_densenet.py or run "
            "train.py with mpirun."))
    args, _ = parser.parse_known_args()

    strategy = None
    verbose = 1

    os.environ['ENABLE_EXPERIMENTAL_FLAGS'] = 'true'
    os.environ['RUN_TPC_FUSER'] = '******'

    if args.deterministic:
        if args.inputs is None:
            raise ValueError("Must provide inputs for deterministic mode")
        if args.resume_from_checkpoint_path is None:
            raise ValueError("Must provide checkpoint for deterministic mode")

    if args.dtype == 'bf16':
        os.environ['TF_BF16_CONVERSION'] = '1'

    if args.run_on_hpu:
        load_habana_module()
        if args.use_hpu_strategy:
            hls_addresses = str(os.environ.get(
                "MULTI_HLS_IPS", "127.0.0.1")).split(",")
            TF_BASE_PORT = 2410
            mpi_rank = comm_rank()
            mpi_size = comm_size()
            if mpi_rank > 0:
                verbose = 0
            worker_hosts = ""
            for address in hls_addresses:
                # worker_hosts: comma-separated list of worker ip:port pairs.
                worker_hosts = worker_hosts + ",".join(
                    [address + ':' + str(TF_BASE_PORT + rank)
                     for rank in range(mpi_size//len(hls_addresses))])
            task_index = mpi_rank

            # Configures cluster spec for distribution strategy.
            _ = distribution_utils.configure_cluster(worker_hosts, task_index)
            strategy = HPUStrategy()
            print('Number of devices: {}'.format(
                strategy.num_replicas_in_sync))
    else:
        strategy = tf.distribute.MultiWorkerMirroredStrategy()
        print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    if args.seed is not None:
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        random.seed(args.seed)
        np.random.seed(args.seed)
        tf.random.set_seed(args.seed)

    img_rows, img_cols = 224, 224  # Resolution of inputs
    channel = 3
    num_classes = 1000
    batch_size = args.batch_size
    nb_epoch = args.epochs
    dataset_dir = args.dataset_dir
    resume_from_checkpoint_path = args.resume_from_checkpoint_path
    resume_from_epoch = args.resume_from_epoch
    dropout_rate = args.dropout_rate
    weight_decay = args.weight_decay
    optim_name = args.optimizer
    initial_lr = args.initial_lr
    model_name = args.model
    save_summary_steps = args.save_summary_steps

    if model_name == "densenet121":
        growth_rate = 32
        nb_filter = 64
        nb_layers = [6, 12, 24, 16]

    elif model_name == "densenet161":
        growth_rate = 48
        nb_filter = 96
        nb_layers = [6, 12, 36, 24]

    elif model_name == "densenet169":
        growth_rate = 32
        nb_filter = 64
        nb_layers = [6, 12, 32, 32]

    else:
        print("model is not supported")
        exit(1)

    # Load our model
    if strategy:
        with strategy.scope():
            model = densenet_model(img_rows=img_rows, img_cols=img_cols, color_type=channel,
                                   dropout_rate=dropout_rate, weight_decay=weight_decay, num_classes=num_classes,
                                   growth_rate=growth_rate, nb_filter=nb_filter, nb_layers=nb_layers)
            optimizer = get_optimizer(
                model_name, optim_name, initial_lr, epsilon=1e-2)
            model.compile(optimizer=optimizer,
                          loss='categorical_crossentropy', metrics=['accuracy'])
    else:
        model = densenet_model(img_rows=img_rows, img_cols=img_cols, color_type=channel,
                               dropout_rate=dropout_rate, weight_decay=weight_decay, num_classes=num_classes,
                               growth_rate=growth_rate, nb_filter=nb_filter, nb_layers=nb_layers)
        optimizer = get_optimizer(
            model_name, optim_name, initial_lr, epsilon=1e-2)
        model.compile(optimizer=optimizer,
                      loss='categorical_crossentropy', metrics=['accuracy'])

    # Start training
    steps_per_epoch = 1281167 // batch_size
    if args.steps_per_epoch is not None:
        steps_per_epoch = args.steps_per_epoch
    validation_steps = 50000 // batch_size
    if args.validation_steps is not None:
        validation_steps = args.validation_steps
    warmup_steps = args.warmup_epochs * steps_per_epoch
    lr_sched = {0: 1, 30: 0.1, 60: 0.01, 80: 0.001}
    lr_sched_steps = {
        epoch * steps_per_epoch: multiplier for (epoch, multiplier) in lr_sched.items()}

    lrate = StepLearningRateScheduleWithWarmup(initial_lr=initial_lr,
                                               initial_global_step=0,
                                               warmup_steps=warmup_steps,
                                               decay_schedule=lr_sched_steps,
                                               verbose=0)

    save_name = model_name if not model_name.endswith('.h5') else \
        os.path.split(model_name)[-1].split('.')[0].split('-')[0]

    model_ckpt = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(args.model_dir, config.SAVE_DIR,
                     save_name) + '-ckpt-{epoch:03d}.h5',
        monitor='train_loss')

    callbacks = [lrate, model_ckpt]

    if save_summary_steps is not None and save_summary_steps > 0:
        log_dir = os.path.join(args.model_dir, config.LOG_DIR)
        local_batch_size = batch_size
        
        if args.use_hpu_strategy:
            log_dir = os.path.join(log_dir, 'worker_' + str(comm_rank()))
            local_batch_size = batch_size // strategy.num_replicas_in_sync

        callbacks += [
            TensorBoardWithHParamsV2(
                args.__dict__, log_dir=log_dir,
                update_freq=save_summary_steps, profile_batch=0),
            ExamplesPerSecondKerasHookV2(
                save_summary_steps, output_dir=log_dir,
                batch_size=local_batch_size),
        ]

    if args.evaluate_checkpoint_path is not None:
        # Build the validation dataset before standalone evaluation; ds_valid
        # is otherwise only created in the training branch below.
        ds_valid = get_dataset(dataset_dir, args.val_subset, batch_size)
        model.load_weights(args.evaluate_checkpoint_path)
        results = model.evaluate(x=ds_valid, steps=validation_steps)
        print("Test loss, Test acc:", results)
        exit()

    if resume_from_epoch is not None and resume_from_checkpoint_path is not None:
        model.load_weights(resume_from_checkpoint_path)

    if args.deterministic:
        set_deterministic()
        if not os.path.isfile(args.dump_config):
            raise FileNotFoundError("wrong dump config path")

        import pickle
        x_path = os.path.join(args.inputs, "input")
        y_path = os.path.join(args.inputs, "target")
        x = pickle.load(open(x_path, 'rb'))
        y = pickle.load(open(y_path, 'rb'))

        with dump_callback(args.dump_config):
            model.fit(x=x, y=y,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      initial_epoch=resume_from_epoch,
                      epochs=nb_epoch,
                      shuffle=False,
                      verbose=verbose,
                      validation_data=None,
                      validation_steps=0,
                      )
    else:
        ds_train = get_dataset(dataset_dir, args.train_subset, batch_size)
        ds_valid = get_dataset(dataset_dir, args.val_subset, batch_size)

        model.fit(x=ds_train, y=None,
                  steps_per_epoch=steps_per_epoch,
                  callbacks=callbacks,
                  initial_epoch=resume_from_epoch,
                  epochs=nb_epoch,
                  shuffle=True,
                  verbose=verbose,
                  validation_data=(ds_valid, None),
                  validation_steps=validation_steps,
                  validation_freq=1,
                  )
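One detail of the example above worth spelling out: the epoch-keyed lr_sched dict is rescaled to global steps before it reaches StepLearningRateScheduleWithWarmup. A pure-Python sketch of that mapping, assuming a global batch size of 256:

# Epoch-keyed multipliers rescaled to step-keyed boundaries (assumed batch size).
steps_per_epoch = 1281167 // 256   # ImageNet train size over a hypothetical batch
lr_sched = {0: 1, 30: 0.1, 60: 0.01, 80: 0.001}
lr_sched_steps = {epoch * steps_per_epoch: mult
                  for epoch, mult in lr_sched.items()}
print(lr_sched_steps)              # {0: 1, 150120: 0.1, 300240: 0.01, 400320: 0.001}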
Example 5
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Dictionary of training and eval stats.
    """

    keras_utils.set_session_config(
        enable_eager=flags_obj.enable_eager,
        enable_xla=flags_obj.enable_xla,
        enable_scoped_allocator=flags_obj.enable_scoped_allocator)
    # Enable Habana bf16 conversion pass only if native Keras mixed precision is disabled.
    if flags.FLAGS.dtype == 'bf16' and not flags.FLAGS.use_keras_mixed_precision:
        performance.set_mixed_precision_policy(tf.float32)
        os.environ['TF_BF16_CONVERSION'] = flags.FLAGS.bf16_config_path
    else:
        performance.set_mixed_precision_policy(
            flags_core.get_tf_dtype(flags_obj))

    os.environ.setdefault("TF_DISABLE_MKL", "1")
    os.environ.setdefault("TF_ALLOW_CONTROL_EDGES_IN_HABANA_OPS", "1")

    # This only affects GPU.
    common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)
    batch_size = adjust_batch_size(flags_obj.batch_size)

    if horovod_enabled():
        model_dir = os.path.join(flags_obj.model_dir,
                                 "worker_" + str(hvd.rank()))
    else:
        model_dir = flags_obj.model_dir

    hls_addresses = str(os.environ.get("MULTI_HLS_IPS",
                                       "127.0.0.1")).split(",")
    TF_BASE_PORT = 2410
    mpi_rank = comm_rank()
    mpi_size = comm_size()

    worker_hosts = ",".join([
        ",".join([
            address + ':' + str(TF_BASE_PORT + rank)
            for rank in range(mpi_size // len(hls_addresses))
        ]) for address in hls_addresses
    ])
    task_index = mpi_rank

    # Configures cluster spec for distribution strategy.
    _ = distribution_utils.configure_cluster(worker_hosts, task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    train_writer, eval_writer = None, None
    if flags_obj.enable_tensorboard:
        train_writer = tf.summary.create_file_writer(model_dir)
        eval_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'eval'))
        hparams = flags_obj.flag_values_dict()
        hparams.setdefault('precision', flags_obj.dtype)
        write_hparams_v2(train_writer, hparams)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
    train_steps = train_epochs * per_epoch_steps

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        train_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        batch_size,
        flags_obj.log_steps,
        summary_writer=train_writer,
        batch_size_per_node=flags_obj.batch_size)
    profiler_callback = None
    if flags_obj.profile_steps is not None:
        profiler_callback = keras_utils.get_profiler_callback(
            model_dir, flags_obj.profile_steps, flags_obj.enable_tensorboard,
            per_epoch_steps)
    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                                  train_steps, per_epoch_steps,
                                                  profiler_callback)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
    checkpoint_interval = (per_epoch_steps
                           if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = flags_obj.log_steps if flags_obj.enable_tensorboard else None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    train_steps = per_epoch_steps * train_epochs

    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        runnable.evaluate,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=train_steps,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval,
        train_summary_writer=train_writer,
        eval_summary_writer=eval_writer)

    time_callback.on_train_begin()
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    time_callback.on_train_end()

    stats = build_stats(runnable, time_callback)
    return stats
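The CheckpointManager above saves at most once per checkpoint_interval steps, with step_counter deciding when an interval has elapsed. A minimal standalone sketch of that tf.train API; the directory and interval are made up:

# Minimal sketch of interval-based checkpointing with tf.train.CheckpointManager.
import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt, directory='/tmp/ckpt_sketch', max_to_keep=10,
    step_counter=step, checkpoint_interval=100)   # hypothetical interval

for _ in range(250):
    step.assign_add(1)
    # With check_interval=True, save() is skipped until checkpoint_interval
    # steps have passed since the previous checkpoint.
    manager.save(check_interval=True)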
Example 6
def main(_):
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    params = train_utils.parse_configuration(FLAGS)

    if params.runtime.num_hpus > 0:
        import os
        #TODO: remove when SW-49334 is fixed [SW-49404]
        os.environ["TF_DISABLE_EAGER_TO_FUNC_REWRITER"] = "1"
        from habana_frameworks.tensorflow import load_habana_module
        load_habana_module()

    if params.task.train_data.deterministic or params.task.validation_data.deterministic:
        import os
        os.environ['PYTHONHASHSEED'] = '0'
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        import numpy
        numpy.random.seed(0)
        import tensorflow as tf
        tf.random.set_seed(0)
        tf.compat.v1.set_random_seed(0)
        import random
        random.seed(0)

    if FLAGS.dtype == "bf16":
        print("Using bf16 config list {}".format(FLAGS.bf16_config_path))
        os.environ['TF_BF16_CONVERSION'] = FLAGS.bf16_config_path

    hls_addresses = str(os.environ.get("MULTI_HLS_IPS",
                                       "127.0.0.1")).split(",")
    TF_BASE_PORT = 2410
    mpi_rank = comm_rank()
    mpi_size = comm_size()

    if params.runtime.num_hpus > 1:
        model_dir = os.path.join(FLAGS.model_dir, "worker_" + str(mpi_rank))
    else:
        model_dir = FLAGS.model_dir

    # Prepare a comma-separated list of device addresses.
    worker_list = []
    for address in hls_addresses:
        for rank in range(mpi_size // len(hls_addresses)):
            worker_list.append(address + ':' + str(TF_BASE_PORT + rank))
    worker_hosts = ",".join(worker_list)
    task_index = mpi_rank

    # Configures cluster spec for distribution strategy.
    distribution_utils.configure_cluster(worker_hosts, task_index)
    if 'train' in FLAGS.mode:
        # Pure eval modes do not output yaml files; otherwise a continuous eval
        # job may race against the train job to write the same file.
        train_utils.serialize_config(params, model_dir)

    # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
    # can significantly speed up models by using float16 on GPUs and bfloat16
    # on TPUs. loss_scale takes effect only when dtype is float16.
    if params.runtime.mixed_precision_dtype:
        performance.set_mixed_precision_policy(
            params.runtime.mixed_precision_dtype)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        num_hpus=params.runtime.num_hpus,
        tpu_address=params.runtime.tpu)

    with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

    train_lib.run_experiment(distribution_strategy=distribution_strategy,
                             task=task,
                             mode=FLAGS.mode,
                             params=params,
                             model_dir=model_dir)

    train_utils.save_gin_config(FLAGS.mode, model_dir)
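configure_cluster receives the worker_hosts string and task index built above. The helper itself is project-specific, but its usual job is to populate TF_CONFIG, the JSON that tf.distribute strategies read at construction time; the layout below is an assumption based on that convention, with illustrative addresses.

# Assumed shape of the TF_CONFIG that configure_cluster produces.
import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': ['127.0.0.1:2410', '127.0.0.1:2411']},
    'task': {'type': 'worker', 'index': 0},
})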
Example 7
def main():
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument('--dataset', '--dataset_dir', metavar='PATH',
                        default=config.DEFAULT_DATASET_DIR, help='Dataset directory.')
    parser.add_argument('--optimizer', default='sgd',
                        choices=['sgd', 'adam', 'rmsprop'], help='Optimizer.')
    parser.add_argument('-d', '--dtype', default='fp32',
                        choices=['fp32', 'bf16'], help='Data type.')
    parser.add_argument('--batch_size', type=int,
                        default=32, help='Global batch size.')
    parser.add_argument('--lr_sched', default='WarmupCosine', choices=[
                        'linear', 'exp', 'steps', 'constant', 'WarmupCosine'], help='Learning rate scheduler.')
    parser.add_argument('--initial_lr', type=float,
                        default=6e-2, help='Initial learning rate.')
    parser.add_argument('--final_lr', type=float,
                        default=1e-5, help='Final learning rate.')
    parser.add_argument('--warmup_steps', type=int,
                        default=4000, help='Warmup steps.')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Total number of epochs for training.')
    parser.add_argument('--steps_per_epoch', type=int,
                        help='Number of steps for training per epoch, overrides default value.')
    parser.add_argument('--validation_steps', type=int,
                        help='Number of steps for validation, overrides default value.')
    parser.add_argument('--model', default='ViT-B_16',
                        choices=['ViT-B_16', 'ViT-L_16', 'ViT-B_32', 'ViT-L_32'], help='Model.')
    parser.add_argument('--train_subset', default='train',
                        help='Pattern to detect train subset in dataset directory.')
    parser.add_argument('--val_subset', default='validation',
                        help='Pattern to detect validation subset in dataset directory.')
    parser.add_argument('--grad_accum_steps', type=int,
                        default=8, help='Gradient accumulation steps.')
    parser.add_argument('--resume_from_checkpoint_path',
                        metavar='PATH', help='Path to checkpoint to start from.')
    parser.add_argument('--resume_from_epoch', metavar='EPOCH_INDEX',
                        type=int, default=0, help='Initial epoch index.')
    parser.add_argument('--evaluate_checkpoint_path', metavar='PATH',
                        help='Checkpoint path for evaluating the model on --val_subset')
    parser.add_argument('--weights_path', metavar='PATH',
                        help='Path to weights cache directory. ~/.keras is used if not set.')
    parser.add_argument('--deterministic', action='store_true', default=False,
                        help='Enable deterministic behavior; this also disables data augmentation. --seed must be set.')
    parser.add_argument('--seed', type=int,
                        help='Seed to be used by random functions.')
    parser.add_argument('--device', default='HPU',
                        choices=['CPU', 'HPU'], help='Device type.')
    parser.add_argument('--distributed', action='store_true',
                        default=False, help='Enable distributed training.')
    parser.add_argument('--base_tf_server_port', type=int,
                        default=7850, help='Rank 0 port used by tf.distribute.')
    parser.add_argument('--save_summary_steps', type=int, default=0,
                        help='Steps between saving summaries to TensorBoard.')
    parser.add_argument('--recipe_cache', default='/tmp/vit_recipe_cache',
                        help='Path to recipe cache directory. Set to empty to disable recipe cache. Externally set \'TF_RECIPE_CACHE_PATH\' will override this setting.')
    parser.add_argument(
        '--dump_config', help='Side-by-side config file. Internal, do not use.')
    args = parser.parse_args()

    if args.weights_path is not None:
        config.WEIGHTS_DIR = args.weights_path

    if args.dtype == 'bf16':
        tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

    if args.device == 'HPU':
        if args.distributed:
            os.environ['TF_HCCL_MEMORY_ALLOWANCE_MB'] = '500'
        from habana_frameworks.tensorflow import load_habana_module
        from habana_frameworks.tensorflow.ops.layer_norm import HabanaLayerNormalization
        load_habana_module()
        tf.keras.layers.LayerNormalization = HabanaLayerNormalization

        # Handle recipe caching.
        recipe_cache = args.recipe_cache
        if 'TF_RECIPE_CACHE_PATH' not in os.environ.keys() and recipe_cache:
            os.environ['TF_RECIPE_CACHE_PATH'] = recipe_cache

        # Clear previous recipe cache.
        if not args.distributed or comm_rank() == 0:
            if os.path.exists(recipe_cache) and os.path.isdir(recipe_cache):
                import shutil
                shutil.rmtree(recipe_cache)
        # Wait for rank 0 to remove cache.
        if args.distributed:
            from mpi4py import MPI
            MPI.COMM_WORLD.Barrier()

    # Handle determinism.
    config.DETERMINISTIC = args.deterministic
    config.SEED = args.seed
    if args.deterministic:
        assert args.seed is not None, "Deterministic behavior requires a seed to be set."
        tf.config.threading.set_inter_op_parallelism_threads(1)
        tf.config.threading.set_intra_op_parallelism_threads(1)
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        config.DATA_AUGMENTATION = False
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        tf.random.set_seed(args.seed)

    # Handle distribution strategy.
    if args.distributed:
        tf_distribute_config(args.base_tf_server_port)
        if args.device == 'HPU':
            os.environ['HBN_TF_REGISTER_DATASETOPS'] = '1'
            from habana_frameworks.tensorflow.distribute import HPUStrategy
            strategy = HPUStrategy()
        else:
            strategy = tf.distribute.MultiWorkerMirroredStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(f'device:{args.device}:0')

    if not args.distributed or comm_rank() == 0:
        print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    num_classes = 1000
    batch_size = args.batch_size
    nb_epoch = args.epochs
    dataset = args.dataset
    resume_from_checkpoint_path = args.resume_from_checkpoint_path
    resume_from_epoch = args.resume_from_epoch
    optim_name = args.optimizer
    initial_lr = args.initial_lr
    final_lr = args.final_lr
    lr_sched = args.lr_sched
    warmup_steps = args.warmup_steps
    model_name = args.model
    grad_accum_steps = args.grad_accum_steps

    ds_train = get_dataset(dataset, args.train_subset, batch_size,
                           is_training=True, distributed=args.distributed)
    ds_valid = get_dataset(dataset, args.val_subset,
                           batch_size, False, distributed=args.distributed)

    if args.dump_config is not None:
        vit.CONFIG_B['dropout'] = 0.0
        vit.CONFIG_L['dropout'] = 0.0

    # Load our model
    with strategy.scope():
        image_size = 384
        if model_name == 'ViT-B_16':
            model = vit.vit_b16(
                image_size=image_size,
                activation='softmax',
                pretrained=True,
                include_top=True,
                pretrained_top=False,
                classes=num_classes,
                weights="imagenet21k")
        elif model_name == 'ViT-L_16':
            model = vit.vit_l16(
                image_size=image_size,
                activation='softmax',
                pretrained=True,
                include_top=True,
                pretrained_top=False,
                classes=num_classes,
                weights="imagenet21k")
        elif model_name == 'ViT-B_32':
            model = vit.vit_b32(
                image_size=image_size,
                activation='softmax',
                pretrained=True,
                include_top=True,
                pretrained_top=False,
                classes=num_classes,
                weights="imagenet21k")
        elif model_name == 'ViT-L_32':
            model = vit.vit_l32(
                image_size=image_size,
                activation='softmax',
                pretrained=True,
                include_top=True,
                pretrained_top=False,
                classes=num_classes,
                weights="imagenet21k")
        else:
            print("Model is not supported; use one of ViT-B_16, ViT-L_16, "
                  "ViT-B_32 or ViT-L_32.")
            exit(1)

        optimizer = get_optimizer(
            optim_name, initial_lr, accumulation_steps=grad_accum_steps, epsilon=1e-2)
        model.compile(optimizer=optimizer, loss='categorical_crossentropy',
                      metrics=['accuracy'], run_eagerly=False)

        # Start training

        steps_per_epoch = 1281167 // batch_size
        if args.steps_per_epoch is not None:
            steps_per_epoch = args.steps_per_epoch
        validation_steps = 50000 // batch_size
        if args.validation_steps is not None:
            validation_steps = args.validation_steps

        total_steps = nb_epoch * steps_per_epoch
        resume_step = resume_from_epoch * steps_per_epoch

        lrate = get_lr_func(nb_epoch, lr_sched, initial_lr,
                            final_lr, warmup_steps, resume_step, total_steps)

        save_name = model_name if not model_name.endswith('.h5') else \
            os.path.split(model_name)[-1].split('.')[0].split('-')[0]
        model_ckpt = tf.keras.callbacks.ModelCheckpoint(
            os.path.join(config.SAVE_DIR, save_name) + '-ckpt-{epoch:03d}.h5',
            monitor='train_loss')

        callbacks = [lrate, model_ckpt]
        if args.save_summary_steps > 0:
            callbacks += [TensorBoardWithHParamsV2(
                vars(args), log_dir=config.LOG_DIR, update_freq=args.save_summary_steps)]
            callbacks += [ExamplesPerSecondKerasHookV2(
                output_dir=config.LOG_DIR, every_n_steps=args.save_summary_steps, batch_size=args.batch_size)]

        if args.evaluate_checkpoint_path is not None:
            model.load_weights(args.evaluate_checkpoint_path)
            results = model.evaluate(x=ds_valid, steps=validation_steps)
            print("Test loss, Test acc:", results)
            exit()

        if resume_from_epoch is not None and resume_from_checkpoint_path is not None:
            model.load_weights(resume_from_checkpoint_path)

        with dump_callback(args.dump_config):
            model.fit(x=ds_train, y=None,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      initial_epoch=resume_from_epoch,
                      epochs=nb_epoch,
                      shuffle=not args.deterministic,
                      verbose=1 if not args.distributed else comm_rank() == 0,
                      validation_data=(ds_valid, None),
                      validation_steps=validation_steps,
                      )

        if not args.distributed or comm_rank() == 0:
            model.save(f'{config.SAVE_DIR}/{save_name}-model-final.h5')
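When --dtype bf16 is passed, the script switches Keras to the mixed_bfloat16 policy before the model is built. A small standalone check of what that policy implies:

# Standalone sketch: mixed_bfloat16 keeps variables in float32 while running
# computations in bfloat16.
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
policy = tf.keras.mixed_precision.global_policy()
print(policy.compute_dtype, policy.variable_dtype)   # bfloat16 float32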