Example #1
class Agent():
    def __init__(self, learn_rate, input_shape, num_actions):
        self.num_actions = num_actions
        self.gamma = 0.99
        self.critic_update_max = 20
        self.actor_update_max = 10
        self.memories = []

        # self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = torch.device("cpu")

        self.actor = Actor().to(self.device)
        self.critic = Critic().to(self.device)

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=learn_rate)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=learn_rate)

    def choose_action(self, state, hidden_state):
        state = torch.tensor(state, dtype=torch.float32).to(self.device)

        policy, hidden_state_ = self.actor(state, hidden_state)
        policy = F.softmax(policy, dim=-1)
        actions_probs = torch.distributions.Categorical(policy)
        action = actions_probs.sample()
        action_log_prob = actions_probs.log_prob(action).unsqueeze(0)
        # action = torch.argmax(policy)

        #   prep for storage
        action = action.item()

        return action, policy, hidden_state_, action_log_prob

    def store_memory(self, memory):
        self.memories.append(memory)

    def get_discounted_cum_rewards(self, memory):
        cum_rewards = []
        total = 0
        for reward in reversed(memory.rewards):
            total = reward + total * self.gamma
            cum_rewards.append(total)
        cum_rewards = list(reversed(cum_rewards))
        cum_disc_rewards = torch.tensor(cum_rewards).float().to(self.device)

        return cum_disc_rewards

    def learn(self):
        critic_losses = []
        for memory_idx, memory in enumerate(self.memories):
            print(memory_idx)
            states, actions, policies, rewards, dones, actor_hidden_states, action_log_probs = \
                memory.fetch_on_device(self.device)
            cum_disc_rewards = self.get_discounted_cum_rewards(memory)
            ''' train critic    '''
            self.critic.train()
            self.actor.eval()

            critic_hidden_state = self.critic.get_new_hidden_state()
            for i in range(len(memory.states)):
                state = states[i].detach()
                policy = policies[i].detach()
                action_log_prob = action_log_probs[i].detach()
                done = dones[i].detach()

                true_value = cum_disc_rewards[i]

                value, critic_hidden_state_ = self.critic(
                    state, action_log_prob, critic_hidden_state)
                if done:
                    true_value *= 0.0
                error = value - true_value
                # print("true: {}, value: {}".format(true_value, value))
                critic_loss = error**2
                if critic_loss >= self.critic_update_max:
                    print("critic_loss BIG: {}".format(critic_loss))
                critic_loss = torch.clamp(critic_loss, -self.critic_update_max,
                                          self.critic_update_max)
                critic_losses.append(critic_loss)

                critic_hidden_state = critic_hidden_state_

        # print("end")
        all_critic_loss = sum(critic_losses)
        # all_critic_loss = torch.stack(critic_losses).mean()
        self.critic_optimizer.zero_grad()
        all_critic_loss.backward()
        self.critic_optimizer.step()

        actor_losses = []
        for memory_idx, memory in enumerate(self.memories):
            print(memory_idx)
            states, actions, policies, rewards, dones, actor_hidden_states, action_log_probs = \
                memory.fetch_on_device(self.device)
            ''' train actor     '''
            self.critic.eval()
            self.actor.train()

            critic_hidden_state = self.critic.get_new_hidden_state()
            for i in range(len(memory.states)):
                state = states[i].detach()
                # policy = policies[i]
                action_log_prob = action_log_probs[i]
                critic_hidden_state = critic_hidden_state.detach()
                done = dones[i].detach()

                value, critic_hidden_state_ = self.critic(
                    state, action_log_prob, critic_hidden_state)
                if done:
                    value *= 0.0
                # print("true: {}, value: {}".format(true_value, value))
                actor_loss = value
                if actor_loss >= self.actor_update_max:
                    print("actor_loss BIG: {}".format(actor_loss))
                actor_loss = torch.clamp(actor_loss, -self.actor_update_max,
                                         self.actor_update_max)
                actor_losses.append(actor_loss)

                critic_hidden_state = critic_hidden_state_

        all_actor_loss = sum(actor_losses)
        # all_actor_loss = torch.stack(actor_losses).mean()
        self.actor_optimizer.zero_grad()
        all_actor_loss.backward()
        self.actor_optimizer.step()
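A minimal, self-contained sketch of the discounted reward-to-go computation that get_discounted_cum_rewards implements above (the gamma value and the reward list here are illustrative, not taken from the example):

import torch

def discounted_returns(rewards, gamma=0.99):
    # G_t = r_t + gamma * G_{t+1}, computed by walking the rewards in reverse
    returns = []
    total = 0.0
    for reward in reversed(rewards):
        total = reward + gamma * total
        returns.append(total)
    returns.reverse()
    return torch.tensor(returns, dtype=torch.float32)

print(discounted_returns([1.0, 1.0, 1.0]))  # tensor([2.9701, 1.9900, 1.0000])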
Example #2
def train(BATCH_SIZE, DISCOUNT, ENTROPY_WEIGHT, HIDDEN_SIZE, LEARNING_RATE,
          MAX_STEPS, POLYAK_FACTOR, REPLAY_SIZE, TEST_INTERVAL,
          UPDATE_INTERVAL, UPDATE_START, ENV, OBSERVATION_LOW, VALUE_FNC,
          FLOW_TYPE, FLOWS, DEMONSTRATIONS, PRIORITIZE_REPLAY,
          BEHAVIOR_CLONING, ARM, BASE, RPA, REWARD_DENSE, logdir):

    ALPHA = 0.3
    BETA = 1
    epsilon = 0.0001  #0.1
    epsilon_d = 0.1  #0.3
    weights = 1  #1
    lambda_ac = 0.85  #0.7
    lambda_bc = 0.3  #0.4

    setup_logger(logdir, locals())
    ENV = __import__(ENV)
    if ARM and BASE:
        env = ENV.youBotAll('youbot_navig2.ttt',
                            obs_lowdim=OBSERVATION_LOW,
                            rpa=RPA,
                            reward_dense=REWARD_DENSE,
                            boundary=1)
    elif ARM:
        env = ENV.youBotArm('youbot_navig.ttt',
                            obs_lowdim=OBSERVATION_LOW,
                            rpa=RPA,
                            reward_dense=REWARD_DENSE)
    elif BASE:
        env = ENV.youBotBase('youbot_navig.ttt',
                             obs_lowdim=OBSERVATION_LOW,
                             rpa=RPA,
                             reward_dense=REWARD_DENSE,
                             boundary=1)

    action_space = env.action_space
    obs_space = env.observation_space()
    step_limit = env.step_limit()

    if OBSERVATION_LOW:
        actor = SoftActorGated(HIDDEN_SIZE,
                               action_space,
                               obs_space,
                               flow_type=FLOW_TYPE,
                               flows=FLOWS).float().to(device)
        critic_1 = Critic(HIDDEN_SIZE,
                          1,
                          obs_space,
                          action_space,
                          state_action=True).float().to(device)
        critic_2 = Critic(HIDDEN_SIZE,
                          1,
                          obs_space,
                          action_space,
                          state_action=True).float().to(device)
    else:
        actor = ActorImageNet(HIDDEN_SIZE,
                              action_space,
                              obs_space,
                              flow_type=FLOW_TYPE,
                              flows=FLOWS).float().to(device)
        critic_1 = Critic(HIDDEN_SIZE,
                          1,
                          obs_space,
                          action_space,
                          state_action=True).float().to(device)
        critic_2 = Critic(HIDDEN_SIZE,
                          1,
                          obs_space,
                          action_space,
                          state_action=True).float().to(device)
        critic_1.load_state_dict(
            torch.load(
                'data/youbot_all_final_21-08-2019_22-32-00/models/critic1_model_473000.pkl'
            ))
        critic_2.load_state_dict(
            torch.load(
                'data/youbot_all_final_21-08-2019_22-32-00/models/critic2_model_473000.pkl'
            ))

    actor.apply(weights_init)
    # critic_1.apply(weights_init)
    # critic_2.apply(weights_init)

    if VALUE_FNC:
        value_critic = Critic(HIDDEN_SIZE, 1, obs_space,
                              action_space).float().to(device)
        target_value_critic = create_target_network(value_critic).float().to(
            device)
        value_critic_optimiser = optim.Adam(value_critic.parameters(),
                                            lr=LEARNING_RATE)
    else:
        target_critic_1 = create_target_network(critic_1)
        target_critic_2 = create_target_network(critic_2)
    actor_optimiser = optim.Adam(actor.parameters(), lr=LEARNING_RATE)
    critics_optimiser = optim.Adam(list(critic_1.parameters()) +
                                   list(critic_2.parameters()),
                                   lr=LEARNING_RATE)

    # Replay buffer
    if PRIORITIZE_REPLAY:
        # D = PrioritizedReplayBuffer(REPLAY_SIZE, ALPHA)
        D = ReplayMemory(device, 3, DISCOUNT, 1, BETA, ALPHA, REPLAY_SIZE)
    else:
        D = deque(maxlen=REPLAY_SIZE)

    eval_ = evaluation_sac(env, logdir, device)

    #Automatic entropy tuning init
    target_entropy = -np.prod(action_space).item()
    log_alpha = torch.zeros(1, requires_grad=True, device=device)
    alpha_optimizer = optim.Adam([log_alpha], lr=LEARNING_RATE)

    home = os.path.expanduser('~')
    if DEMONSTRATIONS:
        dir_dem = os.path.join(home, 'robotics_drl/data/demonstrations/',
                               DEMONSTRATIONS)
        D, n_demonstrations = load_buffer_demonstrations(
            D, dir_dem, PRIORITIZE_REPLAY, OBSERVATION_LOW)
    else:
        n_demonstrations = 0

    if not BEHAVIOR_CLONING:
        behavior_loss = 0

    os.mkdir(os.path.join(home, 'robotics_drl', logdir, 'models'))
    dir_models = os.path.join(home, 'robotics_drl', logdir, 'models')

    state, done = env.reset(), False
    if OBSERVATION_LOW:
        state = state.float().to(device)
    else:
        state['low'] = state['low'].float()
        state['high'] = state['high'].float()
    pbar = tqdm(range(1, MAX_STEPS + 1), unit_scale=1, smoothing=0)

    steps = 0
    success = 0
    eval_c2 = False  # set to True once an episode finishes, so evaluation only runs between episodes
    for step in pbar:
        with torch.no_grad():
            if step < UPDATE_START and not DEMONSTRATIONS:
                # To improve exploration take actions sampled from a uniform random distribution over actions at the start of training
                action = torch.tensor(env.sample_action(),
                                      dtype=torch.float32,
                                      device=device).unsqueeze(dim=0)
            else:
                # Observe state s and select action a ~ μ(a|s)
                if not OBSERVATION_LOW:
                    state['low'] = state['low'].float().to(device)
                    state['high'] = state['high'].float().to(device)
                action, _ = actor(state, log_prob=False, deterministic=False)
                if not OBSERVATION_LOW:
                    state['low'] = state['low'].float().cpu()
                    state['high'] = state['high'].float().cpu()
                #if (policy.mean).mean() > 0.4:
                #    print("GOOD VELOCITY")
            # Execute a in the environment and observe next state s', reward r, and done signal d to indicate whether s' is terminal
            next_state, reward, done = env.step(
                action.squeeze(dim=0).cpu().tolist())
            if OBSERVATION_LOW:
                next_state = next_state.float().to(device)
            else:
                next_state['low'] = next_state['low'].float()
                next_state['high'] = next_state['high'].float()
            # Store (s, a, r, s', d) in replay buffer D
            if PRIORITIZE_REPLAY:
                if OBSERVATION_LOW:
                    D.add(state.cpu().tolist(),
                          action.cpu().squeeze().tolist(), reward,
                          next_state.cpu().tolist(), done)
                else:
                    D.append(state['high'], state['low'],
                             action.cpu().squeeze().tolist(), reward, done)
            else:
                D.append({
                    'state': state.unsqueeze(dim=0) if OBSERVATION_LOW else state,
                    'action': action,
                    'reward': torch.tensor([reward], dtype=torch.float32, device=device),
                    'next_state': next_state.unsqueeze(dim=0) if OBSERVATION_LOW else next_state,
                    'done': torch.tensor([True if reward == 1 else False],
                                         dtype=torch.float32,
                                         device=device)
                })

            state = next_state

            # If s' is terminal, reset environment state
            steps += 1

            if done or steps > step_limit:  #TODO: incorporate step limit in the environment
                eval_c2 = True  #TODO: multiprocess pyrep with a session for each testing and training
                steps = 0
                if OBSERVATION_LOW:
                    state = env.reset().float().to(device)
                else:
                    state = env.reset()
                    state['low'] = state['low'].float()
                    state['high'] = state['high'].float()
                if reward == 1:
                    success += 1

        if step > UPDATE_START and step % UPDATE_INTERVAL == 0:
            for _ in range(1):
                # Randomly sample a batch of transitions B = {(s, a, r, s', d)} from D
                if PRIORITIZE_REPLAY:
                    if OBSERVATION_LOW:
                        state_batch, action_batch, reward_batch, state_next_batch, done_batch, weights_pr, idxes = D.sample(
                            BATCH_SIZE, BETA)
                        state_batch = torch.from_numpy(state_batch).float().to(
                            device)
                        next_state_batch = torch.from_numpy(
                            state_next_batch).float().to(device)
                        action_batch = torch.from_numpy(
                            action_batch).float().to(device)
                        reward_batch = torch.from_numpy(
                            reward_batch).float().to(device)
                        done_batch = torch.from_numpy(done_batch).float().to(
                            device)
                        weights_pr = torch.from_numpy(weights_pr).float().to(
                            device)
                    else:
                        idxes, high_state_batch, low_state_batch, action_batch, reward_batch, high_state_next_batch, low_state_next_batch, done_batch, weights_pr = D.sample(
                            BATCH_SIZE)

                        state_batch = {
                            'low': low_state_batch.float().to(device).view(-1, 32),
                            'high': high_state_batch.float().to(device).view(-1, 12, 128, 128)
                        }
                        next_state_batch = {
                            'low': low_state_next_batch.float().to(device).view(-1, 32),
                            'high': high_state_next_batch.float().to(device).view(-1, 12, 128, 128)
                        }

                        action_batch = action_batch.float().to(device)
                        reward_batch = reward_batch.float().to(device)
                        done_batch = done_batch.float().to(device)
                        weights_pr = weights_pr.float().to(device)
                        # for j in range(BATCH_SIZE):
                        #     new_state_batch['high'] = torch.cat((new_state_batch['high'], state_batch[j].tolist()['high'].view(-1,(3+1)*env.frames,128,128)), dim=0)
                        #     new_state_batch['low'] = torch.cat((new_state_batch['low'], state_batch[j].tolist()['low'].view(-1,32)), dim=0)
                        #     new_next_state_batch['high'] = torch.cat((new_next_state_batch['high'], state_next_batch[j].tolist()['high'].view(-1,(3+1)*env.frames,128,128)), dim=0)
                        #     new_next_state_batch['low'] = torch.cat((new_next_state_batch['low'], state_next_batch[j].tolist()['low'].view(-1,32)), dim=0)
                        # new_state_batch['high'] = new_state_batch['high'].to(device)
                        # new_state_batch['low'] = new_state_batch['low'].to(device)
                        # new_next_state_batch['high'] = new_next_state_batch['high'].to(device)
                        # new_next_state_batch['low'] = new_next_state_batch['low'].to(device)

                    batch = {
                        'state': state_batch,
                        'action': action_batch,
                        'reward': reward_batch,
                        'next_state': next_state_batch,
                        'done': done_batch
                    }
                    state_batch = []
                    state_next_batch = []

                else:
                    batch = random.sample(D, BATCH_SIZE)
                    state_batch = []
                    action_batch = []
                    reward_batch = []
                    state_next_batch = []
                    done_batch = []
                    for d in batch:
                        state_batch.append(d['state'])
                        action_batch.append(d['action'])
                        reward_batch.append(d['reward'])
                        state_next_batch.append(d['next_state'])
                        done_batch.append(d['done'])

                    batch = {
                        'state': torch.cat(state_batch, dim=0),
                        'action': torch.cat(action_batch, dim=0),
                        'reward': torch.cat(reward_batch, dim=0),
                        'next_state': torch.cat(state_next_batch, dim=0),
                        'done': torch.cat(done_batch, dim=0)
                    }

                action, log_prob = actor(batch['state'],
                                         log_prob=True,
                                         deterministic=False)

                #Automatic entropy tuning
                alpha_loss = -(
                    log_alpha.float() *
                    (log_prob + target_entropy).float().detach()).mean()
                alpha_optimizer.zero_grad()
                alpha_loss.backward()
                alpha_optimizer.step()
                alpha = log_alpha.exp()
                weighted_sample_entropy = (alpha.float() * log_prob).view(
                    -1, 1)

                # Compute targets for Q and V functions
                if VALUE_FNC:
                    y_q = batch['reward'] + DISCOUNT * (
                        1 - batch['done']) * target_value_critic(
                            batch['next_state'])
                    y_v = torch.min(
                        critic_1(batch['state']['low'], action.detach()),
                        critic_2(batch['state']['low'], action.detach())
                    ) - weighted_sample_entropy.detach()
                else:
                    # No value function network
                    with torch.no_grad():
                        next_actions, next_log_prob = actor(
                            batch['next_state'],
                            log_prob=True,
                            deterministic=False)
                        target_qs = torch.min(
                            target_critic_1(
                                batch['next_state']['low'] if
                                not OBSERVATION_LOW else batch['next_state'],
                                next_actions),
                            target_critic_2(
                                batch['next_state']['low'] if
                                not OBSERVATION_LOW else batch['next_state'],
                                next_actions)) - alpha * next_log_prob
                    y_q = batch['reward'] + DISCOUNT * (
                        1 - batch['done']) * target_qs.detach()

                td_error_critic1 = critic_1(
                    batch['state']['low'] if not OBSERVATION_LOW else
                    batch['state'], batch['action']) - y_q
                td_error_critic2 = critic_2(
                    batch['state']['low'] if not OBSERVATION_LOW else
                    batch['state'], batch['action']) - y_q

                q_loss = (td_error_critic1).pow(2).mean() + (
                    td_error_critic2).pow(2).mean()
                # q_loss = (F.mse_loss(critic_1(batch['state'], batch['action']), y_q) + F.mse_loss(critic_2(batch['state'], batch['action']), y_q)).mean()
                critics_optimiser.zero_grad()
                q_loss.backward()
                critics_optimiser.step()

                # Compute priorities, taking demonstrations into account
                if PRIORITIZE_REPLAY:
                    td_error = weights_pr * (td_error_critic1.detach() +
                                             td_error_critic2.detach()).mean()
                    action_dem = torch.tensor([]).to(device)
                    if OBSERVATION_LOW:
                        state_dem = torch.tensor([]).to(device)
                    else:
                        state_dem = {
                            'low': torch.tensor([]).float().to(device),
                            'high': torch.tensor([]).float().to(device)
                        }
                    priorities = torch.abs(td_error).tolist()
                    i = 0
                    count_dem = 0
                    for idx in idxes:
                        priorities[i] += epsilon
                        if idx < n_demonstrations:
                            priorities[i] += epsilon_d
                            count_dem += 1
                            if BEHAVIOR_CLONING:
                                action_dem = torch.cat(
                                    (action_dem, batch['action'][i].view(
                                        1, -1)),
                                    dim=0)
                                if OBSERVATION_LOW:
                                    state_dem = torch.cat(
                                        (state_dem, batch['state'][i].view(
                                            1, -1)),
                                        dim=0)
                                else:
                                    state_dem['high'] = torch.cat(
                                        (state_dem['high'],
                                         batch['state']['high'][i, ].view(
                                             -1,
                                             (3 + 1) * env.frames, 128, 128)),
                                        dim=0)
                                    state_dem['low'] = torch.cat(
                                        (state_dem['low'],
                                         batch['state']['low'][i, ].view(
                                             -1, 32)),
                                        dim=0)
                        i += 1
                    if not action_dem.nelement() == 0:
                        actual_action_dem, _ = actor(state_dem,
                                                     log_prob=False,
                                                     deterministic=True)
                        # q_value_actor = (critic_1(batch['state'][i], batch['action'][i]) + critic_2(batch['state'][i], batch['action'][i]))/2
                        # q_value_actual = (critic_1(batch['state'][i], actual_action_dem) + critic_2(batch['state'][i], actual_action_dem))/2
                        # if q_value_actor > q_value_actual: # Q Filter
                        behavior_loss = F.mse_loss(
                            action_dem, actual_action_dem).unsqueeze(dim=0)
                    else:
                        behavior_loss = 0

                    D.update_priorities(idxes, priorities)
                    # count_dem only exists in the prioritized-replay branch, so keep this inside it
                    lambda_bc = (count_dem / BATCH_SIZE) / 5

                # Update V-function by one step of gradient descent
                if VALUE_FNC:
                    v_loss = (value_critic(batch['state']) -
                              y_v).pow(2).mean().to(device)

                    value_critic_optimiser.zero_grad()
                    v_loss.backward()
                    value_critic_optimiser.step()

                # Update policy by one step of gradient ascent
                with torch.no_grad():
                    new_qs = torch.min(
                        critic_1(
                            batch["state"]['low'] if not OBSERVATION_LOW else
                            batch['state'], action),
                        critic_2(
                            batch["state"]['low'] if not OBSERVATION_LOW else
                            batch['state'], action))
                policy_loss = lambda_ac * (weighted_sample_entropy.view(
                    -1) - new_qs).mean().to(device) + lambda_bc * behavior_loss
                actor_optimiser.zero_grad()
                policy_loss.backward()
                actor_optimiser.step()

                # Update target value network
                if VALUE_FNC:
                    update_target_network(value_critic, target_value_critic,
                                          POLYAK_FACTOR)
                else:
                    update_target_network(critic_1, target_critic_1,
                                          POLYAK_FACTOR)
                    update_target_network(critic_2, target_critic_2,
                                          POLYAK_FACTOR)
        state_dem = []

        # Keep sampling transitions until the episode is done and an evaluation is due
        eval_c = step > UPDATE_START and step % TEST_INTERVAL == 0

        if eval_c and eval_c2:
            eval_c = False
            eval_c2 = False
            actor.eval()
            critic_1.eval()
            critic_2.eval()
            q_value_eval = eval_.get_qvalue(critic_1, critic_2)
            return_ep, steps_ep = eval_.sample_episode(actor)

            logz.log_tabular('Training steps', step)
            logz.log_tabular('Cumulative Success', success)
            logz.log_tabular('Validation return', return_ep.mean())
            logz.log_tabular('Validation steps', steps_ep.mean())
            logz.log_tabular('Validation return std', return_ep.std())
            logz.log_tabular('Validation steps std', steps_ep.std())
            logz.log_tabular('Q-value evaluation', q_value_eval)
            logz.log_tabular('Q-network loss', q_loss.detach().cpu().numpy())
            if VALUE_FNC:
                logz.log_tabular('Value-network loss',
                                 v_loss.detach().cpu().numpy())
            logz.log_tabular('Policy-network loss',
                             policy_loss.detach().cpu().squeeze().numpy())
            logz.log_tabular('Alpha loss', alpha_loss.detach().cpu().numpy())
            logz.log_tabular('Alpha', alpha.detach().cpu().squeeze().numpy())
            logz.log_tabular('Demonstrations current batch', count_dem)
            logz.dump_tabular()

            logz.save_pytorch_model(actor.state_dict())

            torch.save(actor.state_dict(),
                       os.path.join(dir_models, 'actor_model_%s.pkl' % (step)))
            torch.save(
                critic_1.state_dict(),
                os.path.join(dir_models, 'critic1_model_%s.pkl' % (step)))
            torch.save(
                critic_2.state_dict(),
                os.path.join(dir_models, 'critic2_model_%s.pkl' % (step)))

            #pbar.set_description('Step: %i | Reward: %f' % (step, return_ep.mean()))

            actor.train()
            critic_1.train()
            critic_2.train()

    env.terminate()
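create_target_network and update_target_network are external helpers not shown in this example; a minimal sketch of what such helpers typically do, assuming the common convention target <- polyak * target + (1 - polyak) * source (the project's actual implementation may differ):

import copy
import torch

def create_target_network(network):
    # deep-copy the network and freeze it; the copy is only updated via polyak averaging
    target = copy.deepcopy(network)
    for param in target.parameters():
        param.requires_grad = False
    return target

def update_target_network(source, target, polyak):
    # soft update: target <- polyak * target + (1 - polyak) * source
    with torch.no_grad():
        for src_param, tgt_param in zip(source.parameters(), target.parameters()):
            tgt_param.data.mul_(polyak).add_(src_param.data, alpha=1 - polyak)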
Example #3
class Agent():
    def __init__(self, learn_rate, input_shape, num_actions):
        self.num_actions = num_actions
        self.gamma = 0.9999
        self.memory = Memory()

        # self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = torch.device("cpu")

        self.actor = Actor().to(self.device)
        self.critic = Critic().to(self.device)

        # debug-only inspection of the critic parameters; left disabled so the agent can be constructed
        # critic_params = list(self.critic.parameters())
        # pprint(critic_params)
        # quit()

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=learn_rate)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=learn_rate)

    def choose_action(self, state, hidden_state):
        state = torch.tensor(state, dtype=torch.float32).to(self.device)

        policy, hidden_state_ = self.actor(state, hidden_state)
        action = torch.argmax(policy)

        #   prep for storage
        action = action.item()

        return action, policy, hidden_state_

    def store_memory(self, memory):
        self.memory.store(*memory)

    def get_discounted_cum_rewards(self):
        cum_rewards = []
        total = 0
        for reward in reversed(self.memory.rewards):
            total = reward + total * self.gamma
            cum_rewards.append(total)
        cum_rewards = list(reversed(cum_rewards))
        cum_disc_rewards = torch.tensor(cum_rewards).float().to(self.device)

        return cum_disc_rewards

    def learn(self):
        states, actions, policies, rewards, dones, actor_hidden_states = self.memory.fetch_on_device(self.device)
        cum_disc_rewards = self.get_discounted_cum_rewards()
        # print("cum_disc_rewards")
        # pprint(cum_disc_rewards)
        # quit()

        ''' train critic    '''
        self.critic.train()
        self.actor.eval()

        critic_losses = []
        critic_hidden_state = self.critic.get_new_hidden_state()
        for i in range(len(self.memory.states) - 1):
            state = states[i].detach()
            policy = policies[i].detach()
            true_value = cum_disc_rewards[i]

            value, critic_hidden_state_ = self.critic(state, policy, critic_hidden_state)
            error = value - true_value
            # print("true: {}, value: {}".format(true_value, value))
            critic_loss = error**2
            critic_losses.append(critic_loss)

            critic_hidden_state = critic_hidden_state_

        # print("end")
        all_critic_loss = sum(critic_losses)
        self.critic_optimizer.zero_grad()
        all_critic_loss.backward()
        self.critic_optimizer.step()

        ''' train actor     '''
        self.critic.eval()
        self.actor.train()

        actor_losses = []
        critic_hidden_state = self.critic.get_new_hidden_state()
        for i in range(len(self.memory.states) - 1):
            state = states[i].detach()
            policy = policies[i]
            critic_hidden_state = critic_hidden_state.detach()

            value, critic_hidden_state_ = self.critic(state, policy, critic_hidden_state)
            # print("true: {}, value: {}".format(true_value, value))
            actor_loss = value
            actor_losses.append(actor_loss)

            critic_hidden_state = critic_hidden_state_

        all_actor_loss = sum(actor_losses)
        self.actor_optimizer.zero_grad()
        all_actor_loss.backward()
        self.actor_optimizer.step()

    def learn_old(self, state, policy, reward, state_, done):
        # legacy single-step TD update; expects one transition plus the policy used at `state`
        self.actor.eval()
        self.critic.train()
    
        state = torch.tensor(state).float()
        state = state.to(self.device)
        state_ = torch.tensor(state_).float()
        state_ = state_.to(self.device)

        reward = torch.tensor(reward, dtype=torch.float).to(self.device)
        done = torch.tensor(done, dtype=torch.bool).to(self.device)

        value, critic_hidden_state_ = self.critic(state, policy, self.critic_hidden_state)

        policy_, _ = self.actor(state, self.actor_hidden_state)
        value_, critic_hidden_state_ = self.critic(state_, policy_, critic_hidden_state_)
        if done:
            value_ = 0.0

        target = reward + self.gamma * value_
        td = target - value

        critic_loss = td**2

        self.critic_optimizer.zero_grad()
        if not done:
            critic_loss.backward(retain_graph=True)
        else:
            critic_loss.backward()
        self.critic_optimizer.step()

        ''' update based on new policy of old states '''
        self.critic.eval()
        self.actor.train()
        retro_value, retro_critic_hidden_state_ = self.critic(state, policy, self.critic_hidden_state)

        retro_policy_, actor_hidden_state_ = self.actor(state, self.actor_hidden_state)
        retro_value_, _ = self.critic(state_, retro_policy_, retro_critic_hidden_state_)
        if done:
            retro_value_ = 0.0

        actor_loss = -(retro_value_ - retro_value)

        self.actor_optimizer.zero_grad()
        if not done:
            actor_loss.backward(retain_graph=True)
        else:
            actor_loss.backward()
        self.actor_optimizer.step()

        ''' update hidden states    '''
        self.actor_hidden_state = actor_hidden_state_
        self.critic_hidden_state = critic_hidden_state_
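Both learn() loops above carry a recurrent hidden state through the trajectory and detach it between steps; a minimal, self-contained illustration of that detach pattern with a plain GRUCell (layer sizes are arbitrary):

import torch
import torch.nn as nn

rnn = nn.GRUCell(input_size=4, hidden_size=8)
hidden = torch.zeros(1, 8)

losses = []
for step in range(5):
    x = torch.randn(1, 4)
    # detaching the carried hidden state stops gradients from flowing into earlier steps
    hidden = rnn(x, hidden.detach())
    losses.append(hidden.pow(2).sum())

sum(losses).backward()  # each loss term only backpropagates through its own step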
Example #4
class DDPG(object):

    def __init__(self, gamma, tau, num_inputs, env, device, results_path=None):

        self.gamma = gamma
        self.tau = tau
        self.min_action, self.max_action = env.action_range()
        self.device = device
        self.num_actions = env.action_space()
        self.noise_stddev = 0.3

        self.results_path = results_path
        self.checkpoint_path = os.path.join(self.results_path, 'checkpoint/')
        os.makedirs(self.checkpoint_path, exist_ok=True)

        # Define the actor
        self.actor = Actor(num_inputs, self.num_actions).to(device)
        self.actor_target = Actor(num_inputs, self.num_actions).to(device)

        # Define the critic
        self.critic = Critic(num_inputs, self.num_actions).to(device)
        self.critic_target = Critic(num_inputs, self.num_actions).to(device)

        # Define the optimizers for both networks
        self.actor_optimizer = Adam(self.actor.parameters(), lr=1e-4)  # optimizer for the actor network
        self.critic_optimizer = Adam(self.critic.parameters(), lr=1e-4, weight_decay=0.002)  # optimizer for the critic network

        self.hard_swap()

        self.ou_noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(self.num_actions),
                                            sigma=float(self.noise_stddev) * np.ones(self.num_actions))
        self.ou_noise.reset()

    def eval_mode(self):
        self.actor.eval()
        self.actor_target.eval()
        self.critic_target.eval()
        self.critic.eval()

    def train_mode(self):
        self.actor.train()
        self.actor_target.train()
        self.critic_target.train()
        self.critic.train()


    def get_action(self, state, episode, action_noise=True):
        x = state.to(self.device)

        # Get the continous action value to perform in the env
        self.actor.eval()  # Sets the actor in evaluation mode
        mu = self.actor(x)
        self.actor.train()  # Sets the actor in training mode
        mu = mu.data

        # During training we add noise for exploration
        if action_noise:
            noise = torch.Tensor(self.ou_noise.noise()).to(self.device) * 1.0/(1.0 + 0.1*episode)
            noise = noise.clamp(0,0.1)
            mu = mu + noise  # Add exploration noise ε ~ p(ε) to the action. Do not use OU noise (https://spinningup.openai.com/en/latest/algorithms/ddpg.html)

        # Clip the output according to the action space of the env
        mu = mu.clamp(self.min_action,self.max_action)

        return mu

    def update_params(self, batch):
        # Get tensors from the batch
        state_batch = torch.cat(batch.state).to(self.device)
        action_batch = torch.cat(batch.action).to(self.device)
        reward_batch = torch.cat(batch.reward).to(self.device)
        done_batch = torch.cat(batch.done).to(self.device)
        next_state_batch = torch.cat(batch.next_state).to(self.device)

        # Get the actions and the state values to compute the targets
        next_action_batch = self.actor_target(next_state_batch)
        next_state_action_values = self.critic_target(next_state_batch, next_action_batch.detach())

        # Compute the target
        reward_batch = reward_batch.unsqueeze(1)
        done_batch = done_batch.unsqueeze(1)
        expected_values = reward_batch + (1.0 - done_batch) * self.gamma * next_state_action_values

        # Update the critic network
        self.critic_optimizer.zero_grad()
        state_action_batch = self.critic(state_batch, action_batch)
        value_loss = F.mse_loss(state_action_batch, expected_values.detach())
        value_loss.backward()
        self.critic_optimizer.step()

        # Update the actor network
        self.actor_optimizer.zero_grad()
        policy_loss = -self.critic(state_batch, self.actor(state_batch))
        policy_loss = policy_loss.mean()
        policy_loss.backward()
        for param in self.actor.parameters():
            param.grad.data.clamp_(-1, 1)
        self.actor_optimizer.step()

        # Update the target networks
        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)

        return value_loss.item(), policy_loss.item()
    
    def hard_swap(self):
        # Make sure both targets are with the same weight
        hard_update(self.actor_target, self.actor)
        hard_update(self.critic_target, self.critic)

    def store_model(self):
        print("Storing model at: ", self.checkpoint_path)
        checkpoint = {
            'actor': self.actor.state_dict(),
            'actor_optim': self.actor_optimizer.state_dict(),
            'critic': self.critic.state_dict(),
            'critic_optim': self.critic_optimizer.state_dict()
        }
        torch.save(checkpoint, os.path.join(self.checkpoint_path, 'checkpoint.pth') )

    def load_model(self):
        files = os.listdir(self.checkpoint_path)
        if files:
            print("Loading models checkpoints!")
            model_dicts = torch.load(os.path.join(self.checkpoint_path, 'checkpoint.pth'),map_location=self.device)
            self.actor.load_state_dict(model_dicts['actor'])
            self.actor_optimizer.load_state_dict(model_dicts['actor_optim'])
            self.critic.load_state_dict(model_dicts['critic'])
            self.critic_optimizer.load_state_dict(model_dicts['critic_optim'])
        else:
            print("Checkpoints not found!")
Example #5
crt = Critic(Z_DIM, IN_CHANNELS, CHANNELS_IMG).to(DEVICE)

opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.0, 0.99))
opt_crt = optim.Adam(crt.parameters(), lr=lr, betas=(0.0, 0.99))

scalar_crt = torch.cuda.amp.GradScaler()
scalar_gen = torch.cuda.amp.GradScaler()

writer = SummaryWriter("logs/gan")

if LOAD_MODEL:
    load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, lr)
    load_checkpoint(CHECKPOINT_CRITIC, crt, opt_crt, lr)

gen.train()
crt.train()

tensorboard_step = 0

step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))

for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-5
    loader, dataset = get_loader(4 * 2 ** step)
    print(f"Current image size: {4 * 2 ** step}")
    
    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]") 
        tensorboard_step, alpha = train(crt, gen, loader, dataset, step, alpha, opt_crt, opt_gen, tensorboard_step, writer, scalar_crt, scalar_gen)
        if SAVE_MODEL:
            save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(crt, opt_crt, filename=CHECKPOINT_CRITIC)  # presumably intended: keep the critic checkpoint that load_checkpoint expects up to date
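The GradScaler objects are consumed inside train(), which is not shown here; a minimal, self-contained sketch of the usual autocast/GradScaler update step (the tiny model and random data are placeholders, not part of the example):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

x = torch.randn(8, 16, device=device)
target = torch.randn(8, 1, device=device)

with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
    loss = nn.functional.mse_loss(model(x), target)

optimizer.zero_grad()
scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
scaler.step(optimizer)         # unscales gradients, skips the step if they overflowed
scaler.update()                # adjust the scale factor for the next iteration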
Example #6
class Agent:
    def __init__(self, env, env_params, args, models=None, record_episodes=[0, .1, .25, .5, .75, 1.]):
        self.env = env
        self.env_params = env_params
        self.args = args


        # networks
        if models is None:
            self.actor = Actor(self.env_params).double()
            self.critic = Critic(self.env_params).double()
        else:
            # models is assumed to be a pair of checkpoint paths (actor_path, critic_path)
            self.actor, self.critic = self.LoadModels(*models)
        # target networks used to predict env actions with
        self.actor_target = Actor(self.env_params).double()
        self.critic_target = Critic(self.env_params).double()

        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())

        if self.args.cuda:
            self.actor.cuda()
            self.critic.cuda()
            self.actor_target.cuda()
            self.critic_target.cuda()


        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=0.001)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=0.001)

        self.normalize = Normalizer(env_params,self.args.gamma)
        self.buffer = ReplayBuffer(1_000_000, self.env_params)
        self.tensorboard = ModifiedTensorBoard(log_dir = f"logs")
        self.record_episodes = [int(eps * self.args.n_epochs) for eps in record_episodes]

    def ModelsEval(self):
        self.actor.eval()
        self.actor_target.eval()
        self.critic.eval()
        self.critic_target.eval()

    def ModelsTrain(self):
        self.actor.train()
        self.actor_target.train()
        self.critic.train()
        self.critic_target.train()

    def GreedyAction(self, state):
        self.ModelsEval()
        with torch.no_grad():
            state = torch.tensor(state, dtype=torch.double).unsqueeze(dim=0)
            if self.args.cuda:
                state = state.cuda()
            action = self.actor.forward(state).detach().cpu().numpy().squeeze()
        return action

    def NoiseAction(self, state):
        self.ModelsEval()
        with torch.no_grad():
            state = torch.tensor(state, dtype=torch.double).unsqueeze(dim=0)
            if self.args.cuda:
                state = state.cuda()
            action = self.actor.forward(state).detach().cpu().numpy()
            action += self.args.noise_eps * self.env_params['max_action'] * np.random.randn(*action.shape)
            action = np.clip(action, -self.env_params['max_action'], self.env_params['max_action'])
        return action.squeeze()

    def Update(self):
        self.ModelsTrain()
        for i in range(self.args.n_batch):
            state, a_batch, r_batch, nextstate, d_batch = self.buffer.SampleBuffer(self.args.batch_size)
            a_batch = torch.tensor(a_batch,dtype=torch.double)
            r_batch = torch.tensor(r_batch,dtype=torch.double)
            # d_batch = torch.tensor(d_batch,dtype=torch.double)
            state = torch.tensor(state,dtype=torch.double)
            nextstate = torch.tensor(nextstate,dtype=torch.double)
            # d_batch = 1 - d_batch

            if self.args.cuda:
                a_batch = a_batch.cuda()
                r_batch = r_batch.cuda()
                # d_batch = d_batch.cuda()
                state = state.cuda()
                nextstate = nextstate.cuda()

            with torch.no_grad():
                action_next = self.actor_target.forward(nextstate)
                q_next = self.critic_target.forward(nextstate,action_next)
                q_next = q_next.detach().squeeze()
                q_target = r_batch + self.args.gamma * q_next
                q_target = q_target.detach().squeeze()

            q_prime = self.critic.forward(state, a_batch).squeeze()
            critic_loss = F.mse_loss(q_target, q_prime)

            action = self.actor.forward(state)
            actor_loss = -self.critic.forward(state, action).mean()
            # params = torch.cat([x.view(-1) for x in self.actor.parameters()])
            # l2_reg = self.args.l2_norm *torch.norm(params,2)
            # actor_loss += l2_reg

            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            self.critic_optim.zero_grad()
            critic_loss.backward()
            self.critic_optim.step()

        self.SoftUpdateTarget(self.critic, self.critic_target)
        self.SoftUpdateTarget(self.actor, self.actor_target)

    def Explore(self):
        for epoch in range(self.args.n_epochs +1):
            start_time = time.process_time()
            for cycle in range(self.args.n_cycles):
                for _ in range(self.args.num_rollouts_per_mpi):
                    state = self.env.reset()
                    for t in range(self.env_params['max_timesteps']):
                        action = self.NoiseAction(state)
                        nextstate, reward, done, info = self.env.step([action])
                        nextstate = nextstate.squeeze()
                        reward = self.normalize.normalize_reward(reward)
                        self.buffer.StoreTransition(state, action, reward, nextstate, done)
                        state = nextstate
                    self.Update()
            avg_reward = self.Evaluate()
            self.tensorboard.step = epoch
            elapsed_time = time.process_time() - start_time
            print(f"Epoch {epoch} of total of {self.args.n_epochs +1} epochs, average reward is: {avg_reward}.\
                    Elapsedtime: {int(elapsed_time /60)} minutes {int(elapsed_time %60)} seconds")
            if epoch % 5 == 0 or epoch + 1 == self.args.n_epochs:
                self.SaveModels(epoch)
                self.record(epoch)


    def Evaluate(self):
        self.ModelsEval()
        total_reward = []
        episode_reward = 0
        succes_rate = []
        for episode in range(self.args.n_evaluate):
            state = self.env.reset()
            episode_reward = 0
            for t in range(self.env_params['max_timesteps']):
                action = self.GreedyAction(state)
                nextstate, reward, done, info = self.env.step([action])
                episode_reward += reward
                state = nextstate
                if done or t + 1 == self.env_params['max_timesteps']:
                    total_reward.append(episode_reward)
                    episode_reward = 0

        average_reward = sum(total_reward)/len(total_reward)
        min_reward = min(total_reward)
        max_reward = max(total_reward)
        self.tensorboard.update_stats(reward_avg=average_reward, reward_min=min_reward, reward_max=max_reward)
        return average_reward

    def record(self, epoch):
        self.ModelsEval()
        try:
            if not os.path.exists("videos"):
                os.mkdir('videos')
            recorder = VideoRecorder(self.env, path=f'videos/epoch-{epoch}.mp4')
            for _ in range(self.args.n_record):
                done = False
                state = self.env.reset()
                while not done:
                    recorder.capture_frame()
                    action = self.GreedyAction(state)
                    nextstate, reward, done, info = self.env.step([action])
                    state = nextstate
            recorder.close()  # close once after all recorded episodes; closing inside the loop breaks later capture_frame calls
        except Exception as e:
            print(e)

    def SaveModels(self, ep):
        if not os.path.exists("models"):
            os.mkdir('models')
        torch.save(self.actor.state_dict(), os.path.join('models', 'Actor.pt'))
        torch.save(self.critic.state_dict(), os.path.join('models', 'Critic.pt'))

    def LoadModels(self, actorpath, criticpath):
        # build the networks with the same constructor arguments used in __init__
        actor = Actor(self.env_params).double()
        critic = Critic(self.env_params).double()
        actor.load_state_dict(torch.load(actorpath))
        critic.load_state_dict(torch.load(criticpath))
        return actor, critic

    def SoftUpdateTarget(self, source, target):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data)
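ReplayBuffer is an external class; a minimal, self-contained sketch of a uniform replay buffer exposing the same StoreTransition/SampleBuffer interface used in Update() and Explore() above (the exact field layout of the real buffer is an assumption):

import random
from collections import deque

import numpy as np

class SimpleReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def StoreTransition(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def SampleBuffer(self, batch_size):
        batch = random.sample(list(self.buffer), batch_size)
        states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
        return states, actions, rewards, next_states, dones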
Example #7
def adversarial_debiasing(model_state_dict, data, config, device):
    logger.info('Training Adversarial model.')
    actor = load_model(data.num_features, config.get('hyperparameters', {}))
    actor.load_state_dict(model_state_dict)
    actor.to(device)
    hid = config['hyperparameters']['hid'] if 'hyperparameters' in config else 32
    critic = Critic(hid * config['adversarial']['batch_size'],
                    num_deep=config['adversarial']['num_deep'],
                    hid=hid)
    critic.to(device)
    critic_optimizer = optim.Adam(critic.parameters())
    critic_loss_fn = torch.nn.MSELoss()

    actor_optimizer = optim.Adam(actor.parameters(),
                                 lr=config['adversarial']['lr'])
    actor_loss_fn = torch.nn.BCELoss()

    for epoch in range(config['adversarial']['epochs']):
        for param in critic.parameters():
            param.requires_grad = True
        for param in actor.parameters():
            param.requires_grad = False
        actor.eval()
        critic.train()
        for step in range(config['adversarial']['critic_steps']):
            critic_optimizer.zero_grad()
            indices = torch.randint(0, data.X_valid.size(0),
                                    (config['adversarial']['batch_size'], ))
            cX_valid = data.X_valid_gpu[indices]
            cy_valid = data.y_valid[indices]
            cp_valid = data.p_valid[indices]
            with torch.no_grad():
                scores = actor(cX_valid)[:, 0].reshape(-1).cpu().numpy()

            bias = compute_bias(scores, cy_valid.numpy(), cp_valid,
                                config['metric'])

            res = critic(actor.trunc_forward(cX_valid))
            loss = critic_loss_fn(torch.tensor([bias], device=device), res[0])
            loss.backward()
            train_loss = loss.item()
            critic_optimizer.step()
            if (epoch % 10 == 0) and (step % 100 == 0):
                logger.info(
                    f'=======> Critic Epoch: {(epoch, step)} loss: {train_loss}'
                )

        for param in critic.parameters():
            param.requires_grad = False
        for param in actor.parameters():
            param.requires_grad = True
        actor.train()
        critic.eval()
        for step in range(config['adversarial']['actor_steps']):
            actor_optimizer.zero_grad()
            indices = torch.randint(0, data.X_valid.size(0),
                                    (config['adversarial']['batch_size'], ))
            cy_valid = data.y_valid_gpu[indices]
            cX_valid = data.X_valid_gpu[indices]

            pred_bias = critic(actor.trunc_forward(cX_valid))
            bceloss = actor_loss_fn(actor(cX_valid)[:, 0], cy_valid)

            # loss = lam*abs(pred_bias) + (1-lam)*loss
            objloss = max(
                1, config['adversarial']['lambda'] *
                (abs(pred_bias[0][0]) - config['objective']['epsilon'] +
                 config['adversarial']['margin']) + 1) * bceloss

            objloss.backward()
            train_loss = objloss.item()
            actor_optimizer.step()
            if (epoch % 10 == 0) and (step % 100 == 0):
                logger.info(
                    f'=======> Actor Epoch: {(epoch, step)} loss: {train_loss}'
                )

        if epoch % 10 == 0:
            with torch.no_grad():
                scores = actor(data.X_valid_gpu)[:, 0].reshape(-1, 1).cpu().numpy()
                _, best_adv_obj = get_best_thresh(
                    scores,
                    np.linspace(0, 1, 1001),
                    data,
                    config,
                    valid=False,
                    margin=config['adversarial']['margin'])
                logger.info(f'Objective: {best_adv_obj}')

    logger.info('Finding optimal threshold for Adversarial model.')
    with torch.no_grad():
        scores = actor(data.X_valid_gpu)[:, 0].reshape(-1, 1).cpu().numpy()

    best_adv_thresh, _ = get_best_thresh(
        scores,
        np.linspace(0, 1, 1001),
        data,
        config,
        valid=False,
        margin=config['adversarial']['margin'])

    logger.info('Evaluating Adversarial model on best threshold.')
    with torch.no_grad():
        labels = (actor(data.X_valid_gpu)[:, 0] > best_adv_thresh).reshape(
            -1, 1).cpu().numpy()
    results_valid = get_valid_objective(labels, data, config)
    logger.info(f'Results: {results_valid}')

    with torch.no_grad():
        labels = (actor(data.X_test_gpu)[:, 0] > best_adv_thresh).reshape(
            -1, 1).cpu().numpy()
    results_test = get_test_objective(labels, data, config)

    return results_valid, results_test