Example no. 1
    def backward_attn(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # rebuild the per-level squeezed/split tensors so the blocks can be inverted below
        outputs = []
        out = input
        for i in range(self.levels - 1):
            if i > 0:
                out1, out2 = split2d(out, self.blocks[i].z_channels)
                outputs.append(out2)
                out = out1
            out = squeeze2d(out, factor=2)
            if self.squeeze_h:
                h = squeeze2d(h, factor=2)

        logdet_accum = input.new_zeros(input.size(0))
        # invert the blocks from the coarsest level back up, unsqueezing and re-attaching the splits
        for i, block in enumerate(reversed(self.blocks)):
            if i > 0:
                out = unsqueeze2d(out, factor=2)
                if self.squeeze_h:
                    h = unsqueeze2d(h, factor=2)
                if i < self.levels - 1:
                    out2 = outputs.pop()
                    out = unsplit2d([out, out2])
            if i < self.levels - 1:
                out, logdet = block.backward(out, h=h)
            else:
                out, logdet, attn = block.backward_attn(out, h=h)
            logdet_accum = logdet_accum + logdet
        assert len(outputs) == 0
        return out, logdet_accum, attn
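These multi-scale methods rely on squeeze2d/unsqueeze2d helpers whose definitions are not part of the snippets. The sketch below is an assumption: it implements the standard Glow-style space-to-channel rearrangement, which matches how the examples use the helpers (factor 2 halves height and width and multiplies the channel count by 4).

import torch

def squeeze2d(x: torch.Tensor, factor: int = 2) -> torch.Tensor:
    # [B, C, H, W] -> [B, C * factor**2, H // factor, W // factor]
    B, C, H, W = x.size()
    assert H % factor == 0 and W % factor == 0
    x = x.view(B, C, H // factor, factor, W // factor, factor)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    return x.view(B, C * factor * factor, H // factor, W // factor)

def unsqueeze2d(x: torch.Tensor, factor: int = 2) -> torch.Tensor:
    # inverse of squeeze2d: [B, C, H, W] -> [B, C // factor**2, H * factor, W * factor]
    B, C, H, W = x.size()
    assert C % (factor * factor) == 0
    x = x.view(B, C // (factor * factor), factor, factor, H, W)
    x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
    return x.view(B, C // (factor * factor), H * factor, W * factor)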
Example no. 2
    def forward_attn(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        logdet_accum = input.new_zeros(input.size(0))
        out = input
        outputs = []
        attns = None
        for i, block in enumerate(self.blocks):
            if i == 0:
                out, logdet, attns = block.forward_attn(out, h=h)
            else:
                out, logdet = block.forward(out, h=h)
            logdet_accum = logdet_accum + logdet
            if i < self.levels - 1:
                if i > 0:
                    # split when block is not bottom or top
                    out1, out2 = split2d(out, block.z_channels)
                    outputs.append(out2)
                    out = out1
                # squeeze when block is not top
                out = squeeze2d(out, factor=2)
                if self.squeeze_h:
                    h = squeeze2d(h, factor=2)

        out = unsqueeze2d(out, factor=2)
        for _ in range(self.internals):
            out2 = outputs.pop()
            out = unsqueeze2d(unsplit2d([out, out2]), factor=2)
        assert len(outputs) == 0
        
        return out, logdet_accum, attns
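The split2d/unsplit2d helpers used above divide a tensor along the channel dimension and concatenate the pieces back together. Their exact definitions are not shown in these examples; below is a minimal sketch, assuming the common convention that the first chunk keeps the requested number of channels.

import torch
from typing import List, Tuple

def split2d(x: torch.Tensor, z1_channels: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # split along the channel dimension; the first part keeps z1_channels channels
    return x[:, :z1_channels], x[:, z1_channels:]

def unsplit2d(xs: List[torch.Tensor]) -> torch.Tensor:
    # inverse of split2d: concatenate the chunks back along the channel dimension
    return torch.cat(xs, dim=1)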
Example no. 3
    def init(self,
             data: torch.Tensor,
             h=None,
             init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        logdet_accum = data.new_zeros(data.size(0))
        out = data
        outputs = []
        for i, block in enumerate(self.blocks):
            out, logdet = block.init(out, h=h, init_scale=init_scale)
            logdet_accum = logdet_accum + logdet
            if i < self.levels - 1:
                if i > 0:
                    # split when block is not bottom or top
                    out1, out2 = split2d(out, block.z_channels)
                    outputs.append(out2)
                    out = out1
                # squeeze when block is not top
                out = squeeze2d(out, factor=2)
                if self.squeeze_h:
                    h = squeeze2d(h, factor=2)

        out = unsqueeze2d(out, factor=2)
        for _ in range(self.internals):
            out2 = outputs.pop()
            out = unsqueeze2d(unsplit2d([out, out2]), factor=2)
        assert len(outputs) == 0
        return out, logdet_accum
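The init methods shown here are typically run once, before training, to perform data-dependent initialization (for example of ActNorm-style scales and biases). A usage sketch follows; flow and init_batch are hypothetical names for the multi-scale module above and a representative batch of training data, not identifiers from these examples.

import torch

with torch.no_grad():
    z, logdet = flow.init(init_batch, init_scale=1.0)
# z has the same shape as init_batch; logdet holds one log-determinant per sample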
Example no. 4
    def init(self, data, h=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        # conv1x1
        out, logdet_accum = self.conv1x1.init(data, init_scale=init_scale)
        # coupling
        out, logdet = self.coupling.init(out, h=h, init_scale=init_scale)
        logdet_accum = logdet_accum + logdet
        # actnorm
        out1, out2 = split2d(out, self.z1_channels)
        out2, logdet = self.actnorm.init(out2, init_scale=init_scale)
        logdet_accum = logdet_accum + logdet
        out = unsplit2d([out1, out2])
        return out, logdet_accum
Example no. 5
    def backward(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
        # actnorm
        out1, out2 = split2d(input, self.z1_channels)
        out2, logdet_accum = self.actnorm.backward(out2)
        out = unsplit2d([out1, out2])
        # coupling
        out, logdet = self.coupling.backward(out, h=h)
        logdet_accum = logdet_accum + logdet
        # conv1x1
        out, logdet = self.conv1x1.backward(out)
        logdet_accum = logdet_accum + logdet
        return out, logdet_accum
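A quick sanity check for a step like this is a forward/backward round trip: the inverse shown above (actnorm, then coupling, then conv1x1) should undo the forward pass (conv1x1, then coupling, then actnorm, as in the next example). A hedged usage sketch, assuming unit is such a step with, say, 8 input channels:

import torch

x = torch.randn(2, 8, 16, 16)          # hypothetical input shape
y, logdet_fwd = unit.forward(x)         # conv1x1 -> coupling -> actnorm
x_rec, logdet_bwd = unit.backward(y)    # actnorm -> coupling -> conv1x1 (as above)
assert torch.allclose(x, x_rec, atol=1e-5)
# under the usual sign convention the two log-determinants cancel: logdet_fwd + logdet_bwd ≈ 0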
Example no. 6
    def forward_attn(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # conv1x1
        out, logdet_accum = self.conv1x1.forward(input)
        # coupling
        out, logdet, attn = self.coupling.forward_attn(out, h=h)
        logdet_accum = logdet_accum + logdet
        # actnorm
        out1, out2 = split2d(out, self.z1_channels)
        out2, logdet = self.actnorm.forward(out2)
        logdet_accum = logdet_accum + logdet
        out = unsplit2d([out1, out2])
        return out, logdet_accum, attn
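Compared with a plain forward pass, forward_attn additionally returns the attention weights produced by the coupling layer, which can be collected for inspection. A brief usage sketch (unit and x are hypothetical names, as before):

y, logdet, attn = unit.forward_attn(x, h=None)
# attn comes from the coupling layer; its exact shape depends on that layer's implementation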
Example no. 7
    def init(self, data, h=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        out = data
        # [batch]
        logdet_accum = data.new_zeros(data.size(0))
        outputs = []
        for layer, prior in zip(self.layers, self.priors):
            for step in layer:
                out, logdet = step.init(out, h=h, init_scale=init_scale)
                logdet_accum = logdet_accum + logdet
            out, logdet = prior.init(out, h=h, init_scale=init_scale)
            logdet_accum = logdet_accum + logdet
            # split
            out1, out2 = split2d(out, prior.z1_channels)
            outputs.append(out2)
            out = out1

        outputs.append(out)
        outputs.reverse()
        out = unsplit2d(outputs)
        return out, logdet_accum
Example no. 8
    def forward(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
        out = input
        # [batch]
        logdet_accum = input.new_zeros(input.size(0))
        outputs = []
        for layer, prior in zip(self.layers, self.priors):
            for step in layer:
                out, logdet = step.forward(out, h=h)
                logdet_accum = logdet_accum + logdet
            out, logdet = prior.forward(out, h=h)
            logdet_accum = logdet_accum + logdet
            # split
            out1, out2 = split2d(out, prior.z1_channels)
            outputs.append(out2)
            out = out1

        outputs.append(out)
        outputs.reverse()
        out = unsplit2d(outputs)
        return out, logdet_accum
Example no. 9
    def backward(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
        out = input
        outputs = []
        for prior in self.priors:
            out1, out2 = split2d(out, prior.z1_channels)
            outputs.append(out2)
            out = out1

        # [batch]
        logdet_accum = out.new_zeros(out.size(0))
        for layer, prior in zip(reversed(self.layers), reversed(self.priors)):
            out2 = outputs.pop()
            out = unsplit2d([out, out2])
            out, logdet = prior.backward(out, h=h)
            logdet_accum = logdet_accum + logdet
            for step in reversed(layer):
                out, logdet = step.backward(out, h=h)
                logdet_accum = logdet_accum + logdet

        assert len(outputs) == 0
        return out, logdet_accum
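The backward pass above first re-creates the per-scale channel splits from the flattened latent and then inverts the priors and steps in reverse order, so running forward followed by backward should recover the input. A round-trip sketch, with flow as a hypothetical instance of this module and an input shape matching its configuration:

import torch

x = torch.randn(4, 3, 32, 32)            # hypothetical input
z, logdet = flow.forward(x)               # encode: per-scale splits concatenated into z
x_rec, logdet_inv = flow.backward(z)      # decode: re-split z and invert every step
assert torch.allclose(x, x_rec, atol=1e-5)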