def classification_train_op(self):
  with tf.name_scope('Classification_train_op'):
    config = self.config
    classification_weight = config.classification_weight
    features = self.phi(self.training_data, update_batch_stats=False)
    logits = compute_logits(self._persistent_protos, features)
    # Temperature-scaled softmax cross-entropy against the persistent
    # prototypes.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits / 200, labels=self.training_labels)
    loss = tf.reduce_mean(loss)
    opt = tf.train.AdamOptimizer(
        classification_weight * self.learn_rate,
        name='Classification-Optimizer')
    grads_and_vars = opt.compute_gradients(loss)
    train_op = opt.apply_gradients(grads_and_vars)
    self.adv_summaries.append(
        tf.summary.scalar("loss", loss, family='classification'))
    for gradient, variable in grads_and_vars:
      if gradient is None:
        gradient = tf.constant(0.0)
      self.adv_summaries.append(
          tf.summary.scalar(
              "gradients/" + variable.name,
              l2_norm(gradient),
              family="classification"))
    return loss, train_op
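# Usage sketch (illustrative only): how the loss/op pair above would be
# driven in a TF1 session loop. `model`, `batch_x`, and `batch_y` are
# hypothetical names, not part of this module.
#
# loss, train_op = model.classification_train_op()
# with tf.Session() as sess:
#   sess.run(tf.global_variables_initializer())
#   for step in range(num_steps):
#     loss_val, _ = sess.run(
#         [loss, train_op],
#         feed_dict={model.training_data: batch_x,
#                    model.training_labels: batch_y})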
def noisy_forward(self,
                  data,
                  noise=tf.constant(0.0),
                  update_batch_stats=False,
                  wts=None):
  # Forward pass through the embedding with an additive input perturbation;
  # `wts` allows substituting external (e.g. fast) weights.
  if wts is None:
    wts = self.embedding_weights
  with tf.name_scope("forward"):
    encoded = self.phi(
        data + noise, update_batch_stats=update_batch_stats, ext_wts=wts)
    logits = compute_logits(self.protos, encoded)
  return logits
def predict(self):
  """See `model.py` for documentation."""
  h_train, h_test = self.get_encoded_inputs(self.x_train, self.x_test)
  y_train = self.y_train
  nclasses = self.nway
  protos = self._compute_protos(nclasses, h_train, y_train)
  logits = compute_logits(protos, h_test)
  return [logits]
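# Reference sketch of the prototypical-network helpers used throughout this
# file (an assumption about their semantics; the real `_compute_protos` and
# `compute_logits` are defined elsewhere in the repo). Prototypes are
# per-class means of support embeddings; logits are negative squared
# Euclidean distances to each prototype.
#
# def _compute_protos_sketch(nclasses, h_train, y_train):
#   """h_train: [B, N, D], y_train: [B, N] -> protos: [B, K, D]."""
#   protos = []
#   for kk in range(nclasses):
#     # [B, N, 1] indicator of membership in class kk.
#     mask = tf.expand_dims(tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
#     protos.append(
#         tf.reduce_sum(h_train * mask, 1) /
#         tf.maximum(tf.reduce_sum(mask, 1), 1.0))
#   return tf.stack(protos, axis=1)
#
# def _compute_logits_sketch(protos, h):
#   """protos: [B, K, D], h: [B, N, D] -> logits: [B, N, K]."""
#   diff = tf.expand_dims(h, 2) - tf.expand_dims(protos, 1)  # [B, N, K, D]
#   return -tf.reduce_sum(tf.square(diff), 3)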
def noisy_forward(self, data, noise=tf.constant(0.0), update_batch_stats=False):
  with tf.name_scope("forward"):
    encoded = self.phi(data + noise, update_batch_stats=update_batch_stats)
    logits = compute_logits(self.protos, encoded)
  return logits
def predict(self):
  """See `model.py` for documentation."""
  super().predict()
  nclasses = self.nway
  num_cluster_steps = self.config.num_cluster_steps
  h_train, h_unlabel, h_test = self.encode(self.x_train, self.x_unlabel,
                                           self.x_test)
  y_train = self.y_train
  protos = self._compute_protos(nclasses, h_train, y_train)
  # Hard assignment for training images.
  prob_train = [None] * nclasses
  for kk in range(nclasses):
    # [B, N, 1]
    prob_train[kk] = tf.expand_dims(
        tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
  prob_train = concat(prob_train, 2)
  h_all = concat([h_train, h_unlabel], 1)
  logits_list = []
  logits_list.append(compute_logits(protos, h_test))
  # Run soft k-means refinement.
  for tt in range(num_cluster_steps):
    # Label assignment for the unlabeled points.
    prob_unlabel = assign_cluster(protos, h_unlabel)
    # Per-point assignment entropy (currently unused).
    entropy = tf.reduce_sum(
        -prob_unlabel * tf.log(prob_unlabel), [2], keep_dims=True)
    prob_all = concat([prob_train, prob_unlabel], 1)
    prob_all = tf.stop_gradient(prob_all)
    protos = update_cluster(h_all, prob_all)
    logits_list.append(compute_logits(protos, h_test))
  self._unlabel_logits = compute_logits(self.protos, h_unlabel)[0]
  self._logits = logits_list
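# Assumed semantics of the clustering helpers above (a sketch; the actual
# `assign_cluster` / `update_cluster` live elsewhere). One soft k-means
# step: soft-assign points via a softmax over negative squared distances,
# then recompute prototypes as assignment-weighted means.
#
# def _assign_cluster_sketch(protos, h):
#   # protos: [B, K, D], h: [B, N, D] -> prob: [B, N, K]
#   return tf.nn.softmax(_compute_logits_sketch(protos, h))
#
# def _update_cluster_sketch(h, prob):
#   # h: [B, N, D], prob: [B, N, K] -> new protos: [B, K, D]
#   weighted_sum = tf.einsum('bnk,bnd->bkd', prob, h)
#   count = tf.expand_dims(tf.reduce_sum(prob, 1), 2)  # [B, K, 1]
#   return weighted_sum / tf.maximum(count, 1e-8)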
def compute_output(self):
  if not self.is_training:
    config = self.config
    VAT_weight = 0.00005
    num_steps = 20
    weights = self.embedding_weights
    self.fast_weights = self.embedding_weights
    x = tf.reshape(self.x_train,
                   [-1, config.height, config.width, config.num_channel])
    y = tf.squeeze(self.y_train_one_hot)
    for i in range(num_steps):
      loss = self.virtual_adversarial_loss(
          self.x_unlabel_flat, self._unlabel_logits, name="VAT-Inference")
      loss += self.virtual_adversarial_loss(x, y, name="VAT-Inference")
      ent_loss = entropy_y_x(compute_logits(self.protos, self.h_test))
      # NOTE: this overrides the VAT terms above, so only the entropy loss
      # drives the fast-weight updates.
      loss = ent_loss
      grads = tf.gradients(loss, list(weights.values()))
      grads = [tf.stop_gradient(grad) for grad in grads]
      gradients = dict(zip(weights.keys(), grads))
      # One step of gradient descent on a copy of the embedding weights.
      self.fast_weights = {
          key: self.fast_weights[key] - VAT_weight * gradients[key]
          for key in weights.keys()
      }
    # Re-encode with the adapted fast weights and recompute prototypes.
    encoded_train, encoded_test = self.encode(
        self.x_train,
        self.x_test,
        update_batch_stats=False,
        ext_wts=self.fast_weights)
    protos = self._compute_protos(self.nway, encoded_train, self.y_train)
    self._logits = [compute_logits(protos, encoded_test)]
  super().compute_output()
def compute_output(self):
  if not self.is_training:
    config = self.config
    num_steps = config.num_steps
    step_size = config.inference_step_size
    y = tf.squeeze(self.y_train_one_hot)
    protos = self.protos
    h_train = self.h_train
    h_test = self.h_test
    h_unlabel = self.h_unlabel
    with tf.name_scope('Adaptation'):
      for i in range(num_steps):
        # Candidate adaptation objectives; only `ent_loss_ulbl` is used
        # below.
        ent_loss_lbl_r = relative_entropy_y_x(
            compute_logits(self.protos, h_train)[0])
        ent_loss_ulbl_r = relative_entropy_y_x(self._unlabel_logits)
        ent_loss_test_r = relative_entropy_y_x(
            compute_logits(self.protos, h_test)[0])
        ent_loss_lbl = entropy_y_x(compute_logits(self.protos, h_train))
        ent_loss_ulbl = entropy_y_x(tf.expand_dims(self._unlabel_logits, 0))
        ent_loss_test = entropy_y_x(compute_logits(self.protos, h_test))
        loss = ent_loss_ulbl
        # Gradient descent directly on the prototypes.
        grads = tf.gradients(loss, self.protos, name='GRADS')[0]
        self._protos = self.protos - step_size * grads
        self._unlabel_logits = compute_logits(self.protos, h_unlabel)
        self._logits = [compute_logits(self.protos, h_test)]
  super().compute_output()
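# Assumed form of `entropy_y_x` (a sketch): the mean Shannon entropy of the
# softmax predictive distribution. Minimizing it w.r.t. the prototypes, as
# in the loop above, sharpens predictions on the unlabeled set.
#
# def _entropy_y_x_sketch(logits):
#   # logits: [B, N, K]
#   p = tf.nn.softmax(logits)
#   return tf.reduce_mean(-tf.reduce_sum(p * tf.log(p + 1e-8), axis=-1))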
def get_train_op(self, logits, y_test):
  loss, train_op = BasicModelVAT.get_train_op(self, logits, y_test)
  config = self.config
  ENT_weight = config.ENT_weight
  VAT_ENT_step_size = config.VAT_ENT_step_size
  logits = self._unlabel_logits
  s = tf.shape(logits)[0]
  p = tf.stop_gradient(self.h_unlabel)
  # Pairwise affinities between unlabeled embeddings; a large negative
  # constant on the diagonal removes self-transitions.
  affinity_matrix = compute_logits(p, p) - (tf.eye(s, dtype=tf.float32) * 1000.0)
  ENT_loss = walking_penalty(logits, affinity_matrix)
  loss += ENT_weight * ENT_loss
  ENT_opt = tf.train.AdamOptimizer(
      VAT_ENT_step_size * self.learn_rate, name="Entropy-optimizer")
  ENT_grads_and_vars = ENT_opt.compute_gradients(loss)
  train_op = ENT_opt.apply_gradients(ENT_grads_and_vars)
  for gradient, variable in ENT_grads_and_vars:
    if gradient is None:
      gradient = tf.constant(0.0)
    self.adv_summaries.append(
        tf.summary.scalar(
            "ENT/gradients/" + variable.name,
            l2_norm(gradient),
            family="Grads"))
    self.adv_summaries.append(
        tf.summary.histogram(
            "ENT/gradients/" + variable.name, gradient, family="Grads"))
  self.summaries.append(tf.summary.scalar('entropy_loss', ENT_loss))
  return loss, train_op
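# Note on the affinity matrix above (illustrative): `compute_logits(p, p)`
# gives pairwise negative squared distances between unlabeled embeddings,
# and subtracting 1000 on the diagonal drives self-affinities to
# effectively -inf, so a row-wise softmax yields a transition matrix with
# no self-loops:
#
# s = tf.shape(p)[0]
# affinity = compute_logits(p, p) - tf.eye(s) * 1000.0
# transition = tf.nn.softmax(affinity)  # row-stochastic, no self-transitions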
def get_SSL_train_op(self):
  config = self.config
  VAT_weight = config.VAT_weight
  ENT_weight = config.ENT_weight
  features = self.phi(self.unlabeled_training_data, update_batch_stats=False)
  logits = compute_logits(self.protos, features)
  VAT_loss = self.virtual_adversarial_loss(self.unlabeled_training_data,
                                           logits)
  # The entropy loss is only logged below; the returned op optimizes the
  # VAT loss alone.
  ENT_loss = entropy_y_x(tf.expand_dims(logits, 0))
  VAT_opt = tf.train.AdamOptimizer(
      VAT_weight * self.learn_rate, name='Classification-VAT-Optimizer')
  VAT_grads_and_vars = VAT_opt.compute_gradients(VAT_loss)
  VAT_train_op = VAT_opt.apply_gradients(VAT_grads_and_vars)
  self.adv_summaries.append(
      tf.summary.scalar("VAT-loss", VAT_loss, family='classification'))
  self.adv_summaries.append(
      tf.summary.scalar("ENT-loss", ENT_loss, family='classification'))
  return VAT_loss, VAT_train_op
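# Sketch of the standard VAT objective (Miyato et al.) for reference; the
# repo's `virtual_adversarial_loss` may differ in details such as the
# number of power iterations or the perturbation norm. Uses the
# `noisy_forward` hook defined above.
#
# def _vat_loss_sketch(model, x, logits, xi=1e-6, eps=2.0, num_iters=1):
#   p = tf.stop_gradient(tf.nn.softmax(logits))
#   d = tf.random_normal(shape=tf.shape(x))
#   for _ in range(num_iters):
#     # Power iteration to approximate the most sensitive direction.
#     d = xi * d / (tf.norm(d) + 1e-12)
#     kl = tf.reduce_mean(tf.reduce_sum(
#         p * (tf.log(p + 1e-8) -
#              tf.nn.log_softmax(model.noisy_forward(x, noise=d))), -1))
#     d = tf.stop_gradient(tf.gradients(kl, [d])[0])
#   r_adv = eps * d / (tf.norm(d) + 1e-12)
#   return tf.reduce_mean(tf.reduce_sum(
#       p * (tf.log(p + 1e-8) -
#            tf.nn.log_softmax(model.noisy_forward(x, noise=r_adv))), -1))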
def predict(self):
  """See `model.py` for documentation."""
  nclasses = self.nway
  num_cluster_steps = self.config.num_cluster_steps
  h_train, h_unlabel, h_test = self.encode(self.x_train, self.x_unlabel,
                                           self.x_test)
  y_train = self.y_train
  protos = self._compute_protos(nclasses, h_train, y_train)
  logits_list = []
  logits_list.append(compute_logits(protos, h_test))
  # Hard assignment for training images.
  prob_train = [None] * nclasses
  for kk in range(nclasses):
    # [B, N, 1]
    prob_train[kk] = tf.expand_dims(
        tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
  prob_train = concat(prob_train, 2)
  y_train_shape = tf.shape(y_train)
  bsize = y_train_shape[0]
  h_all = concat([h_train, h_unlabel], 1)
  mask = None

  # Calculate pairwise distances between prototypes and unlabeled points.
  protos_1 = tf.expand_dims(protos, 2)
  protos_2 = tf.expand_dims(h_unlabel, 1)
  pair_dist = tf.reduce_sum((protos_1 - protos_2)**2, [3])  # [B, K, N]
  mean_dist = tf.reduce_mean(pair_dist, [2], keep_dims=True)
  pair_dist_normalize = pair_dist / mean_dist
  min_dist = tf.reduce_min(pair_dist_normalize, [2], keep_dims=True)  # [B, K, 1]
  max_dist = tf.reduce_max(pair_dist_normalize, [2], keep_dims=True)
  mean_dist, var_dist = tf.nn.moments(pair_dist_normalize, [2], keep_dims=True)
  mean_dist += tf.to_float(tf.equal(mean_dist, 0.0))
  var_dist += tf.to_float(tf.equal(var_dist, 0.0))
  skew = tf.reduce_mean(
      ((pair_dist_normalize - mean_dist)**3) / (tf.sqrt(var_dist)**3), [2],
      keep_dims=True)
  kurt = tf.reduce_mean(
      ((pair_dist_normalize - mean_dist)**4) / (var_dist**2) - 3, [2],
      keep_dims=True)
  n_features = 5
  n_out = 3
  dist_features = tf.reshape(
      concat([min_dist, max_dist, var_dist, skew, kurt], 2),
      [-1, n_features])  # [BK, 5]
  dist_features = tf.stop_gradient(dist_features)
  hdim = [n_features, 20, n_out]
  act_fn = [tf.nn.tanh, None]
  # Small MLP that predicts a per-class scale and bias schedule for the
  # soft mask.
  thresh = mlp(
      dist_features,
      hdim,
      is_training=True,
      act_fn=act_fn,
      dtype=tf.float32,
      add_bias=True,
      wd=None,
      init_std=[0.01, 0.01],
      init_method=None,
      scope="dist_mlp",
      dropout=None,
      trainable=True)
  scale = tf.exp(thresh[:, 2])
  bias_start = tf.exp(thresh[:, 0])
  bias_add = thresh[:, 1]
  bias_start = tf.reshape(bias_start, [bsize, 1, -1])  # [B, 1, K]
  bias_add = tf.reshape(bias_add, [bsize, 1, -1])
  self._scale = scale
  self._bias_start = bias_start
  self._bias_add = bias_add

  # Run clustering.
  for tt in range(num_cluster_steps):
    protos_1 = tf.expand_dims(protos, 2)
    protos_2 = tf.expand_dims(h_unlabel, 1)
    pair_dist = tf.reduce_sum((protos_1 - protos_2)**2, [3])  # [B, K, N]
    m_dist = tf.reduce_mean(pair_dist, [2])  # [B, K]
    m_dist_1 = tf.expand_dims(m_dist, 1)  # [B, 1, K]
    m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0))
    # Label assignment with a bias annealed across cluster steps.
    if num_cluster_steps > 1:
      bias_tt = bias_start + (tt / float(num_cluster_steps - 1)) * bias_add
    else:
      bias_tt = bias_start
    negdist = compute_logits(protos, h_unlabel)
    mask = tf.sigmoid((negdist / m_dist_1 + bias_tt) * scale)
    prob_unlabel, mask = assign_cluster_soft_mask(protos, h_unlabel, mask)
    prob_all = concat([prob_train, prob_unlabel * mask], 1)
    # No update if there are zero unlabeled examples.
    protos = tf.cond(
        tf.shape(self._x_unlabel)[1] > 0,
        lambda: update_cluster(h_all, prob_all), lambda: protos)
    logits_list.append(compute_logits(protos, h_test))

  # Distractor evaluation.
  if mask is not None:
    max_mask = tf.reduce_max(mask, [2])
    mean_mask = tf.reduce_mean(max_mask)
    pred_non_distractor = tf.to_float(max_mask > mean_mask)
    acc, recall, precision = eval_distractor(pred_non_distractor,
                                             self.y_unlabel)
    self._non_distractor_acc = acc
    self._distractor_recall = recall
    self._distractor_precision = precision
    self._distractor_pred = max_mask
  return logits_list
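# Quick numpy check of the distance statistics fed to `dist_mlp` above
# (illustrative; mirrors the moment computations):
#
# d = np.random.rand(4, 5, 20)             # [B, K, N] pairwise distances
# d = d / d.mean(axis=2, keepdims=True)    # normalize by per-class mean
# mu = d.mean(axis=2, keepdims=True)
# var = d.var(axis=2, keepdims=True)
# skew = (((d - mu) ** 3) / var ** 1.5).mean(axis=2)    # third moment
# kurt = (((d - mu) ** 4) / var ** 2 - 3).mean(axis=2)  # excess kurtosis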
def noisy_forward(self, data, noise, update_batch_stats=False):
  # Perturbs the prototypes rather than the inputs; `data` and
  # `update_batch_stats` are unused because the cached unlabeled
  # embeddings are reused.
  with tf.name_scope("forward"):
    encoded = self.h_unlabel
    logits = compute_logits(self.protos + noise, encoded)
  return logits
def predict(self):
  """See `model.py` for documentation."""
  nclasses = self.nway
  num_cluster_steps = self.config.num_cluster_steps
  h_train, h_unlabel, h_test = self.encode(self.x_train, self.x_unlabel,
                                           self.x_test)
  y_train = self.y_train
  protos = self._compute_protos(nclasses, h_train, y_train)
  # Distractor class has a zero vector as prototype.
  protos = concat([protos, tf.zeros_like(protos[:, 0:1, :])], 1)

  # Hard assignment for training images.
  prob_train = [None] * (nclasses + 1)
  for kk in range(nclasses):
    # [B, N, 1]
    prob_train[kk] = tf.expand_dims(
        tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
  prob_train[-1] = tf.zeros_like(prob_train[0])
  prob_train = concat(prob_train, 2)

  # Initialize cluster radii.
  radii = [None] * (nclasses + 1)
  y_train_shape = tf.shape(y_train)
  bsize = y_train_shape[0]
  for kk in range(nclasses):
    radii[kk] = tf.ones([bsize, 1]) * 1.0

  # Distractor class has a larger radius.
  if FLAGS.learn_radius:
    log_distractor_radius = tf.get_variable(
        "log_distractor_radius",
        shape=[],
        dtype=tf.float32,
        initializer=tf.constant_initializer(np.log(FLAGS.init_radius)))
    distractor_radius = tf.exp(log_distractor_radius)
  else:
    distractor_radius = FLAGS.init_radius
  # With zero unlabeled examples, make the distractor radius effectively
  # infinite.
  distractor_radius = tf.cond(
      tf.shape(self._x_unlabel)[1] > 0, lambda: distractor_radius,
      lambda: 100000.0)
  radii[-1] = tf.ones([bsize, 1]) * distractor_radius
  radii = concat(radii, 1)  # [B, K]
  self.radii = radii
  h_all = concat([h_train, h_unlabel], 1)
  logits_list = []
  logits_list.append(compute_logits_radii(protos, h_test, radii))

  # Run clustering.
  for tt in range(num_cluster_steps):
    # Label assignment.
    prob_unlabel = assign_cluster_radii(protos, h_unlabel, radii)
    prob_all = concat([prob_train, prob_unlabel], 1)
    protos = update_cluster(h_all, prob_all)
    logits_list.append(compute_logits_radii(protos, h_test, radii))

  # Distractor evaluation.
  is_distractor = tf.equal(tf.argmax(prob_unlabel, axis=-1), nclasses)
  pred_non_distractor = 1.0 - tf.to_float(is_distractor)
  acc, recall, precision = eval_distractor(pred_non_distractor,
                                           self.y_unlabel)
  self._non_distractor_acc = acc
  self._distractor_recall = recall
  self._distractor_precision = precision
  self._distractor_pred = 1.0 - tf.exp(prob_unlabel[:, :, -1])
  self._logits = logits_list
  if not self.is_training:
    # Drop the distractor prototype; keep only the real class logits.
    self._logits = [compute_logits(protos[:, :nclasses], h_test)]
    protos = concat([self.protos, tf.zeros_like(self.protos[:, 0:1, :])], 1)
    radii = tf.stop_gradient(self.radii)
    self._unlabel_logits = compute_logits_radii(protos, h_unlabel, radii)[0]
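# One common formulation of radius-scaled logits such as
# `compute_logits_radii` (an assumption; the real helper is defined
# elsewhere): distances are scaled by a per-cluster radius sigma, so the
# large distractor radius flattens its cluster, e.g.
#
#   logits[b, n, k] = -||h[b, n] - p[b, k]||^2 / (2 * sigma[b, k]^2)
#                     - log(sigma[b, k])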
def predict(self):
  super(RefineModel, self).predict()
  with tf.name_scope('Predict/VAT'):
    self._unlabel_logits = compute_logits(self._ssl_protos, self.h_unlabel)
def __init__(self,
             config,
             x,
             y,
             x_b,
             y_b,
             x_b_v,
             y_b_v,
             num_classes_a,
             num_classes_b,
             is_training=True,
             ext_wts=None,
             y_sel=None,
             w_class_a=None,
             b_class_a=None,
             nshot=None):
  self._config = config
  self._is_training = is_training
  self._num_classes_a = num_classes_a
  self._num_classes_b = num_classes_b

  if config.backbone_class == 'resnet_backbone':
    bb_config = config.resnet_config
  else:
    assert False, 'Not supported'
  opt_config = config.optimizer_config
  proto_config = config.protonet_config
  transfer_config = config.transfer_config

  self._backbone = get_model(config.backbone_class, bb_config)
  self._inputs = x
  self._labels = y
  # Single-GPU path; with multiple GPUs the labels would be gathered
  # across workers, e.g. `allgather(self._labels)`.
  self._labels_all = self._labels
  self._inputs_b = x_b
  self._labels_b = y_b
  self._inputs_b_v = x_b_v
  self._labels_b_v = y_b_v
  self._labels_b_v_all = self._labels_b_v
  self._y_sel = y_sel
  self._mask = tf.placeholder(tf.bool, [], name='mask')

  global_step = tf.contrib.framework.get_or_create_global_step()
  self._global_step = global_step
  log.info('LR decay steps {}'.format(opt_config.lr_decay_steps))
  log.info('LR list {}'.format(opt_config.lr_list))
  learn_rate = tf.train.piecewise_constant(
      global_step,
      list(np.array(opt_config.lr_decay_steps).astype(np.int64)),
      list(opt_config.lr_list))
  self._learn_rate = learn_rate
  opt = self.get_optimizer(opt_config.optimizer, learn_rate)

  with tf.name_scope('TaskA'):
    h_a = self.backbone(x, is_training=is_training, ext_wts=ext_wts)
    self._h_a = h_a

  # Batch-norm update ops collected from the backbone.
  bn_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

  with tf.name_scope('TaskB'):
    x_b_all = tf.concat([x_b, x_b_v], axis=0)
    if ext_wts is not None:
      h_b_all = self.backbone(
          x_b_all, is_training=is_training, reuse=True, ext_wts=ext_wts)
    else:
      h_b_all = self.backbone(x_b_all, is_training=is_training, reuse=True)

  with tf.name_scope('TaskA'):
    # Calculate the hidden activation size.
    h_shape = h_a.get_shape()
    h_size = 1
    for ss in h_shape[1:]:
      h_size *= int(ss)
    if w_class_a is None:
      if ext_wts is not None:
        w_class_a = weight_variable(
            [h_size, num_classes_a],
            init_method='numpy',
            dtype=tf.float32,
            init_param={'val': np.transpose(ext_wts['w_class_a'])},
            wd=config.wd,
            name='w_class_a')
        b_class_a = weight_variable(
            [],
            init_method='numpy',
            dtype=tf.float32,
            init_param={'val': ext_wts['b_class_a']},
            wd=0e0,
            name='b_class_a')
      else:
        w_class_a = weight_variable(
            [h_size, num_classes_a],
            init_method='truncated_normal',
            dtype=tf.float32,
            init_param={'stddev': 0.01},
            wd=bb_config.wd,
            name='w_class_a')
        b_class_a = weight_variable(
            [num_classes_a],
            init_method='constant',
            init_param={'val': 0.0},
            name='b_class_a')
      self._w_class_a_orig = w_class_a
      self._b_class_a_orig = b_class_a
    else:
      assert b_class_a is not None
      w_class_a_orig = weight_variable(
          [h_size, num_classes_a],
          init_method='truncated_normal',
          dtype=tf.float32,
          init_param={'stddev': 0.01},
          wd=bb_config.wd,
          name='w_class_a')
      b_class_a_orig = weight_variable(
          [num_classes_a],
          init_method='constant',
          init_param={'val': 0.0},
          name='b_class_a')
      self._w_class_a_orig = w_class_a_orig
      self._b_class_a_orig = b_class_a_orig

    self._w_class_a = w_class_a
    self._b_class_a = b_class_a
    num_classes_a_dyn = tf.cast(tf.shape(b_class_a)[0], tf.int64)
    num_classes_a_dyn32 = tf.shape(b_class_a)[0]

    if proto_config.cosine_a:
      if proto_config.cosine_tau:
        if ext_wts is None:
          init_val = 10.0
        else:
          init_val = ext_wts['tau'][0]
        tau = weight_variable(
            [],
            init_method='constant',
            init_param={'val': init_val},
            name='tau')
      else:
        tau = tf.constant(1.0)
      # Cosine-similarity classifier with temperature `tau`.
      w_class_a_norm = self._normalize(w_class_a, 0)
      h_a_norm = self._normalize(h_a, 1)
      dot = tf.matmul(h_a_norm, w_class_a_norm)
      if ext_wts is not None:
        dot += b_class_a
      logits_a = tau * dot
    else:
      logits_a = compute_euc(tf.transpose(w_class_a), h_a)
    self._prediction_a = logits_a
    self._prediction_a_all = self._prediction_a

    xent_a = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits_a, labels=y)
    cost_a = tf.reduce_mean(xent_a, name='xent')
    self._cost_a = cost_a
    cost_a += self._decay()
    correct_a = tf.equal(tf.argmax(logits_a, axis=1), y)
    self._correct_a = correct_a
    self._acc_a = tf.reduce_mean(tf.cast(correct_a, cost_a.dtype))

  with tf.name_scope('TaskB'):
    h_b = h_b_all[:tf.shape(x_b)[0]]
    h_b_v = h_b_all[tf.shape(x_b)[0]:]
    # Add a new axis for the `batch` dimension.
    h_b_ = tf.expand_dims(h_b, 0)
    h_b_v_ = tf.expand_dims(h_b_v, 0)
    y_b_ = tf.expand_dims(y_b, 0)
    y_b_v_ = tf.expand_dims(y_b_v, 0)
    if transfer_config.old_and_new:
      protos_b = self._compute_protos(num_classes_b, h_b_,
                                      y_b_ - num_classes_a)
    else:
      protos_b = self._compute_protos(num_classes_b, h_b_, y_b_)
    w_class_a_ = tf.expand_dims(tf.transpose(w_class_a), 0)
    if proto_config.protos_phi:
      w_p1 = weight_variable(
          [h_size],
          init_method='constant',
          dtype=tf.float32,
          init_param={'val': 1.0},
          wd=bb_config.wd,
          name='w_p1')
    if proto_config.cosine_attention:
      w_q = weight_variable(
          [h_size, h_size],
          init_method='truncated_normal',
          dtype=tf.float32,
          init_param={'stddev': 0.1},
          wd=bb_config.wd,
          name='w_q')
      k_b = weight_variable(
          [num_classes_a, h_size],
          init_method='truncated_normal',
          dtype=tf.float32,
          init_param={'stddev': 0.1},
          wd=bb_config.wd,
          name='k_b')
      tau_q = weight_variable(
          [], init_method='constant', init_param={'val': 10.0}, name='tau_q')
      if transfer_config.old_and_new:
        w_class_b = self._compute_protos_attend_fix(
            num_classes_b, h_b_, y_b_ - num_classes_a_dyn, w_q, tau_q, k_b,
            self._w_class_a_orig)
      else:
        w_class_b = self._compute_protos_attend_fix(
            num_classes_b, h_b_, y_b_, w_q, tau_q, k_b, self._w_class_a_orig)
      assert proto_config.protos_phi
      w_p2 = weight_variable(
          [h_size],
          init_method='constant',
          dtype=tf.float32,
          init_param={'val': 1.0},
          wd=bb_config.wd,
          name='w_p2')
      self._k_b = tf.expand_dims(w_p2, 1) * self._w_class_a_orig
      self._k_b2 = k_b
      self.bias = w_class_b
      self.new_protos = w_p1 * protos_b
      self.new_bias = w_p2 * w_class_b
      # Combine episodic prototypes with attended base-class weights.
      w_class_b = w_p1 * protos_b + w_p2 * w_class_b
      self.protos = protos_b
      self.w_class_b_final = w_class_b
    else:
      w_class_b = protos_b
      if proto_config.protos_phi:
        w_class_b = w_p1 * w_class_b

    self._w_class_b = w_class_b

    if transfer_config.old_and_new:
      w_class_all = tf.concat([w_class_a_, w_class_b], axis=1)
    else:
      w_class_all = w_class_b

    if proto_config.cosine_softmax_tau:
      tau_b = weight_variable(
          [], init_method='constant', init_param={'val': 10.0}, name='tau_b')
    else:
      tau_b = tf.constant(1.0)

    if proto_config.similarity == 'euclidean':
      logits_b_v = compute_logits(w_class_all, h_b_v_)
    elif proto_config.similarity == 'cosine':
      logits_b_v = tau_b * compute_logits_cosine(w_class_all, h_b_v_)
    else:
      raise ValueError('Unknown similarity')
    self._logits_b_v = logits_b_v
    self._prediction_b = logits_b_v[0]
    self._prediction_b_all = self._prediction_b

    # Mask out the old classes.
    def mask_fn():
      bin_mask = tf.expand_dims(
          tf.reduce_sum(
              tf.one_hot(y_sel, num_classes_a + num_classes_b),
              0,
              keep_dims=True), 0)
      logits_b_v_m = logits_b_v * (1.0 - bin_mask)
      logits_b_v_m -= bin_mask * 100.0
      return logits_b_v_m

    # Masking is currently disabled:
    # if transfer_config.old_and_new:
    #   logits_b_v = tf.cond(self._mask, mask_fn, lambda: logits_b_v)
    xent_b_v = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits_b_v, labels=y_b_v_)
    cost_b = tf.reduce_mean(xent_b_v, name='xent')
    self._cost_b = cost_b

  if transfer_config.old_and_new:
    total_cost = cost_b
  else:
    total_cost = (transfer_config.cost_a_ratio * cost_a +
                  transfer_config.cost_b_ratio * cost_b)
  self._total_cost = total_cost

  if not transfer_config.meta_only:
    var_list = tf.trainable_variables()
    var_list = list(filter(lambda x: 'phi' in x.name, var_list))
    layers = self.config.transfer_config.meta_layers
    if layers == "all":
      pass
    elif layers == "4":
      keywords = ['TaskB', 'unit_4_']
      filter_fn = lambda x: any([kw in x.name for kw in keywords])
      var_list = list(filter(filter_fn, var_list))
    else:
      raise ValueError('Unknown finetune layers {}'.format(layers))
    [log.info('Slow weights {}'.format(v.name)) for v in var_list]
  else:
    var_list = []

  if proto_config.cosine_softmax_tau:
    var_list += [tau_b]
  if proto_config.cosine_attention:
    var_list += [w_q, tau_q, k_b, w_p2]
  if proto_config.protos_phi:
    var_list += [w_p1]
  if transfer_config.train_wclass_a:
    if proto_config.similarity == 'euclidean':
      var_list += [w_class_a, b_class_a]
    elif proto_config.similarity == 'cosine':
      var_list += [w_class_a]

  if is_training:
    grads_and_vars = opt.compute_gradients(total_cost, var_list)
    with tf.control_dependencies(bn_ops):
      [log.info('BN op {}'.format(op.name)) for op in bn_ops]
      train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
    grads_and_vars_b = opt.compute_gradients(cost_b, var_list)
    with tf.control_dependencies(bn_ops):
      train_op_b = opt.apply_gradients(
          grads_and_vars_b, global_step=global_step)
    with tf.control_dependencies(bn_ops):
      train_op_a = opt.minimize(cost_a, global_step=global_step)
    self._train_op = train_op
    self._train_op_a = train_op_a
    self._train_op_b = train_op_b
  self._initializer = tf.global_variables_initializer()
  self._w_class_a = w_class_a
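# Construction sketch (illustrative): the class name `TransferModel`, the
# placeholder shapes, and `my_config` are assumptions, not part of this
# module.
#
# x = tf.placeholder(tf.float32, [None, 84, 84, 3])      # Task A inputs
# y = tf.placeholder(tf.int64, [None])                   # Task A labels
# x_b = tf.placeholder(tf.float32, [None, 84, 84, 3])    # Task B support
# y_b = tf.placeholder(tf.int64, [None])
# x_b_v = tf.placeholder(tf.float32, [None, 84, 84, 3])  # Task B query
# y_b_v = tf.placeholder(tf.int64, [None])
# model = TransferModel(my_config, x, y, x_b, y_b, x_b_v, y_b_v,
#                       num_classes_a=64, num_classes_b=5, is_training=True)
# sess.run(model._train_op, feed_dict={...})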
def predict(self):
  """See `model.py` for documentation."""
  with tf.name_scope('Predict'):
    self.init_episode_classifier()
    logits = compute_logits(self.protos, self.h_test)
    self._logits = [logits]