import os

import torch
from torch.utils.data import TensorDataset, DataLoader

# `args`, `input_dim`, `device`, and the project-local ReplayBuffer are
# defined/imported earlier in this script.

num_actions = 25

# Select raw-data buffers with or without demographic context.
if args.dem_context:
    train_buffer_file = '/scratch/ssd001/home/tkillian/ml4h2020_srl/raw_data_buffers/train_buffer'
    validation_buffer_file = '/scratch/ssd001/home/tkillian/ml4h2020_srl/raw_data_buffers/val_buffer'
else:
    train_buffer_file = '/scratch/ssd001/home/tkillian/ml4h2020_srl/raw_data_buffers/train_noCntxt_buffer'
    validation_buffer_file = '/scratch/ssd001/home/tkillian/ml4h2020_srl/raw_data_buffers/val_noCntxt_buffer'

storage_dir = '/scratch/ssd001/home/tkillian/ml4h2020_srl/BehavCloning/' + args.storage_folder + '/'
if not os.path.exists(storage_dir):
    os.mkdir(storage_dir)

# Initialize and load the training and validation buffers to populate dataloaders
train_buffer = ReplayBuffer(input_dim, args.batch_size, 200000, device)
train_buffer.load(train_buffer_file)
states = train_buffer.state[:train_buffer.crt_size]
actions = train_buffer.action[:train_buffer.crt_size]
train_dataset = TensorDataset(torch.from_numpy(states).float(),
                              torch.from_numpy(actions).long())
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

val_buffer = ReplayBuffer(input_dim, args.batch_size, 50000, device)
val_buffer.load(validation_buffer_file)
val_states = val_buffer.state[:val_buffer.crt_size]
val_actions = val_buffer.action[:val_buffer.crt_size]
val_dataset = TensorDataset(torch.from_numpy(val_states).float(),
                            torch.from_numpy(val_actions).long())
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
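# Illustrative only: a minimal behavioral-cloning loop over the loaders above.
# The policy network here is a hypothetical stand-in (not part of this repo);
# it simply shows how the (state, action) batches would be consumed. `input_dim`,
# `num_actions`, `device`, and `train_dataloader` come from the script above.
import torch.nn as nn
import torch.optim as optim

policy = nn.Sequential(nn.Linear(input_dim, 64), nn.ReLU(),
                       nn.Linear(64, num_actions)).to(device)
optimizer = optim.Adam(policy.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(5):  # illustrative epoch count
    for batch_states, batch_actions in train_dataloader:
        batch_states = batch_states.to(device)
        # CrossEntropyLoss expects class indices of shape (B,); view(-1)
        # handles actions stored as either (B,) or (B, 1).
        batch_actions = batch_actions.to(device).view(-1)
        loss = criterion(policy(batch_states), batch_actions)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()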
import os

import numpy as np

# DDQN, NeuralNet, ReplayBuffer, and LinearSchedule are project-local modules.


class DDQNLearner(DDQN):
    def __init__(self,
                 env,
                 save_dirs,
                 save_freq=10000,
                 gamma=0.99,
                 batch_size=32,
                 learning_rate=0.0001,
                 buffer_size=10000,
                 learn_start=10000,
                 target_network_update_freq=1000,
                 train_freq=4,
                 epsilon_min=0.01,
                 exploration_fraction=0.1,
                 tot_steps=int(1e7)):
        DDQN.__init__(self, env=env, save_dirs=save_dirs, learning_rate=learning_rate)
        self.gamma = gamma
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.buffer_size = buffer_size
        self.learn_start = learn_start
        self.target_network_update_freq = target_network_update_freq
        self.train_freq = train_freq
        self.epsilon_min = epsilon_min
        self.exploration_fraction = exploration_fraction
        self.tot_steps = tot_steps

        # Epsilon is annealed linearly from 1.0 to epsilon_min over the first
        # exploration_fraction of total training steps.
        self.epsilon = 1.0
        self.exploration = LinearSchedule(
            schedule_timesteps=int(self.exploration_fraction * self.tot_steps),
            initial_p=self.epsilon,
            final_p=self.epsilon_min)

        self.save_freq = save_freq
        self.replay_buffer = ReplayBuffer(save_dirs=save_dirs,
                                          buffer_size=self.buffer_size,
                                          obs_shape=self.input_shape)

        self.exploration_factor_save_path = os.path.join(self.save_path,
                                                         'exploration-factor.npz')
        self.target_model_save_path = os.path.join(self.save_path, 'target-wts.h5')

        # Separate target network, periodically synced to the online (local) network.
        self.target_model = NeuralNet(input_shape=self.input_shape,
                                      num_actions=self.num_actions,
                                      learning_rate=learning_rate,
                                      blueprint=self.blueprint).model

        self.show_hyperparams()
        self.update_target()
        self.load()

    def update_exploration(self, t):
        self.epsilon = self.exploration.value(t)

    def update_target(self):
        self.target_model.set_weights(self.local_model.get_weights())

    def remember(self, obs, action, rew, new_obs, done):
        self.replay_buffer.add(obs, action, rew, new_obs, done)

    def step_update(self, t):
        hist = None
        if t <= self.learn_start:
            return hist
        if t % self.train_freq == 0:
            hist = self.learn()
        if t % self.target_network_update_freq == 0:
            self.update_target()
        return hist

    def act(self, obs):
        # Epsilon-greedy action selection over the online network's Q-values.
        if np.random.rand() < self.epsilon:
            return self.env.action_space.sample()
        q_vals = self.local_model.predict(
            np.expand_dims(obs, axis=0).astype(float) / 255, batch_size=1)
        return np.argmax(q_vals[0])

    def learn(self):
        if self.replay_buffer.meta_data['fill_size'] < self.batch_size:
            return
        curr_obs, action, reward, next_obs, done = self.replay_buffer.get_minibatch(
            self.batch_size)

        # Start from the online network's predictions and overwrite only the
        # entries for the actions actually taken.
        target = self.local_model.predict(curr_obs.astype(float) / 255,
                                          batch_size=self.batch_size)
        done_mask = done.ravel()
        undone_mask = np.invert(done).ravel()

        # Terminal transitions: the target is the immediate reward.
        target[done_mask, action[done_mask].ravel()] = reward[done_mask].ravel()

        # Non-terminal transitions: bootstrap from the target network.
        Q_target = self.target_model.predict(next_obs.astype(float) / 255,
                                             batch_size=self.batch_size)
        Q_future = np.max(Q_target[undone_mask], axis=1)
        target[undone_mask, action[undone_mask].ravel()] = (
            reward[undone_mask].ravel() + self.gamma * Q_future)

        hist = self.local_model.fit(curr_obs.astype(float) / 255,
                                    target,
                                    batch_size=self.batch_size,
                                    verbose=0).history
        return hist

    def load_mdl(self):
        super().load_mdl()
        if os.path.isfile(self.target_model_save_path):
            self.target_model.load_weights(self.target_model_save_path)
            print('Loaded Target Model...')
        else:
            print('No existing Target Model found...')

    def save_mdl(self):
        self.local_model.save_weights(self.local_model_save_path)
        print('Local Model Saved...')
        self.target_model.save_weights(self.target_model_save_path)
        print('Target Model Saved...')

    def save_exploration(self):
        np.savez(self.exploration_factor_save_path, exploration=self.epsilon)
        print('Exploration Factor Saved...')

    def load_exploration(self):
        if os.path.isfile(self.exploration_factor_save_path):
            with np.load(self.exploration_factor_save_path) as f:
                # np.asscalar was removed from NumPy; .item() is the
                # supported equivalent.
                self.epsilon = f['exploration'].item()
            print('Exploration Factor Loaded...')
        else:
            print('No existing Exploration Factor found...')

    def save(self, t, logger):
        ep = logger.data['episode']
        if (self.save_freq is not None and t > self.learn_start and ep > 100
                and t % self.save_freq == 0):
            if logger.update_best_score():
                logger.save_state()
                self.save_mdl()
                self.save_exploration()
                self.replay_buffer.save()

    def load(self):
        self.load_mdl()
        self.load_exploration()
        self.replay_buffer.load()

    def show_hyperparams(self):
        print('Discount Factor (gamma): {}'.format(self.gamma))
        print('Batch Size: {}'.format(self.batch_size))
        print('Replay Buffer Size: {}'.format(self.buffer_size))
        print('Training Frequency: {}'.format(self.train_freq))
        print('Target network update Frequency: {}'.format(
            self.target_network_update_freq))
        print('Replay start size: {}'.format(self.learn_start))
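# A minimal sketch of how this learner could be driven, assuming a Gym-style
# environment and a logger like the one referenced by `save()`. `make_env`,
# `logger`, and `save_dirs` are illustrative names, not part of this class.
#
# env = make_env(...)
# agent = DDQNLearner(env=env, save_dirs=save_dirs)
# obs = env.reset()
# for t in range(agent.tot_steps):
#     agent.update_exploration(t)          # anneal epsilon
#     action = agent.act(obs)              # epsilon-greedy action
#     new_obs, rew, done, _ = env.step(action)
#     agent.remember(obs, action, rew, new_obs, done)
#     agent.step_update(t)                 # train / sync target as scheduled
#     agent.save(t, logger)                # periodic checkpointing
#     obs = env.reset() if done else new_obs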