Example #1
        # Update Q-functions by one step of gradient descent
        value_loss = (
            critic_1(batch['state'], batch['action']) - y_q).pow(2).mean() + (
                critic_2(batch['state'], batch['action']) - y_q).pow(2).mean()
        critics_optimiser.zero_grad()
        value_loss.backward()
        critics_optimiser.step()

        # Update V-function by one step of gradient descent
        value_loss = (value_critic(batch['state']) - y_v).pow(2).mean()
        value_critic_optimiser.zero_grad()
        value_loss.backward()
        value_critic_optimiser.step()

        # Update policy by one step of gradient ascent
        policy_loss = (weighted_sample_entropy -
                       critic_1(batch['state'], action)).mean()
        actor_optimiser.zero_grad()
        policy_loss.backward()
        actor_optimiser.step()

        # Update target value network
        update_target_network(value_critic, target_value_critic, POLYAK_FACTOR)

    if step > UPDATE_START and step % TEST_INTERVAL == 0:
        actor.eval()
        total_reward = test(actor)
        pbar.set_description('Step: %i | Reward: %f' % (step, total_reward))
        plot(step, total_reward, 'sac')
        actor.train()
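The snippet above (and the examples that follow) relies on an update_target_network helper that is not shown. Below is a minimal sketch of a Polyak-averaging update consistent with those call sites; the helper's real signature and the exact convention for POLYAK_FACTOR are assumptions.

import torch

def update_target_network(network, target_network, polyak_factor):
    # target <- polyak_factor * target + (1 - polyak_factor) * online
    with torch.no_grad():
        for param, target_param in zip(network.parameters(),
                                       target_network.parameters()):
            target_param.data = polyak_factor * target_param.data + (
                1 - polyak_factor) * param.data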
Example #2
        target_action = torch.clamp(
            target_actor(batch['next_state']) +
            torch.clamp(0.2 * torch.randn(1, 1), min=-0.5, max=0.5),
            min=-1,
            max=1)
        # Compute targets (clipped double Q-learning)
        y = batch['reward'] + discount * (1 - batch['done']) * torch.min(
            target_critic_1(batch['next_state'], target_action),
            target_critic_2(batch['next_state'], target_action))

        # Update Q-functions by one step of gradient descent
        value_loss = (critic_1(batch['state'], batch['action']) -
                      y).pow(2).mean() + (critic_2(
                          batch['state'], batch['action']) - y).pow(2).mean()
        critics_optimiser.zero_grad()
        value_loss.backward()
        critics_optimiser.step()

        if step % (policy_delay * update_interval) == 0:
            # Update policy by one step of gradient ascent
            policy_loss = -critic_1(batch['state'], actor(
                batch['state'])).mean()
            actor_optimiser.zero_grad()
            policy_loss.backward()
            actor_optimiser.step()

            # Update target networks
            update_target_network(critic_1, target_critic_1, polyak_rate)
            update_target_network(critic_2, target_critic_2, polyak_rate)
            update_target_network(actor, target_actor, polyak_rate)
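Restated as a standalone helper, the target computation in the snippet above combines target policy smoothing with clipped double Q-learning. This is a sketch only: the function name and the noise_std/noise_clip parameters are illustrative, and it draws per-element noise with randn_like, whereas the snippet reuses a single shared noise sample.

import torch

def td3_target(batch, target_actor, target_critic_1, target_critic_2, discount,
               noise_std=0.2, noise_clip=0.5):
    with torch.no_grad():
        next_action = target_actor(batch['next_state'])
        # Target policy smoothing: perturb the target action with clipped Gaussian noise
        noise = torch.clamp(noise_std * torch.randn_like(next_action),
                            min=-noise_clip, max=noise_clip)
        target_action = torch.clamp(next_action + noise, min=-1, max=1)
        # Clipped double Q-learning: bootstrap from the smaller of the two target critics
        target_q = torch.min(target_critic_1(batch['next_state'], target_action),
                             target_critic_2(batch['next_state'], target_action))
        return batch['reward'] + discount * (1 - batch['done']) * target_q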
Example #3
        target_action = torch.clamp(
            target_actor(batch['next_state']) +
            torch.clamp(TARGET_ACTION_NOISE * torch.randn(1, 1),
                        min=-TARGET_ACTION_NOISE_CLIP,
                        max=TARGET_ACTION_NOISE_CLIP),
            min=-1,
            max=1)
        # Compute targets (clipped double Q-learning)
        y = batch['reward'] + DISCOUNT * (1 - batch['done']) * torch.min(
            target_critic_1(batch['next_state'], target_action),
            target_critic_2(batch['next_state'], target_action))

        # Update Q-functions by one step of gradient descent
        value_loss = (critic_1(batch['state'], batch['action']) -
                      y).pow(2).mean() + (critic_2(
                          batch['state'], batch['action']) - y).pow(2).mean()
        critics_optimiser.zero_grad()
        value_loss.backward()
        critics_optimiser.step()

        if step % (POLICY_DELAY * UPDATE_INTERVAL) == 0:
            # Update policy by one step of gradient ascent
            policy_loss = -critic_1(batch['state'], actor(
                batch['state'])).mean()
            actor_optimiser.zero_grad()
            policy_loss.backward()
            actor_optimiser.step()

            # Update target networks
            update_target_network(critic_1, target_critic_1, POLYAK_FACTOR)
            update_target_network(critic_2, target_critic_2, POLYAK_FACTOR)
            update_target_network(actor, target_actor, POLYAK_FACTOR)
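Examples #2 and #3 assume that the target networks already exist; in the training script of Example #4 they are produced by create_target_network. A minimal sketch of such a helper, assuming it simply deep-copies the online network and freezes the copy's parameters, is:

import copy

def create_target_network(network):
    # Copy the online network and detach the copy from gradient updates
    target_network = copy.deepcopy(network)
    for param in target_network.parameters():
        param.requires_grad = False
    return target_network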
Example #4
def train(BATCH_SIZE, DISCOUNT, ENTROPY_WEIGHT, HIDDEN_SIZE, LEARNING_RATE,
          MAX_STEPS, POLYAK_FACTOR, REPLAY_SIZE, TEST_INTERVAL,
          UPDATE_INTERVAL, UPDATE_START, ENV, OBSERVATION_LOW, VALUE_FNC,
          FLOW_TYPE, FLOWS, DEMONSTRATIONS, PRIORITIZE_REPLAY,
          BEHAVIOR_CLONING, ARM, BASE, RPA, REWARD_DENSE, logdir):

    # Prioritised-replay and loss-weighting hyperparameters (values in trailing comments are alternatives)
    ALPHA = 0.3  # PER priority exponent
    BETA = 1  # PER importance-sampling exponent
    epsilon = 0.0001  # 0.1; baseline priority added to every sampled transition
    epsilon_d = 0.1  # 0.3; extra priority bonus for demonstration transitions
    weights = 1  # 1
    lambda_ac = 0.85  # 0.7; weight on the SAC term of the policy loss
    lambda_bc = 0.3  # 0.4; weight on the behaviour-cloning term (recomputed per batch below)

    setup_logger(logdir, locals())
    ENV = __import__(ENV)
    if ARM and BASE:
        env = ENV.youBotAll('youbot_navig2.ttt',
                            obs_lowdim=OBSERVATION_LOW,
                            rpa=RPA,
                            reward_dense=REWARD_DENSE,
                            boundary=1)
    elif ARM:
        env = ENV.youBotArm('youbot_navig.ttt',
                            obs_lowdim=OBSERVATION_LOW,
                            rpa=RPA,
                            reward_dense=REWARD_DENSE)
    elif BASE:
        env = ENV.youBotBase('youbot_navig.ttt',
                             obs_lowdim=OBSERVATION_LOW,
                             rpa=RPA,
                             reward_dense=REWARD_DENSE,
                             boundary=1)

    action_space = env.action_space
    obs_space = env.observation_space()
    step_limit = env.step_limit()

    if OBSERVATION_LOW:
        actor = SoftActorGated(HIDDEN_SIZE,
                               action_space,
                               obs_space,
                               flow_type=FLOW_TYPE,
                               flows=FLOWS).float().to(device)
        critic_1 = Critic(HIDDEN_SIZE,
                          1,
                          obs_space,
                          action_space,
                          state_action=True).float().to(device)
        critic_2 = Critic(HIDDEN_SIZE,
                          1,
                          obs_space,
                          action_space,
                          state_action=True).float().to(device)
    else:
        actor = ActorImageNet(HIDDEN_SIZE,
                              action_space,
                              obs_space,
                              flow_type=FLOW_TYPE,
                              flows=FLOWS).float().to(device)
        critic_1 = Critic(HIDDEN_SIZE,
                          1,
                          obs_space,
                          action_space,
                          state_action=True).float().to(device)
        critic_2 = Critic(HIDDEN_SIZE,
                          1,
                          obs_space,
                          action_space,
                          state_action=True).float().to(device)
        critic_1.load_state_dict(
            torch.load(
                'data/youbot_all_final_21-08-2019_22-32-00/models/critic1_model_473000.pkl'
            ))
        critic_2.load_state_dict(
            torch.load(
                'data/youbot_all_final_21-08-2019_22-32-00/models/critic1_model_473000.pkl'
            ))

    actor.apply(weights_init)
    # critic_1.apply(weights_init)
    # critic_2.apply(weights_init)

    if VALUE_FNC:
        value_critic = Critic(HIDDEN_SIZE, 1, obs_space,
                              action_space).float().to(device)
        target_value_critic = create_target_network(value_critic).float().to(
            device)
        value_critic_optimiser = optim.Adam(value_critic.parameters(),
                                            lr=LEARNING_RATE)
    else:
        target_critic_1 = create_target_network(critic_1)
        target_critic_2 = create_target_network(critic_2)
    actor_optimiser = optim.Adam(actor.parameters(), lr=LEARNING_RATE)
    critics_optimiser = optim.Adam(list(critic_1.parameters()) +
                                   list(critic_2.parameters()),
                                   lr=LEARNING_RATE)

    # Replay buffer
    if PRIORITIZE_REPLAY:
        # D = PrioritizedReplayBuffer(REPLAY_SIZE, ALPHA)
        D = ReplayMemory(device, 3, DISCOUNT, 1, BETA, ALPHA, REPLAY_SIZE)
    else:
        D = deque(maxlen=REPLAY_SIZE)

    eval_ = evaluation_sac(env, logdir, device)

    #Automatic entropy tuning init
    target_entropy = -np.prod(action_space).item()
    log_alpha = torch.zeros(1, requires_grad=True, device=device)
    alpha_optimizer = optim.Adam([log_alpha], lr=LEARNING_RATE)

    home = os.path.expanduser('~')
    if DEMONSTRATIONS:
        dir_dem = os.path.join(home, 'robotics_drl/data/demonstrations/',
                               DEMONSTRATIONS)
        D, n_demonstrations = load_buffer_demonstrations(
            D, dir_dem, PRIORITIZE_REPLAY, OBSERVATION_LOW)
    else:
        n_demonstrations = 0

    # Default the behaviour-cloning loss to zero; it is recomputed per batch when
    # demonstration transitions are sampled and BEHAVIOR_CLONING is enabled
    behavior_loss = 0

    os.mkdir(os.path.join(home, 'robotics_drl', logdir, 'models'))
    dir_models = os.path.join(home, 'robotics_drl', logdir, 'models')

    state, done = env.reset(), False
    if OBSERVATION_LOW:
        state = state.float().to(device)
    else:
        state['low'] = state['low'].float()
        state['high'] = state['high'].float()
    pbar = tqdm(range(1, MAX_STEPS + 1), unit_scale=1, smoothing=0)

    steps = 0
    success = 0
    eval_c2 = False  # set once an episode finishes; gates evaluation below
    for step in pbar:
        with torch.no_grad():
            if step < UPDATE_START and not DEMONSTRATIONS:
                # To improve exploration take actions sampled from a uniform random distribution over actions at the start of training
                action = torch.tensor(env.sample_action(),
                                      dtype=torch.float32,
                                      device=device).unsqueeze(dim=0)
            else:
                # Observe state s and select action a ~ μ(a|s)
                if not OBSERVATION_LOW:
                    state['low'] = state['low'].float().to(device)
                    state['high'] = state['high'].float().to(device)
                action, _ = actor(state, log_prob=False, deterministic=False)
                if not OBSERVATION_LOW:
                    state['low'] = state['low'].float().cpu()
                    state['high'] = state['high'].float().cpu()
                #if (policy.mean).mean() > 0.4:
                #    print("GOOD VELOCITY")
            # Execute a in the environment and observe next state s', reward r, and done signal d to indicate whether s' is terminal
            next_state, reward, done = env.step(
                action.squeeze(dim=0).cpu().tolist())
            if OBSERVATION_LOW:
                next_state = next_state.float().to(device)
            else:
                next_state['low'] = next_state['low'].float()
                next_state['high'] = next_state['high'].float()
            # Store (s, a, r, s', d) in replay buffer D
            if PRIORITIZE_REPLAY:
                if OBSERVATION_LOW:
                    D.add(state.cpu().tolist(),
                          action.cpu().squeeze().tolist(), reward,
                          next_state.cpu().tolist(), done)
                else:
                    D.append(state['high'], state['low'],
                             action.cpu().squeeze().tolist(), reward, done)
            else:
                D.append({
                    'state':
                    state.unsqueeze(dim=0) if OBSERVATION_LOW else state,
                    'action':
                    action,
                    'reward':
                    torch.tensor([reward], dtype=torch.float32, device=device),
                    'next_state':
                    next_state.unsqueeze(
                        dim=0) if OBSERVATION_LOW else next_state,
                    'done':
                    torch.tensor([True if reward == 1 else False],
                                 dtype=torch.float32,
                                 device=device)
                })

            state = next_state

            # If s' is terminal, reset environment state
            steps += 1

            if done or steps > step_limit:  #TODO: incorporate step limit in the environment
                eval_c2 = True  #TODO: multiprocess pyrep with a session for each testing and training
                steps = 0
                if OBSERVATION_LOW:
                    state = env.reset().float().to(device)
                else:
                    state = env.reset()
                    state['low'] = state['low'].float()
                    state['high'] = state['high'].float()
                if reward == 1:
                    success += 1

        if step > UPDATE_START and step % UPDATE_INTERVAL == 0:
            for _ in range(1):
                # Randomly sample a batch of transitions B = {(s, a, r, s', d)} from D
                if PRIORITIZE_REPLAY:
                    if OBSERVATION_LOW:
                        state_batch, action_batch, reward_batch, state_next_batch, done_batch, weights_pr, idxes = D.sample(
                            BATCH_SIZE, BETA)
                        state_batch = torch.from_numpy(state_batch).float().to(
                            device)
                        next_state_batch = torch.from_numpy(
                            state_next_batch).float().to(device)
                        action_batch = torch.from_numpy(
                            action_batch).float().to(device)
                        reward_batch = torch.from_numpy(
                            reward_batch).float().to(device)
                        done_batch = torch.from_numpy(done_batch).float().to(
                            device)
                        weights_pr = torch.from_numpy(weights_pr).float().to(
                            device)
                    else:
                        idxes, high_state_batch, low_state_batch, action_batch, reward_batch, high_state_next_batch, low_state_next_batch, done_batch, weights_pr = D.sample(
                            BATCH_SIZE)

                        state_batch = {
                            'low':
                            low_state_batch.float().to(device).view(-1, 32),
                            'high':
                            high_state_batch.float().to(device).view(
                                -1, 12, 128, 128)
                        }
                        next_state_batch = {
                            'low':
                            low_state_next_batch.float().to(device).view(
                                -1, 32),
                            'high':
                            high_state_next_batch.float().to(device).view(
                                -1, 12, 128, 128)
                        }

                        action_batch = action_batch.float().to(device)
                        reward_batch = reward_batch.float().to(device)
                        done_batch = done_batch.float().to(device)
                        weights_pr = weights_pr.float().to(device)
                        # for j in range(BATCH_SIZE):
                        #     new_state_batch['high'] = torch.cat((new_state_batch['high'], state_batch[j].tolist()['high'].view(-1,(3+1)*env.frames,128,128)), dim=0)
                        #     new_state_batch['low'] = torch.cat((new_state_batch['low'], state_batch[j].tolist()['low'].view(-1,32)), dim=0)
                        #     new_next_state_batch['high'] = torch.cat((new_next_state_batch['high'], state_next_batch[j].tolist()['high'].view(-1,(3+1)*env.frames,128,128)), dim=0)
                        #     new_next_state_batch['low'] = torch.cat((new_next_state_batch['low'], state_next_batch[j].tolist()['low'].view(-1,32)), dim=0)
                        # new_state_batch['high'] = new_state_batch['high'].to(device)
                        # new_state_batch['low'] = new_state_batch['low'].to(device)
                        # new_next_state_batch['high'] = new_next_state_batch['high'].to(device)
                        # new_next_state_batch['low'] = new_next_state_batch['low'].to(device)

                    batch = {
                        'state': state_batch,
                        'action': action_batch,
                        'reward': reward_batch,
                        'next_state': next_state_batch,
                        'done': done_batch
                    }
                    state_batch = []
                    state_next_batch = []

                else:
                    batch = random.sample(D, BATCH_SIZE)
                    state_batch = []
                    action_batch = []
                    reward_batch = []
                    state_next_batch = []
                    done_batch = []
                    for d in batch:
                        state_batch.append(d['state'])
                        action_batch.append(d['action'])
                        reward_batch.append(d['reward'])
                        state_next_batch.append(d['next_state'])
                        done_batch.append(d['done'])

                    batch = {
                        'state': torch.cat(state_batch, dim=0),
                        'action': torch.cat(action_batch, dim=0),
                        'reward': torch.cat(reward_batch, dim=0),
                        'next_state': torch.cat(state_next_batch, dim=0),
                        'done': torch.cat(done_batch, dim=0)
                    }

                action, log_prob = actor(batch['state'],
                                         log_prob=True,
                                         deterministic=False)

                #Automatic entropy tuning
                alpha_loss = -(
                    log_alpha.float() *
                    (log_prob + target_entropy).float().detach()).mean()
                alpha_optimizer.zero_grad()
                alpha_loss.backward()
                alpha_optimizer.step()
                alpha = log_alpha.exp()
                weighted_sample_entropy = (alpha.float() * log_prob).view(
                    -1, 1)

                # Compute targets for Q and V functions
                if VALUE_FNC:
                    y_q = batch['reward'] + DISCOUNT * (
                        1 - batch['done']) * target_value_critic(
                            batch['next_state'])
                    y_v = torch.min(
                        critic_1(batch['state']['low'], action.detach()),
                        critic_2(batch['state']['low'], action.detach())
                    ) - weighted_sample_entropy.detach()
                else:
                    # No value function network
                    with torch.no_grad():
                        next_actions, next_log_prob = actor(
                            batch['next_state'],
                            log_prob=True,
                            deterministic=False)
                        target_qs = torch.min(
                            target_critic_1(
                                batch['next_state']['low'] if
                                not OBSERVATION_LOW else batch['next_state'],
                                next_actions),
                            target_critic_2(
                                batch['next_state']['low'] if
                                not OBSERVATION_LOW else batch['next_state'],
                                next_actions)) - alpha * next_log_prob
                    y_q = batch['reward'] + DISCOUNT * (
                        1 - batch['done']) * target_qs.detach()

                td_error_critic1 = critic_1(
                    batch['state']['low'] if not OBSERVATION_LOW else
                    batch['state'], batch['action']) - y_q
                td_error_critic2 = critic_2(
                    batch['state']['low'] if not OBSERVATION_LOW else
                    batch['state'], batch['action']) - y_q

                q_loss = (td_error_critic1).pow(2).mean() + (
                    td_error_critic2).pow(2).mean()
                # q_loss = (F.mse_loss(critic_1(batch['state'], batch['action']), y_q) + F.mse_loss(critic_2(batch['state'], batch['action']), y_q)).mean()
                critics_optimiser.zero_grad()
                q_loss.backward()
                critics_optimiser.step()

                # Compute priorities, taking demonstrations into account
                if PRIORITIZE_REPLAY:
                    td_error = weights_pr * (td_error_critic1.detach() +
                                             td_error_critic2.detach()).mean()
                    action_dem = torch.tensor([]).to(device)
                    if OBSERVATION_LOW:
                        state_dem = torch.tensor([]).to(device)
                    else:
                        state_dem = {
                            'low': torch.tensor([]).float().to(device),
                            'high': torch.tensor([]).float().to(device)
                        }
                    priorities = torch.abs(td_error).tolist()
                    i = 0
                    count_dem = 0
                    for idx in idxes:
                        priorities[i] += epsilon
                        if idx < n_demonstrations:
                            priorities[i] += epsilon_d
                            count_dem += 1
                            if BEHAVIOR_CLONING:
                                action_dem = torch.cat(
                                    (action_dem, batch['action'][i].view(
                                        1, -1)),
                                    dim=0)
                                if OBSERVATION_LOW:
                                    state_dem = torch.cat(
                                        (state_dem, batch['state'][i].view(
                                            1, -1)),
                                        dim=0)
                                else:
                                    state_dem['high'] = torch.cat(
                                        (state_dem['high'],
                                         batch['state']['high'][i, ].view(
                                             -1,
                                             (3 + 1) * env.frames, 128, 128)),
                                        dim=0)
                                    state_dem['low'] = torch.cat(
                                        (state_dem['low'],
                                         batch['state']['low'][i, ].view(
                                             -1, 32)),
                                        dim=0)
                        i += 1
                    if not action_dem.nelement() == 0:
                        actual_action_dem, _ = actor(state_dem,
                                                     log_prob=False,
                                                     deterministic=True)
                        # q_value_actor = (critic_1(batch['state'][i], batch['action'][i]) + critic_2(batch['state'][i], batch['action'][i]))/2
                        # q_value_actual = (critic_1(batch['state'][i], actual_action_dem) + critic_2(batch['state'][i], actual_action_dem))/2
                        # if q_value_actor > q_value_actual: # Q Filter
                        behavior_loss = F.mse_loss(
                            action_dem, actual_action_dem).unsqueeze(dim=0)
                    else:
                        behavior_loss = 0

                    D.update_priorities(idxes, priorities)
                    lambda_bc = (count_dem / BATCH_SIZE) / 5

                # Update V-function by one step of gradient descent
                if VALUE_FNC:
                    v_loss = (value_critic(batch['state']) -
                              y_v).pow(2).mean().to(device)

                    value_critic_optimiser.zero_grad()
                    v_loss.backward()
                    value_critic_optimiser.step()

                # Update policy by one step of gradient ascent
                with torch.no_grad():
                    new_qs = torch.min(
                        critic_1(
                            batch["state"]['low'] if not OBSERVATION_LOW else
                            batch['state'], action),
                        critic_2(
                            batch["state"]['low'] if not OBSERVATION_LOW else
                            batch['state'], action))
                policy_loss = lambda_ac * (weighted_sample_entropy.view(
                    -1) - new_qs).mean().to(device) + lambda_bc * behavior_loss
                actor_optimiser.zero_grad()
                policy_loss.backward()
                actor_optimiser.step()

                # Update target value network
                if VALUE_FNC:
                    update_target_network(value_critic, target_value_critic,
                                          POLYAK_FACTOR)
                else:
                    update_target_network(critic_1, target_critic_1,
                                          POLYAK_FACTOR)
                    update_target_network(critic_2, target_critic_2,
                                          POLYAK_FACTOR)
        state_dem = []

        # Evaluate only when a test interval has elapsed (eval_c) and the current episode has finished (eval_c2)
        eval_c = step > UPDATE_START and step % TEST_INTERVAL == 0

        if eval_c and eval_c2:
            eval_c = False
            eval_c2 = False
            actor.eval()
            critic_1.eval()
            critic_2.eval()
            q_value_eval = eval_.get_qvalue(critic_1, critic_2)
            return_ep, steps_ep = eval_.sample_episode(actor)

            logz.log_tabular('Training steps', step)
            logz.log_tabular('Cumulative Success', success)
            logz.log_tabular('Validation return', return_ep.mean())
            logz.log_tabular('Validation steps', steps_ep.mean())
            logz.log_tabular('Validation return std', return_ep.std())
            logz.log_tabular('Validation steps std', steps_ep.std())
            logz.log_tabular('Q-value evaluation', q_value_eval)
            logz.log_tabular('Q-network loss', q_loss.detach().cpu().numpy())
            if VALUE_FNC:
                logz.log_tabular('Value-network loss',
                                 v_loss.detach().cpu().numpy())
            logz.log_tabular('Policy-network loss',
                             policy_loss.detach().cpu().squeeze().numpy())
            logz.log_tabular('Alpha loss', alpha_loss.detach().cpu().numpy())
            logz.log_tabular('Alpha', alpha.detach().cpu().squeeze().numpy())
            logz.log_tabular('Demonstrations current batch', count_dem)
            logz.dump_tabular()

            logz.save_pytorch_model(actor.state_dict())

            torch.save(actor.state_dict(),
                       os.path.join(dir_models, 'actor_model_%s.pkl' % (step)))
            torch.save(
                critic_1.state_dict(),
                os.path.join(dir_models, 'critic1_model_%s.pkl' % (step)))
            torch.save(
                critic_2.state_dict(),
                os.path.join(dir_models, 'critic2_model_%s.pkl' % (step)))

            #pbar.set_description('Step: %i | Reward: %f' % (step, return_ep.mean()))

            actor.train()
            critic_1.train()
            critic_2.train()

    env.terminate()
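The automatic entropy tuning used in Example #4 can be read in isolation as follows. This is a sketch under the assumption that the policy returns per-sample log-probabilities; action_dim and the learning rate are illustrative.

import numpy as np
import torch
from torch import optim

action_dim = 4  # illustrative
# The temperature alpha is learned so that the policy entropy stays near a target value,
# commonly the negative action dimensionality.
target_entropy = -float(np.prod(action_dim))
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = optim.Adam([log_alpha], lr=3e-4)

def update_temperature(log_prob):
    # log_prob: log pi(a|s) for actions freshly sampled from the current policy
    alpha_loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
    return log_alpha.exp()  # current temperature alpha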
Example #5
        )  # a(s) is a sample from μ(·|s) which is differentiable wrt θ via the reparameterisation trick
        weighted_sample_entropy = (alpha * policy.log_prob(action)).sum(dim=1)
        y_v = torch.min(
            critic_1(batch['state'], action.detach()),
            critic_2(batch['state'],
                     action.detach())) - weighted_sample_entropy.detach()

        # Update Q-functions by one step of gradient descent
        value_loss = (
            critic_1(batch['state'], batch['action']) - y_q).pow(2).mean() + (
                critic_2(batch['state'], batch['action']) - y_q).pow(2).mean()
        critics_optimiser.zero_grad()
        value_loss.backward()
        critics_optimiser.step()

        # Update V-function by one step of gradient descent
        value_loss = (value_critic(batch['state']) - y_v).pow(2).mean()
        value_critic_optimiser.zero_grad()
        value_loss.backward()
        value_critic_optimiser.step()

        # Update policy by one step of gradient ascent
        policy_loss = -(critic_1(batch['state'], action) +
                        weighted_sample_entropy).mean()
        actor_optimiser.zero_grad()
        policy_loss.backward()
        actor_optimiser.step()

        # Update target value network
        update_target_network(value_critic, target_value_critic, polyak_rate)
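The comment at the start of Example #5 refers to the reparameterisation trick. Below is a minimal illustration of how a sampled action stays differentiable with respect to the policy parameters, using a diagonal Gaussian policy as an assumed stand-in for the actual actor.

import torch
from torch import nn
from torch.distributions import Normal

class GaussianPolicy(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(state_dim, hidden_size), nn.Tanh())
        self.mean = nn.Linear(hidden_size, action_dim)
        self.log_std = nn.Linear(hidden_size, action_dim)

    def forward(self, state):
        h = self.body(state)
        return Normal(self.mean(h), self.log_std(h).exp())

policy = GaussianPolicy(state_dim=3, action_dim=1)
dist = policy(torch.randn(8, 3))
# rsample() draws a = mean + std * eps with eps ~ N(0, 1), so gradients flow back
# into mean and std; sample() would block them.
action = dist.rsample()
log_prob = dist.log_prob(action)  # feeds the entropy term alpha * log pi(a|s)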