clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue]) first_clone_scope = deploy_config.clone_scope(0) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope) with tf.name_scope('synchronized_train'): with tf.device(deploy_config.optimizer_device()): learning_rate = tf.train.exponential_decay( args.learning_rate, global_step, args.learning_rate_decay_steps, args.learning_rate_decay, staircase=True, name='exponential_decay_learning_rate') optimizer = tf.train.AdamOptimizer(learning_rate) variables_to_train = tf.trainable_variables() total_loss, clones_gradients = model_deploy.optimize_clones( clones, optimizer, var_list=variables_to_train) grad_updates = optimizer.apply_gradients(clones_gradients, global_step=global_step) update_ops.append(grad_updates) update_op = tf.group(*update_ops) train_tensor = control_flow_ops.with_dependencies([update_op], total_loss, name='train_op') with tf.name_scope('summaries'): end_points = clones[0].outputs for end_point in end_points: x = end_points[end_point] summaries.add(tf.summary.histogram('activations/' + end_point, x)) summaries.add( tf.summary.scalar('sparsity/' + end_point, tf.nn.zero_fraction(x)))
def main(unused_argv):
    """Build the DeepLab training graph and run it with ``slim.learning.train``.

    Every setting comes from ``FLAGS``: deployment topology (clones,
    replicas, parameter servers), dataset name/split/location, input
    augmentation, learning-rate policy and checkpoint/summary intervals.
    ``FLAGS.train_batch_size`` is split evenly across clones, one model clone
    per device is built with ``_build_deeplab``, gradients of the "last
    layers" are optionally rescaled, and the resulting train op is handed to
    ``slim.learning.train``.

    Args:
        unused_argv: Positional command-line arguments; ignored.

    Raises:
        ValueError: If ``FLAGS.train_batch_size`` is not divisible by the
            number of clones.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.num_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs. Validate explicitly instead of with
    # `assert`, which is removed when Python runs with -O.
    if FLAGS.train_batch_size % config.num_clones != 0:
        raise ValueError(
            'Training batch size not divisible by number of clones (GPUs).')
    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(
        FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(
                samples, capacity=128 * config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                common.OUTPUT_TYPE: dataset.num_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config, model_fn, args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables. All model variables already exist
        # (the clones were built above), so this is done exactly once; a second
        # pass would only create duplicate `name_1` histograms.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions.
        if FLAGS.save_summaries_images:
            # `.strip('/')` drops the leading '/' produced when the first clone
            # scope is the empty string.
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            # argmax over axis 3 — presumes NHWC logits; TODO confirm.
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
            summaries.add(
                tf.summary.image(
                    'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy,
                FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps,
                FLAGS.learning_power,
                FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        # Each worker task waits proportionally to its task index before starting.
        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            # Evaluating train_op runs all updates, then returns the loss.
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False)

        # Start the training.
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_logdir,
            log_every_n_steps=FLAGS.log_steps,
            master=FLAGS.master,
            number_of_steps=FLAGS.training_number_of_steps,
            is_chief=(FLAGS.task == 0),
            session_config=session_config,
            startup_delay_steps=startup_delay_steps,
            init_fn=train_utils.get_model_init_fn(
                FLAGS.train_logdir,
                FLAGS.tf_initial_checkpoint,
                FLAGS.initialize_last_layer,
                last_layers,
                ignore_missing_vars=True),
            summary_op=summary_op,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs)
def main(_):
    """Train the model selected by FLAGS using a hand-written session loop.

    Validates the required flags, builds a (possibly multi-clone) training
    graph from ``dataset_factory.DataLoader`` and ``nets_factory``, optionally
    restores pretrained weights, then runs steps ``0 .. FLAGS.max_number_of_steps``
    (inclusive). Every ``FLAGS.log_every_n_steps`` steps it prints the total
    loss and each named component loss. A checkpoint is written to
    ``FLAGS.train_dir`` every ``FLAGS.model_snapshot_steps`` steps and always
    after the final step.

    Raises:
        ValueError: If ``--dataset_dir``, ``--aug_mode`` or the training
            image height/width flags are missing.
        Exception: If a loss in the first clone's LOSSES collection has a
            name this function does not recognise.
    """
    if not os.path.isdir(FLAGS.train_dir):
        os.makedirs(FLAGS.train_dir)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')
    if not FLAGS.aug_mode:
        raise ValueError('aug_mode need to be speficied.')
    if (not FLAGS.train_image_height) or (not FLAGS.train_image_width):
        raise ValueError(
            'The image height and width must be define explicitly.')
    # hd_data forces a fixed 400x200 input size.
    if FLAGS.hd_data:
        if FLAGS.train_image_height != 400 or FLAGS.train_image_width != 200:
            FLAGS.train_image_height, FLAGS.train_image_width = 400, 200
            print("set the image size to (%d, %d)" % (400, 200))

    # config and print log
    config_and_print_log(FLAGS)

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Config model_deploy.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # Select the preprocessing function.
        img_func = get_img_func()

        # Select the dataset.
        dataset = dataset_factory.DataLoader(
            FLAGS.model_name, FLAGS.dataset_name, FLAGS.dataset_dir, FLAGS.set,
            FLAGS.hd_data, img_func, FLAGS.batch_size, FLAGS.batch_k,
            FLAGS.max_number_of_steps, get_pair_type())

        # Select the network.
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            sample_number=FLAGS.sample_number)

        # Define the model.
        def clone_fn(tf_batch_queue):
            return build_graph(tf_batch_queue, network_fn)

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [dataset.tf_batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Classify the first clone's losses by op name for per-loss logging.
        # NOTE(review): with more than one clone, names are prefixed by the
        # clone scope, so the exact-match 'clf' branch may not trigger and
        # the key suffix may come from the clone scope; confirm.
        loss_dict = {}
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            if loss.name == 'softmax_cross_entropy_loss/value:0':
                loss_dict['clf'] = loss
            elif 'softmax_cross_entropy_loss' in loss.name:
                loss_dict['sample_clf_' +
                          str(loss.name.split('/')[0].split('_')[-1])] = loss
            elif 'entropy' in loss.name:
                loss_dict['entropy'] = loss
            else:
                raise Exception('Loss type error')

        # Configure the moving averages.
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # Configure the optimization procedure.
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step, FLAGS)
            optimizer = _configure_optimizer(learning_rate)

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the
            # chief queue runner.
            # NOTE(review): this hand-written loop never starts the sync
            # optimizer's chief queue runner, and `replica_id` may not be an
            # accepted argument of tf.train.SyncReplicasOptimizer in TF >= 1.0;
            # verify this path before relying on it.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # total_loss is the sum of all LOSSES and REGULARIZATION_LOSSES.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies(
            [update_op], total_loss, name='train_op')

        # Fetch the train op plus every named loss (sorted by key) each step and
        # build a matching printf-style format string.
        train_tensor_list = [train_tensor]
        format_str = 'step %d, loss = %.2f'
        for loss_key in sorted(loss_dict.keys()):
            train_tensor_list.append(loss_dict[loss_key])
            format_str += (', %s_loss = ' % loss_key + '%.8f')
        format_str += ' (%.1f examples/sec; %.3f sec/batch)'

        # Create a saver that keeps only the latest checkpoint.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')

        # Kick off the training.
        init = tf.global_variables_initializer()

        # allow_soft_placement must be True to build towers on GPU, as some
        # ops do not have GPU implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Load pretrained weights.
        if FLAGS.checkpoint_path is not None:
            print("Load the pretrained weights")
            weight_ini_fn = _get_init_fn()
            weight_ini_fn(sess)
        else:
            print("Train from the scratch")

        # Start the queue runners under a coordinator so that they can be
        # stopped cleanly and the session closed when training ends or fails.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # Steps run from 0 to max_number_of_steps inclusive.
            for step in xrange(FLAGS.max_number_of_steps + 1):
                start_time = time.time()
                loss_value_list = sess.run(train_tensor_list,
                                           feed_dict=dataset.get_feed_dict())
                duration = time.time() - start_time

                if step % FLAGS.log_every_n_steps == 0:
                    num_examples_per_step = FLAGS.batch_size
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration
                    print(format_str % tuple([step] + loss_value_list +
                                             [examples_per_sec, sec_per_batch]))

                # Save periodically, and always after the last step so the final
                # weights are kept even when max_number_of_steps is not a
                # multiple of model_snapshot_steps.
                if (step % FLAGS.model_snapshot_steps == 0 or
                        step == FLAGS.max_number_of_steps):
                    saver.save(sess, checkpoint_path, global_step=step)
        finally:
            coord.request_stop()
            coord.join(threads)
            sess.close()

    print('OK...')
def main(_):
    """Train a spatial-transformer (STN) fine-grained classifier with TF-Slim.

    Builds, inside a fresh graph:
      * a localization net (inception_v2 base + conv head) predicting
        NUM_TRANSFORMER sets of transformer parameters;
      * one inception_v2 classification path per transformed image, whose
        feature maps are concatenated and fed to a shared logits head;
      * an input pipeline reading (filename, label) pairs from a CSV-like list
        file, decoding JPEGs and batching them.
    Localization and classification variables are optimized by two separate
    optimizers with separate learning rates; only the classification update
    advances ``global_step``. Training is driven by ``slim.learning.train``.

    NOTE(review): several names used here are not defined in this function and
    are presumably module-level globals: NUM_TRANSFORMER, NUM_THETA_PARAMS,
    NUM_CLASSES, transformer, transformer_output_size, dropout_keep_prob,
    filename_dir, file, num_epochs, default_image_size, loc_train_vars_scope,
    cls_train_vars_scope. Confirm they exist at module scope.
    """
    # Dataset-directory validation is disabled in this variant; input files
    # come from `filename_dir` / `file` instead of FLAGS.dataset_dir.
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # The factory-based dataset / network selection is replaced by the
        # hand-built STN network and list-file input pipeline below.

        def localization_net_alpha(inputs, num_transformer, num_theta_params):
            """Localization net of the spatial transformer, built on inception_v2.

            Args:
                inputs: Image batch fed to inception_v2_base.
                num_transformer: Number of transformers whose parameters are
                    predicted.
                num_theta_params: Number of parameters per transformer.

            Returns:
                Tensor of num_transformer * num_theta_params tanh-activated
                values per example (spatial dims 1 and 2 squeezed away).
            """
            # Backbone features (the original author notes 7*7*1024 output,
            # i.e. inception_v2's default final endpoint; depends on input size).
            with tf.variable_scope('inception_net'):
                net, _ = inception_v2.inception_v2_base(inputs)
            # FC-style head built from convolutions.
            with tf.variable_scope('logits'):
                net = slim.conv2d(net, 128, [1, 1], scope='conv2d_a_1x1')
                # Kernel shrinks to the feature-map size for small inputs.
                kernel_size = inception_v2._reduced_kernel_size_for_small_input(net, [7, 7])
                net = slim.conv2d(net, 128, kernel_size, padding='VALID',
                                  scope='conv2d_b_{}x{}'.format(*kernel_size))
                # Bias initial value repeats a 4-value pattern per transformer.
                # NOTE(review): this presumes num_theta_params == 4; with any
                # other value the pattern no longer lines up per transformer.
                init_biase = tf.constant_initializer([1.1, .0, 1.1, .0] * num_transformer)
                logits = slim.conv2d(net, num_transformer * num_theta_params, [1, 1],
                                     weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
                                     biases_initializer=init_biase,
                                     normalizer_fn=None,
                                     activation_fn=tf.nn.tanh,
                                     scope='conv2d_c_1x1')
            return tf.squeeze(logits, [1, 2])

        def _inception_logits(inputs, num_outputs, dropout_keep_prob, activ_fn=None):
            """Average-pool, dropout and a 1x1 conv producing `num_outputs` logits.

            Returns a tensor with spatial dims 1 and 2 squeezed away.
            """
            with tf.variable_scope('logits'):
                kernel_size = inception_v2._reduced_kernel_size_for_small_input(inputs, [7, 7])
                # Pool with a VALID kernel to reach (presumably) 1x1 spatial size.
                net = slim.avg_pool2d(inputs, kernel_size, padding='VALID')
                # Dropout before the fully connected (1x1 conv) layer.
                net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='dropout')
                # 1x1 conv acting as the fully connected layer.
                logits = slim.conv2d(net, num_outputs, [1, 1], normalizer_fn=None,
                                     activation_fn=activ_fn, scope='conv2_a_1x1')
            return tf.squeeze(logits, [1, 2])

        def network_fn(inputs):
            """Multi-path STN classifier: localize, transform, classify.

            Returns:
                (classification_logits, end_points) where end_points has the
                keys 'stn/localization/transformer_theta' and
                'stn/classification/logits'.
            """
            end_points = {}
            arg_scope = inception_v2.inception_v2_arg_scope(weight_decay=FLAGS.weight_decay)
            with slim.arg_scope(arg_scope):
                with tf.variable_scope('stn'):
                    with tf.variable_scope('localization'):
                        transformer_theta = localization_net_alpha(inputs, NUM_TRANSFORMER, NUM_THETA_PARAMS)
                        # One parameter slice per transformer, split along axis 1.
                        transformer_theta_split = tf.split(transformer_theta, NUM_TRANSFORMER, axis=1)
                        end_points['stn/localization/transformer_theta'] = transformer_theta

                    # Warp the input once per transformer.
                    transformer_outputs = []
                    for theta in transformer_theta_split:
                        transformer_outputs.append(
                            transformer(inputs, theta, transformer_output_size, sampling_kernel='bilinear'))

                    inception_outputs = []
                    # Static shape for every warped image; batch dim is fixed
                    # to FLAGS.batch_size and the channel count to 3.
                    transformer_outputs_shape = [FLAGS.batch_size,
                                                 transformer_output_size[0],
                                                 transformer_output_size[1],
                                                 3]
                    with tf.variable_scope('classification'):
                        # One independent inception_v2 base per transformer path.
                        for path_idx, inception_inputs in enumerate(transformer_outputs):
                            with tf.variable_scope('path_{}'.format(path_idx)):
                                inception_inputs.set_shape(transformer_outputs_shape)
                                net, _ = inception_v2.inception_v2_base(inception_inputs)
                                inception_outputs.append(net)
                        # Concatenate the path features along the last axis.
                        multipath_outputs = tf.concat(inception_outputs, axis=-1)
                        # Final fc layer logits.
                        classification_logits = _inception_logits(multipath_outputs, NUM_CLASSES, dropout_keep_prob)
                        end_points['stn/classification/logits'] = classification_logits
            return classification_logits, end_points

        #####################################
        # Select the preprocessing function #
        #####################################
        def image_preprocessing_fn(image, out_height, out_width):
            """Crop, resize, randomly flip and rescale one image for training.

            Converts to float32 when needed, keeps the central 97.5%, resizes
            bilinearly to (out_height, out_width), applies a random left-right
            flip, and maps values v to (v - 0.5) * 2 (so [0, 1] becomes
            [-1, 1]). The static shape is set to (out_height, out_width, 3).
            """
            if image.dtype != tf.float32:
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)
            image = tf.image.central_crop(image, central_fraction=0.975)
            # resize_bilinear needs a batch dimension; add then remove it.
            image = tf.expand_dims(image, 0)
            image = tf.image.resize_bilinear(image, [out_height, out_width], align_corners=False)
            image = tf.squeeze(image, [0])
            image = tf.image.random_flip_left_right(image)
            image = tf.subtract(image, 0.5)
            image = tf.multiply(image, 2.0)
            image.set_shape((out_height, out_width, 3))
            return image

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        def _get_filename_list(file_dir, file):
            """Read the list file at file_dir/file.

            Every line must have exactly four comma-separated fields; only the
            first (filename) and second (integer class label) are kept.

            Returns:
                (filename_list, cls_label_list)
            """
            filename_path = os.path.join(file_dir, file)
            filename_list = []
            cls_label_list = []
            with open(filename_path, 'r') as f:
                for line in f:
                    filename, label, nid, attr = line.strip().split(',')
                    filename_list.append(filename)
                    cls_label_list.append(int(label))
            return filename_list, cls_label_list

        with tf.device(deploy_config.inputs_device()):
            # Create the filename and label example.
            # NOTE(review): `file` shadows a Python 2 builtin; make sure a
            # module-level `file` string is defined.
            filename_list, label_list = _get_filename_list(filename_dir, file)
            num_samples = len(filename_list)
            filename, label = tf.train.slice_input_producer([filename_list, label_list], num_epochs)
            # Decode and preprocess the image.
            file_content = tf.read_file(filename)
            image = tf.image.decode_jpeg(file_content, channels=3)
            train_image_size = FLAGS.train_image_size or default_image_size
            image = image_preprocessing_fn(image, train_image_size, train_image_size)
            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, NUM_CLASSES - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            with tf.device(deploy_config.inputs_device()):
                images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            # NOTE: network_fn above never emits an 'AuxLogits' end point, so
            # this auxiliary-loss branch does not run with the current network.
            if 'AuxLogits' in end_points:
                tf.losses.softmax_cross_entropy(
                    logits=end_points['AuxLogits'], onehot_labels=labels,
                    label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
            tf.losses.softmax_cross_entropy(
                logits=logits, onehot_labels=labels,
                label_smoothing=FLAGS.label_smoothing, weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(tf.summary.scalar('sparsity/' + end_point, tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        # Separate learning rates / optimizers for the localization ("loc")
        # and classification ("cls") sub-networks.
        with tf.device(deploy_config.optimizer_device()):
            learning_rate_loc = _configure_learning_rate_loc(num_samples, global_step)
            learning_rate_cls = _configure_learning_rate_cls(num_samples, global_step)
            optimizer_loc = _configure_optimizer(learning_rate_loc)
            optimizer_cls = _configure_optimizer(learning_rate_cls)
            summaries.add(tf.summary.scalar('learning_rate_loc', learning_rate_loc))
            summaries.add(tf.summary.scalar('learning_rate_cls', learning_rate_cls))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            # NOTE(review): this sync path looks broken — `replica_id` may not be
            # accepted by tf.train.SyncReplicasOptimizer in TF >= 1.0,
            # optimizer_loc.apply_gradients below passes no global_step, and
            # slim.learning.train is called with sync_optimizer=None. Verify
            # before enabling --sync_replicas.
            optimizer_loc = tf.train.SyncReplicasOptimizer(
                opt=optimizer_loc,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
            optimizer_cls = tf.train.SyncReplicasOptimizer(
                opt=optimizer_cls,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(variable_averages.apply(moving_average_variables))

        # Variables to train, split by sub-network scope.
        loc_vars_to_train = _get_localization_vars_to_train(loc_train_vars_scope)
        cls_vars_to_train = _get_classification_vars_to_train(cls_train_vars_scope)

        # Gradients of the same clone losses w.r.t. each variable subset; only
        # the second call's total_loss is kept.
        _, clones_gradients_loc = model_deploy.optimize_clones(
            clones, optimizer_loc, var_list=loc_vars_to_train)
        total_loss, clones_gradients_cls = model_deploy.optimize_clones(
            clones, optimizer_cls, var_list=cls_vars_to_train)

        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates. Only the classification update increments
        # global_step, so it advances once per training step.
        grad_updates_loc = optimizer_loc.apply_gradients(clones_gradients_loc)
        grad_updates_cls = optimizer_cls.apply_gradients(clones_gradients_cls,
                                                         global_step=global_step)
        update_ops.append(grad_updates_loc)
        update_ops.append(grad_updates_cls)

        update_op = tf.group(*update_ops)
        # Evaluating train_op runs every update, then returns the loss.
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=None)
def main(_):
    """Train the classifier selected by FLAGS on the `nsfw` dataset with TF-Slim.

    Builds the network via ``nets_factory`` and the preprocessing via
    ``preprocessing_factory``. Reads ``(image, label)`` pairs from
    ``nsfw.get_split(FLAGS.dataset_split_name, FLAGS.dataset_dir)``, batches
    and one-hot encodes them, and creates one model clone per device. It then
    adds histogram/sparsity/loss/variable summaries, configures the optimizer
    (optionally synchronous across replicas and/or with moving averages) and
    runs ``slim.learning.train``.

    Raises:
        ValueError: If ``--dataset_dir`` is not supplied.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Config model_deploy.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # Select the dataset.
        dataset = nsfw.get_split(FLAGS.dataset_split_name, FLAGS.dataset_dir)

        # Select the network.
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        # Select the preprocessing function.
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=True)

        # Create a dataset provider that loads data from the dataset.
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            # Shift labels so the first used class maps to index 0.
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size, train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        # Define the model.
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn.

            Registers the (optional auxiliary and) main softmax cross-entropy
            losses in tf.GraphKeys.LOSSES and returns the network end points.
            """
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            # Specify the loss function. Uses core tf.losses (explicit keyword
            # arguments) instead of the deprecated slim.losses /
            # tf.contrib.losses wrappers; both add to tf.GraphKeys.LOSSES.
            if 'AuxLogits' in end_points:
                tf.losses.softmax_cross_entropy(
                    onehot_labels=labels,
                    logits=end_points['AuxLogits'],
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            tf.losses.softmax_cross_entropy(
                onehot_labels=labels,
                logits=logits,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point, tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # Configure the moving averages.
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # Configure the optimization procedure.
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Sum the clone losses and compute (clone-averaged) gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones,
            optimizer,
            var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        # Evaluating train_op runs every update, then returns the loss.
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Kick off the training.
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
def main(_):
    """Train the ResSeg segmentation model on COCO using TF-Slim.

    Builds ``FLAGS.num_clones`` model clones with ``model_deploy``, sets up the
    learning rate, the optimizer (optionally wrapped in
    ``tf.train.SyncReplicasOptimizer``) and optional moving averages, adds
    image/scalar summaries, and finally runs ``slim.learning.train``. Every
    setting is read from ``FLAGS``.

    Args:
        _: Unused (presumably the argv list passed by the app launcher --
            confirm against the caller).
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step on the device that holds the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            iterator = coco.get_dataset(
                FLAGS.train_data_file,
                batch_size=FLAGS.batch_size,
                num_epochs=500,
                buffer_size=250 * FLAGS.num_clones,
                num_parallel_calls=4 * FLAGS.num_clones,
                crop_height=FLAGS.height,
                crop_width=FLAGS.width,
                resize_shape=FLAGS.width,
                data_augment=True)

        # Lookup table from raw label value (0..255) to training label: ids listed
        # in coco.id2trainid_objects map to (train id + 1); every other value maps
        # to 0. It is loop-invariant, so build it once here; each clone wraps it
        # in its own tf.constant.
        lmap_size = 256
        lmap_table = np.array([0] * lmap_size)
        for k, v in coco.id2trainid_objects.items():
            lmap_table[k] = v + 1

        def clone_fn(iterator):
            """Build one model clone and register its weighted cross-entropy loss.

            Args:
                iterator: Dataset iterator yielding (images, labels) batches.

            Returns:
                Tuple ``(end_points, batch_image, down_labels, logits)`` for this
                clone.
            """
            with tf.device(deploy_config.inputs_device()):
                batch_image, batch_labels = iterator.get_next()
                # Pin the static batch dimension to FLAGS.batch_size.
                s = batch_labels.get_shape().as_list()
                batch_labels.set_shape([FLAGS.batch_size, s[1], s[2], s[3]])
                s = batch_image.get_shape().as_list()
                batch_image.set_shape([FLAGS.batch_size, s[1], s[2], s[3]])

            num_classes = coco.num_classes()
            logits, end_points = resseg_model(
                batch_image, FLAGS.height, FLAGS.width, FLAGS.scale,
                FLAGS.weight_decay, FLAGS.use_seperable_convolution,
                num_classes, is_training=True,
                use_batch_norm=FLAGS.use_batch_norm,
                num_units=FLAGS.num_units,
                filter_depth_multiplier=FLAGS.filter_depth_multiplier)
            logits_shape = logits.get_shape().as_list()

            with tf.device(deploy_config.inputs_device()):
                lmap = tf.constant(lmap_table, tf.uint8)
                down_labels = tf.cast(batch_labels, tf.int32)
                # Pixels whose raw value is >= 255 are excluded: they get label 0
                # and zero loss weight below.
                label_mask = tf.squeeze((down_labels < 255))
                down_labels = tf.gather(lmap, down_labels)
                down_labels = tf.cast(down_labels, tf.int32)
                # Labels are reshaped to the logits' spatial size.
                down_labels = tf.reshape(
                    down_labels,
                    tf.TensorShape([FLAGS.batch_size, logits_shape[1], logits_shape[2]]))
                down_labels = tf.cast(label_mask, tf.int32) * down_labels
                # Every non-excluded pixel is weighted by FLAGS.foreground_weight.
                # Computed in float32 so fractional weights are accepted (an int32
                # constant rejects them; the loss casts weights to float anyway).
                label_weights = (tf.cast(label_mask, tf.float32) *
                                 float(FLAGS.foreground_weight))

            # tf.losses.* already adds the loss to GraphKeys.LOSSES, which is where
            # model_deploy gathers each clone's losses from; an extra
            # tf.losses.add_loss() would count the cross-entropy twice.
            tf.losses.sparse_softmax_cross_entropy(
                down_labels, logits, weights=label_weights, scope='xentropy')
            return end_points, batch_image, down_labels, logits

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn, [iterator])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by the model.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(
                FLAGS.num_samples_per_epoch, global_step, deploy_config.num_clones)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner. (The TF 1.x constructor takes no replica_id.)
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(variable_averages.apply(moving_average_variables))

        # Image summaries from the first clone: colourised labels, colourised
        # predictions and the input images.
        _, batch_image, down_labels, logits = clones[0].outputs
        cmap = tf.constant(np.array(coco.id2color), tf.uint8)
        seg_map = tf.gather(cmap, down_labels)
        predictions = tf.argmax(logits, axis=3)
        pred_map = tf.gather(cmap, predictions)
        summaries.add(tf.summary.image('labels', seg_map))
        summaries.add(tf.summary.image('predictions', pred_map))
        summaries.add(tf.summary.image('images', batch_image))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Returns the total loss and the per-variable gradients across clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates and run them (plus all update_ops) before
        # reporting the loss.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies(
            [update_op], total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        if FLAGS.sync_replicas:
            # slim.learning.train needs the SyncReplicasOptimizer instance and
            # requires startup_delay_steps == 0 when one is supplied.
            sync_optimizer = optimizer
            startup_delay_steps = 0
        else:
            sync_optimizer = None
            startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        ###########################
        # Kick off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            startup_delay_steps=startup_delay_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=sync_optimizer)
def main(model_root, datasets_dir, model_name):
    """Train an image classifier with attention crop/drop augmentation and a center loss.

    Each clone runs the network three times, on the raw batch, on an
    attention-guided crop and on an attention-guided drop. It sums the three
    softmax cross-entropies (weight 1/3 each) plus a center loss on the
    ``embeddings`` end point, then trains with ``slim.learning.train``.
    Hyper-parameters (num_clones, batch_size, moving_average_decay, ...) are
    read from module-level names defined elsewhere in this file.

    Args:
        model_root: Root directory; logs/checkpoints go to
            ``os.path.join(model_root, model_name)``.
        datasets_dir: Passed as ``dataset_dir`` to
            ``convert_data.get_datasets('train', ...)``.
        model_name: Selects the network and preprocessing function, and names
            the train sub-directory and the ``<model_name>.ckpt`` init file.
    """
    # Training-related configuration.
    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=False,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=num_ps_tasks)
        # Create global_step on the device that holds the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        train_dir = os.path.join(model_root, model_name)
        dataset = convert_data.get_datasets('train', dataset_dir=datasets_dir)
        network_fn = net_select.get_network_fn(
            model_name, num_classes=dataset.num_classes,
            weight_decay=weight_decay, is_training=True)
        image_preprocessing_fn = preprocessing_select.get_preprocessing(
            model_name, is_training=True)
        print("the data_sources:", dataset.data_sources)

        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=num_readers,
                common_queue_capacity=20 * batch_size,
                common_queue_min=10 * batch_size)
            [image, label] = provider.get(['image', 'label'])
            train_image_size = network_fn.default_image_size
            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)
            images, labels = tf.compat.v1.train.batch(
                [image, label],
                batch_size=batch_size,
                num_threads=num_preprocessing_threads,
                capacity=5 * batch_size)
            labels = slim.one_hot_encoding(labels, dataset.num_classes)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        def calculate_pooling_center_loss(features, label, alfa, nrof_classes,
                                          weights, name):
            """Center loss against a non-trainable per-class center variable.

            Args:
                features: Feature tensor; flattened to (batch, -1) using its
                    static batch dimension.
                label: One-hot labels; converted to class indices via argmax.
                alfa: Center update rate; each step subtracts
                    ``(1 - alfa) * (normalized_center - feature)`` from the
                    sampled class centers.
                nrof_classes: Number of rows in the center variable.
                weights: Scalar multiplier applied to the loss.
                name: Name of the center variable; the loss is named
                    ``name + '_loss'``.

            Returns:
                Scalar loss: mean over the batch of the squared L2 distance
                between each feature and its (L2-normalized) class center.
            """
            features = tf.reshape(features, [features.shape[0], -1])
            label = tf.argmax(label, 1)
            nrof_features = features.get_shape()[1]
            centers = tf.compat.v1.get_variable(
                name, [nrof_classes, nrof_features], dtype=tf.float32,
                initializer=tf.constant_initializer(0), trainable=False)
            label = tf.reshape(label, [-1])
            centers_batch = tf.gather(centers, label)
            # Only the gathered centers are L2-normalized, not the features.
            centers_batch = tf.nn.l2_normalize(centers_batch, axis=-1)
            diff = (1 - alfa) * (centers_batch - features)
            centers = tf.compat.v1.scatter_sub(centers, label, diff)
            # Force the center update to run whenever the loss is evaluated.
            with tf.control_dependencies([centers]):
                distance = tf.square(features - centers_batch)
            distance = tf.reduce_sum(distance, axis=-1)
            center_loss = tf.reduce_mean(distance)
            center_loss = tf.identity(center_loss * weights, name=name + '_loss')
            return center_loss

        def select_part_index(attention_map):
            """Sample one attention channel of a (height, width, parts) numpy map.

            Channels are weighted by sqrt of their spatial mean activation. When
            no channel has positive weight (all-zero maps), the normalized
            probabilities would be undefined and np.random.choice would raise,
            so a uniform choice is used instead.
            """
            num_parts = attention_map.shape[-1]
            part_weights = attention_map.mean(axis=0).mean(axis=0)
            # Clamp before sqrt so negative means cannot produce NaN weights.
            part_weights = np.sqrt(np.maximum(part_weights, 0))
            total = np.sum(part_weights)
            if total > 0:
                part_weights = part_weights / total
                return np.random.choice(np.arange(0, num_parts), 1,
                                        p=part_weights)[0]
            return np.random.choice(np.arange(0, num_parts), 1)[0]

        def attention_crop(attention_maps):
            """Data augmentation from attention maps: the paper's "crop mask".

            For each image, sample one attention part and take the bounding box
            of the pixels whose activation is at least a random 40-60% of that
            part's maximum, padded by 0.1 on each side.

            Args:
                attention_maps: numpy array (batch, height, width, num_parts),
                    obtained by reducing the feature maps.

            Returns:
                float32 array (batch, 4) of [ymin, xmin, ymax, xmax] normalized
                by height/width; values may fall outside [0, 1].
            """
            num_images, height, width, _ = attention_maps.shape
            bboxes = []
            for i in range(num_images):
                attention_map = attention_maps[i]
                selected_index = select_part_index(attention_map)
                mask = attention_map[:, :, selected_index]
                threshold = random.uniform(0.4, 0.6)
                itemindex = np.where(mask >= mask.max() * threshold)
                ymin = itemindex[0].min() / height - 0.1
                ymax = itemindex[0].max() / height + 0.1
                xmin = itemindex[1].min() / width - 0.1
                xmax = itemindex[1].max() / width + 0.1
                bboxes.append(np.asarray([ymin, xmin, ymax, xmax],
                                         dtype=np.float32))
            return np.asarray(bboxes, np.float32)

        def attention_drop(attention_maps):
            """Attention drop: erase one attended part of each image.

            The goal is to make the model attend to other parts of the object,
            since different attention maps may focus on the same part.

            Args:
                attention_maps: numpy array (batch, height, width, num_parts).

            Returns:
                float32 array (batch, height, width, 1): 0 where the sampled
                part's activation is at least a random 20-50% of its maximum,
                1 elsewhere.
            """
            num_images = attention_maps.shape[0]
            masks = []
            for i in range(num_images):
                attention_map = attention_maps[i]
                selected_index = select_part_index(attention_map)
                mask = attention_map[:, :, selected_index:selected_index + 1]
                # Hard (binary) mask: drop strongly activated pixels, keep the rest.
                threshold = random.uniform(0.2, 0.5)
                mask = (mask < threshold * mask.max()).astype(np.float32)
                masks.append(mask)
            return np.asarray(masks, dtype=np.float32)

        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn.

            Runs the network on the raw, attention-cropped and attention-dropped
            batch, registers the three cross-entropies and the center loss, and
            returns the end points of the raw-batch pass.
            """
            images, labels = batch_queue.dequeue()
            logits_1, end_points_1 = network_fn(images)
            attention_maps = end_points_1['attention_maps']
            attention_maps = tf.image.resize(
                attention_maps, [train_image_size, train_image_size],
                method=tf.image.ResizeMethod.BILINEAR)

            # Attention crop.
            bboxes = tf.compat.v1.py_func(attention_crop, [attention_maps],
                                          [tf.float32])
            bboxes = tf.reshape(bboxes, [batch_size, 4])
            box_ind = tf.range(batch_size, dtype=tf.int32)
            images_crop = tf.image.crop_and_resize(
                images, bboxes, box_ind,
                crop_size=[train_image_size, train_image_size])

            # Attention drop.
            masks = tf.compat.v1.py_func(attention_drop, [attention_maps],
                                         [tf.float32])
            masks = tf.reshape(
                masks, [batch_size, train_image_size, train_image_size, 1])
            images_drop = images * masks

            logits_2, _ = network_fn(images_crop, reuse=True)
            logits_3, _ = network_fn(images_drop, reuse=True)

            slim.losses.softmax_cross_entropy(logits_1, labels, weights=1 / 3.0,
                                              scope='cross_entropy_1')
            slim.losses.softmax_cross_entropy(logits_2, labels, weights=1 / 3.0,
                                              scope='cross_entropy_2')
            slim.losses.softmax_cross_entropy(logits_3, labels, weights=1 / 3.0,
                                              scope='cross_entropy_3')

            embeddings = end_points_1['embeddings']
            center_loss = calculate_pooling_center_loss(
                features=embeddings, label=labels, alfa=0.95,
                nrof_classes=dataset.num_classes, weights=1.0,
                name='center_loss')
            # The center loss is built by hand, so it must be registered explicitly.
            slim.losses.add_loss(center_loss)
            return end_points_1

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Graph-mode (v1) summary ops are used throughout: they return tensors
        # that can be merged below, whereas TF2 tf.summary.* ops write to a
        # default writer and are no-ops without one.
        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point, x in end_points.items():
            summaries.add(
                tf.compat.v1.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.compat.v1.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(
                tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(
                tf.compat.v1.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step)
            optimizer = configure_optimizer(learning_rate)
            summaries.add(
                tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        if moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = get_variables_to_train(trainable_scopes)

        # Returns the total loss and the per-variable gradients across clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

        # Create gradient updates and run them (plus all update_ops) before
        # reporting the loss.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES,
                                        first_clone_scope))

        # Merge the collected summaries together.
        summary_op = tf.compat.v1.summary.merge(list(summaries),
                                                name='summary_op')

        config = tf.compat.v1.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = "0"

        save_model_path = os.path.join(checkpoint_path, model_name,
                                       "%s.ckpt" % model_name)
        print(save_model_path)

        # NOTE(review): this is called after the graph has been built, inside
        # the graph context -- confirm it is still required here.
        tf.compat.v1.disable_eager_execution()

        # Train the model.
        slim.learning.train(
            train_op=train_tensor,
            logdir=train_dir,
            is_chief=(task == 0),
            init_fn=_get_init_fn(save_model_path, train_dir=train_dir),
            summary_op=summary_op,
            number_of_steps=max_number_of_steps,
            log_every_n_steps=log_every_n_steps,
            save_summaries_secs=save_summaries_secs,
            save_interval_secs=save_interval_secs,
            session_config=config)