Example #1
import math
from argparse import Namespace

import torch

from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one


def initializer(model: torch.nn.Module, args: Namespace):
    """Initialize transducer model.

    Args:
        model: Transducer model.
        args: Namespace containing model options.

    """
    for name, p in model.named_parameters():
        if any(x in name for x in ["enc.", "dec.", "transducer_tasks."]):
            if p.dim() == 1:
                # bias
                p.data.zero_()
            elif p.dim() == 2:
                # linear weight
                n = p.size(1)
                stdv = 1.0 / math.sqrt(n)
                p.data.normal_(0, stdv)
            elif p.dim() in (3, 4):
                # conv weight
                n = p.size(1)
                for k in p.size()[2:]:
                    n *= k
                stdv = 1.0 / math.sqrt(n)
                p.data.normal_(0, stdv)

    if args.dtype != "custom":
        model.dec.embed.weight.data.normal_(0, 1)

        for i in range(model.dec.dlayers):
            set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_ih_l0"))
            set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_hh_l0"))
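The forget-gate trick above is delegated to `set_forget_bias_to_one` from ESPnet's initialization helpers. A minimal sketch of what such a helper does, assuming PyTorch's LSTM bias layout of four hidden_size-sized gate slices ordered [input, forget, cell, output]:

import torch


def set_forget_bias_to_one(bias: torch.Tensor):
    """Fill the forget-gate slice of an LSTM bias vector with 1.0."""
    n = bias.size(0)
    # The forget gate occupies the second quarter of [i, f, g, o].
    bias.data[n // 4 : n // 2].fill_(1.0)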
Example #2
import six

from espnet.nets.pytorch_backend.initialization import lecun_normal_init_parameters
from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one
from espnet.nets.pytorch_backend.transformer.initializer import initialize


def initializer(model, args):
    """Initialize transducer model.

    Args:
        model (torch.nn.Module): transducer instance
        args (Namespace): argument Namespace containing options

    """
    if args.dtype != 'transformer':
        if args.etype == 'transformer':
            initialize(model.encoder, args.transformer_init)
            lecun_normal_init_parameters(model.dec)
        else:
            lecun_normal_init_parameters(model)

        model.dec.embed.weight.data.normal_(0, 1)

        for i in six.moves.range(len(model.dec.decoder)):
            set_forget_bias_to_one(model.dec.decoder[i].bias_ih)
    else:
        if args.etype == 'transformer':
            initialize(model, args.transformer_init)
        else:
            lecun_normal_init_parameters(model.encoder)
            initialize(model.decoder, args.transformer_init)
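To make the branching above concrete, a hypothetical call site; the `dtype`, `etype`, and `transformer_init` values below are illustrative stand-ins, not taken from a real config:

from argparse import Namespace

args = Namespace(dtype="lstm", etype="blstmp", transformer_init="pytorch")
initializer(model, args)  # model: an already-constructed transducer instance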
Example #3
import math

from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one


def initializer(model, args):
    """Initialize transducer model.

    Args:
        model (torch.nn.Module): transducer instance
        args (Namespace): argument Namespace containing options

    """
    for name, p in model.named_parameters():
        if any(x in name for x in ["enc.", "dec.", "joint_network"]):
            # rnn based parts + joint network
            if p.dim() == 1:
                # bias
                p.data.zero_()
            elif p.dim() == 2:
                # linear weight
                n = p.size(1)
                stdv = 1.0 / math.sqrt(n)
                p.data.normal_(0, stdv)
            elif p.dim() in (3, 4):
                # conv weight
                n = p.size(1)
                for k in p.size()[2:]:
                    n *= k
                stdv = 1.0 / math.sqrt(n)
                p.data.normal_(0, stdv)

    if args.dtype != "custom":
        model.dec.embed.weight.data.normal_(0, 1)

        for i in range(model.dec.dlayers):
            set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_ih_l0"))
            set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_hh_l0"))
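The conv branch treats fan-in as in_channels times the kernel volume. A quick check of that arithmetic on a hypothetical 64x3x3x3 conv weight:

import math

import torch

w = torch.empty(64, 3, 3, 3)  # (out_channels, in_channels, kh, kw)
n = w.size(1)
for k in w.size()[2:]:
    n *= k
assert n == 27  # fan-in = 3 * 3 * 3
w.normal_(0, 1.0 / math.sqrt(n))  # stdv ~= 0.19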
Example #4
from espnet.nets.pytorch_backend.initialization import lecun_normal_init_parameters
from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one
from espnet.nets.pytorch_backend.transformer.initializer import initialize


def initializer(model, args):
    """Initialize transducer model.

    Args:
        model (torch.nn.Module): transducer instance
        args (Namespace): argument Namespace containing options

    """
    if "custom" not in args.dtype:
        if "custom" in args.etype:
            initialize(model.encoder, args.transformer_init)
            lecun_normal_init_parameters(model.dec)
        else:
            lecun_normal_init_parameters(model)

        model.dec.embed.weight.data.normal_(0, 1)

        for i in range(model.dec.dlayers):
            set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_ih_l0"))
            set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_hh_l0"))
    else:
        if "custom" in args.etype:
            initialize(model, args.transformer_init)
        else:
            lecun_normal_init_parameters(model.enc)
            initialize(model.decoder, args.transformer_init)
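Examples #2 and #4 (and both init_like_chainer variants below) lean on `lecun_normal_init_parameters`. A sketch of such a helper, mirroring the per-parameter loop of Examples #1 and #3 (zero biases, weights ~ Normal(0, fan_in ** -0.5)):

import math

import torch


def lecun_normal_init_parameters(module: torch.nn.Module):
    """LeCun-style init: biases to zero, weights ~ Normal(0, 1 / sqrt(fan_in))."""
    for p in module.parameters():
        data = p.data
        if data.dim() == 1:
            data.zero_()  # bias
        elif data.dim() == 2:
            data.normal_(0, 1.0 / math.sqrt(data.size(1)))  # linear weight
        elif data.dim() in (3, 4):
            n = data.size(1)  # conv weight: fan-in = in_channels * kernel volume
            for k in data.size()[2:]:
                n *= k
            data.normal_(0, 1.0 / math.sqrt(n))
        else:
            raise NotImplementedError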
Example #5
    def init_like_chainer(self):
        """Initialize weight like chainer.

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """
        lecun_normal_init_parameters(self)
        # exceptions
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
        # forget-bias = 1.0
        # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
        for i in six.moves.range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[i].bias_ih)
        # gs534 - lextree
        if self.init_from is not None:
            model_init = torch.load(self.init_from, map_location=lambda storage, loc: storage)
            model_init = model_init.state_dict() if not isinstance(model_init, dict) else model_init
            own_state = self.state_dict()
            for name, param in model_init.items():
                if name in own_state:
                    own_state[name].copy_(param.data)
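The `init_from` block above copies matching tensors by hand. A terser alternative with the same effect, assuming the checkpoint's key names and shapes agree with the model's, is `load_state_dict(..., strict=False)`, which also reports what was skipped (the path below is a placeholder):

import torch

state = torch.load("pretrained.pt", map_location="cpu")  # placeholder path
state = state.state_dict() if not isinstance(state, dict) else state
missing, unexpected = model.load_state_dict(state, strict=False)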
Example #6
File: e2e_st.py  Project: unilight/espnet
    def init_like_chainer(self):
        """Initialize weight like chainer.

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """
        lecun_normal_init_parameters(self)
        # exceptions
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
        # forget-bias = 1.0
        # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
        for i in six.moves.range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[i].bias_ih)