def up_pass(images):
    """Run the upward pass over ``images`` and return the last stage's output.

    Two feature streams are seeded from the padded images:

    * ``us``  - a down-shifted ``ConvDown`` with a (2, 3) filter;
    * ``uls`` - the sum of a down-shifted (1, 3) ``ConvDown`` and a
      right-shifted (2, 1) ``ConvDownRight``.

    The streams then pass through three ``ResnetUpBlock`` stages. Between
    consecutive stages, ``HalveDown`` is applied to the newest ``us`` entry
    and ``HalveDownRight`` to the newest ``uls`` entry, and each result is
    appended to its stream.

    Returns:
        Whatever the final ``ResnetUpBlock()(us, uls)`` call returns
        (presumably the updated ``(us, uls)`` lists - confirm against
        ResnetUpBlock).
    """
    # Append one constant-1 channel on the last axis. The 4-entry pad spec
    # fixes rank 4; channels-last (NHWC) layout is assumed - TODO confirm.
    padded = np.pad(images, ((0, 0), (0, 0), (0, 0), (0, 1)), constant_values=1)

    us = [down_shift(ConvDown(filter_shape=(2, 3))(padded))]

    # Build the ConvDown term before the ConvDownRight term so submodules
    # are created in the same order as the original `a + b` expression.
    down_term = down_shift(ConvDown(filter_shape=(1, 3))(padded))
    down_right_term = right_shift(ConvDownRight(filter_shape=(2, 1))(padded))
    uls = [down_term + down_right_term]

    # Two rounds of (resnet stage, then halve both streams) ...
    for _ in range(2):
        us, uls = ResnetUpBlock()(us, uls)
        us.append(HalveDown()(us[-1]))
        uls.append(HalveDownRight()(uls[-1]))
    # ... followed by one final resnet stage.
    return ResnetUpBlock()(us, uls)
def down_right_shifted_conv(inputs):
    """Zero-pad the leading edge of axes 1 and 2, then apply a 'VALID' ``Conv``.

    Axis 1 receives ``f_h - 1`` leading zeros and axis 2 receives
    ``f_w - 1`` leading zeros; nothing is padded at the trailing edges.
    ``f_h``, ``f_w``, ``out_chan``, ``filter_shape``, ``strides`` and
    ``kwargs`` are free names, presumably bound by an enclosing factory
    that is not visible here (TODO confirm that ``(f_h, f_w) == filter_shape``).
    If so, for stride 1 the output keeps the input's size on axes 1 and 2.

    Args:
        inputs: rank-4 array (rank fixed by the pad spec); NHWC layout is
            assumed - confirm with callers.

    Returns:
        The result of ``Conv(out_chan, filter_shape, strides, 'VALID', **kwargs)``
        applied to the padded inputs.
    """
    # Leading-only padding, so each output position only reaches back
    # along axes 1 and 2, never forward.
    pad_width = ((0, 0), (f_h - 1, 0), (f_w - 1, 0), (0, 0))
    padded_inputs = np.pad(inputs, pad_width)
    conv = Conv(out_chan, filter_shape, strides, 'VALID', **kwargs)
    return conv(padded_inputs)
def right_shift(input):
    """Shift ``input`` by one position along its second-to-last axis.

    The last entry along axis -2 is dropped and a zero entry is prepended,
    so the output has the same shape and dtype as the input. With
    channels-last data (assumed, e.g. ``(H, W, C)``), this is a one-pixel
    shift to the right along the width axis.

    The original implementation accepted only rank-3 input. This version
    also accepts leading batch axes, e.g. ``(N, H, W, C)``, and gives
    identical results for rank-3 input.

    Args:
        input: array of rank >= 3.

    Returns:
        Array with the same shape and dtype as ``input``.

    Raises:
        ValueError: if ``input`` has fewer than 3 dimensions.
    """
    if input.ndim < 3:
        raise ValueError(f"right_shift expects rank >= 3, got shape {input.shape}")
    # Leave every leading axis untouched; shift only axis -2. The channel
    # axis (-1) is never padded.
    pad_width = [(0, 0)] * (input.ndim - 2) + [(1, 0), (0, 0)]
    return np.pad(input[..., :-1, :], pad_width)
def down_shift(input):
    """Shift ``input`` by one position along its third-to-last axis.

    The last entry along axis -3 is dropped and a zero entry is prepended,
    so the output has the same shape and dtype as the input. With
    channels-last data (assumed, e.g. ``(H, W, C)``), this is a one-row
    shift downward along the height axis.

    The original implementation accepted only rank-3 input. This version
    also accepts leading batch axes, e.g. ``(N, H, W, C)``, and gives
    identical results for rank-3 input.

    Args:
        input: array of rank >= 3.

    Returns:
        Array with the same shape and dtype as ``input``.

    Raises:
        ValueError: if ``input`` has fewer than 3 dimensions.
    """
    if input.ndim < 3:
        raise ValueError(f"down_shift expects rank >= 3, got shape {input.shape}")
    # Leave every leading axis untouched; shift only axis -3. The last two
    # axes are never padded.
    pad_width = [(0, 0)] * (input.ndim - 3) + [(1, 0), (0, 0), (0, 0)]
    return np.pad(input[..., :-1, :, :], pad_width)