Example #1
    def forward(self, inputs, single=False):  #pylint: disable=arguments-differ
        # Single-device path (also taken by every replica created below).
        if single or not self.gpus or len(self.gpus) == 1:
            return self._forward(inputs)
        # Multi-GPU path: replicate this module and scatter `inputs`; each replica
        # re-enters forward() with single=True and takes the branch above.
        # return data_parallel(self.super_net, (inputs, self.genotypes_grouped), self.gpus)
        return data_parallel(self, (inputs, ),
                             self.gpus,
                             module_kwargs={"single": True})
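Not part of aw_nas: below is a minimal, self-contained sketch of the dispatch pattern used above, in which the module passes itself to torch.nn.parallel.data_parallel with module_kwargs={"single": True} so that each GPU replica re-enters forward() and takes the plain single-device path. ToyCandidateNet and its single linear layer are hypothetical names used only for illustration; on a CPU-only machine the code simply falls back to the direct call.

import torch
from torch import nn
from torch.nn.parallel import data_parallel

class ToyCandidateNet(nn.Module):  # hypothetical, for illustration only
    def __init__(self, gpus=()):
        super().__init__()
        self.gpus = list(gpus)      # e.g. [0, 1]; empty on CPU-only machines
        self.fc = nn.Linear(8, 4)

    def _forward(self, inputs):
        return self.fc(inputs)

    def forward(self, inputs, single=False):
        # Replicas created by data_parallel call back in here with single=True.
        if single or not self.gpus or len(self.gpus) == 1:
            return self._forward(inputs)
        return data_parallel(self, (inputs,), self.gpus,
                             module_kwargs={"single": True})

net = ToyCandidateNet(gpus=list(range(torch.cuda.device_count())))
if torch.cuda.is_available():
    net = net.cuda()
device = "cuda" if torch.cuda.is_available() else "cpu"
out = net(torch.randn(16, 8, device=device))
print(out.shape)  # torch.Size([16, 4])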
Example #2
File: ofa.py  Project: zeta1999/aw_nas
    def forward(self, inputs, single=False):  # pylint: disable=arguments-differ
        if self.multiprocess:
            # Distributed setting: the parallelized super net handles the scatter itself.
            out = self.super_net.parallel_model.forward(inputs, self.rollout)
        elif not single and len(self.gpus) > 1:
            # Replicate this module across GPUs; each replica re-enters forward()
            # with single=True and falls through to the direct path below.
            out = data_parallel(self, (inputs, ),
                                self.gpus,
                                module_kwargs={"single": True})
        else:
            return self._forward(inputs)
        return out
Example #3
    def forward(self, inputs, detach_arch=True): #pylint: disable=arguments-differ
        if detach_arch:
            # Detach the arch tensors so this forward pass does not backprop
            # into the architecture (controller) parameters.
            arch = [
                DartsArch(
                    op_weights=op_weights.detach(),
                    edge_norms=edge_norms.detach() if edge_norms is not None else None
                ) for op_weights, edge_norms in self.arch
            ]
        else:
            arch = self.arch

        if not self.gpus or len(self.gpus) == 1:
            # Single-device path: call the super net directly.
            return self.super_net.forward(inputs, arch, detach_arch=detach_arch)

        if arch[0].op_weights.ndimension() == 2:
            # No rollout batch dimension: repeat along dim 0 so that `data_parallel`'s
            # scatter hands each GPU one full copy of the arch weights.
            arch = [
                DartsArch(
                    op_weights=a.op_weights.repeat(len(self.gpus), 1),
                    edge_norms=(a.edge_norms.repeat(len(self.gpus))
                                if a.edge_norms is not None else None))
                for a in arch
            ]
        else:
            # Ugly fix for rollout_size > 1
            # call scatter here and stack...
            # split along dimension 1,
            # then concatenate along dimension 0 for `data_parallel` to scatter it again
            num_split = len(self.gpus)
            rollout_batch_size = arch[0].op_weights.shape[1]
            assert rollout_batch_size % num_split == 0
            split_size = rollout_batch_size // num_split
            # arch = [torch.cat(torch.split(a, split_size, dim=1), dim=0) for a in arch]
            # Note: edge_norms (1-dim) do not support batch_size, just repeat
            arch = [
                DartsArch(
                    op_weights=torch.cat(torch.split(a.op_weights, split_size, dim=1), dim=0),
                    edge_norms=(a.edge_norms.repeat(len(self.gpus))
                                if a.edge_norms is not None else None))
                for a in arch
            ]
        return data_parallel(self.super_net, (inputs, arch), self.gpus,
                             module_kwargs={"detach_arch": detach_arch})
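A CPU-only sketch (not aw_nas code) of the rollout_size > 1 trick above: because data_parallel scatters along dim 0, the rollout axis (dim 1) is split and restacked on dim 0 so that each device's dim-0 chunk carries its own rollouts with all edges intact. The shapes (num_edges, rollout_batch_size, num_ops) are illustrative, and torch.chunk stands in for the scatter.

import torch

num_gpus = 2
num_edges, rollout_batch_size, num_ops = 3, 4, 5
op_weights = torch.randn(num_edges, rollout_batch_size, num_ops)

assert rollout_batch_size % num_gpus == 0
split_size = rollout_batch_size // num_gpus

# Same transform as in the forward() above: split rollouts (dim 1), restack on dim 0.
rearranged = torch.cat(torch.split(op_weights, split_size, dim=1), dim=0)
print(rearranged.shape)  # torch.Size([6, 2, 5]) == (num_gpus * num_edges, split_size, num_ops)

# Simulate data_parallel's scatter: one dim-0 chunk per device.
chunks = torch.chunk(rearranged, num_gpus, dim=0)
# Device 0 gets rollouts 0..1, device 1 gets rollouts 2..3, each with all edges.
assert torch.equal(chunks[0], op_weights[:, :split_size, :])
assert torch.equal(chunks[1], op_weights[:, split_size:, :])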
Example #4
    def forward(self, inputs, detach_arch=True):  #pylint: disable=arguments-differ
        # Optionally detach the arch tensors so this pass does not backprop into them.
        arch = [a.detach() for a in self.arch] if detach_arch else self.arch
        if not self.gpus or len(self.gpus) == 1:
            return self.super_net.forward(inputs,
                                          arch,
                                          detach_arch=detach_arch)
        if arch[0].ndimension() == 2:
            # No rollout batch dimension: repeat along dim 0 so each GPU receives
            # a full copy of the arch weights after the scatter.
            arch = [a.repeat([len(self.gpus), 1]) for a in arch]
        else:
            # Ugly fix for rollout_size > 1
            # call scatter here and stack...
            # split along dimension 1,
            # then concatenate along dimension 0 for `data_parallel` to scatter it again
            num_split = len(self.gpus)
            rollout_batch_size = arch[0].shape[1]
            assert rollout_batch_size % num_split == 0
            split_size = rollout_batch_size // num_split
            arch = [
                torch.cat(torch.split(a, split_size, dim=1), dim=0)
                for a in arch
            ]
        return data_parallel(self.super_net, (inputs, arch),
                             self.gpus,
                             module_kwargs={"detach_arch": detach_arch})
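For the 2-D branch (arch[0].ndimension() == 2), a similar CPU-only sketch, again not aw_nas code and with illustrative shapes: repeating the arch tensor along dim 0 before the call means data_parallel's dim-0 scatter hands every device one full copy.

import torch

num_gpus = 2
num_edges, num_ops = 3, 5
alpha = torch.randn(num_edges, num_ops)           # one 2-D arch tensor

repeated = alpha.repeat([num_gpus, 1])            # (num_gpus * num_edges, num_ops)
chunks = torch.chunk(repeated, num_gpus, dim=0)   # what each device would receive
assert all(torch.equal(chunk, alpha) for chunk in chunks)
print([tuple(chunk.shape) for chunk in chunks])   # [(3, 5), (3, 5)]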