Example No. 1
 def __init__(self):
     self.n_item = 5
     self.max_c = 100
     self.obs_low = np.concatenate(
         ([0] * args.n_discrete_state, [-5] * args.n_continuous_state))
     self.obs_high = np.concatenate(
         ([1] * args.n_discrete_state, [5] * args.n_continuous_state))
     self.observation_space = spaces.Box(low=self.obs_low,
                                         high=self.obs_high,
                                         dtype=np.float32)
     self.action_space = spaces.Box(low=-5,
                                    high=5,
                                    shape=(args.n_discrete_action +
                                           args.n_continuous_action, ),
                                    dtype=np.float32)
     self.trans_model = Policy(discrete_action_sections,
                               discrete_state_sections,
                               state_0=dataset.state)
     if mode == 'train':
         self.trans_model.transition_net.load_state_dict(
             torch.load('./model_pkl/Transition_model_sas_train_4.pkl'))
         self.trans_model.policy_net.load_state_dict(
             torch.load('./model_pkl/Policy_model_sas_train_4.pkl'))
         self.trans_model.policy_net_action_std = torch.load(
             './model_pkl/policy_net_action_std_model_sas_train_4.pkl')
     elif mode == 'test':
         self.trans_model.transition_net.load_state_dict(
             torch.load('./model_pkl/Transition_model_sas_test.pkl'))
         self.trans_model.policy_net.load_state_dict(
             torch.load('./model_pkl/Policy_model_sas_test.pkl'))
         self.trans_model.policy_net_action_std = torch.load(
             './model_pkl/policy_net_action_std_model_sas_test.pkl')
     else:
         assert False, "mode must be 'train' or 'test'"
     self.reset()
Example No. 2
def test():
    env_name = "Pendulum-v0"
    env = gym.make(env_name)
    render = True
    ppo = Policy([0], [0])
    ppo.policy_net.load_state_dict(
        torch.load('./model_pkl/Policy_model_4.pkl'))
    for ep in range(500):
        ep_reward = 0
        state = env.reset()
        # env.render()
        for t in range(200):
            state = torch.unsqueeze(
                torch.from_numpy(state).type(torch.FloatTensor), 0).to(device)
            # _, action, _ = ppo.get_policy_net_action(state,size=10000)
            discrete_action_probs_with_continuous_mean = ppo.policy_net(state)
            action = discrete_action_probs_with_continuous_mean[:, 0:]
            action = torch.squeeze(action, 1)
            # print(action)
            action = action.cpu().detach().numpy()
            state, reward, done, _ = env.step(action)
            ep_reward += reward
            env.render()
            if done:
                break
        # writer.add_scalar('ep_reward', ep_reward, ep)
        print('Episode: {}\tReward: {}'.format(ep, int(ep_reward)))
    env.close()
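Note: Pendulum-v0 expects a single continuous action inside its Box(-2.0, 2.0, (1,)) action space, while the snippet above feeds the raw policy-net mean straight to env.step. A minimal, self-contained sketch of clipping an arbitrary action to the declared bounds, assuming the same legacy gym API the snippet uses (the raw_action value is illustrative):

import gym
import numpy as np

env = gym.make("Pendulum-v0")
state = env.reset()
# Clip a hypothetical raw policy output to the env's declared action bounds.
raw_action = np.array([3.7], dtype=np.float32)
action = np.clip(raw_action, env.action_space.low, env.action_space.high)
state, reward, done, _ = env.step(action)
print(action, reward)
env.close()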
Example No. 3
def test():
    # load std models
    # policy_log_std = torch.load('./model_pkl/policy_net_action_std_model_1.pkl')
    # transition_log_std = torch.load('./model_pkl/transition_net_state_std_model_1.pkl')

    # define actor/critic/discriminator net and optimizer
    policy = Policy(discrete_action_sections, discrete_state_sections)
    value = Value()
    discriminator = Discriminator()
    discriminator_criterion = nn.BCELoss()

    # load expert data
    dataset = ExpertDataSet(args.data_set_path)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=0)

    # load net models
    discriminator.load_state_dict(
        torch.load('./model_pkl/Discriminator_model_2.pkl'))
    policy.transition_net.load_state_dict(
        torch.load('./model_pkl/Transition_model_2.pkl'))
    policy.policy_net.load_state_dict(
        torch.load('./model_pkl/Policy_model_2.pkl'))
    value.load_state_dict(torch.load('./model_pkl/Value_model_2.pkl'))

    discrete_state_loss_list = []
    continuous_state_loss_list = []
    action_loss_list = []
    cnt = 0
    for expert_state_batch, expert_action_batch, expert_next_state in data_loader:
        cnt += 1
        expert_state_action = torch.cat(
            (expert_state_batch, expert_action_batch),
            dim=-1).type(FloatTensor)
        next_discrete_state, next_continuous_state, _ = policy.get_transition_net_state(
            expert_state_action)
        gen_next_state = torch.cat(
            (next_discrete_state.to(device), next_continuous_state.to(device)),
            dim=-1)

        loss_func = torch.nn.MSELoss()
        continuous_state_loss = loss_func(gen_next_state[:, 132:],
                                          expert_next_state[:, 132:])
        discrete_state_loss = hamming_loss(
            gen_next_state[:, :132],
            expert_next_state[:, :132].type(torch.LongTensor))

        discrete_action, continuous_action, _ = policy.get_policy_net_action(
            expert_state_batch.type(FloatTensor))
        gen_action = torch.FloatTensor(continuous_action)
        loss_func = torch.nn.MSELoss()
        action_loss = loss_func(gen_action, expert_action_batch)

        discrete_state_loss_list.append(discrete_state_loss)
        continuous_state_loss_list.append(continuous_state_loss.item())
        action_loss_list.append(action_loss.item())
    # report average losses over the dataset
    print(sum(discrete_state_loss_list) / cnt)
    print(sum(continuous_state_loss_list) / cnt)
    print(sum(action_loss_list) / cnt)
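For reference, the discrete-part metric above matches the behavior of scikit-learn's hamming_loss on multi-label 0/1 vectors: the fraction of label positions where the two vectors disagree. A small illustration, assuming the hamming_loss used in the snippet is scikit-learn's (the vectors below are made up):

import numpy as np
from sklearn.metrics import hamming_loss

y_true = np.array([[1, 0, 1, 0]])
y_pred = np.array([[1, 1, 1, 0]])
# One of the four positions differs, so the Hamming loss is 0.25.
print(hamming_loss(y_true, y_pred))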
Example No. 4
 def __init__(self, alpha: float = 0.1, gamma: float = 0.9):
     super().__init__()
     self._alpha = alpha
     self._gamma = gamma
     behavioral_mapping = {
         state: [0., 1.] if state.current_sum < 12 else [.5, .5]
         for state in ALL_STATES
     }
     self._beh_policy = Policy.from_probabilistic_mapping(
         behavioral_mapping)
Example No. 5
 def __init__(self, states):
     self._states = states
     initial_policy_map = {}
     for ind, state in self._states.get_states.items():
         initial_policy_map[ind] = []
         actions = ['left', 'right', 'up', 'down']
         for action in actions:
             initial_policy_map[ind].append(action)
     self._state_values = np.zeros((self._states.get_num_states,))
     self._current_policy = Policy(initial_policy_map)
     super().__init__()
Example No. 6
    def _play_stage(self, initial_state: State, policy: Policy,
                    log_action: Callable) -> State:
        taken_action = None
        state = initial_state

        while taken_action != Action.STICK and state != BUST:
            taken_action = policy.make_decision_in(state)
            log_action(state, taken_action)
            if taken_action == Action.HIT:
                state = state.move_with(self._deck.get_next_card())
        return state
Example No. 7
    def setUp(self) -> None:
        n_discrete_state = randint(20, 30)
        n_discrete_action = randint(20, 30)

        self.policy_discrete_state_sections = self.generate_random_sections(
            n_discrete_state)
        self.policy_discrete_action_sections = self.generate_random_sections(
            n_discrete_action)
        self.policy = Policy(self.policy_discrete_action_sections,
                             self.policy_discrete_state_sections,
                             n_discrete_state=n_discrete_state,
                             n_discrete_action=n_discrete_action,
                             n_continuous_action=1,
                             n_continuous_state=1)
        self.no_discrete_policy = Policy([0], [0],
                                         n_discrete_action=0,
                                         n_discrete_state=0,
                                         n_continuous_state=1,
                                         n_continuous_action=1)
        self.no_continuous_policy = Policy(
            self.policy_discrete_action_sections,
            self.policy_discrete_state_sections,
            n_continuous_action=0,
            n_continuous_state=0,
            n_discrete_state=n_discrete_state,
            n_discrete_action=n_discrete_action)
Example No. 8
def evaluate_env():
    env_name = "Pendulum-v0"
    env = gym.make(env_name)
    ppo = Policy([0], [0])
    ppo.policy_net.load_state_dict(
        torch.load('./model_pkl/Policy_model_2.pkl'))
    ppo.transition_net.load_state_dict(
        torch.load('./model_pkl/Transition_model_2.pkl'))
    state = env.reset()
    t = 0
    value_list = []
    for j in range(50):
        state = env.reset()
        for i in range(200):
            t += 1
            if (i == 0):
                gen_state = torch.unsqueeze(
                    torch.from_numpy(state).type(torch.FloatTensor), 0)
                real_state = torch.unsqueeze(
                    torch.from_numpy(state).type(torch.FloatTensor), 0)
            else:
                _, action, _ = ppo.get_policy_net_action(gen_state.to(device),
                                                         size=10000)
                _, gen_state, _ = ppo.get_transition_net_state(torch.cat(
                    (gen_state.to(device), action), dim=-1),
                                                               size=10000)
                action = torch.squeeze(action, 1)
                action = action.cpu().numpy()
                real_state, reward, done, _ = env.step(action)
                value = torch.dist((gen_state.to(device)).float(),
                                   (torch.from_numpy(real_state).unsqueeze(0)
                                    ).float().to(device),
                                   p=2)
                value_list.append(value.item())
                if done:
                    break
    plt.plot(np.linspace(0, 100, len(value_list)), value_list)
    plt.show()
Example No. 9
    def run_algorithm(self):
        policy_map = self._current_policy.get_actions_map()
        iterations = 100

        for k in range(iterations):
            # policy evaluation
            policy_state_values = self.evaluate_policy(self._state_values,
                                                       policy_map)
            self._state_values = policy_state_values
            # greedy policy improvement
            new_policy_map, changed = self.improve_policy(
                policy_state_values, policy_map)
            policy_map = new_policy_map
            # check convergence to the optimal policy
            if not changed:
                break

        self._current_policy = Policy(policy_map)
Example No. 10
    def run_algorithm(self):
        """
            V*(s) = max_A Q(S,A)
            Q(S,A) = R(S,A) + gamma * sigma ( T(S,A,S') * V*(S') )
            since T(S,A,S') = 1, as it's deterministic
            thus Q(S,A) = R(S,A) + gamma * V*(S')
        """
        num_states = self._states.get_num_states
        states = self._states.get_states
        values = np.zeros((num_states, ))
        policy_map = {}
        iterations = 100
        convergence_error = 0.001

        # iterate for N times.
        for k in range(iterations):
            new_values = np.zeros((num_states, ))
            max_delta = 0

            # Update for each state
            for ind, state in states.items():

                # for each state, see all actions
                max_q, max_action = self._get_max_q(state, states, values)

                # update the value of the state
                new_values[state.get_single_index] = max_q
                # state.set_value(max_q)
                max_delta = max(
                    max_delta,
                    abs(new_values[state.get_single_index] -
                        values[state.get_single_index]))

                # update the policy
                policy_map[ind] = [max_action]

            values = new_values
            self._state_values = values
            if max_delta < convergence_error:
                break

        self._current_policy = Policy(policy_map)
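To make the deterministic Bellman backup from the docstring concrete, here is a minimal, self-contained sketch on a hypothetical three-state chain (the states, rewards, and gamma below are illustrative and not taken from the project):

import numpy as np

# Deterministic chain 0 -> 1 -> 2, where state 2 is terminal.
# With a single action, max_A is trivial and Q(S, A) = R(S, A) + gamma * V(S').
rewards = {0: 0.0, 1: 1.0}    # reward for leaving state 0 and state 1
gamma = 0.9
values = np.zeros(3)          # V(2) stays 0 because state 2 is terminal

for _ in range(100):
    new_values = values.copy()
    for s in (0, 1):
        new_values[s] = rewards[s] + gamma * values[s + 1]
    if np.max(np.abs(new_values - values)) < 0.001:
        break
    values = new_values

print(values)  # converges to [0.9, 1.0, 0.0]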
Example No. 11
def main():
    # define actor/critic/discriminator net and optimizer
    policy = Policy(discrete_action_sections, discrete_state_sections)
    value = Value()
    discriminator = Discriminator()
    optimizer_policy = torch.optim.Adam(policy.parameters(), lr=args.policy_lr)
    optimizer_value = torch.optim.Adam(value.parameters(), lr=args.value_lr)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.discrim_lr)
    discriminator_criterion = nn.BCELoss()
    writer = SummaryWriter()

    # load expert data
    dataset = ExpertDataSet(args.expert_activities_data_path,
                            args.expert_cost_data_path)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=args.expert_batch_size,
                                  shuffle=False,
                                  num_workers=1)

    # load models
    # discriminator.load_state_dict(torch.load('./model_pkl/Discriminator_model_3.pkl'))
    # policy.transition_net.load_state_dict(torch.load('./model_pkl/Transition_model_3.pkl'))
    # policy.policy_net.load_state_dict(torch.load('./model_pkl/Policy_model_3.pkl'))
    # value.load_state_dict(torch.load('./model_pkl/Value_model_3.pkl'))

    print('#############  start training  ##############')

    # update discriminator
    num = 0
    for ep in tqdm(range(args.training_epochs)):
        # collect data from environment for ppo update
        start_time = time.time()
        memory = policy.collect_samples(args.ppo_buffer_size, size=10000)
        # print('sample_data_time:{}'.format(time.time()-start_time))
        batch = memory.sample()
        continuous_state = torch.stack(
            batch.continuous_state).squeeze(1).detach()
        discrete_action = torch.stack(
            batch.discrete_action).squeeze(1).detach()
        continuous_action = torch.stack(
            batch.continuous_action).squeeze(1).detach()
        next_discrete_state = torch.stack(
            batch.next_discrete_state).squeeze(1).detach()
        next_continuous_state = torch.stack(
            batch.next_continuous_state).squeeze(1).detach()
        old_log_prob = torch.stack(batch.old_log_prob).detach()
        mask = torch.stack(batch.mask).squeeze(1).detach()
        discrete_state = torch.stack(batch.discrete_state).squeeze(1).detach()
        d_loss = torch.empty(0, device=device)
        p_loss = torch.empty(0, device=device)
        v_loss = torch.empty(0, device=device)
        gen_r = torch.empty(0, device=device)
        expert_r = torch.empty(0, device=device)
        for _ in range(1):
            for expert_state_batch, expert_action_batch in data_loader:
                gen_state = torch.cat((discrete_state, continuous_state),
                                      dim=-1)
                gen_action = torch.cat((discrete_action, continuous_action),
                                       dim=-1)
                gen_r = discriminator(gen_state, gen_action)
                expert_r = discriminator(expert_state_batch,
                                         expert_action_batch)
                optimizer_discriminator.zero_grad()
                d_loss = discriminator_criterion(gen_r,
                                                 torch.zeros(gen_r.shape, device=device)) + \
                         discriminator_criterion(expert_r,
                                                 torch.ones(expert_r.shape, device=device))
                total_d_loss = d_loss - 10 * torch.var(gen_r.to(device))
                d_loss.backward()
                # total_d_loss.backward()
                optimizer_discriminator.step()
        writer.add_scalar('d_loss', d_loss, ep)
        # writer.add_scalar('total_d_loss', total_d_loss, ep)
        writer.add_scalar('expert_r', expert_r.mean(), ep)

        # update PPO
        gen_r = discriminator(
            torch.cat((discrete_state, continuous_state), dim=-1),
            torch.cat((discrete_action, continuous_action), dim=-1))
        optimize_iter_num = int(
            math.ceil(discrete_state.shape[0] / args.ppo_mini_batch_size))
        for ppo_ep in range(args.ppo_optim_epoch):
            for i in range(optimize_iter_num):
                num += 1
                index = slice(
                    i * args.ppo_mini_batch_size,
                    min((i + 1) * args.ppo_mini_batch_size,
                        discrete_state.shape[0]))
                discrete_state_batch, continuous_state_batch, discrete_action_batch, continuous_action_batch, \
                old_log_prob_batch, mask_batch, next_discrete_state_batch, next_continuous_state_batch, gen_r_batch = \
                    discrete_state[index], continuous_state[index], discrete_action[index], continuous_action[index], \
                    old_log_prob[index], mask[index], next_discrete_state[index], next_continuous_state[index], gen_r[
                        index]
                v_loss, p_loss = ppo_step(
                    policy, value, optimizer_policy, optimizer_value,
                    discrete_state_batch, continuous_state_batch,
                    discrete_action_batch, continuous_action_batch,
                    next_discrete_state_batch, next_continuous_state_batch,
                    gen_r_batch, old_log_prob_batch, mask_batch,
                    args.ppo_clip_epsilon)
            writer.add_scalar('p_loss', p_loss, num)
            writer.add_scalar('v_loss', v_loss, num)
            writer.add_scalar('gen_r', gen_r.mean(), num)

        print('#' * 5 + 'training episode:{}'.format(ep) + '#' * 5)
        print('d_loss', d_loss.item())
        # print('p_loss', p_loss.item())
        # print('v_loss', v_loss.item())
        print('gen_r:', gen_r.mean().item())
        print('expert_r:', expert_r.mean().item())

        memory.clear_memory()
        # save models
        torch.save(discriminator.state_dict(),
                   './model_pkl/Discriminator_model_4.pkl')
        torch.save(policy.transition_net.state_dict(),
                   './model_pkl/Transition_model_4.pkl')
        torch.save(policy.policy_net.state_dict(),
                   './model_pkl/Policy_model_4.pkl')
        torch.save(value.state_dict(), './model_pkl/Value_model_4.pkl')
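The discriminator update above is the standard GAIL objective: generated state-action pairs are labeled 0, expert pairs are labeled 1, and the two binary-cross-entropy terms are summed before a gradient step. A minimal, self-contained sketch of that objective with a toy discriminator (the network shape, feature size, and learning rate are illustrative, not the project's Discriminator):

import torch
import torch.nn as nn

# Toy discriminator over concatenated (state, action) features; the size 8 is arbitrary.
disc = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(disc.parameters(), lr=1e-3)

gen_sa = torch.randn(16, 8)      # stand-in for generated (state, action) batches
expert_sa = torch.randn(16, 8)   # stand-in for expert (state, action) batches

gen_r = disc(gen_sa)
expert_r = disc(expert_sa)
# Generated samples are pushed toward label 0, expert samples toward label 1.
d_loss = criterion(gen_r, torch.zeros_like(gen_r)) + \
         criterion(expert_r, torch.ones_like(expert_r))

optimizer.zero_grad()
d_loss.backward()
optimizer.step()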
Example No. 12
class PolicyTestCase(unittest.TestCase):
    def generate_random_sections(self, total_dim):
        if total_dim == 0:
            return [0]
        # discard a random value
        randint(2, 8)
        sections_len = randint(2, 8)
        sections = [2] * sections_len
        remain_dim = total_dim - sections_len * 2
        if remain_dim <= 0:
            return self.generate_random_sections(total_dim)
        for i in range(sections_len - 1):
            dim = randint(1, remain_dim)
            sections[i] += dim
            remain_dim -= dim
            if remain_dim <= 1:
                break
        sections[sections_len - 1] += (total_dim - sum(sections))
        assert sum(sections) == total_dim
        return sections

    def setUp(self) -> None:
        n_discrete_state = randint(20, 30)
        n_discrete_action = randint(20, 30)

        self.policy_discrete_state_sections = self.generate_random_sections(
            n_discrete_state)
        self.policy_discrete_action_sections = self.generate_random_sections(
            n_discrete_action)
        self.policy = Policy(self.policy_discrete_action_sections,
                             self.policy_discrete_state_sections,
                             n_discrete_state=n_discrete_state,
                             n_discrete_action=n_discrete_action,
                             n_continuous_action=1,
                             n_continuous_state=1)
        self.no_discrete_policy = Policy([0], [0],
                                         n_discrete_action=0,
                                         n_discrete_state=0,
                                         n_continuous_state=1,
                                         n_continuous_action=1)
        self.no_continuous_policy = Policy(
            self.policy_discrete_action_sections,
            self.policy_discrete_state_sections,
            n_continuous_action=0,
            n_continuous_state=0,
            n_discrete_state=n_discrete_state,
            n_discrete_action=n_discrete_action)

    def test_collect_time(self):
        import time
        start_time = time.time()
        memory = self.policy.collect_samples(2048)
        end_time = time.time()
        print('Total time %f' % (end_time - start_time))
        start_time = time.time()
        memory = self.policy.collect_samples(4, 512)
        end_time = time.time()
        print('Total time %f' % (end_time - start_time))

    def test_collect_samples_size(self):
        memory = self.policy.collect_samples(2048, 99)
        batch = memory.sample()
        discrete_state = torch.stack(batch.discrete_state).to(
            self.policy.device).squeeze(1).detach()
        continuous_state = torch.stack(batch.continuous_state).to(
            self.policy.device).squeeze(1).detach()
        next_discrete_state = torch.stack(batch.next_discrete_state).to(
            self.policy.device).squeeze(1).detach()
        continuous_action = torch.stack(batch.continuous_action).to(
            self.policy.device).squeeze(1).detach()
        next_continuous_state = torch.stack(batch.next_continuous_state).to(
            self.policy.device).squeeze(1).detach()
        discrete_action = torch.stack(batch.discrete_action).to(
            self.policy.device).squeeze(1).detach()

        old_log_prob = torch.stack(batch.old_log_prob).to(
            self.policy.device).squeeze(1).detach()
        mask = torch.stack(batch.mask).to(
            self.policy.device).squeeze(1).detach()

        policy_log_prob_new = self.policy.get_policy_net_log_prob(
            torch.cat((discrete_state, continuous_state), dim=-1),
            discrete_action, continuous_action)
        transition_log_prob_new = self.policy.get_transition_net_log_prob(
            torch.cat((discrete_state, continuous_state, discrete_action,
                       continuous_action),
                      dim=-1), next_discrete_state, next_continuous_state)
        new_log_prob = policy_log_prob_new + transition_log_prob_new
        # Since the policy has not been updated yet, old_log_prob and new_log_prob must be equal
        assert tensor_close(new_log_prob, old_log_prob)

    def test_log_prob(self):
        memory = self.policy.collect_samples(2048)
        batch = memory.sample()
        discrete_state = torch.stack(batch.discrete_state).to(
            self.policy.device).squeeze(1).detach()
        continuous_state = torch.stack(batch.continuous_state).to(
            self.policy.device).squeeze(1).detach()
        next_discrete_state = torch.stack(batch.next_discrete_state).to(
            self.policy.device).squeeze(1).detach()
        continuous_action = torch.stack(batch.continuous_action).to(
            self.policy.device).squeeze(1).detach()
        next_continuous_state = torch.stack(batch.next_continuous_state).to(
            self.policy.device).squeeze(1).detach()
        discrete_action = torch.stack(batch.discrete_action).to(
            self.policy.device).squeeze(1).detach()

        assert discrete_state.size(1) == sum(
            self.policy_discrete_state_sections)
        assert continuous_state.size(1) == 1
        assert next_discrete_state.size(1) == sum(
            self.policy_discrete_state_sections)
        assert continuous_action.size(1) == 1
        assert next_continuous_state.size(1) == 1
        assert discrete_action.size(1) == sum(
            self.policy_discrete_action_sections)

        old_log_prob = torch.stack(batch.old_log_prob).to(
            self.policy.device).squeeze(1).detach()
        mask = torch.stack(batch.mask).to(
            self.policy.device).squeeze(1).detach()
        # the sampled batch should contain at least one terminal ('done') transition
        assert (mask == 0).any()

        policy_log_prob_new = self.policy.get_policy_net_log_prob(
            torch.cat((discrete_state, continuous_state), dim=-1),
            discrete_action, continuous_action)
        transition_log_prob_new = self.policy.get_transition_net_log_prob(
            torch.cat((discrete_state, continuous_state, discrete_action,
                       continuous_action),
                      dim=-1), next_discrete_state, next_continuous_state)
        new_log_prob = policy_log_prob_new + transition_log_prob_new
        # Since the policy has not been updated yet, old_log_prob and new_log_prob must be equal
        assert tensor_close(new_log_prob, old_log_prob)

    def test_no_discrete_log_prob(self):
        memory = self.no_discrete_policy.collect_samples(2048)
        batch = memory.sample()
        discrete_state = torch.stack(batch.discrete_state).to(
            self.no_discrete_policy.device).squeeze(1).detach()
        continuous_state = torch.stack(batch.continuous_state).to(
            self.no_discrete_policy.device).squeeze(1).detach()
        next_discrete_state = torch.stack(batch.next_discrete_state).to(
            self.no_discrete_policy.device).squeeze(1).detach()
        continuous_action = torch.stack(batch.continuous_action).to(
            self.no_discrete_policy.device).squeeze(1).detach()
        next_continuous_state = torch.stack(batch.next_continuous_state).to(
            self.no_discrete_policy.device).squeeze(1).detach()
        discrete_action = torch.stack(batch.discrete_action).to(
            self.no_discrete_policy.device).squeeze(1).detach()

        assert discrete_state.size(1) == sum(
            self.no_discrete_policy.discrete_state_sections)
        assert continuous_state.size(1) == 1
        assert next_discrete_state.size(1) == sum(
            self.no_discrete_policy.discrete_state_sections)
        assert continuous_action.size(1) == 1
        assert next_continuous_state.size(1) == 1
        assert discrete_action.size(1) == sum(
            self.no_discrete_policy.discrete_action_sections)

        old_log_prob = torch.stack(batch.old_log_prob).to(
            self.no_discrete_policy.device).squeeze(1).detach()
        mask = torch.stack(batch.mask).to(
            self.no_discrete_policy.device).squeeze(1).detach()
        # the sampled batch should contain at least one terminal ('done') transition
        assert (mask == 0).any()

        policy_log_prob_new = self.no_discrete_policy.get_policy_net_log_prob(
            torch.cat((discrete_state, continuous_state), dim=-1),
            discrete_action, continuous_action)
        transition_log_prob_new = self.no_discrete_policy.get_transition_net_log_prob(
            torch.cat((discrete_state, continuous_state, discrete_action,
                       continuous_action),
                      dim=-1), next_discrete_state, next_continuous_state)
        new_log_prob = policy_log_prob_new + transition_log_prob_new
        # Since the policy has not been updated yet, old_log_prob and new_log_prob must be equal
        assert tensor_close(new_log_prob, old_log_prob)

    def test_no_continuous_log_prob(self):
        memory = self.no_continuous_policy.collect_samples(2048)
        batch = memory.sample()
        discrete_state = torch.stack(batch.discrete_state).to(
            self.no_continuous_policy.device).squeeze(1).detach()
        continuous_state = torch.stack(batch.continuous_state).to(
            self.no_continuous_policy.device).squeeze(1).detach()
        next_discrete_state = torch.stack(batch.next_discrete_state).to(
            self.no_continuous_policy.device).squeeze(1).detach()
        continuous_action = torch.stack(batch.continuous_action).to(
            self.no_continuous_policy.device).squeeze(1).detach()
        next_continuous_state = torch.stack(batch.next_continuous_state).to(
            self.no_continuous_policy.device).squeeze(1).detach()
        discrete_action = torch.stack(batch.discrete_action).to(
            self.no_continuous_policy.device).squeeze(1).detach()

        assert discrete_state.size(1) == sum(
            self.no_continuous_policy.discrete_state_sections)
        assert continuous_state.size(1) == 0
        assert next_discrete_state.size(1) == sum(
            self.no_continuous_policy.discrete_state_sections)
        assert continuous_action.size(1) == 0
        assert next_continuous_state.size(1) == 0
        assert discrete_action.size(1) == sum(
            self.no_continuous_policy.discrete_action_sections)

        old_log_prob = torch.stack(batch.old_log_prob).to(
            self.no_continuous_policy.device).squeeze(1).detach()
        mask = torch.stack(batch.mask).to(
            self.no_continuous_policy.device).squeeze(1).detach()
        # the sampled batch should contain at least one terminal ('done') transition
        assert (mask == 0).any()

        policy_log_prob_new = self.no_continuous_policy.get_policy_net_log_prob(
            torch.cat((discrete_state, continuous_state), dim=-1),
            discrete_action, continuous_action)
        transition_log_prob_new = self.no_continuous_policy.get_transition_net_log_prob(
            torch.cat((discrete_state, continuous_state, discrete_action,
                       continuous_action),
                      dim=-1), next_discrete_state, next_continuous_state)
        new_log_prob = policy_log_prob_new + transition_log_prob_new
        # Since the policy has not been updated yet, old_log_prob and new_log_prob must be equal
        assert tensor_close(new_log_prob, old_log_prob)
Example No. 13
 def policy(self) -> Policy:
     return Policy.from_probabilistic_mapping(self._pi)
Example No. 14
 def policy(self) -> Policy:
     return Policy.from_values(self._Q)
Example No. 15
def main():
    # load std models
    # policy_log_std = torch.load('./model_pkl/policy_net_action_std_model_1.pkl')
    # transition_log_std = torch.load('./model_pkl/transition_net_state_std_model_1.pkl')

    # load expert data
    print(args.data_set_path)
    dataset = ExpertDataSet(args.data_set_path)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=args.expert_batch_size,
                                  shuffle=True,
                                  num_workers=0)
    # define actor/critic/discriminator net and optimizer
    policy = Policy(onehot_action_sections,
                    onehot_state_sections,
                    state_0=dataset.state)
    value = Value()
    discriminator = Discriminator()
    optimizer_policy = torch.optim.Adam(policy.parameters(), lr=args.policy_lr)
    optimizer_value = torch.optim.Adam(value.parameters(), lr=args.value_lr)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.discrim_lr)
    discriminator_criterion = nn.BCELoss()
    if write_scalar:
        writer = SummaryWriter(log_dir='runs/' + model_name)

    # load net models
    if load_model:
        discriminator.load_state_dict(
            torch.load('./model_pkl/Discriminator_model_' + model_name +
                       '.pkl'))
        policy.transition_net.load_state_dict(
            torch.load('./model_pkl/Transition_model_' + model_name + '.pkl'))
        policy.policy_net.load_state_dict(
            torch.load('./model_pkl/Policy_model_' + model_name + '.pkl'))
        value.load_state_dict(
            torch.load('./model_pkl/Value_model_' + model_name + '.pkl'))

        policy.policy_net_action_std = torch.load(
            './model_pkl/Policy_net_action_std_model_' + model_name + '.pkl')
        policy.transition_net_state_std = torch.load(
            './model_pkl/Transition_net_state_std_model_' + model_name +
            '.pkl')
    print('#############  start training  ##############')

    # update discriminator
    num = 0
    for ep in tqdm(range(args.training_epochs)):
        # collect data from environment for ppo update
        policy.train()
        value.train()
        discriminator.train()
        start_time = time.time()
        memory, n_trajs = policy.collect_samples(
            batch_size=args.sample_batch_size)
        # print('sample_data_time:{}'.format(time.time()-start_time))
        batch = memory.sample()
        onehot_state = torch.cat(batch.onehot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        multihot_state = torch.cat(batch.multihot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        continuous_state = torch.cat(batch.continuous_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()

        onehot_action = torch.cat(batch.onehot_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        multihot_action = torch.cat(batch.multihot_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        continuous_action = torch.cat(batch.continuous_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        next_onehot_state = torch.cat(batch.next_onehot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        next_multihot_state = torch.cat(batch.next_multihot_state,
                                        dim=1).reshape(
                                            n_trajs * args.sample_traj_length,
                                            -1).detach()
        next_continuous_state = torch.cat(
            batch.next_continuous_state,
            dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach()

        old_log_prob = torch.cat(batch.old_log_prob, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        mask = torch.cat(batch.mask,
                         dim=1).reshape(n_trajs * args.sample_traj_length,
                                        -1).detach()
        gen_state = torch.cat((onehot_state, multihot_state, continuous_state),
                              dim=-1)
        gen_action = torch.cat(
            (onehot_action, multihot_action, continuous_action), dim=-1)
        if ep % 1 == 0:
            # if (d_slow_flag and ep % 50 == 0) or (not d_slow_flag and ep % 1 == 0):
            d_loss = torch.empty(0, device=device)
            p_loss = torch.empty(0, device=device)
            v_loss = torch.empty(0, device=device)
            gen_r = torch.empty(0, device=device)
            expert_r = torch.empty(0, device=device)
            for expert_state_batch, expert_action_batch in data_loader:
                noise1 = torch.normal(0,
                                      args.noise_std,
                                      size=gen_state.shape,
                                      device=device)
                noise2 = torch.normal(0,
                                      args.noise_std,
                                      size=gen_action.shape,
                                      device=device)
                noise3 = torch.normal(0,
                                      args.noise_std,
                                      size=expert_state_batch.shape,
                                      device=device)
                noise4 = torch.normal(0,
                                      args.noise_std,
                                      size=expert_action_batch.shape,
                                      device=device)
                gen_r = discriminator(gen_state + noise1, gen_action + noise2)
                expert_r = discriminator(
                    expert_state_batch.to(device) + noise3,
                    expert_action_batch.to(device) + noise4)

                # gen_r = discriminator(gen_state, gen_action)
                # expert_r = discriminator(expert_state_batch.to(device), expert_action_batch.to(device))
                optimizer_discriminator.zero_grad()
                d_loss = discriminator_criterion(gen_r, torch.zeros(gen_r.shape, device=device)) + \
                            discriminator_criterion(expert_r,torch.ones(expert_r.shape, device=device))
                variance = 0.5 * torch.var(gen_r.to(device)) + 0.5 * torch.var(
                    expert_r.to(device))
                total_d_loss = d_loss - 10 * variance
                d_loss.backward()
                # total_d_loss.backward()
                optimizer_discriminator.step()
            if write_scalar:
                writer.add_scalar('d_loss', d_loss, ep)
                writer.add_scalar('total_d_loss', total_d_loss, ep)
                writer.add_scalar('variance', 10 * variance, ep)
        if ep % 1 == 0:
            # update PPO
            noise1 = torch.normal(0,
                                  args.noise_std,
                                  size=gen_state.shape,
                                  device=device)
            noise2 = torch.normal(0,
                                  args.noise_std,
                                  size=gen_action.shape,
                                  device=device)
            gen_r = discriminator(gen_state + noise1, gen_action + noise2)
            #if gen_r.mean().item() < 0.1:
            #    d_stop = True
            #if d_stop and gen_r.mean()
            optimize_iter_num = int(
                math.ceil(onehot_state.shape[0] / args.ppo_mini_batch_size))
            # gen_r = -(1 - gen_r + 1e-10).log()
            for ppo_ep in range(args.ppo_optim_epoch):
                for i in range(optimize_iter_num):
                    num += 1
                    index = slice(
                        i * args.ppo_mini_batch_size,
                        min((i + 1) * args.ppo_mini_batch_size,
                            onehot_state.shape[0]))
                    onehot_state_batch, multihot_state_batch, continuous_state_batch, onehot_action_batch, multihot_action_batch, continuous_action_batch, \
                    old_log_prob_batch, mask_batch, next_onehot_state_batch, next_multihot_state_batch, next_continuous_state_batch, gen_r_batch = \
                        onehot_state[index], multihot_state[index], continuous_state[index], onehot_action[index], multihot_action[index], continuous_action[index], \
                        old_log_prob[index], mask[index], next_onehot_state[index], next_multihot_state[index], next_continuous_state[index], gen_r[
                            index]
                    v_loss, p_loss = ppo_step(
                        policy, value, optimizer_policy, optimizer_value,
                        onehot_state_batch, multihot_state_batch,
                        continuous_state_batch, onehot_action_batch,
                        multihot_action_batch, continuous_action_batch,
                        next_onehot_state_batch, next_multihot_state_batch,
                        next_continuous_state_batch, gen_r_batch,
                        old_log_prob_batch, mask_batch, args.ppo_clip_epsilon)
                    if write_scalar:
                        writer.add_scalar('p_loss', p_loss, ep)
                        writer.add_scalar('v_loss', v_loss, ep)
        policy.eval()
        value.eval()
        discriminator.eval()
        noise1 = torch.normal(0,
                              args.noise_std,
                              size=gen_state.shape,
                              device=device)
        noise2 = torch.normal(0,
                              args.noise_std,
                              size=gen_action.shape,
                              device=device)
        gen_r = discriminator(gen_state + noise1, gen_action + noise2)
        expert_r = discriminator(
            expert_state_batch.to(device) + noise3,
            expert_action_batch.to(device) + noise4)
        gen_r_noise = gen_r.mean().item()
        expert_r_noise = expert_r.mean().item()
        gen_r = discriminator(gen_state, gen_action)
        expert_r = discriminator(expert_state_batch.to(device),
                                 expert_action_batch.to(device))
        if write_scalar:
            writer.add_scalar('gen_r', gen_r.mean(), ep)
            writer.add_scalar('expert_r', expert_r.mean(), ep)
            writer.add_scalar('gen_r_noise', gen_r_noise, ep)
            writer.add_scalar('expert_r_noise', expert_r_noise, ep)
        print('#' * 5 + 'training episode:{}'.format(ep) + '#' * 5)
        print('gen_r_noise', gen_r_noise)
        print('expert_r_noise', expert_r_noise)
        print('gen_r:', gen_r.mean().item())
        print('expert_r:', expert_r.mean().item())
        print('d_loss', d_loss.item())
        # save models
        if model_name is not None:
            torch.save(
                discriminator.state_dict(),
                './model_pkl/Discriminator_model_' + model_name + '.pkl')
            torch.save(policy.transition_net.state_dict(),
                       './model_pkl/Transition_model_' + model_name + '.pkl')
            torch.save(policy.policy_net.state_dict(),
                       './model_pkl/Policy_model_' + model_name + '.pkl')
            torch.save(
                policy.policy_net_action_std,
                './model_pkl/Policy_net_action_std_model_' + model_name +
                '.pkl')
            torch.save(
                policy.transition_net_state_std,
                './model_pkl/Transition_net_state_std_model_' + model_name +
                '.pkl')
            torch.save(value.state_dict(),
                       './model_pkl/Value_model_' + model_name + '.pkl')
        memory.clear_memory()
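The Gaussian noise added to both generated and expert inputs above acts as instance noise, a common trick to keep the discriminator from saturating early. A minimal, self-contained illustration of that step alone (tensor sizes and the std are made up):

import torch

noise_std = 0.1
expert_batch = torch.randn(16, 8)   # stand-in for expert (state, action) features
gen_batch = torch.randn(16, 8)      # stand-in for generated features

# Freshly drawn, same-shaped zero-mean noise for every forward pass.
expert_noisy = expert_batch + torch.normal(0.0, noise_std, size=tuple(expert_batch.shape))
gen_noisy = gen_batch + torch.normal(0.0, noise_std, size=tuple(gen_batch.shape))
print(expert_noisy.shape, gen_noisy.shape)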
Example No. 16
 def policy(self) -> Policy:
     return Policy.from_deterministic_mapping(self._pi)
Example No. 17
class WebEye(gym.Env):
    metadata = {'render.modes': ['human']}

    def __init__(self):
        self.n_item = 5
        self.max_c = 100
        self.obs_low = np.concatenate(
            ([0] * args.n_discrete_state, [-5] * args.n_continuous_state))
        self.obs_high = np.concatenate(
            ([1] * args.n_discrete_state, [5] * args.n_continuous_state))
        self.observation_space = spaces.Box(low=self.obs_low,
                                            high=self.obs_high,
                                            dtype=np.float32)
        self.action_space = spaces.Box(low=-5,
                                       high=5,
                                       shape=(args.n_discrete_action +
                                              args.n_continuous_action, ),
                                       dtype=np.float32)
        self.trans_model = Policy(discrete_action_sections,
                                  discrete_state_sections,
                                  state_0=dataset.state)
        if mode == 'train':
            self.trans_model.transition_net.load_state_dict(
                torch.load('./model_pkl/Transition_model_sas_train_4.pkl'))
            self.trans_model.policy_net.load_state_dict(
                torch.load('./model_pkl/Policy_model_sas_train_4.pkl'))
            self.trans_model.policy_net_action_std = torch.load(
                './model_pkl/policy_net_action_std_model_sas_train_4.pkl')
        elif mode == 'test':
            self.trans_model.transition_net.load_state_dict(
                torch.load('./model_pkl/Transition_model_sas_test.pkl'))
            self.trans_model.policy_net.load_state_dict(
                torch.load('./model_pkl/Policy_model_sas_test.pkl'))
            self.trans_model.policy_net_action_std = torch.load(
                './model_pkl/policy_net_action_std_model_sas_test.pkl')
        else:
            assert False, "mode must be 'train' or 'test'"
        self.reset()

    def seed(self, sd=0):
        torch.manual_seed(sd)

    @property
    def state(self):
        return torch.cat((self.discrete_state, self.continuous_state), dim=-1)

    def __user_generator(self):
        # with shape(n_user_feature,)
        user = self.user_model.generate()
        self.__leave = self.user_leave_model.predict(user)
        return user

    def _calc_reward(self):
        return self.state[:, args.n_discrete_state - 1 + 8].to(device)

    def step(self, action):
        assert action.shape[0] == self.batch_size
        self.length += 1
        self.discrete_state, self.continuous_state, _ = self.trans_model.get_transition_net_state(
            torch.cat((self.state, action), dim=-1))
        done = (self.length >= 5)
        #if done:
        reward = self._calc_reward()
        #else:
        #    reward = torch.zeros(size=(self.batch_size,)).to(device)
        return self.state, reward, done, {}

    def reset(self, batch_size=1):
        self.batch_size = batch_size
        self.length = 0
        self.discrete_state, self.continuous_state = self.trans_model.reset(
            num_trajs=self.batch_size)
        return self.state

    def render(self, mode='human', close=False):
        pass
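A hypothetical driver loop for the environment above; WebEye, args, torch, and device are assumed to come from the surrounding module, and the uniform random action is only a placeholder for a real controller:

env = WebEye()
state = env.reset(batch_size=1)
done = False
while not done:
    # Sample a flat action matching the declared Box(-5, 5) action space,
    # laid out as the discrete part followed by the continuous part.
    n_action = args.n_discrete_action + args.n_continuous_action
    action = (torch.rand(1, n_action, device=device) * 10) - 5
    state, reward, done, _ = env.step(action)
    print(reward)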