def residual_block(x, in_filter, out_filter, stride, update_bn=True,
                   activate_before_residual=False):
    """Builds one residual unit: BN->ReLU->3x3 conv twice, plus a shortcut.

    Args:
      x: Tensor that is the output of the previous layer in the model.
      in_filter: Number of filters `x` has.
      out_filter: Number of filters the output of this unit will have.
      stride: Integer stride given to the first convolution and, when the
        filter counts differ, to the pooling applied on the shortcut.
      update_bn: Forwarded as `update_stats` to every `ops.batch_norm` call
        made by this unit.
      activate_before_residual: Boolean. If True, the leading BN->ReLU is
        applied to `x` itself, so both the shortcut and the convolution
        branch start from the activated tensor. If False, only the
        convolution branch is activated and the shortcut keeps the raw `x`.

    Returns:
      A Tensor: the shortcut tensor added to the convolution branch output.
    """
    if activate_before_residual:
        # Activation is shared: shortcut and branch both use the BN->ReLU
        # output.
        with tf.variable_scope('shared_activation'):
            activated = tf.nn.relu(
                ops.batch_norm(x, update_stats=update_bn, scope='init_bn'))
        shortcut = activated
        branch = activated
    else:
        # Activation is private to the branch: the shortcut bypasses it.
        shortcut = x
        with tf.variable_scope('residual_only_activation'):
            branch = tf.nn.relu(
                ops.batch_norm(x, update_stats=update_bn, scope='init_bn'))

    with tf.variable_scope('sub1'):
        branch = ops.conv2d(branch, out_filter, 3, stride=stride,
                            scope='conv1')

    with tf.variable_scope('sub2'):
        branch = ops.batch_norm(branch, update_stats=update_bn, scope='bn2')
        branch = tf.nn.relu(branch)
        branch = ops.conv2d(branch, out_filter, 3, stride=1, scope='conv2')

    with tf.variable_scope('sub_add'):
        # When the filter counts disagree the shortcut is pooled with the
        # block's stride and then zero padded from in_filter to out_filter
        # (exact semantics live in ops.avg_pool / ops.zero_pad -- confirm
        # there).
        if in_filter != out_filter:
            shortcut = ops.avg_pool(shortcut, stride, stride)
            shortcut = ops.zero_pad(shortcut, in_filter, out_filter)

    return shortcut + branch
def _shake_shake_skip_connection(x, output_filters, stride):
    """Builds the skip (shortcut) tensor of a shake-shake block.

    When dimension 3 of `x` already equals `output_filters`, `x` is returned
    untouched. Otherwise two paths are built, each ending in a 1x1
    convolution producing `int(output_filters / 2)` filters, and their
    outputs are concatenated along axis 3 and batch-normalised.

    Args:
      x: Input tensor; its dimension 3 is compared with `output_filters`
        and the pooling below is performed with `data_format='NHWC'`.
      output_filters: Number of filters requested for the skip tensor.
      stride: Stride handed to `ops.stride_arr` to build the pooling strides.

    Returns:
      `x` itself when no filter change is needed, otherwise the
      batch-normalised concatenation of the two skip paths.
    """
    if int(x.shape[3]) == output_filters:
        return x

    channel_axis = 3
    unit_window = [1, 1, 1, 1]
    pool_strides = ops.stride_arr(stride, stride)
    half_filters = int(output_filters / 2)

    # Path 1: subsample x with a 1x1 pooling window, then 1x1 conv.
    first = tf.nn.avg_pool(
        x, unit_window, pool_strides, 'VALID', data_format='NHWC')
    first = ops.conv2d(first, half_filters, 1, scope='path1_conv')

    # Path 2: append one zero row/column at the end of axes 1 and 2, then
    # drop the first row/column, so this path samples x shifted by one
    # position before the same subsample + 1x1 conv.
    shift_padding = [[0, 0], [0, 1], [0, 1], [0, 0]]
    second = tf.pad(x, shift_padding)[:, 1:, 1:, :]
    second = tf.nn.avg_pool(
        second, unit_window, pool_strides, 'VALID', data_format='NHWC')
    second = ops.conv2d(second, half_filters, 1, scope='path2_conv')

    # Join the two halves along the filter axis and normalise the result.
    merged = tf.concat(values=[first, second], axis=channel_axis)
    return ops.batch_norm(merged, scope='final_path_bn')
def build_shake_shake_model(images, num_classes, hparams, is_training):
    """Builds the Shake-Shake model and returns its logits.

    Builds the Shake-Shake model from https://arxiv.org/abs/1705.07485.

    Args:
      images: Tensor of images that will be fed into the Shake-Shake model.
      num_classes: Number of classes that the model needs to predict.
      hparams: Object holding additional hyperparameters; only
        `shake_shake_widen_factor` is read here, and it multiplies the base
        filter count (16, 32, 64) of each of the three stages.
      is_training: Whether the model is being built for training; passed
        through to `_shake_shake_layer`.

    Returns:
      The logits of the Shake-Shake model.
    """
    depth = 26
    widen_factor = hparams.shake_shake_widen_factor
    # Evaluates to 4 for depth 26; handed to every stage as its third
    # argument (presumably the number of blocks per stage -- confirm against
    # _shake_shake_layer).
    per_stage = int((depth - 2) / 6)

    net = ops.conv2d(images, 16, 3, scope='init_conv')
    net = ops.batch_norm(net, scope='init_bn')

    # (variable scope, base filter count, fourth argument of the stage call
    # -- presumably its stride).
    stage_table = (
        ('L1', 16, 1),
        ('L2', 32, 2),
        ('L3', 64, 2),
    )
    for scope_name, base_filters, stage_arg in stage_table:
        with tf.variable_scope(scope_name):
            net = _shake_shake_layer(
                net, base_filters * widen_factor, per_stage, stage_arg,
                is_training)

    net = tf.nn.relu(net)
    net = ops.global_avg_pool(net)

    # Fully connected classifier head.
    return ops.fc(net, num_classes)
def build_wrn_model(images, num_classes, wrn_size):
    """Builds the Wide ResNet model and returns its logits.

    Builds the Wide ResNet model from https://arxiv.org/abs/1605.07146.

    Args:
      images: Tensor of images that will be fed into the Wide ResNet model.
      num_classes: Number of classes that the model needs to predict.
      wrn_size: Scales the number of filters in the model. The initial
        convolution uses `min(wrn_size, 16)` filters and the three residual
        groups use `wrn_size`, `2 * wrn_size` and `4 * wrn_size` filters.

    Returns:
      The logits of the Wide ResNet model.
    """
    conv_size = 3
    units_per_group = 4
    widths = [min(wrn_size, 16), wrn_size, wrn_size * 2, wrn_size * 4]
    group_strides = [1, 2, 2]  # One stride per residual group.

    # Initial convolution.
    with tf.variable_scope('init'):
        x = ops.conv2d(images, widths[0], conv_size, scope='init_conv')

    first_x = x  # Kept for the shortcut spanning the whole network.
    orig_x = x  # Input of the current group, for the per-group shortcut.

    group_specs = zip(widths[:-1], widths[1:], group_strides)
    for group, (in_width, out_width, group_stride) in enumerate(
            group_specs, start=1):
        # The first unit of a group changes the filter count and applies the
        # group's stride; only the very first group shares its activation
        # with the shortcut.
        with tf.variable_scope('unit_{}_0'.format(group)):
            x = residual_block(
                x, in_width, out_width, group_stride,
                activate_before_residual=(group == 1))

        # Remaining units keep the filter count and use stride 1.
        for unit in range(1, units_per_group):
            with tf.variable_scope('unit_{}_{}'.format(group, unit)):
                x = residual_block(
                    x, out_width, out_width, 1,
                    activate_before_residual=False)

        # Combine with the group's input; _res_add also returns the tensor
        # used as `orig_x` for the next group.
        x, orig_x = _res_add(in_width, out_width, group_stride, x, orig_x)

    # Shortcut from the initial convolution, using the product of all group
    # strides.
    total_stride = np.prod(group_strides)
    x, _ = _res_add(widths[0], widths[3], total_stride, x, first_x)

    with tf.variable_scope('unit_last'):
        x = ops.batch_norm(x, scope='final_bn')
        x = tf.nn.relu(x)
        x = ops.global_avg_pool(x)
        logits = ops.fc(x, num_classes)
    return logits