示例#1
0
def main(env='PongNoFrameskip-v4'):
    """Train a PPO agent with a NatureCNN policy on an Atari game.

    Rollouts are collected with the policy on the CPU; each update runs on
    the GPU, so the policy is moved between devices every epoch.

    Args:
        env: Gym environment id to create.
    """
    # Seed every source of randomness the script uses.
    random.seed(SEED)
    np.random.seed(SEED)
    th.manual_seed(SEED)

    # Atari preprocessing -> logging -> torch conversion -> rollout runner.
    env = envs.Runner(
        envs.Torch(
            envs.Logger(envs.OpenAIAtari(gym.make(env)), interval=PPO_STEPS)))
    env.seed(SEED)

    policy = NatureCNN(env).to('cuda:0')
    optimizer = optim.Adam(policy.parameters(), lr=LR, eps=1e-5)
    total_updates = TOTAL_STEPS // PPO_STEPS + 1

    def linear_decay(update_idx):
        # Multiplier that anneals the learning rate linearly towards zero.
        return 1 - update_idx / total_updates

    lr_schedule = optim.lr_scheduler.LambdaLR(optimizer, linear_decay)

    def get_action(state):
        return get_action_value(state, policy)

    for _ in range(total_updates):
        policy.cpu()
        rollout = env.run(get_action, steps=PPO_STEPS, render=RENDER).cuda()
        policy.cuda()
        update(rollout, optimizer, policy, env, lr_schedule)
示例#2
0
def main(env='PongNoFrameskip-v4', num_steps=10000000):
    """Train an A2C agent with a NatureCNN policy on an Atari game.

    Args:
        env: Gym environment id to create.
        num_steps: Training budget; the loop runs
            ``num_steps // A2C_STEPS + 1`` updates (default: the previously
            hard-coded 10M).
    """
    th.set_num_threads(1)
    random.seed(SEED)
    th.manual_seed(SEED)
    np.random.seed(SEED)

    env = gym.make(env)
    # Use SEED everywhere: the former hard-coded 1234 re-seeds silently
    # overrode it.
    env.seed(SEED)
    env = envs.Logger(env, interval=1000)
    env = envs.OpenAIAtari(env)
    env = envs.Torch(env)
    env = envs.Runner(env)
    env.seed(SEED)

    num_updates = num_steps // A2C_STEPS + 1
    # Re-seed torch right before building the network so its initial weights
    # depend only on SEED, not on RNG use during env construction.
    th.manual_seed(SEED)
    policy = NatureCNN(env)
    optimizer = optim.RMSprop(policy.parameters(), lr=LR, alpha=0.99, eps=1e-5)

    def get_action(state):
        return get_action_value(state, policy)

    for _ in range(num_updates):
        # Sample some transitions, then update the policy on them.
        replay = env.run(get_action, steps=A2C_STEPS)
        update(replay, optimizer, policy, env=env)
示例#3
0
        def test_config(n_envs, base_env, use_torch, use_logger, return_info):
            """Roll out a vector env for NUM_STEPS steps and check shapes.

            Args:
                n_envs: number of sub-environments in the vector env.
                base_env: a gym id (str) passed to ``gym.vector.make``, or a
                    zero-argument callable that builds one sub-environment.
                use_torch: whether to add the ``envs.Torch`` wrapper.
                use_logger: whether to add the ``envs.Logger`` wrapper.
                return_info: whether the agent also returns an info dict whose
                    'policy' entry is checked via ``replay.policy()``.
            """
            config = 'n_envs' + str(n_envs) + '-base_env' + str(base_env) \
                    + '-torch' + str(use_torch) + '-logger' + str(use_logger) \
                    + '-info' + str(return_info)
            if isinstance(base_env, str):
                env = vec_env = gym.vector.make(base_env, num_envs=n_envs)
            else:

                def make_env():
                    return base_env()

                env_fns = [make_env for _ in range(n_envs)]
                env = vec_env = AsyncVectorEnv(env_fns)

            if use_logger:
                env = envs.Logger(env, interval=5, logger=self.logger)

            if use_torch:
                env = envs.Torch(env)

                def policy(x):
                    return ch.totensor(vec_env.action_space.sample())
            else:

                def policy(x):
                    return vec_env.action_space.sample()

            if return_info:

                def agent(x):
                    # Sample once so the info entry describes the action that
                    # is actually returned (sampling twice made them differ).
                    action = policy(x)
                    return action, {'policy': action[0]}
            else:
                agent = policy

            # Gather experience
            env = envs.Runner(env)
            replay = env.run(agent, steps=NUM_STEPS)

            # Pre-compute expected shapes: (steps, envs) leading dimensions.
            shape = (NUM_STEPS, n_envs)
            state_shape = vec_env.observation_space.sample()[0]
            if isinstance(state_shape, (int, float)):
                state_shape = tuple()
            else:
                state_shape = state_shape.shape
            action_shape = vec_env.action_space.sample()[0]
            if isinstance(action_shape, (int, float)):
                # Scalar actions are stored with a trailing dimension of 1.
                action_shape = (1, )
            else:
                action_shape = action_shape.shape
            done_shape = tuple()

            # Check shapes
            states = replay.state()
            self.assertEqual(states.shape, shape + state_shape, config)
            actions = replay.action()
            self.assertEqual(actions.shape, shape + action_shape, config)
            dones = replay.done()
            self.assertEqual(dones.shape, shape + done_shape, config)
            if return_info:
                policies = replay.policy()
                self.assertEqual(policies.shape, (NUM_STEPS, ) + action_shape,
                                 config)
示例#4
0
def main(env='Pendulum-v0'):
    """Train a PPO actor-critic agent on a continuous-control gym env.

    Each iteration collects one full episode. Once the replay holds at least
    BATCH_SIZE transitions, it is used for PPO_EPOCHS epochs of PPO-Clip
    actor updates and value-regression critic updates, then emptied.

    Args:
        env: Gym environment id to create.
    """
    agent = ActorCritic(HIDDEN_SIZE)
    actor_optimiser = optim.Adam(agent.actor.parameters(), lr=LEARNING_RATE)
    critic_optimiser = optim.Adam(agent.critic.parameters(), lr=LEARNING_RATE)

    env = gym.make(env)
    env.seed(SEED)
    env = envs.Torch(env)
    env = envs.Logger(env)
    env = envs.Runner(env)
    # A single replay buffer (it was previously constructed twice).
    replay = ch.ExperienceReplay()

    for _ in range(1, MAX_STEPS + 1):
        replay += env.run(agent, episodes=1)

        if len(replay) >= BATCH_SIZE:
            with torch.no_grad():
                # Normalised GAE advantages; torch.zeros(1) is presumably the
                # bootstrap value after the last transition.
                advantages = pg.generalized_advantage(DISCOUNT,
                                                      TRACE_DECAY,
                                                      replay.reward(),
                                                      replay.done(),
                                                      replay.value(),
                                                      torch.zeros(1))
                advantages = ch.normalize(advantages, epsilon=1e-8)
                # Discounted returns are the critic's regression targets.
                returns = td.discount(DISCOUNT,
                                      replay.reward(),
                                      replay.done())
                old_log_probs = replay.log_prob()

            # The first epoch reuses the values/log-probs recorded during the
            # rollout; later epochs recompute them with the updated agent.
            new_values = replay.value()
            new_log_probs = replay.log_prob()
            for epoch in range(PPO_EPOCHS):
                if epoch > 0:
                    _, infos = agent(replay.state())
                    masses = infos['mass']
                    new_values = infos['value'].view(-1, 1)
                    new_log_probs = masses.log_prob(replay.action())

                # Update the policy by maximising the PPO-Clip objective
                policy_loss = ch.algorithms.ppo.policy_loss(new_log_probs,
                                                            old_log_probs,
                                                            advantages,
                                                            clip=PPO_CLIP_RATIO)
                actor_optimiser.zero_grad()
                policy_loss.backward()
                actor_optimiser.step()

                # Fit value function by regression on mean-squared error
                value_loss = ch.algorithms.a2c.state_value_loss(new_values,
                                                                returns)
                critic_optimiser.zero_grad()
                value_loss.backward()
                critic_optimiser.step()

            replay.empty()
示例#5
0
def main(env):
    """Train a DQN on a continuous-control env by discretising its actions.

    The agent picks one of ACTION_DISCRETISATION discrete actions, which
    ``convert_discrete_to_continuous_action`` maps to the env's action space.
    Uniformly random actions are used for the first UPDATE_START steps.

    Args:
        env: Gym environment id to create.
    """
    env = gym.make(env)
    env.seed(SEED)
    env = envs.Torch(env)
    env = envs.ActionLambda(env, convert_discrete_to_continuous_action)
    env = envs.Logger(env)
    env = envs.Runner(env)

    replay = ch.ExperienceReplay()
    agent = DQN(HIDDEN_SIZE, ACTION_DISCRETISATION)
    target_agent = create_target_network(agent)
    optimiser = optim.Adam(agent.parameters(), lr=LEARNING_RATE)

    def get_random_action(state):
        # Uniformly random discrete action index, shaped (1, 1).
        return torch.tensor([[random.randint(0, ACTION_DISCRETISATION - 1)]])

    def get_action(state):
        # The agent's first output is used directly as the action.
        # NOTE(review): exploration presumably happens inside DQN — confirm.
        return agent(state)[0]

    for step in range(1, MAX_STEPS + 1):
        with torch.no_grad():
            if step < UPDATE_START:
                replay += env.run(get_random_action, steps=1)
            else:
                replay += env.run(get_action, steps=1)

            # Keep only the most recent REPLAY_SIZE transitions.
            replay = replay[-REPLAY_SIZE:]

        if step > UPDATE_START and step % UPDATE_INTERVAL == 0:
            # Randomly sample a batch of experience
            batch = random.sample(replay, BATCH_SIZE)
            batch = ch.ExperienceReplay(batch)

            # Bootstrapped targets come from the target network and must not
            # carry gradients (previously they accumulated on target_agent).
            with torch.no_grad():
                target_values = target_agent(batch.next_state())[1].max(
                    dim=1, keepdim=True)[0]
                target_values = batch.reward() + DISCOUNT * (
                    1 - batch.done()) * target_values

            # Update Q-function by one step of gradient descent
            pred_values = agent(batch.state())[1].gather(1, batch.action())
            value_loss = F.mse_loss(pred_values, target_values)
            optimiser.zero_grad()
            value_loss.backward()
            optimiser.step()

        if step > UPDATE_START and step % TARGET_UPDATE_INTERVAL == 0:
            # Update target network
            target_agent = create_target_network(agent)
示例#6
0
def main(env='Pendulum-v0'):
    """Train a DDPG agent (deterministic actor, Q critic) on a gym env.

    Random actions are used for the first UPDATE_START steps. Afterwards the
    actor's action plus Gaussian noise, clamped to [-1, 1], is executed.
    Every UPDATE_INTERVAL steps, one critic step, one actor step and a Polyak
    update of both target networks are performed on a random minibatch.

    Args:
        env: Gym environment id to create.
    """
    env = gym.make(env)
    env.seed(SEED)
    env = envs.Runner(envs.Logger(envs.Torch(env)))

    actor = Actor(HIDDEN_SIZE, stochastic=False, layer_norm=True)
    critic = Critic(HIDDEN_SIZE, state_action=True, layer_norm=True)
    target_actor = create_target_network(actor)
    target_critic = create_target_network(critic)
    actor_optimiser = optim.Adam(actor.parameters(), lr=LEARNING_RATE)
    critic_optimiser = optim.Adam(critic.parameters(), lr=LEARNING_RATE)
    memory = ch.ExperienceReplay()

    def get_action(s):
        action = actor(s)
        exploration = ACTION_NOISE * torch.randn(1, 1)
        return (action + exploration).clamp(-1, 1)

    for step in range(1, MAX_STEPS + 1):
        acting_fn = get_random_action if step < UPDATE_START else get_action
        with torch.no_grad():
            memory += env.run(acting_fn, steps=1)
        memory = memory[-REPLAY_SIZE:]

        if not (step > UPDATE_START and step % UPDATE_INTERVAL == 0):
            continue

        batch = ch.ExperienceReplay(random.sample(memory, BATCH_SIZE))
        states = batch.state()
        next_states = batch.next_state()

        # Critic: one gradient step against the detached target bootstrap.
        next_q = target_critic(next_states,
                               target_actor(next_states)).view(-1, 1)
        q = critic(states, batch.action()).view(-1, 1)
        critic_loss = ch.algorithms.ddpg.state_value_loss(
            q, next_q.detach(), batch.reward(), batch.done(), DISCOUNT)
        critic_optimiser.zero_grad()
        critic_loss.backward()
        critic_optimiser.step()

        # Actor: one step of gradient ascent on the critic's estimate.
        actor_loss = -critic(states, actor(states)).mean()
        actor_optimiser.zero_grad()
        actor_loss.backward()
        actor_optimiser.step()

        # Polyak-average both target networks towards the online ones.
        ch.models.polyak_average(target_critic, critic, POLYAK_FACTOR)
        ch.models.polyak_average(target_actor, actor, POLYAK_FACTOR)
示例#7
0
def main():
    """Run one worker of a 16-process distributed PPO job on AntBulletEnv-v0.

    The worker's rank comes from ``--local_rank``; only rank 0 logs. The Adam
    optimizer is wrapped in ``Distributed`` (presumably to combine updates
    across workers).
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()
    # NOTE(review): rendezvous file and world size are hard-coded.
    dist.init_process_group('gloo',
                            init_method='file:///home/seba-1511/.dist_init',
                            rank=args.local_rank,
                            world_size=16)

    rank = dist.get_rank()
    th.set_num_threads(1)
    # Offset every seed by the rank so the workers differ from each other.
    random.seed(SEED + rank)
    th.manual_seed(SEED + rank)
    np.random.seed(SEED + rank)

    # Other envs tried: 'CartPoleBulletEnv-v0', 'RoboschoolAnt-v1'.
    env_name = 'AntBulletEnv-v0'
    env = gym.make(env_name)
    env = envs.AddTimestep(env)
    if rank == 0:
        env = envs.Logger(env, interval=PPO_STEPS)
    env = envs.Normalizer(env, states=True, rewards=True)
    env = envs.Torch(env)
    env = envs.Runner(env)
    # Seed the environment per rank too; with a shared SEED every worker
    # would presumably start from identical episode resets.
    env.seed(SEED + rank)

    policy = ActorCriticNet(env)
    optimizer = optim.Adam(policy.parameters(), lr=LR, eps=1e-5)
    num_updates = TOTAL_STEPS // PPO_STEPS + 1
    # Linearly anneal the learning rate towards zero over training.
    lr_schedule = optim.lr_scheduler.LambdaLR(
        optimizer, lambda epoch: 1 - epoch / num_updates)
    optimizer = Distributed(policy.parameters(), optimizer)

    def get_action(state):
        return get_action_value(state, policy)

    for _ in range(num_updates):
        # We use the Runner collector, but could've written our own
        replay = env.run(get_action, steps=PPO_STEPS, render=False)

        # Update policy
        update(replay, optimizer, policy, env, lr_schedule)
示例#8
0
def main(num_steps=10000000,
         env_name='PongNoFrameskip-v4',
         seed=42):
    """Train a DQN agent on an Atari game (e.g. 'BreakoutNoFrameskip-v4').

    Epsilon is decreased linearly during the first 1e6 environment steps.
    After EXPLORATION_STEPS, the replay is capped at REPLAY_SIZE and the
    Q-network is updated every iteration. The target network is synced every
    TARGET_UPDATE_FREQ iterations.

    Args:
        num_steps: Training budget in environment steps.
        env_name: Gym environment id to create.
        seed: Seed for all RNGs and the environment.
    """
    th.set_num_threads(1)
    random.seed(seed)
    th.manual_seed(seed)
    np.random.seed(seed)

    env = gym.make(env_name)
    env = envs.Logger(env, interval=1000)
    env = envs.OpenAIAtari(env)
    env = envs.Torch(env)
    env = envs.Runner(env)
    env.seed(seed)

    online_net = DQN(env)
    target_net = copy.deepcopy(online_net)
    optimizer = optim.RMSprop(online_net.parameters(), lr=LR, alpha=0.95,
                              eps=0.01, centered=True)
    memory = ch.ExperienceReplay()
    epsilon = EPSILON

    def get_action(state):
        # Reads the *current* epsilon from the enclosing scope on every call.
        return epsilon_greedy(online_net(state), epsilon)

    for it in range(num_steps // UPDATE_FREQ + 1):
        env_steps = it * UPDATE_FREQ
        memory += env.run(get_action, steps=UPDATE_FREQ)

        if env_steps < 1e6:
            epsilon -= 9.9e-7 * UPDATE_FREQ

        if env_steps <= EXPLORATION_STEPS:
            continue

        # Cap the replay to its most recent REPLAY_SIZE transitions.
        memory = memory[-REPLAY_SIZE:]
        update(memory, optimizer, online_net, target_net, env=env)

        if it % TARGET_UPDATE_FREQ == 0:
            target_net.load_state_dict(online_net.state_dict())
示例#9
0
def main(env='CliffWalking-v0'):
    """One-step Q-learning with a one-hot-state agent on a discrete gym env.

    Each of the 9999 iterations takes one env step and one SGD step on half
    the squared TD error, bootstrapping from the greedy next-state Q-value.

    Args:
        env: Gym environment id to create.
    """
    env = envs.Runner(envs.Torch(envs.Logger(gym.make(env), interval=1000)))
    agent = Agent(env)
    gamma = 1.00
    optimizer = optim.SGD(agent.parameters(), lr=0.5, momentum=0.0)

    for _ in range(1, 10000):
        step = env.run(agent, steps=1)[0]

        # Greedy bootstrap from the next state's best Q-value (no gradient).
        next_state = ch.onehot(step.next_state, dim=env.state_size)
        bootstrap = agent.qf(next_state).max().detach()
        td_error = ch.temporal_difference(gamma, step.reward, step.done,
                                          step.q_action, bootstrap)

        loss = td_error.pow(2).mul(0.5)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
示例#10
0
def main(env='PongNoFrameskip-v4', num_steps=5000000, seed=42):
    """Run one worker of a 16-process distributed A2C job on an Atari game.

    Only rank 0 logs. The RMSprop optimizer is wrapped in ``Distributed``
    (presumably to combine updates across workers).

    Args:
        env: Gym environment id; also used to name the rendezvous file.
        num_steps: Training budget; the loop runs
            ``num_steps // A2C_STEPS + 1`` updates (previously hard-coded).
        seed: Base seed; each worker uses ``seed + rank``.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()
    # NOTE(review): rendezvous path and world size are hard-coded.
    dist.init_process_group(
        'gloo',
        init_method='file:///home/seba-1511/.dist_init_' + env,
        rank=args.local_rank,
        world_size=16)

    rank = dist.get_rank()
    th.set_num_threads(1)
    random.seed(seed + rank)
    th.manual_seed(seed + rank)
    np.random.seed(seed + rank)

    env = gym.make(env)
    if rank == 0:
        env = envs.Logger(env, interval=1000)
    env = envs.OpenAIAtari(env)
    env = envs.Torch(env)
    env = envs.Runner(env)
    env.seed(seed + rank)

    policy = NatureCNN(env)
    optimizer = optim.RMSprop(policy.parameters(), lr=LR, alpha=0.99, eps=1e-5)
    optimizer = Distributed(policy.parameters(), optimizer)

    def get_action(state):
        return get_action_value(state, policy)

    for _ in range(num_steps // A2C_STEPS + 1):
        # Sample some transitions, then update the policy on them.
        replay = env.run(get_action, steps=A2C_STEPS)
        update(replay, optimizer, policy, env=env)
示例#11
0
def main(env='MinitaurTrottingEnv-v0'):
    """Train PPO with an ActorCriticNet policy on a Minitaur gym env.

    Observations get a timestep feature and are normalised together with the
    rewards. The learning rate is annealed linearly towards zero.

    Args:
        env: Gym environment id to create.
    """
    import random

    import numpy as np

    # Seed the global RNGs as well; previously only the environment was
    # seeded, leaving e.g. the policy initialisation unseeded.
    random.seed(SEED)
    np.random.seed(SEED)
    th.manual_seed(SEED)

    env = gym.make(env)
    env = envs.AddTimestep(env)
    env = envs.Logger(env, interval=PPO_STEPS)
    env = envs.Normalizer(env, states=True, rewards=True)
    env = envs.Torch(env)
    # env = envs.Recorder(env)
    env = envs.Runner(env)
    env.seed(SEED)

    th.set_num_threads(1)
    policy = ActorCriticNet(env)
    optimizer = optim.Adam(policy.parameters(), lr=LR, eps=1e-5)
    num_updates = TOTAL_STEPS // PPO_STEPS + 1
    lr_schedule = optim.lr_scheduler.LambdaLR(
        optimizer, lambda epoch: 1 - epoch / num_updates)

    def get_action(state):
        return get_action_value(state, policy)

    for _ in range(num_updates):
        # We use the Runner collector, but could've written our own
        replay = env.run(get_action, steps=PPO_STEPS, render=RENDER)

        # Update policy
        update(replay, optimizer, policy, env, lr_schedule)
示例#12
0
    # Compute loss
    # REINFORCE-style objective: sum of -log pi(a|s) * reward over the stored
    # transitions. NOTE(review): `policy_loss` and `rewards` are defined above
    # this fragment; presumably `policy_loss` starts as an empty list and
    # `rewards` holds per-step (discounted) returns — confirm.
    for sars, reward in zip(replay, rewards):
        log_prob = sars.log_prob
        policy_loss.append(-log_prob * reward)

    # Take optimization step
    optimizer.zero_grad()
    policy_loss = th.stack(policy_loss).sum()
    policy_loss.backward()
    optimizer.step()


if __name__ == '__main__':
    # CartPole with a Logger and torch-tensor observations/actions.
    env = gym.make('CartPole-v0')
    env = envs.Logger(env, interval=1000)
    env = envs.Torch(env)
    env.seed(SEED)

    policy = PolicyNet()
    optimizer = optim.Adam(policy.parameters(), lr=1e-2)
    # NOTE(review): presumably a moving average of episode reward used for
    # reporting/stopping — its update is not visible in this fragment.
    running_reward = 10.0
    replay = ch.ExperienceReplay()

    for i_episode in count(1):
        state = env.reset()
        for t in range(10000):  # Don't infinite loop while learning
            # Sample an action from the categorical distribution given by
            # the policy's output for the current state.
            mass = Categorical(policy(state))
            action = mass.sample()
            # Keep the pre-step state; presumably used below (not visible
            # here) to record the transition.
            old_state = state
            state, reward, done, _ = env.step(action)
示例#13
0
def main(env='Pendulum-v0'):
    """Train a PPO actor-critic on a continuous-control gym env on ``device``.

    Each iteration collects one full episode. Once at least BATCH_SIZE
    transitions are stored, the whole replay is moved to ``device`` and used
    for PPO_EPOCHS epochs of PPO-Clip actor updates and value-regression
    critic updates, then emptied. Both learning rates are halved every 2000
    scheduler steps (one scheduler step per batch update).

    Args:
        env: Gym environment id to create.
    """
    agent = ActorCritic(HIDDEN_SIZE).to(device)
    agent.apply(weights_init)

    actor_optimizer = optim.Adam(agent.actor.parameters(), lr=LEARNING_RATE)
    critic_optimizer = optim.Adam(agent.critic.parameters(), lr=LEARNING_RATE)
    actor_scheduler = torch.optim.lr_scheduler.StepLR(actor_optimizer,
                                                      step_size=2000,
                                                      gamma=0.5)
    critic_scheduler = torch.optim.lr_scheduler.StepLR(critic_optimizer,
                                                       step_size=2000,
                                                       gamma=0.5)

    env = gym.make(env)
    env.seed(SEED)
    env = envs.Torch(env)
    env = envs.Logger(env)
    env = envs.Runner(env)
    # A single replay buffer (it was previously constructed twice).
    replay = ch.ExperienceReplay()

    def get_action(state):
        # Move states to the agent's device before the forward pass.
        return agent(state.to(device))

    for _ in range(1, MAX_STEPS + 1):
        replay += env.run(get_action, episodes=1)

        if len(replay) >= BATCH_SIZE:
            batch = replay.to(device)
            with torch.no_grad():
                # Normalised GAE advantages; the zero tensor is presumably
                # the bootstrap value after the last transition.
                advantages = pg.generalized_advantage(
                    DISCOUNT, TRACE_DECAY, batch.reward(), batch.done(),
                    batch.value(),
                    torch.zeros(1, device=device))
                advantages = ch.normalize(advantages, epsilon=1e-8)
                # Discounted returns are the critic's regression targets.
                returns = td.discount(DISCOUNT, batch.reward(), batch.done())
                old_log_probs = batch.log_prob()

            # The first epoch reuses the values/log-probs recorded during the
            # rollout; later epochs recompute them with the updated agent.
            new_values = batch.value()
            new_log_probs = batch.log_prob()
            for epoch in range(PPO_EPOCHS):
                if epoch > 0:
                    _, infos = agent(batch.state())
                    masses = infos['mass']
                    new_values = infos['value'].view(-1, 1)
                    new_log_probs = masses.log_prob(batch.action())

                # Update the policy by maximising the PPO-Clip objective
                policy_loss = ch.algorithms.ppo.policy_loss(
                    new_log_probs,
                    old_log_probs,
                    advantages,
                    clip=PPO_CLIP_RATIO)
                actor_optimizer.zero_grad()
                policy_loss.backward()
                actor_optimizer.step()

                # Fit value function by regression on mean-squared error
                value_loss = ch.algorithms.a2c.state_value_loss(
                    new_values, returns)
                critic_optimizer.zero_grad()
                value_loss.backward()
                critic_optimizer.step()

            actor_scheduler.step()
            critic_scheduler.step()

            replay.empty()
示例#14
0
 def setUp(self):
     """Wrap a fresh Dummy env in envs.Logger and keep it as self.logger."""
     self.logger = envs.Logger(Dummy())
示例#15
0
        def test_config(n_envs, n_episodes, base_env, use_torch, use_logger,
                        return_info, retry):
            """Collect episodes from a vector env and check replay shapes.

            Args:
                n_envs: number of sub-environments in the vector env.
                n_episodes: number of episodes passed to ``Runner.run``.
                base_env: gym id (str) for ``gym.vector.make``, or a
                    zero-argument callable building one sub-environment.
                use_torch: whether to add the ``envs.Torch`` wrapper.
                use_logger: whether to add the ``envs.Logger`` wrapper.
                return_info: whether the agent also returns an info dict with
                    'policy' and 'act' entries to be stored in the replay.
                retry: collect a second time with the same Runner and check
                    the second replay.
            """
            config = 'n_envs' + str(n_envs) + '-n_eps' + str(n_episodes) \
                    + '-base_env' + str(base_env) \
                    + '-torch' + str(use_torch) + '-logger' + str(use_logger) \
                    + '-info' + str(return_info)
            if isinstance(base_env, str):
                env = vec_env = gym.vector.make(base_env, num_envs=n_envs)
            else:

                def make_env():
                    return base_env()

                env_fns = [make_env for _ in range(n_envs)]
                env = vec_env = AsyncVectorEnv(env_fns)

            if use_logger:
                env = envs.Logger(env, interval=5, logger=self.logger)

            if use_torch:
                env = envs.Torch(env)

                def policy(x):
                    return ch.totensor(vec_env.action_space.sample())
            else:

                def policy(x):
                    return vec_env.action_space.sample()

            if return_info:

                def agent(x):
                    # Sample once so 'policy' and 'act' describe the action
                    # actually returned (sampling per entry made them differ).
                    action = policy(x)
                    return action, {'policy': action[0], 'act': action}
            else:
                agent = policy

            # Gather experience
            env = envs.Runner(env)
            replay = env.run(agent, episodes=n_episodes)
            if retry:
                replay = env.run(agent, episodes=n_episodes)

            # Pre-compute expected shapes: one row per stored transition.
            shape = (len(replay), )
            state_shape = vec_env.observation_space.sample().shape[1:]
            action_shape = np.array(vec_env.action_space.sample())[0].shape
            if len(action_shape) == 0:
                # Scalar actions are stored with a trailing dimension of 1.
                action_shape = (1, )
            done_shape = (1, )

            # Check shapes
            states = replay.state()
            self.assertEqual(states.shape, shape + state_shape, config)
            actions = replay.action()
            self.assertEqual(actions.shape, shape + action_shape, config)
            dones = replay.done()
            self.assertEqual(dones.shape, shape + done_shape, config)
            if return_info:
                policies = replay.policy()
                self.assertEqual(policies.shape, shape + action_shape, config)
                acts = replay.act()
                self.assertEqual(acts.shape,
                                 (len(replay), n_envs) + action_shape, config)
示例#16
0
def main(env='Pendulum-v0'):
    """Train a soft actor-critic agent (twin Q critics + V critic) on a gym env.

    Random actions are used for the first UPDATE_START steps; afterwards
    actions are sampled from the actor. Every UPDATE_INTERVAL steps the Q
    critics, the V critic and the actor each take one gradient step on a
    random minibatch, then the target V network is Polyak-averaged.

    Args:
        env: Gym environment id to create.
    """
    env = gym.make(env)
    env.seed(SEED)
    env = envs.Torch(env)
    env = envs.Logger(env)
    env = envs.Runner(env)
    replay = ch.ExperienceReplay()

    actor = SoftActor(HIDDEN_SIZE)
    critic_1 = Critic(HIDDEN_SIZE, state_action=True)
    critic_2 = Critic(HIDDEN_SIZE, state_action=True)
    value_critic = Critic(HIDDEN_SIZE)
    target_value_critic = create_target_network(value_critic)
    actor_optimiser = optim.Adam(actor.parameters(), lr=LEARNING_RATE)
    # One optimiser drives both Q critics.
    critics_optimiser = optim.Adam(
        (list(critic_1.parameters()) + list(critic_2.parameters())),
        lr=LEARNING_RATE)
    value_critic_optimiser = optim.Adam(value_critic.parameters(),
                                        lr=LEARNING_RATE)
    get_action = lambda state: actor(state).sample()

    for step in range(1, MAX_STEPS + 1):
        with torch.no_grad():
            # NOTE(review): get_random_action is not defined in this block;
            # presumably a module-level helper.
            if step < UPDATE_START:
                replay += env.run(get_random_action, steps=1)
            else:
                replay += env.run(get_action, steps=1)
        # Keep only the most recent REPLAY_SIZE transitions.
        replay = replay[-REPLAY_SIZE:]

        if step > UPDATE_START and step % UPDATE_INTERVAL == 0:
            sample = random.sample(replay, BATCH_SIZE)
            batch = ch.ExperienceReplay(sample)

            # Pre-compute some quantities
            # rsample() keeps the sampled actions differentiable w.r.t. the
            # actor, which the policy loss below relies on.
            states = batch.state()
            rewards = batch.reward()
            old_actions = batch.action()
            dones = batch.done()
            masses = actor(states)
            actions = masses.rsample()
            log_probs = masses.log_prob(actions)
            # Minimum of the twin Q estimates for the freshly sampled actions;
            # used (detached) as the V-function target.
            q_values = torch.min(critic_1(states, actions.detach()),
                                 critic_2(states,
                                          actions.detach())).view(-1, 1)

            # Compute Q losses
            # Both critics regress on the stored actions, bootstrapping from
            # the (detached) target V network.
            v_next = target_value_critic(batch.next_state()).view(-1, 1)
            q_old_pred1 = critic_1(states, old_actions.detach()).view(-1, 1)
            q_old_pred2 = critic_2(states, old_actions.detach()).view(-1, 1)
            qloss1 = ch.algorithms.sac.action_value_loss(
                q_old_pred1, v_next.detach(), rewards, dones, DISCOUNT)
            qloss2 = ch.algorithms.sac.action_value_loss(
                q_old_pred2, v_next.detach(), rewards, dones, DISCOUNT)

            # Update Q-functions by one step of gradient descent
            qloss = qloss1 + qloss2
            critics_optimiser.zero_grad()
            qloss.backward()
            critics_optimiser.step()

            # Update V-function by one step of gradient descent
            v_pred = value_critic(batch.state()).view(-1, 1)
            vloss = ch.algorithms.sac.state_value_loss(v_pred,
                                                       log_probs.detach(),
                                                       q_values.detach(),
                                                       alpha=ENTROPY_WEIGHT)
            value_critic_optimiser.zero_grad()
            vloss.backward()
            value_critic_optimiser.step()

            # Update policy by one step of gradient ascent
            # Only critic_1 scores the policy's actions here. Its parameters
            # also receive gradients from this backward pass; those are
            # cleared by critics_optimiser.zero_grad() before the next Q step.
            q_actions = critic_1(batch.state(), actions).view(-1, 1)
            policy_loss = ch.algorithms.sac.policy_loss(log_probs,
                                                        q_actions,
                                                        alpha=ENTROPY_WEIGHT)
            actor_optimiser.zero_grad()
            policy_loss.backward()
            actor_optimiser.step()

            # Update target value network
            ch.models.polyak_average(target_value_critic, value_critic,
                                     POLYAK_FACTOR)