class DDQNLearner(DDQN):
    """Double-DQN agent with an epsilon-greedy policy and experience replay.

    Maintains a local (online) network inherited from ``DDQN`` and a separate
    target network that is periodically synchronized with the local weights.
    Observations are assumed to be uint8 images; they are scaled to [0, 1]
    by dividing by 255 before every forward pass.
    """

    def __init__(self, env, save_dirs, save_freq=10000, gamma=0.99,
                 batch_size=32, learning_rate=0.0001, buffer_size=10000,
                 learn_start=10000, target_network_update_freq=1000,
                 train_freq=4, epsilon_min=0.01, exploration_fraction=0.1,
                 tot_steps=int(1e7)):
        """Build the learner, its target network and replay buffer.

        Args:
            env: Gym-style environment (provides ``action_space.sample()``).
            save_dirs: Directories used by the base class / buffer for state.
            save_freq: Steps between checkpoint attempts (``None`` disables).
            gamma: Discount factor for bootstrapped targets.
            batch_size: Minibatch size used by ``learn``.
            learning_rate: Optimizer learning rate for both networks.
            buffer_size: Replay buffer capacity.
            learn_start: Steps of pure exploration before training begins.
            target_network_update_freq: Steps between target-net syncs.
            train_freq: Steps between gradient updates.
            epsilon_min: Final exploration rate.
            exploration_fraction: Fraction of ``tot_steps`` over which epsilon
                is annealed linearly from 1.0 to ``epsilon_min``.
            tot_steps: Total planned environment steps (schedule horizon).
        """
        DDQN.__init__(self, env=env, save_dirs=save_dirs,
                      learning_rate=learning_rate)
        self.gamma = gamma
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.buffer_size = buffer_size
        self.learn_start = learn_start
        self.target_network_update_freq = target_network_update_freq
        self.train_freq = train_freq
        self.epsilon_min = epsilon_min
        self.exploration_fraction = exploration_fraction
        self.tot_steps = tot_steps
        # Epsilon starts fully exploratory and is annealed by the schedule.
        self.epsilon = 1.0
        self.exploration = LinearSchedule(
            schedule_timesteps=int(self.exploration_fraction * self.tot_steps),
            initial_p=self.epsilon,
            final_p=self.epsilon_min)
        self.save_freq = save_freq
        self.replay_buffer = ReplayBuffer(save_dirs=save_dirs,
                                          buffer_size=self.buffer_size,
                                          obs_shape=self.input_shape)
        self.exploration_factor_save_path = os.path.join(
            self.save_path, 'exploration-factor.npz')
        self.target_model_save_path = os.path.join(self.save_path,
                                                   'target-wts.h5')
        # Target network mirrors the local network's architecture (blueprint).
        self.target_model = NeuralNet(input_shape=self.input_shape,
                                      num_actions=self.num_actions,
                                      learning_rate=learning_rate,
                                      blueprint=self.blueprint).model
        self.show_hyperparams()
        self.update_target()
        self.load()

    def update_exploration(self, t):
        """Set epsilon from the linear schedule at global step ``t``."""
        self.epsilon = self.exploration.value(t)

    def update_target(self):
        """Copy local-network weights into the target network."""
        self.target_model.set_weights(self.local_model.get_weights())

    def remember(self, obs, action, rew, new_obs, done):
        """Store one transition in the replay buffer."""
        self.replay_buffer.add(obs, action, rew, new_obs, done)

    def step_update(self, t):
        """Run the per-step training hooks at global step ``t``.

        Returns the Keras ``history`` dict from ``learn`` when a gradient
        step was taken this step, otherwise ``None``. No training happens
        until ``t`` exceeds ``learn_start``.
        """
        hist = None
        if t <= self.learn_start:
            return hist
        if t % self.train_freq == 0:
            hist = self.learn()
        if t % self.target_network_update_freq == 0:
            self.update_target()
        return hist

    def act(self, obs):
        """Epsilon-greedy action selection for a single observation."""
        if np.random.rand() < self.epsilon:
            return self.env.action_space.sample()
        q_vals = self.local_model.predict(
            np.expand_dims(obs, axis=0).astype(float) / 255, batch_size=1)
        return np.argmax(q_vals[0])

    def learn(self):
        """Sample a minibatch and take one gradient step on the local net.

        Terminal transitions get target = reward; non-terminal ones get
        reward + gamma * max_a' Q_target(s', a'). Returns the fit history
        dict, or ``None`` if the buffer is not yet full enough.
        """
        if self.replay_buffer.meta_data['fill_size'] < self.batch_size:
            return
        curr_obs, action, reward, next_obs, done = \
            self.replay_buffer.get_minibatch(self.batch_size)
        # Start from the local net's predictions so that only the taken
        # action's Q-value receives a non-zero TD error.
        target = self.local_model.predict(curr_obs.astype(float) / 255,
                                          batch_size=self.batch_size)
        done_mask = done.ravel()
        undone_mask = np.invert(done).ravel()
        target[done_mask, action[done_mask].ravel()] = reward[done_mask].ravel()
        Q_target = self.target_model.predict(next_obs.astype(float) / 255,
                                             batch_size=self.batch_size)
        Q_future = np.max(Q_target[undone_mask], axis=1)
        target[undone_mask, action[undone_mask].ravel(
        )] = reward[undone_mask].ravel() + self.gamma * Q_future
        hist = self.local_model.fit(curr_obs.astype(float) / 255,
                                    target,
                                    batch_size=self.batch_size,
                                    verbose=0).history
        return hist

    def load_mdl(self):
        """Load local weights (via base class) and target weights from disk."""
        super().load_mdl()
        if os.path.isfile(self.target_model_save_path):
            self.target_model.load_weights(self.target_model_save_path)
            print('Loaded Target Model...')
        else:
            print('No existing Target Model found...')

    def save_mdl(self):
        """Persist both local and target network weights."""
        self.local_model.save_weights(self.local_model_save_path)
        print('Local Model Saved...')
        self.target_model.save_weights(self.target_model_save_path)
        print('Target Model Saved...')

    def save_exploration(self):
        """Persist the current epsilon so training can resume mid-schedule."""
        np.savez(self.exploration_factor_save_path, exploration=self.epsilon)
        print('Exploration Factor Saved...')

    def load_exploration(self):
        """Restore epsilon from disk if a saved value exists."""
        if os.path.isfile(self.exploration_factor_save_path):
            with np.load(self.exploration_factor_save_path) as f:
                # np.asscalar was removed in NumPy 1.23; float() extracts the
                # scalar from the saved 0-d array equivalently.
                self.epsilon = float(f['exploration'])
            print('Exploration Factor Loaded...')
        else:
            print('No existing Exploration Factor found...')

    def save(self, t, logger):
        """Checkpoint everything when due and the best score improved.

        Saving is considered every ``save_freq`` steps, only after training
        has started (``t > learn_start``) and at least 100 episodes have run.
        """
        ep = logger.data['episode']
        if (self.save_freq is not None and t > self.learn_start and ep > 100
                and t % self.save_freq == 0):
            if logger.update_best_score():
                logger.save_state()
                self.save_mdl()
                self.save_exploration()
                self.replay_buffer.save()

    def load(self):
        """Restore models, exploration factor and replay buffer from disk."""
        self.load_mdl()
        self.load_exploration()
        self.replay_buffer.load()

    def show_hyperparams(self):
        """Print the key hyperparameters for this run."""
        print('Discount Factor (gamma): {}'.format(self.gamma))
        print('Batch Size: {}'.format(self.batch_size))
        print('Replay Buffer Size: {}'.format(self.buffer_size))
        print('Training Frequency: {}'.format(self.train_freq))
        print('Target network update Frequency: {}'.format(
            self.target_network_update_freq))
        print('Replay start size: {}'.format(self.learn_start))
class Experiment(object):
    """End-to-end harness: train a sequence autoencoder on sepsis ICU data,
    encode all data into the learned representation space, then train and
    evaluate a discrete-BCQ policy on the encoded transitions.

    Supported autoencoder backends: 'AIS', 'AE', 'DST', 'DDM', 'RNN',
    'CDE', 'ODERNN'.
    """

    def __init__(self, domain, train_data_file, validation_data_file,
                 test_data_file, minibatch_size, rng, device,
                 behav_policy_file_wDemo, behav_policy_file,
                 context_input=False, context_dim=0,
                 drop_smaller_than_minibatch=True, folder_name='/Name',
                 autoencoder_saving_period=20, resume=False,
                 sided_Q='negative', autoencoder_num_epochs=50,
                 autoencoder_lr=0.001, autoencoder='AIS', hidden_size=16,
                 ais_gen_model=1, ais_pred_model=1, embedding_dim=4,
                 state_dim=42, num_actions=25, corr_coeff_param=10,
                 dst_hypers=None, cde_hypers=None, odernn_hypers=None,
                 **kwargs):
        ''' We assume discrete actions and scalar rewards! '''
        # Avoid shared mutable default arguments; treat None as "no hypers".
        dst_hypers = {} if dst_hypers is None else dst_hypers
        cde_hypers = {} if cde_hypers is None else cde_hypers
        odernn_hypers = {} if odernn_hypers is None else odernn_hypers

        self.rng = rng
        self.device = device
        self.train_data_file = train_data_file
        self.validation_data_file = validation_data_file
        self.test_data_file = test_data_file
        self.minibatch_size = minibatch_size
        self.drop_smaller_than_minibatch = drop_smaller_than_minibatch
        self.autoencoder_num_epochs = autoencoder_num_epochs
        self.autoencoder = autoencoder
        self.autoencoder_lr = autoencoder_lr
        self.saving_period = autoencoder_saving_period
        self.resume = resume
        self.sided_Q = sided_Q
        self.num_actions = num_actions
        self.state_dim = state_dim
        self.corr_coeff_param = corr_coeff_param
        self.context_input = context_input  # Check to see if we'll one-hot encode the categorical contextual input
        self.context_dim = context_dim  # Check to see if we'll remove the context from the input and only use it for decoding
        self.hidden_size = hidden_size

        # Encoder input = observation (+ optional demographic context) + action.
        if self.context_input:
            self.input_dim = self.state_dim + self.context_dim + self.num_actions
        else:
            self.input_dim = self.state_dim + self.num_actions

        self.autoencoder_lower = self.autoencoder.lower()
        self.data_folder = folder_name + f'/{self.autoencoder_lower}_data'
        self.checkpoint_file = folder_name + f'/{self.autoencoder_lower}_checkpoints/checkpoint.pt'
        # makedirs(exist_ok=True) also creates missing parents and is safe to
        # call concurrently, unlike the original exists()+mkdir pattern.
        os.makedirs(folder_name + f'/{self.autoencoder_lower}_checkpoints',
                    exist_ok=True)
        os.makedirs(folder_name + f'/{self.autoencoder_lower}_data',
                    exist_ok=True)
        self.store_path = folder_name
        self.gen_file = folder_name + f'/{self.autoencoder_lower}_data/{self.autoencoder_lower}_gen.pt'
        self.pred_file = folder_name + f'/{self.autoencoder_lower}_data/{self.autoencoder_lower}_pred.pt'

        # Instantiate the chosen encoder/decoder pair (and, for DDM, the
        # latent dynamics model).
        if self.autoencoder == 'AIS':
            self.container = AIS.ModelContainer(device, ais_gen_model,
                                                ais_pred_model)
            self.gen = self.container.make_encoder(
                self.hidden_size, self.state_dim, self.num_actions,
                context_input=self.context_input, context_dim=self.context_dim)
            self.pred = self.container.make_decoder(self.hidden_size,
                                                    self.state_dim,
                                                    self.num_actions)
        elif self.autoencoder == 'AE':
            self.container = AE.ModelContainer(device)
            self.gen = self.container.make_encoder(
                self.hidden_size, self.state_dim, self.num_actions,
                context_input=self.context_input, context_dim=self.context_dim)
            self.pred = self.container.make_decoder(self.hidden_size,
                                                    self.state_dim,
                                                    self.num_actions)
        elif self.autoencoder == 'DST':
            self.dst_hypers = dst_hypers
            self.container = DST.ModelContainer(device)
            self.gen = self.container.make_encoder(
                self.input_dim, self.hidden_size,
                gru_n_layers=self.dst_hypers['gru_n_layers'],
                augment_chs=self.dst_hypers['augment_chs'])
            self.pred = self.container.make_decoder(
                self.hidden_size, self.state_dim,
                self.dst_hypers['decoder_hidden_units'])
        elif self.autoencoder == 'DDM':
            self.container = DDM.ModelContainer(device)
            self.gen = self.container.make_encoder(
                self.state_dim, self.hidden_size,
                context_input=self.context_input, context_dim=self.context_dim)
            self.pred = self.container.make_decoder(self.state_dim,
                                                    self.hidden_size)
            self.dyn = self.container.make_dyn(self.num_actions,
                                               self.hidden_size)
            self.all_params = chain(self.gen.parameters(),
                                    self.pred.parameters(),
                                    self.dyn.parameters())
            # DDM combines several loss terms; these weights scale them.
            self.inv_loss_coef = 10
            self.dec_loss_coef = 0.1
            self.max_grad_norm = 50
            self.dyn_file = folder_name + '/ddm_data/ddm_dyn.pt'
        elif self.autoencoder == 'RNN':
            self.container = RNN.ModelContainer(device)
            self.gen = self.container.make_encoder(
                self.hidden_size, self.state_dim, self.num_actions,
                context_input=self.context_input, context_dim=self.context_dim)
            self.pred = self.container.make_decoder(self.hidden_size,
                                                    self.state_dim,
                                                    self.num_actions)
        elif self.autoencoder == 'CDE':
            self.cde_hypers = cde_hypers
            self.container = CDE.ModelContainer(device)
            # +1 input channel for the time axis the CDE integrates over.
            self.gen = self.container.make_encoder(
                self.input_dim + 1, self.hidden_size,
                hidden_hidden_channels=self.
                cde_hypers['encoder_hidden_hidden_channels'],
                num_hidden_layers=self.cde_hypers['encoder_num_hidden_layers'])
            self.pred = self.container.make_decoder(
                self.hidden_size, self.state_dim,
                self.cde_hypers['decoder_num_layers'],
                self.cde_hypers['decoder_num_units'])
        elif self.autoencoder == 'ODERNN':
            self.odernn_hypers = odernn_hypers
            self.container = ODERNN.ModelContainer(device)
            self.gen = self.container.make_encoder(self.input_dim,
                                                   self.hidden_size,
                                                   self.odernn_hypers)
            self.pred = self.container.make_decoder(
                self.hidden_size, self.state_dim,
                self.odernn_hypers['decoder_n_layers'],
                self.odernn_hypers['decoder_n_units'])
        else:
            raise NotImplementedError

        self.buffer_save_file = self.data_folder + '/ReplayBuffer'
        self.next_obs_pred_errors_file = self.data_folder + '/test_next_obs_pred_errors.pt'
        self.test_representations_file = self.data_folder + '/test_representations.pt'
        self.test_correlations_file = self.data_folder + '/test_correlations.pt'
        self.policy_eval_save_file = self.data_folder + '/dBCQ_policy_eval'
        self.policy_save_file = self.data_folder + '/dBCQ_policy'
        self.behav_policy_file_wDemo = behav_policy_file_wDemo
        self.behav_policy_file = behav_policy_file

        # Read in the data csv files
        assert (domain == 'sepsis')
        self.train_demog, self.train_states, self.train_interventions, self.train_lengths, self.train_times, self.acuities, self.rewards = torch.load(
            self.train_data_file)
        train_idx = torch.arange(self.train_demog.shape[0])
        self.train_dataset = TensorDataset(self.train_demog,
                                           self.train_states,
                                           self.train_interventions,
                                           self.train_lengths,
                                           self.train_times, self.acuities,
                                           self.rewards, train_idx)
        self.train_loader = DataLoader(self.train_dataset,
                                       batch_size=self.minibatch_size,
                                       shuffle=True)

        self.val_demog, self.val_states, self.val_interventions, self.val_lengths, self.val_times, self.val_acuities, self.val_rewards = torch.load(
            self.validation_data_file)
        val_idx = torch.arange(self.val_demog.shape[0])
        self.val_dataset = TensorDataset(self.val_demog, self.val_states,
                                         self.val_interventions,
                                         self.val_lengths, self.val_times,
                                         self.val_acuities, self.val_rewards,
                                         val_idx)
        self.val_loader = DataLoader(self.val_dataset,
                                     batch_size=self.minibatch_size,
                                     shuffle=False)

        self.test_demog, self.test_states, self.test_interventions, self.test_lengths, self.test_times, self.test_acuities, self.test_rewards = torch.load(
            self.test_data_file)
        test_idx = torch.arange(self.test_demog.shape[0])
        self.test_dataset = TensorDataset(self.test_demog, self.test_states,
                                          self.test_interventions,
                                          self.test_lengths, self.test_times,
                                          self.test_acuities,
                                          self.test_rewards, test_idx)
        self.test_loader = DataLoader(self.test_dataset,
                                      batch_size=self.minibatch_size,
                                      shuffle=False)

        # encode CDE data first to save time
        if self.autoencoder == 'CDE':
            self.train_coefs = load_cde_data('train', self.train_dataset,
                                             self.cde_hypers['coefs_folder'],
                                             self.context_input, device)
            self.val_coefs = load_cde_data('val', self.val_dataset,
                                           self.cde_hypers['coefs_folder'],
                                           self.context_input, device)
            self.test_coefs = load_cde_data('test', self.test_dataset,
                                            self.cde_hypers['coefs_folder'],
                                            self.context_input, device)

    def load_model_from_checkpoint(self, checkpoint_file_path):
        """Restore encoder/decoder (and DDM dynamics) weights from a
        checkpoint produced by ``train_autoencoder``."""
        checkpoint = torch.load(checkpoint_file_path)
        self.gen.load_state_dict(checkpoint['{}_gen_state_dict'.format(
            self.autoencoder.lower())])
        self.pred.load_state_dict(checkpoint['{}_pred_state_dict'.format(
            self.autoencoder.lower())])
        if self.autoencoder == 'DDM':
            self.dyn.load_state_dict(checkpoint['{}_dyn_state_dict'.format(
                self.autoencoder.lower())])
        print("Experiment: generator and predictor models loaded.")

    def train_autoencoder(self):
        """Train the autoencoder, periodically validating and checkpointing.

        Validation runs (and a checkpoint is written) every
        ``self.saving_period`` epochs; a final checkpoint plus the trained
        encoder/decoder weights are always saved at the end.
        """
        print('Experiment: training autoencoder')
        device = self.device
        # DDM optimizes the dynamics model jointly with encoder/decoder.
        if self.autoencoder != 'DDM':
            self.optimizer = torch.optim.Adam(
                list(self.gen.parameters()) + list(self.pred.parameters()),
                lr=self.autoencoder_lr, amsgrad=True)
        else:
            self.optimizer = torch.optim.Adam(
                list(self.gen.parameters()) + list(self.pred.parameters()) +
                list(self.dyn.parameters()),
                lr=self.autoencoder_lr, amsgrad=True)

        self.autoencoding_losses = []
        self.autoencoding_losses_validation = []
        if self.resume:  # Need to rebuild this to resume training for 400 additional epochs if feasible...
            try:
                checkpoint = torch.load(self.checkpoint_file)
                self.gen.load_state_dict(checkpoint['gen_state_dict'])
                self.pred.load_state_dict(checkpoint['pred_state_dict'])
                if self.autoencoder == 'DDM':
                    self.dyn.load_state_dict(checkpoint['dyn_state_dict'])
                self.optimizer.load_state_dict(
                    checkpoint['optimizer_state_dict'])
                epoch_0 = checkpoint['epoch'] + 1
                self.autoencoding_losses = checkpoint['loss']
                self.autoencoding_losses_validation = checkpoint[
                    'validation_loss']
                print(
                    'Starting from epoch: {0} and continuing up to epoch {1}'.
                    format(epoch_0, self.autoencoder_num_epochs))
            # Narrowed from a bare `except:` so Ctrl-C / SystemExit still
            # propagate; any load failure falls back to a fresh start.
            except Exception:
                epoch_0 = 0
                print(
                    'Error loading file, training from default setting. epoch_0 = 0'
                )
        else:
            epoch_0 = 0

        for epoch in range(epoch_0, self.autoencoder_num_epochs):
            epoch_loss = []
            print(
                "Experiment: autoencoder {0}: training Epoch = ".format(
                    self.autoencoder), epoch + 1, 'out of',
                self.autoencoder_num_epochs, 'epochs')

            # Loop through the data using the data loader
            for ii, (dem, ob, ac, l, t, scores, rewards,
                     idx) in enumerate(self.train_loader):
                # print("Batch {}".format(ii),end='')
                dem = dem.to(
                    device
                )  # 5 dimensional vector (Gender, Ventilation status, Re-admission status, Age, Weight)
                ob = ob.to(
                    device)  # 33 dimensional vector (time varying measures)
                ac = ac.to(device)
                l = l.to(device)
                t = t.to(device)
                scores = scores.to(device)
                idx = idx.to(device)
                loss_pred = 0

                # Cut tensors down to the batch's largest sequence length... Trying to speed things up a bit...
                max_length = int(l.max().item())

                # The following losses are for DDM and will not be modified by any other approach
                train_loss, dec_loss, inv_loss = 0, 0, 0
                model_loss, recon_loss, forward_loss = 0, 0, 0

                self.gen.train()
                self.pred.train()
                ob = ob[:, :max_length, :]
                dem = dem[:, :max_length, :]
                ac = ac[:, :max_length, :]
                scores = scores[:, :max_length, :]

                if self.autoencoder == 'CDE':
                    loss_pred, mse_loss, _ = self.container.loop(
                        ob, dem, ac, scores, l, max_length,
                        self.context_input,
                        corr_coeff_param=self.corr_coeff_param,
                        device=device, coefs=self.train_coefs, idx=idx)
                else:
                    loss_pred, mse_loss, _ = self.container.loop(
                        ob, dem, ac, scores, l, max_length,
                        self.context_input,
                        corr_coeff_param=self.corr_coeff_param,
                        device=device, autoencoder=self.autoencoder)
                self.optimizer.zero_grad()

                if self.autoencoder != 'DDM':
                    loss_pred.backward()
                    self.optimizer.step()
                    epoch_loss.append(loss_pred.detach().cpu().numpy())
                else:
                    # DDM's loop returns a tuple of loss components; the total
                    # training loss is re-assembled from the weighted terms.
                    train_loss, dec_loss, inv_loss, model_loss, recon_loss, forward_loss, corr_loss, loss_pred = loss_pred
                    train_loss = forward_loss + self.inv_loss_coef * inv_loss + self.dec_loss_coef * dec_loss - self.corr_coeff_param * corr_loss.sum(
                    )
                    train_loss.backward()
                    # clip_grad_norm (no underscore) is deprecated/removed in
                    # current PyTorch; clip_grad_norm_ is the in-place API.
                    torch.nn.utils.clip_grad_norm_(self.all_params,
                                                   self.max_grad_norm)
                    self.optimizer.step()
                    epoch_loss.append(loss_pred.detach().cpu().numpy())

            self.autoencoding_losses.append(epoch_loss)
            if (epoch + 1) % self.saving_period == 0:  # Run validation and also save checkpoint

                # Computing validation loss
                epoch_validation_loss = []
                with torch.no_grad():
                    for jj, (dem, ob, ac, l, t, scores, rewards,
                             idx) in enumerate(self.val_loader):
                        dem = dem.to(device)
                        ob = ob.to(device)
                        ac = ac.to(device)
                        l = l.to(device)
                        t = t.to(device)
                        idx = idx.to(device)
                        scores = scores.to(device)
                        loss_val = 0

                        # Cut tensors down to the batch's largest sequence length... Trying to speed things up a bit...
                        max_length = int(l.max().item())

                        ob = ob[:, :max_length, :]
                        dem = dem[:, :max_length, :]
                        ac = ac[:, :max_length, :]
                        scores = scores[:, :max_length, :]

                        self.gen.eval()
                        self.pred.eval()

                        if self.autoencoder == 'CDE':
                            # Pass context_input here too, matching the
                            # training and evaluation CDE calls (it was
                            # missing from the original validation call).
                            loss_val, mse_loss, _ = self.container.loop(
                                ob, dem, ac, scores, l, max_length,
                                self.context_input, corr_coeff_param=0,
                                device=device, coefs=self.val_coefs, idx=idx)
                        else:
                            loss_val, mse_loss, _ = self.container.loop(
                                ob, dem, ac, scores, l, max_length,
                                self.context_input, corr_coeff_param=0,
                                device=device, autoencoder=self.autoencoder)

                        if self.autoencoder in ['DST', 'ODERNN', 'CDE']:
                            epoch_validation_loss.append(mse_loss)
                        elif self.autoencoder == "DDM":
                            epoch_validation_loss.append(
                                loss_val[-1].detach().cpu().numpy())
                        else:
                            epoch_validation_loss.append(
                                loss_val.detach().cpu().numpy())

                self.autoencoding_losses_validation.append(
                    epoch_validation_loss)

                save_dict = {
                    'epoch': epoch,
                    'gen_state_dict': self.gen.state_dict(),
                    'pred_state_dict': self.pred.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'loss': self.autoencoding_losses,
                    'validation_loss': self.autoencoding_losses_validation
                }
                if self.autoencoder == 'DDM':
                    save_dict['dyn_state_dict'] = self.dyn.state_dict()
                try:
                    torch.save(save_dict, self.checkpoint_file)
                    # torch.save(save_dict, self.checkpoint_file[:-3] + str(epoch) +'_.pt')
                    np.save(
                        self.data_folder +
                        '/{}_losses.npy'.format(self.autoencoder.lower()),
                        np.array(self.autoencoding_losses))
                except Exception as e:
                    print(e)
                try:
                    np.save(
                        self.data_folder + '/{}_validation_losses.npy'.format(
                            self.autoencoder.lower()),
                        np.array(self.autoencoding_losses_validation))
                except Exception as e:
                    print(e)

        # Final epoch checkpoint
        try:
            save_dict = {
                'epoch': self.autoencoder_num_epochs - 1,
                'gen_state_dict': self.gen.state_dict(),
                'pred_state_dict': self.pred.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': self.autoencoding_losses,
                'validation_loss': self.autoencoding_losses_validation,
            }
            if self.autoencoder == 'DDM':
                save_dict['dyn_state_dict'] = self.dyn.state_dict()
                torch.save(self.dyn.state_dict(), self.dyn_file)
            torch.save(self.gen.state_dict(), self.gen_file)
            torch.save(self.pred.state_dict(), self.pred_file)
            torch.save(save_dict, self.checkpoint_file)
            np.save(
                self.data_folder +
                '/{}_losses.npy'.format(self.autoencoder.lower()),
                np.array(self.autoencoding_losses))
        except Exception as e:
            print(e)

    def evaluate_trained_model(self):
        '''After training, this method can be called to use the trained autoencoder to embed all the data in the representation space.
        We encode all data subsets (train, validation and test) separately and save them off as independent tuples. We then will
        also combine these subsets to populate a replay buffer to train a policy from.

        This method will also evaluate the decoder's ability to correctly predict the next observation from the and also will
        evaluate the trained representation's correlation with the acuity scores.
        '''
        # Initialize the replay buffer
        self.replay_buffer = ReplayBuffer(
            self.hidden_size, self.minibatch_size, 350000, self.device,
            encoded_state=True,
            obs_state_dim=self.state_dim +
            (self.context_dim if self.context_input else 0))

        errors = []
        correlations = torch.Tensor()
        test_representations = torch.Tensor()
        print('Encoding the Training and Validataion Data.')

        ## LOOP THROUGH THE DATA
        # -----------------------------------------------
        # For Training and Validation sets (Encode the observations only, add all data to the experience replay buffer)
        # For the Test set:
        # - Encode the observations
        # - Save off the data (as test tuples and place in the experience replay buffer)
        # - Evaluate accuracy of predicting the next observation using the decoder module of the model
        # - Evaluate the correlation coefficient between the learned representations and the acuity scores
        with torch.no_grad():
            for i_set, loader in enumerate(
                    [self.train_loader, self.val_loader, self.test_loader]):
                if i_set == 2:
                    print(
                        'Encoding the Test Data. Evaluating prediction accuracy. Calculating Correlation Coefficients.'
                    )
                for dem, ob, ac, l, t, scores, rewards, idx in loader:
                    dem = dem.to(self.device)
                    ob = ob.to(self.device)
                    ac = ac.to(self.device)
                    l = l.to(self.device)
                    t = t.to(self.device)
                    scores = scores.to(self.device)
                    rewards = rewards.to(self.device)

                    max_length = int(l.max().item())

                    ob = ob[:, :max_length, :]
                    dem = dem[:, :max_length, :]
                    ac = ac[:, :max_length, :]
                    scores = scores[:, :max_length, :]
                    rewards = rewards[:, :max_length]

                    # Split every sequence into (current, next) aligned views.
                    cur_obs, next_obs = ob[:, :-1, :], ob[:, 1:, :]
                    cur_dem, next_dem = dem[:, :-1, :], dem[:, 1:, :]
                    cur_actions = ac[:, :-1, :]
                    cur_rewards = rewards[:, :-1]
                    cur_scores = scores[:, :-1, :]
                    # All-zero observation rows are padding; mask them out.
                    mask = (cur_obs == 0).all(dim=2)

                    self.gen.eval()
                    self.pred.eval()

                    if self.autoencoder in ['AE', 'AIS', 'RNN']:
                        # Encoder input pairs each observation with the
                        # PREVIOUS action (zeros for the first step).
                        if self.context_input:
                            representations = self.gen(
                                torch.cat(
                                    (cur_obs, cur_dem,
                                     torch.cat((torch.zeros(
                                         (ob.shape[0], 1, ac.shape[-1])).to(
                                             self.device), ac[:, :-2, :]),
                                               dim=1)),
                                    dim=-1))
                        else:
                            representations = self.gen(
                                torch.cat(
                                    (cur_obs,
                                     torch.cat((torch.zeros(
                                         (ob.shape[0], 1, ac.shape[-1])).to(
                                             self.device), ac[:, :-2, :]),
                                               dim=1)),
                                    dim=-1))

                        if self.autoencoder == 'RNN':
                            pred_obs = self.pred(representations)
                        else:
                            pred_obs = self.pred(
                                torch.cat((representations, cur_actions),
                                          dim=-1))
                        pred_error = F.mse_loss(next_obs[~mask],
                                                pred_obs[~mask])

                    elif self.autoencoder == 'DDM':
                        # Initialize hidden states for the LSTM layer
                        cx_d = torch.zeros(1, ob.shape[0],
                                           self.hidden_size).to(self.device)
                        hx_d = torch.zeros(1, ob.shape[0],
                                           self.hidden_size).to(self.device)

                        if self.context_input:
                            representations = self.gen(
                                torch.cat((cur_obs, cur_dem), dim=-1))
                            z_prime = self.gen(
                                torch.cat((next_obs, next_dem), dim=-1))
                        else:
                            representations = self.gen(cur_obs)
                            z_prime = self.gen(next_obs)

                        s_hat = self.pred(representations)
                        z_prime_hat, a_hat, _ = self.dyn(
                            (representations, z_prime, cur_actions, (hx_d,
                                                                     cx_d)))
                        s_prime_hat = self.pred(z_prime_hat)

                        __, pred_error, __, __, __ = get_dynamics_losses(
                            cur_obs[~mask], s_hat[~mask], next_obs[~mask],
                            s_prime_hat[~mask], z_prime[~mask],
                            z_prime_hat[~mask], a_hat[~mask],
                            cur_actions[~mask], discrete=False)

                    elif self.autoencoder in ['DST', 'ODERNN']:
                        _, pred_error, representations = self.container.loop(
                            ob, dem, ac, scores, l, max_length,
                            self.context_input, corr_coeff_param=0,
                            device=self.device)
                        representations = representations[:, :-1, :].detach(
                        )  # remove latent of last time step (with no target)

                    elif self.autoencoder == 'CDE':
                        i_coefs = (self.train_coefs, self.val_coefs,
                                   self.test_coefs)[i_set]
                        _, pred_error, representations = self.container.loop(
                            ob, dem, ac, scores, l, max_length,
                            self.context_input, corr_coeff_param=0,
                            device=self.device, coefs=i_coefs, idx=idx)
                        representations = representations[:, :-1, :].detach()

                    if i_set == 2:  # If we're evaluating the models on the test set...
                        # Compute the Pearson correlation of the learned representations and the acuity scores
                        corr = torch.zeros(
                            (cur_obs.shape[0], representations.shape[-1],
                             cur_scores.shape[-1]))
                        for i in range(cur_obs.shape[0]):
                            corr[i] = pearson_correlation(
                                representations[i][~mask[i]],
                                cur_scores[i][~mask[i]], device=self.device)

                        # Concatenate this batch's correlations with the larger tensor
                        correlations = torch.cat((correlations, corr), dim=0)

                        # Concatenate the batch's representations with the larger tensor
                        test_representations = torch.cat(
                            (test_representations, representations.cpu()),
                            dim=0)

                        # Append the batch's prediction errors to the list
                        if torch.is_tensor(pred_error):
                            errors.append(pred_error.item())
                        else:
                            errors.append(pred_error)

                    # Remove values with the computed mask and add data to the experience replay buffer
                    cur_rep = torch.cat(
                        (representations[:, :-1, :],
                         torch.zeros(
                             (cur_obs.shape[0], 1, self.hidden_size)).to(
                                 self.device)), dim=1)
                    next_rep = torch.cat(
                        (representations[:, 1:, :],
                         torch.zeros(
                             (cur_obs.shape[0], 1, self.hidden_size)).to(
                                 self.device)), dim=1)
                    cur_rep = cur_rep[~mask].cpu()
                    next_rep = next_rep[~mask].cpu()
                    cur_actions = cur_actions[~mask].cpu()
                    cur_rewards = cur_rewards[~mask].cpu()
                    cur_obs = cur_obs[~mask].cpu(
                    )  # Need to keep track of the actual observations that were made to form the corresponding representations (for downstream WIS)
                    next_obs = next_obs[~mask].cpu()
                    cur_dem = cur_dem[~mask].cpu()
                    next_dem = next_dem[~mask].cpu()

                    # Loop over all transitions and add them to the replay buffer
                    for i_trans in range(cur_rep.shape[0]):
                        # A nonzero reward marks a terminal transition.
                        done = cur_rewards[i_trans] != 0
                        if self.context_input:
                            self.replay_buffer.add(
                                cur_rep[i_trans].numpy(),
                                cur_actions[i_trans].argmax().item(),
                                next_rep[i_trans].numpy(),
                                cur_rewards[i_trans].item(), done.item(),
                                torch.cat((cur_obs[i_trans],
                                           cur_dem[i_trans]),
                                          dim=-1).numpy(),
                                torch.cat(
                                    (next_obs[i_trans], next_dem[i_trans]),
                                    dim=-1).numpy())
                        else:
                            self.replay_buffer.add(
                                cur_rep[i_trans].numpy(),
                                cur_actions[i_trans].argmax().item(),
                                next_rep[i_trans].numpy(),
                                cur_rewards[i_trans].item(), done.item(),
                                cur_obs[i_trans].numpy(),
                                next_obs[i_trans].numpy())

        ## SAVE OFF DATA
        # --------------
        self.replay_buffer.save(self.buffer_save_file)
        torch.save(errors, self.next_obs_pred_errors_file)
        torch.save(test_representations, self.test_representations_file)
        torch.save(correlations, self.test_correlations_file)

    def train_dBCQ_policy(self, pol_learning_rate=1e-3):
        """Train a discrete BCQ policy on the encoded replay buffer and
        evaluate it against the (pretrained) behavior-cloning policy."""
        # Initialize parameters for policy learning
        params = {
            "eval_freq": 500,
            "discount": 0.99,
            "buffer_size": 350000,
            "batch_size": self.minibatch_size,
            "optimizer": "Adam",
            "optimizer_parameters": {
                "lr": pol_learning_rate
            },
            "train_freq": 1,
            "polyak_target_update": True,
            "target_update_freq": 1,
            "tau": 0.01,
            "max_timesteps": 5e5,
            "BCQ_threshold": 0.3,
            "buffer_dir": self.buffer_save_file,
            "policy_file": self.policy_save_file + f'_l{pol_learning_rate}.pt',
            "pol_eval_file":
            self.policy_eval_save_file + f'_l{pol_learning_rate}.npy',
        }

        # Initialize a dataloader for policy evaluation (will need representations, observations, demographics, rewards and actions from the test dataset)
        test_representations = torch.load(
            self.test_representations_file)  # Load the test representations
        pol_eval_dataset = TensorDataset(test_representations,
                                         self.test_states,
                                         self.test_interventions,
                                         self.test_demog, self.test_rewards)
        pol_eval_dataloader = DataLoader(pol_eval_dataset,
                                         batch_size=self.minibatch_size,
                                         shuffle=False)

        # Initialize and Load the experience replay buffer corresponding with the current settings of rand_num, hidden_size, etc...
        replay_buffer = ReplayBuffer(
            self.hidden_size, self.minibatch_size, 350000, self.device,
            encoded_state=True,
            obs_state_dim=self.state_dim +
            (self.context_dim if self.context_input else 0))

        # Load the pretrained policy for whether or not the demographic context was used to train the representations
        behav_input = self.state_dim + (self.context_dim
                                        if self.context_input else 0)
        behav_pol = FC_BC(behav_input, self.num_actions, 64).to(self.device)
        if self.context_input:
            behav_pol.load_state_dict(torch.load(self.behav_policy_file_wDemo))
        else:
            behav_pol.load_state_dict(torch.load(self.behav_policy_file))
        behav_pol.eval()

        # Run dBCQ_utils.train_dBCQ
        train_dBCQ(replay_buffer, self.num_actions, self.hidden_size,
                   self.device, params, behav_pol, pol_eval_dataloader,
                   self.context_input)