Esempio n. 1
0
def test_reorder_spatial_single_spatial(CDHWN):
    """A tensor with batch, channel and one spatial axis is reordered to match CDHWN."""
    # Input layout is N, C, H (taken from positions -1, 0 and 2 of the fixture).
    batch_axis, channel_axis, height_axis = CDHWN[-1], CDHWN[0], CDHWN[2]
    nch_input = ng.placeholder([batch_axis, channel_axis, height_axis])

    reordered = reorder_spatial_axes(nch_input, "C", ("D", "H", "W"))
    result_axes = reordered.axes

    assert result_axes == CDHWN
    # D was missing from the input, so it is expected to appear with length 1.
    assert result_axes[1].length == 1
    # W was missing from the input, so it is expected to appear with length 1.
    assert result_axes[3].length == 1
Esempio n. 2
0
def test_reorder_spatial_double_spatial(CDHWN):
    """A tensor with batch, channel and two spatial axes comes back with five axes."""
    batch_axis, channel_axis = CDHWN[-1], CDHWN[0]
    first_spatial, second_spatial = CDHWN[3], CDHWN[2]
    source = ng.placeholder(
        [batch_axis, channel_axis, first_spatial, second_spatial])

    result_axes = reorder_spatial_axes(source).axes

    assert len(result_axes) == 5
    assert result_axes[0].name == 'C'
    # The axis that was absent from the input shows up with length 1.
    assert result_axes[1].length == 1
    # The two supplied spatial axes keep the lengths of their input positions.
    assert result_axes[2].length == first_spatial.length
    assert result_axes[3].length == second_spatial.length
Esempio n. 3
0
    def __call__(self, in_obj):
        """
        Apply pooling to in_obj after reordering its axes with reorder_spatial_axes.

        The output axes are computed only while self.o_axes is None and are
        reused on every later call.

        Arguments:
            in_obj (Op): Input op

        Returns:
            The op produced by ng.pooling for the reordered input.
        """
        pool_cfg = self.poolparams.copy()
        in_obj = reorder_spatial_axes(in_obj)
        src_axes = in_obj.axes

        if self.o_axes is None:
            # Output axes reuse the names of the non-batch input axes; their
            # lengths are filled in below.
            unsized = [
                ng.make_axis(name=ax.name) for ax in src_axes if not ax.is_batch
            ]
            self.o_axes = ng.make_axes(unsized)

            # One (window, padding, stride) key triple per leading input axis.
            # NOTE(review): the 'J' row uses the same 'pad_d'/'str_d' keys as
            # the 'T' row - confirm that this is intended.
            dim_keys = (
                ('J', 'pad_d', 'str_d'),
                ('T', 'pad_d', 'str_d'),
                ('R', 'pad_h', 'str_h'),
                ('S', 'pad_w', 'str_w'),
            )
            out_shape = [
                output_dim(src_axes[pos].length,
                           pool_cfg[size_key],
                           pool_cfg[pad_key],
                           pool_cfg[stride_key])
                for pos, (size_key, pad_key, stride_key) in enumerate(dim_keys)
            ]
            self.o_axes.set_shape(out_shape)
            # The batch axis is carried over unchanged from the input.
            self.o_axes |= src_axes.batch_axis()

        return ng.pooling(pool_cfg, in_obj, axes=self.o_axes)
Esempio n. 4
0
    def __call__(self, in_obj):
        """
        Convolve in_obj with this layer's filter variable.

        The filter axes (together with the weight variable and the shadow-axis
        map) and the output axes are each built only while the corresponding
        attribute is None, and are reused on later calls.

        Arguments:
            in_obj (Op): Input op

        Returns:
            The result of ng.convolution passed through ng.map_roles with
            self.axes_map.
        """
        conv_cfg = self.convparams.copy()
        in_obj = reorder_spatial_axes(in_obj)
        src_axes = in_obj.axes

        if self.f_axes is None:
            # Filter axes: the first input axis followed by one axis for each
            # of 'T', 'R', 'S' and 'K', sized from the convolution parameters.
            self.f_axes = ng.make_axes([src_axes[0]])
            for key in ('T', 'R', 'S', 'K'):
                self.f_axes |= ng.make_axis(length=conv_cfg[key], name=key)

            # mark 'K' as a shadow axis for the initializers.
            self.axes_map = shadow_axes_map(self.f_axes.find_by_name('K'))
            remapped = []
            for axis in self.f_axes:
                if axis.name == 'K':
                    remapped.append(list(self.axes_map.keys())[0])
                else:
                    remapped.append(axis)
            self.f_axes = ng.make_axes(remapped)

            weights = ng.variable(axes=self.f_axes, initial_value=self.init,
                                  scope=self.scope)
            self.W = weights.named('convwt')

        if self.o_axes is None:
            # Output axes reuse the names of the non-batch input axes; their
            # lengths are filled in below.
            unsized = [
                ng.make_axis(name=ax.name) for ax in src_axes if not ax.is_batch
            ]
            self.o_axes = ng.make_axes(unsized)

            # First output length comes from the last filter axis; the
            # remaining three are derived from input axes 1..3.
            out_shape = [self.f_axes[-1].length]
            spatial_keys = (
                (1, 'T', 'pad_d', 'str_d', 'dil_d'),
                (2, 'R', 'pad_h', 'str_h', 'dil_h'),
                (3, 'S', 'pad_w', 'str_w', 'dil_w'),
            )
            for pos, size_key, pad_key, stride_key, dilation_key in spatial_keys:
                out_shape.append(
                    output_dim(src_axes[pos].length, conv_cfg[size_key],
                               conv_cfg[pad_key], conv_cfg[stride_key], False,
                               conv_cfg[dilation_key]))
            self.o_axes.set_shape(out_shape)
            # The batch axis is carried over unchanged from the input.
            self.o_axes |= src_axes.batch_axis()

        convolved = ng.convolution(conv_cfg, in_obj, self.W, axes=self.o_axes)
        return ng.map_roles(convolved, self.axes_map)
    def __call__(self,
                 in_obj,
                 channel_axes="C",
                 spatial_axes=("D", "H", "W"),
                 **kwargs):
        """
        Apply the layer's convolution to in_obj and return the result with the
        axes in the same order as the input.

        The input is first reordered with reorder_spatial_axes. On the first call
        (while self.initialized is false) the weights are created; on later calls
        the filter axes implied by the new input must equal the existing ones.

        Arguments:
            in_obj (Op): Input op
            channel_axes (str): name of the expected channel axis type - defaults to "C"
            spatial_axes (tuple): names of expected depth, height and width axis types - defaults
                                  to "D", "H", and "W". A dict keyed by "D"/"H"/"W" is also
                                  accepted; missing keys fall back to the key itself. Falsy
                                  entries of a tuple are replaced by the matching default.
            **kwargs: accepted but not used by this method.

        Returns:
            The convolution output, with axes introduced by the reordering that
            have length 1 removed, and the remaining axes ordered like the input.

        Raises:
            ValueError: if spatial_axes is a tuple whose length is not 3, or if
                the layer is already initialized and the filter axes required
                by in_obj differ from the axes of the existing weights.
        """
        if isinstance(spatial_axes, dict):
            spatial_axes = tuple(
                spatial_axes.get(name, name) for name in ("D", "H", "W"))
        elif isinstance(spatial_axes, tuple):
            # Reject every length other than 3: zip() below would otherwise
            # silently discard the surplus names of a longer tuple.
            if len(spatial_axes) != 3:
                raise ValueError(
                    "spatial_axes must have length 3 (e.g. ('D', 'H', 'W'))")
            spatial_axes = tuple(
                name if name else default
                for name, default in zip(spatial_axes, ("D", "H", "W")))

        # Remember the caller's axis order so the output can be restored to it.
        orig_axes = in_obj.axes
        in_obj = reorder_spatial_axes(in_obj, channel_axes, spatial_axes)
        # From here on the names refer to axis objects of the reordered input.
        channel_axes = in_obj.axes.get_by_names(channel_axes)
        spatial_axes = in_obj.axes.get_by_names(*spatial_axes)

        filter_axes = self._filter_axes(channel_axes, spatial_axes)

        # mark 'K' as a shadow axis for the initializers.
        axes_map = shadow_axes_map(filter_axes.find_by_name('K'))
        filter_axes = ng.make_axes([
            axis if axis.name != 'K' else list(axes_map.keys())[0]
            for axis in filter_axes
        ])

        if not self.initialized:
            if not self.weight_norm:
                self.W = ng.variable(axes=filter_axes,
                                     initial_value=self.init,
                                     metadata={
                                         "label": LABELS["weight"]
                                     }).named("W")
            else:
                # Weight-normalized form: W = g * v / sqrt(mean(v ** 2) + 1e-3),
                # with the mean result and g living on the shadow 'K' axis.
                self.v = ng.variable(axes=filter_axes,
                                     initial_value=self.init,
                                     metadata={
                                         "label": LABELS["weight"]
                                     }).named("v")
                out_axes = ng.make_axes(
                    [filter_axes.get_by_names("K__NG_SHADOW")])
                v_norm = ng.mean(ng.square(self.v), out_axes=out_axes)
                self.g = ng.variable(axes=out_axes,
                                     initial_value=self.init,
                                     metadata={
                                         "label": LABELS["weight"]
                                     }).named("g")
                self.W = self.g * self.v * ng.reciprocal(
                    ng.sqrt(v_norm + 1e-3))
        else:
            # Weights already exist: the new input must imply the same filter axes.
            if filter_axes != self.W.axes:
                raise ValueError(
                    ("{layer_name} layer has already been initialized with an "
                     "input object which has resulted in filter axes: "
                     "{existing_filter_axes}. This new input object has axes: "
                     "{input_axes}, which implies the need for filter axes: "
                     "{new_filter_axes} which are different than the existing "
                     "filter axes.").format(
                         layer_name=self.name,
                         existing_filter_axes=self.W.axes,
                         input_axes=in_obj.axes,
                         new_filter_axes=filter_axes,
                     ))

        output = ng.map_roles(
            self._conv_op(in_obj, channel_axes, spatial_axes), axes_map)
        # Reorder the output to match the input order
        output_axis_order = ng.make_axes(
            [output.axes.find_by_name(ax.name)[0] for ax in orig_axes])
        # Remove introduced axes. If their length is > 1, then perhaps they should be kept
        slices = [
            0 if (ax not in orig_axes) and ax.length == 1 else slice(None)
            for ax in output.axes
        ]
        output = ng.tensor_slice(output, slices)
        # New axes with length > 1 may have been introduced. Add them to the end.
        output_axis_order = output_axis_order | output.axes
        return ng.axes_with_order(output, output_axis_order)
Esempio n. 6
0
def test_reorder_spatial_toomany_spatial(CDHWN, axis_a):
    """Adding an extra axis to the full CDHWN set must raise IncompatibleAxesError."""
    oversized_axes = CDHWN + axis_a
    source = ng.placeholder(oversized_axes)
    expected_failure = pytest.raises(IncompatibleAxesError)
    with expected_failure:
        reorder_spatial_axes(source, "C", ("D", "H", "W"))
Esempio n. 7
0
def test_reorder_spatial_triple_spatial(CDHWN):
    """A tensor holding all five fixture axes in shuffled order is reordered to CDHWN."""
    # Input layout is N, C, W, H, D (fixture positions -1, 0, 3, 2, 1).
    shuffled_axes = [CDHWN[pos] for pos in (-1, 0, 3, 2, 1)]
    source = ng.placeholder(shuffled_axes)

    reordered = reorder_spatial_axes(source, "C", ("D", "H", "W"))

    assert reordered.axes == CDHWN
Esempio n. 8
0
def test_reorder_spatial_no_spatial(CDHWN):
    """A tensor built from only the first and last fixture axes must be rejected."""
    first_axis, last_axis = CDHWN[0], CDHWN[-1]
    source = ng.placeholder([first_axis, last_axis])
    expected_failure = pytest.raises(IncompatibleAxesError)
    with expected_failure:
        reorder_spatial_axes(source, "C", ("D", "H", "W"))
Esempio n. 9
0
def test_reorder_spatial_no_channel(CDHWN):
    """With only the last two fixture axes as input, a length-1 'C' axis comes first."""
    trailing_axes = CDHWN[-2:]
    source = ng.placeholder(trailing_axes)

    result_axes = reorder_spatial_axes(source, "C", ("D", "H", "W")).axes

    assert len(result_axes) == 5
    leading_name = result_axes[0].name
    assert leading_name == 'C'
    leading_length = result_axes[0].length
    assert leading_length == 1
Esempio n. 10
0
def test_reorder_spatial_no_batch(CDHWN):
    """A tensor built from only the first two fixture axes must raise ValueError."""
    leading_axes = CDHWN[0:2]
    source = ng.placeholder(leading_axes)
    expected_failure = pytest.raises(ValueError)
    with expected_failure:
        reorder_spatial_axes(source, "C", ("D", "H", "W"))
Esempio n. 11
0
def test_reorder_spatial_toomany_spatial(CDHWN, axis_a):
    """Adding an extra axis to the full CDHWN set must raise ValueError."""
    oversized_axes = CDHWN + axis_a
    source = ng.placeholder(oversized_axes)
    expected_failure = pytest.raises(ValueError)
    with expected_failure:
        reorder_spatial_axes(source)
Esempio n. 12
0
def test_reorder_spatial_no_spatial(CDHWN):
    """A tensor built from only the first and last fixture axes must raise ValueError."""
    first_axis, last_axis = CDHWN[0], CDHWN[-1]
    source = ng.placeholder([first_axis, last_axis])
    expected_failure = pytest.raises(ValueError)
    with expected_failure:
        reorder_spatial_axes(source)