Beispiel #1
0
    def add_value(self, key: str, value, round_digits: Optional[int] = None):
        """
        Add a column value to the current step.

        PyTorch tensors are detached, moved to the CPU, and converted to numpy arrays. Single-element lists, tensors,
        and arrays are unwrapped to a scalar; multi-element tensors and arrays (of any dimensionality) are flattened
        into a Python list. Lists with more than one element are stored as given.

        :param key: data key, the logger's prefix string is prepended to it
        :param value: value to record, pass '' to print nothing
        :param round_digits: digits to round to, pass `None` (default) for no rounding. Rounding is applied to floats
                             (including numpy floating scalars such as `np.float32`), numpy arrays, and tensors.
        :raises pyrado.TypeErr: if `key` is not a string, or `round_digits` is neither `None` nor an `int`
        :raises pyrado.KeyErr: if, after the first step, a key is used which was not recorded during the first step
        """
        if not isinstance(key, str):
            raise pyrado.TypeErr(given=key, expected_type=str)
        if round_digits is not None and not isinstance(round_digits, int):
            raise pyrado.TypeErr(given=round_digits, expected_type=int)

        # Compute full prefixed key
        key = self._prefix_str + key

        if self._first_step:
            # Record new keys during the first step, but only once, so that repeated calls with the same key do not
            # register duplicate columns
            if key not in self._value_keys:
                self._value_keys.append(key)
        elif key not in self._value_keys:
            # Make sure the key was used during first step
            raise pyrado.KeyErr(
                msg=
                "New value keys may only be added before the first step is finished"
            )

        # Unwrap single-element lists
        if isinstance(value, list) and len(value) == 1:
            value = value[0]

        # Pre-process PyTorch tensors and numpy arrays (the same way)
        if isinstance(value, to.Tensor):
            value = value.detach().cpu().numpy()
        if isinstance(value, np.ndarray):
            if round_digits is not None:
                value = np.round(value, round_digits)
            if value.size == 1:
                # Scalar (this also covers 0-dim arrays)
                value = value.item()
            else:
                # Arrays of any dimensionality are flattened, hence the result is always a flat list
                value = value.flatten().tolist()
        # Pre-process floats, np.floating also covers numpy scalars like np.float32 which do not subclass float
        elif isinstance(value, (float, np.floating)):
            if round_digits is not None:
                value = round(value, round_digits)

        # Record value
        self._current_values[key] = value
        self._values_changed = True
Beispiel #2
0
    def adapt(self, domain_distr_param: str,
              domain_distr_param_value: Union[float, int, to.Tensor]):
        """
        Overwrite one distribution parameter (e.g. mean or std) of this domain parameter.

        .. note::
            This function should be called by the subclasses' `adapt()` function.

        :param domain_distr_param: name of the distribution parameter to update, must be one of `get_field_names()`
        :param domain_distr_param_value: new value of the distribution parameter
        :raises pyrado.KeyErr: if this domain parameter has no distribution parameter named `domain_distr_param`
        """
        known_fields = self.get_field_names()
        if domain_distr_param in known_fields:
            # Valid field: store the new value in the attribute of the same name
            setattr(self, domain_distr_param, domain_distr_param_value)
            return

        raise pyrado.KeyErr(
            msg=f"The domain parameter {self.name} does not have a domain distribution parameter "
            f"called {domain_distr_param}!")
Beispiel #3
0
    def save_snapshot(self, meta_info: dict = None):
        """
        Save the algorithm's snapshot, the subroutine's domain distribution parameter policies, the subroutine's
        simulation environment, and the real-world rollouts.

        .. note::
            As a side effect, the subroutine environment's randomizer is left set to the best fitted domain
            parameter distribution.

        :param meta_info: meta information, must contain the key `rollouts_real`; the optional key `prefix`
                          (default "") additionally triggers an iteration-specific snapshot of the subroutine
        :raises pyrado.KeyErr: if `meta_info` is `None` or does not contain the key `rollouts_real`
        """
        # Validate before writing anything, so that invalid input does not leave a partially written snapshot behind
        if meta_info is None:
            raise pyrado.KeyErr(
                msg="The meta_info passed to save_snapshot() must contain the key rollouts_real, but it is None!")
        if "rollouts_real" not in meta_info:
            raise pyrado.KeyErr(keys="rollouts_real", container=meta_info)

        super().save_snapshot(meta_info)

        # ParameterExploring subroutine saves the best policy (in this case a DomainDistrParamPolicy)
        prefix = meta_info.get("prefix", "")
        if prefix != "":
            self._subrtn.save_snapshot(meta_info=dict(prefix=f"{prefix}_ddp"))  # save iter_X_ddp_policy.pt
        self._subrtn.save_snapshot(meta_info=dict(prefix="ddp"))  # override ddp_policy.pt

        joblib.dump(self._subrtn.env, osp.join(self.save_dir, "env_sim.pkl"))

        def set_randomizer(policy_param, caption: str):
            # Map the policy parameters to domain distribution parameters, apply them to the randomizer, and print it
            ddp = self._subrtn.policy.transform_to_ddp_space(policy_param)
            self._subrtn.env.adapt_randomizer(domain_distr_param_values=ddp.detach().cpu().numpy())
            print_cbt(f"{caption}\n{self._subrtn.env.randomizer}", "g")

        # Print the distribution given by the subroutine policy's current parameters
        set_randomizer(self._subrtn.policy.param_values, "Current policy domain parameter distribution")

        # Set the randomizer to the best fitted domain distribution (this state is kept after returning)
        set_randomizer(self._subrtn.best_policy_param, "Best fitted domain parameter distribution")

        pyrado.save(meta_info["rollouts_real"], "rollouts_real.pkl", self.save_dir, prefix=prefix)
    def adapt_one_distr_param(self, domain_param_name: str,
                              domain_distr_param: str,
                              domain_distr_param_value: Union[float, int]):
        """
        Update the randomizer's domain parameter distribution for one domain parameter.

        Every domain parameter whose name equals `domain_param_name` is updated. If there is none, nothing changes.

        :param domain_param_name: name of the domain parameter whose distribution parameter should be updated
        :param domain_distr_param: distribution parameter to update, e.g. mean or std
        :param domain_distr_param_value: new value of the distribution parameter
        :raises pyrado.KeyErr: if the domain parameter does not have a distribution parameter `domain_distr_param`
        :raises pyrado.TypeErr: if `domain_distr_param_value` is not an `int`, `float`, or `bool`
        """
        for dp in self.domain_params:
            if dp.name != domain_param_name:
                continue

            if domain_distr_param not in dp.get_field_names():
                raise pyrado.KeyErr(
                    msg=
                    f"The domain parameter {dp.name} does not have a domain distribution parameter "
                    f"called {domain_distr_param}!")

            # Reject non-numeric values before setting the new value
            if not isinstance(domain_distr_param_value, (int, float, bool)):
                raise pyrado.TypeErr(given=domain_distr_param_value,
                                     expected_type=[int, float, bool])
            dp.adapt(domain_distr_param, domain_distr_param_value)
Beispiel #5
0
    def train_policy_sim(self,
                         domain_params: to.Tensor,
                         prefix: str,
                         cnt_rep: int,
                         use_rec_init_states: bool = True) -> float:
        """
        Train a policy in simulation using the subroutine, with the given domain parameters written into the
        simulation training environment's domain parameter buffer.

        Side effects: increments `self._cnt_samples` by the subroutine's current sample count, resets the subroutine,
        re-initializes its sampler with the training environment, and (optionally) overwrites the initial state space
        of the inner simulation environment.

        :param domain_params: domain parameters sampled from the posterior [shape N x D where N is the number of
                              samples and D is the number of domain parameters]
        :param prefix: set a prefix to the saved file name, use "" for no prefix
        :param cnt_rep: current repetition count, coming from the wrapper function; the warm start is only used
                        when this is 0
        :param use_rec_init_states: if `True`, the real-world rollouts are loaded from `rollouts_real.pkl` (with the
                                    given prefix) in the save directory, and the simulation environment's initial
                                    state space is set to a discrete space of their first states
        :return: average return of the trained policy, evaluated in the simulation training environment on the same
                 buffered domain parameters it was trained on
        :raises pyrado.ShapeErr: if `domain_params` is not 2-dim with `len(self.dp_mapping)` columns, or if the
                                 recorded initial states do not match the simulation's state dimension
        :raises pyrado.KeyErr: if the policy subroutine has no `sampler` attribute
        """
        if not (domain_params.ndim == 2
                and domain_params.shape[1] == len(self.dp_mapping)):
            raise pyrado.ShapeErr(given=domain_params,
                                  expected_match=(-1, len(self.dp_mapping)))

        # Insert the domain parameters into the wrapped environment's buffer
        self.fill_domain_param_buffer(self._env_sim_trn, self.dp_mapping,
                                      domain_params)

        # Set the initial state spaces of the simulation environment to match the observed initial states
        if use_rec_init_states:
            rollouts_real = pyrado.load("rollouts_real.pkl",
                                        self._save_dir,
                                        prefix=prefix)
            # One row per real rollout, containing its first state
            init_states_real = np.stack(
                [ro.states[0, :] for ro in rollouts_real])
            if not init_states_real.shape == (
                    len(rollouts_real),
                    self._env_sim_trn.state_space.flat_dim):
                raise pyrado.ShapeErr(
                    given=init_states_real,
                    expected_match=(len(rollouts_real),
                                    self._env_sim_trn.state_space.flat_dim))
            inner_env(
                self._env_sim_trn).init_space = DiscreteSpace(init_states_real)
            print_cbt(
                "The simulation environment's initial states have been set to the recorded ones.",
                "w")

        # Reset the subroutine algorithm which includes resetting the exploration.
        # The subroutine's sample count is accumulated first, presumably because reset() clears it (TODO confirm).
        self._cnt_samples += self._subrtn_policy.sample_count
        self._subrtn_policy.reset()

        # Propagate the updated training environment to the SamplerPool's workers
        if hasattr(self._subrtn_policy, "sampler"):
            self._subrtn_policy.sampler.reinit(env=self._env_sim_trn)
        else:
            raise pyrado.KeyErr(keys="sampler", container=self._subrtn_policy)

        # Do a warm start, but randomly reset the policy parameters if training failed once
        # (presumably cnt_rep > 0 means this is a repetition after a failed training run — confirm with the caller)
        self._subrtn_policy.init_modules(self.warmstart and cnt_rep == 0)

        # Train a policy in simulation using the subroutine
        self._subrtn_policy.train(
            snapshot_mode=self._subrtn_policy_snapshot_mode,
            meta_info=dict(prefix=prefix))

        # Return the estimated return of the trained policy in simulation
        # NOTE(review): this assert is an internal invariant check and is stripped under `python -O`
        assert len(self._env_sim_trn.buffer) == self.num_eval_samples
        self._env_sim_trn.ring_idx = 0  # don't reset the buffer to eval on the same domains as trained
        avg_ret_sim = self.eval_policy(None, self._env_sim_trn,
                                       self._subrtn_policy.policy, prefix,
                                       self.num_eval_samples)
        return float(avg_ret_sim)
Beispiel #6
0
    num_cand = cands.shape[0]  # number of samples i.e. iterations of BayRn (including init phase)
    dim_cand = cands.shape[1]  # number of domain distribution parameters
    print_cbt(
        f'Found {num_cand} candidates of dimension {dim_cand}:\n{cands.detach().cpu().numpy()}',
        'w')
    # Presumably every domain parameter is described by two distribution parameters — TODO confirm
    if dim_cand % 2 != 0:
        raise pyrado.ShapeErr(
            msg='The dimension of domain distribution parameters must be a multiple of 2!')

    # Load the experiment's hyper-parameters and make sure the required keys exist
    hparams = load_dict_from_yaml(osp.join(ex_dir, 'hyperparams.yaml'))
    if 'algo_name' not in hparams:
        raise pyrado.KeyErr(keys='algo_name', container=hparams)
    if 'dp_map' not in hparams:
        raise pyrado.KeyErr(keys='dp_map', container=hparams)

    # Process algorithms differently
    if hparams['algo_name'] == BayRn.name:
        try:
            num_init_cand = hparams['algo']['num_init_cand']
        except KeyError as err:
            raise KeyError(
                'There was no num_init_cand key in the hyperparams.yaml file! '
                'Are you sure you loaded a BayRn experiment?') from err
        if not args.load_all:
            # Remove the initial candidates
            cands = cands[num_init_cand:, :]
            # cands_values = cands_values[num_init_cand:, :]
            num_cand -= num_init_cand
Beispiel #7
0
def draw_curve(plot_type: str,
               ax: plt.Axes,
               data: pd.DataFrame,
               x_grid: Union[list, np.ndarray, to.Tensor],
               x_label: Optional[Union[str, Sequence[str]]] = None,
               y_label: Optional[str] = None,
               curve_label: Optional[str] = None,
               area_label: Optional[str] = None,
               vline_level: Optional[float] = None,
               vline_label: str = 'approx. solved',
               title: Optional[str] = None,
               show_legend: bool = True,
               legend_kwargs: dict = None,
               plot_kwargs: dict = None) -> plt.Figure:
    """
    Draw the `mean` column of `data` as a curve over `x_grid`, on top of a shaded area which spans either
    +/- 2 standard deviations around the mean (`plot_type='mean_std'`) or the range between the `min` and `max`
    columns (`plot_type='min_mean_max'`). The plot is neither shown nor saved.

    .. note::
        If you want to have a tight layout, it is best to pass axes of a figure with `tight_layout=True` or
        `constrained_layout=True`.

    :param plot_type: type of the shaded area, pass `'mean_std'` or `'min_mean_max'` (case-insensitive)
    :param ax: axis of the figure to plot on
    :param data: pandas DataFrame containing the column `mean`, plus `std` for `'mean_std'`, or `min` and `max`
                 for `'min_mean_max'`
    :param x_grid: values to plot the data over, e.g. time
    :param x_label: label for the x-axis, pass `None` (default) to remove the x-axis ticks instead
    :param y_label: label for the y-axis, pass `None` to set no label
    :param curve_label: label of the (1-dim) mean curve
    :param area_label: label of the (transparent) area, a default depending on `plot_type` is used if `None`
    :param vline_level: if not `None` (default), add a horizontal dashed line at this y-value, e.g. a threshold
    :param vline_label: label for the horizontal threshold line
    :param title: title displayed above the figure, set to None to suppress the title
    :param show_legend: if `True` the legend is shown, useful when handling multiple subplots
    :param legend_kwargs: keyword arguments forwarded to matplotlib's `Axes.legend()` function, e.g. `loc='best'`
    :param plot_kwargs: keyword arguments forwarded to matplotlib's `Axes.fill_between()` and `Axes.plot()`; the
                        area uses `alpha=0.3` unless overwritten, the mean curve is always drawn with `alpha=1`
    :return: handle to the figure (or subfigure) containing `ax`
    :raises pyrado.ValueErr: if `plot_type` is not supported
    :raises pyrado.TypeErr: if `data` is not a DataFrame, or `x_label` / `y_label` are neither `None` nor strings
    :raises pyrado.KeyErr: if `data` lacks the columns required by `plot_type`
    """
    plot_type = plot_type.lower()
    if plot_type not in ['mean_std', 'min_mean_max']:
        raise pyrado.ValueErr(given=plot_type,
                              eq_constraint='mean_std or min_mean_max')
    if not isinstance(data, pd.DataFrame):
        raise pyrado.TypeErr(given=data, expected_type=pd.DataFrame)
    if x_label is not None and not isinstance(x_label, str):
        raise pyrado.TypeErr(given=x_label, expected_type=str)
    if y_label is not None and not isinstance(y_label, str):
        raise pyrado.TypeErr(given=y_label, expected_type=str)

    # Set defaults which can be overwritten by passing plot_kwargs
    plot_kwargs = merge_dicts([dict(alpha=0.3), plot_kwargs])
    legend_kwargs = dict() if legend_kwargs is None else legend_kwargs

    # Preprocess
    if isinstance(x_grid, list):
        x_grid = np.array(x_grid)
    elif isinstance(x_grid, to.Tensor):
        x_grid = x_grid.detach().cpu().numpy()

    # Plot the shaded area
    if plot_type == 'mean_std':
        if not ('mean' in data.columns and 'std' in data.columns):
            raise pyrado.KeyErr(keys="'mean' and 'std'", container=data)
        num_stds = 2
        if area_label is None:
            area_label = rf'$\pm {num_stds}$ std'
        ax.fill_between(x_grid,
                        data['mean'] - num_stds * data['std'],
                        data['mean'] + num_stds * data['std'],
                        label=area_label,
                        **plot_kwargs)

    elif plot_type == 'min_mean_max':
        if not ('mean' in data.columns and 'min' in data.columns
                and 'max' in data.columns):
            raise pyrado.KeyErr(keys="'mean' and 'min' and 'max'",
                                container=data)
        if area_label is None:
            area_label = r'min \& max'
        ax.fill_between(x_grid,
                        data['min'],
                        data['max'],
                        label=area_label,
                        **plot_kwargs)

    # Plot the mean last for proper z-ordering. The curve is opaque; use a copy to not alter the kwargs dict.
    curve_kwargs = dict(plot_kwargs, alpha=1)
    ax.plot(x_grid, data['mean'], label=curve_label, **curve_kwargs)

    # Postprocess
    if vline_level is not None:
        # Add a horizontal dashed line to mark a threshold
        ax.axhline(vline_level, c='k', ls='--', lw=1., label=vline_label)

    if x_label is None:
        ax.get_xaxis().set_ticks([])
    else:
        ax.set_xlabel(x_label)

    if y_label is not None:
        ax.set_ylabel(y_label)

    if show_legend:
        ax.legend(**legend_kwargs)

    if title is not None:
        ax.set_title(title)

    # Return the figure that owns `ax`, which is not necessarily the current pyplot figure
    return ax.get_figure()
Beispiel #8
0
    def step(self, snapshot_mode: str, meta_info: dict = None):
        """
        Do one iteration of system identification via episodic RL.

        For every set of policy parameters (i.e. domain distribution parameters) sampled by the subroutine's
        exploration strategy:
        - the randomizer is adapted,
        - simulated rollouts are sampled starting from the real rollouts' initial states,
        - each simulated rollout gets the negative loss w.r.t. its real counterpart written into its first reward
          (all other rewards are set to zero).
        Afterwards, loss statistics are logged, the best parameter set is stored, a snapshot is made, and the
        subroutine is updated.

        :param snapshot_mode: snapshot mode forwarded to `make_snapshot()`
        :param meta_info: meta information, must contain the key `rollouts_real` holding the real-world rollouts
        :raises pyrado.KeyErr: if `meta_info` is `None` or does not contain the key `rollouts_real`
        :raises pyrado.ValueErr: if all losses computed for one real rollout are zero
        """
        if meta_info is None:
            raise pyrado.KeyErr(
                msg="The meta_info passed to step() must contain the key rollouts_real, but it is None!")
        if "rollouts_real" not in meta_info:
            raise pyrado.KeyErr(keys="rollouts_real", container=meta_info)

        # Extract the initial states from the real rollouts
        rollouts_real = meta_info["rollouts_real"]
        init_states_real = [ro.states[0, :] for ro in rollouts_real]

        # Sample new policy parameters a.k.a domain distribution parameters
        param_sets = self._subrtn.expl_strat.sample_param_sets(
            nominal_params=self._subrtn.policy.param_values,
            num_samples=self._subrtn.pop_size,
            include_nominal_params=True,
        )

        # Iterate over every domain parameter distribution. We basically mimic the ParameterExplorationSampler here,
        # but we need to adapt the randomizer (and not just the domain parameters) for every policy param set
        param_samples = []
        loss_hist = []
        for ps in param_sets:
            # Update the randomizer to use the new domain distribution parameters, then sample domain parameters
            new_ddp_vals = self._subrtn.policy.transform_to_ddp_space(ps)
            self._subrtn.env.adapt_randomizer(
                domain_distr_param_values=new_ddp_vals.detach().cpu().numpy())
            self._subrtn.env.randomizer.randomize(
                num_samples=self.num_rollouts_per_distr)
            sampled_domain_params = self._subrtn.env.randomizer.get_params()

            # Sample the rollouts
            rollouts_sim = self.behavior_sampler.sample(init_states_real,
                                                        sampled_domain_params,
                                                        eval=True)

            # Iterate over the batches of simulated rollouts which share the initial state of one real rollout
            for idx_real, idcs_sim in enumerate(
                    gen_ordered_batch_idcs(self.num_rollouts_per_distr,
                                           len(rollouts_sim),
                                           sorted=True)):
                # Truncate the rollouts yielding two lists of pairwise equally long rollouts
                ros_real_tr, ros_sim_tr = self.truncate_rollouts(
                    [rollouts_real[idx_real]],
                    rollouts_sim[slice(idcs_sim[0], idcs_sim[-1] + 1)])

                # Check the validity of the initial states. The domain parameters will be different.
                assert len(ros_real_tr) == len(ros_sim_tr) == len(idcs_sim)
                assert check_all_equal([ro.states[0, :] for ro in ros_real_tr])
                assert check_all_equal([ro.states[0, :] for ro in ros_sim_tr])
                assert all([
                    np.allclose(r.states[0, :], s.states[0, :])
                    for r, s in zip(ros_real_tr, ros_sim_tr)
                ])

                # Compute the losses
                losses = np.asarray([
                    self.loss_fcn(ro_r, ro_s)
                    for ro_r, ro_s in zip(ros_real_tr, ros_sim_tr)
                ])

                if np.all(losses == 0.0):
                    raise pyrado.ValueErr(
                        msg=
                        "All SysIdViaEpisodicRL losses are equal to zero! Most likely the domain "
                        "randomization is too extreme, such that every trajectory is done after "
                        "one step. Check the exploration strategy.")

                # Handle zero losses by setting them to the maximum current loss
                losses[losses == 0] = np.max(losses)
                loss_hist.extend(losses)

                # We need to assign the loss value to the simulated rollout, but this one can be of a different
                # length than the real-world rollouts as well as of different length than the original
                # (non-truncated) simulated rollout. Thus, we simply write the loss value into the first step.
                for i, loss in zip(range(idcs_sim[0], idcs_sim[-1] + 1), losses):
                    rollouts_sim[i].rewards[:] = 0.0
                    rollouts_sim[i].rewards[0] = -loss

            # Collect the results
            param_samples.append(
                ParameterSample(params=ps, rollouts=rollouts_sim))

        # Bind the parameter samples and their rollouts in the usual container
        param_samp_res = ParameterSamplingResult(param_samples)
        self._cnt_samples += sum(
            len(ro) for pss in param_samp_res for ro in pss.rollouts)

        # Log metrics computed from the old policy (before the update)
        loss_hist = np.asarray(loss_hist)
        self.logger.add_value("min sysid loss", np.min(loss_hist), 6)
        self.logger.add_value("median sysid loss", np.median(loss_hist), 6)
        self.logger.add_value("avg sysid loss", np.mean(loss_hist), 6)
        self.logger.add_value("max sysid loss", np.max(loss_hist), 6)
        self.logger.add_value("std sysid loss", np.std(loss_hist), 6)

        # Extract the best policy parameter sample for saving it later
        self._subrtn.best_policy_param = param_samp_res.parameters[np.argmax(
            param_samp_res.mean_returns)].clone()

        # Save snapshot data
        self.make_snapshot(snapshot_mode,
                           float(np.max(param_samp_res.mean_returns)),
                           meta_info)

        # Update the wrapped algorithm's update method
        self._subrtn.update(
            param_samp_res,
            ret_avg_curr=param_samp_res[0].mean_undiscounted_return)
Beispiel #9
0
    num_cand = cands.shape[0]  # number of samples i.e. iterations of BayRn (including init phase)
    dim_cand = cands.shape[1]  # number of domain distribution parameters
    print_cbt(
        f"Found {num_cand} candidates of dimension {dim_cand}:\n{cands.detach().cpu().numpy()}",
        "w")
    # Presumably every domain parameter is described by two distribution parameters — TODO confirm
    if dim_cand % 2 != 0:
        raise pyrado.ShapeErr(
            msg="The dimension of domain distribution parameters must be a multiple of 2!")

    # Load the experiment's hyper-parameters and make sure the required keys exist
    hparams = load_dict_from_yaml(osp.join(ex_dir, "hyperparams.yaml"))
    if "algo_name" not in hparams:
        raise pyrado.KeyErr(keys="algo_name", container=hparams)
    if "dp_map" not in hparams:
        raise pyrado.KeyErr(keys="dp_map", container=hparams)

    # Process algorithms differently
    if hparams["algo_name"] == BayRn.name:
        try:
            num_init_cand = hparams["algo"]["num_init_cand"]
        except KeyError as err:
            raise KeyError(
                "There was no num_init_cand key in the hyperparams.yaml file! "
                "Are you sure you loaded a BayRn experiment?") from err
        if not args.load_all:
            # Remove the initial candidates
            cands = cands[num_init_cand:, :]
            # cands_values = cands_values[num_init_cand:, :]
            num_cand -= num_init_cand
Beispiel #10
0
def init_param(m, **kwargs):
    """
    Initialize the parameters of the PyTorch Module / layer / network / cell according to its type.

    - `nn.Linear`, `nn.RNN`, `nn.GRU`, `nn.GRUCell`:
      - weights with at least 2 dims are initialized orthogonally [1]; other weights from a standard normal;
      - biases from a zero-mean normal with std 1/sqrt(num_elements), or, if `uniform_bias=True`, uniformly in
        [-1/sqrt(num_elements), 1/sqrt(num_elements)].
    - `nn.LSTM`, `nn.LSTMCell`:
      - input-to-hidden weights orthogonally; hidden-to-hidden weights as 4 stacked identity matrices;
      - every bias tensor (i.e. both input-to-hidden and hidden-to-hidden biases) is set to 0, except that the
        forget gate's bias is 1, or, if `t_max` is given, chrono initialization [2] of the forget and input gates.
    - `nn.Conv1d`, `MirrConv1d`: only if `bell=True`, the kernel weights are set via `_apply_weights_conf`.
    - `ScaleLayer`: weights are set to 1. `PositiveScaleLayer`: log-weights are set to 0.
    - `IndiNonlinLayer`: weights and biases (if they exist) from zero-mean normals with std 1/sqrt(num_elements).
    - Any other module is left unchanged.

    :param m: PyTorch Module / layer / network / cell to initialize
    :param kwargs: optional keyword arguments: `uniform_bias` (bool), `t_max` (float, int, or Tensor) for LSTM's
                   chrono initialization [2], `bell` (bool) for the convolution layers
    :raises pyrado.KeyErr: if a parameter of a linear / RNN / GRU module has neither 'weight' nor 'bias' in its name
    :raises pyrado.TypeErr: if `t_max` is given but is not a float, int, or tensor

    .. seealso::
        [1] A.M. Saxe, J. L. McClelland, S. Ganguli, "Exact solutions to the nonlinear dynamics of learning in
        deep linear neural networks", 2014

        [2] C. Tallec, Y. Ollivier, "Can recurrent neural networks warp time?", 2018, ICLR
    """
    if isinstance(m, (nn.Linear, nn.RNN, nn.GRU, nn.GRUCell)):
        for name, param in m.named_parameters():
            if 'weight' in name:
                if len(param.shape) >= 2:
                    # Most common case
                    init.orthogonal_(param.data)  # former: init.xavier_normal_(param.data)
                else:
                    init.normal_(param.data)
            elif 'bias' in name:
                bound = 1. / sqrt(param.data.nelement())
                if kwargs.get('uniform_bias', False):
                    init.uniform_(param.data, a=-bound, b=bound)
                else:
                    # Default case
                    init.normal_(param.data, std=bound)
            else:
                raise pyrado.KeyErr(keys='weight or bias', container=param)

    elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
        hs = m.hidden_size
        for name, param in m.named_parameters():
            if 'weight_ih' in name:
                # Initialize the input-to-hidden weights (gates i, f, g, o) orthogonally
                nn.init.orthogonal_(param.data)
            elif 'weight_hh' in name:
                # Initialize the hidden-to-hidden weights of the 4 gates as identity matrices stacked along dim 0
                param.data.copy_(to.eye(hs, hs).repeat(4, 1))
            elif 'bias' in name:
                # Bias layout along dim 0: b_i, b_f, b_g, b_o (each of length hidden_size)
                if 't_max' in kwargs:
                    t_max = kwargs['t_max']
                    if not isinstance(t_max, (float, int, to.Tensor)):
                        raise pyrado.TypeErr(
                            given=t_max,
                            expected_type=[float, int, to.Tensor])
                    # Chrono init: all biases 0, forget gate bias ~ log(U[1, t_max - 1]), input gate bias = -forget
                    nn.init.constant_(param.data, val=0)
                    forget_bias = param.data[hs:2 * hs]
                    nn.init.uniform_(forget_bias, 1, t_max - 1)
                    forget_bias.log_()
                    param.data[0:hs] = -forget_bias
                else:
                    # Initialize all biases to 0, but the bias of the forget gate to 1
                    nn.init.constant_(param.data, val=0)
                    param.data[hs:2 * hs].fill_(1)

    elif isinstance(m, nn.Conv1d):
        if kwargs.get('bell', False):
            # Initialize the kernel weights with a shifted curve of shape exp(-x^2 / sigma^2) (as the original
            # comment states; computed by `_apply_weights_conf`). The biases are left unchanged.
            ks = m.weight.data.shape[2]
            ks_half = ceil(ks / 2)  # equals ks // 2 for even kernel sizes
            ls_half = to.linspace(ks_half, 0, ks_half)  # descending
            # Mirror the descending half; for odd kernel sizes the last element is not duplicated
            ls = to.cat([ls_half, reversed(ls_half if ks % 2 == 0 else ls_half[:-1])])
            _apply_weights_conf(m, ls, ks_half)

    elif isinstance(m, MirrConv1d):
        if kwargs.get('bell', False):
            # Initialize the kernel weights with a shifted curve of shape exp(-x^2 / sigma^2).
            # The biases are left unchanged (does not exist by default).
            ks = m.weight.data.shape[2]  # ks_mirr = ceil(ks_conv1d / 2)
            ls = to.linspace(ks, 0, ks)  # descending
            _apply_weights_conf(m, ls, ks)

    elif isinstance(m, ScaleLayer):
        # Initialize all weights to 1
        m.weight.data.fill_(1.)

    elif isinstance(m, PositiveScaleLayer):
        # Initialize all log-weights to 0, i.e. weights of 1 assuming the layer uses exp(log_weight)
        m.log_weight.data.fill_(0.)

    elif isinstance(m, IndiNonlinLayer):
        # Initialize weights and biases (if they exist) from zero-mean normals with std 1/sqrt(num_elements)
        if m.weight is not None:
            init.normal_(m.weight, std=1. / sqrt(m.weight.nelement()))
        if m.bias is not None:
            init.normal_(m.bias, std=1. / sqrt(m.bias.nelement()))