def resnet_down_block(u, ul, us, uls):
    """Run ``nr_resnet`` down-stream residual steps, consuming skip activations.

    Each step applies a freshly constructed ``ResnetDown`` to ``u`` together
    with the most recent entry of ``us``. It then applies a freshly
    constructed ``ResnetDownRight`` to ``ul`` together with ``u``
    concatenated (last axis) to the most recent entry of ``uls``.

    The caller's ``us`` / ``uls`` are not mutated. Entries are popped from
    private copies, and those shortened copies are returned.

    Returns:
        Tuple ``(u, ul, us, uls)``: the updated streams plus the skip lists
        with the consumed entries removed.
    """
    us_left = us.copy()
    uls_left = uls.copy()
    for _ in range(nr_resnet):
        # Keep construct -> pop -> apply ordering identical to a one-liner
        # call, in case submodule registration depends on call order.
        down = ResnetDown()
        u = down(u, us_left.pop())
        down_right = ResnetDownRight()
        ul = down_right(ul, np.concatenate((u, uls_left.pop()), -1))
    return u, ul, us_left, uls_left
def up_pass(images):
    """Build the upward pass over ``images``.

    Steps:
      1. Append a constant channel of ones along the last axis.
      2. Seed the vertical stream ``us`` with a down-shifted ``ConvDown``.
      3. Seed the combined stream ``uls`` with a down-shifted ``ConvDown``
         plus a right-shifted ``ConvDownRight``.
      4. Run three ``ResnetUpBlock`` stages. Between stages, the outputs of
         ``HalveDown`` / ``HalveDownRight`` (applied to the latest entry of
         each stream) are appended to the streams.

    Returns:
        Whatever the final ``ResnetUpBlock`` call returns. The two earlier
        calls are unpacked as an ``(us, uls)`` pair.
    """
    # Constant ones channel; presumably lets later layers distinguish real
    # pixels from zeros introduced by shifting/padding -- TODO confirm.
    ones = np.ones(images.shape[:-1] + (1, ))
    images = np.concatenate((images, ones), -1)

    us = [down_shift(ConvDown(filter_shape=(2, 3))(images))]

    above = down_shift(ConvDown(filter_shape=(1, 3))(images))
    left = right_shift(ConvDownRight(filter_shape=(2, 1))(images))
    uls = [above + left]

    stage_out = ResnetUpBlock()(us, uls)
    for _ in range(2):
        us, uls = stage_out
        us.append(HalveDown()(us[-1]))
        uls.append(HalveDownRight()(uls[-1]))
        stage_out = ResnetUpBlock()(us, uls)
    return stage_out
def right_shift(input):
    """Shift ``input`` one column to the right, zero-filling the first column.

    The last column is dropped, so the width is unchanged.

    Args:
        input: Array laid out as ``(..., height, width, channels)``. Any
            number of leading batch dimensions is accepted.

    Returns:
        Array with the same dtype as ``input``, where
        ``out[..., 0, :] == 0`` and ``out[..., 1:, :] == input[..., :-1, :]``.

    Raises:
        ValueError: If ``input`` has fewer than 3 dimensions.
    """
    # ``input`` shadows the builtin; the name is kept for keyword callers.
    if input.ndim < 3:
        raise ValueError(
            "right_shift expects (..., height, width, channels), "
            f"got shape {input.shape}"
        )
    # Build the zero column in the input's dtype so the concatenation does
    # not silently upcast (NumPy's np.zeros defaults to float64).
    pad = np.zeros(input.shape[:-2] + (1,) + input.shape[-1:], dtype=input.dtype)
    return np.concatenate((pad, input[..., :-1, :]), -2)
def down_shift(input):
    """Shift ``input`` one row down, zero-filling the first row.

    The last row is dropped, so the height is unchanged.

    Args:
        input: Array laid out as ``(..., height, width, channels)``. Any
            number of leading batch dimensions is accepted.

    Returns:
        Array with the same dtype as ``input``, where
        ``out[..., 0, :, :] == 0`` and
        ``out[..., 1:, :, :] == input[..., :-1, :, :]``.

    Raises:
        ValueError: If ``input`` has fewer than 3 dimensions.
    """
    # ``input`` shadows the builtin; the name is kept for keyword callers.
    if input.ndim < 3:
        raise ValueError(
            "down_shift expects (..., height, width, channels), "
            f"got shape {input.shape}"
        )
    # Build the zero row in the input's dtype so the concatenation does not
    # silently upcast (NumPy's np.zeros defaults to float64).
    pad = np.zeros(input.shape[:-3] + (1,) + input.shape[-2:], dtype=input.dtype)
    return np.concatenate((pad, input[..., :-1, :, :]), -3)
def concat_elu(x, axis=-1):
    """Apply ``elu`` to ``x`` and ``-x`` joined along ``axis``.

    The joined array is twice as large as ``x`` along ``axis``: ``x`` first,
    then ``-x``. Assuming ``elu`` is elementwise (TODO confirm), the result
    holds ``elu(x)`` followed by ``elu(-x)`` along that axis.
    """
    both_signs = np.concatenate([x, -x], axis=axis)
    return elu(both_signs)