def make_params(input_fn):
  params = {}
  params['hparams'] = hparams.HParams()
  params['hparams'].input_data.input_fn = input_fn.__name__
  params['use_tpu'] = True
  params['data_dir'] = None
  return params
def make_trainer(self, mode):
  return launcher.ContrastiveTrainer(
      model_inputs=tf.random.normal(
          shape=[self.batch_size, 224, 224, 3 * self.num_views],
          mean=0,
          stddev=.1,
          dtype=tf.float32),
      labels=tf.random.uniform(
          shape=[self.batch_size],
          minval=0,
          maxval=self.num_classes,
          dtype=tf.int32),
      train_global_batch_size=self.batch_size,
      hparams=hparams.HParams(
          bs=self.batch_size,
          eval=hparams.Eval(batch_size=self.batch_size)),
      mode=mode,
      num_classes=self.num_classes,
      training_set_size=self.training_set_size,
      is_tpu=False)
def hparams_from_flags():
  """Builds an HParams instance from the command-line flag values."""
  return hparams.HParams(
      bs=FLAGS.batch_size,
      architecture=hparams.Architecture(
          encoder_architecture=FLAGS.resnet_architecture,
          encoder_depth=FLAGS.resnet_depth,
          encoder_width=FLAGS.resnet_width,
          first_conv_kernel_size=FLAGS.first_conv_kernel_size,
          first_conv_stride=FLAGS.first_conv_stride,
          use_initial_max_pool=FLAGS.use_initial_max_pool,
          projection_head_layers=tuple(
              map(int, FLAGS.projection_head_layers)),
          projection_head_use_batch_norm=FLAGS.use_projection_batch_norm,
          projection_head_use_batch_norm_beta=(
              FLAGS.use_projection_batch_norm_beta),
          normalize_projection_head_inputs=FLAGS.normalize_embedding,
          normalize_classifier_inputs=FLAGS.normalize_embedding,
          zero_initialize_classifier=FLAGS.zero_initialize_classifier,
          stop_gradient_before_classification_head=(
              FLAGS.stop_gradient_before_classification_head),
          stop_gradient_before_projection_head=(
              FLAGS.stop_gradient_before_projection_head),
          use_global_batch_norm=FLAGS.use_global_batch_norm),
      loss_all_stages=hparams.LossAllStages(
          contrastive=hparams.ContrastiveLoss(
              use_labels=FLAGS.use_labels,
              temperature=FLAGS.temperature,
              contrast_mode=FLAGS.contrast_mode,
              summation_location=FLAGS.summation_location,
              denominator_mode=FLAGS.denominator_mode,
              positives_cap=FLAGS.positives_cap,
              scale_by_temperature=FLAGS.scale_by_temperature),
          cross_entropy=hparams.CrossEntropyLoss(
              label_smoothing=FLAGS.label_smoothing),
          include_bias_in_weight_decay=FLAGS.use_bias_weight_decay),
      stage_1=hparams.Stage(
          training=hparams.TrainingStage(
              train_epochs=FLAGS.stage_1_epochs,
              learning_rate_warmup_epochs=FLAGS.stage_1_warmup_epochs,
              base_learning_rate=FLAGS.stage_1_base_learning_rate,
              learning_rate_decay=FLAGS.stage_1_learning_rate_decay,
              decay_rate=FLAGS.stage_1_decay_rate,
              decay_boundary_epochs=tuple(
                  map(int, FLAGS.stage_1_decay_boundary_epochs)),
              epochs_per_decay=FLAGS.stage_1_epochs_per_decay,
              optimizer=FLAGS.stage_1_optimizer,
              update_encoder_batch_norm=(
                  FLAGS.stage_1_update_encoder_batch_norm),
              rmsprop_epsilon=FLAGS.stage_1_rmsprop_epsilon),
          loss=hparams.LossStage(
              contrastive_weight=FLAGS.stage_1_contrastive_loss_weight,
              cross_entropy_weight=FLAGS.stage_1_cross_entropy_loss_weight,
              weight_decay_coeff=FLAGS.stage_1_weight_decay,
              use_encoder_weight_decay=FLAGS.stage_1_use_encoder_weight_decay,
              use_projection_head_weight_decay=(
                  FLAGS.stage_1_use_projection_head_weight_decay),
              use_classification_head_weight_decay=(
                  FLAGS.stage_1_use_classification_head_weight_decay))),
      stage_2=hparams.Stage(
          training=hparams.TrainingStage(
              train_epochs=FLAGS.stage_2_epochs,
              learning_rate_warmup_epochs=FLAGS.stage_2_warmup_epochs,
              base_learning_rate=FLAGS.stage_2_base_learning_rate,
              learning_rate_decay=FLAGS.stage_2_learning_rate_decay,
              decay_rate=FLAGS.stage_2_decay_rate,
              decay_boundary_epochs=tuple(
                  map(int, FLAGS.stage_2_decay_boundary_epochs)),
              epochs_per_decay=FLAGS.stage_2_epochs_per_decay,
              optimizer=FLAGS.stage_2_optimizer,
              update_encoder_batch_norm=(
                  FLAGS.stage_2_update_encoder_batch_norm),
              rmsprop_epsilon=FLAGS.stage_2_rmsprop_epsilon),
          loss=hparams.LossStage(
              contrastive_weight=FLAGS.stage_2_contrastive_loss_weight,
              cross_entropy_weight=FLAGS.stage_2_cross_entropy_loss_weight,
              weight_decay_coeff=FLAGS.stage_2_weight_decay,
              use_encoder_weight_decay=FLAGS.stage_2_use_encoder_weight_decay,
              use_projection_head_weight_decay=(
                  FLAGS.stage_2_use_projection_head_weight_decay),
              use_classification_head_weight_decay=(
                  FLAGS.stage_2_use_classification_head_weight_decay))),
      eval=hparams.Eval(batch_size=FLAGS.eval_batch_size),
      input_data=hparams.InputData(
          input_fn=FLAGS.input_fn,
          preprocessing=hparams.ImagePreprocessing(
              allow_mixed_precision=FLAGS.allow_mixed_precision,
              image_size=FLAGS.image_size,
              augmentation_type=FLAGS.augmentation_type,
              augmentation_magnitude=FLAGS.augmentation_magnitude,
              blur_probability=FLAGS.blur_probability,
              defer_blurring=FLAGS.defer_blurring,
              use_pytorch_color_jitter=FLAGS.use_pytorch_color_jitter,
              apply_whitening=FLAGS.apply_whitening,
              crop_area_range=tuple(map(float, FLAGS.crop_area_range)),
              eval_crop_method=FLAGS.eval_crop_method,
              crop_padding=FLAGS.crop_padding),
          max_samples=FLAGS.num_images,
          label_noise_prob=FLAGS.label_noise_prob,
          shard_per_host=FLAGS.shard_per_host),
      warm_start=hparams.WarmStart(
          warm_start_classifier=FLAGS.warm_start_classifier,
          ignore_missing_checkpoint_vars=FLAGS.ignore_missing_checkpoint_vars,
          warm_start_projection_head=FLAGS.warm_start_projection_head,
          warm_start_encoder=FLAGS.warm_start_encoder,
          batch_norm_in_train_mode=FLAGS.batch_norm_in_train_mode))
def main(_):
  tf.disable_v2_behavior()
  tf.enable_resource_variables()

  if FLAGS.hparams is None:
    hparams = hparams_flags.hparams_from_flags()
  else:
    hparams = hparams_lib.HParams(FLAGS.hparams)

  cluster = None
  if FLAGS.use_tpu and FLAGS.master is None:
    if FLAGS.tpu_name:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(cluster)
    tf.tpu.experimental.initialize_tpu_system(cluster)

  session_config = tf.ConfigProto()
  # Workaround for https://github.com/tensorflow/tensorflow/issues/26411 where
  # convolutions (used in blurring) get confused about data-format when used
  # inside a tf.data pipeline that is run on GPU.
  if (tf.test.is_built_with_cuda() and
      not hparams.input_data.preprocessing.defer_blurring):
    # RewriterConfig.OFF = 2
    session_config.graph_options.rewrite_options.layout_optimizer = 2

  run_config = tf.estimator.tpu.RunConfig(
      master=FLAGS.master,
      cluster=cluster,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=FLAGS.save_interval_steps,
      keep_checkpoint_max=FLAGS.max_checkpoints_to_keep,
      keep_checkpoint_every_n_hours=(
          FLAGS.keep_checkpoint_interval_secs / (60.0 * 60.0)),
      log_step_count_steps=100,
      session_config=session_config,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.steps_per_loop,
          per_host_input_for_training=True,
          experimental_host_call_every_n_steps=FLAGS.summary_interval_steps,
          tpu_job_name='train_tpu_worker' if FLAGS.mode == 'train' else None,
          eval_training_input_configuration=(
              tf.estimator.tpu.InputPipelineConfig.SLICED if FLAGS.use_tpu
              else tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1)))

  params = {
      'hparams': hparams,
      'use_tpu': FLAGS.use_tpu,
      'data_dir': FLAGS.data_dir,
  }
  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      params=params,
      train_batch_size=hparams.bs,
      eval_batch_size=hparams.eval.batch_size)

  if hparams.input_data.input_fn not in dir(inputs):
    raise ValueError(f'Unknown input_fn: {hparams.input_data.input_fn}')
  input_fn = getattr(inputs, hparams.input_data.input_fn)

  training_set_size = inputs.get_num_train_images(hparams)
  steps_per_epoch = training_set_size / hparams.bs
  stage_1_epochs = hparams.stage_1.training.train_epochs
  stage_2_epochs = hparams.stage_2.training.train_epochs
  total_steps = int((stage_1_epochs + stage_2_epochs) * steps_per_epoch)

  num_eval_examples = inputs.get_num_eval_images(hparams)
  eval_steps = num_eval_examples // hparams.eval.batch_size

  if FLAGS.mode == 'eval':
    for ckpt_str in tf.train.checkpoints_iterator(
        FLAGS.model_dir,
        min_interval_secs=FLAGS.eval_interval_secs,
        timeout=60 * 60):
      result = estimator.evaluate(
          input_fn=input_fn, checkpoint_path=ckpt_str, steps=eval_steps)
      estimator.export_saved_model(
          os.path.join(FLAGS.model_dir, 'exports'),
          lambda: input_fn(tf.estimator.ModeKeys.PREDICT, params),
          checkpoint_path=ckpt_str)
      if result['global_step'] >= total_steps:
        return
  else:  # 'train' or 'train_then_eval'.
    estimator.train(input_fn=input_fn, max_steps=total_steps)
    if FLAGS.mode == 'train_then_eval':
      result = estimator.evaluate(input_fn=input_fn, steps=eval_steps)
      estimator.export_saved_model(
          os.path.join(FLAGS.model_dir, 'exports'),
          lambda: input_fn(tf.estimator.ModeKeys.PREDICT, params))
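# A minimal entry-point sketch, not part of the original excerpt: the
# `main(_)` signature above suggests this launcher is run through absl.app,
# which parses FLAGS before invoking main. The absl import here is an
# assumption about the project's dependencies.
if __name__ == '__main__':
  from absl import app  # Assumed available; absl-py is a TensorFlow dependency.
  app.run(main)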