Example #1
# Imports assumed by this snippet; game_env, simulator, ActorCritic, SpliterWarpper
# and the module-level TensorBoard `writer` come from the surrounding project.
import random
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


class A3CSpliterPPOEnv(game_env.GameEnv):
    def __init__(self, *args, **kwargs):
        super(A3CSpliterPPOEnv, self).__init__(*args, **kwargs)

        random.seed(time.time())

        self.game_no = 0

        self.out_classes = 9

        self.a3c_model = ActorCritic(5, self.out_classes, 64)
        self.optimizer = optim.SGD(self.a3c_model.parameters(), lr=0.001)

        self.reset()

        self.gamma = 0.99
        self.tau = 1.0
        self.entropy_coef = 0.01
        self.epsilon = 0.2

        self.batch_size = 256
        self.buffer_size = 1000
        self.i = 0

        self.update_steps = 3

        self.nll_loss_fn = nn.NLLLoss()

        self.sw = SpliterWarpper(5, 128)

    def reset(self):
        self.states = []
        self.actions = []
        self.entropies = []
        self.values = []
        self.rewards = []
        self.log_probs = []
        self.raw_log_probs = []
        self.raw_probs = []
        self.predefined_steps = []

    def update(self):
        print("Game %d updating" % self.game_no)
        self.optimizer.step()
        self.a3c_model.zero_grad()
        self.optimizer.zero_grad()

    def ppo_train_actor(self, old_model):
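        # PPO policy update: minimise the negative clipped surrogate objective
        # -E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)], where r_t is the
        # probability ratio between the current policy and `old_model`.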
        self.a3c_model.zero_grad()
        self.optimizer.zero_grad()

        l = 0.0
        R = torch.zeros(1, 1)

        # Discounted returns, accumulated backwards through the episode
        reduced_r = []
        for i in reversed(range(len(self.rewards))):
            R = self.gamma * R + self.rewards[i]
            reduced_r.append(R)

        reduced_r = list(reversed(reduced_r))

        # Sample a random minibatch of timesteps from the episode
        idxs = list(range(len(self.rewards)))
        random.shuffle(idxs)
        idxs = idxs[:self.batch_size]

        total_r = 0.0

        #TODO: turn `for loop` to tensor operations; see the sketch after Example #2
        for i in idxs:
            new_prob, v = self.a3c_model(self.states[i])
            new_prob = F.softmax(new_prob, dim=1)
            old_prob, _ = old_model(self.states[i])
            old_prob = F.softmax(old_prob, dim=1)
            adv = reduced_r[i] - v.data
            onehot_act = torch.zeros(self.out_classes)
            onehot_act[self.actions[i]] = 1

            # Probability ratio r_t = pi_new(a_t|s_t) / pi_old(a_t|s_t)
            ratio = torch.sum(new_prob * onehot_act) / torch.sum(
                old_prob * onehot_act)
            surr = ratio * adv

            # Clipped surrogate term, negated because the optimiser minimises
            l = l - torch.min(
                surr,
                torch.clamp(ratio, 1.0 - self.epsilon, 1.0 + self.epsilon) *
                adv)

        l = l / self.batch_size

        l.backward(retain_graph=True)
        writer.add_scalar("train/policy_loss", l.item())
        self.optimizer.step()

    def teacher_train(self):
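        # Imitation ("teacher") update: class-balanced NLL loss between the
        # policy's log-probabilities and the simulator's predefined actions.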
        #TODO: teacher_loss * (1 - Gini coefficient); see the sketch after this example
        #TODO: entropy loss
        self.a3c_model.zero_grad()
        self.optimizer.zero_grad()
        teacher_loss = 0

        labels = torch.cat(self.predefined_steps)

        # Inverse-frequency class weights so that rare actions are not swamped
        weight = torch.zeros((self.out_classes, ))
        for i in range(self.out_classes):
            weight[i] = torch.sum(labels == i)
            if weight[i] > 0:
                weight[i] = 1. / weight[i]

        nll = nn.NLLLoss(weight=weight)

        log_probs = torch.cat(self.raw_log_probs)

        teacher_loss = nll(log_probs, labels)
        teacher_loss.backward(retain_graph=True)

        writer.add_scalar("train/teacher_loss", teacher_loss.item())

        self.optimizer.step()

    def ppo_train_critic(self):
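        # Critic update: mean squared error between the discounted returns
        # and the value head's estimates.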

        self.a3c_model.zero_grad()
        self.optimizer.zero_grad()

        R = torch.zeros(1, 1)
        l = 0.0

        reduced_r = []
        for i in reversed(range(len(self.rewards))):
            R = self.gamma * R + self.rewards[i]
            reduced_r.append(R)

        reduced_r = list(reversed(reduced_r))

        idxs = list(range(len(self.rewards)))
        random.shuffle(idxs)
        idxs = idxs[:self.batch_size]

        for i in idxs:
            # Squared advantage (return minus predicted value) for each sampled step
            adv = reduced_r[i] - self.a3c_model(self.states[i])[1]
            l = l + adv**2

        l = l / self.batch_size
        l.backward(retain_graph=True)

        writer.add_scalar("train/value_loss", l.item())
        self.optimizer.step()

    def train(self):
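        # Episode-end update. In this variant only the imitation (teacher) loss is
        # applied; old_model is cloned but the PPO actor/critic updates are not
        # invoked here.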
        #just make it simple

        old_model = self.a3c_model.clone()

        #TODO:move it to the last
        self.teacher_train()

        print("reward %f" % sum(self.rewards))

        if self.game_no % 100 == 0:
            torch.save(self.a3c_model.state_dict(),
                       "./tmp/model_%d" % self.game_no)

        writer.add_scalar("train/rewards", sum(self.rewards))
        writer.add_scalar("train/values",
                          sum(self.values).item() / len(self.values))
        writer.add_scalar("train/entropy",
                          sum(self.entropies).item() / len(self.entropies))

    def _generator_run(self, input_):
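        # Rollout: play one simulator episode, recording states, actions, rewards,
        # log-probabilities and entropies, then train at episode end. The method is
        # a generator and yields control back to the caller once per simulator tick.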
        self.game_no = self.game_no + 1
        self.init_fn(input_)

        self.engine = simulator.Simulator(feature_name='unit_test1',
                                          actionspace_name='lattice1',
                                          canvas=self.canvas)

        self.reset()
        self.sw.reset()

        while self.engine.get_time() < 200:
            self.i = self.i + 1

            #print(dire_predefine_step)
            dire_state = self.engine.get_state_tup("Dire", 0)

            dire_predefine_step = self.engine.predefined_step("Dire", 0)
            predefine_move = torch.LongTensor([dire_predefine_step[1]])

            is_end = dire_state[2]
            if is_end:
                break

            self.predefined_steps.append(predefine_move)
            state_now = dire_state[0]
            self.states.append(state_now)
            action_out, value_out = self.a3c_model(state_now)

            self.sw.step(state_now)

            prob = F.softmax(action_out, dim=1)
            self.raw_probs.append(prob)
            log_prob = F.log_softmax(action_out, dim=1)
            self.raw_log_probs.append(log_prob)

            entropy = -(log_prob * prob).sum(1, keepdim=True)
            self.entropies.append(entropy)

            # Greedy action selection; the commented line below would sample instead.
            #action = prob.multinomial(num_samples=1).data
            action = torch.argmax(log_prob, 1).data.view(-1, 1)
            log_prob = log_prob.gather(1, Variable(action))
            self.actions.append(action)

            self.engine.set_order("Dire", 0, (1, action))

            self.engine.loop()

            reward = dire_state[1]
            self.rewards.append(reward)
            self.values.append(value_out)
            self.log_probs.append(log_prob)

            yield

        self.train()
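
The first example leaves two TODOs in teacher_train: scaling the imitation loss by
(1 - Gini coefficient) and adding an entropy term. Below is a minimal sketch of how the
Gini scaling could look, assuming the coefficient is computed over the per-class action
counts already built for the class weights; the names gini and gini_scaled_teacher_loss
and the exact choice of distribution are illustrative assumptions, not part of the
original code.

import torch

def gini(counts):
    # Gini coefficient of a discrete distribution: 0 for a uniform
    # distribution, approaching 1 when a single class dominates.
    counts = counts.float()
    total = counts.sum()
    if total == 0:
        return torch.tensor(0.0)
    diffs = torch.abs(counts.unsqueeze(0) - counts.unsqueeze(1))
    return diffs.sum() / (2 * counts.numel() * total)

def gini_scaled_teacher_loss(teacher_loss, class_counts):
    # Down-weight the imitation loss when the demonstrated actions are
    # heavily concentrated on a few classes.
    return teacher_loss * (1.0 - gini(class_counts))
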
Example #2
# Imports assumed by this snippet; game_env, simulator and ActorCritic come from the
# surrounding project. SummaryWriter is assumed here to come from
# torch.utils.tensorboard (tensorboardX in older setups).
import os
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter


class DoubleA3CPPOEnv(game_env.GameEnv):
    def __init__(self, *args, **kwargs):
        super(DoubleA3CPPOEnv, self).__init__(*args, **kwargs)

        random.seed(os.getpid())
        torch.random.manual_seed(os.getpid())

        self.game_no = 0

        self.out_classes = 9

        self.a3c_model = ActorCritic(5, self.out_classes, 64)
        self.optimizer = optim.SGD(self.a3c_model.parameters(), lr=0.1)

        self.reset()

        self.gamma = 0.99
        self.tau = 1.0
        self.entropy_coef = 0.01
        self.epsilon = 0.2

        self.batch_size = 256
        self.buffer_size = 1000
        self.i = 0

        self.update_steps = 3

        self.nll_loss_fn = nn.NLLLoss()

        self.writer = SummaryWriter(comment='_%d' % os.getpid())

        self.rank = -1

    def set_rank(self, rank):
        self.rank = rank

    def set_model(self, model):
        self.a3c_model = model
        self.optimizer = optim.SGD(self.a3c_model.parameters(), lr=0.1)

    def get_model(self):
        return self.a3c_model

    def reset(self):
        self.states = []
        self.actions = []
        self.entropies = []
        self.values = []
        self.rewards = []
        self.log_probs = []
        self.raw_log_probs = []
        self.raw_probs = []
        self.predefined_steps = []

    def update(self):
        print("Game %d updating" % self.game_no)
        self.optimizer.step()
        self.a3c_model.zero_grad()
        self.optimizer.zero_grad()

    def ppo_train_actor(self, old_model):
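        # PPO policy update against a frozen copy of the policy (`old_model`),
        # using the clipped surrogate objective.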
        self.a3c_model.zero_grad()
        self.optimizer.zero_grad()

        l = 0.0
        R = torch.zeros(1, 1)

        reduced_r = []
        for i in reversed(range(len(self.rewards))):
            R = self.gamma * R + self.rewards[i]
            reduced_r.append(R)

        reduced_r = list(reversed(reduced_r))

        idxs = list(range(len(self.rewards)))
        random.shuffle(idxs)
        idxs = idxs[:self.batch_size]

        #TODO: turn `for loop` to tensor operations; see the sketch after this example
        for i in idxs:
            new_prob, v = self.a3c_model(self.states[i])
            new_prob = F.softmax(new_prob, dim=1)
            old_prob, _ = old_model(self.states[i])
            old_prob = F.softmax(old_prob, dim=1)
            adv = reduced_r[i] - v.data
            onehot_act = torch.zeros(self.out_classes)
            onehot_act[self.actions[i]] = 1

            # Probability ratio r_t = pi_new(a_t|s_t) / pi_old(a_t|s_t)
            ratio = torch.sum(new_prob * onehot_act) / torch.sum(
                old_prob * onehot_act)
            surr = ratio * adv

            # Clipped surrogate term, negated because the optimiser minimises
            l = l - torch.min(
                surr,
                torch.clamp(ratio, 1.0 - self.epsilon, 1.0 + self.epsilon) *
                adv)

        l = l / self.batch_size

        l.backward(retain_graph=True)
        # `l` is already averaged over the minibatch above, so log it directly
        self.writer.add_scalar("train/policy_loss", l.item())
        self.optimizer.step()

    def ppo_train_critic(self):
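        # Critic update: squared error between discounted returns and value estimates.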

        self.a3c_model.zero_grad()
        self.optimizer.zero_grad()

        R = torch.zeros(1, 1)
        l = 0.0

        reduced_r = []
        for i in reversed(range(len(self.rewards))):
            R = self.gamma * R + self.rewards[i]
            reduced_r.append(R)

        reduced_r = list(reversed(reduced_r))

        idxs = list(range(len(self.rewards)))
        random.shuffle(idxs)
        idxs = idxs[:self.batch_size]

        for i in idxs:
            # Squared advantage (return minus predicted value) for each sampled step
            adv = reduced_r[i] - self.a3c_model(self.states[i])[1]
            l = l + adv**2

        l = l / self.batch_size
        l.backward(retain_graph=True)

        # `l` is already averaged over the minibatch above, so log it directly
        self.writer.add_scalar("train/value_loss", l.item())
        self.optimizer.step()

    def train(self):
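        # Episode-end update: several PPO actor passes against a frozen copy of the
        # policy, then several critic passes, followed by episode statistics logging.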
        #just make it simple
        self.a3c_model.train()

        old_model = self.a3c_model.clone()

        for _ in range(self.update_steps):
            self.ppo_train_actor(old_model)

        for _ in range(self.update_steps):
            self.ppo_train_critic()

        print("reward %f" % sum(self.rewards))

        self.writer.add_scalar("train/rewards", sum(self.rewards))
        self.writer.add_scalar("train/values",
                               sum(self.values).item() / len(self.values))
        self.writer.add_scalar(
            "train/entropy",
            sum(self.entropies).item() / len(self.entropies))

        # Collected here but not yet used for a loss
        acts = torch.cat(self.raw_log_probs)
        pd = torch.tensor(self.predefined_steps)

    def _generator_run(self, input_):
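        # Rollout: one simulator episode per call. Non-zero ranks follow the
        # predefined steps and are the only ones that train at episode end;
        # rank 0 acts greedily with the current policy.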
        self.game_no = self.game_no + 1
        self.init_fn(input_)

        self.engine = simulator.Simulator(feature_name='unit_test1',
                                          actionspace_name='lattice1',
                                          canvas=self.canvas)

        self.reset()

        while self.engine.get_time() < 200:
            self.i = self.i + 1

            #print(dire_predefine_step)
            dire_state = self.engine.get_state_tup("Dire", 0)

            dire_predefine_step = self.engine.predefined_step("Dire", 0)
            predefine_move = torch.LongTensor([dire_predefine_step[1]])

            is_end = dire_state[2]
            if is_end:
                break

            self.predefined_steps.append(predefine_move)
            state_now = dire_state[0]
            self.states.append(state_now)
            action_out, value_out = self.a3c_model(state_now)

            prob = F.softmax(action_out, dim=1)
            self.raw_probs.append(prob)
            log_prob = F.log_softmax(action_out, dim=1)
            self.raw_log_probs.append(log_prob)

            entropy = -(log_prob * prob).sum(1, keepdim=True)
            self.entropies.append(entropy)

            if self.rank != 0:
                # Non-zero ranks follow the simulator's predefined (scripted) action
                action = predefine_move.view(1, -1).data
            else:
                # Rank 0 acts greedily; the commented line below would sample instead.
                #action = prob.multinomial(num_samples=1).data
                action = torch.argmax(log_prob, 1).data.view(-1, 1)
            self.actions.append(action)
            log_prob = log_prob.gather(1, Variable(action))

            self.engine.set_order("Dire", 0, (1, action))

            self.engine.loop()

            reward = dire_state[1]
            self.rewards.append(reward)
            self.values.append(value_out)
            self.log_probs.append(log_prob)

            yield
        print("rank %d os.pid %d" % (self.rank, os.getpid()))
        # Only the non-zero ranks (those following the predefined steps) update the model
        if self.rank != 0:
            self.train()
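
Both examples carry a TODO about replacing the per-timestep Python loop in
ppo_train_actor with tensor operations. Below is a minimal sketch of a batched
clipped-surrogate loss under two assumptions that the snippets do not guarantee: that
the recorded states can be stacked into a single batch tensor, and that ActorCritic's
forward pass accepts batched input. The function name and argument layout are
illustrative, not part of the original code.

import torch
import torch.nn.functional as F

def clipped_surrogate_loss(new_logits, old_logits, actions, returns, values,
                           epsilon=0.2):
    # new_logits, old_logits: (N, num_actions) action scores from the new/old policy
    # actions: (N, 1) long tensor of taken actions
    # returns: (N, 1) discounted returns; values: (N, 1) value estimates
    new_prob = F.softmax(new_logits, dim=1).gather(1, actions)  # pi_new(a_t | s_t)
    old_prob = F.softmax(old_logits, dim=1).gather(1, actions)  # pi_old(a_t | s_t)
    ratio = new_prob / old_prob
    adv = returns - values.detach()
    surr = ratio * adv
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * adv
    # Negate because the optimiser minimises; average over the sampled batch
    return -torch.min(surr, clipped).mean()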