def init_like_chainer(self):
    """Initialize weight like chainer.

    chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
    pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)

    however, there are two exceptions as far as I know.
    - EmbedID.W ~ Normal(0, 1)
    - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
    """
    lecun_normal_init_parameters(self)
    # exceptions
    # embed weight ~ Normal(0, 1)
    self.dec.embed.weight.data.normal_(0, 1)
    # forget-bias = 1.0
    # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
    for i in six.moves.range(len(self.dec.decoder)):
        set_forget_bias_to_one(self.dec.decoder[i].bias_ih)
    # gs534 - lextree: optionally warm-start from a pre-trained checkpoint,
    # copying only the parameters whose names match this model's state dict
    if self.init_from is not None:
        model_init = torch.load(self.init_from, map_location=lambda storage, loc: storage)
        model_init = model_init.state_dict() if not isinstance(model_init, dict) else model_init
        own_state = self.state_dict()
        for name, param in model_init.items():
            if name in own_state:
                own_state[name].copy_(param.data)
def init_like_chainer(self):
    """Initialize weight like chainer.

    chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
    pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)

    however, there are two exceptions as far as I know.
    - EmbedID.W ~ Normal(0, 1)
    - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
    """
    lecun_normal_init_parameters(self)
    # exceptions
    # embed weight ~ Normal(0, 1)
    self.dec.embed.weight.data.normal_(0, 1)
    # forget-bias = 1.0
    # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
    for i in six.moves.range(len(self.dec.decoder)):
        set_forget_bias_to_one(self.dec.decoder[i].bias_ih)
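# The two helpers used above are defined elsewhere in the codebase. A minimal
# sketch of what lecun_normal_init_parameters and set_forget_bias_to_one are
# assumed to do, based on the docstring (W ~ Normal(0, fan_in ** -0.5), b = 0)
# and on the standard torch.nn.LSTM gate layout (input, forget, cell, output),
# might look like this; it is an illustration, not the project's confirmed
# implementation.
import math

import torch


def lecun_normal_init_parameters(module):
    """Initialize weights with LeCun-style normal noise and biases with zero."""
    for p in module.parameters():
        data = p.data
        if data.dim() == 1:
            # bias vectors
            data.zero_()
        elif data.dim() == 2:
            # linear weights: fan_in is the second dimension
            data.normal_(0, 1.0 / math.sqrt(data.size(1)))
        elif data.dim() in (3, 4):
            # conv weights: fan_in is in_channels times the kernel size
            n = data.size(1)
            for k in data.size()[2:]:
                n *= k
            data.normal_(0, 1.0 / math.sqrt(n))
        else:
            raise NotImplementedError("unsupported parameter shape")


def set_forget_bias_to_one(bias):
    """Set the forget-gate slice of an LSTM bias vector to 1.0."""
    # torch.nn.LSTM packs the gate biases as (input, forget, cell, output),
    # so the forget gate occupies the second quarter of the vector.
    n = bias.size(0)
    start, end = n // 4, n // 2
    bias.data[start:end].fill_(1.0)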
def initializer(model, args):
    """Initialize transducer model.

    Args:
        model (torch.nn.Module): transducer instance
        args (Namespace): argument Namespace containing options

    """
    if args.dtype != 'transformer':
        if args.etype == 'transformer':
            initialize(model.encoder, args.transformer_init)
            lecun_normal_init_parameters(model.dec)
        else:
            lecun_normal_init_parameters(model)
        model.dec.embed.weight.data.normal_(0, 1)
        for l in six.moves.range(len(model.dec.decoder)):
            set_forget_bias_to_one(model.dec.decoder[l].bias_ih)
    else:
        if args.etype == 'transformer':
            initialize(model, args.transformer_init)
        else:
            lecun_normal_init_parameters(model.encoder)
            initialize(model.decoder, args.transformer_init)
def initializer(model, args):
    """Initialize transducer model.

    Args:
        model (torch.nn.Module): transducer instance
        args (Namespace): argument Namespace containing options

    """
    if "custom" not in args.dtype:
        if "custom" in args.etype:
            initialize(model.encoder, args.transformer_init)
            lecun_normal_init_parameters(model.dec)
        else:
            lecun_normal_init_parameters(model)
        model.dec.embed.weight.data.normal_(0, 1)
        for i in range(model.dec.dlayers):
            set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_ih_l0"))
            set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_hh_l0"))
    else:
        if "custom" in args.etype:
            initialize(model, args.transformer_init)
        else:
            lecun_normal_init_parameters(model.enc)
            initialize(model.decoder, args.transformer_init)
def initializer(model, args):
    """Initialize transducer model.

    Args:
        model (torch.nn.Module): transducer instance
        args (Namespace): argument Namespace containing options

    """
    if args.dtype != 'transformer':
        if args.etype == 'transformer':
            initialize(model.encoder, args.transformer_init)
            lecun_normal_init_parameters(model.decoder)
        else:
            lecun_normal_init_parameters(model)
        model.decoder.embed.weight.data.normal_(0, 1)
    else:
        if args.etype == 'transformer':
            initialize(model, args.transformer_init)
        else:
            lecun_normal_init_parameters(model.encoder)
            initialize(model.decoder, args.transformer_init)
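# Each initializer variant above delegates Transformer-style modules to
# initialize(module, args.transformer_init), whose body is not shown in this
# file. A plausible sketch, assuming transformer_init names one of PyTorch's
# built-in schemes (xavier_uniform, xavier_normal, kaiming_uniform,
# kaiming_normal), is given below; the behaviour is an assumption for
# illustration, not the project's confirmed implementation.
import torch


def initialize(model, init_type="pytorch"):
    """Apply the chosen weight-init scheme to every parameter of ``model``."""
    if init_type == "pytorch":
        # keep PyTorch's default initialization
        return
    for p in model.parameters():
        if p.dim() > 1:
            # weight matrices and convolution kernels
            if init_type == "xavier_uniform":
                torch.nn.init.xavier_uniform_(p.data)
            elif init_type == "xavier_normal":
                torch.nn.init.xavier_normal_(p.data)
            elif init_type == "kaiming_uniform":
                torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
            elif init_type == "kaiming_normal":
                torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
            else:
                raise ValueError("Unknown initialization: " + init_type)
        else:
            # bias vectors
            p.data.zero_()
    # let embeddings fall back to their own default (normal) initialization
    for m in model.modules():
        if isinstance(m, torch.nn.Embedding):
            m.reset_parameters()


# Example call, assuming ``model`` exists and the usual ESPnet-style options
# (values here are hypothetical):
#     from argparse import Namespace
#     args = Namespace(etype="transformer", dtype="lstm", transformer_init="xavier_uniform")
#     initializer(model, args)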