Example #1
class SACAgent:
    def __init__(self, env_name, model, env_max_steps=0, min_steps_per_update=1, iters_per_update=100,
                 replay_batch_size=64, seed=0, gamma=0.95, polyak=0.995, alpha=0.2, sgd_batch_size=64,
                 sgd_lr=1e-3, exploration_steps=100, replay_buf_size=int(100000), normalize_steps=1000,
                 use_gpu=False, reward_stop=None, env_config=None, sgd_lr_sched=None):
        """
        Implements soft actor critic

        Args:
            env_name: name of the openAI gym environment to solve
            model: model from seagul.rl.models. Contains policy, value fn, q1_fn, q2_fn
            min_steps_per_update: minimum number of steps to take before running updates; episodes are finished before updating
            env_max_steps: number of steps the environment takes before finishing, if the environment emits a done signal before this we consider it a failure.
            iters_per_update: how many update steps to make every time we update
            replay_batch_size: how big a batch to pull from the replay buffer for each update
            seed: random seed for all rngs
            gamma: discount applied to future rewards, usually close to 1
            polyak: term determining how fast the target network is copied from the value function
            alpha: weighting term for the entropy. 0 corresponds to no penalty for deterministic policy
            sgd_batch_size: minibatch size for policy updates
            sgd_lr: initial learning rate for policy optimizer
            normalize_steps: number of initial steps, taken with a random policy, used to estimate observation normalization statistics
            exploration_steps: initial number of random actions to take, aids exploration
            replay_buf_size: how big of a replay buffer to use
            use_gpu: determines if we try to use a GPU or not
            reward_stop: reward value to bail at
            env_config: dictionary containing kwargs to pass to the environment
            sgd_lr_sched: list of sgd_lrs to interpolate between as training goes on

        """
        self.env_name = env_name
        self.model = model
        self.env_max_steps=env_max_steps
        self.min_steps_per_update = min_steps_per_update
        self.iters_per_update = iters_per_update
        self.replay_batch_size = replay_batch_size
        self.seed = seed
        self.gamma = gamma
        self.polyak = polyak
        self.alpha = alpha
        self.sgd_batch_size = sgd_batch_size
        self.sgd_lr = sgd_lr
        self.exploration_steps = exploration_steps
        self.replay_buf_size = replay_buf_size
        self.normalize_steps = normalize_steps
        self.use_gpu = use_gpu
        self.reward_stop = reward_stop
        self.env_config = env_config if env_config is not None else {}
        self.sgd_lr_sched = sgd_lr_sched

    def learn(self, train_steps):
        """
        Runs SAC for train_steps environment steps.

        Returns:
            model: trained model
            raw_rew_hist: list with the total reward for each episode
            var_dict: dictionary with all locals, for logging/debugging purposes
        """

        torch.set_num_threads(1) # performance issue with data loader

        env = gym.make(self.env_name, **self.env_config)
        if isinstance(env.action_space, gym.spaces.Box):
            act_size = env.action_space.shape[0]
            act_dtype = env.action_space.sample().dtype
        else:
            raise NotImplementedError("trying to use unsupported action space", env.action_space)

        obs_size = env.observation_space.shape[0]

        random_model = RandModel(self.model.act_limit, act_size)
        self.replay_buf = ReplayBuffer(obs_size, act_size, self.replay_buf_size)
        self.target_value_fn = copy.deepcopy(self.model.value_fn)

        pol_opt = torch.optim.Adam(self.model.policy.parameters(), lr=self.sgd_lr)
        val_opt = torch.optim.Adam(self.model.value_fn.parameters(), lr=self.sgd_lr)
        q1_opt = torch.optim.Adam(self.model.q1_fn.parameters(), lr=self.sgd_lr)
        q2_opt = torch.optim.Adam(self.model.q2_fn.parameters(), lr=self.sgd_lr)

        if self.sgd_lr_sched:
            sgd_lookup = make_schedule(self.sgd_lr_sched, train_steps)
        else:
            sgd_lookup = None


        # seed all our RNGs
        env.seed(self.seed)
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)

        # set defaults, and decide if we are using a GPU or not
        use_cuda = torch.cuda.is_available() and self.use_gpu
        device = torch.device("cuda:0" if use_cuda else "cpu")

        self.raw_rew_hist = []
        self.val_loss_hist = []
        self.pol_loss_hist = []
        self.q1_loss_hist = []
        self.q2_loss_hist = []

        progress_bar = tqdm.tqdm(total=train_steps + self.normalize_steps)
        cur_total_steps = 0
        progress_bar.update(0)
        early_stop = False
        norm_obs1 = torch.empty(0)

        while cur_total_steps < self.normalize_steps:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, self.env_max_steps)
            norm_obs1 = torch.cat((norm_obs1, ep_obs1))

            ep_steps = ep_rews.shape[0]
            cur_total_steps += ep_steps

            progress_bar.update(ep_steps)
        if self.normalize_steps > 0:
            obs_mean = norm_obs1.mean(axis=0)
            obs_std  = norm_obs1.std(axis=0)
            obs_std[torch.isinf(1/obs_std)] = 1

            self.model.policy.state_means = obs_mean
            self.model.policy.state_std  =  obs_std
            self.model.value_fn.state_means = obs_mean
            self.model.value_fn.state_std = obs_std
            self.target_value_fn.state_means = obs_mean
            self.target_value_fn.state_std = obs_std

            self.model.q1_fn.state_means = torch.cat((obs_mean, torch.zeros(act_size)))
            self.model.q1_fn.state_std = torch.cat((obs_std, torch.ones(act_size)))
            self.model.q2_fn.state_means = self.model.q1_fn.state_means
            self.model.q2_fn.state_std = self.model.q1_fn.state_std

        while cur_total_steps < self.exploration_steps:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, self.env_max_steps)
            self.replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

            ep_steps = ep_rews.shape[0]
            cur_total_steps += ep_steps

            progress_bar.update(ep_steps)

        while cur_total_steps < train_steps:
            cur_batch_steps = 0

            # Bail out if we have met our reward threshold
            if len(self.raw_rew_hist) > 2 and self.reward_stop:
                if self.raw_rew_hist[-1] >= self.reward_stop and self.raw_rew_hist[-2] >= self.reward_stop:
                    early_stop = True
                    break

            # collect data with the current policy
            # ========================================================================
            while cur_batch_steps < self.min_steps_per_update:
                ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, self.model, self.env_max_steps)
                self.replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

                ep_steps = ep_rews.shape[0]
                cur_batch_steps += ep_steps
                cur_total_steps += ep_steps

                self.raw_rew_hist.append(torch.sum(ep_rews))
                #print(self.raw_rew_hist[-1])


            progress_bar.update(cur_batch_steps)

            for _ in range(min(int(ep_steps), self.iters_per_update)):

                torch.autograd.set_grad_enabled(False)
                # compute targets for Q and V
                # ========================================================================
                replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = self.replay_buf.sample_batch(self.replay_batch_size)

                q_targ = replay_rews + self.gamma * (1 - replay_done) * self.target_value_fn(replay_obs2)

                noise = torch.randn(self.replay_batch_size, act_size)
                sample_acts, sample_logp = self.model.select_action(replay_obs1, noise)

                q_in = torch.cat((replay_obs1, sample_acts), dim=1)
                q_preds = torch.cat((self.model.q1_fn(q_in), self.model.q2_fn(q_in)), dim=1)
                q_min, q_min_idx = torch.min(q_preds, dim=1)
                q_min = q_min.reshape(-1,1)

                v_targ = q_min - self.alpha * sample_logp
                #v_targ = v_targ

                torch.autograd.set_grad_enabled(True)

                # q_fn update
                # ========================================================================
                num_mbatch = int(self.replay_batch_size / self.sgd_batch_size)
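                # NOTE: replay_batch_size is assumed to be a multiple of sgd_batch_size;
                # with the defaults (64 / 64) the minibatch loops below run exactly once per update.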

                for i in range(num_mbatch):
                    cur_sample = i*self.sgd_batch_size

                    q_in = torch.cat((replay_obs1[cur_sample:cur_sample + self.sgd_batch_size], replay_acts[cur_sample:cur_sample + self.sgd_batch_size]), dim=1)
                    q1_preds = self.model.q1_fn(q_in)
                    q2_preds = self.model.q2_fn(q_in)
                    q1_loss = torch.pow(q1_preds - q_targ[cur_sample:cur_sample + self.sgd_batch_size], 2).mean()
                    q2_loss = torch.pow(q2_preds - q_targ[cur_sample:cur_sample + self.sgd_batch_size], 2).mean()
                    q_loss = q1_loss + q2_loss

                    q1_opt.zero_grad()
                    q2_opt.zero_grad()
                    q_loss.backward()
                    q1_opt.step()
                    q2_opt.step()

                # val_fn update
                # ========================================================================
                for i in range(num_mbatch):
                    cur_sample = i*self.sgd_batch_size

                    # predict and calculate loss for the batch
                    val_preds = self.model.value_fn(replay_obs1[cur_sample:cur_sample + self.sgd_batch_size])
                    val_loss = torch.pow(val_preds - v_targ[cur_sample:cur_sample + self.sgd_batch_size], 2).mean()

                    # do the normal pytorch update
                    val_opt.zero_grad()
                    val_loss.backward()
                    val_opt.step()

                # policy_fn update
                # ========================================================================
                for param in self.model.q1_fn.parameters():
                    param.requires_grad = False

                for i in range(num_mbatch):
                    cur_sample = i*self.sgd_batch_size

                    noise = torch.randn(replay_obs1[cur_sample:cur_sample + self.sgd_batch_size].shape[0], act_size)
                    local_acts, local_logp = self.model.select_action(replay_obs1[cur_sample:cur_sample + self.sgd_batch_size], noise)

                    q_in = torch.cat((replay_obs1[cur_sample:cur_sample + self.sgd_batch_size], local_acts), dim=1)
                    pol_loss = (self.alpha * local_logp - self.model.q1_fn(q_in)).mean()

                    pol_opt.zero_grad()
                    pol_loss.backward()
                    pol_opt.step()

                for param in self.model.q1_fn.parameters():
                    param.requires_grad = True

                # Update target value fn with polyak average
                # ========================================================================
                self.val_loss_hist.append(val_loss.item())
                self.pol_loss_hist.append(pol_loss.item())
                self.q1_loss_hist.append(q1_loss.item())
                self.q2_loss_hist.append(q2_loss.item())

                val_sd = self.model.value_fn.state_dict()
                tar_sd = self.target_value_fn.state_dict()
                for layer in tar_sd:
                    tar_sd[layer] = self.polyak * tar_sd[layer] + (1 - self.polyak) * val_sd[layer]

                self.target_value_fn.load_state_dict(tar_sd)


            #Update LRs
            if sgd_lookup:
                pol_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)
                val_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)
                q1_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)
                q2_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)

        return self.model, self.raw_rew_hist, locals()
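
A minimal usage sketch for SACAgent, assembled from the SACModel / MLP construction shown in the docstrings of the later examples (Pendulum-v0, 3 observations, 1 action); the network sizes and the 10000-step budget are only illustrative.

import torch.nn as nn
from seagul.nn import MLP
from seagul.rl.models import SACModel

input_size = 3   # Pendulum-v0 observation size
output_size = 1  # Pendulum-v0 action size
num_layers = 2
layer_size = 64

policy = MLP(input_size, output_size * 2, num_layers, layer_size, nn.ReLU)
value_fn = MLP(input_size, 1, num_layers, layer_size, nn.ReLU)
q1_fn = MLP(input_size + output_size, 1, num_layers, layer_size, nn.ReLU)
q2_fn = MLP(input_size + output_size, 1, num_layers, layer_size, nn.ReLU)
model = SACModel(policy, value_fn, q1_fn, q2_fn, 1)

agent = SACAgent("Pendulum-v0", model)
trained_model, rew_hist, var_dict = agent.learn(10000)
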
Example #2
def td3(
        env_name,
        train_steps,
        model,
        env_max_steps=0,
        min_steps_per_update=1,
        iters_per_update=200,
        replay_batch_size=64,
        seed=0,
        act_std_schedule=(.1,),
        gamma=0.95,
        polyak=0.995,
        sgd_batch_size=64,
        sgd_lr=3e-4,
        exploration_steps=1000,
        replay_buf_size=int(100000),
        reward_stop=None,
        env_config=None
):
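    """
    Implements TD3-style training for the given model. (Parameter descriptions mirror the SAC docstrings elsewhere in this listing; entries without a SAC counterpart are inferred from how the argument is used in the body.)

    Args:
        env_name: name of the openAI gym environment to solve
        train_steps: number of environment timesteps to train for
        model: model from seagul.rl.models. Contains policy, q1_fn, q2_fn, and act_limit
        env_max_steps: number of steps the environment takes before finishing, if the environment emits a done signal before this we consider it a failure.
        min_steps_per_update: minimum number of steps to take before running updates; episodes are finished before updating
        iters_per_update: how many update steps to make every time we update
        replay_batch_size: how big a batch to pull from the replay buffer for each update
        seed: random seed for all rngs
        act_std_schedule: action-noise standard deviations to interpolate between as training goes on (passed to do_rollout)
        gamma: discount applied to future rewards, usually close to 1
        polyak: term determining how fast the target networks track the online networks
        sgd_batch_size: minibatch size for the SGD updates
        sgd_lr: learning rate for all optimizers
        exploration_steps: initial number of random actions to take, aids exploration
        replay_buf_size: how big of a replay buffer to use
        reward_stop: reward value to bail at
        env_config: dictionary containing kwargs to pass to the environment

    Returns:
        model: trained model
        raw_rew_hist: list with the total reward for each episode
        var_dict: dictionary with all locals, for logging/debugging purposes
    """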
    # Initialize env, and other globals
    # ========================================================================
    if env_config is None:
        env_config = {}
    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = env.action_space.sample().dtype
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    obs_size = env.observation_space.shape[0]

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    random_model = RandModel(model.act_limit, act_size)
    replay_buf = ReplayBuffer(obs_size, act_size, replay_buf_size)
    target_q1_fn = dill.loads(dill.dumps(model.q1_fn))
    target_q2_fn = dill.loads(dill.dumps(model.q2_fn))
    target_policy = dill.loads(dill.dumps(model.policy))

    for param in target_q1_fn.parameters():
        param.requires_grad = False

    for param in target_q2_fn.parameters():
        param.requires_grad = False

    for param in target_policy.parameters():
        param.requires_grad = False

    act_std_lookup = make_schedule(act_std_schedule, train_steps)
    act_std = act_std_lookup(0)

    pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
    q1_opt = torch.optim.Adam(model.q1_fn.parameters(), lr=sgd_lr)
    q2_opt = torch.optim.Adam(model.q2_fn.parameters(), lr=sgd_lr)

    progress_bar = tqdm.tqdm(total=train_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False

    raw_rew_hist = []
    pol_loss_hist = []
    q1_loss_hist = []
    q2_loss_hist = []

    # Fill the replay buffer with actions taken from a random model
    # ========================================================================
    while cur_total_steps < exploration_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, env_max_steps, act_std)
        replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

        ep_steps = ep_rews.shape[0]
        cur_total_steps += ep_steps

        progress_bar.update(ep_steps)

    # Keep training until we take train_step environment steps
    # ========================================================================
    while cur_total_steps < train_steps:
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            print(raw_rew_hist[-1])
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # collect data with the current policy
        # ========================================================================
        while cur_batch_steps < min_steps_per_update:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, model, env_max_steps, act_std)
            replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

            ep_steps = ep_rews.shape[0]
            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

            raw_rew_hist.append(torch.sum(ep_rews))

        progress_bar.update(cur_batch_steps)

        # Do the update
        # ========================================================================
        for _ in range(min(int(ep_steps), iters_per_update)):

            # Compute target Q
            replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = replay_buf.sample_batch(replay_batch_size)

            with torch.no_grad():
                acts_from_target = target_policy(replay_obs2)
                q_in = torch.cat((replay_obs2, acts_from_target), dim=1)
                q_targ = replay_rews + gamma*(1 - replay_done)*target_q1_fn(q_in)

            num_mbatch = int(replay_batch_size / sgd_batch_size)

            # q_fn update
            # ========================================================================
            for i in range(num_mbatch):
                cur_sample = i * sgd_batch_size

                q_in_local = torch.cat((replay_obs1[cur_sample:cur_sample + sgd_batch_size], replay_acts[cur_sample:cur_sample + sgd_batch_size]), dim=1)
                local_qtarg = q_targ[cur_sample:cur_sample + sgd_batch_size]

                q1_loss = ((model.q1_fn(q_in_local) - local_qtarg)**2).mean()

                #q2_preds = model.q2_fn(q_in_local)
                #q2_loss = ((q2_preds - local_qtarg)**2).mean()
                # NOTE: as written only q1_fn is trained and used for the target;
                # standard TD3 trains both critics and takes the min of the two target Q values.
                q_loss = q1_loss  # + q2_loss

                q1_opt.zero_grad()
                #q2_opt.zero_grad()
                q_loss.backward()
                q1_opt.step()
                #q2_opt.step()

            # policy_fn update
            # ========================================================================
            for param in model.q1_fn.parameters():
                param.requires_grad = False

            for i in range(num_mbatch):
                cur_sample = i * sgd_batch_size
                local_obs = replay_obs1[cur_sample:cur_sample + sgd_batch_size]
                local_acts = model.policy(local_obs)
                q_in = torch.cat((local_obs, local_acts), dim=1)

                pol_loss = -(model.q1_fn(q_in).mean())

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

            for param in model.q1_fn.parameters():
                param.requires_grad = True

            # Update target value fn with polyak average
            # ========================================================================
            pol_loss_hist.append(pol_loss.item())
            q1_loss_hist.append(q1_loss.item())
            #q2_loss_hist.append(q2_loss.item())

            target_q1_fn = update_target_fn(model.q1_fn, target_q1_fn, polyak)
            target_q2_fn = update_target_fn(model.q2_fn, target_q2_fn, polyak)
            target_policy = update_target_fn(model.policy, target_policy, polyak)
            act_std = act_std_lookup(cur_total_steps)

    return model, raw_rew_hist, locals()
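
Both td3 above and the SAC examples rely on helpers that are not included in this listing. Below are minimal sketches of make_schedule and update_target_fn that are consistent with how they are called here; the actual seagul implementations may differ.

import numpy as np
import torch

def make_schedule(values, total_steps):
    # Return a lookup function mapping a step count to a value linearly
    # interpolated between the given values, spread evenly over [0, total_steps].
    values = np.asarray(values, dtype=np.float64)
    if values.size == 1:
        return lambda step: float(values[0])
    xp = np.linspace(0, total_steps, values.size)
    return lambda step: float(np.interp(step, xp, values))

def update_target_fn(fn, target_fn, polyak):
    # Polyak-average the target network toward the online network:
    # target <- polyak * target + (1 - polyak) * online
    with torch.no_grad():
        fn_sd = fn.state_dict()
        tar_sd = target_fn.state_dict()
        for layer in tar_sd:
            tar_sd[layer] = polyak * tar_sd[layer] + (1 - polyak) * fn_sd[layer]
        target_fn.load_state_dict(tar_sd)
    return target_fn
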
Example #3
def sac_switched(
    env_name,
    total_steps,
    model,
    env_steps=0,
    min_steps_per_update=1,
    iters_per_update=100,
    replay_batch_size=64,
    seed=0,
    gamma=0.95,
    polyak=0.995,
    alpha=0.2,
    sgd_batch_size=64,
    sgd_lr=1e-3,
    exploration_steps=100,
    replay_buf_size=int(100000),
    use_gpu=False,
    reward_stop=None,
    goal_state=np.array([np.pi / 2, 0, 0, 0]),
    goal_lookback=10,
    goal_thresh=1,
    needle_lookup_prob=.5,
    gate_update_freq=500,
    gate_x=None,
    gate_y=None,
    gate_lr=1e-5,
    gate_w=1e-2,
    gate_epochs=1,
    env_config=None,
):
    """
    Implements soft actor critic for a switched policy: in addition to the usual SAC updates, the model's gate_fn is trained as a binary classifier on states from episodes that reached the goal region, and goal-reaching episodes are kept in a separate "needle" replay buffer

    Args:
        env_name: name of the openAI gym environment to solve
        total_steps: number of environment timesteps to run SAC for
        model: model from seagul.rl.models. Contains policy, value fn, q1_fn, q2_fn
        min_steps_per_update: minimum number of steps to take before running updates; episodes are finished before updating
        env_steps: number of steps the environment takes before finishing, if the environment emits a done signal before this we consider it a failure.
        iters_per_update: how many update steps to make every time we update
        replay_batch_size: how big a batch to pull from the replay buffer for each update
        seed: random seed for all rngs
        gamma: discount applied to future rewards, usually close to 1
        polyak: term determining how fast the target network is copied from the value function
        alpha: weighting term for the entropy. 0 corresponds to no penalty for deterministic policy
        sgd_batch_size: minibatch size for policy updates
        sgd_lr: initial learning rate for policy optimizer
        goal_state: state the system should reach; used to decide whether an episode ended in the goal region
        goal_lookback: number of final transitions that must all be within goal_thresh of goal_state
        goal_thresh: threshold on the summed per-dimension distance to goal_state
        needle_lookup_prob: threshold used when choosing between the needle (goal-reaching) buffer and the normal replay buffer for each update batch
        gate_update_freq: number of environment steps between gate classifier updates
        gate_x: initial inputs for the gate classifier training set (grown during training)
        gate_y: initial labels for the gate classifier training set (grown during training)
        gate_lr: learning rate for the gate classifier updates
        gate_w: scaling applied to the positive class weight of the gate's BCE loss
        gate_epochs: number of epochs per gate classifier update
        exploration_steps: initial number of random actions to take, aids exploration
        replay_buf_size: how big of a replay buffer to use
        use_gpu: determines if we try to use a GPU or not
        reward_stop: reward value to bail at
        env_config: dictionary containing kwargs to pass to the environment

    Returns:
        model: trained model
        raw_rew_hist: list with the total reward for each episode
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.rl.algos.sac import sac
        import torch.nn as nn
        from seagul.nn import MLP
        from seagul.rl.models import SACModel

        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2
        activation = nn.ReLU

        policy = MLP(input_size, output_size*2, num_layers, layer_size, activation)
        value_fn = MLP(input_size, 1, num_layers, layer_size, activation)
        q1_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
        q2_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
        model = SACModel(policy, value_fn, q1_fn, q2_fn, 1)

        model, rews, var_dict = sac("Pendulum-v0", 10000, model)
    """

    args = locals()
    torch.set_num_threads(1)

    if env_config is None:
        env_config = {}
    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = env.action_space.sample().dtype
    else:
        raise NotImplementedError("trying to use unsupported action space",
                                  env.action_space)

    obs_size = env.observation_space.shape[0]

    random_model = dill.loads(dill.dumps(model))
    random_model.swingup_controller = lambda x: torch.rand(
        model.num_acts) * 2 * model.act_limit - model.act_limit

    replay_buf = ReplayBuffer(obs_size, act_size, replay_buf_size)
    needle_buf = ReplayBuffer(obs_size, act_size, replay_buf_size)
    target_value_fn = dill.loads(dill.dumps(model.value_fn))

    pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
    val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=sgd_lr)
    q1_opt = torch.optim.Adam(model.q1_fn.parameters(), lr=sgd_lr)
    q2_opt = torch.optim.Adam(model.q2_fn.parameters(), lr=sgd_lr)

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    q1_loss_hist = []
    q2_loss_hist = []

    progress_bar = tqdm.tqdm(total=total_steps)
    cur_total_steps = 0
    gate_update_counter = 0
    progress_bar.update(0)
    early_stop = False

    needle_count = 0
    not_needle_count = 0
    while cur_total_steps < exploration_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done, ep_path = do_rollout(
            env, random_model, env_steps)
        in_goal = torch.sum(torch.sqrt(
            (ep_obs2[-goal_lookback:] - goal_state)**2),
                            axis=1) < goal_thresh

        if in_goal.all():
            needle_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

        replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

        ep_steps = ep_rews.shape[0]
        cur_total_steps += ep_steps

    progress_bar.update(cur_total_steps)

    while cur_total_steps < total_steps:
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[
                    -2] >= reward_stop:
                early_stop = True
                break

        # Transfer back to CPU, which is faster for rollouts
        model = model.to('cpu')
        target_value_fn = target_value_fn.to('cpu')

        # collect data with the current policy
        # ========================================================================
        while cur_batch_steps < min_steps_per_update:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done, ep_path = do_rollout(
                env, model, env_steps)

            in_goal = torch.sum(torch.sqrt(
                (ep_obs2[-goal_lookback:] - goal_state)**2),
                                axis=1) < goal_thresh

            if in_goal.all():
                needle_count += 1
                needle_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)
            else:
                not_needle_count += 1
                replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

            if ep_path.sum() != 0:
                reverse_obs = np.flip(ep_obs1.numpy(), 0).copy()
                reverse_obs = torch.from_numpy(reverse_obs)

                reverse_path = np.flip(ep_path.numpy(), 0).copy()
                reverse_path = torch.from_numpy(reverse_path)

                if in_goal.all():
                    for path, obs in zip(reverse_path, reverse_obs):
                        if not path:
                            break
                        else:
                            gate_x = torch.cat((gate_x, obs.reshape(1, -1)))
                            gate_y = torch.cat(
                                (gate_y, torch.ones((1, 1),
                                                    dtype=torch.float32)))
                else:
                    for path, obs in zip(reverse_path, reverse_obs):
                        if not path:
                            pass
                        else:
                            gate_x = torch.cat((gate_x, obs.reshape(1, -1)))
                            gate_y = torch.cat((gate_y,
                                                torch.zeros(
                                                    (1, 1),
                                                    dtype=torch.float32)))

            ep_steps = ep_rews.shape[0]
            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps
            gate_update_counter += ep_steps

            raw_rew_hist.append(torch.sum(ep_rews))

            progress_bar.update(ep_steps)

        print("needle/normal: ", str(needle_buf.size), str(replay_buf.size))

        if gate_update_counter > gate_update_freq:
            # For training, transfer model to GPU
            model = model.to('cuda:0')
            target_value_fn = target_value_fn.to('cuda:0')

            class_weight = (gate_y.shape[0] / sum(gate_y) *
                            gate_w).to('cuda:0')
            gate_loss = fit_model(
                model.gate_fn,
                gate_x,
                gate_y,
                gate_epochs,
                use_tqdm=False,
                use_cuda=True,
                batch_size=8192,
                loss_fn=torch.nn.BCEWithLogitsLoss(pos_weight=class_weight),
                learning_rate=gate_lr)

            print("gate updated: " + str(gate_y.shape[0]) + "  " +
                  str(sum(gate_y)))

            model = model.to('cpu')
            target_value_fn = target_value_fn.to('cpu')
            gate_update_counter = 0

        for _ in range(min(int(ep_steps), iters_per_update)):
            # compute targets for Q and V
            # ========================================================================

            p = np.random.random_sample(1)
            if p > needle_lookup_prob and needle_buf.size > 0:
                replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = needle_buf.sample_batch(
                    replay_batch_size)
            else:
                replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = replay_buf.sample_batch(
                    replay_batch_size)

            replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = \
            [replay_obs1.to(device),
             replay_obs2.to(device),
             replay_acts.to(device),
             replay_rews.to(device),
             replay_done.to(device)]

            q_targ = replay_rews + gamma * (
                1 - replay_done) * target_value_fn(replay_obs2)
            q_targ = q_targ.detach()

            noise = torch.randn(replay_batch_size, act_size).to(device)
            sample_acts, sample_logp = model.select_action_parallel(
                replay_obs1, noise)

            q_in = torch.cat((replay_obs1, sample_acts), dim=1)

            q_preds = torch.cat((model.q1_fn(q_in), model.q2_fn(q_in)), dim=1)
            q_min, q_min_idx = torch.min(q_preds, dim=1)
            q_min = q_min.reshape(-1, 1)

            v_targ = q_min - alpha * sample_logp
            v_targ = v_targ.detach()

            # q_fn update
            # ========================================================================
            training_data = data.TensorDataset(replay_obs1, replay_acts,
                                               q_targ)
            training_generator = data.DataLoader(training_data,
                                                 batch_size=sgd_batch_size,
                                                 shuffle=True,
                                                 num_workers=0,
                                                 pin_memory=False)

            for local_obs, local_acts, local_qtarg in training_generator:
                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs, local_acts, local_qtarg = (
                    local_obs.to(device),
                    local_acts.to(device),
                    local_qtarg.to(device),
                )

                q_in = torch.cat((local_obs, local_acts), dim=1)
                q1_preds = model.q1_fn(q_in)
                q2_preds = model.q2_fn(q_in)
                q1_loss = torch.pow(q1_preds - local_qtarg, 2).mean()
                q2_loss = torch.pow(q2_preds - local_qtarg, 2).mean()
                q_loss = q1_loss + q2_loss

                q1_opt.zero_grad()
                q2_opt.zero_grad()
                q_loss.backward()
                q1_opt.step()
                q2_opt.step()

            # val_fn update
            # ========================================================================
            training_data = data.TensorDataset(replay_obs1, v_targ)
            training_generator = data.DataLoader(training_data,
                                                 batch_size=sgd_batch_size,
                                                 shuffle=True,
                                                 num_workers=0,
                                                 pin_memory=False)

            for local_obs, local_vtarg in training_generator:
                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs, local_vtarg = (local_obs.to(device),
                                          local_vtarg.to(device))

                # predict and calculate loss for the batch
                val_preds = model.value_fn(local_obs)
                val_loss = torch.sum(torch.pow(val_preds - local_vtarg,
                                               2)) / replay_batch_size

                # do the normal pytorch update
                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

            # policy_fn update
            # ========================================================================
            training_data = data.TensorDataset(replay_obs1)
            training_generator = data.DataLoader(training_data,
                                                 batch_size=sgd_batch_size,
                                                 shuffle=True,
                                                 num_workers=0,
                                                 pin_memory=False)

            for local_obs in training_generator:
                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs = local_obs[0].to(device)

                noise = torch.randn(local_obs.shape[0], act_size).to(device)
                local_acts, local_logp = model.select_action_parallel(
                    local_obs, noise)

                q_in = torch.cat((local_obs, local_acts), dim=1)
                pol_loss = torch.sum(alpha * local_logp -
                                     model.q1_fn(q_in)) / replay_batch_size

                # do the normal pytorch update
                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

            # Update target networks
            # ========================================================================
            val_sd = model.value_fn.state_dict()
            tar_sd = target_value_fn.state_dict()
            for layer in tar_sd:
                tar_sd[layer] = polyak * tar_sd[layer] + (
                    1 - polyak) * val_sd[layer]

            target_value_fn.load_state_dict(tar_sd)

    return model, raw_rew_hist, locals()
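
sac_switched concatenates new rows onto gate_x and gate_y during training, so the None defaults cannot be used as-is. A minimal sketch of how the gate training buffers might be initialized before calling it; the observation size of 4 here is an assumption matching the default goal_state.

import torch

obs_size = 4  # assumption: matches the 4-element default goal_state
gate_x = torch.empty((0, obs_size), dtype=torch.float32)
gate_y = torch.empty((0, 1), dtype=torch.float32)
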
Example #4
def sac_sym(
        env_name,
        total_steps,
        model,
        env_steps=0,
        min_steps_per_update=1,
        iters_per_update=100,
        replay_batch_size=64,
        seed=0,
        gamma=0.95,
        polyak=0.995,
        alpha=0.2,
        pol_batch_size=64,
        val_batch_size=64,
        q_batch_size=64,
        pol_lr=1e-3,
        val_lr=1e-3,
        q_lr=1e-3,
        exploration_steps=100,
        replay_buf_size=int(50000),
        use_gpu=False,
        reward_stop=None,
):
    """
    Implements soft actor critic with symmetric data augmentation: every transition is also stored with observations and actions mirrored via mirror_obs / mirror_act

    Args:
        env_name: name of the openAI gym environment to solve
        total_steps: number of environment timesteps to run SAC for
        model: model from seagul.rl.models. Contains policy, value fn, q1_fn, q2_fn
        min_steps_per_update: minimum number of steps to take before running updates; episodes are finished before updating
        env_steps: number of steps the environment takes before finishing, if the environment emits a done signal before this we consider it a failure.
        iters_per_update: how many update steps to make every time we update
        replay_batch_size: how big a batch to pull from the replay buffer for each update
        seed: random seed for all rngs
        gamma: discount applied to future rewards, usually close to 1
        polyak: term determining how fast the target network is copied from the value function
        alpha: weighting term for the entropy. 0 corresponds to no penalty for deterministic policy
        pol_batch_size: minibatch size for policy updates
        val_batch_size: minibatch size for value fn updates
        q_batch_size: minibatch size for q fn updates
        pol_lr: initial learning rate for policy optimizer
        val_lr: initial learning rate for value optimizer
        q_lr: initial learning rate for q fn optimizer
        exploration_steps: initial number of random actions to take, aids exploration
        replay_buf_size: how big of a replay buffer to use
        use_gpu: determines if we try to use a GPU or not
        reward_stop: reward value to bail at

    Returns:
        model: trained model
        raw_rew_hist: list with the total reward for each episode
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.rl.algos.sac import sac
        import torch.nn as nn
        from seagul.nn import MLP
        from seagul.rl.models import SACModel


        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2
        activation = nn.ReLU

        policy = MLP(input_size, output_size*2, num_layers, layer_size, activation)
        value_fn = MLP(input_size, 1, num_layers, layer_size, activation)
        q1_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
        q2_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
        model = SACModel(policy, value_fn, q1_fn, q2_fn, 1)


        model, rews, var_dict = sac("Pendulum-v0", 10000, model)
    """

    env = gym.make(env_name)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = env.action_space.sample().dtype
    else:
        raise NotImplementedError("trying to use unsupported action space",
                                  env.action_space)

    obs_size = env.observation_space.shape[0]

    random_model = RandModel(model.act_limit, act_size)
    replay_buf = ReplayBuffer(obs_size, act_size, replay_buf_size)
    target_value_fn = dill.loads(dill.dumps(model.value_fn))

    pol_opt = torch.optim.Adam(model.policy.parameters(), lr=pol_lr)
    val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=val_lr)
    q1_opt = torch.optim.Adam(model.q1_fn.parameters(), lr=q_lr)
    q2_opt = torch.optim.Adam(model.q2_fn.parameters(), lr=q_lr)

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    q1_loss_hist = []
    q2_loss_hist = []

    #progress_bar = tqdm.tqdm(total=total_steps)
    cur_total_steps = 0
    #progress_bar.update(0)
    early_stop = False

    while cur_total_steps < exploration_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(
            env, random_model, env_steps)

        # can def be made more efficient if found to be a bottleneck

        for obs1, obs2, acts, rews, done in zip(ep_obs1, ep_obs2, ep_acts,
                                                ep_rews, ep_done):
            replay_buf.store(obs1, obs2, acts, rews, done)
            replay_buf.store(mirror_obs(obs1), mirror_obs(obs2),
                             mirror_act(acts), rews, done)

        ep_steps = ep_rews.shape[0] * 2
        cur_total_steps += ep_steps

    # progress_bar.update(cur_total_steps)

    while cur_total_steps < total_steps:
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[
                    -2] >= reward_stop:
                early_stop = True
                break

        # collect data with the current policy
        # ========================================================================
        while cur_batch_steps < min_steps_per_update:

            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(
                env, model, env_steps)

            # can def be made more efficient if found to be a bottleneck
            for obs1, obs2, acts, rews, done in zip(ep_obs1, ep_obs2, ep_acts,
                                                    ep_rews, ep_done):
                replay_buf.store(obs1, obs2, acts, rews, done)
                replay_buf.store(mirror_obs(obs1), mirror_obs(obs2),
                                 mirror_act(acts), rews, done)

            ep_steps = ep_rews.shape[0] * 2
            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

        raw_rew_hist.append(torch.sum(ep_rews))
        #progress_bar.update(cur_batch_steps)

        for _ in range(ep_steps):
            # compute targets for Q and V
            # ========================================================================
            replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = replay_buf.sample_batch(
                replay_batch_size)

            q_targ = replay_rews + gamma * (
                1 - replay_done) * target_value_fn(replay_obs2)
            q_targ = q_targ.detach()

            noise = torch.randn(replay_batch_size, act_size)
            sample_acts, sample_logp = model.select_action(replay_obs1, noise)

            q_in = torch.cat((replay_obs1, sample_acts), dim=1)
            q_preds = torch.cat((model.q1_fn(q_in), model.q2_fn(q_in)), dim=1)
            q_min, q_min_idx = torch.min(q_preds, dim=1)
            q_min = q_min.reshape(-1, 1)

            v_targ = q_min - alpha * sample_logp
            v_targ = v_targ.detach()

            # q_fn update
            # ========================================================================
            training_data = data.TensorDataset(replay_obs1, replay_acts,
                                               q_targ)
            training_generator = data.DataLoader(training_data,
                                                 batch_size=q_batch_size,
                                                 shuffle=False)

            for local_obs, local_acts, local_qtarg in training_generator:
                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs, local_acts, local_qtarg = (
                    local_obs.to(device),
                    local_acts.to(device),
                    local_qtarg.to(device),
                )

                q_in = torch.cat((local_obs, local_acts), dim=1)
                q1_preds = model.q1_fn(q_in)
                q2_preds = model.q2_fn(q_in)
                q1_loss = torch.pow(q1_preds - local_qtarg, 2).mean()
                q2_loss = torch.pow(q2_preds - local_qtarg, 2).mean()
                q_loss = q1_loss + q2_loss

                q1_opt.zero_grad()
                q2_opt.zero_grad()
                q_loss.backward()
                q1_opt.step()
                q2_opt.step()

            # val_fn update
            # ========================================================================
            training_data = data.TensorDataset(replay_obs1, v_targ)
            training_generator = data.DataLoader(training_data,
                                                 batch_size=q_batch_size,
                                                 shuffle=False)

            for local_obs, local_vtarg in training_generator:
                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs, local_vtarg = (local_obs.to(device),
                                          local_vtarg.to(device))

                # predict and calculate loss for the batch
                val_preds = model.value_fn(local_obs)
                val_loss = torch.sum(torch.pow(val_preds - local_vtarg,
                                               2)) / replay_batch_size

                # do the normal pytorch update
                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

            # policy_fn update
            # ========================================================================
            training_data = data.TensorDataset(replay_obs1)
            training_generator = data.DataLoader(training_data,
                                                 batch_size=pol_batch_size,
                                                 shuffle=False)

            for local_obs in training_generator:
                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs = local_obs[0].to(device)

                noise = torch.randn(local_obs.shape[0], act_size)
                local_acts, local_logp = model.select_action(local_obs, noise)

                q_in = torch.cat((local_obs, local_acts), dim=1)
                pol_loss = torch.sum(alpha * local_logp -
                                     model.q1_fn(q_in)) / replay_batch_size

                # do the normal pytorch update
                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

            # Update target value fn with polyak average
            # ========================================================================
            val_loss_hist.append(val_loss.item())
            pol_loss_hist.append(pol_loss.item())
            q1_loss_hist.append(q1_loss.item())
            q2_loss_hist.append(q2_loss.item())

            #
            # model.policy.state_means = update_mean(replay_obs1, model.policy.state_means, cur_total_steps)
            # model.policy.state_var = update_var(replay_obs1, model.policy.state_var, cur_total_steps)
            # model.value_fn.state_means = model.policy.state_means
            # model.value_fn.state_var = model.policy.state_var
            #
            # model.q1_fn.state_means = update_mean(torch.cat((replay_obs1, replay_acts.detach()), dim=1), model.q1_fn.state_means, cur_total_steps)
            # model.q1_fn.state_var = update_var(torch.cat((replay_obs1, replay_acts.detach()), dim=1), model.q1_fn.state_var, cur_total_steps)
            # model.q2_fn.state_means = model.q1_fn.state_means
            # model.q2_fn.state_var = model.q1_fn.state_var

            val_sd = model.value_fn.state_dict()
            tar_sd = target_value_fn.state_dict()
            for layer in tar_sd:
                tar_sd[layer] = polyak * tar_sd[layer] + (
                    1 - polyak) * val_sd[layer]

            target_value_fn.load_state_dict(tar_sd)

    return (model, raw_rew_hist, locals())
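
sac_sym depends on environment-specific mirror_obs and mirror_act helpers that are not shown in this listing. Below is a hedged sketch for a system whose dynamics happen to be symmetric under negating every state and action component; a real environment would typically flip only selected indices.

def mirror_obs(obs):
    # Assumed full sign symmetry; replace with the index flips that
    # actually hold for your environment.
    return -obs

def mirror_act(act):
    # Same assumption for the action.
    return -act
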
Example #5
def sac(
        env_name,
        train_steps,
        model,
        env_max_steps=0,
        min_steps_per_update=1,
        iters_per_update=100,
        replay_batch_size=64,
        seed=0,
        gamma=0.95,
        polyak=0.995,
        alpha=0.2,
        sgd_batch_size=64,
        sgd_lr=1e-3,
        exploration_steps=100,
        replay_buf_size=int(100000),
        normalize_steps=1000,
        use_gpu=False,
        reward_stop=None,
        env_config=None,
):
    """
    Implements soft actor critic

    Args:
        env_name: name of the openAI gym environment to solve
        train_steps: number of environment timesteps to run SAC for
        model: model from seagul.rl.models. Contains policy, value fn, q1_fn, q2_fn
        min_steps_per_update: minimum number of steps to take before running updates; episodes are finished before updating
        env_max_steps: number of steps the environment takes before finishing, if the environment emits a done signal before this we consider it a failure.
        iters_per_update: how many update steps to make every time we update
        replay_batch_size: how big a batch to pull from the replay buffer for each update
        seed: random seed for all rngs
        gamma: discount applied to future rewards, usually close to 1
        polyak: term determining how fast the target network is copied from the value function
        alpha: weighting term for the entropy. 0 corresponds to no penalty for deterministic policy
        sgd_batch_size: minibatch size for policy updates
        sgd_lr: initial learning rate for policy optimizer
        normalize_steps: number of initial steps, taken with a random policy, used to estimate observation normalization statistics
        exploration_steps: initial number of random actions to take, aids exploration
        replay_buf_size: how big of a replay buffer to use
        use_gpu: determines if we try to use a GPU or not
        reward_stop: reward value to bail at
        env_config: dictionary containing kwargs to pass to the environment
    
    Returns:
        model: trained model
        raw_rew_hist: list with the total reward for each episode
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.rl.algos.sac import sac
        import torch.nn as nn
        from seagul.nn import MLP
        from seagul.rl.models import SACModel

        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2
        activation = nn.ReLU

        policy = MLP(input_size, output_size*2, num_layers, layer_size, activation)
        value_fn = MLP(input_size, 1, num_layers, layer_size, activation)
        q1_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
        q2_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
        model = SACModel(policy, value_fn, q1_fn, q2_fn, 1)

        model, rews, var_dict = sac("Pendulum-v0", 10000, model)
    """
    torch.set_num_threads(1)

    if env_config is None:
        env_config = {}

    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = env.action_space.sample().dtype
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    obs_size = env.observation_space.shape[0]

    random_model = RandModel(model.act_limit, act_size)
    replay_buf = ReplayBuffer(obs_size, act_size, replay_buf_size)
    target_value_fn = dill.loads(dill.dumps(model.value_fn))

    pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
    val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=sgd_lr)
    q1_opt = torch.optim.Adam(model.q1_fn.parameters(), lr=sgd_lr)
    q2_opt = torch.optim.Adam(model.q2_fn.parameters(), lr=sgd_lr)

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    q1_loss_hist = []
    q2_loss_hist = []

    progress_bar = tqdm.tqdm(total=train_steps + normalize_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False
    norm_obs1 = torch.empty(0)



    while cur_total_steps < normalize_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, env_max_steps)
        norm_obs1 = torch.cat((norm_obs1, ep_obs1))
        
        ep_steps = ep_rews.shape[0]
        cur_total_steps += ep_steps

        progress_bar.update(ep_steps)
    if normalize_steps > 0:
        obs_mean = norm_obs1.mean(axis=0)
        obs_std  = norm_obs1.std(axis=0)
        obs_std[torch.isinf(1/obs_std)] = 1

        model.policy.state_means = obs_mean
        model.policy.state_std  =  obs_std
        model.value_fn.state_means = obs_mean
        model.value_fn.state_std = obs_std
        target_value_fn.state_means = obs_mean
        target_value_fn.state_std = obs_std

        model.q1_fn.state_means = torch.cat((obs_mean, torch.zeros(act_size)))
        model.q1_fn.state_std = torch.cat((obs_std, torch.ones(act_size)))
        model.q2_fn.state_means = model.q1_fn.state_means
        model.q2_fn.state_std = model.q1_fn.state_std

    while cur_total_steps < exploration_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, env_max_steps)
        replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

        ep_steps = ep_rews.shape[0]
        cur_total_steps += ep_steps

        progress_bar.update(ep_steps)

    while cur_total_steps < train_steps:
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # collect data with the current policy
        # ========================================================================
        while cur_batch_steps < min_steps_per_update:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, model, env_max_steps)
            replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

            ep_steps = ep_rews.shape[0]
            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

            raw_rew_hist.append(torch.sum(ep_rews))
            print(raw_rew_hist[-1])



        progress_bar.update(cur_batch_steps)

        for _ in range(min(int(ep_steps), iters_per_update)):
            # compute targets for Q and V
            # ========================================================================
            replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = replay_buf.sample_batch(replay_batch_size)

            q_targ = replay_rews + gamma * (1 - replay_done) * target_value_fn(replay_obs2)
            q_targ = q_targ.detach()

            noise = torch.randn(replay_batch_size, act_size)
            sample_acts, sample_logp = model.select_action(replay_obs1, noise)

            q_in = torch.cat((replay_obs1, sample_acts), dim=1)
            q_preds = torch.cat((model.q1_fn(q_in), model.q2_fn(q_in)), dim=1)
            q_min, q_min_idx = torch.min(q_preds, dim=1)
            q_min = q_min.reshape(-1, 1)

            v_targ = q_min - alpha * sample_logp
            v_targ = v_targ.detach()

            # q_fn update
            # ========================================================================
            num_mbatch = int(replay_batch_size / sgd_batch_size)

            for i in range(num_mbatch):
                cur_sample = i*sgd_batch_size

                q_in = torch.cat((replay_obs1[cur_sample:cur_sample + sgd_batch_size], replay_acts[cur_sample:cur_sample + sgd_batch_size]), dim=1)
                q1_preds = model.q1_fn(q_in)
                q2_preds = model.q2_fn(q_in)
                q1_loss = torch.pow(q1_preds - q_targ[cur_sample:cur_sample + sgd_batch_size], 2).mean()
                q2_loss = torch.pow(q2_preds - q_targ[cur_sample:cur_sample + sgd_batch_size], 2).mean()
                q_loss = q1_loss + q2_loss
            
                q1_opt.zero_grad()
                q2_opt.zero_grad()
                q_loss.backward()
                q1_opt.step()
                q2_opt.step()

            # val_fn update
            # ========================================================================
            for i in range(num_mbatch):
                cur_sample = i*sgd_batch_size

                # predict and calculate loss for the batch
                val_preds = model.value_fn(replay_obs1[cur_sample:cur_sample + sgd_batch_size])
                val_loss = torch.sum(torch.pow(val_preds - v_targ[cur_sample:cur_sample + sgd_batch_size], 2)) / replay_batch_size

                # do the normal pytorch update
                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

            # policy_fn update
            # ========================================================================
            for param in model.q1_fn.parameters():
                param.requires_grad = False

            for i in range(num_mbatch):
                cur_sample = i*sgd_batch_size

                noise = torch.randn(replay_obs1[cur_sample:cur_sample + sgd_batch_size].shape[0], act_size)
                local_acts, local_logp = model.select_action(replay_obs1[cur_sample:cur_sample + sgd_batch_size], noise)

                q_in = torch.cat((replay_obs1[cur_sample:cur_sample + sgd_batch_size], local_acts), dim=1)
                pol_loss = torch.sum(alpha * local_logp - model.q1_fn(q_in)) / replay_batch_size

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

            for param in model.q1_fn.parameters():
                param.requires_grad = True

            # Update target value fn with polyak average
            # ========================================================================
            val_loss_hist.append(val_loss.item())
            pol_loss_hist.append(pol_loss.item())
            q1_loss_hist.append(q1_loss.item())
            q2_loss_hist.append(q2_loss.item())



            val_sd = model.value_fn.state_dict()
            tar_sd = target_value_fn.state_dict()
            for layer in tar_sd:
                tar_sd[layer] = polyak * tar_sd[layer] + (1 - polyak) * val_sd[layer]

            target_value_fn.load_state_dict(tar_sd)

    return model, raw_rew_hist, locals()
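
The normalize_steps block above only has an effect if the networks apply the stored statistics to their inputs. A sketch of the whitening that seagul's MLP is assumed to perform in its forward pass; the exact attribute handling in seagul.nn may differ.

def whiten(x, state_means, state_std):
    # Assumed input normalization using the statistics collected during normalize_steps.
    return (x - state_means) / state_std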