示例#1
0
    def init_decoder(self, enc_outputs, expand_size=1):
        """Assemble the initial decoder state from the encoder outputs.

        Args:
            enc_outputs: Mapping that must provide ``'ctx'`` and ``'ctx_mask'``.
            expand_size: Replication factor handed to ``tile_batch``
                (presumably the beam width — TODO confirm). Tiling and
                capsule-cache precomputation only happen when it is > 1.

        Returns:
            dict with ``ctx`` / ``ctx_mask`` (tiled when ``expand_size > 1``),
            ``enc_attn_caches`` / ``self_attn_caches`` (always ``None``), and
            ``capsule_caches`` / ``layer_wise_capsule_caches`` (``None`` unless
            precomputed for the configured ``capsule_type``).
        """
        context, context_mask = enc_outputs['ctx'], enc_outputs['ctx_mask']
        final_caches = None
        per_layer_caches = None

        if expand_size > 1:
            context = tile_batch(context, multiplier=expand_size)
            context_mask = tile_batch(context_mask, multiplier=expand_size)

            # The capsule caches are built from the context *after* tiling.
            # NOTE(review): when expand_size <= 1 nothing is precomputed and
            # both cache entries stay None — confirm the decoder copes with that.
            if self.config["capsule_type"] == "output":
                final_caches = self.decoder.final_capsule_layer.compute_caches(
                    context)
            elif self.config["capsule_type"] == "layer-wise":
                # One cache per decoder layer, collected in layer order.
                per_layer_caches = [
                    self.decoder.block_stack[layer].capsule_layer.compute_caches(
                        context)
                    for layer in range(self.config["n_layers"])
                ]

        return {
            "ctx": context,
            "ctx_mask": context_mask,
            "enc_attn_caches": None,
            "self_attn_caches": None,
            "capsule_caches": final_caches,
            "layer_wise_capsule_caches": per_layer_caches,
        }
示例#2
0
    def init_decoder(self, enc_outputs, expand_size=1):
        """Build the initial decoding state, seeded by the decoder itself.

        ``self.decoder.init_decoder`` is called on the untiled encoder outputs.
        When ``expand_size > 1``, the context, its mask, the initial hidden
        state and the decoder caches are each passed through ``tile_batch``.

        Args:
            enc_outputs: Mapping that must provide ``'ctx'`` and ``'ctx_mask'``.
            expand_size: Replication factor for ``tile_batch`` (presumably the
                beam width — TODO confirm).

        Returns:
            dict with keys ``dec_hiddens``, ``dec_caches``, ``ctx``, ``ctx_mask``.
        """
        context = enc_outputs['ctx']
        context_mask = enc_outputs['ctx_mask']

        # The decoder's initial state is derived before any tiling happens.
        hiddens, caches = self.decoder.init_decoder(context=context, mask=context_mask)

        if expand_size > 1:
            # Tile in the same order as before: ctx, ctx_mask, hiddens, caches.
            context, context_mask, hiddens, caches = (
                tile_batch(item, expand_size)
                for item in (context, context_mask, hiddens, caches)
            )

        return {
            "dec_hiddens": hiddens,
            "dec_caches": caches,
            "ctx": context,
            "ctx_mask": context_mask,
        }
示例#3
0
    def init_decoder(self, enc_outputs, expand_size=1):
        """Return the starting decoder state for a transformer-style decoder.

        Args:
            enc_outputs: Mapping that must provide ``'ctx'`` and ``'ctx_mask'``.
            expand_size: Replication factor for ``tile_batch`` (presumably the
                beam width — TODO confirm); tiling only happens when > 1.

        Returns:
            dict with ``ctx`` and ``ctx_mask`` plus ``enc_attn_caches`` and
            ``slf_attn_caches``, both initialised to ``None``.
        """
        context, context_mask = enc_outputs['ctx'], enc_outputs['ctx_mask']

        if expand_size > 1:
            context, context_mask = (
                tile_batch(context, multiplier=expand_size),
                tile_batch(context_mask, multiplier=expand_size),
            )

        # NOTE(review): the self-attention key is spelled "slf_attn_caches"
        # here; it must match whatever key the decoder reads back.
        return dict(
            ctx=context,
            ctx_mask=context_mask,
            enc_attn_caches=None,
            slf_attn_caches=None,
        )
示例#4
0
    def init_decoder(self, enc_outputs, expand_size=1):
        """Build the initial decoding state, including an optional capsule.

        ``self.decoder.init_decoder`` is called on the untiled encoder outputs
        and must return ``(hiddens, caches, capsule)``; ``capsule`` may be
        ``None``. When ``expand_size > 1``, every piece is passed through
        ``tile_batch``, except a ``None`` capsule, which is left as is.

        Args:
            enc_outputs: Mapping that must provide ``'ctx'`` and ``'ctx_mask'``.
            expand_size: Replication factor for ``tile_batch`` (presumably the
                beam width — TODO confirm).

        Returns:
            dict with keys ``dec_hiddens``, ``dec_caches``, ``ctx``,
            ``ctx_mask``, ``dec_capsule`` and ``routing_weights`` (always
            ``None``).
        """
        context = enc_outputs['ctx']
        context_mask = enc_outputs['ctx_mask']

        hiddens, caches, capsule = self.decoder.init_decoder(
            context=context, mask=context_mask)

        if expand_size > 1:
            context = tile_batch(context, expand_size)
            context_mask = tile_batch(context_mask, expand_size)
            hiddens = tile_batch(hiddens, expand_size)
            caches = tile_batch(caches, expand_size)
            # The capsule state is optional; only tile it when it exists.
            capsule = capsule if capsule is None else tile_batch(capsule, expand_size)

        return {
            "dec_hiddens": hiddens,
            "dec_caches": caches,
            "ctx": context,
            "ctx_mask": context_mask,
            "dec_capsule": capsule,
            "routing_weights": None,
        }