Example #1
def model_fn(features, mode, params):
    '''The model_fn to be used with TPUEstimator.

    Args:
      features: `dict` of batched input tensors; `features['image']` holds the
          images and, outside of PREDICT mode, `features['label']` holds the
          labels.
      mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
      params: `dict` of parameters passed to the model from the TPUEstimator;
          `params['batch_size']` is always provided and should be used as the
          effective batch size.

    Returns:
      A `TPUEstimatorSpec` for the model.
    '''
    def preprocess_image(image):
        # On GPU, the NCHW data format usually gives a significant performance
        # boost over the default NHWC. NHWC should be used only if the network
        # has to run on CPU, since the pooling ops there only support NHWC. On
        # TPU, the XLA compiler figures out the best layout on its own.
        if FLAGS.data_format == 'channels_first':
            assert not FLAGS.transpose_input  # channels_first only for GPU
            image = tf.transpose(image, [0, 3, 1, 2])

        if FLAGS.transpose_input and mode == tf.estimator.ModeKeys.TRAIN:
            image = tf.transpose(image, [3, 0, 1, 2])  # HWCN to NHWC
        return image

    def normalize_image(image):
        # Normalize the image to zero mean and unit variance.
        if FLAGS.data_format == 'channels_first':
            stats_shape = [3, 1, 1]
        else:
            stats_shape = [1, 1, 3]
        mean, std = task_info.get_mean_std(FLAGS.task_name)
        image -= tf.constant(mean, shape=stats_shape, dtype=image.dtype)
        image /= tf.constant(std, shape=stats_shape, dtype=image.dtype)
        return image

    image = features['image']
    image = preprocess_image(image)

    image_shape = image.get_shape().as_list()
    tf.logging.info('image shape: {}'.format(image_shape))
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    if mode != tf.estimator.ModeKeys.PREDICT:
        labels = features['label']
    else:
        labels = None

    # Inside the model_fn, use params['batch_size'] instead of the batch-size
    # flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    if FLAGS.unlabel_ratio and is_training:
        unl_bsz = features['unl_probs'].shape[0]
    else:
        unl_bsz = 0

    lab_bsz = image.shape[0] - unl_bsz
    assert lab_bsz == batch_size

    metric_dict = {}
    global_step = tf.train.get_global_step()

    has_moving_average_decay = (FLAGS.moving_average_decay > 0)
    # This is essential if using a Keras-derived model.
    tf.keras.backend.set_learning_phase(is_training)
    tf.logging.info('Using open-source implementation.')
    override_params = {}
    if FLAGS.dropout_rate is not None:
        override_params['dropout_rate'] = FLAGS.dropout_rate
    if FLAGS.stochastic_depth_rate is not None:
        override_params['stochastic_depth_rate'] = FLAGS.stochastic_depth_rate
    if FLAGS.data_format:
        override_params['data_format'] = FLAGS.data_format
    if FLAGS.num_label_classes:
        override_params['num_classes'] = FLAGS.num_label_classes
    if FLAGS.depth_coefficient:
        override_params['depth_coefficient'] = FLAGS.depth_coefficient
    if FLAGS.width_coefficient:
        override_params['width_coefficient'] = FLAGS.width_coefficient

    def build_model(scope=None,
                    reuse=tf.AUTO_REUSE,
                    model_name=None,
                    model_is_training=None,
                    input_image=None,
                    use_adv_bn=False,
                    is_teacher=False):
        model_name = model_name or FLAGS.model_name
        if model_is_training is None:
            model_is_training = is_training
        if input_image is None:
            input_image = image
        input_image = normalize_image(input_image)

        scope_model_name = model_name

        if scope:
            scope = scope + '/'
        else:
            scope = ''
        with tf.variable_scope(scope + scope_model_name, reuse=reuse):
            if model_name.startswith('efficientnet'):
                logits, _ = efficientnet_builder.build_model(
                    input_image,
                    model_name=model_name,
                    training=model_is_training,
                    override_params=override_params,
                    model_dir=FLAGS.model_dir,
                    use_adv_bn=use_adv_bn,
                    is_teacher=is_teacher)
            else:
                assert False, 'model {} not implemented'.format(model_name)
        return logits

    if params['use_bfloat16']:
        with tf.tpu.bfloat16_scope():
            logits = tf.cast(build_model(), tf.float32)
    else:
        logits = build_model()

    if FLAGS.teacher_model_name:
        teacher_image = preprocess_image(features['teacher_image'])
        if params['use_bfloat16']:
            with tf.tpu.bfloat16_scope():
                teacher_logits = tf.cast(
                    build_model(scope='teacher_model',
                                model_name=FLAGS.teacher_model_name,
                                model_is_training=False,
                                input_image=teacher_image,
                                is_teacher=True), tf.float32)
        else:
            teacher_logits = build_model(scope='teacher_model',
                                         model_name=FLAGS.teacher_model_name,
                                         model_is_training=False,
                                         input_image=teacher_image,
                                         is_teacher=True)
        teacher_logits = tf.stop_gradient(teacher_logits)
        if FLAGS.teacher_softmax_temp != -1:
            teacher_prob = tf.nn.softmax(teacher_logits /
                                         FLAGS.teacher_softmax_temp)
        else:
            teacher_prob = None
            teacher_one_hot_pred = tf.argmax(teacher_logits,
                                             axis=1,
                                             output_type=labels.dtype)

    if mode == tf.estimator.ModeKeys.PREDICT:
        if has_moving_average_decay:
            ema = tf.train.ExponentialMovingAverage(
                decay=FLAGS.moving_average_decay)
            ema_vars = utils.get_all_variable()
            restore_vars_dict = ema.variables_to_restore(ema_vars)
            tf.logging.info(
                'restored variables:\n%s',
                json.dumps(sorted(restore_vars_dict.keys()), indent=4))

        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            scaffold_fn=functools.partial(_scaffold_fn,
                                          restore_vars_dict=restore_vars_dict)
            if has_moving_average_decay else None)

    if has_moving_average_decay:
        ema_step = global_step
        ema = tf.train.ExponentialMovingAverage(
            decay=FLAGS.moving_average_decay, num_updates=ema_step)
        ema_vars = utils.get_all_variable()

    lab_labels = labels[:lab_bsz]
    lab_logits = logits[:lab_bsz]
    lab_pred = tf.argmax(lab_logits, axis=-1, output_type=labels.dtype)
    lab_prob = tf.nn.softmax(lab_logits)
    lab_acc = tf.to_float(tf.equal(lab_pred, lab_labels))
    metric_dict['lab/acc'] = tf.reduce_mean(lab_acc)
    metric_dict['lab/pred_prob'] = tf.reduce_mean(
        tf.reduce_max(lab_prob, axis=-1))

    if FLAGS.unlabel_ratio:
        unl_labels = labels[lab_bsz:]
        unl_logits = logits[lab_bsz:]
        unl_pred = tf.argmax(unl_logits, axis=-1, output_type=labels.dtype)
        unl_prob = tf.nn.softmax(unl_logits)
        unl_acc = tf.to_float(tf.equal(unl_pred, unl_labels))
        metric_dict['unl/acc_to_dump'] = tf.reduce_mean(unl_acc)
        metric_dict['unl/pred_prob'] = tf.reduce_mean(
            tf.reduce_max(unl_prob, axis=-1))

    # compute lab_loss
    one_hot_labels = tf.one_hot(lab_labels, FLAGS.num_label_classes)
    lab_loss = tf.losses.softmax_cross_entropy(
        logits=lab_logits,
        onehot_labels=one_hot_labels,
        label_smoothing=FLAGS.label_smoothing,
        reduction=tf.losses.Reduction.NONE)
    if FLAGS.label_data_sample_prob != 1:
        # mask out part of the labeled data
        random_mask = tf.floor(
            FLAGS.label_data_sample_prob +
            tf.random_uniform(tf.shape(lab_loss), dtype=lab_loss.dtype))
        lab_loss = tf.reduce_mean(lab_loss * random_mask)
    else:
        lab_loss = tf.reduce_mean(lab_loss)
    metric_dict['lab/loss'] = lab_loss

    if FLAGS.unlabel_ratio:
        if FLAGS.teacher_softmax_temp == -1:  # Hard labels
            # Get one-hot labels
            if FLAGS.teacher_model_name:
                ext_teacher_pred = teacher_one_hot_pred[lab_bsz:]
                one_hot_labels = tf.one_hot(ext_teacher_pred,
                                            FLAGS.num_label_classes)
            else:
                one_hot_labels = tf.one_hot(unl_labels,
                                            FLAGS.num_label_classes)
            # Compute cross entropy
            unl_loss = tf.losses.softmax_cross_entropy(
                logits=unl_logits,
                onehot_labels=one_hot_labels,
                label_smoothing=FLAGS.label_smoothing)
        else:  # Soft labels
            # Get teacher prob
            if FLAGS.teacher_model_name:
                unl_teacher_prob = teacher_prob[lab_bsz:]
            else:
                scaled_prob = tf.pow(features['unl_probs'],
                                     1 / FLAGS.teacher_softmax_temp)
                unl_teacher_prob = scaled_prob / tf.reduce_sum(
                    scaled_prob, axis=-1, keepdims=True)
            metric_dict['unl/target_prob'] = tf.reduce_mean(
                tf.reduce_max(unl_teacher_prob, axis=-1))
            unl_loss = cross_entropy(unl_teacher_prob,
                                     unl_logits,
                                     return_mean=True)

        metric_dict['ext/loss'] = unl_loss
    else:
        unl_loss = 0

    real_lab_bsz = tf.to_float(lab_bsz) * FLAGS.label_data_sample_prob
    real_unl_bsz = batch_size * FLAGS.label_data_sample_prob * FLAGS.unlabel_ratio
    data_loss = lab_loss * real_lab_bsz + unl_loss * real_unl_bsz
    data_loss = data_loss / real_lab_bsz

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = data_loss + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])
    metric_dict['train/data_loss'] = data_loss
    metric_dict['train/loss'] = loss

    host_call = None
    restore_vars_dict = None

    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])
        real_train_batch_size = FLAGS.train_batch_size
        real_train_batch_size *= FLAGS.label_data_sample_prob
        scaled_lr = FLAGS.base_learning_rate * (real_train_batch_size / 256.0)
        if FLAGS.final_base_lr:
            # total number of training epochs
            total_epochs = (FLAGS.train_steps * FLAGS.train_batch_size * 1. /
                            FLAGS.num_train_images - 5)
            decay_times = math.log(FLAGS.final_base_lr /
                                   FLAGS.base_learning_rate) / math.log(0.97)
            decay_epochs = total_epochs / decay_times
            tf.logging.info(
                'setting decay_epochs to {:.2f}'.format(decay_epochs) +
                '\n' * 3)
        else:
            decay_epochs = 2.4 * FLAGS.train_ratio
        learning_rate = utils.build_learning_rate(
            scaled_lr,
            global_step,
            params['steps_per_epoch'],
            decay_epochs=decay_epochs,
            start_from_step=FLAGS.train_steps - FLAGS.train_last_step_num,
            warmup_epochs=5,
        )
        metric_dict['train/lr'] = learning_rate
        metric_dict['train/epoch'] = current_epoch
        optimizer = utils.build_optimizer(learning_rate)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        tvars = tf.trainable_variables()
        g_vars = []
        tvars = sorted(tvars, key=lambda var: var.name)
        for var in tvars:
            if 'teacher_model' not in var.name:
                g_vars += [var]
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step, var_list=g_vars)

        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not FLAGS.skip_host_call:
            host_call = utils.construct_scalar_host_call(metric_dict)
        scaffold_fn = None
        if FLAGS.teacher_model_name or FLAGS.init_model:
            scaffold_fn = utils.init_from_ckpt(scaffold_fn)
    else:
        train_op = None
        if has_moving_average_decay:
            # Load moving average variables for eval.
            restore_vars_dict = ema.variables_to_restore(ema_vars)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        scaffold_fn = functools.partial(_scaffold_fn,
                                        restore_vars_dict=restore_vars_dict
                                        ) if has_moving_average_decay else None

        def metric_fn(labels, logits):
            '''Evaluation metric function. Evaluates accuracy.

            This function is executed on the CPU and should not directly
            reference any Tensors in the rest of the `model_fn`. To pass
            Tensors from the model to the `metric_fn`, provide them as part of
            `eval_metrics`. See
            https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
            for more information.

            Arguments should match the list of `Tensor` objects passed as the
            second element in the tuple passed to `eval_metrics`.

            Args:
              labels: `Tensor` with shape `[batch]`.
              logits: `Tensor` with shape `[batch, num_classes]`.

            Returns:
              A dict of the metrics to return from evaluation.
            '''

            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            result_dict = {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

            return result_dict

        eval_metrics = (metric_fn, [labels, logits])

    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('number of trainable parameters: {}'.format(num_params))

    return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                             loss=loss,
                                             train_op=train_op,
                                             host_call=host_call,
                                             eval_metrics=eval_metrics,
                                             scaffold_fn=scaffold_fn)
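
The model_fn above is only half of the picture: a TPUEstimator constructs it and injects `params['batch_size']` (and, in this script, also reads `params['use_bfloat16']` and `params['steps_per_epoch']`). Below is a minimal wiring sketch; the flag names, the input functions, and the way steps_per_epoch is computed are illustrative assumptions, not taken from the original script.

# Hypothetical wiring sketch for the model_fn above. FLAGS, train_input_fn and
# the steps_per_epoch computation are illustrative assumptions.
import tensorflow as tf

run_config = tf.estimator.tpu.RunConfig(
    cluster=tf.distribute.cluster_resolver.TPUClusterResolver(FLAGS.tpu),
    model_dir=FLAGS.model_dir,
    tpu_config=tf.estimator.tpu.TPUConfig(
        iterations_per_loop=FLAGS.iterations_per_loop))

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=model_fn,
    use_tpu=FLAGS.use_tpu,
    config=run_config,
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size,
    # Extra entries show up in model_fn's `params` next to the injected
    # 'batch_size'.
    params={
        'use_bfloat16': FLAGS.use_bfloat16,
        'steps_per_epoch': FLAGS.num_train_images / FLAGS.train_batch_size,
    })

estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
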
Example #2
    def model_fn(features, labels, mode, params):
        sup_labels = tf.reshape(features["label"], [-1])

        #### Configuring the optimizer
        global_step = tf.train.get_global_step()
        metric_dict = {}
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        if FLAGS.unsup_ratio > 0 and is_training:
            all_images = tf.concat([
                features["image"], features["ori_image"], features["aug_image"]
            ], 0)
        else:
            all_images = features["image"]

        with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
            all_logits = build_model(
                inputs=all_images,
                num_classes=FLAGS.num_classes,
                is_training=is_training,
                update_bn=True and is_training,
                hparams=hparams,
            )

            sup_bsz = tf.shape(features["image"])[0]
            sup_logits = all_logits[:sup_bsz]

            sup_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=sup_labels, logits=sup_logits)
            sup_prob = tf.nn.softmax(sup_logits, axis=-1)
            metric_dict["sup/pred_prob"] = tf.reduce_mean(
                tf.reduce_max(sup_prob, axis=-1))
        if FLAGS.tsa:
            sup_loss, avg_sup_loss = anneal_sup_loss(sup_logits, sup_labels,
                                                     sup_loss, global_step,
                                                     metric_dict)
        else:
            avg_sup_loss = tf.reduce_mean(sup_loss)
        total_loss = avg_sup_loss

        if FLAGS.unsup_ratio > 0 and is_training:
            aug_bsz = tf.shape(features["ori_image"])[0]

            ori_logits = all_logits[sup_bsz:sup_bsz + aug_bsz]
            aug_logits = all_logits[sup_bsz + aug_bsz:]
            if FLAGS.uda_softmax_temp != -1:
                ori_logits_tgt = ori_logits / FLAGS.uda_softmax_temp
            else:
                ori_logits_tgt = ori_logits
            ori_prob = tf.nn.softmax(ori_logits, axis=-1)
            aug_prob = tf.nn.softmax(aug_logits, axis=-1)
            metric_dict["unsup/ori_prob"] = tf.reduce_mean(
                tf.reduce_max(ori_prob, axis=-1))
            metric_dict["unsup/aug_prob"] = tf.reduce_mean(
                tf.reduce_max(aug_prob, axis=-1))

            aug_loss = _kl_divergence_with_logits(
                p_logits=tf.stop_gradient(ori_logits_tgt), q_logits=aug_logits)

            if FLAGS.uda_confidence_thresh != -1:
                ori_prob = tf.nn.softmax(ori_logits, axis=-1)
                largest_prob = tf.reduce_max(ori_prob, axis=-1)
                loss_mask = tf.cast(
                    tf.greater(largest_prob, FLAGS.uda_confidence_thresh),
                    tf.float32)
                metric_dict["unsup/high_prob_ratio"] = tf.reduce_mean(
                    loss_mask)
                loss_mask = tf.stop_gradient(loss_mask)
                aug_loss = aug_loss * loss_mask
                metric_dict["unsup/high_prob_loss"] = tf.reduce_mean(aug_loss)

            if FLAGS.ent_min_coeff > 0:
                ent_min_coeff = FLAGS.ent_min_coeff
                metric_dict["unsup/ent_min_coeff"] = ent_min_coeff
                per_example_ent = get_ent(ori_logits)
                ent_min_loss = tf.reduce_mean(per_example_ent)
                total_loss = total_loss + ent_min_coeff * ent_min_loss

            avg_unsup_loss = tf.reduce_mean(aug_loss)
            total_loss += FLAGS.unsup_coeff * avg_unsup_loss
            metric_dict["unsup/loss"] = avg_unsup_loss

        total_loss = utils.decay_weights(total_loss, FLAGS.weight_decay_rate)

        #### Check model parameters
        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        tf.logging.info("#params: {}".format(num_params))

        if FLAGS.verbose:
            format_str = "{{:<{0}s}}\t{{}}".format(
                max([len(v.name) for v in tf.trainable_variables()]))
            for v in tf.trainable_variables():
                tf.logging.info(format_str.format(v.name, v.get_shape()))

        #### Evaluation mode
        if mode == tf.estimator.ModeKeys.EVAL:
            #### Metric function for classification
            def metric_fn(per_example_loss, label_ids, logits):
                # classification loss & accuracy
                loss = tf.metrics.mean(per_example_loss)

                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(label_ids, predictions)

                ret_dict = {
                    "eval/classify_loss": loss,
                    "eval/classify_accuracy": accuracy
                }

                return ret_dict

            eval_metrics = (metric_fn, [sup_loss, sup_labels, sup_logits])

            #### Constructing evaluation TPUEstimatorSpec.
            eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=total_loss, eval_metrics=eval_metrics)

            return eval_spec

        # increase the learning rate linearly
        if FLAGS.warmup_steps > 0:
            warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                        * FLAGS.learning_rate
        else:
            warmup_lr = 0.0

        # decay the learning rate using the cosine schedule
        decay_lr = tf.train.cosine_decay(
            FLAGS.learning_rate,
            global_step=global_step - FLAGS.warmup_steps,
            decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
            alpha=FLAGS.min_lr_ratio)

        learning_rate = tf.where(global_step < FLAGS.warmup_steps, warmup_lr,
                                 decay_lr)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=0.9,
                                               use_nesterov=True)

        if FLAGS.use_tpu:
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        grads_and_vars = optimizer.compute_gradients(total_loss)
        gradients, variables = zip(*grads_and_vars)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.apply_gradients(
                zip(gradients, variables),
                global_step=tf.train.get_global_step())

        #### Creating training logging hook
        # compute accuracy
        sup_pred = tf.argmax(sup_logits, axis=-1, output_type=sup_labels.dtype)
        is_correct = tf.to_float(tf.equal(sup_pred, sup_labels))
        acc = tf.reduce_mean(is_correct)
        metric_dict["sup/sup_loss"] = avg_sup_loss
        metric_dict["training/loss"] = total_loss
        metric_dict["sup/acc"] = acc
        metric_dict["training/lr"] = learning_rate
        metric_dict["training/step"] = global_step

        if not FLAGS.use_tpu:
            log_info = ("step [{training/step}] lr {training/lr:.6f} "
                        "loss {training/loss:.4f} "
                        "sup/acc {sup/acc:.4f} sup/loss {sup/sup_loss:.6f} ")
            if FLAGS.unsup_ratio > 0:
                log_info += "unsup/loss {unsup/loss:.6f} "
            formatter = lambda kwargs: log_info.format(**kwargs)
            logging_hook = tf.train.LoggingTensorHook(
                tensors=metric_dict,
                every_n_iter=FLAGS.iterations,
                formatter=formatter)
            training_hooks = [logging_hook]
            #### Constructing training TPUEstimatorSpec.
            train_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=training_hooks)
        else:
            #### Constructing training TPUEstimatorSpec.
            host_call = utils.construct_scalar_host_call(
                metric_dict=metric_dict,
                model_dir=params["model_dir"],
                prefix="",
                reduce_fn=tf.reduce_mean)
            train_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                         loss=total_loss,
                                                         train_op=train_op,
                                                         host_call=host_call)

        return train_spec
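
The consistency term above relies on a `_kl_divergence_with_logits` helper that is not part of this snippet. A minimal sketch, assuming the helper returns the per-example KL(p || q) computed from raw logits, would be:

# Minimal sketch, assuming _kl_divergence_with_logits returns per-example
# KL(p || q) computed from logits; the original helper is not shown above.
def _kl_divergence_with_logits(p_logits, q_logits):
    p = tf.nn.softmax(p_logits, axis=-1)
    log_p = tf.nn.log_softmax(p_logits, axis=-1)
    log_q = tf.nn.log_softmax(q_logits, axis=-1)
    # Sum over classes -> one KL value per example, shape [batch].
    return tf.reduce_sum(p * (log_p - log_q), axis=-1)
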
Example #3
    def model_fn(features, labels, mode, params):
        sup_labels = tf.reshape(features["label"], [-1])

        #### Configuring the optimizer
        global_step = tf.train.get_global_step()
        metric_dict = {}
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        if FLAGS.unsup_ratio > 0 and is_training:
            all_images = tf.concat([
                features["image"], features["ori_image"], features["aug_image"]
            ], 0)
        else:
            all_images = features["image"]

        with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
            all_logits = build_model(
                inputs=all_images,
                num_classes=FLAGS.num_classes,
                feature_dim=128,
                is_training=is_training,
                update_bn=True and is_training,
                hparams=hparams,
            )
            sup_bsz = tf.shape(features["image"])[0]

            sup_logits = all_logits[0][:sup_bsz]
            tf.logging.info('sup_bsz: %s', sup_bsz)
            sup_features = all_logits[1][:sup_bsz]

            map_dict = read_pkl()
            tmp_list = [x.numpy() for x in map_dict.values()]
            pedcc_features_all = np.concatenate(tmp_list)

            # One branch function per class, each returning that class's fixed
            # PEDCC centroid; the per-example tf.case below picks the centroid
            # that matches each supervised label.
            branch_fns = [(lambda v=v: v) for v in tmp_list]

            for i in range(FLAGS.train_batch_size):
                tmp = sup_labels[i]
                test = tf.case(
                    {tf.equal(tmp, k): branch_fns[k]
                     for k in range(10)},
                    exclusive=True)
                if i == 0:
                    feature_label = test
                else:
                    feature_label = tf.concat([feature_label, test], axis=0)

            pedcc_features = tf.cast(feature_label, dtype=tf.float32)

            mse_loss = tf.reduce_mean(tf.square(sup_features - pedcc_features))
            loss_2 = AM_loss(sup_logits, sup_labels)
            sup_loss = mse_loss + loss_2
            sup_prob = tf.nn.softmax(sup_logits, axis=-1)
            metric_dict["sup/pred_prob"] = tf.reduce_mean(
                tf.reduce_max(sup_prob, axis=-1))

        avg_sup_loss = tf.reduce_mean(sup_loss)
        total_loss = avg_sup_loss

        if FLAGS.unsup_ratio > 0 and is_training:
            aug_bsz = tf.shape(features["ori_image"])[0]

            ori_logits = all_logits[0][sup_bsz:sup_bsz + aug_bsz]
            ori_features = all_logits[1][sup_bsz:sup_bsz + aug_bsz]
            aug_logits = all_logits[0][sup_bsz + aug_bsz:]

            ori_logits_tgt = ori_logits
            ori_prob = tf.nn.softmax(ori_logits, axis=-1)
            aug_prob = tf.nn.softmax(aug_logits, axis=-1)
            metric_dict["unsup/ori_prob"] = tf.reduce_mean(
                tf.reduce_max(ori_prob, axis=-1))
            metric_dict["unsup/aug_prob"] = tf.reduce_mean(
                tf.reduce_max(aug_prob, axis=-1))

            # Replicate the stacked PEDCC centroids so the total number of rows
            # matches the unlabeled (ori) batch: train_batch_size * unsup_ratio.
            for i in range(int(FLAGS.train_batch_size * FLAGS.unsup_ratio / 10 - 1)):
                if i == 0:
                    pedcc_features_sum = tf.concat(
                        [pedcc_features_all, pedcc_features_all], axis=0)
                else:
                    pedcc_features_sum = tf.concat(
                        [pedcc_features_sum, pedcc_features_all], axis=0)
            pedcc_features_sum = tf.cast(pedcc_features_sum, dtype=tf.float32)

            mmd_loss = mmd_rbf(ori_features, pedcc_features_sum)
            mmd_loss = mmd_loss * 0.2
            aug_loss = _kl_divergence_with_logits(
                p_logits=tf.stop_gradient(ori_logits_tgt), q_logits=aug_logits)

            avg_unsup_loss = tf.reduce_mean(aug_loss)
            avg_unsup_loss = avg_unsup_loss * 400
            total_loss += FLAGS.unsup_coeff * avg_unsup_loss
            total_loss += mmd_loss
            metric_dict["unsup/mmd_loss"] = mmd_loss
            metric_dict["unsup/loss"] = avg_unsup_loss

        total_loss = utils.decay_weights(total_loss, FLAGS.weight_decay_rate)

        #### Check model parameters
        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        tf.logging.info("#params: {}".format(num_params))

        #### Evaluation mode
        if mode == tf.estimator.ModeKeys.EVAL:
            #### Metric function for classification
            def metric_fn(per_example_loss, label_ids, logits):
                # classification loss & accuracy
                loss = tf.metrics.mean(per_example_loss)

                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(label_ids, predictions)

                ret_dict = {
                    "eval/classify_loss": loss,
                    "eval/classify_accuracy": accuracy
                }

                return ret_dict

            eval_metrics = (metric_fn, [sup_loss, sup_labels, sup_logits])

            #### Constructing evaluation TPUEstimatorSpec.
            eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=total_loss, eval_metrics=eval_metrics)

            return eval_spec

        # increase the learning rate linearly
        if FLAGS.warmup_steps > 0:
            warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                        * FLAGS.learning_rate
        else:
            warmup_lr = 0.0

        # decay the learning rate using the cosine schedule
        decay_lr = tf.train.cosine_decay(
            FLAGS.learning_rate,
            global_step=global_step - FLAGS.warmup_steps,
            decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
            alpha=FLAGS.min_lr_ratio)

        learning_rate = tf.where(global_step < FLAGS.warmup_steps, warmup_lr,
                                 decay_lr)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=0.9,
                                               use_nesterov=True)

        #### use_tpu is False in this configuration ####
        if FLAGS.use_tpu:
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        grads_and_vars = optimizer.compute_gradients(total_loss)
        gradients, variables = zip(*grads_and_vars)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.apply_gradients(
                zip(gradients, variables),
                global_step=tf.train.get_global_step())

        #### Creating training logging hook
        # compute accuracy
        sup_pred = tf.argmax(sup_logits, axis=-1, output_type=sup_labels.dtype)
        is_correct = tf.to_float(tf.equal(sup_pred, sup_labels))
        acc = tf.reduce_mean(is_correct)
        metric_dict["sup/sup_loss"] = avg_sup_loss
        metric_dict["training/loss"] = total_loss
        metric_dict["sup/acc"] = acc
        metric_dict["training/lr"] = learning_rate
        metric_dict["training/step"] = global_step

        if not FLAGS.use_tpu:
            log_info = ("step [{training/step}] lr {training/lr:.6f} "
                        "loss {training/loss:.4f} "
                        "sup/acc {sup/acc:.4f} sup/loss {sup/sup_loss:.6f} ")
            if FLAGS.unsup_ratio > 0:
                log_info += "unsup/loss {unsup/loss:.6f} "
                log_info += "unsup/mmd_loss {unsup/mmd_loss:.6f} "
            formatter = lambda kwargs: log_info.format(**kwargs)
            logging_hook = tf.train.LoggingTensorHook(
                tensors=metric_dict,
                every_n_iter=FLAGS.iterations,
                formatter=formatter)
            training_hooks = [logging_hook]
            #### Constructing training TPUEstimatorSpec.
            train_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=training_hooks)
        else:
            #### Constructing training TPUEstimatorSpec.
            host_call = utils.construct_scalar_host_call(
                metric_dict=metric_dict,
                model_dir=params["model_dir"],
                prefix="",
                reduce_fn=tf.reduce_mean)
            train_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                         loss=total_loss,
                                                         train_op=train_op,
                                                         host_call=host_call)

        return train_spec
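
This variant additionally needs an `mmd_rbf` helper to pull the unlabeled features toward the fixed PEDCC centroids; that helper is also not included in the snippet. A compact sketch of a multi-bandwidth RBF-kernel MMD, assuming `x` is `[n, d]` and `y` is `[m, d]` float tensors and using illustrative bandwidths, could look like:

# Sketch of a multi-bandwidth RBF-kernel MMD; the bandwidths are illustrative
# assumptions and the original mmd_rbf helper is not shown above.
def mmd_rbf(x, y, sigmas=(1.0, 5.0, 10.0)):
    def pairwise_sq_dists(a, b):
        # ||a_i - b_j||^2 = ||a_i||^2 - 2 * a_i . b_j + ||b_j||^2
        a_sq = tf.reduce_sum(tf.square(a), axis=1, keepdims=True)  # [n, 1]
        b_sq = tf.reduce_sum(tf.square(b), axis=1, keepdims=True)  # [m, 1]
        return a_sq - 2.0 * tf.matmul(a, b, transpose_b=True) + tf.transpose(b_sq)

    def mean_kernel(a, b):
        d2 = pairwise_sq_dists(a, b)
        k = tf.add_n([tf.exp(-d2 / (2.0 * s * s)) for s in sigmas])
        return tf.reduce_mean(k)

    # Biased estimate of MMD^2 between the two samples.
    return mean_kernel(x, x) + mean_kernel(y, y) - 2.0 * mean_kernel(x, y)
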
Example #4
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = GroverModel(
            config=config,
            is_training=is_training,
            input_ids=input_ids,
            pad_token_id=config.pad_token_id,
            chop_off_last_token=True,
        )

        total_loss = model.lm_loss()
        tf.logging.info('logits_flat: %s', model.logits_flat)
        tf.logging.info('total_loss: %s', total_loss)

        if is_training:
            train_op, train_metrics = create_optimizer(total_loss,
                                                       learning_rate,
                                                       num_train_steps,
                                                       num_warmup_steps,
                                                       use_tpu)
            tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        else:
            train_op = None
            train_metrics = {}
            tvars = tf.trainable_variables()
        params_sum = np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ])
        tf.logging.info("**** Trainable params_sum ****")
        tf.logging.info(params_sum)
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map,
             initialized_variable_names) = get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            if use_tpu:
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    host_call=construct_scalar_host_call(
                        metric_dict=train_metrics,
                        model_dir=params['model_dir'],
                        prefix='training/'),
                    scaffold_fn=scaffold_fn)
            else:
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    training_hooks=[
                        tf.train.LoggingTensorHook(
                            {
                                "train_loss": total_loss,
                                "global_step": tf.train.global_step
                            },
                            every_n_iter=10)
                    ],
                    scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(total_loss):
                loss = tf.metrics.mean(values=total_loss)
                return {
                    "eval_loss": loss,
                }

            eval_metrics = (metric_fn, [total_loss])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            gt_logprobs = tf.squeeze(tf.batch_gather(
                model.log_probs, model.target_ids[:, :, None]),
                                     axis=2)

            # Cumulative probability mass of tokens scored above the ground
            # truth, i.e. the smallest top-p that would still include the
            # ground-truth token under nucleus (top-p) sampling.
            better_than_gt = model.log_probs > gt_logprobs[:, :, None]
            top_p_required = tf.reduce_sum(
                tf.cast(better_than_gt, tf.float32) * tf.exp(model.log_probs),
                axis=2)

            # No top-p sampling for now, since this seems to be too slow on TPUs
            if use_tpu:
                predictions = tf.reshape(
                    tf.random.categorical(logits=model.logits_flat,
                                          num_samples=1),
                    get_shape_list(model.target_ids),
                )
            else:
                # Argmax
                # predictions = tf.math.argmax(model.log_probs, axis=-1, output_type=tf.int32)
                predictions = tf.reshape(
                    _top_p_sample(model.logits_flat, num_samples=1,
                                  p=0.99)['sample'],
                    get_shape_list(model.target_ids),
                )
            pred_logprobs = tf.squeeze(tf.batch_gather(model.log_probs,
                                                       predictions[:, :,
                                                                   None]),
                                       axis=2)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions={
                    'gt_logprobs': gt_logprobs,
                    'top_p_required': top_p_required,
                    'predictions': predictions,
                    'pred_logprobs': pred_logprobs,
                    'labels': input_ids
                },
                scaffold_fn=scaffold_fn)
        return output_spec
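
On GPU/CPU the PREDICT branch above falls back to a `_top_p_sample` helper that is not included in this snippet. A rough sketch of nucleus (top-p) sampling over `[batch, vocab]` logits, returning a dict with a `'sample'` entry of shape `[batch, num_samples]` as the caller expects, might be:

# Rough sketch of nucleus (top-p) sampling; the real _top_p_sample helper is
# not shown above, so the details here are assumptions.
def _top_p_sample(logits, num_samples=1, p=0.99):
    sort_idx = tf.argsort(logits, direction='DESCENDING', axis=-1)
    sorted_logits = tf.batch_gather(logits, sort_idx)
    sorted_probs = tf.nn.softmax(sorted_logits, axis=-1)
    # Cumulative probability *before* each token; keep a token while the mass
    # of all strictly better tokens is still below p.
    cum_probs = tf.cumsum(sorted_probs, axis=-1, exclusive=True)
    masked_logits = tf.where(cum_probs < p, sorted_logits,
                             tf.fill(tf.shape(sorted_logits), -1e10))
    # Sample positions in the sorted order, then map back to vocabulary ids.
    sampled_pos = tf.random.categorical(masked_logits, num_samples=num_samples,
                                        dtype=tf.int32)
    return {'sample': tf.batch_gather(sort_idx, sampled_pos)}
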