def experiment(variant):
    print('CUDA status:', torch.cuda.is_available())
    env = make_env(variant['env'])

    # Set seeds
    variant['seed'] = int(variant['seed'])
    env.seed(int(variant['seed']))
    torch.manual_seed(int(variant['seed']))
    np.random.seed(int(variant['seed']))

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    kwargs = {"state_dim": state_dim, "action_dim": action_dim, "max_action": max_action,
              "discount": variant['discount'], "tau": variant['tau'],
              'network_class': NETWORK_CLASSES[variant['network_class']]}

    # custom network kwargs
    mlp_network_kwargs = dict(n_hidden=variant['n_hidden'],
                              hidden_dim=variant['hidden_dim'],
                              first_dim=variant['first_dim'])
    dropout_mlp_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                      hidden_dim=variant['hidden_dim'],
                                      first_dim=variant['first_dim'],
                                      dropout_p=variant['dropout_p'])
    variable_init_mlp_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                            hidden_dim=variant['hidden_dim'],
                                            first_dim=variant['first_dim'],
                                            sigma=variant['sigma'])
    fourier_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                  hidden_dim=variant['hidden_dim'],
                                  fourier_dim=variant['fourier_dim'],
                                  sigma=variant['sigma'],
                                  concatenate_fourier=variant['concatenate_fourier'],
                                  train_B=variant['train_B'])
    siren_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                hidden_dim=variant['hidden_dim'],
                                first_omega_0=variant['omega'],
                                hidden_omega_0=variant['omega'])
    if variant['network_class'] in {'MLP', 'D2RL', 'ConcatMLP', 'SpectralMLP'}:
        kwargs['network_kwargs'] = mlp_network_kwargs
    elif variant['network_class'] == 'DropoutMLP':
        kwargs['network_kwargs'] = dropout_mlp_network_kwargs
    elif variant['network_class'] == 'VariableInitMLP':
        kwargs['network_kwargs'] = variable_init_mlp_network_kwargs
    elif variant['network_class'] in {'FourierMLP', 'LogUniformFourierMLP'}:
        kwargs['network_kwargs'] = fourier_network_kwargs
    elif variant['network_class'] == 'Siren':
        kwargs['network_kwargs'] = siren_network_kwargs
    else:
        raise NotImplementedError

    # Initialize policy
    if variant['policy'] == "TD3":
        # Target policy smoothing is scaled wrt the action scale
        kwargs["policy_noise"] = variant['policy_noise * max_action']
        kwargs["noise_clip"] = variant['noise_clip * max_action']
        kwargs["policy_freq"] = variant['policy_freq']
        policy = TD3.TD3(**kwargs)
    elif variant['policy'] == "OurDDPG":
        policy = OurDDPG.DDPG(**kwargs)
    elif variant['policy'] == "DDPG":
        policy = DDPG.DDPG(**kwargs)
    elif variant['policy'] == "SAC":
        kwargs['lr'] = variant['lr']
        kwargs['alpha'] = variant['alpha']
        kwargs['automatic_entropy_tuning'] = variant['automatic_entropy_tuning']
        kwargs['weight_decay'] = variant['weight_decay']
        # left out dmc
        policy = SAC(**kwargs)
    elif 'PytorchSAC' in variant['policy']:
        kwargs['action_range'] = [float(env.action_space.low.min()), float(env.action_space.high.max())]
        kwargs['actor_lr'] = variant['lr']
        kwargs['critic_lr'] = variant['lr']
        kwargs['alpha_lr'] = variant['alpha_lr']
        kwargs['weight_decay'] = variant['weight_decay']
        kwargs['no_target'] = variant['no_target']
        kwargs['mlp_policy'] = variant['mlp_policy']
        kwargs['mlp_qf'] = variant['mlp_qf']
        del kwargs['max_action']
        if variant['policy'] == 'PytorchSAC':
            policy = PytorchSAC(**kwargs)
        elif variant['policy'] == 'RandomNoisePytorchSAC':
            kwargs['noise_dist'] = variant['noise_dist']
            kwargs['noise_scale'] = variant['noise_scale']
            policy = RandomNoiseSACAgent(**kwargs)
        elif variant['policy'] == 'SmoothedPytorchSAC':
            kwargs['n_critic_samples'] = variant['n_critic_samples']
            kwargs['noise_dist'] = variant['noise_dist']
            kwargs['noise_scale'] = variant['noise_scale']
            policy = SmoothedSACAgent(**kwargs)
        elif variant['policy'] == 'FuncRegPytorchSAC':
            kwargs['critic_target_update_frequency'] = variant['critic_freq']
            kwargs['fr_weight'] = variant['fr_weight']
            policy = FuncRegSACAgent(**kwargs)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    if variant['load_model'] != "":
        raise RuntimeError('load_model is not supported in this distillation experiment')

    # load replay buffer
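    # the buffer is expected to be the relabelled 'generated_replay_buffer.pt' produced by the
    # data-generation experiment below (Example #2), which fills in state, action,
    # correct_action and q_value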
    replay_buffer = torch.load(os.path.join(variant['replay_buffer_folder'], 'generated_replay_buffer.pt'))

    policy_optimizer = torch.optim.Adam(policy.actor.parameters(), lr=variant['lr'])
    qf_optimizer = torch.optim.Adam(policy.critic.Q1.parameters(), lr=variant['lr'])

    # split into train and val for both action and q_value
    indices = np.arange(replay_buffer.max_size)
    np.random.shuffle(indices)  # use the seeded numpy RNG; the stdlib `random` module is never seeded here
    train_indices = indices[:int(0.9 * len(indices))]
    val_indices = indices[int(0.9 * len(indices)):]
    train_dataset = torch.utils.data.TensorDataset(torch.tensor(replay_buffer.state[train_indices]).float(),
                                                   torch.tensor(replay_buffer.action[train_indices]).float(),
                                                   torch.tensor(replay_buffer.correct_action[train_indices]).float(),
                                                   torch.tensor(replay_buffer.q_value[train_indices]).float())
    val_dataset = torch.utils.data.TensorDataset(torch.tensor(replay_buffer.state[val_indices]).float(),
                                                 torch.tensor(replay_buffer.action[val_indices]).float(),
                                                 torch.tensor(replay_buffer.correct_action[val_indices]).float(),
                                                 torch.tensor(replay_buffer.q_value[val_indices]).float())

    # train a network on it
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=variant['batch_size'], shuffle=True,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=variant['batch_size'], shuffle=True,
                                             pin_memory=True)

    train_q_losses = []
    train_policy_losses = []
    val_q_losses = []
    val_policy_losses = []
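    # each epoch distills the relabelled buffer into the fresh networks: Q1 regresses onto the
    # stored q_value targets while the actor mean regresses onto correct_action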
    for _ in trange(variant['n_train_epochs']):
        total_q_loss = 0
        total_policy_loss = 0
        for (state, action, correct_action, q) in train_loader:
            state = state.to(DEVICE)
            action = action.to(DEVICE)
            correct_action = correct_action.to(DEVICE)
            q = q.to(DEVICE)
            q_preds = policy.critic.Q1(torch.cat([state, action], dim=-1))
            policy_preds = policy.actor(state).mean
            q_loss = F.mse_loss(q_preds, q)
            policy_loss = F.mse_loss(policy_preds, correct_action)
            qf_optimizer.zero_grad()
            policy_optimizer.zero_grad()
            q_loss.backward()
            policy_loss.backward()
            qf_optimizer.step()
            policy_optimizer.step()
            total_q_loss += q_loss.item()
            total_policy_loss += policy_loss.item()

        # get validation stats
        total_val_q_loss = 0
        total_val_policy_loss = 0
        with torch.no_grad():
            for (state, action, correct_action, q) in val_loader:
                state = state.to(DEVICE)
                action = action.to(DEVICE)
                correct_action = correct_action.to(DEVICE)
                q = q.to(DEVICE)
                q_preds = policy.critic.Q1(torch.cat([state, action], dim=-1))
                policy_preds = policy.actor(state).mean
                q_loss = F.mse_loss(q_preds, q)
                policy_loss = F.mse_loss(policy_preds, correct_action)
                total_val_q_loss += q_loss.item()
                total_val_policy_loss += policy_loss.item()

        train_q_losses.append(total_q_loss / len(train_loader))
        train_policy_losses.append(total_policy_loss / len(train_loader))
        val_q_losses.append(total_val_q_loss / len(val_loader))
        val_policy_losses.append(total_val_policy_loss / len(val_loader))
        print(f'train: qf loss: {train_q_losses[-1]:.4f}, policy loss: {train_policy_losses[-1]:.4f}')
        print(f'val: qf loss: {val_q_losses[-1]:.4f}, policy loss: {val_policy_losses[-1]:.4f}')

    # evaluate the resulting policy for the configured number of episodes
    eval_return = eval_policy(policy, variant['env'], variant['seed'], eval_episodes=variant['eval_episodes'])

    # save the results
    to_save = dict(
        train_q_losses=train_q_losses,
        train_policy_losses=train_policy_losses,
        val_q_losses=val_q_losses,
        val_policy_losses=val_policy_losses,
        eval_return=eval_return,
        qf=policy.critic.Q1.state_dict(),
        policy=policy.actor.state_dict()
    )
    torch.save(to_save, os.path.join(variant['replay_buffer_folder'], f'{variant["network_class"]}_distillation.pt'))
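
# A minimal sketch of a variant dict for this distillation experiment. Every network-kwargs
# field (dropout_p, sigma, fourier_dim, concatenate_fourier, train_B, omega) is read eagerly
# above, so all of them must be present even when the chosen network class ignores them.
# The values below are illustrative assumptions, not recommended settings.
example_variant = dict(
    env='HalfCheetah-v2', seed=0, discount=0.99, tau=0.005,
    policy='PytorchSAC', network_class='MLP',
    n_hidden=2, hidden_dim=256, first_dim=256,
    dropout_p=0.0, sigma=1.0, fourier_dim=128, concatenate_fourier=True, train_B=False,
    omega=30.0,
    lr=3e-4, alpha_lr=3e-4, weight_decay=0.0,
    no_target=False, mlp_policy=False, mlp_qf=False,
    load_model='', replay_buffer_folder='path/to/buffer_folder',
    batch_size=256, n_train_epochs=100, eval_episodes=10,
)
# experiment(example_variant)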
Example #2
def experiment(variant):
    print('CUDA status:', torch.cuda.is_available())
    env = make_env(variant['env'])

    # Set seeds
    variant['seed'] = int(variant['seed'])
    env.seed(int(variant['seed']))
    torch.manual_seed(int(variant['seed']))
    np.random.seed(int(variant['seed']))

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "max_action": max_action,
        "discount": variant['discount'],
        "tau": variant['tau'],
        'network_class': NETWORK_CLASSES[variant['network_class']]
    }

    # custom network kwargs
    mlp_network_kwargs = dict(n_hidden=variant['n_hidden'],
                              hidden_dim=variant['hidden_dim'],
                              first_dim=variant['first_dim'])
    dropout_mlp_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                      hidden_dim=variant['hidden_dim'],
                                      first_dim=variant['first_dim'],
                                      dropout_p=variant['dropout_p'])
    variable_init_mlp_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                            hidden_dim=variant['hidden_dim'],
                                            first_dim=variant['first_dim'],
                                            sigma=variant['sigma'])
    fourier_network_kwargs = dict(
        n_hidden=variant['n_hidden'],
        hidden_dim=variant['hidden_dim'],
        fourier_dim=variant['fourier_dim'],
        sigma=variant['sigma'],
        concatenate_fourier=variant['concatenate_fourier'],
        train_B=variant['train_B'])
    siren_network_kwargs = dict(n_hidden=variant['n_hidden'],
                                hidden_dim=variant['hidden_dim'],
                                first_omega_0=variant['omega'],
                                hidden_omega_0=variant['omega'])
    if variant['network_class'] in {'MLP', 'D2RL', 'ConcatMLP', 'SpectralMLP'}:
        kwargs['network_kwargs'] = mlp_network_kwargs
    elif variant['network_class'] == 'DropoutMLP':
        kwargs['network_kwargs'] = dropout_mlp_network_kwargs
    elif variant['network_class'] == 'VariableInitMLP':
        kwargs['network_kwargs'] = variable_init_mlp_network_kwargs
    elif variant['network_class'] in {'FourierMLP', 'LogUniformFourierMLP'}:
        kwargs['network_kwargs'] = fourier_network_kwargs
    elif variant['network_class'] == 'Siren':
        kwargs['network_kwargs'] = siren_network_kwargs
    else:
        raise NotImplementedError

    # Initialize policy
    if variant['policy'] == "TD3":
        # Target policy smoothing is scaled wrt the action scale
        kwargs["policy_noise"] = variant['policy_noise * max_action']
        kwargs["noise_clip"] = variant['noise_clip * max_action']
        kwargs["policy_freq"] = variant['policy_freq']
        policy = TD3.TD3(**kwargs)
    elif variant['policy'] == "OurDDPG":
        policy = OurDDPG.DDPG(**kwargs)
    elif variant['policy'] == "DDPG":
        policy = DDPG.DDPG(**kwargs)
    elif variant['policy'] == "SAC":
        kwargs['lr'] = variant['lr']
        kwargs['alpha'] = variant['alpha']
        kwargs['automatic_entropy_tuning'] = variant[
            'automatic_entropy_tuning']
        kwargs['weight_decay'] = variant['weight_decay']
        # left out dmc
        policy = SAC(**kwargs)
    elif 'PytorchSAC' in variant['policy']:
        kwargs['action_range'] = [
            float(env.action_space.low.min()),
            float(env.action_space.high.max())
        ]
        kwargs['actor_lr'] = variant['lr']
        kwargs['critic_lr'] = variant['lr']
        kwargs['alpha_lr'] = variant['alpha_lr']
        kwargs['weight_decay'] = variant['weight_decay']
        kwargs['no_target'] = variant['no_target']
        kwargs['mlp_policy'] = variant['mlp_policy']
        kwargs['mlp_qf'] = variant['mlp_qf']
        del kwargs['max_action']
        if variant['policy'] == 'PytorchSAC':
            policy = PytorchSAC(**kwargs)
        elif variant['policy'] == 'RandomNoisePytorchSAC':
            kwargs['noise_dist'] = variant['noise_dist']
            kwargs['noise_scale'] = variant['noise_scale']
            policy = RandomNoiseSACAgent(**kwargs)
        elif variant['policy'] == 'SmoothedPytorchSAC':
            kwargs['n_critic_samples'] = variant['n_critic_samples']
            kwargs['noise_dist'] = variant['noise_dist']
            kwargs['noise_scale'] = variant['noise_scale']
            policy = SmoothedSACAgent(**kwargs)
        elif variant['policy'] == 'FuncRegPytorchSAC':
            kwargs['critic_target_update_frequency'] = variant['critic_freq']
            kwargs['fr_weight'] = variant['fr_weight']
            policy = FuncRegSACAgent(**kwargs)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    if variant['load_model'] == "":
        raise RuntimeError('This experiment requires a pretrained policy; set load_model')
    policy_file = variant['load_model']
    policy.load(policy_file)

    replay_buffer = CustomReplayBuffer(state_dim,
                                       action_dim,
                                       max_size=int(variant['max_timesteps']))

    # fill replay buffer, save immediately
    state, done = env.reset(), False
    episode_reward = 0
    episode_timesteps = 0
    episode_num = 0
    curr_time = datetime.now()
    for t in trange(int(variant['max_timesteps'])):
        episode_timesteps += 1
        action = policy.select_action(np.array(state), evaluate=False)

        # Perform action
        next_state, reward, done, _ = env.step(action)
        replay_buffer.add(state, action)
        state = next_state
        episode_reward += reward

        if done or episode_timesteps > env._max_episode_steps:
            # +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True
            print(
                f"Total T: {t + 1} Episode Num: {episode_num + 1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}"
            )
            # Reset environment
            state, done = env.reset(), False
            episode_reward = 0
            episode_timesteps = 0
            episode_num += 1

    # save the replay buffer
    folder = os.path.dirname(policy_file)
    torch.save(replay_buffer, os.path.join(folder,
                                           'generated_replay_buffer.pt'))
    assert replay_buffer.max_size == replay_buffer.size

    # label the items in the replay buffer with my q-networks and policy
    with torch.no_grad():
        for start_idx in trange(0, replay_buffer.max_size,
                                variant['batch_size']):
            end_idx = start_idx + variant['batch_size']
            obs = torch.tensor(replay_buffer.state[start_idx:end_idx],
                               device=DEVICE,
                               dtype=torch.float32)
            action = torch.tensor(replay_buffer.action[start_idx:end_idx],
                                  device=DEVICE,
                                  dtype=torch.float32)
            actor_Q1, actor_Q2 = policy.critic(obs, action)
            actor_Q = torch.min(actor_Q1, actor_Q2)
            correct_action = policy.actor(obs).mean.clamp(*policy.action_range)
            replay_buffer.set_values(start_idx, end_idx, to_np(actor_Q),
                                     to_np(correct_action))

    # overwrite the previously saved buffer, now labelled with q_value and correct_action
    torch.save(replay_buffer, os.path.join(folder,
                                           'generated_replay_buffer.pt'))
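
# CustomReplayBuffer itself is not shown in these examples. Below is a minimal sketch of the
# interface they rely on, reconstructed as an assumption from the calls above (not the original
# class): numpy storage for states and actions, add() to append transitions, set_values() to
# write the relabelled q_value / correct_action slices, and size / max_size bookkeeping.
import numpy as np


class CustomReplayBuffer:
    def __init__(self, state_dim, action_dim, max_size):
        self.max_size = max_size
        self.size = 0
        self.state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.action = np.zeros((max_size, action_dim), dtype=np.float32)
        self.correct_action = np.zeros((max_size, action_dim), dtype=np.float32)
        self.q_value = np.zeros((max_size, 1), dtype=np.float32)

    def add(self, state, action):
        # the data-generation loop above adds exactly max_size transitions
        self.state[self.size] = state
        self.action[self.size] = action
        self.size += 1

    def set_values(self, start_idx, end_idx, q_values, correct_actions):
        # q_values / correct_actions are numpy arrays produced by to_np() during relabelling
        self.q_value[start_idx:end_idx] = q_values
        self.correct_action[start_idx:end_idx] = correct_actions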