def train_model(path='/home/jht00622/wiki_new.pkl', batch_size=50,
                learn_rate=0.001, num_steps=50000):
    """Train the age network on randomly drawn mini-batches.

    The train / validation / test splits are all read in one go from the
    file at ``path`` through ``load_data``; image arrays are divided by
    255.0 before use.  Each step draws ``batch_size`` distinct random
    training samples and calls ``net.learn`` on them.  Mini-batch metrics
    are printed every 20 steps, validation and test metrics every 100.

    NOTE: the complete training set is held in memory.  If it is too large
    for that, load the validation/test splits first and the training data
    separately.

    :param path: location of the data file handed to ``load_data``.
    :param batch_size: number of samples per training step.
    :param learn_rate: learning rate forwarded to ``Network``.
    :param num_steps: number of training steps to run.
    :return: None; progress is only reported through ``print``.
    """
    data = load_data(path)

    # Divide by a float literal: under Python 2 (without
    # ``from __future__ import division``) an integer array divided by the
    # int 255 is floor-divided, which would wipe out almost every pixel.
    train_dataset = data['train_dataset'] / 255.0
    train_age_labels = data['train_age_labels']

    valid_dataset = data['valid_dataset'] / 255.0
    valid_age_labels = data['valid_age_labels']

    test_dataset = data['test_dataset'] / 255.0
    test_age_labels = data['test_age_labels']

    height = 128   # passed as n_length; presumably the image side -- TODO confirm
    channel = 1
    n_output = 4   # presumably the number of age classes -- TODO confirm
    total_size = train_dataset.shape[0]

    net = Network(n_output=n_output,
                  n_length=height,
                  learning_rate=learn_rate,
                  batch_size=batch_size,
                  channel=channel,
                  output_graph=False,
                  use_ckpt=False)

    for i in range(num_steps):
        # Fresh random subset (without replacement) of the training set.
        indices = np.random.permutation(total_size)[:batch_size]
        batch_x = train_dataset[indices, :, :, :]
        batch_y = train_age_labels[indices, :]
        net.learn(batch_x, batch_y)

        if i % 20 == 0:
            cost, accu_rate = net.get_accuracy_rate(batch_x, batch_y)
            print("Iteration: %i. Train loss %.5f, Minibatch accuracy:"
                  " %.1f%%" % (i, cost, accu_rate))

        if i % 100 == 0:
            cost, accu_rate = net.get_accuracy_rate(valid_dataset,
                                                    valid_age_labels)
            print("Iteration: %i. Validation loss %.5f, Validation accuracy:"
                  " %.1f%%" % (i, cost, accu_rate))
            cost, accu_rate = net.get_accuracy_rate(test_dataset,
                                                    test_age_labels)
            print("Iteration: %i. Test loss %.5f, Test accuracy:"
                  " %.1f%%" % (i, cost, accu_rate))
def train_model(path='/home/hengtong/project/age_gender/data/small/wiki_new.pkl',
                batch_size=128, learn_rate=0.01, epochs=400):
    """Train the age network epoch by epoch and plot the accuracy curves.

    The train / validation / test splits are all read in one go from the
    file at ``path`` through ``load_data``; image arrays are divided by
    255.0 before use.  Every epoch reshuffles the training indices and
    walks through ``total_size // batch_size`` full mini-batches (samples
    left over after the last full batch are skipped for that epoch).

    Every 50 training steps the mini-batch, validation and test metrics are
    printed and their accuracies recorded; every 500 steps
    ``net.save_parameters()`` is called.  After training, the cost curve is
    plotted via ``net.plot_cost()`` and the recorded accuracies are written
    to ``age.png``.

    NOTE: the complete training set is held in memory.  If it is too large
    for that, load the validation/test splits first and the training data
    separately.

    :param path: location of the data file handed to ``load_data``.
    :param batch_size: number of samples per training step.
    :param learn_rate: learning rate forwarded to ``Network``.
    :param epochs: maximum number of passes over the training set.
    :return: None; results are printed, checkpointed and plotted.
    """
    data = load_data(path)

    # Divide by a float literal: under Python 2 (without
    # ``from __future__ import division``) an integer array divided by the
    # int 255 is floor-divided, which would wipe out almost every pixel.
    train_dataset = data['train_dataset'] / 255.0
    train_age_labels = data['train_age_labels']

    valid_dataset = data['valid_dataset'] / 255.0
    valid_age_labels = data['valid_age_labels']

    test_dataset = data['test_dataset'] / 255.0
    test_age_labels = data['test_age_labels']

    height = 128   # passed as n_length; presumably the image side -- TODO confirm
    channel = 1
    n_output = 4   # age mode
    total_size = train_dataset.shape[0]

    net = Network(
        n_output=n_output,
        n_length=height,
        learning_rate=learn_rate,
        batch_size=batch_size,
        channel=channel,
        output_graph=False,
        use_ckpt=False
    )

    # Number of full mini-batches per epoch.
    iteration = total_size // batch_size
    print(iteration)

    i = 1  # global training-step counter, runs across epochs
    accu_train_age = []
    accu_valid_age = []
    accu_test_age = []
    early_stop = 0  # epochs seen so far that ended with 100% batch accuracy
    train_rate_age = 0

    for e in range(epochs):
        print("-------------------------------")
        print("epoch %d" % (e + 1))

        # New random order of the training samples for this epoch.
        indices = np.random.permutation(total_size)

        for ite in range(iteration):
            mini_indices = indices[ite * batch_size:(ite + 1) * batch_size]
            batch_x = train_dataset[mini_indices, :, :, :]
            batch_y_age = train_age_labels[mini_indices, :]
            net.learn(batch_x, batch_y_age)

            if i % 50 == 0:
                cost, train_rate_age = net.get_accuracy_rate(batch_x,
                                                             batch_y_age)
                print("Iteration: %i. Train loss %.5f, Minibatch gen accuracy:"
                      " %.1f%%" % (i, cost, train_rate_age))
                accu_train_age.append(train_rate_age)

                cost, valid_rate_age = net.get_accuracy_rate(valid_dataset,
                                                             valid_age_labels)
                print("Iteration: %i. Validation loss %.5f, Validation gen accuracy:"
                      " %.1f%%" % (i, cost, valid_rate_age))
                accu_valid_age.append(valid_rate_age)

                cost, test_rate_age = net.get_accuracy_rate(test_dataset,
                                                            test_age_labels)
                print("Iteration: %i. Test loss %.5f, Test gen accuracy:"
                      " %.1f%%" % (i, cost, test_rate_age))
                accu_test_age.append(test_rate_age)

            if i % 500 == 0:
                net.save_parameters()

            i = i + 1

        # Early stopping, checked once per epoch against the most recently
        # measured mini-batch accuracy.  The counter is never reset, so it
        # counts such epochs cumulatively rather than consecutively.
        # NOTE(review): placement at epoch level (so that ``break`` leaves
        # the epoch loop) is assumed -- confirm against the original layout.
        if train_rate_age == 100:
            if early_stop == 10:
                print("Early Stopping!")
                break
            else:
                early_stop = early_stop + 1

    net.plot_cost()  # plot training cost

    # Plot accuracy.  Each point is one evaluation (taken every 50 training
    # steps), so the x axis counts evaluations even though it is labelled
    # 'epoch'.
    plt.figure()
    plt.plot(np.arange(len(accu_train_age)), accu_train_age,
             label='train age', linestyle='--')
    plt.plot(np.arange(len(accu_valid_age)), accu_valid_age,
             label='valid age', linestyle='-')
    plt.plot(np.arange(len(accu_test_age)), accu_test_age,
             label='test age', linestyle=':')
    plt.ylabel('age accuracy')
    plt.xlabel('epoch')
    plt.legend(loc='lower right')
    plt.grid()
    plt.savefig('age.png')