def reduction_cell(graph, prev, cur, out_channels):
    """Append a NASNet-style reduction cell to ``graph`` and return its output.

    ``cur`` is first passed through ``squeeze`` (with ``out_channels``).
    ``prev`` is then passed through ``fit`` against the squeezed ``cur``.
    Five combinations are built, each the ``graph.add`` of two branches
    (all ops use ``"SAME"`` padding):

    0. 7x7 separable conv of ``prev`` + 5x5 separable conv of ``cur``,
       both stride (2, 2)
    1. 3x3 max-pool of ``cur`` + 7x7 separable conv of ``prev``,
       both stride (2, 2)
    2. 3x3 avg-pool of ``cur`` + 5x5 separable conv of ``prev``,
       both stride (2, 2)
    3. 3x3 max-pool of ``cur`` (stride (2, 2)) + 3x3 separable conv of
       combination 0 (stride (1, 1))
    4. 3x3 avg-pool of combination 0 (stride (1, 1)) + combination 1

    Args:
        graph: graph builder exposing ``add``, ``maxpool2d``, ``avgpool2d``
            and ``concat``.
        prev: the earlier input tensor, adapted to ``cur`` via ``fit``.
        cur: the current input tensor, squeezed to ``out_channels``.
        out_channels: channel count given to ``squeeze`` and to every
            separable conv.

    Returns:
        ``graph.concat(1, [comb0, comb1, comb2, comb3, comb4])``.
        NOTE(review): axis 1 is presumably the channel axis; confirm
        against the graph API.
    """
    cur = squeeze(graph, out_channels, cur)
    prev = fit(graph, cur, prev)

    def sep(src, size, stride):
        # Separable conv with a square kernel/stride and SAME padding.
        return seperable_conv(graph, input=src, out_channels=out_channels,
                              kernels=(size, size), strides=(stride, stride),
                              padding="SAME")

    # Branch ops are created left to right, in the same order as the
    # combinations are listed in the docstring.
    comb0 = graph.add(sep(prev, 7, 2), sep(cur, 5, 2))
    comb1 = graph.add(
        graph.maxpool2d(input=cur, kernels=(3, 3), strides=(2, 2), padding="SAME"),
        sep(prev, 7, 2),
    )
    comb2 = graph.add(
        graph.avgpool2d(input=cur, kernels=(3, 3), strides=(2, 2), padding="SAME"),
        sep(prev, 5, 2),
    )
    # A fresh max-pool node, not a reuse of comb1's, so the graph keeps the
    # same set of nodes.
    comb3 = graph.add(
        graph.maxpool2d(input=cur, kernels=(3, 3), strides=(2, 2), padding="SAME"),
        sep(comb0, 3, 1),
    )
    # comb0 is built from stride-2 branches; the ops applied to it use stride 1.
    comb4 = graph.add(
        graph.avgpool2d(input=comb0, kernels=(3, 3), strides=(1, 1), padding="SAME"),
        comb1,
    )
    return graph.concat(1, [comb0, comb1, comb2, comb3, comb4])
def normal_cell(graph, prev, cur, out_channels):
    """Append a NASNet-style normal cell to ``graph`` and return its output.

    ``cur`` is first passed through ``squeeze`` (with ``out_channels``).
    ``prev`` is then passed through ``fit`` against the squeezed ``cur``.
    Every op uses a 3x3 kernel, stride (1, 1) and ``"SAME"`` padding.
    Five ``graph.add`` sums are formed:

    0. separable conv of ``cur`` + ``cur`` itself
    1. separable conv of ``prev`` + separable conv of ``cur``
    2. avg-pool of ``cur`` + ``prev`` itself
    3. avg-pool of ``prev`` + a second avg-pool of ``prev``
    4. separable conv of ``prev`` + a second separable conv of ``prev``

    Args:
        graph: graph builder exposing ``add``, ``avgpool2d`` and ``concat``.
        prev: the earlier input tensor, adapted to ``cur`` via ``fit``.
        cur: the current input tensor, squeezed to ``out_channels``.
        out_channels: channel count given to ``squeeze`` and to every
            separable conv.

    Returns:
        ``graph.concat(1, sums)`` over the five sums above.
        NOTE(review): axis 1 is presumably the channel axis; confirm
        against the graph API.
    """
    cur = squeeze(graph, out_channels, cur)
    prev = fit(graph, cur, prev)

    def sep3x3(src):
        return seperable_conv(graph, input=src, out_channels=out_channels,
                              kernels=(3, 3), strides=(1, 1), padding="SAME")

    def avg3x3(src):
        return graph.avgpool2d(input=src, kernels=(3, 3), strides=(1, 1),
                               padding="SAME")

    # Flat list of operand pairs (left, right, left, right, ...). Every branch
    # op is created here, left to right, before any of the adds below.
    # Duplicate ops on the same input are separate graph nodes.
    branches = [
        sep3x3(cur), cur,
        sep3x3(prev), sep3x3(cur),
        avg3x3(cur), prev,
        avg3x3(prev), avg3x3(prev),
        sep3x3(prev), sep3x3(prev),
    ]
    assert len(branches) == 10, "Expected 10 tensors, got {}".format(len(branches))
    sums = [graph.add(left, right)
            for left, right in zip(branches[0::2], branches[1::2])]
    return graph.concat(1, sums)