Example #1
    def classification_train_op(self):
        with tf.name_scope('Classification_train_op'):
            config = self.config
            classification_weight = config.classification_weight
            features = self.phi(self.training_data, update_batch_stats=False)

            logits = compute_logits(self._persistent_protos, features)
            # probs = tf.nn.softmax(logits/1000)
            # logits = tf.Print(logits, [self.training_labels],summarize=50)
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits / 200, labels=self.training_labels)
            loss = tf.reduce_mean(loss)
            opt = tf.train.AdamOptimizer(classification_weight *
                                         self.learn_rate,
                                         name='Classification-Optimizer')
            grads_and_vars = opt.compute_gradients(loss)
            train_op = opt.apply_gradients(grads_and_vars)
            self.adv_summaries.append(
                tf.summary.scalar("loss", loss, family='classification'))

            for gradient, variable in grads_and_vars:
                if gradient is None:
                    gradient = tf.constant(0.0)
                self.adv_summaries.append(
                    tf.summary.scalar("gradients/" + variable.name,
                                      l2_norm(gradient),
                                      family="classification"))
                # self.adv_summaries.append(tf.summary.scalar("variables/" + variable.name, l2_norm(variable), family="VARS"))
                # self.adv_summaries.append(tf.summary.histogram("gradients/" + variable.name, gradient, family="Grads"))

        return loss, train_op
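Every example on this page calls `compute_logits` to score embeddings against class prototypes, but the helper itself is not reproduced here. A minimal sketch, assuming the usual prototypical-network convention (logits are negative squared Euclidean distances, prototypes `[B, K, D]`, features `[B, N, D]`; the unbatched cases used in some examples follow the same broadcasting idea):

import tensorflow as tf  # TensorFlow 1.x, as in the examples

def compute_logits(protos, features):
    """Hypothetical sketch: negative squared Euclidean distance.

    protos:   [B, K, D] class prototypes
    features: [B, N, D] query embeddings
    returns:  [B, N, K] logits (larger means closer to the prototype)
    """
    protos = tf.expand_dims(protos, 1)      # [B, 1, K, D]
    features = tf.expand_dims(features, 2)  # [B, N, 1, D]
    return -tf.reduce_sum(tf.square(features - protos), axis=-1)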
Example #2
	def noisy_forward(self, data, noise=tf.constant(0.0), update_batch_stats=False, wts=None):
		if wts is None:
			wts = self.embedding_weights
		with tf.name_scope("forward"):
			encoded = self.phi(data+noise, update_batch_stats=update_batch_stats, ext_wts=wts)
			logits = compute_logits(self.protos, encoded)
		return logits
Example #3
    def predict(self):
        """See `model.py` for documentation."""
        h_train, h_test = self.get_encoded_inputs(self.x_train, self.x_test)
        y_train = self.y_train
        nclasses = self.nway
        protos = self._compute_protos(nclasses, h_train, y_train)
        logits = compute_logits(protos, h_test)
        return [logits]
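Several examples build prototypes with `self._compute_protos(nclasses, h_train, y_train)`, whose body is also not shown. A minimal sketch under the standard assumption that a prototype is the mean embedding of its class:

import tensorflow as tf  # TensorFlow 1.x

def compute_protos(nclasses, h_train, y_train):
    """Hypothetical sketch: per-class mean of the support embeddings.

    h_train: [B, N, D] support embeddings
    y_train: [B, N]    integer labels in [0, nclasses)
    returns: [B, K, D] prototypes, K = nclasses
    """
    onehot = tf.cast(tf.one_hot(y_train, nclasses), h_train.dtype)  # [B, N, K]
    counts = tf.reduce_sum(onehot, axis=1, keep_dims=True)          # [B, 1, K]
    counts = tf.maximum(counts, 1.0)                                # guard against empty classes
    protos = tf.matmul(onehot, h_train, transpose_a=True)           # [B, K, D] class sums
    return protos / tf.transpose(counts, [0, 2, 1])                 # divide by class sizes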
Example #4
    def noisy_forward(self,
                      data,
                      noise=tf.constant(0.0),
                      update_batch_stats=False):
        with tf.name_scope("forward"):
            encoded = self.phi(data + noise,
                               update_batch_stats=update_batch_stats)
            logits = compute_logits(self.protos, encoded)
        return logits
Example #5
    def predict(self):
        """See `model.py` for documentation."""
        super().predict()

        nclasses = self.nway
        num_cluster_steps = self.config.num_cluster_steps
        h_train, h_unlabel, h_test = self.encode(self.x_train, self.x_unlabel,
                                                 self.x_test)
        y_train = self.y_train
        protos = self._compute_protos(nclasses, h_train, y_train)
        logits = compute_logits(protos, h_test)

        # Hard assignment for training images.
        prob_train = [None] * nclasses
        for kk in range(nclasses):
            # [B, N, 1]
            prob_train[kk] = tf.expand_dims(
                tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
        prob_train = concat(prob_train, 2)

        h_all = concat([h_train, h_unlabel], 1)

        logits_list = []
        logits_list.append(compute_logits(protos, h_test))

        # Run clustering.
        for tt in range(num_cluster_steps):
            # Label assignment.
            prob_unlabel = assign_cluster(protos, h_unlabel)
            entropy = tf.reduce_sum(-prob_unlabel * tf.log(prob_unlabel), [2],
                                    keep_dims=True)
            prob_all = concat([prob_train, prob_unlabel], 1)
            prob_all = tf.stop_gradient(prob_all)
            protos = update_cluster(h_all, prob_all)
            # protos = tf.cond(
            #     tf.shape(self._x_unlabel)[1] > 0,
            #     lambda: update_cluster(h_all, prob_all), lambda: protos)
            logits_list.append(compute_logits(protos, h_test))

        self._unlabel_logits = compute_logits(self.protos, h_unlabel)[0]
        self._logits = logits_list
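The clustering loop above alternates `assign_cluster` (a soft E-step over the prototypes) and `update_cluster` (an M-step recomputing them). Those helpers are not listed on this page; a minimal soft k-means sketch consistent with how they are called:

import tensorflow as tf  # TensorFlow 1.x

def assign_cluster(protos, h):
    """Hypothetical E-step: soft responsibilities of points for prototypes.

    protos: [B, K, D], h: [B, N, D] -> [B, N, K]
    """
    logits = -tf.reduce_sum(
        tf.square(tf.expand_dims(h, 2) - tf.expand_dims(protos, 1)), axis=-1)
    return tf.nn.softmax(logits)  # softmax over the K prototypes

def update_cluster(h, prob):
    """Hypothetical M-step: responsibility-weighted mean of the points.

    h: [B, N, D], prob: [B, N, K] -> new prototypes [B, K, D]
    """
    norm = tf.maximum(tf.reduce_sum(prob, axis=1, keep_dims=True), 1e-8)  # [B, 1, K]
    return tf.matmul(prob / norm, h, transpose_a=True)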
Example #6
	def compute_output(self):
		if not self.is_training:
			config = self.config
			VAT_weight = 0.00005
			num_steps = 20

			weights = self.embedding_weights
			self.fast_weights = self.embedding_weights

			x = tf.reshape(self.x_train, [-1, config.height, config.width, config.num_channel])
			y = tf.squeeze(self.y_train_one_hot)

			for i in range(num_steps):
				loss = self.virtual_adversarial_loss(self.x_unlabel_flat, self._unlabel_logits, name="VAT-Inference")
				loss += self.virtual_adversarial_loss(x, y, name="VAT-Inference")
				ent_loss = entropy_y_x(compute_logits(self.protos, self.h_test))
				loss = ent_loss
				grads = tf.gradients(loss, list(weights.values()))
				grads = [tf.stop_gradient(grad) for grad in grads]
				gradients = dict(zip(weights.keys(), grads))

				self.fast_weights = {
					key: self.fast_weights[key] - VAT_weight * gradients[key]
					for key in weights.keys()}
				# self.vat_grads_and_vars =[]
				# for gradient, variable in vat_grads_and_vars:
				# 	if gradient is None:
				# 		gradient = tf.constant(0.0)
				# 	self.vat_grads_and_vars.append((gradient, variable))

				# with tf.control_dependencies([vat_train_op]):
				encoded_train, encoded_test = self.encode(self.x_train, self.x_test, update_batch_stats=False, ext_wts=self.fast_weights)
				protos = self._compute_protos(self.nway, encoded_train, self.y_train)

				self._logits = [compute_logits(protos, encoded_test)]

		# self._logits = tf.Print(self._logits, [tf.shape(self.x_unlabel_flat), tf.shape(self._unlabel_logits)])
		# self._logits = tf.Print(self._logits, [tf.shape(self.x_train), tf.shape(self.y_train_one_hot)])
		super().compute_output()
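`entropy_y_x` serves here (and in the next example) as an entropy-minimization objective on the model's own predictions. A minimal sketch, assuming it returns the mean conditional entropy of the softmax over the last axis of the logits:

import tensorflow as tf  # TensorFlow 1.x

def entropy_y_x(logits):
    """Hypothetical sketch: mean entropy of p(y|x) for logits shaped [..., K]."""
    p = tf.nn.softmax(logits)
    log_p = tf.nn.log_softmax(logits)
    return -tf.reduce_mean(tf.reduce_sum(p * log_p, axis=-1))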
Example #7
	def compute_output(self):
		if not self.is_training:
			config = self.config
			num_steps = config.num_steps
			step_size = config.inference_step_size

			# x = tf.reshape(self._h_test, [-1, config.height, config.width, config.num_channel])
			y = tf.squeeze(self.y_train_one_hot)
			# self._protos = tf.Print(self.protos, [tf.shape(self._unlabel_logits),
			# 																			tf.shape(compute_logits(self.protos, self.h_test)),
			# 																			entropy_y_x(tf.expand_dims(self._unlabel_logits, 0))])
			protos = self.protos
			h_train = self.h_train
			h_test = self.h_test
			h_unlabel = self.h_unlabel
			with tf.name_scope('Adaptation'):
				for i in range(num_steps):
					# vat_loss_lbl = self.virtual_adversarial_loss\
					# 	(tf.squeeze(h_train), tf.squeeze(compute_logits(self.protos, h_train)), name="VAT-Inference")
					# vat_loss_ulbl = self.virtual_adversarial_loss\
					# 	(h_unlabel, self._unlabel_logits, name="VAT-Inference")
					# vat_loss_test = self.virtual_adversarial_loss\
					# 	(tf.squeeze(h_test), tf.squeeze(self._logits), name="VAT-Inference")
					ent_loss_lbl_r = relative_entropy_y_x(compute_logits(self.protos, h_train)[0])
					ent_loss_ulbl_r = relative_entropy_y_x(self._unlabel_logits)
					ent_loss_test_r = relative_entropy_y_x(compute_logits(self.protos, h_test)[0])
					ent_loss_lbl = entropy_y_x(compute_logits(self.protos, h_train))
					ent_loss_ulbl = entropy_y_x(tf.expand_dims(self._unlabel_logits, 0))
					ent_loss_test = entropy_y_x(compute_logits(self.protos, h_test))
					# ent_c = entropy_y_x(
					# 	tf.concat([compute_logits(self.protos, tf.concat( [self.h_train, self.h_test], 1)) , tf.expand_dims(self._unlabel_logits,0)],1))

					# grads_vat_ulbl = tf.gradients(vat_loss_ulbl, self.protos)[0]
					# grads_vat_test = tf.gradients(vat_loss_test, self.protos)[0]
					# grads_ent_ulbl = tf.gradients(ent_loss_ulbl, self.protos)[0]
					# grads_ent_test = tf.gradients(ent_loss_test, self.protos)[0]
					# grads_ent_lbl = tf.gradients(ent_loss_lbl, self.protos)[0]

					loss = ent_loss_ulbl
					grads = tf.gradients(loss, self.protos, name='GRADS')[0]
					# grads = tf.Print(grads, [ent_loss_lbl, ent_loss_ulbl, ent_loss_test], name="PRINT")
					self._protos = self.protos - step_size * grads
					self._unlabel_logits = compute_logits(self.protos, h_unlabel)
					self._logits = [compute_logits(self.protos, h_test)]

				# self.inference_summaries.append(tf.summary.scalar('ent-ulbl-loss', l2_norm(ent_loss_ulbl), family='loss_inference' ))
				# self.inference_summaries.append(tf.summary.scalar('vat-ulbl-grad', l2_norm(grads_vat_ulbl), family='inference'))
				# self.inference_summaries.append(tf.summary.scalar('vat-test-grad', l2_norm(grads_vat_test), family='inference'))
				# self.inference_summaries.append(tf.summary.scalar('ent-ulbl-grad', l2_norm(grads_ent_ulbl), family='inference'))
				# self.inference_summaries.append(tf.summary.scalar('ent-test-grad', l2_norm(grads_ent_test), family='inference'))
				# self.inference_summaries.append(tf.summary.scalar('ent-lbl-grad', l2_norm(grads_ent_lbl), family='inference'))

			# self._logits = tf.Print(self._logits,[tf.shape(self._logits)], "\n-------------------------------------\n")
			# self._logits = tf.Print(self._logits, [tf.shape(self.x_unlabel), tf.shape(self.y_train_one_hot)])
		super().compute_output()
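Stripped of the commented-out variants, Example #7 is a transductive refinement loop: take a few gradient steps on the prototypes to reduce the entropy of the unlabeled predictions. A self-contained sketch of the same pattern, reusing the hypothetical `compute_logits` and `entropy_y_x` above with a made-up step size:

import tensorflow as tf  # TensorFlow 1.x

def refine_protos(protos, h_unlabel, num_steps=5, step_size=0.1):
    """Hypothetical sketch: entropy-minimization refinement of prototypes.

    protos: [B, K, D], h_unlabel: [B, N, D] -> refined prototypes [B, K, D]
    """
    for _ in range(num_steps):
        logits = compute_logits(protos, h_unlabel)  # [B, N, K]
        loss = entropy_y_x(logits)                  # scalar entropy objective
        grads = tf.gradients(loss, protos)[0]       # gradient w.r.t. the prototypes
        protos = protos - step_size * grads         # one descent step
    return protos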
Example #8
    def get_train_op(self, logits, y_test):
        loss, train_op = BasicModelVAT.get_train_op(self, logits, y_test)
        config = self.config
        ENT_weight = config.ENT_weight
        VAT_ENT_step_size = config.VAT_ENT_step_size

        logits = self._unlabel_logits

        s = tf.shape(logits)
        s = s[0]
        p = tf.stop_gradient(self.h_unlabel)
        affinity_matrix = compute_logits(
            p, p) - (tf.eye(s, dtype=tf.float32) * 1000.0)
        # logits = tf.Print(logits, [tf.shape(point_logits)])

        ENT_loss = walking_penalty(logits, affinity_matrix)
        loss += ENT_weight * ENT_loss

        ENT_opt = tf.train.AdamOptimizer(VAT_ENT_step_size * self.learn_rate,
                                         name="Entropy-optimizer")
        ENT_grads_and_vars = ENT_opt.compute_gradients(loss)
        train_op = ENT_opt.apply_gradients(ENT_grads_and_vars)

        for gradient, variable in ENT_grads_and_vars:
            if gradient is None:
                gradient = tf.constant(0.0)
            self.adv_summaries.append(
                tf.summary.scalar("ENT/gradients/" + variable.name,
                                  l2_norm(gradient),
                                  family="Grads"))
            self.adv_summaries.append(
                tf.summary.histogram("ENT/gradients/" + variable.name,
                                     gradient,
                                     family="Grads"))

        self.summaries.append(tf.summary.scalar('entropy loss', ENT_loss))

        return loss, train_op
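Examples #1 and #8 log `l2_norm(...)` of each gradient to TensorBoard; that helper is not shown. A one-line sketch, assuming it is the global L2 norm of a tensor (note also that Example #8 subtracts 1000 on the diagonal of the affinity matrix so a point is never treated as its own nearest neighbour):

import tensorflow as tf  # TensorFlow 1.x

def l2_norm(x):
    """Hypothetical sketch: global L2 norm of a tensor of any shape."""
    return tf.sqrt(tf.reduce_sum(tf.square(x)))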
Example #9
    def get_SSL_train_op(self):
        config = self.config
        VAT_weight = config.VAT_weight
        ENT_weight = config.ENT_weight
        features = self.phi(self.unlabeled_training_data,
                            update_batch_stats=False)
        logits = compute_logits(self.protos, features)

        VAT_loss = self.virtual_adversarial_loss(self.unlabeled_training_data,
                                                 logits)
        ENT_loss = entropy_y_x(tf.expand_dims(logits, 0))
        # probs = tf.nn.softmax(logits)
        # ENT_loss = tf.Print(ENT_loss, [])
        VAT_opt = tf.train.AdamOptimizer(VAT_weight * self.learn_rate,
                                         name='Classification-VAT-Optimizer')
        VAT_grads_and_vars = VAT_opt.compute_gradients(VAT_loss)
        VAT_train_op = VAT_opt.apply_gradients(VAT_grads_and_vars)

        self.adv_summaries.append(
            tf.summary.scalar("VAT-loss", VAT_loss, family='classification'))
        self.adv_summaries.append(
            tf.summary.scalar("ENT-loss", ENT_loss, family='classification'))

        return VAT_loss, VAT_train_op
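`virtual_adversarial_loss` is used throughout but never listed. A compact sketch of standard VAT (Miyato et al.) built on the `noisy_forward` methods from Examples #2 and #4; this illustrates the technique under assumed hyperparameters and is not the original implementation:

import tensorflow as tf  # TensorFlow 1.x

def kl_from_logits(p_logits, q_logits):
    """Mean KL(p || q) computed from two sets of logits."""
    p = tf.nn.softmax(p_logits)
    return tf.reduce_mean(tf.reduce_sum(
        p * (tf.nn.log_softmax(p_logits) - tf.nn.log_softmax(q_logits)), axis=-1))

def virtual_adversarial_loss(model, x, logits, epsilon=1.0, xi=1e-6):
    """Hypothetical sketch of VAT: estimate the perturbation that changes the
    prediction the most, then penalize the change it causes."""
    logits = tf.stop_gradient(logits)
    d = tf.random_normal(shape=tf.shape(x))
    d = xi * d / (tf.norm(d) + 1e-12)               # tiny probe direction
    dist = kl_from_logits(logits, model.noisy_forward(x, noise=d))
    d = tf.stop_gradient(tf.gradients(dist, d)[0])  # one power-iteration step
    r_adv = epsilon * d / (tf.norm(d) + 1e-12)      # adversarial perturbation
    return kl_from_logits(logits, model.noisy_forward(x, noise=r_adv))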
Example #10
    def predict(self):
        """See `model.py` for documentation."""
        nclasses = self.nway
        num_cluster_steps = self.config.num_cluster_steps
        h_train, h_unlabel, h_test = self.encode(self.x_train, self.x_unlabel,
                                                 self.x_test)
        y_train = self.y_train
        protos = self._compute_protos(nclasses, h_train, y_train)
        logits_list = []
        logits_list.append(compute_logits(protos, h_test))

        # Hard assignment for training images.
        prob_train = [None] * (nclasses)
        for kk in range(nclasses):
            # [B, N, 1]
            prob_train[kk] = tf.expand_dims(
                tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
        prob_train = concat(prob_train, 2)

        y_train_shape = tf.shape(y_train)
        bsize = y_train_shape[0]

        h_all = concat([h_train, h_unlabel], 1)
        mask = None

        # Calculate pairwise distances.
        protos_1 = tf.expand_dims(protos, 2)
        protos_2 = tf.expand_dims(h_unlabel, 1)
        pair_dist = tf.reduce_sum((protos_1 - protos_2)**2, [3])  # [B, K, N]
        mean_dist = tf.reduce_mean(pair_dist, [2], keep_dims=True)
        pair_dist_normalize = pair_dist / mean_dist
        min_dist = tf.reduce_min(pair_dist_normalize, [2],
                                 keep_dims=True)  # [B, K, 1]
        max_dist = tf.reduce_max(pair_dist_normalize, [2], keep_dims=True)
        mean_dist, var_dist = tf.nn.moments(pair_dist_normalize, [2],
                                            keep_dims=True)
        mean_dist += tf.to_float(tf.equal(mean_dist, 0.0))
        var_dist += tf.to_float(tf.equal(var_dist, 0.0))
        skew = tf.reduce_mean(
            ((pair_dist_normalize - mean_dist)**3) / (tf.sqrt(var_dist)**3),
            [2],
            keep_dims=True)
        kurt = tf.reduce_mean(
            ((pair_dist_normalize - mean_dist)**4) / (var_dist**2) - 3, [2],
            keep_dims=True)

        n_features = 5
        n_out = 3

        dist_features = tf.reshape(
            concat([min_dist, max_dist, var_dist, skew, kurt], 2),
            [-1, n_features])  # [BK, 5]
        dist_features = tf.stop_gradient(dist_features)

        hdim = [n_features, 20, n_out]
        act_fn = [tf.nn.tanh, None]
        thresh = mlp(dist_features,
                     hdim,
                     is_training=True,
                     act_fn=act_fn,
                     dtype=tf.float32,
                     add_bias=True,
                     wd=None,
                     init_std=[0.01, 0.01],
                     init_method=None,
                     scope="dist_mlp",
                     dropout=None,
                     trainable=True)
        scale = tf.exp(thresh[:, 2])
        bias_start = tf.exp(thresh[:, 0])
        bias_add = thresh[:, 1]
        bias_start = tf.reshape(bias_start, [bsize, 1, -1])  #[B, 1, K]
        bias_add = tf.reshape(bias_add, [bsize, 1, -1])

        self._scale = scale
        self._bias_start = bias_start
        self._bias_add = bias_add

        # Run clustering.
        for tt in range(num_cluster_steps):
            protos_1 = tf.expand_dims(protos, 2)
            protos_2 = tf.expand_dims(h_unlabel, 1)
            pair_dist = tf.reduce_sum((protos_1 - protos_2)**2,
                                      [3])  # [B, K, N]
            m_dist = tf.reduce_mean(pair_dist, [2])  # [B, K]
            m_dist_1 = tf.expand_dims(m_dist, 1)  # [B, 1, K]
            m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0))
            # Label assignment.
            if num_cluster_steps > 1:
                bias_tt = bias_start + (
                    tt / float(num_cluster_steps - 1)) * bias_add
            else:
                bias_tt = bias_start

            negdist = compute_logits(protos, h_unlabel)
            mask = tf.sigmoid((negdist / m_dist_1 + bias_tt) * scale)
            prob_unlabel, mask = assign_cluster_soft_mask(
                protos, h_unlabel, mask)
            prob_all = concat([prob_train, prob_unlabel * mask], 1)
            # No update if 0 unlabel.
            protos = tf.cond(
                tf.shape(self._x_unlabel)[1] > 0,
                lambda: update_cluster(h_all, prob_all), lambda: protos)
            logits_list.append(compute_logits(protos, h_test))

        # Distractor evaluation.
        if mask is not None:
            max_mask = tf.reduce_max(mask, [2])
            mean_mask = tf.reduce_mean(max_mask)
            pred_non_distractor = tf.to_float(max_mask > mean_mask)
            acc, recall, precision = eval_distractor(pred_non_distractor,
                                                     self.y_unlabel)
            self._non_distractor_acc = acc
            self._distractor_recall = recall
            self._distractor_precision = precision
            self._distractor_pred = max_mask
        return logits_list
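The distractor evaluation calls `eval_distractor(pred_non_distractor, y_unlabel)`, also not listed. A plausible sketch, assuming both arguments are {0, 1} indicators where 1 marks a genuine (non-distractor) unlabeled item and the function reports accuracy plus recall/precision of the distractor decision:

import tensorflow as tf  # TensorFlow 1.x

def eval_distractor(pred_non_distractor, y_unlabel):
    """Hypothetical sketch of the distractor metrics."""
    pred = tf.cast(pred_non_distractor, tf.float32)
    truth = tf.cast(y_unlabel, tf.float32)
    acc = tf.reduce_mean(tf.cast(tf.equal(pred, truth), tf.float32))
    # Treat "distractor" (label 0) as the positive class.
    tp = tf.reduce_sum((1.0 - pred) * (1.0 - truth))
    recall = tp / tf.maximum(tf.reduce_sum(1.0 - truth), 1.0)
    precision = tp / tf.maximum(tf.reduce_sum(1.0 - pred), 1.0)
    return acc, recall, precision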
Example #11
    def noisy_forward(self, data, noise, update_batch_stats=False):
        with tf.name_scope("forward"):
            encoded = self.h_unlabel
            logits = compute_logits(self.protos + noise, encoded)
        return logits
Example #12
    def predict(self):
        """See `model.py` for documentation."""
        super().predict()
        nclasses = self.nway
        num_cluster_steps = self.config.num_cluster_steps
        h_train, h_unlabel, h_test = self.encode(self.x_train, self.x_unlabel,
                                                 self.x_test)
        y_train = self.y_train
        protos = self._compute_protos(nclasses, h_train, y_train)

        # Distractor class has a zero vector as prototype.
        protos = concat([protos, tf.zeros_like(protos[:, 0:1, :])], 1)

        # Hard assignment for training images.
        prob_train = [None] * (nclasses + 1)
        for kk in range(nclasses):
            # [B, N, 1]
            prob_train[kk] = tf.expand_dims(
                tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
            prob_train[-1] = tf.zeros_like(prob_train[0])
        prob_train = concat(prob_train, 2)

        # Initialize cluster radii.
        radii = [None] * (nclasses + 1)
        y_train_shape = tf.shape(y_train)
        bsize = y_train_shape[0]
        for kk in range(nclasses):
            radii[kk] = tf.ones([bsize, 1]) * 1.0

        # Distractor class has a larger radius.
        if FLAGS.learn_radius:
            log_distractor_radius = tf.get_variable(
                "log_distractor_radius",
                shape=[],
                dtype=tf.float32,
                initializer=tf.constant_initializer(np.log(FLAGS.init_radius)))
            distractor_radius = tf.exp(log_distractor_radius)
        else:
            distractor_radius = FLAGS.init_radius
        distractor_radius = tf.cond(
            tf.shape(self._x_unlabel)[1] > 0, lambda: distractor_radius,
            lambda: 100000.0)
        # distractor_radius = tf.Print(distractor_radius, [distractor_radius])
        radii[-1] = tf.ones([bsize, 1]) * distractor_radius
        radii = concat(radii, 1)  # [B, K]

        self.radii = radii

        h_all = concat([h_train, h_unlabel], 1)
        logits_list = []
        logits_list.append(compute_logits_radii(protos, h_test, radii))

        # Run clustering.
        for tt in range(num_cluster_steps):
            # Label assignment.
            prob_unlabel = assign_cluster_radii(protos, h_unlabel, radii)
            prob_all = concat([prob_train, prob_unlabel], 1)
            protos = update_cluster(h_all, prob_all)
            logits_list.append(compute_logits_radii(protos, h_test, radii))

        # Distractor evaluation.
        is_distractor = tf.equal(tf.argmax(prob_unlabel, axis=-1), nclasses)
        pred_non_distractor = 1.0 - tf.to_float(is_distractor)
        acc, recall, precision = eval_distractor(pred_non_distractor,
                                                 self.y_unlabel)
        self._non_distractor_acc = acc
        self._distractor_recall = recall
        self._distractor_precision = precision
        self._distractor_pred = 1.0 - tf.exp(prob_unlabel[:, :, -1])

        self._logits = logits_list
        if not self.is_training:
            self._logits = [compute_logits(protos[:, 0:5], h_test)]
        protos = concat(
            [self.protos, tf.zeros_like(self.protos[:, 0:1, :])], 1)
        radii = tf.stop_gradient(self.radii)
        self._unlabel_logits = compute_logits_radii(protos, h_unlabel,
                                                    radii)[0]
Example #13
    def predict(self):
        super(RefineModel, self).predict()
        with tf.name_scope('Predict/VAT'):
            self._unlabel_logits = compute_logits(self._ssl_protos,
                                                  self.h_unlabel)
Example #14
  def __init__(self,
               config,
               x,
               y,
               x_b,
               y_b,
               x_b_v,
               y_b_v,
               num_classes_a,
               num_classes_b,
               is_training=True,
               ext_wts=None,
               y_sel=None,
               w_class_a=None,
               b_class_a=None,
               nshot=None):
    self._config = config
    self._is_training = is_training
    self._num_classes_a = num_classes_a
    self._num_classes_b = num_classes_b

    if config.backbone_class == 'resnet_backbone':
      bb_config = config.resnet_config
    else:
      assert False, 'Not supported'
    opt_config = config.optimizer_config
    proto_config = config.protonet_config
    transfer_config = config.transfer_config

    self._backbone = get_model(config.backbone_class, bb_config)
    self._inputs = x
    self._labels = y
    # if opt_config.num_gpu > 1:
    #   self._labels_all = allgather(self._labels)
    # else:
    self._labels_all = self._labels
    self._inputs_b = x_b
    self._labels_b = y_b
    self._inputs_b_v = x_b_v
    self._labels_b_v = y_b_v
    # if opt_config.num_gpu > 1:
    #   self._labels_b_v_all = allgather(self._labels_b_v)
    # else:
    self._labels_b_v_all = self._labels_b_v
    self._y_sel = y_sel
    self._mask = tf.placeholder(tf.bool, [], name='mask')

    # global_step = tf.get_variable(
    #     'global_step', shape=[], dtype=tf.int64, trainable=False)
    global_step = tf.contrib.framework.get_or_create_global_step()
    self._global_step = global_step
    log.info('LR decay steps {}'.format(opt_config.lr_decay_steps))
    log.info('LR list {}'.format(opt_config.lr_list))
    learn_rate = tf.train.piecewise_constant(
        global_step, list(
            np.array(opt_config.lr_decay_steps).astype(np.int64)),
        list(opt_config.lr_list))
    self._learn_rate = learn_rate

    opt = self.get_optimizer(opt_config.optimizer, learn_rate)
    # if opt_config.num_gpu > 1:
    #   opt = hvd.DistributedOptimizer(opt)

    with tf.name_scope('TaskA'):
      h_a = self.backbone(x, is_training=is_training, ext_wts=ext_wts)
      self._h_a = h_a

    # Apply BN ops.
    bn_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.name_scope('TaskB'):
      x_b_all = tf.concat([x_b, x_b_v], axis=0)
      if ext_wts is not None:
        h_b_all = self.backbone(
            x_b_all, is_training=is_training, reuse=True, ext_wts=ext_wts)
      else:
        h_b_all = self.backbone(x_b_all, is_training=is_training, reuse=True)

    with tf.name_scope('TaskA'):
      # Calculates hidden activation size.
      h_shape = h_a.get_shape()
      h_size = 1
      for ss in h_shape[1:]:
        h_size *= int(ss)

      if w_class_a is None:
        if ext_wts is not None:
          w_class_a = weight_variable(
              [h_size, num_classes_a],
              init_method='numpy',
              dtype=tf.float32,
              init_param={'val': np.transpose(ext_wts['w_class_a'])},
              wd=config.wd,
              name='w_class_a')
          b_class_a = weight_variable([],
                                      init_method='numpy',
                                      dtype=tf.float32,
                                      init_param={'val': ext_wts['b_class_a']},
                                      wd=0e0,
                                      name='b_class_a')
        else:
          w_class_a = weight_variable([h_size, num_classes_a],
                                      init_method='truncated_normal',
                                      dtype=tf.float32,
                                      init_param={'stddev': 0.01},
                                      wd=bb_config.wd,
                                      name='w_class_a')
          b_class_a = weight_variable([num_classes_a],
                                      init_method='constant',
                                      init_param={'val': 0.0},
                                      name='b_class_a')
        self._w_class_a_orig = w_class_a
        self._b_class_a_orig = b_class_a
      else:
        assert b_class_a is not None
        w_class_a_orig = weight_variable([h_size, num_classes_a],
                                         init_method='truncated_normal',
                                         dtype=tf.float32,
                                         init_param={'stddev': 0.01},
                                         wd=bb_config.wd,
                                         name='w_class_a')
        b_class_a_orig = weight_variable([num_classes_a],
                                         init_method='constant',
                                         init_param={'val': 0.0},
                                         name='b_class_a')
        self._w_class_a_orig = w_class_a_orig
        self._b_class_a_orig = b_class_a_orig

      self._w_class_a = w_class_a
      self._b_class_a = b_class_a
      num_classes_a_dyn = tf.cast(tf.shape(b_class_a)[0], tf.int64)
      num_classes_a_dyn32 = tf.shape(b_class_a)[0]

      if proto_config.cosine_a:
        if proto_config.cosine_tau:
          if ext_wts is None:
            init_val = 10.0
          else:
            init_val = ext_wts['tau'][0]
          tau = weight_variable([],
                                init_method='constant',
                                init_param={'val': init_val},
                                name='tau')
        else:
          tau = tf.constant(1.0)
        w_class_a_norm = self._normalize(w_class_a, 0)
        h_a_norm = self._normalize(h_a, 1)
        dot = tf.matmul(h_a_norm, w_class_a_norm)
        if ext_wts is not None:
          dot += b_class_a
        logits_a = tau * dot
      else:
        logits_a = compute_euc(tf.transpose(w_class_a), h_a)
      self._prediction_a = logits_a
      # if opt_config.num_gpu > 1:
      #   self._prediction_a_all = allgather(self._prediction_a)
      # else:
      self._prediction_a_all = self._prediction_a

      xent_a = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits_a, labels=y)
      cost_a = tf.reduce_mean(xent_a, name='xent')
      self._cost_a = cost_a
      cost_a += self._decay()
      correct_a = tf.equal(tf.argmax(logits_a, axis=1), y)
      self._correct_a = correct_a
      self._acc_a = tf.reduce_mean(tf.cast(correct_a, cost_a.dtype))

    with tf.name_scope('TaskB'):
      h_b = h_b_all[:tf.shape(x_b)[0]]
      h_b_v = h_b_all[tf.shape(x_b)[0]:]

      # Add new axes for the `batch` dimension.
      h_b_ = tf.expand_dims(h_b, 0)
      h_b_v_ = tf.expand_dims(h_b_v, 0)
      y_b_ = tf.expand_dims(y_b, 0)
      y_b_v_ = tf.expand_dims(y_b_v, 0)

      if transfer_config.old_and_new:
        protos_b = self._compute_protos(num_classes_b, h_b_,
                                        y_b_ - num_classes_a)
      else:
        protos_b = self._compute_protos(num_classes_b, h_b_, y_b_)

      w_class_a_ = tf.expand_dims(tf.transpose(w_class_a), 0)
      if proto_config.protos_phi:
        w_p1 = weight_variable([h_size],
                               init_method='constant',
                               dtype=tf.float32,
                               init_param={'val': 1.0},
                               wd=bb_config.wd,
                               name='w_p1')
      if proto_config.cosine_attention:
        w_q = weight_variable([h_size, h_size],
                              init_method='truncated_normal',
                              dtype=tf.float32,
                              init_param={'stddev': 0.1},
                              wd=bb_config.wd,
                              name='w_q')
        k_b = weight_variable([num_classes_a, h_size],
                              init_method='truncated_normal',
                              dtype=tf.float32,
                              init_param={'stddev': 0.1},
                              wd=bb_config.wd,
                              name='k_b')
        tau_q = weight_variable([],
                                init_method='constant',
                                init_param={'val': 10.0},
                                name='tau_q')
        if transfer_config.old_and_new:
          w_class_b = self._compute_protos_attend_fix(
              num_classes_b, h_b_, y_b_ - num_classes_a_dyn, w_q, tau_q, k_b,
              self._w_class_a_orig)
        else:
          w_class_b = self._compute_protos_attend_fix(
              num_classes_b, h_b_, y_b_, w_q, tau_q, k_b, self._w_class_a_orig)
        assert proto_config.protos_phi
        w_p2 = weight_variable([h_size],
                               init_method='constant',
                               dtype=tf.float32,
                               init_param={'val': 1.0},
                               wd=bb_config.wd,
                               name='w_p2')
        self._k_b = tf.expand_dims(w_p2, 1) * self._w_class_a_orig
        self._k_b2 = k_b
        self.bias = w_class_b
        self.new_protos = w_p1 * protos_b
        self.new_bias = w_p2 * w_class_b
        w_class_b = w_p1 * protos_b + w_p2 * w_class_b
        self.protos = protos_b
        self.w_class_b_final = w_class_b
      else:
        w_class_b = protos_b
        if proto_config.protos_phi:
          w_class_b = w_p1 * w_class_b

      self._w_class_b = w_class_b

      if transfer_config.old_and_new:
        w_class_all = tf.concat([w_class_a_, w_class_b], axis=1)
      else:
        w_class_all = w_class_b

      if proto_config.cosine_softmax_tau:
        tau_b = weight_variable([],
                                init_method='constant',
                                init_param={'val': 10.0},
                                name='tau_b')
      else:
        tau_b = tf.constant(1.0)

      if proto_config.similarity == 'euclidean':
        logits_b_v = compute_logits(w_class_all, h_b_v_)
      elif proto_config.similarity == 'cosine':
        logits_b_v = tau_b * compute_logits_cosine(w_class_all, h_b_v_)
      else:
        raise ValueError('Unknown similarity')
      self._logits_b_v = logits_b_v
      self._prediction_b = logits_b_v[0]
      # if opt_config.num_gpu > 1:
      #   self._prediction_b_all = allgather(self._prediction_b)
      # else:
      self._prediction_b_all = self._prediction_b

      # Mask out the old classes.
      def mask_fn():
        bin_mask = tf.expand_dims(
            tf.reduce_sum(
                tf.one_hot(y_sel, num_classes_a + num_classes_b),
                0,
                keep_dims=True), 0)
        logits_b_v_m = logits_b_v * (1.0 - bin_mask)
        logits_b_v_m -= bin_mask * 100.0
        return logits_b_v_m

      # if transfer_config.old_and_new:
      #   logits_b_v = tf.cond(self._mask, mask_fn, lambda: logits_b_v)
      xent_b_v = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits_b_v, labels=y_b_v_)
      cost_b = tf.reduce_mean(xent_b_v, name='xent')
      self._cost_b = cost_b

    if transfer_config.old_and_new:
      total_cost = cost_b
    else:
      total_cost = (transfer_config.cost_a_ratio * cost_a +
                    transfer_config.cost_b_ratio * cost_b)
    self._total_cost = total_cost

    if not transfer_config.meta_only:
      # assert False, 'let us go for pretrained model first'
      var_list = tf.trainable_variables()
      var_list = list(filter(lambda x: 'phi' in x.name, var_list))
      layers = self.config.transfer_config.meta_layers
      if layers == "all":
        pass
      elif layers == "4":
        keywords = ['TaskB', 'unit_4_']
        filter_fn = lambda x: any([kw in x.name for kw in keywords])
        var_list = list(filter(filter_fn, var_list))
      else:
        raise ValueError('Unknown finetune layers {}'.format(layers))
      [log.info('Slow weights {}'.format(v.name)) for v in var_list]
    else:
      var_list = []

    if proto_config.cosine_softmax_tau:
      var_list += [tau_b]

    if proto_config.cosine_attention:
      var_list += [w_q, tau_q, k_b, w_p2]

    if proto_config.protos_phi:
      var_list += [w_p1]

    if transfer_config.train_wclass_a:
      if proto_config.similarity == 'euclidean':
        var_list += [w_class_a, b_class_a]
      elif proto_config.similarity == 'cosine':
        var_list += [w_class_a]

    if is_training:
      grads_and_vars = opt.compute_gradients(total_cost, var_list)
      with tf.control_dependencies(bn_ops):
        [log.info('BN op {}'.format(op.name)) for op in bn_ops]
        train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)

      grads_and_vars_b = opt.compute_gradients(cost_b, var_list)
      with tf.control_dependencies(bn_ops):
        train_op_b = opt.apply_gradients(
            grads_and_vars_b, global_step=global_step)

      with tf.control_dependencies(bn_ops):
        train_op_a = opt.minimize(cost_a, global_step=global_step)
      self._train_op = train_op
      self._train_op_a = train_op_a
      self._train_op_b = train_op_b
    self._initializer = tf.global_variables_initializer()
    self._w_class_a = w_class_a
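Example #14 switches between Euclidean and cosine logits via `proto_config.similarity`. `compute_logits_cosine` is not shown on this page; a minimal sketch, assuming it is the cosine similarity between L2-normalized features and prototypes:

import tensorflow as tf  # TensorFlow 1.x

def compute_logits_cosine(protos, features):
    """Hypothetical sketch: cosine similarity logits.

    protos: [B, K, D], features: [B, N, D] -> [B, N, K]
    """
    def normalize(x):
        return x / tf.sqrt(tf.reduce_sum(tf.square(x), axis=-1, keep_dims=True) + 1e-10)
    return tf.matmul(normalize(features), normalize(protos), transpose_b=True)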
Example #15
    def predict(self):
        """See `model.py` for documentation."""
        with tf.name_scope('Predict'):
            self.init_episode_classifier()
            logits = compute_logits(self.protos, self.h_test)
            self._logits = [logits]