def _factorized_reduction(self, x, out_filters, stride, is_training):
    """Reduces the shape of x without information loss due to striding."""
    assert out_filters % 2 == 0, (
        "Need even number of filters when using this factorized reduction.")
    if stride == 1:
        with tf.variable_scope("path_conv"):
            inp_c = get_C(x, self.data_format)
            w = create_weight("w", [1, 1, inp_c, out_filters])
            x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME",
                             data_format=self.data_format)
            x = batch_norm(x, is_training, data_format=self.data_format)
            return x

    stride_spec = get_strides(stride, self.data_format)
    # Skip path 1
    path1 = tf.nn.avg_pool(
        x, [1, 1, 1, 1], stride_spec, "VALID", data_format=self.data_format)
    with tf.variable_scope("path1_conv"):
        inp_c = get_C(path1, self.data_format)
        w = create_weight("w", [1, 1, inp_c, out_filters // 2])
        path1 = tf.nn.conv2d(path1, w, [1, 1, 1, 1], "SAME",
                             data_format=self.data_format)

    # Skip path 2
    # First pad with 0s on the right and bottom, then shift the input so the
    # pooling grid includes those 0s that were added.
    if self.data_format == "NHWC":
        pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
        path2 = tf.pad(x, pad_arr)[:, 1:, 1:, :]
        concat_axis = 3
    else:
        pad_arr = [[0, 0], [0, 0], [0, 1], [0, 1]]
        path2 = tf.pad(x, pad_arr)[:, :, 1:, 1:]
        concat_axis = 1

    path2 = tf.nn.avg_pool(
        path2, [1, 1, 1, 1], stride_spec, "VALID", data_format=self.data_format)
    with tf.variable_scope("path2_conv"):
        inp_c = get_C(path2, self.data_format)
        w = create_weight("w", [1, 1, inp_c, out_filters // 2])
        path2 = tf.nn.conv2d(path2, w, [1, 1, 1, 1], "SAME",
                             data_format=self.data_format)

    # Concat and apply BN
    final_path = tf.concat(values=[path1, path2], axis=concat_axis)
    final_path = batch_norm(final_path, is_training,
                            data_format=self.data_format)
    return final_path
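# A minimal numpy sketch (illustrative only, not part of the model) of the
# stride-2 case above. With a 1x1 window, avg_pool with stride 2 is plain
# subsampling: path1 samples input positions (2i, 2j), while the padded and
# shifted path2 samples (2i+1, 2j+1). Each path then gets out_filters // 2
# channels from its 1x1 conv, so the concatenation keeps out_filters channels
# at half the spatial resolution:
#
#   import numpy as np
#   x = np.arange(16).reshape(1, 4, 4, 1)          # NHWC toy input
#   path1 = x[:, ::2, ::2, :]                      # positions (2i, 2j)
#   shifted = np.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])[:, 1:, 1:, :]
#   path2 = shifted[:, ::2, ::2, :]                # positions (2i+1, 2j+1)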
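# Note on skip connections (a summary of post_process_out inside _model
# below): for every mutable layer the tuner may pick a subset of earlier
# layer outputs as optional_inputs. post_process_out pins the static shape
# of the branch output, sums it with the chosen outputs via tf.add_n,
# applies batch norm, and records the result in `layers`. With
# optional_input_size [0, 1], at most one skip input is added, e.g.
# out = BN(out + layer_0_out) for a hypothetical tuner choice.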
def _model(self, images, is_training, reuse=False):
    '''Build model'''
    with tf.variable_scope(self.name, reuse=reuse):
        layers = []
        out_filters = self.out_filters
        with tf.variable_scope("stem_conv"):
            w = create_weight("w", [3, 3, 3, out_filters])
            x = tf.nn.conv2d(images, w, [1, 1, 1, 1], "SAME",
                             data_format=self.data_format)
            x = batch_norm(x, is_training, data_format=self.data_format)
            layers.append(x)

        def add_fixed_pooling_layer(layer_id, layers, out_filters, is_training):
            '''Add a fixed pooling layer every four layers'''
            out_filters *= 2
            with tf.variable_scope("pool_at_{0}".format(layer_id)):
                pooled_layers = []
                for i, layer in enumerate(layers):
                    with tf.variable_scope("from_{0}".format(i)):
                        x = self._factorized_reduction(
                            layer, out_filters, 2, is_training)
                    pooled_layers.append(x)
            return pooled_layers, out_filters

        def post_process_out(out, optional_inputs):
            '''Form skip connection and perform batch norm'''
            with tf.variable_scope("skip"):
                inputs = layers[-1]
                if self.data_format == "NHWC":
                    inp_h = inputs.get_shape()[1].value
                    inp_w = inputs.get_shape()[2].value
                    inp_c = inputs.get_shape()[3].value
                    out.set_shape([None, inp_h, inp_w, out_filters])
                elif self.data_format == "NCHW":
                    inp_c = inputs.get_shape()[1].value
                    inp_h = inputs.get_shape()[2].value
                    inp_w = inputs.get_shape()[3].value
                    out.set_shape([None, out_filters, inp_h, inp_w])
                optional_inputs.append(out)
                pout = tf.add_n(optional_inputs)
                out = batch_norm(pout, is_training,
                                 data_format=self.data_format)
            layers.append(out)
            return out

        # Module-level counter used to generate a unique variable scope name
        # for each mutable layer.
        global layer_id
        layer_id = -1

        def get_layer_id():
            global layer_id
            layer_id += 1
            return 'layer_' + str(layer_id)

        def conv3(inputs):
            # inputs[0] holds the fixed input (layers[-1], the latest layer
            # output); inputs[1] holds the previous layers chosen by the
            # tuner to form skip connections.
            with tf.variable_scope(get_layer_id()):
                with tf.variable_scope('branch_0'):
                    out = conv_op(inputs[0][0], 3, is_training, out_filters,
                                  out_filters, self.data_format,
                                  start_idx=None)
                out = post_process_out(out, inputs[1])
            return out

        def conv3_sep(inputs):
            with tf.variable_scope(get_layer_id()):
                with tf.variable_scope('branch_1'):
                    out = conv_op(inputs[0][0], 3, is_training, out_filters,
                                  out_filters, self.data_format,
                                  start_idx=None, separable=True)
                out = post_process_out(out, inputs[1])
            return out

        def conv5(inputs):
            with tf.variable_scope(get_layer_id()):
                with tf.variable_scope('branch_2'):
                    out = conv_op(inputs[0][0], 5, is_training, out_filters,
                                  out_filters, self.data_format,
                                  start_idx=None)
                out = post_process_out(out, inputs[1])
            return out

        def conv5_sep(inputs):
            with tf.variable_scope(get_layer_id()):
                with tf.variable_scope('branch_3'):
                    out = conv_op(inputs[0][0], 5, is_training, out_filters,
                                  out_filters, self.data_format,
                                  start_idx=None, separable=True)
                out = post_process_out(out, inputs[1])
            return out

        def avg_pool(inputs):
            with tf.variable_scope(get_layer_id()):
                with tf.variable_scope('branch_4'):
                    out = pool_op(inputs[0][0], is_training, out_filters,
                                  out_filters, "avg", self.data_format,
                                  start_idx=None)
                out = post_process_out(out, inputs[1])
            return out

        def max_pool(inputs):
            with tf.variable_scope(get_layer_id()):
                with tf.variable_scope('branch_5'):
                    out = pool_op(inputs[0][0], is_training, out_filters,
                                  out_filters, "max", self.data_format,
                                  start_idx=None)
                out = post_process_out(out, inputs[1])
            return out

        # The """@nni.mutable_layers(...)""" docstrings below are NNI
        # annotations: the NNI annotation converter rewrites them into real
        # layer-selection code before execution, which is what binds the
        # layer_*_out names used afterwards.
        """@nni.mutable_layers(
        {
            layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
            fixed_inputs: [x],
            layer_output: layer_0_out
        },
        {
            layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
            fixed_inputs: [layer_0_out],
            optional_inputs: [layer_0_out],
            optional_input_size: [0, 1],
            layer_output: layer_1_out
        },
        {
            layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
            fixed_inputs: [layer_1_out],
            optional_inputs: [layer_0_out, layer_1_out],
            optional_input_size: [0, 1],
            layer_output: layer_2_out
        },
        {
            layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
            fixed_inputs: [layer_2_out],
            optional_inputs: [layer_0_out, layer_1_out, layer_2_out],
            optional_input_size: [0, 1],
            layer_output: layer_3_out
        }
        )"""
        layers, out_filters = add_fixed_pooling_layer(
            3, layers, out_filters, is_training)
        layer_0_out, layer_1_out, layer_2_out, layer_3_out = layers[-4:]
        """@nni.mutable_layers(
        {
            layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
            fixed_inputs: [layer_3_out],
            optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out],
            optional_input_size: [0, 1],
            layer_output: layer_4_out
        },
        {
            layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
            fixed_inputs: [layer_4_out],
            optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out],
            optional_input_size: [0, 1],
            layer_output: layer_5_out
        },
        {
            layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
            fixed_inputs: [layer_5_out],
            optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out],
            optional_input_size: [0, 1],
            layer_output: layer_6_out
        },
        {
            layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
            fixed_inputs: [layer_6_out],
            optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out, layer_6_out],
            optional_input_size: [0, 1],
            layer_output: layer_7_out
        }
        )"""
        layers, out_filters = add_fixed_pooling_layer(
            7, layers, out_filters, is_training)
        (layer_0_out, layer_1_out, layer_2_out, layer_3_out,
         layer_4_out, layer_5_out, layer_6_out, layer_7_out) = layers[-8:]
        """@nni.mutable_layers(
        {
            layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
            fixed_inputs: [layer_7_out],
            optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out, layer_6_out, layer_7_out],
            optional_input_size: [0, 1],
            layer_output: layer_8_out
        },
        {
            layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
            fixed_inputs: [layer_8_out],
            optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out, layer_6_out, layer_7_out, layer_8_out],
            optional_input_size: [0, 1],
            layer_output: layer_9_out
        },
        {
            layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
            fixed_inputs: [layer_9_out],
            optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out, layer_6_out, layer_7_out, layer_8_out, layer_9_out],
            optional_input_size: [0, 1],
            layer_output: layer_10_out
        },
        {
            layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
            fixed_inputs: [layer_10_out],
            optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out, layer_6_out, layer_7_out, layer_8_out, layer_9_out, layer_10_out],
            optional_input_size: [0, 1],
            layer_output: layer_11_out
        }
        )"""
        x = global_avg_pool(layer_11_out, data_format=self.data_format)
        if is_training:
            x = tf.nn.dropout(x, self.keep_prob)
        with tf.variable_scope("fc"):
            if self.data_format == "NHWC":
                inp_c = x.get_shape()[3].value
            elif self.data_format == "NCHW":
                inp_c = x.get_shape()[1].value
            else:
                raise ValueError("Unknown data_format {0}".format(
                    self.data_format))
            w = create_weight("w", [inp_c, 10])
            x = tf.matmul(x, w)
    return x
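# Hypothetical usage sketch (assumes TF 1.x and CIFAR-10-sized NHWC inputs;
# `child` stands for an instance of the enclosing class and `labels` for an
# integer label tensor, both hypothetical):
#
#   images = tf.placeholder(tf.float32, [None, 32, 32, 3])
#   logits = child._model(images, is_training=True)   # shape [None, 10]
#   loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
#                                                 logits=logits)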