Example #1
0
        def resnet_down_block(u, ul, us, uls):
            """Apply ``nr_resnet`` ResnetDown / ResnetDownRight layer pairs.

            Each iteration pops one entry off the end of ``us`` and feeds it,
            together with ``u``, to a fresh ``ResnetDown``. It then pops one
            entry off the end of ``uls``, concatenates it with the new ``u``
            along the last axis, and feeds that, together with ``ul``, to a
            fresh ``ResnetDownRight``. ``nr_resnet`` is read from the enclosing
            scope.

            The caller's ``us`` / ``uls`` lists are not mutated; shallow copies
            are consumed instead.

            Returns:
                Tuple ``(u, ul, us_left, uls_left)``, where the last two items
                are the copies with the consumed entries removed.
            """
            # Work on copies so the pops below never touch the caller's lists.
            pending_us = us.copy()
            pending_uls = uls.copy()
            for _ in range(nr_resnet):
                skip_u = pending_us.pop()
                u = ResnetDown()(u, skip_u)
                skip_ul = pending_uls.pop()
                ul = ResnetDownRight()(ul, np.concatenate((u, skip_ul), -1))
            return u, ul, pending_us, pending_uls
Example #2
0
 def up_pass(images):
     """Build the upward-pass feature stacks for ``images``.

     1. A constant channel of ones is appended along the last axis.
     2. Two initial streams are formed:
        ``us`` holds a down-shifted (2, 3) ConvDown output, and ``uls`` holds
        the sum of a down-shifted (1, 3) ConvDown output and a right-shifted
        (2, 1) ConvDownRight output.
     3. The streams go through three ResnetUpBlock stages. After each of the
        first two stages, a HalveDown / HalveDownRight result of the latest
        entry is appended.

     Returns:
         Whatever the final ``ResnetUpBlock()(us, uls)`` call returns
         (presumably the updated ``(us, uls)`` lists -- confirm against
         ResnetUpBlock).
     """
     # NOTE(review): the ones channel presumably lets later convolutions tell
     # zero padding apart from real pixels; confirm intent.
     ones = np.ones(images.shape[:-1] + (1, ))
     augmented = np.concatenate((images, ones), -1)

     us = [down_shift(ConvDown(filter_shape=(2, 3))(augmented))]
     ul_from_down = down_shift(ConvDown(filter_shape=(1, 3))(augmented))
     ul_from_right = right_shift(ConvDownRight(filter_shape=(2, 1))(augmented))
     uls = [ul_from_down + ul_from_right]

     # Two stages, each followed by appending a HalveDown / HalveDownRight
     # result of the newest entry ...
     for _ in range(2):
         us, uls = ResnetUpBlock()(us, uls)
         us.append(HalveDown()(us[-1]))
         uls.append(HalveDownRight()(uls[-1]))
     # ... and a final stage with nothing appended afterwards.
     return ResnetUpBlock()(us, uls)
Example #3
0
def right_shift(input):
    """Shift ``input`` one column to the right along the width axis.

    The last column is dropped and a column of zeros is inserted on the left:
    ``out[..., :, 0, :] == 0`` and
    ``out[..., :, j, :] == input[..., :, j - 1, :]`` for ``j >= 1``.
    The output has the same shape and dtype as ``input``.

    Args:
        input: Array whose trailing three axes are (height, width, channels).
            Any leading axes (e.g. a batch axis) are passed through unchanged.
            The name shadows the builtin but is kept for keyword callers.

    Returns:
        The shifted array.
    """
    # Build the pad by slicing a single column instead of np.zeros(shape).
    # This keeps the input's dtype (no silent float promotion) and any leading
    # axes. It also yields an empty pad for zero-width input, so the shape is
    # preserved.
    pad = np.zeros_like(input[..., :1, :])
    return np.concatenate((pad, input[..., :-1, :]), -2)
Example #4
0
def down_shift(input):
    """Shift ``input`` one row down along the height axis.

    The last row is dropped and a row of zeros is inserted at the top:
    ``out[..., 0, :, :] == 0`` and
    ``out[..., i, :, :] == input[..., i - 1, :, :]`` for ``i >= 1``.
    The output has the same shape and dtype as ``input``.

    Args:
        input: Array whose trailing three axes are (height, width, channels).
            Any leading axes (e.g. a batch axis) are passed through unchanged.
            The name shadows the builtin but is kept for keyword callers.

    Returns:
        The shifted array.
    """
    # Build the pad by slicing a single row instead of np.zeros(shape).
    # This keeps the input's dtype (no silent float promotion) and any leading
    # axes. It also yields an empty pad for zero-height input, so the shape is
    # preserved.
    pad = np.zeros_like(input[..., :1, :, :])
    return np.concatenate((pad, input[..., :-1, :, :]), -3)
Example #5
0
def concat_elu(x, axis=-1):
    """Apply ``elu`` to ``x`` joined with its negation.

    ``x`` and ``-x`` are concatenated along ``axis``, which doubles that
    axis's length. ``elu`` (imported or defined elsewhere in the file) is then
    applied to the combined array.
    """
    both_signs = np.concatenate([x, -x], axis=axis)
    return elu(both_signs)