Esempio n. 1
0
    def forward(self, stim, actions, embed):
        '''Score the candidate action arguments against the stimulus

        Args:
            stim: stimulus tensor; two singleton dims are inserted at
                position 1 before it is passed to self.net
            actions: 4-tuple (atnTensor, atnTensorLens, atnLens, atnLenLens).
                Only atnTensor (5-D, last dim is the key) and atnLens are
                used here
            embed: pair (nameMap, table); table is indexed with the result
                of self.names(nameMap, keys)

        Returns:
            (atns, atnsIdx) as produced by self.net, with every entry of
            atns passed through unpack using the matching atnLens entry
        '''
        nameMap, table = embed
        atnTensor, _atnTensorLens, atnLens, _atnLenLens = actions

        #Flatten to one key per row so each key can be resolved to a name
        batch, nAtn, nArgs, nAtnArg, keyDim = atnTensor.shape
        keys = [tuple(row) for row in atnTensor.reshape(-1, keyDim)]
        keyEmbed = table[self.names(nameMap, keys)]

        #Restore the action layout, then sum the atn and arg embedding
        #to make a key dim
        keyEmbed = keyEmbed.view(batch, nAtn, nArgs, nAtnArg, -1).sum(-2)

        #The dot prod net does not match dims, so pad stim with two axes
        query = stim.unsqueeze(1).unsqueeze(1)
        atns, atnsIdx = self.net(query, keyEmbed, atnLens)

        #No gradients are kept through the outputs in test mode
        if self.config.TEST:
            atns = atns.detach()

        atns = [unpack(atn, l) for atn, l in zip(atns, atnLens)]
        return atns, atnsIdx
Esempio n. 2
0
    def unbatch(atnTensor, idxTensor, keyTensor, lenTensor):
        '''Internal inverse batcher

        Args:
            atnTensor, idxTensor, keyTensor: packed action data, unpacked
                along dim 1 with the outer lengths
            lenTensor: pair (inner lengths, outer lengths)

        Returns:
            One list per outer entry, holding (key, atn, idx) tuples
        '''
        innerLens, outerLens = lenTensor

        #Unpack outer set (careful with unpack dim)
        atnGroups, idxGroups, keyGroups, lenGroups = (
            utils.unpack(packed, outerLens, dim=1)
            for packed in (atnTensor, idxTensor, keyTensor, innerLens))

        #Unpack inner set and regroup as (key, atn, idx) per entry
        return [
            list(zip(keys, utils.unpack(atns, lens, dim=-2), idxs))
            for atns, idxs, keys, lens in zip(
                atnGroups, idxGroups, keyGroups, lenGroups)]
Esempio n. 3
0
    def forward(self, pop, rollouts, data):
        '''Recompute forward pass and assemble rollout objects

        Args:
            pop: passed through to self.net as its first argument
            rollouts: returned unchanged; results are written through
                self.manager.fill instead
            data: 7-tuple (keys, _, stims, rawActions, actions, rewards,
                dones)

        Returns:
            The rollouts argument
        '''
        keys, _, stims, rawActions, actions, rewards, dones = data
        _, outs, vals = self.net(pop, stims, atnArgs=actions)

        #Unpack outputs
        _atnTensor, _idxTensor, _atnKeyTensor, lenPair = actions
        _innerLens, outerLens = lenPair
        #NOTE(review): this result is never read below; the per-sample
        #loop iterates over outs directly. Call kept to preserve behavior
        utils.unpack(outs, outerLens, dim=1)

        #Collect rollouts
        samples = zip(keys, outs, rawActions, vals, rewards, dones)
        for key, out, rawAtn, val, _reward, done in samples:
            #Split the (key, args, idx) triples into parallel tuples,
            #keeping only the number of args for each action
            triples = [(k, len(e), idx) for k, e, idx in rawAtn]
            atnKey, argLens, atnIdx = zip(*triples)

            packet = (atnKey, np.array(atnIdx), utils.unpack(out, argLens))
            self.manager.fill(key, packet, val, done)

        return rollouts
Esempio n. 4
0
    def unbatch(batch):
        '''Internal inverse batcher

        Args:
            batch: mapping of group -> (keys, values), where keys is a
                sequence with one sized entry per stimulus and values maps
                attribute name -> packed values for the whole group

        Returns:
            A list with one dict per stimulus. Each dict maps
            group -> [key, attrs], where attrs maps attribute name to the
            slice of values belonging to that stimulus
        '''
        stims = []
        for group, (keys, values) in batch.items():
            #Assign keys, growing the output as new stimulus slots appear
            for idx, key in enumerate(keys):
                if idx == len(stims):
                    stims.append(defaultdict(list))
                stims[idx][group] = [key, defaultdict(list)]

            #The split sizes depend only on keys, so compute them once per
            #group instead of once per attribute
            lens = [len(e) for e in keys]

            #Assign values
            for attr, vals in values.items():
                for idx, val in enumerate(utils.unpack(vals, lens)):
                    stims[idx][group][1][attr] = val

        return stims
Esempio n. 5
0
   def forward(self, net, stims):
      '''Embed every observation group and build the entity lookup table

      Args:
         net: provides the per-group modules net.attns[group] and the
            combining module net.attns['Meta']
         stims: mapping of group -> (names, subnet)

      Returns:
         (features, embed): the output of net.attns['Meta'] applied to the
         per-group features stacked on dim -2, and lookup.table()

      Raises:
         ValueError: if no entity embedding was produced, since the two
            padding entries are derived from the last one seen
      '''
      features, lookup = {}, Lookup()
      self.actions(lookup)

      #Chunks of the most recent entity embedding; used as the template
      #for the padding entries added after the loop
      lastChunks = None

      #Pack entities of each observation set
      for group, stim in stims.items():
         names, subnet = stim
         embs, feats = self.attrs(group, net.attns[group], subnet)
         features[group] = feats

         #Unpack and flatten for embedding
         lens = [len(e) for e in names]
         vals = utils.unpack(embs, lens, dim=1)
         for k, v in zip(names, vals):
            lastChunks = v.split(1, dim=0)
            lookup.add(k, lastChunks)

      if lastChunks is None:
         raise ValueError(
            'forward() needs at least one entity embedding to build '
            'the padding entries')

      #Register the all-0 key with the last chunk multiplied by zero
      k = [tuple([0]*Serial.KEYLEN)]
      zero = lastChunks[-1] * 0
      lookup.add(k, [zero])

      #Register the all-(-1) key the same way, as a separate tensor
      k = [tuple([-1]*Serial.KEYLEN)]
      lookup.add(k, [zero * 0])

      #Concat feature block
      features = torch.stack(list(features.values()), -2)

      #Batch, group (tile/ent), hidden
      features = net.attns['Meta'](features)

      embed = lookup.table()
      return features, embed
Esempio n. 6
0
   def forward(self, net, stims):
      '''Embed every observation group and build the entity lookup table

      Args:
         net: passed to self.attrs for each group; its attn2 module
            combines the concatenated group embeddings
         stims: mapping of group -> (names, subnet)

      Returns:
         (features, embed): the output of net.attn2 with dim 0 squeezed,
         and lookup.table()
      '''
      lookup = Lookup()
      self.actions(lookup)

      #Pack entities of each observation set
      groupFeats = {}
      for group, (names, subnet) in stims.items():
         emb = self.attrs(group, net, subnet)
         groupFeats[group] = emb.unsqueeze(0)

         #Unpack and flatten for embedding
         counts = [len(name) for name in names]
         for key, entEmb in zip(names, utils.unpack(emb, counts, dim=1)):
            lookup.add(key, entEmb.split(1, dim=0))

      #Concat feature block
      stacked = torch.cat(list(groupFeats.values()), -2)
      features = net.attn2(stacked).squeeze(0)

      return features, lookup.table()