Example 1
    def init_like_chainer(self):
        """Initialize weight like chainer.

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """
        lecun_normal_init_parameters(self)
        # exceptions
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
        # forget-bias = 1.0
        # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
        for i in six.moves.range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[i].bias_ih)
        # gs534 - lextree
        # optionally warm-start from a pretrained checkpoint: copy every
        # parameter whose name also exists in this model's state_dict
        if self.init_from is not None:
            model_init = torch.load(self.init_from, map_location=lambda storage, loc: storage)
            model_init = model_init.state_dict() if not isinstance(model_init, dict) else model_init
            own_state = self.state_dict()
            for name, param in model_init.items():
                if name in own_state:
                    own_state[name].copy_(param.data)
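
The helpers referenced here, lecun_normal_init_parameters and set_forget_bias_to_one, are defined elsewhere in the codebase and are not shown in these examples. Below is a minimal sketch of what they plausibly do, based only on the scheme the docstring describes (weights ~ Normal(0, fan_in ** -0.5), biases = 0, forget-gate bias = 1); the actual implementations may differ in details.

import math


def lecun_normal_init_parameters(module):
    """Sketch (assumed): LeCun-style init as described in the docstring above.

    Weight tensors get Normal(0, fan_in ** -0.5); bias vectors get 0.
    """
    for p in module.parameters():
        data = p.data
        if data.dim() == 1:
            # bias vector
            data.zero_()
        elif data.dim() == 2:
            # linear weight: fan_in is the input dimension
            n = data.size(1)
            data.normal_(0, 1.0 / math.sqrt(n))
        elif data.dim() in (3, 4):
            # convolution weight: fan_in = in_channels * prod(kernel dims)
            n = data.size(1)
            for k in data.size()[2:]:
                n *= k
            data.normal_(0, 1.0 / math.sqrt(n))


def set_forget_bias_to_one(bias):
    """Sketch (assumed): set the forget-gate slice of an LSTM bias to 1.

    PyTorch packs LSTM gate parameters as [input, forget, cell, output],
    so the forget gate occupies the second quarter of the bias vector.
    """
    n = bias.size(0)
    start, end = n // 4, n // 2
    bias.data[start:end].fill_(1.0)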
Example 2
    def init_like_chainer(self):
        """Initialize weight like chainer.

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """
        lecun_normal_init_parameters(self)
        # exceptions
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
        # forget-bias = 1.0
        # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
        for i in six.moves.range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[i].bias_ih)
Example 3
def initializer(model, args):
    """Initialize transducer model.

    Args:
        model (torch.nn.Module): transducer instance
        args (Namespace): argument Namespace containing options

    """
    if args.dtype != 'transformer':
        # non-transformer (RNN) decoder
        if args.etype == 'transformer':
            # transformer encoder paired with the RNN decoder
            initialize(model.encoder, args.transformer_init)
            lecun_normal_init_parameters(model.dec)
        else:
            # non-transformer encoder and decoder
            lecun_normal_init_parameters(model)

        model.dec.embed.weight.data.normal_(0, 1)

        for l in six.moves.range(len(model.dec.decoder)):
            set_forget_bias_to_one(model.dec.decoder[l].bias_ih)
    else:
        # transformer decoder
        if args.etype == 'transformer':
            # transformer encoder and decoder
            initialize(model, args.transformer_init)
        else:
            # non-transformer encoder with a transformer decoder
            lecun_normal_init_parameters(model.encoder)
            initialize(model.decoder, args.transformer_init)
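
initialize(model, args.transformer_init) refers to a transformer-style initializer that is likewise not shown in these examples. As a rough sketch, assuming it simply dispatches on an init-type string (e.g. 'xavier_uniform'), zeroes biases, and resets LayerNorm modules, it might look like the following; the real helper may behave differently.

import torch


def initialize(model, init_type="pytorch"):
    """Sketch (assumed): dispatch parameter init on an init-type string."""
    if init_type == "pytorch":
        return  # keep PyTorch's default initialization
    for p in model.parameters():
        if p.dim() > 1:
            # weight matrices and convolution kernels
            if init_type == "xavier_uniform":
                torch.nn.init.xavier_uniform_(p)
            elif init_type == "xavier_normal":
                torch.nn.init.xavier_normal_(p)
            elif init_type == "kaiming_uniform":
                torch.nn.init.kaiming_uniform_(p, nonlinearity="relu")
            elif init_type == "kaiming_normal":
                torch.nn.init.kaiming_normal_(p, nonlinearity="relu")
            else:
                raise ValueError("Unknown initialization: " + init_type)
        else:
            # bias vectors
            p.data.zero_()
    # reset LayerNorm modules to their defaults (weight = 1, bias = 0)
    for m in model.modules():
        if isinstance(m, torch.nn.LayerNorm):
            m.reset_parameters()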
Example 4
def initializer(model, args):
    """Initialize transducer model.

    Args:
        model (torch.nn.Module): transducer instance
        args (Namespace): argument Namespace containing options

    """
    if "custom" not in args.dtype:
        if "custom" in args.etype:
            initialize(model.encoder, args.transformer_init)
            lecun_normal_init_parameters(model.dec)
        else:
            lecun_normal_init_parameters(model)

        model.dec.embed.weight.data.normal_(0, 1)

        for i in range(model.dec.dlayers):
            set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_ih_l0"))
            set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_hh_l0"))
    else:
        if "custom" in args.etype:
            initialize(model, args.transformer_init)
        else:
            lecun_normal_init_parameters(model.enc)
            initialize(model.decoder, args.transformer_init)
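
This variant writes 1 into both bias_ih_l0 and bias_hh_l0. The naming follows torch.nn.LSTM, which keeps two bias vectors per layer and adds them inside the gate computation, so setting the forget-gate slice of both to 1 yields an effective forget bias of 2. A quick illustrative check (not from the source):

import torch

# a single-layer LSTM exposes exactly these four parameters
lstm = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=1)
print([name for name, _ in lstm.named_parameters()])
# ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0']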
Example 5
def initializer(model, args):
    """Initialize transducer model.

    Args:
        model (torch.nn.Module): transducer instance
        args (Namespace): argument Namespace containing options

    """
    if args.dtype != 'transformer':
        if args.etype == 'transformer':
            initialize(model.encoder, args.transformer_init)
            lecun_normal_init_parameters(model.decoder)
        else:
            lecun_normal_init_parameters(model)

        model.decoder.embed.weight.data.normal_(0, 1)
    else:
        if args.etype == 'transformer':
            initialize(model, args.transformer_init)
        else:
            lecun_normal_init_parameters(model.encoder)
            initialize(model.decoder, args.transformer_init)