Ejemplo n.º 1
0
    def init(self,
             data,
             s=None,
             init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the initialisation pass of every block in ``self.blocks``.

        Args:
            data: input tensor; its first dimension sizes the accumulator.
            s: optional side input forwarded to every block. It is squeezed
                alongside the data whenever the data is squeezed.
            init_scale: forwarded unchanged to each ``block.init``.

        Returns:
            A pair ``(out, logdet_accum)``: the re-assembled output and the
            sum of the second value returned by each ``block.init``, one
            entry per element of the first dimension of ``data``.
        """
        total_logdet = data.new_zeros(data.size(0))
        hidden = data
        stashed = []
        for block in self.blocks:
            is_internal = isinstance(block, MaCowInternalBlock)
            # Internal and top blocks operate on a squeezed input.
            if is_internal or isinstance(block, MaCowTopBlock):
                if s is not None:
                    s = squeeze2d(s, factor=2)
                hidden = squeeze2d(hidden, factor=2)
            hidden, block_logdet = block.init(
                hidden, s=s, init_scale=init_scale)
            total_logdet = total_logdet + block_logdet
            if is_internal:
                # Keep only the first part flowing; stash the remainder.
                hidden, held_out = split2d(hidden, block.z1_channels)
                stashed.append(held_out)

        # Undo the squeezes, re-attaching the stashed parts last-in first-out.
        hidden = unsqueeze2d(hidden, factor=2)
        for _ in range(self.internals):
            hidden = unsqueeze2d(
                unsplit2d([hidden, stashed.pop()]), factor=2)
        assert len(stashed) == 0
        return hidden, total_logdet
Ejemplo n.º 2
0
    def backward(self,
                 input: torch.Tensor,
                 s=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run every block's ``backward`` in reverse order.

        The input is first squeezed and split once per internal block so
        that each block later receives the pieces it needs; the blocks are
        then applied from last to first.

        Args:
            input: tensor to push backwards through the blocks.
            s: optional side input, squeezed/unsqueezed in step with the
                data and passed to every ``block.backward``.

        Returns:
            A pair ``(out, logdet_accum)`` where ``logdet_accum`` sums the
            second value returned by each ``block.backward``.
        """
        if s is not None:
            s = squeeze2d(s, factor=2)
        hidden = squeeze2d(input, factor=2)

        # Peel off one stashed part per internal block, in forward order.
        stashed = []
        for block in self.blocks:
            if not isinstance(block, MaCowInternalBlock):
                continue
            if s is not None:
                s = squeeze2d(s, factor=2)
            hidden, held_out = split2d(hidden, block.z1_channels)
            stashed.append(held_out)
            hidden = squeeze2d(hidden, factor=2)

        total_logdet = input.new_zeros(input.size(0))
        for block in reversed(self.blocks):
            is_internal = isinstance(block, MaCowInternalBlock)
            if is_internal:
                # Re-attach the part that was split off for this block.
                hidden = unsplit2d([hidden, stashed.pop()])
            hidden, block_logdet = block.backward(hidden, s=s)
            total_logdet = total_logdet + block_logdet
            if is_internal or isinstance(block, MaCowTopBlock):
                if s is not None:
                    s = unsqueeze2d(s, factor=2)
                hidden = unsqueeze2d(hidden, factor=2)
        assert len(stashed) == 0
        return hidden, total_logdet
Ejemplo n.º 3
0
    def init(
        self,
        data,
        s=None,
        init_scale=1.0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the initialisation pass through every layer and its prior.

        For each ``(layer, prior)`` pair, every step of the layer and then
        the prior are initialised in order; the result is split at
        ``prior.z1_channels`` and the second part is set aside.

        Args:
            data: input tensor; its first dimension sizes the accumulator.
            s: optional side input forwarded to every ``init`` call.
            init_scale: forwarded unchanged to every ``init`` call.

        Returns:
            A flat pair ``(out, logdet_accum)``: ``out`` is the result of
            ``unsplit2d`` over the final part followed by the set-aside
            parts in reverse order of creation, and ``logdet_accum`` sums
            the second value returned by each ``init`` call.
        """
        out = data
        # [batch]
        logdet_accum = data.new_zeros(data.size(0))
        outputs = []
        for layer, prior in zip(self.layers, self.priors):
            for step in layer:
                out, logdet = step.init(out, s=s, init_scale=init_scale)
                logdet_accum = logdet_accum + logdet
            out, logdet = prior.init(out, s=s, init_scale=init_scale)
            logdet_accum = logdet_accum + logdet
            # Split: the first part keeps flowing, the second is set aside.
            out, held_out = split2d(out, prior.z1_channels)
            outputs.append(held_out)

        # Final part goes first, then the set-aside parts, newest to oldest.
        outputs.append(out)
        outputs.reverse()
        out = unsplit2d(outputs)
        return out, logdet_accum
Ejemplo n.º 4
0
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the initialisation pass of every block in ``self.blocks``.

        Every block receives a squeezed input. After a
        ``GlowInternalBlock`` the result is split in half along dimension
        1 and the second half is stashed until the final re-assembly.

        Args:
            data: input tensor; its first dimension sizes the accumulator.
            init_scale: forwarded unchanged to each ``block.init``.

        Returns:
            A pair ``(out, logdet_accum)`` where ``logdet_accum`` sums the
            second value returned by each ``block.init``.
        """
        total_logdet = data.new_zeros(data.size(0))
        hidden = data
        stashed = []
        for block in self.blocks:
            hidden = squeeze2d(hidden, factor=2)
            hidden, block_logdet = block.init(hidden, init_scale=init_scale)
            total_logdet = total_logdet + block_logdet
            if isinstance(block, GlowInternalBlock):
                hidden, held_out = split2d(hidden, hidden.size(1) // 2)
                stashed.append(held_out)

        # Undo the squeezes, re-attaching stashed halves last-in first-out.
        hidden = unsqueeze2d(hidden, factor=2)
        for _ in range(self.levels - 1):
            hidden = unsqueeze2d(
                unsplit2d([hidden, stashed.pop()]), factor=2)
        assert len(stashed) == 0
        return hidden, total_logdet
Ejemplo n.º 5
0
    def backward(self,
                 input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run every block's ``backward`` in reverse order.

        The input is squeezed and halved along dimension 1
        ``self.levels - 1`` times up front; each ``GlowInternalBlock`` then
        gets its stashed half re-attached before its ``backward`` runs.

        Args:
            input: tensor to push backwards through the blocks.

        Returns:
            A pair ``(out, logdet_accum)`` where ``logdet_accum`` sums the
            second value returned by each ``block.backward``.
        """
        hidden = squeeze2d(input, factor=2)
        stashed = []
        for _ in range(self.levels - 1):
            hidden, held_out = split2d(hidden, hidden.size(1) // 2)
            stashed.append(held_out)
            hidden = squeeze2d(hidden, factor=2)

        total_logdet = input.new_zeros(input.size(0))
        for block in reversed(self.blocks):
            if isinstance(block, GlowInternalBlock):
                # Re-attach the half that was split off for this block.
                hidden = unsplit2d([hidden, stashed.pop()])
            hidden, block_logdet = block.backward(hidden)
            total_logdet = total_logdet + block_logdet
            hidden = unsqueeze2d(hidden, factor=2)
        assert len(stashed) == 0
        return hidden, total_logdet
Ejemplo n.º 6
0
    def forward(self,
                input: torch.Tensor,
                s=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply every layer and its prior in order.

        For each ``(layer, prior)`` pair, the steps of the layer and then
        the prior are applied; the result is split at ``prior.z1_channels``
        and the second part is set aside.

        Args:
            input: tensor to transform.
            s: optional side input forwarded to every ``forward`` call.

        Returns:
            A pair ``(out, logdet_accum)``: ``out`` is ``unsplit2d`` over
            the final part followed by the set-aside parts, newest first;
            ``logdet_accum`` sums the second value returned by each
            ``forward`` call.
        """
        hidden = input
        # [batch]
        total_logdet = input.new_zeros(input.size(0))
        pieces = []
        for layer, prior in zip(self.layers, self.priors):
            for step in layer:
                hidden, step_logdet = step.forward(hidden, s=s)
                total_logdet = total_logdet + step_logdet
            hidden, prior_logdet = prior.forward(hidden, s=s)
            total_logdet = total_logdet + prior_logdet
            # Split: the first part keeps flowing, the second is set aside.
            hidden, held_out = split2d(hidden, prior.z1_channels)
            pieces.append(held_out)

        pieces.append(hidden)
        # Final part first, then the set-aside parts from newest to oldest.
        merged = unsplit2d(list(reversed(pieces)))
        return merged, total_logdet
Ejemplo n.º 7
0
    def backward(self,
                 input: torch.Tensor,
                 s=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply every prior and layer backwards, last pair first.

        The input is first split once per prior (at ``prior.z1_channels``)
        with the second parts stashed; each stashed part is re-attached
        right before the matching prior's ``backward`` runs.

        Args:
            input: tensor to push backwards.
            s: optional side input forwarded to every ``backward`` call.

        Returns:
            A pair ``(out, logdet_accum)`` where ``logdet_accum`` sums the
            second value returned by each ``backward`` call.
        """
        hidden = input
        stashed = []
        for prior in self.priors:
            hidden, held_out = split2d(hidden, prior.z1_channels)
            stashed.append(held_out)

        # [batch]
        total_logdet = hidden.new_zeros(hidden.size(0))
        pairs = zip(reversed(self.layers), reversed(self.priors))
        for layer, prior in pairs:
            hidden = unsplit2d([hidden, stashed.pop()])
            hidden, prior_logdet = prior.backward(hidden, s=s)
            total_logdet = total_logdet + prior_logdet
            for step in reversed(layer):
                hidden, step_logdet = step.backward(hidden, s=s)
                total_logdet = total_logdet + step_logdet

        assert len(stashed) == 0
        return hidden, total_logdet