# Esempio n. 1 (Example no. 1)
# 0
def get_config(model_name, dataset_cfg, hparam_str=''):
    """Build the EfficientNetV2 config for a model/dataset combination.

    Starts from a deep copy of ``hparams.base_config`` (so the shared base
    config is never mutated) and applies overrides in this order: the
    model-specific config, the dataset-specific config, then ``hparam_str``.
    Finally the model's ``num_classes`` is set from the dataset's
    ``num_classes`` so the classifier head matches the dataset.

    Args:
      model_name: model identifier passed to
        ``effnetv2_configs.get_model_config``.
      dataset_cfg: dataset identifier passed to
        ``datasets.get_dataset_config``.
      hparam_str: extra overrides applied last; ``''`` by default.
        NOTE(review): accepted format is whatever ``config.override`` parses
        (presumably a ``"k=v,k2=v2"`` string) -- confirm.

    Returns:
      The merged config object. This function builds a config only; it does
      not create a Keras model.
    """
    config = copy.deepcopy(hparams.base_config)
    config.override(effnetv2_configs.get_model_config(model_name))
    config.override(datasets.get_dataset_config(dataset_cfg))
    config.override(hparam_str)
    # Keep the classifier width in sync with the dataset.
    config.model.num_classes = config.data.num_classes
    return config
# Esempio n. 2 (Example no. 2)
# 0
def main(unused_argv):
    """Train or evaluate an EfficientNetV2 model with a TF1 ``TPUEstimator``.

    The config is ``hparams.base_config`` (deep-copied) overridden, in this
    order, by the model config, the dataset config, ``FLAGS.hparam_str`` and
    ``FLAGS.sweeps``.

    ``FLAGS.mode`` selects the action:
      * ``'eval'``: evaluate each new checkpoint in ``FLAGS.model_dir``. It
        stops after evaluating a checkpoint at or past the final training
        step, or one whose name carries no step. It also stops when no new
        checkpoint appears within 24 hours.
      * ``'train'``: train up to ``train_steps``. When ``config.train.stages``
        is set, training runs progressively in stages whose image size grows
        (and, with ``config.train.sched``, whose ``data.ram`` / mixup / cutmix
        values grow) from stage to stage.

    Args:
      unused_argv: leftover command-line arguments; ignored.

    Raises:
      ValueError: if ``FLAGS.mode`` is neither ``'train'`` nor ``'eval'``.
    """
    # Fail fast on an unknown mode, before connecting to a TPU, building the
    # estimator or reading checkpoints (it used to be detected only at the very
    # end of the training branch).
    if FLAGS.mode not in ('train', 'eval'):
        raise ValueError('Unknown mode %s' % FLAGS.mode)

    config = copy.deepcopy(hparams.base_config)
    config.override(effnetv2_configs.get_model_config(FLAGS.model_name))
    config.override(datasets.get_dataset_config(FLAGS.dataset_cfg))
    config.override(FLAGS.hparam_str)
    config.override(FLAGS.sweeps)

    train_size = config.train.isize
    eval_size = config.eval.isize
    # A train isize <= 16 is treated as a fraction of the eval size, rounded
    # down to a multiple of 16.
    if train_size <= 16.:
        train_size = int(eval_size * train_size) // 16 * 16
    input_image_size = eval_size if FLAGS.mode == 'eval' else train_size

    if FLAGS.mode == 'train':
        if not tf.io.gfile.exists(FLAGS.model_dir):
            tf.io.gfile.makedirs(FLAGS.model_dir)
        config.save_to_yaml(os.path.join(FLAGS.model_dir, 'config.yaml'))

    train_split = config.train.split or 'train'
    eval_split = config.eval.split or 'eval'
    num_train_images = config.data.splits[train_split].num_images
    num_eval_images = config.data.splits[eval_split].num_images

    if FLAGS.tpu or FLAGS.use_tpu:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster_resolver = None

    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=max(100, config.runtime.iterations_per_loop),
        keep_checkpoint_max=config.runtime.keep_checkpoint_max,
        keep_checkpoint_every_n_hours=(
            config.runtime.keep_checkpoint_every_n_hours),
        log_step_count_steps=config.runtime.log_step_count_steps,
        session_config=tf.ConfigProto(isolate_session_state=True,
                                      log_device_placement=False),
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=config.runtime.iterations_per_loop,
            tpu_job_name=FLAGS.tpu_job_name,
            per_host_input_for_training=(
                tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2)))
    # Initializes model parameters.
    params = dict(steps_per_epoch=num_train_images / config.train.batch_size,
                  image_size=input_image_size,
                  config=config)

    est = tf.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=config.train.batch_size,
        eval_batch_size=config.eval.batch_size,
        export_to_tpu=FLAGS.export_to_tpu,
        params=params)

    image_dtype = None
    if config.runtime.mixed_precision:
        image_dtype = 'bfloat16' if FLAGS.use_tpu else 'float16'

    train_steps = max(
        config.train.min_steps,
        config.train.epochs * num_train_images // config.train.batch_size)
    dataset_eval = datasets.build_dataset_input(False, input_image_size,
                                                image_dtype, FLAGS.data_dir,
                                                eval_split, config.data)

    if FLAGS.mode == 'eval':
        eval_steps = num_eval_images // config.eval.batch_size
        # Run evaluation whenever a new checkpoint appears; give up if none
        # shows up within 24 hours.
        for ckpt in tf.train.checkpoints_iterator(FLAGS.model_dir,
                                                  timeout=60 * 60 * 24):
            logging.info('Starting to evaluate.')
            try:
                # This time will include compilation time.
                start_timestamp = time.time()
                eval_results = est.evaluate(input_fn=dataset_eval.input_fn,
                                            steps=eval_steps,
                                            checkpoint_path=ckpt,
                                            name=FLAGS.eval_name)
                elapsed_time = int(time.time() - start_timestamp)
                logging.info('Eval results: %s. Elapsed seconds: %d',
                             eval_results, elapsed_time)
                if FLAGS.archive_ckpt:
                    utils.archive_ckpt(eval_results,
                                       eval_results['eval/acc_top1'], ckpt)

                # Terminate eval job when final checkpoint is reached.
                try:
                    current_step = int(os.path.basename(ckpt).split('-')[1])
                except IndexError:
                    logging.info('%s has no global step info: stop!', ckpt)
                    break

                logging.info('Finished step: %d, total %d', current_step,
                             train_steps)
                if current_step >= train_steps:
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                logging.info('Checkpoint %s no longer exists, skip saving.',
                             ckpt)
    else:  # FLAGS.mode == 'train' (validated at the top).
        try:
            checkpoint_reader = tf.compat.v1.train.NewCheckpointReader(
                tf.train.latest_checkpoint(FLAGS.model_dir))
            current_step = checkpoint_reader.get_tensor(
                tf.compat.v1.GraphKeys.GLOBAL_STEP)
        except Exception:  # pylint: disable=broad-except
            # Best effort: no readable checkpoint (e.g. a fresh model_dir)
            # means starting from step 0. ``Exception`` instead of a bare
            # ``except`` so KeyboardInterrupt / SystemExit still propagate.
            current_step = 0

        logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', train_steps, config.train.epochs, current_step)

        hooks = []  # add hooks if needed.
        if not config.train.stages:
            dataset_train = datasets.build_dataset_input(
                True, input_image_size, image_dtype, FLAGS.data_dir,
                train_split, config.data)
            est.train(input_fn=dataset_train.input_fn,
                      max_steps=train_steps,
                      hooks=hooks)
        else:
            # Progressive training: stage k (1-based) trains up to
            # k / total_stages of train_steps, with an image size that grows
            # linearly from ibase towards input_image_size.
            curr_step = 0
            try:
                ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                curr_step = int(os.path.basename(ckpt).split('-')[1])
            except (IndexError, TypeError):
                # TypeError: no checkpoint (ckpt is None);
                # IndexError: checkpoint name has no '-<step>' suffix.
                logging.info('%s has no ckpt with valid step.',
                             FLAGS.model_dir)

            total_stages = config.train.stages
            if config.train.sched:
                # Per-stage augmentation values ramp linearly up to the
                # configured values. (Dropout is not scheduled; the former
                # dp_list was built and immediately deleted, so it is gone.)
                ram_list = np.linspace(5, config.data.ram, total_stages)
                mixup_list = np.linspace(0, config.data.mixup_alpha,
                                         total_stages)
                cutmix_list = np.linspace(0, config.data.cutmix_alpha,
                                          total_stages)

            ibase = config.data.ibase or (input_image_size / 2)
            for stage in range(total_stages):
                ratio = float(stage + 1) / float(total_stages)
                max_steps = int(ratio * train_steps)
                if max_steps <= curr_step:
                    # Stage already completed by an earlier run; Estimator.train
                    # would return immediately for it, so skip rebuilding the
                    # graph, dataset and estimator. This replaces the old start
                    # index ``curr_step // train_steps``, which skipped nothing
                    # until training had fully finished and divided by zero
                    # when train_steps == 0.
                    continue

                tf.compat.v1.reset_default_graph()
                image_size = int(ibase +
                                 (input_image_size - ibase) * ratio)
                params['image_size'] = image_size

                if config.train.sched:
                    config.data.ram = ram_list[stage]
                    config.data.mixup_alpha = mixup_list[stage]
                    config.data.cutmix_alpha = cutmix_list[stage]

                ds_lab_cls = datasets.build_dataset_input(
                    True, image_size, image_dtype, FLAGS.data_dir,
                    train_split, config.data)

                # A fresh estimator per stage so it picks up the new params
                # (image size) after the default graph was reset.
                est = tf.estimator.tpu.TPUEstimator(
                    use_tpu=FLAGS.use_tpu,
                    model_fn=model_fn,
                    config=run_config,
                    train_batch_size=config.train.batch_size,
                    eval_batch_size=config.eval.batch_size,
                    export_to_tpu=FLAGS.export_to_tpu,
                    params=params)
                est.train(input_fn=ds_lab_cls.input_fn,
                          max_steps=max_steps,
                          hooks=hooks)
# Esempio n. 3 (Example no. 3)
# 0
def main(_) -> None:
    """Train and/or evaluate an EfficientNetV2 Keras model under tf.distribute.

    The config is ``hparams.base_config`` (deep-copied) overridden, in this
    order, by the model config, the dataset config and ``FLAGS.hparam_str``.
    ``config.runtime.strategy`` selects the distribution strategy:
      * ``'tpu'``: TPUStrategy (defaulting ``bn_type`` to ``'tpu_bn'``);
      * ``'gpus'``: MirroredStrategy over all visible GPUs;
      * anything else: a single-device MirroredStrategy on GPU:0 if a GPU is
        visible, otherwise CPU:0.

    ``FLAGS.mode``:
      * ``'train'``: ``model.fit`` on the training split.
      * ``'traineval'``: like ``'train'``, plus validation on the eval split.
      * ``'eval'``: evaluate each new checkpoint in ``FLAGS.model_dir``. It
        stops after evaluating a checkpoint whose epoch is at least
        ``config.train.epochs``, or one whose name has no epoch. It also stops
        when no new checkpoint appears within 24 hours.

    Raises:
      ValueError: if ``FLAGS.mode`` is not one of the modes above.
    """
    # Validate the mode before any expensive or side-effecting setup (TPU
    # initialization, writing config.yaml). Previously an invalid mode was only
    # reported after the model was built, and a bogus mode that merely
    # contained 'train' as a substring would already have written config.yaml.
    if FLAGS.mode not in ('train', 'eval', 'traineval'):
        raise ValueError(f'Invalid mode {FLAGS.mode}')

    config = copy.deepcopy(hparams.base_config)
    config.override(effnetv2_configs.get_model_config(FLAGS.model_name))
    config.override(datasets.get_dataset_config(FLAGS.dataset_cfg))
    config.override(FLAGS.hparam_str)
    config.model.num_classes = config.data.num_classes
    strategy = config.runtime.strategy
    if strategy == 'tpu' and not config.model.bn_type:
        config.model.bn_type = 'tpu_bn'

    # log and save config.
    logging.info('config=%s', str(config))
    if FLAGS.mode in ('train', 'traineval'):
        if not tf.io.gfile.exists(FLAGS.model_dir):
            tf.io.gfile.makedirs(FLAGS.model_dir)
        config.save_to_yaml(os.path.join(FLAGS.model_dir, 'config.yaml'))

    if strategy == 'tpu':
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
        ds_strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
        logging.info('All devices: %s', tf.config.list_logical_devices('TPU'))
    elif strategy == 'gpus':
        ds_strategy = tf.distribute.MirroredStrategy()
        logging.info('All devices: %s', tf.config.list_physical_devices('GPU'))
    else:
        if tf.config.list_physical_devices('GPU'):
            ds_strategy = tf.distribute.MirroredStrategy(['GPU:0'])
        else:
            ds_strategy = tf.distribute.MirroredStrategy(['CPU:0'])

    with ds_strategy.scope():
        train_split = config.train.split or 'train'
        eval_split = config.eval.split or 'eval'
        num_train_images = config.data.splits[train_split].num_images
        num_eval_images = config.data.splits[eval_split].num_images

        train_size = config.train.isize
        eval_size = config.eval.isize
        # A train isize <= 16 is treated as a fraction of the eval size,
        # rounded down to a multiple of 16.
        if train_size <= 16.:
            train_size = int(eval_size * train_size) // 16 * 16

        image_dtype = None
        if config.runtime.mixed_precision:
            image_dtype = 'bfloat16' if strategy == 'tpu' else 'float16'
            precision = 'mixed_bfloat16' if strategy == 'tpu' else 'mixed_float16'
            policy = tf.keras.mixed_precision.Policy(precision)
            tf.keras.mixed_precision.set_global_policy(policy)

        model = TrainableModel(config.model.model_name,
                               config.model,
                               weight_decay=config.train.weight_decay)

        if config.train.ft_init_ckpt:  # load pretrained ckpt for finetuning.
            # Dummy forward pass so the model's variables exist before the
            # checkpoint is restored into them.
            model(tf.ones([1, 224, 224, 3]))
            ckpt = config.train.ft_init_ckpt
            utils.restore_tf2_ckpt(model,
                                   ckpt,
                                   exclude_layers=('_head', 'optimizer'))

        steps_per_epoch = num_train_images // config.train.batch_size
        total_steps = steps_per_epoch * config.train.epochs

        # Linear LR scaling relative to a reference batch size of 256.
        scaled_lr = config.train.lr_base * (config.train.batch_size / 256.0)
        scaled_lr_min = config.train.lr_min * (config.train.batch_size / 256.0)
        learning_rate = utils.WarmupLearningRateSchedule(
            scaled_lr,
            steps_per_epoch=steps_per_epoch,
            decay_epochs=config.train.lr_decay_epoch,
            warmup_epochs=config.train.lr_warmup_epoch,
            decay_factor=config.train.lr_decay_factor,
            lr_decay_type=config.train.lr_sched,
            total_steps=total_steps,
            minimal_lr=scaled_lr_min)

        optimizer = build_tf2_optimizer(learning_rate,
                                        optimizer_name=config.train.optimizer)

        model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.CategoricalCrossentropy(
                label_smoothing=config.train.label_smoothing,
                from_logits=True),
            metrics=[
                tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='acc_top1'),
                tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='acc_top5')
            ],
        )

        ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
            os.path.join(FLAGS.model_dir, 'ckpt-{epoch:d}'),
            verbose=1,
            save_weights_only=True)
        tb_callback = tf.keras.callbacks.TensorBoard(log_dir=FLAGS.model_dir,
                                                     update_freq=100)
        rstr_callback = tf.keras.callbacks.experimental.BackupAndRestore(
            backup_dir=FLAGS.model_dir)

        def get_dataset(training):
            """A shared utility to get input dataset."""
            if training:
                return ds_strategy.distribute_datasets_from_function(
                    datasets.build_dataset_input(
                        True, train_size, image_dtype, FLAGS.data_dir,
                        train_split, config.data).distribute_dataset_fn(
                            config.train.batch_size))
            else:
                return ds_strategy.distribute_datasets_from_function(
                    datasets.build_dataset_input(
                        False, eval_size, image_dtype, FLAGS.data_dir,
                        eval_split, config.data).distribute_dataset_fn(
                            config.eval.batch_size))

        # don't log spam if running on tpus
        verbose = 2 if strategy == 'tpu' else 1
        eval_steps = num_eval_images // config.eval.batch_size

        if FLAGS.mode in ('train', 'traineval'):
            # Build the training dataset first (same order as before).
            train_ds = get_dataset(training=True)
            validation_kwargs = {}
            if FLAGS.mode == 'traineval':
                validation_kwargs = dict(
                    validation_data=get_dataset(training=False),
                    validation_steps=eval_steps)
            model.fit(
                train_ds,
                epochs=config.train.epochs,
                steps_per_epoch=steps_per_epoch,
                callbacks=[ckpt_callback, tb_callback, rstr_callback],
                verbose=verbose,
                **validation_kwargs,
            )
        else:  # FLAGS.mode == 'eval' (validated at the top).
            for ckpt in tf.train.checkpoints_iterator(FLAGS.model_dir,
                                                      timeout=60 * 60 * 24):
                model.load_weights(ckpt)
                # ``batch_size`` is not passed: the distributed dataset already
                # yields batches and Keras documents that batch_size must not
                # be set for dataset inputs. BackupAndRestore only acts during
                # fit(), so only the TensorBoard callback is attached here.
                eval_results = model.evaluate(
                    get_dataset(training=False),
                    steps=eval_steps,
                    callbacks=[tb_callback],
                    verbose=verbose,
                )
                logging.info('Eval results for %s: %s', ckpt, eval_results)

                try:
                    current_epoch = int(os.path.basename(ckpt).split('-')[1])
                except IndexError:
                    logging.info('%s has no epoch info: stop!', ckpt)
                    break

                logging.info('Epoch: %d, total %d', current_epoch,
                             config.train.epochs)
                if current_epoch >= config.train.epochs:
                    break
# Esempio n. 4 (Example no. 4)
# 0
            if dataset == "imagenet21k":
                classes, dropout, survival_prob, load_weights, save_model_suffix = 21843, 1e-6, 1.0, "imagenet21k", "-21k"
            elif dataset == "imagenetft":
                classes, dropout, survival_prob, load_weights, save_model_suffix = 1000, 0.2, 0.8, "imagenet21k-ft1k", "-21k-ft1k"
            else:  # "imagenet"
                classes, dropout, survival_prob, load_weights, save_model_suffix = 1000, 0.2, 0.8, "imagenet", "-imagenet"

            print(">>>> classes = {}, dropout = {}, load_weights = {}, save_model_suffix = {}".format(classes, dropout, load_weights, save_model_suffix))

            """ Define Keras model first just to keep the names start from `0` """
            keras_model = keras_efficientnet_v2.EfficientNetV2(
                model_type=model_type, drop_connect_rate=0, dropout=dropout, num_classes=classes, classifier_activation=None, pretrained=None
            )

            """ Load checkpoints using official defination """
            cc = orign_datasets.get_dataset_config(dataset)
            if cc.get("model", None):
                cc.model.num_classes = cc.data.num_classes
            else:
                cc["model"] = None
            model = orign_effnetv2_model.get_model("efficientnetv2-{}".format(model_type), model_config=cc.model, weights=load_weights)

            """ Save h5 weights if no error happens """
            model.save_weights("convert_temp_aa.h5")

            """ Reload weights with this modified version """
            mm = EffNetV2Model("efficientnetv2-{}".format(model_type), num_classes=classes)
            len(mm(tf.ones([1, 224, 224, 3]), False))
            mm.load_weights("convert_temp_aa.h5")

            """ Define a new model using `mm.call`, as mm is a subclassed model, cannot be saved as h5 """