Example #1
File: main.py  Project: zyq0104/uda
    def model_fn(features, labels, mode, params):
        sup_labels = tf.reshape(features["label"], [-1])

        #### Configuring the optimizer
        global_step = tf.train.get_global_step()
        metric_dict = {}
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        if FLAGS.unsup_ratio > 0 and is_training:
            all_images = tf.concat([
                features["image"], features["ori_image"], features["aug_image"]
            ], 0)
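            # Batch layout along axis 0: [labeled images | unlabeled originals | augmented copies].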
        else:
            all_images = features["image"]

        with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
            all_logits = build_model(
                inputs=all_images,
                num_classes=FLAGS.num_classes,
                is_training=is_training,
                update_bn=True and is_training,
                hparams=hparams,
            )

            sup_bsz = tf.shape(features["image"])[0]
            sup_logits = all_logits[:sup_bsz]

            sup_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=sup_labels, logits=sup_logits)
            sup_prob = tf.nn.softmax(sup_logits, axis=-1)
            metric_dict["sup/pred_prob"] = tf.reduce_mean(
                tf.reduce_max(sup_prob, axis=-1))
        if FLAGS.tsa:
            sup_loss, avg_sup_loss = anneal_sup_loss(sup_logits, sup_labels,
                                                     sup_loss, global_step,
                                                     metric_dict)
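            # FLAGS.tsa enables Training Signal Annealing of the supervised loss
            # (see anneal_sup_loss).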
        else:
            avg_sup_loss = tf.reduce_mean(sup_loss)
        total_loss = avg_sup_loss

        if FLAGS.unsup_ratio > 0 and is_training:
            aug_bsz = tf.shape(features["ori_image"])[0]

            ori_logits = all_logits[sup_bsz:sup_bsz + aug_bsz]
            aug_logits = all_logits[sup_bsz + aug_bsz:]
            if FLAGS.uda_softmax_temp != -1:
                ori_logits_tgt = ori_logits / FLAGS.uda_softmax_temp
            else:
                ori_logits_tgt = ori_logits
            ori_prob = tf.nn.softmax(ori_logits, axis=-1)
            aug_prob = tf.nn.softmax(aug_logits, axis=-1)
            metric_dict["unsup/ori_prob"] = tf.reduce_mean(
                tf.reduce_max(ori_prob, axis=-1))
            metric_dict["unsup/aug_prob"] = tf.reduce_mean(
                tf.reduce_max(aug_prob, axis=-1))

            aug_loss = _kl_divergence_with_logits(
                p_logits=tf.stop_gradient(ori_logits_tgt), q_logits=aug_logits)
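            # Per-example consistency loss: KL between the prediction on the original
            # image (held fixed via stop_gradient) and the prediction on its augmented copy.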

            if FLAGS.uda_confidence_thresh != -1:
                ori_prob = tf.nn.softmax(ori_logits, axis=-1)
                largest_prob = tf.reduce_max(ori_prob, axis=-1)
                loss_mask = tf.cast(
                    tf.greater(largest_prob, FLAGS.uda_confidence_thresh),
                    tf.float32)
                metric_dict["unsup/high_prob_ratio"] = tf.reduce_mean(
                    loss_mask)
                loss_mask = tf.stop_gradient(loss_mask)
                aug_loss = aug_loss * loss_mask
                metric_dict["unsup/high_prob_loss"] = tf.reduce_mean(aug_loss)

            if FLAGS.ent_min_coeff > 0:
                ent_min_coeff = FLAGS.ent_min_coeff
                metric_dict["unsup/ent_min_coeff"] = ent_min_coeff
                per_example_ent = get_ent(ori_logits)
                ent_min_loss = tf.reduce_mean(per_example_ent)
                total_loss = total_loss + ent_min_coeff * ent_min_loss

            avg_unsup_loss = tf.reduce_mean(aug_loss)
            total_loss += FLAGS.unsup_coeff * avg_unsup_loss
            metric_dict["unsup/loss"] = avg_unsup_loss

        total_loss = utils.decay_weights(total_loss, FLAGS.weight_decay_rate)

        #### Check model parameters
        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        tf.logging.info("#params: {}".format(num_params))

        if FLAGS.verbose:
            format_str = "{{:<{0}s}}\t{{}}".format(
                max([len(v.name) for v in tf.trainable_variables()]))
            for v in tf.trainable_variables():
                tf.logging.info(format_str.format(v.name, v.get_shape()))

        #### Evaluation mode
        if mode == tf.estimator.ModeKeys.EVAL:
            #### Metric function for classification
            def metric_fn(per_example_loss, label_ids, logits):
                # classification loss & accuracy
                loss = tf.metrics.mean(per_example_loss)

                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(label_ids, predictions)

                ret_dict = {
                    "eval/classify_loss": loss,
                    "eval/classify_accuracy": accuracy
                }

                return ret_dict

            eval_metrics = (metric_fn, [sup_loss, sup_labels, sup_logits])

            #### Constructing evaluation TPUEstimatorSpec.
            eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=total_loss, eval_metrics=eval_metrics)

            return eval_spec

        # increase the learning rate linearly
        if FLAGS.warmup_steps > 0:
            warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                        * FLAGS.learning_rate
        else:
            warmup_lr = 0.0

        # decay the learning rate using the cosine schedule
        decay_lr = tf.train.cosine_decay(
            FLAGS.learning_rate,
            global_step=global_step - FLAGS.warmup_steps,
            decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
            alpha=FLAGS.min_lr_ratio)

        learning_rate = tf.where(global_step < FLAGS.warmup_steps, warmup_lr,
                                 decay_lr)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=0.9,
                                               use_nesterov=True)

        if FLAGS.use_tpu:
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        grads_and_vars = optimizer.compute_gradients(total_loss)
        gradients, variables = zip(*grads_and_vars)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.apply_gradients(
                zip(gradients, variables),
                global_step=tf.train.get_global_step())

        #### Creating training logging hook
        # compute accuracy
        sup_pred = tf.argmax(sup_logits, axis=-1, output_type=sup_labels.dtype)
        is_correct = tf.to_float(tf.equal(sup_pred, sup_labels))
        acc = tf.reduce_mean(is_correct)
        metric_dict["sup/sup_loss"] = avg_sup_loss
        metric_dict["training/loss"] = total_loss
        metric_dict["sup/acc"] = acc
        metric_dict["training/lr"] = learning_rate
        metric_dict["training/step"] = global_step

        if not FLAGS.use_tpu:
            log_info = ("step [{training/step}] lr {training/lr:.6f} "
                        "loss {training/loss:.4f} "
                        "sup/acc {sup/acc:.4f} sup/loss {sup/sup_loss:.6f} ")
            if FLAGS.unsup_ratio > 0:
                log_info += "unsup/loss {unsup/loss:.6f} "
            formatter = lambda kwargs: log_info.format(**kwargs)
            logging_hook = tf.train.LoggingTensorHook(
                tensors=metric_dict,
                every_n_iter=FLAGS.iterations,
                formatter=formatter)
            training_hooks = [logging_hook]
            #### Constructing training TPUEstimatorSpec.
            train_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=training_hooks)
        else:
            #### Constructing training TPUEstimatorSpec.
            host_call = utils.construct_scalar_host_call(
                metric_dict=metric_dict,
                model_dir=params["model_dir"],
                prefix="",
                reduce_fn=tf.reduce_mean)
            train_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                         loss=total_loss,
                                                         train_op=train_op,
                                                         host_call=host_call)

        return train_spec
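
The model_fn above delegates the unsupervised consistency term to _kl_divergence_with_logits, which is not shown in this listing. A minimal sketch of what such a helper typically looks like for the UDA objective (names and the exact form are assumed, not taken from this file):

import tensorflow as tf  # TF 1.x


def _kl_divergence_with_logits(p_logits, q_logits):
    """Per-example KL(p || q) computed from logits."""
    p = tf.nn.softmax(p_logits, axis=-1)
    log_p = tf.nn.log_softmax(p_logits, axis=-1)
    log_q = tf.nn.log_softmax(q_logits, axis=-1)
    # Sum over classes; keep the batch dimension so a per-example mask
    # (e.g. the confidence mask above) can still be applied.
    return tf.reduce_sum(p * (log_p - log_q), axis=-1)

In the caller, tf.stop_gradient on p_logits keeps the prediction on the original image fixed as the target, so gradients flow only through the augmented-image logits.
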
  def model_fn(features, labels, mode, params):
    sup_labels = tf.reshape(features["label"], [-1])

    #### Configuring the optimizer
    global_step = tf.train.get_global_step()
    metric_dict = {}
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    if FLAGS.unsup_ratio > 0 and is_training:
      all_images = tf.concat([features["image"],
                              features["ori_image"],
                              features["aug_image"]], 0)
    else:
      all_images = features["image"]

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
      all_logits = build_model(
          inputs=all_images,
          num_classes=FLAGS.num_classes,
          feature_dim=128,
          is_training=is_training,
          update_bn=True and is_training,
          hparams=hparams,
      )
      sup_bsz = tf.shape(features["image"])[0]

      # build_model here returns a (logits, features) pair; take the
      # supervised slice of the batch from each output.
      sup_logits = all_logits[0][:sup_bsz]
      sup_features = all_logits[1][:sup_bsz]


      map_dict = read_pkl()
      tmp_list = [x.numpy() for x in map_dict.values()]
      pedcc_features_all = np.concatenate(tmp_list)
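      # map_dict maps each class id to a predefined target feature vector
      # (PEDCC class centers); pedcc_features_all stacks them row-wise.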

      # One branch function per class index: tf.case selects the target
      # feature row(s) that correspond to each supervised label.
      branch_fns = {k: (lambda k=k: tmp_list[k]) for k in range(10)}

      for i in range(FLAGS.train_batch_size):
          tmp = sup_labels[i]
          test = tf.case(
              {tf.equal(tmp, k): branch_fns[k] for k in range(10)},
              exclusive=True)
          if i == 0:
              feature_label = test
          else:
              feature_label = tf.concat([feature_label, test], axis=0)

      pedcc_features = tf.cast(feature_label, dtype=tf.float32)

      # Supervised loss: MSE between the backbone features and the per-class
      # target features, plus the AM_loss term on the supervised logits.
      mse_loss = tf.reduce_mean(tf.square(sup_features - pedcc_features))
      loss_2 = AM_loss(sup_logits, sup_labels)
      sup_loss = mse_loss + loss_2
      sup_prob = tf.nn.softmax(sup_logits, axis=-1)
      metric_dict["sup/pred_prob"] = tf.reduce_mean(
          tf.reduce_max(sup_prob, axis=-1))

    avg_sup_loss = tf.reduce_mean(sup_loss)
    total_loss = avg_sup_loss

    if FLAGS.unsup_ratio > 0 and is_training:
      aug_bsz = tf.shape(features["ori_image"])[0]

      ori_logits = all_logits[0][sup_bsz:sup_bsz + aug_bsz]
      ori_features = all_logits[1][sup_bsz:sup_bsz + aug_bsz]
      aug_logits = all_logits[0][sup_bsz + aug_bsz:]

      ori_logits_tgt = ori_logits
      ori_prob = tf.nn.softmax(ori_logits, axis=-1)
      aug_prob = tf.nn.softmax(aug_logits, axis=-1)
      metric_dict["unsup/ori_prob"] = tf.reduce_mean(
          tf.reduce_max(ori_prob, axis=-1))
      metric_dict["unsup/aug_prob"] = tf.reduce_mean(
          tf.reduce_max(aug_prob, axis=-1))

      # Repeat the stacked class centers so the MMD term below compares the
      # unlabeled features against the full set of target features.
      for i in range(int(FLAGS.train_batch_size * FLAGS.unsup_ratio / 10 - 1)):
          if i == 0:
              pedcc_features_sum = tf.concat(
                  [pedcc_features_all, pedcc_features_all], axis=0)
          else:
              pedcc_features_sum = tf.concat(
                  [pedcc_features_sum, pedcc_features_all], axis=0)
      pedcc_features_sum = tf.cast(pedcc_features_sum, dtype=tf.float32)


      # MMD between the unlabeled-original features and the target features,
      # with a fixed weight of 0.04.
      mmd_loss = 0.04 * mmd_rbf(ori_features, pedcc_features_sum)
      aug_loss = _kl_divergence_with_logits(
          p_logits=tf.stop_gradient(ori_logits_tgt),
          q_logits=aug_logits)

      avg_unsup_loss = tf.reduce_mean(aug_loss)
      avg_unsup_loss = avg_unsup_loss * 1600  # fixed scaling on the consistency loss
      total_loss += FLAGS.unsup_coeff * avg_unsup_loss
      total_loss += mmd_loss
      metric_dict["unsup/mmd_loss"] = mmd_loss
      metric_dict["unsup/loss"] = avg_unsup_loss

    total_loss = utils.decay_weights(
        total_loss,
        FLAGS.weight_decay_rate)



    #### Check model parameters
    num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info("#params: {}".format(num_params))


    #### Evaluation mode
    if mode == tf.estimator.ModeKeys.EVAL:
      #### Metric function for classification
      def metric_fn(per_example_loss, label_ids, logits):
        # classification loss & accuracy
        loss = tf.metrics.mean(per_example_loss)

        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(label_ids, predictions)

        ret_dict = {
            "eval/classify_loss": loss,
            "eval/classify_accuracy": accuracy
        }

        return ret_dict

      eval_metrics = (metric_fn, [sup_loss, sup_labels, sup_logits])

      #### Constructing evaluation TPUEstimatorSpec.
      eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics)

      return eval_spec

    # increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
      warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                  * FLAGS.learning_rate
    else:
      warmup_lr = 0.0

    # decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step-FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps-FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)


    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate,
        momentum=0.9,
        use_nesterov=True)

    #### use_tpu is False in this setup ####
    if FLAGS.use_tpu:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    grads_and_vars = optimizer.compute_gradients(total_loss)
    gradients, variables = zip(*grads_and_vars)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.apply_gradients(
          zip(gradients, variables), global_step=tf.train.get_global_step())

    #### Creating training logging hook
    # compute accuracy
    sup_pred = tf.argmax(sup_logits, axis=-1, output_type=sup_labels.dtype)
    is_correct = tf.to_float(tf.equal(sup_pred, sup_labels))
    acc = tf.reduce_mean(is_correct)
    metric_dict["sup/sup_loss"] = avg_sup_loss
    metric_dict["training/loss"] = total_loss
    metric_dict["sup/acc"] = acc
    metric_dict["training/lr"] = learning_rate
    metric_dict["training/step"] = global_step


    if not FLAGS.use_tpu:
      log_info = ("step [{training/step}] lr {training/lr:.6f} "
                  "loss {training/loss:.4f} "
                  "sup/acc {sup/acc:.4f} sup/loss {sup/sup_loss:.6f} ")
      if FLAGS.unsup_ratio > 0:
        log_info += "unsup/loss {unsup/loss:.6f} "
        log_info += "unsup/mmd_loss {unsup/mmd_loss:.6f} "
      formatter = lambda kwargs: log_info.format(**kwargs)
      logging_hook = tf.train.LoggingTensorHook(
          tensors=metric_dict,
          every_n_iter=FLAGS.iterations,
          formatter=formatter)
      training_hooks = [logging_hook]
      #### Constructing training TPUEstimatorSpec.
      train_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode, loss=total_loss, train_op=train_op,
          training_hooks=training_hooks)
    else:
      #### Constructing training TPUEstimatorSpec.
      host_call = utils.construct_scalar_host_call(
          metric_dict=metric_dict,
          model_dir=params["model_dir"],
          prefix="",
          reduce_fn=tf.reduce_mean)
      train_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode, loss=total_loss, train_op=train_op,
          host_call=host_call)

    return train_spec
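
The second variant additionally calls read_pkl, AM_loss, and mmd_rbf, none of which appear in this listing. A minimal sketch of an RBF-kernel MMD in the same TF 1.x style is shown below; the function name matches the call site, but the multi-bandwidth kernel and the biased estimator are assumptions rather than the project's actual implementation:

import tensorflow as tf  # TF 1.x


def mmd_rbf(x, y, sigmas=(1.0, 5.0, 10.0)):
    """Biased estimate of MMD^2 between two batches of feature vectors."""
    def pairwise_sq_dists(a, b):
        a_sq = tf.reduce_sum(tf.square(a), axis=1, keepdims=True)
        b_sq = tf.reduce_sum(tf.square(b), axis=1, keepdims=True)
        return a_sq - 2.0 * tf.matmul(a, b, transpose_b=True) + tf.transpose(b_sq)

    def kernel(a, b):
        d = pairwise_sq_dists(a, b)
        k = tf.zeros_like(d)
        for sigma in sigmas:
            k += tf.exp(-d / (2.0 * sigma ** 2))
        return k

    return (tf.reduce_mean(kernel(x, x))
            + tf.reduce_mean(kernel(y, y))
            - 2.0 * tf.reduce_mean(kernel(x, y)))

AM_loss presumably applies a margin-based softmax to the supervised logits, and read_pkl loads the predefined per-class target features from disk; neither signature is recoverable from this file, so no sketches are given for them.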