def _preprocess_states_actions(actions, states, device):
    # Process states and actions
    states = [''.join(list(state)) for state in states]
    states, states_len = pad_sequences(states)
    states, _ = seq2tensor(states, get_default_tokens())
    states = torch.from_numpy(states).long().to(device)
    states_len = torch.tensor(states_len).long().to(device)
    actions, _ = seq2tensor(actions, get_default_tokens())
    actions = torch.from_numpy(actions.reshape(-1)).long().to(device)
    return (states, states_len), actions
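# Illustrative usage sketch for _preprocess_states_actions (assumption: the toy
# partial-SMILES states and single-token actions below are arbitrary examples;
# pad_sequences, seq2tensor and get_default_tokens come from this project).
def _example_preprocess_usage(device='cpu'):
    states = ['CC(', 'c1ccc']   # partial SMILES strings (generation states)
    actions = ['C', 'c']        # one next-token action per state
    (state_t, state_len), action_t = _preprocess_states_actions(actions, states, device)
    # state_t: (batch, max_len) long tensor of token indices (space-padded)
    # state_len: (batch,) original lengths before padding
    # action_t: (batch,) long tensor of action token indices
    return state_t.shape, state_len.shape, action_t.shape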
def data_provider(k, flags):
    tokens = get_default_tokens()
    demo_data = GeneratorData(training_data_path=flags.demo_file, delimiter='\t',
                              cols_to_read=[0], keep_header=True, pad_symbol=' ',
                              max_len=120, tokens=tokens, use_cuda=use_cuda)
    unbiased_data = GeneratorData(training_data_path=flags.unbiased_file, delimiter='\t',
                                  cols_to_read=[0], keep_header=True, pad_symbol=' ',
                                  max_len=120, tokens=tokens, use_cuda=use_cuda)
    prior_data = GeneratorData(training_data_path=flags.prior_data, delimiter='\t',
                               cols_to_read=[0], keep_header=True, pad_symbol=' ',
                               max_len=120, tokens=tokens, use_cuda=use_cuda)
    return {'demo_data': demo_data,
            'unbiased_data': unbiased_data,
            'prior_data': prior_data}
def initialize(hparams, train_data, val_data, test_data):
    # Create pytorch data loaders
    train_loader = DataLoader(SmilesDataset(train_data[0], train_data[1]),
                              batch_size=hparams['batch'], collate_fn=lambda x: x)
    if val_data:
        val_loader = DataLoader(SmilesDataset(val_data[0], val_data[1]),
                                batch_size=hparams['batch'], collate_fn=lambda x: x)
    else:
        val_loader = None
    test_loader = DataLoader(SmilesDataset(test_data[0], test_data[1]),
                             batch_size=hparams['batch'], collate_fn=lambda x: x)

    # Create model and optimizer
    model = RNNPredictorModel(d_model=int(hparams['d_model']),
                              tokens=get_default_tokens(),
                              num_layers=int(hparams['rnn_num_layers']),
                              dropout=float(hparams['dropout']),
                              bidirectional=hparams['is_bidirectional'],
                              unit_type=hparams['unit_type'],
                              device=device).to(device)
    optimizer = parse_optimizer(hparams, model)
    metrics = [mean_squared_error, root_mean_squared_error, r2_score]
    return {'data_loaders': {'train': train_loader,
                             'val': val_loader,
                             'test': test_loader},
            'model': model,
            'optimizer': optimizer,
            'metrics': metrics}
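# Hypothetical call sketch for initialize() above (assumption: the concrete
# hyperparameter values are illustrative only, the data splits are
# (smiles_list, labels) pairs as expected by SmilesDataset, and
# optimizer_hparams carries whatever keys parse_optimizer consumes).
def _example_initialize_call(train_split, test_split, optimizer_hparams):
    hparams = {'batch': 128, 'd_model': 256, 'rnn_num_layers': 2, 'dropout': 0.2,
               'is_bidirectional': True, 'unit_type': 'lstm'}
    hparams.update(optimizer_hparams)  # optimizer settings (not shown here)
    # No validation split in this sketch, so 'val' loader comes back as None.
    return initialize(hparams, train_split, None, test_split)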
def __init__(self, hparams, device, is_binary=False):
    expert_model_dir = hparams['model_dir']
    assert os.path.isdir(expert_model_dir), 'Expert model(s) should be in a dedicated folder'
    self.models = []
    self.tokens = get_default_tokens()
    self.device = device
    model_paths = os.listdir(expert_model_dir)
    self.transformer = None
    self.is_binary = is_binary
    for model_file in model_paths:
        if 'transformer' in model_file:
            with open(os.path.join(expert_model_dir, model_file), 'rb') as f:
                self.transformer = joblib.load(f)
            continue
        model = RNNPredictorModel(d_model=hparams['d_model'],
                                  tokens=self.tokens,
                                  num_layers=hparams['rnn_num_layers'],
                                  dropout=hparams['dropout'],
                                  bidirectional=hparams['is_bidirectional'],
                                  unit_type=hparams['unit_type'],
                                  device=device).to(device)
        if is_binary:
            model = torch.nn.Sequential(model, torch.nn.Sigmoid()).to(device)
        model.load_state_dict(torch.load(os.path.join(expert_model_dir, model_file),
                                         map_location=torch.device(device)))
        model = model.eval()
        self.models.append(model)
def initialize(hparams, train_data, val_data, test_data):
    # Create pytorch data loaders
    train_loader = DataLoader(SmilesDataset(train_data[0], train_data[1]),
                              batch_size=hparams['batch'], shuffle=True, collate_fn=lambda x: x)
    if val_data:
        val_loader = DataLoader(SmilesDataset(val_data[0], val_data[1]),
                                batch_size=hparams['batch'], collate_fn=lambda x: x)
    else:
        val_loader = None
    test_loader = DataLoader(SmilesDataset(test_data[0], test_data[1]),
                             batch_size=hparams['batch'], collate_fn=lambda x: x)

    # Create model and optimizer
    model = torch.nn.Sequential(RNNPredictorModel(d_model=int(hparams['d_model']),
                                                  tokens=get_default_tokens(),
                                                  num_layers=int(hparams['rnn_num_layers']),
                                                  dropout=float(hparams['dropout']),
                                                  bidirectional=hparams['is_bidirectional'],
                                                  unit_type=hparams['unit_type'],
                                                  device=device),
                                torch.nn.Sigmoid()).to(device)
    optimizer = parse_optimizer(hparams, model)
    metrics = [accuracy_score, precision_score, recall_score, f1_score]
    return {'data_loaders': {'train': train_loader,
                             'val': val_loader,
                             'test': test_loader},
            'model': model,
            'optimizer': optimizer,
            'metrics': metrics}
def calc_adv_ref(self, trajectory):
    states, actions, _ = unpack_batch([trajectory], self.gamma)
    last_state = ''.join(list(states[-1]))
    inp, _ = seq2tensor([last_state], tokens=get_default_tokens())
    inp = torch.from_numpy(inp).long().to(self.device)
    values_v = self.critic(inp)
    values = values_v.view(-1).data.cpu().numpy()
    last_gae = 0.0
    result_adv = []
    result_ref = []
    for val, next_val, exp in zip(reversed(values[:-1]), reversed(values[1:]),
                                  reversed(trajectory[:-1])):
        if exp.last_state is None:  # terminal state
            delta = exp.reward - val
            last_gae = delta
        else:
            delta = exp.reward + self.gamma * next_val - val
            last_gae = delta + self.gamma * self.gae_lambda * last_gae
        result_adv.append(last_gae)
        result_ref.append(last_gae + val)
    adv_v = torch.FloatTensor(list(reversed(result_adv))).to(self.device)
    ref_v = torch.FloatTensor(list(reversed(result_ref))).to(self.device)
    return states[:-1], actions[:-1], adv_v, ref_v
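# Standalone numeric sketch of the GAE recursion used by calc_adv_ref above
# (assumption: the toy rewards/values are made up purely for illustration; in
# the real method the values come from the critic over a sampled trajectory,
# and the final value is treated as zero at the terminal step).
def _example_gae(rewards=(1.0, 0.0, 2.0), values=(0.5, 0.4, 0.3, 0.0),
                 gamma=0.97, gae_lambda=0.95):
    last_gae = 0.0
    advantages, references = [], []
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD error
        last_gae = delta + gamma * gae_lambda * last_gae        # GAE accumulation
        advantages.append(last_gae)
        references.append(last_gae + values[t])                 # value targets
    return list(reversed(advantages)), list(reversed(references))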
def smiles_to_tensor(smiles):
    smiles = list(smiles)
    _, valid_vec = canonical_smiles(smiles)
    valid_vec = torch.tensor(valid_vec).view(-1, 1).float().to(device)
    smiles, _ = pad_sequences(smiles)
    inp, _ = seq2tensor(smiles, tokens=get_default_tokens())
    inp = torch.from_numpy(inp).long().to(device)
    return inp, valid_vec
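# Quick usage sketch for smiles_to_tensor (assumption: the SMILES strings are
# arbitrary examples and `device` is the module-level torch device used above).
def _example_smiles_to_tensor():
    inp, valid_vec = smiles_to_tensor(['CCO', 'c1ccccc1', 'CC(=O)O'])
    # inp: (batch, max_len) long tensor of token indices (space-padded)
    # valid_vec: (batch, 1) float tensor flagging RDKit-parsable SMILES
    return inp.shape, valid_vec.view(-1).tolist()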
def data_provider(k, flags):
    tokens = get_default_tokens()
    gen_data = GeneratorData(training_data_path=flags.data_file, delimiter='\t',
                             cols_to_read=[0], keep_header=True, pad_symbol=' ',
                             max_len=120, tokens=tokens, use_cuda=use_cuda)
    return {'train': gen_data, 'val': gen_data, 'test': gen_data}
def fit(self, trajectories):
    """Train the reward function / model using the GRL algorithm."""
    if self.use_buffer:
        extra_trajs = self.replay_buffer.sample(self.batch_size)
        trajectories.extend(extra_trajs)
        self.replay_buffer.populate(trajectories)
    d_traj, d_traj_probs = [], []
    for traj in trajectories:
        d_traj.append(''.join(list(traj.terminal_state.state)) + traj.terminal_state.action)
        d_traj_probs.append(traj.traj_prob)
    _, valid_vec_samp = canonical_smiles(d_traj)
    valid_vec_samp = torch.tensor(valid_vec_samp).view(-1, 1).float().to(self.device)
    d_traj, _ = pad_sequences(d_traj)
    d_samp, _ = seq2tensor(d_traj, tokens=get_default_tokens())
    d_samp = torch.from_numpy(d_samp).long().to(self.device)
    losses = []
    for i in trange(self.k, desc='IRL optimization...'):
        # D_demo processing
        demo_states, demo_actions = self.demo_gen_data.random_training_set()
        d_demo = torch.cat([demo_states, demo_actions[:, -1].reshape(-1, 1)], dim=1).to(self.device)
        valid_vec_demo = torch.ones(d_demo.shape[0]).view(-1, 1).float().to(self.device)
        d_demo_out = self.model([d_demo, valid_vec_demo])

        # D_samp processing
        d_samp_out = self.model([d_samp, valid_vec_samp])
        d_out_combined = torch.cat([d_samp_out, d_demo_out], dim=0)
        if d_samp_out.shape[0] < 1000:
            d_samp_out = torch.cat([d_samp_out, d_demo_out], dim=0)
        z = torch.ones(d_samp_out.shape[0]).float().to(self.device)  # dummy importance weights TODO: replace this
        d_samp_out = z.view(-1, 1) * torch.exp(d_samp_out)

        # objective
        loss = torch.mean(d_demo_out) - torch.log(torch.mean(d_samp_out))
        losses.append(loss.item())
        loss = -loss  # for maximization

        # update params
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # self.lr_sch.step()
    return np.mean(losses)
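# Standalone sketch of the objective used in fit() above, in the guided-cost-
# learning style: maximize E_demo[r(x)] - log E_samp[w * exp(r(x))]
# (assumption: the tensors here are random dummies chosen only to show the
# shapes and the sign handling; in the method r(.) is self.model and w are the
# importance weights, currently all ones).
def _example_irl_objective():
    d_demo_out = torch.randn(32, 1)          # reward of demonstration sequences
    d_samp_out = torch.randn(64, 1)          # reward of sampled/background sequences
    z = torch.ones(d_samp_out.shape[0], 1)   # placeholder importance weights
    weighted = z * torch.exp(d_samp_out)
    objective = torch.mean(d_demo_out) - torch.log(torch.mean(weighted))
    return -objective                        # negated for gradient descent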
def train(init_args, agent_net_path=None, agent_net_name=None, seed=0, n_episodes=500,
          sim_data_node=None, tb_writer=None, is_hsearch=False, n_to_generate=200,
          learn_irl=True, bias_mode='max'):
    tb_writer = tb_writer()
    agent = init_args['agent']
    probs_reg = init_args['probs_reg']
    drl_algorithm = init_args['drl_alg']
    irl_algorithm = init_args['irl_alg']
    reward_func = init_args['reward_func']
    gamma = init_args['gamma']
    episodes_to_train = init_args['episodes_to_train']
    expert_model = init_args['expert_model']
    demo_data_gen = init_args['demo_data_gen']
    unbiased_data_gen = init_args['unbiased_data_gen']
    best_model_wts = None
    best_score = 0.
    exp_avg = ExpAverage(beta=0.6)

    # load pretrained model
    if agent_net_path and agent_net_name:
        print('Loading pretrained model...')
        agent.model.load_state_dict(IReLeaSE.load_model(agent_net_path, agent_net_name))
        print('Pretrained model loaded successfully!')

    # collect mean predictions
    unbiased_smiles_mean_pred, biased_smiles_mean_pred, gen_smiles_mean_pred = [], [], []
    unbiased_smiles_mean_pred_data_node = DataNode('baseline_mean_vals', unbiased_smiles_mean_pred)
    biased_smiles_mean_pred_data_node = DataNode('biased_mean_vals', biased_smiles_mean_pred)
    gen_smiles_mean_pred_data_node = DataNode('gen_mean_vals', gen_smiles_mean_pred)
    if sim_data_node:
        sim_data_node.data = [unbiased_smiles_mean_pred_data_node,
                              biased_smiles_mean_pred_data_node,
                              gen_smiles_mean_pred_data_node]

    start = time.time()

    # Begin simulation and training
    total_rewards = []
    irl_trajectories = []
    done_episodes = 0
    batch_episodes = 0
    exp_trajectories = []
    env = MoleculeEnv(actions=get_default_tokens(), reward_func=reward_func)
    exp_source = ExperienceSourceFirstLast(env, agent, gamma, steps_count=1, steps_delta=1)
    traj_prob = 1.
    exp_traj = []
    demo_score = np.mean(expert_model(demo_data_gen.random_training_set_smiles(1000))[1])
    baseline_score = np.mean(expert_model(unbiased_data_gen.random_training_set_smiles(1000))[1])
    with contextlib.suppress(Exception if is_hsearch else DummyException):
        with TBMeanTracker(tb_writer, 1) as tracker:
            for step_idx, exp in tqdm(enumerate(exp_source)):
                exp_traj.append(exp)
                traj_prob *= probs_reg.get(list(exp.state), exp.action)
                if exp.last_state is None:
                    irl_trajectories.append(Trajectory(terminal_state=EpisodeStep(exp.state, exp.action),
                                                       traj_prob=traj_prob))
                    exp_trajectories.append(exp_traj)  # for ExperienceFirstLast objects
                    exp_traj = []
                    traj_prob = 1.
                    probs_reg.clear()
                    batch_episodes += 1

                new_rewards = exp_source.pop_total_rewards()
                if new_rewards:
                    reward = new_rewards[0]
                    done_episodes += 1
                    total_rewards.append(reward)
                    mean_rewards = float(np.mean(total_rewards[-100:]))
                    tracker.track('mean_total_reward', mean_rewards, step_idx)
                    tracker.track('total_reward', reward, step_idx)
                    print(f'Time = {time_since(start)}, step = {step_idx}, reward = {reward:6.2f}, '
                          f'mean_100 = {mean_rewards:6.2f}, episodes = {done_episodes}')

                    with torch.set_grad_enabled(False):
                        samples = generate_smiles(drl_algorithm.model, demo_data_gen,
                                                  init_args['gen_args'], num_samples=n_to_generate)
                    predictions = expert_model(samples)[1]
                    mean_preds = np.mean(predictions)
                    try:
                        percentage_in_threshold = np.sum((predictions >= 7.0)) / len(predictions)
                    except:
                        percentage_in_threshold = 0.
                    per_valid = len(predictions) / n_to_generate
                    print(f'Mean value of predictions = {mean_preds}, '
                          f'% of valid SMILES = {per_valid}, '
                          f'% in drug-like region={percentage_in_threshold}')
                    unbiased_smiles_mean_pred.append(float(baseline_score))
                    biased_smiles_mean_pred.append(float(demo_score))
                    gen_smiles_mean_pred.append(float(mean_preds))
                    tb_writer.add_scalars('qsar_score', {'sampled': mean_preds,
                                                         'baseline': baseline_score,
                                                         'demo_data': demo_score}, step_idx)
                    tb_writer.add_scalars('SMILES stats', {'per. of valid': per_valid,
                                                           'per. above threshold': percentage_in_threshold},
                                          step_idx)
                    eval_dict = {}
                    eval_score = IReLeaSE.evaluate(eval_dict, samples,
                                                   demo_data_gen.random_training_set_smiles(1000))
                    for k in eval_dict:
                        tracker.track(k, eval_dict[k], step_idx)
                    tracker.track('Average SMILES length', np.nanmean([len(s) for s in samples]), step_idx)
                    if bias_mode == 'max':
                        diff = mean_preds - demo_score
                    else:
                        diff = demo_score - mean_preds
                    score = np.exp(diff)
                    exp_avg.update(score)
                    tracker.track('score', score, step_idx)
                    if exp_avg.value > best_score:
                        best_model_wts = [copy.deepcopy(drl_algorithm.model.state_dict()),
                                          copy.deepcopy(irl_algorithm.model.state_dict())]
                        best_score = exp_avg.value
                    if best_score >= np.exp(0.):
                        print(f'threshold reached, best score={mean_preds}, '
                              f'threshold={demo_score}, training completed')
                        break

                if done_episodes == n_episodes:
                    print('Training completed!')
                    break

                if batch_episodes < episodes_to_train:
                    continue

                # Train models
                print('Fitting models...')
                irl_stmt = ''
                if learn_irl:
                    irl_loss = irl_algorithm.fit(irl_trajectories)
                    tracker.track('irl_loss', irl_loss, step_idx)
                    irl_stmt = f'IRL loss = {irl_loss}, '
                rl_loss = drl_algorithm.fit(exp_trajectories)
                samples = generate_smiles(drl_algorithm.model, demo_data_gen,
                                          init_args['gen_args'], num_samples=3)
                print(f'{irl_stmt}RL loss = {rl_loss}, samples = {samples}')
                tracker.track('agent_loss', rl_loss, step_idx)

                # Reset
                batch_episodes = 0
                irl_trajectories.clear()
                exp_trajectories.clear()

    if best_model_wts:
        drl_algorithm.model.load_state_dict(best_model_wts[0])
        irl_algorithm.model.load_state_dict(best_model_wts[1])
    duration = time.time() - start
    print('\nTraining duration: {:.0f}m {:.0f}s'.format(duration // 60, duration % 60))
    return {'model': [drl_algorithm.model, irl_algorithm.model],
            'score': round(best_score, 3),
            'epoch': done_episodes}
                    default='drd2_active.smi',
                    help='The filename for the created dataset')
args = parser.parse_args()
assert os.path.exists(args.svc)
assert os.path.exists(args.data)
assert 0 < args.threshold < 1

# Load file containing SMILES
gen_data = GeneratorData(training_data_path=args.data, delimiter='\t', cols_to_read=[0],
                         keep_header=True, pad_symbol=' ', max_len=120,
                         tokens=get_default_tokens(), use_cuda=False)

# Load classifier
clf = DRD2Model(args.svc)

# Screen SMILES in data file and write active compounds to file.
os.makedirs(args.save_dir, exist_ok=True)
num_active = 0
with open(os.path.join(args.save_dir, args.filename), 'w') as f:
    for i in trange(gen_data.file_len, desc='Screening compounds...'):
        smiles = gen_data.file[i][1:-1]
        p = clf(smiles)
        if p >= args.threshold:
            f.write(smiles + '\n')
            num_active += 1
def train(init_args, model_path=None, agent_net_name=None, reward_net_name=None, n_episodes=500,
          sim_data_node=None, tb_writer=None, is_hsearch=False, n_to_generate=200, learn_irl=True):
    tb_writer = tb_writer()
    agent = init_args['agent']
    probs_reg = init_args['probs_reg']
    drl_algorithm = init_args['drl_alg']
    irl_algorithm = init_args['irl_alg']
    reward_func = init_args['reward_func']
    gamma = init_args['gamma']
    episodes_to_train = init_args['episodes_to_train']
    expert_model = init_args['expert_model']
    demo_data_gen = init_args['demo_data_gen']
    unbiased_data_gen = init_args['unbiased_data_gen']
    best_model_wts = None
    exp_avg = ExpAverage(beta=0.6)
    best_score = -1.

    # load pretrained model
    if model_path and agent_net_name and reward_net_name:
        try:
            print('Loading pretrained model...')
            weights = IReLeaSE.load_model(model_path, agent_net_name)
            agent.model.load_state_dict(weights)
            print('Pretrained model loaded successfully!')
            reward_func.model.load_state_dict(IReLeaSE.load_model(model_path, reward_net_name))
            print('Reward model loaded successfully!')
        except:
            print('Pretrained model could not be loaded. Terminating prematurely.')
            return {'model': [drl_algorithm.actor, drl_algorithm.critic, irl_algorithm.model],
                    'score': round(best_score, 3),
                    'epoch': -1}

    start = time.time()

    # Begin simulation and training
    total_rewards = []
    trajectories = []
    done_episodes = 0
    batch_episodes = 0
    exp_trajectories = []
    step_idx = 0

    # collect mean predictions
    unbiased_smiles_mean_pred, biased_smiles_mean_pred, gen_smiles_mean_pred = [], [], []
    unbiased_smiles_mean_pred_data_node = DataNode('baseline_mean_vals', unbiased_smiles_mean_pred)
    biased_smiles_mean_pred_data_node = DataNode('biased_mean_vals', biased_smiles_mean_pred)
    gen_smiles_mean_pred_data_node = DataNode('gen_mean_vals', gen_smiles_mean_pred)
    if sim_data_node:
        sim_data_node.data = [unbiased_smiles_mean_pred_data_node,
                              biased_smiles_mean_pred_data_node,
                              gen_smiles_mean_pred_data_node]

    env = MoleculeEnv(actions=get_default_tokens(), reward_func=reward_func)
    exp_source = ExperienceSourceFirstLast(env, agent, gamma, steps_count=1, steps_delta=1)
    traj_prob = 1.
    exp_traj = []
    demo_score = np.mean(expert_model(demo_data_gen.random_training_set_smiles(1000))[1])
    baseline_score = np.mean(expert_model(unbiased_data_gen.random_training_set_smiles(1000))[1])
    # with contextlib.suppress(RuntimeError if is_hsearch else DummyException):
    try:
        with TBMeanTracker(tb_writer, 1) as tracker:
            for step_idx, exp in tqdm(enumerate(exp_source)):
                exp_traj.append(exp)
                traj_prob *= probs_reg.get(list(exp.state), exp.action)
                if exp.last_state is None:
                    trajectories.append(Trajectory(terminal_state=EpisodeStep(exp.state, exp.action),
                                                   traj_prob=traj_prob))
                    exp_trajectories.append(exp_traj)  # for ExperienceFirstLast objects
                    exp_traj = []
                    traj_prob = 1.
                    probs_reg.clear()
                    batch_episodes += 1

                new_rewards = exp_source.pop_total_rewards()
                if new_rewards:
                    reward = new_rewards[0]
                    done_episodes += 1
                    total_rewards.append(reward)
                    mean_rewards = float(np.mean(total_rewards[-100:]))
                    tracker.track('mean_total_reward', mean_rewards, step_idx)
                    tracker.track('total_reward', reward, step_idx)
                    print(f'Time = {time_since(start)}, step = {step_idx}, reward = {reward:6.2f}, '
                          f'mean_100 = {mean_rewards:6.2f}, episodes = {done_episodes}')

                    with torch.set_grad_enabled(False):
                        samples = generate_smiles(drl_algorithm.model, demo_data_gen,
                                                  init_args['gen_args'], num_samples=n_to_generate)
                    predictions = expert_model(samples)[1]
                    mean_preds = np.nanmean(predictions)
                    if math.isnan(mean_preds) or math.isinf(mean_preds):
                        print(f'mean preds is {mean_preds}, terminating')
                        # best_score = -1.
                        break
                    try:
                        percentage_in_threshold = np.sum((predictions <= demo_score)) / len(predictions)
                    except:
                        percentage_in_threshold = 0.
                    per_valid = len(predictions) / n_to_generate
                    if per_valid < 0.2:
                        print(f'Percentage of valid SMILES is = {per_valid}. Terminating...')
                        # best_score = -1.
                        break
                    print(f'Mean value of predictions = {mean_preds}, % of valid SMILES = {per_valid}')
                    unbiased_smiles_mean_pred.append(float(baseline_score))
                    biased_smiles_mean_pred.append(float(demo_score))
                    gen_smiles_mean_pred.append(float(mean_preds))
                    tb_writer.add_scalars('qsar_score', {'sampled': mean_preds,
                                                         'baseline': baseline_score,
                                                         'demo_data': demo_score}, step_idx)
                    tb_writer.add_scalars('SMILES stats', {'per. of valid': per_valid,
                                                           'per. in drug-like region': percentage_in_threshold},
                                          step_idx)
                    eval_dict = {}
                    eval_score = IReLeaSE.evaluate(eval_dict, samples,
                                                   demo_data_gen.random_training_set_smiles(1000))
                    for k in eval_dict:
                        tracker.track(k, eval_dict[k], step_idx)
                    avg_len = np.nanmean([len(s) for s in samples])
                    tracker.track('Average SMILES length', avg_len, step_idx)
                    d_penalty = eval_score < .5
                    s_penalty = avg_len < 20
                    diff = demo_score - mean_preds
                    # score = 3 * np.exp(diff) + np.log(per_valid + 1e-5) - s_penalty * np.exp(
                    #     diff) - d_penalty * np.exp(diff)
                    score = np.exp(diff)
                    # score = np.exp(diff) + np.mean([np.exp(per_valid), np.exp(percentage_in_threshold)])
                    if math.isnan(score) or math.isinf(score):
                        # best_score = -1.
                        print(f'Score is {score}, terminating.')
                        break
                    tracker.track('score', score, step_idx)
                    exp_avg.update(score)
                    if is_hsearch:
                        best_score = exp_avg.value
                    if exp_avg.value > best_score:
                        best_model_wts = [copy.deepcopy(drl_algorithm.actor.state_dict()),
                                          copy.deepcopy(drl_algorithm.critic.state_dict()),
                                          copy.deepcopy(irl_algorithm.model.state_dict())]
                        best_score = exp_avg.value
                    if best_score >= np.exp(0.):
                        print(f'threshold reached, best score={mean_preds}, '
                              f'threshold={demo_score}, training completed')
                        break

                if done_episodes == n_episodes:
                    print('Training completed!')
                    break

                if batch_episodes < episodes_to_train:
                    continue

                # Train models
                print('Fitting models...')
                irl_loss = 0.
                # irl_loss = irl_algorithm.fit(trajectories) if learn_irl else 0.
                rl_loss = drl_algorithm.fit(exp_trajectories)
                samples = generate_smiles(drl_algorithm.model, demo_data_gen,
                                          init_args['gen_args'], num_samples=3)
                print(f'IRL loss = {irl_loss}, RL loss = {rl_loss}, samples = {samples}')
                tracker.track('irl_loss', irl_loss, step_idx)
                tracker.track('critic_loss', rl_loss[0], step_idx)
                tracker.track('agent_loss', rl_loss[1], step_idx)

                # Reset
                batch_episodes = 0
                trajectories.clear()
                exp_trajectories.clear()
    except Exception as e:
        print(str(e))

    if best_model_wts:
        drl_algorithm.actor.load_state_dict(best_model_wts[0])
        drl_algorithm.critic.load_state_dict(best_model_wts[1])
        irl_algorithm.model.load_state_dict(best_model_wts[2])
    duration = time.time() - start
    print('\nTraining duration: {:.0f}m {:.0f}s'.format(duration // 60, duration % 60))
    # if math.isinf(best_score) or math.isnan(best_score):
    #     best_score = -1.
    return {'model': [drl_algorithm.actor, drl_algorithm.critic, irl_algorithm.model],
            'score': round(best_score, 3),
            'epoch': done_episodes}
def fit(self, trajectories):
    sq2ten = lambda x: torch.from_numpy(seq2tensor(x, get_default_tokens())[0]).long().to(self.device)
    t_states, t_actions, t_adv, t_ref = [], [], [], []
    t_old_probs = []
    for traj in trajectories:
        states, actions, adv_v, ref_v = self.calc_adv_ref(traj)
        if len(states) == 0:
            continue
        t_states.append(states)
        t_actions.append(actions)
        t_adv.append(adv_v)
        t_ref.append(ref_v)
        with torch.set_grad_enabled(False):
            hidden_states = self.initial_states_func(batch_size=1, **self.initial_states_args)
            trajectory_input = sq2ten(states[-1])
            actions = sq2ten(actions)
            old_probs = []
            for p in range(len(trajectory_input)):
                outputs = self.model([trajectory_input[p].reshape(1, 1)] + hidden_states)
                output, hidden_states = outputs[0], outputs[1:]
                log_prob = torch.log_softmax(output.view(1, -1), dim=1)
                old_probs.append(log_prob[0, actions[p]].item())
        t_old_probs.append(old_probs)
    if len(t_states) == 0:
        return 0., 0.
    for epoch in trange(self.ppo_epochs, desc='PPO optimization...'):
        cr_loss = 0.
        ac_loss = 0.
        for i in range(len(t_states)):
            traj_last_state = t_states[i][-1]
            traj_actions = t_actions[i]
            traj_adv = t_adv[i]
            traj_ref = t_ref[i]
            traj_old_probs = t_old_probs[i]
            hidden_states = self.initial_states_func(1, **self.initial_states_args)
            for p in range(len(traj_last_state)):
                state, action, adv = traj_last_state[p], traj_actions[p], traj_adv[p]
                old_log_prob = traj_old_probs[p]
                state, action = sq2ten(state), sq2ten(action)

                # Critic
                pred = self.critic(state)
                cr_loss = cr_loss + F.mse_loss(pred.reshape(-1, 1), traj_ref[p].reshape(-1, 1))

                # Actor
                outputs = self.actor([state] + hidden_states)
                output, hidden_states = outputs[0], outputs[1:]
                logprob_pi_v = torch.log_softmax(output.view(1, -1), dim=-1)
                logprob_pi_v = logprob_pi_v[0, action]
                ratio_v = torch.exp(logprob_pi_v - old_log_prob)
                surr_obj_v = adv * ratio_v
                clipped_surr_v = adv * torch.clamp(ratio_v, 1.0 - self.ppo_eps, 1.0 + self.ppo_eps)
                loss_policy_v = torch.min(surr_obj_v, clipped_surr_v)

                # Maximize entropy
                prob = torch.softmax(output.view(1, -1), dim=1)
                prob = prob[0, action]
                entropy = prob * logprob_pi_v
                entropy_loss = self.entropy_beta * entropy
                ac_loss = ac_loss - (loss_policy_v + entropy_loss)

        # Update weights
        self.critic_opt.zero_grad()
        self.actor_opt.zero_grad()
        cr_loss = cr_loss / len(trajectories)
        ac_loss = ac_loss / len(trajectories)
        cr_loss.backward()
        ac_loss.backward()
        self.critic_opt.step()
        self.actor_opt.step()
    return cr_loss.item(), -ac_loss.item()
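# Toy standalone illustration of the PPO clipped-surrogate term computed per
# token in fit() above (assumption: the scalar dummy values are arbitrary; in
# the method the log-probabilities come from the actor RNN and the advantage
# from calc_adv_ref).
def _example_ppo_clip(new_log_prob=-0.9, old_log_prob=-1.2, adv=0.5, ppo_eps=0.2):
    ratio = torch.exp(torch.tensor(new_log_prob) - torch.tensor(old_log_prob))
    surrogate = adv * ratio
    clipped = adv * torch.clamp(ratio, 1.0 - ppo_eps, 1.0 + ppo_eps)
    # The policy objective keeps the more pessimistic of the two estimates.
    return torch.min(surrogate, clipped)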