#!/bin/env python3.5
#Author: Saurabh Pathak
from matplotlib import pyplot as pl
from mpl_toolkits.mplot3d import Axes3D
from numpy import genfromtxt, matrix, zeros, exp, tanh
from ann import ANN

cols = {0, 1, 4}
data = genfromtxt('data/iris.data.train', delimiter=',', converters={4: lambda x: 0. if x == b'Iris-setosa' else 1. if x == b'Iris-versicolor' else 2.}, usecols=cols)
dataset, y = matrix(data[:,:2]), data[:,2]

def bitmapper():
    y_new = matrix(zeros((y.shape[0], 3), 'float64'))
    for i in range(y.size): y_new[i, y[i]] = 1
    return y_new.T

y = bitmapper()
print(y)

#nn = ANN(lambda x: 1 / (1 + exp(-x)), (dataset.shape[1], 4, 5, 4, 3))
nn = ANN(tanh, (dataset.shape[1], 7, 3))
nn.learn(dataset, y)