Example #1
    def __init__(self, args, state_dim, action_dim, is_dict_action):

        self.device = args.device
        self.config = args

        # initialize discriminator
        discriminator_action_dim = 1 if is_dict_action else action_dim
        self.discriminator = Discriminator(state_dim +
                                           discriminator_action_dim).to(
                                               self.device)
        self.discriminator_loss = nn.BCEWithLogitsLoss()
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(), lr=self.config.learning_rate)
        # initialize actor
        self.policy = PPO(args, state_dim, action_dim, is_dict_action)
Example #2
    def __init__(self, state_dim, action_dim, channels, kernel_sizes, strides, paddings=None,
                 head_hidden_size=(128, 128), num_aux=0, activation='relu', use_maxpool=False,
                 resnet_first_layer=False):
        super().__init__(state_dim, 1, channels, kernel_sizes, strides, paddings,
                         activation, use_maxpool, num_aux, resnet_first_layer)

        self.head = Discriminator(self.conv_out_size_for_fc + action_dim + num_aux,
                                  head_hidden_size, activation)
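
The head width above (self.conv_out_size_for_fc + action_dim + num_aux) indicates that the flattened convolutional features are concatenated with the action (and any auxiliary inputs) before the discriminator head scores the pair. Below is a minimal, self-contained sketch of that wiring with made-up layer sizes; the real base class and its forward pass are not shown in this snippet and may differ.

import torch
import torch.nn as nn


class TinyCNNDiscriminator(nn.Module):
    """Illustrative only: conv features are flattened, then concatenated
    with the action vector before an MLP head produces one score per pair."""

    def __init__(self, in_channels, action_dim, conv_out_size_for_fc):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, stride=2), nn.ReLU(),
            nn.Flatten())
        self.head = nn.Sequential(
            nn.Linear(conv_out_size_for_fc + action_dim, 128), nn.ReLU(),
            nn.Linear(128, 1))

    def forward(self, obs, action):
        features = self.conv(obs)  # [batch, conv_out_size_for_fc]
        return self.head(torch.cat([features, action], dim=1))  # [batch, 1]


# 16x16 inputs through one 3x3/stride-2 conv give 8 * 7 * 7 = 392 flattened features
d = TinyCNNDiscriminator(in_channels=3, action_dim=4, conv_out_size_for_fc=392)
print(d(torch.rand(2, 3, 16, 16), torch.rand(2, 4)).shape)  # torch.Size([2, 1])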
Example #3
class TestDiscriminator(TestCase):
    def setUp(self) -> None:
        self.d = Discriminator(6, 5, drop_rate=0.5)
        print(self.d)

    def test_forward(self):
        res = self.d.forward(torch.rand((6, 6)), torch.rand((6, 5)))
        self.assertEqual(res.size(), torch.Size([6, 1]))
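
The test pins down the interface it expects: Discriminator(num_states, num_actions, drop_rate=...) with forward(states, actions) mapping a [6, 6] state batch and a [6, 5] action batch to a [6, 1] score. The following is a minimal stand-in that would pass this test, written for illustration; it is not the project's actual Discriminator.

import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Minimal stand-in matching the interface exercised by the test above."""

    def __init__(self, num_states, num_actions, num_hiddens=(64, 64), drop_rate=0.0):
        super().__init__()
        layers, in_dim = [], num_states + num_actions
        for num_hidden in num_hiddens:
            layers += [nn.Linear(in_dim, num_hidden), nn.ReLU(), nn.Dropout(drop_rate)]
            in_dim = num_hidden
        layers.append(nn.Linear(in_dim, 1))  # one real/fake score per (state, action) pair
        self.net = nn.Sequential(*layers)

    def forward(self, states, actions):
        return self.net(torch.cat([states, actions], dim=-1))  # [batch, 1]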
Example #4
    def _init_model(self):
        self.V = Value(num_states=self.config["value"]["num_states"],
                       num_hiddens=self.config["value"]["num_hiddens"],
                       drop_rate=self.config["value"]["drop_rate"],
                       activation=self.config["value"]["activation"])
        self.P = JointPolicy(
            initial_state=self.expert_dataset.state.to(device),
            config=self.config["jointpolicy"])
        self.D = Discriminator(
            num_states=self.config["discriminator"]["num_states"],
            num_actions=self.config["discriminator"]["num_actions"],
            num_hiddens=self.config["discriminator"]["num_hiddens"],
            drop_rate=self.config["discriminator"]["drop_rate"],
            use_noise=self.config["discriminator"]["use_noise"],
            noise_std=self.config["discriminator"]["noise_std"],
            activation=self.config["discriminator"]["activation"])

        print("Model Structure")
        print(self.P)
        print(self.V)
        print(self.D)
        print()

        self.optimizer_policy = optim.Adam(
            self.P.parameters(),
            lr=self.config["jointpolicy"]["learning_rate"])
        self.optimizer_value = optim.Adam(
            self.V.parameters(), lr=self.config["value"]["learning_rate"])
        self.optimizer_discriminator = optim.Adam(
            self.D.parameters(),
            lr=self.config["discriminator"]["learning_rate"])
        self.scheduler_discriminator = optim.lr_scheduler.StepLR(
            self.optimizer_discriminator, step_size=2000, gamma=0.95)

        self.discriminator_func = nn.BCELoss()

        to_device(self.V, self.P, self.D, self.discriminator_func)
Example #5
torch.manual_seed(args.seed)
if use_gpu:
    torch.cuda.manual_seed_all(args.seed)

env_dummy = env_factory(0)
state_dim = env_dummy.observation_space.shape[0]
is_disc_action = len(env_dummy.action_space.shape) == 0
action_dim = (1 if is_disc_action else env_dummy.action_space.shape[0])
ActionTensor = LongTensor if is_disc_action else DoubleTensor
"""define actor, critic and discrimiator"""
if is_disc_action:
    policy_net = DiscretePolicy(state_dim, env_dummy.action_space.n)
else:
    policy_net = Policy(state_dim, env_dummy.action_space.shape[0])
value_net = Value(state_dim)
discrim_net = Discriminator(state_dim + action_dim)
discrim_criterion = nn.BCELoss()
if use_gpu:
    policy_net = policy_net.cuda()
    value_net = value_net.cuda()
    discrim_net = discrim_net.cuda()
    discrim_criterion = discrim_criterion.cuda()

optimizer_policy = torch.optim.Adam(policy_net.parameters(),
                                    lr=args.learning_rate)
optimizer_value = torch.optim.Adam(value_net.parameters(),
                                   lr=args.learning_rate)
optimizer_discrim = torch.optim.Adam(discrim_net.parameters(),
                                     lr=args.learning_rate)

# optimization epoch number and batch size for PPO
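
For reference, here is a minimal sketch of the discriminator step these objects are typically wired into: generator pairs are labeled 1 and expert pairs 0, which matches the -log(D(s, a)) surrogate reward used in the other examples here. The tensors passed in are placeholders; this is not the original script's update function.

import torch


def update_discriminator(discrim_net, discrim_criterion, optimizer_discrim,
                         states, actions, expert_state_actions):
    """One GAIL discriminator step; all tensors are assumed to share a device."""
    g_o = discrim_net(torch.cat([states, actions], dim=1))  # generator (s, a) pairs
    e_o = discrim_net(expert_state_actions)                 # expert (s, a) pairs
    optimizer_discrim.zero_grad()
    # convention in these examples: generator -> 1, expert -> 0,
    # so the policy is rewarded with -log(D(s, a))
    loss = discrim_criterion(g_o, torch.ones_like(g_o)) + \
           discrim_criterion(e_o, torch.zeros_like(e_o))
    loss.backward()
    optimizer_discrim.step()
    return loss.item()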
Example #6
def create_networks():
    """define actor and critic"""
    if is_disc_action:
        policy_net = DiscretePolicy(state_dim,
                                    env.action_space.n,
                                    hidden_size=(64, 32),
                                    activation='relu')
    else:
        policy_net = Policy(state_dim,
                            env.action_space.shape[0],
                            log_std=args.log_std,
                            hidden_size=(64, 32),
                            activation='relu')
    value_net = Value(state_dim, hidden_size=(32, 16), activation='relu')
    if args.WGAN:
        discrim_net = SNDiscriminator(state_dim + action_dim,
                                      hidden_size=(32, 16),
                                      activation='relu')
    elif args.EBGAN or args.GMMIL:
        discrim_net = AESNDiscriminator(state_dim + action_dim,
                                        hidden_size=(32, ),
                                        encode_size=64,
                                        activation='relu',
                                        slope=0.1,
                                        dropout=False,
                                        dprob=0.2)
    elif args.GEOMGAN:
        # new kernel
        #discrim_net = KernelNet(state_dim + action_dim,state_dim + action_dim)
        noise_dim = 64
        discrim_net = AESNDiscriminator(state_dim + action_dim,
                                        hidden_size=(32, ),
                                        encode_size=noise_dim,
                                        activation='relu',
                                        slope=0.1,
                                        dropout=False,
                                        dprob=0.2)
        kernel_net = NoiseNet(noise_dim,
                              hidden_size=(32, ),
                              encode_size=noise_dim,
                              activation='relu',
                              slope=0.1,
                              dropout=False,
                              dprob=0.2)
        optimizer_kernel = torch.optim.Adam(kernel_net.parameters(),
                                            lr=args.learning_rate / 2)
        scheduler_kernel = MultiStepLR(optimizer_kernel,
                                       milestones=args.milestones,
                                       gamma=args.lr_decay)
    else:
        discrim_net = Discriminator(state_dim + action_dim,
                                    hidden_size=(32, 16),
                                    activation='relu')

    optimizer_policy = torch.optim.Adam(policy_net.parameters(),
                                        lr=args.learning_rate)
    optimizer_value = torch.optim.Adam(value_net.parameters(),
                                       lr=args.learning_rate)
    optimizer_discrim = torch.optim.Adam(discrim_net.parameters(),
                                         lr=args.learning_rate)

    scheduler_policy = MultiStepLR(optimizer_policy,
                                   milestones=args.milestones,
                                   gamma=args.lr_decay)
    scheduler_value = MultiStepLR(optimizer_value,
                                  milestones=args.milestones,
                                  gamma=args.lr_decay)
    scheduler_discrim = MultiStepLR(optimizer_discrim,
                                    milestones=args.milestones,
                                    gamma=args.lr_decay)

    if args.WGAN:

        class ExpertReward():
            def __init__(self):
                self.a = 0

            def expert_reward(self, state, action):
                state_action = tensor(np.hstack([state, action]), dtype=dtype)
                with torch.no_grad():
                    return -discrim_net(state_action)[0].item()
                    # return -discrim_net(state_action).sum().item()

        learned_reward = ExpertReward()
    elif args.EBGAN:

        class ExpertReward():
            def __init__(self):
                self.a = 0

            def expert_reward(self, state, action):
                state_action = tensor(np.hstack([state, action]), dtype=dtype)
                with torch.no_grad():
                    _, recon_out = discrim_net(state_action)
                    return -elementwise_loss(
                        recon_out, state_action).item() + args.r_margin

        learned_reward = ExpertReward()
    elif args.GMMIL or args.GEOMGAN:

        class ExpertReward():
            def __init__(self):
                self.r_bias = 0

            def expert_reward(self, state, action):
                with torch.no_grad():
                    return self.r_bias

            def update_XX_YY(self):
                self.XX = torch.diag(torch.mm(self.e_o_enc, self.e_o_enc.t()))
                self.YY = torch.diag(torch.mm(self.g_o_enc, self.g_o_enc.t()))

        learned_reward = ExpertReward()
    else:

        class ExpertReward():
            def __init__(self):
                self.a = 0

            def expert_reward(self, state, action):
                state_action = tensor(np.hstack([state, action]), dtype=dtype)
                with torch.no_grad():
                    return -math.log(discrim_net(state_action)[0].item())

        learned_reward = ExpertReward()
    """create agent"""
    agent = Agent(env,
                  policy_net,
                  device,
                  custom_reward=learned_reward,
                  running_state=None,
                  render=args.render,
                  num_threads=args.num_threads)

    def update_params(batch, i_iter):
        dataSize = min(args.min_batch_size, len(batch.state))
        states = torch.from_numpy(np.stack(
            batch.state)[:dataSize, ]).to(dtype).to(device)
        actions = torch.from_numpy(np.stack(
            batch.action)[:dataSize, ]).to(dtype).to(device)
        rewards = torch.from_numpy(np.stack(
            batch.reward)[:dataSize, ]).to(dtype).to(device)
        masks = torch.from_numpy(np.stack(
            batch.mask)[:dataSize, ]).to(dtype).to(device)
        with torch.no_grad():
            values = value_net(states)
            fixed_log_probs = policy_net.get_log_prob(states, actions)
        """estimate reward"""
        """get advantage estimation from the trajectories"""
        advantages, returns = estimate_advantages(rewards, masks, values,
                                                  args.gamma, args.tau, device)
        """update discriminator"""
        for _ in range(args.discriminator_epochs):
            #dataSize = states.size()[0]
            # expert_state_actions = torch.from_numpy(expert_traj).to(dtype).to(device)
            exp_idx = random.sample(range(expert_traj.shape[0]), dataSize)
            expert_state_actions = torch.from_numpy(
                expert_traj[exp_idx, :]).to(dtype).to(device)

            dis_input_real = expert_state_actions
            if len(actions.shape) == 1:
                actions.unsqueeze_(-1)
                dis_input_fake = torch.cat([states, actions], 1)
                actions.squeeze_(-1)
            else:
                dis_input_fake = torch.cat([states, actions], 1)

            if args.EBGAN or args.GMMIL or args.GEOMGAN:
                # TBD: no discriminator learning here
                pass
            else:
                g_o = discrim_net(dis_input_fake)
                e_o = discrim_net(dis_input_real)

            optimizer_discrim.zero_grad()
            if args.GEOMGAN:
                optimizer_kernel.zero_grad()

            if args.WGAN:
                if args.LSGAN:
                    pdist = l1dist(dis_input_real,
                                   dis_input_fake).mul(args.lamb)
                    discrim_loss = LeakyReLU(e_o - g_o + pdist).mean()
                else:
                    discrim_loss = torch.mean(e_o) - torch.mean(g_o)
            elif args.EBGAN:
                e_recon = elementwise_loss(e_o, dis_input_real)
                g_recon = elementwise_loss(g_o, dis_input_fake)
                discrim_loss = e_recon
                if (args.margin - g_recon).item() > 0:
                    discrim_loss += (args.margin - g_recon)
            elif args.GMMIL:
                #mmd2_D,K = mix_rbf_mmd2(e_o_enc, g_o_enc, args.sigma_list)
                mmd2_D, K = mix_rbf_mmd2(dis_input_real, dis_input_fake,
                                         args.sigma_list)
                #tbd
                #rewards = K[0]+K[1]-2*K[2]
                rewards = K[1] - K[2]  # -(exp - gen): -(kxy-kyy)=kyy-kxy
                rewards = -rewards.detach()  # exp - gen, maximize (gen label negative)
                errD = mmd2_D
                discrim_loss = -errD  # maximize errD

                # prep for generator
                advantages, returns = estimate_advantages(
                    rewards, masks, values, args.gamma, args.tau, device)
            elif args.GEOMGAN:
                # larger, better, but slower
                noise_num = 100
                mmd2_D, K = mix_imp_mmd2(e_o_enc, g_o_enc, noise_num,
                                         noise_dim, kernel_net, cuda)
                rewards = K[1] - K[2]  # -(exp - gen): -(kxy-kyy)=kyy-kxy
                rewards = -rewards.detach()
                errD = mmd2_D  #+ args.lambda_rg * one_side_errD
                discrim_loss = -errD  # maximize errD

                # prep for generator
                advantages, returns = estimate_advantages(
                    rewards, masks, values, args.gamma, args.tau, device)
            else:
                discrim_loss = discrim_criterion(g_o, ones((states.shape[0], 1), device=device)) + \
                               discrim_criterion(e_o, zeros((e_o.shape[0], 1), device=device))
            if not (args.EBGAN or args.GMMIL or args.GEOMGAN):
                # backpropagate only where the discriminator was actually evaluated;
                # the MMD-style branches above skip discriminator learning for now
                discrim_loss.backward()
                optimizer_discrim.step()
            if args.GEOMGAN:
                optimizer_kernel.step()
        """perform mini-batch PPO update"""
        optim_iter_num = int(math.ceil(states.shape[0] / args.ppo_batch_size))
        for _ in range(args.generator_epochs):
            perm = np.arange(states.shape[0])
            np.random.shuffle(perm)
            perm = LongTensor(perm).to(device)

            states, actions, returns, advantages, fixed_log_probs = \
                states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), \
                fixed_log_probs[perm].clone()

            for i in range(optim_iter_num):
                ind = slice(
                    i * args.ppo_batch_size,
                    min((i + 1) * args.ppo_batch_size, states.shape[0]))
                states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                    states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]

                ppo_step(policy_net, value_net, optimizer_policy,
                         optimizer_value, 1, states_b, actions_b, returns_b,
                         advantages_b, fixed_log_probs_b, args.clip_epsilon,
                         args.l2_reg)

        return rewards

    if args.GEOMGAN:
        return (policy_net, value_net, discrim_net, kernel_net,
                optimizer_policy, optimizer_value, optimizer_discrim,
                optimizer_kernel, agent, update_params,
                scheduler_policy, scheduler_value, scheduler_discrim,
                scheduler_kernel)
    else:
        return (policy_net, value_net, discrim_net,
                optimizer_policy, optimizer_value, optimizer_discrim,
                agent, update_params,
                scheduler_policy, scheduler_value, scheduler_discrim)
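
A hypothetical caller showing how the non-GEOMGAN return tuple is usually unpacked and driven; the outer loop, args.max_iter_num, and agent.collect_samples are assumptions based on the surrounding code rather than part of this snippet.

# Hypothetical driver; assumes args and the globals used by create_networks().
(policy_net, value_net, discrim_net,
 optimizer_policy, optimizer_value, optimizer_discrim,
 agent, update_params,
 scheduler_policy, scheduler_value, scheduler_discrim) = create_networks()

for i_iter in range(args.max_iter_num):
    batch, log = agent.collect_samples(args.min_batch_size)  # rollouts scored by learned_reward
    rewards = update_params(batch, i_iter)                   # discriminator + PPO updates
    scheduler_policy.step()
    scheduler_value.step()
    scheduler_discrim.step()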
Example #7
                    help="pretrain discriminator iteration (default: 30)")

args = parser.parse_args()
use_gpu = True
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if use_gpu:
    torch.cuda.manual_seed_all(args.seed)

is_disc_action = False
action_dim = 10
ActionTensor = DoubleTensor
"""define actor, critic and discrimiator"""
policy_net = Policy(10, 256, 10, num_layers=2)
value_net = Value(10, 256, num_layers=3)
discrim_net = Discriminator(10, 256, 10, num_layers=3)
discrim_criterion = nn.BCELoss()

#####################################################
### Load Models
load_models = True
if load_models:
    print("Loading Models")
    policy_net, value_net, discrim_net = pickle.load(
        open('learned_models/nextaction_pretrain_sigpolicy.p', 'rb'))
    #_, _, discrim_net = pickle.load(open('learned_models/nextaction_trained_sigpolicy.p', 'rb'))
    print("Loading Models Finished")
#####################################################

if use_gpu:
    policy_net = policy_net.cuda()
Example #8
action_dim = 1 if is_disc_action else env.action_space.shape[0]
running_state = ZFilter((state_dim,), clip=5)
# running_reward = ZFilter((1,), demean=False, clip=10)

"""seeding"""
np.random.seed(args.seed)
torch.manual_seed(args.seed)
env.seed(args.seed)

"""define actor and critic"""
if is_disc_action:
    policy_net = DiscretePolicy(state_dim, env.action_space.n)
else:
    policy_net = Policy(state_dim, env.action_space.shape[0], log_std=args.log_std)
value_net = Value(state_dim)
discrim_net = Discriminator(state_dim + action_dim)
discrim_criterion = nn.BCELoss()
to_device(device, policy_net, value_net, discrim_net, discrim_criterion)

optimizer_policy = torch.optim.Adam(policy_net.parameters(), lr=args.learning_rate)
optimizer_value = torch.optim.Adam(value_net.parameters(), lr=args.learning_rate)
optimizer_discrim = torch.optim.Adam(discrim_net.parameters(), lr=args.learning_rate)

# optimization epoch number and batch size for PPO
optim_epochs = 10
optim_batch_size = 64

# load trajectory
expert_traj, running_state = pickle.load(open(args.expert_traj_path, "rb"))
running_state.fix = True
Example #9
    def setUp(self) -> None:
        self.d = Discriminator(6, 5, drop_rate=0.5)
        print(self.d)
Example #10
if torch.cuda.is_available():
    torch.cuda.set_device(args.gpu_index)
env = gym.make(args.env_name)
state_dim = env.observation_space.shape[0]
is_disc_action = len(env.action_space.shape) == 0
action_dim = 1 if is_disc_action else env.action_space.shape[0]
running_state = ZFilter((state_dim,), clip=5)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
env.seed(args.seed)
if is_disc_action:
    policy_net = DiscretePolicy(state_dim, env.action_space.n)
else:
    policy_net = Policy(state_dim, env.action_space.shape[0], log_std=args.log_std)
value_net = Value(state_dim)
discriminator = Discriminator(state_dim + action_dim)
discrim_criterion = nn.BCELoss()
to_device(device, policy_net, value_net, discriminator, discrim_criterion)
optimizer_policy = torch.optim.Adam(policy_net.parameters(), lr=args.learning_rate)
optimizer_value = torch.optim.Adam(value_net.parameters(), lr=args.learning_rate)
optimizer_discrim = torch.optim.Adam(discriminator.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999))
if args.resume_training:
    policy_net, value_net, discriminator, running_state = pickle.load(open('assets/learned_models/AIRL/{}/{}_AIRL_s{}.p'.format(args.dataset_size, args.env_name, args.seed), "rb"))
# load trajectory and concatenate all demonstration sets - this is done because subsampled trajectories
# are saved in the form (traj_num, samples, dim) to fit in SIL's algorithm (1). Here we put them all together for GAIL.
subsampled_expert_traj, running_state = pickle.load(open(args.expert_traj_path, "rb"))
running_state.fix = True
print(running_state.clip)
subsampled_expert_traj = subsampled_expert_traj[:args.dataset_size, ::]
print(subsampled_expert_traj.shape)
expert_traj = []
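
The snippet is cut off at expert_traj = []; the comment above says the subsampled (traj_num, samples, dim) demonstrations are flattened into one array for GAIL. A hedged sketch of how that concatenation typically continues:

# Sketch only: flatten (traj_num, samples, dim) demonstrations into (N, state_dim + action_dim).
for traj in subsampled_expert_traj:
    expert_traj.append(traj)
expert_traj = np.concatenate(expert_traj, axis=0)
print(expert_traj.shape)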
Example #11
class MAGAIL:
    def __init__(self, config, log_dir, exp_name):
        self.config = config
        self.exp_name = exp_name
        self.writer = SummaryWriter(log_dir=f"{log_dir}/{self.exp_name}")
        """seeding"""
        seed = self.config["general"]["seed"]
        torch.manual_seed(seed)
        np.random.seed(seed)

        self._load_expert_data()
        self._init_model()

    def _init_model(self):
        self.V = Value(num_states=self.config["value"]["num_states"],
                       num_hiddens=self.config["value"]["num_hiddens"],
                       drop_rate=self.config["value"]["drop_rate"],
                       activation=self.config["value"]["activation"])
        self.P = JointPolicy(
            initial_state=self.expert_dataset.state.to(device),
            config=self.config["jointpolicy"])
        self.D = Discriminator(
            num_states=self.config["discriminator"]["num_states"],
            num_actions=self.config["discriminator"]["num_actions"],
            num_hiddens=self.config["discriminator"]["num_hiddens"],
            drop_rate=self.config["discriminator"]["drop_rate"],
            use_noise=self.config["discriminator"]["use_noise"],
            noise_std=self.config["discriminator"]["noise_std"],
            activation=self.config["discriminator"]["activation"])

        print("Model Structure")
        print(self.P)
        print(self.V)
        print(self.D)
        print()

        self.optimizer_policy = optim.Adam(
            self.P.parameters(),
            lr=self.config["jointpolicy"]["learning_rate"])
        self.optimizer_value = optim.Adam(
            self.V.parameters(), lr=self.config["value"]["learning_rate"])
        self.optimizer_discriminator = optim.Adam(
            self.D.parameters(),
            lr=self.config["discriminator"]["learning_rate"])
        self.scheduler_discriminator = optim.lr_scheduler.StepLR(
            self.optimizer_discriminator, step_size=2000, gamma=0.95)

        self.discriminator_func = nn.BCELoss()

        to_device(self.V, self.P, self.D, self.discriminator_func)

    def _load_expert_data(self):
        num_expert_states = self.config["general"]["num_states"]
        num_expert_actions = self.config["general"]["num_actions"]
        expert_batch_size = self.config["general"]["expert_batch_size"]

        self.expert_dataset = ExpertDataSet(
            data_set_path=self.config["general"]["expert_data_path"],
            num_states=num_expert_states,
            num_actions=num_expert_actions)
        self.expert_data_loader = DataLoader(
            dataset=self.expert_dataset,
            batch_size=expert_batch_size,
            shuffle=True,
            num_workers=multiprocessing.cpu_count() // 2)

    def train(self, epoch):
        self.P.train()
        self.D.train()
        self.V.train()

        # collect generated batch
        gen_batch = self.P.collect_samples(
            self.config["ppo"]["sample_batch_size"])
        # batch: ('state', 'action', 'next_state', 'log_prob', 'mask')
        gen_batch_state = trans_shape_func(
            torch.stack(gen_batch.state
                        ))  # [trajectory length * parallel size, state size]
        gen_batch_action = trans_shape_func(
            torch.stack(gen_batch.action
                        ))  # [trajectory length * parallel size, action size]
        gen_batch_next_state = trans_shape_func(
            torch.stack(gen_batch.next_state)
        )  # [trajectory length * parallel size, state size]
        gen_batch_old_log_prob = trans_shape_func(
            torch.stack(
                gen_batch.log_prob))  # [trajectory length * parallel size, 1]
        gen_batch_mask = trans_shape_func(torch.stack(
            gen_batch.mask))  # [trajectory length * parallel size, 1]

        # grad_collect_func = lambda d: torch.cat([grad.view(-1) for grad in torch.autograd.grad(d, self.D.parameters(), retain_graph=True)]).unsqueeze(0)
        ####################################################
        # update discriminator
        ####################################################
        for expert_batch_state, expert_batch_action in self.expert_data_loader:
            gen_r = self.D(gen_batch_state, gen_batch_action)
            expert_r = self.D(expert_batch_state.to(device),
                              expert_batch_action.to(device))

            # label smoothing for discriminator
            expert_labels = torch.ones_like(expert_r)
            gen_labels = torch.zeros_like(gen_r)

            if self.config["discriminator"]["use_label_smoothing"]:
                smoothing_rate = self.config["discriminator"][
                    "label_smooth_rate"]
                expert_labels *= (1 - smoothing_rate)
                gen_labels += torch.ones_like(gen_r) * smoothing_rate

            e_loss = self.discriminator_func(expert_r, expert_labels)
            g_loss = self.discriminator_func(gen_r, gen_labels)
            d_loss = e_loss + g_loss

            # """ WGAN with Gradient Penalty"""
            # d_loss = gen_r.mean() - expert_r.mean()
            # differences_batch_state = gen_batch_state[:expert_batch_state.size(0)] - expert_batch_state
            # differences_batch_action = gen_batch_action[:expert_batch_action.size(0)] - expert_batch_action
            # alpha = torch.rand(expert_batch_state.size(0), 1)
            # interpolates_batch_state = gen_batch_state[:expert_batch_state.size(0)] + (alpha * differences_batch_state)
            # interpolates_batch_action = gen_batch_action[:expert_batch_action.size(0)] + (alpha * differences_batch_action)
            # gradients = torch.cat([x for x in map(grad_collect_func, self.D(interpolates_batch_state, interpolates_batch_action))])
            # slopes = torch.norm(gradients, p=2, dim=-1)
            # gradient_penalty = torch.mean((slopes - 1.) ** 2)
            # d_loss += 10 * gradient_penalty

            self.optimizer_discriminator.zero_grad()
            d_loss.backward()
            self.optimizer_discriminator.step()

            self.scheduler_discriminator.step()

        self.writer.add_scalar('train/loss/d_loss', d_loss.item(), epoch)
        self.writer.add_scalar("train/loss/e_loss", e_loss.item(), epoch)
        self.writer.add_scalar("train/loss/g_loss", g_loss.item(), epoch)
        self.writer.add_scalar('train/reward/expert_r',
                               expert_r.mean().item(), epoch)
        self.writer.add_scalar('train/reward/gen_r',
                               gen_r.mean().item(), epoch)

        with torch.no_grad():
            gen_batch_value = self.V(gen_batch_state)
            gen_batch_reward = self.D(gen_batch_state, gen_batch_action)

        gen_batch_advantage, gen_batch_return = estimate_advantages(
            gen_batch_reward, gen_batch_mask, gen_batch_value,
            self.config["gae"]["gamma"], self.config["gae"]["tau"],
            self.config["jointpolicy"]["trajectory_length"])

        ####################################################
        # update policy by ppo [mini_batch]
        ####################################################
        ppo_optim_epochs = self.config["ppo"]["ppo_optim_epochs"]
        ppo_mini_batch_size = self.config["ppo"]["ppo_mini_batch_size"]
        gen_batch_size = gen_batch_state.shape[0]
        optim_iter_num = int(math.ceil(gen_batch_size / ppo_mini_batch_size))

        for _ in range(ppo_optim_epochs):
            perm = torch.randperm(gen_batch_size)

            for i in range(optim_iter_num):
                ind = perm[slice(
                    i * ppo_mini_batch_size,
                    min((i + 1) * ppo_mini_batch_size, gen_batch_size))]
                mini_batch_state, mini_batch_action, mini_batch_next_state, mini_batch_advantage, mini_batch_return, \
                mini_batch_old_log_prob = gen_batch_state[ind], gen_batch_action[ind], gen_batch_next_state[ind], \
                                          gen_batch_advantage[ind], gen_batch_return[ind], gen_batch_old_log_prob[ind]

                v_loss, p_loss = ppo_step(
                    self.P,
                    self.V,
                    self.optimizer_policy,
                    self.optimizer_value,
                    states=mini_batch_state,
                    actions=mini_batch_action,
                    next_states=mini_batch_next_state,
                    returns=mini_batch_return,
                    old_log_probs=mini_batch_old_log_prob,
                    advantages=mini_batch_advantage,
                    ppo_clip_ratio=self.config["ppo"]["clip_ratio"],
                    value_l2_reg=self.config["value"]["l2_reg"])

                self.writer.add_scalar('train/loss/p_loss', p_loss, epoch)
                self.writer.add_scalar('train/loss/v_loss', v_loss, epoch)

        print(f" Training episode:{epoch} ".center(80, "#"))
        print('gen_r:', gen_r.mean().item())
        print('expert_r:', expert_r.mean().item())
        print('d_loss', d_loss.item())

    def eval(self, epoch):
        self.P.eval()
        self.D.eval()
        self.V.eval()

        gen_batch = self.P.collect_samples(
            self.config["ppo"]["sample_batch_size"])
        gen_batch_state = torch.stack(gen_batch.state)
        gen_batch_action = torch.stack(gen_batch.action)

        gen_r = self.D(gen_batch_state, gen_batch_action)
        for expert_batch_state, expert_batch_action in self.expert_data_loader:
            expert_r = self.D(expert_batch_state.to(device),
                              expert_batch_action.to(device))

            print(f" Evaluating episode:{epoch} ".center(80, "-"))
            print('validate_gen_r:', gen_r.mean().item())
            print('validate_expert_r:', expert_r.mean().item())

        self.writer.add_scalar("validate/reward/gen_r",
                               gen_r.mean().item(), epoch)
        self.writer.add_scalar("validate/reward/expert_r",
                               expert_r.mean().item(), epoch)

    def save_model(self, save_path):
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        # dump model from pkl file
        # torch.save((self.D, self.P, self.V), f"{save_path}/{self.exp_name}.pt")
        torch.save(self.D, f"{save_path}/{self.exp_name}_Discriminator.pt")
        torch.save(self.P, f"{save_path}/{self.exp_name}_JointPolicy.pt")
        torch.save(self.V, f"{save_path}/{self.exp_name}_Value.pt")

    def load_model(self, model_path):
        # load entire model
        # self.D, self.P, self.V = torch.load((self.D, self.P, self.V), f"{save_path}/{self.exp_name}.pt")
        self.D = torch.load(f"{model_path}_Discriminator.pt",
                            map_location=device)
        self.P = torch.load(f"{model_path}_JointPolicy.pt",
                            map_location=device)
        self.V = torch.load(f"{model_path}_Value.pt", map_location=device)
Example #12
def create_networks():
    """define actor and critic"""
    if is_disc_action:
        policy_net = DiscretePolicy(state_dim,
                                    env.action_space.n,
                                    hidden_size=(64, 32),
                                    activation='relu')
    else:
        policy_net = Policy(state_dim,
                            env.action_space.shape[0],
                            log_std=args.log_std,
                            hidden_size=(64, 32),
                            activation='relu')
    value_net = Value(state_dim, hidden_size=(32, 16), activation='relu')
    if args.AL:
        discrim_net = SNDiscriminator(state_dim + action_dim,
                                      hidden_size=(32, 16),
                                      activation='relu')
    elif args.EBGAN or args.GMMIL:
        discrim_net = AESNDiscriminator(state_dim + action_dim,
                                        hidden_size=(32, ),
                                        encode_size=64,
                                        activation='leakyrelu',
                                        slope=0.1,
                                        dropout=True,
                                        dprob=0.2)
    elif args.VAKLIL:
        noise_dim = 64
        mid_dim = 32
        discrim_net = VAEDiscriminator(state_dim + action_dim,
                                       num_outputs=noise_dim,
                                       sigmoid_out=False,
                                       sn=True,
                                       test=False,
                                       w_init=False,
                                       hidden_size_enc=(),
                                       hidden_size_dec=(),
                                       encode_size=mid_dim,
                                       activation='relu',
                                       dropout=False)
        kernel_net = NoiseNet(noise_dim,
                              hidden_size=(32, ),
                              encode_size=noise_dim,
                              activation='relu',
                              dropout=False)
        optimizer_kernel = torch.optim.Adam(kernel_net.parameters(),
                                            lr=args.learning_rate)
        scheduler_kernel = MultiStepLR(optimizer_kernel,
                                       milestones=args.milestones,
                                       gamma=args.lr_kernel_decay)
    else:
        discrim_net = Discriminator(state_dim + action_dim,
                                    hidden_size=(32, 16),
                                    activation='relu')

    optimizer_policy = torch.optim.Adam(policy_net.parameters(),
                                        lr=args.learning_rate)
    optimizer_value = torch.optim.Adam(value_net.parameters(),
                                       lr=args.learning_rate)
    optimizer_discrim = torch.optim.Adam(discrim_net.parameters(),
                                         lr=args.learning_rate)

    scheduler_policy = MultiStepLR(optimizer_policy,
                                   milestones=args.milestones,
                                   gamma=args.lr_decay)
    scheduler_value = MultiStepLR(optimizer_value,
                                  milestones=args.milestones,
                                  gamma=args.lr_decay)
    scheduler_discrim = MultiStepLR(optimizer_discrim,
                                    milestones=args.milestones,
                                    gamma=args.lr_kernel_decay)

    if args.AL:

        class ExpertReward():
            def __init__(self):
                self.a = 0

            def expert_reward(self, state, action):
                state_action = tensor(np.hstack([state, action]), dtype=dtype)
                with torch.no_grad():
                    return -discrim_net(state_action)[0].item()

        learned_reward = ExpertReward()
    elif args.EBGAN:

        class ExpertReward():
            def __init__(self):
                self.a = 0

            def expert_reward(self, state, action):
                state_action = tensor(np.hstack([state, action]), dtype=dtype)
                with torch.no_grad():
                    _, recon_out = discrim_net(state_action)
                    return -elementwise_loss(
                        recon_out, state_action).item() + args.r_margin

        learned_reward = ExpertReward()
    elif args.GMMIL or args.VAKLIL:

        class ExpertReward():
            def __init__(self):
                self.r_bias = 0

            def expert_reward(self, state, action):
                with torch.no_grad():
                    return self.r_bias

            def update_XX_YY(self):
                self.XX = torch.diag(torch.mm(self.e_o_enc, self.e_o_enc.t()))
                self.YY = torch.diag(torch.mm(self.g_o_enc, self.g_o_enc.t()))

        learned_reward = ExpertReward()
    else:

        class ExpertReward():
            def __init__(self):
                self.a = 0

            def expert_reward(self, state, action):
                state_action = tensor(np.hstack([state, action]), dtype=dtype)
                with torch.no_grad():
                    return -math.log(discrim_net(state_action)[0].item())

        learned_reward = ExpertReward()
    """create agent"""
    agent = Agent(env,
                  policy_net,
                  device,
                  custom_reward=learned_reward,
                  running_state=None,
                  render=args.render,
                  num_threads=args.num_threads)

    def update_params(batch, i_iter):
        dataSize = min(args.min_batch_size, len(batch.state))
        states = torch.from_numpy(np.stack(
            batch.state)[:dataSize, ]).to(dtype).to(device)
        actions = torch.from_numpy(np.stack(
            batch.action)[:dataSize, ]).to(dtype).to(device)
        rewards = torch.from_numpy(np.stack(
            batch.reward)[:dataSize, ]).to(dtype).to(device)
        masks = torch.from_numpy(np.stack(
            batch.mask)[:dataSize, ]).to(dtype).to(device)
        with torch.no_grad():
            values = value_net(states)
            fixed_log_probs = policy_net.get_log_prob(states, actions)
        """estimate reward"""
        """get advantage estimation from the trajectories"""
        advantages, returns = estimate_advantages(rewards, masks, values,
                                                  args.gamma, args.tau, device)
        """update discriminator"""
        for _ in range(args.discriminator_epochs):
            exp_idx = random.sample(range(expert_traj.shape[0]), dataSize)
            expert_state_actions = torch.from_numpy(
                expert_traj[exp_idx, :]).to(dtype).to(device)

            dis_input_real = expert_state_actions
            if len(actions.shape) == 1:
                actions.unsqueeze_(-1)
                dis_input_fake = torch.cat([states, actions], 1)
                actions.squeeze_(-1)
            else:
                dis_input_fake = torch.cat([states, actions], 1)

            if args.EBGAN or args.GMMIL or args.VAKLIL:
                g_o_enc, g_mu, g_sigma = discrim_net(dis_input_fake,
                                                     mean_mode=False)
                e_o_enc, e_mu, e_sigma = discrim_net(dis_input_real,
                                                     mean_mode=False)
            else:
                g_o = discrim_net(dis_input_fake)
                e_o = discrim_net(dis_input_real)

            optimizer_discrim.zero_grad()
            if args.VAKLIL:
                optimizer_kernel.zero_grad()

            if args.AL:
                if args.LSGAN:
                    pdist = l1dist(dis_input_real,
                                   dis_input_fake).mul(args.lamb)
                    discrim_loss = LeakyReLU(e_o - g_o + pdist).mean()
                else:
                    discrim_loss = torch.mean(e_o) - torch.mean(g_o)
            elif args.EBGAN:
                e_recon = elementwise_loss(e_o, dis_input_real)
                g_recon = elementwise_loss(g_o, dis_input_fake)
                discrim_loss = e_recon
                if (args.margin - g_recon).item() > 0:
                    discrim_loss += (args.margin - g_recon)
            elif args.GMMIL:
                mmd2_D, K = mix_rbf_mmd2(e_o_enc, g_o_enc, args.sigma_list)
                rewards = K[1] - K[2]  # -(exp - gen): -(kxy-kyy)=kyy-kxy
                rewards = -rewards.detach()  # exp - gen, maximize (gen label negative)
                errD = mmd2_D
                discrim_loss = -errD  # maximize errD

                advantages, returns = estimate_advantages(
                    rewards, masks, values, args.gamma, args.tau, device)
            elif args.VAKLIL:
                noise_num = 20000
                mmd2_D_net, _, penalty = mix_imp_with_bw_mmd2(
                    e_o_enc, g_o_enc, noise_num, noise_dim, kernel_net, cuda,
                    args.sigma_list)
                mmd2_D_rbf, _ = mix_rbf_mmd2(e_o_enc, g_o_enc, args.sigma_list)
                errD = (mmd2_D_net + mmd2_D_rbf) / 2
                # 1e-8: small number for numerical stability
                i_c = 0.2
                mu_all = torch.cat((e_mu, g_mu), dim=0)
                sigma_all = torch.cat((e_sigma, g_sigma), dim=0)
                bottleneck_loss = torch.mean(0.5 * torch.sum(
                    mu_all**2 + sigma_all**2 -
                    torch.log(sigma_all**2 + 1e-8) - 1,
                    dim=1)) - i_c
                discrim_loss = -errD + (args.beta * bottleneck_loss) + (
                    args.lambda_h * penalty)
            else:
                discrim_loss = discrim_criterion(g_o, ones((states.shape[0], 1), device=device)) + \
                               discrim_criterion(e_o, zeros((e_o.shape[0], 1), device=device))

            discrim_loss.backward()
            optimizer_discrim.step()
            if args.VAKLIL:
                optimizer_kernel.step()

        if args.VAKLIL:
            with torch.no_grad():
                noise_num = 20000
                g_o_enc, _, _ = discrim_net(dis_input_fake)
                e_o_enc, _, _ = discrim_net(dis_input_real)
                _, K_net, _ = mix_imp_with_bw_mmd2(e_o_enc, g_o_enc, noise_num,
                                                   noise_dim, kernel_net, cuda,
                                                   args.sigma_list)
                _, K_rbf = mix_rbf_mmd2(e_o_enc, g_o_enc, args.sigma_list)
                K = [sum(x) / 2 for x in zip(K_net, K_rbf)]
                rewards = K[1] - K[2]  # -(exp - gen): -(kxy-kyy)=kyy-kxy
                rewards = -rewards  #.detach()
                advantages, returns = estimate_advantages(
                    rewards, masks, values, args.gamma, args.tau, device)
        """perform mini-batch PPO update"""
        optim_iter_num = int(math.ceil(states.shape[0] / args.ppo_batch_size))
        for _ in range(args.generator_epochs):
            perm = np.arange(states.shape[0])
            np.random.shuffle(perm)
            perm = LongTensor(perm).to(device)

            states, actions, returns, advantages, fixed_log_probs = \
                states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), \
                fixed_log_probs[perm].clone()

            for i in range(optim_iter_num):
                ind = slice(
                    i * args.ppo_batch_size,
                    min((i + 1) * args.ppo_batch_size, states.shape[0]))
                states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                    states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]

                ppo_step(policy_net, value_net, optimizer_policy,
                         optimizer_value, 1, states_b, actions_b, returns_b,
                         advantages_b, fixed_log_probs_b, args.clip_epsilon,
                         args.l2_reg)

        return rewards

    if args.VAKLIL:
        return (policy_net, value_net, discrim_net, kernel_net,
                optimizer_policy, optimizer_value, optimizer_discrim,
                optimizer_kernel, agent, update_params,
                scheduler_policy, scheduler_value, scheduler_discrim,
                scheduler_kernel)
    else:
        return (policy_net, value_net, discrim_net,
                optimizer_policy, optimizer_value, optimizer_discrim,
                agent, update_params,
                scheduler_policy, scheduler_value, scheduler_discrim)
Example #13
class GAIL(object):
    """
    A vanilla GAIL with JS-divergence Discriminator and PPO actor.
    """
    def __init__(self, args, state_dim, action_dim, is_dict_action):

        self.device = args.device
        self.config = args

        # initialize discriminator
        discriminator_action_dim = 1 if is_dict_action else action_dim
        self.discriminator = Discriminator(state_dim +
                                           discriminator_action_dim).to(
                                               self.device)
        self.discriminator_loss = nn.BCEWithLogitsLoss()
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(), lr=self.config.learning_rate)
        # initialize actor
        self.policy = PPO(args, state_dim, action_dim, is_dict_action)

    def expert_reward(self, state, action):
        """
        Compute the reward signal for the (PPO) actor update
        :param state:
        :param action:
        :return:
        """
        state_action = torch.DoubleTensor(np.hstack([state, action]))
        with torch.no_grad():
            return -math.log(
                1 - torch.sigmoid(self.discriminator(state_action)[0]) + 1e-8)

    def compute_entropy(self, logits):
        logsigmoid = nn.LogSigmoid()
        ent = (1. - torch.sigmoid(logits)) * logits - logsigmoid(logits)
        return torch.mean(ent)

    def set_expert(self, expert_traj, num_trajs):
        """
        Set the expert trajectories.
        :param expert_traj:
        :return:
        """
        self.expert_traj_pool = expert_traj
        self.expert_traj = np.vstack(expert_traj[:num_trajs])

    def train_discriminator(self, batch):
        """
        Train the discriminator.
        :param batch:
        :return:
        """
        states = torch.DoubleTensor(np.stack(batch.state)).to(self.device)
        actions = torch.DoubleTensor(np.stack(batch.action)).to(self.device)
        # take a fixed number (5) of discriminator gradient steps per batch
        for _ in range(5):
            expert_state_actions = torch.DoubleTensor(self.expert_traj).to(
                self.device)
            g_o = self.discriminator(torch.cat([states, actions], 1))
            e_o = self.discriminator(expert_state_actions)
            self.discriminator_optimizer.zero_grad()
            generator_loss = self.discriminator_loss(
                g_o, zeros((states.shape[0], 1), device=self.device))
            expert_loss = self.discriminator_loss(
                e_o, ones((self.expert_traj.shape[0], 1), device=self.device))

            # print(g_o.shape, e_o.shape)
            # entropy_loss = self.compute_entropy(torch.cat([g_o, e_o], dim=0))
            # print(generator_loss, expert_loss, entropy_loss)
            # discrim_loss = generator_loss + expert_loss + 0.001 * entropy_loss
            discrim_loss = generator_loss + expert_loss

            # compute accuracy
            with torch.no_grad():
                generator_accuracy = torch.mean(
                    (torch.sigmoid(g_o) < 0.5).float())
                expert_accuracy = torch.mean(
                    (torch.sigmoid(e_o) > 0.5).float())

            # discrim_loss = self.discriminator_loss(g_o, ones((states.shape[0], 1), device=self.device)) + \
            #     self.discriminator_loss(e_o, zeros((self.expert_traj.shape[0], 1), device=self.device))
            discrim_loss.backward()
            self.discriminator_optimizer.step()
        return {
            "d_loss": discrim_loss.to('cpu').detach().numpy(),
            "e_loss": expert_loss.to('cpu').detach().numpy(),
            "g_loss": generator_loss.to('cpu').detach().numpy(),
            "g_acc": generator_accuracy.to('cpu').detach().numpy(),
            "e_acc": expert_accuracy.to('cpu').detach().numpy()
        }

    def train(self, batch):
        """
        Train the discriminator and the actor.
        :param batch:
        :return:
        """

        loss = self.train_discriminator(batch)
        self.policy.train(batch)
        return loss
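
A hypothetical driver for this class, assuming an Agent-style sampler whose batches expose .state/.action arrays and whose rollouts use gail.expert_reward as the custom reward; agent, args, and expert_demos are illustrative names, not part of the class.

# Illustrative usage only; `agent`, `args`, and `expert_demos` are assumed to exist.
gail = GAIL(args, state_dim, action_dim, is_dict_action=False)
gail.set_expert(expert_demos, num_trajs=10)   # stack the first 10 expert trajectories

for i_iter in range(args.max_iter_num):
    # rollouts should be rewarded with gail.expert_reward(state, action), not the env reward
    batch, log = agent.collect_samples(args.min_batch_size)
    stats = gail.train(batch)                 # discriminator step(s) + one PPO actor update
    print(i_iter, stats["d_loss"], stats["g_acc"], stats["e_acc"])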
Example #14
                         cnn_options['strides'],
                         head_hidden_size=cnn_options['head_hidden_sizes'],
                         num_aux=num_aux,
                         resnet_first_layer=args.cnn_resnet_first_layer)
    discrim_net = CNNDiscriminator(
        state_dim,
        action_dim,
        cnn_options['channels'],
        cnn_options['kernel_sizes'],
        cnn_options['strides'],
        head_hidden_size=cnn_options['head_hidden_sizes'],
        num_aux=num_aux,
        resnet_first_layer=args.cnn_resnet_first_layer)
else:
    value_net = Value(state_dim)
    discrim_net = Discriminator(state_dim + action_dim)

discrim_criterion = nn.BCELoss()
to_device(device, policy_net, value_net, discrim_net, discrim_criterion)

optimizer_policy = torch.optim.Adam(policy_net.parameters(),
                                    lr=args.pol_learning_rate)
optimizer_value = torch.optim.Adam(value_net.parameters(),
                                   lr=args.learning_rate)
optimizer_discrim = torch.optim.Adam(discrim_net.parameters(),
                                     lr=args.learning_rate)

# optimization epoch number and batch size for PPO
optim_epochs = 10
optim_batch_size = 64