def model_fn(features, labels, mode, params):
    """Build a ResNet-v1 training graph and return a ``TPUEstimatorSpec``.

    The network is always constructed with ``is_training=True``; ``mode`` is
    only forwarded to the returned spec.

    Args:
        features: Batched input images fed to the network.
        labels: Integer class ids; one-hot encoded to
            ``params['output_size']`` classes for the cross-entropy loss.
        mode: Estimator mode key, passed through to the returned spec.
        params: Dict providing ``'output_size'`` (number of classes).

    Returns:
        A ``tpu_estimator.TPUEstimatorSpec`` carrying the loss and train op.

    Raises:
        ValueError: If ``opt`` is not one of ``'sgd'``, ``'momentum'`` or
            ``'rms'`` (previously this left ``optimizer`` unbound and failed
            later with an ``UnboundLocalError``).
    """
    output_size = params['output_size']

    def _build_logits():
        # Single network definition shared by both precision paths.
        network = resnet_model.resnet_v1(
            resnet_layers, block_fn, num_classes=output_size,
            data_format='channels_last', filters=filters)
        return network(inputs=features, is_training=True)

    if FLAGS.data_type == 'float32':
        net = _build_logits()
    else:
        # Non-float32 runs build the network under the 'cg' variable scope
        # with a custom getter (presumably a reduced-precision variable
        # getter -- confirm in get_custom_getter), then cast the logits back
        # to float32 for the loss.
        with tf.variable_scope('cg', custom_getter=get_custom_getter()):
            net = _build_logits()
            net = tf.cast(net, tf.float32)

    onehot_labels = tf.one_hot(labels, output_size)
    loss = tf.losses.softmax_cross_entropy(
        onehot_labels=onehot_labels, logits=net)

    # 0.1 * 0.97 ** (global_step / 25000); staircase is left at its default.
    learning_rate = tf.train.exponential_decay(
        0.1, tf.train.get_global_step(), 25000, 0.97)

    if opt == 'sgd':
        tf.logging.info('Using SGD optimizer')
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    elif opt == 'momentum':
        tf.logging.info('Using Momentum optimizer')
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate, momentum=0.9)
    elif opt == 'rms':
        tf.logging.info('Using RMS optimizer')
        optimizer = tf.train.RMSPropOptimizer(
            learning_rate, RMSPROP_DECAY, momentum=RMSPROP_MOMENTUM,
            epsilon=RMSPROP_EPSILON)
    else:
        raise ValueError('Unsupported optimizer: {!r}'.format(opt))

    if FLAGS.use_tpu:
        # Aggregate gradients across TPU shards.
        optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(
        loss, global_step=tf.train.get_global_step())

    # Profile trainable-parameter counts and FLOPs of the built graph. The
    # profiler emits its own report, so the returned stats are not kept.
    tf.profiler.profile(
        tf.get_default_graph(),
        options=tf.profiler.ProfileOptionBuilder
        .trainable_variables_parameter())
    tf.profiler.profile(
        tf.get_default_graph(),
        options=tf.profiler.ProfileOptionBuilder.float_operation())

    return tpu_estimator.TPUEstimatorSpec(
        mode=mode, loss=loss, train_op=train_op)
def build_network():
    """Apply an inference-mode ResNet-50 v1 (1000 classes) to ``images``.

    DropBlock is disabled: no block size and a ``None`` keep-probability for
    each of the four entries. ``images`` is a free variable, presumably a
    tensor from the enclosing scope -- confirm at the definition site.

    Returns:
        Whatever the ``resnet_v1`` network callable returns for ``images``.
    """
    no_dropblock = [None, None, None, None]
    resnet_fn = resnet_model.resnet_v1(
        resnet_depth=50,
        num_classes=1000,
        dropblock_size=None,
        dropblock_keep_probs=no_dropblock,
        data_format='channels_last',
    )
    return resnet_fn(inputs=images, is_training=False)
def build_network():
    """Apply a flag-configured ResNet-v1 to ``features``.

    Depth, class count and data format come from ``FLAGS``; ``features`` and
    ``mode`` are free variables, presumably from the enclosing model_fn.

    Returns:
        The network output, computed in training mode only when
        ``mode == tf.estimator.ModeKeys.TRAIN``.
    """
    resnet_fn = resnet_model.resnet_v1(
        resnet_depth=FLAGS.resnet_depth,
        num_classes=FLAGS.num_label_classes,
        data_format=FLAGS.data_format,
    )
    training = mode == tf.estimator.ModeKeys.TRAIN
    return resnet_fn(inputs=features, is_training=training)
def build_network(l_features):
    """Apply a flag-configured ResNet-v1 with DropBlock to ``l_features``.

    ``dropblock_keep_probs`` and ``mode`` are free variables, presumably
    defined in the enclosing model_fn.

    Args:
        l_features: Input tensor passed to the network.

    Returns:
        The network output, computed in training mode only when
        ``mode == tf.estimator.ModeKeys.TRAIN``.
    """
    resnet_fn = resnet_model.resnet_v1(
        resnet_depth=FLAGS.resnet_depth,
        num_classes=FLAGS.num_label_classes,
        dropblock_size=FLAGS.dropblock_size,
        dropblock_keep_probs=dropblock_keep_probs,
        data_format=FLAGS.data_format,
    )
    training = mode == tf.estimator.ModeKeys.TRAIN
    return resnet_fn(inputs=l_features, is_training=training)
def test_load_resnet18_v1(self):
    """Smoke test: build ResNet-18 v1 and run one forward pass on noise."""
    network = resnet_model.resnet_v1(
        resnet_depth=18, num_classes=10, data_format='channels_last')
    input_bhw3 = tf.placeholder(tf.float32, [1, 28, 28, 3])
    # The network callable takes `is_training` (as at every other call site
    # in this file); the previous `train=True` keyword does not match it.
    resnet_output = network(inputs=input_bhw3, is_training=True)
    # The context manager closes the session even if the forward pass raises.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(
            resnet_output,
            feed_dict={input_bhw3: np.random.randn(1, 28, 28, 3)})
def create_model():
    """Create the model and compute the logits.

    Uses the Keras ResNet50 application when ``FLAGS.use_keras_model`` is
    truthy, otherwise ``resnet_model.resnet_v1``. ``features`` and
    ``is_training`` are free variables from the enclosing scope.
    """
    if not FLAGS.use_keras_model:
        resnet_fn = resnet_model.resnet_v1(
            resnet_depth=_RESNET_DEPTH,
            num_classes=_NUM_CLASSES,
            data_format='channels_last',
        )
        return resnet_fn(inputs=features, is_training=is_training)

    keras_resnet = tf.keras.applications.resnet50.ResNet50(
        include_top=True,
        weights=None,
        input_tensor=None,
        input_shape=None,
        pooling=None,
        classes=_NUM_CLASSES,
    )
    return keras_resnet(features, training=is_training)
def build_network(features, mode, params):
    """Build a ResNet-50 (v1) and apply it to ``features``.

    Args:
        features: Batched input images passed to the network.
        mode: A ``tf.estimator.ModeKeys`` value; the network runs in training
            mode only when ``mode == tf.estimator.ModeKeys.TRAIN``.
        params: Dict of hyperparameters; ``params["classes"]`` sets the number
            of output classes and ``params["data_format"]`` the tensor layout
            passed to ``resnet_v1``.

    Returns:
        The network's output for ``features`` (presumably the class logits --
        confirm against ``resnet_v1``).
    """
    network = resnet_v1(
        resnet_depth=50,
        num_classes=params["classes"],
        data_format=params["data_format"],
    )
    return network(inputs=features,
                   is_training=(mode == tf.estimator.ModeKeys.TRAIN))
def resnet_model_fn(features, labels, mode, params):
    """The model_fn for ResNet to be used with TPUEstimator.

    Args:
        features: `Tensor` of batched images, or a dict holding that tensor
            under the key 'feature'.
        labels: `Tensor` of labels for the data samples.
        mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
        params: `dict` of parameters passed to the model from the
            TPUEstimator; `params['batch_size']` is always provided and should
            be used as the effective batch size.

    Returns:
        A `tf.estimator.EstimatorSpec` in PREDICT mode, otherwise a
        `TPUEstimatorSpec` for the model.
    """
    if isinstance(features, dict):
        features = features['feature']

    # In most cases, the default data format NCHW instead of NHWC should be
    # used for a significant performance boost on GPU/TPU. NHWC should be used
    # only if the network needs to be run on CPU since the pooling operations
    # are only supported on NHWC.
    if FLAGS.data_format == 'channels_first':
        # Presumably the input arrives NHWC; this permutation yields NCHW.
        features = tf.transpose(features, [0, 3, 1, 2])

    network = resnet_model.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                                     num_classes=LABEL_CLASSES,
                                     data_format=FLAGS.data_format)

    logits = network(inputs=features,
                     is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the
    # batch size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, LABEL_CLASSES)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=one_hot_labels)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Compute the current epoch and associated learning rate from global_step.
        global_step = tf.train.get_global_step()
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
        learning_rate = learning_rate_schedule(current_epoch)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=MOMENTUM,
                                               use_nesterov=True)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To
            # the user, this should look like regular synchronous training.
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency
        # to the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if not FLAGS.skip_host_call:

            def host_call_fn(gs, loss, lr, ce):
                """Training host call. Creates scalar summaries for training metrics.

                This function is executed on the CPU and should not directly
                reference any Tensors in the rest of the `model_fn`. To pass
                Tensors from the model to `host_call_fn`, provide them as part
                of the `host_call`. See
                https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
                for more information.

                Arguments should match the list of `Tensor` objects passed as
                the second element in the tuple passed to `host_call`.

                Args:
                    gs: `Tensor` with shape `[batch]` for the global_step.
                    loss: `Tensor` with shape `[batch]` for the training loss.
                    lr: `Tensor` with shape `[batch]` for the learning_rate.
                    ce: `Tensor` with shape `[batch]` for the current_epoch.

                Returns:
                    List of summary ops to run on the CPU host.
                """
                # Only element 0 of each input is summarized.
                gs = gs[0]
                with summary.create_file_writer(FLAGS.model_dir).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('loss', loss[0], step=gs)
                        summary.scalar('learning_rate', lr[0], step=gs)
                        summary.scalar('current_epoch', ce[0], step=gs)

                        return summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for
            # Tensorboard, the summary op needs to be run on the host CPU via
            # host_call. host_call expects Tensors with a leading batch
            # dimension, so each scalar is reshaped to shape [1]; host_call_fn
            # reads only element [0].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

            This function is executed on the CPU and should not directly
            reference any Tensors in the rest of the `model_fn`. To pass
            Tensors from the model to the `metric_fn`, provide them as part of
            the `eval_metrics`. See
            https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
            for more information.

            Arguments should match the list of `Tensor` objects passed as the
            second element in the tuple passed to `eval_metrics`.

            Args:
                labels: `Tensor` with shape `[batch]`.
                logits: `Tensor` with shape `[batch, num_classes]`.

            Returns:
                A dict of the metrics to return from evaluation.
            """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          host_call=host_call,
                                          eval_metrics=eval_metrics)
def resnet_model_fn(features, labels, mode, params):
    """The model_fn for ResNet (optionally with attention) for TPUEstimator.

    The architecture is chosen from ``FLAGS.resnet_depth``, a string of the
    form ``<variant>_<size>`` where ``<variant>`` is one of ``v1``,
    ``paper-v1``, ``fc-v1``, ``v2``, ``paper-v2`` or ``fc-v2`` and ``<size>``
    is the integer depth.

    Args:
        features: `Tensor` of batched images, or a dict holding that tensor
            under the key 'feature'.
        labels: `Tensor` of labels for the data samples.
        mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
        params: `dict` of parameters passed to the model from the
            TPUEstimator; `params['batch_size']` is always provided and should
            be used as the effective batch size.

    Returns:
        A `tf.estimator.EstimatorSpec` in PREDICT mode, otherwise a
        `TPUEstimatorSpec` for the model.

    Raises:
        ValueError: If ``FLAGS.resnet_depth`` names an unknown variant. This
            was previously ``assert False``, which ``python -O`` strips,
            leaving ``network`` unbound.
    """
    if isinstance(features, dict):
        features = features['feature']

    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])

    if FLAGS.transpose_input:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # On TPU the network is built inside a bfloat16 scope; otherwise an empty
    # variable scope stands in as the context manager.
    if FLAGS.use_tpu:
        import bfloat16
        scope_fn = lambda: bfloat16.bfloat16_scope()
    else:
        scope_fn = lambda: tf.variable_scope("")

    with scope_fn():
        # e.g. "paper-v1_50" -> 50
        resnet_size = int(FLAGS.resnet_depth.split("_")[-1])
        if FLAGS.resnet_depth.startswith("v1_"):
            print("\n\n\n\n\nUSING RESNET V1 {}\n\n\n\n\n".format(
                FLAGS.resnet_depth))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention=None,
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("paper-v1_"):
            print("\n\n\n\n\nUSING RESNET V1 (Paper) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention="paper",
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("fc-v1_"):
            print("\n\n\n\n\nUSING RESNET V1 (fc) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention="fc",
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("v2_"):
            print("\n\n\n\n\nUSING RESNET V2 {}\n\n\n\n\n".format(resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention=False,
                                                extra_convs=0,
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        elif FLAGS.resnet_depth.startswith("paper-v2_"):
            print("\n\n\n\n\nUSING RESNET V2 (Paper) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention="paper",
                                                extra_convs=0,
                                                apply_to="output",
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        elif FLAGS.resnet_depth.startswith("fc-v2_"):
            print("\n\n\n\n\nUSING RESNET V2 (fc) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention="fc",
                                                extra_convs=1,
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        else:
            raise ValueError(
                "Unsupported resnet_depth: {!r}".format(FLAGS.resnet_depth))

        logits = network(inputs=features,
                         is_training=(mode == tf.estimator.ModeKeys.TRAIN))
    # Loss and metrics are computed in float32.
    logits = tf.cast(logits, tf.float32)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the
    # batch size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, LABEL_CLASSES)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=one_hot_labels)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Compute the current epoch and associated learning rate from global_step.
        global_step = tf.train.get_global_step()
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
        learning_rate = learning_rate_schedule(current_epoch)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=MOMENTUM,
                                               use_nesterov=True)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To
            # the user, this should look like regular synchronous training.
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency
        # to the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            if FLAGS.clip_gradients == 0:
                print("\nnot clipping gradients\n")
                train_op = optimizer.minimize(loss, global_step)
            else:
                print("\nclipping gradients\n")
                gradients, variables = zip(*optimizer.compute_gradients(loss))
                # Rescale all gradients jointly so their global L2 norm is at
                # most FLAGS.clip_gradients.
                gradients, _ = tf.clip_by_global_norm(gradients,
                                                      FLAGS.clip_gradients)
                train_op = optimizer.apply_gradients(
                    zip(gradients, variables), global_step=global_step)

        if not FLAGS.skip_host_call:

            def host_call_fn(gs, loss, lr, ce):
                """Training host call. Creates scalar summaries for training metrics.

                This function is executed on the CPU and should not directly
                reference any Tensors in the rest of the `model_fn`. To pass
                Tensors from the model to `host_call_fn`, provide them as part
                of the `host_call`. See
                https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
                for more information.

                Arguments should match the list of `Tensor` objects passed as
                the second element in the tuple passed to `host_call`.

                Args:
                    gs: `Tensor` with shape `[batch]` for the global_step.
                    loss: `Tensor` with shape `[batch]` for the training loss.
                    lr: `Tensor` with shape `[batch]` for the learning_rate.
                    ce: `Tensor` with shape `[batch]` for the current_epoch.

                Returns:
                    List of summary ops to run on the CPU host.
                """
                gs = gs[0]
                with summary.create_file_writer(FLAGS.model_dir).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('loss', loss[0], step=gs)
                        summary.scalar('learning_rate', lr[0], step=gs)
                        summary.scalar('current_epoch', ce[0], step=gs)

                        return summary.all_summary_ops()

            # host_call expects Tensors with a leading batch dimension, so
            # each scalar is reshaped to shape [1]; host_call_fn reads [0].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

            This function is executed on the CPU and should not directly
            reference any Tensors in the rest of the `model_fn`. To pass
            Tensors from the model to the `metric_fn`, provide them as part of
            the `eval_metrics`. See
            https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
            for more information.

            Arguments should match the list of `Tensor` objects passed as the
            second element in the tuple passed to `eval_metrics`.

            Args:
                labels: `Tensor` with shape `[batch]`.
                logits: `Tensor` with shape `[batch, num_classes]`.

            Returns:
                A dict of the metrics to return from evaluation.
            """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])

    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        host_call=host_call,
        eval_metrics=eval_metrics,
    )
def model_fn(features, labels, mode):
    """Definition for ResNet model.

    Builds a bfloat16 ResNet-v1 and returns a ``tf.estimator.EstimatorSpec``
    for EVAL or TRAIN mode.

    Args:
        features: Image batch. It is transposed with ``[3, 0, 1, 2]`` below,
            so it presumably arrives HWCN from the input pipeline
            (double-transpose trick) and becomes NHWC -- confirm against the
            input_fn.
        labels: Integer class ids for sparse softmax cross-entropy.
        mode: A ``tf.estimator.ModeKeys`` value.

    Returns:
        A ``tf.estimator.EstimatorSpec`` with the loss plus top-1 accuracy
        (EVAL) or the train op (TRAIN).

    Raises:
        NotImplementedError: In PREDICT mode, which is not supported yet.
        ValueError: For any other unexpected ``mode``.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    features = tf.transpose(features, [3, 0, 1, 2])  # Double-transpose trick

    # Normalize the image to zero mean and unit variance.
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

    with tf.contrib.tpu.bfloat16_scope():
        network = resnet_model.resnet_v1(
            resnet_depth=_RESNET_DEPTH,
            num_classes=_NUM_CLASSES,
            data_format='channels_last')
        logits = network(inputs=features, is_training=is_training)
    # Loss and metrics are computed in float32.
    logits = tf.cast(logits, tf.float32)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # An explicit raise replaces `assert False`, which `python -O`
        # strips and so silently returned an unvetted predictions spec.
        raise NotImplementedError('Not implemented correctly right now!')

    cross_entropy = tf.losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits)
    # L2 weight decay on every trainable variable except batch-norm ones.
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()
         if 'batch_normalization' not in v.name])

    if mode == tf.estimator.ModeKeys.EVAL:
        predictions = tf.argmax(logits, axis=1)
        top_1_accuracy = tf.metrics.accuracy(labels, predictions)
        # TODO(priyag): Add this back when in_top_k is supported on TPU.
        # in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
        # top_5_accuracy = tf.metrics.mean(in_top_5)

        eval_metric_ops = {
            'top_1_accuracy': top_1_accuracy,
            # 'top_5_accuracy': top_5_accuracy,
        }

        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=eval_metric_ops)

    if mode != tf.estimator.ModeKeys.TRAIN:
        raise ValueError('Unsupported mode: {}'.format(mode))

    global_step = tf.train.get_or_create_global_step()
    # Presumably train_batch_size is per core -- confirm against the flags.
    batches_per_epoch = (_NUM_TRAIN_IMAGES /
                         (FLAGS.train_batch_size * FLAGS.num_cores))
    current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
    learning_rate = learning_rate_schedule(current_epoch)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=_MOMENTUM, use_nesterov=True)

    # Batch-norm moving-average updates must run with every train step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
def resnet_network():
    """Run the flag-configured ResNet-v1 on ``feature_image``.

    No ``num_classes`` is passed to ``resnet_v1``, so this presumably yields
    backbone features rather than class logits -- confirm in resnet_model.
    ``feature_image`` and ``mode`` are free variables from the enclosing
    scope.

    Returns:
        The network output, computed in training mode only when
        ``mode == tf.estimator.ModeKeys.TRAIN``.
    """
    backbone = resnet_model.resnet_v1(
        resnet_depth=FLAGS.resnet_depth,
        data_format=FLAGS.data_format,
    )
    training = mode == tf.estimator.ModeKeys.TRAIN
    return backbone(inputs=feature_image, is_training=training)
def build_network(features, mode, params):
    """Apply a channels-first ResNet-50 (v1) to ``features``.

    Args:
        features: Batched input images passed to the network.
        mode: A ``tf.estimator.ModeKeys`` value; training mode is used only
            for ``tf.estimator.ModeKeys.TRAIN``.
        params: Dict-like; ``params["classes"]`` sets the class count.

    Returns:
        The network output for ``features``.
    """
    resnet50 = resnet_v1(
        resnet_depth=50,
        num_classes=params["classes"],
        data_format="channels_first",
    )
    training = mode == tf.estimator.ModeKeys.TRAIN
    return resnet50(inputs=features, is_training=training)
# Obtain `model`: restore it from `filepath`, or build and compile a fresh
# ResNet (v2 when `version == 2`, otherwise v1).
# NOTE(review): this fragment was collapsed onto one line, so the nesting
# below is a reconstruction -- confirm that compile/summary/finetune belong
# to the `else` (fresh-build) branch rather than running in both cases.
if load_checkpoint:
    model = load_model(filepath)
else:
    # build the graph
    if version == 2:
        model = resnet_v2(input_shape=input_shape, depth=depth,
                          activation_bits=activation_bits,
                          weight_noise=weight_noise,
                          # Conv layers are frozen when fine-tuning; dense
                          # layers always train.
                          trainable_conv=not finetune,
                          trainable_dense=True)
    else:
        model = resnet_v1(input_shape=input_shape, depth=depth,
                          activation_bits=activation_bits,
                          weight_noise=weight_noise,
                          trainable_conv=not finetune,
                          trainable_dense=True)

    # The initial Adam learning rate comes from the epoch-0 schedule value.
    model.compile(loss='categorical_crossentropy',
                  optimizer=Adam(learning_rate=lr_schedule(0)),
                  metrics=['accuracy'])
    model.summary()
    print(model_type)

    if finetune:
        # Load the newest checkpoint found under `finetune_ckpt_path`,
        # resolved relative to the current working directory.
        # NOTE(review): tf.train.latest_checkpoint returns None when no
        # checkpoint exists, which would make load_weights fail.
        weights_path = os.path.join(os.getcwd(), finetune_ckpt_path)
        latest = tf.train.latest_checkpoint(weights_path)
        print(latest)
        model.load_weights(latest)
def resnet_model_fn(features, labels, mode, params):
    """The model_fn for ResNet (optionally with attention) for TPUEstimator.

    The architecture is selected by ``FLAGS.resnet_depth``, a string
    ``<variant>_<size>`` with ``<variant>`` one of ``v1``, ``SE-v1``,
    ``GALA-v1``, ``v2``, ``SE-v2``, ``GALA-v2``. When heatmap annotations are
    available (see ``FLAGS.annotation``), an attention-map loss against the
    blurred heatmaps is added to the classification loss.

    Args:
        features: `Tensor` of batched images, or a dict with keys 'image',
            'hm' (heatmaps), 'bbox' and 'ccount'.
        labels: `Tensor` of labels for the data samples.
        mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
        params: `dict` of parameters passed to the model from the
            TPUEstimator; `params['batch_size']` is always provided and should
            be used as the effective batch size.

    Returns:
        A `tf.estimator.EstimatorSpec` in PREDICT mode, otherwise a
        `TPUEstimatorSpec` for the model.

    Raises:
        ValueError: If ``FLAGS.resnet_depth`` names an unknown variant.
        NotImplementedError: If ``FLAGS.annotation`` is not 'hms', 'bboxs'
            or 'none'.
        RuntimeError: If heatmaps are supervised but a GALA/SE model produced
            no attention maps.
    """
    if isinstance(features, dict):
        images = features['image']
        hms = features['hm']
        bboxs = features['bbox']
        ccount = features['ccount']
    else:
        images = features
        hms = None
        bboxs = None

    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input  # channels_first only for GPU
        images = tf.transpose(images, [0, 3, 1, 2])
        if hms is not None:
            hms = tf.transpose(hms, [0, 3, 1, 2])
            bboxs = tf.transpose(bboxs, [0, 3, 1, 2])

    if FLAGS.transpose_input:
        images = tf.transpose(images, [3, 0, 1, 2])  # HWCN to NHWC
        if hms is not None:
            hms = tf.transpose(hms, [3, 0, 1, 2])
            bboxs = tf.transpose(bboxs, [3, 0, 1, 2])

    # On TPU the network is built inside a bfloat16 scope; otherwise an empty
    # variable scope stands in as the context manager.
    if FLAGS.use_tpu:
        import bfloat16
        scope_fn = lambda: bfloat16.bfloat16_scope()
    else:
        scope_fn = lambda: tf.variable_scope("")

    with scope_fn():
        # e.g. "GALA-v1_50" -> 50
        resnet_size = int(FLAGS.resnet_depth.split("_")[-1])
        if FLAGS.resnet_depth.startswith("v1_"):
            print("\n\n\n\n\nUSING RESNET V1 {}\n\n\n\n\n".format(
                FLAGS.resnet_depth))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention=None,
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("SE-v1_"):
            print(
                "\n\n\n\n\nUSING RESNET V1 (Squeeze-and-excite) {}\n\n\n\n\n".
                format(resnet_size))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention="se",
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("GALA-v1_"):
            print("\n\n\n\n\nUSING RESNET V1 (GALA) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention="gala",
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("v2_"):
            print("\n\n\n\n\nUSING RESNET V2 {}\n\n\n\n\n".format(resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention=False,
                                                extra_convs=0,
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        elif FLAGS.resnet_depth.startswith("SE-v2_"):
            print(
                "\n\n\n\n\nUSING RESNET V2 (Squeeze-and-excite) {}\n\n\n\n\n".
                format(resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention="se",
                                                extra_convs=0,
                                                apply_to="output",
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        elif FLAGS.resnet_depth.startswith("GALA-v2_"):
            print("\n\n\n\n\nUSING RESNET V2 (GALA) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention="gala",
                                                extra_convs=1,
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        else:
            # Explicit raise instead of `assert False` (stripped under -O).
            raise ValueError(
                "Unsupported resnet_depth: {!r}".format(FLAGS.resnet_depth))

        # The network returns the logits plus an iterable of attention maps.
        logits, attention = network(
            inputs=images, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
    logits = tf.cast(logits, tf.float32)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, LABEL_CLASSES)
    loss = tf.losses.softmax_cross_entropy(logits=logits,
                                           onehot_labels=one_hot_labels)

    # Choose which annotation supervises the attention maps.
    if FLAGS.annotation == 'hms':
        pass
    elif FLAGS.annotation == 'bboxs':
        hms = bboxs
    elif FLAGS.annotation == 'none':
        hms = None
    else:
        raise NotImplementedError(FLAGS.annotation)

    # Add attention losses
    if hms is not None:
        map_loss_list = []
        blur_click_maps = 49  # 0 = no, > 0 blur kernel
        blur_click_maps_sigma = 28  # 14

        # Blur the heatmaps
        hms = blur(hms,
                   kernel=blur_click_maps,
                   sigma=blur_click_maps_sigma,
                   dtype=images.dtype)
        # Per-example 0/1 weight: examples with ccount <= 0 contribute no
        # attention loss.
        mask = tf.cast(tf.greater(ccount, 0), tf.float32)
        mask = tf.reshape(mask, [int(hms.get_shape()[0]), 1, 1, 1])
        # NOTE(review): the [1:3] spatial slice and resize_bilinear assume
        # NHWC; confirm this path is never used with channels_first.
        for layer in attention:
            layer_shape = [int(x) for x in layer.get_shape()[1:3]]
            layer = tf.cast(layer, tf.float32)
            hms = tf.cast(hms, tf.float32)
            resized_maps = tf.image.resize_bilinear(hms,
                                                    layer_shape,
                                                    align_corners=True)
            if layer.get_shape().as_list()[-1] > 1:
                # Collapse channels into one map (mean squared activation).
                layer = tf.reduce_mean(tf.pow(layer, 2),
                                       axis=-1,
                                       keep_dims=True)
            resized_maps = l2_channel_norm(resized_maps)
            layer = l2_channel_norm(layer)
            dist = resized_maps - layer
            d = tf.nn.l2_loss(dist * mask)
            map_loss_list += [d]

        if map_loss_list:
            # Average over attention layers, weighted by 1e-5.
            map_loss = (tf.add_n(map_loss_list) / float(len(attention))) * 1e-5
            loss += map_loss
        elif FLAGS.resnet_depth.startswith(("GALA", "SE")):
            # The previous `assert not A.startswith("GALA") or not
            # A.startswith("SE")` was always true (no name starts with both
            # prefixes), so this failure was never reported.
            raise RuntimeError("Failed to apply attention.")

    # NOTE(review): tf.add_n raises on an empty list if these filters ever
    # exclude every trainable variable.
    loss += (WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name and 'ATTENTION' not in v.name
        and 'block' not in v.name and 'training' not in v.name
    ]))

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Compute the current epoch and associated learning rate from global_step.
        global_step = tf.train.get_global_step()
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
        learning_rate = learning_rate_schedule(current_epoch)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=MOMENTUM,
                                               use_nesterov=True)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To
            # the user, this should look like regular synchronous training.
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency
        # to the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            if FLAGS.clip_gradients == 0:
                print("\nnot clipping gradients\n")
                train_op = optimizer.minimize(loss, global_step)
            else:
                print("\nclipping gradients\n")
                gradients, variables = zip(*optimizer.compute_gradients(loss))
                # Rescale all gradients jointly so their global L2 norm is at
                # most FLAGS.clip_gradients.
                gradients, _ = tf.clip_by_global_norm(gradients,
                                                      FLAGS.clip_gradients)
                train_op = optimizer.apply_gradients(
                    zip(gradients, variables), global_step=global_step)

        if not FLAGS.skip_host_call:

            def host_call_fn(gs, loss, lr, ce):
                """Training host call. Creates scalar summaries for training metrics.

                This function is executed on the CPU and should not directly
                reference any Tensors in the rest of the `model_fn`. To pass
                Tensors from the model to `host_call_fn`, provide them as part
                of the `host_call`. See
                https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
                for more information.

                Arguments should match the list of `Tensor` objects passed as
                the second element in the tuple passed to `host_call`.

                Args:
                    gs: `Tensor` with shape `[batch]` for the global_step.
                    loss: `Tensor` with shape `[batch]` for the training loss.
                    lr: `Tensor` with shape `[batch]` for the learning_rate.
                    ce: `Tensor` with shape `[batch]` for the current_epoch.

                Returns:
                    List of summary ops to run on the CPU host.
                """
                gs = gs[0]
                with summary.create_file_writer(FLAGS.model_dir).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('loss', loss[0], step=gs)
                        summary.scalar('learning_rate', lr[0], step=gs)
                        summary.scalar('current_epoch', ce[0], step=gs)
                        return summary.all_summary_ops()

            # host_call expects Tensors with a leading batch dimension, so
            # each scalar is reshaped to shape [1]; host_call_fn reads [0].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])
            host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])
    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

            This function is executed on the CPU and should not directly
            reference any Tensors in the rest of the `model_fn`. To pass
            Tensors from the model to the `metric_fn`, provide them as part of
            the `eval_metrics`. See
            https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
            for more information.

            Arguments should match the list of `Tensor` objects passed as the
            second element in the tuple passed to `eval_metrics`.

            Args:
                labels: `Tensor` with shape `[batch]`.
                logits: `Tensor` with shape `[batch, num_classes]`.

            Returns:
                A dict of the metrics to return from evaluation.
            """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])

    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        host_call=host_call,
        eval_metrics=eval_metrics,
    )
def resnet_model_fn(features, labels, mode, params):
    """The model_fn for ResNet, built under a custom variable getter.

    Unlike the TPUEstimator variants, this returns a plain
    ``tf.estimator.EstimatorSpec``. It wraps its optimizer in
    ``tf.contrib.estimator.TowerOptimizer``, which TensorFlow requires to be
    used together with ``replicate_model_fn``.

    Args:
        features: `Tensor` of batched images, or a dict holding that tensor
            under the key 'feature'.
        labels: `Tensor` of labels for the data samples.
        mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
        params: `dict` of parameters; `params['batch_size']` is read but not
            otherwise used.

    Returns:
        A `tf.estimator.EstimatorSpec` for the model.
    """
    if isinstance(features, dict):
        features = features['feature']

    # In most cases, the default data format NCHW instead of NHWC should be
    # used for a significant performance boost on GPU/TPU. NHWC should be used
    # only if the network needs to be run on CPU since the pooling operations
    # are only supported on NHWC.
    if FLAGS.data_format == 'channels_first':
        features = tf.transpose(features, [0, 3, 1, 2])

    # Variables are created through the custom getter under scope 'cg'
    # (presumably a reduced-precision getter -- confirm in
    # get_custom_getter); logits are cast to float32 for the loss.
    with tf.variable_scope('cg', custom_getter=get_custom_getter()):
        network = resnet_model.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                                         num_classes=LABEL_CLASSES,
                                         data_format=FLAGS.data_format)

        logits = network(inputs=features,
                         is_training=(mode == tf.estimator.ModeKeys.TRAIN))

        logits = tf.cast(logits, tf.float32)

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the
    # batch size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, LABEL_CLASSES)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=one_hot_labels)
    # Named tensor + summary so hooks/TensorBoard can find the raw loss.
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Compute the current epoch and associated learning rate from global_step.
        global_step = tf.train.get_global_step()
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
        learning_rate = learning_rate_schedule(current_epoch)

        # Create a tensor named learning_rate for logging purposes
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=MOMENTUM,
                                               use_nesterov=True)
        optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency
        # to the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.group(optimizer.minimize(loss, global_step),
                                update_ops)
    else:
        train_op = None

    # Streaming accuracy; argmax of the one-hot labels recovers the class ids.
    accuracy = tf.metrics.accuracy(tf.argmax(one_hot_labels, axis=1),
                                   predictions['classes'])
    metrics = {'accuracy': accuracy}
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=metrics)