def build_model(): """Builds graph for model to train with rewrites for quantization. Returns: g: Graph with fake quantization ops and batch norm folding suitable for training quantized weights. train_tensor: Train op for execution during training. """ g = tf.Graph() with g.as_default(), tf.device( tf.train.replica_device_setter(FLAGS.ps_tasks)): inputs, labels, _ = get_batch(FLAGS.dataset_dir, FLAGS.batch_size) with slim.arg_scope(mobilenet_v2.training_scope()): logits, _ = mobilenet_v2.mobilenet_v2_050( inputs, num_classes=FLAGS.num_classes) labels = slim.one_hot_encoding(labels, FLAGS.num_classes) tf.losses.softmax_cross_entropy(labels, logits) # Call rewriter to produce graph with fake quant ops and folded batch norms # quant_delay delays start of quantization till quant_delay steps, allowing # for better model accuracy. if FLAGS.quantize: tf.contrib.quantize.create_training_graph( quant_delay=get_quant_delay()) total_loss = tf.losses.get_total_loss(name='total_loss') # Configure the learning rate using an exponential decay. num_epochs_per_decay = 1 imagenet_size = 51200 decay_steps = int(imagenet_size / FLAGS.batch_size * num_epochs_per_decay) learning_rate = tf.train.exponential_decay( get_learning_rate(), tf.train.get_or_create_global_step(), decay_steps, _LEARNING_RATE_DECAY_FACTOR, staircase=False) #opt = tf.train.GradientDescentOptimizer(learning_rate) opt = tf.train.RMSPropOptimizer(learning_rate, decay=0.9, momentum=0.9) train_tensor = slim.learning.create_train_op(total_loss, optimizer=opt) slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses') slim.summaries.add_scalar_summary(learning_rate, 'learning_rate', 'training') return g, train_tensor
def net(inputs, data_format='channels_last', depth_mult=1.0,
        VGG_PARAMS_FILE=None, is_train=False):
    if data_format != 'channels_last':
        print('only works for channels_last for now')
        return None
    with tf.contrib.slim.arg_scope(
            mobilenet_v2.training_scope(is_training=is_train)):
        logits, endpoints = mobilenet_v2.mobilenet(inputs,
                                                   base_only=True,
                                                   reuse=tf.AUTO_REUSE,
                                                   final_endpoint='layer_19',
                                                   depth_multiplier=depth_mult)
    l15e = endpoints['layer_15/expansion_output']
    l19 = endpoints['layer_19']
    return [l15e, l19], endpoints
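# A minimal usage sketch (not part of the original snippet): pull the two
# MobileNetV2 feature maps for a batch of RGB images. The placeholder shape
# 224x224 is an assumption for illustration only.
images = tf.placeholder(tf.float32, [None, 224, 224, 3], name='images')
(feat_l15e, feat_l19), _ = net(images, depth_mult=1.0, is_train=False)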
def build_model(): """Build the mobilenet_v1 model for evaluation. Returns: g: graph with rewrites after insertion of quantization ops and batch norm folding. eval_ops: eval ops for inference. variables_to_restore: List of variables to restore from checkpoint. """ g = tf.Graph() with g.as_default(): inputs, labels, _ = get_batch(FLAGS.dataset_dir, FLAGS.batch_size) scope = mobilenet_v2.training_scope(is_training=False, weight_decay=0.0) with slim.arg_scope(scope): _, end_points = mobilenet_v2.mobilenet_v2_050( inputs, is_training=False, num_classes=FLAGS.num_classes) if FLAGS.quantize: tf.contrib.quantize.create_eval_graph() eval_ops = metrics(end_points['Predictions'], labels) return g, eval_ops
def test_model():
    errfile = '../_error/other_errors.txt'
    sess = tf.Session()

    inputs, labels, name_batch = get_batch(FLAGS.dataset_dir,
                                           FLAGS.batch_size,
                                           shuffle=False)
    scope = mobilenet_v2.training_scope(is_training=False, weight_decay=0.0)
    with slim.arg_scope(scope):
        logits, end_points = mobilenet_v2.mobilenet_v2_050(
            inputs, is_training=False, num_classes=FLAGS.num_classes)

    # evaluate model, for classification
    correct_pred = tf.equal(tf.argmax(end_points['Predictions'], 1), labels)
    acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # saver for restore model
    saver = tf.train.Saver()
    print('[*] Try to load trained model...')
    ckpt_name = load(sess, saver, FLAGS.checkpoint_dir)

    step = 0
    accs = 0
    me_acc = 0
    errors_name = []
    max_steps = int(FLAGS.num_examples / FLAGS.batch_size)
    print('START TESTING...')
    try:
        while not coord.should_stop():
            for _step in range(step + 1, step + max_steps + 1):
                # test
                # _label, _logits, _points = sess.run([labels, logits, end_points])
                _name, _logits, _corr, _acc = sess.run(
                    [name_batch, logits, correct_pred, acc])
                if (~_corr).any():
                    errors_name.extend(list(_name[~_corr]))
                accs += _acc
                me_acc = accs / _step
                if _step % 20 == 0:
                    print(time.strftime("%X"),
                          'global_step:{0}, current_acc:{1:.6f}'.format(
                              _step, me_acc))
    except tf.errors.OutOfRangeError:
        accuracy = 1 - len(errors_name) / FLAGS.num_examples
        print(time.strftime("%X"),
              'RESULT >>> current_acc:{0:.6f}'.format(accuracy))
        # print(errors_name)
        errorsfile = open(errfile, 'a')
        errorsfile.writelines('\n' + ckpt_name + '--' + str(accuracy))
        for err in errors_name:
            errorsfile.writelines('\n' + err.decode('utf-8'))
        errorsfile.close()
    finally:
        coord.request_stop()
        coord.join(threads)
        sess.close()
    print('FINISHED TESTING.')
# NOTE: this helper arrived truncated; the signature below is reconstructed from
# the call site, and IMG_MEAN / the num_channels derivation are assumptions.
def _mean_image_subtraction(image, img_mean=IMG_MEAN):
    num_channels = image.get_shape().as_list()[-1]
    channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image)
    for i in range(num_channels):
        channels[i] -= img_mean[i]
    return tf.concat(axis=2, values=channels)


image_raw_data = tf.placeholder(tf.string, None)
img_data = tf.image.decode_jpeg(image_raw_data)
image = tf.image.convert_image_dtype(img_data, dtype=tf.uint8)
img_show = image
image.set_shape([img_size, img_size, 3])
image = tf.to_float(image)
image = _mean_image_subtraction(image)
image = tf.expand_dims(image, [0])

with tf.contrib.slim.arg_scope(mnv2.training_scope(is_training=False)):
    logits, endpoint = mnv2.mobilenet(image, num_classes=4)
# logits = tf.Print(logits, [logits], 'logits: ', summarize=32)
# logits = tf.sigmoid(logits)
# logits = tf.nn.softmax(logits)
# logits = tf.Print(logits, [logits], 'softmax: ', summarize=32)

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './output/chamo_3000.000000_0.002084/chamo.ckpt')
    while True:
        img, cur_id = get_sample(0)
        image_raw_data_jpg = tf.gfile.FastGFile('./re/chamo.jpg', 'rb').read()
        re = sess.run([logits, img_show],
                      feed_dict={image_raw_data: image_raw_data_jpg})
        print(re[0][0])
        show_img = re[1]
def model_fn(features, labels, mode):
    feature_dict = features
    labels = tf.reshape(feature_dict["labels"], [-1])

    if mode == tf.estimator.ModeKeys.TRAIN:
        is_training = True
    else:
        is_training = False

    if FLAGS.model == 'mobilenet':
        scope = mobilenet_v2.training_scope(is_training=is_training)
        with tf.contrib.slim.arg_scope(scope):
            net, end_points = mobilenet_v2.mobilenet(
                feature_dict["features"],
                is_training=is_training,
                num_classes=FLAGS.num_classes)
        end_points['embedding'] = end_points['global_pool']
    elif FLAGS.model == 'mobilefacenet':
        scope = mobilenet_v2.training_scope(is_training=is_training)
        with tf.contrib.slim.arg_scope(scope):
            net, end_points = mobilefacenet.mobilefacenet(
                feature_dict["features"],
                is_training=is_training,
                num_classes=FLAGS.num_classes)
    elif FLAGS.model == 'metric_learning':
        with slim.arg_scope(
                inception_v1.inception_v1_arg_scope(weight_decay=0.0)):
            net, end_points = proxy_metric_learning.metric_learning(
                feature_dict["features"],
                is_training=is_training,
                num_classes=FLAGS.num_classes)
    elif FLAGS.model == 'faceresnet':
        net, end_points = mobilefacenet.faceresnet(
            feature_dict["features"],
            is_training=is_training,
            num_classes=FLAGS.num_classes)
    elif FLAGS.model == 'cifar100':
        net, end_points = cifar.nin(feature_dict["features"],
                                    labels,
                                    is_training=is_training,
                                    num_classes=FLAGS.num_classes)
    else:
        raise ValueError("Unknown model %s" % FLAGS.model)

    small_embeddings = tf.squeeze(end_points['embedding'])
    logits = end_points['Logits']
    predictions = tf.cast(tf.argmax(logits, 1), dtype=tf.int32)

    if FLAGS.final_activation in ['soft_thresh', 'none']:
        abs_embeddings = tf.abs(small_embeddings)
    else:
        abs_embeddings = small_embeddings
    nnz_small_embeddings = tf.cast(tf.less(FLAGS.zero_threshold,
                                           abs_embeddings),
                                   dtype=tf.float32)
    small_sparsity = tf.reduce_sum(nnz_small_embeddings, axis=1)
    small_sparsity = tf.reduce_mean(small_sparsity)

    mean_nnz_col = tf.reduce_mean(nnz_small_embeddings, axis=0)
    sum_nnz_col = tf.reduce_sum(nnz_small_embeddings, axis=0)
    mean_flops_ub = tf.reduce_sum(sum_nnz_col * (sum_nnz_col - 1)) / (
        FLAGS.batch_size * (FLAGS.batch_size - 1))
    mean_flops = tf.reduce_sum(mean_nnz_col * mean_nnz_col)

    l1_norm_row = tf.reduce_sum(abs_embeddings, axis=1)
    l1_norm_col = tf.reduce_mean(abs_embeddings, axis=0)
    mean_l1_norm = tf.reduce_mean(l1_norm_row)
    mean_flops_sur = tf.reduce_sum(l1_norm_col * l1_norm_col)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions_dict = {
            'predictions': predictions,
            'true_labels': feature_dict["labels"],
            'true_label_texts': feature_dict["label_texts"],
            'small_embeddings': small_embeddings,
            'sparsity/small': small_sparsity,
            'sparsity/flops': mean_flops_ub,
            'filename': features["filename"]
        }
        return tf.estimator.EstimatorSpec(mode, predictions=predictions_dict)

    if FLAGS.model == 'mobilenet':
        cr_ent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels)
    elif 'face' in FLAGS.model:
        cr_ent_loss = mobilefacenet.arcface_loss(logits=logits,
                                                 labels=labels,
                                                 out_num=FLAGS.num_classes)
    elif FLAGS.model == 'metric_learning':
        cr_ent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels)
        # cr_ent_loss = mobilefacenet.arcface_loss(
        #     logits=logits, labels=labels, out_num=FLAGS.num_classes)
    elif FLAGS.model == 'cifar100':
        # cr_ent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        #     logits=logits, labels=labels)
        cr_ent_loss = tf.contrib.losses.metric_learning.triplet_semihard_loss(
            labels, small_embeddings, margin=0.1)
        # cr_ent_loss = mobilefacenet.arcface_loss(logits=logits,
        #                                          labels=labels,
        #                                          out_num=FLAGS.num_classes,
        #                                          m=0.0)
        # cr_ent_loss = triplet_semihard_loss(
        #     labels=labels, embeddings=small_embeddings,
        #     pairwise_distance=lambda embed: pairwise_distance_euclid(
        #         embed, squared=True),
        #     margin=0.3)
    else:
        raise ValueError('Unknown model %s' % FLAGS.model)

    cr_ent_loss = tf.reduce_mean(
        cr_ent_loss) + tf.losses.get_regularization_loss()

    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    ema_op = ema.apply([mean_l1_norm])
    moving_l1_norm = ema.average(mean_l1_norm)

    global_step = tf.train.get_or_create_global_step()
    all_ops = [ema_op]

    l1_weight = tf.Variable(0.0, name='l1_weight', trainable=False)
    if FLAGS.l1_weighing_scheme == 'constant':
        l1_weight = FLAGS.l1_parameter
    elif FLAGS.l1_weighing_scheme == 'dynamic_1':
        l1_weight = FLAGS.l1_parameter / moving_l1_norm
        l1_weight = tf.stop_gradient(l1_weight)
        l1_weight = tf.train.piecewise_constant(x=global_step,
                                                boundaries=[5],
                                                values=[0.0, l1_weight])
    elif FLAGS.l1_weighing_scheme == 'dynamic_2':
        if FLAGS.sparsity_type == "flops_sur":
            update_lr = 1e-5
        else:
            update_lr = 1e-4
        update = update_lr * (FLAGS.l1_parameter - cr_ent_loss)
        assign_op = tf.assign(l1_weight, tf.nn.relu(l1_weight + update))
        all_ops.append(assign_op)
    elif FLAGS.l1_weighing_scheme == 'dynamic_3':
        update_lr = 1e-4
        global_step = tf.train.get_or_create_global_step()
        upper_bound = FLAGS.l1_parameter - (
            FLAGS.l1_parameter - 12.0) * tf.cast(global_step, tf.float32) / 5e5
        upper_bound = tf.cast(upper_bound, tf.float32)
        update = update_lr * tf.sign(upper_bound - cr_ent_loss)
        assign_op = tf.assign(l1_weight, l1_weight + tf.nn.relu(update))
        all_ops.append(assign_op)
    elif FLAGS.l1_weighing_scheme == 'dynamic_4':
        l1_weight = FLAGS.l1_parameter * tf.minimum(
            1.0, (tf.cast(global_step, tf.float32) / FLAGS.l1_p_steps)**2)
    elif FLAGS.l1_weighing_scheme is None:
        l1_weight = 0.0
    else:
        raise ValueError('Unknown l1_weighing_scheme %s' %
                         FLAGS.l1_weighing_scheme)

    if FLAGS.sparsity_type == 'l1_norm':
        sparsity_loss = l1_weight * mean_l1_norm
    elif FLAGS.sparsity_type == 'flops_sur':
        sparsity_loss = l1_weight * mean_flops_sur
    elif FLAGS.sparsity_type is None:
        sparsity_loss = 0.0
    else:
        raise ValueError("Unknown sparsity_type %s" % FLAGS.sparsity_type)

    total_loss = cr_ent_loss + sparsity_loss
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(predictions, labels), dtype=tf.float32))

    if mode == tf.estimator.ModeKeys.EVAL:
        global_accuracy = tf.metrics.mean(accuracy)
        global_small_sparsity = tf.metrics.mean(small_sparsity)
        global_flops = tf.metrics.mean(mean_flops_ub)
        global_l1_norm = tf.metrics.mean(mean_l1_norm)
        global_mean_flops_sur = tf.metrics.mean(mean_flops_sur)
        global_cr_ent_loss = tf.metrics.mean(cr_ent_loss)

        metrics = {
            'accuracy': global_accuracy,
            'sparsity/small': global_small_sparsity,
            'sparsity/flops': global_flops,
            'l1_norm': global_l1_norm,
            'l1_norm/flops_sur': global_mean_flops_sur,
            'loss/cr_ent_loss': global_cr_ent_loss,
        }
        return tf.estimator.EstimatorSpec(mode,
                                          loss=total_loss,
                                          eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.TRAIN:
        base_lrate = FLAGS.learning_rate
        learning_rate = tf.train.piecewise_constant(
            x=global_step,
            boundaries=[
                FLAGS.decay_step if FLAGS.decay_step is not None else int(1e7)
            ],
            values=[base_lrate, base_lrate / 10.0])

        tf.summary.image("input", feature_dict["features"], max_outputs=1)
        tf.summary.scalar('sparsity/small', small_sparsity)
        tf.summary.scalar('sparsity/flops', mean_flops_ub)
        tf.summary.scalar('accuracy', accuracy)
        tf.summary.scalar('l1_norm', mean_l1_norm)
        tf.summary.scalar('l1_norm/ema', moving_l1_norm)  # Comment this for mom
        tf.summary.scalar('l1_norm/l1_weight', l1_weight)
        tf.summary.scalar('l1_norm/flops_sur', mean_flops_sur)
        tf.summary.scalar('learning_rate', learning_rate)
        tf.summary.scalar('loss/cr_ent_loss', cr_ent_loss)
        tf.summary.scalar(
            'sparsity/ratio',
            mean_flops * FLAGS.embedding_size /
            (small_sparsity * small_sparsity))
        try:
            tf.summary.scalar('loss/upper_bound', upper_bound)
        except NameError:
            print("Skipping 'upper_bound' summary")
        for variable in tf.trainable_variables():
            if 'soft_thresh' in variable.name:
                print('Adding summary for lambda')
                tf.summary.scalar('lambda', variable)

        # Histogram summaries
        # gamma = tf.get_default_graph().get_tensor_by_name(
        #     'MobilenetV2/Conv_2/BatchNorm/gamma:0')
        # pre_relu = tf.get_default_graph().get_tensor_by_name(
        #     'MobilenetV2/Conv_2/BatchNorm/FusedBatchNorm:0')
        # pre_relu = tf.squeeze(pre_relu)
        # tf.summary.histogram('gamma', gamma)
        # tf.summary.histogram('pre_relu', pre_relu[:, 237])
        # tf.summary.histogram('small_activations', nnz_small_embeddings)
        # tf.summary.histogram('small_activations/log',
        #                      tf.log(nnz_small_embeddings + 1e-10))
        # fl_sur_ratio = (mean_nnz_col * mean_nnz_col) / (l1_norm_col * l1_norm_col)
        # tf.summary.histogram('fl_sur_ratio', fl_sur_ratio)
        # tf.summary.histogram('fl_sur_ratio/log', tf.log(fl_sur_ratio + 1e-10))
        # l1_sur_ratio = (mean_nnz_col * mean_nnz_col) / l1_norm_col
        # tf.summary.histogram('l1_sur_ratio', l1_sur_ratio)
        # tf.summary.histogram('l1_sur_ratio/log', tf.log(l1_sur_ratio + 1e-10))

        if FLAGS.optimizer == 'mom':
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=FLAGS.momentum)
        elif FLAGS.optimizer == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        elif FLAGS.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                               epsilon=0.0001)
        elif FLAGS.optimizer == 'rmsprop':
            optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                                  epsilon=0.01,
                                                  momentum=FLAGS.momentum)
        else:
            raise ValueError("Unknown optimizer %s" % FLAGS.optimizer)

        # Make the optimizer distributed
        optimizer = hvd.DistributedOptimizer(optimizer)

        train_op = tf.contrib.slim.learning.create_train_op(
            total_loss,
            optimizer,
            global_step=tf.train.get_or_create_global_step())
        all_ops.append(train_op)
        merged = tf.group(*all_ops)

        return tf.estimator.EstimatorSpec(mode,
                                          loss=total_loss,
                                          train_op=merged)
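# A minimal wiring sketch (not part of the original snippet): model_fn above is
# an Estimator model_fn whose optimizer is wrapped with Horovod, so a driver
# might look like the following. run_training, train_input_fn and
# FLAGS.model_dir are assumed names.
def run_training():
    hvd.init()
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=FLAGS.model_dir if hvd.rank() == 0 else None,
        config=tf.estimator.RunConfig(session_config=config))
    # Broadcast initial variables from rank 0 so all workers start in sync.
    estimator.train(input_fn=train_input_fn,
                    hooks=[hvd.BroadcastGlobalVariablesHook(0)])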
def build_model(apply_or_model=False, apply_and_model=False):
    """Build test model and write model as pb file.

    Args:
        apply_or_model, apply_and_model: whether to apply or/and model.
    """
    g = tf.Graph()
    with g.as_default(), tf.device(
            tf.train.replica_device_setter(FLAGS.ps_tasks)):
        anchors = anchor_generator.generate_anchors(**_anchors_figure)
        box_pred = box_predictor.SSDBoxPredictor(FLAGS.is_training,
                                                 FLAGS.num_classes,
                                                 box_code_size=4)
        batchnorm_updates_collections = (None if FLAGS.inplace_batchnorm_update
                                         else tf.GraphKeys.UPDATE_OPS)
        anchors = tf.convert_to_tensor(anchors,
                                       dtype=tf.float32,
                                       name='anchors')
        convert_ratio = tf.convert_to_tensor(_convert_ratio, tf.float32,
                                             name='convert_ratio')
        value_to_ratio = tf.convert_to_tensor(_value_to_ratio, tf.float32,
                                              name='convert_ratio')

        img_tensor = tf.placeholder(
            tf.float32,
            [1, FLAGS.original_image_height, FLAGS.original_image_width, 3],
            name='input_img')
        grid_size_tensor = tf.placeholder(tf.float32, [2], 'input_grid_size')
        preimg_batch, grid_points_tl = preprocess(img_tensor, grid_size_tensor,
                                                  FLAGS.image_size,
                                                  value_to_ratio,
                                                  apply_or_model)

        with slim.arg_scope(
                [slim.batch_norm],
                is_training=(FLAGS.is_training and not FLAGS.freeze_batchnorm),
                updates_collections=batchnorm_updates_collections), \
             slim.arg_scope(
                 mobilenet_v2.training_scope(is_training=None, bn_decay=0.997)):
            _, image_features = mobilenet_v2.mobilenet_base(
                preimg_batch,
                final_endpoint='layer_18',
                depth_multiplier=FLAGS.depth_multiplier,
                finegrain_classification_mode=True)

        feature_maps = feature_map_generator.pooling_pyramid_feature_maps(
            base_feature_map_depth=0,
            num_layers=2,
            image_features={'image_features': image_features['layer_18']})
        pred_dict = box_pred.predict(feature_maps.values(), [1, 1])
        box_encodings = tf.concat(pred_dict['box_encodings'], axis=1)
        if box_encodings.shape.ndims == 4 and box_encodings.shape[2] == 1:
            box_encodings = tf.squeeze(box_encodings, axis=2)
        class_predictions_with_background = tf.concat(
            pred_dict['class_predictions_with_background'], axis=1)

        detection_boxes, detection_scores = postprocess(
            anchors,
            box_encodings,
            class_predictions_with_background,
            convert_ratio,
            grid_points_tl,
            num_classes=FLAGS.num_classes,
            score_threshold=FLAGS.score_threshold,
            apply_and_model=apply_and_model)

        input_boxes = tf.placeholder_with_default(detection_boxes[:1],
                                                  [None, 4],
                                                  name='input_boxes')
        if apply_or_model or apply_and_model:
            return g, img_tensor, input_boxes, detection_boxes, detection_scores

        num_batch = shape_utils.combined_static_and_dynamic_shape(input_boxes)
        input_scores = tf.tile([0.7], [num_batch[0]])
        total_boxes = tf.concat([detection_boxes, input_boxes], 0)
        total_scores = tf.concat([detection_scores, input_scores], 0)
        result_dict = non_max_suppression(
            total_boxes,
            total_scores,
            max_output_size=FLAGS.max_output_size,
            iou_threshold=FLAGS.iou_threshold)

        output_node_names = [
            'Non_max_suppression/result_boxes',
            'Non_max_suppression/result_scores',
            'Non_max_suppression/abnormal_indices',
            'Non_max_suppression/abnormal_inter_idx',
            'Non_max_suppression/abnormal_inter'
        ]
        init_op = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init_op)
            # saver for restore model
            saver = tf.train.Saver()
            print('[*] Try to load trained model...')
            ckpt_name = load(sess, saver, FLAGS.checkpoint_dir)
            write_pb_model(FLAGS.checkpoint_dir + ckpt_name + '.pb', sess,
                           g.as_graph_def(), output_node_names)
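# A minimal loading sketch (not part of the original snippet): import the pb
# written by write_pb_model above and return a graph whose NMS output nodes can
# be fetched by name. load_pb_model is an assumed helper name, and the exact
# freezing behaviour of write_pb_model is an assumption.
def load_pb_model(pb_path):
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    g = tf.Graph()
    with g.as_default():
        tf.import_graph_def(graph_def, name='')
    return g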
def build_model():
    g = tf.Graph()
    with g.as_default(), tf.device(
            tf.train.replica_device_setter(FLAGS.ps_tasks)):
        matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=0.5,
                                               unmatched_threshold=0.4,
                                               force_match_for_each_row=True)
        anchors = anchor_generator.generate_anchors(**_anchors_figure)
        box_pred = box_predictor.SSDBoxPredictor(FLAGS.is_training,
                                                 FLAGS.num_classes,
                                                 box_code_size=4)
        batchnorm_updates_collections = (None if FLAGS.inplace_batchnorm_update
                                         else tf.GraphKeys.UPDATE_OPS)

        with tf.variable_scope('inputs'):
            img_batch, bbox_batch, bbox_num_batch, _ = get_batch(
                FLAGS.dataset_dir, FLAGS.batch_size)
            img_batch = tf.cast(img_batch, tf.float32) / 127.5 - 1
            img_batch = tf.identity(img_batch, name='gt_imgs')
            bbox_list = []
            for i in range(FLAGS.batch_size):
                gt_boxes = tf.identity(bbox_batch[i][:bbox_num_batch[i]],
                                       name='gt_boxes')
                bbox_list.append(gt_boxes)

        anchors = tf.convert_to_tensor(anchors,
                                       dtype=tf.float32,
                                       name='anchors')

        with slim.arg_scope(
                [slim.batch_norm],
                is_training=(FLAGS.is_training and not FLAGS.freeze_batchnorm),
                updates_collections=batchnorm_updates_collections), \
             slim.arg_scope(
                 mobilenet_v2.training_scope(is_training=None, bn_decay=0.997)):
            _, image_features = mobilenet_v2.mobilenet_base(
                img_batch,
                final_endpoint='layer_18',
                depth_multiplier=FLAGS.depth_multiplier,
                finegrain_classification_mode=True)

        feature_maps = feature_map_generator.pooling_pyramid_feature_maps(
            base_feature_map_depth=0,
            num_layers=2,
            image_features={'image_features': image_features['layer_18']})
        pred_dict = box_pred.predict(feature_maps.values(), [1, 1])
        box_encodings = tf.concat(pred_dict['box_encodings'], axis=1)
        if box_encodings.shape.ndims == 4 and box_encodings.shape[2] == 1:
            box_encodings = tf.squeeze(box_encodings, axis=2)
        class_predictions_with_background = tf.concat(
            pred_dict['class_predictions_with_background'], axis=1)

        losses_dict = loss_op.loss(box_encodings,
                                   class_predictions_with_background,
                                   bbox_list,
                                   anchors,
                                   matcher,
                                   random_example=False)
        for loss_tensor in losses_dict.values():
            tf.losses.add_loss(loss_tensor)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # Configure the learning rate using an exponential decay.
        num_epochs_per_decay = 3
        imagenet_size = 10240
        decay_steps = int(imagenet_size / FLAGS.batch_size *
                          num_epochs_per_decay)
        learning_rate = tf.train.exponential_decay(
            FLAGS.learning_rate,
            tf.train.get_or_create_global_step(),
            decay_steps,
            _LEARNING_RATE_DECAY_FACTOR,
            staircase=True)

        opt = tf.train.AdamOptimizer(learning_rate)

        total_losses = []
        cls_loc_losses = tf.get_collection(tf.GraphKeys.LOSSES)
        cls_loc_loss = tf.add_n(cls_loc_losses, name='cls_loc_loss')
        total_losses.append(cls_loc_loss)
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
        regularization_loss = tf.add_n(regularization_losses,
                                       name='regularization_loss')
        total_losses.append(regularization_loss)
        total_loss = tf.add_n(total_losses, name='total_loss')

        grads_and_vars = opt.compute_gradients(total_loss)
        total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')
        grad_updates = opt.apply_gradients(
            grads_and_vars, global_step=tf.train.get_or_create_global_step())
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops, name='update_barrier')
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        slim.summaries.add_scalar_summary(cls_loc_loss, 'cls_loc_loss',
                                          'losses')
        slim.summaries.add_scalar_summary(regularization_loss,
                                          'regularization_loss', 'losses')
        slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses')
        slim.summaries.add_scalar_summary(learning_rate, 'learning_rate',
                                          'training')
    return g, train_tensor
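# A minimal usage sketch (not part of the original snippet): drive the train op
# returned above with a plain session and queue runners, since get_batch here
# appears queue-based. main and FLAGS.num_steps are assumed names.
def main(_):
    g, train_tensor = build_model()
    with g.as_default(), tf.Session() as sess:
        sess.run(tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer()))
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for step in range(FLAGS.num_steps):
                loss = sess.run(train_tensor)
                if step % 100 == 0:
                    print('step %d: loss = %.4f' % (step, loss))
        finally:
            coord.request_stop()
            coord.join(threads)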