def test_reorder_spatial_single_spatial(CDHWN):
    """An (N, C, H) placeholder is expanded to the full CDHWN axis set.

    Only one spatial axis is present on the input, so the reordered result
    must contain unit-length axes at positions 1 (D) and 3 (W).
    """
    n_ax, c_ax, h_ax = CDHWN[-1], CDHWN[0], CDHWN[2]
    # Reorder to NCH
    reordered = reorder_spatial_axes(
        ng.placeholder([n_ax, c_ax, h_ax]), "C", ("D", "H", "W"))
    result = reordered.axes
    assert result == CDHWN
    # D was missing on the input and is inserted with length 1
    assert result[1].length == 1
    # W was missing on the input and is inserted with length 1
    assert result[3].length == 1
def test_reorder_spatial_double_spatial(CDHWN):
    """With default arguments, a four-axis placeholder is padded to five axes.

    The input is ordered (CDHWN[-1], CDHWN[0], CDHWN[3], CDHWN[2]). The result
    must start with the 'C' axis, have a unit-length axis at position 1, and
    keep the two supplied spatial lengths in the order they were given.
    """
    n_ax, c_ax, first_sp, second_sp = CDHWN[-1], CDHWN[0], CDHWN[3], CDHWN[2]
    result = reorder_spatial_axes(
        ng.placeholder([n_ax, c_ax, first_sp, second_sp])).axes
    assert len(result) == 5
    assert result[0].name == 'C'
    assert result[1].length == 1
    assert result[2].length == first_sp.length
    assert result[3].length == second_sp.length
def __call__(self, in_obj):
    """Apply pooling to ``in_obj``.

    The input is passed through ``reorder_spatial_axes`` first. On the first
    call, the non-batch output axes are created with lengths computed by
    ``output_dim`` from the pooling parameters, and the input's batch axis is
    appended. Later calls reuse ``self.o_axes``.

    Arguments:
        in_obj (Op): input op.

    Returns:
        The op produced by ``ng.pooling`` with axes ``self.o_axes``.
    """
    ppm = self.poolparams.copy()
    in_obj = reorder_spatial_axes(in_obj)
    in_axes = in_obj.axes

    if self.o_axes is None:
        self.o_axes = ng.make_axes([
            ng.make_axis(name=a.name) for a in in_axes if not a.is_batch
        ])
        # set lengths.
        # The channel axis (index 0) is pooled with window J. It must use its
        # own padding/stride rather than the depth ('pad_d'/'str_d') values;
        # otherwise a depth stride or padding would also change the channel
        # length. When no channel-specific keys are configured, use the neutral
        # values (no padding, unit stride).
        out_shape = [
            output_dim(in_axes[0].length, ppm['J'],
                       ppm.get('pad_c', 0), ppm.get('str_c', 1)),
            output_dim(in_axes[1].length, ppm['T'], ppm['pad_d'], ppm['str_d']),
            output_dim(in_axes[2].length, ppm['R'], ppm['pad_h'], ppm['str_h']),
            output_dim(in_axes[3].length, ppm['S'], ppm['pad_w'], ppm['str_w']),
        ]
        self.o_axes.set_shape(out_shape)
        self.o_axes |= in_axes.batch_axis()

    return ng.pooling(ppm, in_obj, axes=self.o_axes)
def __call__(self, in_obj):
    """Convolve ``in_obj`` with this layer's filter, building state lazily.

    On the first call this builds the filter axes. These are the input's
    leading axis followed by T, R, S and K, with K replaced by the shadow axis
    from ``shadow_axes_map``. The first call also creates the ``convwt`` weight
    variable. The non-batch output axes are likewise created once, and the
    input's batch axis is appended to them.

    Arguments:
        in_obj (Op): input op.

    Returns:
        The ``ng.convolution`` result with roles remapped through
        ``self.axes_map``.
    """
    params = self.convparams.copy()
    in_obj = reorder_spatial_axes(in_obj)
    axes_in = in_obj.axes

    if self.f_axes is None:
        self.f_axes = ng.make_axes([axes_in[0]])
        for dim_name in ('T', 'R', 'S', 'K'):
            self.f_axes |= ng.make_axis(length=params[dim_name], name=dim_name)
        # mark 'K' as a shadow axis for the initializers.
        self.axes_map = shadow_axes_map(self.f_axes.find_by_name('K'))
        self.f_axes = ng.make_axes([
            list(self.axes_map.keys())[0] if ax.name == 'K' else ax
            for ax in self.f_axes
        ])
        self.W = ng.variable(axes=self.f_axes,
                             initial_value=self.init,
                             scope=self.scope).named('convwt')

    if self.o_axes is None:
        self.o_axes = ng.make_axes(
            [ng.make_axis(name=ax.name) for ax in axes_in if not ax.is_batch])
        # Set lengths. The first entry is the length of the filter's last (K)
        # axis; the rest are the convolved depth/height/width extents.
        spatial_specs = (
            (1, 'T', 'pad_d', 'str_d', 'dil_d'),
            (2, 'R', 'pad_h', 'str_h', 'dil_h'),
            (3, 'S', 'pad_w', 'str_w', 'dil_w'),
        )
        out_shape = [self.f_axes[-1].length]
        out_shape.extend(
            output_dim(axes_in[idx].length, params[win], params[pad],
                       params[stride], False, params[dil])
            for idx, win, pad, stride, dil in spatial_specs
        )
        self.o_axes.set_shape(out_shape)
        self.o_axes |= axes_in.batch_axis()

    conv = ng.convolution(params, in_obj, self.W, axes=self.o_axes)
    return ng.map_roles(conv, self.axes_map)
def __call__(self, in_obj, channel_axes="C", spatial_axes=("D", "H", "W"), **kwargs):
    """
    Apply the convolution to ``in_obj``, creating the filter variables on first use.

    Arguments:
        in_obj (Op): Input op
        channel_axes (str): name of the expected channel axis type - defaults to "C"
        spatial_axes (tuple, list or dict): names of expected depth, height and
            width axis types - defaults to "D", "H", and "W".
            A tuple/list must have exactly three entries, and falsy entries fall
            back to the matching default name. A dict maps any of the default
            names "D", "H", "W" to replacement names; missing keys keep the
            default.
        **kwargs: accepted for interface compatibility; not used by this method.

    Returns:
        Op: the convolution output. Its axes are ordered to follow ``in_obj``'s
        original axis order (matched by name). Unit-length axes that were not
        on the input are sliced away. Any other new axes are appended at the end.

    Raises:
        ValueError: if ``spatial_axes`` is a tuple/list whose length is not 3,
            or if the layer was already initialized with filter axes that differ
            from the ones implied by ``in_obj``.
    """
    default_spatial = ("D", "H", "W")
    if isinstance(spatial_axes, dict):
        spatial_axes = tuple(
            spatial_axes.get(name, name) for name in default_spatial)
    elif isinstance(spatial_axes, (tuple, list)):
        # zip() below would silently drop extra names, so require exactly three.
        if len(spatial_axes) != 3:
            raise ValueError(
                "spatial_axes must have length 3 (e.g. ('D', 'H', 'W'))")
        spatial_axes = tuple(
            name if name else default
            for name, default in zip(spatial_axes, default_spatial))

    orig_axes = in_obj.axes
    in_obj = reorder_spatial_axes(in_obj, channel_axes, spatial_axes)
    channel_axes = in_obj.axes.get_by_names(channel_axes)
    spatial_axes = in_obj.axes.get_by_names(*spatial_axes)

    filter_axes = self._filter_axes(channel_axes, spatial_axes)

    # mark 'K' as a shadow axis for the initializers.
    axes_map = shadow_axes_map(filter_axes.find_by_name('K'))
    filter_axes = ng.make_axes([
        axis if axis.name != 'K' else list(axes_map.keys())[0]
        for axis in filter_axes
    ])

    if not self.initialized:
        if not self.weight_norm:
            self.W = ng.variable(axes=filter_axes,
                                 initial_value=self.init,
                                 metadata={"label": LABELS["weight"]}).named("W")
        else:
            # Weight-normalized filter: W = g * v / sqrt(mean(v**2) + 1e-3),
            # where the reduction leaves only out_axes.
            self.v = ng.variable(axes=filter_axes,
                                 initial_value=self.init,
                                 metadata={"label": LABELS["weight"]}).named("v")
            # NOTE(review): "K__NG_SHADOW" is presumably the name that
            # shadow_axes_map gives the shadow K axis - confirm.
            out_axes = ng.make_axes(
                [filter_axes.get_by_names("K__NG_SHADOW")])
            v_norm = ng.mean(ng.square(self.v), out_axes=out_axes)
            self.g = ng.variable(axes=out_axes,
                                 initial_value=self.init,
                                 metadata={"label": LABELS["weight"]}).named("g")
            self.W = self.g * self.v * ng.reciprocal(ng.sqrt(v_norm + 1e-3))
    else:
        if filter_axes != self.W.axes:
            raise ValueError(
                ("{layer_name} layer has already been initialized with an "
                 "input object which has resulted in filter axes: "
                 "{existing_filter_axes}. This new input object has axes: "
                 "{input_axes}, which implies the need for filter axes: "
                 "{new_filter_axes} which are different than the existing "
                 "filter axes.").format(
                    layer_name=self.name,
                    existing_filter_axes=self.W.axes,
                    input_axes=in_obj.axes,
                    new_filter_axes=filter_axes,
                ))

    output = ng.map_roles(
        self._conv_op(in_obj, channel_axes, spatial_axes), axes_map)
    # Reorder the output to match the input order
    output_axis_order = ng.make_axes(
        [output.axes.find_by_name(ax.name)[0] for ax in orig_axes])

    # Remove introduced axes. If their length is > 1, then perhaps they should be kept
    slices = [
        0 if (ax not in orig_axes) and ax.length == 1 else slice(None)
        for ax in output.axes
    ]
    output = ng.tensor_slice(output, slices)
    # New axes with length > 1 may have been introduced. Add them to the end.
    output_axis_order = output_axis_order | output.axes
    return ng.axes_with_order(output, output_axis_order)
def test_reorder_spatial_toomany_spatial(CDHWN, axis_a):
    """Appending ``axis_a`` to the CDHWN axes must raise IncompatibleAxesError.

    NOTE(review): a later ``test_reorder_spatial_toomany_spatial`` (expecting
    ValueError) reuses this name. If both live in the same module, only the
    last definition is collected by pytest - consider renaming one.
    """
    oversized = ng.placeholder(CDHWN + axis_a)
    with pytest.raises(IncompatibleAxesError):
        reorder_spatial_axes(oversized, "C", ("D", "H", "W"))
def test_reorder_spatial_triple_spatial(CDHWN):
    """A placeholder with all axes present but shuffled is reordered to CDHWN."""
    # Reorder to NCWHD
    shuffled = [CDHWN[-1], CDHWN[0], CDHWN[3], CDHWN[2], CDHWN[1]]
    result = reorder_spatial_axes(
        ng.placeholder(shuffled), "C", ("D", "H", "W")).axes
    assert result == CDHWN
def test_reorder_spatial_no_spatial(CDHWN):
    """A placeholder with only CDHWN[0] and CDHWN[-1] must raise IncompatibleAxesError.

    NOTE(review): a later ``test_reorder_spatial_no_spatial`` (expecting
    ValueError) reuses this name. If both live in the same module, only the
    last definition is collected by pytest - consider renaming one.
    """
    no_spatial = ng.placeholder([CDHWN[0], CDHWN[-1]])
    with pytest.raises(IncompatibleAxesError):
        reorder_spatial_axes(no_spatial, "C", ("D", "H", "W"))
def test_reorder_spatial_no_channel(CDHWN):
    """Without a 'C' axis, the result still has five axes, led by a unit 'C' axis."""
    channelless = ng.placeholder(CDHWN[-2:])
    result = reorder_spatial_axes(channelless, "C", ("D", "H", "W")).axes
    assert len(result) == 5
    leading = result[0]
    assert leading.name == 'C'
    assert leading.length == 1
def test_reorder_spatial_no_batch(CDHWN):
    """A placeholder built from only the first two CDHWN axes must raise ValueError."""
    batchless = ng.placeholder(CDHWN[0:2])
    with pytest.raises(ValueError):
        reorder_spatial_axes(batchless, "C", ("D", "H", "W"))
def test_reorder_spatial_toomany_spatial(CDHWN, axis_a):
    """With default arguments, CDHWN plus ``axis_a`` must raise ValueError.

    NOTE(review): an earlier test with this same name expects
    IncompatibleAxesError. If both live in one module, this definition
    shadows it.
    """
    oversized = ng.placeholder(CDHWN + axis_a)
    with pytest.raises(ValueError):
        reorder_spatial_axes(oversized)
def test_reorder_spatial_no_spatial(CDHWN):
    """With default arguments, a placeholder lacking spatial axes must raise ValueError.

    NOTE(review): an earlier test with this same name expects
    IncompatibleAxesError. If both live in one module, this definition
    shadows it.
    """
    no_spatial = ng.placeholder([CDHWN[0], CDHWN[-1]])
    with pytest.raises(ValueError):
        reorder_spatial_axes(no_spatial)