Example #1
    def init(self, name, with_name=None, outputs=None):
        """Initializes model bookkeeping after the graph is built.

        Args:
          name: name of the model variable scope.
          with_name: substring used to select batch-norm update variables.
          outputs: an output tensor of the model, used to collect update ops.
        """
        self.name = name
        self.updates = self.get_update_ops(name, outputs)
        assert self.updates

        self.trainable_variables = utils.get_var(tf.trainable_variables(),
                                                 name)
        # batch_norm variables (moving mean and moving variance)
        self.updates_variables = [
            a for a in utils.get_var(
                tf.global_variables(), name, with_name=with_name)
            if a not in self.trainable_variables
        ]
        self.total_vars = len(self.trainable_variables) + len(
            self.updates_variables)

        tf.logging.info(
            '[StrategyNetBase] {}: found {} trainable_variables, '
            '{} update_variables, and {} updates'.format(
                name, len(self.trainable_variables),
                len(self.updates_variables), len(self.updates)))

        self.created = True
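
Every example here leans on utils.get_var, which is not shown. A minimal sketch of a plausible implementation, assuming it filters a variable list by scope-name prefix plus an optional substring; the body below is a guess, only the call signature comes from the examples:

def get_var(var_list, scope_name, with_name=None):
    # Hypothetical: keep variables whose names start with the scope prefix.
    selected = [v for v in var_list if v.name.startswith(scope_name)]
    if with_name is not None:
        # Optionally require a substring such as 'moving' or
        # 'batch_normalization' to isolate batch-norm statistics.
        selected = [v for v in selected if with_name in v.name]
    return selected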
Example #2
    def init(self, name, with_name=None, outputs=None, stop_grad_scope=None):
        """Initializes model bookkeeping after the graph is built.

        Args:
          name: name of the model variable scope.
          with_name: substring used to select batch-norm update variables.
          outputs: an output tensor of the model, used to collect update ops.
          stop_grad_scope: scope name marking the trainable boundary; all
            variables created before this scope (e.g. 'dense') are frozen.
        """
        self.name = name
        self.updates = self.get_update_ops(name, outputs)
        assert self.updates or self.stop_gradient

        self.trainable_variables = utils.get_var(tf.trainable_variables(),
                                                 name)

        if self.stop_gradient:
            # Archive the full set of model variables, including frozen ones.
            self.total_model_variables = self.trainable_variables
            assert stop_grad_scope
            # Keep only the variables from stop_grad_scope onward; everything
            # created before that scope stays frozen.
            ind = 0
            for v in self.trainable_variables:
                if stop_grad_scope in v.name:
                    break
                ind += 1
            self.trainable_variables = self.trainable_variables[ind:]
            assert self.trainable_variables

        # batch_norm variables (moving mean and moving variance)
        self.updates_variables = [
            a for a in utils.get_var(
                tf.global_variables(), name, with_name=with_name)
            if a not in self.trainable_variables
        ]
        if self.stop_gradient:
            # Archive the full set of update variables, including frozen ones,
            # and leave nothing for the frozen part to update.
            self.total_updates_variables = self.updates_variables
            self.updates_variables = []

        self.total_vars = len(self.trainable_variables) + len(
            self.updates_variables)

        tf.logging.info(
            '[StrategyNetBase] {}: found {} trainable_variables, '
            '{} update_variables, and {} updates'.format(
                name, len(self.trainable_variables),
                len(self.updates_variables), len(self.updates)))

        self.created = True
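
The stop_grad_scope logic above relies on tf.trainable_variables() returning variables in creation order, so everything built before the named scope can be sliced off by index. A toy illustration of just that boundary search, using made-up variable names:

names = ['net/conv1/kernel:0', 'net/conv2/kernel:0',
         'net/dense/kernel:0', 'net/dense/bias:0']
stop_grad_scope = 'dense'
ind = 0
for n in names:
    if stop_grad_scope in n:
        break
    ind += 1
# Variables before the first 'dense' name are frozen; the rest train.
assert names[ind:] == ['net/dense/kernel:0', 'net/dense/bias:0']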
Example #3
    def meta_momentum_update(self, grad, var_name, optimizer):
        # Finds the momentum accumulator that corresponds to the variable.
        accumulation = utils.get_var(optimizer.variables(),
                                     var_name.split(':')[0])
        if len(accumulation) != 1:
            raise ValueError('Expected one accumulator for {}, got {}'.format(
                var_name, len(accumulation)))
        # Blend the frozen momentum buffer into the incoming gradient.
        new_grad = tf.math.add(
            tf.stop_gradient(accumulation[0]) * FLAGS.meta_momentum, grad)
        return new_grad
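
A hedged sketch of how this could plug into a training step, assuming a tf.train.MomentumOptimizer whose slot variables carry the owning variable's name, and assuming the slots already exist (they are only created once apply_gradients has run for these variables); loss and model below are hypothetical:

optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
grads_and_vars = optimizer.compute_gradients(loss)
# Blend each raw gradient with its frozen momentum buffer before applying.
new_grads_and_vars = [(model.meta_momentum_update(g, v.name, optimizer), v)
                      for g, v in grads_and_vars if g is not None]
train_op = optimizer.apply_gradients(new_grads_and_vars)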
Example #4
    def clean_acc_history(self):
        """Cleans accumulated counter in metrics.accuracy."""

        if not hasattr(self, 'clean_accstate_op'):
            self.clean_accstate_op = [
                a.assign(0)
                for a in utils.get_var(tf.local_variables(), 'accuracy')
            ]
            logging.info('Created {} accuracy-state reset ops'.format(
                len(self.clean_accstate_op)))
        self.sess.run(self.clean_accstate_op)
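
This reset works because tf.metrics.accuracy stores its running total and count as local variables under an 'accuracy' name scope, so they can be located by name and zeroed. A small sketch of the variables involved; labels and predictions stand in for whatever the model produces:

acc, acc_update_op = tf.metrics.accuracy(labels, predictions, name='accuracy')
# tf.local_variables() now contains 'accuracy/total:0' and
# 'accuracy/count:0'; assigning zero to both restarts the running mean.
reset_ops = [v.assign(0) for v in tf.local_variables() if 'accuracy' in v.name]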
Example #5
    def __call__(self, inputs, name, training, reuse=True, custom_getter=None):
        """Forward pass."""

        self.name = name
        with tf.variable_scope(name,
                               reuse=reuse,
                               custom_getter=custom_getter,
                               dtype=tf.float32):

            logits = super(ImagenetModelv2, self).__call__(inputs, training)

            if not isinstance(reuse, bool) or not reuse:
                # When reuse is tf.AUTO_REUSE, the model is copied across
                # devices; we still want to initialize the variables below.
                self.regularization_loss = decay_weights(
                    self.wd, utils.get_var(tf.trainable_variables(), name))
                self.init(name,
                          with_name='batch_normalization',
                          outputs=logits)
                self.count_parameters(name)

        return logits
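
A usage sketch of the reuse contract above (the constructor call is hypothetical): the first call builds the graph and runs the bookkeeping branch, a plain reuse=True call only shares weights, and tf.AUTO_REUSE (not a bool) re-runs the bookkeeping so replicas on other devices are initialized the same way:

model = ImagenetModelv2(num_classes=1000)  # hypothetical constructor args
logits = model(images, name='model', training=True, reuse=False)
eval_logits = model(eval_images, name='model', training=False, reuse=True)
replica_logits = model(images, name='model', training=True,
                       reuse=tf.AUTO_REUSE)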
Example #6
    def __call__(self,
                 images,
                 name,
                 reuse=True,
                 training=True,
                 custom_getter=None):
        """Builds the WRN model.

    Build the Wide ResNet model from https://arxiv.org/abs/1605.07146.

    Args:
      images: Tensor of images that will be fed into the Wide ResNet Model.
      name: Name of the model as scope
      reuse: If True, reuses the parameters.
      training: If True, for training stage.
      custom_getter: custom_getter function for variable_scope.

    Returns:
      The logits of the Wide ResNet model.
    """
        num_classes = self.num_classes
        wrn_size = self.wrn_size

        kernel_size = wrn_size
        filter_size = 3
        num_blocks_per_resnet = 4
        filters = [
            min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4
        ]
        strides = [1, 2, 2]  # stride for each resblock

        with setup_arg_scopes(training):
            with tf.variable_scope(name,
                                   reuse=reuse,
                                   custom_getter=custom_getter):

                # Run the first conv
                with tf.variable_scope('init'):
                    x = images
                    output_filters = filters[0]
                    x = ops.conv2d(x,
                                   output_filters,
                                   filter_size,
                                   scope='init_conv')

                first_x = x  # Residual from the start of the network.
                orig_x = x  # Residual from the previous block.

                for block_num in range(1, 4):
                    with tf.variable_scope('unit_{}_0'.format(block_num)):
                        activate_before_residual = (block_num == 1)
                        x = residual_block(
                            x,
                            filters[block_num - 1],
                            filters[block_num],
                            strides[block_num - 1],
                            activate_before_residual=activate_before_residual)
                    for i in range(1, num_blocks_per_resnet):
                        with tf.variable_scope('unit_{}_{}'.format(
                                block_num, i)):
                            x = residual_block(x,
                                               filters[block_num],
                                               filters[block_num],
                                               1,
                                               activate_before_residual=False)
                    x, orig_x = _res_add(filters[block_num - 1],
                                         filters[block_num],
                                         strides[block_num - 1], x, orig_x)
                final_stride_val = np.prod(strides)
                x, _ = _res_add(filters[0], filters[3], final_stride_val, x,
                                first_x)
                with tf.variable_scope('unit_last'):
                    x = ops.batch_norm(x, scope='final_bn')
                    x = tf.nn.relu(x)
                    x = ops.global_avg_pool(x)
                    logits = ops.fc(x, num_classes)

                if not isinstance(reuse, bool) or not reuse:
                    self.regularization_loss = decay_weights(
                        self.wd, utils.get_var(tf.trainable_variables(), name))
                    self.init(name, with_name='moving', outputs=logits)
                    self.count_parameters(name)
        return logits
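
For scale, a hedged sketch of driving this builder; the class name and constructor arguments below are assumptions, not taken from the snippet. With wrn_size=160 the filter widths become [16, 160, 320, 640], and three groups of four residual blocks make this a WRN-28 at that width:

images = tf.placeholder(tf.float32, [None, 32, 32, 3])
wrn = WideResNetModel(num_classes=10, wrn_size=160)  # hypothetical class
logits = wrn(images, name='wrn', reuse=False, training=True)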