def build_detection_graph(input_type, detection_model, input_shape,
                          output_collection_name, graph_hook_fn):
    """Assemble the inference graph for a detection model.

    Args:
        input_type: Key into `input_placeholder_fn_map` selecting the
            function that builds the input placeholder.
        detection_model: Model object forwarded to
            `_get_outputs_from_inputs`.
        input_shape: Optional shape passed to the placeholder builder as
            `input_shape`. Only accepted for `image_tensor`,
            `encoded_image_string_tensor` and `tf_example` inputs.
        output_collection_name: Forwarded to `_get_outputs_from_inputs`.
        graph_hook_fn: Optional zero-argument callable invoked after the
            graph has been assembled.

    Returns:
        A tuple `(outputs, placeholder_tensor)`.

    Raises:
        ValueError: If `input_type` is not a known input type, or if
            `input_shape` is given for an input type that does not take one.
    """
    if input_type not in input_placeholder_fn_map:
        raise ValueError('Unknown input type: {}'.format(input_type))

    shape_capable_types = (
        'image_tensor',
        'encoded_image_string_tensor',
        'tf_example',
    )
    builder_kwargs = {}
    if input_shape is not None:
        if input_type not in shape_capable_types:
            raise ValueError(
                'Can only specify input shape for `image_tensor`, '
                '`encoded_image_string_tensor`, or `tf_example` '
                'inputs.')
        builder_kwargs['input_shape'] = input_shape

    build_placeholder = input_placeholder_fn_map[input_type]
    placeholder_tensor, input_tensors = build_placeholder(**builder_kwargs)

    outputs = _get_outputs_from_inputs(
        input_tensors=input_tensors,
        detection_model=detection_model,
        output_collection_name=output_collection_name)

    # Make sure the graph contains a global step.
    slim.get_or_create_global_step()

    if graph_hook_fn:
        graph_hook_fn()

    return outputs, placeholder_tensor
def build_detection_graph(input_type, detection_model, input_shape,
                          output_collection_name, graph_hook_fn,
                          use_side_inputs=False, side_input_shapes=None,
                          side_input_names=None, side_input_types=None):
    """Build the detection graph, optionally with side inputs.

    Args:
        input_type: Key into `input_placeholder_fn_map` selecting the
            function that builds the main input placeholder.
        detection_model: Model object forwarded to
            `_get_outputs_from_inputs`.
        input_shape: Optional shape passed to the placeholder builder as
            `input_shape`. Only accepted for `image_tensor`,
            `encoded_image_string_tensor`, `tf_example` and
            `tf_sequence_example` inputs.
        output_collection_name: Forwarded to `_get_outputs_from_inputs`.
        graph_hook_fn: Optional zero-argument callable invoked after the
            graph has been assembled.
        use_side_inputs: When true, one placeholder is created per entry of
            `side_input_names` via `_side_input_tensor_placeholder`.
        side_input_shapes: Per-side-input shapes, aligned by position with
            `side_input_names`.
        side_input_names: Names of the side inputs. Each name is used both
            as the keyword passed to `_get_outputs_from_inputs` and as the
            key in the returned placeholder dict.
        side_input_types: Per-side-input types, aligned by position with
            `side_input_names`.

    Returns:
        A tuple `(outputs, placeholder_tensors)` where `placeholder_tensors`
        is a dict holding the main placeholder under `'inputs'` plus one
        entry per side input name.

    Raises:
        ValueError: If `input_type` is unknown, if `input_shape` is given
            for an input type that does not take one, or if
            `use_side_inputs` is set but the side-input lists are missing
            or shorter than `side_input_names`.
    """
    if input_type not in input_placeholder_fn_map:
        raise ValueError('Unknown input type: {}'.format(input_type))

    placeholder_args = {}
    if input_shape is not None:
        if input_type not in ('image_tensor',
                              'encoded_image_string_tensor',
                              'tf_example',
                              'tf_sequence_example'):
            raise ValueError(
                'Can only specify input shape for `image_tensor`, '
                '`encoded_image_string_tensor`, `tf_example`, '
                'or `tf_sequence_example` inputs.')
        placeholder_args['input_shape'] = input_shape

    placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type](
        **placeholder_args)
    placeholder_tensors = {'inputs': placeholder_tensor}

    side_inputs = {}
    if use_side_inputs:
        if (side_input_names is None or side_input_shapes is None
                or side_input_types is None):
            raise ValueError(
                '`side_input_shapes`, `side_input_names` and '
                '`side_input_types` must all be provided when '
                '`use_side_inputs` is set.')
        side_input_names = list(side_input_names)
        if (len(side_input_shapes) < len(side_input_names)
                or len(side_input_types) < len(side_input_names)):
            raise ValueError(
                'Every side input name needs a matching shape and type: '
                'got {} names, {} shapes and {} types.'.format(
                    len(side_input_names), len(side_input_shapes),
                    len(side_input_types)))
        # zip stops at the last name, so surplus shapes/types are ignored
        # exactly as the previous index-based loop ignored them.
        for name, shape, dtype in zip(side_input_names, side_input_shapes,
                                      side_input_types):
            side_input_placeholder, side_input = (
                _side_input_tensor_placeholder(shape, name, dtype))
            side_inputs[name] = side_input
            placeholder_tensors[name] = side_input_placeholder

    outputs = _get_outputs_from_inputs(
        input_tensors=input_tensors,
        detection_model=detection_model,
        output_collection_name=output_collection_name,
        **side_inputs)

    # Add global step to the graph.
    slim.get_or_create_global_step()

    if graph_hook_fn:
        graph_hook_fn()

    return outputs, placeholder_tensors
def deploy(config, model_fn, args=None, kwargs=None, optimizer=None,
           summarize_gradients=False):
    """Deploy a Slim-constructed model across multiple clones.

    `model_fn` is invoked through `create_clones` as
    `model_fn(*args, **kwargs)`, once per clone requested by `config`. When
    an `optimizer` is supplied the deployment is set up for training:
    gradients are computed over all clones, applied with the global step,
    and `train_op` evaluates to the total loss after the updates have run.
    Without an optimizer only the per-clone losses are gathered and summed.

    Args:
        config: A `DeploymentConfig` object.
        model_fn: A callable, called as `model_fn(*args, **kwargs)`.
        args: Optional positional arguments for `model_fn`.
        kwargs: Optional keyword arguments for `model_fn`.
        optimizer: Optional `Optimizer` object. If given, the model is
            deployed for training with that optimizer.
        summarize_gradients: Whether to add summaries for the gradients.

    Returns:
        A `DeployedModel` namedtuple built from
        `(train_op, summary_op, total_loss, clones)`.
    """
    graph_keys = tf.compat.v1.GraphKeys

    # Summaries that already exist before any clone is created.
    collected = set(tf.get_collection(graph_keys.SUMMARIES))

    clones = create_clones(config, model_fn, args, kwargs)
    lead_clone = clones[0]

    # Update ops of the first clone only, e.g. batch-norm statistics
    # updates created by model_fn.
    pending_updates = tf.get_collection(graph_keys.UPDATE_OPS,
                                        lead_clone.scope)

    train_op = None
    total_loss = None
    with tf.device(config.optimizer_device()):
        if optimizer:
            # The global step lives on the device that stores variables.
            with tf.device(config.variables_device()):
                step = slim.get_or_create_global_step()

            total_loss, grads_and_vars = optimize_clones(clones, optimizer)

            if grads_and_vars:
                if summarize_gradients:
                    collected.update(
                        _add_gradients_summaries(grads_and_vars))

                pending_updates.append(
                    optimizer.apply_gradients(grads_and_vars,
                                              global_step=step))

                # train_op yields the loss only after all updates ran.
                with tf.control_dependencies([tf.group(*pending_updates)]):
                    train_op = tf.identity(total_loss, name='train_op')
        else:
            per_clone_losses = []
            reg_losses = tf.get_collection(graph_keys.REGULARIZATION_LOSSES)
            clone_count = len(clones)
            for position, clone in enumerate(clones):
                with tf.name_scope(clone.scope):
                    # Regularization losses are counted for the first
                    # clone only.
                    loss = _gather_clone_loss(
                        clone, clone_count,
                        reg_losses if position == 0 else None)
                    if loss is not None:
                        per_clone_losses.append(loss)
            if per_clone_losses:
                total_loss = tf.add_n(per_clone_losses, name='total_loss')

        # Summaries of the first clone: those created by model_fn and by
        # optimize_clones() or _gather_clone_loss().
        collected.update(
            tf.get_collection(graph_keys.SUMMARIES, lead_clone.scope))

        if total_loss is not None:
            collected.add(tf.summary.scalar('total_loss', total_loss))

        summary_op = (
            tf.summary.merge(list(collected), name='summary_op')
            if collected else None)

    return DeployedModel(train_op, summary_op, total_loss, clones)
def main(_):
    """Evaluate a classification checkpoint once over the selected dataset.

    Builds the dataset provider, preprocessing and network from `FLAGS`,
    defines streaming accuracy and recall@5 metrics, and runs
    `slim.evaluation.evaluate_once` on the resolved checkpoint.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If `--dataset_dir` is not supplied, or if
            `--checkpoint_path` is a directory that holds no checkpoint.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=False)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=2 * FLAGS.batch_size,
            common_queue_min=FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= FLAGS.labels_offset

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        eval_image_size = (
            FLAGS.eval_image_size or network_fn.default_image_size)

        image = image_preprocessing_fn(
            image, eval_image_size, eval_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)

        ####################
        # Define the model #
        ####################
        logits, _ = network_fn(images)

        if FLAGS.moving_average_decay:
            # Restore the moving-average shadow values in place of the
            # raw model variables, plus the global step itself.
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        predictions = tf.argmax(logits, 1)
        labels = tf.squeeze(labels)

        # Define the metrics:
        names_to_values, names_to_updates = (
            slim.metrics.aggregate_metric_map({
                'Accuracy': slim.metrics.streaming_accuracy(
                    predictions, labels),
                'Recall_5': slim.metrics.streaming_recall_at_k(
                    logits, labels, 5),
            }))

        # Print the summaries to screen.
        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.compat.v1.GraphKeys.SUMMARIES, op)

        # TODO(sguada) use num_epochs=1
        if FLAGS.max_num_batches:
            num_batches = FLAGS.max_num_batches
        else:
            # This ensures that we make a single pass over all of the data.
            num_batches = math.ceil(
                dataset.num_samples / float(FLAGS.batch_size))

        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(
                FLAGS.checkpoint_path)
            # latest_checkpoint returns None when the directory holds no
            # checkpoint; fail fast instead of handing None to the
            # evaluator.
            if checkpoint_path is None:
                raise ValueError(
                    'No checkpoint found in directory: {}'.format(
                        FLAGS.checkpoint_path))
        else:
            checkpoint_path = FLAGS.checkpoint_path

        tf.logging.info('Evaluating %s', checkpoint_path)

        slim.evaluation.evaluate_once(
            master=FLAGS.master,
            checkpoint_path=checkpoint_path,
            logdir=FLAGS.eval_dir,
            num_evals=num_batches,
            eval_op=list(names_to_updates.values()),
            variables_to_restore=variables_to_restore)
def main(_):
    """Evaluate an Inception V3 FCN checkpoint once over the dataset.

    Configures the network from fixed `InceptionV3FCNParams` /
    `ConvScopeParams` values and `FLAGS.receptive_field_size`, defines
    streaming accuracy and recall@2 metrics, and runs
    `slim.evaluation.evaluate_once` on the resolved checkpoint.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If `--dataset_dir` is not supplied, or if
            `--checkpoint_path` is a directory that holds no checkpoint.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Required when creating the session.
        _ = slim.get_or_create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        #########################
        # Configure the network #
        #########################
        inception_params = network_params.InceptionV3FCNParams(
            receptive_field_size=FLAGS.receptive_field_size,
            prelogit_dropout_keep_prob=0.8,
            depth_multiplier=0.1,
            min_depth=16,
            inception_fcn_stride=0,
        )
        conv_params = network_params.ConvScopeParams(
            dropout=False,
            dropout_keep_prob=0.8,
            batch_norm=True,
            batch_norm_decay=0.99,
            l2_weight_decay=4e-05,
        )
        network_fn = inception_v3_fcn.get_inception_v3_fcn_network_fn(
            inception_params,
            conv_params,
            num_classes=dataset.num_classes,
            is_training=False,
        )

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=2 * FLAGS.batch_size,
            common_queue_min=FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])

        #####################################
        # Select the preprocessing function #
        #####################################
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            'inception_v3', is_training=False)
        # Images are evaluated at the network's receptive field size.
        eval_image_size = FLAGS.receptive_field_size
        image = image_preprocessing_fn(
            image, eval_image_size, eval_image_size)
        images, labels = tf.train.batch([image, label],
                                        batch_size=FLAGS.batch_size,
                                        num_threads=PREPROCESSING_THREADS,
                                        capacity=5 * FLAGS.batch_size)

        ####################
        # Define the model #
        ####################
        logits, _ = network_fn(images)

        variables_to_restore = slim.get_variables_to_restore()

        predictions = tf.argmax(logits, 1)
        labels = tf.squeeze(labels)

        # Define the metrics:
        names_to_values, names_to_updates = (
            slim.metrics.aggregate_metric_map({
                'Accuracy': slim.metrics.streaming_accuracy(
                    predictions, labels),
                'Recall_2': slim.metrics.streaming_recall_at_k(
                    logits, labels, 2),
            }))

        # Print the summaries to screen.
        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

        # This ensures that we make a single pass over all of the data.
        num_batches = math.ceil(
            dataset.num_samples / float(FLAGS.batch_size))

        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(
                FLAGS.checkpoint_path)
            # latest_checkpoint returns None when the directory holds no
            # checkpoint; fail fast instead of handing None to the
            # evaluator.
            if checkpoint_path is None:
                raise ValueError(
                    'No checkpoint found in directory: {}'.format(
                        FLAGS.checkpoint_path))
        else:
            checkpoint_path = FLAGS.checkpoint_path

        tf.logging.info('Evaluating %s', checkpoint_path)

        slim.evaluation.evaluate_once(
            master='',
            checkpoint_path=checkpoint_path,
            logdir=FLAGS.eval_dir,
            num_evals=num_batches,
            eval_op=list(names_to_updates.values()),
            session_config=tf.ConfigProto(allow_soft_placement=True),
            variables_to_restore=variables_to_restore)
def main(_):
    """Evaluate a quantised classification checkpoint once.

    Builds the network with per-layer `quant_params` derived from
    `FLAGS.number_hashing_functions` and `FLAGS.neuron_vector_length`,
    defines accuracy and recall@5 metrics, and runs
    `slim.evaluation.evaluate_once` on the resolved checkpoint.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If `--dataset_dir` is not supplied, or if
            `--checkpoint_path` is a directory that holds no checkpoint.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

        ####################
        # Select the model #
        ####################
        n_hash = FLAGS.number_hashing_functions
        L_vec = FLAGS.neuron_vector_length
        # One [n_hash, L] integer pair per position of the two flag lists.
        quant_params = [
            [int(hash_count), int(L_vec[i])]
            for i, hash_count in enumerate(n_hash)
        ]

        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            quant_params=quant_params,
            is_training=False)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=2 * FLAGS.batch_size,
            common_queue_min=FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= FLAGS.labels_offset

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        eval_image_size = (
            FLAGS.eval_image_size or network_fn.default_image_size)

        image = image_preprocessing_fn(
            image, eval_image_size, eval_image_size)

        images, labels = tf.compat.v1.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)

        ####################
        # Define the model #
        ####################
        logits, _ = network_fn(images)

        if FLAGS.moving_average_decay:
            # Restore the moving-average shadow values in place of the
            # raw model variables, plus the global step itself.
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        predictions = tf.argmax(input=logits, axis=1)
        labels = tf.squeeze(labels)

        # Define the metrics. Each entry must be a metric result; the
        # previous 'Recall_5' entry was the bare tuple (logits, labels, 5)
        # with the metric call missing.
        # NOTE(review): recall_at_k expects int64 labels, hence the cast --
        # confirm this matches the dataset's label dtype.
        names_to_values, names_to_updates = aggregate_metric_map({
            'Accuracy': tf.compat.v1.metrics.accuracy(labels, predictions),
            'Recall_5': tf.compat.v1.metrics.recall_at_k(
                tf.cast(labels, tf.int64), logits, 5),
        })

        # Print the summaries to screen.
        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.compat.v1.summary.scalar(
                summary_name, value, collections=[])
            op = tf.compat.v1.Print(op, [value], summary_name)
            tf.compat.v1.add_to_collection(
                tf.compat.v1.GraphKeys.SUMMARIES, op)

        # TODO(sguada) use num_epochs=1
        if FLAGS.max_num_batches:
            num_batches = FLAGS.max_num_batches
        else:
            # This ensures that we make a single pass over all of the data.
            num_batches = math.ceil(
                dataset.num_samples / float(FLAGS.batch_size))

        if tf.io.gfile.isdir(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(
                FLAGS.checkpoint_path)
            # latest_checkpoint returns None when the directory holds no
            # checkpoint; fail fast instead of handing None to the
            # evaluator.
            if checkpoint_path is None:
                raise ValueError(
                    'No checkpoint found in directory: {}'.format(
                        FLAGS.checkpoint_path))
        else:
            checkpoint_path = FLAGS.checkpoint_path

        tf.compat.v1.logging.info('Evaluating %s', checkpoint_path)

        config = tf.compat.v1.ConfigProto()
        # Grow GPU memory on demand instead of reserving it all up front.
        config.gpu_options.allow_growth = True

        slim.evaluation.evaluate_once(
            master=FLAGS.master,
            checkpoint_path=checkpoint_path,
            logdir=FLAGS.eval_dir,
            num_evals=num_batches,
            eval_op=list(names_to_updates.values()),
            session_config=config,
            variables_to_restore=variables_to_restore)
def predict(model_root, datasets_dir, model_name, test_image_name,
            num_classes=20):
    """Classify a single JPEG image and return the predicted class index.

    The image is run through the network twice: once as-is and once cropped
    to the bounding box that `mask2bbox` derives from the averaged
    attention maps of the first pass. The two softmax outputs are averaged
    and the arg-max of the result is the prediction.

    Args:
        model_root: Directory joined with `model_name` to locate the
            checkpoint (either a checkpoint directory or a checkpoint
            path).
        datasets_dir: Directory containing the test image.
        model_name: Name passed to `net_select.get_network_fn` and
            `preprocessing_select.get_preprocessing`.
        test_image_name: File name of the JPEG image inside
            `datasets_dir`.
        num_classes: Number of output classes of the network.

    Returns:
        The predicted class index (the value that is also printed).

    Raises:
        ValueError: If the checkpoint location is a directory that holds
            no checkpoint.
    """
    with tf.Graph().as_default():
        # The global step is created so that it is part of the graph's
        # variables collected by get_variables_to_restore() below.
        slim.get_or_create_global_step()

        test_image = os.path.join(datasets_dir, test_image_name)
        network_fn = net_select.get_network_fn(
            model_name, num_classes=num_classes, is_training=False)
        batch_size = 1
        eval_image_size = network_fn.default_image_size

        image_preprocessing_fn = preprocessing_select.get_preprocessing(
            model_name, is_training=False)

        image_data = tf.io.read_file(test_image)
        image_data = tf.image.decode_jpeg(image_data, channels=3)
        image_data = image_preprocessing_fn(
            image_data, eval_image_size, eval_image_size)
        # Add the batch dimension expected by the network.
        image_data = tf.expand_dims(image_data, 0)

        # First pass: full image.
        logits_1, end_points_1 = network_fn(image_data)

        # Average the attention maps over the last axis and resize them to
        # the input resolution before extracting a bounding box.
        attention_maps = tf.reduce_mean(
            end_points_1['attention_maps'], axis=-1, keepdims=True)
        attention_maps = tf.image.resize(
            attention_maps, [eval_image_size, eval_image_size],
            method=tf.image.ResizeMethod.BILINEAR)
        bboxes = tf_v1.py_func(mask2bbox, [attention_maps], [tf.float32])
        bboxes = tf.reshape(bboxes, [batch_size, 4])
        box_ind = tf.range(batch_size, dtype=tf.int32)

        # Second pass: the attended region, cropped and resized.
        images = tf.image.crop_and_resize(
            image_data, bboxes, box_ind,
            crop_size=[eval_image_size, eval_image_size])
        logits_2, _ = network_fn(images, reuse=True)

        # Log of the mean of the two softmax distributions.
        logits = tf.math.log(
            tf.nn.softmax(logits_1) * 0.5 + tf.nn.softmax(logits_2) * 0.5)

        checkpoint_path = os.path.join(model_root, model_name)
        if tf.io.gfile.isdir(checkpoint_path):
            checkpoint_dir = checkpoint_path
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            # latest_checkpoint returns None when the directory holds no
            # checkpoint; fail fast with a clear message.
            if checkpoint_path is None:
                raise ValueError(
                    'No checkpoint found in directory: {}'.format(
                        checkpoint_dir))

        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint_path, slim.get_variables_to_restore())

        with tf_v1.Session() as sess:
            with slim.queues.QueueRunners(sess):
                sess.run(tf_v1.initialize_local_variables())
                init_fn(sess)
                np_probabilities = sess.run(logits)
                predicted_label = np.argmax(np_probabilities[0, :])
                print(predicted_label)

    return predicted_label
def main(model_root, datasets_dir, model_name, test_image_name):
    """Repeatedly evaluate the latest checkpoint on the 'train' split.

    The network is run twice per batch: once on the full images and once
    on the regions cropped to the bounding boxes that `mask2bbox` derives
    from the averaged attention maps. Evaluation summaries are registered
    through `add_eval_summary` for the combined output and for each pass,
    and `slim.evaluation.evaluate_once` is then called in an endless loop.

    Reads the module-level settings `batch_size`,
    `num_preprocessing_threads`, `moving_average_decay` and
    `max_num_batches`.

    Args:
        model_root: Directory joined with `model_name` to locate the
            checkpoint (either a checkpoint directory or a checkpoint
            path).
        datasets_dir: Dataset directory passed to
            `convert_data.get_datasets`.
        model_name: Name passed to `net_select.get_network_fn` and
            `preprocessing_select.get_preprocessing`.
        test_image_name: Currently unused; kept for interface
            compatibility.

    Raises:
        ValueError: If the checkpoint location is a directory that holds
            no checkpoint.
    """
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        dataset = convert_data.get_datasets('train', dataset_dir=datasets_dir)
        network_fn = net_select.get_network_fn(
            model_name, num_classes=dataset.num_classes, is_training=False)

        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=20 * batch_size,
            common_queue_min=10 * batch_size)
        [image, label] = provider.get(['image', 'label'])

        image_preprocessing_fn = preprocessing_select.get_preprocessing(
            model_name, is_training=False)
        eval_image_size = network_fn.default_image_size
        image = image_preprocessing_fn(
            image, eval_image_size, eval_image_size)

        images, labels = tf_v1.train.batch(
            [image, label],
            batch_size=batch_size,
            num_threads=num_preprocessing_threads,
            capacity=5 * batch_size)

        # Kept unresolved so the loop below can look up the newest
        # checkpoint on every pass.
        checkpoint_root = os.path.join(model_root, model_name)

        # First pass: full images.
        logits_1, end_points_1 = network_fn(images)

        # Average the attention maps over the last axis and resize them to
        # the input resolution before extracting bounding boxes.
        attention_maps = tf.reduce_mean(
            end_points_1['attention_maps'], axis=-1, keepdims=True)
        attention_maps = tf.image.resize(
            attention_maps, [eval_image_size, eval_image_size],
            method=tf.image.ResizeMethod.BILINEAR)
        bboxes = tf_v1.py_func(mask2bbox, [attention_maps], [tf.float32])
        bboxes = tf.reshape(bboxes, [batch_size, 4])
        box_ind = tf.range(batch_size, dtype=tf.int32)

        # Second pass: the attended regions, cropped and resized.
        images = tf.image.crop_and_resize(
            images, bboxes, box_ind,
            crop_size=[eval_image_size, eval_image_size])
        logits_2, _ = network_fn(images, reuse=True)

        # Log of the mean of the two softmax distributions.
        logits = tf_v1.log(
            tf.nn.softmax(logits_1) * 0.5 + tf.nn.softmax(logits_2) * 0.5)

        if moving_average_decay:
            # Restore the moving-average shadow values in place of the
            # raw model variables, plus the global step itself.
            variable_averages = tf.train.ExponentialMovingAverage(
                moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        logits_to_updates = add_eval_summary(
            logits, labels, scope='/bilinear')
        logits_1_to_updates = add_eval_summary(
            logits_1, labels, scope='/logits_1')
        logits_2_to_updates = add_eval_summary(
            logits_2, labels, scope='/logits_2')

        if max_num_batches:
            num_batches = max_num_batches
        else:
            # This ensures that we make a single pass over all of the data.
            num_batches = math.ceil(dataset.num_samples / float(batch_size))

        config = tf_v1.ConfigProto(allow_soft_placement=True,
                                   log_device_placement=False)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 1.0
        tf.compat.v1.disable_eager_execution()

        # The update ops do not change between passes, so build the
        # (nested) list once instead of on every iteration.
        eval_op = [
            list(logits_to_updates.values()),
            list(logits_1_to_updates.values()),
            list(logits_2_to_updates.values()),
        ]

        while True:
            # Re-resolve from the checkpoint directory on every pass. The
            # previous code resolved the path once before the loop, so a
            # newer checkpoint was never picked up.
            if tf.io.gfile.isdir(checkpoint_root):
                checkpoint_path = tf.train.latest_checkpoint(checkpoint_root)
                # latest_checkpoint returns None when the directory holds
                # no checkpoint; fail fast with a clear message.
                if checkpoint_path is None:
                    raise ValueError(
                        'No checkpoint found in directory: {}'.format(
                            checkpoint_root))
            else:
                checkpoint_path = checkpoint_root

            print('Evaluating %s' % checkpoint_path)

            # Evaluation output is written under the checkpoint path.
            test_dir = checkpoint_path
            slim.evaluation.evaluate_once(
                # Empty master string, as in the other evaluation scripts.
                # NOTE(review): this was ' ' (a single space) -- confirm
                # no caller relied on that.
                master='',
                checkpoint_path=checkpoint_path,
                logdir=test_dir,
                num_evals=num_batches,
                eval_op=eval_op,
                variables_to_restore=variables_to_restore,
                final_op=None,
                session_config=config)
def deploy(config, model_fn, args=None, kwargs=None, optimizer=None,
           summarize_gradients=False):
    """Deploy a model across the clones described by `config`.

    Clones are built with `create_clones(config, model_fn, args, kwargs)`.
    With an `optimizer`, gradients from `optimize_clones` are applied with
    the global step and `train_op` evaluates to the total loss once the
    update ops have run. Without one, the per-clone losses from
    `_gather_clone_loss` are summed into `total_loss` and `train_op` stays
    `None`.

    Args:
        config: Deployment configuration; must provide
            `optimizer_device()` and `variables_device()`.
        model_fn: Callable forwarded to `create_clones`.
        args: Optional positional arguments forwarded to `create_clones`.
        kwargs: Optional keyword arguments forwarded to `create_clones`.
        optimizer: Optional optimizer providing `apply_gradients`. If
            given, the model is deployed for training.
        summarize_gradients: Whether to add summaries for the gradients.

    Returns:
        `DeployedModel(train_op, summary_op, total_loss, clones)`. Both
        `train_op` and `total_loss` may be `None`; `summary_op` is `None`
        when there are no summaries.
    """
    graph_keys = tf.compat.v1.GraphKeys

    # Gather initial summaries.
    summaries = set(tf.compat.v1.get_collection(graph_keys.SUMMARIES))

    # Create Clones.
    clones = create_clones(config, model_fn, args, kwargs)
    first_clone = clones[0]

    # Update ops are taken from the first clone's scope only.
    update_ops = tf.compat.v1.get_collection(graph_keys.UPDATE_OPS,
                                             first_clone.scope)

    train_op = None
    total_loss = None
    with tf.device(config.optimizer_device()):
        if optimizer:
            # Place the global step on the device storing the variables.
            with tf.device(config.variables_device()):
                global_step = slim.get_or_create_global_step()

            # Compute the gradients for the clones.
            total_loss, clones_gradients = optimize_clones(clones, optimizer)

            if clones_gradients:
                if summarize_gradients:
                    summaries |= set(
                        _add_gradients_summaries(clones_gradients))

                # Create gradient updates.
                grad_updates = optimizer.apply_gradients(
                    clones_gradients, global_step=global_step)
                update_ops.append(grad_updates)

                # train_op yields the loss only after all updates ran.
                update_op = tf.group(*update_ops)
                with tf.control_dependencies([update_op]):
                    train_op = tf.identity(total_loss, name='train_op')
        else:
            clones_losses = []
            regularization_losses = tf.compat.v1.get_collection(
                graph_keys.REGULARIZATION_LOSSES)
            for clone in clones:
                with tf.name_scope(clone.scope):
                    clone_loss = _gather_clone_loss(clone, len(clones),
                                                    regularization_losses)
                    if clone_loss is not None:
                        clones_losses.append(clone_loss)
                    # Regularization losses are only passed for the first
                    # clone.
                    regularization_losses = None
            if clones_losses:
                total_loss = tf.add_n(clones_losses, name='total_loss')

        # Add the summaries found under the first clone's scope.
        summaries |= set(
            tf.compat.v1.get_collection(graph_keys.SUMMARIES,
                                        first_clone.scope))

        if total_loss is not None:
            # Use the compat.v1 summary op so the result can be merged by
            # tf.compat.v1.summary.merge below, like every other summary
            # collected here.
            summaries.add(
                tf.compat.v1.summary.scalar('total_loss', total_loss))

        if summaries:
            # Merge all summaries together.
            summary_op = tf.compat.v1.summary.merge(list(summaries),
                                                    name='summary_op')
        else:
            summary_op = None

    return DeployedModel(train_op, summary_op, total_loss, clones)