Example #1
def make_params(input_fn):
    params = {}
    params['hparams'] = hparams.HParams()
    params['hparams'].input_data.input_fn = input_fn.__name__
    params['use_tpu'] = True
    params['data_dir'] = None
    return params
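A minimal usage sketch for the helper above. The name dummy_input_fn is a hypothetical placeholder; any callable works, since make_params only records its __name__ in the hparams.

def dummy_input_fn(mode, params):  # hypothetical; never called by make_params
    del mode, params  # unused in this sketch
    return None

params = make_params(dummy_input_fn)
assert params['hparams'].input_data.input_fn == 'dummy_input_fn'
assert params['use_tpu'] and params['data_dir'] is None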
  def make_trainer(self, mode):
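    # Random tensors stand in for real data in this test helper: a batch of
    # multi-view image crops and integer class labels.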
   return launcher.ContrastiveTrainer(
       model_inputs=tf.random.normal(
           shape=[self.batch_size, 224, 224, 3 * self.num_views],
           mean=0,
           stddev=.1,
           dtype=tf.float32),
       labels=tf.random.uniform(
           shape=[self.batch_size],
           minval=0,
           maxval=self.num_classes,
           dtype=tf.int32),
       train_global_batch_size=self.batch_size,
       hparams=hparams.HParams(
           bs=self.batch_size, eval=hparams.Eval(batch_size=self.batch_size)),
       mode=mode,
       num_classes=self.num_classes,
       training_set_size=self.training_set_size,
       is_tpu=False)
def hparams_from_flags():
    return hparams.HParams(
        bs=FLAGS.batch_size,
        architecture=hparams.Architecture(
            encoder_architecture=FLAGS.resnet_architecture,
            encoder_depth=FLAGS.resnet_depth,
            encoder_width=FLAGS.resnet_width,
            first_conv_kernel_size=FLAGS.first_conv_kernel_size,
            first_conv_stride=FLAGS.first_conv_stride,
            use_initial_max_pool=FLAGS.use_initial_max_pool,
            projection_head_layers=tuple(map(int,
                                             FLAGS.projection_head_layers)),
            projection_head_use_batch_norm=FLAGS.use_projection_batch_norm,
            projection_head_use_batch_norm_beta=(
                FLAGS.use_projection_batch_norm_beta),
            normalize_projection_head_inputs=FLAGS.normalize_embedding,
            normalize_classifier_inputs=FLAGS.normalize_embedding,
            zero_initialize_classifier=FLAGS.zero_initialize_classifier,
            stop_gradient_before_classification_head=(
                FLAGS.stop_gradient_before_classification_head),
            stop_gradient_before_projection_head=(
                FLAGS.stop_gradient_before_projection_head),
            use_global_batch_norm=FLAGS.use_global_batch_norm),
        loss_all_stages=hparams.LossAllStages(
            contrastive=hparams.ContrastiveLoss(
                use_labels=FLAGS.use_labels,
                temperature=FLAGS.temperature,
                contrast_mode=FLAGS.contrast_mode,
                summation_location=FLAGS.summation_location,
                denominator_mode=FLAGS.denominator_mode,
                positives_cap=FLAGS.positives_cap,
                scale_by_temperature=FLAGS.scale_by_temperature),
            cross_entropy=hparams.CrossEntropyLoss(
                label_smoothing=FLAGS.label_smoothing),
            include_bias_in_weight_decay=FLAGS.use_bias_weight_decay),
        stage_1=hparams.Stage(
            training=hparams.TrainingStage(
                train_epochs=FLAGS.stage_1_epochs,
                learning_rate_warmup_epochs=FLAGS.stage_1_warmup_epochs,
                base_learning_rate=FLAGS.stage_1_base_learning_rate,
                learning_rate_decay=FLAGS.stage_1_learning_rate_decay,
                decay_rate=FLAGS.stage_1_decay_rate,
                decay_boundary_epochs=tuple(
                    map(int, FLAGS.stage_1_decay_boundary_epochs)),
                epochs_per_decay=FLAGS.stage_1_epochs_per_decay,
                optimizer=FLAGS.stage_1_optimizer,
                update_encoder_batch_norm=(
                    FLAGS.stage_1_update_encoder_batch_norm),
                rmsprop_epsilon=FLAGS.stage_1_rmsprop_epsilon),
            loss=hparams.LossStage(
                contrastive_weight=FLAGS.stage_1_contrastive_loss_weight,
                cross_entropy_weight=FLAGS.stage_1_cross_entropy_loss_weight,
                weight_decay_coeff=FLAGS.stage_1_weight_decay,
                use_encoder_weight_decay=(
                    FLAGS.stage_1_use_encoder_weight_decay),
                use_projection_head_weight_decay=(
                    FLAGS.stage_1_use_projection_head_weight_decay),
                use_classification_head_weight_decay=(
                    FLAGS.stage_1_use_classification_head_weight_decay)),
        ),
        stage_2=hparams.Stage(
            training=hparams.TrainingStage(
                train_epochs=FLAGS.stage_2_epochs,
                learning_rate_warmup_epochs=FLAGS.stage_2_warmup_epochs,
                base_learning_rate=FLAGS.stage_2_base_learning_rate,
                learning_rate_decay=FLAGS.stage_2_learning_rate_decay,
                decay_rate=FLAGS.stage_2_decay_rate,
                decay_boundary_epochs=tuple(
                    map(int, FLAGS.stage_2_decay_boundary_epochs)),
                epochs_per_decay=FLAGS.stage_2_epochs_per_decay,
                optimizer=FLAGS.stage_2_optimizer,
                update_encoder_batch_norm=(
                    FLAGS.stage_2_update_encoder_batch_norm),
                rmsprop_epsilon=FLAGS.stage_2_rmsprop_epsilon),
            loss=hparams.LossStage(
                contrastive_weight=FLAGS.stage_2_contrastive_loss_weight,
                cross_entropy_weight=FLAGS.stage_2_cross_entropy_loss_weight,
                weight_decay_coeff=FLAGS.stage_2_weight_decay,
                use_encoder_weight_decay=(
                    FLAGS.stage_2_use_encoder_weight_decay),
                use_projection_head_weight_decay=(
                    FLAGS.stage_2_use_projection_head_weight_decay),
                use_classification_head_weight_decay=(
                    FLAGS.stage_2_use_classification_head_weight_decay))),
        eval=hparams.Eval(batch_size=FLAGS.eval_batch_size),
        input_data=hparams.InputData(
            input_fn=FLAGS.input_fn,
            preprocessing=hparams.ImagePreprocessing(
                allow_mixed_precision=FLAGS.allow_mixed_precision,
                image_size=FLAGS.image_size,
                augmentation_type=FLAGS.augmentation_type,
                augmentation_magnitude=FLAGS.augmentation_magnitude,
                blur_probability=FLAGS.blur_probability,
                defer_blurring=FLAGS.defer_blurring,
                use_pytorch_color_jitter=FLAGS.use_pytorch_color_jitter,
                apply_whitening=FLAGS.apply_whitening,
                crop_area_range=tuple(map(float, FLAGS.crop_area_range)),
                eval_crop_method=FLAGS.eval_crop_method,
                crop_padding=FLAGS.crop_padding,
            ),
            max_samples=FLAGS.num_images,
            label_noise_prob=FLAGS.label_noise_prob,
            shard_per_host=FLAGS.shard_per_host),
        warm_start=hparams.WarmStart(
            warm_start_classifier=FLAGS.warm_start_classifier,
            ignore_missing_checkpoint_vars=(
                FLAGS.ignore_missing_checkpoint_vars),
            warm_start_projection_head=FLAGS.warm_start_projection_head,
            warm_start_encoder=FLAGS.warm_start_encoder,
            batch_norm_in_train_mode=FLAGS.batch_norm_in_train_mode,
        ),
    )
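A quick sketch of reading back the nested configuration built above. It assumes the corresponding absl flags have already been defined and parsed (as happens in main() below); the attribute paths shown are the same ones main() itself reads.

hp = hparams_from_flags()
print('global batch size:', hp.bs)
print('eval batch size:', hp.eval.batch_size)
print('stage 1 epochs:', hp.stage_1.training.train_epochs)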
Example #4
def main(_):
    tf.disable_v2_behavior()
    tf.enable_resource_variables()

    if FLAGS.hparams is None:
        hparams = hparams_flags.hparams_from_flags()
    else:
        hparams = hparams_lib.HParams(FLAGS.hparams)

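    # Resolve the TPU cluster only when running on TPU without an explicit
    # master address; otherwise the default (local or --master) target is used.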
    cluster = None
    if FLAGS.use_tpu and FLAGS.master is None:
        if FLAGS.tpu_name:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        else:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(cluster)
            tf.tpu.experimental.initialize_tpu_system(cluster)

    session_config = tf.ConfigProto()
    # Workaround for https://github.com/tensorflow/tensorflow/issues/26411 where
    # convolutions (used in blurring) get confused about data-format when used
    # inside a tf.data pipeline that is run on GPU.
    if (tf.test.is_built_with_cuda()
            and not hparams.input_data.preprocessing.defer_blurring):
        # RewriterConfig.OFF = 2
        session_config.graph_options.rewrite_options.layout_optimizer = 2
    run_config = tf.estimator.tpu.RunConfig(
        master=FLAGS.master,
        cluster=cluster,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=FLAGS.save_interval_steps,
        keep_checkpoint_max=FLAGS.max_checkpoints_to_keep,
        keep_checkpoint_every_n_hours=(FLAGS.keep_checkpoint_interval_secs /
                                       (60.0 * 60.0)),
        log_step_count_steps=100,
        session_config=session_config,
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.steps_per_loop,
            per_host_input_for_training=True,
            experimental_host_call_every_n_steps=FLAGS.summary_interval_steps,
            tpu_job_name='train_tpu_worker' if FLAGS.mode == 'train' else None,
            eval_training_input_configuration=(
                tf.estimator.tpu.InputPipelineConfig.SLICED if FLAGS.use_tpu
                else tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1)))
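    # Extra objects that TPUEstimator forwards to model_fn and input_fn via
    # their params argument.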
    params = {
        'hparams': hparams,
        'use_tpu': FLAGS.use_tpu,
        'data_dir': FLAGS.data_dir,
    }
    estimator = tf.estimator.tpu.TPUEstimator(
        model_fn=model_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        params=params,
        train_batch_size=hparams.bs,
        eval_batch_size=hparams.eval.batch_size)

    if hparams.input_data.input_fn not in dir(inputs):
        raise ValueError(f'Unknown input_fn: {hparams.input_data.input_fn}')
    input_fn = getattr(inputs, hparams.input_data.input_fn)

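    # Derive the total number of training steps from the two-stage epoch
    # schedule and the global batch size.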
    training_set_size = inputs.get_num_train_images(hparams)
    steps_per_epoch = training_set_size / hparams.bs
    stage_1_epochs = hparams.stage_1.training.train_epochs
    stage_2_epochs = hparams.stage_2.training.train_epochs
    total_steps = int((stage_1_epochs + stage_2_epochs) * steps_per_epoch)

    num_eval_examples = inputs.get_num_eval_images(hparams)
    eval_steps = num_eval_examples // hparams.eval.batch_size

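    # In 'eval' mode, poll the model directory for new checkpoints, evaluate
    # and export each one, and stop once the final training step is reached.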
    if FLAGS.mode == 'eval':
        for ckpt_str in tf.train.checkpoints_iterator(
                FLAGS.model_dir,
                min_interval_secs=FLAGS.eval_interval_secs,
                timeout=60 * 60):
            result = estimator.evaluate(input_fn=input_fn,
                                        checkpoint_path=ckpt_str,
                                        steps=eval_steps)
            estimator.export_saved_model(
                os.path.join(FLAGS.model_dir, 'exports'),
                lambda: input_fn(tf.estimator.ModeKeys.PREDICT, params),
                checkpoint_path=ckpt_str)
            if result['global_step'] >= total_steps:
                return
    else:  # 'train' or 'train_then_eval'.
        estimator.train(input_fn=input_fn, max_steps=total_steps)
        if FLAGS.mode == 'train_then_eval':
            result = estimator.evaluate(input_fn=input_fn, steps=eval_steps)
            estimator.export_saved_model(
                os.path.join(FLAGS.model_dir, 'exports'),
                lambda: input_fn(tf.estimator.ModeKeys.PREDICT, params))
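The snippet ends without the script entry point; a typical absl-style stub is sketched below. The import of app is an assumption about the surrounding module, which already relies on absl FLAGS.

from absl import app  # assumed to be imported alongside the other modules

if __name__ == '__main__':
    app.run(main)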