Code example #1
File: run_hmax.py  Project: serre-lab/prj_selfsup
def module_fn(is_training):
     """Function that builds TF-Hub module."""
     endpoints = {}
     inputs = tf.placeholder(tf.float32, [None, None, None, 3])
     with tf.variable_scope('base_model', reuse=tf.AUTO_REUSE):
         hiddens = model(inputs, is_training)
         for v in [
                 'initial_conv', 'initial_max_pool', 'block_group1',
                 'block_group2', 'block_group3', 'block_group4',
                 'final_avg_pool'
         ]:
             endpoints[v] = tf.get_default_graph().get_tensor_by_name(
                 'base_model/{}:0'.format(v))
     if FLAGS.train_mode == 'pretrain':
         hiddens_proj = model_util.projection_head(hiddens, is_training)
         endpoints['proj_head_input'] = hiddens
         endpoints['proj_head_output'] = hiddens_proj
     else:
         logits_sup = model_util.supervised_head(hiddens, num_classes,
                                                 is_training)
         endpoints['logits_sup'] = logits_sup
     hub.add_signature(inputs=dict(images=inputs),
                       outputs=dict(endpoints, default=hiddens))
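
For reference, a module_fn like this is normally wrapped into a TF-Hub module spec and exported next to the checkpoint. Below is a minimal sketch using the standard tensorflow_hub TF1 API; the tags_and_args pairs, export path, and checkpoint_path variable are illustrative assumptions rather than code from the project:

import tensorflow_hub as hub

tags_and_args = [
    # Default graph: batch norm / dropout built in inference mode.
    ([], {'is_training': False}),
    # 'train'-tagged graph: batch norm / dropout built in training mode.
    (['train'], {'is_training': True}),
]
spec = hub.create_module_spec(module_fn, tags_and_args)
spec.export('/tmp/hub_module', checkpoint_path=checkpoint_path)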
Code example #2
File: model_hmax.py  Project: serre-lab/prj_selfsup
  def model_fn(features, labels, mode, params=None):
    """Build model and optimizer."""
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Check training mode.
    if FLAGS.train_mode == 'pretrain':
      num_transforms = 1
      if FLAGS.use_td_loss:
        num_transforms += 1
      if FLAGS.use_bu_loss:
        num_transforms += 1

      if FLAGS.fine_tune_after_block > -1:
        raise ValueError('Does not support layer freezing during pretraining, '
                         'should set fine_tune_after_block<=-1 for safety.')
    elif FLAGS.train_mode == 'finetune':
      num_transforms = 1
    else:
      raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))
    
    # Split channels, and optionally apply extra batched augmentation.
    features_list = tf.split(
        features, num_or_size_splits=num_transforms, axis=-1)
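    # The input pipeline stacks one copy of the image per transform along the
    # channel axis, so this split recovers one (bsz, h, w, c) tensor per view.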
    
    if FLAGS.use_td_loss:
      target_images = features_list[-1]
      features_list = features_list[:-1]
      # transforms
      thetas_list = tf.split(
        labels['thetas'], num_or_size_splits=num_transforms, axis=-1)
      if FLAGS.train_mode == 'pretrain':  # Fix for fine-tuning/eval
        thetas = tf.concat(thetas_list[:-1], 0)
    else:
      target_images = features_list
    

    if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
      features_list, sigmas = data_util.batch_random_blur(
          features_list, FLAGS.image_size, FLAGS.image_size)
      if FLAGS.use_td_loss: 
        sigmas = tf.concat(sigmas, 0)
        thetas = tf.concat([thetas, sigmas[:,None]], 1) 
    else:
      if FLAGS.use_td_loss:
        sigmas = tf.zeros_like(thetas[:,0])
        thetas = tf.concat([thetas, sigmas[:,None]], 1) 
        # thetas = tf.zeros([target_images.get_shape().as_list()[0], 11]) 

    features = tf.concat(features_list, 0)  # (num_transforms * bsz, h, w, c)
    
    # Base network forward pass.
    with tf.variable_scope('base_model'):
      if FLAGS.train_mode == 'finetune':
        if FLAGS.fine_tune_after_block >= 4:
          # Finetuning just the supervised (linear) head will not update BN stats.
          model_train_mode = False
        else:
          # Finetuning earlier blocks still updates BN stats.
          model_train_mode = is_training
      else:
        if FLAGS.use_td_loss:
          viz_features = features
          features = (features, thetas)
        else:
          viz_features = features

        # Pretraining will update BN stats.
        model_train_mode = is_training

      outputs = model(features, is_training=model_train_mode)
      
    # Add head and loss.
    if FLAGS.train_mode == 'pretrain':
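      # The TPU context (when present) is forwarded to the contrastive losses,
      # e.g. so negatives can be pooled across replicas.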
      tpu_context = params['context'] if 'context' in params else None
      
      if FLAGS.use_td_loss and isinstance(outputs, tuple):
        hiddens, reconstruction, metric_hidden_r, metric_hidden_t = outputs
      else:
        hiddens = outputs
        reconstruction = features

      if FLAGS.use_td_loss:
        with tf.name_scope('td_loss'):
          if FLAGS.td_loss=='attractive':
            td_loss, logits_td_con, labels_td_con = obj_lib.td_attractive_loss(
              reconstruction=metric_hidden_r,
              target=metric_hidden_t,
              temperature=FLAGS.temperature,
              tpu_context=tpu_context if is_training else None)
            logits_td_con = tf.zeros([params['batch_size'], params['batch_size']])
            labels_td_con = tf.zeros([params['batch_size'], params['batch_size']])
          elif FLAGS.td_loss=='attractive_repulsive':
            td_loss, logits_td_con, labels_td_con = obj_lib.td_attractive_repulsive_loss(
              reconstruction=metric_hidden_r,
              target=metric_hidden_t,
              temperature=FLAGS.temperature,
              tpu_context=tpu_context if is_training else None)
          else:
            raise NotImplementedError("Error at TD loss {}".format(FLAGS.td_loss))
      else:
        # No TD loss
        logits_td_con = tf.zeros([params['batch_size'], params['batch_size']])
        labels_td_con = tf.zeros([params['batch_size'], params['batch_size']])
        td_loss = 0.
      hiddens_proj = model_util.projection_head(hiddens, is_training)

      if FLAGS.use_bu_loss:
        with tf.name_scope('bu_loss'):
          if FLAGS.bu_loss=='attractive':
            bu_loss, logits_bu_con, labels_bu_con = obj_lib.attractive_loss(
              hiddens_proj,
              temperature=FLAGS.temperature,
              hidden_norm=FLAGS.hidden_norm)
            logits_bu_con = tf.zeros([params['batch_size'], params['batch_size']])
            labels_bu_con = tf.zeros([params['batch_size'], params['batch_size']])

          elif FLAGS.bu_loss=='attractive_repulsive':
            bu_loss, logits_bu_con, labels_bu_con = obj_lib.attractive_repulsive_loss(
              hiddens_proj,
              hidden_norm=FLAGS.hidden_norm,
              temperature=FLAGS.temperature,
              tpu_context=tpu_context if is_training else None)  
          else:
            raise NotImplementedError('Unknown loss')
      else:
        # No BU loss
        logits_bu_con = tf.zeros([params['batch_size'], params['batch_size']])
        labels_bu_con = tf.zeros([params['batch_size'], params['batch_size']])
        bu_loss = 0.
      logits_sup = tf.zeros([params['batch_size'], num_classes])

    else:
      # contrast_loss = tf.zeros([])
      td_loss = tf.zeros([])
      bu_loss = tf.zeros([])
      logits_td_con = tf.zeros([params['batch_size'], 10])
      labels_td_con = tf.zeros([params['batch_size'], 10])
      logits_bu_con = tf.zeros([params['batch_size'], 10])
      labels_bu_con = tf.zeros([params['batch_size'], 10])
      hiddens = outputs
      hiddens = model_util.projection_head(hiddens, is_training)
      logits_sup = model_util.supervised_head(
          hiddens, num_classes, is_training)
      sup_loss = obj_lib.supervised_loss(
          labels=labels['labels'],
          logits=logits_sup,
          weights=labels['mask'])

    # Add weight decay to loss, for non-LARS optimizers.
    model_util.add_weight_decay(adjust_per_optimizer=True)
    
    # reg_loss = tf.losses.get_regularization_losses()

    
    if FLAGS.train_mode == 'pretrain':
      tf.logging.info('bu_loss: %s', bu_loss)
      tf.logging.info('td_loss: %s', td_loss)
      loss = tf.add_n(
          [td_loss * FLAGS.td_loss_weight, bu_loss * FLAGS.bu_loss_weight] +
          tf.losses.get_regularization_losses())
    else:
      loss = tf.add_n([sup_loss] + tf.losses.get_regularization_losses())
           
    # loss = tf.losses.get_total_loss()

    if FLAGS.train_mode == 'pretrain':
      variables_to_train = tf.trainable_variables()
    else:
      collection_prefix = 'trainable_variables_inblock_'
      variables_to_train = []
      for j in range(FLAGS.fine_tune_after_block + 1, 6):
        variables_to_train += tf.get_collection(collection_prefix + str(j))
      assert variables_to_train, 'variables_to_train shouldn\'t be empty!'

    tf.logging.info('===============Variables to train (begin)===============')
    tf.logging.info(variables_to_train)
    tf.logging.info('================Variables to train (end)================')

    learning_rate = model_util.learning_rate_schedule(
        FLAGS.learning_rate, num_train_examples)

    if is_training:
      
      if FLAGS.train_summary_steps > 0:
        # Compute stats for the summary.
        prob_bu_con = tf.nn.softmax(logits_bu_con)
        entropy_bu_con = - tf.reduce_mean(
            tf.reduce_sum(prob_bu_con * tf.math.log(prob_bu_con + 1e-8), -1))
        prob_td_con = tf.nn.softmax(logits_td_con)
        entropy_td_con = - tf.reduce_mean(
            tf.reduce_sum(prob_td_con * tf.math.log(prob_td_con + 1e-8), -1))

        contrast_bu_acc = tf.equal(
            tf.argmax(labels_bu_con, 1), tf.argmax(logits_bu_con, axis=1))
        contrast_bu_acc = tf.reduce_mean(tf.cast(contrast_bu_acc, tf.float32))
        contrast_td_acc = tf.equal(
            tf.argmax(labels_td_con, 1), tf.argmax(logits_td_con, axis=1))
        contrast_td_acc = tf.reduce_mean(tf.cast(contrast_td_acc, tf.float32))
        
        label_acc = tf.equal(
            tf.argmax(labels['labels'], 1), tf.argmax(logits_sup, axis=1))
        label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
        

        def host_call_fn(gs, g_l, bu_l, td_l, c_bu_a, c_td_a, l_a, c_e_bu,
                         c_e_td, lr, tar_im, viz_f, rec_im):
          """Training host call: writes scalar and image summaries."""
          gs = gs[0]
          with tf2.summary.create_file_writer(
              FLAGS.model_dir,
              max_queue=FLAGS.checkpoint_steps).as_default():
            with tf2.summary.record_if(True):
              tf2.summary.scalar(
                  'total_loss',
                  g_l[0],
                  step=gs)
                  
              tf2.summary.scalar(
                  'train_bottomup_loss',
                  bu_l[0],
                  step=gs)

              tf2.summary.scalar(
                  'train_topdown_loss',
                  td_l[0],
                  step=gs)
              
              tf2.summary.scalar(
                  'train_bottomup_acc',
                  c_bu_a[0],
                  step=gs)
              tf2.summary.scalar(
                  'train_topdown_acc',
                  c_td_a[0],
                  step=gs)
              
              tf2.summary.scalar(
                  'train_label_accuracy',
                  l_a[0],
                  step=gs)
              
              tf2.summary.scalar(
                  'contrast_bu_entropy',
                  c_e_bu[0],
                  step=gs)
              tf2.summary.scalar(
                  'contrast_td_entropy',
                  c_e_td[0],
                  step=gs)
              
              tf2.summary.scalar(
                  'learning_rate', lr[0],
                  step=gs)

              # print("Images")
              # print(target_images)
              # print("Features")
              # print(viz_features)
              # print("Reconstruction")
              # print(reconstruction)
              tf2.summary.image(
                  'Images',
                  tar_im[0],
                  step=gs)
              tf2.summary.image(
                  'Transformed images',
                  viz_f[0],
                  step=gs)
              tf2.summary.image(
                  'Reconstructed images',
                  rec_im[0],
                  step=gs)

            return tf.summary.all_v2_summary_ops()


        n_images = 4
        if isinstance(target_images, list):
          target_images = target_images[0]
        image_shape = target_images.get_shape().as_list()

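        # TPUEstimator host calls expect every tensor argument to carry a
        # leading dimension, so images are packed into a [1, n_images, ...]
        # batch and scalars are reshaped to [1], then read back with [0]
        # inside host_call_fn.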
        tar_im = tf.reshape(tf.cast(target_images[:n_images], tf.float32),
                            [1, n_images] + image_shape[1:])
        viz_f = tf.reshape(tf.cast(viz_features[:n_images], tf.float32),
                           [1, n_images] + image_shape[1:])
        rec_im = tf.reshape(tf.cast(reconstruction[:n_images], tf.float32),
                            [1, n_images] + image_shape[1:])
        
        gs = tf.reshape(tf.train.get_global_step(), [1])
        
        g_l = tf.reshape(loss, [1])

        bu_l = tf.reshape(bu_loss, [1])
        td_l = tf.reshape(td_loss, [1])

        c_bu_a = tf.reshape(contrast_bu_acc, [1])
        c_td_a = tf.reshape(contrast_td_acc, [1])
        
        l_a = tf.reshape(label_acc, [1])
        c_e_bu = tf.reshape(entropy_bu_con, [1])
        c_e_td = tf.reshape(entropy_td_con, [1])
        
        lr = tf.reshape(learning_rate, [1])
        
        host_call = (host_call_fn,
                     [gs, g_l, bu_l, td_l, c_bu_a, c_td_a, l_a, c_e_bu, c_e_td,
                      lr, tar_im, viz_f, rec_im])
        
      else:
        host_call = None

      optimizer = model_util.get_optimizer(learning_rate)
      control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      # if FLAGS.train_summary_steps > 0:
      #   control_deps.extend(tf.summary.all_v2_summary_ops())
      with tf.control_dependencies(control_deps):
        train_op = optimizer.minimize(
            loss, global_step=tf.train.get_or_create_global_step(),
            var_list=variables_to_train)
      
      
      if FLAGS.checkpoint:
        def scaffold_fn():
          """Scaffold function to restore non-logits vars from checkpoint."""
          tf.logging.info('*'*180)
          tf.logging.info('Initializing from checkpoint %s'%FLAGS.checkpoint)
          tf.logging.info('*'*180)

          tf.train.init_from_checkpoint(
              FLAGS.checkpoint,
              {v.op.name: v.op.name
               for v in tf.global_variables(FLAGS.variable_schema)})

          if FLAGS.zero_init_logits_layer:
            # Init op that initializes output layer parameters to zeros.
            output_layer_parameters = [
                var for var in tf.trainable_variables() if var.name.startswith(
                    'head_supervised')]
            tf.logging.info('Initializing output layer parameters %s to zero',
                            [x.op.name for x in output_layer_parameters])
            with tf.control_dependencies([tf.global_variables_initializer()]):
              init_op = tf.group([
                  tf.assign(x, tf.zeros_like(x))
                  for x in output_layer_parameters])
            return tf.train.Scaffold(init_op=init_op)
          else:
            return tf.train.Scaffold()
      else:
        scaffold_fn = None

      return tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode, 
          train_op=train_op, 
          loss=loss, 
          scaffold_fn=scaffold_fn, 
          host_call=host_call
          )

    else:

      def metric_fn(logits_sup, labels_sup, logits_bu_con, labels_bu_con, 
                    logits_td_con, labels_td_con, mask,
                    **kws):
        """Inner metric function."""
        metrics = {k: tf.metrics.mean(v, weights=mask)
                   for k, v in kws.items()}
        metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
            tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1),
            weights=mask)
        metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
            tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
        
        metrics['bottomup_top_1_accuracy'] = tf.metrics.accuracy(
            tf.argmax(labels_bu_con, 1), tf.argmax(logits_bu_con, axis=1),
            weights=mask)
        # metrics['bottomup_top_5_accuracy'] = tf.metrics.recall_at_k(
        #     tf.argmax(labels_bu_con, 1), logits_bu_con, k=5, weights=mask)

        metrics['topdown_top_1_accuracy'] = tf.metrics.accuracy(
            tf.argmax(labels_td_con, 1), tf.argmax(logits_td_con, axis=1),
            weights=mask)
        # metrics['topdown_top_5_accuracy'] = tf.metrics.recall_at_k(
        #     tf.argmax(labels_td_con, 1), logits_td_con, k=5, weights=mask)
        return metrics

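      # Scalar losses are tiled to batch size so metric_fn can average them
      # with tf.metrics.mean under the same per-example mask.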
      metrics = {
          'logits_sup': logits_sup,
          'labels_sup': labels['labels'],
          'logits_bu_con': logits_bu_con,
          'logits_td_con': logits_td_con,
          'labels_bu_con': labels_bu_con,
          'labels_td_con': labels_td_con,
          'mask': labels['mask'],
          'td_loss': tf.fill((params['batch_size'],), td_loss),
          'bu_loss': tf.fill((params['batch_size'],), bu_loss),
          'regularization_loss': tf.fill((params['batch_size'],),
                                         tf.losses.get_regularization_loss()),
      }

      return tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(metric_fn, metrics),
          host_call=None,
          scaffold_fn=None)
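
As a usage note, a model_fn with this signature is what TPUEstimator expects. Below is a minimal sketch of the wiring, assuming the same TF1 `tf` import as the code above; the flag names, input_fn, and step count are placeholders, not taken from the project:

run_config = tf.estimator.tpu.RunConfig(
    model_dir=FLAGS.model_dir,
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100))
estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=FLAGS.use_tpu,
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size)
estimator.train(input_fn=train_input_fn, max_steps=train_steps)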