Esempio n. 1
0
def wsmtrain(Xa,Ya,w1,b1,w2,b2,Fhid,xshape,skip,sidevalue=0,   \
             n_epochs=100,lr=0.1,tfp=10,dfreq=10,dprint=1,dcurves=0, dperf=1, \
             Xv=None,Yv=None,pval=0,ntfpamin=100,dval=1) :
    ''' Apprentissage d'un réseau de neurones à masques et poids partagés pour une
    | architecture particulière : qui ne comprend qu'un niveau de convolution
    | constitué que d'une seule carte de caractéristique.
    | Sans données de validation : w1,b1,w2,b2 = wsmtrain(Xa,Ya,...)
    | Avec données de validation : w1,b1,w2,b2,it_minval = wsmtrain(Xa,Ya,...)
    | 
    | En Entrée :
    | Xa        : (matrice Nxp) Données d'entrées de l'ensemble d'apprentissage.
    |             N est le nombre de lignes d'exemples, p est la dimension des
    |             exemples (c'est à dire des vecteurs d'entrée en ligne).
    | Ya        : (matrice Nxq) Données de sortie de l'ensemble d'apprentissage
    | w1,b1,w2,b2: Poids et seuils initiaux.
    | Fhid      : Fonction à utiliser pour le calcul de la carte des caractéristiques
    |             (convolution sur les entrées). :
    |             "lin" : fonction linéaire
    |             "tah" ou "tgh" : fonction tangente hyperbolique
    |             Remarque : La couche de sortie est calculée avec la fonction
    |                        tangente hyperbolique
    | xshape    : [a, b] Les dimensions 2D à appliquer sur les vecteurs d'entrées
    |             de l'ensemble d'apprentissage.   
    | skip      : [l, c] : Décalage à appliquer pour le masque (mask) d'entrée.
    |              l : décalage horizontal (en nombre de ligne1)
    |              c : décalage vertical (en nombre de colonne)
    | sidevalue : Valeur d'entrée à affecter sur les points de la zone frontalière 
    |             des formes d'entrées
    | n_epochs  : (entier) Nombre maximum d'itérations (defaut: 100)
    | lr        : learning rate (ou pas de gradient) (defaut: 0.1)
    | tfp       : (entier) Top Frequency Processing : Permet de diminuer le nombre 
    |             de certains calcul (eraly stopping) en ne les effectuant que tous
    |             les tfp fois. (defaut: 10)
    | dfreq    : (entier) Fréquence d'affichage. par defaut, qu'un affichage final.    
    | dprint   : (entier) print (ou pas) des valeurs pendant le déroulement de    
    |            l'apprentissage, aussi en fonction des paramètre dperf, dval qui
    |            suivent. Valeur par défaut = 1 : affichage. 
    | dcurves  : Affichage (ou pas) des courbes du déroulement de l'apprentissage
    |            aussi en fonction des paramètre dperf, dval qui suivent.
    |            valeur par défaut = 0 : pas d'affichage.
    |            >= 1 : Affichage
    |               2 : Les erreurs sont affichées en log.
    | dperf    : Affichage (si=1) ou pas (si=0) de valeurs associées aux performances
    |            qui se comprennent ici comme des pourcentages d'erreur.
    | Xv       : (matrice) Données d'entrée de validation. Permet de pratiquer
    |            l'arrêt prématuré.
    | Yv       : (matrice) Sorties désirées associées aux données de validation Xv.
    | pval     : Indique sur quelle valeur doit se faire l'appréciation de l'arrêt
    |            prématuré :
    |            0 : sur l'erreur en validation (c'est le cas pas défaut)
    |            1 : sur la performance en validation 
    | ntfpamin : (entier) Nombre maximum de tfp à faire après le dernier minimum 
    |             constaté en validation (défaut: 100).
    | dval     : Affichge (si=1) ou pas (si=0) des valeurs qui se concerne l'ensemble
    |            de validation
    |
    | En Sortie :
    | Sans données de validation :
    |   w1,b1,w2,b2 : Matrices des poids obtenus à la fin de la procédure.
    | Avec données de validation   :
    |   w1,b1,w2,b2 : Matrices des poids obtenus au minimum sur l'ensemble de
    |                 validation. (Les poids de fin de procédure sont stockés
    |                 dans un fichier: wsmendwei)
    |   it_minval: l'indice de l'itération où le minimum en validation a été trouvé.
    '''
    nbitamin = ntfpamin
    #...
    mask = np.shape(w1)
    h_size, loop = wsm_convmask(xshape, mask, skip)

    # Check and init stuff
    if np.prod(xshape) != np.size(Xa, 1):
        print(
            "wsmtrain: la longueur des entrées n'est pas compatible avec les dimensions données dans xshape"
        )
        sys.exit(0)
    if Fhid is not "lin" and Fhid is not "tah" and Fhid is not "tgh":
        print("wsmtrain: ", Fhid, "activation function not defined !")
        sys.exit(0)
    Nout, dimout = np.shape(Ya)
    Napp, papp = np.shape(Xa)
    if Nout != Napp:
        print(
            "wsmtrain: input and output of learning set must have the same number of item"
        )
        sys.exit(0)

    # Do we use Validation data (for early stopping)
    valid_on = False
    if Xv is not None and Yv is not None:
        Nxval, pxval = np.shape(Xv)
        if Nxval != np.size(Yv, 0):
            print(
                "wsmtrain: input and output validating set must have the same number of item"
            )
            sys.exit(0)
        if pxval != papp:
            print(
                "wsmtrain: input dim of validating and learning set must be the same"
            )
            sys.exit(0)
        valid_on = True
        cpit_addmin = 0
        # ComPtage du nbre d'IT apres le Min
        minval = 1e16
        # Initialisation du minimum en validation
        it_minval = 0
        # Itération du minimum en validation
        WBMV = [np.copy(w1),
                np.copy(b1),
                np.copy(w2),
                np.copy(b2)]
    if valid_on == False:
        dval = 0
        # Si dval était à 1, on le force à 0.

    # init other stuff
    rw1, cw1 = np.shape(w1)
    lr2 = lr / max(h_size)
    lr1 = lr / np.sqrt(rw1 * cw1)
    y = np.zeros(dimout)
    #yt      = np.zeros((Napp,dimout));

    #--------------------------------------------------------
    if (dcurves > 0):
        plt.figure()
        xplot = np.array([-1, 0]).astype(int)
        plt.ion()

    #=========================================================
    # boucle principale d'Apprentissage
    continumongars = True
    epk = 0
    while (epk < n_epochs) and continumongars:
        epk += 1

        err = []
        for n in np.arange(Napp):  # pour chaque forme d'apprentissage
            # Forwarding data :
            Ai, sub_x = wsm_convol(Xa[n, :],
                                   xshape,
                                   w1,
                                   loop,
                                   sidevalue=sidevalue)

            if Fhid == "lin":
                Ai = Ai + b1
            else:  #:=> hidden_function=='tah'
                Ai = np.tanh(Ai + b1)

            # network output
            for k in np.arange(dimout):
                y[k] = np.tanh(np.sum(w2[k] * Ai) + b2[k])
            #yt[n,:] = y; # + sto de chaque y calculé

            # Error computation:
            errn = (Ya[n, :] - y)
            err.append(errn)

            dy_db2 = (1 - y * y)
            sum_err_dy_dw1 = np.zeros((rw1, cw1))
            sum_err_dy_db1 = np.zeros((h_size[0], h_size[1]))

            for k in np.arange(dimout):
                dy_dw1_k = np.zeros((rw1, cw1))
                dy_db1 = np.zeros((h_size[0], h_size[1]))

                w2_k = w2[k]

                if Fhid == "lin":
                    for i in np.arange(h_size[0]):
                        sub_xi = sub_x[i]
                        for j in np.arange(h_size[1]):
                            #dy_dw1_k   = dy_dw1_k    + np.dot(w2_k[i,j], sub_xi[j]);
                            Z = np.dot(w2_k[i, j], sub_xi[j])
                            dy_dw1_k = dy_dw1_k + Z
                            dy_db1[i, j] = dy_db1[i, j] + w2_k[i, j]
                else:  #:=> hidden_function=='tah'
                    for i in np.arange(h_size[0]):
                        sub_xi = sub_x[i]
                        for j in np.arange(h_size[1]):
                            Z = np.dot((1 - Ai[i, j]**2), w2_k[i, j])
                            dy_dw1_k = dy_dw1_k + np.dot(Z, sub_xi[j])
                            dy_db1[i, j] = dy_db1[i, j] + Z

                dy_dw1_k = np.dot(dy_db2[k], dy_dw1_k)
                dy_db1 = np.dot(dy_db2[k], dy_db1)

                sum_err_dy_dw1 = sum_err_dy_dw1 + np.dot(err[n][k], dy_dw1_k)
                sum_err_dy_db1 = sum_err_dy_db1 + np.dot(err[n][k], dy_db1)

                w2[k] = w2_k + np.dot(lr2 * dy_db2[k], err[n][k] * Ai)

            b2 = b2 + lr2 * dy_db2 * err[n]
            w1 = w1 + lr1 * sum_err_dy_dw1 / dimout
            b1 = b1 + lr1 * sum_err_dy_db1 / dimout

        # fin d'1 passage de toute la base d'apprentissage
        #
        # Top Frequency Processing -----------------------
        if np.remainder(epk,
                        tfp) == 0:  # toutes les tpf passages de la base d'app
            # Je ne stocke rien
            if valid_on:
                Y = wsm_out(Fhid,
                            w1,
                            b1,
                            w2,
                            b2,
                            Xv,
                            xshape,
                            loop,
                            sidevalue=-1)
                if pval:  # C'est la perf qui doit servir pour l'early stop.
                    perfval = tls.classperf(Yv, Y, miss=1)
                    earlyval = perfval
                else:  # c'est l'erreur quadratique
                    errqval = np.sum((Yv - Y)**2)
                    earlyval = errqval

                # Early stopping (stop criterium 1)
                if earlyval < minval:  # detection du minimum en validation
                    minval = np.copy(earlyval)
                    it_minval = np.copy(epk)
                    cpit_addmin = 0
                    # garder les poids du min en validation (MV stand for Min Validation)
                    WBMV = [
                        np.copy(w1),
                        np.copy(b1),
                        np.copy(w2),
                        np.copy(b2)
                    ]
                else:
                    cpit_addmin += 1
                    if cpit_addmin >= nbitamin:
                        print(
                            "wsmtrain: Min Val (%f) at it %d then stop %d it after"
                            % (minval, it_minval, nbitamin * tfp))
                        continumongars = False

        # Affichage ---------------------------------------
        # Si demandé et selon dfreq (ou fin de boucle pour le dernier affichage)
        if  (dprint>0 or dcurves>0) \
        and (np.remainder(epk,dfreq)==0 or continumongars==0 or epk>=n_epochs) :
            # Il se peut que dfreq ne coincide pas avec tfp, et que rendu ici
            # toutes les valeurs nécéssaires n'aient pas été calculées, le plus
            # simple alors seraient de les calculer ici.
            #
            # Pour l'ens d'App il faut calculer un Y sur l'ensemble de la base
            Y = wsm_out(Fhid,
                        w1,
                        b1,
                        w2,
                        b2,
                        Xa,
                        xshape,
                        loop,
                        sidevalue=sidevalue)
            errqapp = np.sum((Ya - Y)**2)
            errpapp = np.copy(errqapp)
            # l'erreur à ploter
            if dcurves > 1:
                errpapp = np.log(errqapp)
                # l'erreur à ploter en log
            if dperf:
                perfapp = tls.classperf(Ya, Y, miss=1)

            if dval == 1:  # display des elt de validation
                Y = wsm_out(Fhid,
                            w1,
                            b1,
                            w2,
                            b2,
                            Xv,
                            xshape,
                            loop,
                            sidevalue=sidevalue)
                errqval = np.sum((Yv - Y)**2)
                errpval = np.copy(errqval)
                # l'erreur à ploter
                if dcurves > 1:
                    errpval = np.log(errqval)
                    # l'erreur à ploter en log
                if dperf:
                    perfval = tls.classperf(Yv, Y, miss=1)

            # D'abord les prints (si demandés)
            if dprint:
                # L'entête
                if np.remainder(epk, 20 * dfreq) == dfreq:
                    print("| epk  |   ErrqApp  ", end='')
                    if dval:
                        print("|   ErrqVal  ", end='')
                    if dperf:
                        print("| PerfApp ", end='')
                        if dval:
                            print("| Perfval ", end='')
                    print("|")
                # Les valeurs
                print("| %4d | %10.6f " % (epk, errqapp), end='')
                # l'err tjrs
                if dval:
                    print("| %10.6f " % (errqval), end='')
                if dperf:
                    print("|  %.4f " % (perfapp), end='')
                    if dval:
                        print("|  %.4f " % (perfval), end='')
                print("|")

            # Ensuite la figure des courbes (si demandée)
            if dcurves > 0:
                xplot = xplot + 1
                if xplot[0] > 0:  # parce que la 1ère fois on a pas encore de valeur précédente valide
                    if dperf:
                        plt.subplot(2, 1, 2)
                        plt.plot(xplot, [PrevPerfApp, perfapp], '.-')
                        if dval:
                            plt.plot(xplot, [PrevPerfVal, perfval], '.-')
                        plt.subplot(2, 1, 1)
                    #
                    plt.plot(xplot, [PrevErrpApp, errpapp], '.-')
                    #l'err tjrs
                    if dval:
                        plt.plot(xplot, [PrevErrpVal, errpval], '.-')
                    plt.draw()
                # sauvegarde des valeurs de plot précédent
                PrevErrpApp = errpapp
                # tjrs
                if dperf:
                    PrevPerfApp = perfapp
                if dval:
                    PrevErrpVal = errpval
                    if dperf:
                        PrevPerfVal = perfval

    # fin du while : fin de la boucle principale d'Apprentissage
    #=============================================================
    # Last Action ---------------------------------------
    print("wsmtrain: %d epochs done" % epk)
    #
    if dcurves > 0:
        #plt.xlabel("X %d Epoch" % (tfp));
        plt.xlabel("X %d Epoch" % (dfreq))
        if dperf:
            plt.subplot(2, 1, 2)
            plt.ylabel("Perf: % Err. de classif.")
            plt.axis("tight")
            plt.subplot(2, 1, 1)
            # pour les erreurs (rem: pas de perf => pas de subplot)
        if dcurves == 1:
            plt.ylabel("$\sum$ Errors Quadratiques")
        else:
            plt.ylabel("log($\sum$ Errors Quadratiques)")
        plt.axis("tight")
        if dval:
            plt.legend(["Learning", "Validating"])
        else:
            plt.legend(["Learning"])
        plt.show()

    if valid_on:
        np.save("wsmendwei", [w1, b1, w2, b2])
        # Sauvegarde des poids en fin d'apprentissage
        return WBMV[0], WBMV[1], WBMV[2], WBMV[3], it_minval
    else:
        return w1, b1, w2, b2
Esempio n. 2
0
def pmctrain(Xa,Ya,WW,FF,nbitemax=1000, fparm=[0.6667,1.7159,0.0], \
             gstep=0.1, alpha=0.4737, weivar_seuil=10**(-6), \
             tfp=10,dfreq=-1,dprint=1,dcurves=0, dperf=0, \
             Xv=None,Yv=None,pval=0, ntfpamin=100,dval=1) :
    ''' Sans données de validation : WW = pmctrain(Xa,Ya,WW,FF,...)
    | Avec données de validation   : WW, it_minval = pmctrain(Xa,Ya,WW,FF,...)
    | Effectue les itérations d'apprentissage d'un PMC par l'algorithme de
    | rétropropagation de gradient dans sa version batch.
    |
    | En Entrée :
    | Xa : (matrice) Données d'apprentissage en entrée du PMC
    | Ya : (matrice) Sorties désirées (données de référence) associées à Xa.
    | WW  : liste des matrices de poids pour chaque couche (ci->ci+1)
    | FF  : liste de string des fonctions d'activation des neurones cachés.
    |       Les valeurs possibles sont :
    |       'tah' : pour la tangente hyperbolique 
    |       'sig' : pour la fonction sigmoide
    |       'lin' : pour la fonction linéaire
    |       'exp' : pour la fonction exponnentielle
    |       Il doit y avoir autant de fonctions dans FF que de matrice de poids
    |       dans WW. 
    | nbitemax : (entier) Nombre maximum d'itérations (defaut: 1000)
    | fsigparm : (vecteur) Paramètres de la fonction sigmoide :
    |            [asympt, pente, offset], (defaut: [0.6667, 1.7159, 0.0]).
    | gstep    : (réel) Pas de gradient initial (defaut: 0.1) 
    | alpha    : (réel) facteur de momentum (souvenir des états passés)
    |            default value=0.9/(1+.9)=0.4737)    
    | weivar_seuil : (réel) Paramètres pour l'arret sur un seuil minimum de 
    |            variation des poids. Valeur par défaut = 10**(-6).
    | tfp      : (entier) Top Frequency Processing : Permet de diminuer le nombre 
    |            de certains calcul (eraly stopping) et de stockage en ne les
    |            effectuant que tous les tfp fois. Ne doit etre supérieur à
    |            dfreq ou nbitemax (defaut: 10)
    | dfreq    : (entier) Fréquence d'affichage. par defaut, qu'un affichage final.    
    | dprint   : (entier) print (ou pas) des valeurs pendant le déroulement de    
    |            l'apprentissage, aussi en fonction des paramètre dperf, dval qui
    |            suivent. Valeur par défaut = 1 : affichage. 
    | dcurves  : Affichage (ou pas) des courbes du déroulement de l'apprentissage
    |            aussi en fonction des paramètre dperf, dval qui suivent.
    |            valeur par défaut = 0 : pas d'affichage.
    |            >= 1 : Affichage
    |               2 : Les erreurs sont affichées en log.
    | dperf    : Affichage (si=1) ou pas (si=0) de valeurs associées aux performances
    |            qui se comprennent ici comme des pourcentages d'erreur.
    | Xv       : (matrice) Données d'entrée de validation. Permet de pratiquer
    |            l'arrêt prématuré.
    | Yv       : (matrice) Sorties désirées associées aux données de validation Xv.
    | pval     : Indique sur quelle valeur doit se faire l'appréciation de l'arrêt
    |            prématuré :
    |            0 : sur l'erreur en validation (c'est le cas pas défaut)
    |            1 : sur la performance en validation 
    | ntfpamin : (entier) Nombre maximum de tfp à faire après le dernier minimum 
    |             constaté en validation (défaut: 100).
    | dval     : Affichge (si=1) ou pas (si=0) des valeurs qui se concerne l'ensemble
    |            de validation
    |
    | En Sortie :
    | Sans données de validation :
    |    WW : Matrices des poids obtenus à la fin de la procédure
    | Avec données de validation   :
    |    WW : Matrices des poids obtenus au minimum sur l'ensemble de validation.
    |         (Les poids de fin de procédure sont stockés dans un fichier: pmcendwei)
    |    it_minval: l'indice de l'itération où le minimum en validation a été trouvé.
    '''
    nbitamin = ntfpamin
    #...
    nbc = len(WW)
    # nombre de couches sans l'input
    if np.size(FF) != nbc:
        print(
            "pmctrain: Le nombre de fonctions de transfert doit correspondre au nombre de couches à calculer"
        )
        sys.exit(0)
    if dfreq == -1:
        dfreq = nbitemax
    if tfp > dfreq or tfp > nbitemax:
        print(
            "pmctrain: tfp parameter must not be greater than dfreq or nbitemax"
        )
        sys.exit(0)

    asympt = fparm[0]
    pente = fparm[1]
    offset = fparm[2]
    #
    #Initialisations ------------------------
    # Dimension stuff
    ell = np.size(Xa, 0)
    if np.size(Ya, 0) != ell:
        print("pmctrain: Xa and Ya must have the same number of line")
        sys.exit(0)
    onell = np.ones((ell, 1))
    # biais...
    t = []
    for i in np.arange(nbc):
        t.append(np.size(WW[i], 0))

    # Pour l'ajustement des pas (1)
    a = 1.5
    b = 1 / a
    dim = 0.5
    # diminution du pas en cas d'augmentation de l'erreur

    # Init sauvegardes pour marche arriere and other
    #WWp    = np.copy(WW);
    WWp = copy.deepcopy(WW)
    gradWWp = []
    OVF = []
    WWpp = []
    sizWW = []
    nbWW = []
    oneWW = []
    PAS = []
    for i in np.arange(nbc):
        gwpi = np.zeros(WW[i].shape)
        gradWWp.append(gwpi)
        # Pour l'ajustement des pas (2)
        ovfi = 1e3 * np.ones(np.shape(WW[i]))
        OVF.append(ovfi)
        # For weights variation threshold check and deal (1)
        wppi = np.zeros(np.shape(WW[i]))
        WWpp.append(wppi)
        sizwi = np.shape(WW[i])
        sizWW.append(sizwi)
        nbwi = np.prod(sizwi)
        nbWW.append(nbwi)
        onewi = np.ones(nbwi)
        oneWW.append(onewi)
        # ...
        pasi = gstep * np.ones(sizwi) * 2 / ell
        PAS.append(pasi)

    gradWW = copy.deepcopy(gradWWp)
    #np.copy(gradWWp); # easier if initialized
    descWW = copy.deepcopy(gradWW)
    #np.copy(gradWW);  # easier if initialized
    errold = 1e16

    # For weights variation threshold check and deal (2)
    sbougeti = ""
    #lambada = 0;               # Weight decay factor
    #lambda  = 0 #=lambada.*ones(size(w2));
    errtot = []
    # memoire erreur pour sortie
    perftot = []
    # memoire performance pour sortie
    nballit = 0
    # nombre total d'iterations effectuees
    nbite = 0
    # nombre de 'bonnes iterations'
    ctrback = 0
    # compteur d'iteration de marche arriere
    asympt = 0.6667
    pente = 1.7159
    offset = 0.0
    # sigmoide parameters
    weivar_freq = 200
    # check weights variation frequency default value (2nd stop criterium)

    # Do we use Validation data
    valid_on = False
    if Xv is not None and Yv is not None:
        if np.size(Xv, 0) != np.size(Yv, 0):
            print("pmctrain: Xv and Yv must have the same number of line")
            sys.exit(0)
        if np.size(Xv, 0) > 0:
            valid_on = True
            cpit_addmin = 0
            # ComPtage du nbre d'IT apres le Min
            minval = 1e16
            # Initialisation du minimum en validation
            WWMV = WW
            # garder les poids du min en validation
            it_minval = 0
            # Iteration du Min en sortie
            errtotval = []
            # memoire erreur pour sortie
            perftotval = []
            # memoire performance pour sortie
    if valid_on == False:
        dval = 0
        # Si dval était à 1, on le force à 0.

    #--------------------------------------------------------
    if (dcurves > 0):
        plt.figure()
        plt.ion()

    #=========================================================
    # boucle principale d'Apprentissage
    continumongars = 1
    while (nballit < nbitemax) & continumongars:

        #Increment d'iterations ..........................
        nballit = nballit + 1

        # Propagation Avant -------------------------------
        Yi = []
        Xi = np.copy(Xa)
        for i in np.arange(nbc):
            Xi = np.concatenate((Xi, onell), axis=1)
            #Ajout du biais aux entrées
            Ai = np.dot(Xi, WW[i])
            # Somme pondérées des entrées
            if FF[i] == "tah":
                Xi = np.tanh(Ai)
            elif FF[i] == "sig":
                ekx = np.exp(-pente * Ai)
                Xi = asympt * (1 - ekx) / (1 + ekx) + offset
            elif FF[i] == "lin":
                Xi = Ai
            elif FF[i] == "exp":
                Xi = np.exp(Ai)
            else:
                print("pmctrain: Unknown activation function %d." % (i + 1))
                sys.exit(0)
            Yi.append(Xi)

        # Coût --------------------------------------------
        err = (Yi[nbc - 1] - Ya)
        errnew = np.sum(err * err)  #+ sum(sum(lambda.*w2.^2));

        # Top Frequency Processing ------------------------
        # (to avoid too much computation and storage)
        if np.remainder(nballit, tfp) == 0:
            # on App
            errtot.append(errnew)
            # Sto Eapp
            if dperf:  #si cas de classification
                perfapp = tls.classperf(Ya, Yi[nbc - 1], miss=1)
                # Compute Papp
                perftot.append(perfapp)
                # Sto Papp

            # on Validation (if given) . . . . . . .
            if valid_on:
                Yvi = pmcout(Xv, WW, FF, fparm=fparm)
                if pval == 1:  # C'est la perf qui doit servir pour l'early stop.
                    perfval = tls.classperf(Yv, Yvi, miss=1)
                    earlyval = perfval
                else:  # c'est l'erreur quadratique
                    errqval = np.sum((Yv - Yvi)**2)
                    earlyval = errqval

                if dval:  # On veut afficher les valeurs de validation ;
                    # il faut calculer ce qui ne l'a pas été
                    if pval == 0:
                        perfval = tls.classperf(Yv, Yvi, miss=1)
                    else:
                        errqval = np.sum((Yv - Yvi)**2)
                    errtotval.append(errqval)
                    perftotval.append(perfval)

                # Early stopping (stop criterium 1)
                if earlyval < minval:  # detection du minimum en validation
                    minval = np.copy(earlyval)
                    it_minval = np.copy(nballit)
                    cpit_addmin = 0
                    WWMV = copy.deepcopy(WW)
                    #np.copy(WW);
                    # garder les poids du min en validation. MV stand for Min Validation
                else:
                    cpit_addmin += 1
                    if cpit_addmin >= nbitamin:
                        print(
                            "pmctrain: Min Val (%f) at it %d then stop %d it after"
                            % (minval, it_minval, nbitamin * tfp))
                        continumongars = 0

        # Marche arriere ----------------------------------
        if errnew > errold:
            ctrback = ctrback + 1
            # Increment des mauvaises iterations
            WW = copy.deepcopy(WWp)
            gradWW = copy.deepcopy(gradWWp)
            descWW = copy.deepcopy(gradWW)
            for i in np.arange(nbc):
                PAS[i] = PAS[i] * dim
                WW[i] = WW[i] - PAS[i] * gradWW[i]
            #
        else:  # Ca marche => un pas en avant :
            nbite = nbite + 1
            # Increment des bonnes iterations

            # Back Propagation -------------------------------------
            # From la 1ère couche de sortie
            dJdy = 2 * err
            y = Yi[nbc - 1]
            if FF[nbc - 1] == "tah":
                f2Ai = (1 - y * y)
            elif FF[nbc - 1] == "sig":
                f2Ai = asympt * (1 - y * y)
            elif FF[nbc - 1] == "lin":
                f2Ai = 1
            elif FF[nbc - 1] == "exp":
                f2Ai = y
            #else : Ne devrait pas se produire : déjà  controlé lors de passe avant
            dJdAi = dJdy * f2Ai
            if nbc > 1:
                Xai = np.concatenate((Yi[nbc - 2], onell), axis=1)
            else:
                Xai = np.concatenate((Xa, onell), axis=1)
            gradWW[nbc - 1] = np.dot(Xai.T, dJdAi) / ell
            #+ (2*lambda.*w2))./ell;
            descWW[nbc -
                   1] = (1 - alpha) * gradWW[nbc - 1] + alpha * descWW[nbc - 1]

            # Les couches cachées internes
            for i in np.arange(nbc - 1):
                Wi = WW[nbc - 1 - i]
                ti = t[nbc - 1 - i]
                Wi = Wi[0:ti - 1, :]
                y = Yi[nbc - 2 - i]
                if FF[nbc - 2 - i] == "tah":
                    dJdAi = np.dot(Wi, dJdAi.T).T * (1 - y * y)
                elif FF[nbc - 2 - i] == "sig":
                    dJdAi = np.dot(Wi, dJdAi.T).T * asympt * (1 - y * y)
                elif FF[nbc - 2 - i] == "lin":
                    dJdAi = np.dot(Wi, dJdAi.T).T
                elif FF[nbc - 2 - i] == "exp":
                    dJdAi = np.dot(Wi, dJdAi.T).T * y
                #else : Ne devrait pas se produire : déjà  controlé lors de passe avant
                if i < nbc - 2:  # pas fin
                    Xai = np.concatenate((Yi[nbc - 3 - i], onell), axis=1)
                else:
                    Xai = np.concatenate((Xa, onell), axis=1)
                gradWW[nbc - 2 - i] = np.dot(Xai.T, dJdAi) / ell
                descWW[nbc - 2 - i] = (1 - alpha) * gradWW[
                    nbc - 2 - i] + alpha * descWW[nbc - 2 - i]

            # sauvegardes (avant correction des poids) ............
            errold = errnew
            WWp = copy.deepcopy(WW)
            gradWWp = copy.deepcopy(gradWW)

            # Ajustement des pas et correction des poids
            for i in np.arange(nbc):
                # Ajustement des pas ..................................
                test = (gradWW[i] * descWW[i]) >= 0
                PAS[i] = (test * a + np.invert(test) * b) * PAS[i]
                PAS[i] = (PAS[i] <= OVF[i]) * PAS[i] + (PAS[i] >
                                                        OVF[i]) * OVF[i]
                # + voir aussi ci-dessous

                #Correction des poids (apres ajustement des pas) -----
                WW[i] = WW[i] - PAS[i] * descWW[i]

            # Check weights variation (stop criterium 2) ---------
            if np.remainder(nbite, weivar_freq) == 0:
                bougeti = 0.0
                for i in np.arange(nbc - 1):
                    Wi = WW[i]
                    abswi = abs(Wi)
                    abswi = np.reshape(abswi, nbWW[i])
                    maxwione = np.max((oneWW[i], abswi), axis=0)
                    maxwione = np.reshape(maxwione, sizWW[i])
                    maxi = max(np.max(abs(WWpp[i] - Wi) / maxwione, axis=0))
                    if maxi > bougeti:
                        bougeti = maxi
                if bougeti < weivar_seuil:  # est-ce le bon critère d'arret : les poids peuvent bouger
                    continumongars = 0
                    # relativement à leur propre valeur sans que cela n'impact
                    print('pmctrain: Weight variation threshold stop\n')
                    # plus la fonction de cout significativement !!!???
                    # (nécessite de plus la sauvegarde des matrices)
                sbougeti = "%6.6f" % (bougeti)

            # Régulierement, tous les 1000 it (arbitrairement?) on se repositionnne et on redistribue
            # le jeu ... (sachant qu'avec d'autres valeurs (100, 2000) ca fait une différence)
            if np.remainder(nbite, 1000) == 0:
                # Sauvegarde des Poids (pour apprecier leur variation (cf bougeti) entre 1000 !? itérations)
                WWpp = np.copy(WW)
                # Moyennage des Pas (pour redistribuer le jeu)
                for i in np.arange(nbc):
                    PAS[i] = np.mean(PAS[i]) * np.ones(sizWW[i]) * 2 / ell

            # fin back propagation--------------------------------

        # Affichage ----------------------------------------------
        if  (dprint>0 or dcurves>0) \
        and (np.remainder(nballit,dfreq)==0 or continumongars==0 or nballit>=nbitemax) :

            # D'abord les prints (si demandés)
            if dprint:
                # L'entête
                if np.remainder(nballit, 20 * dfreq) == dfreq:
                    print("|    #epoch  |  Errorapp  ", end='')
                    if dval:
                        print("|   ErrqVal  ", end='')
                    if dperf:
                        print("| PerfApp ", end='')
                        if dval:
                            print("| Perfval ", end='')
                    print("| bougeti  |")

                # Les valeurs
                print("|%5d %5d | %10.6f " % (nballit, nbite, errnew), end='')
                if dval:
                    print("| %10.6f " % (errqval), end='')
                if dperf:
                    print("|  %.4f " % (perfapp), end='')
                    if dval:
                        print("|  %.4f " % (perfval), end='')
                print("| %8s |" % sbougeti)

            # Ensuite la figure des courbes (si demandée)
            if dcurves > 0:  # Si courbes demandée(s)
                absci = np.arange(np.size(errtot)) + 1
                # abscisse pour les plots
                # Plot des perf (si y'en a)
                if np.size(perftot) > 0:  # Si y'a des perf
                    plt.subplot(2, 1, 2)
                    plt.plot(absci, perftot, '.-b')
                    if dval:
                        plt.plot(absci, perftotval, '.-r')
                    plt.subplot(2, 1, 1)
                    # pour les erreurs (rem: pas de perf => pas de subplot)
                # Plot des erreurs
                if dcurves == 1:
                    plt.plot(absci, errtot, '.-b')
                else:
                    plt.plot(absci, np.log(errtot), '.-b')
                if dval:
                    if dcurves == 1:
                        plt.plot(absci, errtotval, '.-r')
                    else:
                        plt.plot(absci, np.log(errtotval), '.-r')
                plt.axis("tight")
                plt.draw()

    # fin du while : fin de la boucle principale d'Apprentissage
    #=============================================================
    # Last Actions --------------------------------------------
    print("ctrback=", ctrback)
    print("pmctrain: %d iterations done" % nballit)
    #
    if dcurves > 0:  # Ornemantation à la fin seulement
        plt.xlabel("x %d Epoch" % tfp)
        if np.size(perftot) > 0:  # Si y'a des perf
            plt.subplot(2, 1, 2)
            plt.ylabel("Perf: % Err. de classif.")
            plt.axis("tight")
            plt.subplot(2, 1, 1)
            # pour les erreurs (rem: pas de perf => pas de subplot)
        if dcurves == 1:
            plt.ylabel("$\sum$ Errors Quadratiques")
        else:
            plt.ylabel("log($\sum$ Errors Quadratiques)")
        plt.axis("tight")
        #if valid_on :
        if dval:
            plt.legend(["Learning", "Validating"])
        else:
            plt.legend(["Learning"])
        plt.show()

    # Last computations for return values
    if valid_on:
        np.save("pmcendwei", WWp)
        # Sauvegarde des poids en fin d'apprentissage
        return WWMV, it_minval
    else:
        return WWp