def forward(self, stim, actions, embed):
    '''Score per-action argument candidates against the stimulus.

    Embeds the flattened action/argument key tensor, sums the action and
    argument embeddings into one key per slot, and runs the dot-product
    attention net against the stimulus.

    Returns (atns, atnsIdx): per-example unpacked logits and the
    selected argument indices.
    '''
    name_map, embed_table = embed
    atn_keys, _, atn_lens, _ = actions

    batch, n_atn, n_args, n_atn_arg, key_dim = atn_keys.shape
    flat_keys = atn_keys.reshape(-1, key_dim)

    # Resolve serialized keys to rows of the embedding table
    row_ids = self.names(name_map, [tuple(k) for k in flat_keys])
    targs = embed_table[row_ids]
    targs = targs.view(batch, n_atn, n_args, n_atn_arg, -1)

    # Sum the atn and arg embeddings to make a single key per slot
    targs = targs.sum(-2)

    # The dot-prod net does not match dims: broadcast stim over slots
    query = stim.unsqueeze(1).unsqueeze(1)
    atns, atnsIdx = self.net(query, targs, atn_lens)

    # No gradient bookkeeping needed at test time
    if self.config.TEST:
        atns = atns.detach()

    # NOTE(review): bare `unpack` (not utils.unpack) — confirm it is in scope
    atns = [unpack(a, n) for a, n in zip(atns, atn_lens)]
    return (atns, atnsIdx)
def unbatch(atnTensor, idxTensor, keyTensor, lenTensor):
    '''Internal inverse batcher.

    Splits padded action tensors back into per-example lists of
    (keys, atns, idxs) triples.
    '''
    lens, lenLens = lenTensor

    # Unpack outer set (careful with unpack dim): split every tensor
    # into per-example slices along dim 1
    per_example = [
        utils.unpack(t, lenLens, dim=1)
        for t in (atnTensor, idxTensor, keyTensor, lens)
    ]

    # Unpack inner set: within each example, split rows per action
    actions = []
    for atns, idxs, keys, inner_lens in zip(*per_example):
        split_atns = utils.unpack(atns, inner_lens, dim=-2)
        actions.append(list(zip(keys, split_atns, idxs)))

    return actions
def forward(self, pop, rollouts, data):
    '''Recompute the forward pass and assemble rollout objects.

    Runs the policy net over the batched stims/actions, splits each
    agent's flat output row back into per-action segments, and fills
    the rollout manager. Returns the (unmodified) rollouts handle.

    Fixes vs. original: removed a dead `atnOuts = utils.unpack(...)`
    computation whose result was never used, and renamed the inner
    per-agent lengths so they no longer shadow the outer `lens`.
    '''
    keys, _, stims, rawActions, actions, rewards, dones = data
    _, outs, vals = self.net(pop, stims, atnArgs=actions)

    # Only the length metadata of the packed action tuple is used here
    _, _, _, lenTensor = actions
    lens, lenTensor = lenTensor

    # Collect rollouts: one manager entry per agent key
    # (rewards is zipped for alignment but unused in the fill)
    for key, out, atn, val, _reward, done in zip(
            keys, outs, rawActions, vals, rewards, dones):
        # Split raw actions into parallel (key, length, chosen-idx) lists
        atnKey, atnLens, atnIdx = zip(
            *[(k, len(e), idx) for k, e, idx in atn])
        atnIdx = np.array(atnIdx)

        # Split this agent's flat output row by per-action lengths
        out = utils.unpack(out, atnLens)

        self.manager.fill(key, (atnKey, atnIdx, out), val, done)

    return rollouts
def unbatch(batch):
    '''Internal inverse batcher.

    Splits grouped batch stats back into a per-observation list of
    defaultdicts: stims[idx][group] = [key, {attr: values}].

    Fix vs. original: the per-observation entry counts `lens` were
    recomputed inside the per-attribute loop even though they only
    depend on `keys`; hoisted the computation out of the loop.
    '''
    stims = []
    for group, stat in batch.items():
        keys, values = stat

        # Assign keys, growing the per-observation list on demand
        for idx, key in enumerate(keys):
            if idx == len(stims):
                stims.append(defaultdict(list))
            stims[idx][group] = [key, defaultdict(list)]

        # Entry counts are identical for every attr of this group:
        # compute once, outside the attribute loop
        lens = [len(e) for e in keys]

        # Assign values: split each packed attr tensor per observation
        for attr, vals in values.items():
            for idx, val in enumerate(utils.unpack(vals, lens)):
                stims[idx][group][1][attr] = val

    return stims
def forward(self, net, stims):
    '''Embed all observed entities and build the key/embedding lookup.

    Packs per-group entity features, registers each entity's embedding
    in the lookup keyed by its serialized name, stacks the group
    features, and passes them through the 'Meta' attention net.

    Returns (features, embed): the attended feature block and the flat
    embedding table produced by the lookup.

    Fixes vs. original: removed the dead `feats = features` rebind and
    two commented-out code fragments.
    '''
    features, lookup = {}, Lookup()
    self.actions(lookup)

    # Pack entities of each observation set
    for group, stim in stims.items():
        names, subnet = stim
        embs, feats = self.attrs(group, net.attns[group], subnet)
        features[group] = feats

        # Unpack and flatten for embedding
        lens = [len(e) for e in names]
        vals = utils.unpack(embs, lens, dim=1)
        for k, v in zip(names, vals):
            v = v.split(1, dim=0)
            lookup.add(k, v)

    # Register zero embeddings for the all-zero and all-(-1) sentinel
    # keys (padding / null entries).
    # NOTE(review): relies on `v` leaking out of the loop above — this
    # raises NameError if stims is empty; confirm callers always pass
    # at least one group.
    k = [tuple([0] * Serial.KEYLEN)]
    v = [v[-1] * 0]
    lookup.add(k, v)

    k = [tuple([-1] * Serial.KEYLEN)]
    v = [v[-1] * 0]
    lookup.add(k, v)

    # Concat feature block: batch, group (tile/ent), hidden
    features = list(features.values())
    features = torch.stack(features, -2)
    features = net.attns['Meta'](features)

    embed = lookup.table()
    return features, embed
def forward(self, net, stims):
    '''Embed observed entities per group and assemble the feature block.

    Packs each observation group through its attribute net, registers
    every entity's embedding in the lookup by name, concatenates the
    group features, and attends over them with net.attn2.

    Returns (features, embed): the attended feature block and the flat
    embedding table from the lookup.
    '''
    groupFeats, lookup = {}, Lookup()
    self.actions(lookup)

    # Pack entities of each observation set
    for group, (names, subnet) in stims.items():
        emb = self.attrs(group, net, subnet)
        groupFeats[group] = emb.unsqueeze(0)

        # Unpack and flatten: map each entity key to its embedding slices
        entLens = [len(n) for n in names]
        for key, val in zip(names, utils.unpack(emb, entLens, dim=1)):
            lookup.add(key, val.split(1, dim=0))

    # Concat feature block across groups, then attend
    block = torch.cat(list(groupFeats.values()), -2)
    block = net.attn2(block).squeeze(0)

    return block, lookup.table()