Example #1
def update_params(batch_mgr, batch_wrk):
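    # hierarchical A2C update: batch_mgr holds manager transitions (whose "actions" are subgoals)
    # and batch_wrk holds worker transitions; each level has its own policy and value networks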
    states_mgr = torch.from_numpy(np.stack(batch_mgr.state)).to(dtype).to(device)
    subgoals = torch.from_numpy(np.stack(batch_mgr.action)).to(dtype).to(device)
    rewards_mgr = torch.from_numpy(np.stack(batch_mgr.reward)).to(dtype).to(device)
    masks_mgr = torch.from_numpy(np.stack(batch_mgr.mask)).to(dtype).to(device)

    states_wrk = torch.from_numpy(np.stack(batch_wrk.state)).to(dtype).to(device)
    actions = torch.from_numpy(np.stack(batch_wrk.action)).to(dtype).to(device)
    rewards_wrk = torch.from_numpy(np.stack(batch_wrk.reward)).to(dtype).to(device)
    masks_wrk = torch.from_numpy(np.stack(batch_wrk.mask)).to(dtype).to(device)

    with torch.no_grad():
        values_mgr = value_mgr(states_mgr)
        values_wrk = value_wrk(states_wrk)

    """get advantage estimation from the trajectories"""
    advantages_mgr, returns_mgr = estimate_advantages(rewards_mgr, masks_mgr, values_mgr, args.gamma, args.tau, device)
    advantages_wrk, returns_wrk = estimate_advantages(rewards_wrk, masks_wrk, values_wrk, args.gamma, args.tau, device)

    #print (torch.sum(torch.isnan(advantages_mgr)*1.0), torch.sum(torch.isnan(returns_mgr)*1.0))
    #print (torch.sum(torch.isnan(advantages_wrk)*1.0), torch.sum(torch.isnan(returns_wrk)*1.0))

    """perform TRPO update"""
    policy_loss_m = 0
    policy_loss_m, value_loss_m = a2c_step(policy_mgr, value_mgr, optim_policy_m, optim_value_m, states_mgr, subgoals, returns_mgr, advantages_mgr, args.l2_reg)
    policy_loss_w, value_loss_w = a2c_step(policy_wrk, value_wrk, optim_policy_w, optim_value_w, states_wrk, actions, returns_wrk, advantages_wrk, args.l2_reg)

    return policy_loss_m, policy_loss_w
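
# Not part of the original examples: a minimal sketch of what an estimate_advantages(rewards,
# masks, values, gamma, tau, device) helper with this call signature typically computes, namely
# GAE(gamma, tau); the actual implementation used by these snippets may differ (e.g., in how or
# whether advantages are normalized). The _sketch suffix marks it as illustrative only.
import torch

def estimate_advantages_sketch(rewards, masks, values, gamma, tau, device):
    rewards, masks, values = rewards.to(device), masks.to(device), values.to(device)
    advantages = torch.zeros_like(values)
    prev_value, prev_advantage = 0.0, 0.0
    # walk the batch backwards; masks[i] == 0 marks the end of an episode and stops bootstrapping
    for i in reversed(range(rewards.size(0))):
        delta = rewards[i] + gamma * prev_value * masks[i] - values[i]
        advantages[i] = delta + gamma * tau * prev_advantage * masks[i]
        prev_value, prev_advantage = values[i], advantages[i]
    returns = values + advantages
    # advantages are commonly normalized before the policy-gradient update
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages, returns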
Example #2
def update_params(batch_mgr, batch_wrk):
    states_mgr = torch.from_numpy(np.stack(batch_mgr.state)).to(dtype).to(device)
    directions = torch.from_numpy(np.stack(batch_mgr.action)).to(dtype).to(device)
    rewards_mgr = torch.from_numpy(np.stack(batch_mgr.reward)).to(dtype).to(device)
    masks_mgr = torch.from_numpy(np.stack(batch_mgr.mask)).to(dtype).to(device)

    states_wrk = torch.from_numpy(np.stack(batch_wrk.state)).to(dtype).to(device)
    actions = torch.from_numpy(np.stack(batch_wrk.action)).to(dtype).to(device)
    rewards_wrk = torch.from_numpy(np.stack(batch_wrk.reward)).to(dtype).to(device)
    masks_wrk = torch.from_numpy(np.stack(batch_wrk.mask)).to(dtype).to(device)

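    # cache values and behaviour-policy log-probabilities without gradients; the fixed log-probs
    # are later passed to ppo_step (typically to form the PPO importance ratio)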
    with torch.no_grad():
        values_mgr = value_mgr(states_mgr)
        fixed_logprobs_mgr = policy_mgr.get_log_prob(states_mgr, directions)
        values_wrk = value_wrk(states_wrk)
        fixed_logprobs_wrk = policy_wrk.get_log_prob(states_wrk, actions)

    """get advantage estimation from the trajectories"""
    advantages_mgr, returns_mgr = estimate_advantages(rewards_mgr, masks_mgr, values_mgr, args.gamma, args.tau, device)
    advantages_wrk, returns_wrk = estimate_advantages(rewards_wrk, masks_wrk, values_wrk, args.gamma, args.tau, device)

    #print (torch.sum(torch.isnan(advantages_mgr)*1.0), torch.sum(torch.isnan(returns_mgr)*1.0))
    #print (torch.sum(torch.isnan(advantages_wrk)*1.0), torch.sum(torch.isnan(returns_wrk)*1.0))

    """perform TRPO update"""
    #policy_loss_m, value_loss_m = a2c_step(policy_mgr, value_mgr, optim_policy_m, optim_value_m, states_mgr, directions, returns_mgr, advantages_mgr, args.l2_reg)
    #policy_loss_w, value_loss_w = a2c_step(policy_wrk, value_wrk, optim_policy_w, optim_value_w, states_wrk, actions, returns_wrk, advantages_wrk, args.l2_reg)
    optim_iter_mgr = int(math.ceil(states_mgr.shape[0] / optim_batch_size))
    optim_iter_wrk = int(math.ceil(states_wrk.shape[0] / optim_batch_size))
    for _ in range(optim_epochs):
        perm_mgr = np.arange(states_mgr.shape[0])
        np.random.shuffle(perm_mgr)
        perm_mgr = LongTensor(perm_mgr).to(device)

        perm_wrk = np.arange(states_wrk.shape[0])
        np.random.shuffle(perm_wrk)
        perm_wrk = LongTensor(perm_wrk).to(device)

        states_mgr, directions, returns_mgr, advantages_mgr, fixed_logprobs_mgr = \
            states_mgr[perm_mgr].clone(), directions[perm_mgr].clone(), returns_mgr[perm_mgr].clone(), advantages_mgr[perm_mgr].clone(), fixed_logprobs_mgr[perm_mgr].clone()
        states_wrk, actions, returns_wrk, advantages_wrk, fixed_logprobs_wrk = \
            states_wrk[perm_wrk].clone(), actions[perm_wrk].clone(), returns_wrk[perm_wrk].clone(), advantages_wrk[perm_wrk].clone(), fixed_logprobs_wrk[perm_wrk].clone()

        for i in range(optim_iter_mgr):
            ind = slice(i * optim_batch_size, min((i + 1) * optim_batch_size, states_mgr.shape[0]))
            states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                states_mgr[ind], directions[ind], advantages_mgr[ind], returns_mgr[ind], fixed_logprobs_mgr[ind]

            ppo_step(policy_mgr, value_mgr, optim_policy_m, optim_value_m, 1, states_b, actions_b, returns_b,
                     advantages_b, fixed_log_probs_b, args.clip_epsilon, args.l2_reg)

        for i in range(optim_iter_wrk):
            ind = slice(i * optim_batch_size, min((i + 1) * optim_batch_size, states_wrk.shape[0]))
            states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                states_wrk[ind], actions[ind], advantages_wrk[ind], returns_wrk[ind], fixed_logprobs_wrk[ind]

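            # note: clip_epsilon is passed as 1e20 for the worker update, which effectively disables PPO clipping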
            ppo_step(policy_wrk, value_wrk, optim_policy_w, optim_value_w, 1, states_b, actions_b, returns_b,
                     advantages_b, fixed_log_probs_b, 1e20, args.l2_reg)
Example #3
def process_data(value_net, policy_net, model_states, model_actions,
                 model_rewards, gamma, tau, i_iter, print_freq):
    if use_gpu:
        model_states, model_actions, model_rewards = \
            model_states.cuda(), model_actions.cuda(), model_rewards.cuda()
    model_values = value_net(Variable(model_states, volatile=True))[0].data
    fixed_log_probs = policy_net.get_log_prob(
        Variable(model_states, volatile=True), Variable(model_actions)).data
    model_advantages, model_returns = estimate_advantages(
        model_rewards, model_values, gamma, tau, use_gpu)

    if i_iter % print_freq == 0:
        with open("intermediates.txt", "a") as text_file:
            text_file.write('iter: {} \n'.format(i_iter))
            text_file.write('rewards: ' +
                            to_string(model_rewards[:, 0].squeeze()) + '\n')
            text_file.write('values: ' +
                            to_string(model_values[:, 0].squeeze()) + '\n')
            text_file.write('advs: ' +
                            to_string(model_advantages[:, 0].squeeze()) + '\n')
            text_file.write('returns: ' +
                            to_string(model_returns[:, 0].squeeze()) + '\n')
            text_file.write('\n')

    return model_advantages, model_returns, fixed_log_probs
Example #4
def update_params(batch, i_iter):
    states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device)
    actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device)
    rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device)
    masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device)
    with torch.no_grad():
        values = value_net(states)
        fixed_log_probs = policy_net.get_log_prob(states, actions)
    """get advantage estimation from the trajectories"""
    advantages, returns = estimate_advantages(rewards, masks, values,
                                              args.gamma, args.tau, device)
    """perform mini-batch PPO update"""
    optim_iter_num = int(math.ceil(states.shape[0] / optim_batch_size))
    for _ in range(optim_epochs):
        perm = np.arange(states.shape[0])
        np.random.shuffle(perm)
        perm = LongTensor(perm).to(device)

        states, actions, returns, advantages, fixed_log_probs = \
            states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), fixed_log_probs[perm].clone()

        for i in range(optim_iter_num):
            ind = slice(i * optim_batch_size,
                        min((i + 1) * optim_batch_size, states.shape[0]))
            states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]

            ppo_step(policy_net, value_net, optimizer_policy, optimizer_value,
                     1, states_b, actions_b, returns_b, advantages_b,
                     fixed_log_probs_b, args.clip_epsilon, args.l2_reg)
Example #5
def update_params(batch, i_iter):
    states = torch.from_numpy(np.stack(batch.state))
    actions = torch.from_numpy(np.stack(batch.action))
    rewards = torch.from_numpy(np.stack(batch.reward))
    masks = torch.from_numpy(np.stack(batch.mask).astype(np.float64))
    if use_gpu:
        states, actions, rewards, masks = states.cuda(), actions.cuda(), rewards.cuda(), masks.cuda()
    values = value_net(Variable(states, volatile=True)).data
    fixed_log_probs = policy_net.get_log_prob(Variable(states, volatile=True), Variable(actions)).data

    """get advantage estimation from the trajectories"""
    advantages, returns = estimate_advantages(rewards, masks, values, args.gamma, args.tau, use_gpu)

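    # linearly anneal the learning-rate multiplier from 1 to 0 over max_iter_num iterations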
    lr_mult = max(1.0 - float(i_iter) / args.max_iter_num, 0)

    """perform mini-batch PPO update"""
    optim_iter_num = int(math.ceil(states.shape[0] / optim_batch_size))
    for _ in range(optim_epochs):
        perm = torch.randperm(states.shape[0])
        # perm = np.arange(states.shape[0])
        # np.random.shuffle(perm)
        # perm = LongTensor(perm).cuda() if use_gpu else LongTensor(perm)

        states, actions, returns, advantages, fixed_log_probs = \
            states[perm], actions[perm], returns[perm], advantages[perm], fixed_log_probs[perm]

        for i in range(optim_iter_num):
            ind = slice(i * optim_batch_size, min((i + 1) * optim_batch_size, states.shape[0]))
            states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]

            ppo_step(policy_net, value_net, optimizer_policy, optimizer_value, 1, states_b, actions_b, returns_b,
                     advantages_b, fixed_log_probs_b, lr_mult, args.learning_rate, args.clip_epsilon, args.l2_reg)
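
# Not part of the original examples: a minimal sketch of the clipped-surrogate update that a
# ppo_step(policy_net, value_net, optimizer_policy, optimizer_value, optim_value_iternum, states,
# actions, returns, advantages, fixed_log_probs, clip_epsilon, l2_reg) helper with this call
# signature typically performs; the real helper used above may additionally handle learning-rate
# multipliers, gradient clipping, auxiliary inputs, or diagnostics. The _sketch suffix marks it
# as illustrative only.
import torch

def ppo_step_sketch(policy_net, value_net, optimizer_policy, optimizer_value, optim_value_iternum,
                    states, actions, returns, advantages, fixed_log_probs, clip_epsilon, l2_reg):
    # regress the value function towards the empirical returns, with L2 regularization
    for _ in range(optim_value_iternum):
        values_pred = value_net(states)
        value_loss = (values_pred - returns).pow(2).mean()
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg
        optimizer_value.zero_grad()
        value_loss.backward()
        optimizer_value.step()

    # clipped PPO surrogate built from the cached behaviour-policy log-probabilities
    log_probs = policy_net.get_log_prob(states, actions)
    ratio = torch.exp(log_probs - fixed_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()
    optimizer_policy.zero_grad()
    policy_loss.backward()
    optimizer_policy.step()
    return policy_loss.item(), value_loss.item()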
Example #6
def update_params(batch):
    states = Tensor(batch.state)
    actions = ActionTensor(batch.action)
    rewards = Tensor(batch.reward)
    masks = Tensor(batch.mask)
    values = value_net(Variable(states, volatile=True)).data
    """get advantage estimation from the trajectories"""
    advantages, returns = estimate_advantages(rewards, masks, values,
                                              args.gamma, args.tau, Tensor)
    """perform TRPO update"""
    a2c_step(policy_net, value_net, optimizer_policy, optimizer_value, states,
             actions, returns, advantages, args.l2_reg)
Example #7
def update_params(batch):
    states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device)
    actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device)
    rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device)
    masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device)
    with torch.no_grad():
        values = value_net(states)
    """get advantage estimation from the trajectories"""
    advantages, returns = estimate_advantages(rewards, masks, values,
                                              args.gamma, args.tau, device)
    """perform TRPO update"""
    a2c_step(policy_net, value_net, optimizer_policy, optimizer_value, states,
             actions, returns, advantages, args.l2_reg)
Example #8
def update_params_trpo(batch):
    # (3)
    states = torch.from_numpy(np.stack(batch.state)).to(args.dtype).to(device)
    actions = torch.from_numpy(np.stack(batch.action)).to(args.dtype).to(device)
    rewards = torch.from_numpy(np.stack(batch.reward)).to(args.dtype).to(device)
    masks = torch.from_numpy(np.stack(batch.mask)).to(args.dtype).to(device)
    with torch.no_grad():
        values = value_net(states)  # estimate value function of each state with NN

    """get advantage estimation from the trajectories"""
    advantages, returns = estimate_advantages(rewards, masks, values, args.gamma, args.tau, device)

    """perform TRPO update"""
    trpo_step(policy_net, value_net, states, actions, returns, advantages, args.max_kl_trpo, args.damping, args.l2_reg)
Example #9
def update_params(batch):
    states = torch.from_numpy(np.stack(batch.state))
    actions = torch.from_numpy(np.stack(batch.action))
    rewards = torch.from_numpy(np.stack(batch.reward))
    masks = torch.from_numpy(np.stack(batch.mask).astype(np.float64))
    if use_gpu:
        states, actions, rewards, masks = states.cuda(), actions.cuda(), rewards.cuda(), masks.cuda()
    values = value_net(Variable(states, volatile=True)).data

    """get advantage estimation from the trajectories"""
    advantages, returns = estimate_advantages(rewards, masks, values, args.gamma, args.tau, use_gpu)

    """perform TRPO update"""
    trpo_step(policy_net, value_net, states, actions, returns, advantages, args.max_kl, args.damping, args.l2_reg)
Example #10
def update_params(batch, i_iter):
    states = torch.from_numpy(np.stack(batch.state))
    actions = torch.from_numpy(np.stack(batch.action))
    rewards = torch.from_numpy(np.stack(batch.reward))
    masks = torch.from_numpy(np.stack(batch.mask).astype(np.float64))
    if use_gpu:
        states, actions, rewards, masks = \
            states.cuda(), actions.cuda(), rewards.cuda(), masks.cuda()
    values = value_net(Variable(states, volatile=True)).data
    fixed_log_probs = policy_net.get_log_prob(Variable(states, volatile=True),
                                              Variable(actions)).data
    """get advantage estimation from the trajectories"""
    advantages, returns = estimate_advantages(rewards, masks, values,
                                              args.gamma, args.tau, use_gpu)

    lr_mult = max(1.0 - float(i_iter) / args.max_iter_num, 0)
    """update discriminator"""
    for _ in range(3):
        expert_state_actions = Tensor(expert_traj)
        if use_gpu:
            expert_state_actions = expert_state_actions.cuda()
        g_o = discrim_net(Variable(torch.cat([states, actions], 1)))
        e_o = discrim_net(Variable(expert_state_actions))
        optimizer_discrim.zero_grad()
        discrim_loss = discrim_criterion(g_o, Variable(ones((states.shape[0], 1)))) + \
            discrim_criterion(e_o, Variable(zeros((expert_traj.shape[0], 1))))
        discrim_loss.backward()
        optimizer_discrim.step()
    """perform mini-batch PPO update"""
    optim_iter_num = int(math.ceil(states.shape[0] / optim_batch_size))
    for _ in range(optim_epochs):
        perm = np.arange(states.shape[0])
        np.random.shuffle(perm)
        perm = LongTensor(perm)
        if use_gpu:
            perm = perm.cuda()
        states, actions, returns, advantages, fixed_log_probs = \
            states[perm], actions[perm], returns[perm], advantages[perm], fixed_log_probs[perm]

        for i in range(optim_iter_num):
            ind = slice(i * optim_batch_size,
                        min((i + 1) * optim_batch_size, states.shape[0]))
            states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]

            ppo_step(policy_net, value_net, optimizer_policy, optimizer_value,
                     1, states_b, actions_b, returns_b, advantages_b,
                     fixed_log_probs_b, lr_mult, args.learning_rate,
                     args.clip_epsilon, args.l2_reg)
Example #11
def update_params(batch):
    states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device)
    actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device)
    rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device)
    masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device)
    with torch.no_grad():
        values = value_net(states)

    #get advantage estimation from the trajectories
    advantages, returns = estimate_advantages(rewards, masks, values,
                                              args.gamma, args.tau, device)

    #perform TRPO update
    trpo_step(policy_net, value_net, states, actions, returns, advantages,
              args.max_kl, args.damping, args.l2_reg)
Example #12
def update_params(batch, i_iter):
    states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device)
    actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device)
    rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device)
    masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device)
    with torch.no_grad():
        values = value_net(states)
        fixed_log_probs, act_mean, act_std = policy_net.get_log_prob(
            states, actions)
    """get advantage estimation from the trajectories"""
    advantages, returns = estimate_advantages(rewards, masks, values,
                                              args.gamma, args.tau, device)
    """update discriminator"""
    for _ in range(1):
        expert_state_actions = torch.from_numpy(expert_traj).to(dtype).to(
            device)
        g_o = discrim_net(torch.cat([states, actions], 1))
        e_o = discrim_net(expert_state_actions)
        optimizer_discrim.zero_grad()
        discrim_loss = discrim_criterion(g_o, ones((states.shape[0], 1), device=device)) + \
            discrim_criterion(e_o, zeros((expert_traj.shape[0], 1), device=device))
        discrim_loss.backward()
        torch.nn.utils.clip_grad_norm_(discrim_net.parameters(), 0.5)

        optimizer_discrim.step()
    """perform mini-batch PPO update"""
    optim_iter_num = int(math.ceil(states.shape[0] / optim_batch_size))
    for _ in range(optim_epochs):
        perm = np.arange(states.shape[0])
        np.random.shuffle(perm)
        perm = LongTensor(perm).to(device)

        states, actions, returns, advantages, fixed_log_probs = \
            states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), fixed_log_probs[perm].clone()

        for i in range(optim_iter_num):
            ind = slice(i * optim_batch_size,
                        min((i + 1) * optim_batch_size, states.shape[0]))
            states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]

            policy_surr, value_loss, ev, clip_frac, entropy, approxkl = ppo_step(
                policy_net, value_net, optimizer_policy, optimizer_value, 1,
                states_b, actions_b, returns_b, advantages_b,
                fixed_log_probs_b, args.clip_epsilon, args.l2_reg)

    return discrim_loss.item(), policy_surr, value_loss, ev, clip_frac, entropy, approxkl
Example #13
def calculate_returns(batch,expert_data):
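    # tau (the GAE lambda) is fixed to 1 below, so the advantage estimate effectively reduces to
    # a Monte-Carlo return minus the value baseline (assuming estimate_advantages implements GAE)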
    tau = 1
    states = torch.from_numpy(np.stack(batch.state))
    actions = torch.from_numpy(np.stack(batch.action))
    rewards = torch.from_numpy(np.stack(batch.reward))
    next_states = torch.from_numpy(np.stack(batch.next_state))
    masks = torch.from_numpy(np.stack(batch.mask).astype(np.float64))
    if use_gpu:
        states, actions, rewards, masks = states.cuda(), actions.cuda(), rewards.cuda(), masks.cuda()
    values = value_net(Variable(states)).data
    
    perm = np.arange(states.shape[0])
    advantages, returns = estimate_advantages(rewards, masks, values, args.gamma, tau, use_gpu)
    num = 0
    for i in perm:
        expert_data.push(states[i], actions[i], masks[i], next_states[i], returns[i])
        num += 1
Example #14
def update_params(batch):

    states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device)
    actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device)
    rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device)
    masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device)

    with torch.no_grad():
        values = value_net(states)

    advantages, returns = estimate_advantages(rewards, masks, values,
                                              exp_args["config"]["gamma"],
                                              exp_args["config"]["tau"],
                                              device)

    a2c_step(policy_net, value_net, optimizer_policy, optimizer_value, states,
             actions, returns, advantages, exp_args["config"]["l2-reg"])
Example #15
def update_params_trpo(batch, i_iter):
    # (3)
    value_np.training = False
    num_transition = 0
    idx = []
    for b in batch:
        l = len(b)
        idx.append(l + num_transition - 1)
        num_transition += l
    masks = ones(num_transition)
    masks[idx] = 0
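    # masks is 1 for intermediate transitions and 0 at the final transition of each episode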
    states = zeros(1, num_transition, state_dim)
    actions = zeros(1, num_transition, action_dim)
    disc_rewards = zeros(1, num_transition, 1)
    rewards = zeros(1, num_transition, 1)
    i = 0
    for e, ep in enumerate(batch):
        for t, tr in enumerate(ep):
            states[:, i, :] = torch.from_numpy(tr.state).to(dtype).to(device)
            actions[:, i, :] = torch.from_numpy(tr.action).to(dtype).to(device)
            disc_rewards[:, i, :] = torch.tensor(tr.disc_rew).to(dtype).to(device)
            rewards[:, i, :] = torch.tensor(tr.reward).to(dtype).to(device)
            i += 1

    with torch.no_grad():
        values_distr = value_np(
            states, disc_rewards,
            states)  # estimate value function of each state with NN
        values = values_distr.mean
    """get advantage estimation from the trajectories"""
    advantages, returns = estimate_advantages(rewards.squeeze_(0), masks,
                                              values.squeeze_(0), args.gamma,
                                              args.tau, device)
    """plot"""
    plot_values(value_np, value_net, states, values, advantages, disc_rewards,
                env, args, i_iter)
    """perform TRPO update"""
    trpo_step(policy_net, value_net, states.squeeze(0), actions.squeeze(0),
              returns.squeeze(0), advantages, args.max_kl, args.damping,
              args.l2_reg)
Example #16
def update_params(batch, i_iter):
    states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device)
    if args.peb:
        next_states = torch.from_numpy(np.stack(
            batch.next_state)).to(dtype).to(device)
    actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device)
    rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device)
    masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device)
    if hasattr(batch, 'aux_state'):
        aux_states = torch.from_numpy(np.stack(
            batch.aux_state)).to(dtype).to(device)
        if args.peb:
            aux_next_states = torch.from_numpy(np.stack(
                batch.aux_next_state)).to(dtype).to(device)
    else:
        aux_states = None
    if hasattr(batch, 'expert_mask'):
        expert_masks_np = np.array(batch.expert_mask)
        expert_masks = torch.from_numpy(np.stack(
            batch.expert_mask)).to(dtype).to(device)
    else:
        expert_masks_np = None
        expert_masks = None
    with torch.no_grad():
        if aux_states is not None:
            values = value_net(states, aux_states)
            fixed_log_probs = policy_net.get_log_prob(
                states, actions, aux_states)  # sums log probs for all actions
        else:
            values = value_net(states)
            fixed_log_probs = policy_net.get_log_prob(states, actions)

            # TODO if training of value_net is fixed to have normalized outputs, then they need to be correspondingly
            # unnormalized in estimated_advantages to match the statistics of the returns in the new data
    """partial episode bootstrapping for fixing biased low return on env timeout"""
    if args.peb:
        with torch.no_grad():
            if aux_states is not None:
                terminal_ns_values = value_net(next_states[masks == 0],
                                               aux_next_states[masks == 0])
            else:
                terminal_ns_values = value_net(next_states[masks == 0])
    else:
        terminal_ns_values = None
    """modify buffers if expert intervention occurred"""
    if np.any(expert_masks_np):
        # states & actions marked with a 1 are new expert data
        expert_data['states'] = torch.cat((expert_data['states'], states))
        expert_data['actions'] = torch.cat((expert_data['actions'], actions))

        # masks need to be last autonomous state before every correction
        masks = masks.fill_(1)

        new_e_labels = torch.zeros(states.shape[0])
        next_non_e_ind = np.argmax(1. - expert_masks_np)
        next_e_ind = np.argmax(expert_masks_np)
        last_loop = False
        while True:
            if next_non_e_ind < next_e_ind:
                first_label = 1 / (next_e_ind - next_non_e_ind)
                new_e_labels[next_non_e_ind:next_e_ind] = torch.linspace(
                    float(first_label), 1., int(next_e_ind - next_non_e_ind))
                masks[next_e_ind - 1] = 0
                if last_loop:
                    break
                next_non_e_ind = np.argmax(
                    (1. - expert_masks_np)[next_e_ind:]) + next_e_ind
                if not np.any((1 - expert_masks_np)[next_e_ind:]):
                    next_non_e_ind = expert_masks_np.shape[0]
                    break
            else:
                # new_e_labels[next_e_ind:next_non_e_ind] = torch.zeros(next_non_e_ind - next_e_ind)
                if last_loop:
                    break
                next_e_ind = np.argmax(
                    expert_masks_np[next_non_e_ind:]) + next_non_e_ind
                if not np.any(expert_masks_np[next_non_e_ind:]):
                    next_e_ind = expert_masks_np.shape[0]
                    last_loop = True

            # TODO consider adding new reward values for the ppo update with a similar scheme as the new
            # discriminator labels, but it would have to be -math.log(new label), and they would have to be
            # somehow comparable in scale to what the rewards already being output were
            # (to not mess up value estimator)

        expert_data['labels'] = torch.cat(
            (expert_data['labels'],
             new_e_labels.unsqueeze(1).to(dtype).to(device)))

        new_expert_states = states[expert_masks_np.nonzero()]
        states = states[(1 - expert_masks_np).nonzero()]
        if aux_states is not None:
            new_expert_aux_states = aux_states[expert_masks_np.nonzero()]
            aux_states = aux_states[(1 - expert_masks_np).nonzero()]
        else:
            new_expert_aux_states = None
        new_expert_actions = actions[expert_masks_np.nonzero()]
        actions = actions[(1 - expert_masks_np).nonzero()]
    """get advantage estimation from the trajectories"""
    advantages, returns = estimate_advantages(rewards, masks, values,
                                              args.gamma, args.tau, device,
                                              terminal_ns_values)
    """update discriminator"""
    for _ in range(10):
        if is_img_state:
            g_o = discrim_net(states, actions, aux_states)
            e_o = discrim_net(expert_data['states'], expert_data['actions'],
                              expert_data['aux'])
        else:
            expert_state_actions = torch.from_numpy(expert_traj).to(dtype).to(
                device)
            g_o = discrim_net(torch.cat([states, actions], 1))
            e_o = discrim_net(expert_state_actions)
        optimizer_discrim.zero_grad()
        if args.intervention_device is not None:
            discrim_loss = discrim_criterion(g_o, ones((states.shape[0], 1), device=device)) + \
                discrim_criterion(e_o, expert_data['labels'])
        else:
            discrim_loss = discrim_criterion(g_o, ones((states.shape[0], 1), device=device)) + \
            discrim_criterion(e_o, zeros((num_expert_data, 1), device=device))
        discrim_loss.backward()
        optimizer_discrim.step()
    """if removing episode-termination bias, remove last set of states before terminal"""
    # if args.peb and i_iter <= args.first_ppo_iter:
    if args.peb:
        max_ret_dip = .001  # as a percent
        num_to_remove = math.ceil(
            np.log(max_ret_dip) / np.log(args.gamma * args.tau))
        terminal_inds = (masks == 0).nonzero()
        ep_first_ind = 0
        states_f, actions_f, returns_f, advantages_f, fixed_log_probs_f = [], [], [], [], []
        if aux_states is not None:
            aux_states_f = []
        for terminal_ind in terminal_inds:
            ep_last_ind = max(terminal_ind + 1 - num_to_remove, ep_first_ind)
            states_f.append(states[ep_first_ind:ep_last_ind])
            actions_f.append(actions[ep_first_ind:ep_last_ind])
            returns_f.append(returns[ep_first_ind:ep_last_ind])
            advantages_f.append(advantages[ep_first_ind:ep_last_ind])
            fixed_log_probs_f.append(fixed_log_probs[ep_first_ind:ep_last_ind])
            if aux_states is not None:
                aux_states_f.append(aux_states[ep_first_ind:ep_last_ind])

            ep_first_ind = terminal_ind + 1

        states, actions, returns, advantages, fixed_log_probs = \
            torch.cat(states_f), torch.cat(actions_f), torch.cat(returns_f), torch.cat(advantages_f), \
            torch.cat(fixed_log_probs_f)
        if aux_states is not None:
            aux_states = torch.cat(aux_states_f)

        # renormalize advantages
        advantages = (advantages - advantages.mean()) / advantages.std()
    """perform mini-batch PPO update"""
    if i_iter >= args.first_ppo_iter:
        original_policy_net = copy.deepcopy(policy_net)
        optim_iter_num = int(math.ceil(states.shape[0] / optim_batch_size))
        for _ in range(optim_epochs):
            perm = np.arange(states.shape[0])
            np.random.shuffle(perm)
            perm = LongTensor(perm).to(device)

            states, actions, returns, advantages, fixed_log_probs = \
                states[perm].clone(), actions[perm].clone(), returns[perm].clone(), \
                advantages[perm].clone(), fixed_log_probs[perm].clone()
            if aux_states is not None:
                aux_states = aux_states[perm].clone()

            # TODO also need to divide new expert data into minibatches AND ensure that whichever
            # of expert vs non expert is bigger will control the number of minibatches

            for i in range(optim_iter_num):
                ind = slice(i * optim_batch_size,
                            min((i + 1) * optim_batch_size, states.shape[0]))
                states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                    states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]
                if aux_states is not None:
                    aux_states_b = aux_states[ind]
                else:
                    aux_states_b = None

                ppo_step(policy_net, value_net, optimizer_policy,
                         optimizer_value, 1, states_b, actions_b, returns_b,
                         advantages_b, fixed_log_probs_b, args.clip_epsilon,
                         args.l2_reg, aux_states_b)

                kl = policy_net.get_kl_comp(states, aux_states,
                                            original_policy_net)
                print("kl div: %f" % kl.mean())
                if kl.mean() > .05:
                    break

            if kl.mean() > .05:
                break
Example #17
def update_params(batch, i_iter, wi_list, partitioner):
    states = torch.from_numpy(np.stack(batch.state))
    actions = torch.from_numpy(np.stack(batch.action))
    rewards = torch.from_numpy(np.stack(batch.reward))
    masks = torch.from_numpy(np.stack(batch.mask).astype(np.float64))
    if use_gpu:
        states, actions, rewards, masks = \
            states.cuda(), actions.cuda(), rewards.cuda(), masks.cuda()
    # remove volatile
    # values = value_net(Variable(states, volatile=True)).data
    with torch.no_grad():
        values = value_net(states)

    if i_iter % 10 == 0:
        with torch.no_grad():
            advantage_net(states, actions, verbose=True)
    #advantage = advantages_symbol.data
    # remove volatile
    # fixed_log_probs = policy_net.get_log_prob(Variable(states, volatile=True), Variable(actions), wi_list).data
    with torch.no_grad():
        fixed_log_probs = policy_net.get_log_prob(states, actions, wi_list)
    """get advantage estimation from the trajectories"""
    advantages, returns, advantages_unbiased = estimate_advantages(
        rewards, masks, values, args.gamma, args.tau, use_gpu)

    lr_mult = max(1.0 - float(i_iter) / args.max_iter_num, 0)

    list_H = []
    """perform mini-batch PPO update"""
    optim_iter_num = int(math.ceil(states.shape[0] / optim_batch_size))
    for _ in range(optim_epochs):
        perm = np.arange(states.shape[0])
        np.random.shuffle(perm)
        perm = LongTensor(perm).cuda() if use_gpu else LongTensor(perm)

        states, actions, returns, advantages, fixed_log_probs = \
            states[perm], actions[perm], returns[perm], advantages[perm], fixed_log_probs[perm]

        for i in range(optim_iter_num):
            ind = slice(i * optim_batch_size,
                        min((i + 1) * optim_batch_size, states.shape[0]))
            states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]
            H = ppo_step(policy_net, value_net, advantage_net,
                         optimizer_policy, optimizer_value,
                         optimizer_advantage, 1, states_b, actions_b,
                         returns_b, advantages_b, fixed_log_probs_b, lr_mult,
                         args.learning_rate, args.clip_epsilon, args.l2_reg,
                         wi_list, args.method)

            if args.method in ['posa', 'posa-mlp']:
                list_H.append(H.unsqueeze(0))
    if args.method in ['posa', 'posa-mlp']:
        H_hat = torch.cat(list_H,
                          dim=0).mean(dim=0).cpu().numpy().astype(np.float64)
    elif args.method in ['combinatorial']:
        action_prime = policy_net.select_action(states)
        H = combinatorial(advantage_net, states, actions, action_prime)
        H_hat = H.astype(np.float64)
    elif args.method in ['submodular']:
        action_prime = policy_net.select_action(states)
        wi = submodular(advantage_net, states, actions, action_prime)
        wi_list = [wi, 1 - wi]
        H_hat = None
    if args.method in ['posa', 'posa-mlp', 'combinatorial']:
        partition = partitioner.step(H_hat)
        wi_list = get_wi(partition)
    #with open('logh{0}'.format(logger_name), 'a') as fa:
    #    fa.write(str(H_hat) + '\n')
    #put a log and see how many clusters are connected
    #import pdb; pdb.set_trace()
    print(wi_list[0])
    return wi_list, H_hat
Example #18
def update_params(batch, i_iter, opt):
    """update discriminator"""
    reirl_weights.write(
        reirl(expert_traj[:, :-action_dim], np.stack(batch.state), opt))
    value_net = Value(state_dim)
    optimizer_value = torch.optim.Adam(value_net.parameters(),
                                       lr=args.learning_rate)
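    # note: a fresh value network and optimizer are created on every call, so the critic is
    # re-fitted from scratch each iteration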
    if i_iter > 0:
        j_max = 3  #if i_iter < 20 else 15
        for j in range(j_max):  #3):
            batch, log = ppo_agent.collect_samples(3000)
            print('{}\tT_sample {}\texpert_R_avg {}\tR_avg {}'.format(
                j, log['sample_time'], log['avg_c_reward'], log['avg_reward']))
            states = torch.from_numpy(np.stack(
                batch.state)).to(dtype).to(device)
            player_actions = torch.from_numpy(np.stack(
                batch.player_action)).to(dtype).to(device)
            opponent_actions = torch.from_numpy(np.stack(
                batch.opponent_action)).to(dtype).to(device)
            rewards = torch.from_numpy(np.stack(
                batch.reward)).to(dtype).to(device)
            masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device)
            with torch.no_grad():
                values = value_net(states)
                fixed_log_probs = policy_net.get_log_prob(
                    states, player_actions)
                opponent_fixed_log_probs = opponent_net.get_log_prob(
                    states, opponent_actions)
            """get advantage estimation from the trajectories"""
            advantages, returns = estimate_advantages(rewards, masks, values,
                                                      args.gamma, args.tau,
                                                      device)
            """perform mini-batch PPO update"""
            optim_iter_num = int(math.ceil(states.shape[0] / optim_batch_size))
            for _ in range(optim_epochs):
                perm = np.arange(states.shape[0])
                np.random.shuffle(perm)
                perm = LongTensor(perm).to(device)

                states, player_actions, opponent_actions, returns, advantages, fixed_log_probs, opponent_fixed_log_probs = \
                    states[perm].clone(), player_actions[perm].clone(), \
                    opponent_actions[perm].clone(), returns[perm].clone(), \
                    advantages[perm].clone(), \
                    fixed_log_probs[perm].clone(), opponent_fixed_log_probs[
                        perm].clone()

                for i in range(optim_iter_num):
                    ind = slice(
                        i * optim_batch_size,
                        min((i + 1) * optim_batch_size, states.shape[0]))
                    states_b, player_actions_b, opponent_actions_b, advantages_b, returns_b, fixed_log_probs_b, opponent_fixed_log_probs_b = \
                        states[ind], player_actions[ind], opponent_actions[ind], \
                        advantages[ind], returns[ind], fixed_log_probs[ind], \
                        opponent_fixed_log_probs[ind]

                    # Update the player
                    ppo_step(policy_net,
                             value_net,
                             optimizer_policy,
                             optimizer_value,
                             1,
                             states_b,
                             player_actions_b,
                             returns_b,
                             advantages_b,
                             fixed_log_probs_b,
                             args.clip_epsilon,
                             args.l2_reg,
                             max_grad=max_grad)
                    # Update the opponent
                    ppo_step(opponent_net,
                             value_net,
                             optimizer_opponent,
                             optimizer_value,
                             1,
                             states_b,
                             opponent_actions_b,
                             returns_b,
                             advantages_b,
                             opponent_fixed_log_probs_b,
                             args.clip_epsilon,
                             args.l2_reg,
                             opponent=True,
                             max_grad=max_grad)
Example #19
    def update_params(batch, i_iter):
        dataSize = min(args.min_batch_size, len(batch.state))
        states = torch.from_numpy(np.stack(
            batch.state)[:dataSize, ]).to(dtype).to(device)
        actions = torch.from_numpy(np.stack(
            batch.action)[:dataSize, ]).to(dtype).to(device)
        rewards = torch.from_numpy(np.stack(
            batch.reward)[:dataSize, ]).to(dtype).to(device)
        masks = torch.from_numpy(np.stack(
            batch.mask)[:dataSize, ]).to(dtype).to(device)
        with torch.no_grad():
            values = value_net(states)
            fixed_log_probs = policy_net.get_log_prob(states, actions)
        """estimate reward"""
        """get advantage estimation from the trajectories"""
        advantages, returns = estimate_advantages(rewards, masks, values,
                                                  args.gamma, args.tau, device)
        """update discriminator"""
        for _ in range(args.discriminator_epochs):
            #dataSize = states.size()[0]
            # expert_state_actions = torch.from_numpy(expert_traj).to(dtype).to(device)
            exp_idx = random.sample(range(expert_traj.shape[0]), dataSize)
            expert_state_actions = torch.from_numpy(
                expert_traj[exp_idx, :]).to(dtype).to(device)

            dis_input_real = expert_state_actions
            if len(actions.shape) == 1:
                actions.unsqueeze_(-1)
                dis_input_fake = torch.cat([states, actions], 1)
                actions.squeeze_(-1)
            else:
                dis_input_fake = torch.cat([states, actions], 1)

            if args.EBGAN or args.GMMIL or args.GEOMGAN:
                # tbd, no discriminator learning
                pass
            else:
                g_o = discrim_net(dis_input_fake)
                e_o = discrim_net(dis_input_real)

            optimizer_discrim.zero_grad()
            if args.GEOMGAN:
                optimizer_kernel.zero_grad()

            if args.WGAN:
                if args.LSGAN:
                    pdist = l1dist(dis_input_real,
                                   dis_input_fake).mul(args.lamb)
                    discrim_loss = LeakyReLU(e_o - g_o + pdist).mean()
                else:
                    discrim_loss = torch.mean(e_o) - torch.mean(g_o)
            elif args.EBGAN:
                e_recon = elementwise_loss(e_o, dis_input_real)
                g_recon = elementwise_loss(g_o, dis_input_fake)
                discrim_loss = e_recon
                if (args.margin - g_recon).item() > 0:
                    discrim_loss += (args.margin - g_recon)
            elif args.GMMIL:
                #mmd2_D,K = mix_rbf_mmd2(e_o_enc, g_o_enc, args.sigma_list)
                mmd2_D, K = mix_rbf_mmd2(dis_input_real, dis_input_fake,
                                         args.sigma_list)
                #tbd
                #rewards = K[0]+K[1]-2*K[2]
                rewards = K[1] - K[2]  # -(exp - gen): -(kxy-kyy)=kyy-kxy
                rewards = -rewards.detach()  # exp - gen, maximize (gen label negative)
                errD = mmd2_D
                discrim_loss = -errD  # maximize errD

                # prep for generator
                advantages, returns = estimate_advantages(
                    rewards, masks, values, args.gamma, args.tau, device)
            elif args.GEOMGAN:
                # larger, better, but slower
                noise_num = 100
                mmd2_D, K = mix_imp_mmd2(e_o_enc, g_o_enc, noise_num,
                                         noise_dim, kernel_net, cuda)
                rewards = K[1] - K[2]  # -(exp - gen): -(kxy-kyy)=kyy-kxy
                rewards = -rewards.detach()
                errD = mmd2_D  #+ args.lambda_rg * one_side_errD
                discrim_loss = -errD  # maximize errD

                # prep for generator
                advantages, returns = estimate_advantages(
                    rewards, masks, values, args.gamma, args.tau, device)
            else:
                discrim_loss = discrim_criterion(g_o, ones((states.shape[0], 1), device=device)) + \
                               discrim_criterion(e_o, zeros((e_o.shape[0], 1), device=device))
            if not (args.EBGAN or args.GMMIL or args.GEOMGAN):
                # back-propagate and update the discriminator for the GAN/WGAN objectives; the
                # MMD-based variants above do not pass their inputs through discrim_net here
                discrim_loss.backward()
                optimizer_discrim.step()
            if args.GEOMGAN:
                optimizer_kernel.step()
        """perform mini-batch PPO update"""
        optim_iter_num = int(math.ceil(states.shape[0] / args.ppo_batch_size))
        for _ in range(args.generator_epochs):
            perm = np.arange(states.shape[0])
            np.random.shuffle(perm)
            perm = LongTensor(perm).to(device)

            states, actions, returns, advantages, fixed_log_probs = \
                states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), \
                fixed_log_probs[perm].clone()

            for i in range(optim_iter_num):
                ind = slice(
                    i * args.ppo_batch_size,
                    min((i + 1) * args.ppo_batch_size, states.shape[0]))
                states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                    states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]

                ppo_step(policy_net, value_net, optimizer_policy,
                         optimizer_value, 1, states_b, actions_b, returns_b,
                         advantages_b, fixed_log_probs_b, args.clip_epsilon,
                         args.l2_reg)

        return rewards
Example #20
def update_params(batch, i_iter):
    # because each item in the batch is a list we must convert it
    #  accordingly.
    # print(batch.reward)
    # print(np.stack(batch.state, 1)[0])
    # print(len(np.stack(batch.state, 1)[0]))
    # print(batch.state[:2])
    # print(np.stack(batch.state[:2], 1))
    # print(batch.state[:3])
    states = [
        torch.from_numpy(a).to(dtype).to(device)
        for a in np.stack(batch.state, 1)
    ]
    actions = [
        torch.from_numpy(a).to(dtype).to(device)
        for a in np.stack(batch.action, 1)
    ]
    rewards = [
        torch.from_numpy(a).to(dtype).to(device)[:, None]
        for a in np.stack(batch.reward, 1)
    ]
    masks = [
        torch.from_numpy(a).to(dtype).to(device)
        for a in np.stack(batch.mask, 1)
    ]
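    # states/actions/rewards/masks are kept as per-agent lists of tensors (one entry per agent),
    # matching the multi_agent=True PPO update below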
    # print(rewards)
    # print(states)
    # print(states[0].size())
    # print(rewards)
    # states = [torch.from_numpy(np.stack(bs)).to(dtype).to(device) for bs in batch.state]
    # actions = [torch.from_numpy(np.stack(bs)).to(dtype).to(device) for bs in batch.action]
    # rewards = [torch.from_numpy(np.stack(bs)).to(dtype).to(device) for bs in batch.reward]
    # masks = [torch.from_numpy(np.stack(bs)).to(dtype).to(device) for bs in batch.mask]
    # print(rewards)

    # states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device)
    # actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device)
    # rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device)
    # masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device)
    with torch.no_grad():
        values = critic_net(states)
        # print(values)
        # print(states)
        # print(len(states))
        # print(states[0].size())
        # print(actions)
        # print(len(actions))
        # print(actions[0].size())
        fixed_log_probs = actor_net.get_log_prob(states, actions)
    """get advantage estimation from the trajectories"""
    adv_returns_tuple = \
        [estimate_advantages(r, m, v, gamma, tau, device)
        for r,m,v in zip(rewards, masks, values)]
    advantages = []
    returns = []
    for a, r in adv_returns_tuple:
        advantages.append(a)
        returns.append(r)
    # print(advantages)

    # advantages, returns = estimate_advantages(rewards, masks, values, args.gamma, args.tau, device)
    """perform mini-batch PPO update"""
    optim_iter_num = int(math.ceil(states[0].shape[0] / optim_batch_size))
    for _ in range(optim_epochs):
        perm = np.arange(states[0].shape[0])
        np.random.shuffle(perm)
        perm = LongTensor(perm).to(device)

        states = [st[perm].clone() for st in states]
        actions = [st[perm].clone() for st in actions]
        returns = [st[perm].clone() for st in returns]
        advantages = [st[perm].clone() for st in advantages]
        fixed_log_probs = [st[perm].clone() for st in fixed_log_probs]

        # states, actions, returns, advantages, fixed_log_probs = \
        #     states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), fixed_log_probs[perm].clone()

        for i in range(optim_iter_num):
            ind = slice(i * optim_batch_size,
                        min((i + 1) * optim_batch_size, states[0].shape[0]))

            states_b = [st[ind] for st in states]
            actions_b = [st[ind] for st in actions]
            returns_b = [st[ind] for st in returns]
            advantages_b = [st[ind] for st in advantages]
            fixed_log_probs_b = [st[ind] for st in fixed_log_probs]

            # states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
            #     states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]

            ppo_step(actor_net,
                     critic_net,
                     optimiser_actor,
                     optimiser_critic,
                     1,
                     states_b,
                     actions_b,
                     returns_b,
                     advantages_b,
                     fixed_log_probs_b,
                     clip_epsilon,
                     l2_reg,
                     multi_agent=True,
                     num_agents=num_agents)
Example #21
def update_params(batch):
    for i in range(len(policy_net)):
        policy_net[i].train()
        value_net[i].train()

    # states = torch.from_numpy(np.stack(batch.state))
    # actions = torch.from_numpy(np.stack(batch.action))
    # rewards = torch.from_numpy(np.stack(batch.reward))
    # masks = torch.from_numpy(np.stack(batch.mask).astype(np.float64))
    states = to_tensor_var(batch.state, True, "double").view(-1, agent.n_agents, agent.obs_shape_n[0]).data
    actions = to_tensor_var(batch.action, True, "long").view(-1, agent.n_agents, 1).data
    rewards = to_tensor_var(batch.reward, True, "double").view(-1, agent.n_agents, 1).data
    masks = to_tensor_var(batch.mask, True, "double").view(-1, agent.n_agents, 1).data

    whole_states_var = states.view(-1, agent.whole_critic_state_dim)
    whole_actions_var = actions.view(-1, agent.whole_critic_action_dim)

    # print( whole_states_var, whole_actions_var )



    if use_gpu:
        states, actions, rewards, masks = states.cuda(), actions.cuda(), rewards.cuda(), masks.cuda()
        whole_states_var, whole_actions_var = whole_states_var.cuda(), whole_actions_var.cuda()
    # values = value_net(Variable(whole_states_var, volatile=True)).data
    values = []
    for i in range(len(value_net)):
        # values.append(value_net[i](th.Tensor(whole_states_var)).data)
        # input = Variable(whole_states_var, volatile=True)
        values.append(value_net[i](Variable(whole_states_var)))

    # print(rewards, masks, values)
    # values = to_tensor_var(values,True,"double").view(-1, agent.n_agents, 1).data

    # Transpose!
    values_tmp = [[r[col] for r in values] for col in range(len(values[0]))]
    values = to_tensor_var(values_tmp, True, "double").view(-1, agent.n_agents, 1).data.cuda()

    """get advantage estimation from the trajectories"""
    # advantages, returns = estimate_advantages(rewards, masks, values, args.gamma, args.tau, use_gpu)
    advantages, returns = [], []
    for i in range(len(value_net)):
        adv, ret = estimate_advantages(rewards[:,i,:], masks[:,i,:], values[:,i,:], args.gamma, args.tau, use_gpu)
        advantages.append(adv)
        returns.append(ret)
    #print(advantages, returns)

    # Transpose!
    advantages = [[r[col] for r in advantages] for col in range(len(advantages[0]))]
    advantages = to_tensor_var(advantages, True, "double").view(-1, agent.n_agents, 1).data.cuda()

    # Transpose!
    returns = [[r[col] for r in returns] for col in range(len(returns[0]))]
    returns = to_tensor_var(returns, True, "double").view(-1, agent.n_agents, 1).data.cuda()

    # # combine n agent's related advantages together
    # tmp_ary = np.empty_like(advantages[0])
    # for i in range(len(advantages)):
    #     tmp_ary = np.hstack((tmp_ary, advantages[i]))
    # advantages = tmp_ary[:,1:len(value_net)+1]

    # tmp_ary = np.empty_like(returns[0])
    # for i in range(len(returns)):
    #     tmp_ary = np.hstack((tmp_ary, returns[i]))
    # returns = tmp_ary[:,1:len(value_net)+1]

    # advantages = to_tensor_var(advantages, True, "double").view(-1, agent.n_agents, 1).data.cuda()
    # returns = to_tensor_var(returns, True, "double").view(-1, agent.n_agents, 1).data.cuda()

    """perform TRPO update"""
    for i in range(len(value_net)):
        # a2c_step(policy_net[i], value_net[i], optimizer_policy[i], optimizer_value[i], states[:,i,:], actions[:,i,:], returns[:,i,:], advantages[:,i,:], args.l2_reg)
        a2c_step(policy_net[i], value_net[i], optimizer_policy[i], optimizer_value[i], states,
                 actions, returns[:,i,:], advantages[:,i,:], args.l2_reg, i)
    def update_params(batch, i_iter):
        dataSize = min(args.min_batch_size, len(batch.state))
        states = torch.from_numpy(np.stack(
            batch.state)[:dataSize, ]).to(dtype).to(device)
        actions = torch.from_numpy(np.stack(
            batch.action)[:dataSize, ]).to(dtype).to(device)
        rewards = torch.from_numpy(np.stack(
            batch.reward)[:dataSize, ]).to(dtype).to(device)
        masks = torch.from_numpy(np.stack(
            batch.mask)[:dataSize, ]).to(dtype).to(device)
        with torch.no_grad():
            values = value_net(states)
            fixed_log_probs = policy_net.get_log_prob(states, actions)
        """estimate reward"""
        """get advantage estimation from the trajectories"""
        advantages, returns = estimate_advantages(rewards, masks, values,
                                                  args.gamma, args.tau, device)
        """update discriminator"""
        for _ in range(args.discriminator_epochs):
            exp_idx = random.sample(range(expert_traj.shape[0]), dataSize)
            expert_state_actions = torch.from_numpy(
                expert_traj[exp_idx, :]).to(dtype).to(device)

            dis_input_real = expert_state_actions
            if len(actions.shape) == 1:
                actions.unsqueeze_(-1)
                dis_input_fake = torch.cat([states, actions], 1)
                actions.squeeze_(-1)
            else:
                dis_input_fake = torch.cat([states, actions], 1)

            if args.EBGAN or args.GMMIL or args.VAKLIL:
                g_o_enc, g_mu, g_sigma = discrim_net(dis_input_fake,
                                                     mean_mode=False)
                e_o_enc, e_mu, e_sigma = discrim_net(dis_input_real,
                                                     mean_mode=False)
            else:
                g_o = discrim_net(dis_input_fake)
                e_o = discrim_net(dis_input_real)

            optimizer_discrim.zero_grad()
            if args.VAKLIL:
                optimizer_kernel.zero_grad()

            if args.AL:
                if args.LSGAN:
                    pdist = l1dist(dis_input_real,
                                   dis_input_fake).mul(args.lamb)
                    discrim_loss = LeakyReLU(e_o - g_o + pdist).mean()
                else:
                    discrim_loss = torch.mean(e_o) - torch.mean(g_o)
            elif args.EBGAN:
                e_recon = elementwise_loss(e_o, dis_input_real)
                g_recon = elementwise_loss(g_o, dis_input_fake)
                discrim_loss = e_recon
                if (args.margin - g_recon).item() > 0:
                    discrim_loss += (args.margin - g_recon)
            elif args.GMMIL:
                mmd2_D, K = mix_rbf_mmd2(e_o_enc, g_o_enc, args.sigma_list)
                rewards = K[1] - K[2]  # -(exp - gen): -(kxy-kyy)=kyy-kxy
                rewards = -rewards.detach()  # exp - gen, maximize (gen label negative)
                errD = mmd2_D
                discrim_loss = -errD  # maximize errD

                advantages, returns = estimate_advantages(
                    rewards, masks, values, args.gamma, args.tau, device)
            elif args.VAKLIL:
                noise_num = 20000
                mmd2_D_net, _, penalty = mix_imp_with_bw_mmd2(
                    e_o_enc, g_o_enc, noise_num, noise_dim, kernel_net, cuda,
                    args.sigma_list)
                mmd2_D_rbf, _ = mix_rbf_mmd2(e_o_enc, g_o_enc, args.sigma_list)
                errD = (mmd2_D_net + mmd2_D_rbf) / 2
                # 1e-8: small number for numerical stability
                i_c = 0.2
                bottleneck_loss = torch.mean(
                    0.5 * torch.sum(
                        torch.cat((e_mu, g_mu), dim=0) ** 2
                        + torch.cat((e_sigma, g_sigma), dim=0) ** 2
                        - torch.log(torch.cat((e_sigma, g_sigma), dim=0) ** 2 + 1e-8) - 1,
                        dim=1)) - i_c
                discrim_loss = -errD + (args.beta * bottleneck_loss) + (
                    args.lambda_h * penalty)
            else:
                discrim_loss = discrim_criterion(g_o, ones((states.shape[0], 1), device=device)) + \
                               discrim_criterion(e_o, zeros((e_o.shape[0], 1), device=device))

            discrim_loss.backward()
            optimizer_discrim.step()
            if args.VAKLIL:
                optimizer_kernel.step()

        if args.VAKLIL:
            with torch.no_grad():
                noise_num = 20000
                g_o_enc, _, _ = discrim_net(dis_input_fake)
                e_o_enc, _, _ = discrim_net(dis_input_real)
                _, K_net, _ = mix_imp_with_bw_mmd2(e_o_enc, g_o_enc, noise_num,
                                                   noise_dim, kernel_net, cuda,
                                                   args.sigma_list)
                _, K_rbf = mix_rbf_mmd2(e_o_enc, g_o_enc, args.sigma_list)
                K = [sum(x) / 2 for x in zip(K_net, K_rbf)]
                rewards = K[1] - K[2]  # -(exp - gen): -(kxy-kyy)=kyy-kxy
                rewards = -rewards  #.detach()
                advantages, returns = estimate_advantages(
                    rewards, masks, values, args.gamma, args.tau, device)
        """perform mini-batch PPO update"""
        optim_iter_num = int(math.ceil(states.shape[0] / args.ppo_batch_size))
        for _ in range(args.generator_epochs):
            perm = np.arange(states.shape[0])
            np.random.shuffle(perm)
            perm = LongTensor(perm).to(device)

            states, actions, returns, advantages, fixed_log_probs = \
                states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), \
                fixed_log_probs[perm].clone()

            for i in range(optim_iter_num):
                ind = slice(
                    i * args.ppo_batch_size,
                    min((i + 1) * args.ppo_batch_size, states.shape[0]))
                states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                    states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]

                ppo_step(policy_net, value_net, optimizer_policy,
                         optimizer_value, 1, states_b, actions_b, returns_b,
                         advantages_b, fixed_log_probs_b, args.clip_epsilon,
                         args.l2_reg)

        return rewards