Example #1
 def teacher_param_getter(getter, name, *args, **kwargs):
     # Custom variable getter: when the requested variable has a student
     # counterpart, return its EMA (teacher) average; otherwise fall back to
     # the default getter.
     # assert name in student_param_dict, "Unknown variable {}.".format(name)
     if name in student_param_dict:
         if "log_sigma2" in name:
             print("Take the log over average of sigma2!")
             return log_w_clip(
                 teacher_ema.average(student_param_dict[name]))
         else:
             return teacher_ema.average(student_param_dict[name])
     else:
         return getter(name, *args, **kwargs)
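A custom getter like this is typically attached when the teacher network is built, so that every variable lookup inside the scope is redirected to the EMA copy of the corresponding student weight. A minimal sketch, assuming a builder function build_classifier, an input placeholder x_ph, and the scope name "model" (none of which appear in the original snippet):

import tensorflow as tf

# Assumed: the student network was already built once under the "model" scope,
# which created the variables referenced by 'student_param_dict'.
with tf.variable_scope("model", reuse=True,
                       custom_getter=teacher_param_getter):
    # This forward pass re-uses the same graph-building code, but every
    # tf.get_variable call now returns the EMA (teacher) weights.
    teacher_logits = build_classifier(x_ph)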
Example #2
    def _consistency_loss(self):
        # Consistency loss between the stochastic prediction and a stop-gradient
        # target prediction (deterministic or stochastic, depending on
        # 'cons_against_mean').
        y_prob = self.get_output('y_dist_sto')['prob']
        if self.cons_against_mean:
            ya_prob = self.get_output('ya_dist_det')['prob']
        else:
            ya_prob = self.get_output('ya_dist_sto')['prob']

        if self.cons_mode == 'mse':
            # IMPORTANT: Here, we take the sum over classes.
            # Implementations in other papers use 'mean' instead of 'sum',
            # which suggests that our 'cons_coeff' should be about 10, not 100 as in those papers.
            print("cons_mode=mse!")
            consistency = tf.reduce_sum(tf.square(y_prob - tf.stop_gradient(ya_prob)), axis=1)
        elif self.cons_mode == 'kld':
            print("cons_mode=kld!")
            from my_utils.tensorflow_utils.distributions import KLD_2Cats_v2
            consistency = KLD_2Cats_v2(y_prob, tf.stop_gradient(ya_prob))
        elif self.cons_mode == 'rev_kld':
            print("cons_mode=rev_kld!")
            from my_utils.tensorflow_utils.distributions import KLD_2Cats_v2
            consistency = KLD_2Cats_v2(tf.stop_gradient(ya_prob), y_prob)
        elif self.cons_mode == '2rand':
            print("cons_mode=2rand!")
            from my_utils.tensorflow_utils.activations import log_w_clip
            # IMPORTANT: We try to stop gradient here!
            # consistency = -log_w_clip(tf.reduce_sum(y_prob * ya_prob, axis=1))
            consistency = -log_w_clip(tf.reduce_sum(y_prob * tf.stop_gradient(ya_prob), axis=1))
        else:
            raise ValueError("Do not support 'cons_mode'={}!".format(self.cons_mode))

        if self.cons_4_unlabeled_only:
            label_flag_inv = self.get_output('label_flag_inv')
            num_unlabeled = self.get_output('num_unlabeled')
            consistency = tf.reduce_sum(consistency * label_flag_inv, axis=0) * 1.0 / (num_unlabeled + 1e-8)
        else:
            consistency = tf.reduce_mean(consistency, axis=0)

        results = {
            'cons': consistency,
        }

        return results
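KLD_2Cats_v2 itself is not shown on this page. Given that it is called with two (batch, num_classes) probability tensors and the result is then reduced over the batch axis, it presumably returns a per-example KL divergence between two categorical distributions. A hedged sketch of such a function (the name kld_2cats_sketch, the argument order, and the clipping constant are assumptions, not the library's actual implementation):

import tensorflow as tf

def kld_2cats_sketch(p_prob, q_prob, eps=1e-8):
    # Per-example KL(p || q) between two batches of categorical
    # distributions, summed over the class axis.
    p = tf.clip_by_value(p_prob, eps, 1.0)
    q = tf.clip_by_value(q_prob, eps, 1.0)
    return tf.reduce_sum(p * (tf.log(p) - tf.log(q)), axis=1)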
Example #3
    def _class_loss(self):
        # This function considers both labeled and unlabeled data in a single batch
        from my_utils.tensorflow_utils.activations import log_w_clip
        y_idx = self.y_ph
        y = self.get_output('y')
        y_dist = self.get_output('y_dist')
        label_flag = self.get_output('label_flag')
        label_flag_inv = self.get_output('label_flag_inv')
        num_labeled = self.get_output('num_labeled')
        num_unlabeled = self.get_output('num_unlabeled')

        y_logit, y_prob = y_dist['logit'], y_dist['prob']

        # Cross entropy loss for labeled data
        cross_ent_l = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=y, logits=y_logit, dim=-1)
        cross_ent_l = tf.reduce_sum(cross_ent_l * label_flag, axis=0) * 1.0 / (num_labeled + 1e-8)

        # Conditional entropy loss for unlabeled data
        cond_ent_u = tf.reduce_sum(-y_prob * log_w_clip(y_prob), axis=1)
        cond_ent_u = tf.reduce_sum(cond_ent_u * label_flag_inv, axis=0) * 1.0 / (num_unlabeled + 1e-8)

        y_pred = tf.argmax(y_prob, axis=1, output_type=tf.int32)
        y_matched = tf.cast(tf.equal(y_pred, y_idx), dtype=tf.float32)
        acc_y_l = tf.reduce_sum(y_matched * label_flag, axis=0) * 1.0 / (num_labeled + 1e-8)
        acc_y_u = tf.reduce_sum(y_matched * label_flag_inv, axis=0) * 1.0 / (num_unlabeled + 1e-8)
        acc_y = tf.reduce_mean(y_matched, axis=0)

        results = {
            'cross_ent_l': cross_ent_l,
            'cond_ent_u': cond_ent_u,
            'y_pred': y_pred,
            'acc_y_l': acc_y_l,
            'acc_y_u': acc_y_u,
            'acc_y': acc_y,
        }

        return results
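The helpers log_w_clip and exp_w_clip from my_utils.tensorflow_utils.activations are used throughout these examples but are not reproduced on this page; from the way they are called, they behave like log and exp with clipping for numerical stability. A hedged sketch under that assumption (the function names and the clipping bounds below are illustrative, not the library's actual values):

import tensorflow as tf

def log_w_clip_sketch(x, eps=1e-8):
    # log with the argument kept away from zero so the value and
    # gradient stay finite.
    return tf.log(tf.maximum(x, eps))

def exp_w_clip_sketch(x, clip=50.0):
    # exp with the exponent clipped to avoid overflow.
    return tf.exp(tf.clip_by_value(x, -clip, clip))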
Example #4
    def __init__(self,
                 params,
                 swa_params=None,
                 param_scales=None,
                 collection=None):
        # Maintain stochastic-weight-averaging (SWA) copies of 'params' and
        # build the ops that update them.
        self._num_updates = 0
        self._num_updates_ph = tf.placeholder(dtype=tf.float32,
                                              shape=[],
                                              name="num_updates")

        self._params = params

        if param_scales is None:
            self._param_scales = [None for _ in range(len(params))]
        else:
            assert len(param_scales) == len(params), "If 'param_scales' is not None, " \
                "its length ({}) must be equal to the length of 'params' ({})!"\
                .format(len(param_scales), len(params))
            self._param_scales = param_scales

        self.collection = "STOCHASTIC_WEIGHT_AVERAGE" if None else collection

        self.swa_params_dict = {}
        self.swa_updates = []
        # Update ops for the first step: simply assign params to swa_params
        self.swa_updates_first = []

        # Create 'swa_params' if it is not provided
        if swa_params is None:
            self.swa_params = []

            for param in params:
                # assert isinstance(param, tf.Variable), "type(param) = {}".format(type(param))
                with tf.variable_scope(param.op.name):
                    # Create swa_param
                    swa_param = tf.get_variable(
                        "SWA",  # shape=param.shape,
                        dtype=param._initial_value.dtype,
                        initializer=param._initial_value,
                        trainable=False,
                        # 'collections' expects a list of collection names
                        collections=[self.collection])
                    self.swa_params.append(swa_param)
                    self.swa_params_dict[param.op.name] = swa_param
        else:
            assert len(swa_params) == len(params), "len(swa_params) ({}) must " \
                "be equal to len(params) ({})!".format(len(swa_params), len(params))

            self.swa_params = swa_params
            for param, swa_param in zip(params, swa_params):
                self.swa_params_dict[param.op.name] = swa_param

        for i, (param, swa_param) in enumerate(zip(params, self.swa_params)):
            # Running average on the original parameter scale
            if self._param_scales[i] is None:
                new_swa_param = (swa_param * (self._num_updates_ph - 1) + param) \
                                / self._num_updates_ph
            # Parameter stored as log(w): average on the w scale, then take the log again
            elif self._param_scales[i] == "log":
                from my_utils.tensorflow_utils.activations import exp_w_clip, log_w_clip
                new_swa_param = (exp_w_clip(swa_param) * (self._num_updates_ph - 1) + exp_w_clip(param)) \
                                / self._num_updates_ph
                new_swa_param = log_w_clip(new_swa_param)
            elif self._param_scales[i] == "exp":
                from my_utils.tensorflow_utils.activations import exp_w_clip, log_w_clip
                new_swa_param = (log_w_clip(swa_param) * (self._num_updates_ph - 1) + log_w_clip(param)) \
                                / self._num_updates_ph
                new_swa_param = exp_w_clip(new_swa_param)
            else:
                raise ValueError(
                    "Do not support scale = '{}' for {}-th param!".format(
                        self._param_scales[i], i))

            swa_update = tf.assign(swa_param, new_swa_param)
            self.swa_updates.append(swa_update)

            # Update at first step
            swa_update_first = tf.assign(swa_param, param)
            self.swa_updates_first.append(swa_update_first)
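The training loop that drives these update ops is not part of this excerpt. Presumably swa_updates_first is run once at the first averaging step and swa_updates afterwards, with the running count fed through the placeholder. A sketch under that assumption (the class name SWA, and train_op, swa_freq, num_steps are placeholders, not names from the original code):

import tensorflow as tf

swa = SWA(params=tf.trainable_variables())  # assumed wrapper class name

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    num_swa_updates = 0
    for step in range(num_steps):
        sess.run(train_op)               # one regular training step (assumed)
        if (step + 1) % swa_freq == 0:   # averaging schedule (assumed)
            num_swa_updates += 1
            if num_swa_updates == 1:
                # First averaging step: copy the current params into the SWA slots.
                sess.run(swa.swa_updates_first)
            else:
                sess.run(swa.swa_updates,
                         feed_dict={swa._num_updates_ph: num_swa_updates})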
Example #5
    def _mur_loss(self):
        # Regularizer evaluated at a point x* that (approximately) maximizes the
        # conditional entropy within a ball of radius 'mur_noise_radius' around
        # the (student) input.
        from tensorflow.contrib.graph_editor import graph_replace
        from my_utils.tensorflow_utils.activations import log_w_clip

        # IMPORTANT: We use 'x_pert_stu' to ensure no perturbation on the input
        # (batch, x_dim)
        x0 = self.get_output('x_pert_stu')

        # IMPORTANT: The output here is 'y_dist_stu_sto' not 'y_dist_stu'
        # (batch, num_classes)
        y0_prob = self.get_output('y_dist_stu_sto')['prob']
        # (batch, )
        cond_ent0 = tf.reduce_sum(-y0_prob * log_w_clip(y0_prob), axis=1)

        normalized_axes = list(range(1, x0.shape.ndims))
        g0 = tf.gradients(cond_ent0, [x0])[0]
        g0_norm = tf.stop_gradient(
            tf.sqrt(tf.reduce_sum(g0**2, axis=normalized_axes, keepdims=True))
            + 1e-15)

        rad = self.mur_noise_radius

        # Direct approximation
        if self.mur_opt_steps == 1:
            print("Direct approximation of x*!")
            eps = rad * g0 / (g0_norm + 1e-8)
            x_final = tf.stop_gradient(x0 + eps)

        else:
            lr = self.mur_opt_lr

            if self.mur_iter_mode == "grad_asc_w_lagrangian_relax":
                print(
                    "Iterative approximation of x* using vanilla gradient ascent!"
                )
                x_t = x0
                cond_ent_t = cond_ent0

                for _ in range(self.mur_opt_steps):
                    grad_x_t = tf.gradients(cond_ent_t, [x_t])[0]

                    xt_m_x0_norm = tf.stop_gradient(
                        tf.sqrt(
                            tf.reduce_sum((x_t - x0)**2,
                                          axis=normalized_axes,
                                          keepdims=True)) + 1e-15)

                    # Update 'x_t' and 'cond_ent_t'
                    x_t = tf.stop_gradient(
                        x_t + lr *
                        (grad_x_t - g0_norm / self.mur_noise_radius *
                         (x_t - x0) * (2 - self.mur_noise_radius /
                                       (xt_m_x0_norm + 1e-15))))
                    cond_ent_t = graph_replace(cond_ent0,
                                               replacement_ts={x0: x_t})

                x_final = x_t

            elif self.mur_iter_mode == "proj_grad_asc":
                print(
                    "Iterative approximation of x* using project gradient ascent!"
                )
                x_t = x0
                cond_ent_t = cond_ent0

                for _ in range(self.mur_opt_steps):
                    grad_x_t = tf.gradients(cond_ent_t, [x_t])[0]

                    z_t = x_t + lr * grad_x_t

                    # (batch, 1, 1, 1)
                    zt_m_x0_norm = tf.stop_gradient(
                        tf.sqrt(
                            tf.reduce_sum((z_t - x0)**2,
                                          axis=normalized_axes,
                                          keepdims=True)) + 1e-15)

                    cond = tf.cast(tf.less_equal(zt_m_x0_norm, rad),
                                   dtype=tf.float32)
                    x_t = cond * z_t + (1.0 -
                                        cond) * (x0 +
                                                 (z_t - x0) / zt_m_x0_norm)
                    x_t = tf.stop_gradient(x_t)

                    cond_ent_t = graph_replace(cond_ent0,
                                               replacement_ts={x0: x_t})

                x_final = x_t

            else:
                raise ValueError(self.mur_iter_mode)

        y_prob_final = graph_replace(y0_prob, replacement_ts={x0: x_final})

        if self.mur_mode == "mse_wrt_point":
            mur = tf.reduce_sum(
                tf.square(tf.stop_gradient(y0_prob) - y_prob_final), axis=1)
        elif self.mur_mode == "mse_wrt_neigh":
            mur = tf.reduce_sum(tf.square(y0_prob -
                                          tf.stop_gradient(y_prob_final)),
                                axis=1)
        else:
            raise ValueError("Do not support mur_mode={}!".format(
                self.mur_mode))

        if self.mur_4_unlabeled_only:
            label_flag_inv = self.get_output('label_flag_inv')
            num_unlabeled = self.get_output('num_unlabeled')
            mur = tf.reduce_sum(mur * label_flag_inv,
                                axis=0) * 1.0 / (num_unlabeled + 1e-8)
        else:
            mur = tf.reduce_mean(mur, axis=0)

        return {'grad_norm_avg': tf.reduce_mean(g0_norm), 'mur': mur}
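For reference, the mur_opt_steps == 1 branch above is the first-order (closed-form) solution of maximizing the conditional entropy within an L2 ball of radius rad around the unperturbed input. In the notation below (chosen here for exposition, not taken from the source), H(x) denotes the conditional entropy computed as cond_ent0:

$$
x^{*} \;=\; \arg\max_{\|x - x_0\|_2 \le r} H(x)
\;\approx\; x_0 + r \, \frac{\nabla_x H(x_0)}{\lVert \nabla_x H(x_0) \rVert_2},
$$

which matches eps = rad * g0 / (g0_norm + 1e-8) followed by x_final = x0 + eps.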