Example #1
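A minimal constructor for the agent container: it builds one ANN per population member (config.NPOP), keeps per-entity rollout stores for both the agents and the lawmaker, and instantiates a trainable Lawmaker or the no-op LawmakerAbstract depending on args.lm.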
    def __init__(self, config, args):
        self.config, self.args = config, args
        self.nANN = config.NPOP
        self.anns = [ANN(config, args) for _ in range(self.nANN)]

        self.updates, self.rollouts = defaultdict(lambda: Rollout()), {}
        self.updates_lm, self.rollouts_lm = defaultdict(lambda: Rollout()), {}
        self.initBuffer()

        self.lawmaker = Lawmaker(
            args, config) if args.lm else LawmakerAbstract(args, config)
Example #2
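A constructor variant that also wires up a per-policy ReplayMemory list, a shared ReplayMemoryLm for the lawmaker, and Forward/ForwardLm helpers, presumably to support replay-based training.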
    def __init__(self, config, args, idx):
        self.config, self.args = config, args
        self.nANN, self.h = config.NPOP, config.HIDDEN
        self.anns = [ANN(config, args) for _ in range(self.nANN)]

        self.init, self.nRollouts = True, 32
        self.updates = defaultdict(lambda: Rollout(config))
        self.blobs = []
        self.idx = idx
        self.ReplayMemory = [ReplayMemory(self.config) for _ in range(self.nANN)]
        self.ReplayMemoryLm = ReplayMemoryLm(self.config)
        self.forward, self.forward_lm = Forward(config), ForwardLm(config)

        self.lawmaker = Lawmaker(args, config) if args.lm else LawmakerAbstract(args, config)
        self.tick = 0
Example #3
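An init method that builds matching online and target networks for each policy, creates Adam optimizers with weight decay (skipped under config.TEST), and adds an optimizer for the lawmaker when args.lm is set.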
    def init(self):
        print('Initializing new model...')
        self.anns = [
            ANN(self.config, self.args).to(self.config.device)
            for _ in range(self.nANN)
        ]
        self.targetAnns = [
            ANN(self.config, self.args).to(self.config.device)
            for _ in range(self.nANN)
        ]
        self.updateTargetAnns()
        self.annsOpts = None
        if not self.config.TEST:
            self.annsOpts = [
                Adam(ann.parameters(), lr=0.0005, weight_decay=0.00001)
                for ann in self.anns
            ]

        self.lm = (Lawmaker(self.args, self.config).to(self.config.device) if
                   self.args.lm else LawmakerAbstract(self.args, self.config))
        self.lmOpt = None
        if not self.config.TEST and self.args.lm:
            self.lmOpt = Adam(self.lm.parameters(),
                              lr=0.0005,
                              weight_decay=0.00001)
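updateTargetAnns is not shown in this excerpt; in DQN-style training it typically copies the online weights into the frozen target copies. A minimal sketch of that pattern, with an illustrative two-layer network standing in for ANN (make_net and hard_update are hypothetical helpers, not from this codebase):

import torch.nn as nn

def make_net() -> nn.Module:
    # Stand-in for ANN: any module works as long as both copies share it.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

def hard_update(target: nn.Module, online: nn.Module) -> None:
    # Overwrite the target's parameters with the online network's weights.
    target.load_state_dict(online.state_dict())

online, target = make_net(), make_net()
hard_update(target, online)  # target now mirrors online exactly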
Example #4
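A related init that attaches a StepLR scheduler to every optimizer (the learning rate decays by a factor of 0.9998 per step) and explicitly initializes the lawmaker's weights.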
    def init(self):
        print('Initializing new model...')
        self.unshared(self.config.NPOP)

        self.opt = None
        if not self.config.TEST:
            self.opt = [
                Adam(ann.parameters(), lr=self.config.LR, weight_decay=0.00001)
                for ann in self.anns
            ]
            self.scheduler = [StepLR(opt, 1, gamma=0.9998) for opt in self.opt]

        if self.args.lm:
            self.lawmaker = Lawmaker(
                self.args,
                self.config,
                device=self.config.DEVICE_OPTIMIZER,
                batch_size=self.config.LSTM_PERIOD).to(
                    self.config.DEVICE_OPTIMIZER)
        else:
            self.lawmaker = LawmakerAbstract(
                self.args,
                self.config,
                device=self.config.DEVICE_OPTIMIZER,
                batch_size=self.config.LSTM_PERIOD)
        initialize_weights(self.lawmaker)
        if self.args.lm:
            self.lmOpt = Adam(self.lawmaker.parameters(),
                              lr=self.config.LR,
                              weight_decay=0.00001)
            self.lmScheduler = StepLR(self.lmOpt, 1, gamma=0.9998)
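StepLR(opt, 1, gamma=0.9998) multiplies the learning rate by 0.9998 on every scheduler.step(), so after N steps the rate is LR * 0.9998**N. A self-contained sketch of that decay, assuming a throwaway one-parameter model:

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

param = torch.nn.Parameter(torch.zeros(1))
opt = Adam([param], lr=0.0005, weight_decay=0.00001)
scheduler = StepLR(opt, step_size=1, gamma=0.9998)

for _ in range(1000):
    opt.step()        # no gradients here; a real loop would backprop first
    scheduler.step()  # lr *= 0.9998
print(opt.param_groups[0]['lr'])  # ~0.0005 * 0.9998**1000 ≈ 4.09e-4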
Example #5
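A constructor that accumulates training data in per-policy deques: returns, flat and entity states, contacts, and the actions and policies of both the agents and the lawmaker.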
    def __init__(self, config, args):
        self.config, self.args = config, args
        self.nANN, self.h = config.NPOP, config.HIDDEN
        self.anns = [ANN(config, args) for _ in range(self.nANN)]

        self.updates, self.rollouts = defaultdict(lambda: Rollout(config)), {}
        self.updates_lm, self.rollouts_lm = defaultdict(
            lambda: Rollout(config)), {}
        self.nGrads = 0
        self.rets = defaultdict(deque)
        self.flat_states = defaultdict(deque)
        self.ent_states = defaultdict(deque)
        self.lmActions = defaultdict(deque)
        self.lmPolicy = defaultdict(deque)
        self.policy = defaultdict(deque)
        self.actions = defaultdict(deque)
        self.contacts = defaultdict(deque)
        self.buffer = None

        self.lawmaker = Lawmaker(
            args, config) if args.lm else LawmakerAbstract(args, config)
Example #6
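A complete on-policy Sword worker: decide runs the lawmaker and the selected ANN on each stimulus, collectRollout folds finished rollouts into the per-policy deques, and backward repacks those deques into aligned numpy batches once config.stepsPerEpoch gradients have accumulated.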
class Sword:
    def __init__(self, config, args):
        self.config, self.args = config, args
        self.nANN, self.h = config.NPOP, config.HIDDEN
        self.anns = [ANN(config, args) for _ in range(self.nANN)]

        self.updates, self.rollouts = defaultdict(lambda: Rollout(config)), {}
        self.updates_lm, self.rollouts_lm = defaultdict(
            lambda: Rollout(config)), {}
        self.nGrads = 0
        self.rets = defaultdict(deque)
        self.flat_states = defaultdict(deque)
        self.ent_states = defaultdict(deque)
        self.lmActions = defaultdict(deque)
        self.lmPolicy = defaultdict(deque)
        self.policy = defaultdict(deque)
        self.actions = defaultdict(deque)
        self.contacts = defaultdict(deque)
        self.buffer = None

        self.lawmaker = Lawmaker(
            args, config) if args.lm else LawmakerAbstract(args, config)

    def backward(self):
        self.rollouts_lm = {}
        self.blobs = [r.feather.blob for r in self.rollouts.values()]
        self.rollouts = {}
        length = min([len(self.rets[i]) for i in range(self.nANN)])
        if length == 0:
            return
        self.initBuffer()
        buffer = defaultdict(
            lambda: {
                'policy': defaultdict(list),
                'action': defaultdict(list),
                'lmPolicy': defaultdict(list),
                'lmAction': defaultdict(list),
                'flat': [],
                'ents': [],
                'return': []
            })
        for _ in range(length):
            for i in range(self.nANN):
                buffer[i]['flat'].append(self.flat_states[i].popleft())
                buffer[i]['ents'].append(self.ent_states[i].popleft())
                buffer[i]['return'].append(self.rets[i].popleft())
                policy = self.policy[i].popleft()
                actions = self.actions[i].popleft()
                for (k, v), action in zip(policy.items(), actions.values()):
                    buffer[i]['policy'][k].append(v.detach().numpy())
                    buffer[i]['action'][k].append(action.detach().numpy())
                lm_policy = self.lmPolicy[i].popleft()
                lm_actions = self.lmActions[i].popleft()
                for (k, v), action in zip(lm_policy.items(), lm_actions.values()):
                    buffer[i]['lmPolicy'][k].append(v.detach().numpy())
                    buffer[i]['lmAction'][k].append(action.detach().numpy())
        for i in range(self.nANN):
            self.buffer[i]['flat'] = np.asarray(buffer[i]['flat'],
                                                dtype=np.float32)
            self.buffer[i]['ents'] = np.asarray(buffer[i]['ents'],
                                                dtype=np.float32)
            self.buffer[i]['return'] = np.asarray(buffer[i]['return'],
                                                  dtype=np.float32)
            self.buffer[i]['policy'] = {
                k: np.asarray(v, dtype=np.float32)
                for k, v in buffer[i]['policy'].items()
            }
            self.buffer[i]['action'] = {
                k: np.asarray(v, dtype=np.float32)
                for k, v in buffer[i]['action'].items()
            }
            self.buffer[i]['lmPolicy'] = {
                k: np.asarray(v, dtype=np.float32)
                for k, v in buffer[i]['lmPolicy'].items()
            }
            self.buffer[i]['lmAction'] = {
                k: np.asarray(v, dtype=np.float32)
                for k, v in buffer[i]['lmAction'].items()
            }

    def sendLogUpdate(self):
        blobs = self.blobs
        self.blobs = []
        return blobs

    def sendUpdate(self):
        if self.buffer is None:
            return None, None
        buffer = self.dispatchBuffer()
        return buffer, self.sendLogUpdate()

    def recvUpdate(self, update):
        update, update_lm = update
        for idx, paramVec in enumerate(update):
            setParameters(self.anns[idx], paramVec)
            zeroGrads(self.anns[idx])

        setParameters(self.lawmaker, update_lm[0])
        zeroGrads(self.lawmaker)

    def collectStep(self, entID, action, policy, flat, ents, reward, contact,
                    val):
        if self.config.TEST:
            return
        self.updates[entID].step(action, policy, flat, ents, reward, contact,
                                 val)

    def collectStepLm(self, entID, action, policy):
        if self.config.TEST:
            return
        self.updates_lm[entID].step(action, policy)

    def collectRollout(self, entID, ent, tick):
        assert entID not in self.rollouts
        rollout = self.updates[entID]
        rollout.finish()
        nGrads = rollout.lifespan
        self.rets[ent.annID] += rollout.rets
        self.ent_states[ent.annID] += rollout.ent_states[:-1]
        self.flat_states[ent.annID] += rollout.flat_states[:-1]
        self.policy[ent.annID] += rollout.policy[:-1]
        self.actions[ent.annID] += rollout.actions[:-1]
        self.contacts[ent.annID] += rollout.contacts[:-1]
        rollout_lm = self.updates_lm[entID]
        self.lmPolicy[ent.annID] += rollout_lm.policy[:-1]
        self.lmActions[ent.annID] += rollout_lm.actions[:-1]
        del self.updates_lm[entID]
        rollout.feather.blob.tick = tick
        self.rollouts[entID] = rollout
        del self.updates[entID]
        self.nGrads += nGrads

        if self.nGrads >= self.config.stepsPerEpoch:
            self.nGrads = 0
            self.backward()

    def initBuffer(self):
        self.buffer = defaultdict(dict)

    def dispatchBuffer(self):
        buffer = self.buffer
        self.buffer = None
        return buffer

    def getActionArguments(self, annReturns, stim, ent):
        actions = ActionTree(stim, ent, ActionV2).actions()
        move, attkShare = actions
        playerActions = [move]
        actionDecisions = {}
        moveAction = int(annReturns['actions']['move'])
        attack = moveAction > 4
        if attack:
            moveAction -= 5
        actionTargets = [move.args(stim, ent, self.config)[moveAction]]

        action = attkShare.args(stim, ent, self.config)['attack']
        targets = action.args(stim, ent, self.config)
        target, decision = checkTile(ent, int(attack), targets)
        playerActions.append(action)
        actionTargets.append([target])
        actionDecisions['attack'] = decision

        return playerActions, actionTargets, actionDecisions

    def decide(self, ent, stim, isDead, n_dead=0):
        entID, annID = ent.entID, ent.annID
        reward = self.config.STEPREWARD + self.config.DEADREWARD * n_dead

        stim_tensor = torchlib.Stim(ent, stim, self.config)
        outsLm = self.lawmaker(stim_tensor.flat.view(1, -1),
                               stim_tensor.ents.unsqueeze(0), isDead, annID)
        annReturns = self.anns[annID](stim_tensor.flat.view(1, -1),
                                      stim_tensor.ents.unsqueeze(0), outsLm,
                                      isDead)

        playerActions, actionTargets, actionDecisions = self.getActionArguments(
            annReturns, stim, ent)

        moveAction = int(annReturns['actions']['move'])
        attack = actionDecisions.get('attack', None)
        if moveAction > 4:
            moveAction -= 5
        ent.moveDec = moveAction
        contact = int(attack is not None)

        Asw = -np.mean([float(t.mean()) for t in outsLm['Qs'].values()])
        outsLm = self.lawmaker.get_punishment(outsLm, annReturns['actions'])
        Asw += np.mean([float(t) for t in outsLm['Qs'].values()])

        self.collectStep(entID, annReturns['actions'], annReturns['policy'],
                         stim_tensor.flat.numpy(), stim_tensor.ents.numpy(),
                         reward, contact, float(annReturns['val']))
        self.collectStepLm(entID, outsLm['actions'], outsLm['policy'])
        if not self.config.TEST:
            self.updates[entID].feather.scrawl(ent, float(annReturns['val']),
                                               reward, Asw, attack, contact)
        return playerActions, actionTargets
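The heart of backward above is draining the per-policy deques in lockstep: truncating to the shortest deque keeps every batch aligned, and popleft consumes steps in order. The same pattern in miniature, with toy returns in place of rollout data:

from collections import defaultdict, deque
import numpy as np

rets = {0: deque([1.0, 2.0, 3.0]), 1: deque([4.0, 5.0])}
length = min(len(q) for q in rets.values())  # 2: trim to the shortest policy

buffer = defaultdict(lambda: {'return': []})
for _ in range(length):
    for i in rets:
        buffer[i]['return'].append(rets[i].popleft())

batched = {i: np.asarray(b['return'], dtype=np.float32)
           for i, b in buffer.items()}
print(batched)  # {0: array([1., 2.], dtype=float32), 1: array([4., 5.], dtype=float32)}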
Example #7
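A replay-based Sword variant: collectStep writes directly into per-policy ReplayMemory objects, sendBufferUpdate ships buffers (with priorities when config.REPLAY_PRIO is set) once each reaches buffer_size_to_send, and logs are flushed roughly every 64 calls.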
class Sword:
    def __init__(self, config, args, idx):
        self.config, self.args = config, args
        self.nANN, self.h = config.NPOP, config.HIDDEN
        self.anns = [ANN(config, args) for _ in range(self.nANN)]

        self.init, self.nRollouts = True, 32
        self.updates = defaultdict(lambda: Rollout(config))
        self.blobs = []
        self.idx = idx
        self.ReplayMemory = [
            ReplayMemory(self.config) for _ in range(self.nANN)
        ]
        self.ReplayMemoryLm = ReplayMemoryLm(self.config)
        self.forward, self.forward_lm = Forward(config), ForwardLm(config)
        self.buffer_size_to_send = 2**8

        self.lawmaker = Lawmaker(
            args, config) if args.lm else LawmakerAbstract(args, config)
        self.tick = 0
        self.logTick = 0

    def sendBufferUpdate(self):
        for replay in self.ReplayMemory:
            if len(replay) < self.buffer_size_to_send:
                return None, None
        buffer = [replay.send_buffer() for replay in self.ReplayMemory]
        if self.config.REPLAY_PRIO:
            priorities = [
                self.forward.get_priorities_from_samples(
                    buf, ann, ann, self.lawmaker)
                for buf, ann in zip(buffer, self.anns)
            ]
        else:
            priorities = [None] * len(self.anns)
        if self.args.lm:
            bufferLm = self.ReplayMemoryLm.send_buffer()
            prioritiesLm = None
            if self.config.REPLAY_PRIO:
                prioritiesLm = self.forward_lm.get_priorities_from_samples(
                    bufferLm, self.anns, self.lawmaker)
            return (buffer, priorities), (bufferLm, prioritiesLm)
        else:
            return (buffer, priorities), None

    def sendLogUpdate(self):
        self.logTick += 1
        if ((self.logTick + 1) % 2**6) == 0:
            blobs = self.blobs
            self.blobs = []
            return blobs
        return None

    def sendUpdate(self):
        recvs, recvs_lm = self.sendBufferUpdate()
        logs = self.sendLogUpdate() if recvs is not None else None
        return recvs, recvs_lm, logs

    def recvUpdate(self, update):
        update, update_lm = update
        if update is not None:
            self.loadAnnsFrom(update)
        if self.args.lm and (update_lm is not None):
            self.loadLmFrom(update_lm)

    def collectStep(self, entID, annID, s, atnArgs, reward, dead, val):
        if self.config.TEST:
            return
        actions = {key: arg[1] for key, arg in atnArgs.items()}
        self.ReplayMemory[annID].append(entID, s, actions, reward, dead, val)
        if self.args.lm:
            self.ReplayMemoryLm.append(entID, annID, s, actions, reward, dead,
                                       val)

    def collectRollout(self, entID, tick):
        rollout = self.updates[entID]
        rollout.feather.blob.tick = tick
        rollout.finish()
        self.blobs.append(rollout.feather.blob)
        del self.updates[entID]

    def decide(self, ent, stim, isDead=False, n_dead=0):
        entID, annID = ent.entID, ent.annID
        reward = self.config.DEADREWARD * n_dead + self.config.STEPREWARD

        flat, ents = self.prepareInput(ent, stim)
        outputsLm, punishmentsLm = self.lawmaker(flat, ents)
        atnArgs, val = self.anns[annID](flat, ents, self.config.EPS_CUR,
                                        punishmentsLm)
        action, arguments, decs = self.actionTree(ent, stim, atnArgs)

        attack = decs.get('attack', None)
        shareFood = decs.get('shareFood', None)
        shareWater = decs.get('shareWater', None)
        ent.moveDec = int(atnArgs['move'][1])
        contact = int(attack is not None)
        if not contact:
            punishAttack, punishWater, punishFood = None, None, None
        else:
            punishAttack = punishmentsLm.get('attack', [[None, None]])[0][1]
            punishWater = punishmentsLm.get('shareWater', [[None, None]])[0][1]
            punishFood = punishmentsLm.get('shareFood', [[None, None]])[0][1]
            ent.shareFoodDec = shareFood
            ent.shareWaterDec = shareWater
            ent.attackDec = attack

        self.collectStep(entID, annID, {'flat': flat, 'ents': ents},
                         atnArgs, reward, isDead, val.detach().mean(2))
        avgPunishmentLm = calcAvgPunishment(atnArgs, outputsLm)
        if not self.config.TEST:
            self.updates[entID].feather.scrawl(stim, ent,
                                               val.detach().mean(2), reward,
                                               avgPunishmentLm, punishAttack,
                                               punishWater, punishFood, attack,
                                               shareFood, shareWater, contact)
        return action, arguments

    def prepareInput(self, ent, env):
        s = torchlib.Stim(ent, env, self.config)
        return s.flat.unsqueeze(0), s.ents.unsqueeze(0)

    def actionTree(self, ent, env, outputs):
        actions = ActionTree(env, ent, ActionV2).actions()
        _, move, attkShare = actions

        playerActions = [move]
        actionTargets = [
            move.args(env, ent, self.config)[int(outputs['move'][1])]
        ]

        actionDecisions = {}
        for name in ['attack', 'shareWater', 'shareFood']:
            if name not in outputs.keys():
                continue
            action = attkShare.args(env, ent, self.config)[name]
            targets = action.args(env, ent, self.config)
            target, decision = checkTile(ent, int(outputs[name][1]), targets)
            playerActions.append(action)
            actionTargets.append([target])
            actionDecisions[name] = decision
        return playerActions, actionTargets, actionDecisions

    def reset_noise(self):
        nets = self.anns + [self.lawmaker]
        for net in nets:
            net.reset_noise()

    def loadAnnsFrom(self, states):
        for ann, state in zip(self.anns, states):
            ann.load_state_dict(state)

    def loadLmFrom(self, state):
        self.lawmaker.load_state_dict(state)
Example #8
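A close relative of Example #7 for grid observations: prepareInput converts an HWC numpy stimulus into a batched CHW float tensor, and decide returns a single discrete action.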
class Sword:
    def __init__(self, config, args, idx):
        self.config, self.args = config, args
        self.nANN, self.h = config.NPOP, config.HIDDEN
        self.anns = [ANN(config, args) for _ in range(self.nANN)]

        self.init, self.nRollouts = True, 32
        self.updates = defaultdict(lambda: Rollout(config))
        self.blobs = []
        self.idx = idx
        self.ReplayMemory = [ReplayMemory(self.config) for _ in range(self.nANN)]
        self.ReplayMemoryLm = ReplayMemoryLm(self.config)
        self.forward, self.forward_lm = Forward(config), ForwardLm(config)

        self.lawmaker = Lawmaker(args, config) if args.lm else LawmakerAbstract(args, config)
        self.tick = 0

    def sendBufferUpdate(self):
        buffer = [replay.send_buffer() for replay in self.ReplayMemory]
        if self.config.REPLAY_PRIO:
            priorities = [
                self.forward.get_priorities_from_samples(
                    buf, ann, ann, self.lawmaker)
                for buf, ann in zip(buffer, self.anns)
            ]
        else:
            priorities = [None] * len(self.anns)
        if self.args.lm:
            bufferLm = self.ReplayMemoryLm.send_buffer()
            prioritiesLm = None
            if self.config.REPLAY_PRIO:
                prioritiesLm = self.forward_lm.get_priorities_from_samples(
                    bufferLm, self.anns, self.lawmaker)
            return (buffer, priorities), (bufferLm, prioritiesLm)
        else:
            return (buffer, priorities), None

    def sendLogUpdate(self):
        blobs = self.blobs
        self.blobs = []
        return blobs

    def sendUpdate(self):
        recvs, recvs_lm = self.sendBufferUpdate()
        logs = self.sendLogUpdate()
        return recvs, recvs_lm, logs

    def recvUpdate(self, update):
        update, update_lm = update
        if update is not None:
            self.loadAnnsFrom(update)
        if self.args.lm and (update_lm is not None):
            self.loadLmFrom(update_lm)

    def collectStep(self, entID, annID, s, atnArgs, reward, dead, val):
        if self.config.TEST:
            return
        actions = {key: arg[1] for key, arg in atnArgs.items()}
        self.ReplayMemory[annID].append(entID, s, actions, reward, dead, val)
        if self.args.lm:
            self.ReplayMemoryLm.append(entID, annID, s, actions, reward, dead, val)

    def collectRollout(self, entID):
        rollout = self.updates[entID]
        rollout.feather.blob.tick = self.tick
        rollout.finish()
        self.blobs.append(rollout.feather.blob)
        del self.updates[entID]

    def decide(self, entID, annID, stim, reward, reward_stats, apples, isDead):
        stim_tensor = self.prepareInput(stim)
        outsLm, punishLm = self.lawmaker(stim_tensor)
        atnArgs, val = self.anns[annID](stim_tensor, self.config.EPS_CUR, punishLm)
        action = int(atnArgs['action'][1])

        self.collectStep(entID, annID, stim_tensor, atnArgs, reward, isDead, val.detach().mean(2))
        if not self.config.TEST:
            self.updates[entID].feather.scrawl(
                annID, val.detach().mean(2), reward_stats, apples,
                outsLm['action'][:, action].mean().detach().numpy() - outsLm['action'].mean().detach().numpy())
        return action

    def prepareInput(self, stim):
        stim = np.transpose(stim, (2, 0, 1)).copy()
        stim_tensor = torch.from_numpy(stim).unsqueeze(0).float()
        return stim_tensor

    def reset_noise(self):
        nets = self.anns + [self.lawmaker]
        for net in nets:
            net.reset_noise()

    def loadAnnsFrom(self, states):
        for ann, state in zip(self.anns, states):
            ann.load_state_dict(state)

    def loadLmFrom(self, state):
        self.lawmaker.load_state_dict(state)
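prepareInput performs the standard HWC-to-CHW conversion that PyTorch convolutions expect; the .copy() makes the transposed array contiguous before torch.from_numpy. The transform in isolation, assuming a 15x15x3 observation like the one implied by Example #9's state buffer:

import numpy as np
import torch

stim = np.zeros((15, 15, 3), dtype=np.uint8)         # HWC observation
chw = np.transpose(stim, (2, 0, 1)).copy()           # CHW, contiguous copy
tensor = torch.from_numpy(chw).unsqueeze(0).float()  # batch dim -> (1, 3, 15, 15)
print(tensor.shape)  # torch.Size([1, 3, 15, 15])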
Example #9
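An on-policy variant with preallocated flat buffers: initBuffer reserves numpy arrays of length HORIZON * EPOCHS per policy, and collectRollout fills the return column by backward recursion with a 0.99 discount.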
class Sword:
    def __init__(self, config, args):
        self.config, self.args = config, args
        self.nANN = config.NPOP
        self.anns = [ANN(config, args) for _ in range(self.nANN)]

        self.updates, self.rollouts = defaultdict(lambda: Rollout()), {}
        self.updates_lm, self.rollouts_lm = defaultdict(lambda: Rollout()), {}
        self.initBuffer()

        self.lawmaker = Lawmaker(
            args, config) if args.lm else LawmakerAbstract(args, config)

    def backward(self):
        self.rollouts_lm = {}
        self.blobs = [r.feather.blob for r in self.rollouts.values()]
        self.rollouts = {}

    def sendLogUpdate(self):
        blobs = self.blobs
        self.blobs = []
        return blobs

    def recvUpdate(self, update):
        update, update_lm = update
        for idx, paramVec in enumerate(update):
            setParameters(self.anns[idx], paramVec)
            zeroGrads(self.anns[idx])

        setParameters(self.lawmaker, update_lm[0])
        zeroGrads(self.lawmaker)

    def collectStep(self, entID, atnArgs, val, reward, stim=None):
        if self.config.TEST:
            return
        self.updates[entID].step(atnArgs, val, reward, stim)

    def collectRollout(self, entID, ent, tick, epoch):
        assert entID not in self.rollouts
        rollout = self.updates[entID]
        rollout.finish()
        rollout.feather.blob.tick = tick
        annID = ent.annID
        horizon = self.config.HORIZON
        last = (epoch + 1) * horizon - 1
        # Seed the final step with its own reward, then discount backward.
        self.buffer[annID]['return'][last] = self.buffer[annID]['reward'][last]
        for i in reversed(range(epoch * horizon, last)):
            self.buffer[annID]['return'][i] = (
                self.buffer[annID]['reward'][i] +
                0.99 * self.buffer[annID]['return'][i + 1])
        self.rollouts[entID] = rollout
        del self.updates[entID]

    def initBuffer(self):
        batchSize = self.config.HORIZON * self.config.EPOCHS
        self.buffer = defaultdict(
            lambda: {
                'state': np.empty((batchSize, 3, 15, 15), dtype=int),
                'policy': np.empty((batchSize, 8), dtype=float),
                'lmPolicy': np.empty((batchSize, 8), dtype=float),
                'action': np.empty((batchSize, ), dtype=int),
                'reward': np.empty((batchSize, ), dtype=int),
                'return': np.empty((batchSize, ), dtype=float),
                'lmAction': np.empty((batchSize, 8), dtype=float)
            })

    def dispatchBuffer(self):
        buffer = deepcopy(self.buffer)
        self.initBuffer()
        return buffer

    def decide(self, ent, stim, reward, isDead, step, epoch):
        entID, annID = ent.agent_id + str(epoch), ent.annID

        stim = np.transpose(stim, (2, 0, 1)).copy()
        stim_tensor = torch.from_numpy(stim).unsqueeze(0).float()
        outsLm = self.lawmaker(stim_tensor, isDead, annID)
        annReturns = self.anns[annID](stim_tensor, outsLm, isDead)

        self.buffer[annID]['state'][step] = stim
        self.buffer[annID]['policy'][step] = (
            annReturns['outputs']['action'][0].detach().numpy())
        self.buffer[annID]['action'][step] = annReturns['outputs']['action'][1]
        self.buffer[annID]['lmAction'][step] = (
            outsLm['action'][1].detach().numpy())
        action = int(annReturns['outputs']['action'][1])

        Asw = -outsLm['action'][-1].mean()
        outsLm = self.lawmaker.get_punishment(
            outsLm, torch.tensor(action),
            annReturns['outputs']['action'][0].detach())
        self.buffer[annID]['lmPolicy'][step] = (
            outsLm['action'][1].detach().numpy())
        Asw += float(outsLm['action'][-1])

        self.collectStep(entID, annReturns['outputs'], annReturns['val'],
                         reward, stim)
        if not self.config.TEST:
            self.updates[entID].feather.scrawl(
                np.max(self.buffer[annID]['lmAction'][step]), ent,
                np.max(self.buffer[annID]['policy'][step]), reward, float(Asw))
        return action
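collectRollout in this example fills the return column backward via ret[i] = reward[i] + 0.99 * ret[i + 1], seeding the final step with its own reward. The recursion in isolation, with toy rewards:

import numpy as np

rewards = np.array([0.0, 0.0, 1.0])
returns = np.empty_like(rewards)
returns[-1] = rewards[-1]                  # last step: return == reward
for i in reversed(range(len(rewards) - 1)):
    returns[i] = rewards[i] + 0.99 * returns[i + 1]
print(returns)  # [0.9801 0.99   1.    ]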