def create_ema_model(model):
    """Return a fresh ``Res_Deeplab`` whose parameters are copied from ``model``.

    The new network's parameters are detached in place, then overwritten
    position-by-position with cloned values of ``model``'s parameters.
    When more than one GPU id is configured (module-level ``gpus``), the
    copy is wrapped for data parallelism before being returned.

    NOTE(review): a second ``create_ema_model`` is defined later in this
    file; that later definition shadows this one at import time.

    Args:
        model: source network; its ``.parameters()`` are read in order.

    Returns:
        The copied network, possibly wrapped in a data-parallel container.
    """
    ema_model = Res_Deeplab(num_classes=num_classes)
    # Detach the copy's weights so they are not tracked by autograd.
    for weight in ema_model.parameters():
        weight.detach_()

    target_weights = list(ema_model.parameters())
    source_weights = list(model.parameters())
    # Positional copy driven by the *source* list, so a shorter target list
    # raises IndexError and any surplus target weights are left untouched.
    for idx, source_weight in enumerate(source_weights):
        target_weights[idx].data[:] = source_weight.data[:].clone()

    if len(gpus) <= 1:
        return ema_model
    if use_sync_batchnorm:
        # convert_model presumably swaps in synchronized BatchNorm layers —
        # its definition is not visible here.
        return DataParallelWithCallback(convert_model(ema_model), device_ids=gpus)
    return torch.nn.DataParallel(ema_model, device_ids=gpus)
def create_ema_model(model):
    """Create a detached parameter copy of ``model`` for use as an EMA model.

    A new ``Res_Deeplab(num_classes=num_classes)`` is built. Its parameters
    are detached in place and then overwritten with cloned values of
    ``model``'s parameters, matched by position. The averaging update itself
    is not performed here; presumably a separate routine updates these
    weights during training.

    Only ``.parameters()`` are copied. Buffers, such as BatchNorm running
    statistics, keep the values produced by the fresh ``Res_Deeplab``
    constructor.

    When more than one GPU id is configured (module-level ``gpus``), the
    copy is wrapped in ``DataParallelWithCallback``. This happens after
    ``convert_model`` when ``use_sync_batchnorm`` is set. Otherwise it is
    wrapped in ``torch.nn.DataParallel``.

    NOTE(review): the model is not moved to a device in this function;
    presumably the caller moves it (e.g. ``.cuda()``) — confirm, since
    ``DataParallel`` expects its module on ``device_ids[0]``.

    Args:
        model: source network whose parameter values are copied. It is
            expected to have the same parameter layout as ``Res_Deeplab``.

    Returns:
        The copied network, possibly wrapped in a data-parallel container.

    Raises:
        ValueError: if ``model`` and the new ``Res_Deeplab`` have a
            different number of parameter tensors.
    """
    ema_model = Res_Deeplab(num_classes=num_classes)
    # Detach in place so the EMA weights are never updated by backprop.
    for param in ema_model.parameters():
        param.detach_()

    src_params = list(model.parameters())
    ema_params = list(ema_model.parameters())
    # A silent partial copy would leave some EMA weights at random init, so
    # fail loudly on a layout mismatch instead.
    if len(src_params) != len(ema_params):
        raise ValueError(
            "create_ema_model: parameter count mismatch "
            f"(model has {len(src_params)}, EMA model has {len(ema_params)})"
        )
    for ema_param, src_param in zip(ema_params, src_params):
        ema_param.data[:] = src_param.data[:].clone()

    if len(gpus) > 1:
        if use_sync_batchnorm:
            ema_model = convert_model(ema_model)
            ema_model = DataParallelWithCallback(ema_model, device_ids=gpus)
        else:
            ema_model = torch.nn.DataParallel(ema_model, device_ids=gpus)
    return ema_model