def step(self, packet, weights, backward):
    '''Synchronizes weights from upstream; computes agent decisions;
    optionally computes policy updates.

    Args:
        packet   : An IO object specifying observations
        weights  : An optional parameter vector to replace model weights;
                   when None, the current model weights are kept as-is
        backward : (bool) Whether or not a backward pass should be performed

    Returns:
        data    : The same IO object populated with action decisions
        grads   : A vector of gradients aggregated across trajectories,
                  or None when no backward pass was performed
        summary : A BlobSummary object logging agent statistics,
                  or None when no backward pass was performed
    '''
    grads, blobs = None, None

    # Sync model weights only when upstream actually sent a vector;
    # `weights` is documented as optional, so None must not be forwarded.
    if weights is not None:
        setParameters(self.net, weights)

    # Batch observations, then run the forward pass. Per the Returns
    # section, `packet` itself carries the action decisions back out.
    self.manager.collectInputs(packet)
    self.net(packet, self.manager)

    # Compute backward pass and logs from full rollouts (the original
    # design discards partial trajectories). Skipped in test mode.
    if backward and not self.config.TEST:
        rollouts, blobs = self.manager.step()
        optim.backward(rollouts, self.config)
        grads = self.net.grads()

    return packet, grads, blobs
def recvUpdate(self, update):
    '''Load fresh parameters into every agent network and the lawmaker.

    Args:
        update : A pair ``(ann_params, lawmaker_params)``. The i-th vector
                 of ``ann_params`` is loaded into ``self.anns[i]``;
                 ``lawmaker_params[0]`` is loaded into ``self.lawmaker``.
    '''
    ann_params, lawmaker_params = update

    # Overwrite each agent network's weights, then zero its gradients.
    for i, vec in enumerate(ann_params):
        net = self.anns[i]
        setParameters(net, vec)
        zeroGrads(net)

    lawmaker = self.lawmaker
    setParameters(lawmaker, lawmaker_params[0])
    zeroGrads(lawmaker)
def permuteNet(self, goodIdx, badIdx):
    '''Overwrite network ``badIdx`` with a noise-perturbed copy of ``goodIdx``.

    Gaussian noise (one standard-normal draw per parameter, scaled by
    ``self.config.PERMVAL``) is added to the parameters read from network
    ``goodIdx``, and the result is written into network ``badIdx``. This
    function never writes to network ``goodIdx``.
    '''
    src = self.model.net.net[goodIdx]
    dst = self.model.net.net[badIdx]

    base = getParameters(src)
    jitter = self.config.PERMVAL * np.random.randn(len(base))
    perturbed = np.array(base) + jitter
    setParameters(dst, perturbed)
def recvUpdate(self, update):
    '''Apply an upstream parameter vector to this object.

    Args:
        update : Parameter vector passed to ``setParameters``, or None to
                 leave the current parameters untouched.
    '''
    if update is not None:
        setParameters(self, update)
def recvUpdate(self, update):
    '''Load one parameter vector into each agent network.

    Args:
        update : Sequence of parameter vectors; the vector at position i is
                 loaded into ``self.anns[i]``, after which that network's
                 gradients are zeroed via ``zeroGrads``.
    '''
    for position, vector in enumerate(update):
        target = self.anns[position]
        setParameters(target, vector)
        zeroGrads(target)
def syncParameters(self):
    '''Push ``self.parameters`` into ``self.net``.

    ``self.parameters`` is detached from autograd and converted through NumPy
    to plain Python values before being handed to ``setParameters``.
    (Presumably a torch tensor, given ``.detach().numpy()`` — confirm.)
    '''
    values = self.parameters.detach().numpy()
    setParameters(self.net, values.tolist())