def train_resnet_mentornet(max_step_run):
  """Trains the mentornet with the student resnet model.

  Builds a ResNet student, weights each example's cross-entropy loss by the
  MentorNet output `v`, and updates only the student: variables matching
  'mentornet' are excluded from the gradient step. The data provider also
  returns the clean labels. They are used only to split `v` into clean/noisy
  parts for collections and histogram summaries. If
  `FLAGS.trained_mentornet_dir` is set, MentorNet variables are restored from
  that checkpoint (or from the latest checkpoint in that directory) when the
  session starts.

  Args:
    max_step_run: The maximum number of gradient steps.
  """
  if not os.path.exists(FLAGS.train_log_dir):
    os.makedirs(FLAGS.train_log_dir)
  g = tf.Graph()

  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      tf_global_step = tf.train.get_or_create_global_step()

      # The third returned item (clean images) is not used here.
      (images, one_hot_labels, _, clean_one_hot_labels, num_samples_per_epoch,
       num_of_classes) = cifar_data_provider.my_provide_resnet_data(
           FLAGS.dataset_name,
           'train',
           FLAGS.batch_size,
           dataset_dir=FLAGS.data_dir)

      hps = resnet_model.HParams(
          batch_size=FLAGS.batch_size,
          num_classes=num_of_classes,
          min_lrn_rate=0.0001,
          lrn_rate=FLAGS.learning_rate,
          num_residual_units=9,
          use_bottleneck=False,
          weight_decay_rate=0.0002,
          relu_leakiness=0.1,
          optimizer='mom')

      images.set_shape([FLAGS.batch_size, 32, 32, 3])
      tf.logging.info('num_of_example=%s', num_samples_per_epoch)

      # Define the model:
      resnet = resnet_model.ResNet(hps, images, one_hot_labels, mode='train')
      logits = resnet.build_model()

      # Specify the loss function:
      loss = tf.nn.softmax_cross_entropy_with_logits(
          labels=one_hot_labels, logits=logits)

      dropout_rates = utils.parse_dropout_rate_list(FLAGS.example_dropout_rates)
      example_dropout_rates = tf.convert_to_tensor(
          dropout_rates, np.float32, name='example_dropout_rates')

      # Constant schedule with 100 entries, matching the `* 100` scaling of
      # epoch_step below.
      loss_p_percentile = tf.convert_to_tensor(
          np.array([FLAGS.loss_p_percentile] * 100),
          np.float32,
          name='loss_p_percentile')

      loss = tf.reshape(loss, [-1, 1])

      # Training progress as a percentage of max_step_run.
      epoch_step = tf.to_int32(
          tf.floor(tf.divide(tf_global_step, max_step_run) * 100))

      zero_labels = tf.zeros([tf.shape(loss)[0], 1], tf.float32)

      v = utils.mentornet(
          epoch_step,
          loss,
          zero_labels,
          loss_p_percentile,
          example_dropout_rates,
          burn_in_epoch=FLAGS.burn_in_epoch,
          fixed_epoch_after_burn_in=FLAGS.fixed_epoch_after_burn_in,
          loss_moving_average_decay=FLAGS.loss_moving_average_decay)
      # tf.stop_gradient returns a new tensor and does not modify `v` in place.
      # Rebind it so the student update treats the example weights as
      # constants.
      v = tf.stop_gradient(v)

      # Split v into clean data & noise data part
      is_clean = tf.reshape(
          tf.reduce_all(tf.equal(one_hot_labels, clean_one_hot_labels), axis=1),
          [-1, 1])
      clean_v = tf.boolean_mask(v, is_clean)
      noise_v = tf.boolean_mask(v, ~is_clean)
      tf.add_to_collection('v', v)
      tf.add_to_collection('v', clean_v)
      tf.add_to_collection('v', noise_v)
      slim.summaries.add_histogram_summary(clean_v, 'clean_v')
      slim.summaries.add_histogram_summary(noise_v, 'noisy_v')

      # Log data utilization
      data_util = utils.summarize_data_utilization(v, tf_global_step,
                                                   FLAGS.batch_size)

      decay_loss = resnet.decay()
      weighted_loss_vector = tf.multiply(loss, v)
      weighted_loss = tf.reduce_mean(weighted_loss_vector)

      slim.summaries.add_scalar_summary(
          tf.reduce_mean(loss), 'mentornet/orig_loss')
      slim.summaries.add_scalar_summary(weighted_loss,
                                        'mentornet/weighted_loss')

      # Scale the decay loss by the mean example weight in the batch.
      weighed_decay_loss = decay_loss * (tf.reduce_sum(v) / FLAGS.batch_size)
      weighted_total_loss = weighted_loss + weighed_decay_loss

      slim.summaries.add_scalar_summary(weighted_total_loss,
                                        'mentornet/total_loss')
      slim.summaries.add_scalar_summary(weighted_total_loss, 'total_loss')
      tf.add_to_collection('total_loss', weighted_total_loss)

      boundaries = [19531, 25000, 30000]
      values = [FLAGS.learning_rate * t for t in [1, 0.1, 0.01, 0.001]]
      lr = tf.train.piecewise_constant(tf_global_step, boundaries, values)
      slim.summaries.add_scalar_summary(lr, 'learning_rate')

      # Specify the optimization scheme:
      with tf.control_dependencies([weighted_total_loss, data_util]):
        # Set up training. MentorNet variables are excluded so that only the
        # student network is updated.
        trainable_variables = tf.trainable_variables()
        trainable_variables = tf.contrib.framework.filter_variables(
            trainable_variables, exclude_patterns=['mentornet'])

        grads = tf.gradients(weighted_total_loss, trainable_variables)
        optimizer = tf.train.MomentumOptimizer(lr, momentum=0.9)

        apply_op = optimizer.apply_gradients(
            zip(grads, trainable_variables),
            global_step=tf_global_step,
            name='train_step')

        train_ops = [apply_op] + resnet.extra_train_ops
        train_op = tf.group(*train_ops)

      # Parameter restore setup. ckpt_model stays None when no pretrained
      # MentorNet directory is configured; it is logged unconditionally below.
      ckpt_model = None
      if FLAGS.trained_mentornet_dir is not None:
        ckpt_model = FLAGS.trained_mentornet_dir
        if os.path.isdir(FLAGS.trained_mentornet_dir):
          ckpt_model = tf.train.latest_checkpoint(ckpt_model)

        # Fix the mentornet parameters
        variables_to_restore = slim.get_variables_to_restore(
            # TODO(lujiang): mentornet_inputs or mentor_inputs?
            include=['mentornet', 'mentornet_inputs'])
        iassign_op1, ifeed_dict1 = tf.contrib.framework.assign_from_checkpoint(
            ckpt_model, variables_to_restore)

        # Create an initial assignment function.
        def init_assign_fn(sess):
          tf.logging.info('Restore using customer initializer %s', '.' * 10)
          sess.run(iassign_op1, ifeed_dict1)
      else:
        init_assign_fn = None

      tf.logging.info('-' * 20 + 'MentorNet' + '-' * 20)
      tf.logging.info('loaded pretrained mentornet from %s', ckpt_model)
      tf.logging.info('loss_p_percentile=%3f', FLAGS.loss_p_percentile)
      tf.logging.info('burn_in_epoch=%d', FLAGS.burn_in_epoch)
      tf.logging.info('fixed_epoch_after_burn_in=%s',
                      FLAGS.fixed_epoch_after_burn_in)
      tf.logging.info('loss_moving_average_decay=%3f',
                      FLAGS.loss_moving_average_decay)
      tf.logging.info('example_dropout_rates %s',
                      ','.join(str(t) for t in dropout_rates))
      tf.logging.info('-' * 20)

      saver = tf.train.Saver(max_to_keep=10, keep_checkpoint_every_n_hours=24)

      # Run training.
      slim.learning.train(
          train_op=train_op,
          train_step_fn=resnet_train_step,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          saver=saver,
          number_of_steps=max_step_run,
          init_fn=init_assign_fn,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
def train_inception_mentornet(max_step_run):
  """Trains the mentornet with the student inception model.

  Builds a CifarNet (inception-style) student and weights each example's
  cross-entropy loss by the MentorNet output `v`. It then trains only the
  student: variables matching 'mentornet' are excluded from
  `variables_to_train`. If `FLAGS.trained_mentornet_dir` is set, MentorNet
  variables are restored from that checkpoint (or from the latest checkpoint in
  that directory) when the session starts.

  Args:
    max_step_run: The maximum number of gradient steps.
  """
  if not os.path.exists(FLAGS.train_log_dir):
    os.makedirs(FLAGS.train_log_dir)
  g = tf.Graph()

  with g.as_default():
    # If ps_tasks is zero, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      config = tf.ConfigProto()
      # limit gpu memory to run train and eval on the same gpu
      config.gpu_options.per_process_gpu_memory_fraction = 0.8

      tf_global_step = tf.train.get_or_create_global_step()

      (images, one_hot_labels, num_samples_per_epoch,
       num_of_classes) = cifar_data_provider.provide_cifarnet_data(
           FLAGS.dataset_name,
           'train',
           FLAGS.batch_size,
           dataset_dir=FLAGS.data_dir)

      images.set_shape([FLAGS.batch_size, 32, 32, 3])
      tf.logging.info('num_of_example=%s', num_samples_per_epoch)

      # Define the model:
      with slim.arg_scope(
          inception_model.cifarnet_arg_scope(weight_decay=0.004)):
        logits, _ = inception_model.cifarnet(
            images, num_of_classes, is_training=True, dropout_keep_prob=0.8)

      # Specify the loss function:
      loss = tf.nn.softmax_cross_entropy_with_logits(
          labels=one_hot_labels, logits=logits)

      dropout_rates = utils.parse_dropout_rate_list(FLAGS.example_dropout_rates)
      example_dropout_rates = tf.convert_to_tensor(
          dropout_rates, np.float32, name='example_dropout_rates')

      # Constant schedule with 100 entries, matching the `* 100` scaling of
      # epoch_step below.
      loss_p_percentile = tf.convert_to_tensor(
          np.array([FLAGS.loss_p_percentile] * 100),
          np.float32,
          name='loss_p_percentile')

      # Training progress as a percentage of max_step_run.
      epoch_step = tf.to_int32(
          tf.floor(tf.divide(tf_global_step, max_step_run) * 100))

      zero_labels = tf.zeros([tf.shape(loss)[0], 1], tf.float32)

      loss = tf.reshape(loss, [-1, 1])
      v = utils.mentornet(
          epoch_step,
          loss,
          zero_labels,
          loss_p_percentile,
          example_dropout_rates,
          burn_in_epoch=FLAGS.burn_in_epoch,
          fixed_epoch_after_burn_in=FLAGS.fixed_epoch_after_burn_in,
          loss_moving_average_decay=FLAGS.loss_moving_average_decay)
      # tf.stop_gradient returns a new tensor and does not modify `v` in place.
      # Rebind it so the student update treats the example weights as
      # constants.
      v = tf.stop_gradient(v)

      # log data utilization
      data_util = utils.summarize_data_utilization(v, tf_global_step,
                                                   FLAGS.batch_size)

      weighted_loss_vector = tf.multiply(loss, v)
      weighted_loss = tf.reduce_mean(weighted_loss_vector)

      slim.summaries.add_scalar_summary(
          tf.reduce_mean(loss), 'mentornet/orig_loss')
      slim.summaries.add_scalar_summary(weighted_loss,
                                        'mentornet/weighted_loss')

      # No separate decay term is added to the optimized loss here (it is a
      # constant 0).
      # NOTE(review): the weight_decay passed to cifarnet_arg_scope is presumably
      # registered as regularization losses, which create_train_op below does
      # not add to weighted_total_loss. Confirm whether this is intended.
      weighed_decay_loss = 0
      weighted_total_loss = weighted_loss + weighed_decay_loss

      slim.summaries.add_scalar_summary(weighted_total_loss,
                                        'mentornet/total_loss')
      slim.summaries.add_scalar_summary(weighted_total_loss, 'total_loss')
      tf.add_to_collection('total_loss', weighted_total_loss)

      decay_steps = int(
          num_samples_per_epoch / FLAGS.batch_size * FLAGS.num_epochs_per_decay)

      lr = tf.train.exponential_decay(
          FLAGS.learning_rate,
          tf_global_step,
          decay_steps,
          FLAGS.learning_rate_decay_factor,
          staircase=True)
      slim.summaries.add_scalar_summary(lr, 'learning_rate', print_summary=True)

      with tf.control_dependencies([weighted_total_loss, data_util]):
        # Set up training. MentorNet variables are excluded so that only the
        # student network is updated.
        trainable_variables = tf.trainable_variables()
        trainable_variables = tf.contrib.framework.filter_variables(
            trainable_variables, exclude_patterns=['mentornet'])

        # Specify the optimization scheme:
        optimizer = tf.train.GradientDescentOptimizer(lr)
        train_op = slim.learning.create_train_op(
            weighted_total_loss,
            optimizer,
            variables_to_train=trainable_variables)

      # Restore setup. ckpt_model stays None when no pretrained MentorNet
      # directory is configured; it is logged unconditionally below.
      ckpt_model = None
      if FLAGS.trained_mentornet_dir is not None:
        ckpt_model = FLAGS.trained_mentornet_dir
        if os.path.isdir(FLAGS.trained_mentornet_dir):
          ckpt_model = tf.train.latest_checkpoint(ckpt_model)

        # fix the mentornet parameters
        variables_to_restore = slim.get_variables_to_restore(
            # TODO(lujiang): mentornet_inputs or mentor_inputs?
            include=['mentornet', 'mentornet_inputs'])
        iassign_op1, ifeed_dict1 = tf.contrib.framework.assign_from_checkpoint(
            ckpt_model, variables_to_restore)

        # Create an initial assignment function.
        def init_assign_fn(sess):
          tf.logging.info('Restore using customer initializer %s', '.' * 10)
          sess.run(iassign_op1, ifeed_dict1)
      else:
        init_assign_fn = None

      tf.logging.info('-' * 20 + 'MentorNet' + '-' * 20)
      tf.logging.info('loaded pretrained mentornet from %s', ckpt_model)
      tf.logging.info('loss_p_percentile=%3f', FLAGS.loss_p_percentile)
      tf.logging.info('burn_in_epoch=%d', FLAGS.burn_in_epoch)
      tf.logging.info('fixed_epoch_after_burn_in=%s',
                      FLAGS.fixed_epoch_after_burn_in)
      tf.logging.info('loss_moving_average_decay=%3f',
                      FLAGS.loss_moving_average_decay)
      tf.logging.info('example_dropout_rates %s',
                      ','.join(str(t) for t in dropout_rates))
      tf.logging.info('-' * 20)

      saver = tf.train.Saver(max_to_keep=10, keep_checkpoint_every_n_hours=5)

      # Run training.
      slim.learning.train(
          train_op=train_op,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          saver=saver,
          session_config=config,
          number_of_steps=max_step_run,
          init_fn=init_assign_fn,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
def train_resnet_mentormix(max_step_run):
  """Trains a student ResNet with MentorMix.

  A first forward pass of the student (variable scope 'ResNet32') produces
  per-example losses. MentorNet turns these into example weights `v`.
  `utils.mentor_mix_up` then uses `v` to build a mixed batch. A second forward
  pass on the mixed batch reuses the same 'ResNet32' variables and provides the
  loss that is optimized. If `FLAGS.second_reweight` is set, the mixed losses
  are re-weighted by a second MentorNet call. Variables matching 'mentornet'
  are excluded from training. If `FLAGS.trained_mentornet_dir` is set, they
  are restored from that checkpoint (or from the latest checkpoint in that
  directory) when the session starts.

  Args:
    max_step_run: The maximum number of gradient steps.
  """
  if not os.path.exists(FLAGS.train_log_dir):
    os.makedirs(FLAGS.train_log_dir)
  g = tf.Graph()

  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      tf_global_step = tf.train.get_or_create_global_step()

      (images, one_hot_labels, num_samples_per_epoch,
       num_of_classes) = cifar_data_provider.provide_resnet_data(
           FLAGS.dataset_name,
           'train',
           FLAGS.batch_size,
           dataset_dir=FLAGS.data_dir)

      hps = resnet_model.HParams(
          batch_size=FLAGS.batch_size,
          num_classes=num_of_classes,
          min_lrn_rate=0.0001,
          lrn_rate=FLAGS.learning_rate,
          num_residual_units=5,
          use_bottleneck=False,
          weight_decay_rate=0.0002,
          relu_leakiness=0.1,
          optimizer='mom')

      images.set_shape([FLAGS.batch_size, 32, 32, 3])

      # Define the model:
      resnet = resnet_model.ResNet(hps, images, one_hot_labels, mode='train')
      with tf.variable_scope('ResNet32'):
        logits = resnet.build_model()

      # Specify the loss function:
      loss = tf.nn.softmax_cross_entropy_with_logits(
          labels=one_hot_labels, logits=logits)

      dropout_rates = utils.parse_dropout_rate_list(FLAGS.example_dropout_rates)
      example_dropout_rates = tf.convert_to_tensor(
          dropout_rates, np.float32, name='example_dropout_rates')

      # Constant schedule with 100 entries, matching the `* 100` scaling of
      # epoch_step below.
      loss_p_percentile = tf.convert_to_tensor(
          np.array([FLAGS.loss_p_percentile] * 100),
          np.float32,
          name='loss_p_percentile')

      loss = tf.reshape(loss, [-1, 1])

      # Training progress as a percentage of max_step_run.
      epoch_step = tf.to_int32(
          tf.floor(tf.divide(tf_global_step, max_step_run) * 100))

      zero_labels = tf.zeros([tf.shape(loss)[0], 1], tf.float32)

      mentornet_net_hparams = utils.get_mentornet_network_hyperparameter(
          FLAGS.trained_mentornet_dir)

      # In the simplest case, this function can be replaced with a thresholding
      # function. See loss_thresholding_function in utils.py.
      v = utils.mentornet(
          epoch_step,
          loss,
          zero_labels,
          loss_p_percentile,
          example_dropout_rates,
          burn_in_epoch=FLAGS.burn_in_epoch,
          mentornet_net_hparams=mentornet_net_hparams,
          avg_name='individual')
      v = tf.stop_gradient(v)

      # Perform MentorMix
      images_mix, labels_mix = utils.mentor_mix_up(
          images, one_hot_labels, v, FLAGS.mixup_alpha)
      resnet = resnet_model.ResNet(hps, images_mix, labels_mix, mode='train')
      # Second forward pass that shares the student's 'ResNet32' variables.
      with tf.variable_scope('ResNet32', reuse=True):
        logits_mix = resnet.build_model()

      loss = tf.nn.softmax_cross_entropy_with_logits(
          labels=labels_mix, logits=logits_mix)
      decay_loss = resnet.decay()

      # second weighting
      if FLAGS.second_reweight:
        loss = tf.reshape(loss, [-1, 1])
        v = utils.mentornet(
            epoch_step,
            loss,
            zero_labels,
            loss_p_percentile,
            example_dropout_rates,
            burn_in_epoch=FLAGS.burn_in_epoch,
            mentornet_net_hparams=mentornet_net_hparams,
            avg_name='mixed')
        v = tf.stop_gradient(v)
        weighted_loss_vector = tf.multiply(loss, v)
        loss = tf.reduce_mean(weighted_loss_vector)
        # reproduced with the following decay loss which should be 0.
        # NOTE(review): the nesting of these two lines under second_reweight was
        # inferred from the collapsed source; confirm against history.
        decay_loss = tf.losses.get_regularization_loss()
        decay_loss = decay_loss * (tf.reduce_sum(v) / FLAGS.batch_size)

      # Log data utilization
      data_util = utils.summarize_data_utilization(v, tf_global_step,
                                                   FLAGS.batch_size)

      loss = tf.reduce_mean(loss)
      # `loss` is already a scalar here.
      slim.summaries.add_scalar_summary(loss, 'mentormix/mix_loss')

      weighted_total_loss = loss + decay_loss

      slim.summaries.add_scalar_summary(weighted_total_loss, 'total_loss')
      tf.add_to_collection('total_loss', weighted_total_loss)

      # Set up the moving averages:
      moving_average_variables = tf.trainable_variables()
      moving_average_variables = tf.contrib.framework.filter_variables(
          moving_average_variables, exclude_patterns=['mentornet'])

      variable_averages = tf.train.ExponentialMovingAverage(
          0.9999, tf_global_step)
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                           variable_averages.apply(moving_average_variables))

      decay_steps = (
          FLAGS.num_epochs_per_decay * num_samples_per_epoch / FLAGS.batch_size)
      lr = tf.train.exponential_decay(
          FLAGS.learning_rate,
          tf_global_step,
          decay_steps,
          FLAGS.learning_rate_decay_factor,
          staircase=True)
      lr = tf.squeeze(lr)
      slim.summaries.add_scalar_summary(lr, 'learning_rate')

      # Specify the optimization scheme:
      with tf.control_dependencies([weighted_total_loss, data_util]):
        # Set up training. MentorNet variables are excluded so that only the
        # student network is updated.
        trainable_variables = tf.trainable_variables()
        trainable_variables = tf.contrib.framework.filter_variables(
            trainable_variables, exclude_patterns=['mentornet'])

        grads = tf.gradients(weighted_total_loss, trainable_variables)
        optimizer = tf.train.MomentumOptimizer(lr, momentum=0.9)

        apply_op = optimizer.apply_gradients(
            zip(grads, trainable_variables),
            global_step=tf_global_step,
            name='train_step')

        train_ops = ([apply_op] + resnet.extra_train_ops +
                     tf.get_collection(tf.GraphKeys.UPDATE_OPS))
        train_op = tf.group(*train_ops)

      # Parameter restore setup
      if FLAGS.trained_mentornet_dir is not None:
        ckpt_model = FLAGS.trained_mentornet_dir
        if os.path.isdir(FLAGS.trained_mentornet_dir):
          ckpt_model = tf.train.latest_checkpoint(ckpt_model)

        # Fix the mentornet parameters
        variables_to_restore = slim.get_variables_to_restore(
            include=['mentornet', 'mentornet_inputs'])
        iassign_op1, ifeed_dict1 = tf.contrib.framework.assign_from_checkpoint(
            ckpt_model, variables_to_restore)

        # Create an initial assignment function.
        def init_assign_fn(sess):
          tf.logging.info('Restore using customer initializer %s', '.' * 10)
          sess.run(iassign_op1, ifeed_dict1)
      else:
        init_assign_fn = None

      tf.logging.info('-' * 20 + 'MentorMix' + '-' * 20)
      tf.logging.info('loss_p_percentile=%3f', FLAGS.loss_p_percentile)
      # %s rather than %d so a fractional alpha is not truncated in the log.
      tf.logging.info('mixup_alpha=%s', FLAGS.mixup_alpha)
      tf.logging.info('-' * 20)

      saver = tf.train.Saver(max_to_keep=10, keep_checkpoint_every_n_hours=24)

      # Run training.
      slim.learning.train(
          train_op=train_op,
          train_step_fn=resnet_train_step,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          saver=saver,
          number_of_steps=max_step_run,
          init_fn=init_assign_fn,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)