def collectOutputs(self, realm, obs, actions, rewards, dones):
    '''Collects output data to internal buffers

    Empties self.partial, then walks the four per-agent sequences in
    lockstep. Each (ob, action) pair is run through Serial.outputs; the
    key it returns is reduced with Serial.nontemporal and used to find
    the rollout in self.complete that receives the serialized action,
    the reward and the done flag.

    Args:
        realm: forwarded unchanged to Serial.outputs
        obs: per-agent observations
        actions: per-agent actions, parallel to obs
        rewards: per-agent rewards, parallel to obs
        dones: per-agent done flags, parallel to obs

    Raises:
        AssertionError: if a nontemporal key is not already present
            in self.complete
    '''
    self.partial.clear()

    for ob, rawAtn, reward, done in zip(obs, actions, rewards, dones):
        _unused, fullKey, serialAtn = Serial.outputs(realm, ob, rawAtn)
        rolloutKey = Serial.nontemporal(fullKey)

        # Outputs must only ever land on a rollout that already exists
        assert rolloutKey in self.complete
        rollout = self.complete[rolloutKey]
        rollout.outputs(serialAtn, reward, done)
def collectInputs(self, realm, obs, stims):
    '''Collects observation data to internal buffers

    Walks obs and stims in lockstep. Each pair is run through
    Serial.inputs, self.nUpdates is incremented by one, and the
    (key, ob, stim) triple is recorded on the rollout stored under
    the nontemporal key in both self.partial and self.complete.

    Args:
        realm: forwarded unchanged to Serial.inputs
        obs: per-agent observations
        stims: per-agent stimuli, parallel to obs
    '''
    for ob, rawStim in zip(obs, stims):
        _unused, fullKey, serialStim = Serial.inputs(realm, ob, rawStim)
        self.nUpdates += 1

        rolloutKey = Serial.nontemporal(fullKey)

        # The same input is recorded twice: once per buffer
        partialRollout = self.partial[rolloutKey]
        partialRollout.inputs(fullKey, ob, serialStim)

        completeRollout = self.complete[rolloutKey]
        completeRollout.inputs(fullKey, ob, serialStim)
def recv(partial, full, packets):
    '''Unpack rollouts from workers on optim server

    Every packet is split into keys, stimuli, actions, rewards and done
    flags. Stimuli and actions are unbatched, then each step is appended
    to the rollout stored in partial under its nontemporal key. A rollout
    whose done attribute becomes truthy is moved from partial to full.

    Args:
        partial: A defaultdict of partially complete rollouts
        full: A defaultdict of complete rollouts
        packets: a list of serialized experience packets

    Returns:
        A (nUpdates, nRollouts) tuple: the summed len() of the rollouts
        moved into full during this call, and how many were moved.

    Raises:
        AssertionError: if a finished rollout's key is already in full
    '''
    updateCount = 0
    rolloutCount = 0

    for packet in packets:
        keys, stims, actions, rewards, dones = packet
        stims = Stimulus.unbatch(stims)
        actions = Action.unbatch(*actions)

        steps = zip(keys, stims, actions, rewards, dones)
        for iden, stim, atn, reward, done in steps:
            key = Serial.nontemporal(iden)

            rollout = partial[key]
            rollout.inputs(iden, None, stim)
            rollout.outputs(atn, reward, done)

            if not rollout.done:
                continue

            # Finished: promote the rollout from partial to full
            assert key not in full
            full[key] = rollout
            del partial[key]

            updateCount += len(rollout)
            rolloutCount += 1

    return updateCount, rolloutCount
def fill(self, key, out, val, done):
    '''Fill in output/value data needed for the backward pass

    Looks up the rollout stored in self.complete under the nontemporal
    form of key and hands it (key, out, val) through rollout.fill. When
    done is truthy the rollout's feather is finished, its blob is
    appended to self.log, and the rollout is removed from self.complete.

    Args:
        key: key identifying the step; reduced with Serial.nontemporal
        out: forwarded to rollout.fill
        val: forwarded to rollout.fill
        done: when truthy, finalize and retire the rollout
    '''
    rolloutKey = Serial.nontemporal(key)
    rollout = self.complete[rolloutKey]

    # NOTE(review): rollout.fill receives the nontemporal key, not the
    # key that was passed in -- confirm this is intended.
    rollout.fill(rolloutKey, out, val)

    if not done:
        return

    rollout.feather.finish()
    finishedBlob = rollout.feather.blob
    self.log.append(finishedBlob)
    del self.complete[rolloutKey]
def serialize(stim, iden):
    '''Internal stimulus serializer for communication across machines

    Replaces every name in each group of stim with Serial.key(name, iden).
    As a side effect, the computed key is also stored on the name object
    as name.injectedSerial.

    Args:
        stim: mapping of group -> (names, data)
        iden: forwarded to Serial.key alongside each name

    Returns:
        A new dict mapping group -> (list of serial keys, data), with
        data passed through untouched.
    '''
    from forge.ethyr.io import Serial

    serialized = {}
    for group, (names, payload) in stim.items():
        keys = []
        for name in names:
            key = Serial.key(name, iden)
            name.injectedSerial = key
            keys.append(key)
        serialized[group] = (keys, payload)

    return serialized
def serialize(stim):
    '''Internal stimulus serializer for communication across machines

    Replaces every name in each group of stim with Serial.key(name).
    As a side effect, the computed key is also stored on the name object
    as name.injectedSerial.

    Args:
        stim: mapping of group -> (names, data)

    Returns:
        A new dict mapping group -> (list of serial keys, data), with
        data passed through untouched.
    '''
    from forge.ethyr.io import Serial

    def inject(name):
        # Compute the serial key and remember it on the name itself
        serialName = Serial.key(name)
        name.injectedSerial = serialName
        return serialName

    result = {}
    for group, (names, payload) in stim.items():
        serialNames = [inject(name) for name in names]
        result[group] = (serialNames, payload)

    return result
def actions(self, lookup):
    '''Embed actions

    For every action in StaticAction.actions, builds a one-element long
    tensor from the action's idx on self.config.DEVICE, embeds it with
    self.action, and registers that embedding in lookup twice: once
    under the action object and once under Serial.key(atn).

    Args:
        lookup: object whose add(keys, values) method receives the
            embeddings
    '''
    for atn in StaticAction.actions:
        # Brackets on atn.serial?
        index = torch.Tensor([atn.idx]).long().to(self.config.DEVICE)
        embedding = self.action(index)

        # Dirty hack -- remove duplicates
        lookup.add([atn], [embedding])

        # What to replace serial with here
        serialKey = Serial.key(atn)
        lookup.add([serialKey], [embedding])
def grouped(rollouts):
    '''Bucket rollouts by the population computed from each key

    Args:
        rollouts: mapping of key -> rollout

    Returns:
        The items() view of a defaultdict mapping
        Serial.population(key) -> {key: rollout}.
    '''
    byPopulation = defaultdict(dict)

    for rolloutKey, rollout in rollouts.items():
        population = Serial.population(rolloutKey)
        byPopulation[population][rolloutKey] = rollout

    return byPopulation.items()