Example #1
class ActorCriticAgent():
    def __init__(self):
        # if gpu is to be used
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.policy_net = ActorCritic().to(device).double()
        self.target_net = ActorCritic().to(device).double()
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.lr = 1e-5
        # Note that the shared fc layers are included in both optimizers, so
        # they are updated by the actor loss and the critic loss.
        self.optimizer = optim.Adam(
            [
                {"params": self.policy_net.head_a_m.parameters()},
                {"params": self.policy_net.head_a_t.parameters()},
                {"params": self.policy_net.fc.parameters()},
            ],
            lr=self.lr,
        )
        self.optimizer2 = optim.Adam(
            [
                {"params": self.policy_net.head_v.parameters()},
                {"params": self.policy_net.fc.parameters()},
            ],
            lr=self.lr,
        )
        self.memory = ReplayMemory(100000)

    def preprocess(self, state: RobotState):
        return torch.tensor([state.scan]).double()  # 1x2xseq

    def run_AC(self, tensor_state):
        with torch.no_grad():
            self.target_net.eval()
            a_m, a_t, v = self.target_net(tensor_state.to(self.device))
        a_m = a_m.cpu().numpy()[0]  # left, ahead, right
        a_t = a_t.cpu().numpy()[0]  # turn left, stay, right

        return a_m, a_t

    def decode_action(self, a_m, a_t, state, mode):
        if mode == "max_probability":
            a_m = np.argmax(a_m)
            a_t = np.argmax(a_t)
        elif mode == "sample":
            #a_m += 0.01
            a_m /= a_m.sum()
            a_m = np.random.choice(range(3), p=a_m)
            #a_t += 0.01
            a_t /= a_t.sum()
            a_t = np.random.choice(range(3), p=a_t)

        action = Action()
        if a_m == 0:  # left
            action.v_n = -1.0
        elif a_m == 1:  # ahead
            action.v_t = +1.0
        elif a_m == 2:  # right
            action.v_n = +1.0

        if a_t == 0:  # left
            action.angular = +1.0
        elif a_t == 1:  # stay
            action.angular = 0.0
        elif a_t == 2:  # right
            action.angular = -1.0

        if state.detect:
            action.shoot = +1.0
        else:
            action.shoot = 0.0

        return action

    def select_action(self, state, mode):
        tensor_state = self.preprocess(state).to(self.device)
        a_m, a_t = self.run_AC(tensor_state)
        action = self.decode_action(a_m, a_t, state, mode)

        return action

    def push(self, state, next_state, action, reward):
        self.memory.push(state, action, next_state, reward)

    def make_state_map(self, state):
        return torch.cat(state, dim=0).double()  # batchx2xseq

    def sample_memory(self, is_test=False):
        device = self.device
        transitions = self.memory.sample(BATCH_SIZE, is_test)
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
        # detailed explanation). This converts batch-array of Transitions
        # to Transition of batch-arrays.
        batch = Transition(*zip(*transitions))

        #state_batch = torch.cat(batch.state).to(device)
        #next_state_batch = torch.cat(batch.next_state).to(device)
        state_batch = self.make_state_map(batch.state).double()
        next_state_batch = self.make_state_map(batch.next_state).double()
        action_batch = torch.tensor(batch.action).double()
        reward_batch = torch.tensor(batch.reward).double()

        return state_batch, action_batch, reward_batch, next_state_batch

    def optimize_once(self, data):
        state_batch, action_batch, reward_batch, next_state_batch = data
        device = self.device
        state_batch = state_batch.to(device)
        action_batch = action_batch.to(device)
        reward_batch = reward_batch.to(device)
        next_state_batch = next_state_batch.to(device)

        self.policy_net.train()
        state_batch.requires_grad_()  # Variable is deprecated; track gradients on the tensor directly
        a_m, a_t, value_eval = self.policy_net(state_batch)  # batch, 1, 10, 16
        ### Critic ###
        td_error = reward_batch - value_eval
        loss = nn.MSELoss()(value_eval, reward_batch)
        self.optimizer2.zero_grad()
        loss.backward(retain_graph=True)
        self.optimizer2.step()
        ### Actor ###
        #prob = x.gather(1, (action_batch[:,0:1]*32).long()) * y.gather(1, (action_batch[:,1:2]*20).long())
        # Gather with batch x 1 index tensors so the shapes match the 2-D
        # probability outputs (as in test_model below).
        prob_m = a_m.gather(1, action_batch[:, 0:1].long())
        prob_t = a_t.gather(1, action_batch[:, 1:2].long())
        log_prob = torch.log(prob_m * prob_t + 1e-6)
        exp_v = torch.mean(log_prob * td_error.detach())
        loss = -exp_v + F.smooth_l1_loss(value_eval, reward_batch)
        self.optimizer.zero_grad()
        loss.backward()
        # Optionally clamp gradients to stabilise training:
        # for param in self.policy_net.parameters():
        #     if param.grad is not None:
        #         param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        return loss.item()

    def optimize_online(self):
        if len(self.memory) < BATCH_SIZE:
            return
        data = self.sample_memory()
        loss = self.optimize_once(data)
        return loss

    def test_model(self):
        if len(self.memory) < BATCH_SIZE:
            return
        state_batch, action_batch, reward_batch, next_state_batch = self.sample_memory(
            True)
        device = self.device
        state_batch = state_batch.to(device)
        action_batch = action_batch.to(device)
        reward_batch = reward_batch.to(device)
        next_state_batch = next_state_batch.to(device)

        with torch.no_grad():
            self.target_net.eval()
            a_m, a_t, value_eval = self.target_net(
                state_batch)  # batch, 1, 10, 16
            ### Critic ###
            td_error = reward_batch - value_eval
            ### Actor ###
            #prob = x.gather(1, (action_batch[:,0:1]*32).long()) * y.gather(1, (action_batch[:,1:2]*20).long())
            prob_m = a_m.gather(1, action_batch[:, 0:1].long())
            prob_t = a_t.gather(1, action_batch[:, 1:2].long())
            log_prob = torch.log(prob_m * prob_t + 1e-6)
            exp_v = torch.mean(log_prob * td_error.detach())
            loss = -exp_v

        return loss.item()

    def save_model(self, file_path):
        torch.save(self.policy_net.state_dict(), file_path)

    def save_memory(self, file_path):
        torch.save(self.memory, file_path)

    def load_model(self, file_path):
        self.policy_net.load_state_dict(
            torch.load(file_path, map_location=self.device))
        # FIXME: at startup, load the previously obtained parameters directly as experience
        # self.update_target_net()

    def load_memory(self, file_path):
        self.memory = torch.load(file_path)

    def optimize_offline(self, num_epoch):
        def batch_state_map(transitions):
            batch = Transition(*zip(*transitions))
            state_batch = self.make_state_map(batch.state).double()
            next_state_batch = self.make_state_map(batch.next_state).double()
            action_batch = torch.tensor(batch.action).double()
            reward_batch = torch.tensor(batch.reward).double()
            return state_batch, action_batch, reward_batch, next_state_batch

        dataloader = DataLoader(self.memory.main_memory,
                                batch_size=BATCH_SIZE,
                                shuffle=True,
                                collate_fn=batch_state_map,
                                num_workers=0,
                                pin_memory=True)
        device = self.device
        for epoch in range(num_epoch):
            #print("Train epoch: [{}/{}]".format(epoch, num_epoch))
            for data in dataloader:
                loss = self.optimize_once(data)
            #loss = self.test_model()
            #print("Test loss: {}".format(loss))
        return loss

    def decay_LR(self, decay):
        self.lr *= decay
        # Apply the decayed rate to both optimizers in place, preserving the
        # Adam moment estimates.
        for opt in (self.optimizer, self.optimizer2):
            for group in opt.param_groups:
                group["lr"] = self.lr

    def update_target_net(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())
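
For context, here is a minimal sketch of a training loop that exercises the methods above. `env`, `NUM_EPISODES`, and `TARGET_UPDATE` are assumed placeholders that do not appear in the original example, and the environment's `reset`/`step` interface is likewise an assumption.

# Hypothetical usage sketch; env, NUM_EPISODES and TARGET_UPDATE are assumptions.
agent = ActorCriticAgent()
for episode in range(NUM_EPISODES):
    state = env.reset()  # a RobotState
    done = False
    while not done:
        # Run the actor-critic heads and keep the discrete indices so they can
        # be stored in the replay memory alongside the transition.
        a_m, a_t = agent.run_AC(agent.preprocess(state))
        a_m_idx, a_t_idx = int(np.argmax(a_m)), int(np.argmax(a_t))
        action = agent.decode_action(a_m, a_t, state, "max_probability")
        next_state, reward, done = env.step(action)  # assumed env API
        agent.push(agent.preprocess(state), agent.preprocess(next_state),
                   (a_m_idx, a_t_idx), reward)
        agent.optimize_online()
        state = next_state
    if episode % TARGET_UPDATE == 0:
        agent.update_target_net()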
Example #2
# Setup model and data loader
trainer = RL_VAEGAN(config)
trainer.cuda()

# Setup output folders
output_directory = opts.output_path + '/output/' + opts.env_name

checkpoint_directory, result_directory = prepare_sub_folder(output_directory)
print('checkpoint: ', checkpoint_directory)
print(f'attacker {opts.attacker} with epsilon {opts.epsilon_adv}')

# Start training
iterations = 0

trained_model.eval().cuda()
state = env.reset()  #(3,80,80)
state = torch.from_numpy(state).unsqueeze(0).cuda()
episode_length = 1
epoch = 0
actions = deque(maxlen=100)
rewards = []
while True:
    epoch += 1
    episode_length = 1
    is_Terminal = False
    rewards = []
    while not is_Terminal:
        value, logit = trained_model(state)  #(1,3,80,80)
        prob = F.softmax(logit, dim=-1)
        action = prob.multinomial(num_samples=1)[0]
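        # --- Hypothetical continuation; the excerpt above stops here. ---
        # Assuming a gym-style env.step(action) returning (obs, reward, done,
        # info) and a 'max_episode_length' entry in config, the evaluation
        # loop would typically proceed roughly as follows:
        obs, reward, done, _ = env.step(action.item())
        state = torch.from_numpy(obs).unsqueeze(0).cuda()
        rewards.append(reward)
        actions.append(action.item())
        episode_length += 1
        is_Terminal = done or episode_length >= config['max_episode_length']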