Example No. 1
class Sampler():

    def __init__(self,device,actionsize):
        self.samplenet = DQN(actionsize).to(device)
        self.targetnet = DQN(actionsize).to(device)
        self.opt = torch.optim.Adam(itertools.chain(self.samplenet.parameters()),lr=0.00001,betas=(0.0,0.9))
        self.device = device
        self.memory = ReplayMemory(1000,device=device)
        self.BATCH_SIZE = 10
        self.GAMMA = 0.99
        self.count = 0

    def select_action(self, model):
        self.samplenet.eval()
        action = self.samplenet(model.conv2.weight.data.view(-1,5,5).unsqueeze(0))
        return torch.max(action,1)[1]

    def step(self,state,action,next_state,reward,done):
        self.memory.push(state,action,next_state,reward,done)

        # don't bother optimizing until there are enough transitions in memory
        if len(self.memory) >= self.BATCH_SIZE:
            self.optimize()

    def optimize(self):
        self.samplenet.train()
        self.targetnet.eval()
        s1,actions,r1,s2,d = self.memory.sample(self.BATCH_SIZE)

        # get current and target Q-values for the Bellman equation
        qvals = self.samplenet(s1)
        state_action_values = qvals.gather(1,actions[:,0].unsqueeze(1))
        with torch.no_grad():
            qvals_t = self.targetnet(s2)
            q1_t = qvals_t.max(1)[0].unsqueeze(1)

        expected_state_action_values = (q1_t * self.GAMMA) * (1-d) + r1

        # loss is the L2 (MSE) error of the Bellman equation
        loss = torch.nn.MSELoss()(state_action_values,expected_state_action_values)

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        self.count += 1
        if self.count % 20 == 0:
            self.targetnet.load_state_dict(self.samplenet.state_dict())

        return loss.item()
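
The optimize() step above follows the standard one-step DQN backup: the online network's Q(s, a) is regressed toward r + gamma * max_a' Q_target(s', a') * (1 - done). A minimal self-contained sketch of that target computation with toy tensors (shapes and values are illustrative only; no DQN or ReplayMemory is needed):

import torch

# Toy batch standing in for a replay sample: batch of 10 states, 4 actions.
q_online = torch.randn(10, 4, requires_grad=True)   # Q(s, a) from the online net
q_target_next = torch.randn(10, 4)                  # Q(s', a) from the target net
actions = torch.randint(0, 4, (10, 1))
rewards = torch.randn(10, 1)
dones = torch.zeros(10, 1)

# Q-value of the action actually taken.
state_action_values = q_online.gather(1, actions)
# Bellman target: r + gamma * max_a' Q_target(s', a') * (1 - done).
expected = rewards + 0.99 * q_target_next.max(1, keepdim=True)[0] * (1 - dones)
loss = torch.nn.MSELoss()(state_action_values, expected)
loss.backward()
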
Example No. 2
class Generator():
    def __init__(self, device, data):
        self.data = data
        self.actor = Actor().to(device)
        self.critic = Critic().to(device)
        #self.ctarget = Critic().to(device)
        self.actor_opt = torch.optim.Adam(itertools.chain(
            self.actor.parameters()),
                                          lr=0.0001,
                                          betas=(0.0, 0.9))
        self.critic_opt = torch.optim.Adam(itertools.chain(
            self.critic.parameters()),
                                           lr=0.001,
                                           betas=(0.0, 0.9))

        def init_weights(m):
            if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_uniform_(m.weight.data)

        self.actor.apply(init_weights)
        self.critic.apply(init_weights)
        #self.ctarget.apply(init_weights)

        self.device = device
        self.memory = ReplayMemory(1000, device=device)
        self.batch_size = 5
        self.GAMMA = 0.99
        self.count = 0

    def select_action(self, imgs):
        with torch.no_grad():
            self.actor.eval()
            action = self.actor(imgs)
            return action

    def step(self, state, action, next_state, reward, done):
        self.memory.push(state, action, next_state, reward, done)

        if len(self.memory) >= self.batch_size:
            self.optimize()

    def optimize(self):
        self.actor.train()
        self.critic.train()
        #self.ctarget.eval()

        s1, a, r, s2, d = self.memory.sample(self.batch_size)

        # train the critic: regress its score of each sampled action toward the observed reward
        for reward, action in zip(r, a):
            qval = self.critic(action)
            avgQ = qval.mean().unsqueeze(0)
            loss = torch.nn.L1Loss()(avgQ, reward)
            self.critic_opt.zero_grad()
            loss.backward()
            self.critic_opt.step()

        # train the actor: maximize the critic's score of the actor's output
        img, target = self.data[random.randint(0, len(self.data) - 1)]
        batch = self.actor(img)
        score = self.critic(batch)
        actor_loss = -score.mean()
        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        #if self.count % 5 == 0:
        #    self.ctarget.load_state_dict(self.critic.state_dict())
        #self.count += 1

    def save(self):
        torch.save(self.actor.state_dict(), os.path.join('model', 'actor.pth'))
        torch.save(self.critic.state_dict(),
                   os.path.join('model', 'critic.pth'))
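
The Generator's optimize() pairs an L1-regression critic with an actor trained to maximize the critic's score of its own output. A minimal sketch of that update pattern, assuming placeholder linear networks (the real Actor/Critic architectures and data are not shown in this example):

import torch
import torch.nn as nn

# Placeholder networks standing in for Actor (state -> action) and Critic (action -> score).
actor = nn.Linear(8, 4)
critic = nn.Linear(4, 1)
actor_opt = torch.optim.Adam(actor.parameters(), lr=0.0001, betas=(0.0, 0.9))
critic_opt = torch.optim.Adam(critic.parameters(), lr=0.001, betas=(0.0, 0.9))

state = torch.randn(16, 8)
reward = torch.randn(16, 1)

# Critic step: regress its score of the (detached) action toward the reward (L1 loss).
with torch.no_grad():
    action = actor(state)
critic_loss = nn.L1Loss()(critic(action), reward)
critic_opt.zero_grad()
critic_loss.backward()
critic_opt.step()

# Actor step: maximize the critic's score of the actor's output.
actor_loss = -critic(actor(state)).mean()
actor_opt.zero_grad()
actor_loss.backward()
actor_opt.step()
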
Example No. 3
class Agent:
    def __init__(self, **config):
        self.config = config
        self.n_actions = self.config["n_actions"]
        self.state_shape = self.config["state_shape"]
        self.batch_size = self.config["batch_size"]
        self.gamma = self.config["gamma"]
        self.initial_mem_size_to_train = self.config[
            "initial_mem_size_to_train"]
        torch.manual_seed(self.config["seed"])

        if torch.cuda.is_available():
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            torch.cuda.empty_cache()
            torch.cuda.manual_seed(self.config["seed"])
            torch.cuda.manual_seed_all(self.config["seed"])
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.memory = ReplayMemory(self.config["mem_size"],
                                   self.config["alpha"], self.config["seed"])
        self.v_min = self.config["v_min"]
        self.v_max = self.config["v_max"]
        self.n_atoms = self.config["n_atoms"]
        self.support = torch.linspace(self.v_min, self.v_max,
                                      self.n_atoms).to(self.device)
        self.delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1)
        self.offset = torch.linspace(0, (self.batch_size - 1) * self.n_atoms, self.batch_size).long() \
            .unsqueeze(1).expand(self.batch_size, self.n_atoms).to(self.device)

        self.n_step = self.config["n_step"]
        self.n_step_buffer = deque(maxlen=self.n_step)

        self.online_model = Model(self.state_shape, self.n_actions,
                                  self.n_atoms, self.support,
                                  self.device).to(self.device)
        self.target_model = Model(self.state_shape, self.n_actions,
                                  self.n_atoms, self.support,
                                  self.device).to(self.device)
        self.hard_update_target_network()

        self.optimizer = Adam(self.online_model.parameters(),
                              lr=self.config["lr"],
                              eps=self.config["adam_eps"])

    def choose_action(self, state):
        state = np.expand_dims(state, axis=0)
        state = from_numpy(state).byte().to(self.device)
        with torch.no_grad():
            self.online_model.reset()
            action = self.online_model.get_q_value(state).argmax(-1)
        return action.item()

    def store(self, state, action, reward, next_state, done):
        """Save I/O s to store them in RAM and not to push pressure on GPU RAM """
        assert state.dtype == "uint8"
        assert next_state.dtype == "uint8"
        assert isinstance(reward, int)
        assert isinstance(done, bool)

        self.n_step_buffer.append((state, action, reward, next_state, done))
        if len(self.n_step_buffer) < self.n_step:
            return

        reward, next_state, done = self.get_n_step_returns()
        state, action, *_ = self.n_step_buffer.popleft()

        self.memory.add(state, np.uint8(action), reward, next_state, done)

    def soft_update_target_network(self, tau):
        for target_param, local_param in zip(self.target_model.parameters(),
                                             self.online_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
        # self.target_model.train()
        for param in self.target_model.parameters():
            param.requires_grad = False

    def hard_update_target_network(self):
        self.target_model.load_state_dict(self.online_model.state_dict())
        # self.target_model.train()
        for param in self.target_model.parameters():
            param.requires_grad = False

    def unpack_batch(self, batch):
        batch = self.config["transition"](*zip(*batch))

        states = from_numpy(np.stack(batch.state)).to(self.device)
        actions = from_numpy(np.stack(batch.action)).to(self.device).view(
            (-1, 1))
        rewards = from_numpy(np.stack(batch.reward)).to(self.device).view(
            (-1, 1))
        next_states = from_numpy(np.stack(batch.next_state)).to(self.device)
        dones = from_numpy(np.stack(batch.done)).to(self.device).view((-1, 1))
        return states, actions, rewards, next_states, dones

    def train(self, beta):
        if len(self.memory) < self.initial_mem_size_to_train:
            return 0, 0  # not enough samples yet, so no loss
        batch, weights, indices = self.memory.sample(self.batch_size, beta)
        states, actions, rewards, next_states, dones = self.unpack_batch(batch)
        weights = from_numpy(weights).float().to(self.device)

        with torch.no_grad():
            self.online_model.reset()
            self.target_model.reset()
            q_eval_next = self.online_model.get_q_value(next_states)
            selected_actions = torch.argmax(q_eval_next, dim=-1)
            q_next = self.target_model(next_states)[range(self.batch_size),
                                                    selected_actions]

            # Distributional n-step Bellman target: shift the fixed support by the
            # return and clamp it to [v_min, v_max].
            projected_atoms = rewards + (self.gamma**
                                         self.n_step) * self.support * (~dones)
            projected_atoms = projected_atoms.clamp(min=self.v_min,
                                                    max=self.v_max)

            # Categorical (C51) projection: split each shifted atom's probability
            # mass between its neighbouring lower/upper indices on the support.
            b = (projected_atoms - self.v_min) / self.delta_z
            lower_bound = b.floor().long()
            upper_bound = b.ceil().long()
            lower_bound[(upper_bound > 0) * (lower_bound == upper_bound)] -= 1
            upper_bound[(lower_bound < (self.n_atoms - 1)) *
                        (lower_bound == upper_bound)] += 1

            projected_dist = torch.zeros(q_next.size(),
                                         dtype=torch.float64).to(self.device)
            projected_dist.view(-1).index_add_(
                0, (lower_bound + self.offset).view(-1),
                (q_next * (upper_bound.float() - b)).view(-1))
            projected_dist.view(-1).index_add_(
                0, (upper_bound + self.offset).view(-1),
                (q_next * (b - lower_bound.float())).view(-1))

        eval_dist = self.online_model(states)[range(self.batch_size),
                                              actions.squeeze().long()]
        dqn_loss = -(projected_dist * torch.log(eval_dist + 1e-6)).sum(-1)
        td_error = dqn_loss.abs() + 1e-6
        self.memory.update_priorities(indices, td_error.detach().cpu().numpy())
        dqn_loss = (dqn_loss * weights).mean()

        self.optimizer.zero_grad()
        dqn_loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.online_model.parameters(), self.config["clip_grad_norm"])
        self.optimizer.step()

        return dqn_loss.item(), grad_norm.item()

    def ready_to_play(self, state_dict):
        self.online_model.load_state_dict(state_dict)
        self.online_model.eval()

    def get_n_step_returns(self):
        reward, next_state, done = self.n_step_buffer[-1][-3:]

        for transition in reversed(list(self.n_step_buffer)[:-1]):
            r, n_s, d = transition[-3:]

            reward = r + self.gamma * reward * (1 - d)
            next_state, done = (n_s, d) if d else (next_state, done)

        return reward, next_state, done
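
get_n_step_returns() folds the n_step_buffer backwards, discounting the accumulated reward and cutting it at the first terminal it meets. The same recursion with toy values, as a minimal sketch (transitions reduced to (reward, next_state, done) tuples; no Agent or environment needed):

# Toy n-step return, following get_n_step_returns above.
gamma = 0.99
buffer = [(1.0, "s1", False), (0.5, "s2", False), (2.0, "s3", False)]

reward, next_state, done = buffer[-1]
for r, n_s, d in reversed(buffer[:-1]):
    reward = r + gamma * reward * (1 - d)            # discounted accumulation, cut at terminals
    next_state, done = (n_s, d) if d else (next_state, done)

print(reward, next_state, done)  # 1.0 + 0.99 * (0.5 + 0.99 * 2.0), 's3', False
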
Example No. 4
class PpgAgent:
    def __init__(
            self,
            observation_space,
            action_space,
            device,
            gamma=0.995,
            actor_lr=5e-4,
            critic_lr=5e-4,
            batch_size=128,
            memory_size=50000,
            tau=5e-3,
            weight_decay=1e-2,
            sigma=0.2,
            noise_clip=0.5,
            alpha=0.2,
            alpha_lr=3e-4,
            rollout_length=2048,
            lambda_=0.95,
            beta_clone=1.0,
            coef_ent=0.01,
            num_updates=32,
            policy_epoch=1,
            value_epoch=1,
            aux_num_updates=6,
            aux_epoch_batch=64,
            max_grad_norm=0.5,
            aux_critic_loss_coef=1.0,
            clip_eps=0.2,
            writer=None,
            is_image=False,
            clip_aux_critic_loss=None,
            clip_aux_multinet_critic_loss=None,
            multipleet_upadte_clip_grad_norm=None,
            summary_interval=1,
            debug_no_aux_phase=False):
        super(PpgAgent, self).__init__()
        self.action_mean = (0.5 * (action_space.high + action_space.low))[0]
        self.action_halfwidth = (0.5 * (action_space.high - action_space.low))[0]
        self.num_state = observation_space.shape[0]
        self.num_action = action_space.shape[0]
        self.state_mean = None
        self.state_halfwidth = None
        if abs(observation_space.high[0]) != math.inf:
            self.state_mean = 0.5 * (observation_space.high + observation_space.low)
            self.state_halfwidth = 0.5 * (observation_space.high - observation_space.low)
        self.gamma = gamma
        self.batch_size = batch_size
        self.device = device
        self.multipleNet = MultipleNetwork(self.num_state, action_space, device, is_image=is_image).to(self.device)
        self.multipleNet_target = MultipleNetwork(self.num_state, action_space, device, is_image=is_image).to(self.device)
        self.multipleNet_target.load_state_dict(self.multipleNet.state_dict())
        self.multipleNet_optimizer = optim.Adam(self.multipleNet.parameters(), lr=actor_lr)

        self.critic = CriticNetwork(self.num_state, action_space, device, is_image=is_image).to(self.device)
        self.critic_target = CriticNetwork(self.num_state, action_space, device, is_image=is_image).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr, weight_decay=weight_decay)

        self.alpha = alpha
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.log_alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)

        self.memory = ReplayMemory(observation_space, action_space, device, num_state=self.num_state, memory_size=memory_size, is_image=is_image)
        self.criterion = nn.MSELoss()
        self.device = device
        self.tau = tau
        self.writer = writer
        self.is_image = is_image
        self.sigma = sigma
        self.noise_clip = noise_clip
        self.rollout_length = rollout_length
        self.lambda_ = lambda_
        self.coef_ent = coef_ent
        self.aux_critic_loss_coef = aux_critic_loss_coef
        self.max_grad_norm = max_grad_norm
        self.aux_num_updates = aux_num_updates
        self.clip_eps = clip_eps
        self.beta_clone = beta_clone
        self.policy_epoch = policy_epoch
        self.value_epoch = value_epoch
        self.num_updates = num_updates
        self.aux_epoch_batch = aux_epoch_batch
        self.clip_aux_critic_loss = clip_aux_critic_loss
        self.clip_aux_multinet_critic_loss = clip_aux_multinet_critic_loss
        self.multipleet_upadte_clip_grad_norm = multipleet_upadte_clip_grad_norm
        self.summary_interval = summary_interval
        self.debug_no_aux_phase = debug_no_aux_phase
        self.update_step = 0

    def normalize_state(self, state):
        """return normalized state
        """
        if self.state_mean is None:
            return state
        state = (state - self.state_mean) / self.state_halfwidth
        return state

    def soft_update(self, target_net, net):
        """Polyark update
        """
        for target_param, param in zip(target_net.parameters(), net.parameters()):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )

    def is_update(self, steps):
        """update in rollout length interval
        """
        return steps % self.rollout_length == 0

    def update(self, state=None):
        r"""Training process. According to the original paper, training proceeds as follows:
        initialize replay buffer B
        1. perform rollout N_{\pi} times                       \\ Policy phase
        2. update multiplenet, loss = L^{clip} + behavior cloning loss
           update critic,      loss = L^{value} (using GAE)
        3. update multiplenet, loss = L^{joint}                \\ Auxiliary phase
           update critic,      loss = L^{value} (using GAE)
        4. reset B
        """
        if not self.is_update(self.memory.index):
            return
        self.update_step += 1
        # sample from replay buffer
        with torch.no_grad():
            batch = self.memory.sample(state)
            action_batch = batch['actions'].to(self.device)
            state_batch = batch['obs'].to(self.device)
            reward_batch = batch['rewards'].to(self.device)
            terminate_batch = batch['terminates'].to(self.device)
            log_pis_batch = batch['log_pis'].to(self.device)
            values = self.critic(state_batch)

        # calculate the value target (\hat{V}_t^{targ}) for each state
        targets, advantages = util.calculate_advantage(values, reward_batch, terminate_batch, self.gamma, self.lambda_)

        # 2. policy phase, update multiplenet and critic
        # https://arxiv.org/pdf/2009.04416.pdf algorithm 1, line 6-7
        loss_critic = 0
        for i in range(self.value_epoch):
            indices = np.arange(self.rollout_length)
            np.random.shuffle(indices)
            for start in range(0, self.rollout_length, self.batch_size):
                idxes = indices[start:start + self.batch_size]
                loss_critic += self.update_critic(state_batch[idxes], targets[idxes])

        n_train_iteration = self.value_epoch * (self.rollout_length // self.batch_size)
        loss_critic = loss_critic / n_train_iteration
        self.writer.add_scalar('/critic/loss/policy_phase', loss_critic, self.update_step)

        # https://arxiv.org/pdf/2009.04416.pdf algorithm 1, line 8-9
        loss_actor, l_clip, bc_loss = 0, 0, 0
        for i in range(self.policy_epoch):
            indices = np.arange(self.rollout_length)
            np.random.shuffle(indices)
            for start in range(0, self.rollout_length, self.batch_size):
                idxes = indices[start:start + self.batch_size]
                batch_loss_actor, batch_l_clip, batch_bc_loss = self.update_MultipleNet(
                    state_batch[idxes], action_batch[idxes],
                    log_pis_batch[idxes], advantages[idxes])
                loss_actor += batch_loss_actor
                l_clip += batch_l_clip
                bc_loss += batch_bc_loss

        # average the losses and write them to tensorboard
        n_train_iteration = self.policy_epoch * (self.rollout_length // self.batch_size)
        loss_actor = loss_actor / n_train_iteration
        l_clip = l_clip / n_train_iteration
        bc_loss = bc_loss / n_train_iteration
        self.writer.add_scalar('/multiplenet/policy_phase/actor', loss_actor, self.update_step)
        self.writer.add_scalar('/multiplenet/policy_phase/l_clip', l_clip, self.update_step)
        self.writer.add_scalar('/multiplenet/policy_phase/bc_loss', bc_loss, self.update_step)

        with torch.no_grad():
            log_pis_old = self.multipleNet.evaluate_log_pi(state_batch[:-1], action_batch)

        # 3. auxiliary phase, update multiplenet and critic
        # https://arxiv.org/pdf/2009.04416.pdf algorithm 1, line 12-14
        # if self.debug_no_aux_phase is True, skip this phase, which should make this code equivalent to PPO
        loss_critic_multi, bc_loss, loss_joint, loss_critic_aux = 0, 0, 0, 0
        if (self.update_step % self.num_updates == 0) and (not self.debug_no_aux_phase):
            for _ in range(self.aux_num_updates):
                indices = np.arange(self.rollout_length)
                np.random.shuffle(indices)
                for start in range(0, self.rollout_length, self.batch_size):
                    idxes = indices[start:start + self.batch_size]
                    batch_critic_multi, batch_bc_loss, batch_loss_joint = self.update_actor_Auxiliary(
                        state_batch[idxes], action_batch[idxes],
                        log_pis_old[idxes], targets[idxes], advantages[idxes])
                    batch_critic_aux = self.update_critic_Auxiliary(state_batch[idxes], targets[idxes])
                    loss_critic_multi += batch_critic_multi
                    bc_loss += batch_bc_loss
                    loss_joint += batch_loss_joint
                    loss_critic_aux += batch_critic_aux
            # 4. initialize replay buffer to empty
            # https://arxiv.org/pdf/2009.04416.pdf algorithm 1, line 2
            self.memory.reset()

            # average the losses and write them to tensorboard
            n_train_iteration = self.aux_num_updates * (self.rollout_length // self.batch_size)
            loss_critic_multi = loss_critic_multi / n_train_iteration
            bc_loss = bc_loss / n_train_iteration
            loss_joint = loss_joint / n_train_iteration
            loss_critic_aux = loss_critic_aux / n_train_iteration
            self.writer.add_scalar('/multiplenet/loss/auxiliary_phase/critic', loss_critic_multi, self.update_step)
            self.writer.add_scalar('/multiplenet/loss/auxiliary_phase/bc_loss', bc_loss, self.update_step)
            self.writer.add_scalar('/multiplenet/loss/auxiliary_phase/loss_joint', loss_joint, self.update_step)
            self.writer.add_scalar('/critic/loss/auxiliary_phase/critic', loss_critic_aux, self.update_step)
        self.multipleNet.eval()
        self.critic.eval()

    def update_actor_Auxiliary(self, states, actions, log_pis_old, targets, advantages):
        r"""loss = L^{joint}
        L^{joint} = L^{aux} + \beta_{clone} * KL(\pi_{old}, \pi_{current})
        In the original paper, L^{aux} = mse(v_{\pi}(s_t), v_targ)  \\ task for V_{\theta_\pi}
        """
        loss_critic = (self.multipleNet.q_forward(states) - targets).pow_(2).mean() * 0.5
        if self.clip_aux_multinet_critic_loss is not None:
            loss_critic = torch.clamp(loss_critic, min=0, max=self.clip_aux_multinet_critic_loss)
        loss_critic = self.aux_critic_loss_coef * loss_critic
        log_pis = self.multipleNet.evaluate_log_pi(states, actions)
        pis_old = log_pis_old.exp_()
        kl_loss = (pis_old * (log_pis - log_pis_old)).mean()
        
        loss_joint = loss_critic + self.beta_clone * kl_loss
        self.multipleNet_optimizer.zero_grad()
        loss_joint.backward(retain_graph=False)
        self.multipleNet_optimizer.step()
        return loss_critic, self.beta_clone * kl_loss, loss_joint
        
    def update_critic_Auxiliary(self, states, targets):
        """loss = L^{value} = mse(v(s) - v_targ)
        """
        # add * 0.5 according to https://arxiv.org/pdf/2009.04416.pdf page 2
        loss_critic_aux = (self.critic(states) - targets).pow_(2).mean() * 0.5
        if self.clip_aux_critic_loss is not None:
            loss_critic_aux = torch.clamp(loss_critic_aux, min=0, max=self.clip_aux_critic_loss)
        self.critic_optimizer.zero_grad()
        loss_critic_aux.backward(retain_graph=False)
        # nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.critic_optimizer.step()
        return loss_critic_aux

    def update_critic(self, states, targets):
        """loss = L^{value} = mse(v(s) - v_targ)
        """
        # add * 0.5 according to https://arxiv.org/pdf/2009.04416.pdf page 2
        loss_critic = (self.critic(states) - targets).pow_(2).mean() * 0.5
        self.critic_optimizer.zero_grad()
        loss_critic.backward(retain_graph=False)
        # nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.critic_optimizer.step()
        return loss_critic

    def update_MultipleNet(self, states, actions, log_pis_old, advantages):
        """Policy phase: update multiplenet.
        loss = L^{clip} + behavior cloning loss
        """
        log_pis = self.multipleNet.evaluate_log_pi(states, actions)
        mean_ent = log_pis.mean()
        ratios = (log_pis - log_pis_old).exp_()
        loss_actor1 = -ratios * advantages
        loss_actor2 = -torch.clamp(
            ratios,
            1.0 - self.clip_eps,
            1.0 + self.clip_eps
        ) * advantages
        l_clip = torch.max(loss_actor1, loss_actor2).mean()
        bc_loss = self.coef_ent * mean_ent
        loss_actor = l_clip + bc_loss
        self.multipleNet_optimizer.zero_grad()
        loss_actor.backward(retain_graph=False)
        if self.multipleet_upadte_clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.multipleNet.parameters(),
                                           self.multipleet_upadte_clip_grad_norm)
        #nn.utils.clip_grad_norm_(self.multipleNet.parameters(), self.max_grad_norm)
        self.multipleNet_optimizer.step()
        return loss_actor, l_clip, bc_loss

    def get_action(self, state):
        """select action that has maximus Q value.
        """
        self.multipleNet.eval()
        if not self.is_image:
            state_tensor = torch.tensor(self.normalize_state(state), dtype=torch.float).view(-1, self.num_state).to(self.device)
        else:
            state_tensor = torch.tensor(state.copy() / 255., dtype=torch.float).unsqueeze(0).to(self.device)
        with torch.no_grad():
            action, log_pis = self.multipleNet.sample(state_tensor)
            action = action.view(self.num_action).to('cpu').detach().numpy().copy()
        return action, log_pis
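
update_MultipleNet() above computes the standard PPO clipped surrogate L^{clip}. A minimal standalone sketch of that loss with made-up log-probabilities and advantages (no MultipleNetwork or rollout data required):

import torch

clip_eps = 0.2
log_pis = torch.tensor([-1.0, -0.5, -2.0], requires_grad=True)   # current policy
log_pis_old = torch.tensor([-1.1, -0.6, -1.5])                   # rollout policy
advantages = torch.tensor([1.0, -0.5, 2.0])

ratios = (log_pis - log_pis_old).exp()
surr1 = -ratios * advantages
surr2 = -torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
l_clip = torch.max(surr1, surr2).mean()   # pessimistic (clipped) surrogate
l_clip.backward()
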
Example No. 5
class PpgAgent:
    def __init__(self,
                 observation_space,
                 action_space,
                 device,
                 gamma=0.995,
                 actor_lr=5e-4,
                 critic_lr=5e-4,
                 batch_size=128,
                 memory_size=50000,
                 tau=5e-3,
                 weight_decay=1e-2,
                 sigma=0.2,
                 noise_clip=0.5,
                 alpha=0.2,
                 alpha_lr=3e-4,
                 rollout_length=2048,
                 lambda_=0.95,
                 beta_clone=1.0,
                 coef_ent=0.01,
                 num_updates=32,
                 policy_epoch=1,
                 value_epoch=1,
                 aux_num_updates=6,
                 aux_epoch_batch=16,
                 max_grad_norm=0.5,
                 clip_eps=0.2,
                 writer=None,
                 is_image=False):
        super(PpgAgent, self).__init__()
        self.action_mean = (0.5 * (action_space.high + action_space.low))[0]
        self.action_halfwidth = (0.5 *
                                 (action_space.high - action_space.low))[0]
        self.num_state = observation_space.shape[0]
        self.num_action = action_space.shape[0]
        self.state_mean = None
        self.state_halfwidth = None
        if abs(observation_space.high[0]) != math.inf:
            self.state_mean = 0.5 * (observation_space.high +
                                     observation_space.low)
            self.state_halfwidth = 0.5 * (observation_space.high -
                                          observation_space.low)
        self.gamma = gamma
        self.batch_size = batch_size
        self.device = device
        self.multipleNet = MultipleNetwork(self.num_state,
                                           action_space,
                                           device,
                                           is_image=is_image).to(self.device)
        self.multipleNet_target = MultipleNetwork(self.num_state,
                                                  action_space,
                                                  device,
                                                  is_image=is_image).to(
                                                      self.device)
        self.multipleNet_target.load_state_dict(self.multipleNet.state_dict())
        self.multipleNet_optimizer = optim.Adam(self.multipleNet.parameters(),
                                                lr=actor_lr)

        self.critic = CriticNetwork(self.num_state,
                                    action_space,
                                    device,
                                    is_image=is_image).to(self.device)
        self.critic_target = CriticNetwork(self.num_state,
                                           action_space,
                                           device,
                                           is_image=is_image).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(),
                                           lr=critic_lr,
                                           weight_decay=weight_decay)

        self.alpha = alpha
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.log_alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)

        self.memory = ReplayMemory(observation_space,
                                   action_space,
                                   device,
                                   num_state=self.num_state,
                                   memory_size=memory_size,
                                   is_image=is_image)
        self.criterion = nn.MSELoss()
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.tau = tau
        self.writer = writer
        self.update_step = 0
        self.is_image = is_image
        self.sigma = sigma
        self.noise_clip = noise_clip
        self.rollout_length = rollout_length
        self.lambda_ = lambda_
        self.coef_ent = coef_ent
        self.max_grad_norm = max_grad_norm
        self.aux_num_updates = aux_num_updates
        self.clip_eps = clip_eps
        self.beta_clone = beta_clone
        self.policy_epoch = policy_epoch
        self.value_epoch = value_epoch
        self.num_updates = num_updates
        self.aux_epoch_batch = aux_epoch_batch

    def normalize_state(self, state):
        if self.state_mean is None:
            return state
        state = (state - self.state_mean) / self.state_halfwidth
        return state

    def soft_update(self, target_net, net):
        for target_param, param in zip(target_net.parameters(),
                                       net.parameters()):
            target_param.data.copy_(self.tau * param.data +
                                    (1 - self.tau) * target_param.data)

    def is_update(self, steps):
        return steps % self.rollout_length == 0

    def update(self, state=None):
        if not self.is_update(self.memory.index):
            return
        self.update_step += 1
        with torch.no_grad():
            batch = self.memory.sample(state)
            # fetch the state/action values for each sample
            action_batch = batch['actions'].to(self.device)
            state_batch = batch['obs'].to(self.device)
            reward_batch = batch['rewards'].to(self.device)
            terminate_batch = batch['terminates'].to(self.device)
            log_pis_batch = batch['log_pis'].to(self.device)

            values = self.critic(state_batch)
        targets, advantages = util.calculate_advantage(values, reward_batch,
                                                       terminate_batch,
                                                       self.gamma,
                                                       self.lambda_)
        for j in range(self.num_updates):
            for i in range(max(self.policy_epoch, self.value_epoch)):
                indices = np.arange(self.rollout_length)
                np.random.shuffle(indices)
                for start in range(0, self.rollout_length, self.batch_size):
                    idxes = indices[start:start + self.batch_size]
                    if self.policy_epoch > i:
                        self.update_MultipleNet(state_batch[idxes],
                                                action_batch[idxes],
                                                log_pis_batch[idxes],
                                                advantages[idxes])
                    if self.value_epoch > i:
                        self.update_critic(state_batch[idxes], targets[idxes])
        with torch.no_grad():
            log_pis_old = self.multipleNet.evaluate_log_pi(
                state_batch[:-1], action_batch)
        for _ in range(self.aux_num_updates):
            indices = np.arange(self.rollout_length)
            np.random.shuffle(indices)
            for start in range(0, self.rollout_length, self.aux_epoch_batch):
                idxes = indices[start:start + self.aux_epoch_batch]
                self.update_actor_Auxiliary(state_batch[idxes],
                                            action_batch[idxes],
                                            log_pis_old[idxes], targets[idxes],
                                            advantages[idxes])
                self.update_critic_Auxiliary(state_batch[idxes],
                                             targets[idxes])
        self.multipleNet.eval()
        self.critic.eval()

    def update_actor_Auxiliary(self, states, actions, log_pis_old, targets,
                               advantages):
        loss_critic = (self.multipleNet.q_forward(states) -
                       targets).pow_(2).mean()
        loss_bc = (self.multipleNet.p_forward(states) - actions).pow_(2).mean()
        log_pis = self.multipleNet.evaluate_log_pi(states, actions)
        pis_old = log_pis_old.exp_()
        kl_loss = (pis_old * (log_pis - log_pis_old)).mean()

        loss_joint = loss_critic + self.beta_clone * kl_loss
        self.multipleNet_optimizer.zero_grad()
        loss_joint.backward(retain_graph=False)
        self.multipleNet_optimizer.step()
        if self.update_step % 10 == 0:
            print("aux actor loss:", loss_joint.item())

    def update_critic_Auxiliary(self, states, targets):
        loss_critic_aux = (self.critic(states) - targets).pow_(2).mean()
        self.critic_optimizer.zero_grad()
        loss_critic_aux.backward(retain_graph=False)
        #nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.critic_optimizer.step()
        #if self.update_step % 50 == 0:
        #    print("aux critic loss:", loss_critic_aux.item())

    def update_critic(self, states, targets):
        loss_critic = (self.critic(states) - targets).pow_(2).mean()
        self.critic_optimizer.zero_grad()
        loss_critic.backward(retain_graph=False)
        #nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.critic_optimizer.step()
        #if self.update_step % 50 == 0:
        #    print("critic loss:", loss_critic.item())

    def update_MultipleNet(self, states, actions, log_pis_old, advantages):
        log_pis = self.multipleNet.evaluate_log_pi(states, actions)
        if self.update_step % 50 == 0:
            print("log_pis:", log_pis)
        mean_ent = -log_pis.mean()
        ratios = (log_pis - log_pis_old).exp_()
        loss_actor1 = -ratios * advantages
        loss_actor2 = -torch.clamp(ratios, 1.0 - self.clip_eps,
                                   1.0 + self.clip_eps) * advantages
        loss_actor = torch.max(loss_actor1,
                               loss_actor2).mean() - self.coef_ent * mean_ent
        self.multipleNet_optimizer.zero_grad()
        loss_actor.backward(retain_graph=False)
        #nn.utils.clip_grad_norm_(self.multipleNet.parameters(), self.max_grad_norm)
        self.multipleNet_optimizer.step()
        if self.update_step % 50 == 0:
            print("actor loss:", loss_actor.item())

    # select the action with the maximum Q-value
    def get_action(self, state):
        self.multipleNet.eval()
        if not self.is_image:
            state_tensor = torch.tensor(self.normalize_state(state),
                                        dtype=torch.float).view(
                                            -1, self.num_state).to(self.device)
        else:
            state_tensor = torch.tensor(state.copy() / 255.,
                                        dtype=torch.float).unsqueeze(0).to(
                                            self.device)
        with torch.no_grad():
            action, log_pis = self.multipleNet.sample(state_tensor)
            action = action.view(
                self.num_action).to('cpu').detach().numpy().copy()
        return action, log_pis
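
Both PpgAgent variants define soft_update() as a Polyak average of target and online parameters. A minimal sketch with placeholder linear layers and the same tau=5e-3 default (the real networks are not needed to illustrate the mixing rule):

import torch
import torch.nn as nn

tau = 5e-3
net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)
target_net.load_state_dict(net.state_dict())

# Polyak update: target <- tau * online + (1 - tau) * target.
with torch.no_grad():
    for target_param, param in zip(target_net.parameters(), net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
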