def __init__(self, x, y, num_classes, dtype=tf.float32, learn_rate=1e-3):
    """Builds a softmax (logistic regression) classifier on flattened inputs.

    Args:
      x: Input tensor. All dimensions after the first are flattened into a
        single feature axis, so they must be statically known.
      y: Class labels, one per example. They are fed to
        `tf.nn.sparse_softmax_cross_entropy_with_logits` and cast to int64
        for the accuracy comparison.
      num_classes: Number of output classes.
      dtype: Data type of the accuracy metric.
      learn_rate: Adam learning rate.
    """
    # Flatten every non-batch dimension into one feature axis.
    x_size = 1
    for dim in x.get_shape()[1:]:
        x_size *= int(dim)
    x = tf.reshape(x, [-1, x_size])
    # Create the classifier parameters in the dtype of the features so the
    # matmul below always type-checks (they used to be pinned to float32,
    # which only worked for float32 inputs).
    param_dtype = x.dtype.base_dtype
    w_class = weight_variable([x_size, num_classes],
                              init_method='truncated_normal',
                              dtype=param_dtype,
                              init_param={'stddev': 0.01},
                              name='w_class')
    b_class = weight_variable([num_classes],
                              init_method='constant',
                              dtype=param_dtype,
                              init_param={'val': 0.0},
                              name='b_class')
    logits = tf.matmul(x, w_class) + b_class
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                          labels=y)
    xent = tf.reduce_mean(xent, name='xent')
    cost = xent
    # `_decay` is defined elsewhere on the class; its result is added to the
    # mean cross-entropy.
    cost += self._decay()
    self._cost = cost
    self._inputs = x  # Note: the flattened [batch, x_size] tensor.
    self._labels = y
    # Only the classifier parameters are optimized.
    self._train_op = tf.train.AdamOptimizer(learn_rate).minimize(
        cost, var_list=[w_class, b_class])
    # argmax yields int64; cast the labels so int32 labels also compare.
    correct = tf.equal(tf.argmax(logits, axis=1), tf.cast(y, tf.int64))
    self._acc = tf.reduce_mean(tf.cast(correct, dtype))
    self._prediction = tf.nn.softmax(logits)
示例#2
0
 def _weight_variable(self,
                      shape,
                      init_method=None,
                      dtype=tf.float32,
                      init_param=None,
                      wd=None,
                      name=None,
                      trainable=True,
                      seed=0):
     """Declares a variable, or fetches it from the external weights.

     Args:
       shape: Variable shape, forwarded to `weight_variable`.
       init_method: Initializer name, forwarded to `weight_variable`.
       dtype: Variable dtype, forwarded to `weight_variable`.
       init_param: Initializer parameters, forwarded to `weight_variable`.
       wd: Weight decay, forwarded to `weight_variable`.
       name: Variable name. Required when external weights are used, since
         it forms the lookup key.
       trainable: Forwarded to `weight_variable`.
       seed: Forwarded to `weight_variable`.

     Returns:
       A new variable from `weight_variable` when `self._ext_wts` is None;
       otherwise the entry of `self._ext_wts` keyed by
       `<current variable scope name>/<name>` (just `<name>` at the root
       variable scope).

     Raises:
       AssertionError: if external weights are used while `self._slow_bn` is
         falsy or `name` is None.
       ValueError: if the key is not present in `self._ext_wts`.
     """
     if self._ext_wts is None:
         return weight_variable(shape,
                                init_method=init_method,
                                dtype=dtype,
                                init_param=init_param,
                                wd=wd,
                                name=name,
                                trainable=trainable,
                                seed=seed)

     assert self._slow_bn, "Must enable slow BN"
     assert name is not None  # The name is needed to build the lookup key.
     scope_name = tf.get_variable_scope().name
     # The root variable scope has an empty name; joining it naively would
     # produce a spurious leading '/' that cannot match keys stored without it.
     var_name = '{}/{}'.format(scope_name, name) if scope_name else name
     if var_name in self._ext_wts:
         log.info('Found variable {} in external weights'.format(var_name))
         return self._ext_wts[var_name]
     log.error('Not found variable {} in external weights'.format(var_name))
     raise ValueError('Variable not found')
 def build_fast_weights(self):
     """Builds the fast weights for task B.

     The weights created depend on `self.config.transfer_config.fast_model`:

     - 'lr': a linear classifier, returned as `[w_class_b, b_class_b]` with
       shapes [h_size, K] and [K].
     - 'resmlp': a one-hidden-layer MLP plus a direct linear term, returned
       as `[w_class_b3, w_class_b2, b_class_b2, w_class_b, b_class_b]` with
       shapes [h_size, K], [h_size, M], [M], [M, K], [K], where M is
       `transfer_config.fast_mlp_hidden`.

     Here h_size is `self._h_size` and K is `self.num_classes_b`. All weights
     are initialized to 0 except the output bias `b_class_b`, which starts at
     -1. Only the weight matrices (not the biases) receive
     `transfer_config.finetune_wd`. Variable names get a `_<K>` suffix unless
     K == 5.

     Returns:
       A list of the weight variables described above.

     Raises:
       ValueError: if `fast_model` is neither 'lr' nor 'resmlp'.
     """
     transfer_config = self.config.transfer_config
     num_classes_b = self.num_classes_b
     h_size = self._h_size
     suffix = '' if num_classes_b == 5 else '_{}'.format(num_classes_b)
     fast_model = transfer_config.fast_model
     if fast_model == 'lr':
         w_class_b = weight_variable([h_size, num_classes_b],
                                     dtype=self.dtype,
                                     init_method='constant',
                                     init_param={'val': 0.0},
                                     wd=transfer_config.finetune_wd,
                                     name='w_class_b' + suffix)
         b_class_b = weight_variable([num_classes_b],
                                     dtype=self.dtype,
                                     init_method='constant',
                                     init_param={'val': -1.0},
                                     name='b_class_b' + suffix)
         fast_weights = [w_class_b, b_class_b]
     elif fast_model == 'resmlp':
         mlp_size = transfer_config.fast_mlp_hidden
         # Hidden layer: [h_size] -> [mlp_size].
         w_class_b2 = weight_variable([h_size, mlp_size],
                                      dtype=self.dtype,
                                      init_method='constant',
                                      init_param={'val': 0.0},
                                      wd=transfer_config.finetune_wd,
                                      name='w_class_b2' + suffix)
         b_class_b2 = weight_variable([mlp_size],
                                      dtype=self.dtype,
                                      init_method='constant',
                                      init_param={'val': 0.0},
                                      name='b_class_b2' + suffix)
         # Direct linear term: [h_size] -> [K].
         w_class_b3 = weight_variable([h_size, num_classes_b],
                                      dtype=self.dtype,
                                      init_method='constant',
                                      init_param={'val': 0.0},
                                      wd=transfer_config.finetune_wd,
                                      name='w_class_b3' + suffix)
         # Output layer: [mlp_size] -> [K].
         w_class_b = weight_variable([mlp_size, num_classes_b],
                                     dtype=self.dtype,
                                     init_method='constant',
                                     init_param={'val': 0.0},
                                     wd=transfer_config.finetune_wd,
                                     name='w_class_b' + suffix)
         b_class_b = weight_variable([num_classes_b],
                                     dtype=self.dtype,
                                     init_method='constant',
                                     init_param={'val': -1.0},
                                     name='b_class_b' + suffix)
         fast_weights = [
             w_class_b3, w_class_b2, b_class_b2, w_class_b, b_class_b
         ]
     else:
         # An explicit raise survives `python -O`, unlike `assert False`,
         # which would fall through to an UnboundLocalError below.
         raise ValueError('Unknown fast_model: {}'.format(fast_model))
     return fast_weights
    def __call__(self, fast_weights, reuse=None, **kwargs):
        """Computes the static-attractor penalty for linear fast weights.

        Args:
          fast_weights: A tuple of two elements.
              - w_b: [D, K]. Weights of the logistic regression.
              - b_b: [K]. Bias of the logistic regression.
          reuse: Bool. Whether to reuse variables.
          **kwargs: Unused.

        Returns:
          Scalar tensor: for each class column of the combined weight/bias
          matrix, the gamma-weighted squared distance to a learned attractor
          vector `attr`, averaged over classes. As a side effect the
          per-row gamma column is stored on `self.gamma`.
        """
        with tf.variable_scope("transfer_loss", reuse=reuse):
            combined = self.combine_wb(fast_weights[0], fast_weights[1])
            dtype = fast_weights[0].dtype
            num_rows = int(combined.shape[0])
            center = weight_variable([num_rows],
                                     dtype=dtype,
                                     init_method='constant',
                                     init_param={'val': 0.0},
                                     wd=self.config.wd,
                                     name='attr')
            center_col = tf.expand_dims(center, 1)
            init_log_gamma = np.log(self.config.init_gamma)
            if not self.config.learn_gamma:
                # Fixed precision: every row uses the initial gamma.
                log_gamma = tf.ones([num_rows], dtype=dtype) * init_log_gamma
            else:
                # Learned per-row precision, parameterized in log space.
                log_gamma = weight_variable([num_rows],
                                            init_method='constant',
                                            dtype=dtype,
                                            init_param={'val': init_log_gamma},
                                            wd=self.config.wd,
                                            name='log_gamma')
            gamma_col = tf.exp(tf.expand_dims(log_gamma, 1))
            self.gamma = gamma_col
        weighted_sq = tf.square(combined - center_col) * gamma_col
        return tf.reduce_mean(tf.reduce_sum(weighted_sq, [0]))
    def __call__(self, fast_weights, reuse=None, **kwargs):
        """Computes the static-attractor penalty for residual-MLP fast weights.

        Args:
          fast_weights: A tuple of the following elements:
            w_1: [D, K]. Logistic regression (direct linear) weights.
            w_2: [D, H]. First layer weights.
            b_2: [H]. First layer biases.
            w_3: [H, K]. Second layer weights.
            b_3: [K]. Second layer bias.

          reuse: Bool. Whether to reuse variables.
          **kwargs: Unused.

        Returns:
          Scalar tensor: the sum of three terms, one per weight matrix (with
          its bias combined in; w_1 gets an all-zero bias). Each term is the
          gamma-weighted squared distance of every column to a learned
          attractor vector, averaged over columns.
        """
        with tf.variable_scope("transfer_loss", reuse=reuse):
            w_1, w_2, b_2, w_3, b_3 = [fast_weights[i] for i in range(5)]
            dtype = w_1.dtype
            rows_hidden = int(w_3.shape[0]) + 1  # H+1
            rows_input = int(w_2.shape[0]) + 1  # D+1
            # Second layer weights combined with their biases.
            reg_out = self.combine_wb(w_3, b_3)  # [H+1, K]
            # First layer weights combined with their biases.
            reg_hidden = self.combine_wb(w_2, b_2)  # [D+1, H]
            # Logistic regression weights combined with an all-zero bias.
            reg_linear = self.combine_wb(
                w_1, tf.zeros([int(w_1.shape[1])], dtype=dtype))  # [D+1, K]

            row_sizes = (rows_hidden, rows_input, rows_input)
            attr_vars = [
                weight_variable([size],
                                init_method='truncated_normal',
                                dtype=dtype,
                                init_param={'stddev': 0.01},
                                wd=self.config.wd,
                                name=var_name)
                for size, var_name in zip(row_sizes, ('attr', 'attr2', 'attr3'))
            ]
            # Attractor centers as column vectors: [H+1, 1], [D+1, 1], [D+1, 1].
            centers = [tf.expand_dims(attr_var, 1) for attr_var in attr_vars]

            init_log_gamma = np.log(self.config.init_gamma)
            if not self.config.learn_gamma:
                log_gammas = [
                    tf.ones([size], dtype=dtype) * init_log_gamma
                    for size in row_sizes
                ]
            else:
                log_gammas = [
                    weight_variable([size],
                                    init_method='constant',
                                    dtype=dtype,
                                    init_param={'val': init_log_gamma},
                                    wd=self.config.wd,
                                    name=var_name)
                    for size, var_name in zip(
                        row_sizes, ('log_gamma', 'log_gamma2', 'log_gamma3'))
                ]
            # Per-row precisions as column vectors.
            precisions = [
                tf.exp(tf.expand_dims(log_gamma, 1)) for log_gamma in log_gammas
            ]

        loss = None
        for reg, center, precision in zip(
                (reg_out, reg_hidden, reg_linear), centers, precisions):
            term = tf.reduce_mean(
                tf.reduce_sum(tf.square(reg - center) * precision, [0]))
            loss = term if loss is None else loss + term
        return loss
示例#6
0
    def __call__(self, fast_weights, is_training=True, reuse=None, **kwargs):
        """Applies attention attractor for MLP with residual connection.

    The attractor centers are not free variables. They are produced by
    attending over keys derived from the (optionally masked) base-class
    weights `w_class_a`. The queries are the support-set class prototypes
    and the episode mean of the support features. The loss is a
    gamma-weighted squared distance between each fast-weight matrix (with
    biases combined in) and its attractor center.

    Args:
      fast_weights: A tuple of the following elements:
        w_1: [D, K]. Logistic regression weights.
        w_2: [D, H]. First layer weights.
        b_2: [H]. First layer biases.
        w_3: [H, K]. Second layer weights.
        b_3: [K]. Second layer bias.

      is_training: Bool. When True, dropout (keep_prob=0.9) is applied to
        the prototypes.

      reuse: Bool. Whether to reuse variables.

      kwargs: Contains the following fields:
        - y_b: Labels of the support examples.
        - h_b: Features of the support examples.
        - w_class_a: Base class weights.
        - mask: Bool flag whether we need to mask the base class weights.
        - y_sel: Binary flag on the base class weights.

    Returns:
      Scalar loss tensor: the sum of the second-layer, first-layer and
      logistic-regression attractor terms.
    """
        y_b = kwargs['y_b']
        h_b = kwargs['h_b']
        w_class_a = kwargs['w_class_a']
        mask = kwargs['mask']
        y_sel = kwargs['y_sel']
        with tf.variable_scope("transfer_loss", reuse=reuse):
            w_1 = fast_weights[0]
            w_2 = fast_weights[1]
            b_2 = fast_weights[2]
            w_3 = fast_weights[3]
            b_3 = fast_weights[4]
            dtype = w_1.dtype
            Hplus1 = int(w_3.shape[0]) + 1  # H+1
            Dplus1 = int(w_2.shape[0]) + 1  # D+1
            K = int(w_1.shape[1])
            D = int(w_1.shape[0])
            M = self.config.mlp_hidden
            # Second layer weights + biases.
            w_class_b_reg = self.combine_wb(w_3, b_3)  # [H+1, K]
            # First layer weights + biases.
            w_class_b_reg2 = self.combine_wb(w_2, b_2)  # [D+1, H]
            # Logistic regression weights with an all-zero bias.
            b_1 = tf.zeros([int(w_1.shape[1])], dtype=dtype)
            w_class_b_reg3 = self.combine_wb(w_1, b_1)  # [D+1, K]
            h_size_reg2 = int(w_class_b_reg2.shape[0])  # D+1
            # Learned temperatures for the prototype / episode queries.
            tau_init = self.config.attn_attr_tau_init
            tau_q = weight_variable([],
                                    init_method='constant',
                                    dtype=dtype,
                                    init_param={'val': tau_init},
                                    name='tau_qq')
            tau_q2 = weight_variable([],
                                     init_method='constant',
                                     dtype=dtype,
                                     init_param={'val': tau_init},
                                     name='tau_qq2')
            Ko = int(w_class_a.shape[1])  # Kold
            # NOTE(review): every other truncated_normal initializer in this
            # file passes 'stddev'; the 'std' key below may be ignored or read
            # differently by weight_variable. Left unchanged to preserve the
            # current initialization -- confirm against weight_variable.
            h_attend_bias = weight_variable([Hplus1],
                                            dtype=dtype,
                                            init_method='truncated_normal',
                                            init_param={'std': 1e-2},
                                            wd=self.config.wd,
                                            name='h_attend_bias')
            h_attend_bias2 = weight_variable([Dplus1],
                                             dtype=dtype,
                                             init_method='truncated_normal',
                                             init_param={'std': 1e-2},
                                             wd=self.config.wd,
                                             name='h_attend_bias2')
            h_attend_bias3 = weight_variable([Dplus1],
                                             dtype=dtype,
                                             init_method='truncated_normal',
                                             init_param={'std': 1e-2},
                                             wd=self.config.wd,
                                             name='h_attend_bias3')
            assert self.config.mlp_hidden != 0
            # Key MLP: base-class weight vector -> tanh hidden (M) -> keys.
            w_kb = weight_variable([D, M],
                                   init_method='truncated_normal',
                                   dtype=dtype,
                                   init_param={'stddev': self.config.mlp_init},
                                   wd=self.config.wd,
                                   name='w_kb')
            b_kb = weight_variable([M],
                                   init_method='constant',
                                   dtype=dtype,
                                   init_param={'val': 0.0},
                                   wd=self.config.wd,
                                   name='b_kb')
            w_kb21 = weight_variable(
                [M, Hplus1],
                init_method='truncated_normal',
                dtype=dtype,
                init_param={'stddev': self.config.mlp_init},
                wd=self.config.wd,
                name='w_kb21')
            b_kb21 = weight_variable([Hplus1],
                                     init_method='constant',
                                     dtype=dtype,
                                     init_param={'val': 0.0},
                                     wd=self.config.wd,
                                     name='b_kb21')
            w_kb22 = weight_variable(
                [M, 2 * Dplus1],
                init_method='truncated_normal',
                dtype=dtype,
                init_param={'stddev': self.config.mlp_init},
                wd=self.config.wd,
                name='w_kb22')
            b_kb22 = weight_variable([2 * Dplus1],
                                     init_method='constant',
                                     dtype=dtype,
                                     init_param={'val': 0.0},
                                     wd=self.config.wd,
                                     name='b_kb22')

            w_class_a_mask = tf.cond(mask,
                                     self._get_mask_fn(w_class_a, y_sel, Ko),
                                     lambda: w_class_a)
            # [D, Ko] -> [Ko, D] -> [Ko, M]
            kbz = tf.tanh(tf.matmul(tf.transpose(w_class_a_mask), w_kb) + b_kb)
            # [Ko, M] -> [Ko, H+1]
            k_b = tf.matmul(kbz, w_kb21) + b_kb21
            # [Ko, M] -> [Ko, 2(D+1)]
            k_b2 = tf.matmul(kbz, w_kb22) + b_kb22
            k_b = tf.transpose(k_b)  # [H+1, Ko]
            k_b2 = tf.transpose(k_b2)  # [2(D+1), Ko]
            # NOTE(review): k_b_mask and k_b2_mask are built but never used;
            # the attention below reads the unmasked k_b / k_b2. Confirm
            # whether that is intended.
            k_b_mask = tf.cond(mask, self._get_mask_fn(k_b, y_sel, Ko),
                               lambda: k_b)
            k_b2_mask = tf.cond(mask, self._get_mask_fn(k_b2, y_sel, Ko),
                                lambda: k_b2)

            if self.config.old_and_new:
                # Shift the support labels so new classes start at 0.
                y_b = y_b - Ko
            protos = self._compute_protos(K, h_b, y_b)  # [K, D]
            if is_training:
                protos = tf.nn.dropout(protos, keep_prob=0.9)
            protos_norm = self._normalize(protos, axis=1)  # [K, D]
            episode_mean = tf.reduce_mean(h_b, [0], keepdims=True)  # [1, D]
            episode_norm = self._normalize(episode_mean, axis=1)  # [1, D]
            w_class_a_norm = self._normalize(w_class_a_mask, axis=0)  # [D, Ko]
            h_dot_w = tf.matmul(protos_norm, w_class_a_norm)  # [K, Ko]
            e_dot_w = tf.matmul(episode_norm, w_class_a_norm)  # [1, Ko]
            h_dot_w *= tau_q  # [K, Ko]
            e_dot_w *= tau_q2  # [1, Ko]
            proto_attend = tf.nn.softmax(h_dot_w)  # [K, Ko]
            episode_attend = tf.nn.softmax(e_dot_w)  # [1, Ko]
            # Split the 2(D+1) key rows: the second half drives the w_1 term,
            # the first half drives the first-layer term.
            k_b3 = k_b2[h_size_reg2:, :]  # [D+1, Ko]
            k_b2 = k_b2[:h_size_reg2, :]  # [D+1, Ko]
            h_attend = tf.matmul(proto_attend, k_b,
                                 transpose_b=True)  # [K, H+1]
            h_attend2 = tf.matmul(episode_attend, k_b2,
                                  transpose_b=True)  # [1, D+1]
            h_attend3 = tf.matmul(proto_attend, k_b3,
                                  transpose_b=True)  # [K, D+1]
            attr_ = tf.transpose(h_attend +
                                 h_attend_bias)  # [K, H+1] -> [H+1, K]
            attr2_ = tf.transpose(h_attend2 +
                                  h_attend_bias2)  # [1, D+1] -> [D+1, 1]
            attr3_ = tf.transpose(h_attend3 +
                                  h_attend_bias3)  # [K, D+1] -> [D+1, K]
        with tf.variable_scope("new_loss", reuse=reuse):
            init_log_gamma = np.log(self.config.init_gamma)
            if self.config.learn_gamma:
                log_gamma = weight_variable([Hplus1],
                                            init_method='constant',
                                            dtype=dtype,
                                            init_param={'val': init_log_gamma},
                                            wd=self.config.wd,
                                            name='log_gamma')
                log_gamma2 = weight_variable(
                    [Dplus1],
                    init_method='constant',
                    dtype=dtype,
                    init_param={'val': init_log_gamma},
                    wd=self.config.wd,
                    name='log_gamma2')
                log_gamma3 = weight_variable(
                    [Dplus1],
                    init_method='constant',
                    dtype=dtype,
                    init_param={'val': init_log_gamma},
                    wd=self.config.wd,
                    name='log_gamma3')
            else:
                log_gamma = tf.ones([Hplus1], dtype=dtype) * init_log_gamma
                log_gamma2 = tf.ones([Dplus1], dtype=dtype) * init_log_gamma
                log_gamma3 = tf.ones([Dplus1], dtype=dtype) * init_log_gamma
            gamma_ = tf.exp(tf.expand_dims(log_gamma, 1))  # [H+1, 1]
            gamma2_ = tf.exp(tf.expand_dims(log_gamma2, 1))  # [D+1, 1]
            gamma3_ = tf.exp(tf.expand_dims(log_gamma3, 1))  # [D+1, 1]

        # Per-column gamma-weighted squared distance to the attractor,
        # averaged over columns, summed over the three weight groups.
        loss = tf.reduce_mean(
            tf.reduce_sum(tf.square(w_class_b_reg - attr_) * gamma_, [0]))
        loss += tf.reduce_mean(
            tf.reduce_sum(tf.square(w_class_b_reg2 - attr2_) * gamma2_, [0]))
        loss += tf.reduce_mean(
            tf.reduce_sum(tf.square(w_class_b_reg3 - attr3_) * gamma3_, [0]))
        return loss
    def __init__(self,
                 config,
                 x,
                 y,
                 num_classes,
                 is_training=True,
                 dtype=tf.float32):
        """Builds a CNN classifier trained with Adam and a step-wise LR.

    Args:
      config: Model config. Fields read here: filter_size, strides, pool_fn,
        pool_size, pool_strides, conv_act_fn, conv_init_std,
        conv_init_method, wd, learn_rate, lr_decay_steps, lr_list.
      x: Input tensor fed to `cnn`.
      y: Class labels, one per example. They are fed to
        `tf.nn.sparse_softmax_cross_entropy_with_logits` and cast to int64
        for the accuracy comparison.
      num_classes: Number of output classes.
      is_training: Bool, forwarded to `cnn` (which is built with
        `batch_norm=True`).
      dtype: Data type forwarded to `cnn` and used for the accuracy metric.
    """
        # NOTE(review): only the lengths of config.pool_fn and
        # config.conv_act_fn are used; every layer gets max-pooling and ReLU
        # regardless of their contents. Confirm this is intended.
        h, _ = cnn(x,
                   config.filter_size,
                   strides=config.strides,
                   pool_fn=[tf.nn.max_pool] * len(config.pool_fn),
                   pool_size=config.pool_size,
                   pool_strides=config.pool_strides,
                   act_fn=[tf.nn.relu for aa in config.conv_act_fn],
                   add_bias=True,
                   init_std=config.conv_init_std,
                   init_method=config.conv_init_method,
                   wd=config.wd,
                   dtype=dtype,
                   batch_norm=True,
                   is_training=is_training,
                   ext_wts=None)
        # Flatten every non-batch dimension of the CNN features.
        h_size = 1
        for dim in h.get_shape()[1:]:
            h_size *= int(dim)
        h = tf.reshape(h, [-1, h_size])
        # Create the classifier in the features' dtype so the matmul below
        # type-checks when `dtype` is not float32 (previously pinned to
        # float32, which broke that case).
        param_dtype = h.dtype.base_dtype
        w_class = weight_variable([h_size, num_classes],
                                  init_method='truncated_normal',
                                  dtype=param_dtype,
                                  init_param={'stddev': 0.01},
                                  name='w_class')
        b_class = weight_variable([num_classes],
                                  init_method='constant',
                                  dtype=param_dtype,
                                  init_param={'val': 0.0},
                                  name='b_class')
        self._feature = h
        logits = tf.matmul(h, w_class) + b_class
        xent = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                              labels=y)
        xent = tf.reduce_mean(xent, name='xent')
        cost = xent
        # `_decay` is defined elsewhere on the class; its result is added to
        # the mean cross-entropy.
        cost += self._decay()
        self._cost = cost
        self._inputs = x
        self._labels = y
        global_step = tf.get_variable('global_step',
                                      shape=[],
                                      dtype=tf.int64,
                                      trainable=False)
        # Learning rate decay: config.learn_rate up to the first boundary in
        # config.lr_decay_steps, then the entries of config.lr_list in turn.
        learn_rate = tf.train.piecewise_constant(
            global_step,
            list(np.array(config.lr_decay_steps).astype(np.int64)),
            [config.learn_rate] + list(config.lr_list))
        self._learn_rate = learn_rate
        self._train_op = tf.train.AdamOptimizer(learn_rate).minimize(
            cost, global_step=global_step)

        # argmax yields int64; cast the labels so int32 labels also compare.
        correct = tf.equal(tf.argmax(logits, axis=1), tf.cast(y, tf.int64))
        self._acc = tf.reduce_mean(tf.cast(correct, dtype))
  def __init__(self,
               config,
               x,
               y,
               x_b,
               y_b,
               x_b_v,
               y_b_v,
               num_classes_a,
               num_classes_b,
               is_training=True,
               ext_wts=None,
               y_sel=None,
               w_class_a=None,
               b_class_a=None,
               nshot=None):
    self._config = config
    self._is_training = is_training
    self._num_classes_a = num_classes_a
    self._num_classes_b = num_classes_b

    if config.backbone_class == 'resnet_backbone':
      bb_config = config.resnet_config
    else:
      assert False, 'Not supported'
    opt_config = config.optimizer_config
    proto_config = config.protonet_config
    transfer_config = config.transfer_config

    self._backbone = get_model(config.backbone_class, bb_config)
    self._inputs = x
    self._labels = y
    # if opt_config.num_gpu > 1:
    #   self._labels_all = allgather(self._labels)
    # else:
    self._labels_all = self._labels
    self._inputs_b = x_b
    self._labels_b = y_b
    self._inputs_b_v = x_b_v
    self._labels_b_v = y_b_v
    # if opt_config.num_gpu > 1:
    #   self._labels_b_v_all = allgather(self._labels_b_v)
    # else:
    self._labels_b_v_all = self._labels_b_v
    self._y_sel = y_sel
    self._mask = tf.placeholder(tf.bool, [], name='mask')

    # global_step = tf.get_variable(
    #     'global_step', shape=[], dtype=tf.int64, trainable=False)
    global_step = tf.contrib.framework.get_or_create_global_step()
    self._global_step = global_step
    log.info('LR decay steps {}'.format(opt_config.lr_decay_steps))
    log.info('LR list {}'.format(opt_config.lr_list))
    learn_rate = tf.train.piecewise_constant(
        global_step, list(
            np.array(opt_config.lr_decay_steps).astype(np.int64)),
        list(opt_config.lr_list))
    self._learn_rate = learn_rate

    opt = self.get_optimizer(opt_config.optimizer, learn_rate)
    # if opt_config.num_gpu > 1:
    #   opt = hvd.DistributedOptimizer(opt)

    with tf.name_scope('TaskA'):
      h_a = self.backbone(x, is_training=is_training, ext_wts=ext_wts)
      self._h_a = h_a

    # Apply BN ops.
    bn_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.name_scope('TaskB'):
      x_b_all = tf.concat([x_b, x_b_v], axis=0)
      if ext_wts is not None:
        h_b_all = self.backbone(
            x_b_all, is_training=is_training, reuse=True, ext_wts=ext_wts)
      else:
        h_b_all = self.backbone(x_b_all, is_training=is_training, reuse=True)

    with tf.name_scope('TaskA'):
      # Calculates hidden activation size.
      h_shape = h_a.get_shape()
      h_size = 1
      for ss in h_shape[1:]:
        h_size *= int(ss)

      if w_class_a is None:
        if ext_wts is not None:
          w_class_a = weight_variable(
              [h_size, num_classes_a],
              init_method='numpy',
              dtype=tf.float32,
              init_param={'val': np.transpose(ext_wts['w_class_a'])},
              wd=config.wd,
              name='w_class_a')
          b_class_a = weight_variable([],
                                      init_method='numpy',
                                      dtype=tf.float32,
                                      init_param={'val': ext_wts['b_class_a']},
                                      wd=0e0,
                                      name='b_class_a')
        else:
          w_class_a = weight_variable([h_size, num_classes_a],
                                      init_method='truncated_normal',
                                      dtype=tf.float32,
                                      init_param={'stddev': 0.01},
                                      wd=bb_config.wd,
                                      name='w_class_a')
          b_class_a = weight_variable([num_classes_a],
                                      init_method='constant',
                                      init_param={'val': 0.0},
                                      name='b_class_a')
        self._w_class_a_orig = w_class_a
        self._b_class_a_orig = b_class_a
      else:
        assert b_class_a is not None
        w_class_a_orig = weight_variable([h_size, num_classes_a],
                                         init_method='truncated_normal',
                                         dtype=tf.float32,
                                         init_param={'stddev': 0.01},
                                         wd=bb_config.wd,
                                         name='w_class_a')
        b_class_a_orig = weight_variable([num_classes_a],
                                         init_method='constant',
                                         init_param={'val': 0.0},
                                         name='b_class_a')
        self._w_class_a_orig = w_class_a_orig
        self._b_class_a_orig = b_class_a_orig

      self._w_class_a = w_class_a
      self._b_class_a = b_class_a
      num_classes_a_dyn = tf.cast(tf.shape(b_class_a)[0], tf.int64)
      num_classes_a_dyn32 = tf.shape(b_class_a)[0]

      if proto_config.cosine_a:
        if proto_config.cosine_tau:
          if ext_wts is None:
            init_val = 10.0
          else:
            init_val = ext_wts['tau'][0]
          tau = weight_variable([],
                                init_method='constant',
                                init_param={'val': init_val},
                                name='tau')
        else:
          tau = tf.constant(1.0)
        w_class_a_norm = self._normalize(w_class_a, 0)
        h_a_norm = self._normalize(h_a, 1)
        dot = tf.matmul(h_a_norm, w_class_a_norm)
        if ext_wts is not None:
          dot += b_class_a
        logits_a = tau * dot
      else:
        logits_a = compute_euc(tf.transpose(w_class_a), h_a)
      self._prediction_a = logits_a
      # if opt_config.num_gpu > 1:
      #   self._prediction_a_all = allgather(self._prediction_a)
      # else:
      self._prediction_a_all = self._prediction_a

      xent_a = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits_a, labels=y)
      cost_a = tf.reduce_mean(xent_a, name='xent')
      self._cost_a = cost_a
      cost_a += self._decay()
      correct_a = tf.equal(tf.argmax(logits_a, axis=1), y)
      self._correct_a = correct_a
      self._acc_a = tf.reduce_mean(tf.cast(correct_a, cost_a.dtype))

    with tf.name_scope('TaskB'):
      h_b = h_b_all[:tf.shape(x_b)[0]]
      h_b_v = h_b_all[tf.shape(x_b)[0]:]

      # Add new axes for the `batch` dimension.
      h_b_ = tf.expand_dims(h_b, 0)
      h_b_v_ = tf.expand_dims(h_b_v, 0)
      y_b_ = tf.expand_dims(y_b, 0)
      y_b_v_ = tf.expand_dims(y_b_v, 0)

      if transfer_config.old_and_new:
        protos_b = self._compute_protos(num_classes_b, h_b_,
                                        y_b_ - num_classes_a)
      else:
        protos_b = self._compute_protos(num_classes_b, h_b_, y_b_)

      w_class_a_ = tf.expand_dims(tf.transpose(w_class_a), 0)
      if proto_config.protos_phi:
        w_p1 = weight_variable([h_size],
                               init_method='constant',
                               dtype=tf.float32,
                               init_param={'val': 1.0},
                               wd=bb_config.wd,
                               name='w_p1')
      if proto_config.cosine_attention:
        w_q = weight_variable([h_size, h_size],
                              init_method='truncated_normal',
                              dtype=tf.float32,
                              init_param={'stddev': 0.1},
                              wd=bb_config.wd,
                              name='w_q')
        k_b = weight_variable([num_classes_a, h_size],
                              init_method='truncated_normal',
                              dtype=tf.float32,
                              init_param={'stddev': 0.1},
                              wd=bb_config.wd,
                              name='k_b')
        tau_q = weight_variable([],
                                init_method='constant',
                                init_param={'val': 10.0},
                                name='tau_q')
        if transfer_config.old_and_new:
          w_class_b = self._compute_protos_attend_fix(
              num_classes_b, h_b_, y_b_ - num_classes_a_dyn, w_q, tau_q, k_b,
              self._w_class_a_orig)
        else:
          w_class_b = self._compute_protos_attend_fix(
              num_classes_b, h_b_, y_b_, w_q, tau_q, k_b, self._w_class_a_orig)
        assert proto_config.protos_phi
        w_p2 = weight_variable([h_size],
                               init_method='constant',
                               dtype=tf.float32,
                               init_param={'val': 1.0},
                               wd=bb_config.wd,
                               name='w_p2')
        self._k_b = tf.expand_dims(w_p2, 1) * self._w_class_a_orig
        self._k_b2 = k_b
        self.bias = w_class_b
        self.new_protos = w_p1 * protos_b
        self.new_bias = w_p2 * w_class_b
        w_class_b = w_p1 * protos_b + w_p2 * w_class_b
        self.protos = protos_b
        self.w_class_b_final = w_class_b
      else:
        w_class_b = protos_b
        if proto_config.protos_phi:
          w_class_b = w_p1 * w_class_b

      self._w_class_b = w_class_b

      if transfer_config.old_and_new:
        w_class_all = tf.concat([w_class_a_, w_class_b], axis=1)
      else:
        w_class_all = w_class_b

      if proto_config.cosine_softmax_tau:
        tau_b = weight_variable([],
                                init_method='constant',
                                init_param={'val': 10.0},
                                name='tau_b')
      else:
        tau_b = tf.constant(1.0)

      if proto_config.similarity == 'euclidean':
        logits_b_v = compute_logits(w_class_all, h_b_v_)
      elif proto_config.similarity == 'cosine':
        logits_b_v = tau_b * compute_logits_cosine(w_class_all, h_b_v_)
      else:
        raise ValueError('Unknown similarity')
      self._logits_b_v = logits_b_v
      self._prediction_b = logits_b_v[0]
      # if opt_config.num_gpu > 1:
      #   self._prediction_b_all = allgather(self._prediction_b)
      # else:
      self._prediction_b_all = self._prediction_b

      # Mask out the old classes.
      def mask_fn():
        bin_mask = tf.expand_dims(
            tf.reduce_sum(
                tf.one_hot(y_sel, num_classes_a + num_classes_b),
                0,
                keep_dims=True), 0)
        logits_b_v_m = logits_b_v * (1.0 - bin_mask)
        logits_b_v_m -= bin_mask * 100.0
        return logits_b_v_m

      # if transfer_config.old_and_new:
      #   logits_b_v = tf.cond(self._mask, mask_fn, lambda: logits_b_v)
      xent_b_v = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits_b_v, labels=y_b_v_)
      cost_b = tf.reduce_mean(xent_b_v, name='xent')
      self._cost_b = cost_b

    if transfer_config.old_and_new:
      total_cost = cost_b
    else:
      total_cost = (transfer_config.cost_a_ratio * cost_a +
                    transfer_config.cost_b_ratio * cost_b)
    self._total_cost = total_cost

    if not transfer_config.meta_only:
      # assert False, 'let us go for pretrained model first'
      var_list = tf.trainable_variables()
      var_list = list(filter(lambda x: 'phi' in x.name, var_list))
      layers = self.config.transfer_config.meta_layers
      if layers == "all":
        pass
      elif layers == "4":
        keywords = ['TaskB', 'unit_4_']
        filter_fn = lambda x: any([kw in x.name for kw in keywords])
        var_list = list(filter(filter_fn, var_list))
      else:
        raise ValueError('Unknown finetune layers {}'.format(layers))
      [log.info('Slow weights {}'.format(v.name)) for v in var_list]
    else:
      var_list = []

    if proto_config.cosine_softmax_tau:
      var_list += [tau_b]

    if proto_config.cosine_attention:
      var_list += [w_q, tau_q, k_b, w_p2]

    if proto_config.protos_phi:
      var_list += [w_p1]

    if transfer_config.train_wclass_a:
      if proto_config.similarity == 'euclidean':
        var_list += [w_class_a, b_class_a]
      elif proto_config.similarity == 'cosine':
        var_list += [w_class_a]

    if is_training:
      grads_and_vars = opt.compute_gradients(total_cost, var_list)
      with tf.control_dependencies(bn_ops):
        [log.info('BN op {}'.format(op.name)) for op in bn_ops]
        train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)

      grads_and_vars_b = opt.compute_gradients(cost_b, var_list)
      with tf.control_dependencies(bn_ops):
        train_op_b = opt.apply_gradients(
            grads_and_vars_b, global_step=global_step)

      with tf.control_dependencies(bn_ops):
        train_op_a = opt.minimize(cost_a, global_step=global_step)
      self._train_op = train_op
      self._train_op_a = train_op_a
      self._train_op_b = train_op_b
    self._initializer = tf.global_variables_initializer()
    self._w_class_a = w_class_a
示例#9
0
    def __call__(self, fast_weights, is_training=True, reuse=None, **kwargs):
        """Applies the attention attractor and returns the transfer loss.

        Builds (or reuses) the attractor variables, computes attended features
        for the novel classes from the support set, and returns the
        gamma-weighted squared distance between the fast weights and those
        attended features, averaged over novel classes.

        Args:
          fast_weights: A tuple of two elements, merged by ``self.combine_wb``
            into a [D, K2] matrix (D = feature size, K2 = number of novel
            classes):
            - w_b: [D, K]. Weights of the logistic regression.
            - b_b: [K]. Bias of the logistic regression.
          is_training: Bool. Forwarded to the prototype-attention helper.
          reuse: Bool. Whether to reuse variables in the "transfer_loss" and
            "new_loss" variable scopes.
          kwargs: Contains the following fields:
            - y_b: Labels of the support examples.
            - h_b: Features of the support examples. Its dtype is used for the
              variables created here.
            - w_class_a: Base class weights, [D_a, K1] (K1 = number of base
              classes).
            - mask: Boolean predicate for ``tf.cond``. When true, the base
              class weights and keys are masked via ``self._get_mask_fn``.
            - y_sel: Binary flag on the base class weights, passed to the
              mask function.

        Returns:
          A scalar tensor: the mean over novel classes of the gamma-weighted
          squared Euclidean distance between the fast weights and the
          attended features.
        """
        y_b = kwargs['y_b']
        h_b = kwargs['h_b']
        w_class_a = kwargs['w_class_a']
        mask = kwargs['mask']
        y_sel = kwargs['y_sel']
        dtype = h_b.dtype
        with tf.variable_scope("transfer_loss", reuse=reuse):
            # Combined novel-class weights: [D, K2].
            w_class_b_reg = self.combine_wb(fast_weights[0], fast_weights[1])
            h_size_reg = int(w_class_b_reg.shape[0])
            h_size = int(w_class_a.shape[0])
            # Scalar passed to the attention helper (presumably a softmax
            # temperature), initialized from config.attn_attr_tau_init.
            tau_qq = weight_variable(
                [],
                dtype=dtype,
                init_method='constant',
                init_param={'val': self.config.attn_attr_tau_init},
                name='tau_qq')
            h_attend_bias = weight_variable(
                [h_size_reg],
                dtype=dtype,
                init_method='truncated_normal',
                init_param={'stddev': 1e-2},
                wd=self.config.wd,  # wasn't there before.
                name='h_attend_bias')
            num_classes_a = int(w_class_a.shape[1])
            num_classes_b = int(w_class_b_reg.shape[1])
            # The key MLP below needs a non-empty hidden layer.
            assert self.config.mlp_hidden != 0
            # Parameters of a one-hidden-layer MLP mapping each base-class
            # weight vector (size h_size) to a key (size h_size_reg).
            w_kb = weight_variable([h_size, self.config.mlp_hidden],
                                   init_method='truncated_normal',
                                   dtype=dtype,
                                   init_param={'stddev': self.config.mlp_init},
                                   wd=self.config.wd,
                                   name='w_kb')
            b_kb = weight_variable([self.config.mlp_hidden],
                                   init_method='constant',
                                   dtype=dtype,
                                   init_param={'val': 0.0},
                                   wd=self.config.wd,
                                   name='b_kb')
            w_kb2 = weight_variable(
                [self.config.mlp_hidden, h_size_reg],
                init_method='truncated_normal',
                dtype=dtype,
                init_param={'stddev': self.config.mlp_init},
                wd=self.config.wd,
                name='w_kb2')
            b_kb2 = weight_variable([h_size_reg],
                                    init_method='constant',
                                    dtype=dtype,
                                    init_param={'val': 0.0},
                                    wd=self.config.wd,
                                    name='b_kb2')
            # Optionally mask base-class weights selected by y_sel.
            w_class_a_mask = tf.cond(
                mask, self._get_mask_fn(w_class_a, y_sel, num_classes_a),
                lambda: w_class_a)
            # Keys: tanh MLP over the (masked) base-class weight vectors,
            # [K1, D_a] -> [K1, h_size_reg].
            k_b = tf.matmul(
                tf.tanh(tf.matmul(tf.transpose(w_class_a_mask), w_kb) + b_kb),
                w_kb2) + b_kb2
            self._k_b = k_b
            # Column-per-class layout [h_size_reg, K1], then mask again.
            k_b = tf.transpose(k_b)
            k_b_mask = tf.cond(mask,
                               self._get_mask_fn(k_b, y_sel, num_classes_a),
                               lambda: k_b)

            if self.config.old_and_new:
                # Novel labels follow the base labels; shift them to start
                # at 0.
                attended_h = self._compute_protos_attend(
                    num_classes_b,
                    h_b,
                    y_b - num_classes_a,
                    tau_qq,
                    h_attend_bias,
                    k_b_mask,
                    w_class_a_mask,
                    is_training=is_training)
            else:
                attended_h = self._compute_protos_attend5_fix(
                    num_classes_b,
                    h_b,
                    y_b,
                    tau_qq,
                    h_attend_bias,
                    k_b_mask,
                    w_class_a_mask,
                    is_training=is_training)
            self.attended_h = attended_h
            self.h_b = h_b

            # Cache the value of the attended features.
            # When enabled, the computed tensor is exposed on self and the
            # loss is built from a [None, h_size_reg] placeholder instead
            # (presumably so a precomputed value can be fed in).
            if self.config.cache_transfer_loss_var:
                self._transfer_loss_var = attended_h
                tloss_var_plh = tf.placeholder(dtype, [None, h_size_reg],
                                               name='transfer_loss_var_plh')
                self._transfer_loss_var_plh = tloss_var_plh
                attended_h = tloss_var_plh
        with tf.variable_scope("new_loss", reuse=reuse):
            # Per-dimension weights gamma = exp(log_gamma), either learned or
            # fixed at config.init_gamma.
            if self.config.learn_gamma:
                log_gamma = weight_variable(
                    [h_size_reg],
                    init_method='constant',
                    dtype=dtype,
                    init_param={'val': np.log(self.config.init_gamma)},
                    wd=self.config.wd,
                    name='log_gamma')
            else:
                log_gamma = tf.ones([h_size_reg], dtype=dtype) * np.log(
                    self.config.init_gamma)
            # [D, 1] column so it broadcasts across the class axis.
            log_gamma_ = tf.expand_dims(log_gamma, 1)
            gamma_ = tf.exp(log_gamma_)
            self.gamma = gamma_

        # [D, K2] and [K2, D]
        # Sum over the feature axis gives one distance per novel class [K2].
        dist = tf.reduce_sum(
            tf.square(w_class_b_reg - tf.transpose(attended_h)) * gamma_, [0])
        transfer_loss = tf.reduce_mean(dist)
        return transfer_loss
示例#10
0
    def build_task_a(self, x, y, is_training, ext_wts=None):
        """Build the task A (base-class) classification branch.

        Runs the backbone on ``x``, creates or restores the base-class weights
        ``w_class_a`` / ``b_class_a``, computes logits and the mean
        cross-entropy, and caches intermediate tensors on ``self``
        (``_h_a``, ``_h_size``, ``_w_class_a``, ``_b_class_a``,
        ``_prediction_a``, ``_prediction_a_all``, ``_cost_a``).

        Args:
          x: Tensor. [N, H, W, C]. Inputs tensor.
          y: Tensor. [N]. Labels tensor.
          is_training: Bool. Whether in training mode.
          ext_wts: Dict or None. External weights dictionary. When given, it
            must provide 'w_class_a' and 'b_class_a', plus 'tau' if cosine
            logits with a learned temperature are enabled.

        Returns:
          logits_a: Tensor. [N, num_classes_a]. Base-class logits.

        Raises:
          ValueError: If ``config.backbone_class`` is not a supported ResNet
            backbone.
        """
        print('build_task_a')
        config = self.config
        if config.backbone_class in ('resnet_backbone',
                                     'resnet_backbone_metaCNN'):
            bb_config = config.resnet_config
        else:
            # Raise instead of `assert False`. Asserts are stripped under -O,
            # which would surface later as a NameError on bb_config.
            raise ValueError('Not supported backbone: {}'.format(
                config.backbone_class))
        proto_config = config.protonet_config
        num_classes_a = self._num_classes_a

        # Classification branch for task A. The backbone's second output is
        # not needed here.
        h_a, _ = self._run_backbone(x,
                                    is_training=is_training,
                                    ext_wts=ext_wts)
        self._h_a = h_a
        # Flattened feature size: product of all non-batch dimensions.
        h_size = 1
        for dim in h_a.get_shape()[1:]:
            h_size *= int(dim)
        self._h_size = h_size

        if ext_wts is not None:
            # The stored weights are transposed to match the
            # [h_size, num_classes_a] variable layout.
            w_class_a = weight_variable(
                [h_size, num_classes_a],
                init_method='numpy',
                dtype=self.dtype,
                init_param={'val': np.transpose(ext_wts['w_class_a'])},
                wd=bb_config.wd,
                name='w_class_a')
            # NOTE(review): declared with shape [] here but [num_classes_a]
            # in the fresh-init branch. Confirm how weight_variable treats
            # the shape for 'numpy' initialization.
            b_class_a = weight_variable(
                [],
                init_method='numpy',
                dtype=self.dtype,
                init_param={'val': ext_wts['b_class_a']},
                wd=0e0,
                name='b_class_a')
        else:
            w_class_a = weight_variable([h_size, num_classes_a],
                                        init_method='truncated_normal',
                                        dtype=self.dtype,
                                        init_param={'stddev': 0.01},
                                        wd=bb_config.wd,
                                        name='w_class_a')
            b_class_a = weight_variable([num_classes_a],
                                        dtype=self.dtype,
                                        init_method='constant',
                                        init_param={'val': 0.0},
                                        name='b_class_a')
        self._w_class_a = w_class_a
        self._b_class_a = b_class_a

        if proto_config.cosine_a:
            # Cosine-similarity logits scaled by tau, which is learned or
            # fixed at 1.0.
            if proto_config.cosine_tau:
                if ext_wts is None:
                    tau_init_val = 10.0
                else:
                    tau_init_val = ext_wts['tau'][0]
                tau = weight_variable([],
                                      dtype=self.dtype,
                                      init_method='constant',
                                      init_param={'val': tau_init_val},
                                      name='tau')
            else:
                tau = tf.constant(1.0)

            w_class_a_norm = self._normalize(w_class_a, axis=0)
            h_a_norm = self._normalize(h_a, axis=1)
            dot = tf.matmul(h_a_norm, w_class_a_norm)
            # The bias is only applied when restoring external weights.
            if ext_wts is not None:
                dot += b_class_a
            logits_a = tau * dot
        else:
            # Euclidean-distance logits. b_class_a is not used on this path.
            logits_a = compute_euc(tf.transpose(w_class_a), h_a)

        self._prediction_a = logits_a
        self._prediction_a_all = logits_a
        y_dense = tf.one_hot(y, num_classes_a)
        xent_a = tf.nn.softmax_cross_entropy_with_logits(logits=logits_a,
                                                         labels=y_dense)
        cost_a = tf.reduce_mean(xent_a, name='xent')
        # _cost_a holds the plain cross-entropy, without weight decay.
        self._cost_a = cost_a
        # NOTE(review): the decayed cost below is neither stored nor
        # returned. The _decay() call is kept in case it has side effects.
        # Confirm whether callers expect _cost_a to include the decay term.
        cost_a += self._decay()
        print('build_task_a done')
        return logits_a