import numpy as np
from sklearn.neural_network import MLPRegressor


def get_mlpn_predict(X, parameters, hidden_n):
    """Predict with an MLP whose weights come from a flat parameter vector.

    `parameters` is laid out as: input->hidden weights, hidden biases,
    hidden->output weights, output bias.
    """
    model = MLPRegressor(hidden_layer_sizes=hidden_n,
                         activation='logistic',
                         max_iter=1)
    # Fit once on dummy targets so sklearn allocates coefs_ and intercepts_;
    # the learned values are immediately overwritten below.
    y = np.random.rand(X.shape[0])
    model.fit(X, y)
    weights1 = []
    weights2 = []
    bias1 = []
    bias2 = []
    # Input -> hidden weights: one slice of hidden_n values per feature.
    for i in range(X.shape[1]):
        weights1.append(parameters[i * hidden_n:(i + 1) * hidden_n])
    # Hidden-layer biases: the next hidden_n values.
    p_index = X.shape[1] * hidden_n
    bias1.append(parameters[p_index:p_index + hidden_n])
    p_index = p_index + hidden_n
    # Hidden -> output weights: one value per hidden unit.
    for i in range(p_index, p_index + hidden_n):
        weights2.append([parameters[i]])
    p_index = p_index + hidden_n
    # Output bias: the single remaining value.
    bias2.append(parameters[p_index:])
    weights1 = np.array(weights1)
    bias1 = np.array(bias1)
    weights2 = np.array(weights2)
    bias2 = np.array(bias2)
    # Overwrite the fitted parameters with the supplied ones.
    model.coefs_ = [weights1, weights2]
    model.intercepts_ = [bias1, bias2]
    return model.predict(X)
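A minimal usage sketch (toy shapes, any values work): the flat vector must supply X.shape[1] * hidden_n input weights, hidden_n hidden biases, hidden_n output weights, and one output bias, in that order.

X = np.random.rand(5, 3)                       # 5 samples, 3 features
hidden_n = 4
n_params = 3 * hidden_n + hidden_n + hidden_n + 1
parameters = np.random.rand(n_params)
preds = get_mlpn_predict(X, parameters, hidden_n)
print(preds.shape)                             # (5,)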
Example #2
    def _create_new_nn(self, weights, biases):
        # Build an MLP with the desired architecture; max_iter=1 because we
        # only fit to allocate the internal arrays, not to train.
        mlp = MLPRegressor(hidden_layer_sizes=self._nn_architecture,
                           alpha=1e-10, max_iter=1)
        # One dummy sample is enough to initialize coefs_ and intercepts_.
        mlp.fit([np.random.randn(self._n_features)],
                [np.random.randn(self._n_actions)])
        # Inject the externally supplied parameters and make predict()
        # return a softmax over the action outputs.
        mlp.coefs_ = weights
        mlp.intercepts_ = biases
        mlp.out_activation_ = 'softmax'
        return mlp
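The weights and biases passed in must match the shapes mlp.fit() allocated. A shape sketch with hypothetical dimensions (the hosting class and its attribute values are not shown in the source):

import numpy as np

# Hypothetical dimensions: 4 features, one hidden layer of 8, 3 actions.
weights = [np.random.randn(4, 8),    # input -> hidden coefficients
           np.random.randn(8, 3)]    # hidden -> output coefficients
biases = [np.random.randn(8),        # hidden-layer intercepts
          np.random.randn(3)]        # output-layer intercepts
# policy = agent._create_new_nn(weights, biases)         # 'agent' hosts the method
# action_probs = policy.predict(state.reshape(1, -1))    # rows sum to 1 (softmax)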
Example #3
    def initMlp(self, netParams):
        """
        initializes a MultiLayer Perceptron (MLP) Regressor with the desired network architecture (layers)
        and network parameters (weights and biases).
        :param netParams: a list of floats representing the network parameters (weights and biases) of the MLP
        :return: initialized MLP Regressor
        """

        # create the initial MLP:
        mlp = MLPRegressor(hidden_layer_sizes=(HIDDEN_LAYER, ), max_iter=1)

        # This will initialize the input and output layers, and the node
        # weights and biases; we are not interested in actually training the
        # MLP here, hence the setting max_iter=1 above.
        mlp.fit(
            np.random.uniform(low=-1, high=1, size=INPUTS).reshape(1, -1),
            np.ones((1, OUTPUTS)))

        # weights are represented as a list of 2 ndarrays:
        # - hidden layer weights: INPUTS x HIDDEN_LAYER
        # - output layer weights: HIDDEN_LAYER x OUTPUTS
        numWeights = INPUTS * HIDDEN_LAYER + HIDDEN_LAYER * OUTPUTS
        weights = np.array(netParams[:numWeights])
        mlp.coefs_ = [
            weights[0:INPUTS * HIDDEN_LAYER].reshape((INPUTS, HIDDEN_LAYER)),
            weights[INPUTS * HIDDEN_LAYER:].reshape((HIDDEN_LAYER, OUTPUTS))
        ]

        # biases are represented as a list of 2 ndarrays:
        # - hidden layer biases: a 1-D array of length HIDDEN_LAYER
        # - output layer biases: a 1-D array of length OUTPUTS
        biases = np.array(netParams[numWeights:])
        mlp.intercepts_ = [biases[:HIDDEN_LAYER], biases[HIDDEN_LAYER:]]

        return mlp
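A hedged usage sketch: the constants below are placeholders (the real values are defined elsewhere in the source), and netParams must contain exactly numWeights + HIDDEN_LAYER + OUTPUTS floats, which an evolutionary algorithm would normally supply.

import numpy as np

INPUTS, HIDDEN_LAYER, OUTPUTS = 4, 20, 1       # placeholder sizes
PARAM_COUNT = INPUTS * HIDDEN_LAYER + HIDDEN_LAYER * OUTPUTS + HIDDEN_LAYER + OUTPUTS
netParams = np.random.uniform(-1, 1, PARAM_COUNT).tolist()
# mlp = controller.initMlp(netParams)              # 'controller' hosts the method above
# action = mlp.predict(observation.reshape(1, -1))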
Example #4
def deserialize_mlp_regressor(model_dict):
    # Rebuild the estimator from its constructor params, then restore the
    # fitted attributes that predict() relies on.
    model = MLPRegressor(**model_dict['params'])

    model.coefs_ = model_dict['coefs_']
    model.loss_ = model_dict['loss_']
    model.intercepts_ = model_dict['intercepts_']
    model.n_iter_ = model_dict['n_iter_']
    model.n_layers_ = model_dict['n_layers_']
    model.n_outputs_ = model_dict['n_outputs_']
    model.out_activation_ = model_dict['out_activation_']

    return model
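A matching serializer sketch, assuming the dict fields mirror the deserializer above (for JSON output the ndarrays in coefs_ and intercepts_ would additionally need .tolist()):

def serialize_mlp_regressor(model):
    # Constructor params plus the fitted attributes restored above.
    return {
        'params': model.get_params(),
        'coefs_': model.coefs_,
        'loss_': model.loss_,
        'intercepts_': model.intercepts_,
        'n_iter_': model.n_iter_,
        'n_layers_': model.n_layers_,
        'n_outputs_': model.n_outputs_,
        'out_activation_': model.out_activation_,
    }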
Example #5
    def get_encoder_from_autoencoder(
        vectors: List[array],
        trained_encoder_decoder: MLPRegressor,
        hidden_layer_sizes: List[int],
        output_layer_size: int,
    ) -> MLPRegressor:
        # Build an MLP shaped like the encoder half of the autoencoder and
        # fit it on dummy targets so sklearn allocates its parameter arrays.
        model = MLPRegressor(random_state=1,
                             activation="relu",
                             hidden_layer_sizes=hidden_layer_sizes)
        model.fit(X=vectors, y=zeros(shape=(len(vectors), output_layer_size)))
        # Copy the first half of the trained parameters: the layers up to and
        # including the bottleneck form the encoder. The intercepts must be
        # copied along with the coefficients.
        n_encoder_layers = len(hidden_layer_sizes) + 1
        model.coefs_ = trained_encoder_decoder.coefs_[:n_encoder_layers]
        model.intercepts_ = trained_encoder_decoder.intercepts_[:n_encoder_layers]
        return model
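A usage sketch with hypothetical shapes: a 32 -> 16 -> 8 -> 16 -> 32 autoencoder trained to reconstruct its input, from which the 32 -> 16 -> 8 encoder half is extracted (the class hosting the static method above is not shown in the source, so the call is indicated in comments):

import numpy as np
from sklearn.neural_network import MLPRegressor

data = [np.random.rand(32) for _ in range(100)]
autoencoder = MLPRegressor(hidden_layer_sizes=[16, 8, 16], activation="relu",
                           random_state=1, max_iter=500)
autoencoder.fit(data, data)
# encoder = Embedder.get_encoder_from_autoencoder(      # hypothetical host class
#     data, autoencoder, hidden_layer_sizes=[16], output_layer_size=8)
# codes = encoder.predict(data)                         # shape (100, 8)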
Example #6
def get_mlp_predict(X, parameters, hidden_n):
    """Like get_mlpn_predict above, but the flat vector carries only the
    input->hidden and hidden->output weights; the intercepts keep whatever
    values the one-iteration dummy fit produced."""
    model = MLPRegressor(hidden_layer_sizes=hidden_n,
                         activation='logistic',
                         max_iter=1)
    y = np.random.rand(X.shape[0])
    model.fit(X, y)
    weights1 = []
    weights2 = []
    # Input -> hidden weights: hidden_n values per feature.
    for i in range(X.shape[1]):
        weights1.append(parameters[i * hidden_n:(i + 1) * hidden_n])
    # Hidden -> output weights: the last hidden_n values. (The original
    # named this list `bias`, but it is the output-layer coefficients.)
    for i in range(len(parameters) - hidden_n, len(parameters)):
        weights2.append([parameters[i]])
    model.coefs_ = [np.array(weights1), np.array(weights2)]
    return model.predict(X)
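Toy call (any values): here parameters holds X.shape[1] * hidden_n + hidden_n floats.

import numpy as np

X = np.random.rand(6, 2)                       # 6 samples, 2 features
hidden_n = 3
parameters = np.random.rand(2 * hidden_n + hidden_n)
preds = get_mlp_predict(X, parameters, hidden_n)
print(preds.shape)                             # (6,)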
Example #7
# Hand-dumped coefficients for a trained net: a 24 x 10 input->hidden weight
# matrix followed by a 10 x 1 hidden->output weight matrix.
net.coefs_ = [
    list([[
        1.41064131e+00, 1.46948691e+00, -2.24180998e-02, -3.78213189e-02,
        -7.88690564e-47, 1.06024932e+00, 3.02749706e-77, 1.01993214e-65,
        1.57168527e-45, 1.33264699e+00
    ],
          [
              1.36787233e+00, 1.48466498e+00, -2.24184017e-02, -2.70854246e-01,
              -3.20318084e-90, 1.43620710e+00, -6.97627193e-56, 8.66637950e-63,
              1.30390091e-58, 1.36911041e+00
          ],
          [
              1.13074871e+00, 1.14861820e+00, -2.24064805e-02, -3.57387868e-02,
              -3.63739675e-84, 1.25371029e+00, 2.74808685e-88, 7.95821987e-72,
              -6.16637575e-95, 9.14102736e-01
          ],
          [
              1.56892593e+00, 1.66921605e+00, -2.24187276e-02, 3.97464248e-02,
              -5.00913299e-96, 1.32831486e+00, 5.23613549e-47, 7.96319992e-47,
              1.11270815e-97, 1.58048805e+00
          ],
          [
              1.46609153e+00, 1.64144575e+00, -2.24185198e-02, -3.70583893e-02,
              -1.48994534e-83, 1.48847808e+00, -2.86366630e-94, 5.36510314e-98,
              1.61654579e-54, 1.78469820e+00
          ],
          [
              4.27915262e-01, 6.13957196e-01, -2.33231506e-02, -7.66902167e-03,
              -3.89765017e-48, 5.10179927e-01, 8.11190958e-68, 1.77924760e-50,
              -1.81057442e-58, 4.42326647e-01
          ],
          [
              1.72333659e+00, 1.64688949e+00, -2.24184017e-02, -3.79464251e-02,
              -5.55046234e-99, 1.55477628e+00, 1.21867458e-96, -9.09268946e-63,
              -7.93446998e-74, 1.66607761e+00
          ],
          [
              1.65764539e+00, 1.81899065e+00, -2.24311935e-02, -4.40808094e-02,
              5.77931344e-94, 1.49321653e+00, -2.22456818e-71, 2.47084981e-99,
              -2.91387914e-99, 1.74245198e+00
          ],
          [
              1.77452049e+00, 1.59332372e+00, -2.24185198e-02, -8.41958797e-03,
              -1.33421184e-72, 1.47806937e+00, 3.21925485e-63, 2.84528628e-89,
              7.20529734e-98, 1.67939888e+00
          ],
          [
              1.72283323e+00, 1.81996851e+00, -2.24186142e-02, -4.08878027e-02,
              -5.87322523e-74, 1.70523929e+00, -2.39415123e-60, 1.37579163e-90,
              -3.34827922e-98, 1.62981667e+00
          ],
          [
              1.49334508e+00, 1.64927013e+00, -2.36990836e-02, -1.14629487e-01,
              -8.11789175e-56, 1.52021563e+00, 5.35984102e-61, -1.43319115e-59,
              -2.33143502e-93, 1.73770059e+00
          ],
          [
              1.75491696e+00, 1.73139747e+00, -2.24184017e-02, -4.84019108e-02,
              -2.16493726e-98, 1.54103295e+00, -1.35152193e-76, 9.97747019e-58,
              -1.70468225e-96, 1.59777217e+00
          ],
          [
              1.72605524e+00, 1.88481937e+00, 1.92892970e-03, -3.69851379e-02,
              3.37490590e-93, 1.49437858e+00, 2.94218424e-98, 6.42548932e-87,
              -1.16766145e-89, 1.77561187e+00
          ],
          [
              1.90768844e+00, 1.93244124e+00, -2.74564574e-02, -2.71667538e-01,
              -2.18666923e-49, 1.55576580e+00, -1.58598149e-75, 3.78152642e-98,
              4.23037119e-99, 1.73901334e+00
          ],
          [
              1.55641842e+00, 1.79417984e+00, -5.33388254e-02, -3.65821918e-02,
              -4.33876657e-76, 1.38924817e+00, -4.44999122e-91, 2.60235079e-78,
              -1.34557746e-92, 1.65395741e+00
          ],
          [
              1.80106140e+00, 1.71822717e+00, -2.24185198e-02, -6.08623498e-02,
              -2.63200165e-95, 1.63469537e+00, -1.59828045e-75, 8.96513524e-57,
              -1.54236694e-91, 1.63753035e+00
          ],
          [
              1.63285069e+00, 1.69965960e+00, -2.24185788e-02, -9.00696543e-04,
              3.80723743e-70, 1.64507097e+00, 1.53521926e-58, -6.48996765e-73,
              -2.38599167e-61, 1.43841903e+00
          ],
          [
              1.66784122e+00, 1.64395019e+00, -1.94467567e-02, 3.32720907e-02,
              -4.44534570e-59, 1.50816903e+00, -7.64643820e-98, 2.75376561e-97,
              1.07638351e-88, 1.70558378e+00
          ],
          [
              1.61329682e+00, 1.63271565e+00, -4.29314297e-02, 1.94655265e-01,
              -1.12144809e-50, 1.27333655e+00, 1.82673396e-54, 3.71140173e-98,
              -1.62296556e-59, 1.61163731e+00
          ],
          [
              1.59510539e+00, 1.81980312e+00, -3.30870058e-02, -3.71166967e-02,
              -7.87829729e-85, 1.42915028e+00, -4.36115173e-69, 2.03217920e-97,
              3.22513628e-94, 1.62940354e+00
          ],
          [
              1.79143343e+00, 1.57172023e+00, -2.24185198e-02, -1.51838680e-01,
              -4.99287499e-66, 1.73919962e+00, -1.17830405e-98,
              -2.28552643e-78, -4.83889897e-56, 1.83747478e+00
          ],
          [
              1.74400131e+00, 1.57415224e+00, -3.35400574e-02, -5.86277565e-02,
              -4.11793646e-94, 1.37246761e+00, 1.05188608e-90, -1.92722372e-92,
              1.73788699e-89, 1.56492990e+00
          ],
          [
              1.60255981e+00, 1.70390259e+00, -2.50406772e-02, -2.13121414e-01,
              1.63413144e-69, 1.43694087e+00, -9.56922597e-98, 8.74895335e-85,
              -1.80515318e-50, 1.60100365e+00
          ],
          [
              1.53168374e+00, 1.84291014e+00, -2.24414250e-02, 3.46318652e-02,
              3.69355801e-75, 1.53155037e+00, 9.85110048e-51, -4.91456836e-54,
              -1.06266143e-47, 1.46674520e+00
          ]]),
    list([[7.61211153e-01], [5.29023058e-01], [-6.76783513e-01],
          [-1.23527535e-01], [1.04599422e-01], [1.06178562e+00],
          [-1.09977597e-43], [-1.22990539e-90], [-3.14851814e-21],
          [7.33380751e-01]])
]
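Before calling net.predict(), the pasted nested lists are best converted back to the ndarrays sklearn expects (a sketch; net is assumed to be an MLPRegressor that has already been fit once, and note the dump covers only coefs_, so net.intercepts_ must come from elsewhere):

import numpy as np

net.coefs_ = [np.asarray(layer, dtype=float) for layer in net.coefs_]
assert [c.shape for c in net.coefs_] == [(24, 10), (10, 1)]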