コード例 #1
0
ファイル: train.py プロジェクト: renqibing/RL_DDPG
class DDPG_trainer(object):
    def __init__(self, nb_state, nb_action):
        self.nb_state = nb_state
        self.nb_action = nb_action

        self.actor = Actor(self.nb_state, self.nb_action)
        self.actor_target = Actor(self.nb_state, self.nb_action)
        self.actor_optim = Adam(self.actor.parameters(), lr=LEARNING_RATE)

        self.critic = Critic(self.nb_state, self.nb_action)
        self.critic_target = Critic(self.nb_state, self.nb_action)
        self.critic_optim = Adam(self.critic.parameters(), lr=LEARNING_RATE)

        hard_update(self.actor_target,
                    self.actor)  # Make sure target is with the same weight
        hard_update(self.critic_target, self.critic)

        #Create replay buffer
        self.memory = SequentialMemory(limit=MEMORY_SIZE, window_length=1)
        self.random_process = OrnsteinUhlenbeckProcess(size=nb_action,
                                                       theta=OU_THETA,
                                                       mu=OU_MU,
                                                       sigma=OU_SIGMA)

        self.is_training = True
        self.epsilon = 1.0
        self.a_t = None
        self.s_t = None

        if USE_CUDA: self.cuda()

    def cuda(self):
        self.actor.cuda()
        self.actor_target.cuda()
        self.critic.cuda()
        self.critic_target.cuda()

    def select_action(self, s_t, decay_epsilon=True):

        action = to_numpy(self.actor(to_tensor(np.array([s_t])))).squeeze(0)
        action += self.is_training * max(self.epsilon,
                                         0) * self.random_process.sample()
        action = np.clip(action, -1., 1.)

        if decay_epsilon:
            self.epsilon -= DELTA_EPSILON

        self.a_t = action
        return action

    def reset(self, observation):
        self.start_state = observation
        self.random_process.reset_states()

    def observe(self, r_t, s_t1, done):

        if self.is_training:
            self.memory.append(self.s_t, self.a_t, r_t, done)
            self.s_t = s_t1

    def update_all(self):
        # Help Warm Up
        if self.memory.nb_entries < BATCH_SIZE * 2:
            return

        # Sample batch
        state_batch, action_batch, reward_batch, \
        next_state_batch, terminal_batch = self.memory.sample_and_split(BATCH_SIZE)

        # Prepare for the target q batch
        with torch.no_grad():
            next_q_values = self.critic_target([
                to_tensor(next_state_batch),
                self.actor_target(to_tensor(next_state_batch)),
            ])

        target_q_batch = to_tensor(reward_batch) + \
                         DISCOUNT * to_tensor(terminal_batch.astype(np.float)) * next_q_values

        # Critic update
        self.critic.zero_grad()
        for state in state_batch:
            if state.shape[0] <= 2:
                # print("Error sampled memory!")
                return

        q_batch = self.critic(
            [to_tensor(state_batch),
             to_tensor(action_batch)])
        value_loss = CRITERION(q_batch, target_q_batch)
        value_loss.backward()
        self.critic_optim.step()

        # Actor update
        self.actor.zero_grad()

        policy_loss = -self.critic(
            [to_tensor(state_batch),
             self.actor(to_tensor(state_batch))])

        policy_loss = policy_loss.mean()
        policy_loss.backward()
        self.actor_optim.step()

        # Target update
        soft_update(self.actor_target, self.actor, TAU)
        soft_update(self.critic_target, self.critic, TAU)
コード例 #2
0
class DDPG(object):
    def __init__(self, nb_states, nb_actions, args):
        # self.cuda = USE_CUDA #args.cuda
        self.cuda = args.cuda

        self.nb_states = nb_states
        self.nb_actions = nb_actions

        #Init models
        #actor_kwargs = {'n_inp':self.nb_states, 'n_feature_list':[args.hidden1,args.hidden2], 'n_class':self.nb_actions}
        #self.actor = MLP(**actor_kwargs)
        #self.actor_target = MLP(**actor_kwargs)
        #self.critic = MLP(**actor_kwargs)  #TODO: actor and critic has same structure for now.
        #self.critic_target = MLP(**actor_kwargs)

        net_cfg = {
            'hidden1': args.hidden1,
            'hidden2': args.hidden2,
            'init_w': args.init_w
        }
        self.actor = Actor(self.nb_states, self.nb_actions, **net_cfg)
        self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg)

        self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg)
        self.critic_target = Critic(self.nb_states, self.nb_actions, **net_cfg)

        self.criterion = nn.MSELoss()
        if self.cuda:
            self.actor = self.actor.cuda(
            )  # torch.nn.DataParallel(self.model).cuda()  #TODO dataparallel not working
            self.critic = self.critic.cuda()
            self.actor_target = self.actor_target.cuda()
            self.critic_target = self.critic_target.cuda()
            self.criterion = self.criterion.cuda()

        # Set optimizer
        self.actor_optim = torch.optim.Adam(self.actor.parameters(),
                                            lr=args.prate)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(),
                                             lr=args.rate)
        # Loss function
        self.loss_fn = torch.nn.MSELoss(size_average=False)

        hard_update(self.actor_target,
                    self.actor)  # Make sure target is with the same weight
        hard_update(self.critic_target, self.critic)

        self.memory = SequentialMemory(limit=args.rmsize,
                                       window_length=args.window_length)
        self.random_process = OrnsteinUhlenbeckProcess(size=nb_actions,
                                                       theta=args.ou_theta,
                                                       mu=args.ou_mu,
                                                       sigma=args.ou_sigma)

        # Hyper-parameters
        self.batch_size = args.bsize
        self.tau = args.tau
        self.discount = args.discount
        self.depsilon = 1.0 / args.epsilon

        self.epsilon = 1.0
        self.s_t = None  # Most recent state
        self.a_t = None  # Most recent action
        self.is_training = True

    def update_policy(self):
        # Sample batch
        state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = self.memory.sample_and_split(
            self.batch_size)

        # Prepare for the target q batch
        next_q_values = self.critic_target([
            to_tensor(next_state_batch, volatile=True),
            self.actor_target(to_tensor(next_state_batch, volatile=True)),
        ])
        next_q_values.volatile = False

        target_q_batch = to_tensor(reward_batch) + \
            self.discount*to_tensor(terminal_batch.astype(np.float))*next_q_values

        # Critic update
        self.critic.zero_grad()

        q_batch = self.critic(
            [to_tensor(state_batch),
             to_tensor(action_batch)])

        value_loss = self.criterion(q_batch, target_q_batch)
        value_loss.backward()
        self.critic_optim.step()

        # Actor update
        self.actor.zero_grad()

        policy_loss = -self.critic(
            [to_tensor(state_batch),
             self.actor(to_tensor(state_batch))])

        policy_loss = policy_loss.mean()
        policy_loss.backward()
        self.actor_optim.step()

        # Target update
        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)

    def eval(self):
        self.actor.eval()
        self.actor_target.eval()
        self.critic.eval()
        self.critic_target.eval()

    def random_action(self):
        action = np.random.uniform(-1., 1., self.nb_actions)
        self.a_t = action
        return action

    def observe(self, r_t, s_t1, done):
        if self.is_training:
            self.memory.append(self.s_t, self.a_t, r_t, done)
            self.s_t = s_t1

    def select_action(self, s_t, decay_epsilon=True):
        action = to_numpy(self.actor(to_tensor(np.array([s_t])))).squeeze(0)
        action += self.is_training * max(self.epsilon,
                                         0) * self.random_process.sample()
        action = np.clip(action, -1., 1.)

        if decay_epsilon:
            self.epsilon -= self.depsilon

        self.a_t = action
        return action

    def reset(self, obs):
        self.s_t = obs
        self.random_process.reset_states()

    def load_wts(self, modelfile):
        if os.path.isfile(modelfile + 'model.pth.tar'):
            checkpoint = torch.load(modelfile)
            self.actor.load_state_dict(checkpoint['actor_state_dict'])
            self.critic.load_state_dict(checkpoint['critic_state_dict'])
            self.actor_optim.load_state_dict(checkpoint['actor_optim'])
            self.critic_optim.load_state_dict(checkpoint['critic_optim'])
            return checkpoint['step']
        else:
            return 0

    def save_wts(self, savefile, step):
        saveme = {  #TODO save other stuff too, like epoch etc
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'actor_optim': self.actor_optim.state_dict(),
            'critic_optim': self.critic_optim.state_dict(),
            'step': step
        }
        torch.save(saveme, savefile + 'model.pth.tar')