def plot_variance_vs_position(result_dir='variance_vs_position', network_name='vae_net', var_max=5.0):
    """Plot latent-code magnitude, normalized loss, and a blended novelty score
    for each of the 100 wind-tunnel agent positions, and dump side-by-side
    input/reconstruction images.

    Args:
        result_dir: directory that receives the plot and reconstructions
            (created if missing).
        network_name: checkpoint basename; restored from './<name>.ckpt'.
        var_max: novelty value assigned where normalized loss is 1
            (novelty = var_max * loss + |mu_z| * (1 - loss)).

    Side effects: restores the VAE session from the checkpoint, writes
    'variance_plot.png' and a 'reconstructions/' directory under result_dir.
    """
    import network  # local import: importing builds the TF graph/session as a side effect

    if not os.path.isdir(result_dir):
        os.mkdir(result_dir)
    game = wind_tunnel.WindTunnel()
    encodings = []
    recons = []
    images = []
    losses = []
    network.saver.restore(network.sess, './' + network_name + '.ckpt')
    for i in tqdm.tqdm(range(100)):
        # Place the agent at position i and render that state.
        game.agent = i
        game.generate_new_state()
        image = np.reshape(np.array(game.get_current_state()), [-1, 84, 84, 1])
        images.append(image)
        # Network expects images scaled to [0, 1].
        [encoding_mean, recon, loss] = network.sess.run(
            [network.mu_z, network.mu_x, network.loss],
            feed_dict={network.inp_image: image / 255.})
        encodings.append(encoding_mean)
        losses.append(loss)
        recons.append(recon)
    losses = np.array(losses)
    # Min-max normalize losses to [0, 1]. Guard the division: if every loss is
    # identical, the max after min-subtraction is 0 and the original code
    # divided by zero, producing NaNs in the novelty plot.
    losses = losses - np.min(losses)
    max_loss = np.max(losses)
    if max_loss > 0:
        losses = losses / max_loss
    encodings = np.concatenate(encodings, 0)
    # Euclidean norm of each latent mean ("variance" in the plot's sense).
    variances = np.sqrt(np.sum(encodings ** 2, axis=1))
    novelty = var_max * losses + variances * (1. - losses)
    f, ax = plt.subplots()
    ax.plot(range(100), variances, 'blue')
    ax.plot(range(100), losses, 'red')
    ax.plot(range(100), novelty, 'green')
    f.savefig(os.path.join(result_dir, 'variance_plot.png'))
    side_by_side_dir = os.path.join(result_dir, 'reconstructions')
    if not os.path.isdir(side_by_side_dir):
        os.mkdir(side_by_side_dir)
    for i in tqdm.tqdm(range(100)):
        # '%05d' zero-pads the index (the original '%5d' space-padded, which
        # created filenames with leading spaces and broke lexicographic order).
        vh.display_side_by_side(os.path.join(side_by_side_dir, '%05d' % i),
                                images[i][0], recons[i][0])
def setup_wind_tunnel_env():
    """Build a fresh WindTunnel environment and report its action-space size.

    Returns:
        (environment, action_count) — the WindTunnel instance and the number
        of actions available (queried with a None state).
    """
    environment = wind_tunnel.WindTunnel()
    action_count = len(environment.get_actions_for_state(None))
    return environment, action_count
import numpy as np
import wind_tunnel
import toy_mr

# Module-level wind-tunnel instance, created as a side effect of import.
game = wind_tunnel.WindTunnel()


class RandomAgent(object):
    """Agent that chooses uniformly at random among the legal actions."""

    def get_action(self, game):
        """Return a random action for *game*'s current state."""
        actions = game.get_actions_for_state(None)
        return np.random.choice(actions)


def get_batch_func(batch_size, game, agent):
    """Roll *game* forward *batch_size* steps under *agent*, collecting states.

    Resets the environment whenever a terminal state is reached, then lets the
    agent act and records the first element of each resulting state.

    Args:
        batch_size: number of environment steps / states to collect.
        game: environment exposing is_current_state_terminal /
            reset_environment / perform_action / get_current_state.
        agent: object with a get_action(game) method (e.g. RandomAgent).

    Returns:
        np.ndarray stacking the collected states.
    """
    states = []
    for _ in range(batch_size):  # step index itself is unused
        if game.is_current_state_terminal():
            game.reset_environment()
        action = agent.get_action(game)
        game.perform_action(action)
        states.append(game.get_current_state()[0])
    return np.array(states)


def setup_toy_mr_env():
    """Build a ToyMR environment from the full map and report its action count.

    Returns:
        (env, num_actions) — the ToyMR instance and its action-space size.
    """
    env = toy_mr.ToyMR('../mr_maps/full_mr_map.txt')
    num_actions = len(env.get_actions_for_state(None))
    return env, num_actions


# Module-level convenience globals, built at import time.
toy_mr_env, toy_mr_num_actions = setup_toy_mr_env()