def experiment(variant):
    print('CUDA status:', torch.cuda.is_available())
    env = make_env(variant['env'])

    # Set seeds
    variant['seed'] = int(variant['seed'])
    env.seed(int(variant['seed']))
    torch.manual_seed(int(variant['seed']))
    np.random.seed(int(variant['seed']))

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "max_action": max_action,
        "discount": variant['discount'],
        "tau": variant['tau'],
        'network_class': NETWORK_CLASSES[variant['network_class']],
    }

    # custom network kwargs
    mlp_network_kwargs = dict(n_hidden=variant['n_hidden'],
                              hidden_dim=variant['hidden_dim'],
                              first_dim=variant['first_dim'])
    dropout_mlp_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                      hidden_dim=variant['hidden_dim'],
                                      first_dim=variant['first_dim'],
                                      dropout_p=variant['dropout_p'])
    variable_init_mlp_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                            hidden_dim=variant['hidden_dim'],
                                            first_dim=variant['first_dim'],
                                            sigma=variant['sigma'])
    fourier_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                  hidden_dim=variant['hidden_dim'],
                                  fourier_dim=variant['fourier_dim'],
                                  sigma=variant['sigma'],
                                  concatenate_fourier=variant['concatenate_fourier'],
                                  train_B=variant['train_B'])
    siren_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                hidden_dim=variant['hidden_dim'],
                                first_omega_0=variant['omega'],
                                hidden_omega_0=variant['omega'])

    if variant['network_class'] in {'MLP', 'D2RL', 'ConcatMLP', 'SpectralMLP'}:
        kwargs['network_kwargs'] = mlp_network_kwargs
    elif variant['network_class'] == 'DropoutMLP':
        kwargs['network_kwargs'] = dropout_mlp_network_kwargs
    elif variant['network_class'] == 'VariableInitMLP':
        kwargs['network_kwargs'] = variable_init_mlp_network_kwargs
    elif variant['network_class'] in {'FourierMLP', 'LogUniformFourierMLP'}:
        kwargs['network_kwargs'] = fourier_network_kwargs
    elif variant['network_class'] == 'Siren':
        kwargs['network_kwargs'] = siren_network_kwargs
    else:
        raise NotImplementedError

    # Initialize policy
    if variant['policy'] == "TD3":
        # Target policy smoothing is scaled wrt the action scale
        kwargs["policy_noise"] = variant['policy_noise'] * max_action
        kwargs["noise_clip"] = variant['noise_clip'] * max_action
        kwargs["policy_freq"] = variant['policy_freq']
        policy = TD3.TD3(**kwargs)
    elif variant['policy'] == "OurDDPG":
        policy = OurDDPG.DDPG(**kwargs)
    elif variant['policy'] == "DDPG":
        policy = DDPG.DDPG(**kwargs)
    elif variant['policy'] == "SAC":
        kwargs['lr'] = variant['lr']
        kwargs['alpha'] = variant['alpha']
        kwargs['automatic_entropy_tuning'] = variant['automatic_entropy_tuning']
        kwargs['weight_decay'] = variant['weight_decay']
        # left out dmc
        policy = SAC(**kwargs)
    elif 'PytorchSAC' in variant['policy']:
        kwargs['action_range'] = [float(env.action_space.low.min()),
                                  float(env.action_space.high.max())]
        kwargs['actor_lr'] = variant['lr']
        kwargs['critic_lr'] = variant['lr']
        kwargs['alpha_lr'] = variant['alpha_lr']
        kwargs['weight_decay'] = variant['weight_decay']
        kwargs['no_target'] = variant['no_target']
        kwargs['mlp_policy'] = variant['mlp_policy']
        kwargs['mlp_qf'] = variant['mlp_qf']
        del kwargs['max_action']
        if variant['policy'] == 'PytorchSAC':
            policy = PytorchSAC(**kwargs)
        elif variant['policy'] == 'RandomNoisePytorchSAC':
            kwargs['noise_dist'] = variant['noise_dist']
            kwargs['noise_scale'] = variant['noise_scale']
            policy = RandomNoiseSACAgent(**kwargs)
        elif variant['policy'] == 'SmoothedPytorchSAC':
            kwargs['n_critic_samples'] = variant['n_critic_samples']
            kwargs['noise_dist'] = variant['noise_dist']
            kwargs['noise_scale'] = variant['noise_scale']
            policy = SmoothedSACAgent(**kwargs)
        elif variant['policy'] == 'FuncRegPytorchSAC':
            kwargs['critic_target_update_frequency'] = variant['critic_freq']
            kwargs['fr_weight'] = variant['fr_weight']
            policy = FuncRegSACAgent(**kwargs)
        else:
            raise NotImplementedError

    if variant['load_model'] != "":
        # loading a pretrained model is not supported in the distillation experiment
        raise RuntimeError

    # load replay buffer
    replay_buffer = torch.load(os.path.join(variant['replay_buffer_folder'],
                                            'generated_replay_buffer.pt'))

    policy_optimizer = torch.optim.Adam(policy.actor.parameters(), lr=variant['lr'])
    qf_optimizer = torch.optim.Adam(policy.critic.Q1.parameters(), lr=variant['lr'])

    # split into train and val for both action and q_value
    indices = np.arange(replay_buffer.max_size)
    np.random.shuffle(indices)  # use the seeded numpy RNG for reproducibility
    train_indices = indices[:int(0.9 * len(indices))]
    val_indices = indices[int(0.9 * len(indices)):]
    train_dataset = torch.utils.data.TensorDataset(
        torch.tensor(replay_buffer.state[train_indices]).float(),
        torch.tensor(replay_buffer.action[train_indices]).float(),
        torch.tensor(replay_buffer.correct_action[train_indices]).float(),
        torch.tensor(replay_buffer.q_value[train_indices]).float())
    val_dataset = torch.utils.data.TensorDataset(
        torch.tensor(replay_buffer.state[val_indices]).float(),
        torch.tensor(replay_buffer.action[val_indices]).float(),
        torch.tensor(replay_buffer.correct_action[val_indices]).float(),
        torch.tensor(replay_buffer.q_value[val_indices]).float())

    # train a network on it
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=variant['batch_size'],
                                               shuffle=True,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=variant['batch_size'],
                                             shuffle=True,
                                             pin_memory=True)

    train_q_losses = []
    train_policy_losses = []
    val_q_losses = []
    val_policy_losses = []
    for _ in trange(variant['n_train_epochs']):
        total_q_loss = 0
        total_policy_loss = 0
        for (state, action, correct_action, q) in train_loader:
            state = state.to(DEVICE)
            action = action.to(DEVICE)
            correct_action = correct_action.to(DEVICE)
            q = q.to(DEVICE)
            q_preds = policy.critic.Q1(torch.cat([state, action], dim=-1))
            policy_preds = policy.actor(state).mean
            q_loss = F.mse_loss(q_preds, q)
            policy_loss = F.mse_loss(policy_preds, correct_action)
            qf_optimizer.zero_grad()
            policy_optimizer.zero_grad()
            q_loss.backward()
            policy_loss.backward()
            qf_optimizer.step()
            policy_optimizer.step()
            total_q_loss += q_loss.item()
            total_policy_loss += policy_loss.item()

        # get validation stats
        total_val_q_loss = 0
        total_val_policy_loss = 0
        with torch.no_grad():
            for (state, action, correct_action, q) in val_loader:
                state = state.to(DEVICE)
                action = action.to(DEVICE)
                correct_action = correct_action.to(DEVICE)
                q = q.to(DEVICE)
                q_preds = policy.critic.Q1(torch.cat([state, action], dim=-1))
                policy_preds = policy.actor(state).mean
                q_loss = F.mse_loss(q_preds, q)
                policy_loss = F.mse_loss(policy_preds, correct_action)
                total_val_q_loss += q_loss.item()
                total_val_policy_loss += policy_loss.item()

        train_q_losses.append(total_q_loss / len(train_loader))
        train_policy_losses.append(total_policy_loss / len(train_loader))
        val_q_losses.append(total_val_q_loss / len(val_loader))
        val_policy_losses.append(total_val_policy_loss / len(val_loader))
        print(f'train: qf loss: {train_q_losses[-1]:.4f}, '
              f'policy loss: {train_policy_losses[-1]:.4f}')
        print(f'val: qf loss: {val_q_losses[-1]:.4f}, '
              f'policy loss: {val_policy_losses[-1]:.4f}')

    # evaluate the resulting policy for eval_episodes episodes
    eval_return = eval_policy(policy, variant['env'], variant['seed'],
                              eval_episodes=variant['eval_episodes'])

    # save the results
    to_save = dict(
        train_q_losses=train_q_losses,
        train_policy_losses=train_policy_losses,
        val_q_losses=val_q_losses,
        val_policy_losses=val_policy_losses,
        eval_return=eval_return,
        qf=policy.critic.Q1.state_dict(),
        policy=policy.actor.state_dict(),
    )
    torch.save(to_save, os.path.join(variant['replay_buffer_folder'],
                                     f'{variant["network_class"]}_distillation.pt'))
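

# A minimal usage sketch for the distillation experiment above. The keys mirror
# the ones read inside experiment() when policy='PytorchSAC' and
# network_class='MLP'; the concrete values (environment id, folder path, sizes)
# are illustrative assumptions, not defaults from this codebase.
EXAMPLE_DISTILLATION_VARIANT = dict(
    env='HalfCheetah-v2',              # assumed gym environment id
    seed=0,
    discount=0.99,
    tau=0.005,
    policy='PytorchSAC',
    network_class='MLP',
    n_hidden=2,
    hidden_dim=256,
    first_dim=256,
    lr=3e-4,
    alpha_lr=3e-4,
    weight_decay=0.0,
    no_target=False,
    mlp_policy=True,
    mlp_qf=True,
    load_model='',                     # must stay empty for this experiment
    replay_buffer_folder='./buffers',  # assumed folder holding generated_replay_buffer.pt
    batch_size=256,
    n_train_epochs=100,
    eval_episodes=100,
)
# experiment(EXAMPLE_DISTILLATION_VARIANT)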
def experiment(variant):
    print('CUDA status:', torch.cuda.is_available())
    env = make_env(variant['env'])

    # Set seeds
    variant['seed'] = int(variant['seed'])
    env.seed(int(variant['seed']))
    torch.manual_seed(int(variant['seed']))
    np.random.seed(int(variant['seed']))

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "max_action": max_action,
        "discount": variant['discount'],
        "tau": variant['tau'],
        'network_class': NETWORK_CLASSES[variant['network_class']],
    }

    # custom network kwargs
    mlp_network_kwargs = dict(n_hidden=variant['n_hidden'],
                              hidden_dim=variant['hidden_dim'],
                              first_dim=variant['first_dim'])
    dropout_mlp_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                      hidden_dim=variant['hidden_dim'],
                                      first_dim=variant['first_dim'],
                                      dropout_p=variant['dropout_p'])
    variable_init_mlp_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                            hidden_dim=variant['hidden_dim'],
                                            first_dim=variant['first_dim'],
                                            sigma=variant['sigma'])
    fourier_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                  hidden_dim=variant['hidden_dim'],
                                  fourier_dim=variant['fourier_dim'],
                                  sigma=variant['sigma'],
                                  concatenate_fourier=variant['concatenate_fourier'],
                                  train_B=variant['train_B'])
    siren_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                hidden_dim=variant['hidden_dim'],
                                first_omega_0=variant['omega'],
                                hidden_omega_0=variant['omega'])

    if variant['network_class'] in {'MLP', 'D2RL', 'ConcatMLP', 'SpectralMLP'}:
        kwargs['network_kwargs'] = mlp_network_kwargs
    elif variant['network_class'] == 'DropoutMLP':
        kwargs['network_kwargs'] = dropout_mlp_network_kwargs
    elif variant['network_class'] == 'VariableInitMLP':
        kwargs['network_kwargs'] = variable_init_mlp_network_kwargs
    elif variant['network_class'] in {'FourierMLP', 'LogUniformFourierMLP'}:
        kwargs['network_kwargs'] = fourier_network_kwargs
    elif variant['network_class'] == 'Siren':
        kwargs['network_kwargs'] = siren_network_kwargs
    else:
        raise NotImplementedError

    # Initialize policy
    if variant['policy'] == "TD3":
        # Target policy smoothing is scaled wrt the action scale
        kwargs["policy_noise"] = variant['policy_noise'] * max_action
        kwargs["noise_clip"] = variant['noise_clip'] * max_action
        kwargs["policy_freq"] = variant['policy_freq']
        policy = TD3.TD3(**kwargs)
    elif variant['policy'] == "OurDDPG":
        policy = OurDDPG.DDPG(**kwargs)
    elif variant['policy'] == "DDPG":
        policy = DDPG.DDPG(**kwargs)
    elif variant['policy'] == "SAC":
        kwargs['lr'] = variant['lr']
        kwargs['alpha'] = variant['alpha']
        kwargs['automatic_entropy_tuning'] = variant['automatic_entropy_tuning']
        kwargs['weight_decay'] = variant['weight_decay']
        # left out dmc
        policy = SAC(**kwargs)
    elif 'PytorchSAC' in variant['policy']:
        kwargs['action_range'] = [float(env.action_space.low.min()),
                                  float(env.action_space.high.max())]
        kwargs['actor_lr'] = variant['lr']
        kwargs['critic_lr'] = variant['lr']
        kwargs['alpha_lr'] = variant['alpha_lr']
        kwargs['weight_decay'] = variant['weight_decay']
        kwargs['no_target'] = variant['no_target']
        kwargs['mlp_policy'] = variant['mlp_policy']
        kwargs['mlp_qf'] = variant['mlp_qf']
        del kwargs['max_action']
        if variant['policy'] == 'PytorchSAC':
            policy = PytorchSAC(**kwargs)
        elif variant['policy'] == 'RandomNoisePytorchSAC':
            kwargs['noise_dist'] = variant['noise_dist']
            kwargs['noise_scale'] = variant['noise_scale']
            policy = RandomNoiseSACAgent(**kwargs)
        elif variant['policy'] == 'SmoothedPytorchSAC':
            kwargs['n_critic_samples'] = variant['n_critic_samples']
            kwargs['noise_dist'] = variant['noise_dist']
            kwargs['noise_scale'] = variant['noise_scale']
            policy = SmoothedSACAgent(**kwargs)
        elif variant['policy'] == 'FuncRegPytorchSAC':
            kwargs['critic_target_update_frequency'] = variant['critic_freq']
            kwargs['fr_weight'] = variant['fr_weight']
            policy = FuncRegSACAgent(**kwargs)
        else:
            raise NotImplementedError

    if variant['load_model'] != "":
        policy_file = variant['load_model']
        # policy_file = file_name if variant['load_model'] == "default" else variant['load_model']
        policy.load(policy_file)
    else:
        # the generated buffer is saved next to the loaded checkpoint below
        raise ValueError('load_model must point to a trained policy checkpoint')

    replay_buffer = CustomReplayBuffer(state_dim, action_dim,
                                       max_size=int(variant['max_timesteps']))

    # fill replay buffer, save immediately
    state, done = env.reset(), False
    episode_reward = 0
    episode_timesteps = 0
    episode_num = 0
    curr_time = datetime.now()
    for t in trange(int(variant['max_timesteps'])):
        episode_timesteps += 1
        action = policy.select_action(np.array(state), evaluate=False)

        # Perform action
        next_state, reward, done, _ = env.step(action)
        replay_buffer.add(state, action)
        state = next_state
        episode_reward += reward

        if done or episode_timesteps > env._max_episode_steps:
            # +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True
            print(f"Total T: {t + 1} Episode Num: {episode_num + 1} "
                  f"Episode T: {episode_timesteps} Reward: {episode_reward:.3f}")
            # Reset environment
            state, done = env.reset(), False
            episode_reward = 0
            episode_timesteps = 0
            episode_num += 1

    # save the replay buffer
    folder = os.path.dirname(policy_file)
    torch.save(replay_buffer, os.path.join(folder, 'generated_replay_buffer.pt'))

    assert replay_buffer.max_size == replay_buffer.size

    # label the items in the replay buffer with my q-networks and policy
    with torch.no_grad():
        for start_idx in trange(0, replay_buffer.max_size, variant['batch_size']):
            end_idx = start_idx + variant['batch_size']
            obs = torch.tensor(replay_buffer.state[start_idx:end_idx],
                               device=DEVICE, dtype=torch.float32)
            action = torch.tensor(replay_buffer.action[start_idx:end_idx],
                                  device=DEVICE, dtype=torch.float32)
            actor_Q1, actor_Q2 = policy.critic(obs, action)
            actor_Q = torch.min(actor_Q1, actor_Q2)
            action = policy.actor(obs).mean.clamp(*policy.action_range)
            replay_buffer.set_values(start_idx, end_idx, to_np(actor_Q), to_np(action))

    # overwrite the unlabeled replay buffer with the labeled version
    torch.save(replay_buffer, os.path.join(folder, 'generated_replay_buffer.pt'))
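

# A companion usage sketch for the buffer-generation experiment above. Here
# load_model must point at a trained PytorchSAC-style checkpoint, since both
# the rollouts and the Q-value/action labels come from the loaded agent; the
# checkpoint path and sizes below are assumptions for illustration only.
EXAMPLE_GENERATION_VARIANT = dict(
    env='HalfCheetah-v2',              # assumed gym environment id
    seed=0,
    discount=0.99,
    tau=0.005,
    policy='PytorchSAC',
    network_class='MLP',
    n_hidden=2,
    hidden_dim=256,
    first_dim=256,
    lr=3e-4,
    alpha_lr=3e-4,
    weight_decay=0.0,
    no_target=False,
    mlp_policy=True,
    mlp_qf=True,
    load_model='./models/sac_policy',  # hypothetical checkpoint path
    max_timesteps=1_000_000,
    batch_size=256,
)
# Running this writes generated_replay_buffer.pt next to the checkpoint; the
# distillation experiment can then be pointed at that folder via
# replay_buffer_folder.
# experiment(EXAMPLE_GENERATION_VARIANT)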