def conv_op(filters, filter_size, stride, dilation_rate, spatial_separable):
    """Build a convolution module, optionally spatially factorized.

    When ``spatial_separable`` is false, returns a single ``conv2d`` built from
    ``filters``, ``filter_size``, ``stride`` and ``dilation_rate``.

    When ``spatial_separable`` is true, returns a sequential module of a
    ``[1, filter_size]`` conv (stride ``[1, stride]``), batch norm, relu and a
    ``[filter_size, 1]`` conv (stride ``[stride, 1]``). ``dilation_rate`` is
    not used in this branch.
    """
    if not spatial_separable:
        return conv2d(D([filters]), D([filter_size]), D([stride]),
                      D([dilation_rate]))

    # Factorize the square kernel into a row kernel followed by a column
    # kernel; each applies the stride along its own axis only.
    row_conv = conv2d(D([filters]), D([[1, filter_size]]), D([[1, stride]]))
    row_bn = batch_normalization()
    row_act = relu()
    col_conv = conv2d(D([filters]), D([[filter_size, 1]]), D([[stride, 1]]))
    return mo.siso_sequential([row_conv, row_bn, row_act, col_conv])
def intermediate_node_fn(num_inputs, filters):
    """Sum ``num_inputs`` inputs, then apply conv2d -> batch norm -> relu.

    The convolution uses ``filters`` output channels and the size
    hyperparameter ``3``.

    NOTE(review): another module-level ``intermediate_node_fn`` with a
    different signature is defined later in this file and replaces this one
    at import time.
    """
    merge = add(num_inputs)
    conv = conv2d(D([filters]), D([3]))
    return mo.siso_sequential([merge, conv, batch_normalization(), relu()])
def full_conv_op(filters, filter_size, stride, dilation_rate, spatial_separable):
    """Wrap ``conv_op`` in a 1x1 bottleneck (reduce -> conv -> expand).

    Each of the three stages is passed through ``wrap_relu_batch_norm``.
    The bottleneck width is ``int(3 * filters / 8)`` for a spatially-separable
    size-3 conv, and ``int(filters / 4)`` otherwise. If that width would be
    below 1, the bottleneck is skipped and only the wrapped ``conv_op`` is
    returned.
    """
    # Add bottleneck layer according to
    # https://github.com/tensorflow/tpu/blob/master/models/official/amoeba_net/network_utils.py
    uses_wide_bottleneck = filter_size == 3 and spatial_separable
    bottleneck = (int(3 * filters / 8) if uses_wide_bottleneck
                  else int(filters / 4))

    if bottleneck < 1:
        # Too few filters to reduce; fall back to the plain (wrapped) conv.
        return wrap_relu_batch_norm(
            conv_op(filters, filter_size, stride, dilation_rate,
                    spatial_separable))

    squeeze = wrap_relu_batch_norm(conv2d(D([bottleneck]), D([1])))
    body = wrap_relu_batch_norm(
        conv_op(bottleneck, filter_size, stride, dilation_rate,
                spatial_separable))
    expand = wrap_relu_batch_norm(conv2d(D([filters]), D([1])))
    return mo.siso_sequential([squeeze, body, expand])
def aux_logits():
    """Build the auxiliary classifier head as a single sequential module.

    The layers, in order:

    - relu;
    - ``avg_pool2d`` with hyperparameters 5, 3 and ``'VALID'`` (presumably
      window, stride and padding; confirm against ``avg_pool2d``);
    - 128-filter size-1 conv, batch norm, relu;
    - ``global_convolution`` with 768 filters, batch norm, relu;
    - flatten;
    - ``fc_layer`` with 10 outputs (presumably the number of classes;
      confirm).
    """
    downsample = [
        relu(),
        avg_pool2d(D([5]), D([3]), D(['VALID'])),
        conv2d(D([128]), D([1])),
        batch_normalization(),
        relu(),
    ]
    classify = [
        global_convolution(D([768])),
        batch_normalization(),
        relu(),
        flatten(),
        fc_layer(D([10])),
    ]
    return mo.siso_sequential(downsample + classify)
def cell_input_fn(filters):
    """Build the two-input / two-output entry of a cell.

    ``in0`` (the previous input) passes through an identity module.
    ``in1`` (the current input) passes through a size-1 conv with ``filters``
    channels, wrapped by ``wrap_relu_batch_norm``.

    Both results feed ``maybe_factorized_reduction(add_relu=True)``, whose
    output becomes ``out0``. The projected current input is also exposed
    directly as ``out1``.

    Returns:
        A pair ``(inputs, outputs)`` of dicts keyed ``'in0'/'in1'`` and
        ``'out0'/'out1'``.
    """
    identity_mod = mo.identity()
    projected_mod = wrap_relu_batch_norm(conv2d(D([filters]), D([1])))
    reduction_mod = maybe_factorized_reduction(add_relu=True)

    reduction_inputs = reduction_mod[0]
    reduction_inputs['in0'].connect(identity_mod[1]['out'])
    reduction_inputs['in1'].connect(projected_mod[1]['out'])

    inputs = {'in0': identity_mod[0]['in'], 'in1': projected_mod[0]['in']}
    outputs = {
        'out0': reduction_mod[1]['out'],
        'out1': projected_mod[1]['out'],
    }
    return inputs, outputs
def generate_stage(stage_num, num_nodes, filters, filter_size):
    """Build one genetic-encoding stage of ``num_nodes`` nodes.

    One ``Bool`` hyperparameter is created for every ordered node pair
    ``(in_id, out_id)`` with ``1 <= in_id < out_id <= num_nodes``. Each is
    named ``'<stage>_in_<in_id>_<out_id>'``.

    The stage's input node is conv2d(filters, filter_size) -> batch norm ->
    relu. Intermediate and output nodes sum their inputs, then apply
    conv2d(filters, 3) -> batch norm -> relu.

    Returns:
        Whatever ``genetic_stage`` returns for these node builders.
    """
    h_connections = [
        Bool(name='%d_in_%d_%d' % (stage_num, in_id, out_id))
        for (in_id, out_id) in itertools.combinations(
            range(1, num_nodes + 1), 2)
    ]

    def _merge_node(num_inputs):
        # Built locally rather than via the module-level
        # ``intermediate_node_fn``: that name is later rebound in this file
        # to a 7-parameter function, so calling it with
        # ``(num_inputs, filters)`` raised TypeError.
        return mo.siso_sequential([
            add(num_inputs),
            conv2d(D([filters]), D([3])),
            batch_normalization(),
            relu()
        ])

    return genetic_stage(
        lambda: mo.siso_sequential([
            conv2d(D([filters]), D([filter_size])),
            batch_normalization(),
            relu()
        ]),
        _merge_node,
        _merge_node,
        h_connections,
        num_nodes)
def intermediate_node_fn(reduction, input_id, node_id, op_num, filters,
                         cell_ratio, cell_ops):
    """Build one operation slot of a cell node, followed by optional drop path.

    The operation is selected by the hyperparameter
    ``cell_ops[node_id * 2 + op_num]`` from the fixed menu below. It is then
    fed to a drop-path module that is only active when the selected op is not
    ``'none'``. The drop path's second input is connected to
    ``global_vars['total_steps']``.

    Args:
        reduction: whether the enclosing cell reduces resolution; together
            with ``input_id`` this decides the stride.
        input_id: index of the input this op reads. Ids below 2 get stride 2
            in a reduction cell; all other cases use stride 1.
        node_id: index of the node within the cell.
        op_num: operation slot within the node (the ``node_id * 2`` indexing
            presumably means two slots per node; confirm).
        filters: number of output filters for every candidate op.
        cell_ratio: forwarded to ``drop_path``.
        cell_ops: sequence of op-selection hyperparameters for the cell.

    Returns:
        ``(op_inputs, drop_path_outputs)``.
    """
    op_selector = cell_ops[node_id * 2 + op_num]
    if reduction and input_id < 2:
        stride = 2
    else:
        stride = 1

    h_is_not_none = co.DependentHyperparameter(
        lambda dh: dh["op"] != 'none', {'op': op_selector})

    candidates = {
        'none': lambda: check_filters(filters, stride),
        'conv1': lambda: wrap_relu_batch_norm(
            conv2d(D([filters]), D([1]), h_stride=D([stride]))),
        'conv3': lambda: full_conv_op(filters, 3, stride, 1, False),
        'depth_sep3': lambda: separable_conv_op(filters, 3, stride),
        'depth_sep5': lambda: separable_conv_op(filters, 5, stride),
        'depth_sep7': lambda: separable_conv_op(filters, 7, stride),
        'dilated_3x3_rate_2': lambda: full_conv_op(filters, 3, stride, 2, False),
        'dilated_3x3_rate_4': lambda: full_conv_op(filters, 3, stride, 4, False),
        'dilated_3x3_rate_6': lambda: full_conv_op(filters, 3, stride, 6, False),
        '1x3_3x1': lambda: full_conv_op(filters, 3, stride, 1, True),
        '1x7_7x1': lambda: full_conv_op(filters, 7, stride, 1, True),
        'avg2': lambda: pool_op(filters, 2, stride, 'avg'),
        'avg3': lambda: pool_op(filters, 3, stride, 'avg'),
        'max2': lambda: pool_op(filters, 2, stride, 'max'),
        'max3': lambda: pool_op(filters, 3, stride, 'max'),
        'min2': lambda: pool_op(filters, 2, stride, 'min'),
    }
    op_in, op_out = mo.siso_or(candidates, op_selector)

    # Drop path only exists when an actual op was chosen; it also reads the
    # global step counter.
    drop_in, drop_out = miso_optional(lambda: drop_path(cell_ratio),
                                      h_is_not_none)
    drop_in['in0'].connect(op_out['out'])
    drop_in['in1'].connect(global_vars['total_steps'])
    return op_in, drop_out
def stem(filters):
    """Build the network stem: conv2d(filters, 3) followed by batch norm."""
    layers = [conv2d(D([filters]), D([3]))]
    layers.append(batch_normalization())
    return mo.siso_sequential(layers)