예제 #1
0
파일: nmt.py 프로젝트: juheeuu/flowseq
    def backward(self, input: torch.Tensor, tgt_mask: torch.Tensor,
                 src: torch.Tensor,
                 src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the blocks in reverse order, mirroring the structure of ``forward``.

        The split/squeeze pattern that ``forward`` applies between levels is
        first replayed on ``input``. For each of the first ``self.levels - 1``
        levels, it records the factored-out half and the mask as it was before
        squeezing. The blocks are then visited deepest-first via
        ``block.backward``. Before every block except the deepest, the most
        recently stashed half is re-attached (unsqueeze, then unsplit) and the
        matching mask is restored.

        Args:
            input: tensor to send back through the blocks. ``input.size(0)``
                sizes the log-determinant accumulator.
            tgt_mask: mask for ``input``. It is squeezed alongside it at each
                level and restored on the way back.
            src: passed unchanged to every ``block.backward``.
            src_mask: passed unchanged to every ``block.backward``.

        Returns:
            ``(out, logdet)``: the final output and the sum of the per-block
            log-determinants, accumulated onto ``input.new_zeros(input.size(0))``.
        """
        n_splits = self.levels - 1

        # Phase 1: redo forward's split/squeeze so each level's factored-out
        # half and its (pre-squeeze) mask can be re-attached later.
        saved_halves = []
        saved_masks = []
        cur = input
        for level in range(n_splits):
            kept, factored = split(cur, self.blocks[level].z_features)
            saved_halves.append(factored)
            saved_masks.append(tgt_mask)
            cur, tgt_mask = squeeze(kept, tgt_mask)

        # Phase 2: walk the blocks deepest-first. The deepest block consumes
        # the fully-squeezed tensor directly; every shallower one first gets
        # its stashed half (and mask) back.
        total_logdet = input.new_zeros(input.size(0))
        for depth, block in enumerate(reversed(self.blocks)):
            if depth > 0:
                factored = saved_halves.pop()
                tgt_mask = saved_masks.pop()
                widened = unsqueeze(cur)
                cur = unsplit([widened, factored])
            cur, step_logdet = block.backward(cur, tgt_mask, src, src_mask)
            total_logdet = total_logdet + step_logdet

        # Every stashed half and mask must have been consumed exactly once.
        assert not saved_halves
        assert not saved_masks

        return cur, total_logdet
예제 #2
0
파일: nmt.py 프로젝트: juheeuu/flowseq
    def init(self,
             data: torch.Tensor,
             tgt_mask: torch.Tensor,
             src: torch.Tensor,
             src_mask: torch.Tensor,
             init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Call every block's ``init`` on ``data``, using the multi-scale layout.

        This follows the same level structure as ``forward``, but invokes
        ``block.init(..., init_scale=init_scale)`` instead of ``block.forward``.
        After each block with index below ``self.levels - 1``, the result is
        split at ``block.z_features``. One half is parked and the other is
        squeezed, together with ``tgt_mask``, for the next block. Afterwards
        the parked halves are re-attached, most recent first.

        Args:
            data: tensor fed to the first block. ``data.size(0)`` sizes the
                log-determinant accumulator.
            tgt_mask: mask for ``data``; squeezed alongside it at each split.
            src: passed unchanged to every ``block.init``.
            src_mask: passed unchanged to every ``block.init``.
            init_scale: forwarded as a keyword to every ``block.init``.
                Its meaning is defined by the blocks (not visible here).

        Returns:
            ``(out, logdet)``: the re-assembled output and the sum of the
            per-block log-determinants, accumulated onto
            ``data.new_zeros(data.size(0))``.
        """
        n_split = self.levels - 1
        total_logdet = data.new_zeros(data.size(0))
        h = data
        parked = []

        for level, block in enumerate(self.blocks):
            h, level_logdet = block.init(
                h, tgt_mask, src, src_mask, init_scale=init_scale)
            total_logdet = total_logdet + level_logdet
            if level < n_split:
                # Park one half; the other continues to the next level squeezed.
                h_keep, h_out = split(h, block.z_features)
                parked.append(h_out)
                h, tgt_mask = squeeze(h_keep, tgt_mask)

        # Undo the squeeze/split pairs, deepest level first.
        for _ in range(n_split):
            h_out = parked.pop()
            h = unsplit([unsqueeze(h), h_out])
        assert not parked
        return h, total_logdet
예제 #3
0
파일: nmt.py 프로젝트: juheeuu/flowseq
    def forward(self, input: torch.Tensor, tgt_mask: torch.Tensor,
                src: torch.Tensor,
                src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run ``input`` through every block, factoring out part of it per level.

        After each block whose index is below ``self.levels - 1``, the result
        is split at that block's ``z_features``. One half is stashed and the
        other is squeezed, together with ``tgt_mask``, before feeding the next
        block. Once all blocks have run, the stashed halves are re-attached in
        reverse order (unsqueeze, then unsplit).

        Args:
            input: tensor fed to the first block. ``input.size(0)`` sizes the
                log-determinant accumulator.
            tgt_mask: mask for ``input``; squeezed alongside it at each split.
            src: passed unchanged to every ``block.forward``.
            src_mask: passed unchanged to every ``block.forward``.

        Returns:
            ``(out, logdet)``: the re-assembled output and the sum of the
            per-block log-determinants, accumulated onto
            ``input.new_zeros(input.size(0))``.
        """
        split_levels = self.levels - 1
        total_logdet = input.new_zeros(input.size(0))
        hidden = input
        stash = []

        for idx, block in enumerate(self.blocks):
            hidden, block_logdet = block.forward(
                hidden, tgt_mask, src, src_mask)
            total_logdet = total_logdet + block_logdet
            if idx >= split_levels:
                continue
            # Keep one half for the next (squeezed) level; stash the other.
            kept, factored = split(hidden, block.z_features)
            stash.append(factored)
            hidden, tgt_mask = squeeze(kept, tgt_mask)

        # Re-attach stashed halves, deepest level first.
        for _ in range(split_levels):
            factored = stash.pop()
            hidden = unsplit([unsqueeze(hidden), factored])
        assert not stash
        return hidden, total_logdet