def ProcessModel(self, request, context):
    """Apply one Q-learning update from a transition sent by a client.

    Unpickles ``request.state`` and ``request.next_state``, builds the
    target Q-vector for ``request.action`` and fits the module-level
    ``model`` on it for a single epoch. Also advances the module-level
    ``server_fits`` counter and, when ``request.done`` is set,
    ``episodes_processed``.

    Args:
        request: message exposing ``state``, ``next_state`` (pickled bytes),
            ``reward``, ``action`` and ``done``.
        context: call context supplied by the RPC framework; unused here.

    Returns:
        An empty ``mpm_pb2.Empty`` acknowledgement.
    """
    global server_fits
    global episodes_processed

    if server_fits == 0:
        print("Client found, fitting model")
    elif server_fits % 100 == 0 and not verbose:
        # Sanity check every 100 fits.
        print(f"{server_fits}th fit")

    # SECURITY: pickle.loads can execute arbitrary code. Only expose this
    # service to trusted clients, or switch to a safe array serialization.
    state = pickle.loads(request.state)
    next_state = pickle.loads(request.next_state)

    # Q-value for the action taken: reward plus discounted best next value.
    target = request.reward + gamma * np.max(model.predict(next_state))
    # Q-values currently predicted for every action in `state`.
    target_vec = model.predict(state)[0]
    # Only the taken action's entry is moved toward the new target.
    target_vec[request.action] = target
    # NOTE(review): the head of this call was missing in the file; it is
    # restored as a fit on `state` -- confirm against the original intent.
    model.fit(state,
              target_vec.reshape(-1, action_count),
              epochs=1,
              verbose=0)
    server_fits += 1

    if request.done:
        print("Done fitting model for current episode")
        episodes_processed += 1

    return mpm_pb2.Empty()
def SaveModel(self, request, context):
    """Persist the model and, optionally, the reward plot.

    The model is written to ``model_file``. Unless ``request.model_only`` is
    set, the reward plot is also written to ``figure_file`` once at least
    three episodes have been processed.

    Args:
        request: message exposing the boolean ``model_only`` flag.
        context: call context supplied by the RPC framework; unused here.

    Returns:
        An empty ``mpm_pb2.Empty`` acknowledgement.
    """
    print(f"Saving model to {model_file}")
    # NOTE(review): the save call itself was missing after the log message;
    # restored here -- confirm `model.save` is the intended persistence call.
    model.save(model_file)

    if not request.model_only and episodes_processed >= 3:
        print(f"Saving reward plot to {figure_file}")
        plot_reward(rewards, figure_file)

    return mpm_pb2.Empty()
def DropClient(self, request, context):
    """Log that the client has ended training and acknowledge the call.

    Neither ``request`` nor ``context`` is inspected.

    Returns:
        An empty ``mpm_pb2.Empty`` message.
    """
    farewell = "Client terminated training"
    print(farewell)
    acknowledgement = mpm_pb2.Empty()
    return acknowledgement
def _to_model_input(observation):
    """Return *observation* as an array with leading and trailing unit axes."""
    frame = np.asarray(observation)
    return frame.reshape((1, ) + frame.shape + (1, ))


def main():
    """Run ``n_episodes`` training episodes against ``env``.

    In ``Run.client`` mode, predictions and model updates are delegated to
    the remote service through ``stub``. In ``Run.local`` mode the
    module-level ``model`` is queried and fitted in-process. After every
    episode the model is saved and ``epsilon`` is decayed while it is above
    ``min_epsilon``. When all episodes are finished the environment is
    closed and the reward plot is produced (locally, or by asking the
    server), after which a client notifies the server that it is leaving.
    """
    global epsilon
    global episodes_processed

    if verbose:
        print("\nAction Space: ", env.action_space)
        print("Action Meanings: \n", env.get_action_meanings())

    for i in range(n_episodes):
        print("Episode:", i)
        if verbose:
            print()

        state = _to_model_input(env.reset())
        done = False
        total_reward = 0
        always_noop = False

        while not done:
            if render:
                env.render()

            action, action_type, always_noop = get_action(
                always_noop, epsilon, env.action_space.sample())

            if run_as == Run.client:
                # An action of -1 means the model should choose the action.
                if action == -1:
                    action_response = stub.PredictAction(
                        mpm_pb2.StateRequest(state=pickle.dumps(state)))
                    action = action_response.action

                next_state, current_reward, done, info = env.step(action)
                next_state = _to_model_input(next_state)

                # The action must be sent too: ProcessModel updates the
                # Q-value at request.action, which would otherwise always be
                # the message's default value.
                stub.ProcessModel(
                    mpm_pb2.ModelRequest(state=pickle.dumps(state),
                                         next_state=pickle.dumps(next_state),
                                         action=int(action),
                                         reward=current_reward,
                                         done=done))
            elif run_as == Run.local:
                # An action of -1 means the model should choose the action.
                if action == -1:
                    action = np.argmax(model.predict(state))

                next_state, current_reward, done, info = env.step(action)
                next_state = _to_model_input(next_state)

                # Q-value for action
                target = current_reward + gamma * np.max(
                    model.predict(next_state))
                # Array of Q-values for all actions
                target_vec = model.predict(state)[0]
                # Change actions value to be target for fitting
                target_vec[action] = target
                # NOTE(review): the head of this call was missing in the
                # file; restored as a fit on `state` -- confirm.
                model.fit(state,
                          target_vec.reshape(-1, action_count),
                          epochs=1,
                          verbose=0)

            total_reward += current_reward

            if verbose:
                print("EP %i. ACTION: %9s%7s | REWARD: %4i | LIVES: %d" %
                      (episodes_processed, env.get_action_meanings()[action],
                       action_type, current_reward, info.get('ale.lives')))

            state = next_state

            if done:
                rewards.append(total_reward)
                if verbose:
                    print()
                print(f"Reward: {total_reward}\n")

        # Checkpoint the model after every episode.
        if run_as == Run.local:
            # NOTE(review): this branch was empty in the file; restored as a
            # local model save -- confirm.
            model.save(model_file)
        else:
            stub.SaveModel(mpm_pb2.SaveRequest(model_only=True))

        if render:
            env.render()

        episodes_processed += 1

        if epsilon > min_epsilon:
            epsilon *= decay
            print("Decayed epsilon to", epsilon)

    env.close()

    if run_as == Run.local:
        if n_episodes >= 3:
            plot_reward(rewards, figure_file)
    elif run_as == Run.client:
        stub.SaveModel(mpm_pb2.SaveRequest(model_only=False))
        stub.DropClient(mpm_pb2.Empty())
# Script entry: run training and, on Ctrl-C, offer to save progress before
# exiting.
try:
    main()
except KeyboardInterrupt:
    print("\nKEYBOARD INTERRUPT")
    try:
        wants_save = False
        if episodes_processed > 0:
            save = input("Save model data? [y/n] ")
            wants_save = save.strip().lower() == 'y'

        if run_as == Run.local:
            if wants_save:
                env.close()
                print(f"Saving model to {model_file}")
                # NOTE(review): the save call was missing after the log
                # message; restored here -- confirm.
                model.save(model_file)
                if episodes_processed >= 3:
                    print(f"Saving reward plot to {figure_file}")
                    plot_reward(rewards, figure_file)
        elif run_as == Run.client:
            if wants_save:
                env.close()
                stub.SaveModel(mpm_pb2.SaveRequest(model_only=False))
            # Always tell the server this client is leaving, whether or not
            # anything was saved.
            stub.DropClient(mpm_pb2.Empty())

        sys.exit(0)
    except SystemExit:
        # Hard exit without running interpreter cleanup.
        os._exit(0)