def inception_block(
        x: TensorType,
        specs: Tuple,
        channel_axis: int,
        name: str,
        weight_suffix: Optional[str] = "weights",
        conv_suffix: Optional[str] = "",
        norm_suffix: Optional[str] = "/BatchNorm") -> TensorType:
    """Build a four-branch Inception block and concatenate the branches.

    Branch 0 is a single 1x1 conv, branches 1 and 2 are a 1x1 conv followed
    by a 3x3 conv, and branch 3 is a max-pool followed by a 1x1 conv. Every
    branch reads the same input `x`.

    Args:
        x: Input tensor.
        specs: Four per-branch filter specs, unpacked as
            (branch_0, branch_1, branch_2, branch_3), e.g.
            ((64,), (96, 128), (16, 32), (32,)).
        channel_axis: Axis along which the branch outputs are concatenated.
        name: Name prefix for every layer in this block.
        weight_suffix: Name of the learnable parameters in each conv.
        conv_suffix: Suffix for conv layer names.
        norm_suffix: Suffix for batch norm layer names.

    Returns:
        Concatenated output of the four branches.
    """

    def _branch_conv(inputs, filters, size, layer_suffix):
        # All convs in the block share the same suffix settings; `size` is
        # forwarded as both the third and fourth positional argument (the
        # layer names indicate these are the two kernel dimensions).
        return conv_norm_relu(inputs, filters, size, size,
                              name=name + layer_suffix,
                              weight_suffix=weight_suffix,
                              conv_suffix=conv_suffix,
                              norm_suffix=norm_suffix)

    spec_0, spec_1, spec_2, spec_3 = specs

    # Branch 0: 1x1.
    tower_0 = _branch_conv(x, spec_0[0], 1, "/Branch_0/Conv2d_0a_1x1")

    # Branch 1: 1x1 -> 3x3.
    tower_1 = _branch_conv(x, spec_1[0], 1, "/Branch_1/Conv2d_0a_1x1")
    tower_1 = _branch_conv(tower_1, spec_1[1], 3, "/Branch_1/Conv2d_0b_3x3")

    # Branch 2: 1x1 -> 3x3.
    tower_2 = _branch_conv(x, spec_2[0], 1, "/Branch_2/Conv2d_0a_1x1")
    # Mixed_5b labels its second Branch_2 conv "0a" instead of "0b".
    # NOTE(review): presumably this mirrors the names in a pretrained
    # checkpoint - confirm before "fixing" it.
    second_conv_suffix = ("/Branch_2/Conv2d_0a_3x3"
                          if name == "InceptionV1/Mixed_5b"
                          else "/Branch_2/Conv2d_0b_3x3")
    tower_2 = _branch_conv(tower_2, spec_2[1], 3, second_conv_suffix)

    # Branch 3: max-pool -> 1x1.
    tower_3 = max_pool(x, 3, strides=1, padding='same',
                       name=name + "/Branch_3/Conv2d_0a_max")
    tower_3 = _branch_conv(tower_3, spec_3[0], 1, "/Branch_3/Conv2d_0b_1x1")

    return concat([tower_0, tower_1, tower_2, tower_3],
                  axis=channel_axis,
                  name=name + "_Concatenated")
def build_model(self, img_input: TensorType) -> TensorType:
    """Assemble the classification graph on top of `img_input`.

    The graph is a convolutional stem, eleven concatenation blocks named
    'mixed0' .. 'mixed10', and a pooling / fully-connected / softmax head.

    Args:
        img_input: 4D image input tensor of shape
            (batch, height, width, channels).

    Returns:
        `Tensor` holding output probabilities per class, shape
        (batch, num_classes).
    """
    # Stem.
    net = conv_norm_relu(img_input, 32, 3, strides=2, padding='VALID')
    net = conv_norm_relu(net, 32, 3, padding='VALID')
    net = conv_norm_relu(net, 64, 3)
    net = max_pool(net, 3, strides=2)
    net = conv_norm_relu(net, 80, 1, padding='VALID')
    net = conv_norm_relu(net, 192, 3, padding='VALID')
    net = max_pool(net, 3, strides=2)

    # mixed 0: 35 x 35 x 256; mixed 1, 2: 35 x 35 x 288.
    # The three blocks differ only in the width of the pooling branch.
    for block_idx, pool_filters in enumerate((32, 64, 64)):
        tower_1x1 = conv_norm_relu(net, 64, 1)

        tower_5x5 = conv_norm_relu(net, 48, 1)
        tower_5x5 = conv_norm_relu(tower_5x5, 64, 5)

        tower_3x3dbl = conv_norm_relu(net, 64, 1)
        tower_3x3dbl = conv_norm_relu(tower_3x3dbl, 96, 3)
        tower_3x3dbl = conv_norm_relu(tower_3x3dbl, 96, 3)

        tower_pool = avg_pool(net, 3, strides=1, padding='SAME')
        tower_pool = conv_norm_relu(tower_pool, pool_filters, 1)

        net = concat([tower_1x1, tower_5x5, tower_3x3dbl, tower_pool],
                     axis=-1,
                     name='mixed%d' % block_idx)

    # mixed 3: 17 x 17 x 768
    tower_3x3 = conv_norm_relu(net, 384, 3, strides=2, padding='VALID')

    tower_3x3dbl = conv_norm_relu(net, 64, 1)
    tower_3x3dbl = conv_norm_relu(tower_3x3dbl, 96, 3)
    tower_3x3dbl = conv_norm_relu(tower_3x3dbl, 96, 3,
                                  strides=2, padding='VALID')

    tower_pool = max_pool(net, 3, 2)
    net = concat([tower_3x3, tower_3x3dbl, tower_pool],
                 axis=-1,
                 name='mixed3')

    # mixed 4: 17 x 17 x 768
    tower_1x1 = conv_norm_relu(net, 192, 1)

    tower_7x7 = conv_norm_relu(net, 128, 1)
    tower_7x7 = conv_norm_relu(tower_7x7, 128, 1, 7)
    tower_7x7 = conv_norm_relu(tower_7x7, 192, 7, 1)

    tower_7x7dbl = conv_norm_relu(net, 128, 1, 1)
    tower_7x7dbl = conv_norm_relu(tower_7x7dbl, 128, 7, 1)
    tower_7x7dbl = conv_norm_relu(tower_7x7dbl, 128, 1, 7)
    tower_7x7dbl = conv_norm_relu(tower_7x7dbl, 128, 7, 1)
    tower_7x7dbl = conv_norm_relu(tower_7x7dbl, 192, 1, 7)

    tower_pool = avg_pool(net, 3, strides=1, padding='SAME')
    tower_pool = conv_norm_relu(tower_pool, 192, 1, 1)
    net = concat([tower_1x1, tower_7x7, tower_7x7dbl, tower_pool],
                 axis=-1,
                 name='mixed4')

    # mixed 5, 6, 7: 17 x 17 x 768
    # Same layout for all three; only the inner 7x7 width changes.
    for block_idx, inner in enumerate((160, 160, 192), 5):
        tower_1x1 = conv_norm_relu(net, 192, 1, 1)

        tower_7x7 = conv_norm_relu(net, inner, 1, 1)
        tower_7x7 = conv_norm_relu(tower_7x7, inner, 1, 7)
        tower_7x7 = conv_norm_relu(tower_7x7, 192, 7, 1)

        tower_7x7dbl = conv_norm_relu(net, inner, 1, 1)
        tower_7x7dbl = conv_norm_relu(tower_7x7dbl, inner, 7, 1)
        tower_7x7dbl = conv_norm_relu(tower_7x7dbl, inner, 1, 7)
        tower_7x7dbl = conv_norm_relu(tower_7x7dbl, inner, 7, 1)
        tower_7x7dbl = conv_norm_relu(tower_7x7dbl, 192, 1, 7)

        tower_pool = avg_pool(net, 3, strides=1, padding='SAME')
        tower_pool = conv_norm_relu(tower_pool, 192, 1, 1)
        net = concat([tower_1x1, tower_7x7, tower_7x7dbl, tower_pool],
                     axis=-1,
                     name='mixed%d' % block_idx)

    # mixed 8: 8 x 8 x 1280
    tower_3x3 = conv_norm_relu(net, 192, 1)
    tower_3x3 = conv_norm_relu(tower_3x3, 320, 3,
                               strides=2, padding='VALID')

    tower_7x7x3 = conv_norm_relu(net, 192, 1, 1)
    tower_7x7x3 = conv_norm_relu(tower_7x7x3, 192, 1, 7)
    tower_7x7x3 = conv_norm_relu(tower_7x7x3, 192, 7, 1)
    tower_7x7x3 = conv_norm_relu(tower_7x7x3, 192, 3, 3,
                                 strides=2, padding='VALID')

    tower_pool = max_pool(net, 3, strides=2)
    net = concat([tower_3x3, tower_7x7x3, tower_pool],
                 axis=-1,
                 name='mixed8')

    # mixed 9, 10: 8 x 8 x 2048
    for rep in (0, 1):
        tower_1x1 = conv_norm_relu(net, 320, 1, 1)

        # The 3x3 branch splits into a 1x3 and a 3x1 conv that are
        # concatenated back together.
        tower_3x3 = conv_norm_relu(net, 384, 1, 1)
        split_a = conv_norm_relu(tower_3x3, 384, 1, 3)
        split_b = conv_norm_relu(tower_3x3, 384, 3, 1)
        tower_3x3 = concat([split_a, split_b],
                           axis=-1,
                           name='mixed9_' + str(rep))

        tower_3x3dbl = conv_norm_relu(net, 448, 1, 1)
        tower_3x3dbl = conv_norm_relu(tower_3x3dbl, 384, 3, 3)
        dbl_split_a = conv_norm_relu(tower_3x3dbl, 384, 1, 3)
        dbl_split_b = conv_norm_relu(tower_3x3dbl, 384, 3, 1)
        tower_3x3dbl = concat([dbl_split_a, dbl_split_b], axis=-1)

        tower_pool = avg_pool(net, 3, strides=1, padding='SAME')
        tower_pool = conv_norm_relu(tower_pool, 192, 1, 1)
        net = concat([tower_1x1, tower_3x3, tower_3x3dbl, tower_pool],
                     axis=-1,
                     name='mixed' + str(9 + rep))

    # Classification block: pool, drop the two spatial axes, project to
    # `self.num_classes` outputs and normalise with softmax.
    net = avg_pool(net, kernel_size=8, strides=1, name='avg_pool')
    net = squeeze(net, axis=[1, 2], name='squeeze')
    logits = fully_connected(net, self.num_classes, name='predictions')
    return softmax(logits, name='output-prob')
def normal_a_cell(ip, p, filters, block_id=None):
    """Adds a Normal cell for NASNet-A (Fig. 4 in the paper).

    `p` is first passed through `NASNetMobile.adjust_block`, and `ip` goes
    through relu -> 1x1 conv -> norm. Five two-input sub-blocks are then
    built from those two tensors and concatenated with the adjusted `p`
    along the last axis.

    Args:
        ip: Input tensor `x`
        p: Input tensor `p`
        filters: Number of output filters
        block_id: String block_id, appended to every layer name

    Returns:
        A tuple `(x, ip)`: the concatenated cell output and the unchanged
        input `ip`.
    """
    concat_axis = -1

    def _tag(stem):
        # Layer and scope names are the stem followed by "_<block_id>".
        return stem + '_%s' % block_id

    with tf.name_scope(_tag('normal_A_block')):
        prev = NASNetMobile.adjust_block(p, ip, filters, block_id)

        cur = layers.relu(ip)
        cur = layers.conv(cur,
                          filters_out=filters,
                          kernel_size=(1, 1),
                          stride=1,
                          padding='same',
                          name=_tag('normal_conv_1'),
                          add_bias=False)
        cur = layers.norm(cur,
                          axis=concat_axis,
                          momentum=0.9997,
                          epsilon=1e-3,
                          name=_tag('normal_bn_1'))

        with tf.name_scope('block_1'):
            left = NASNetMobile.separable_conv_block(
                cur, filters, kernel_size=5,
                block_id=_tag('normal_left1'))
            right = NASNetMobile.separable_conv_block(
                prev, filters,
                block_id=_tag('normal_right1'))
            out_1 = left + right

        with tf.name_scope('block_2'):
            left = NASNetMobile.separable_conv_block(
                prev, filters, 5,
                block_id=_tag('normal_left2'))
            right = NASNetMobile.separable_conv_block(
                prev, filters, 3,
                block_id=_tag('normal_right2'))
            out_2 = left + right

        with tf.name_scope('block_3'):
            pooled = layers.avg_pool(cur, 3, strides=1, padding='same',
                                     name=_tag('normal_left3'))
            out_3 = pooled + prev

        with tf.name_scope('block_4'):
            # Both operands pool the same tensor; they differ only by name.
            left = layers.avg_pool(prev, 3, strides=1, padding='same',
                                   name=_tag('normal_left4'))
            right = layers.avg_pool(prev, 3, strides=1, padding='same',
                                    name=_tag('normal_right4'))
            out_4 = left + right

        with tf.name_scope('block_5'):
            # Separable conv with a skip connection around it.
            convolved = NASNetMobile.separable_conv_block(
                cur, filters,
                block_id=_tag('normal_left5'))
            out_5 = convolved + cur

        merged = layers.concat([prev, out_1, out_2, out_3, out_4, out_5],
                               axis=concat_axis,
                               name=_tag('normal_concat'))
    return merged, ip
def reduction_a_cell(ip, p, filters, block_id=None):
    """Adds a Reduction cell for NASNet-A (Fig. 4 in the paper).

    `p` is first passed through `NASNetMobile.adjust_block`, and `ip` goes
    through relu -> 1x1 conv -> norm. Stride-2 separable convs and pools
    then build five sub-blocks, four of which are concatenated along the
    last axis.

    Args:
        ip: Input tensor `x`
        p: Input tensor `p`
        filters: Number of output filters
        block_id: String block_id, appended to every layer name

    Returns:
        A tuple `(x, ip)`: the concatenated cell output and the unchanged
        input `ip`.
    """
    concat_axis = -1

    def _tag(stem):
        # Layer and scope names are the stem followed by "_<block_id>".
        return stem + '_%s' % block_id

    with tf.name_scope(_tag('reduction_A_block')):
        prev = NASNetMobile.adjust_block(p, ip, filters, block_id)

        cur = layers.relu(ip)
        cur = layers.conv(cur,
                          filters_out=filters,
                          kernel_size=(1, 1),
                          stride=1,
                          padding='same',
                          name=_tag('reduction_conv_1'),
                          add_bias=False)
        cur = layers.norm(cur,
                          axis=concat_axis,
                          momentum=0.9997,
                          epsilon=1e-3,
                          name=_tag('reduction_bn_1'))
        # Explicitly padded copy of `cur`, consumed by the 'valid'
        # stride-2 pooling layers below.
        cur_padded = layers.zero_padding(
            cur,
            padding=NASNetMobile.correct_pad(cur, (3, 3)),
            name=_tag('reduction_pad_1'))

        with tf.name_scope('block_1'):
            left = NASNetMobile.separable_conv_block(
                cur, filters=filters, kernel_size=5, strides=2,
                block_id=_tag('reduction_left1'))
            right = NASNetMobile.separable_conv_block(
                prev, filters=filters, kernel_size=7, strides=2,
                block_id=_tag('reduction_right1'))
            out_1 = left + right

        with tf.name_scope('block_2'):
            left = layers.max_pool(cur_padded, 3, strides=2, padding='valid',
                                   name=_tag('reduction_left2'))
            right = NASNetMobile.separable_conv_block(
                prev, filters=filters, kernel_size=7, strides=2,
                block_id=_tag('reduction_right2'))
            out_2 = left + right

        with tf.name_scope('block_3'):
            left = layers.avg_pool(cur_padded, 3, strides=2, padding='valid',
                                   name=_tag('reduction_left3'))
            right = NASNetMobile.separable_conv_block(
                prev, filters, 5, strides=2,
                block_id=_tag('reduction_right3'))
            out_3 = left + right

        with tf.name_scope('block_4'):
            out_4 = layers.avg_pool(out_1, 3, strides=1, padding='same',
                                    name=_tag('reduction_left4'))
            out_4 += out_2

        with tf.name_scope('block_5'):
            # NOTE(review): this conv reuses the 'reduction_left4' id that
            # block_4's pooling layer also carries. Kept as-is because layer
            # names may be tied to stored weights - confirm before renaming.
            left = NASNetMobile.separable_conv_block(
                out_1, filters, 3,
                block_id=_tag('reduction_left4'))
            right = layers.max_pool(cur_padded, 3, strides=2, padding='valid',
                                    name=_tag('reduction_right5'))
            out_5 = left + right

        # out_1 is only an intermediate (it feeds blocks 4 and 5) and is
        # not part of the concatenated output.
        merged = layers.concat([out_2, out_3, out_4, out_5],
                               axis=concat_axis,
                               name=_tag('reduction_concat'))
    return merged, ip
def adjust_block(p, ip, filters, block_id=None):
    """Adjusts the input `previous path` to match the shape of the `input`.

    Used in situations where the output number of filters needs to be
    changed. The first matching case applies:

    * `p` is None: `ip` is returned.
    * The second-to-last dimensions of `p` and `ip` differ: `p` goes through
      relu, then two stride-2 pooling paths (the second one padded and
      cropped first), each followed by a 1x1 conv with `filters // 2`
      outputs. The two paths are concatenated and normalised.
    * The last dimension of `p` differs from `filters`: `p` goes through
      relu, a 1x1 conv with `filters` outputs, and a norm.
    * Otherwise `p` is returned unchanged.

    Args:
        p: Input tensor which needs to be modified
        ip: Input tensor whose shape needs to be matched
        filters: Number of output filters to be matched
        block_id: String block_id, appended to every layer name

    Returns:
        Adjusted tf tensor.
    """
    channel_axis = -1
    spatial_axis = -2

    target_shape = ip.get_shape().as_list()
    prev_shape = target_shape if p is None else p.get_shape().as_list()

    with tf.name_scope('adjust_block'):
        # No previous path: fall back to the current input.
        if p is None:
            return ip

        if prev_shape[spatial_axis] != target_shape[spatial_axis]:
            with tf.name_scope('adjust_reduction_block_%s' % block_id):
                activated = layers.relu(p, name='adjust_relu_1_%s' % block_id)

                # Path 1: stride-2 subsampling, then 1x1 conv.
                path_a = layers.avg_pool(
                    activated, 1, strides=2, padding='valid',
                    name='adjust_avg_pool_1_%s' % block_id)
                path_a = layers.conv(
                    path_a,
                    filters_out=filters // 2,
                    kernel_size=(1, 1),
                    padding='same',
                    add_bias=False,
                    name='adjust_conv_1_%s' % block_id)

                # Path 2: pad then crop before the same subsampling -
                # presumably a one-pixel shift so the two paths sample
                # different positions (confirm against the layer helpers).
                path_b = layers.zero_padding(
                    activated, padding=((0, 1), (0, 1)))
                path_b = layers.crop(path_b, cropping=((1, 0), (1, 0)))
                path_b = layers.avg_pool(
                    path_b, 1, strides=2, padding='valid',
                    name='adjust_avg_pool_2_%s' % block_id)
                path_b = layers.conv(
                    path_b,
                    filters_out=filters // 2,
                    kernel_size=(1, 1),
                    padding='same',
                    add_bias=False,
                    name='adjust_conv_2_%s' % block_id)

                joined = layers.concat([path_a, path_b], axis=channel_axis)
                return layers.norm(joined,
                                   axis=channel_axis,
                                   momentum=0.9997,
                                   epsilon=1e-3,
                                   name='adjust_bn_%s' % block_id)

        if prev_shape[channel_axis] != filters:
            with tf.name_scope('adjust_projection_block_%s' % block_id):
                projected = layers.relu(p)
                projected = layers.conv(
                    projected,
                    filters_out=filters,
                    kernel_size=(1, 1),
                    stride=1,
                    padding='same',
                    name='adjust_conv_projection_%s' % block_id,
                    add_bias=False)
                return layers.norm(projected,
                                   axis=channel_axis,
                                   momentum=0.9997,
                                   epsilon=1e-3,
                                   name='adjust_bn_%s' % block_id)

        # Shapes already agree.
        return p