def _maybe_calibrate_size(self, layers, out_filters, is_training):
  """Makes sure layers[0] and layers[1] have the same shapes."""

  hw = [self._get_HW(layer) for layer in layers]
  c = [self._get_C(layer) for layer in layers]

  with tf.variable_scope("calibrate"):
    x = layers[0]
    if hw[0] != hw[1]:
      assert hw[0] == 2 * hw[1]
      with tf.variable_scope("pool_x"):
        x = tf.nn.relu(x)
        x = self._factorized_reduction(x, out_filters, 2, is_training)
    elif c[0] != out_filters:
      with tf.variable_scope("pool_x"):
        w = create_weight("w", [1, 1, c[0], out_filters])
        x = tf.nn.relu(x)
        x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME",
                         data_format=self.data_format)
        x = batch_norm(x, is_training, data_format=self.data_format)

    y = layers[1]
    if c[1] != out_filters:
      with tf.variable_scope("pool_y"):
        w = create_weight("w", [1, 1, c[1], out_filters])
        y = tf.nn.relu(y)
        y = tf.nn.conv2d(y, w, [1, 1, 1, 1], "SAME",
                         data_format=self.data_format)
        y = batch_norm(y, is_training, data_format=self.data_format)
  return [x, y]
def _enas_cell(self, x, curr_cell, prev_cell, op_id, out_filters):
  """Performs an enas operation specified by op_id."""

  num_possible_inputs = curr_cell + 1

  with tf.variable_scope("avg_pool"):
    avg_pool = tf.layers.average_pooling2d(
      x, [3, 3], [1, 1], "SAME", data_format=self.actual_data_format)
    avg_pool_c = self._get_C(avg_pool)
    if avg_pool_c != out_filters:
      with tf.variable_scope("conv"):
        w = create_weight(
          "w", [num_possible_inputs, avg_pool_c * out_filters])
        w = w[prev_cell]
        w = tf.reshape(w, [1, 1, avg_pool_c, out_filters])
        avg_pool = tf.nn.relu(avg_pool)
        avg_pool = tf.nn.conv2d(avg_pool, w, strides=[1, 1, 1, 1],
                                padding="SAME",
                                data_format=self.data_format)
        avg_pool = batch_norm(avg_pool, is_training=True,
                              data_format=self.data_format)

  with tf.variable_scope("max_pool"):
    max_pool = tf.layers.max_pooling2d(
      x, [3, 3], [1, 1], "SAME", data_format=self.actual_data_format)
    max_pool_c = self._get_C(max_pool)
    if max_pool_c != out_filters:
      with tf.variable_scope("conv"):
        w = create_weight(
          "w", [num_possible_inputs, max_pool_c * out_filters])
        w = w[prev_cell]
        w = tf.reshape(w, [1, 1, max_pool_c, out_filters])
        max_pool = tf.nn.relu(max_pool)
        max_pool = tf.nn.conv2d(max_pool, w, strides=[1, 1, 1, 1],
                                padding="SAME",
                                data_format=self.data_format)
        max_pool = batch_norm(max_pool, is_training=True,
                              data_format=self.data_format)

  x_c = self._get_C(x)
  if x_c != out_filters:
    with tf.variable_scope("x_conv"):
      w = create_weight("w", [num_possible_inputs, x_c * out_filters])
      w = w[prev_cell]
      w = tf.reshape(w, [1, 1, x_c, out_filters])
      x = tf.nn.relu(x)
      x = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME",
                       data_format=self.data_format)
      x = batch_norm(x, is_training=True, data_format=self.data_format)

  # op_id selects among: 0 = sep conv 3x3, 1 = sep conv 5x5,
  # 2 = avg pool 3x3, 3 = max pool 3x3, 4 = identity.
  out = [
    self._enas_conv(x, curr_cell, prev_cell, 3, out_filters),
    self._enas_conv(x, curr_cell, prev_cell, 5, out_filters),
    avg_pool,
    max_pool,
    x,
  ]

  out = tf.stack(out, axis=0)
  out = out[op_id, :, :, :, :]
  return out
def _factorized_reduction(self, x, out_filters, stride, is_training):
  """Reduces the shape of x without information loss due to striding."""
  assert out_filters % 2 == 0, (
    "Need even number of filters when using this factorized reduction.")
  if stride == 1:
    with tf.variable_scope("path_conv"):
      inp_c = self._get_C(x)
      w = create_weight("w", [1, 1, inp_c, out_filters])
      x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME",
                       data_format=self.data_format)
      x = batch_norm(x, is_training, data_format=self.data_format)
      return x

  stride_spec = self._get_strides(stride)
  # Skip path 1
  path1 = tf.nn.avg_pool(
    x, [1, 1, 1, 1], stride_spec, "VALID", data_format=self.data_format)
  with tf.variable_scope("path1_conv"):
    inp_c = self._get_C(path1)
    w = create_weight("w", [1, 1, inp_c, out_filters // 2])
    path1 = tf.nn.conv2d(path1, w, [1, 1, 1, 1], "VALID",
                         data_format=self.data_format)

  # Skip path 2
  # First pad with 0's on the right and bottom, then shift the filter to
  # include those 0's that were added.
  if self.data_format == "NHWC":
    pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
    path2 = tf.pad(x, pad_arr)[:, 1:, 1:, :]
    concat_axis = 3
  else:
    pad_arr = [[0, 0], [0, 0], [0, 1], [0, 1]]
    path2 = tf.pad(x, pad_arr)[:, :, 1:, 1:]
    concat_axis = 1

  path2 = tf.nn.avg_pool(
    path2, [1, 1, 1, 1], stride_spec, "VALID", data_format=self.data_format)
  with tf.variable_scope("path2_conv"):
    inp_c = self._get_C(path2)
    w = create_weight("w", [1, 1, inp_c, out_filters // 2])
    path2 = tf.nn.conv2d(path2, w, [1, 1, 1, 1], "VALID",
                         data_format=self.data_format)

  # Concat and apply BN
  final_path = tf.concat(values=[path1, path2], axis=concat_axis)
  final_path = batch_norm(final_path, is_training,
                          data_format=self.data_format)

  return final_path
def _fixed_conv(self, x, f_size, out_filters, stride, is_training,
                stack_convs=2):
  """Apply fixed convolution.

  Args:
    stack_convs: number of separable convs to apply.
  """

  for conv_id in range(stack_convs):
    inp_c = self._get_C(x)
    if conv_id == 0:
      strides = self._get_strides(stride)
    else:
      strides = [1, 1, 1, 1]

    with tf.variable_scope("sep_conv_{}".format(conv_id)):
      w_depthwise = create_weight("w_depth", [f_size, f_size, inp_c, 1])
      w_pointwise = create_weight("w_point", [1, 1, inp_c, out_filters])
      x = tf.nn.relu(x)
      x = tf.nn.separable_conv2d(
        x,
        depthwise_filter=w_depthwise,
        pointwise_filter=w_pointwise,
        strides=strides, padding="SAME", data_format=self.data_format)
      x = batch_norm(x, is_training, data_format=self.data_format)

  return x
def _model(self, images, is_training, reuse=False):
  with tf.variable_scope(self.name, reuse=reuse):
    layers = []

    out_filters = self.out_filters
    with tf.variable_scope("stem_conv"):
      w = create_weight("w", [3, 3, 3, out_filters])
      x = tf.nn.conv2d(images, w, [1, 1, 1, 1], "SAME",
                       data_format=self.data_format)
      x = batch_norm(x, is_training, data_format=self.data_format)
      layers.append(x)

    if self.whole_channels:
      start_idx = 0
    else:
      start_idx = self.num_branches
    for layer_id in range(self.num_layers):
      with tf.variable_scope("layer_{0}".format(layer_id)):
        if self.fixed_arc is None:
          x = self._enas_layer(layer_id, layers, start_idx, out_filters,
                               is_training)
        else:
          x = self._fixed_layer(layer_id, layers, start_idx, out_filters,
                                is_training)
        layers.append(x)
        if layer_id in self.pool_layers:
          if self.fixed_arc is not None:
            out_filters *= 2
          with tf.variable_scope("pool_at_{0}".format(layer_id)):
            pooled_layers = []
            for i, layer in enumerate(layers):
              with tf.variable_scope("from_{0}".format(i)):
                x = self._factorized_reduction(
                  layer, out_filters, 2, is_training)
              pooled_layers.append(x)
            layers = pooled_layers
      if self.whole_channels:
        start_idx += 1 + layer_id
      else:
        start_idx += 2 * self.num_branches + layer_id
      print(layers[-1])

    x = global_avg_pool(x, data_format=self.data_format)
    if is_training:
      x = tf.nn.dropout(x, self.keep_prob)
    with tf.variable_scope("fc"):
      if self.data_format == "NHWC":
        inp_c = x.get_shape()[3].value
      elif self.data_format == "NCHW":
        inp_c = x.get_shape()[1].value
      else:
        raise ValueError("Unknown data_format {0}".format(self.data_format))
      w = create_weight("w", [inp_c, 10])
      x = tf.matmul(x, w)
  return x
def _pool_branch(self, inputs, is_training, count, avg_or_max,
                 start_idx=None):
  """
  Args:
    start_idx: where to start taking the output channels. if None, assume
      fixed_arc mode.
    count: how many output_channels to take.
  """

  if start_idx is None:
    assert self.fixed_arc is not None, "you screwed up!"

  if self.data_format == "NHWC":
    inp_c = inputs.get_shape()[3].value
  elif self.data_format == "NCHW":
    inp_c = inputs.get_shape()[1].value

  with tf.variable_scope("conv_1"):
    w = create_weight("w", [1, 1, inp_c, self.out_filters])
    x = tf.nn.conv2d(inputs, w, [1, 1, 1, 1], "SAME",
                     data_format=self.data_format)
    x = batch_norm(x, is_training, data_format=self.data_format)
    x = tf.nn.relu(x)

  with tf.variable_scope("pool"):
    if self.data_format == "NHWC":
      actual_data_format = "channels_last"
    elif self.data_format == "NCHW":
      actual_data_format = "channels_first"

    if avg_or_max == "avg":
      x = tf.layers.average_pooling2d(
        x, [3, 3], [1, 1], "SAME", data_format=actual_data_format)
    elif avg_or_max == "max":
      x = tf.layers.max_pooling2d(
        x, [3, 3], [1, 1], "SAME", data_format=actual_data_format)
    else:
      raise ValueError("Unknown pool {}".format(avg_or_max))

  if start_idx is not None:
    if self.data_format == "NHWC":
      x = x[:, :, :, start_idx:start_idx + count]
    elif self.data_format == "NCHW":
      x = x[:, start_idx:start_idx + count, :, :]

  return x
def _enas_layer(self, layer_id, prev_layers, arc, out_filters):
  """
  Args:
    layer_id: current layer
    prev_layers: cache of previous layers. for skip connections
    arc: the sampled architecture for this cell; each node consumes 4
      entries (x_id, x_op, y_id, y_op).
  """

  assert len(prev_layers) == 2, "need exactly 2 inputs"
  layers = [prev_layers[0], prev_layers[1]]
  layers = self._maybe_calibrate_size(layers, out_filters, is_training=True)
  used = []
  for cell_id in range(self.num_cells):
    prev_layers = tf.stack(layers, axis=0)
    with tf.variable_scope("cell_{0}".format(cell_id)):
      with tf.variable_scope("x"):
        x_id = arc[4 * cell_id]
        x_op = arc[4 * cell_id + 1]
        x = prev_layers[x_id, :, :, :, :]
        x = self._enas_cell(x, cell_id, x_id, x_op, out_filters)
        x_used = tf.one_hot(x_id, depth=self.num_cells + 2, dtype=tf.int32)

      with tf.variable_scope("y"):
        y_id = arc[4 * cell_id + 2]
        y_op = arc[4 * cell_id + 3]
        y = prev_layers[y_id, :, :, :, :]
        y = self._enas_cell(y, cell_id, y_id, y_op, out_filters)
        y_used = tf.one_hot(y_id, depth=self.num_cells + 2, dtype=tf.int32)

      out = x + y
      used.extend([x_used, y_used])
      layers.append(out)

  used = tf.add_n(used)
  indices = tf.where(tf.equal(used, 0))
  indices = tf.to_int32(indices)
  indices = tf.reshape(indices, [-1])
  num_outs = tf.size(indices)
  out = tf.stack(layers, axis=0)
  out = tf.gather(out, indices, axis=0)

  inp = prev_layers[0]
  if self.data_format == "NHWC":
    N = tf.shape(inp)[0]
    H = tf.shape(inp)[1]
    W = tf.shape(inp)[2]
    C = tf.shape(inp)[3]
    out = tf.transpose(out, [1, 2, 3, 0, 4])
    out = tf.reshape(out, [N, H, W, num_outs * out_filters])
  elif self.data_format == "NCHW":
    N = tf.shape(inp)[0]
    C = tf.shape(inp)[1]
    H = tf.shape(inp)[2]
    W = tf.shape(inp)[3]
    out = tf.transpose(out, [1, 0, 2, 3, 4])
    out = tf.reshape(out, [N, num_outs * out_filters, H, W])
  else:
    raise ValueError("Unknown data_format '{0}'".format(self.data_format))

  with tf.variable_scope("final_conv"):
    w = create_weight("w", [self.num_cells + 2, out_filters * out_filters])
    w = tf.gather(w, indices, axis=0)
    w = tf.reshape(w, [1, 1, num_outs * out_filters, out_filters])
    out = tf.nn.relu(out)
    out = tf.nn.conv2d(out, w, strides=[1, 1, 1, 1], padding="SAME",
                       data_format=self.data_format)
    out = batch_norm(out, is_training=True, data_format=self.data_format)

  out = tf.reshape(out, tf.shape(prev_layers[0]))

  return out
def _fixed_layer(self, layer_id, prev_layers, arc, out_filters, stride,
                 is_training, normal_or_reduction_cell="normal"):
  """
  Args:
    prev_layers: cache of previous layers. for skip connections
    is_training: for batch_norm
  """

  assert len(prev_layers) == 2
  layers = [prev_layers[0], prev_layers[1]]
  layers = self._maybe_calibrate_size(layers, out_filters,
                                      is_training=is_training)

  with tf.variable_scope("layer_base"):
    x = layers[1]
    inp_c = self._get_C(x)
    w = create_weight("w", [1, 1, inp_c, out_filters])
    x = tf.nn.relu(x)
    x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME",
                     data_format=self.data_format)
    x = batch_norm(x, is_training, data_format=self.data_format)
    layers[1] = x

  used = np.zeros([self.num_cells + 2], dtype=np.int32)
  f_sizes = [3, 5]
  for cell_id in range(self.num_cells):
    with tf.variable_scope("cell_{}".format(cell_id)):
      x_id = arc[4 * cell_id]
      used[x_id] += 1
      x_op = arc[4 * cell_id + 1]
      x = layers[x_id]
      x_stride = stride if x_id in [0, 1] else 1
      with tf.variable_scope("x_conv"):
        if x_op in [0, 1]:
          f_size = f_sizes[x_op]
          x = self._fixed_conv(x, f_size, out_filters, x_stride, is_training)
        elif x_op in [2, 3]:
          inp_c = self._get_C(x)
          if x_op == 2:
            x = tf.layers.average_pooling2d(
              x, [3, 3], [x_stride, x_stride], "SAME",
              data_format=self.actual_data_format)
          else:
            x = tf.layers.max_pooling2d(
              x, [3, 3], [x_stride, x_stride], "SAME",
              data_format=self.actual_data_format)
          if inp_c != out_filters:
            w = create_weight("w", [1, 1, inp_c, out_filters])
            x = tf.nn.relu(x)
            x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME",
                             data_format=self.data_format)
            x = batch_norm(x, is_training, data_format=self.data_format)
        else:
          inp_c = self._get_C(x)
          if x_stride > 1:
            assert x_stride == 2
            x = self._factorized_reduction(x, out_filters, 2, is_training)
          if inp_c != out_filters:
            w = create_weight("w", [1, 1, inp_c, out_filters])
            x = tf.nn.relu(x)
            x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME",
                             data_format=self.data_format)
            x = batch_norm(x, is_training, data_format=self.data_format)
        if (x_op in [0, 1, 2, 3] and
            self.drop_path_keep_prob is not None and
            is_training):
          x = self._apply_drop_path(x, layer_id)

      y_id = arc[4 * cell_id + 2]
      used[y_id] += 1
      y_op = arc[4 * cell_id + 3]
      y = layers[y_id]
      y_stride = stride if y_id in [0, 1] else 1
      with tf.variable_scope("y_conv"):
        if y_op in [0, 1]:
          f_size = f_sizes[y_op]
          y = self._fixed_conv(y, f_size, out_filters, y_stride, is_training)
        elif y_op in [2, 3]:
          inp_c = self._get_C(y)
          if y_op == 2:
            y = tf.layers.average_pooling2d(
              y, [3, 3], [y_stride, y_stride], "SAME",
              data_format=self.actual_data_format)
          else:
            y = tf.layers.max_pooling2d(
              y, [3, 3], [y_stride, y_stride], "SAME",
              data_format=self.actual_data_format)
          if inp_c != out_filters:
            w = create_weight("w", [1, 1, inp_c, out_filters])
            y = tf.nn.relu(y)
            y = tf.nn.conv2d(y, w, [1, 1, 1, 1], "SAME",
                             data_format=self.data_format)
            y = batch_norm(y, is_training, data_format=self.data_format)
        else:
          inp_c = self._get_C(y)
          if y_stride > 1:
            assert y_stride == 2
            y = self._factorized_reduction(y, out_filters, 2, is_training)
          if inp_c != out_filters:
            w = create_weight("w", [1, 1, inp_c, out_filters])
            y = tf.nn.relu(y)
            y = tf.nn.conv2d(y, w, [1, 1, 1, 1], "SAME",
                             data_format=self.data_format)
            y = batch_norm(y, is_training, data_format=self.data_format)

        if (y_op in [0, 1, 2, 3] and
            self.drop_path_keep_prob is not None and
            is_training):
          y = self._apply_drop_path(y, layer_id)

      out = x + y
      layers.append(out)
  out = self._fixed_combine(layers, used, out_filters, is_training,
                            normal_or_reduction_cell)

  return out
def _model(self, images, is_training, reuse=False):
  """Compute the logits given the images."""

  if self.fixed_arc is None:
    is_training = True

  with tf.variable_scope(self.name, reuse=reuse):
    # the first two inputs
    with tf.variable_scope("stem_conv"):
      w = create_weight("w", [3, 3, 3, self.out_filters * 3])
      x = tf.nn.conv2d(
        images, w, [1, 1, 1, 1], "SAME", data_format=self.data_format)
      x = batch_norm(x, is_training, data_format=self.data_format)
    if self.data_format == "NHWC":
      split_axis = 3
    elif self.data_format == "NCHW":
      split_axis = 1
    else:
      raise ValueError("Unknown data_format '{0}'".format(self.data_format))
    layers = [x, x]

    # building layers in the micro space
    out_filters = self.out_filters
    for layer_id in range(self.num_layers + 2):
      with tf.variable_scope("layer_{0}".format(layer_id)):
        if layer_id not in self.pool_layers:
          if self.fixed_arc is None:
            x = self._enas_layer(
              layer_id, layers, self.normal_arc, out_filters)
          else:
            x = self._fixed_layer(
              layer_id, layers, self.normal_arc, out_filters, 1,
              is_training, normal_or_reduction_cell="normal")
        else:
          out_filters *= 2
          if self.fixed_arc is None:
            x = self._factorized_reduction(x, out_filters, 2, is_training)
            layers = [layers[-1], x]
            x = self._enas_layer(
              layer_id, layers, self.reduce_arc, out_filters)
          else:
            x = self._fixed_layer(
              layer_id, layers, self.reduce_arc, out_filters, 2,
              is_training, normal_or_reduction_cell="reduction")
        print("Layer {0:>2d}: {1}".format(layer_id, x))
        layers = [layers[-1], x]

      # auxiliary heads
      self.num_aux_vars = 0
      if (self.use_aux_heads and
          layer_id in self.aux_head_indices and
          is_training):
        print("Using aux_head at layer {0}".format(layer_id))
        with tf.variable_scope("aux_head"):
          aux_logits = tf.nn.relu(x)
          aux_logits = tf.layers.average_pooling2d(
            aux_logits, [5, 5], [3, 3], "VALID",
            data_format=self.actual_data_format)
          with tf.variable_scope("proj"):
            inp_c = self._get_C(aux_logits)
            w = create_weight("w", [1, 1, inp_c, 128])
            aux_logits = tf.nn.conv2d(aux_logits, w, [1, 1, 1, 1], "SAME",
                                      data_format=self.data_format)
            aux_logits = batch_norm(aux_logits, is_training=True,
                                    data_format=self.data_format)
            aux_logits = tf.nn.relu(aux_logits)

          with tf.variable_scope("avg_pool"):
            inp_c = self._get_C(aux_logits)
            hw = self._get_HW(aux_logits)
            w = create_weight("w", [hw, hw, inp_c, 768])
            aux_logits = tf.nn.conv2d(aux_logits, w, [1, 1, 1, 1], "SAME",
                                      data_format=self.data_format)
            aux_logits = batch_norm(aux_logits, is_training=True,
                                    data_format=self.data_format)
            aux_logits = tf.nn.relu(aux_logits)

          with tf.variable_scope("fc"):
            aux_logits = global_avg_pool(aux_logits,
                                         data_format=self.data_format)
            inp_c = aux_logits.get_shape()[1].value
            w = create_weight("w", [inp_c, 10])
            aux_logits = tf.matmul(aux_logits, w)
            self.aux_logits = aux_logits

        aux_head_variables = [
          var for var in tf.trainable_variables() if (
            var.name.startswith(self.name) and "aux_head" in var.name)]
        self.num_aux_vars = count_model_params(aux_head_variables)
        print("Aux head uses {0} params".format(self.num_aux_vars))

    x = tf.nn.relu(x)
    x = global_avg_pool(x, data_format=self.data_format)
    if is_training and self.keep_prob is not None and self.keep_prob < 1.0:
      x = tf.nn.dropout(x, self.keep_prob)
    with tf.variable_scope("fc"):
      inp_c = self._get_C(x)
      w = create_weight("w", [inp_c, 10])
      x = tf.matmul(x, w)
  return x
def _conv_branch(self, inputs, filter_size, is_training, count, out_filters,
                 ch_mul=1, start_idx=None, separable=False):
  """
  Args:
    start_idx: where to start taking the output channels. if None, assume
      fixed_arc mode.
    count: how many output_channels to take.
  """

  if start_idx is None:
    assert self.fixed_arc is not None, "you screwed up!"

  if self.data_format == "NHWC":
    inp_c = inputs.get_shape()[3].value
  elif self.data_format == "NCHW":
    inp_c = inputs.get_shape()[1].value

  with tf.variable_scope("inp_conv_1"):
    w = create_weight("w", [1, 1, inp_c, out_filters])
    x = tf.nn.conv2d(inputs, w, [1, 1, 1, 1], "SAME",
                     data_format=self.data_format)
    x = batch_norm(x, is_training, data_format=self.data_format)
    x = tf.nn.relu(x)

  with tf.variable_scope("out_conv_{}".format(filter_size)):
    if start_idx is None:
      if separable:
        w_depth = create_weight(
          "w_depth", [filter_size, filter_size, out_filters, ch_mul])
        w_point = create_weight(
          "w_point", [1, 1, out_filters * ch_mul, count])
        x = tf.nn.separable_conv2d(x, w_depth, w_point,
                                   strides=[1, 1, 1, 1], padding="SAME",
                                   data_format=self.data_format)
        x = batch_norm(x, is_training, data_format=self.data_format)
      else:
        # x has out_filters channels after inp_conv_1.
        w = create_weight(
          "w", [filter_size, filter_size, out_filters, count])
        x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME",
                         data_format=self.data_format)
        x = batch_norm(x, is_training, data_format=self.data_format)
    else:
      if separable:
        w_depth = create_weight(
          "w_depth", [filter_size, filter_size, out_filters, ch_mul])
        w_point = create_weight(
          "w_point", [out_filters, out_filters * ch_mul])
        w_point = w_point[start_idx:start_idx + count, :]
        w_point = tf.transpose(w_point, [1, 0])
        w_point = tf.reshape(w_point, [1, 1, out_filters * ch_mul, count])

        x = tf.nn.separable_conv2d(x, w_depth, w_point,
                                   strides=[1, 1, 1, 1], padding="SAME",
                                   data_format=self.data_format)
        mask = tf.range(0, out_filters, dtype=tf.int32)
        mask = tf.logical_and(start_idx <= mask, mask < start_idx + count)
        x = batch_norm_with_mask(
          x, is_training, mask, out_filters, data_format=self.data_format)
      else:
        w = create_weight(
          "w", [filter_size, filter_size, out_filters, out_filters])
        w = tf.transpose(w, [3, 0, 1, 2])
        w = w[start_idx:start_idx + count, :, :, :]
        w = tf.transpose(w, [1, 2, 3, 0])
        x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME",
                         data_format=self.data_format)
        mask = tf.range(0, out_filters, dtype=tf.int32)
        mask = tf.logical_and(start_idx <= mask, mask < start_idx + count)
        x = batch_norm_with_mask(
          x, is_training, mask, out_filters, data_format=self.data_format)
    x = tf.nn.relu(x)
  return x
def _fixed_layer(self, layer_id, prev_layers, start_idx, out_filters,
                 is_training):
  """
  Args:
    layer_id: current layer
    prev_layers: cache of previous layers. for skip connections
    start_idx: where to start looking at. technically, we can infer this
      from layer_id, but why bother...
    is_training: for batch_norm
  """

  inputs = prev_layers[-1]
  if self.whole_channels:
    if self.data_format == "NHWC":
      inp_c = inputs.get_shape()[3].value
    elif self.data_format == "NCHW":
      inp_c = inputs.get_shape()[1].value

    count = self.sample_arc[start_idx]
    if count in [0, 1, 2, 3]:
      size = [3, 3, 5, 5]
      filter_size = size[count]
      with tf.variable_scope("conv_1x1"):
        w = create_weight("w", [1, 1, inp_c, out_filters])
        out = tf.nn.relu(inputs)
        out = tf.nn.conv2d(out, w, [1, 1, 1, 1], "SAME",
                           data_format=self.data_format)
        out = batch_norm(out, is_training, data_format=self.data_format)

      with tf.variable_scope("conv_{0}x{0}".format(filter_size)):
        w = create_weight(
          "w", [filter_size, filter_size, out_filters, out_filters])
        out = tf.nn.relu(out)
        out = tf.nn.conv2d(out, w, [1, 1, 1, 1], "SAME",
                           data_format=self.data_format)
        out = batch_norm(out, is_training, data_format=self.data_format)
    elif count == 4:
      pass  # op 4 (avg pool) is not handled in this fixed whole-channel path
    elif count == 5:
      pass  # op 5 (max pool) is not handled in this fixed whole-channel path
    else:
      raise ValueError("Unknown operation number '{0}'".format(count))
  else:
    count = (self.sample_arc[start_idx:start_idx + 2 * self.num_branches] *
             self.out_filters_scale)
    branches = []
    total_out_channels = 0
    with tf.variable_scope("branch_0"):
      total_out_channels += count[1]
      branches.append(
        self._conv_branch(inputs, 3, is_training, count[1], out_filters))
    with tf.variable_scope("branch_1"):
      total_out_channels += count[3]
      branches.append(
        self._conv_branch(inputs, 3, is_training, count[3], out_filters,
                          separable=True))
    with tf.variable_scope("branch_2"):
      total_out_channels += count[5]
      branches.append(
        self._conv_branch(inputs, 5, is_training, count[5], out_filters))
    with tf.variable_scope("branch_3"):
      total_out_channels += count[7]
      branches.append(
        self._conv_branch(inputs, 5, is_training, count[7], out_filters,
                          separable=True))
    if self.num_branches >= 5:
      with tf.variable_scope("branch_4"):
        total_out_channels += count[9]
        branches.append(
          self._pool_branch(inputs, is_training, count[9], "avg"))
    if self.num_branches >= 6:
      with tf.variable_scope("branch_5"):
        total_out_channels += count[11]
        branches.append(
          self._pool_branch(inputs, is_training, count[11], "max"))

    with tf.variable_scope("final_conv"):
      w = create_weight("w", [1, 1, total_out_channels, out_filters])
      if self.data_format == "NHWC":
        branches = tf.concat(branches, axis=3)
      elif self.data_format == "NCHW":
        branches = tf.concat(branches, axis=1)
      out = tf.nn.relu(branches)
      out = tf.nn.conv2d(out, w, [1, 1, 1, 1], "SAME",
                         data_format=self.data_format)
      out = batch_norm(out, is_training, data_format=self.data_format)

  if layer_id > 0:
    if self.whole_channels:
      skip_start = start_idx + 1
    else:
      skip_start = start_idx + 2 * self.num_branches
    skip = self.sample_arc[skip_start:skip_start + layer_id]
    total_skip_channels = np.sum(skip) + 1

    res_layers = []
    for i in range(layer_id):
      if skip[i] == 1:
        res_layers.append(prev_layers[i])
    prev = res_layers + [out]

    if self.data_format == "NHWC":
      prev = tf.concat(prev, axis=3)
    elif self.data_format == "NCHW":
      prev = tf.concat(prev, axis=1)

    out = prev
    with tf.variable_scope("skip"):
      w = create_weight(
        "w", [1, 1, total_skip_channels * out_filters, out_filters])
      out = tf.nn.relu(out)
      out = tf.nn.conv2d(out, w, [1, 1, 1, 1], "SAME",
                         data_format=self.data_format)
      out = batch_norm(out, is_training, data_format=self.data_format)

  return out
def _enas_layer(self, layer_id, prev_layers, start_idx, out_filters,
                is_training):
  """
  Args:
    layer_id: current layer
    prev_layers: cache of previous layers. for skip connections
    start_idx: where to start looking at. technically, we can infer this
      from layer_id, but why bother...
    is_training: for batch_norm
  """

  inputs = prev_layers[-1]
  if self.whole_channels:
    if self.data_format == "NHWC":
      inp_h = inputs.get_shape()[1].value
      inp_w = inputs.get_shape()[2].value
      inp_c = inputs.get_shape()[3].value
    elif self.data_format == "NCHW":
      inp_c = inputs.get_shape()[1].value
      inp_h = inputs.get_shape()[2].value
      inp_w = inputs.get_shape()[3].value

    count = self.sample_arc[start_idx]
    branches = {}
    # Bind y at definition time (lambda y=y) so each branch returns its own
    # tensor; a bare "lambda: y" would capture the last branch built.
    with tf.variable_scope("branch_0"):
      y = self._conv_branch(inputs, 3, is_training, out_filters, out_filters,
                            start_idx=0)
      branches[tf.equal(count, 0)] = lambda y=y: y
    with tf.variable_scope("branch_1"):
      y = self._conv_branch(inputs, 3, is_training, out_filters, out_filters,
                            start_idx=0, separable=True)
      branches[tf.equal(count, 1)] = lambda y=y: y
    with tf.variable_scope("branch_2"):
      y = self._conv_branch(inputs, 5, is_training, out_filters, out_filters,
                            start_idx=0)
      branches[tf.equal(count, 2)] = lambda y=y: y
    with tf.variable_scope("branch_3"):
      y = self._conv_branch(inputs, 5, is_training, out_filters, out_filters,
                            start_idx=0, separable=True)
      branches[tf.equal(count, 3)] = lambda y=y: y
    if self.num_branches >= 5:
      with tf.variable_scope("branch_4"):
        y = self._pool_branch(inputs, is_training, out_filters, "avg",
                              start_idx=0)
        branches[tf.equal(count, 4)] = lambda y=y: y
    if self.num_branches >= 6:
      with tf.variable_scope("branch_5"):
        y = self._pool_branch(inputs, is_training, out_filters, "max",
                              start_idx=0)
        branches[tf.equal(count, 5)] = lambda y=y: y
    out = tf.case(branches, default=lambda: tf.constant(0, tf.float32),
                  exclusive=True)

    if self.data_format == "NHWC":
      out.set_shape([None, inp_h, inp_w, out_filters])
    elif self.data_format == "NCHW":
      out.set_shape([None, out_filters, inp_h, inp_w])
  else:
    count = self.sample_arc[start_idx:start_idx + 2 * self.num_branches]
    branches = []
    with tf.variable_scope("branch_0"):
      branches.append(
        self._conv_branch(inputs, 3, is_training, count[1], out_filters,
                          start_idx=count[0]))
    with tf.variable_scope("branch_1"):
      branches.append(
        self._conv_branch(inputs, 3, is_training, count[3], out_filters,
                          start_idx=count[2], separable=True))
    with tf.variable_scope("branch_2"):
      branches.append(
        self._conv_branch(inputs, 5, is_training, count[5], out_filters,
                          start_idx=count[4]))
    with tf.variable_scope("branch_3"):
      branches.append(
        self._conv_branch(inputs, 5, is_training, count[7], out_filters,
                          start_idx=count[6], separable=True))
    if self.num_branches >= 5:
      with tf.variable_scope("branch_4"):
        branches.append(
          self._pool_branch(inputs, is_training, count[9], "avg",
                            start_idx=count[8]))
    if self.num_branches >= 6:
      with tf.variable_scope("branch_5"):
        branches.append(
          self._pool_branch(inputs, is_training, count[11], "max",
                            start_idx=count[10]))

    with tf.variable_scope("final_conv"):
      w = create_weight(
        "w", [self.num_branches * out_filters, out_filters])
      w_mask = tf.constant(
        [False] * (self.num_branches * out_filters), tf.bool)
      new_range = tf.range(0, self.num_branches * out_filters,
                           dtype=tf.int32)
      for i in range(self.num_branches):
        start = out_filters * i + count[2 * i]
        new_mask = tf.logical_and(
          start <= new_range, new_range < start + count[2 * i + 1])
        w_mask = tf.logical_or(w_mask, new_mask)
      w = tf.boolean_mask(w, w_mask)
      w = tf.reshape(w, [1, 1, -1, out_filters])

      inp = prev_layers[-1]
      if self.data_format == "NHWC":
        branches = tf.concat(branches, axis=3)
      elif self.data_format == "NCHW":
        branches = tf.concat(branches, axis=1)
        N = tf.shape(inp)[0]
        H = inp.get_shape()[2].value
        W = inp.get_shape()[3].value
        branches = tf.reshape(branches, [N, -1, H, W])
      out = tf.nn.conv2d(branches, w, [1, 1, 1, 1], "SAME",
                         data_format=self.data_format)
      out = batch_norm(out, is_training, data_format=self.data_format)
      out = tf.nn.relu(out)

  if layer_id > 0:
    if self.whole_channels:
      skip_start = start_idx + 1
    else:
      skip_start = start_idx + 2 * self.num_branches
    skip = self.sample_arc[skip_start:skip_start + layer_id]
    with tf.variable_scope("skip"):
      res_layers = []
      for i in range(layer_id):
        res_layers.append(
          tf.cond(tf.equal(skip[i], 1),
                  lambda: prev_layers[i],
                  lambda: tf.zeros_like(prev_layers[i])))
      res_layers.append(out)
      out = tf.add_n(res_layers)
      out = batch_norm(out, is_training, data_format=self.data_format)

  return out