import copy
import pickle

import torch
from torch.nn import MSELoss
from torch.optim import Adam

# Project-local dependencies assumed to be importable in the original repo
# (their defining modules are not shown in this file): NetWork, ReplayBuffer,
# Exploration, concatenation, count_parameters.

# Assumed device setup; the original references `device` without defining it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Agent:

    def __init__(self,
                 exploration='epsilonGreedy',
                 memory=10000,
                 discount=0.99,
                 uncertainty=True,
                 uncertainty_weight=1,
                 update_every=200,
                 double=True,
                 use_distribution=True,
                 reward_normalization=False,
                 encoder=None,
                 hidden_size=40,
                 state_difference=True,
                 state_difference_weight=1,
                 **kwargs) -> None:
        self.uncertainty = uncertainty
        self.hidden_size = hidden_size
        self.network = NetWork(self.hidden_size).to(device)
        self.createEncoder(encoder)
        self.network.hasEncoder = self.hasEncoder
        print("Number of parameters in network:",
              count_parameters(self.network))
        self.criterion = MSELoss()
        self.memory = ReplayBuffer(int(memory))
        self.remember = self.memory.remember()
        self.exploration = Exploration()
        if exploration == 'greedy':
            self.explore = self.exploration.greedy
        elif exploration == 'epsilonGreedy':
            self.explore = self.exploration.epsilonGreedy
        elif exploration == 'softmax':
            self.explore = self.exploration.softmax
        elif exploration == 'epsintosoftmax':
            self.explore = self.exploration.epsintosoftmax
        self.target_network = NetWork(self.hidden_size).to(device)
        self.target_network.hasEncoder = self.hasEncoder
        self.placeholder_network = NetWork(self.hidden_size).to(device)
        self.placeholder_network.hasEncoder = self.hasEncoder
        self.gamma, self.f = discount, 0
        self.update_every, self.double, self.use_distribution = update_every, double, use_distribution
        self.counter = 0
        self.reward_normalization = reward_normalization
        self.state_difference = state_difference
        self.true_state_trace = None
        self.uncertainty_weight = uncertainty_weight
        self.state_difference_weight = state_difference_weight
        # The value head has its own optimizer; the exploration (uncertainty)
        # and state-avoidance heads are trained by separate optimizers below.
        if encoder is not None:
            self.optimizer_value = Adam(
                list(self.network.fromEncoder.parameters()) +
                list(self.network.lstm.parameters()) +
                list(self.network.linear.parameters()),
                lr=1e-4,
                weight_decay=1e-5)
        else:
            self.optimizer_value = Adam(
                list(self.network.color.parameters()) +
                list(self.network.conv1.parameters()) +
                list(self.network.lstm.parameters()) +
                list(self.network.linear.parameters()),
                lr=1e-4,
                weight_decay=1e-5)
        if self.uncertainty:
            self.optimizer_exploration = Adam(
                list(self.network.exploration_network.parameters()),
                lr=1e-4,
                weight_decay=1e-5)
        if self.state_difference:
            self.optimizer_state_avoidance = Adam(
                list(self.network.state_difference_network.parameters()),
                lr=1e-4,
                weight_decay=1e-5)
        self.onpolicy = True

    def rememberMulti(self, *args):
        # Store one transition per parallel environment; everything is moved
        # to the CPU so the replay buffer does not pin GPU memory.
        if self.state_difference:
            [
                self.remember(obs_old.cpu(), act, obs.cpu(), rew,
                              h0.detach().cpu(), c0.detach().cpu(),
                              hn.detach().cpu(), cn.detach().cpu(),
                              int(not done), before_trace.detach().cpu(),
                              state_diff_label.detach().cpu())
                for (obs_old, act, obs, rew, h0, c0, hn, cn, done,
                     before_trace, state_diff_label) in zip(*args)
            ]
        else:
            args = args[:9]
            [
                self.remember(obs_old.cpu(), act, obs.cpu(), rew,
                              h0.detach().cpu(), c0.detach().cpu(),
                              hn.detach().cpu(), cn.detach().cpu(),
                              int(not done))
                for obs_old, act, obs, rew, h0, c0, hn, cn, done in zip(*args)
            ]

    def choose(self, pixels, hn, cn):
        # Single-environment action selection. NOTE: this path expects the
        # network to return (values, uncertainty); chooseMulti expects a
        # three-tuple that also carries the true state.
        self.network.hn, self.network.cn = hn, cn
        vals, uncertainty = self.network(pixels)
        vals = vals.reshape(15)  # the original discarded the reshape result
        return self.explore(
            vals, uncertainty), pixels, hn, cn, self.network.hn, self.network.cn

    def chooseMulti(self, pixels, hn, cn, lambda_decay=0.95, avoid_trace=0, done=None):
        # Batched action selection: per-environment LSTM states are
        # concatenated so all environments share one forward pass.
        self.network.hn = concatenation(hn, 1).to(device)
        self.network.cn = concatenation(cn, 1).to(device)
        vals, uncertainties, true_state = self.network(
            concatenation(pixels, 0).to(device))
        before_trace = self.true_state_trace
        if self.state_difference and done is not None:
            # `done` becomes 1 for environments that are still running.
            done = 1 - torch.tensor(list(done)).type(torch.float)
            if self.true_state_trace is None:
                self.true_state_trace = true_state.detach() * (1 - lambda_decay)
            else:
                # Zero the trace of environments that just terminated, then
                # take an exponential moving average of the true state.
                self.true_state_trace = (
                    done.to(device) *
                    self.true_state_trace.transpose(0, 2)).transpose(0, 2)
                self.true_state_trace = (self.true_state_trace * lambda_decay +
                                         true_state * (1 - lambda_decay))
            avoid_trace = self.network.avoid_similar_state(
                self.true_state_trace)[0]
        if not self.onpolicy:
            self.explore = self.exploration.epsilonGreedy
            vals = self.convert_values(vals, uncertainties, avoid_trace)
            return [
                self.explore(val.reshape(15)) for val in torch.split(vals, 1)
            ], pixels, hn, cn, torch.split(self.network.hn, 1, dim=1), torch.split(
                self.network.cn, 1, dim=1), before_trace, self.true_state_trace
        else:
            self.explore = self.exploration.greedy
            return [
                val.reshape(15).detach().cpu().numpy().argmax()
                for val in torch.split(vals, 1)
            ], pixels, hn, cn, torch.split(self.network.hn, 1, dim=1), torch.split(
                self.network.cn, 1, dim=1), before_trace, self.true_state_trace

    def learn(self):
        self.f += 1
        # For the last 2000 of every 50000 steps, act greedily (on-policy)
        # and skip training; the rest of the time train off-policy.
        if self.f % 50000 > 48000:
            self.onpolicy = True
            return
        else:
            self.onpolicy = False
        if self.f % self.update_every == 0:
            self.update_target_network()
        if self.f > self.update_every:
            self.TD_learn()

    def TD_learn(self):
        sample = (self.memory.sample_distribution(256)
                  if self.use_distribution else self.memory.sample(256))
        if self.state_difference:
            (obs, action, obs_next, reward, h0, c0, hn, sn, done,
             before_trace, after_trace) = sample
        else:
            obs, action, obs_next, reward, h0, c0, hn, sn, done = sample
        self.network.hn, self.network.cn = hn, sn
        self.target_network.hn, self.target_network.cn = hn, sn
        vals_target, uncertainties_target, _ = self.target_network(obs_next)
        if self.double:
            # Double DQN: the online network selects the next action and the
            # target network evaluates it. (The original gathered the target
            # network's own argmax, which is identical to a plain max.)
            with torch.no_grad():
                vals_next_online, uncertainties_next_online, _ = self.network(obs_next)
            v_s_next = torch.gather(
                vals_target, 1,
                torch.argmax(vals_next_online, 1).view(-1, 1)).squeeze(1)
        else:
            v_s_next, _ = torch.max(vals_target, 1)
        if self.uncertainty:
            if self.double:
                uncer_next = torch.gather(
                    uncertainties_target, 1,
                    torch.argmax(uncertainties_next_online, 1).view(-1, 1)).squeeze(1)
            else:
                uncer_next, _ = torch.max(uncertainties_target, 1)
        self.network.hn, self.network.cn = h0, c0
        vals, uncertainties, _ = self.network(obs)
        vs = torch.gather(vals, 1, action)
        # `done` stores int(not done), so terminal transitions contribute no
        # bootstrapped value.
        td = (reward +
              self.gamma * v_s_next * done.type(torch.float)).detach().view(-1, 1)
        if self.uncertainty:
            # The exploration head learns a discounted sum of future absolute
            # TD errors, which convert_values() adds as an exploration bonus.
            estimate_uncertainties = torch.gather(uncertainties, 1, action)
            reward_uncertainty = abs(vs - td).detach().view(-1)
            td_uncertainty = (reward_uncertainty + self.gamma * uncer_next *
                              done.type(torch.float)).detach().view(-1, 1)
            loss_uncertainty = self.criterion(estimate_uncertainties,
                                              td_uncertainty)
            loss_uncertainty.backward(retain_graph=True)
            self.optimizer_exploration.step()
            self.optimizer_exploration.zero_grad()
        if self.state_difference:
            # The state-avoidance head regresses the Euclidean distance
            # between consecutive state traces for the taken action.
            estimate_state_difference = (torch.gather(
                self.network.avoid_similar_state(before_trace), 1,
                action).view(-1) * done.type(torch.float)).view(-1)
            reward_state_difference = (
                (torch.sum((before_trace - after_trace)**2, dim=1)**(1 / 2)).view(-1) *
                done.type(torch.float)).detach()
            loss_state_avoidance = self.criterion(estimate_state_difference,
                                                  reward_state_difference)
            loss_state_avoidance.backward(retain_graph=True)
            self.optimizer_state_avoidance.step()
            self.optimizer_state_avoidance.zero_grad()
        loss_value_network = self.criterion(vs, td)
        loss_value_network.backward()
        self.optimizer_value.step()
        self.optimizer_value.zero_grad()
        # torch.cuda.empty_cache()

    def convert_values(self, vals, uncertainties, state_differences):
        # Combined action score: Q-value plus the weighted exploration bonus
        # and the weighted state-difference bonus.
        if self.f % 100 == 0:
            print([int(x) / 100 for x in 100 * vals[0].cpu().detach().numpy()])
            print([
                int(x) / 100
                for x in 100 * uncertainties[0].cpu().detach().numpy()
            ])
            if torch.is_tensor(state_differences):  # the original used `is not 0`
                print([
                    int(x) / 100
                    for x in 100 * state_differences[0].cpu().detach().numpy()
                ])
            print(" ")
        return vals + (float(self.uncertainty_weight) * uncertainties *
                       float(self.uncertainty)) + (
                           float(self.state_difference_weight) *
                           state_differences * float(self.state_difference))

    def update_target_network(self):
        # Two-step delayed copy (pickle round-trip as a deep copy): the target
        # network lags the online network by two update intervals.
        self.target_network = pickle.loads(
            pickle.dumps(self.placeholder_network))
        self.placeholder_network = pickle.loads(pickle.dumps(self.network))
        self.memory.update_distribution()

    def createEncoder(self, encoder):
        # Optionally load a pre-trained, pickled encoder from Encoders/.
        if encoder:
            with open(f"Encoders/{encoder}.obj", "rb") as file:
                self.encoder = pickle.load(file).encoder.to(device)
            self.hasEncoder = True
        else:
            self.hasEncoder = False
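# ---------------------------------------------------------------------------
# Minimal driver sketch (not part of the original file). It shows how the
# agent above is typically stepped; `env` and its reset()/step() signature
# are hypothetical stand-ins, the (1, 1, hidden) LSTM state shape is an
# assumption, and choose() requires a network variant that returns
# (values, uncertainty). Only Agent's own methods come from this file.
# ---------------------------------------------------------------------------
def example_training_loop(env, steps=100_000):
    agent = Agent(exploration='epsilonGreedy')
    hn = torch.zeros(1, 1, agent.hidden_size, device=device)
    cn = torch.zeros(1, 1, agent.hidden_size, device=device)
    obs = env.reset()  # assumed to return an observation tensor
    for _ in range(steps):
        action, obs_old, h0, c0, hn, cn = agent.choose(obs, hn, cn)
        obs, reward, done = env.step(action)  # hypothetical signature
        agent.remember(obs_old.cpu(), action, obs.cpu(), reward,
                       h0.detach().cpu(), c0.detach().cpu(),
                       hn.detach().cpu(), cn.detach().cpu(), int(not done))
        agent.learn()  # starts training once update_every steps have passed
        if done:
            obs = env.reset()
            hn, cn = torch.zeros_like(hn), torch.zeros_like(cn)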
# A second, simpler Agent variant (presumably an earlier version kept in a
# separate file of the repo): fixed hyperparameters, epsilon-greedy only, and
# a learn() whose TD update is commented out in favour of training the
# convolutional autoencoder.
class Agent:

    def __init__(self) -> None:
        self.network = NetWork().to(device)
        print("Number of parameters in network:",
              count_parameters(self.network))
        self.criterion = MSELoss()
        self.optimizer = Adam(self.network.parameters(),
                              lr=0.001,
                              weight_decay=0.001)
        self.memory = ReplayBuffer(100000)
        self.remember = self.memory.remember()
        self.exploration = Exploration()
        self.explore = self.exploration.epsilonGreedy
        self.target_network = NetWork().to(device)
        self.placeholder_network = NetWork().to(device)

    def choose(self, pixels, hn, cn):
        self.network.hn, self.network.cn = hn, cn
        vals = self.network(pixels).reshape(15)
        return self.explore(
            vals), pixels, hn, cn, self.network.hn, self.network.cn

    def learn(self, double=False):
        gamma = 0.96
        (obs, action, obs_next, reward, h0, c0, hn, sn,
         done) = self.memory.sample_distribution(20)
        # TD update disabled while the autoencoder is pre-trained:
        # self.network.hn, self.network.cn = hn, sn
        # if double:
        #     v_s_next = torch.gather(
        #         self.target_network(obs_next), 1,
        #         torch.argmax(self.network(obs_next), 1).view(-1, 1)).squeeze(1)
        # else:
        #     v_s_next, input_indexes = torch.max(self.target_network(obs_next), 1)
        # self.network.hn, self.network.cn = h0, c0
        # v_s = torch.gather(self.network(obs), 1, action)
        # td = (reward + gamma * v_s_next * done.type(torch.float)).detach().view(-1, 1)
        # loss = self.criterion(v_s, td)
        # loss.backward()
        # self.optimizer.step()
        # self.optimizer.zero_grad()
        # torch.cuda.empty_cache()
        self.autoEncode(obs)

    def autoEncode(self, obs):
        # One autoencoder step: encode, decode, then minimise reconstruction
        # error plus an entropy-like penalty that pushes each spatial
        # position's channel code towards a one-hot activation.
        enc = self.network.color(obs)
        obs_Guess = self.network.colorReverse(enc)
        entro = (enc + 1).prod(1) - (1 + enc).max(1)[0]
        img = self.criterion(obs_Guess.view(20, -1), obs.view(20, -1) / 256)
        loss = img * 100 + (entro * entro).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        # Debug readout: total loss, reconstruction term, entropy term, and
        # the range of the encoding.
        print(f"[{str(float(loss))[:8]}]", end=" ")
        print(f"[{str(float(img * 100))[:8]}]", end=" ")
        print(f"[{str(float((entro * entro).mean()))[:8]}]", end=" ")
        print(f"[{str(float(enc.min()))[:8]}]", end=" ")
        print(f"[{str(float(enc.max()))[:8]}]")
        torch.cuda.empty_cache()

    def update_target_network(self):
        # Same two-step delayed copy as the main Agent, via copy.deepcopy.
        self.target_network = copy.deepcopy(self.placeholder_network)
        self.placeholder_network = copy.deepcopy(self.network)
        self.memory.update_distribution()
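# ---------------------------------------------------------------------------
# Sketch (not part of the original file): a smoke test for the autoencoder
# path of this simpler Agent. The 3x64x64 uint8-style observation shape is an
# assumption; autoEncode() itself hard-codes a batch size of 20 through its
# .view(20, -1) calls, so the batch dimension must be 20.
# ---------------------------------------------------------------------------
def example_autoencoder_step():
    agent = Agent()
    dummy_obs = torch.randint(0, 256, (20, 3, 64, 64), device=device).float()
    agent.autoEncode(dummy_obs)  # one reconstruction + entropy-penalty update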