class Actor:
    """Chooses moves for a game from the output of a neural network.

    The actor builds a ``Network`` whose layer list is the game's state size,
    the requested hidden ``layers`` and the game's number of possible moves,
    with a softmax output function, and uses its predictions to pick a move.
    """

    def __init__(self, game, layers=None, checkpoint=None, format='one_hot',
                 optimizer='adam'):
        """Build the network and optionally restore it from a checkpoint.

        Args:
            game: Object providing ``state_size(format)``,
                ``num_possible_moves()``, ``get_moves(state)`` and
                ``format_for_nn(state, format=...)``.
            layers: Sequence of sizes inserted between the input size and the
                output size. ``None`` (the default) means no extra layers.
            checkpoint: If truthy, passed to ``load_checkpoint`` after the
                network has been built.
            format: State format name forwarded to the game.
            optimizer: Optimizer name forwarded to ``Network``.
        """
        # A ``None`` sentinel instead of ``[]``: a list literal in the
        # signature is created once and shared by every call.
        if layers is None:
            layers = []

        self.game = game
        self.format = format
        self.layers = layers
        self.optimizer = optimizer

        # list(...) lets callers pass any sequence (e.g. a tuple) of sizes.
        layer_sizes = (
            [game.state_size(format)]
            + list(layers)
            + [game.num_possible_moves()]
        )
        self.network = Network(
            layer_sizes,
            [],
            minibatch_size=50,
            steps=1,
            loss_function='cross_entropy',
            validation_fraction=0,
            test_fraction=0,
            learning_rate=0.001,
            optimizer=optimizer,
            output_functions=[tf.nn.softmax]
        )
        self.network.build()

        if checkpoint:
            self.load_checkpoint(checkpoint)

    def select_move(self, state, stochastic=False):
        """Return one of the game's possible moves for ``state``.

        Args:
            state: Game state; formatted with ``game.format_for_nn`` before
                being fed to the network.
            stochastic: When false, return the move with the highest
                prediction. When true, sample a move with probability
                proportional to its prediction (uniformly if all kept
                predictions sum to zero).

        Returns:
            An element of ``game.get_moves(state)``.
        """
        possible_moves = self.game.get_moves(state)
        formatted_state = self.game.format_for_nn(state, format=self.format)
        predictions = self.network.predict([formatted_state])[0]

        # Only the first len(possible_moves) outputs are considered.
        # NOTE(review): this presumes output i corresponds to
        # possible_moves[i] -- confirm against the game's move encoding.
        predictions = np.asarray(predictions[:len(possible_moves)])

        if not stochastic:
            return possible_moves[np.argmax(predictions)]

        candidates = np.arange(len(predictions))
        total = predictions.sum()
        if total == 0:
            # Nothing to weight by: fall back to a uniform draw.
            move = np.random.choice(candidates)
        else:
            # Renormalise because slicing may have removed probability mass.
            move = np.random.choice(candidates, p=predictions / total)
        return possible_moves[move]

    def save_checkpoint(self, checkpoint):
        """Save the network by calling ``network.save(checkpoint)``."""
        self.network.save(checkpoint)

    def load_checkpoint(self, checkpoint):
        """Load the network by calling ``network.load(checkpoint)``."""
        self.network.load(checkpoint)
from data.dataset_generator import Generator, number_to_index
from nn.network import Network

# Script: build a network, generate a three-image test set, load previously
# trained parameters and print the percentage of correct answers.

# The digits are limited to 6, 5 and 2, which (per the original author's
# note) correspond to three initials.
numbers = [6, 5, 2]
layers = [784, 100, 3]

label_indices = number_to_index(numbers)
net = Network(layers, label_indices)

# To create a dataset of 60000 images instead (50000 for training and 10000
# for testing), restricted to the same digits:
# generator = Generator(50000, 10000, numbers, 28)

# Create the dataset with 3 test images corresponding to the three initials,
# as numbers.
test_image_count = 3
generator = Generator(0, test_image_count, numbers, 28, True)
dataset = generator.gen()

# To train:
# net.grad_descent(dataset[:50000], 200, 10, 3.0, test_data=dataset[50000:])

# Load a network that was trained previously.
biases_file = 'pre-trained-data/biases.npy'
weights_file = 'pre-trained-data/weights.npy'
net.load(biases_file, weights_file)

# Percentage of correct answers.
hits = net.evaluate(dataset)
accuracy_percent = (hits / test_image_count) * 100
print(accuracy_percent)