def model_fn(features, mode, params):
  '''The model_fn to be used with TPUEstimator.

  Args:
    features: `dict` of batched input tensors; `features['image']` holds the
      images and `features['label']` the labels for the data samples.
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
    params: `dict` of parameters passed to the model from the TPUEstimator.
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model.
  '''

  def preprocess_image(image):
    # In most cases, NCHW (channels_first) should be used instead of the
    # default NHWC for a significant performance boost on GPU. NHWC should be
    # used only if the network needs to run on CPU, since the pooling
    # operations are only supported on NHWC. On TPU, the XLA compiler figures
    # out the best layout.
    if FLAGS.data_format == 'channels_first':
      assert not FLAGS.transpose_input  # channels_first only for GPU
      image = tf.transpose(image, [0, 3, 1, 2])

    if FLAGS.transpose_input and mode == tf.estimator.ModeKeys.TRAIN:
      image = tf.transpose(image, [3, 0, 1, 2])  # HWCN to NHWC
    return image

  def normalize_image(image):
    # Normalize the image to zero mean and unit variance.
    if FLAGS.data_format == 'channels_first':
      stats_shape = [3, 1, 1]
    else:
      stats_shape = [1, 1, 3]
    mean, std = task_info.get_mean_std(FLAGS.task_name)
    image -= tf.constant(mean, shape=stats_shape, dtype=image.dtype)
    image /= tf.constant(std, shape=stats_shape, dtype=image.dtype)
    return image

  image = features['image']
  image = preprocess_image(image)
  image_shape = image.get_shape().as_list()
  tf.logging.info('image shape: {}'.format(image_shape))
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  if mode != tf.estimator.ModeKeys.PREDICT:
    labels = features['label']
  else:
    labels = None

  # If necessary, in the model_fn, use params['batch_size'] instead of the
  # batch-size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']  # pylint: disable=unused-variable

  if FLAGS.unlabel_ratio and is_training:
    unl_bsz = features['unl_probs'].shape[0]
  else:
    unl_bsz = 0
  lab_bsz = image.shape[0] - unl_bsz
  assert lab_bsz == batch_size

  metric_dict = {}
  global_step = tf.train.get_global_step()

  has_moving_average_decay = (FLAGS.moving_average_decay > 0)
  # This is essential if using a Keras-derived model.
  tf.keras.backend.set_learning_phase(is_training)
  tf.logging.info('Using open-source implementation.')
  override_params = {}
  if FLAGS.dropout_rate is not None:
    override_params['dropout_rate'] = FLAGS.dropout_rate
  if FLAGS.stochastic_depth_rate is not None:
    override_params['stochastic_depth_rate'] = FLAGS.stochastic_depth_rate
  if FLAGS.data_format:
    override_params['data_format'] = FLAGS.data_format
  if FLAGS.num_label_classes:
    override_params['num_classes'] = FLAGS.num_label_classes
  if FLAGS.depth_coefficient:
    override_params['depth_coefficient'] = FLAGS.depth_coefficient
  if FLAGS.width_coefficient:
    override_params['width_coefficient'] = FLAGS.width_coefficient

  def build_model(scope=None,
                  reuse=tf.AUTO_REUSE,
                  model_name=None,
                  model_is_training=None,
                  input_image=None,
                  use_adv_bn=False,
                  is_teacher=False):
    model_name = model_name or FLAGS.model_name
    if model_is_training is None:
      model_is_training = is_training
    if input_image is None:
      input_image = image
    input_image = normalize_image(input_image)

    scope_model_name = model_name
    if scope:
      scope = scope + '/'
    else:
      scope = ''
    with tf.variable_scope(scope + scope_model_name, reuse=reuse):
      if model_name.startswith('efficientnet'):
        logits, _ = efficientnet_builder.build_model(
            input_image,
            model_name=model_name,
            training=model_is_training,
            override_params=override_params,
            model_dir=FLAGS.model_dir,
            use_adv_bn=use_adv_bn,
            is_teacher=is_teacher)
      else:
        assert False, 'model {} not implemented'.format(model_name)
    return logits

  if params['use_bfloat16']:
    with tf.tpu.bfloat16_scope():
      logits = tf.cast(build_model(), tf.float32)
  else:
    logits = build_model()

  if FLAGS.teacher_model_name:
    teacher_image = preprocess_image(features['teacher_image'])
    if params['use_bfloat16']:
      with tf.tpu.bfloat16_scope():
        teacher_logits = tf.cast(
            build_model(
                scope='teacher_model',
                model_name=FLAGS.teacher_model_name,
                model_is_training=False,
                input_image=teacher_image,
                is_teacher=True), tf.float32)
    else:
      teacher_logits = build_model(
          scope='teacher_model',
          model_name=FLAGS.teacher_model_name,
          model_is_training=False,
          input_image=teacher_image,
          is_teacher=True)
    teacher_logits = tf.stop_gradient(teacher_logits)
    if FLAGS.teacher_softmax_temp != -1:
      teacher_prob = tf.nn.softmax(teacher_logits / FLAGS.teacher_softmax_temp)
    else:
      teacher_prob = None
      teacher_one_hot_pred = tf.argmax(
          teacher_logits, axis=1, output_type=labels.dtype)

  if mode == tf.estimator.ModeKeys.PREDICT:
    if has_moving_average_decay:
      ema = tf.train.ExponentialMovingAverage(
          decay=FLAGS.moving_average_decay)
      ema_vars = utils.get_all_variable()
      restore_vars_dict = ema.variables_to_restore(ema_vars)
      tf.logging.info(
          'restored variables:\n%s',
          json.dumps(sorted(restore_vars_dict.keys()), indent=4))
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        scaffold_fn=functools.partial(
            _scaffold_fn, restore_vars_dict=restore_vars_dict)
        if has_moving_average_decay else None)

  if has_moving_average_decay:
    ema_step = global_step
    ema = tf.train.ExponentialMovingAverage(
        decay=FLAGS.moving_average_decay, num_updates=ema_step)
    ema_vars = utils.get_all_variable()

  lab_labels = labels[:lab_bsz]
  lab_logits = logits[:lab_bsz]
  lab_pred = tf.argmax(lab_logits, axis=-1, output_type=labels.dtype)
  lab_prob = tf.nn.softmax(lab_logits)
  lab_acc = tf.to_float(tf.equal(lab_pred, lab_labels))
  metric_dict['lab/acc'] = tf.reduce_mean(lab_acc)
  metric_dict['lab/pred_prob'] = tf.reduce_mean(
      tf.reduce_max(lab_prob, axis=-1))
  one_hot_labels = tf.one_hot(lab_labels, FLAGS.num_label_classes)

  if FLAGS.unlabel_ratio:
    unl_labels = labels[lab_bsz:]
    unl_logits = logits[lab_bsz:]
    unl_pred = tf.argmax(unl_logits, axis=-1, output_type=labels.dtype)
    unl_prob = tf.nn.softmax(unl_logits)
    unl_acc = tf.to_float(tf.equal(unl_pred, unl_labels))
    metric_dict['unl/acc_to_dump'] = tf.reduce_mean(unl_acc)
    metric_dict['unl/pred_prob'] = tf.reduce_mean(
        tf.reduce_max(unl_prob, axis=-1))

  # compute lab_loss
  one_hot_labels = tf.one_hot(lab_labels, FLAGS.num_label_classes)
  lab_loss = tf.losses.softmax_cross_entropy(
      logits=lab_logits,
      onehot_labels=one_hot_labels,
      label_smoothing=FLAGS.label_smoothing,
      reduction=tf.losses.Reduction.NONE)
  if FLAGS.label_data_sample_prob != 1:
    # mask out part of the labeled data
    random_mask = tf.floor(
        FLAGS.label_data_sample_prob +
        tf.random_uniform(tf.shape(lab_loss), dtype=lab_loss.dtype))
    lab_loss = tf.reduce_mean(lab_loss * random_mask)
  else:
    lab_loss = tf.reduce_mean(lab_loss)
  metric_dict['lab/loss'] = lab_loss

  if FLAGS.unlabel_ratio:
    if FLAGS.teacher_softmax_temp == -1:  # Hard labels
      # Get one-hot labels
      if FLAGS.teacher_model_name:
        ext_teacher_pred = teacher_one_hot_pred[lab_bsz:]
        one_hot_labels = tf.one_hot(ext_teacher_pred, FLAGS.num_label_classes)
      else:
        one_hot_labels = tf.one_hot(unl_labels, FLAGS.num_label_classes)
      # Compute cross entropy
      unl_loss = tf.losses.softmax_cross_entropy(
          logits=unl_logits,
          onehot_labels=one_hot_labels,
          label_smoothing=FLAGS.label_smoothing)
    else:  # Soft labels
      # Get teacher prob
      if FLAGS.teacher_model_name:
        unl_teacher_prob = teacher_prob[lab_bsz:]
      else:
        scaled_prob = tf.pow(
            features['unl_probs'], 1 / FLAGS.teacher_softmax_temp)
        unl_teacher_prob = scaled_prob / tf.reduce_sum(
            scaled_prob, axis=-1, keepdims=True)
      metric_dict['unl/target_prob'] = tf.reduce_mean(
          tf.reduce_max(unl_teacher_prob, axis=-1))
      unl_loss = cross_entropy(unl_teacher_prob, unl_logits, return_mean=True)
    metric_dict['ext/loss'] = unl_loss
  else:
    unl_loss = 0

  real_lab_bsz = tf.to_float(lab_bsz) * FLAGS.label_data_sample_prob
  real_unl_bsz = batch_size * FLAGS.label_data_sample_prob * FLAGS.unlabel_ratio
  data_loss = lab_loss * real_lab_bsz + unl_loss * real_unl_bsz
  data_loss = data_loss / real_lab_bsz

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = data_loss + FLAGS.weight_decay * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  metric_dict['train/data_loss'] = data_loss
  metric_dict['train/loss'] = loss

  host_call = None
  restore_vars_dict = None

  if is_training:
    # Compute the current epoch and associated learning rate from global_step.
    current_epoch = (
        tf.cast(global_step, tf.float32) / params['steps_per_epoch'])
    real_train_batch_size = FLAGS.train_batch_size
    real_train_batch_size *= FLAGS.label_data_sample_prob
    scaled_lr = FLAGS.base_learning_rate * (real_train_batch_size / 256.0)
    if FLAGS.final_base_lr:
      # total number of training epochs
      total_epochs = (FLAGS.train_steps * FLAGS.train_batch_size * 1. /
                      FLAGS.num_train_images - 5)
      decay_times = (math.log(FLAGS.final_base_lr / FLAGS.base_learning_rate) /
                     math.log(0.97))
      decay_epochs = total_epochs / decay_times
      tf.logging.info(
          'setting decay_epochs to {:.2f}'.format(decay_epochs) + '\n' * 3)
    else:
      decay_epochs = 2.4 * FLAGS.train_ratio
    learning_rate = utils.build_learning_rate(
        scaled_lr,
        global_step,
        params['steps_per_epoch'],
        decay_epochs=decay_epochs,
        start_from_step=FLAGS.train_steps - FLAGS.train_last_step_num,
        warmup_epochs=5,
    )
    metric_dict['train/lr'] = learning_rate
    metric_dict['train/epoch'] = current_epoch
    optimizer = utils.build_optimizer(learning_rate)
    if FLAGS.use_tpu:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    tvars = tf.trainable_variables()
    g_vars = []
    tvars = sorted(tvars, key=lambda var: var.name)
    for var in tvars:
      if 'teacher_model' not in var.name:
        g_vars += [var]
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step, var_list=g_vars)

    if has_moving_average_decay:
      with tf.control_dependencies([train_op]):
        train_op = ema.apply(ema_vars)

    if not FLAGS.skip_host_call:
      host_call = utils.construct_scalar_host_call(metric_dict)
    scaffold_fn = None
    if FLAGS.teacher_model_name or FLAGS.init_model:
      scaffold_fn = utils.init_from_ckpt(scaffold_fn)
  else:
    train_op = None
    if has_moving_average_decay:
      # Load moving average variables for eval.
      restore_vars_dict = ema.variables_to_restore(ema_vars)

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    scaffold_fn = functools.partial(
        _scaffold_fn, restore_vars_dict=restore_vars_dict
    ) if has_moving_average_decay else None

    def metric_fn(labels, logits):
      '''Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the
      model to the `metric_fn`, provide them as part of the `eval_metrics`.
      See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      '''
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)

      result_dict = {
          'top_1_accuracy': top_1_accuracy,
          'top_5_accuracy': top_5_accuracy,
      }
      return result_dict

    eval_metrics = (metric_fn, [labels, logits])

  num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info('number of trainable parameters: {}'.format(num_params))

  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics,
      scaffold_fn=scaffold_fn)
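

# Note: `_scaffold_fn` and `cross_entropy` are referenced above but are not
# defined in this section. The sketches below are minimal, hypothetical
# implementations inferred only from how they are called (a zero-argument
# scaffold_fn built via functools.partial over `restore_vars_dict`, and a
# cross entropy against soft teacher targets); the actual helpers may differ.


def _scaffold_fn(restore_vars_dict):
  """Hypothetical sketch: restore EMA shadow variables at eval/predict time."""
  saver = tf.train.Saver(restore_vars_dict)
  return tf.train.Scaffold(saver=saver)


def cross_entropy(targets, logits, return_mean=False):
  """Hypothetical sketch: cross entropy against soft target distributions.

  Computes -sum_c targets[:, c] * log_softmax(logits)[:, c] per example.
  """
  per_example = -tf.reduce_sum(targets * tf.nn.log_softmax(logits), axis=-1)
  if return_mean:
    return tf.reduce_mean(per_example)
  return per_example
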
def model_fn(features, labels, mode, params):
  sup_labels = tf.reshape(features["label"], [-1])

  #### Configuring the optimizer
  global_step = tf.train.get_global_step()
  metric_dict = {}
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  if FLAGS.unsup_ratio > 0 and is_training:
    all_images = tf.concat([features["image"],
                            features["ori_image"],
                            features["aug_image"]], 0)
  else:
    all_images = features["image"]

  with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
    all_logits = build_model(
        inputs=all_images,
        num_classes=FLAGS.num_classes,
        is_training=is_training,
        update_bn=True and is_training,
        hparams=hparams,
    )

  sup_bsz = tf.shape(features["image"])[0]
  sup_logits = all_logits[:sup_bsz]

  sup_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=sup_labels,
      logits=sup_logits)
  sup_prob = tf.nn.softmax(sup_logits, axis=-1)
  metric_dict["sup/pred_prob"] = tf.reduce_mean(
      tf.reduce_max(sup_prob, axis=-1))
  if FLAGS.tsa:
    sup_loss, avg_sup_loss = anneal_sup_loss(sup_logits, sup_labels, sup_loss,
                                             global_step, metric_dict)
  else:
    avg_sup_loss = tf.reduce_mean(sup_loss)
  total_loss = avg_sup_loss

  if FLAGS.unsup_ratio > 0 and is_training:
    aug_bsz = tf.shape(features["ori_image"])[0]

    ori_logits = all_logits[sup_bsz:sup_bsz + aug_bsz]
    aug_logits = all_logits[sup_bsz + aug_bsz:]
    if FLAGS.uda_softmax_temp != -1:
      ori_logits_tgt = ori_logits / FLAGS.uda_softmax_temp
    else:
      ori_logits_tgt = ori_logits
    ori_prob = tf.nn.softmax(ori_logits, axis=-1)
    aug_prob = tf.nn.softmax(aug_logits, axis=-1)
    metric_dict["unsup/ori_prob"] = tf.reduce_mean(
        tf.reduce_max(ori_prob, axis=-1))
    metric_dict["unsup/aug_prob"] = tf.reduce_mean(
        tf.reduce_max(aug_prob, axis=-1))

    aug_loss = _kl_divergence_with_logits(
        p_logits=tf.stop_gradient(ori_logits_tgt),
        q_logits=aug_logits)

    if FLAGS.uda_confidence_thresh != -1:
      ori_prob = tf.nn.softmax(ori_logits, axis=-1)
      largest_prob = tf.reduce_max(ori_prob, axis=-1)
      loss_mask = tf.cast(tf.greater(
          largest_prob, FLAGS.uda_confidence_thresh), tf.float32)
      metric_dict["unsup/high_prob_ratio"] = tf.reduce_mean(loss_mask)
      loss_mask = tf.stop_gradient(loss_mask)
      aug_loss = aug_loss * loss_mask
      metric_dict["unsup/high_prob_loss"] = tf.reduce_mean(aug_loss)

    if FLAGS.ent_min_coeff > 0:
      ent_min_coeff = FLAGS.ent_min_coeff
      metric_dict["unsup/ent_min_coeff"] = ent_min_coeff
      per_example_ent = get_ent(ori_logits)
      ent_min_loss = tf.reduce_mean(per_example_ent)
      total_loss = total_loss + ent_min_coeff * ent_min_loss

    avg_unsup_loss = tf.reduce_mean(aug_loss)
    total_loss += FLAGS.unsup_coeff * avg_unsup_loss
    metric_dict["unsup/loss"] = avg_unsup_loss

  total_loss = utils.decay_weights(
      total_loss,
      FLAGS.weight_decay_rate)

  #### Check model parameters
  num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info("#params: {}".format(num_params))

  if FLAGS.verbose:
    format_str = "{{:<{0}s}}\t{{}}".format(
        max([len(v.name) for v in tf.trainable_variables()]))
    for v in tf.trainable_variables():
      tf.logging.info(format_str.format(v.name, v.get_shape()))

  if FLAGS.moving_average_decay > 0.:
    ema = tf.train.ExponentialMovingAverage(
        decay=FLAGS.moving_average_decay)
    ema_vars = utils.get_all_variable()

  #### Evaluation mode
  if mode == tf.estimator.ModeKeys.EVAL:
    if FLAGS.moving_average_decay > 0:
      restore_vars_dict = ema.variables_to_restore(ema_vars)
      scaffold_fn = functools.partial(
          _scaffold_fn, restore_vars_dict=restore_vars_dict)
    else:
      scaffold_fn = None

    #### Metric function for classification
    def metric_fn(per_example_loss, label_ids, logits):
      # classification loss & accuracy
      loss = tf.metrics.mean(per_example_loss)

      predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
      accuracy = tf.metrics.accuracy(label_ids, predictions)

      ret_dict = {
          "eval/classify_loss": loss,
          "eval/classify_accuracy": accuracy
      }

      return ret_dict

    eval_metrics = (metric_fn, [sup_loss, sup_labels, sup_logits])

    #### Constructing evaluation TPUEstimatorSpec.
    eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)

    return eval_spec

  # increase the learning rate linearly
  if FLAGS.warmup_steps > 0:
    warmup_lr = (tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) *
                 FLAGS.learning_rate)
  else:
    warmup_lr = 0.0

  # decay the learning rate using the cosine schedule
  lrate = tf.clip_by_value(
      tf.to_float(global_step - FLAGS.warmup_steps) /
      (FLAGS.train_steps - FLAGS.warmup_steps), 0, 1)
  decay_lr = FLAGS.learning_rate * tf.cos(lrate * (7. / 8) * np.pi / 2)

  learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                           warmup_lr, decay_lr)

  optimizer = tf.train.MomentumOptimizer(
      learning_rate=learning_rate,
      momentum=0.9,
      use_nesterov=True)

  if FLAGS.use_tpu:
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

  grads_and_vars = optimizer.compute_gradients(total_loss)
  gradients, variables = zip(*grads_and_vars)
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.apply_gradients(
        zip(gradients, variables), global_step=tf.train.get_global_step())

  if FLAGS.moving_average_decay > 0:
    with tf.control_dependencies([train_op]):
      train_op = ema.apply(ema_vars)

  #### Creating training logging hook
  # compute accuracy
  sup_pred = tf.argmax(sup_logits, axis=-1, output_type=sup_labels.dtype)
  is_correct = tf.to_float(tf.equal(sup_pred, sup_labels))
  acc = tf.reduce_mean(is_correct)
  metric_dict["sup/sup_loss"] = avg_sup_loss
  metric_dict["training/loss"] = total_loss
  metric_dict["sup/acc"] = acc
  metric_dict["training/lr"] = learning_rate
  metric_dict["training/step"] = global_step

  if not FLAGS.use_tpu:
    log_info = ("step [{training/step}] lr {training/lr:.6f} "
                "loss {training/loss:.4f} "
                "sup/acc {sup/acc:.4f} sup/loss {sup/sup_loss:.6f} ")
    if FLAGS.unsup_ratio > 0:
      log_info += "unsup/loss {unsup/loss:.6f} "
    formatter = lambda kwargs: log_info.format(**kwargs)
    logging_hook = tf.train.LoggingTensorHook(
        tensors=metric_dict,
        every_n_iter=FLAGS.iterations,
        formatter=formatter)
    training_hooks = [logging_hook]
    #### Constructing training TPUEstimatorSpec.
    train_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        training_hooks=training_hooks)
  else:
    #### Constructing training TPUEstimatorSpec.
    host_call = utils.construct_scalar_host_call(
        metric_dict=metric_dict,
        model_dir=params["model_dir"],
        prefix="",
        reduce_fn=tf.reduce_mean)
    train_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        host_call=host_call)

  return train_spec