Beispiel #1
0
def main_circles():
    """Train a small classifier on the circles dataset and plot the results.

    Builds ``CirclesData``, shuffles the training set once (NumPy seed 42, so
    the batch order is reproducible), then runs ``Nepochs`` epochs of
    mini-batch training with the model, loss and optimizer returned by
    ``init_model``. After every epoch the test loss/accuracy are computed and,
    together with the loss/accuracy of the *last* training mini-batch of that
    epoch, passed to ``data.plot_loss`` and logged to TensorBoard through a
    ``SummaryWriter``. Finally the predicted class probabilities over
    ``data.Xgrid`` are plotted and the function blocks on ``input`` so the
    figures stay open.
    """
    # --- data and hyper-parameters -------------------------------------
    data = CirclesData()
    data.plot_data()
    np.random.seed(42)
    N = data.Xtrain.shape[0]
    inds = np.arange(0, N)
    np.random.shuffle(inds)
    Xtrain = data.Xtrain[inds]
    Ytrain = data.Ytrain[inds]
    Nbatch = 15
    nx = data.Xtrain.shape[1]
    nh = 10
    ny = data.Ytrain.shape[1]
    eta = 0.03

    model, loss, optim = init_model(nx, nh, ny, eta)

    writer = SummaryWriter()
    # Values reported for the epoch if no mini-batch runs (e.g. N == 0).
    L, acc = 0, 0

    # --- training ------------------------------------------------------
    Nepochs = 200
    try:
        for epoch in range(Nepochs):
            # Consecutive mini-batches of size Nbatch; the last may be smaller.
            for j in range(0, N, Nbatch):
                Xbatch = Xtrain[j:j + Nbatch]
                Ybatch = Ytrain[j:j + Nbatch]
                Yhat = model(Xbatch)
                L, acc = loss_accuracy(loss, Yhat, Ybatch)
                # Reset accumulated gradients, backpropagate, update weights.
                optim.zero_grad()
                L.backward()
                optim.step()

            # Loss and accuracy on the test set; evaluation only, so there is
            # no need to record an autograd graph.
            with torch.no_grad():
                Yhat_test = model(data.Xtest)
                L_test, acc_test = loss_accuracy(loss, Yhat_test, data.Ytest)

            data.plot_loss(L, L_test, acc, acc_test)

            # The train values come from the last mini-batch of the epoch.
            writer.add_scalar("loss/train", float(L), epoch)
            writer.add_scalar("loss/test", float(L_test), epoch)
            writer.add_scalar("accuracy/train", float(acc), epoch)
            writer.add_scalar("accuracy/test", float(acc_test), epoch)
    finally:
        # Flush pending events and release the log file even on error.
        writer.close()

    # --- decision surface ----------------------------------------------
    with torch.no_grad():
        Ygrid = torch.nn.Softmax(dim=1)(model(data.Xgrid))
    data.plot_data_with_grid(Ygrid.detach())

    # Wait for a key press so the figures stay open.
    input("done")
Beispiel #2
0
    # Fragment of a training routine whose `def` line is not visible here;
    # nx, nh, ny, N, Nbatch, eta, Xtrain, Ytrain and data are presumably
    # defined earlier in that function — TODO confirm.
    # First tests, code to be modified (manual-update variant: init_model
    # returns only the model and the loss, and sgd() applies the update).
    model, loss = init_model(nx, nh, ny)

    # NOTE(review): `writer` is created but never used in the visible code.
    writer = SummaryWriter()
    # Values reported by plot_loss if no mini-batch runs (e.g. N == 0).
    L, acc = 0, 0

    # TODO training
    Nepochs = 200
    for i in range(Nepochs):

        # Consecutive mini-batches of size Nbatch; the last may be smaller.
        for j in range(0, N, Nbatch):
            Xbatch = Xtrain[j:j + Nbatch]
            Ybatch = Ytrain[j:j + Nbatch]
            Yhat = model(Xbatch)
            L, acc = loss_accuracy(loss, Yhat, Ybatch)
            # Compute the gradients
            # NOTE(review): no gradient reset is visible before backward();
            # unless sgd() zeroes the gradients itself, they accumulate
            # across batches — verify sgd.
            L.backward()
            # NOTE(review): the returned `params` is never used below.
            params = sgd(model, eta)

        # Loss and Accuracy on Test
        # (runs with autograd enabled; torch.no_grad() would avoid building
        # a graph here)
        Yhat_test = model(data.Xtest)
        L_test, acc_test = loss_accuracy(loss, Yhat_test, data.Ytest)

        # L / acc are those of the last training mini-batch of the epoch,
        # not an epoch average.
        data.plot_loss(L, L_test, acc, acc_test)

    # Class probabilities over data.Xgrid for the decision-surface plot.
    Ygrid = torch.nn.Softmax(dim=1)(model(data.Xgrid))
    data.plot_data_with_grid(Ygrid.detach())

    # wait for a key press to keep the figures open
    input("done")