def forward(self, inputs, single=False): #pylint: disable=arguments-differ
    if single or not self.gpus or len(self.gpus) == 1:
        return self._forward(inputs)
    # return data_parallel(self.super_net, (inputs, self.genotypes_grouped), self.gpus)
    return data_parallel(self, (inputs, ), self.gpus, module_kwargs={"single": True})
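# --- Illustrative sketch (not part of the original code) ---
# The single=True module kwarg is what keeps the call above from recursing:
# torch.nn.parallel.data_parallel replicates the module itself and passes
# module_kwargs to every replica, so each replica takes the single-device path.
# The Wrapper class and shapes below are hypothetical, for illustration only.
import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel


class Wrapper(nn.Module):
    def __init__(self, gpus):
        super().__init__()
        self.gpus = gpus
        self.net = nn.Linear(8, 4)

    def _forward(self, inputs):
        return self.net(inputs)

    def forward(self, inputs, single=False):
        if single or not self.gpus or len(self.gpus) == 1:
            return self._forward(inputs)
        # Replicas are called with single=True, so they do not re-enter data_parallel.
        return data_parallel(self, (inputs, ), self.gpus, module_kwargs={"single": True})


if __name__ == "__main__":
    gpus = list(range(torch.cuda.device_count()))
    model = Wrapper(gpus).cuda() if gpus else Wrapper(gpus)
    x = torch.randn(16, 8)
    print(model(x.cuda() if gpus else x).shape)  # torch.Size([16, 4])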
def forward(self, inputs, single=False): # pylint: disable=arguments-differ
    # Replicas created by `data_parallel` are called with single=True and must
    # take the single-device path, otherwise the call below would recurse.
    if single:
        return self._forward(inputs)
    if self.multiprocess:
        out = self.super_net.parallel_model.forward(inputs, self.rollout)
    elif len(self.gpus) > 1:
        out = data_parallel(self, (inputs, ), self.gpus, module_kwargs={"single": True})
    else:
        return self._forward(inputs)
    return out
def forward(self, inputs, detach_arch=True): #pylint: disable=arguments-differ
    if detach_arch:
        arch = [
            DartsArch(
                op_weights=op_weights.detach(),
                edge_norms=edge_norms.detach() if edge_norms is not None else None
            )
            for op_weights, edge_norms in self.arch
        ]
    else:
        arch = self.arch

    if not self.gpus or len(self.gpus) == 1:
        return self.super_net.forward(inputs, arch, detach_arch=detach_arch)

    if arch[0].op_weights.ndimension() == 2:
        arch = [
            DartsArch(
                op_weights=a.op_weights.repeat(len(self.gpus), 1),
                edge_norms=(a.edge_norms.repeat(len(self.gpus)) \
                            if a.edge_norms is not None else None))
            for a in arch
        ]
    else:
        # Ugly fix for rollout_size > 1
        # call scatter here and stack...
        # split along dimension 1,
        # then concatenate along dimension 0 for `data_parallel` to scatter it again
        num_split = len(self.gpus)
        rollout_batch_size = arch[0].op_weights.shape[1]
        assert rollout_batch_size % num_split == 0
        split_size = rollout_batch_size // num_split
        # arch = [torch.cat(torch.split(a, split_size, dim=1), dim=0) for a in arch]
        # Note: edge_norms (1-dim) do not support batch_size, just repeat
        arch = [DartsArch(
            op_weights=torch.cat(torch.split(a.op_weights, split_size, dim=1), dim=0),
            edge_norms=(a.edge_norms.repeat(len(self.gpus)) \
                        if a.edge_norms is not None else None))
            for a in arch]

    return data_parallel(self.super_net, (inputs, arch), self.gpus,
                         module_kwargs={"detach_arch": detach_arch})
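# --- Illustrative sketch (not part of the original code) ---
# DartsArch is assumed here to be a namedtuple-like container with "op_weights"
# and "edge_norms" fields, as the attribute accesses above suggest. For 2-D
# op_weights (a single architecture per forward), the arch is simply repeated so
# that data_parallel's scatter along dim 0 hands every GPU an identical copy.
# The edge/op counts below are made up for illustration.
from collections import namedtuple

import torch

DartsArch = namedtuple("DartsArch", ["op_weights", "edge_norms"])

num_gpus = 2
arch = DartsArch(op_weights=torch.rand(14, 8),   # (num_edges, num_ops)
                 edge_norms=torch.rand(14))      # (num_edges,)

replicated = DartsArch(
    op_weights=arch.op_weights.repeat(num_gpus, 1),  # (num_gpus * num_edges, num_ops)
    edge_norms=arch.edge_norms.repeat(num_gpus),     # (num_gpus * num_edges,)
)
print(replicated.op_weights.shape)  # torch.Size([28, 8])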
def forward(self, inputs, detach_arch=True): #pylint: disable=arguments-differ
    arch = [a.detach() for a in self.arch] if detach_arch else self.arch
    if not self.gpus or len(self.gpus) == 1:
        return self.super_net.forward(inputs, arch, detach_arch=detach_arch)
    if arch[0].ndimension() == 2:
        arch = [a.repeat([len(self.gpus), 1]) for a in arch]
    else:
        # Ugly fix for rollout_size > 1
        # call scatter here and stack...
        # split along dimension 1,
        # then concatenate along dimension 0 for `data_parallel` to scatter it again
        num_split = len(self.gpus)
        rollout_batch_size = arch[0].shape[1]
        assert rollout_batch_size % num_split == 0
        split_size = rollout_batch_size // num_split
        arch = [
            torch.cat(torch.split(a, split_size, dim=1), dim=0)
            for a in arch
        ]
    return data_parallel(self.super_net, (inputs, arch), self.gpus,
                         module_kwargs={"detach_arch": detach_arch})
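# --- Illustrative sketch (not part of the original code) ---
# The "ugly fix" above works around data_parallel scattering only along dim 0:
# for rollout_size > 1 the rollout batch lives in dim 1 of each arch tensor, so
# the tensor is split along dim 1 and re-concatenated along dim 0. The default
# scatter then chunks dim 0 into len(gpus) pieces, and each GPU receives exactly
# its own rollout slice. Shapes below are made up for illustration.
import torch

num_gpus = 2
num_edges, rollout_batch_size, num_ops = 14, 4, 8
a = torch.rand(num_edges, rollout_batch_size, num_ops)

assert rollout_batch_size % num_gpus == 0
split_size = rollout_batch_size // num_gpus
rearranged = torch.cat(torch.split(a, split_size, dim=1), dim=0)
print(rearranged.shape)  # torch.Size([28, 2, 8])

# Scattering along dim 0 now reproduces the per-GPU rollout slices:
chunks = torch.chunk(rearranged, num_gpus, dim=0)
assert torch.equal(chunks[0], a[:, :split_size]) and torch.equal(chunks[1], a[:, split_size:])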