def init(self, name, with_name=None, outputs=None):
  """Init model class.

  Args:
    name: name of the model scope.
    with_name: variable prefix name used to select batch-norm update
      variables.
    outputs: an output variable of the model.
  """
  self.name = name
  self.updates = self.get_update_ops(name, outputs)
  assert self.updates
  self.trainable_variables = utils.get_var(tf.trainable_variables(), name)
  # batch_norm variables (moving mean and moving variance)
  self.updates_variables = [
      a for a in utils.get_var(
          tf.global_variables(), name, with_name=with_name)
      if a not in self.trainable_variables
  ]
  self.total_vars = len(self.trainable_variables) + len(self.updates_variables)
  tf.logging.info(
      '[StrategyNetBase] {}: Found {} trainable_variables, {} update_variables'
      ' and {} updates'.format(name, len(self.trainable_variables),
                               len(self.updates_variables), len(self.updates)))
  self.created = True
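For context, batch normalization in TF 1.x registers the assignments for its moving statistics in the UPDATE_OPS collection, which is presumably what get_update_ops gathers per scope. A minimal sketch, assuming TF 1.x; the `net` scope and the tf.layers.batch_normalization usage are illustrative, not from this codebase:

import tensorflow as tf  # TF 1.x assumed

x = tf.placeholder(tf.float32, [None, 4])
with tf.variable_scope('net'):  # illustrative scope name
  h = tf.layers.batch_normalization(x, training=True)

# The moving mean/variance are global, non-trainable variables; their update
# ops live in UPDATE_OPS and must be run alongside the train op.
updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='net')
trainable = tf.trainable_variables('net')
moving = [v for v in tf.global_variables('net') if v not in trainable]
print(len(updates), [v.name for v in moving])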
def init(self, name, with_name=None, outputs=None, stop_grad_scope=None):
  """Init model class.

  Args:
    name: name of the model scope.
    with_name: variable prefix name used to select batch-norm update
      variables.
    outputs: an output variable of the model.
    stop_grad_scope: excludes (freezes) all variables created before this
      scope, e.g. `dense`.
  """
  self.name = name
  self.updates = self.get_update_ops(name, outputs)
  assert self.updates or self.stop_gradient
  self.trainable_variables = utils.get_var(tf.trainable_variables(), name)
  if self.stop_gradient:
    # Archive the total model variables, including frozen ones.
    self.total_model_variables = self.trainable_variables
    assert stop_grad_scope
    ind = 0
    for v in self.trainable_variables:
      if stop_grad_scope in v.name:
        break
      ind += 1
    self.trainable_variables = self.trainable_variables[ind:]
    assert self.trainable_variables
  # batch_norm variables (moving mean and moving variance)
  self.updates_variables = [
      a for a in utils.get_var(
          tf.global_variables(), name, with_name=with_name)
      if a not in self.trainable_variables
  ]
  if self.stop_gradient:
    # Archive the total update variables, including frozen ones.
    self.total_updates_variables = self.updates_variables
    self.updates_variables = []
  self.total_vars = len(self.trainable_variables) + len(self.updates_variables)
  tf.logging.info(
      '[StrategyNetBase] {}: Found {} trainable_variables, {} update_variables'
      ' and {} updates'.format(name, len(self.trainable_variables),
                               len(self.updates_variables), len(self.updates)))
  self.created = True
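A hedged sketch of the stop_grad_scope slicing above, assuming trainable variables come back in creation order; the two-layer model and the `dense` scope name are illustrative:

import tensorflow as tf  # TF 1.x assumed

x = tf.placeholder(tf.float32, [None, 8])
with tf.variable_scope('model'):  # illustrative model
  h = tf.layers.dense(x, 16, name='backbone')
  logits = tf.layers.dense(h, 2, name='dense')

trainable = tf.trainable_variables('model')
# Drop everything before the first variable whose name matches the scope,
# as init() does above; only the `dense` head stays trainable.
ind = 0
for v in trainable:
  if 'dense' in v.name:
    break
  ind += 1
print([v.name for v in trainable[ind:]])  # dense/kernel, dense/bias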
def meta_momentum_update(self, grad, var_name, optimizer):
  """Adds the optimizer's (stopped) momentum accumulator to a gradient."""
  # Find the momentum slot variable corresponding to var_name.
  accumulation = utils.get_var(optimizer.variables(), var_name.split(':')[0])
  if len(accumulation) != 1:
    raise ValueError('length of accumulation {}'.format(len(accumulation)))
  new_grad = tf.math.add(
      tf.stop_gradient(accumulation[0]) * FLAGS.meta_momentum, grad)
  return new_grad
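A runnable sketch of the same correction, assuming TF 1.x and a tf.train.MomentumOptimizer whose slot variables already exist; 0.9 stands in for FLAGS.meta_momentum:

import tensorflow as tf  # TF 1.x assumed

w = tf.get_variable('w', shape=[2], initializer=tf.ones_initializer())
loss = tf.reduce_sum(tf.square(w))
opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
warmup_op = opt.minimize(loss)  # creates the 'w/Momentum' slot variable

grad = tf.gradients(loss, w)[0]
meta_momentum = 0.9  # stand-in for FLAGS.meta_momentum
# Mirror meta_momentum_update: look up the accumulator by variable name and
# add it, stopped, to the fresh gradient.
accumulation = [v for v in opt.variables() if w.name.split(':')[0] in v.name]
assert len(accumulation) == 1
new_grad = tf.stop_gradient(accumulation[0]) * meta_momentum + grad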
def clean_acc_history(self):
  """Cleans accumulated counter in metrics.accuracy."""
  if not hasattr(self, 'clean_accstate_op'):
    self.clean_accstate_op = [
        a.assign(0) for a in utils.get_var(tf.local_variables(), 'accuracy')
    ]
    logging.info('Create {} clean accuracy state ops'.format(
        len(self.clean_accstate_op)))
  self.sess.run(self.clean_accstate_op)
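The local variables being zeroed are the total/count counters behind tf.metrics.accuracy; a standalone sketch of the same reset pattern:

import tensorflow as tf  # TF 1.x assumed

labels = tf.constant([1, 0, 1])
preds = tf.constant([1, 1, 1])
acc, acc_update = tf.metrics.accuracy(labels, preds)

# tf.metrics.accuracy stores `total` and `count` as local variables whose
# names contain 'accuracy'; assigning 0 clears the running average.
reset_op = [v.assign(0) for v in tf.local_variables() if 'accuracy' in v.name]

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(acc_update)
  print(sess.run(acc))  # 0.666...
  sess.run(reset_op)
  print(sess.run(acc))  # 0.0 once the counters are cleared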
def __call__(self, inputs, name, training, reuse=True, custom_getter=None):
  """Forward pass."""
  self.name = name
  with tf.variable_scope(
      name, reuse=reuse, custom_getter=custom_getter, dtype=tf.float32):
    logits = super(ImagenetModelv2, self).__call__(inputs, training)
    if not isinstance(reuse, bool) or not reuse:
      # When using AUTO_REUSE, the model is copied onto different devices;
      # we still want to initialize the following variables.
      self.regularization_loss = decay_weights(
          self.wd, utils.get_var(tf.trainable_variables(), name))
      self.init(name, with_name='batch_normalization', outputs=logits)
      self.count_parameters(name)
  return logits
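A hedged usage sketch of the call convention above; the constructor arguments and input shape are assumptions, not the class's actual signature:

import tensorflow as tf  # TF 1.x assumed

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
model = ImagenetModelv2(num_classes=1000)  # hypothetical constructor args

# The first call (reuse=False) creates the variables and runs the bookkeeping
# branch (regularization loss, init, parameter count); later calls reuse them.
train_logits = model(images, name='resnet', training=True, reuse=False)
eval_logits = model(images, name='resnet', training=False, reuse=True)

# tf.AUTO_REUSE is not a bool, so the bookkeeping branch also runs when the
# model is replicated across devices.
tower_logits = model(images, name='resnet', training=True, reuse=tf.AUTO_REUSE)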
def __call__(self, images, name, reuse=True, training=True, custom_getter=None):
  """Builds the WRN model.

  Builds the Wide ResNet model from https://arxiv.org/abs/1605.07146.

  Args:
    images: Tensor of images that will be fed into the Wide ResNet model.
    name: Name of the model scope.
    reuse: If True, reuses the parameters.
    training: If True, builds the model for the training stage.
    custom_getter: custom_getter function for variable_scope.

  Returns:
    The logits of the Wide ResNet model.
  """
  num_classes = self.num_classes
  wrn_size = self.wrn_size

  kernel_size = wrn_size
  filter_size = 3
  num_blocks_per_resnet = 4
  filters = [
      min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4
  ]
  strides = [1, 2, 2]  # stride for each resblock

  with setup_arg_scopes(training):
    with tf.variable_scope(name, reuse=reuse, custom_getter=custom_getter):
      # Run the first conv.
      with tf.variable_scope('init'):
        x = images
        output_filters = filters[0]
        x = ops.conv2d(x, output_filters, filter_size, scope='init_conv')

      first_x = x  # Residual from the beginning.
      orig_x = x  # Residual from the previous block.

      for block_num in range(1, 4):
        with tf.variable_scope('unit_{}_0'.format(block_num)):
          activate_before_residual = (block_num == 1)
          x = residual_block(
              x,
              filters[block_num - 1],
              filters[block_num],
              strides[block_num - 1],
              activate_before_residual=activate_before_residual)
        for i in range(1, num_blocks_per_resnet):
          with tf.variable_scope('unit_{}_{}'.format(block_num, i)):
            x = residual_block(
                x,
                filters[block_num],
                filters[block_num],
                1,
                activate_before_residual=False)
        x, orig_x = _res_add(filters[block_num - 1], filters[block_num],
                             strides[block_num - 1], x, orig_x)
      final_stride_val = np.prod(strides)
      x, _ = _res_add(filters[0], filters[3], final_stride_val, x, first_x)

      with tf.variable_scope('unit_last'):
        x = ops.batch_norm(x, scope='final_bn')
        x = tf.nn.relu(x)
        x = ops.global_avg_pool(x)
        logits = ops.fc(x, num_classes)

      if not isinstance(reuse, bool) or not reuse:
        self.regularization_loss = decay_weights(
            self.wd, utils.get_var(tf.trainable_variables(), name))
        self.init(name, with_name='moving', outputs=logits)
        self.count_parameters(name)
  return logits
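As a concrete check of the sizing logic above, here is what the filter widths and the final skip-connection stride come out to for an illustrative wrn_size of 32:

import numpy as np

wrn_size = 32  # illustrative value
kernel_size = wrn_size
filters = [min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4]
strides = [1, 2, 2]

print(filters)                # [16, 32, 64, 128]
# first_x is taken before any striding, so the long skip from the initial
# conv must downsample by the product of all block strides.
print(int(np.prod(strides)))  # 4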