Example 1
def train_resnet_mentornet(max_step_run):
  """Trains the mentornet with the student resnet model.

  Args:
    max_step_run: The maximum number of gradient steps.
  """
  if not os.path.exists(FLAGS.train_log_dir):
    os.makedirs(FLAGS.train_log_dir)
  g = tf.Graph()

  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      tf_global_step = tf.train.get_or_create_global_step()

      # pylint: disable=line-too-long
      images, one_hot_labels, clean_images, clean_one_hot_labels, num_samples_per_epoch, num_of_classes = cifar_data_provider.my_provide_resnet_data(
          FLAGS.dataset_name,
          'train',
          FLAGS.batch_size,
          dataset_dir=FLAGS.data_dir)

      hps = resnet_model.HParams(
          batch_size=FLAGS.batch_size,
          num_classes=num_of_classes,
          min_lrn_rate=0.0001,
          lrn_rate=FLAGS.learning_rate,
          num_residual_units=9,
          use_bottleneck=False,
          weight_decay_rate=0.0002,
          relu_leakiness=0.1,
          optimizer='mom')

      images.set_shape([FLAGS.batch_size, 32, 32, 3])
      tf.logging.info('num_of_example=%s', num_samples_per_epoch)

      # Define the model:
      resnet = resnet_model.ResNet(hps, images, one_hot_labels, mode='train')
      logits = resnet.build_model()

      # Specify the loss function:
      loss = tf.nn.softmax_cross_entropy_with_logits(
          labels=one_hot_labels, logits=logits)

      dropout_rates = utils.parse_dropout_rate_list(FLAGS.example_dropout_rates)
      example_dropout_rates = tf.convert_to_tensor(
          dropout_rates, np.float32, name='example_dropout_rates')

      loss_p_percentile = tf.convert_to_tensor(
          np.array([FLAGS.loss_p_percentile] * 100),
          np.float32,
          name='loss_p_percentile')

      loss = tf.reshape(loss, [-1, 1])

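      # Map training progress (global_step / max_step_run) to an integer
      # percentage; MentorNet consumes this as its epoch feature.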
      epoch_step = tf.to_int32(
          tf.floor(tf.divide(tf_global_step, max_step_run) * 100))

      zero_labels = tf.zeros([tf.shape(loss)[0], 1], tf.float32)

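      # MentorNet outputs one weight per example, computed from the example
      # loss, the target loss percentile, the loss moving average, and the
      # training progress.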
      v = utils.mentornet(
          epoch_step,
          loss,
          zero_labels,
          loss_p_percentile,
          example_dropout_rates,
          burn_in_epoch=FLAGS.burn_in_epoch,
          fixed_epoch_after_burn_in=FLAGS.fixed_epoch_after_burn_in,
          loss_moving_average_decay=FLAGS.loss_moving_average_decay)

      v = tf.stop_gradient(v)

      # Split v into clean data & noise data part
      is_clean = tf.reshape(
          tf.reduce_all(tf.equal(one_hot_labels, clean_one_hot_labels), axis=1),
          [-1, 1])
      clean_v = tf.boolean_mask(v, is_clean)
      noise_v = tf.boolean_mask(v, ~is_clean)
      tf.add_to_collection('v', v)
      tf.add_to_collection('v', clean_v)
      tf.add_to_collection('v', noise_v)

      slim.summaries.add_histogram_summary(clean_v, 'clean_v')
      slim.summaries.add_histogram_summary(noise_v, 'noisy_v')

      # Log data utilization
      data_util = utils.summarize_data_utilization(v, tf_global_step,
                                                   FLAGS.batch_size)
      decay_loss = resnet.decay()
      weighted_loss_vector = tf.multiply(loss, v)

      weighted_loss = tf.reduce_mean(weighted_loss_vector)

      slim.summaries.add_scalar_summary(
          tf.reduce_mean(loss), 'mentornet/orig_loss')
      slim.summaries.add_scalar_summary(weighted_loss,
                                        'mentornet/weighted_loss')

      # Normalize the decay loss based on v (see the arithmetic sketch after
      # this example).
      weighted_decay_loss = decay_loss * (tf.reduce_sum(v) / FLAGS.batch_size)

      weighted_total_loss = weighted_loss + weighted_decay_loss

      slim.summaries.add_scalar_summary(weighted_total_loss,
                                        'mentornet/total_loss')

      slim.summaries.add_scalar_summary(weighted_total_loss, 'total_loss')
      tf.add_to_collection('total_loss', weighted_total_loss)

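      # Piecewise-constant schedule: cut the learning rate by 10x at each
      # boundary step.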
      boundaries = [19531, 25000, 30000]
      values = [FLAGS.learning_rate * t for t in [1, 0.1, 0.01, 0.001]]
      lr = tf.train.piecewise_constant(tf_global_step, boundaries, values)
      slim.summaries.add_scalar_summary(lr, 'learning_rate')

      # Specify the optimization scheme:
      with tf.control_dependencies([weighted_total_loss, data_util]):
        # Set up training.
        trainable_variables = tf.trainable_variables()
        trainable_variables = tf.contrib.framework.filter_variables(
            trainable_variables, exclude_patterns=['mentornet'])

        grads = tf.gradients(weighted_total_loss, trainable_variables)
        optimizer = tf.train.MomentumOptimizer(lr, momentum=0.9)

        apply_op = optimizer.apply_gradients(
            zip(grads, trainable_variables),
            global_step=tf_global_step,
            name='train_step')

        train_ops = [apply_op] + resnet.extra_train_ops
        train_op = tf.group(*train_ops)

      # Parameter restore setup
      if FLAGS.trained_mentornet_dir is not None:
        ckpt_model = FLAGS.trained_mentornet_dir
        if os.path.isdir(FLAGS.trained_mentornet_dir):
          ckpt_model = tf.train.latest_checkpoint(ckpt_model)

        # Fix the mentornet parameters
        variables_to_restore = slim.get_variables_to_restore(
            # TODO(lujiang): mentornet_inputs or mentor_inputs?
            include=['mentornet', 'mentornet_inputs'])
        iassign_op1, ifeed_dict1 = tf.contrib.framework.assign_from_checkpoint(
            ckpt_model, variables_to_restore)

        # Create an initial assignment function.
        def init_assign_fn(sess):
          tf.logging.info('Restore using custom initializer %s', '.' * 10)
          sess.run(iassign_op1, ifeed_dict1)
      else:
        ckpt_model = None
        init_assign_fn = None

      tf.logging.info('-' * 20 + 'MentorNet' + '-' * 20)
      tf.logging.info('loaded pretrained mentornet from %s', ckpt_model)
      tf.logging.info('loss_p_percentile=%.3f', FLAGS.loss_p_percentile)
      tf.logging.info('burn_in_epoch=%d', FLAGS.burn_in_epoch)
      tf.logging.info('fixed_epoch_after_burn_in=%s',
                      FLAGS.fixed_epoch_after_burn_in)
      tf.logging.info('loss_moving_average_decay=%.3f',
                      FLAGS.loss_moving_average_decay)
      tf.logging.info('example_dropout_rates %s', ','.join(
          str(t) for t in dropout_rates))
      tf.logging.info('-' * 20)

      saver = tf.train.Saver(max_to_keep=10, keep_checkpoint_every_n_hours=24)

      # Run training.
      slim.learning.train(
          train_op=train_op,
          train_step_fn=resnet_train_step,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          saver=saver,
          number_of_steps=max_step_run,
          init_fn=init_assign_fn,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
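
The weighted objective above is plain per-example reweighting: the cross-entropy vector is multiplied by the MentorNet weights v and averaged, and the weight-decay term is rescaled by the mean weight. A minimal NumPy sketch of that arithmetic (illustrative only, not part of the training script):

import numpy as np

def weighted_total_loss(loss_vec, v, decay_loss):
  """loss_vec, v: arrays of shape [batch_size]; decay_loss: scalar."""
  weighted_loss = np.mean(loss_vec * v)               # mean of v-weighted losses
  weighted_decay = decay_loss * np.sum(v) / len(v)    # decay normalized by v
  return weighted_loss + weighted_decay

# Examples with a large loss (likely noisy labels) and weight 0 drop out entirely.
print(weighted_total_loss(np.array([0.2, 0.3, 5.0, 4.0]),
                          np.array([1.0, 1.0, 0.0, 0.0]),
                          decay_loss=0.01))  # -> 0.13
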
Example 2
def train_inception_mentornet(max_step_run):
  """Trains the mentornet with the student inception model.

  Args:
    max_step_run: The maximum number of gradient steps.
  """
  if not os.path.exists(FLAGS.train_log_dir):
    os.makedirs(FLAGS.train_log_dir)
  g = tf.Graph()

  with g.as_default():
    # If ps_tasks is zero, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      config = tf.ConfigProto()
      # limit gpu memory to run train and eval on the same gpu
      config.gpu_options.per_process_gpu_memory_fraction = 0.8

      tf_global_step = tf.train.get_or_create_global_step()

      # pylint: disable=line-too-long
      images, one_hot_labels, num_samples_per_epoch, num_of_classes = cifar_data_provider.provide_cifarnet_data(
          FLAGS.dataset_name,
          'train',
          FLAGS.batch_size,
          dataset_dir=FLAGS.data_dir)

      images.set_shape([FLAGS.batch_size, 32, 32, 3])
      tf.logging.info('num_of_example=%s', num_samples_per_epoch)

      # Define the model:
      with slim.arg_scope(
          inception_model.cifarnet_arg_scope(weight_decay=0.004)):
        logits, _ = inception_model.cifarnet(
            images, num_of_classes, is_training=True, dropout_keep_prob=0.8)

      # Specify the loss function:
      loss = tf.nn.softmax_cross_entropy_with_logits(
          labels=one_hot_labels, logits=logits)

      dropout_rates = utils.parse_dropout_rate_list(FLAGS.example_dropout_rates)
      example_dropout_rates = tf.convert_to_tensor(
          dropout_rates, np.float32, name='example_dropout_rates')

      loss_p_percentile = tf.convert_to_tensor(
          np.array([FLAGS.loss_p_percentile] * 100),
          np.float32,
          name='loss_p_percentile')

      epoch_step = tf.to_int32(
          tf.floor(tf.divide(tf_global_step, max_step_run) * 100))

      zero_labels = tf.zeros([tf.shape(loss)[0], 1], tf.float32)

      loss = tf.reshape(loss, [-1, 1])

      v = utils.mentornet(
          epoch_step,
          loss,
          zero_labels,
          loss_p_percentile,
          example_dropout_rates,
          burn_in_epoch=FLAGS.burn_in_epoch,
          fixed_epoch_after_burn_in=FLAGS.fixed_epoch_after_burn_in,
          loss_moving_average_decay=FLAGS.loss_moving_average_decay)

      v = tf.stop_gradient(v)

      # log data utilization
      data_util = utils.summarize_data_utilization(v, tf_global_step,
                                                   FLAGS.batch_size)

      weighted_loss_vector = tf.multiply(loss, v)

      weighted_loss = tf.reduce_mean(weighted_loss_vector)

      slim.summaries.add_scalar_summary(
          tf.reduce_mean(loss), 'mentornet/orig_loss')
      slim.summaries.add_scalar_summary(weighted_loss,
                                        'mentornet/weighted_loss')

      # The decay term is kept at zero for the inception student.
      weighted_decay_loss = 0
      weighted_total_loss = weighted_loss + weighted_decay_loss

      slim.summaries.add_scalar_summary(weighted_total_loss,
                                        'mentornet/total_loss')

      slim.summaries.add_scalar_summary(weighted_total_loss, 'total_loss')
      tf.add_to_collection('total_loss', weighted_total_loss)

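      # Staircase exponential decay of the learning rate (a small worked
      # sketch of this schedule follows this example).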
      decay_steps = int(
          num_samples_per_epoch / FLAGS.batch_size * FLAGS.num_epochs_per_decay)

      lr = tf.train.exponential_decay(
          FLAGS.learning_rate,
          tf_global_step,
          decay_steps,
          FLAGS.learning_rate_decay_factor,
          staircase=True)
      slim.summaries.add_scalar_summary(lr, 'learning_rate', print_summary=True)

      with tf.control_dependencies([weighted_total_loss, data_util]):
        # Set up training.
        trainable_variables = tf.trainable_variables()
        trainable_variables = tf.contrib.framework.filter_variables(
            trainable_variables, exclude_patterns=['mentornet'])

        # Specify the optimization scheme:
        optimizer = tf.train.GradientDescentOptimizer(lr)
        train_op = slim.learning.create_train_op(
            weighted_total_loss,
            optimizer,
            variables_to_train=trainable_variables)

      # Restore setup
      if FLAGS.trained_mentornet_dir is not None:
        ckpt_model = FLAGS.trained_mentornet_dir
        if os.path.isdir(FLAGS.trained_mentornet_dir):
          ckpt_model = tf.train.latest_checkpoint(ckpt_model)

        # fix the mentornet parameters
        variables_to_restore = slim.get_variables_to_restore(
            # TODO(lujiang): mentornet_inputs or mentor_inputs?
            include=['mentornet', 'mentornet_inputs'])
        iassign_op1, ifeed_dict1 = tf.contrib.framework.assign_from_checkpoint(
            ckpt_model, variables_to_restore)

        # Create an initial assignment function.
        def init_assign_fn(sess):
          tf.logging.info('Restore using custom initializer %s', '.' * 10)
          sess.run(iassign_op1, ifeed_dict1)
      else:
        ckpt_model = None
        init_assign_fn = None

      tf.logging.info('-' * 20 + 'MentorNet' + '-' * 20)
      tf.logging.info('loaded pretrained mentornet from %s', ckpt_model)
      tf.logging.info('loss_p_percentile=%.3f', FLAGS.loss_p_percentile)
      tf.logging.info('burn_in_epoch=%d', FLAGS.burn_in_epoch)
      tf.logging.info('fixed_epoch_after_burn_in=%s',
                      FLAGS.fixed_epoch_after_burn_in)
      tf.logging.info('loss_moving_average_decay=%.3f',
                      FLAGS.loss_moving_average_decay)
      tf.logging.info('example_dropout_rates %s', ','.join(
          str(t) for t in dropout_rates))
      tf.logging.info('-' * 20)

      saver = tf.train.Saver(max_to_keep=10, keep_checkpoint_every_n_hours=5)

      # Run training.
      slim.learning.train(
          train_op=train_op,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          saver=saver,
          session_config=config,
          number_of_steps=max_step_run,
          init_fn=init_assign_fn,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
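
With staircase=True, tf.train.exponential_decay keeps the learning rate piecewise constant and multiplies it by the decay factor once every decay_steps steps. A small self-contained sketch of that schedule; the dataset size, batch size, and decay settings below are made-up placeholders, not the repository's flag defaults:

def staircase_lr(base_lr, global_step, decay_steps, decay_factor):
  """Mirrors tf.train.exponential_decay(..., staircase=True)."""
  return base_lr * decay_factor ** (global_step // decay_steps)

# Assumed example: 50,000 training images, batch size 128, decay by 0.1 every
# 350 epochs (placeholder values).
decay_steps = int(50000 / 128 * 350)   # about 136,718 steps between decays
for step in (0, 140000, 280000):
  print(step, staircase_lr(0.1, step, decay_steps, 0.1))

Example 3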
def train_resnet_mentormix(max_step_run):
    """Trains the mentornet with the student resnet model.

  Args:
    max_step_run: The maximum number of gradient steps.
  """
    if not os.path.exists(FLAGS.train_log_dir):
        os.makedirs(FLAGS.train_log_dir)
    g = tf.Graph()

    with g.as_default():
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            tf_global_step = tf.train.get_or_create_global_step()

            (images, one_hot_labels, num_samples_per_epoch,
             num_of_classes) = cifar_data_provider.provide_resnet_data(
                 FLAGS.dataset_name,
                 'train',
                 FLAGS.batch_size,
                 dataset_dir=FLAGS.data_dir)

            hps = resnet_model.HParams(batch_size=FLAGS.batch_size,
                                       num_classes=num_of_classes,
                                       min_lrn_rate=0.0001,
                                       lrn_rate=FLAGS.learning_rate,
                                       num_residual_units=5,
                                       use_bottleneck=False,
                                       weight_decay_rate=0.0002,
                                       relu_leakiness=0.1,
                                       optimizer='mom')

            images.set_shape([FLAGS.batch_size, 32, 32, 3])

            # Define the model:
            resnet = resnet_model.ResNet(hps,
                                         images,
                                         one_hot_labels,
                                         mode='train')
            with tf.variable_scope('ResNet32'):
                logits = resnet.build_model()

            # Specify the loss function:
            loss = tf.nn.softmax_cross_entropy_with_logits(
                labels=one_hot_labels, logits=logits)

            dropout_rates = utils.parse_dropout_rate_list(
                FLAGS.example_dropout_rates)
            example_dropout_rates = tf.convert_to_tensor(
                dropout_rates, np.float32, name='example_dropout_rates')

            loss_p_percentile = tf.convert_to_tensor(
                np.array([FLAGS.loss_p_percentile] * 100),
                np.float32,
                name='loss_p_percentile')

            loss = tf.reshape(loss, [-1, 1])

            epoch_step = tf.to_int32(
                tf.floor(tf.divide(tf_global_step, max_step_run) * 100))

            zero_labels = tf.zeros([tf.shape(loss)[0], 1], tf.float32)

            mentornet_net_hparams = utils.get_mentornet_network_hyperparameter(
                FLAGS.trained_mentornet_dir)

            # In the simplest case, this function can be replaced with a thresholding
            # function. See loss_thresholding_function in utils.py.
            v = utils.mentornet(epoch_step,
                                loss,
                                zero_labels,
                                loss_p_percentile,
                                example_dropout_rates,
                                burn_in_epoch=FLAGS.burn_in_epoch,
                                mentornet_net_hparams=mentornet_net_hparams,
                                avg_name='individual')

            v = tf.stop_gradient(v)
            loss = tf.stop_gradient(tf.identity(loss))
            logits = tf.stop_gradient(tf.identity(logits))

            # Perform MentorMix (a weighted-mixup sketch follows this example).
            images_mix, labels_mix = utils.mentor_mix_up(
                images, one_hot_labels, v, FLAGS.mixup_alpha)
            resnet = resnet_model.ResNet(hps,
                                         images_mix,
                                         labels_mix,
                                         mode='train')
            with tf.variable_scope('ResNet32', reuse=True):
                logits_mix = resnet.build_model()

            loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels_mix,
                                                           logits=logits_mix)
            decay_loss = resnet.decay()

            # second weighting
            if FLAGS.second_reweight:
                loss = tf.reshape(loss, [-1, 1])
                v = utils.mentornet(
                    epoch_step,
                    loss,
                    zero_labels,
                    loss_p_percentile,
                    example_dropout_rates,
                    burn_in_epoch=FLAGS.burn_in_epoch,
                    mentornet_net_hparams=mentornet_net_hparams,
                    avg_name='mixed')
                v = tf.stop_gradient(v)
                weighted_loss_vector = tf.multiply(loss, v)
                loss = tf.reduce_mean(weighted_loss_vector)
                # Reproduced with the following decay loss, which should be 0.
                decay_loss = tf.losses.get_regularization_loss()
                decay_loss = decay_loss * (tf.reduce_sum(v) / FLAGS.batch_size)

            # Log data utilization
            data_util = utils.summarize_data_utilization(
                v, tf_global_step, FLAGS.batch_size)

            loss = tf.reduce_mean(loss)
            slim.summaries.add_scalar_summary(loss, 'mentormix/mix_loss')

            weighted_total_loss = loss + decay_loss

            slim.summaries.add_scalar_summary(weighted_total_loss,
                                              'total_loss')
            tf.add_to_collection('total_loss', weighted_total_loss)

            # Set up the moving averages:
            moving_average_variables = tf.trainable_variables()
            moving_average_variables = tf.contrib.framework.filter_variables(
                moving_average_variables, exclude_patterns=['mentornet'])

            variable_averages = tf.train.ExponentialMovingAverage(
                0.9999, tf_global_step)
            tf.add_to_collection(
                tf.GraphKeys.UPDATE_OPS,
                variable_averages.apply(moving_average_variables))

            decay_steps = (FLAGS.num_epochs_per_decay * num_samples_per_epoch /
                           FLAGS.batch_size)
            lr = tf.train.exponential_decay(FLAGS.learning_rate,
                                            tf_global_step,
                                            decay_steps,
                                            FLAGS.learning_rate_decay_factor,
                                            staircase=True)
            lr = tf.squeeze(lr)
            slim.summaries.add_scalar_summary(lr, 'learning_rate')

            # Specify the optimization scheme:
            with tf.control_dependencies([weighted_total_loss, data_util]):
                # Set up training.
                trainable_variables = tf.trainable_variables()
                trainable_variables = tf.contrib.framework.filter_variables(
                    trainable_variables, exclude_patterns=['mentornet'])

                grads = tf.gradients(weighted_total_loss, trainable_variables)
                optimizer = tf.train.MomentumOptimizer(lr, momentum=0.9)

                apply_op = optimizer.apply_gradients(
                    zip(grads, trainable_variables),
                    global_step=tf_global_step,
                    name='train_step')

                train_ops = ([apply_op] + resnet.extra_train_ops +
                             tf.get_collection(tf.GraphKeys.UPDATE_OPS))
                train_op = tf.group(*train_ops)

            # Parameter restore setup
            if FLAGS.trained_mentornet_dir is not None:
                ckpt_model = FLAGS.trained_mentornet_dir
                if os.path.isdir(FLAGS.trained_mentornet_dir):
                    ckpt_model = tf.train.latest_checkpoint(ckpt_model)

                # Fix the mentornet parameters
                variables_to_restore = slim.get_variables_to_restore(
                    include=['mentornet', 'mentornet_inputs'])
                iassign_op1, ifeed_dict1 = tf.contrib.framework.assign_from_checkpoint(
                    ckpt_model, variables_to_restore)

                # Create an initial assignment function.
                def init_assign_fn(sess):
                    tf.logging.info('Restore using custom initializer %s',
                                    '.' * 10)
                    sess.run(iassign_op1, ifeed_dict1)
            else:
                init_assign_fn = None

            tf.logging.info('-' * 20 + 'MentorMix' + '-' * 20)
            tf.logging.info('loss_p_percentile=%.3f', FLAGS.loss_p_percentile)
            tf.logging.info('mixup_alpha=%s', FLAGS.mixup_alpha)
            tf.logging.info('-' * 20)

            saver = tf.train.Saver(max_to_keep=10,
                                   keep_checkpoint_every_n_hours=24)

            # Run training.
            slim.learning.train(train_op=train_op,
                                train_step_fn=resnet_train_step,
                                logdir=FLAGS.train_log_dir,
                                master=FLAGS.master,
                                is_chief=FLAGS.task == 0,
                                saver=saver,
                                number_of_steps=max_step_run,
                                init_fn=init_assign_fn,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
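
utils.mentor_mix_up is the MentorMix step: for each example, a mixing partner is drawn from the mini-batch with probability proportional to the MentorNet weights v, and the pair is convexly combined with a Beta(alpha, alpha) coefficient. A rough NumPy sketch of that weighted-mixup idea; this is a simplification with assumed shapes, and the repository's utils.mentor_mix_up remains the authoritative version:

import numpy as np

def weighted_mixup(images, one_hot_labels, v, alpha, rng=np.random):
  """Sketch: partners sampled with probability ~ v, lambda ~ Beta(alpha, alpha)."""
  batch = images.shape[0]
  p = np.asarray(v, dtype=np.float64)
  p = p / p.sum() if p.sum() > 0 else np.full(batch, 1.0 / batch)
  idx = rng.choice(batch, size=batch, p=p)      # weighted partner indices
  lam = rng.beta(alpha, alpha, size=batch)
  lam_img = lam.reshape(-1, 1, 1, 1)
  lam_lab = lam.reshape(-1, 1)
  images_mix = lam_img * images + (1.0 - lam_img) * images[idx]
  labels_mix = lam_lab * one_hot_labels + (1.0 - lam_lab) * one_hot_labels[idx]
  return images_mix, labels_mix

# Toy usage with random data; shapes match the CIFAR pipeline above.
images = np.random.rand(8, 32, 32, 3)
labels = np.eye(10)[np.random.randint(0, 10, size=8)]
v = np.random.rand(8)
mix_images, mix_labels = weighted_mixup(images, labels, v, alpha=0.4)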