Example #1
def prototypical_layer(nclasses, x_train, y_train, x_test, phi, ext_wts=None):
    """Computes the prototypes, cluster centers.
  Args:
    x_train: [N, ...], Train data.
    y_train: [N], Train class labels.
    x_test: [N, ...], Test data.
    phi: Feature extractor function.
  Returns:
    logits: [N, K], Test prediction.
  """
    protos = [None] * nclasses
    x_all = concat([x_train, x_test], 0)
    h, _ = phi(x_all, reuse=None, is_training=True, ext_wts=ext_wts)
    num_x_train = tf.shape(x_train)[0]
    h_train = h[:num_x_train, :]
    h_test = h[num_x_train:, :]
    for kk in range(nclasses):
        ksel = tf.expand_dims(tf.cast(tf.equal(y_train, kk), h_train.dtype),
                              1)  # [N, 1]
        protos[kk] = tf.reduce_sum(h_train * ksel, [0],
                                   keep_dims=True)  # [1, D]
        protos[kk] /= tf.reduce_sum(ksel)
        protos[kk] = tf.Print(protos[kk], [
            'proto',
            tf.reduce_mean(protos[kk]),
            tf.reduce_max(protos[kk]),
            tf.reduce_min(protos[kk])
        ])
    protos = concat(protos, 0)  # [K, D]
    logits = compute_logits(protos, h_test)
    return logits
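
The helper compute_logits is not defined in these snippets. In prototypical networks it is typically the negative squared Euclidean distance from each test embedding to each prototype; below is a minimal NumPy sketch under that assumption (illustrative only, not code from this repository).

import numpy as np

def compute_logits_np(protos, h_test):
    """protos: [K, D], h_test: [M, D] -> logits: [M, K]."""
    diff = h_test[:, None, :] - protos[None, :, :]  # [M, K, D]
    return -np.sum(diff ** 2, axis=-1)              # [M, K]

protos = np.array([[0.0, 0.0], [1.0, 1.0]])
h_test = np.array([[0.1, 0.0], [0.9, 1.2]])
print(compute_logits_np(protos, h_test).argmax(axis=1))  # -> [0 1]
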
Example #2
def prototypical_clustering_gmm_layer(nclasses,
                                      x_train,
                                      y_train,
                                      x_test,
                                      phi,
                                      num_cluster_steps,
                                      lambd=0.1,
                                      alpha=0.0):
    """Computes the prototypes, cluster centers, with additional clustering on
    the validation data.
  Args:
    x_train: [N, D], Train data.
    y_train: [N], Train class labels.
    x_test: [N, D], Test data.
    phi: Feature extractor function.
  Returns:
    logits: [N, K], Test prediction.
  """
    protos = [None] * nclasses
    covar = [None] * nclasses
    x_all = concat([x_train, x_test], 0)
    h, _ = phi(x_all, reuse=None, is_training=True)
    num_x_train = tf.shape(x_train)[0]
    h_train = h[:num_x_train, :]
    h_test = h[num_x_train:, :]
    ndim = tf.shape(h)[1]

    # Initialize cluster center.
    for kk in range(nclasses):
        ksel = tf.expand_dims(tf.cast(tf.equal(y_train, kk), h_train.dtype),
                              1)  # [N, 1]
        protos[kk] = tf.reduce_sum(h_train * ksel, [0],
                                   keep_dims=True)  # [1, D]
        protos[kk] /= tf.reduce_sum(ksel)
        #covar[kk] = tf.expand_dims(tf.eye(ndim), 0)
        covar[kk] = tf.ones([1, ndim])  # diagonal
    protos = concat(protos, 0)  # [K, D]
    covar = concat(covar, 0)  # [K, D] (diagonal entries)

    # Run clustering.
    all_data = concat([h_train, h_test], 0)  # hoisted: also used after the loop
    for tt in range(num_cluster_steps):
        # Label assignment (E-step).
        prob = assign_gmm_diag_cluster(protos, covar, all_data, nclasses)
        # Parameter update (M-step).
        protos, covar = update_gmm_diag_cluster(all_data, prob)

    # Be cautious here!!
    uloss = mh_diag_dist_loss(protos, covar, all_data, nclasses)
    logits = compute_gmm_diag_logits(protos, covar, h_test, nclasses)
    return logits, uloss
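
The helpers assign_gmm_diag_cluster, update_gmm_diag_cluster, mh_diag_dist_loss and compute_gmm_diag_logits are not shown here. A rough NumPy sketch of one E/M step for a diagonal-covariance GMM, the kind of update the first two plausibly implement (an assumption, not this repository's code):

import numpy as np

def gmm_diag_em_step(data, protos, covar, eps=1e-6):
    """data: [N, D], protos: [K, D], covar: [K, D] (diagonal variances)."""
    diff = data[:, None, :] - protos[None, :, :]                  # [N, K, D]
    log_p = -0.5 * np.sum(diff ** 2 / (covar[None] + eps)
                          + np.log(covar[None] + eps), axis=-1)   # [N, K]
    log_p -= log_p.max(axis=1, keepdims=True)                     # stability
    prob = np.exp(log_p)
    prob /= prob.sum(axis=1, keepdims=True)                       # E-step
    w = prob / (prob.sum(axis=0, keepdims=True) + eps)            # per-cluster weights
    new_protos = w.T @ data                                       # M-step: means
    new_diff = data[:, None, :] - new_protos[None, :, :]
    new_covar = np.einsum('nk,nkd->kd', w, new_diff ** 2)         # M-step: variances
    return prob, new_protos, new_covar
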
Example #3
def prototypical_clustering_layer(nclasses,
                                  x_train,
                                  y_train,
                                  x_test,
                                  phi,
                                  num_cluster_steps,
                                  lambd=0.1,
                                  alpha=0.0):
    """Computes the prototypes, cluster centers, with additional clustering on
    the validation data.
  Args:
    x_train: [N, D], Train data.
    y_train: [N], Train class labels.
    x_test: [N, D], Test data.
    phi: Feature extractor function.
  Returns:
    logits: [N, K], Test prediction.
  """
    protos = [None] * nclasses
    x_all = concat([x_train, x_test], 0)
    h, _ = phi(x_all, reuse=None, is_training=True)
    num_x_train = tf.shape(x_train)[0]
    h_train = h[:num_x_train, :]
    h_test = h[num_x_train:, :]

    # Initialize cluster center.
    for kk in range(nclasses):
        ksel = tf.expand_dims(tf.cast(tf.equal(y_train, kk), h_train.dtype),
                              1)  # [N, 1]
        protos[kk] = tf.reduce_sum(h_train * ksel, [0],
                                   keep_dims=True)  # [1, D]
        protos[kk] /= tf.reduce_sum(ksel)
    protos = concat(protos, 0)  # [K, D]

    # Run clustering.
    all_data = concat([h_train, h_test], 0)  # hoisted: also used after the loop
    for tt in range(num_cluster_steps):
        # Label assignment.
        prob = assign_cluster(protos, all_data)
        # Prototype update.
        ## This is a vanilla version.
        ## Probably want to impose a constraint that each training example can only
        ## be in one cluster.
        protos = update_cluster(all_data, prob)

    # Be cautious here!!
    # Returns the logits and unsupervised clustering loss.
    uloss = sq_dist_loss(protos, all_data)
    logits = compute_logits(protos, h_test)
    return logits, uloss
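
assign_cluster is not shown in these snippets; a batched version of update_cluster appears below in Example #13. A minimal NumPy sketch of assign_cluster, assuming it softmaxes the negative squared distances to each prototype (the soft k-means E-step):

import numpy as np

def assign_cluster_np(protos, data):
    """protos: [K, D], data: [N, D] -> soft assignment prob: [N, K]."""
    logits = -np.sum((data[:, None, :] - protos[None, :, :]) ** 2, axis=-1)
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    prob = np.exp(logits)
    return prob / prob.sum(axis=1, keepdims=True)
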
Example #4
    def _compute_protos(self, nclasses, h_train, y_train):
        """Computes the prototypes, cluster centers.
    Args:
      nclasses: Int. Number of classes.
      h_train: [B, N, D], Train features.
      y_train: [B, N], Train class labels.
    Returns:
      protos: [B, K, D], Class prototypes.
    """
        with tf.name_scope('Compute-protos'):
            protos = [None] * nclasses
            for kk in range(nclasses):
                # [B, N, 1]
                ksel = tf.expand_dims(
                    tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
                # [B, 1, D]
                protos[kk] = tf.reduce_sum(h_train * ksel, [1], keep_dims=True)
                protos[kk] /= tf.reduce_sum(ksel, [1, 2], keep_dims=True)
                protos[kk] = debug_identity(protos[kk], "proto")
            protos = concat(protos, 1)  # [B, K, D]
            self.adv_summaries.append(
                tf.summary.histogram('Proto norms',
                                     tf.norm(tf.squeeze(protos), axis=1)))

        return protos
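
The per-class loop above can equivalently be written as a single contraction against a one-hot label matrix; a small NumPy sketch of the same computation (illustrative only, and it assumes every class has at least one support example):

import numpy as np

def compute_protos_np(nclasses, h_train, y_train):
    """h_train: [B, N, D], y_train: [B, N] int -> protos: [B, K, D]."""
    onehot = np.eye(nclasses)[y_train]                   # [B, N, K]
    counts = onehot.sum(axis=1)[:, :, None]              # [B, K, 1]
    protos = np.einsum('bnk,bnd->bkd', onehot, h_train)  # per-class sums
    return protos / counts                               # per-class means
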
Example #5
 def get_encoded_inputs(self, *x_list, **kwargs):
     """Runs the reference and candidate images through the feature model phi.
 Returns:
   h_train: [B, N, D]
   h_unlabel: [B, P, D]
   h_test: [B, M, D]
 """
     if 'ext_wts' in kwargs:
         ext_wts = kwargs['ext_wts']
     else:
         ext_wts = None
     VAT = False
     if 'VAT' in kwargs:
         VAT = kwargs['VAT']
     config = self.config
     bsize = tf.shape(x_list[0])[0]
     num = [tf.shape(xx)[1] for xx in x_list]
     x_all = concat(x_list, 1)
     x_all = tf.reshape(
         x_all, [-1, config.height, config.width, config.num_channel])
     h_all = self.phi(x_all, ext_wts=ext_wts, VAT=VAT)
     tf.assert_greater(tf.reduce_mean(tf.abs(h_all)), 0.0)
     # h_all_p = self.phi(tf.random_normal(tf.shape(x_all)), ext_wts=ext_wts)
     # h_all = tf.Print(h_all, [tf.reduce_sum(h_all),tf.reduce_sum(h_all - h_all_p)], '\n-----------')
     h_all = tf.reshape(h_all, [bsize, sum(num), -1])
     h_list = tf.split(h_all, num, axis=1)
     return h_list
Example #6
def compute_gmm_logits(cluster_centers, cluster_covar, data, nclasses):
    """Computes the logits of being in one cluster, Manhalanobis.
  Args:
    cluster_centers: [K, D] Cluster center representation.
    cluster_covar: [K, D, D] Cluster covariance matrix.
    data: [N, D] Data representation.
    nclasses: Integer. K, number of classes.
  Returns:
    log_prob: [N, K] logits.
  """
    cluster_centers = tf.expand_dims(cluster_centers, 0)  # [1, K, D]
    data = tf.expand_dims(data, 1)  # [N, 1, D]
    diff = data - cluster_centers  # [N, K, D]
    result = []  # [N, K]
    print("covar", cluster_covar.get_shape())
    for kk in range(nclasses):
        _covar = cluster_covar[kk, :, :]  # [D, D]
        _diff = diff[:, kk, :]  # [N, D]
        _diff_ = tf.expand_dims(_diff, 2)  # [N, D, 1]
        _icovar = tf.matrix_inverse(_covar)
        _icovar = tf.expand_dims(_icovar, 0)  # [1, D, D]
        prod = tf.reduce_sum(_diff_ * _icovar, [1])  # [N, D] = diff @ inv(covar)
        print(_icovar.get_shape())
        print(_diff_.get_shape())
        print(prod.get_shape())
        prod = tf.reduce_sum(_diff * prod, [1], keep_dims=True)  # [N, 1]
        result.append(-prod)
    logits = concat(result, 1)
    print("logits", logits.get_shape())
    return logits
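
The two reduce_sum calls inside the loop compute the quadratic form diff' * inv(covar) * diff per class. For reference, the same Mahalanobis logits in a compact vectorized NumPy form (illustrative sketch, not this repository's code):

import numpy as np

def compute_gmm_logits_np(cluster_centers, cluster_covar, data):
    """cluster_centers: [K, D], cluster_covar: [K, D, D], data: [N, D]."""
    diff = data[:, None, :] - cluster_centers[None, :, :]    # [N, K, D]
    icovar = np.linalg.inv(cluster_covar)                    # [K, D, D]
    maha = np.einsum('nkd,kde,nke->nk', diff, icovar, diff)  # squared Mahalanobis
    return -maha                                             # [N, K]
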
Example #7
 def encode(self, *x_list, **kwargs):
     """
 """
     if 'ext_wts' in kwargs:
         ext_wts = kwargs['ext_wts']
     else:
         ext_wts = None
     update_batch_stats = True
     if 'update_batch_stats' in kwargs:
         update_batch_stats = kwargs['update_batch_stats']
     config = self.config
     bsize = tf.shape(x_list[0])[0]
     num = [tf.shape(xx)[1] for xx in x_list]
     x_all = concat(x_list, 1)
     x_all = tf.reshape(
         x_all, [-1, config.height, config.width, config.num_channel])
     h_all = self.phi(x_all,
                      ext_wts=ext_wts,
                      update_batch_stats=update_batch_stats)
     # tf.assert_greater(tf.reduce_mean(tf.abs(h_all)), 0.0)
     # h_all_p = self.phi(tf.random_normal(tf.shape(x_all)), ext_wts=ext_wts)
     # h_all = tf.Print(h_all, [tf.shape(h_all)], '\n-----------')
     h_all = tf.reshape(h_all, [bsize, sum(num), -1])
     h_list = tf.split(h_all, num, axis=1)
     return h_list
Example #8
    def predict(self):
        """See `model.py` for documentation."""
        super().predict()

        nclasses = self.nway
        num_cluster_steps = self.config.num_cluster_steps
        h_train, h_unlabel, h_test = self.encode(self.x_train, self.x_unlabel,
                                                 self.x_test)
        y_train = self.y_train
        protos = self._compute_protos(nclasses, h_train, y_train)
        logits = compute_logits(protos, h_test)

        # Hard assignment for training images.
        prob_train = [None] * nclasses
        for kk in range(nclasses):
            # [B, N, 1]
            prob_train[kk] = tf.expand_dims(
                tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
        prob_train = concat(prob_train, 2)

        h_all = concat([h_train, h_unlabel], 1)

        logits_list = []
        logits_list.append(compute_logits(protos, h_test))

        # Run clustering.
        for tt in range(num_cluster_steps):
            # Label assignment.
            prob_unlabel = assign_cluster(protos, h_unlabel)
            entropy = tf.reduce_sum(-prob_unlabel * tf.log(prob_unlabel), [2],
                                    keep_dims=True)
            prob_all = concat([prob_train, prob_unlabel], 1)
            prob_all = tf.stop_gradient(prob_all)
            protos = update_cluster(h_all, prob_all)
            # protos = tf.cond(
            #     tf.shape(self._x_unlabel)[1] > 0,
            #     lambda: update_cluster(h_all, prob_all), lambda: protos)
            logits_list.append(compute_logits(protos, h_test))

            self._unlabel_logits = compute_logits(self.protos, h_unlabel)[0]
            self._logits = logits_list
Example #9
 def noisy_forward(self,
                   data,
                   noise=tf.constant(0.0),
                   update_batch_stats=False):
     with tf.name_scope("forward"):
         protos = concat(
             [self.protos,
              tf.zeros_like(self.protos[:, 0:1, :])], 1)
         encoded = self.phi(data + noise,
                            update_batch_stats=update_batch_stats)
         logits = compute_logits_radii(protos, tf.expand_dims(encoded, 0),
                                       self.radii)
     return logits
Example #10
    def _compute_protos(self, nclasses, h_train, y_train):
        """Computes the class prototypes from the support features."""
        num_points = self.nshot
        with tf.name_scope('Compute-protos'):
            protos = [None] * nclasses
            for kk in range(nclasses):
                # [B, N, 1]
                ksel = tf.expand_dims(
                    tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
                protos[kk] = construct_proto(
                    tf.boolean_mask(h_train, ksel[:, :, 0]), num_points)
            protos = concat(protos, 1)  # [B, K, D]

            return protos
Example #11
 def _compute_protos(self, nclasses, h_train, y_train):
     """Computes the prototypes, cluster centers.
 Args:
   nclasses: Int. Number of classes.
   h_train: [B, N, D], Train features.
   y_train: [B, N], Train class labels.
 Returns:
   protos: [B, K, D], Class prototypes.
 """
     protos = [None] * nclasses
     for kk in range(nclasses):
         # [B, N, 1]
         ksel = tf.expand_dims(
             tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
         # [B, 1, D]
         protos[kk] = tf.reduce_sum(h_train * ksel, [1], keep_dims=True)
         protos[kk] /= tf.reduce_sum(ksel, [1, 2], keep_dims=True)
         protos[kk] = debug_identity(protos[kk], "proto")
     protos = concat(protos, 1)  # [B, K, D]
     return protos
Example #12
 def get_encoded_inputs(self, *x_list, **kwargs):
     """Runs the reference and candidate images through the feature model phi.
 Returns:
   h_train: [B, N, D]
   h_unlabel: [B, P, D]
   h_test: [B, M, D]
 """
     config = self.config
     bsize = tf.shape(x_list[0])[0]
     num = [tf.shape(xx)[1] for xx in x_list]
     x_all = concat(x_list, 1)
     if 'ext_wts' in kwargs:
         ext_wts = kwargs['ext_wts']
     else:
         ext_wts = None
     x_all = tf.reshape(
         x_all, [-1, config.height, config.width, config.num_channel])
     h_all = self.phi(x_all, ext_wts=ext_wts)
     h_all = tf.reshape(h_all, [bsize, sum(num), -1])
     h_list = tf.split(h_all, num, axis=1)
     return h_list
Example #13
def update_cluster(data, prob, fix_last_row=False):
  """Updates cluster center based on assignment, standard K-Means.
  Args:
    data: [B, N, D]. Data representation.
    prob: [B, N, K]. Cluster assignment soft probability.
    fix_last_row: Bool. Whether or not to fix the last row to 0.
  Returns:
    cluster_centers: [B, K, D]. Cluster center representation.
  """
  # Normalize across N.
  if fix_last_row:
    prob_ = prob[:, :, :-1]
  else:
    prob_ = prob
  prob_sum = tf.reduce_sum(prob_, [1], keep_dims=True)
  prob_sum += tf.to_float(tf.equal(prob_sum, 0.0))
  prob2 = prob_ / prob_sum
  cluster_centers = tf.reduce_sum(
      tf.expand_dims(data, 2) * tf.expand_dims(prob2, 3), [1])
  if fix_last_row:
    cluster_centers = concat(
        [cluster_centers,
         tf.zeros_like(cluster_centers[:, 0:1, :])], 1)
  return cluster_centers
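
For reference, the same weighted-mean update (without the fix_last_row padding) restated in a few lines of NumPy, including the guard against empty clusters:

import numpy as np

def update_cluster_np(data, prob):
    """data: [B, N, D], prob: [B, N, K] -> cluster_centers: [B, K, D]."""
    prob_sum = prob.sum(axis=1, keepdims=True)     # [B, 1, K]
    prob_sum = prob_sum + (prob_sum == 0.0)        # avoid divide-by-zero
    prob2 = prob / prob_sum
    return np.einsum('bnk,bnd->bkd', prob2, data)  # weighted means
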
Example #14
    def predict(self):
        """See `model.py` for documentation."""
        nclasses = self.nway
        num_cluster_steps = self.config.num_cluster_steps
        h_train, h_unlabel, h_test = self.encode(self.x_train, self.x_unlabel,
                                                 self.x_test)
        y_train = self.y_train
        protos = self._compute_protos(nclasses, h_train, y_train)
        logits_list = []
        logits_list.append(compute_logits(protos, h_test))

        # Hard assignment for training images.
        prob_train = [None] * (nclasses)
        for kk in range(nclasses):
            # [B, N, 1]
            prob_train[kk] = tf.expand_dims(
                tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
        prob_train = concat(prob_train, 2)

        y_train_shape = tf.shape(y_train)
        bsize = y_train_shape[0]

        h_all = concat([h_train, h_unlabel], 1)
        mask = None

        # Calculate pairwise distances.
        protos_1 = tf.expand_dims(protos, 2)
        protos_2 = tf.expand_dims(h_unlabel, 1)
        pair_dist = tf.reduce_sum((protos_1 - protos_2)**2, [3])  # [B, K, N]
        mean_dist = tf.reduce_mean(pair_dist, [2], keep_dims=True)
        pair_dist_normalize = pair_dist / mean_dist
        min_dist = tf.reduce_min(pair_dist_normalize, [2],
                                 keep_dims=True)  # [B, K, 1]
        max_dist = tf.reduce_max(pair_dist_normalize, [2], keep_dims=True)
        mean_dist, var_dist = tf.nn.moments(pair_dist_normalize, [2],
                                            keep_dims=True)
        mean_dist += tf.to_float(tf.equal(mean_dist, 0.0))
        var_dist += tf.to_float(tf.equal(var_dist, 0.0))
        skew = tf.reduce_mean(
            ((pair_dist_normalize - mean_dist)**3) / (tf.sqrt(var_dist)**3),
            [2],
            keep_dims=True)
        kurt = tf.reduce_mean(
            ((pair_dist_normalize - mean_dist)**4) / (var_dist**2) - 3, [2],
            keep_dims=True)

        n_features = 5
        n_out = 3

        dist_features = tf.reshape(
            concat([min_dist, max_dist, var_dist, skew, kurt], 2),
            [-1, n_features])  # [B*K, 5]
        dist_features = tf.stop_gradient(dist_features)

        hdim = [n_features, 20, n_out]
        act_fn = [tf.nn.tanh, None]
        thresh = mlp(dist_features,
                     hdim,
                     is_training=True,
                     act_fn=act_fn,
                     dtype=tf.float32,
                     add_bias=True,
                     wd=None,
                     init_std=[0.01, 0.01],
                     init_method=None,
                     scope="dist_mlp",
                     dropout=None,
                     trainable=True)
        scale = tf.exp(thresh[:, 2])
        bias_start = tf.exp(thresh[:, 0])
        bias_add = thresh[:, 1]
        bias_start = tf.reshape(bias_start, [bsize, 1, -1])  #[B, 1, K]
        bias_add = tf.reshape(bias_add, [bsize, 1, -1])

        self._scale = scale
        self._bias_start = bias_start
        self._bias_add = bias_add

        # Run clustering.
        for tt in range(num_cluster_steps):
            protos_1 = tf.expand_dims(protos, 2)
            protos_2 = tf.expand_dims(h_unlabel, 1)
            pair_dist = tf.reduce_sum((protos_1 - protos_2)**2,
                                      [3])  # [B, K, N]
            m_dist = tf.reduce_mean(pair_dist, [2])  # [B, K]
            m_dist_1 = tf.expand_dims(m_dist, 1)  # [B, 1, K]
            m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0))
            # Label assignment.
            if num_cluster_steps > 1:
                bias_tt = bias_start + (
                    tt / float(num_cluster_steps - 1)) * bias_add
            else:
                bias_tt = bias_start

            negdist = compute_logits(protos, h_unlabel)
            mask = tf.sigmoid((negdist / m_dist_1 + bias_tt) * scale)
            prob_unlabel, mask = assign_cluster_soft_mask(
                protos, h_unlabel, mask)
            prob_all = concat([prob_train, prob_unlabel * mask], 1)
            # No update if 0 unlabel.
            protos = tf.cond(
                tf.shape(self._x_unlabel)[1] > 0,
                lambda: update_cluster(h_all, prob_all), lambda: protos)
            logits_list.append(compute_logits(protos, h_test))

        # Distractor evaluation.
        if mask is not None:
            max_mask = tf.reduce_max(mask, [2])
            mean_mask = tf.reduce_mean(max_mask)
            pred_non_distractor = tf.to_float(max_mask > mean_mask)
            acc, recall, precision = eval_distractor(pred_non_distractor,
                                                     self.y_unlabel)
            self._non_distractor_acc = acc
            self._distractor_recall = recall
            self._distractor_precision = precision
            self._distractor_pred = max_mask
        return logits_list
Example #15
    def predict(self):
        """See `model.py` for documentation."""
        nclasses = self.nway
        num_cluster_steps = self.config.num_cluster_steps
        h_train, h_unlabel, h_test = self.get_encoded_inputs(
            self.x_train, self.x_unlabel, self.x_test)
        y_train = self.y_train
        protos = self._compute_protos(nclasses, h_train, y_train)

        # Distractor class has a zero vector as prototype.
        protos = concat([protos, tf.zeros_like(protos[:, 0:1, :])], 1)

        # Hard assignment for training images.
        prob_train = [None] * (nclasses + 1)
        for kk in range(nclasses):
            # [B, N, 1]
            prob_train[kk] = tf.expand_dims(
                tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
            prob_train[-1] = tf.zeros_like(prob_train[0])
        prob_train = concat(prob_train, 2)

        # Initialize cluster radii.
        radii = [None] * (nclasses + 1)
        y_train_shape = tf.shape(y_train)
        bsize = y_train_shape[0]
        for kk in range(nclasses):
            radii[kk] = tf.ones([bsize, 1]) * 1.0

        # Distractor class has a larger radius.
        if FLAGS.learn_radius:
            log_distractor_radius = tf.get_variable(
                "log_distractor_radius",
                shape=[],
                dtype=tf.float32,
                initializer=tf.constant_initializer(np.log(FLAGS.init_radius)))
            distractor_radius = tf.exp(log_distractor_radius)
        else:
            distractor_radius = FLAGS.init_radius
        distractor_radius = tf.cond(
            tf.shape(self._x_unlabel)[1] > 0, lambda: distractor_radius,
            lambda: 100000.0)
        # distractor_radius = tf.Print(distractor_radius, [distractor_radius])
        radii[-1] = tf.ones([bsize, 1]) * distractor_radius
        radii = concat(radii, 1)  # [B, K]

        h_all = concat([h_train, h_unlabel], 1)
        logits_list = []
        logits_list.append(compute_logits_radii(protos, h_test, radii))

        # Run clustering.
        for tt in range(num_cluster_steps):
            # Label assignment.
            prob_unlabel = assign_cluster_radii(protos, h_unlabel, radii)
            prob_all = concat([prob_train, prob_unlabel], 1)
            protos = update_cluster(h_all, prob_all)
            logits_list.append(compute_logits_radii(protos, h_test, radii))

        # Distractor evaluation.
        is_distractor = tf.equal(tf.argmax(prob_unlabel, axis=-1), nclasses)
        pred_non_distractor = 1.0 - tf.to_float(is_distractor)
        acc, recall, precision = eval_distractor(pred_non_distractor,
                                                 self.y_unlabel)
        self._non_distractor_acc = acc
        self._distractor_recall = recall
        self._distractor_precision = precision
        self._distractor_pred = 1.0 - tf.exp(prob_unlabel[:, :, -1])
        return logits_list
Example #16
def prototypical_clustering_learn_layer(nclasses,
                                        x_train,
                                        y_train,
                                        x_test,
                                        phi,
                                        num_cluster_steps,
                                        lambd=0.1,
                                        alpha=0.0):
    """Computes the prototypes, cluster centers, with additional clustering on
    the validation data.
  Args:
    x_train: [N, D], Train data.
    y_train: [N], Train class labels.
    x_test: [N, D], Test data.
    phi: Feature extractor function.
  Returns:
    logits: [N, K], Test prediction.
  """
    protos = [None] * nclasses
    x_all = concat([x_train, x_test], 0)
    h, wts = phi(x_all, reuse=None, is_training=True)
    num_x_train = tf.shape(x_train)[0]
    h_train = h[:num_x_train, :]
    h_test = h[num_x_train:, :]
    wts_keys = wts.keys()
    wts_tensors = [wts[kk] for kk in wts_keys]

    for kk in range(nclasses):
        ksel = tf.expand_dims(tf.cast(tf.equal(y_train, kk), h_train.dtype),
                              1)  # [N, 1]
        protos[kk] = tf.reduce_sum(h_train * ksel, [0],
                                   keep_dims=True)  # [1, D]
        protos[kk] /= tf.reduce_sum(ksel)

    protos = concat(protos, 0)  # [K, D]

    for tt in range(num_cluster_steps):
        # Compute new representation of the data.
        h, _ = phi(x_all,
                   reuse=True,
                   is_training=True,
                   ext_wts=dict(zip(wts_keys, wts_tensors)))
        h_train = h[:num_x_train, :]
        h_test = h[num_x_train:, :]
        all_data = concat([h_train, h_test], 0)

        # Label assignment.
        prob = assign_cluster(protos, all_data)

        # Prototype update.
        ## This is a vanilla version.
        ## Probably want to impose a constraint that each training example can only
        ## be in one cluster.
        protos = update_cluster(all_data, prob)

        if alpha > 0.0 and tt < num_cluster_steps - 1:
            # We can also use soft labels here.
            loss = lambd * sq_dist_loss(protos, all_data)
            # One gradient update towards fast weights.
            print(wts_tensors)
            [print(vv.name) for vv in wts_tensors]
            grads = tf.gradients(loss, wts_tensors, gate_gradients=True)
            [print(gg) for gg in grads]
            # Stop the gradient of the gradient.
            grads = [tf.stop_gradient(gg) for gg in grads]
            wts_tensors = [
                wt - alpha * gg for wt, gg in zip(wts_tensors, grads)
            ]

    # Be cautious here!!
    #uloss = sq_dist_loss(protos, all_data)
    uloss = 0.0
    logits = compute_logits(protos, h_test)
    return logits, uloss
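
sq_dist_loss is not defined in these snippets, and its value is replaced by uloss = 0.0 above. One plausible reading, given the k-means framing, is the mean squared distance from each point to its nearest prototype; a NumPy sketch of that reading (an assumption, not this repository's code):

import numpy as np

def sq_dist_loss_np(protos, data):
    """protos: [K, D], data: [N, D] -> scalar within-cluster loss."""
    d2 = np.sum((data[:, None, :] - protos[None, :, :]) ** 2, axis=-1)  # [N, K]
    return d2.min(axis=1).mean()  # mean distance to the nearest prototype
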