def __init__(self, config, args):
    self.config, self.args = config, args
    self.nANN = config.NPOP
    self.anns = [ANN(config, args) for _ in range(self.nANN)]
    self.updates, self.rollouts = defaultdict(lambda: Rollout()), {}
    self.updates_lm, self.rollouts_lm = defaultdict(lambda: Rollout()), {}
    self.initBuffer()
    self.lawmaker = Lawmaker(args, config) if args.lm else LawmakerAbstract(args, config)
def __init__(self, config, args, idx):
    self.config, self.args = config, args
    self.nANN, self.h = config.NPOP, config.HIDDEN
    self.anns = [ANN(config, args) for _ in range(self.nANN)]
    self.init, self.nRollouts = True, 32
    self.updates = defaultdict(lambda: Rollout(config))
    self.blobs = []
    self.idx = idx
    self.ReplayMemory = [ReplayMemory(self.config) for _ in range(self.nANN)]
    self.ReplayMemoryLm = ReplayMemoryLm(self.config)
    self.forward, self.forward_lm = Forward(config), ForwardLm(config)
    self.lawmaker = Lawmaker(args, config) if args.lm else LawmakerAbstract(args, config)
    self.tick = 0
def init(self):
    print('Initializing new model...')
    self.anns = [ANN(self.config, self.args).to(self.config.device) for _ in range(self.nANN)]
    self.targetAnns = [ANN(self.config, self.args).to(self.config.device) for _ in range(self.nANN)]
    self.updateTargetAnns()
    self.annsOpts = None
    if not self.config.TEST:
        self.annsOpts = [Adam(ann.parameters(), lr=0.0005, weight_decay=0.00001)
                         for ann in self.anns]
    self.lm = (Lawmaker(self.args, self.config).to(self.config.device)
               if self.args.lm else LawmakerAbstract(self.args, self.config))
    self.lmOpt = None
    if not self.config.TEST and self.args.lm:
        self.lmOpt = Adam(self.lm.parameters(), lr=0.0005, weight_decay=0.00001)
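# A minimal sketch of what `updateTargetAnns` above could look like. The method is
# not defined in this section, so this is an assumption: a hard (full-copy) target
# update in the style of DQN training, pairing each online ANN with its target ANN.
def hard_update(online_nets, target_nets):
    """Copy every online network's parameters into its paired target network."""
    for online, target in zip(online_nets, target_nets):
        target.load_state_dict(online.state_dict())

# Inside `init` this would be invoked as: hard_update(self.anns, self.targetAnns).
# A soft (Polyak-averaged) update would blend parameters instead of copying them.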
def init(self):
    print('Initializing new model...')
    self.unshared(self.config.NPOP)
    self.opt = None
    if not self.config.TEST:
        self.opt = [Adam(ann.parameters(), lr=self.config.LR, weight_decay=0.00001)
                    for ann in self.anns]
        self.scheduler = [StepLR(opt, 1, gamma=0.9998) for opt in self.opt]
    self.lawmaker = (Lawmaker(self.args, self.config,
                              device=self.config.DEVICE_OPTIMIZER,
                              batch_size=self.config.LSTM_PERIOD).to(self.config.DEVICE_OPTIMIZER)
                     if self.args.lm else
                     LawmakerAbstract(self.args, self.config,
                                      device=self.config.DEVICE_OPTIMIZER,
                                      batch_size=self.config.LSTM_PERIOD))
    initialize_weights(self.lawmaker)
    if self.args.lm:
        self.lmOpt = Adam(self.lawmaker.parameters(), lr=self.config.LR, weight_decay=0.00001)
        self.lmScheduler = StepLR(self.lmOpt, 1, gamma=0.9998)
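# Illustrative sketch (not from the repository) of how the per-agent Adam optimizers
# and StepLR schedulers created in `init` are typically stepped once per update.
# `losses` is a hypothetical list holding one scalar loss tensor per agent network.
def optimizer_step(opts, schedulers, losses):
    for opt, scheduler, loss in zip(opts, schedulers, losses):
        opt.zero_grad()
        loss.backward()
        opt.step()
        scheduler.step()  # with StepLR(opt, 1, gamma=0.9998) the lr decays by 0.9998 per call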
def __init__(self, config, args):
    self.config, self.args = config, args
    self.nANN, self.h = config.NPOP, config.HIDDEN
    self.anns = [ANN(config, args) for _ in range(self.nANN)]
    self.updates, self.rollouts = defaultdict(lambda: Rollout(config)), {}
    self.updates_lm, self.rollouts_lm = defaultdict(lambda: Rollout(config)), {}
    self.nGrads = 0
    self.rets = defaultdict(deque)
    self.flat_states = defaultdict(deque)
    self.ent_states = defaultdict(deque)
    self.lmActions = defaultdict(deque)
    self.lmPolicy = defaultdict(deque)
    self.policy = defaultdict(deque)
    self.actions = defaultdict(deque)
    self.contacts = defaultdict(deque)
    self.buffer = None
    self.lawmaker = Lawmaker(args, config) if args.lm else LawmakerAbstract(args, config)
class Sword:
    """Holds one ANN per population member plus a Lawmaker; accumulates per-agent
    rollouts into per-population deques and flushes them into a numpy buffer."""

    def __init__(self, config, args):
        self.config, self.args = config, args
        self.nANN, self.h = config.NPOP, config.HIDDEN
        self.anns = [ANN(config, args) for _ in range(self.nANN)]
        self.updates, self.rollouts = defaultdict(lambda: Rollout(config)), {}
        self.updates_lm, self.rollouts_lm = defaultdict(lambda: Rollout(config)), {}
        self.nGrads = 0
        self.rets = defaultdict(deque)
        self.flat_states = defaultdict(deque)
        self.ent_states = defaultdict(deque)
        self.lmActions = defaultdict(deque)
        self.lmPolicy = defaultdict(deque)
        self.policy = defaultdict(deque)
        self.actions = defaultdict(deque)
        self.contacts = defaultdict(deque)
        self.buffer = None
        self.lawmaker = Lawmaker(args, config) if args.lm else LawmakerAbstract(args, config)

    def backward(self):
        # Flush finished rollouts: pop the per-population deques into aligned numpy arrays.
        self.rollouts_lm = {}
        self.blobs = [r.feather.blob for r in self.rollouts.values()]
        self.rollouts = {}
        length = min([len(self.rets[i]) for i in range(self.nANN)])
        if length == 0:
            return
        self.initBuffer()
        buffer = defaultdict(lambda: {
            'policy': defaultdict(list),
            'action': defaultdict(list),
            'lmPolicy': defaultdict(list),
            'lmAction': defaultdict(list),
            'flat': [],
            'ents': [],
            'return': []
        })
        for _ in range(length):
            for i in range(self.nANN):
                buffer[i]['flat'].append(self.flat_states[i].popleft())
                buffer[i]['ents'].append(self.ent_states[i].popleft())
                buffer[i]['return'].append(self.rets[i].popleft())
                for (k, v), action in zip(self.policy[i].popleft().items(),
                                          self.actions[i].popleft().values()):
                    buffer[i]['policy'][k].append(v.detach().numpy())
                    buffer[i]['action'][k].append(action.detach().numpy())
                for (k, v), action in zip(self.lmPolicy[i].popleft().items(),
                                          self.lmActions[i].popleft().values()):
                    buffer[i]['lmPolicy'][k].append(v.detach().numpy())
                    buffer[i]['lmAction'][k].append(action.detach().numpy())
        for i in range(self.nANN):
            self.buffer[i]['flat'] = np.asarray(buffer[i]['flat'], dtype=np.float32)
            self.buffer[i]['ents'] = np.asarray(buffer[i]['ents'], dtype=np.float32)
            self.buffer[i]['return'] = np.asarray(buffer[i]['return'], dtype=np.float32)
            self.buffer[i]['policy'] = {k: np.asarray(v, dtype=np.float32)
                                        for k, v in buffer[i]['policy'].items()}
            self.buffer[i]['action'] = {k: np.asarray(v, dtype=np.float32)
                                        for k, v in buffer[i]['action'].items()}
            self.buffer[i]['lmPolicy'] = {k: np.asarray(v, dtype=np.float32)
                                          for k, v in buffer[i]['lmPolicy'].items()}
            self.buffer[i]['lmAction'] = {k: np.asarray(v, dtype=np.float32)
                                          for k, v in buffer[i]['lmAction'].items()}

    def sendLogUpdate(self):
        blobs = self.blobs
        self.blobs = []
        return blobs

    def sendUpdate(self):
        if self.buffer is None:
            return None, None
        buffer = self.dispatchBuffer()
        return buffer, self.sendLogUpdate()

    def recvUpdate(self, update):
        update, update_lm = update
        for idx, paramVec in enumerate(update):
            setParameters(self.anns[idx], paramVec)
            zeroGrads(self.anns[idx])
        setParameters(self.lawmaker, update_lm[0])
        zeroGrads(self.lawmaker)

    def collectStep(self, entID, action, policy, flat, ents, reward, contact, val):
        if self.config.TEST:
            return
        self.updates[entID].step(action, policy, flat, ents, reward, contact, val)

    def collectStepLm(self, entID, action, policy):
        if self.config.TEST:
            return
        self.updates_lm[entID].step(action, policy)

    def collectRollout(self, entID, ent, tick):
        # Move a finished agent's rollout into the shared per-population deques and
        # trigger a flush once enough gradient steps have accumulated.
        assert entID not in self.rollouts
        rollout = self.updates[entID]
        rollout.finish()
        nGrads = rollout.lifespan
        self.rets[ent.annID] += rollout.rets
        self.ent_states[ent.annID] += rollout.ent_states[:-1]
        self.flat_states[ent.annID] += rollout.flat_states[:-1]
        self.policy[ent.annID] += rollout.policy[:-1]
        self.actions[ent.annID] += rollout.actions[:-1]
        self.contacts[ent.annID] += rollout.contacts[:-1]
        rollout_lm = self.updates_lm[entID]
        self.lmPolicy[ent.annID] += rollout_lm.policy[:-1]
        self.lmActions[ent.annID] += rollout_lm.actions[:-1]
        del self.updates_lm[entID]
        rollout.feather.blob.tick = tick
        self.rollouts[entID] = rollout
        del self.updates[entID]
        self.nGrads += nGrads
        if self.nGrads >= self.config.stepsPerEpoch:
            self.nGrads = 0
            self.backward()

    def initBuffer(self):
        self.buffer = defaultdict(dict)

    def dispatchBuffer(self):
        buffer = self.buffer
        self.buffer = None
        return buffer

    def getActionArguments(self, annReturns, stim, ent):
        # Translate raw network outputs into executable (action, target) pairs.
        actions = ActionTree(stim, ent, ActionV2).actions()
        move, attkShare = actions
        playerActions = [move]
        actionDecisions = {}
        moveAction = int(annReturns['actions']['move'])
        attack = moveAction > 4
        if attack:
            moveAction -= 5
        actionTargets = [move.args(stim, ent, self.config)[moveAction]]
        action = attkShare.args(stim, ent, self.config)['attack']
        targets = action.args(stim, ent, self.config)
        target, decision = checkTile(ent, int(attack), targets)
        playerActions.append(action)
        actionTargets.append([target])
        actionDecisions['attack'] = decision
        return playerActions, actionTargets, actionDecisions

    def decide(self, ent, stim, isDead, n_dead=0):
        # One environment step for one agent: query the lawmaker, then the agent's
        # ANN (which receives the lawmaker outputs), build the executable actions,
        # and record the transition for both the agent and the lawmaker.
        entID, annID = ent.entID, ent.annID
        reward = self.config.STEPREWARD + self.config.DEADREWARD * n_dead
        stim_tensor = torchlib.Stim(ent, stim, self.config)
        outsLm = self.lawmaker(stim_tensor.flat.view(1, -1),
                               stim_tensor.ents.unsqueeze(0), isDead, annID)
        annReturns = self.anns[annID](stim_tensor.flat.view(1, -1),
                                      stim_tensor.ents.unsqueeze(0), outsLm, isDead)
        playerActions, actionTargets, actionDecisions = self.getActionArguments(
            annReturns, stim, ent)
        moveAction = int(annReturns['actions']['move'])
        attack = actionDecisions.get('attack', None)
        if moveAction > 4:
            moveAction -= 5
        ent.moveDec = moveAction
        contact = int(attack is not None)
        # Lawmaker advantage of the chosen actions: Q of the selected actions minus
        # the mean Q, averaged over action heads.
        Asw = -np.mean([float(t.mean()) for t in outsLm['Qs'].values()])
        outsLm = self.lawmaker.get_punishment(outsLm, annReturns['actions'])
        Asw += np.mean([float(t) for t in outsLm['Qs'].values()])
        self.collectStep(entID, annReturns['actions'], annReturns['policy'],
                         stim_tensor.flat.numpy(), stim_tensor.ents.numpy(),
                         reward, contact, float(annReturns['val']))
        self.collectStepLm(entID, outsLm['actions'], outsLm['policy'])
        if not self.config.TEST:
            self.updates[entID].feather.scrawl(ent, float(annReturns['val']),
                                               reward, Asw, attack, contact)
        return playerActions, actionTargets
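# Hypothetical driver loop (not part of the class above) illustrating how this Sword
# variant is meant to be used: `decide` every tick for each living agent,
# `collectRollout` when an agent dies, and `sendUpdate` to ship the flushed buffer
# plus log blobs to a central optimizer. `env`, `env.agents`, `env.stim`, `env.queue`,
# and `ent.isDead` are assumed names, not the repository's API.
def run_episode(sword, env, ticks):
    for tick in range(ticks):
        n_dead = sum(ent.isDead for ent in env.agents)
        for ent in env.agents:
            actions, targets = sword.decide(ent, env.stim(ent), ent.isDead, n_dead=n_dead)
            env.queue(ent, actions, targets)
        env.step()
        for ent in env.agents:
            if ent.isDead:
                # backward() fires automatically once stepsPerEpoch gradients accumulate
                sword.collectRollout(ent.entID, ent, tick)
    return sword.sendUpdate()  # (buffer, blobs), or (None, None) if nothing was flushed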
class Sword:
    """Replay-memory variant: transitions go into per-population ReplayMemory objects
    (plus a shared lawmaker memory) and are shipped to the learner once large enough."""

    def __init__(self, config, args, idx):
        self.config, self.args = config, args
        self.nANN, self.h = config.NPOP, config.HIDDEN
        self.anns = [ANN(config, args) for _ in range(self.nANN)]
        self.init, self.nRollouts = True, 32
        self.updates = defaultdict(lambda: Rollout(config))
        self.blobs = []
        self.idx = idx
        self.ReplayMemory = [ReplayMemory(self.config) for _ in range(self.nANN)]
        self.ReplayMemoryLm = ReplayMemoryLm(self.config)
        self.forward, self.forward_lm = Forward(config), ForwardLm(config)
        self.buffer_size_to_send = 2**8
        self.lawmaker = Lawmaker(args, config) if args.lm else LawmakerAbstract(args, config)
        self.tick = 0
        self.logTick = 0

    def sendBufferUpdate(self):
        # Only ship once every per-population buffer holds at least buffer_size_to_send transitions.
        for replay in self.ReplayMemory:
            if len(replay) < self.buffer_size_to_send:
                return None, None
        buffer = [replay.send_buffer() for replay in self.ReplayMemory]
        priorities = ([self.forward.get_priorities_from_samples(buf, ann, ann, self.lawmaker)
                       for buf, ann in zip(buffer, self.anns)]
                      if self.config.REPLAY_PRIO else [None] * len(self.anns))
        if self.args.lm:
            bufferLm = self.ReplayMemoryLm.send_buffer()
            prioritiesLm = (self.forward_lm.get_priorities_from_samples(
                                bufferLm, self.anns, self.lawmaker)
                            if self.config.REPLAY_PRIO else None)
            return (buffer, priorities), (bufferLm, prioritiesLm)
        else:
            return (buffer, priorities), None

    def sendLogUpdate(self):
        # Emit accumulated log blobs roughly every 64 calls.
        self.logTick += 1
        if ((self.logTick + 1) % 2**6) == 0:
            blobs = self.blobs
            self.blobs = []
            return blobs
        return None

    def sendUpdate(self):
        recvs, recvs_lm = self.sendBufferUpdate()
        logs = self.sendLogUpdate() if recvs is not None else None
        return recvs, recvs_lm, logs

    def recvUpdate(self, update):
        update, update_lm = update
        if update is not None:
            self.loadAnnsFrom(update)
        if self.args.lm and (update_lm is not None):
            self.loadLmFrom(update_lm)

    def collectStep(self, entID, annID, s, atnArgs, reward, dead, val):
        if self.config.TEST:
            return
        actions = {key: val[1] for key, val in atnArgs.items()}
        self.ReplayMemory[annID].append(entID, s, actions, reward, dead, val)
        if self.args.lm:
            self.ReplayMemoryLm.append(entID, annID, s, actions, reward, dead, val)

    def collectRollout(self, entID, tick):
        rollout = self.updates[entID]
        rollout.feather.blob.tick = tick
        rollout.finish()
        self.blobs.append(rollout.feather.blob)
        del self.updates[entID]

    def decide(self, ent, stim, isDead=False, n_dead=0):
        # Query the lawmaker, then the agent's ANN (which receives the lawmaker
        # punishments), build executable actions, and append the transition to replay.
        entID, annID = ent.entID, ent.annID
        reward = self.config.DEADREWARD * n_dead + self.config.STEPREWARD
        flat, ents = self.prepareInput(ent, stim)
        outputsLm, punishmentsLm = self.lawmaker(flat, ents)
        atnArgs, val = self.anns[annID](flat, ents, self.config.EPS_CUR, punishmentsLm)
        action, arguments, decs = self.actionTree(ent, stim, atnArgs)
        attack = decs.get('attack', None)
        shareFood = decs.get('shareFood', None)
        shareWater = decs.get('shareWater', None)
        ent.moveDec = int(atnArgs['move'][1])
        contact = int(attack is not None)
        if not contact:
            punishAttack, punishWater, punishFood = None, None, None
        else:
            punishAttack = punishmentsLm.get('attack', [[None, None]])[0][1]
            punishWater = punishmentsLm.get('shareWater', [[None, None]])[0][1]
            punishFood = punishmentsLm.get('shareFood', [[None, None]])[0][1]
        ent.shareFoodDec = shareFood
        ent.shareWaterDec = shareWater
        ent.attackDec = attack
        self.collectStep(entID, annID, {'flat': flat, 'ents': ents}, atnArgs,
                         reward, isDead, val.detach().mean(2))
        avgPunishmentLm = calcAvgPunishment(atnArgs, outputsLm)
        if not self.config.TEST:
            self.updates[entID].feather.scrawl(stim, ent, val.detach().mean(2), reward,
                                               avgPunishmentLm, punishAttack, punishWater,
                                               punishFood, attack, shareFood, shareWater,
                                               contact)
        return action, arguments

    def prepareInput(self, ent, env):
        s = torchlib.Stim(ent, env, self.config)
        return s.flat.unsqueeze(0), s.ents.unsqueeze(0)

    def actionTree(self, ent, env, outputs):
        # Translate raw network outputs into executable (action, target) pairs.
        actions = ActionTree(env, ent, ActionV2).actions()
        _, move, attkShare = actions
        playerActions = [move]
        actionTargets = [move.args(env, ent, self.config)[int(outputs['move'][1])]]
        actionDecisions = {}
        for name in ['attack', 'shareWater', 'shareFood']:
            if name not in outputs.keys():
                continue
            action = attkShare.args(env, ent, self.config)[name]
            targets = action.args(env, ent, self.config)
            target, decision = checkTile(ent, int(outputs[name][1]), targets)
            playerActions.append(action)
            actionTargets.append([target])
            actionDecisions[name] = decision
        return playerActions, actionTargets, actionDecisions

    def reset_noise(self):
        nets = self.anns + [self.lawmaker]
        for net in nets:
            net.reset_noise()

    def loadAnnsFrom(self, states):
        for ann, state in zip(self.anns, states):
            ann.load_state_dict(state)

    def loadLmFrom(self, state):
        self.lawmaker.load_state_dict(state)
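# A sketch, under assumed names, of the worker/learner exchange implied by
# sendUpdate/recvUpdate in this variant: the worker ships replay buffers (with
# optional priorities), a central learner (`god` here is hypothetical) trains on
# them and returns fresh state dicts, which the worker then loads.
def sync(sword, god):
    recvs, recvs_lm, logs = sword.sendUpdate()
    if recvs is not None:  # buffers only ship once each reaches buffer_size_to_send
        ann_states, lm_state = god.train(recvs, recvs_lm)  # hypothetical learner API
        sword.recvUpdate((ann_states, lm_state))
    return logs  # log blobs roughly every 64 calls, otherwise None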
class Sword:
    """Grid-environment variant: observations are single image tensors and the ANN
    picks one discrete action per step; transitions go straight into replay memory."""

    def __init__(self, config, args, idx):
        self.config, self.args = config, args
        self.nANN, self.h = config.NPOP, config.HIDDEN
        self.anns = [ANN(config, args) for _ in range(self.nANN)]
        self.init, self.nRollouts = True, 32
        self.updates = defaultdict(lambda: Rollout(config))
        self.blobs = []
        self.idx = idx
        self.ReplayMemory = [ReplayMemory(self.config) for _ in range(self.nANN)]
        self.ReplayMemoryLm = ReplayMemoryLm(self.config)
        self.forward, self.forward_lm = Forward(config), ForwardLm(config)
        self.lawmaker = Lawmaker(args, config) if args.lm else LawmakerAbstract(args, config)
        self.tick = 0

    def sendBufferUpdate(self):
        buffer = [replay.send_buffer() for replay in self.ReplayMemory]
        priorities = ([self.forward.get_priorities_from_samples(buf, ann, ann, self.lawmaker)
                       for buf, ann in zip(buffer, self.anns)]
                      if self.config.REPLAY_PRIO else [None] * len(self.anns))
        if self.args.lm:
            bufferLm = self.ReplayMemoryLm.send_buffer()
            prioritiesLm = (self.forward_lm.get_priorities_from_samples(
                                bufferLm, self.anns, self.lawmaker)
                            if self.config.REPLAY_PRIO else None)
            return (buffer, priorities), (bufferLm, prioritiesLm)
        else:
            return (buffer, priorities), None

    def sendLogUpdate(self):
        blobs = self.blobs
        self.blobs = []
        return blobs

    def sendUpdate(self):
        recvs, recvs_lm = self.sendBufferUpdate()
        logs = self.sendLogUpdate()
        return recvs, recvs_lm, logs

    def recvUpdate(self, update):
        update, update_lm = update
        if update is not None:
            self.loadAnnsFrom(update)
        if self.args.lm and (update_lm is not None):
            self.loadLmFrom(update_lm)

    def collectStep(self, entID, annID, s, atnArgs, reward, dead, val):
        if self.config.TEST:
            return
        actions = {key: val[1] for key, val in atnArgs.items()}
        self.ReplayMemory[annID].append(entID, s, actions, reward, dead, val)
        if self.args.lm:
            self.ReplayMemoryLm.append(entID, annID, s, actions, reward, dead, val)

    def collectRollout(self, entID):
        rollout = self.updates[entID]
        rollout.feather.blob.tick = self.tick
        rollout.finish()
        self.blobs.append(rollout.feather.blob)
        del self.updates[entID]

    def decide(self, entID, annID, stim, reward, reward_stats, apples, isDead):
        stim_tensor = self.prepareInput(stim)
        outsLm, punishLm = self.lawmaker(stim_tensor)
        atnArgs, val = self.anns[annID](stim_tensor, self.config.EPS_CUR, punishLm)
        action = int(atnArgs['action'][1])
        self.collectStep(entID, annID, stim_tensor, atnArgs, reward, isDead,
                         val.detach().mean(2))
        if not self.config.TEST:
            # Logs the lawmaker's advantage for the chosen action: Q[action] minus mean Q.
            self.updates[entID].feather.scrawl(
                annID, val.detach().mean(2), reward_stats, apples,
                outsLm['action'][:, action].mean().detach().numpy()
                - outsLm['action'].mean().detach().numpy())
        return action

    def prepareInput(self, stim):
        # HWC numpy observation -> 1 x C x H x W float tensor.
        stim = np.transpose(stim, (2, 0, 1)).copy()
        stim_tensor = torch.from_numpy(stim).unsqueeze(0).float()
        return stim_tensor

    def reset_noise(self):
        nets = self.anns + [self.lawmaker]
        for net in nets:
            net.reset_noise()

    def loadAnnsFrom(self, states):
        for ann, state in zip(self.anns, states):
            ann.load_state_dict(state)

    def loadLmFrom(self, state):
        self.lawmaker.load_state_dict(state)
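# Quick, self-contained check (illustrative, not from the repository) of what
# `prepareInput` does to a grid observation: an H x W x C numpy image becomes a
# 1 x C x H x W float tensor, the channels-first layout PyTorch convolutions expect.
# The 15 x 15 x 3 shape is an assumption based on the buffer shapes used below.
import numpy as np
import torch

obs = np.zeros((15, 15, 3), dtype=np.uint8)
t = torch.from_numpy(np.transpose(obs, (2, 0, 1)).copy()).unsqueeze(0).float()
assert t.shape == (1, 3, 15, 15)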
class Sword:
    """Policy-gradient grid variant: keeps a fixed-size numpy buffer of
    HORIZON * EPOCHS steps per population member and fills it one step at a time."""

    def __init__(self, config, args):
        self.config, self.args = config, args
        self.nANN = config.NPOP
        self.anns = [ANN(config, args) for _ in range(self.nANN)]
        self.updates, self.rollouts = defaultdict(lambda: Rollout()), {}
        self.updates_lm, self.rollouts_lm = defaultdict(lambda: Rollout()), {}
        self.initBuffer()
        self.lawmaker = Lawmaker(args, config) if args.lm else LawmakerAbstract(args, config)

    def backward(self):
        self.rollouts_lm = {}
        self.blobs = [r.feather.blob for r in self.rollouts.values()]
        self.rollouts = {}

    def sendLogUpdate(self):
        blobs = self.blobs
        self.blobs = []
        return blobs

    def recvUpdate(self, update):
        update, update_lm = update
        for idx, paramVec in enumerate(update):
            setParameters(self.anns[idx], paramVec)
            zeroGrads(self.anns[idx])
        setParameters(self.lawmaker, update_lm[0])
        zeroGrads(self.lawmaker)

    def collectStep(self, entID, atnArgs, val, reward, stim=None):
        if self.config.TEST:
            return
        self.updates[entID].step(atnArgs, val, reward, stim)

    def collectRollout(self, entID, ent, tick, epoch):
        assert entID not in self.rollouts
        rollout = self.updates[entID]
        rollout.finish()
        rollout.feather.blob.tick = tick
        annID = ent.annID
        # Compute discounted returns (gamma = 0.99) backwards over this epoch's horizon slice.
        self.buffer[annID]['return'][(epoch + 1) * self.config.HORIZON - 1] = \
            self.buffer[annID]['reward'][(epoch + 1) * self.config.HORIZON - 1]
        for i in reversed(range(epoch * self.config.HORIZON,
                                (epoch + 1) * self.config.HORIZON - 1)):
            self.buffer[annID]['return'][i] = (self.buffer[annID]['reward'][i]
                                               + 0.99 * self.buffer[annID]['return'][i + 1])
        self.rollouts[entID] = rollout
        del self.updates[entID]

    def initBuffer(self):
        # Arrays are allocated uninitialized and overwritten step by step in decide().
        batchSize = self.config.HORIZON * self.config.EPOCHS
        self.buffer = defaultdict(lambda: {
            'state': np.ndarray((batchSize, 3, 15, 15), dtype=int),
            'policy': np.ndarray((batchSize, 8), dtype=float),
            'lmPolicy': np.ndarray((batchSize, 8), dtype=float),
            'action': np.ndarray((batchSize,), dtype=int),
            'reward': np.ndarray((batchSize,), dtype=int),
            'return': np.ndarray((batchSize,), dtype=float),
            'lmAction': np.ndarray((batchSize, 8), dtype=float)
        })

    def dispatchBuffer(self):
        buffer = deepcopy(self.buffer)
        self.initBuffer()
        return buffer

    def decide(self, ent, stim, reward, isDead, step, epoch):
        entID, annID = ent.agent_id + str(epoch), ent.annID
        stim = np.transpose(stim, (2, 0, 1)).copy()
        stim_tensor = torch.from_numpy(stim).unsqueeze(0).float()
        outsLm = self.lawmaker(stim_tensor, isDead, annID)
        annReturns = self.anns[annID](stim_tensor, outsLm, isDead)
        self.buffer[annID]['state'][step] = stim
        self.buffer[annID]['policy'][step] = annReturns['outputs']['action'][0].detach().numpy()
        self.buffer[annID]['action'][step] = annReturns['outputs']['action'][1]
        self.buffer[annID]['lmAction'][step] = outsLm['action'][1].detach().numpy()
        action = int(annReturns['outputs']['action'][1])
        # Lawmaker advantage of the chosen action: its value minus the mean over actions.
        Asw = -outsLm['action'][-1].mean()
        outsLm = self.lawmaker.get_punishment(outsLm, torch.tensor(action),
                                              annReturns['outputs']['action'][0].detach())
        self.buffer[annID]['lmPolicy'][step] = outsLm['action'][1].detach().numpy()
        Asw += float(outsLm['action'][-1])
        self.collectStep(entID, annReturns['outputs'], annReturns['val'], reward, stim)
        if not self.config.TEST:
            self.updates[entID].feather.scrawl(
                np.max(self.buffer[annID]['lmAction'][step]), ent,
                np.max(self.buffer[annID]['policy'][step]), reward, float(Asw))
        return action
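# Self-contained illustration of the per-horizon return computed in `collectRollout`
# above: within each epoch's HORIZON-long slice the return is the plain discounted
# sum of rewards with gamma = 0.99 and no bootstrapping past the horizon boundary.
# The function and its names are illustrative, not part of the class.
import numpy as np

def discounted_returns(rewards, gamma=0.99):
    rets = np.zeros(len(rewards), dtype=float)
    rets[-1] = rewards[-1]
    for i in reversed(range(len(rewards) - 1)):
        rets[i] = rewards[i] + gamma * rets[i + 1]
    return rets

# e.g. discounted_returns([0, 0, 1]) -> array([0.9801, 0.99, 1.0])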