def _set_cell_ops(edge, C):
    """
    Store the list of candidate operations for this edge under the 'op' key.

    Args:
        edge: Object whose ``data`` attribute provides ``set(key, value)``.
        C: Value passed as the first two positional arguments of both
            ``ReLUConvBN`` candidates (presumably the input/output channel
            count -- confirm against ``ops.ReLUConvBN``).

    Returns:
        None. The candidates are written to ``edge.data``.
    """
    edge_data = edge.data

    # The order of the candidates is kept exactly as listed here.
    candidates = [
        ops.Identity(),
        ops.Zero(stride=1),
        ops.ReLUConvBN(C, C, kernel_size=3),
        ops.ReLUConvBN(C, C, kernel_size=1),
        ops.AvgPool1x1(kernel_size=3, stride=1),
    ]
    edge_data.set('op', candidates)
def _set_ops(edge, C, stride):
    """
    Replace the 'op' at the edge with the candidate operations defined here.
    This function is called by the framework for every edge in the
    defined scope.

    Args:
        edge: The edge to update; its ``data`` attribute must provide
            ``set(key, value)``.
        C (int): convolutional channels
        stride (int): stride for the operation

    Returns:
        None. The candidate list is stored on ``edge.data`` under 'op'.
    """
    edge_data = edge.data

    # With stride 1 the first candidate is a plain identity; for any other
    # stride FactorizedReduce is used in its place.
    if stride == 1:
        skip = ops.Identity()
    else:
        skip = FactorizedReduce(C, C, stride, affine=False)

    candidates = [
        skip,
        ops.Zero(stride=stride),
        ops.MaxPool(3, stride),
        ops.AvgPool(3, stride),
    ]

    # Separable convolutions: (kernel_size, padding) pairs, in list order.
    candidates += [
        ops.SepConv(
            C, C, kernel_size=kernel_size, stride=stride, padding=padding,
            affine=False)
        for kernel_size, padding in ((3, 1), (5, 2))
    ]

    # Dilated convolutions (dilation=2): (kernel_size, padding) pairs.
    candidates += [
        ops.DilConv(
            C, C, kernel_size=kernel_size, stride=stride, padding=padding,
            dilation=2, affine=False)
        for kernel_size, padding in ((3, 2), (5, 4))
    ]

    edge_data.set('op', candidates)
def _set_cell_ops(edge, C):
    """
    Store a two-element candidate list (identity and zero) under the 'op' key.

    Args:
        edge: Object whose ``data`` attribute provides ``set(key, value)``.
        C: Not used by this function.

    Returns:
        None. The candidates are written to ``edge.data``.
    """
    edge_data = edge.data
    candidates = [
        ops.Identity(),
        ops.Zero(stride=1),
    ]
    edge_data.set('op', candidates)
def _set_cell_ops(current_edge_data, C):
    """
    Store a two-element candidate list (identity and zero) under the 'op' key.

    Unlike an edge-based variant, ``set`` is called directly on the object
    that is passed in.

    Args:
        current_edge_data: Object providing ``set(key, value)``.
        C: Not used by this function.

    Returns:
        None. The candidates are written to ``current_edge_data``.
    """
    candidates = [
        ops.Identity(),
        ops.Zero(stride=1),
    ]
    current_edge_data.set('op', candidates)