Example #1
import copy
import pickle

import torch
from torch.nn import MSELoss
from torch.optim import Adam

# NetWork, ReplayBuffer, Exploration, concatenation, count_parameters and
# device are assumed to be provided by the project's own modules (their
# imports are not shown in the original listing).

class Agent:
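    """Recurrent DQN-style agent with optional uncertainty-driven exploration
    and a state-difference ("avoid similar states") bonus.

    The online, target and placeholder networks share the NetWork architecture;
    separate Adam optimizers update the value parameters, the exploration
    (uncertainty) head and the state-difference head.
    """
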
    def __init__(self,
                 exploration='epsilonGreedy',
                 memory=10000,
                 discount=0.99,
                 uncertainty=True,
                 uncertainty_weight=1,
                 update_every=200,
                 double=True,
                 use_distribution=True,
                 reward_normalization=False,
                 encoder=None,
                 hidden_size=40,
                 state_difference=True,
                 state_difference_weight=1,
                 **kwargs) -> None:
        self.uncertainty = uncertainty
        self.hidden_size = hidden_size
        self.network = NetWork(self.hidden_size).to(device)
        self.createEncoder(encoder)
        self.network.hasEncoder = self.hasEncoder
        print("Number of parameters in network:",
              count_parameters(self.network))
        self.criterion = MSELoss()
        self.memory = ReplayBuffer(int(memory))
        self.remember = self.memory.remember()
        self.exploration = Exploration()
        if exploration == 'greedy':
            self.explore = self.exploration.greedy
        elif exploration == 'epsilonGreedy':
            self.explore = self.exploration.epsilonGreedy
        elif exploration == 'softmax':
            self.explore = self.exploration.softmax
        elif exploration == 'epsintosoftmax':
            self.explore = self.exploration.epsintosoftmax
        self.target_network = NetWork(self.hidden_size).to(device)
        self.target_network.hasEncoder = self.hasEncoder
        self.placeholder_network = NetWork(self.hidden_size).to(device)
        self.placeholder_network.hasEncoder = self.hasEncoder
        self.gamma, self.f = discount, 0
        self.update_every, self.double, self.use_distribution = update_every, double, use_distribution
        self.counter = 0
        self.reward_normalization = reward_normalization
        self.state_difference = state_difference
        self.true_state_trace = None
        self.uncertainty_weight = uncertainty_weight
        self.state_difference_weight = state_difference_weight
        if encoder is not None:
            self.optimizer_value = Adam(
                list(self.network.fromEncoder.parameters()) +
                list(self.network.lstm.parameters()) +
                list(self.network.linear.parameters()),
                lr=1e-4,
                weight_decay=1e-5)
        else:
            self.optimizer_value = Adam(list(self.network.color.parameters()) +
                                        list(self.network.conv1.parameters()) +
                                        list(self.network.lstm.parameters()) +
                                        list(self.network.linear.parameters()),
                                        lr=1e-4,
                                        weight_decay=1e-5)
        if self.uncertainty:
            self.optimizer_exploration = Adam(list(
                self.network.exploration_network.parameters()),
                                              lr=1e-4,
                                              weight_decay=1e-5)
        if self.state_difference:
            self.optimizer_state_avoidance = Adam(list(
                self.network.state_difference_network.parameters()),
                                                  lr=1e-4,
                                                  weight_decay=1e-5)
        self.onpolicy = True

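    # Store a batch of transitions in the replay buffer. When state_difference
    # is enabled the state traces (before_trace, state_diff_label) are stored
    # as well; otherwise the extra fields are dropped (args[:9]).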
    def rememberMulti(self, *args):
        if self.state_difference:
            [
                self.remember(obs_old.cpu(), act, obs.cpu(), rew,
                              h0.detach().cpu(),
                              c0.detach().cpu(),
                              hn.detach().cpu(),
                              cn.detach().cpu(), int(not done),
                              before_trace.detach().cpu(),
                              state_diff_label.detach().cpu())
                for obs_old, act, obs, rew, h0, c0, hn, cn, done, before_trace,
                state_diff_label in zip(*args)
            ]
        else:
            args = args[:9]
            [
                self.remember(obs_old.cpu(), act, obs.cpu(), rew,
                              h0.detach().cpu(),
                              c0.detach().cpu(),
                              hn.detach().cpu(),
                              cn.detach().cpu(), int(not done))
                for obs_old, act, obs, rew, h0, c0, hn, cn, done in zip(*args)
            ]

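    # Single-environment action selection: run one observation through the
    # network and return the chosen action together with the LSTM hidden
    # states before and after the step. (Unlike chooseMulti and TD_learn, the
    # network output is unpacked into two values here.)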
    def choose(self, pixels, hn, cn):
        self.network.hn, self.network.cn = hn, cn
        vals, uncertainty = self.network(pixels)
        vals = vals.reshape(15)
        return self.explore(
            vals,
            uncertainty), pixels, hn, cn, self.network.hn, self.network.cn

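    # Batched action selection. When state_difference is enabled, an
    # exponential moving average of the encoded state (decayed by lambda_decay
    # and reset where an episode is done) is maintained and passed to
    # avoid_similar_state to obtain the avoidance bonus. Off-policy phases act
    # epsilon-greedily on the combined values; on-policy phases act greedily
    # on the raw values.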
    def chooseMulti(self,
                    pixels,
                    hn,
                    cn,
                    lambda_decay=0.95,
                    avoid_trace=0,
                    done=None):
        self.network.hn, self.network.cn = concatenation(
            hn, 1).to(device), concatenation(cn, 1).to(device)
        vals, uncertainties, true_state = self.network(
            concatenation(pixels, 0).to(device))
        before_trace = self.true_state_trace
        if self.state_difference and done is not None:
            done = 1 - torch.tensor(list(done)).type(torch.float)
            if self.true_state_trace is None:
                self.true_state_trace = true_state.detach() * (1 -
                                                               lambda_decay)
            else:
                self.true_state_trace = (
                    (done.to(device) *
                     (self.true_state_trace.transpose(0, 2))).transpose(0, 2))
                self.true_state_trace = self.true_state_trace * lambda_decay + true_state * (
                    1 - lambda_decay)
            avoid_trace = self.network.avoid_similar_state(
                self.true_state_trace)[0]

        if not self.onpolicy:
            self.explore = self.exploration.epsilonGreedy
            vals = self.convert_values(vals, uncertainties, avoid_trace)
            return [
                self.explore(val.reshape(15)) for val in torch.split(vals, 1)
            ], pixels, hn, cn, torch.split(
                self.network.hn, 1,
                dim=1), torch.split(self.network.cn, 1,
                                    dim=1), before_trace, self.true_state_trace
        else:
            self.explore = self.exploration.greedy
            return [
                (val.reshape(15)).detach().cpu().numpy().argmax()
                for val in torch.split(vals, 1)
            ], pixels, hn, cn, torch.split(
                self.network.hn, 1,
                dim=1), torch.split(self.network.cn, 1,
                                    dim=1), before_trace, self.true_state_trace

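    # Training schedule: roughly the last 2000 of every 50000 calls run purely
    # on-policy with no updates; otherwise the target network is refreshed
    # every update_every steps and one TD update is performed per call.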
    def learn(self):
        self.f += 1
        if self.f % 50000 > 48000:
            self.onpolicy = True
            return
        self.onpolicy = False
        if self.f % self.update_every == 0:
            self.update_target_network()
        if self.f > self.update_every:
            self.TD_learn()

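    # One TD update on a batch of 256 transitions. Separate losses are
    # backpropagated for the uncertainty head (TD-style target built from the
    # absolute value error) and for the state-difference head (regressed onto
    # the distance between the stored state traces) before the value loss
    # itself. Note: with double=True both action selection and evaluation use
    # the target network, so that branch reduces to the plain max below.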
    def TD_learn(self):
        if self.state_difference:
            obs, action, obs_next, reward, h0, c0, hn, sn, done, before_trace, after_trace = self.memory.sample_distribution(
                256) if self.use_distribution else self.memory.sample(256)
        else:
            obs, action, obs_next, reward, h0, c0, hn, sn, done = self.memory.sample_distribution(
                256) if self.use_distribution else self.memory.sample(256)
        self.network.hn, self.network.cn, self.target_network.hn, self.target_network.cn = hn, sn, hn, sn
        vals_target, uncertainties_target, _ = self.target_network(obs_next)
        if self.double:
            v_s_next = torch.gather(vals_target, 1,
                                    torch.argmax(vals_target,
                                                 1).view(-1, 1)).squeeze(1)
        else:
            v_s_next, _ = torch.max(vals_target, 1)

        if self.uncertainty:
            if self.double:
                uncer_next = torch.gather(
                    uncertainties_target, 1,
                    torch.argmax(uncertainties_target, 1).view(-1,
                                                               1)).squeeze(1)
            else:
                uncer_next, _ = torch.max(uncertainties_target, 1)

        self.network.hn, self.network.cn = h0, c0
        vals, uncertainties, _ = self.network(obs)
        vs = torch.gather(vals, 1, action)
        td = (reward +
              self.gamma * v_s_next * done.type(torch.float)).detach().view(
                  -1, 1)

        if self.uncertainty:
            estimate_uncertainties = torch.gather(uncertainties, 1, action)
            reward_uncertainty = (abs(vs - td)).detach().view(-1)
            td_uncertainty = (reward_uncertainty + self.gamma * uncer_next *
                              done.type(torch.float)).detach().view(-1, 1)
            loss_uncertainty = self.criterion(estimate_uncertainties,
                                              td_uncertainty)
            loss_uncertainty.backward(retain_graph=True)
            self.optimizer_exploration.step()
            self.optimizer_exploration.zero_grad()
        if self.state_difference:
            estimate_state_difference = (
                torch.gather(self.network.avoid_similar_state(before_trace), 1,
                             action).view(-1) *
                done.type(torch.float)).view(-1)
            reward_state_difference = (
                torch.sum((before_trace - after_trace) ** 2,
                          dim=1).sqrt().view(-1) *
                done.type(torch.float)).detach()
            loss_state_avoidance = self.criterion(estimate_state_difference,
                                                  reward_state_difference)
            loss_state_avoidance.backward(retain_graph=True)
            self.optimizer_state_avoidance.step()
            self.optimizer_state_avoidance.zero_grad()

        loss_value_network = self.criterion(vs, td)
        loss_value_network.backward()
        self.optimizer_value.step()
        self.optimizer_value.zero_grad()

        # torch.cuda.empty_cache()

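    # Combine Q-values with the weighted uncertainty and state-difference
    # bonuses used for exploration; every 100 steps (self.f) the individual
    # components are printed for inspection.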
    def convert_values(self, vals, uncertainties, state_differences):
        if self.f % 100 == 0:
            print([int(x) / 100 for x in 100 * vals[0].cpu().detach().numpy()])
            print([
                int(x) / 100
                for x in 100 * uncertainties[0].cpu().detach().numpy()
            ])
            if torch.is_tensor(state_differences):  # the default placeholder is the int 0
                print([
                    int(x) / 100
                    for x in 100 * state_differences[0].cpu().detach().numpy()
                ])
            print(" ")
        return vals + (float(self.uncertainty_weight) * uncertainties *
                       float(self.uncertainty)) + (
                           float(self.state_difference_weight) *
                           state_differences * float(self.state_difference))

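    # Double-buffered target update: the target network takes the previous
    # snapshot (placeholder_network) and the placeholder takes a deep copy of
    # the current online network (via a pickle round-trip); the replay
    # sampling distribution is refreshed at the same time.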
    def update_target_network(self):
        self.target_network = pickle.loads(
            pickle.dumps(self.placeholder_network))
        self.placeholder_network = pickle.loads(pickle.dumps(self.network))
        self.memory.update_distribution()

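    # Optionally load a pickled, pre-trained encoder from Encoders/<name>.obj
    # and mark the networks to route observations through it.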
    def createEncoder(self, encoder):
        if encoder:
            with open(f"Encoders/{encoder}.obj", "rb") as file:
                self.encoder = pickle.load(file).encoder.to(device)
                self.hasEncoder = True
        else:
            self.hasEncoder = False
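

# A second, simplified Agent used for autoencoder training of the observation
# encoder; its TD-learning path is commented out. It reuses the name Agent and
# would shadow the class above if both definitions are kept in one module.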
class Agent:
    def __init__(self) -> None:
        self.network = NetWork().to(device)
        print("Number of parameters in network:",
              count_parameters(self.network))
        self.criterion = MSELoss()
        self.optimizer = Adam(self.network.parameters(),
                              lr=0.001,
                              weight_decay=0.001)
        self.memory = ReplayBuffer(100000)
        self.remember = self.memory.remember()
        self.exploration = Exploration()
        self.explore = self.exploration.epsilonGreedy
        self.target_network = NetWork().to(device)
        self.placeholder_network = NetWork().to(device)

    def choose(self, pixels, hn, cn):
        self.network.hn, self.network.cn = hn, cn
        vals = self.network(pixels).reshape(15)
        return self.explore(
            vals), pixels, hn, cn, self.network.hn, self.network.cn

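    # The TD update below is commented out; learn() currently just samples a
    # batch and runs one autoencoder step on the observations.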
    def learn(self, double=False):
        gamma = 0.96
        obs, action, obs_next, reward, h0, c0, hn, sn, done = self.memory.sample_distribution(
            20)
        # self.network.hn, self.network.cn = hn, sn

        # if double:
        #     v_s_next = torch.gather(self.target_network(obs_next), 1, torch.argmax(self.network(obs_next), 1).view(-1, 1)).squeeze(1)
        # else:
        #     v_s_next, input_indexes = torch.max(self.target_network(obs_next), 1)

        # self.network.hn, self.network.cn = h0, c0
        # v_s = torch.gather(self.network(obs), 1, action)
        # #v_s, _ = torch.max(self.network(obs), 1)
        # td = (reward + gamma * v_s_next * done.type(torch.float)).detach().view(-1, 1)
        # loss = self.criterion(v_s, td)
        # loss.backward()
        # self.optimizer.step()
        # self.optimizer.zero_grad()
        # torch.cuda.empty_cache()
        self.autoEncode(obs)

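    # One autoencoder step on a batch of 20 observations: MSE between the
    # reconstruction and the observations scaled by 1/256 (weighted by 100),
    # plus a squared penalty on (enc + 1).prod(1) - (enc + 1).max(1)[0].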
    def autoEncode(self, obs):
        enc = self.network.color(obs)
        obs_Guess = self.network.colorReverse(enc)
        # print(enc.prod(1).sum())

        # print(f"[{str(float(enc_stand.max()))[:8]}]", end=" ")
        entro = (enc + 1).prod(1) - (1 + enc).max(1)[0]
        img = self.criterion(obs_Guess.view(20, -1), obs.view(20, -1) / 256)
        # print(enc.max(1)[0].max(1)[0].max(1)[0].shape)
        # print(enc.max(1, keepdim=True)[0].shape)
        maxi = enc.max(1, keepdim=False)[0]  # currently unused
        loss = img * 100 + (entro * entro).mean()
        loss.backward()

        self.optimizer.step()
        self.optimizer.zero_grad()
        print(f"[{str(float(loss))[:8]}]", end=" ")
        print(f"[{str(float(img*100))[:8]}]", end=" ")
        print(f"[{str(float((entro * entro).mean()))[:8]}]", end=" ")
        print(f"[{str(float(enc.min()))[:8]}]", end=" ")
        print(f"[{str(float(enc.max()))[:8]}]")
        # print(f"[{str(float(enc.mean()))[:8]}]")
        # print(f"[{str(float(entro.min()))[:8]}]", end=" ")
        # print(f"[{str(float(entro.max()))[:8]}]", end=" ")
        # print(f"[{str(float(enc.max()))[:8]}]")
        # print(*[[float(str(f)[:5]) for f in list(p.detach().cpu().numpy().reshape(-1))] for p in self.network.color.parameters()], *[[float(str(f)[:5]) for f in list(p.detach().cpu().numpy().reshape(-1))] for p in self.network.colorReverse.parameters()])
        torch.cuda.empty_cache()

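    # Same double-buffered target update as in the first Agent, here using
    # copy.deepcopy.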
    def update_target_network(self):
        self.target_network = copy.deepcopy(self.placeholder_network)
        self.placeholder_network = copy.deepcopy(self.network)
        self.memory.update_distribution()