def model_fn(features, mode, params):
    """The model_fn to be used with TPUEstimator.

    Args:
        features: `dict` of batched features. `features['image']` holds the
            images and, outside of predict mode, `features['label']` holds the
            labels for the data samples.
        mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
        params: `dict` of parameters passed to the model from the
            TPUEstimator; `params['batch_size']` is always provided and should
            be used as the effective batch size.

    Returns:
        A `TPUEstimatorSpec` for the model.
    """

    def preprocess_image(image):
        # In most cases, the default data format NCHW instead of NHWC should
        # be used for a significant performance boost on GPU. NHWC should be
        # used only if the network needs to be run on CPU, since the pooling
        # operations are only supported on NHWC. On TPU, the XLA compiler
        # figures out the best layout.
        if FLAGS.data_format == 'channels_first':
            assert not FLAGS.transpose_input  # channels_first only for GPU
            image = tf.transpose(image, [0, 3, 1, 2])
        if FLAGS.transpose_input and mode == tf.estimator.ModeKeys.TRAIN:
            image = tf.transpose(image, [3, 0, 1, 2])  # HWCN to NHWC
        return image

    def normalize_image(image):
        # Normalize the image to zero mean and unit variance.
        if FLAGS.data_format == 'channels_first':
            stats_shape = [3, 1, 1]
        else:
            stats_shape = [1, 1, 3]
        mean, std = task_info.get_mean_std(FLAGS.task_name)
        image -= tf.constant(mean, shape=stats_shape, dtype=image.dtype)
        image /= tf.constant(std, shape=stats_shape, dtype=image.dtype)
        return image

    image = features['image']
    image = preprocess_image(image)
    image_shape = image.get_shape().as_list()
    tf.logging.info('image shape: {}'.format(image_shape))
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    if mode != tf.estimator.ModeKeys.PREDICT:
        labels = features['label']
    else:
        labels = None

    # If necessary, in the model_fn, use params['batch_size'] instead of the
    # batch size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    if FLAGS.unlabel_ratio and is_training:
        unl_bsz = features['unl_probs'].shape[0]
    else:
        unl_bsz = 0
    lab_bsz = image.shape[0] - unl_bsz
    assert lab_bsz == batch_size

    metric_dict = {}
    global_step = tf.train.get_global_step()

    has_moving_average_decay = (FLAGS.moving_average_decay > 0)
    # This is essential if using a keras-derived model.
    tf.keras.backend.set_learning_phase(is_training)
    tf.logging.info('Using open-source implementation.')
    override_params = {}
    if FLAGS.dropout_rate is not None:
        override_params['dropout_rate'] = FLAGS.dropout_rate
    if FLAGS.stochastic_depth_rate is not None:
        override_params['stochastic_depth_rate'] = FLAGS.stochastic_depth_rate
    if FLAGS.data_format:
        override_params['data_format'] = FLAGS.data_format
    if FLAGS.num_label_classes:
        override_params['num_classes'] = FLAGS.num_label_classes
    if FLAGS.depth_coefficient:
        override_params['depth_coefficient'] = FLAGS.depth_coefficient
    if FLAGS.width_coefficient:
        override_params['width_coefficient'] = FLAGS.width_coefficient

    def build_model(scope=None,
                    reuse=tf.AUTO_REUSE,
                    model_name=None,
                    model_is_training=None,
                    input_image=None,
                    use_adv_bn=False,
                    is_teacher=False):
        model_name = model_name or FLAGS.model_name
        if model_is_training is None:
            model_is_training = is_training
        if input_image is None:
            input_image = image
        input_image = normalize_image(input_image)

        scope_model_name = model_name
        if scope:
            scope = scope + '/'
        else:
            scope = ''
        with tf.variable_scope(scope + scope_model_name, reuse=reuse):
            if model_name.startswith('efficientnet'):
                logits, _ = efficientnet_builder.build_model(
                    input_image,
                    model_name=model_name,
                    training=model_is_training,
                    override_params=override_params,
                    model_dir=FLAGS.model_dir,
                    use_adv_bn=use_adv_bn,
                    is_teacher=is_teacher)
            else:
                assert False, 'model {} not implemented'.format(model_name)
        return logits

    if params['use_bfloat16']:
        with tf.tpu.bfloat16_scope():
            logits = tf.cast(build_model(), tf.float32)
    else:
        logits = build_model()

    if FLAGS.teacher_model_name:
        teacher_image = preprocess_image(features['teacher_image'])
        if params['use_bfloat16']:
            with tf.tpu.bfloat16_scope():
                teacher_logits = tf.cast(
                    build_model(scope='teacher_model',
                                model_name=FLAGS.teacher_model_name,
                                model_is_training=False,
                                input_image=teacher_image,
                                is_teacher=True), tf.float32)
        else:
            teacher_logits = build_model(scope='teacher_model',
                                         model_name=FLAGS.teacher_model_name,
                                         model_is_training=False,
                                         input_image=teacher_image,
                                         is_teacher=True)
        teacher_logits = tf.stop_gradient(teacher_logits)
        if FLAGS.teacher_softmax_temp != -1:
            teacher_prob = tf.nn.softmax(
                teacher_logits / FLAGS.teacher_softmax_temp)
        else:
            teacher_prob = None
            teacher_one_hot_pred = tf.argmax(
                teacher_logits, axis=1, output_type=labels.dtype)

    if mode == tf.estimator.ModeKeys.PREDICT:
        if has_moving_average_decay:
            ema = tf.train.ExponentialMovingAverage(
                decay=FLAGS.moving_average_decay)
            ema_vars = utils.get_all_variable()
            restore_vars_dict = ema.variables_to_restore(ema_vars)
            tf.logging.info(
                'restored variables:\n%s',
                json.dumps(sorted(restore_vars_dict.keys()), indent=4))
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            scaffold_fn=functools.partial(
                _scaffold_fn, restore_vars_dict=restore_vars_dict)
            if has_moving_average_decay else None)

    if has_moving_average_decay:
        ema_step = global_step
        ema = tf.train.ExponentialMovingAverage(
            decay=FLAGS.moving_average_decay, num_updates=ema_step)
        ema_vars = utils.get_all_variable()

    lab_labels = labels[:lab_bsz]
    lab_logits = logits[:lab_bsz]
    lab_pred = tf.argmax(lab_logits, axis=-1, output_type=labels.dtype)
    lab_prob = tf.nn.softmax(lab_logits)
    lab_acc = tf.to_float(tf.equal(lab_pred, lab_labels))
    metric_dict['lab/acc'] = tf.reduce_mean(lab_acc)
    metric_dict['lab/pred_prob'] = tf.reduce_mean(
        tf.reduce_max(lab_prob, axis=-1))
    one_hot_labels = tf.one_hot(lab_labels, FLAGS.num_label_classes)

    if FLAGS.unlabel_ratio:
        unl_labels = labels[lab_bsz:]
        unl_logits = logits[lab_bsz:]
        unl_pred = tf.argmax(unl_logits, axis=-1, output_type=labels.dtype)
        unl_prob = tf.nn.softmax(unl_logits)
        unl_acc = tf.to_float(tf.equal(unl_pred, unl_labels))
        metric_dict['unl/acc_to_dump'] = tf.reduce_mean(unl_acc)
        metric_dict['unl/pred_prob'] = tf.reduce_mean(
            tf.reduce_max(unl_prob, axis=-1))

    # Compute lab_loss.
    one_hot_labels = tf.one_hot(lab_labels, FLAGS.num_label_classes)
    lab_loss = tf.losses.softmax_cross_entropy(
        logits=lab_logits,
        onehot_labels=one_hot_labels,
        label_smoothing=FLAGS.label_smoothing,
        reduction=tf.losses.Reduction.NONE)
    if FLAGS.label_data_sample_prob != 1:
        # Mask out part of the labeled data.
        random_mask = tf.floor(
            FLAGS.label_data_sample_prob +
            tf.random_uniform(tf.shape(lab_loss), dtype=lab_loss.dtype))
        lab_loss = tf.reduce_mean(lab_loss * random_mask)
    else:
        lab_loss = tf.reduce_mean(lab_loss)
    metric_dict['lab/loss'] = lab_loss

    if FLAGS.unlabel_ratio:
        if FLAGS.teacher_softmax_temp == -1:  # Hard labels
            # Get one-hot labels.
            if FLAGS.teacher_model_name:
                ext_teacher_pred = teacher_one_hot_pred[lab_bsz:]
                one_hot_labels = tf.one_hot(ext_teacher_pred,
                                            FLAGS.num_label_classes)
            else:
                one_hot_labels = tf.one_hot(unl_labels,
                                            FLAGS.num_label_classes)
            # Compute cross entropy.
            unl_loss = tf.losses.softmax_cross_entropy(
                logits=unl_logits,
                onehot_labels=one_hot_labels,
                label_smoothing=FLAGS.label_smoothing)
        else:  # Soft labels
            # Get the teacher distribution.
            if FLAGS.teacher_model_name:
                unl_teacher_prob = teacher_prob[lab_bsz:]
            else:
                scaled_prob = tf.pow(features['unl_probs'],
                                     1 / FLAGS.teacher_softmax_temp)
                unl_teacher_prob = scaled_prob / tf.reduce_sum(
                    scaled_prob, axis=-1, keepdims=True)
            metric_dict['unl/target_prob'] = tf.reduce_mean(
                tf.reduce_max(unl_teacher_prob, axis=-1))
            unl_loss = cross_entropy(unl_teacher_prob, unl_logits,
                                     return_mean=True)
        metric_dict['ext/loss'] = unl_loss
    else:
        unl_loss = 0

    real_lab_bsz = tf.to_float(lab_bsz) * FLAGS.label_data_sample_prob
    real_unl_bsz = (batch_size * FLAGS.label_data_sample_prob *
                    FLAGS.unlabel_ratio)
    data_loss = lab_loss * real_lab_bsz + unl_loss * real_unl_bsz
    data_loss = data_loss / real_lab_bsz

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = data_loss + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])
    metric_dict['train/data_loss'] = data_loss
    metric_dict['train/loss'] = loss

    host_call = None
    restore_vars_dict = None

    if is_training:
        # Compute the current epoch and associated learning rate from
        # global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])
        real_train_batch_size = FLAGS.train_batch_size
        real_train_batch_size *= FLAGS.label_data_sample_prob
        scaled_lr = FLAGS.base_learning_rate * (real_train_batch_size / 256.0)
        if FLAGS.final_base_lr:
            # Total number of training epochs.
            total_epochs = (FLAGS.train_steps * FLAGS.train_batch_size * 1. /
                            FLAGS.num_train_images - 5)
            decay_times = (
                math.log(FLAGS.final_base_lr / FLAGS.base_learning_rate) /
                math.log(0.97))
            decay_epochs = total_epochs / decay_times
            tf.logging.info(
                'setting decay_epochs to {:.2f}'.format(decay_epochs))
        else:
            decay_epochs = 2.4 * FLAGS.train_ratio
        learning_rate = utils.build_learning_rate(
            scaled_lr,
            global_step,
            params['steps_per_epoch'],
            decay_epochs=decay_epochs,
            start_from_step=FLAGS.train_steps - FLAGS.train_last_step_num,
            warmup_epochs=5,
        )
        metric_dict['train/lr'] = learning_rate
        metric_dict['train/epoch'] = current_epoch
        optimizer = utils.build_optimizer(learning_rate)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer,
            # which handles synchronization details between different TPU
            # cores. To the user, this should look like regular synchronous
            # training.
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency
        # to the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        tvars = tf.trainable_variables()
        g_vars = []
        tvars = sorted(tvars, key=lambda var: var.name)
        for var in tvars:
            if 'teacher_model' not in var.name:
                g_vars += [var]
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step, var_list=g_vars)

        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not FLAGS.skip_host_call:
            host_call = utils.construct_scalar_host_call(metric_dict)
        scaffold_fn = None
        if FLAGS.teacher_model_name or FLAGS.init_model:
            scaffold_fn = utils.init_from_ckpt(scaffold_fn)
    else:
        train_op = None
        if has_moving_average_decay:
            # Load moving average variables for eval.
            restore_vars_dict = ema.variables_to_restore(ema_vars)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        scaffold_fn = functools.partial(
            _scaffold_fn, restore_vars_dict=restore_vars_dict
        ) if has_moving_average_decay else None

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

            This function is executed on the CPU and should not directly
            reference any Tensors in the rest of the `model_fn`. To pass
            Tensors from the model to the `metric_fn`, provide them as part of
            the `eval_metrics`. See
            https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
            for more information.

            Arguments should match the list of `Tensor` objects passed as the
            second element in the tuple passed to `eval_metrics`.

            Args:
                labels: `Tensor` with shape `[batch]`.
                logits: `Tensor` with shape `[batch, num_classes]`.

            Returns:
                A dict of the metrics to return from evaluation.
            """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            result_dict = {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }
            return result_dict

        eval_metrics = (metric_fn, [labels, logits])

    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('number of trainable parameters: {}'.format(num_params))

    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        host_call=host_call,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
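# The model_fn above delegates EMA restoration to a module-level `_scaffold_fn`
# and soft-label distillation to a `cross_entropy` helper, neither of which is
# shown in this listing. The two definitions below are minimal sketches written
# to match the call sites above; they are assumptions, not the repository's
# exact implementations.
def _scaffold_fn(restore_vars_dict):
    """Builds a Scaffold whose Saver restores EMA shadow values (sketch)."""
    saver = tf.train.Saver(restore_vars_dict)
    return tf.train.Scaffold(saver=saver)


def cross_entropy(targets, logits, return_mean=False):
    """Cross entropy between a soft target distribution and logits (sketch)."""
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    per_example = -tf.reduce_sum(targets * log_probs, axis=-1)
    if return_mean:
        return tf.reduce_mean(per_example)
    return per_example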
def model_fn(features, labels, mode, params):
    sup_labels = tf.reshape(features["label"], [-1])

    #### Configuring the optimizer
    global_step = tf.train.get_global_step()
    metric_dict = {}
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    if FLAGS.unsup_ratio > 0 and is_training:
        all_images = tf.concat([features["image"],
                                features["ori_image"],
                                features["aug_image"]], 0)
    else:
        all_images = features["image"]

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        all_logits = build_model(
            inputs=all_images,
            num_classes=FLAGS.num_classes,
            is_training=is_training,
            update_bn=True and is_training,
            hparams=hparams,
        )

    sup_bsz = tf.shape(features["image"])[0]
    sup_logits = all_logits[:sup_bsz]

    sup_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=sup_labels,
        logits=sup_logits)
    sup_prob = tf.nn.softmax(sup_logits, axis=-1)
    metric_dict["sup/pred_prob"] = tf.reduce_mean(
        tf.reduce_max(sup_prob, axis=-1))
    if FLAGS.tsa:
        sup_loss, avg_sup_loss = anneal_sup_loss(sup_logits, sup_labels,
                                                 sup_loss, global_step,
                                                 metric_dict)
    else:
        avg_sup_loss = tf.reduce_mean(sup_loss)
    total_loss = avg_sup_loss

    if FLAGS.unsup_ratio > 0 and is_training:
        aug_bsz = tf.shape(features["ori_image"])[0]

        ori_logits = all_logits[sup_bsz:sup_bsz + aug_bsz]
        aug_logits = all_logits[sup_bsz + aug_bsz:]
        if FLAGS.uda_softmax_temp != -1:
            ori_logits_tgt = ori_logits / FLAGS.uda_softmax_temp
        else:
            ori_logits_tgt = ori_logits
        ori_prob = tf.nn.softmax(ori_logits, axis=-1)
        aug_prob = tf.nn.softmax(aug_logits, axis=-1)
        metric_dict["unsup/ori_prob"] = tf.reduce_mean(
            tf.reduce_max(ori_prob, axis=-1))
        metric_dict["unsup/aug_prob"] = tf.reduce_mean(
            tf.reduce_max(aug_prob, axis=-1))

        aug_loss = _kl_divergence_with_logits(
            p_logits=tf.stop_gradient(ori_logits_tgt),
            q_logits=aug_logits)

        if FLAGS.uda_confidence_thresh != -1:
            ori_prob = tf.nn.softmax(ori_logits, axis=-1)
            largest_prob = tf.reduce_max(ori_prob, axis=-1)
            loss_mask = tf.cast(
                tf.greater(largest_prob, FLAGS.uda_confidence_thresh),
                tf.float32)
            metric_dict["unsup/high_prob_ratio"] = tf.reduce_mean(loss_mask)
            loss_mask = tf.stop_gradient(loss_mask)
            aug_loss = aug_loss * loss_mask
            metric_dict["unsup/high_prob_loss"] = tf.reduce_mean(aug_loss)

        if FLAGS.ent_min_coeff > 0:
            ent_min_coeff = FLAGS.ent_min_coeff
            metric_dict["unsup/ent_min_coeff"] = ent_min_coeff
            per_example_ent = get_ent(ori_logits)
            ent_min_loss = tf.reduce_mean(per_example_ent)
            total_loss = total_loss + ent_min_coeff * ent_min_loss

        avg_unsup_loss = tf.reduce_mean(aug_loss)
        total_loss += FLAGS.unsup_coeff * avg_unsup_loss
        metric_dict["unsup/loss"] = avg_unsup_loss

    total_loss = utils.decay_weights(total_loss, FLAGS.weight_decay_rate)

    #### Check model parameters
    num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info("#params: {}".format(num_params))
    if FLAGS.verbose:
        format_str = "{{:<{0}s}}\t{{}}".format(
            max([len(v.name) for v in tf.trainable_variables()]))
        for v in tf.trainable_variables():
            tf.logging.info(format_str.format(v.name, v.get_shape()))

    #### Evaluation mode
    if mode == tf.estimator.ModeKeys.EVAL:
        #### Metric function for classification
        def metric_fn(per_example_loss, label_ids, logits):
            # classification loss & accuracy
            loss = tf.metrics.mean(per_example_loss)
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
            accuracy = tf.metrics.accuracy(label_ids, predictions)
            ret_dict = {
                "eval/classify_loss": loss,
                "eval/classify_accuracy": accuracy
            }
            return ret_dict

        eval_metrics = (metric_fn, [sup_loss, sup_labels, sup_logits])

        #### Constructing evaluation TPUEstimatorSpec.
        eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=eval_metrics)
        return eval_spec

    # increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = (tf.to_float(global_step) /
                     tf.to_float(FLAGS.warmup_steps) * FLAGS.learning_rate)
    else:
        warmup_lr = 0.0

    # decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=0.9, use_nesterov=True)

    if FLAGS.use_tpu:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    grads_and_vars = optimizer.compute_gradients(total_loss)
    gradients, variables = zip(*grads_and_vars)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.apply_gradients(
            zip(gradients, variables), global_step=tf.train.get_global_step())

    #### Creating training logging hook
    # compute accuracy
    sup_pred = tf.argmax(sup_logits, axis=-1, output_type=sup_labels.dtype)
    is_correct = tf.to_float(tf.equal(sup_pred, sup_labels))
    acc = tf.reduce_mean(is_correct)
    metric_dict["sup/sup_loss"] = avg_sup_loss
    metric_dict["training/loss"] = total_loss
    metric_dict["sup/acc"] = acc
    metric_dict["training/lr"] = learning_rate
    metric_dict["training/step"] = global_step

    if not FLAGS.use_tpu:
        log_info = ("step [{training/step}] lr {training/lr:.6f} "
                    "loss {training/loss:.4f} "
                    "sup/acc {sup/acc:.4f} sup/loss {sup/sup_loss:.6f} ")
        if FLAGS.unsup_ratio > 0:
            log_info += "unsup/loss {unsup/loss:.6f} "
        formatter = lambda kwargs: log_info.format(**kwargs)
        logging_hook = tf.train.LoggingTensorHook(
            tensors=metric_dict,
            every_n_iter=FLAGS.iterations,
            formatter=formatter)
        training_hooks = [logging_hook]
        #### Constructing training TPUEstimatorSpec.
        train_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            training_hooks=training_hooks)
    else:
        #### Constructing training TPUEstimatorSpec.
        host_call = utils.construct_scalar_host_call(
            metric_dict=metric_dict,
            model_dir=params["model_dir"],
            prefix="",
            reduce_fn=tf.reduce_mean)
        train_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            host_call=host_call)

    return train_spec
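# The UDA consistency loss above is computed by `_kl_divergence_with_logits`
# and the optional entropy-minimization term by `get_ent`; both are defined
# elsewhere in the codebase. The sketches below show the quantities they are
# expected to return given the call sites above (per-example KL(p || q) and
# per-example entropy); treat them as illustrative, not the exact upstream
# implementations.
def _kl_divergence_with_logits(p_logits, q_logits):
    """Per-example KL(softmax(p_logits) || softmax(q_logits)) (sketch)."""
    p = tf.nn.softmax(p_logits, axis=-1)
    log_p = tf.nn.log_softmax(p_logits, axis=-1)
    log_q = tf.nn.log_softmax(q_logits, axis=-1)
    return tf.reduce_sum(p * (log_p - log_q), axis=-1)


def get_ent(logits):
    """Per-example entropy of the predicted distribution (sketch)."""
    probs = tf.nn.softmax(logits, axis=-1)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    return -tf.reduce_sum(probs * log_probs, axis=-1)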
def model_fn(features, labels, mode, params):
    sup_labels = tf.reshape(features["label"], [-1])

    #### Configuring the optimizer
    global_step = tf.train.get_global_step()
    metric_dict = {}
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    if FLAGS.unsup_ratio > 0 and is_training:
        all_images = tf.concat([features["image"],
                                features["ori_image"],
                                features["aug_image"]], 0)
    else:
        all_images = features["image"]

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        # build_model returns a (logits, features) pair in this variant.
        all_logits = build_model(
            inputs=all_images,
            num_classes=FLAGS.num_classes,
            feature_dim=128,
            is_training=is_training,
            update_bn=True and is_training,
            hparams=hparams,
        )

    sup_bsz = tf.shape(features["image"])[0]
    sup_logits = all_logits[0][:sup_bsz]
    sup_features = all_logits[1][:sup_bsz]

    # Look up the fixed PEDCC centroid for each labeled example.
    map_dict = read_pkl()
    tmp_list = [x.numpy() for x in map_dict.values()]
    pedcc_features_all = np.concatenate(tmp_list)

    def f0():
        return tmp_list[0]

    def f1():
        return tmp_list[1]

    def f2():
        return tmp_list[2]

    def f3():
        return tmp_list[3]

    def f4():
        return tmp_list[4]

    def f5():
        return tmp_list[5]

    def f6():
        return tmp_list[6]

    def f7():
        return tmp_list[7]

    def f8():
        return tmp_list[8]

    def f9():
        return tmp_list[9]

    for i in range(FLAGS.train_batch_size):
        tmp = sup_labels[i]
        test = tf.case(
            {
                tf.equal(tmp, 0): f0,
                tf.equal(tmp, 1): f1,
                tf.equal(tmp, 2): f2,
                tf.equal(tmp, 3): f3,
                tf.equal(tmp, 4): f4,
                tf.equal(tmp, 5): f5,
                tf.equal(tmp, 6): f6,
                tf.equal(tmp, 7): f7,
                tf.equal(tmp, 8): f8,
                tf.equal(tmp, 9): f9
            },
            exclusive=True)
        if i == 0:
            feature_label = test
        else:
            feature_label = tf.concat([feature_label, test], axis=0)
    pedcc_features = tf.cast(feature_label, dtype=tf.float32)

    mse_loss = tf.reduce_mean(tf.square(sup_features - pedcc_features))
    loss_2 = AM_loss(sup_logits, sup_labels)
    sup_loss = mse_loss + loss_2
    sup_prob = tf.nn.softmax(sup_logits, axis=-1)
    metric_dict["sup/pred_prob"] = tf.reduce_mean(
        tf.reduce_max(sup_prob, axis=-1))
    avg_sup_loss = tf.reduce_mean(sup_loss)
    total_loss = avg_sup_loss

    if FLAGS.unsup_ratio > 0 and is_training:
        aug_bsz = tf.shape(features["ori_image"])[0]

        ori_logits = all_logits[0][sup_bsz:sup_bsz + aug_bsz]
        ori_features = all_logits[1][sup_bsz:sup_bsz + aug_bsz]
        aug_logits = all_logits[0][sup_bsz + aug_bsz:]
        ori_logits_tgt = ori_logits
        ori_prob = tf.nn.softmax(ori_logits, axis=-1)
        aug_prob = tf.nn.softmax(aug_logits, axis=-1)
        metric_dict["unsup/ori_prob"] = tf.reduce_mean(
            tf.reduce_max(ori_prob, axis=-1))
        metric_dict["unsup/aug_prob"] = tf.reduce_mean(
            tf.reduce_max(aug_prob, axis=-1))

        # Tile the 10-centroid matrix until it matches the unsupervised batch
        # size (train_batch_size * unsup_ratio examples).
        for i in range(
                0, int(FLAGS.train_batch_size * FLAGS.unsup_ratio / 10 - 1)):
            if i == 0:
                pedcc_features_sum = tf.concat(
                    [pedcc_features_all, pedcc_features_all], axis=0)
            else:
                pedcc_features_sum = tf.concat(
                    [pedcc_features_sum, pedcc_features_all], axis=0)
        pedcc_features_sum = tf.cast(pedcc_features_sum, dtype=tf.float32)
        mmd_loss = mmd_rbf(ori_features, pedcc_features_sum)
        mmd_loss = mmd_loss * 0.2

        aug_loss = _kl_divergence_with_logits(
            p_logits=tf.stop_gradient(ori_logits_tgt),
            q_logits=aug_logits)
        avg_unsup_loss = tf.reduce_mean(aug_loss)
        avg_unsup_loss = avg_unsup_loss * 400
        total_loss += FLAGS.unsup_coeff * avg_unsup_loss
        total_loss += mmd_loss
        metric_dict["unsup/mmd_loss"] = mmd_loss
        metric_dict["unsup/loss"] = avg_unsup_loss

    total_loss = utils.decay_weights(total_loss, FLAGS.weight_decay_rate)

    #### Check model parameters
    num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info("#params: {}".format(num_params))

    #### Evaluation mode
    if mode == tf.estimator.ModeKeys.EVAL:
        #### Metric function for classification
        def metric_fn(per_example_loss, label_ids, logits):
            # classification loss & accuracy
            loss = tf.metrics.mean(per_example_loss)
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
            accuracy = tf.metrics.accuracy(label_ids, predictions)
            ret_dict = {
                "eval/classify_loss": loss,
                "eval/classify_accuracy": accuracy
            }
            return ret_dict

        eval_metrics = (metric_fn, [sup_loss, sup_labels, sup_logits])

        #### Constructing evaluation TPUEstimatorSpec.
        eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=eval_metrics)
        return eval_spec

    # increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = (tf.to_float(global_step) /
                     tf.to_float(FLAGS.warmup_steps) * FLAGS.learning_rate)
    else:
        warmup_lr = 0.0

    # decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=0.9, use_nesterov=True)

    #### use_tpu is False in this setup.
    if FLAGS.use_tpu:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    grads_and_vars = optimizer.compute_gradients(total_loss)
    gradients, variables = zip(*grads_and_vars)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.apply_gradients(
            zip(gradients, variables), global_step=tf.train.get_global_step())

    #### Creating training logging hook
    # compute accuracy
    sup_pred = tf.argmax(sup_logits, axis=-1, output_type=sup_labels.dtype)
    is_correct = tf.to_float(tf.equal(sup_pred, sup_labels))
    acc = tf.reduce_mean(is_correct)
    metric_dict["sup/sup_loss"] = avg_sup_loss
    metric_dict["training/loss"] = total_loss
    metric_dict["sup/acc"] = acc
    metric_dict["training/lr"] = learning_rate
    metric_dict["training/step"] = global_step

    if not FLAGS.use_tpu:
        log_info = ("step [{training/step}] lr {training/lr:.6f} "
                    "loss {training/loss:.4f} "
                    "sup/acc {sup/acc:.4f} sup/loss {sup/sup_loss:.6f} ")
        if FLAGS.unsup_ratio > 0:
            log_info += "unsup/loss {unsup/loss:.6f} "
            log_info += "unsup/mmd_loss {unsup/mmd_loss:.6f} "
        formatter = lambda kwargs: log_info.format(**kwargs)
        logging_hook = tf.train.LoggingTensorHook(
            tensors=metric_dict,
            every_n_iter=FLAGS.iterations,
            formatter=formatter)
        training_hooks = [logging_hook]
        #### Constructing training TPUEstimatorSpec.
        train_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            training_hooks=training_hooks)
    else:
        #### Constructing training TPUEstimatorSpec.
        host_call = utils.construct_scalar_host_call(
            metric_dict=metric_dict,
            model_dir=params["model_dir"],
            prefix="",
            reduce_fn=tf.reduce_mean)
        train_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            host_call=host_call)

    return train_spec
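# This variant replaces UDA's supervised cross entropy with an MSE term that
# pulls learned features toward fixed PEDCC class centroids, plus an
# AM-Softmax-style margin loss, and adds an RBF-kernel MMD term between the
# unlabeled features and the tiled centroids. `read_pkl`, `AM_loss`, and
# `mmd_rbf` are defined elsewhere; the sketch below illustrates a standard
# multi-kernel RBF MMD consistent with the `mmd_rbf(ori_features,
# pedcc_features_sum)` call above. The kernel bandwidths are assumptions.
def mmd_rbf(x, y, sigmas=(1.0, 5.0, 10.0)):
    """MMD between samples x and y under a mixture of RBF kernels (sketch)."""

    def rbf_kernel(a, b):
        # Pairwise squared distances between rows of a and rows of b.
        sq_dist = tf.reduce_sum(
            tf.square(a[:, None, :] - b[None, :, :]), axis=-1)
        kernel = 0.
        for sigma in sigmas:
            kernel += tf.exp(-sq_dist / (2.0 * sigma ** 2))
        return kernel

    return (tf.reduce_mean(rbf_kernel(x, x)) +
            tf.reduce_mean(rbf_kernel(y, y)) -
            2.0 * tf.reduce_mean(rbf_kernel(x, y)))


# Note: the per-label tf.case chain in the listing above is equivalent to a
# single gather over the [num_classes, feature_dim] centroid matrix, e.g.
# (assuming that shape):
#     pedcc_features = tf.cast(
#         tf.gather(tf.constant(pedcc_features_all), sup_labels), tf.float32)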
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf.logging.info("  name = %s, shape = %s" % (name,
                                                     features[name].shape))

    input_ids = features["input_ids"]
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model = GroverModel(
        config=config,
        is_training=is_training,
        input_ids=input_ids,
        pad_token_id=config.pad_token_id,
        chop_off_last_token=True,
    )

    total_loss = model.lm_loss()

    if is_training:
        train_op, train_metrics = create_optimizer(total_loss, learning_rate,
                                                   num_train_steps,
                                                   num_warmup_steps, use_tpu)
        tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    else:
        train_op = None
        train_metrics = {}
        tvars = tf.trainable_variables()

    params_sum = np.sum(
        [np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
    tf.logging.info("**** Trainable params_sum ****")
    tf.logging.info(params_sum)

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
        (assignment_map,
         initialized_variable_names) = get_assignment_map_from_checkpoint(
             tvars, init_checkpoint)
        if use_tpu:

            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
        tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                        init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        if use_tpu:
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                host_call=construct_scalar_host_call(
                    metric_dict=train_metrics,
                    model_dir=params['model_dir'],
                    prefix='training/'),
                scaffold_fn=scaffold_fn)
        else:
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=[
                    tf.train.LoggingTensorHook(
                        {
                            "train_loss": total_loss,
                            "global_step": tf.train.get_global_step()
                        },
                        every_n_iter=10)
                ],
                scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(total_loss):
            loss = tf.metrics.mean(values=total_loss)
            return {
                "eval_loss": loss,
            }

        eval_metrics = (metric_fn, [total_loss])
        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
    else:
        gt_logprobs = tf.squeeze(
            tf.batch_gather(model.log_probs, model.target_ids[:, :, None]),
            axis=2)

        # Top-p value that would be required for the ground-truth token to be
        # included under top-p (nucleus) sampling.
        better_than_gt = model.log_probs > gt_logprobs[:, :, None]
        top_p_required = tf.reduce_sum(
            tf.cast(better_than_gt, tf.float32) * tf.exp(model.log_probs),
            axis=2)

        # No top-p sampling for now, since this seems to be too slow on TPUs
        if use_tpu:
            predictions = tf.reshape(
                tf.random.categorical(logits=model.logits_flat, num_samples=1),
                get_shape_list(model.target_ids),
            )
        else:
            # Argmax
            # predictions = tf.math.argmax(model.log_probs, axis=-1, output_type=tf.int32)
            predictions = tf.reshape(
                _top_p_sample(model.logits_flat, num_samples=1,
                              p=0.99)['sample'],
                get_shape_list(model.target_ids),
            )
        pred_logprobs = tf.squeeze(
            tf.batch_gather(model.log_probs, predictions[:, :, None]), axis=2)

        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions={
                'gt_logprobs': gt_logprobs,
                'top_p_required': top_p_required,
                'predictions': predictions,
                'pred_logprobs': pred_logprobs,
                'labels': input_ids
            },
            scaffold_fn=scaffold_fn)
    return output_spec
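# On the non-TPU predict path the listing samples with `_top_p_sample`, which
# is defined elsewhere in the Grover codebase. The sketch below shows the idea
# of top-p (nucleus) sampling it is used for: keep the smallest set of
# highest-probability tokens whose cumulative mass stays under p, mask out the
# rest, and sample from the renormalized distribution. The returned dict key
# follows the call site above; the implementation details are assumptions.
def _top_p_sample(logits, num_samples=1, p=0.95):
    """Nucleus (top-p) sampling over the last axis of `logits` (sketch)."""
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    sorted_probs = tf.nn.softmax(sorted_logits, axis=-1)
    # Cumulative probability *before* each sorted token; a token is inside the
    # nucleus while this running total is still below p.
    cumulative = tf.cumsum(sorted_probs, axis=-1, exclusive=True)
    keep = cumulative < p
    # Smallest logit that is still inside the nucleus, per example.
    min_kept_logit = tf.reduce_min(
        tf.where(keep, sorted_logits,
                 tf.fill(tf.shape(sorted_logits), sorted_logits.dtype.max)),
        axis=-1, keepdims=True)
    # Suppress everything below the nucleus threshold, then sample.
    filtered = tf.where(logits < min_kept_logit,
                        tf.fill(tf.shape(logits), -1e10), logits)
    sample = tf.random.categorical(logits=filtered, num_samples=num_samples)
    return {'sample': sample}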