예제 #1
0
class ActorCriticAgent():
    """Actor-critic agent with two discrete policy heads and a value head.

    The agent owns a trainable ``policy_net`` and a frozen ``target_net``
    (both ``ActorCritic`` instances in double precision), a replay memory,
    and two Adam optimizers:

    * ``optimizer``  updates the two actor heads plus the shared ``fc`` trunk;
    * ``optimizer2`` updates the value head plus the shared ``fc`` trunk.

    Actions are selected with ``target_net``; call ``update_target_net`` to
    copy the trained weights over.
    """

    def __init__(self):
        """Create the networks, optimizers and an empty replay memory."""
        # Use the GPU when one is available.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.policy_net = ActorCritic().to(device).double()
        self.target_net = ActorCritic().to(device).double()
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.lr = 1e-5
        # Actor optimizer: movement head, turn head and the shared trunk.
        self.optimizer = optim.Adam([
            {
                "params": self.policy_net.head_a_m.parameters()
            },
            {
                "params": self.policy_net.head_a_t.parameters()
            },
            {
                "params": self.policy_net.fc.parameters()
            },
        ],
                                    lr=self.lr)
        # Critic optimizer: value head and the shared trunk.
        self.optimizer2 = optim.Adam([
            {
                "params": self.policy_net.head_v.parameters()
            },
            {
                "params": self.policy_net.fc.parameters()
            },
        ],
                                     lr=self.lr)
        self.memory = ReplayMemory(100000)

    def preprocess(self, state: "RobotState"):
        """Wrap ``state.scan`` in a leading batch dimension as a double tensor.

        The original author annotated the result as ``1 x 2 x seq``; that
        depends on the layout of ``state.scan`` -- TODO confirm.
        """
        return torch.tensor([state.scan]).double()

    def run_AC(self, tensor_state):
        """Run ``target_net`` without gradients on a batch of states.

        Returns:
            Tuple ``(a_m, a_t)`` of numpy arrays holding the movement and
            turn head outputs for the first batch element only. The value
            output is discarded.
        """
        with torch.no_grad():
            self.target_net.eval()
            a_m, a_t, _ = self.target_net(tensor_state.to(self.device))
        a_m = a_m.cpu().numpy()[0]  # left, ahead, right
        a_t = a_t.cpu().numpy()[0]  # turn left, stay, right

        return a_m, a_t

    def decode_action(self, a_m, a_t, state, mode):
        """Turn the two head outputs into an ``Action``.

        Args:
            a_m: scores for (left, ahead, right) movement.
            a_t: scores for (turn left, stay, turn right).
            state: object whose ``detect`` attribute decides whether to shoot.
            mode: ``"max_probability"`` picks the arg-max of each head,
                ``"sample"`` draws from each head after renormalizing it.

        Raises:
            ValueError: if ``mode`` is not one of the two supported values.
        """
        if mode == "max_probability":
            move_idx = np.argmax(a_m)
            turn_idx = np.argmax(a_t)
        elif mode == "sample":
            # Renormalize out of place so the caller's arrays are not mutated;
            # np.random.choice requires the probabilities to sum to 1.
            move_p = a_m / a_m.sum()
            turn_p = a_t / a_t.sum()
            move_idx = np.random.choice(len(move_p), p=move_p)
            turn_idx = np.random.choice(len(turn_p), p=turn_p)
        else:
            raise ValueError(
                "unknown mode {!r}; expected 'max_probability' or 'sample'".
                format(mode))

        action = Action()
        if move_idx == 0:  # left
            action.v_n = -1.0
        elif move_idx == 1:  # ahead
            action.v_t = +1.0
        elif move_idx == 2:  # right
            action.v_n = +1.0

        if turn_idx == 0:  # left
            action.angular = +1.0
        elif turn_idx == 1:  # stay
            action.angular = 0.0
        elif turn_idx == 2:  # right
            action.angular = -1.0

        # Shooting is rule-based, not learned: fire whenever a target is detected.
        action.shoot = +1.0 if state.detect else 0.0

        return action

    def select_action(self, state, mode):
        """Preprocess ``state``, query the target network and decode an action."""
        tensor_state = self.preprocess(state).to(self.device)
        a_m, a_t = self.run_AC(tensor_state)
        return self.decode_action(a_m, a_t, state, mode)

    def push(self, state, next_state, action, reward):
        """Store one transition; note the memory expects (state, action, next_state, reward)."""
        self.memory.push(state, action, next_state, reward)

    def make_state_map(self, state):
        """Concatenate a sequence of state tensors along dim 0 as doubles."""
        return torch.cat(state, dim=0).double()

    def _collate_transitions(self, transitions):
        """Convert a list of transitions into batched double tensors.

        Transposes the batch (list of ``Transition`` -> ``Transition`` of
        tuples, see https://stackoverflow.com/a/19343/3343043) and returns
        ``(state_batch, action_batch, reward_batch, next_state_batch)``.
        """
        batch = Transition(*zip(*transitions))
        state_batch = self.make_state_map(batch.state)
        next_state_batch = self.make_state_map(batch.next_state)
        action_batch = torch.tensor(batch.action).double()
        reward_batch = torch.tensor(batch.reward).double()
        return state_batch, action_batch, reward_batch, next_state_batch

    def sample_memory(self, is_test=False):
        """Sample ``BATCH_SIZE`` transitions from memory as batched tensors.

        ``is_test`` is forwarded unchanged to ``self.memory.sample``.
        """
        transitions = self.memory.sample(BATCH_SIZE, is_test)
        return self._collate_transitions(transitions)

    def optimize_once(self, data):
        """Run one critic update followed by one actor update.

        Args:
            data: tuple ``(state_batch, action_batch, reward_batch,
                next_state_batch)``; the next states are not used because the
                critic regresses directly onto ``reward_batch``.

        Returns:
            The actor loss of this step as a Python float.
        """
        state_batch, action_batch, reward_batch, _ = data
        device = self.device
        state_batch = state_batch.to(device)
        action_batch = action_batch.to(device)
        reward_batch = reward_batch.to(device)

        self.policy_net.train()

        ### Critic ###
        _, _, value_eval = self.policy_net(state_batch)
        # NOTE(review): if reward_batch is (batch,) while value_eval is
        # (batch, 1) this subtraction broadcasts to (batch, batch) -- confirm
        # the shapes produced by the replay memory.
        td_error = (reward_batch - value_eval).detach()
        critic_loss = nn.MSELoss()(value_eval, reward_batch)
        self.optimizer2.zero_grad()
        critic_loss.backward()
        self.optimizer2.step()

        ### Actor ###
        # optimizer2.step() modified the shared weights in place, so the graph
        # built above can no longer be back-propagated through. Run a fresh
        # forward pass for the actor loss instead of reusing it.
        a_m, a_t, value_eval = self.policy_net(state_batch)
        # gather needs an index with the same number of dims as its input,
        # hence the 0:1 / 1:2 slices that keep the column dimension.
        prob_m = a_m.gather(1, action_batch[:, 0:1].long())
        prob_t = a_t.gather(1, action_batch[:, 1:2].long())
        log_prob = torch.log(prob_m * prob_t + 1e-6)  # epsilon avoids log(0)
        exp_v = torch.mean(log_prob * td_error)
        loss = -exp_v + F.smooth_l1_loss(value_eval, reward_batch)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def optimize_online(self):
        """Do one optimization step on a sampled batch.

        Returns ``None`` while the memory holds fewer than ``BATCH_SIZE``
        transitions, otherwise the actor loss.
        """
        if len(self.memory) < BATCH_SIZE:
            return
        data = self.sample_memory()
        return self.optimize_once(data)

    def test_model(self):
        """Evaluate the policy-gradient loss of ``target_net`` on a test batch.

        Returns ``None`` while the memory holds fewer than ``BATCH_SIZE``
        transitions, otherwise the loss as a Python float. No parameters are
        updated.
        """
        if len(self.memory) < BATCH_SIZE:
            return
        state_batch, action_batch, reward_batch, _ = self.sample_memory(True)
        device = self.device
        state_batch = state_batch.to(device)
        action_batch = action_batch.to(device)
        reward_batch = reward_batch.to(device)

        with torch.no_grad():
            self.target_net.eval()
            a_m, a_t, value_eval = self.target_net(state_batch)
            ### Critic ###
            td_error = reward_batch - value_eval
            ### Actor ###
            prob_m = a_m.gather(1, action_batch[:, 0:1].long())
            prob_t = a_t.gather(1, action_batch[:, 1:2].long())
            log_prob = torch.log(prob_m * prob_t + 1e-6)
            exp_v = torch.mean(log_prob * td_error.detach())
            loss = -exp_v

        return loss.item()

    def save_model(self, file_path):
        """Save the ``policy_net`` weights to ``file_path``."""
        torch.save(self.policy_net.state_dict(), file_path)

    def save_memory(self, file_path):
        """Serialize the whole replay memory object to ``file_path``."""
        torch.save(self.memory, file_path)

    def load_model(self, file_path):
        """Load ``policy_net`` weights from ``file_path`` onto this agent's device."""
        self.policy_net.load_state_dict(
            torch.load(file_path, map_location=self.device))
        # FIXME: load the previously trained parameters directly at start-up
        # so they act as prior experience (target_net is NOT synced here).
        # self.update_target_net()

    def load_memory(self, file_path):
        """Replace the replay memory with the object stored in ``file_path``.

        NOTE(review): this unpickles an arbitrary Python object; only load
        files from trusted sources. Recent PyTorch releases may also require
        ``weights_only=False`` for non-tensor objects -- confirm against the
        installed version.
        """
        self.memory = torch.load(file_path)

    def optimize_offline(self, num_epoch):
        """Train for ``num_epoch`` passes over the stored transitions.

        Returns:
            The loss of the last optimization step, or ``None`` when no step
            was run (``num_epoch == 0`` or an empty memory).
        """
        dataloader = DataLoader(self.memory.main_memory,
                                batch_size=BATCH_SIZE,
                                shuffle=True,
                                collate_fn=self._collate_transitions,
                                num_workers=0,
                                # Pinning only helps host-to-GPU copies.
                                pin_memory=self.device.type == "cuda")
        loss = None
        for _ in range(num_epoch):
            for data in dataloader:
                loss = self.optimize_once(data)
        return loss

    def decay_LR(self, decay):
        """Multiply the learning rate of both optimizers by ``decay``.

        The optimizers are updated in place so their Adam moment estimates
        are preserved.
        """
        self.lr *= decay
        for optimizer in (self.optimizer, self.optimizer2):
            for group in optimizer.param_groups:
                group["lr"] = self.lr

    def update_target_net(self):
        """Copy the current ``policy_net`` weights into ``target_net``."""
        self.target_net.load_state_dict(self.policy_net.state_dict())
예제 #2
0
파일: train.py 프로젝트: Humoon/RL-VAEGAN
def train(rank, args, shared_model, counter, lock, optimizer=None):
    """Run one A3C worker: roll out full episodes and push gradients to ``shared_model``.

    Args:
        rank: worker index; added to ``args.seed`` to seed torch and the env.
        args: namespace read for ``seed``, ``env_name``, ``lr``,
            ``max_episode_length``, ``gamma``, ``tau``, ``entropy_coef``,
            ``value_loss_coef`` and ``max_grad_norm``.
        shared_model: model whose weights are copied into the local model
            before every rollout and which receives the local gradients.
        counter: shared object whose ``value`` is incremented once per
            environment step while holding ``lock``.
        lock: context manager guarding ``counter``.
        optimizer: optimizer over ``shared_model``; an Adam optimizer with
            ``args.lr`` is created when omitted.

    Every 1000 epochs the local weights and the per-episode reward / loss
    histories are written below ``outputs/<env_name>``.
    """
    print('Train with A3C')
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name, args)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], env.action_space)

    if optimizer is None:
        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    model.train()
    output_directory = 'outputs/' + args.env_name
    checkpoint_directory, result_directory = prepare_sub_folder(
        output_directory)
    print(f'checkpoint directory {checkpoint_directory}')
    time.sleep(10)

    state = torch.from_numpy(env.reset())
    episode_length = 0
    total_step = 0

    # Histories across all episodes played by this worker.
    rewards_ep = []
    policy_loss_ep = []
    value_loss_ep = []

    for epoch in range(100000000):
        # Start every rollout from the latest shared weights.
        model.load_state_dict(shared_model.state_dict())

        values, log_probs, rewards, entropies = [], [], [], []

        # Roll out until the episode ends (or hits the length cap).
        while True:
            episode_length += 1
            total_step += 1

            value, logit = model(state.unsqueeze(0))
            prob = F.softmax(logit, dim=-1)
            all_log_probs = F.log_softmax(logit, dim=-1)
            entropies.append(-(all_log_probs * prob).sum(1, keepdim=True))

            action = prob.multinomial(num_samples=1).detach()
            chosen_log_prob = all_log_probs.gather(1, action)

            state, reward, done, _ = env.step(action.numpy())

            done = done or episode_length >= args.max_episode_length
            reward = max(min(reward, 1), -1)  # clip reward to [-1, 1]

            with lock:
                counter.value += 1

            if done:
                # The current reward has not been appended yet, so add it here.
                print(
                    f'epoch {epoch} - steps {total_step} - total rewards {np.sum(rewards) + reward}'
                )
                total_step = 1
                episode_length = 0
                state = env.reset()

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(chosen_log_prob)
            rewards.append(reward)

            if done:
                rewards_ep.append(np.sum(rewards))
                break

        # Bootstrap value for the state after the rollout. The loop above only
        # exits when ``done`` is true, so the bootstrap branch is not taken.
        R = torch.zeros(1, 1)
        if not done:
            value, _ = model(state.unsqueeze(0))
            R = value.detach()
        values.append(R)

        policy_loss = 0
        value_loss = 0
        gae = torch.zeros(1, 1)
        for i in range(len(rewards) - 1, -1, -1):
            # Discounted return and squared-advantage value loss.
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * values[i + 1] - values[i]
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = (policy_loss - log_probs[i] * gae.detach() -
                           args.entropy_coef * entropies[i])

        optimizer.zero_grad()

        policy_loss_ep.append(policy_loss.detach().numpy()[0, 0])
        value_loss_ep.append(value_loss.detach().numpy()[0, 0])

        total_loss = policy_loss + args.value_loss_coef * value_loss
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        # Hand the local gradients to the shared model, then step on it.
        ensure_shared_grads(model, shared_model)
        optimizer.step()

        if epoch % 1000 == 0:
            torch.save({'state_dict': model.state_dict()},
                       checkpoint_directory + '/' + str(epoch) + ".pth.tar")
            histories = (
                ('_rewards.pkl', rewards_ep),
                ('_policy_loss.pkl', policy_loss_ep),
                ('_value_loss.pkl', value_loss_ep),
            )
            for suffix, history in histories:
                with open(result_directory + '/' + str(epoch) + suffix,
                          'wb') as f:
                    pickle.dump(history, f)

        # episode_length was reset to 0 when the rollout ended, so this guard
        # does not fire in practice; kept as in the original.
        if episode_length >= 10000000:
            break

    torch.save({
        'state_dict': model.state_dict(),
    }, checkpoint_directory + '/Last' + ".pth.tar")