def intermediate_node_fn(num_inputs, node_id, filters, cell_ops):
    """Build an intermediate node: merge inputs, apply a chosen op, then BN and ReLU.

    The node is a sequential chain of four single-input/single-output pieces:
    ``add(num_inputs)`` to merge the incoming tensors (presumably an
    element-wise sum -- confirm against ``add``), an ``mo.siso_or`` choice
    among the ``'conv1'``, ``'conv3'`` and ``'max3'`` candidates, followed by
    ``batch_normalization()`` and ``relu()``.

    NOTE(review): this source also contains a later ``intermediate_node_fn``
    with a different signature; if both live in the same module, the later
    one shadows this one. Confirm they come from separate files.

    Args:
        num_inputs: Number of inputs handed to ``add``.
        node_id: Index into ``cell_ops`` selecting this node's op hyperparameter.
        filters: Filter count passed (wrapped in ``D([...])``) to both conv candidates.
        cell_ops: Indexable collection of hyperparameters whose values are
            keys of the candidate table below.

    Returns:
        Whatever ``mo.siso_sequential`` returns for the chained modules.
    """
    # Candidates are thunks: nothing (including the D(...) hyperparameters)
    # is constructed until siso_or actually picks one.
    candidate_ops = {
        'conv1': lambda: conv2d(D([filters]), D([1])),
        'conv3': lambda: conv2d(D([filters]), D([3])),
        'max3': lambda: max_pool2d(D([3])),
    }
    merge = add(num_inputs)
    chosen_op = mo.siso_or(candidate_ops, cell_ops[node_id])
    norm = batch_normalization()
    activation = relu()
    return mo.siso_sequential([merge, chosen_op, norm, activation])
def enas_op(h_op_name, out_filters, name, weight_sharer):
    """Return an ``mo.siso_or`` choice over the ENAS candidate operations.

    The candidates are:

    * ``'conv3'``, ``'conv5'``: ``enas_conv`` with kernel size 3/5 and its
      third argument ``False``.
    * ``'dsep_conv3'``, ``'dsep_conv5'``: the same with the third argument
      ``True`` (presumably depthwise-separable, per the key names).
    * ``'avg_pool'``, ``'max_pool'``: pooling with ``D([3])`` and ``D([1])``.

    Args:
        h_op_name: Hyperparameter whose value is one of the keys above.
        out_filters: Forwarded to every ``enas_conv`` candidate.
        name: Forwarded unchanged to every ``enas_conv`` candidate.
        weight_sharer: Forwarded unchanged to every ``enas_conv`` candidate.

    Returns:
        The result of ``mo.siso_or``.
    """

    def conv_branch(kernel_size, separable):
        # Return a thunk so the conv is only built if it is selected.
        return lambda: enas_conv(out_filters, kernel_size, separable,
                                 weight_sharer, name)

    candidates = {
        'conv3': conv_branch(3, False),
        'conv5': conv_branch(5, False),
        'dsep_conv3': conv_branch(3, True),
        'dsep_conv5': conv_branch(5, True),
        'avg_pool': lambda: avg_pool(D([3]), D([1])),
        'max_pool': lambda: max_pool(D([3]), D([1])),
    }
    return mo.siso_or(candidates, h_op_name)
def intermediate_node_fn(reduction, input_id, node_id, op_num, filters,
                         cell_ratio, cell_ops):
    """Build one operation slot of a cell node, followed by an optional drop path.

    The op is chosen by the hyperparameter ``cell_ops[node_id * 2 + op_num]``
    from a fixed table of candidates (identity-like ``'none'``, convolutions,
    separable/dilated convolutions and pooling). Its output feeds
    ``miso_optional(drop_path(cell_ratio))``, which is enabled only when the
    chosen op is not ``'none'``. The second input of the optional module is
    wired to ``global_vars['progress']``.

    Args:
        reduction: Truthy for a reduction cell; together with ``input_id < 2``
            it selects stride 2.
        input_id: Index of the input this op reads; stride 2 is used only when
            it is below 2 (presumably the cell's two external inputs -- confirm).
        node_id: Node index; each node owns two consecutive entries of ``cell_ops``.
        op_num: Which of the node's two op slots (added to ``node_id * 2``).
        filters: Filter count passed to every candidate builder.
        cell_ratio: Passed to ``drop_path``.
        cell_ops: Indexable collection of hyperparameters whose values are
            keys of the candidate table.

    Returns:
        ``(op_in, drop_out)``: the inputs of the selected op and the outputs
        of the optional drop-path module.
    """
    h_op = cell_ops[node_id * 2 + op_num]
    stride = 2 if reduction and input_id < 2 else 1

    # Closure factories: each returns a thunk so candidate modules are only
    # instantiated if siso_or selects them.
    def conv(kernel_size, dilation, asymmetric):
        # `asymmetric` is True only for the '1xk_kx1' entries.
        return lambda: full_conv_op(filters, kernel_size, stride, dilation,
                                    asymmetric)

    def sep_conv(kernel_size):
        return lambda: separable_conv_op(filters, kernel_size, stride)

    def pool(kernel_size, kind):
        return lambda: pool_op(filters, kernel_size, stride, kind)

    candidates = {
        'none': lambda: check_filters(filters, stride),
        'conv1': lambda: wrap_relu_batch_norm(
            conv2d(D([filters]), D([1]), h_stride=D([stride]))),
        'conv3': conv(3, 1, False),
        'depth_sep3': sep_conv(3),
        'depth_sep5': sep_conv(5),
        'depth_sep7': sep_conv(7),
        'dilated_3x3_rate_2': conv(3, 2, False),
        'dilated_3x3_rate_4': conv(3, 4, False),
        'dilated_3x3_rate_6': conv(3, 6, False),
        '1x3_3x1': conv(3, 1, True),
        '1x7_7x1': conv(7, 1, True),
        'avg2': pool(2, 'avg'),
        'avg3': pool(3, 'avg'),
        'max2': pool(2, 'max'),
        'max3': pool(3, 'max'),
        'min2': pool(2, 'min'),
    }

    # The lambda's parameter name must match the 'op' key of the dict below
    # (presumably DependentHyperparameter passes values by keyword); keep
    # the two in sync.
    h_is_not_none = co.DependentHyperparameter(lambda op: op != 'none',
                                               {'op': h_op})
    op_in, op_out = mo.siso_or(candidates, h_op)

    drop_in, drop_out = miso_optional(lambda: drop_path(cell_ratio),
                                      h_is_not_none)
    drop_in['In0'].connect(op_out['Out'])
    # NOTE(review): 'progress' is presumably a training-progress signal used
    # by drop_path -- confirm against where global_vars is populated.
    drop_in['In1'].connect(global_vars['progress'])
    return op_in, drop_out