Beispiel #1
0
 def transform(x):
     """Convert the arrays produced by ``cache(x)`` into one filtered tensor.

     Each array yielded by ``cache(x)`` is cast to float32 and wrapped as a
     torch tensor. It is given a leading singleton dimension and divided by 8.
     The resulting tensors are stacked along a new first dimension, and the
     stack is passed to ``low_pass_filter`` with argument ``2``.
     """
     def to_tensor(arr):
         # astype() makes a float32 copy; from_numpy then shares that copy's memory.
         return torch.from_numpy(arr.astype(np.float32)).unsqueeze(0) / 8

     # Use a distinct loop name rather than re-binding the parameter ``x``.
     stacked = torch.stack([to_tensor(arr) for arr in cache(x)])
     return low_pass_filter(stacked, 2)
Beispiel #2
0
    def forward(self, x):  # pylint: disable=W
        """Convolve, optionally gate, optionally downsample, optionally drop out.

        Steps, in order:
          1. ``self.conv(x)``.
          2. If ``self.scalar_act`` or ``self.gate_act`` is set, apply the nested
             ``gate``. When ``self.checkpoint`` is truthy, this goes through
             ``torch.utils.checkpoint.checkpoint``, which recomputes ``gate``
             during backward instead of storing its intermediates.
          3. If ``self.stride > 1``, apply
             ``low_pass_filter(z, self.stride, self.stride)``.
          4. If ``self.dropout`` is set, apply it.

        Args:
            x: input passed to ``self.conv``.

        Returns:
            The processed tensor ``z``.
        """
        def gate(y):
            """Apply per-order activations/gates to the capsule channels of ``y``.

            ``y`` is indexed as a 5-D tensor ``[batch, channel, x, y, z]``.
            Its first ``size_out`` channels hold capsules, grouped by order ``n``
            as ``self.repr_out[n]`` copies of dimension ``2n + 1``. When
            ``self.gate_act`` is set, the channels after ``size_out`` are gate
            scalars. They are consumed ``mul`` at a time for each order
            ``n >= 1`` with ``mul > 0``.

            - Order 0 (scalars) gets ``self.scalar_act`` if set; otherwise it
              passes through unchanged.
            - Order n >= 1 is multiplied by ``self.gate_act(gates)``, with one
              gate per capsule broadcast over its ``2n + 1`` entries, if
              ``gate_act`` is set; otherwise it passes through unchanged.

            NOTE(review): nothing checks that ``y`` has exactly
            ``size_out + sum(mul for n >= 1)`` channels. Presumably ``self.conv``
            guarantees this; confirm.

            Returns:
                A new tensor with ``size_out`` channels (gate channels dropped).
            """
            nbatch = y.size(0)
            nx = y.size(2)
            ny = y.size(3)
            nz = y.size(4)

            # Total capsule channels: mul copies of a (2n + 1)-dim capsule per order n.
            size_out = sum(mul * (2 * n + 1)
                           for n, mul in enumerate(self.repr_out))

            if self.gate_act is not None:
                # Everything past the capsule channels is treated as gate input.
                g = y[:, size_out:]
                g = self.gate_act(g)
                begin_g = 0  # running offset into the gate channels of g

            # Uninitialized output is fine: the loop below writes every one of
            # the size_out channels (orders with mul == 0 occupy zero channels).
            z = y.new_empty(
                (y.size(0), size_out, y.size(2), y.size(3), y.size(4)))
            begin_y = 0  # running channel offset into y (and z) for the current order

            for n, mul in enumerate(self.repr_out):
                if mul == 0:
                    continue
                dim = 2 * n + 1

                # crop out capsules of order n
                field_y = y[:, begin_y:begin_y +
                            mul * dim]  # [batch, feature * repr, x, y, z]

                if n == 0:
                    # Scalar activation
                    if self.scalar_act is not None:
                        field = self.scalar_act(field_y)
                    else:
                        field = field_y
                else:
                    if self.gate_act is not None:
                        # reshape channels into capsules and capsule entries
                        # (view() requires contiguous memory, hence contiguous() first)
                        field_y = field_y.contiguous()
                        field_y = field_y.view(
                            nbatch, mul, dim, nx, ny,
                            nz)  # [batch, feature, repr, x, y, z]

                        # crop out the corresponding scalar gates (one per capsule)
                        field_g = g[:, begin_g:begin_g +
                                    mul]  # [batch, feature, x, y, z]
                        begin_g += mul
                        # reshape channels for broadcasting over the repr axis
                        field_g = field_g.contiguous()
                        field_g = field_g.view(
                            nbatch, mul, 1, nx, ny,
                            nz)  # [batch, feature, 1, x, y, z]

                        # scale non-scalar capsules by gate values
                        field = field_y * field_g  # [batch, feature, repr, x, y, z]
                        field = field.view(
                            nbatch, mul * dim, nx, ny,
                            nz)  # [batch, feature * repr, x, y, z]
                        # presumably to drop references early and lower peak memory
                        del field_g
                    else:
                        field = field_y
                del field_y

                z[:, begin_y:begin_y + mul * dim] = field
                begin_y += mul * dim
                del field

            return z

        # convolution
        z = self.conv(x)

        # gate (skipped entirely when neither activation is configured)
        if self.scalar_act is not None or self.gate_act is not None:
            z = torch.utils.checkpoint.checkpoint(
                gate, z) if self.checkpoint else gate(z)

        # stride: downsampling is delegated to low_pass_filter(z, stride, stride)
        if self.stride > 1:
            z = low_pass_filter(z, self.stride, self.stride)

        # dropout
        if self.dropout is not None:
            z = self.dropout(z)

        return z
Beispiel #3
0
 def forward(self, inp):
     """Low-pass filter ``inp`` using this module's ``scale`` and ``stride``.

     Thin wrapper that passes ``self.scale`` and ``self.stride`` as the second
     and third positional arguments of ``low_pass_filter``.
     """
     scale = self.scale
     stride = self.stride
     return low_pass_filter(inp, scale, stride)
Beispiel #4
0
 def transform(x):
     """Convert ``cache(x)`` into a filtered float tensor.

     The cached array is cast to float32 and wrapped as a torch tensor. It is
     given a leading singleton dimension and divided by 8. The result is
     passed to ``low_pass_filter`` with argument ``2``.
     """
     array = cache(x).astype(np.float32)
     # from_numpy shares memory with the float32 copy made just above.
     tensor = torch.from_numpy(array).unsqueeze(0) / 8
     return low_pass_filter(tensor, 2)
Beispiel #5
0
 def forward(self, inp):  # pylint: disable=W
     """Low-pass filter ``inp`` (argument ``2``), then run ``self.sequence`` on it."""
     filtered = low_pass_filter(inp, 2)
     return self.sequence(filtered)