# Relies on module-level imports of gym, numpy as np, torch, and torch.nn.functional as F.
class ICMWrapper(gym.Env):
    def __init__(self, env, network, device=0, obs_key=None, hist_size=3000,
                 reward_func=None, **kwargs):
        '''
        params
        ======
        env (gym.Env) : environment to wrap
        '''
        # from surprise.envs.vizdoom.networks import VAEConv
        # from surprise.envs.vizdoom.buffer import VAEBuffer
        from surprise.envs.vizdoom.buffer import SimpleBuffer
        from surprise.envs.vizdoom.networks import MLP, VizdoomFeaturizer
        # from rlkit.torch.networks import Mlp
        from torch import optim
        from gym import spaces

        self.device = device
        self.env = env
        self._obs_key = obs_key
        self._reward_func = reward_func

        # Gym spaces
        self.action_space = env.action_space
        self.observation_space = env.observation_space

        # ICM networks and replay buffer
        print("hist_size: ", hist_size)
        self._buffer = SimpleBuffer(device=self.device, size=hist_size)
        if kwargs["network_type"] == "flat":
            self.forward_network = MLP([
                self.observation_space.low.size + self.action_space.n, 64,
                self.observation_space.low.size
            ]).to(self.device)
            self.inverse_network = MLP([
                self.observation_space.low.size * 2, 32, 16,
                self.action_space.n
            ], log_softmax=True).to(self.device)
            self.featurizer = MLP([
                self.observation_space.low.size, 64,
                self.observation_space.low.size
            ]).to(self.device)
        else:
            self.forward_network = MLP([
                self.observation_space.low.size + self.action_space.n, 64, 64,
                self.observation_space.low.size
            ]).to(self.device)
            self.inverse_network = MLP([
                self.observation_space.low.size * 2, 32, 16, 8,
                self.action_space.n
            ], log_softmax=True).to(self.device)
            self.featurizer = VizdoomFeaturizer(
                self.observation_space.low.size).to(self.device)
        self.optimizer_forward = optim.Adam(self.forward_network.parameters(),
                                            lr=1e-4)
        self.optimizer_inverse = optim.Adam(
            list(self.inverse_network.parameters()) +
            list(self.featurizer.parameters()), lr=1e-4)
        self.network = self.forward_network
        self.step_freq = 16
        self.inverse_loss = torch.zeros(1)
        self.forward_loss = torch.zeros(1)
        self.loss = torch.zeros(1)

    def step(self, action):
        # Start the (s, a, s') tuple with the previous observation and a
        # one-hot encoding of the action
        if self._obs_key is None:
            data = [
                np.array(self._prev_obs),
                np.eye(self.action_space.n)[action], None
            ]
        else:
            data = [
                np.array(self._prev_obs[self._obs_key]),
                np.eye(self.action_space.n)[action], None
            ]

        # Take action
        obs, rew, done, info = self.env.step(action)

        # Finish off the (s, a, s') tuple and add it to the buffer
        if self._obs_key is None:
            data[2] = np.array(obs)
        else:
            data[2] = np.array(obs[self._obs_key])
        self._buffer.add(tuple(data))

        # Update the ICM networks
        if self._time % self.step_freq == 0:
            self.step_model()

        # Get wrapper outputs
        obs = self.encode_obs(obs)
        info["task_reward"] = rew
        info.update(self.get_info())
        done = self.get_done(done)
        self._time = self._time + 1
        rew = self.get_rews(data)
        if self._reward_func == "add":
            rew = rew + info["task_reward"]
        self._prev_obs = obs
        return obs, rew, done, info

    def step_model(self, batch_size=64):
        # Set phase
        self.forward_network.train()
        self.inverse_network.train()
        self.featurizer.train()

        # Get data (s, a, s')
        s, a, sp = self._buffer.sample(batch_size)
        s = torch.tensor(s).to(self.device).float()
        a = torch.tensor(a).to(self.device).float()
        sp = torch.tensor(sp).to(self.device).float()

        # Featurize
        fs, fsp = self.featurizer(s), self.featurizer(sp)

        # Inverse model: predict the action from (phi(s), phi(s'))
        action_pred = self.inverse_network(torch.cat([fs, fsp], dim=1))

        # Forward model: predict phi(s') from (phi(s), a); the detach keeps
        # the forward loss from updating the featurizer
        state_pred = self.forward_network(torch.cat([fs, a], dim=1).detach())

        # Losses, saved as class attributes so they can be passed to info
        self.inverse_loss = F.nll_loss(action_pred, torch.argmax(a, dim=1))
        self.forward_loss = F.mse_loss(state_pred, fsp)
        self.loss = self.inverse_loss + self.forward_loss

        # Step
        self.optimizer_forward.zero_grad()
        self.optimizer_inverse.zero_grad()
        self.loss.backward()
        self.optimizer_forward.step()
        self.optimizer_inverse.step()

    def get_obs(self, obs):
        '''
        Augment observation, perhaps with generative model params
        '''
        return obs

    def get_info(self):
        return {
            'inverse_loss': self.inverse_loss.item(),
            'forward_loss': self.forward_loss.item(),
            'loss': self.loss.item(),
        }

    def get_done(self, env_done):
        '''
        Figure out if we're done

        params
        ======
        env_done (bool) : done bool from the wrapped env, doesn't necessarily
                          need to be used
        '''
        return env_done

    def reset(self):
        '''
        Reset the wrapped env and the buffer
        '''
        self._time = 0
        # Can't add the first observation to the buffer without an action
        self._prev_obs = self.env.reset()
        self._prev_obs = self.encode_obs(self._prev_obs)
        return self._prev_obs

    def render(self):
        self.env.render()

    def get_rews(self, sas):
        # Set phase
        self.featurizer.eval()
        self.forward_network.eval()

        # Unpack
        s, a, sp = sas

        # Convert to tensor
        s = torch.tensor(s).float().to(self.device).unsqueeze(0)
        a = torch.tensor(a).float().to(self.device).unsqueeze(0)
        sp = torch.tensor(sp).float().to(self.device).unsqueeze(0)

        # Featurize
        fs, fsp = self.featurizer(s), self.featurizer(sp)

        # Forward model
        state_pred = self.forward_network(torch.cat([fs, a], dim=1))

        # The intrinsic reward is the forward-model prediction error
        return F.mse_loss(state_pred, fsp).item()

    def encode_obs(self, obs):
        '''
        Used to encode the observation before putting it on the buffer
        '''
        return obs
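# The sketch below is not part of the original code: it is a minimal example
# of how ICMWrapper might be driven, assuming a discrete-action gym
# environment with flat observations (e.g. CartPole-v1) and the class above
# in scope. network_type="flat" selects the MLP branch of the constructor;
# the `network` argument is stored but not used by the wrapper itself.
def _icm_usage_example():
    import gym
    env = gym.make("CartPole-v1")
    wrapped = ICMWrapper(env, network=None, device="cpu",
                         reward_func="add", network_type="flat")
    obs = wrapped.reset()
    for _ in range(32):
        action = wrapped.action_space.sample()
        # rew is the forward-model prediction error plus the task reward
        # (because reward_func="add"); the ICM losses are reported via info
        obs, rew, done, info = wrapped.step(action)
        if done:
            obs = wrapped.reset()
    return info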
# Relies on module-level imports of gym, numpy as np, torch, and torch.nn.functional as F.
class RNDWrapper(gym.Env):
    def __init__(self, env, network, device=0, obs_key=None, hist_size=5000,
                 reward_func=None, **kwargs):
        '''
        params
        ======
        env (gym.Env) : environment to wrap
        '''
        # from surprise.envs.vizdoom.networks import VAEConv
        # from surprise.envs.vizdoom.buffer import VAEBuffer
        from surprise.envs.vizdoom.buffer import SimpleBuffer
        from surprise.envs.vizdoom.networks import VizdoomFeaturizer
        from rlkit.torch.networks import Mlp
        from torch import optim
        from gym import spaces

        self.device = device
        self.env = env
        self._obs_key = obs_key
        self._reward_func = reward_func

        # Gym spaces
        self.action_space = env.action_space
        self.observation_space = env.observation_space

        # RND networks and replay buffer: a fixed random target network and a
        # trained predictor network
        self._buffer = SimpleBuffer(device=self.device, size=hist_size)
        if kwargs["network_type"] == "flat":
            self.target_net = Mlp(
                hidden_sizes=[128, 64],
                input_size=self.observation_space.low.size,
                output_size=64,
            ).to(self.device)
            self.target_net.eval()
            self.pred_net = Mlp(
                hidden_sizes=[128, 64, 32],
                input_size=self.observation_space.low.size,
                output_size=64,
            ).to(self.device)
        else:
            self.target_net = VizdoomFeaturizer(kwargs["encoding_size"]).to(
                self.device)
            self.target_net.eval()
            self.pred_net = VizdoomFeaturizer(kwargs["encoding_size"]).to(
                self.device)
        self.optimizer = optim.Adam(self.pred_net.parameters(), lr=1e-4)
        self.network = self.pred_net
        self.step_freq = 16
        self.loss = torch.zeros(1)

    def step(self, action):
        # Take action
        obs, rew, done, info = self.env.step(action)

        # Build the (observation, target-feature) pair and add it to the buffer
        if self._obs_key is None:
            target = self.target_net(
                torch.tensor(obs).float().unsqueeze(0).to(self.device))
            data = [obs, target.detach().cpu().numpy()[0]]
        else:
            target = self.target_net(
                torch.tensor(obs[self._obs_key]).float().unsqueeze(0).to(
                    self.device))
            data = [obs[self._obs_key], target.detach().cpu().numpy()[0]]
        self._buffer.add(tuple(data))

        # Update the predictor network
        if self._time % self.step_freq == 0:
            self.step_model()

        # Get wrapper outputs
        obs = self.encode_obs(obs)
        info["rnd_loss"] = self.loss.item()
        info["task_reward"] = rew
        done = self.get_done(done)
        self._time = self._time + 1
        rew = self.get_rews(data)
        if self._reward_func == "add":
            rew = rew + info["task_reward"]
        return obs, rew, done, info

    def step_model(self, batch_size=64):
        # Set phase
        self.pred_net.train()

        # Get a batch of (observation, target-feature) pairs
        data, target = self._buffer.sample(batch_size)
        data = torch.tensor(data).to(self.device).float()
        target = torch.tensor(target).to(self.device).float()

        # Predictor forward pass
        pred = self.pred_net(data)

        # Loss, saved as a class attribute so it can be passed to info
        self.loss = F.mse_loss(pred, target)

        # Step
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()

    def get_obs(self, obs):
        '''
        Augment observation, perhaps with generative model params
        '''
        return obs

    def get_done(self, env_done):
        '''
        Figure out if we're done

        params
        ======
        env_done (bool) : done bool from the wrapped env, doesn't necessarily
                          need to be used
        '''
        return env_done

    def reset(self):
        '''
        Reset the wrapped env and the buffer
        '''
        self._time = 0
        obs = self.env.reset()
        if self._obs_key is None:
            target = self.target_net(
                torch.tensor(obs).float().unsqueeze(0).to(self.device))
            data = [obs, target.detach().cpu().numpy()[0]]
        else:
            target = self.target_net(
                torch.tensor(obs[self._obs_key]).float().unsqueeze(0).to(
                    self.device))
            data = [obs[self._obs_key], target.detach().cpu().numpy()[0]]
        self._buffer.add(tuple(data))
        obs = self.encode_obs(obs)
        return obs

    def render(self):
        self.env.render()

    def get_rews(self, data):
        # Set phase
        self.pred_net.eval()

        # Convert to tensor
        data, target = data
        data = torch.tensor(data).float().to(self.device).unsqueeze(0)
        target = torch.tensor(target).float().to(self.device).unsqueeze(0)

        # Predictor forward pass
        pred = self.pred_net(data)

        ### TODO: Add reward scaling via continuous std from RND paper
        # The intrinsic reward is the predictor's error against the fixed target
        return F.mse_loss(pred, target).item()

    def encode_obs(self, obs):
        '''
        Used to encode the observation before putting it on the buffer
        '''
        return obs
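# Sketch (not from the repo) of the reward scaling left as a TODO in
# get_rews() above: the RND paper divides the intrinsic reward by a running
# estimate of its standard deviation. RunningRewardScaler is an illustrative
# name; Welford's algorithm keeps the running variance numerically stable.
# One scaler could be constructed in __init__ and applied to the value that
# get_rews() currently returns raw.
import numpy as np

class RunningRewardScaler:
    def __init__(self, eps=1e-8):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0   # sum of squared deviations from the running mean
        self.eps = eps

    def update(self, x):
        # Welford update of mean and second moment
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def scale(self, rew):
        # Divide by the running std (not mean-centered, as in the RND paper)
        self.update(rew)
        std = np.sqrt(self.m2 / max(self.count - 1, 1)) + self.eps
        return rew / std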
# Relies on module-level imports of gym, gym.spaces, numpy as np, torch,
# torch.nn.functional as F, torch.optim as optim, cv2, vizdoom as vzd, sleep,
# imsave, SimpleBuffer, VizdoomFeaturizer, and a module-level `device`.
class TakeCoverEnv_RND(gym.Env):
    # def __init__(self, render=False, config_path='/home/dangengdg/minimalEntropy/tree_search/vizdoom/scenarios/take_cover.cfg', god=False, respawn=True):
    def __init__(
            self,
            render=False,
            config_path='/home/gberseth/playground/BayesianSurpriseCode/surprise/envs/vizdoom/scenarios/take_cover.cfg',
            god=False,
            respawn=True):
        # Start game
        self.game = vzd.DoomGame()

        # Set sleep time (for visualization)
        self.sleep_time = 0
        self.game.set_window_visible(render)
        self.sleep_time = .02 * int(render)

        # Game configs
        self.game.load_config(config_path)
        self.game.set_screen_resolution(vzd.ScreenResolution.RES_640X480)
        self.game.set_render_screen_flashes(True)  # Effect upon taking damage
        self.episode_length = 1000
        self.skiprate = 2
        self.game.set_episode_timeout(self.episode_length * self.skiprate)

        # Initialize the game
        self.game.init()

        # Actions are left or right
        self.actions = [[True, False], [False, True]]

        # Env variables
        self.done = False
        self.time = 1
        self.downsample_factor = .02
        self.obs_hist = [self.get_random_state(res=(48, 64)) for _ in range(4)]
        self.god = god
        self.respawn = respawn
        self.deaths = 0
        self.fireball = 0

        # Spaces
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(0, self.episode_length,
                                            shape=(4, 48, 64))

        # RND stuff
        self.buffer = SimpleBuffer(device=device)
        self.target_net = VizdoomFeaturizer(64).to(device)
        self.target_net.eval()
        self.pred_net = VizdoomFeaturizer(64).to(device)
        self.optimizer = optim.Adam(self.pred_net.parameters(), lr=1e-4)
        self.step_freq = 8
        self.loss = torch.zeros(1)

        self.reset()

    def get_random_state(self, res):
        '''
        Get a random, gaussian state (roughly the average state)
        '''
        return .3 + np.random.randn(*res) / 100

    def encode(self, obs):
        '''
        Encodes a (lowres) buffer observation
        '''
        return obs

    def reset_game(self):
        # New episode
        self.game.new_episode()

        # Set invincible
        if self.god:
            self.game.send_game_command('god')

    def reset(self):
        self.reset_game()

        # Env variables
        self.done = False
        self.time = 1
        self.downsample_factor = .02
        self.obs_hist = [self.get_random_state(res=(48, 64)) for _ in range(4)]
        self.deaths = 0
        self.fireball = 0

        # Losses
        self.loss = torch.zeros(1)

        return self.get_obs()

    def _render(self):
        state = self.game.get_state()
        return state.screen_buffer.mean(0) / 256.0

    def render(self):
        state = self.game.get_state()
        state = state.screen_buffer.transpose(1, 2, 0)
        imsave('./rollouts/vizdoom-nogod/{:03d}.png'.format(self.time), state)

    def render_obsres(self):
        img = self._render()
        img = cv2.resize(img, (0, 0), fx=.1, fy=.1,
                         interpolation=cv2.INTER_AREA)
        return img

    def render_lowres(self):
        img = self._render()
        img = cv2.resize(img, (0, 0), fx=self.downsample_factor,
                         fy=self.downsample_factor,
                         interpolation=cv2.INTER_AREA)
        return img

    def get_obs(self):
        # We can reshape in the network
        img_obs = np.array(self.obs_hist).flatten()
        return img_obs

    def get_rews(self, data):
        # Set phase
        self.pred_net.eval()

        # Convert to tensor
        data, target = data
        data = torch.tensor(data).float().to(device).unsqueeze(0)
        target = torch.tensor(target).float().to(device).unsqueeze(0)

        # Predictor forward pass
        pred = self.pred_net(data)

        # The intrinsic reward is the predictor's error against the fixed target
        return F.mse_loss(pred, target).item()

    def get_info(self):
        return {
            'deaths': self.deaths,
            'loss': self.loss.item(),
            'fireball': self.fireball
        }

    def step(self, action):
        # Take action with skiprate
        r = self.game.make_action(self.actions[action], self.skiprate)

        # If died, then return the saved previous observation and reward
        if self.game.is_episode_finished():
            if self.respawn:
                self.deaths += 1
                self.reset_game()
            else:
                self.done = True
            return self.prev_obs, self.prev_rew, self.done, self.get_info()

        # For visualization
        if self.sleep_time > 0:
            sleep(self.sleep_time)

        # Increment time
        self.time += 1

        # If episode finished, set done
        if self.time == self.episode_length:
            self.done = True

        # Add to obs hist
        self.obs_hist.append(self.render_obsres())
        if len(self.obs_hist) > 4:
            self.obs_hist.pop(0)

        # Build the (observation, target-feature) pair and add it to the buffer
        data = self.get_obs().reshape(4, 48, 64)
        target = self.target_net(
            torch.tensor(data).float().unsqueeze(0).to(device))
        data = [data, target.cpu().detach().numpy()[0]]
        self.buffer.add(tuple(data))

        # We need to save these b/c the doom env is weird:
        # after dying, we can't get any observations
        self.prev_obs = self.get_obs()
        self.prev_rew = self.get_rews(data)

        # Update network
        if self.time % self.step_freq == 0:
            self.step_net()

        # Update fireball var
        self.fireball += int(self.game.get_state().screen_buffer.mean() > 120)

        return self.prev_obs, self.prev_rew, self.done, self.get_info()

    def step_net(self, batch_size=64):
        # Set phase
        self.pred_net.train()

        # Get a batch of (observation, target-feature) pairs
        data, target = self.buffer.sample(batch_size)
        data = torch.tensor(data).to(device).float()
        target = torch.tensor(target).to(device).float()

        # Predictor forward pass
        pred = self.pred_net(data)

        # Loss, saved as a class attribute so it can be passed to info
        self.loss = F.mse_loss(pred, target)

        # Step
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
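# Sketch (not from the repo) of rolling out TakeCoverEnv_RND with a random
# policy. It assumes ViZDoom is installed, config_path points at a valid
# take_cover.cfg, and the module-level `device` used by the class is defined.
def _take_cover_random_rollout(config_path, episodes=1):
    env = TakeCoverEnv_RND(render=False, config_path=config_path,
                           god=False, respawn=True)
    stats = []
    for _ in range(episodes):
        obs, total_rew, done = env.reset(), 0.0, False
        while not done:
            obs, rew, done, info = env.step(env.action_space.sample())
            total_rew += rew
        # info carries 'deaths', 'fireball', and the current predictor loss
        stats.append((total_rew, info['deaths'], info['fireball']))
    env.game.close()
    return stats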
# Relies on module-level imports of gym, gym.spaces, numpy as np, torch,
# torch.nn.functional as F, torch.optim as optim, cv2, vizdoom as vzd, sleep,
# PIL.Image, SimpleBuffer, VizdoomFeaturizer, and a module-level `device`.
class DefendTheLineEnv_RND(gym.Env):
    def __init__(
            self,
            render=False,
            config_path='/home/dangengdg/minimalEntropy/tree_search/vizdoom/scenarios/defend_the_line.cfg',
            god=True,
            respawn=True,
            skill=3,
            augment_obs=False,
            true_rew=False,
            joint_rew=False):
        # def __init__(self, render=False, config_path='/home/daniel/Documents/minimalEntropy/tree_search/vizdoom/scenarios/defend_the_line.cfg', god=True, respawn=True, skill=3, augment_obs=True, true_rew=False):
        # Start game
        self.game = vzd.DoomGame()

        # Set sleep time (for visualization)
        self.sleep_time = 0
        self.game.set_window_visible(render)
        self.sleep_time = .02 * int(render)

        # Game configs
        self.game.load_config(config_path)
        self.game.set_screen_resolution(vzd.ScreenResolution.RES_640X480)
        self.game.set_render_screen_flashes(True)  # Effect upon taking damage
        self.game.set_doom_skill(skill)
        self.episode_length = 1000
        self.skiprate = 2
        self.game.set_episode_timeout(self.episode_length * self.skiprate)

        # Initialize the game
        self.game.init()

        # One-hot action list over the game's available buttons
        self.actions = [
            list(x) for x in np.eye(
                self.game.get_available_buttons_size()).astype(bool)
        ]

        # Env variables
        self.done = False
        self.time = 1
        self.downsample_factor = .02
        self.obs_hist = [self.get_random_state(res=48 * 64) for _ in range(4)]
        self.god = god
        self.respawn = respawn
        self.deaths = 0
        self.true_rew = true_rew
        self.joint_rew = joint_rew

        # Spaces
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(0, self.episode_length,
                                            shape=(12289, ))

        # RND stuff
        self.buffer = SimpleBuffer(device=device)
        feature_dim = 32
        self.target_net = VizdoomFeaturizer(feature_dim, qf=True).to(device)
        self.target_net.eval()
        self.predictor_net = VizdoomFeaturizer(feature_dim, qf=True).to(device)
        self.optimizer = optim.Adam(self.predictor_net.parameters(), lr=1e-4)
        self.step_freq = 8
        self.loss = torch.zeros(1)

        self.reset()

    def get_random_state(self, res):
        '''
        Get a random, gaussian state (roughly the average state)
        '''
        return .3 + np.random.randn(res) / 100

    def reset_game(self):
        # New episode
        self.game.new_episode()

        # Set invincible
        if self.god:
            self.game.send_game_command('god')

    def reset(self):
        self.reset_game()

        # Env variables
        self.done = False
        self.time = 1
        self.downsample_factor = .02
        self.obs_hist = [self.get_random_state(res=48 * 64) for _ in range(4)]
        self.deaths = 0

        # Losses
        self.loss = torch.zeros(1)

        return self.get_obs()

    def _render(self):
        state = self.game.get_state()
        return state.screen_buffer.mean(0) / 256.0

    def render(self):
        try:
            state = self.game.get_state()
            state = state.screen_buffer.transpose(1, 2, 0)
            im = Image.fromarray(state)
            im.save('./rollouts/{:04d}.png'.format(self.time))
        except:
            pass

    def render_obsres(self):
        img = self._render()
        img = cv2.resize(img, (0, 0), fx=.1, fy=.1,
                         interpolation=cv2.INTER_AREA)
        return img

    def get_obs(self):
        # We can reshape in the network; the current time is appended as the
        # last entry of the observation
        img_obs = np.array(self.obs_hist).flatten()
        return np.hstack([img_obs, self.time])

    def get_rews(self, state):
        # Set phase
        self.predictor_net.eval()

        # Convert to tensor
        state = torch.tensor(state).float().to(device).unsqueeze(0)

        # Target and predictor forward passes
        target = self.target_net(state).detach()
        pred = self.predictor_net(state)

        # Add data to buffer
        self.buffer.add(tuple([state, target]))

        # The intrinsic reward is the predictor's error against the fixed
        # target, plus a small constant bonus when joint_rew is set
        if self.joint_rew:
            return F.mse_loss(pred, target).item() + 5e-7
        else:
            return F.mse_loss(pred, target).item()

    def get_info(self):
        return {
            'lifespan': self.time,
            'deaths': self.deaths,
            'kills': self.game.get_state().game_variables[2],
            'loss': self.loss.item()
        }

    def step(self, action):
        # Only shoot if we have enough bullets, else take a random move action
        if action == 2:
            if self.game.get_state().game_variables[0] == 0:
                action = np.random.randint(2)

        # Get info before action, otherwise breaks on death
        info = self.get_info()

        # Take action with skiprate
        r = self.game.make_action(self.actions[action], self.skiprate)

        # If died, then return the saved previous observation and reward
        if self.game.is_episode_finished():
            if self.respawn:
                self.deaths += 1
                self.reset_game()
            else:
                self.done = True
            return self.prev_obs, self.prev_rew, self.done, info

        # For visualization
        if self.sleep_time > 0:
            sleep(self.sleep_time)

        # Increment time
        self.time += 1

        # If episode finished, set done
        if self.time == self.episode_length:
            self.done = True

        # Add to obs hist
        self.obs_hist.append(self.render_obsres().flatten())
        if len(self.obs_hist) > 4:
            self.obs_hist.pop(0)

        # We need to save these b/c the doom env is weird:
        # after dying, we can't get any observations
        self.prev_obs = self.get_obs()

        # Get rew and add data to buffer
        self.prev_rew = self.get_rews(np.array(self.obs_hist))

        # Update network
        if self.time % self.step_freq == 0:
            self.step_net()

        return self.prev_obs, self.prev_rew, self.done, info

    def step_net(self, batch_size=64):
        # Set phase
        self.predictor_net.train()

        # Get a batch of (state, target) pairs
        state, target = self.buffer.sample(batch_size)
        state = torch.cat(state, 0)
        target = torch.cat(target, 0)

        # Zero grad
        self.optimizer.zero_grad()

        # Predictor forward pass
        pred = self.predictor_net(state)

        # Loss, saved as a class attribute so it can be passed to info
        self.loss = F.mse_loss(pred, target)

        # Step
        self.loss.backward()
        self.optimizer.step()
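# Sketch (not from the repo): a short evaluation loop for DefendTheLineEnv_RND
# that logs the episode statistics exposed through get_info(). It assumes
# ViZDoom, a valid defend_the_line.cfg, and the module-level `device` are
# available; joint_rew=True adds the small 5e-7 bonus to each step's intrinsic
# reward, as in get_rews() above.
def _defend_the_line_eval(config_path, steps=200):
    env = DefendTheLineEnv_RND(render=False, config_path=config_path,
                               god=True, respawn=True, joint_rew=True)
    obs = env.reset()
    intrinsic = 0.0
    for _ in range(steps):
        obs, rew, done, info = env.step(env.action_space.sample())
        intrinsic += rew
        if done:
            obs = env.reset()
    # 'lifespan', 'deaths', and 'kills' come from the info dict captured
    # before the last action
    print("intrinsic return:", intrinsic,
          "deaths:", info['deaths'], "kills:", info['kills'])
    env.game.close()
    return info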