def _data_parallel(self, batch):
    """
    Run the forward pass data-parallel across ``self._cuda_devices``.

    This is a stripped-down ``torch.nn.parallel.data_parallel`` adapted to the
    allennlp model interface: ``batch`` is scattered and handed to each model
    replica as keyword arguments, and only the ``'loss'`` entry of each
    replica's output dict is kept.

    Returns a dict whose ``'loss'`` is the mean of the per-device losses.
    """
    devices = self._cuda_devices
    per_device_args, per_device_kwargs = scatter_kwargs((), batch, devices, 0)
    # Only use as many devices as the scatter actually produced chunks for.
    active = devices[:len(per_device_args)]
    replicas = replicate(self._model, active)
    results = parallel_apply(replicas, per_device_args, per_device_kwargs, active)
    # Give each replica's loss a leading dim so gather stacks them into a
    # (num_gpu,) tensor on the first active device (assumes scalar losses).
    stacked_losses = gather([res['loss'].unsqueeze(0) for res in results], active[0], 0)
    return {'loss': stacked_losses.mean()}
def gather(self, outputs, output_device):
    """Gather per-device ``outputs`` onto ``output_device`` and average them.

    The outputs are gathered along ``self.dim`` and then reduced with
    ``.mean()`` over all elements (a single scalar, assuming a tensor result).
    Inside a class body, the bare ``gather`` below resolves to the
    module-level gather function, not to this method.
    """
    gathered = gather(outputs, output_device, dim=self.dim)
    return gathered.mean()
def parallel_forward(self, dense_x, lS_o, lS_i):
    """Run the forward pass across several GPUs.

    The MLPs use data parallelism: ``bot_l``/``top_l`` are replicated on every
    device and ``dense_x`` is scattered along the batch dimension (dim 0).
    The embedding tables use model parallelism: table ``k`` (together with
    its sparse offsets/indices) is placed on device ``k % ndevices``.

    Args:
        dense_x: dense feature tensor with the batch on dim 0.
        lS_o: sparse offsets, one entry per embedding table.
        lS_i: sparse indices, one entry per embedding table.

    Returns:
        The top-MLP output gathered on ``self.output_d``. It is clamped to
        ``[loss_threshold, 1 - loss_threshold]`` when
        ``0 < loss_threshold < 1``.

    Calls ``sys.exit`` (raising ``SystemExit``) if the number of sparse inputs
    or embedding results does not match the number of embedding tables.
    """
    ### prepare model (overwrite) ###
    # ndevices is capped by the batch size and by the number of tables.
    batch_size = dense_x.size()[0]
    ndevices = min(self.ndevices, batch_size, len(self.emb_l))
    device_ids = range(ndevices)
    # Embedding table k is pinned to device k % ndevices.
    table_devices = [
        torch.device("cuda:" + str(k % ndevices)) for k in range(len(self.emb_l))
    ]

    # ndevices depends on batch_size, so the model must be redistributed
    # whenever the mini-batch size changes (common for the last mini-batch
    # when the dataset size is not a multiple of the batch size).
    if self.parallel_model_batch_size != batch_size:
        self.parallel_model_is_not_prepared = True

    if self.sync_dense_params or self.parallel_model_is_not_prepared:
        # replicate MLPs (data parallelism)
        self.bot_l_replicas = replicate(self.bot_l, device_ids)
        self.top_l_replicas = replicate(self.top_l, device_ids)
        # Distribute embeddings (model parallelism). Module.to moves the
        # parameters in place and returns the module, so one call suffices.
        self.emb_l = nn.ModuleList(
            [emb.to(dev) for emb, dev in zip(self.emb_l, table_devices)]
        )
        self.parallel_model_batch_size = batch_size
        self.parallel_model_is_not_prepared = False

    ### prepare input (overwrite) ###
    # Scatter dense features along the batch dimension (data parallelism).
    # NOTE(review): the scatter follows torch.chunk semantics, which may yield
    # fewer than ndevices slices for some batch sizes; confirm that the
    # parallel_apply calls below still line up in that case.
    dense_x = scatter(dense_x, device_ids, dim=0)

    # Move each table's sparse offsets/indices onto that table's device.
    if (len(self.emb_l) != len(lS_o)) or (len(self.emb_l) != len(lS_i)):
        sys.exit("ERROR: corrupted model input detected in parallel_forward call")
    lS_o = [offsets.to(dev) for offsets, dev in zip(lS_o, table_devices)]
    lS_i = [indices.to(dev) for indices, dev in zip(lS_i, table_devices)]

    ### compute results in parallel ###
    # Bottom MLP: bot_l_replicas are the per-device copies, and dense_x holds
    # the per-device batch slices. The result is a list of per-device outputs
    # following the distribution of dense_x.
    x = parallel_apply(self.bot_l_replicas, dense_x, None, device_ids)

    # Embeddings: each table's lookup covers the entire batch on its device.
    ly = self.apply_emb(lS_o, lS_i, self.emb_l)

    # Butterfly shuffle (implemented inefficiently for now). Re-slice every
    # table's full-batch result along the batch dimension so that each device
    # holds all tables for its own part of the batch. This matches the
    # distribution of the bottom-MLP output, so both can be used for the
    # per-device interactions.
    if len(self.emb_l) != len(ly):
        sys.exit("ERROR: corrupted intermediate result in parallel_forward call")
    per_table_slices = [scatter(table_out, device_ids, dim=0) for table_out in ly]
    # Transpose the (table, device) nesting into (device, table).
    ly = [list(device_slices) for device_slices in zip(*per_table_slices)]

    # interactions, computed per device
    z = [self.interact_features(x[k], ly[k]) for k in range(ndevices)]

    # Top MLP: top_l_replicas are the per-device copies, and z is already
    # split across devices on the batch dimension by construction.
    p = parallel_apply(self.top_l_replicas, z, None, device_ids)

    ### gather the distributed results ###
    p0 = gather(p, self.output_d, dim=0)

    # clamp output if needed
    if 0.0 < self.loss_threshold < 1.0:
        z0 = torch.clamp(
            p0, min=self.loss_threshold, max=(1.0 - self.loss_threshold)
        )
    else:
        z0 = p0
    return z0
def _pix_acc(correct, labeled):
    """Return correct / labeled, with a tiny epsilon guarding against 0 / 0."""
    return 1.0 * correct / (np.spacing(1) + labeled)


def _mean_iou_percent(iu):
    """Return the NaN-ignoring mean of per-class IoU values ``iu`` in percent, rounded to 2 decimals."""
    return round(np.nanmean(iu) * 100, 2)


def validation(model, val_loader, epoch, writer):
    """Evaluate ``model`` on ``val_loader`` and log the metrics to ``writer``.

    Pixel accuracy and mIoU are tracked for three label sets:
    - part labels (``target``, ``args.num_classes`` classes);
    - ``hlabel`` (``args.hbody_cls`` classes);
    - ``flabel`` (``args.fbody_cls`` classes).

    Every 50th batch, image/label/prediction grids are logged. Per-epoch
    scalars are logged after the loop. The model is left in eval mode.

    Args:
        model: network whose gathered output holds three prediction maps
            (parts, hb, fb) for an image batch.
        val_loader: iterable of ``(image, target, hlabel, flabel, _)`` batches.
        epoch: current epoch, used to compute the logging step.
        writer: summary writer (presumably TensorBoard) providing
            ``add_image``/``add_scalar``.

    Returns:
        ``(pixAcc, mIoU)`` for the part labels; mIoU is a percentage.
    """
    from tqdm import tqdm

    # set evaluate mode
    model.eval()
    total_correct, total_label = 0, 0
    total_correct_hb, total_label_hb = 0, 0
    total_correct_fb, total_label_fb = 0, 0
    hist = np.zeros((args.num_classes, args.num_classes))
    hist_hb = np.zeros((args.hbody_cls, args.hbody_cls))
    hist_fb = np.zeros((args.fbody_cls, args.fbody_cls))

    # Iterate over data.
    tbar = tqdm(val_loader)
    for idx, batch in enumerate(tbar):
        image, target, hlabel, flabel, _ = batch
        image, target, hlabel, flabel = (
            image.cuda(), target.cuda(), hlabel.cuda(), flabel.cuda()
        )
        with torch.no_grad():
            h, w = target.size(1), target.size(2)
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            # upsample each head's prediction to the label resolution
            preds = F.interpolate(input=outputs[0], size=(h, w), mode='bilinear', align_corners=True)
            preds_hb = F.interpolate(input=outputs[1], size=(h, w), mode='bilinear', align_corners=True)
            preds_fb = F.interpolate(input=outputs[2], size=(h, w), mode='bilinear', align_corners=True)

            if idx % 50 == 0:
                step = epoch * len(val_loader) + idx + 1
                img_vis = inv_preprocess(image, num_images=args.save_num)
                label_vis = decode_predictions(target.int(), num_images=args.save_num,
                                               num_classes=args.num_classes)
                pred_vis = decode_predictions(torch.argmax(preds, dim=1), num_images=args.save_num,
                                              num_classes=args.num_classes)
                # visual grids: transpose(0, 3, 1, 2) moves the last axis to
                # position 1 before converting to tensors for make_grid
                img_grid = torchvision.utils.make_grid(
                    torch.from_numpy(img_vis.transpose(0, 3, 1, 2)))
                label_grid = torchvision.utils.make_grid(
                    torch.from_numpy(label_vis.transpose(0, 3, 1, 2)))
                pred_grid = torchvision.utils.make_grid(
                    torch.from_numpy(pred_vis.transpose(0, 3, 1, 2)))
                writer.add_image('val_images', img_grid, step)
                writer.add_image('val_labels', label_grid, step)
                writer.add_image('val_preds', pred_grid, step)

            # pixAcc
            correct, labeled = batch_pix_accuracy(preds.data, target)
            correct_hb, labeled_hb = batch_pix_accuracy(preds_hb.data, hlabel)
            correct_fb, labeled_fb = batch_pix_accuracy(preds_fb.data, flabel)
            # mIoU
            hist += fast_hist(preds, target, args.num_classes)
            hist_hb += fast_hist(preds_hb, hlabel, args.hbody_cls)
            hist_fb += fast_hist(preds_fb, flabel, args.fbody_cls)

            total_correct += correct
            total_correct_hb += correct_hb
            total_correct_fb += correct_fb
            total_label += labeled
            total_label_hb += labeled_hb
            total_label_fb += labeled_fb

            # plot running metrics in the progress bar
            tbar.set_description('{} / {} | {pixAcc:.4f}, {IoU:.4f} |'
                                 '{pixAcc_hb:.4f}, {IoU_hb:.4f} |'
                                 '{pixAcc_fb:.4f}, {IoU_fb:.4f}'.format(
                                     idx + 1, len(val_loader),
                                     pixAcc=_pix_acc(total_correct, total_label),
                                     IoU=_mean_iou_percent(per_class_iu(hist)),
                                     pixAcc_hb=_pix_acc(total_correct_hb, total_label_hb),
                                     IoU_hb=_mean_iou_percent(per_class_iu(hist_hb)),
                                     pixAcc_fb=_pix_acc(total_correct_fb, total_label_fb),
                                     IoU_fb=_mean_iou_percent(per_class_iu(hist_fb))))

    # The final metrics are derived from the accumulators rather than from
    # loop-local variables, so they are bound even if val_loader yields no
    # batches. In that case, the values depend on how per_class_iu handles an
    # all-zero histogram (NOTE(review): confirm it does not raise).
    pixAcc = _pix_acc(total_correct, total_label)
    pixAcc_hb = _pix_acc(total_correct_hb, total_label_hb)
    pixAcc_fb = _pix_acc(total_correct_fb, total_label_fb)
    iu = per_class_iu(hist)
    iu_hb = per_class_iu(hist_hb)
    iu_fb = per_class_iu(hist_fb)

    print('\n per class iou part: {}'.format(iu * 100))
    print('per class iou hb: {}'.format(iu_hb * 100))
    print('per class iou fb: {}'.format(iu_fb * 100))

    mIoU = _mean_iou_percent(iu)
    mIoU_hb = _mean_iou_percent(iu_hb)
    mIoU_fb = _mean_iou_percent(iu_fb)

    writer.add_scalar('val_pixAcc', pixAcc, epoch)
    writer.add_scalar('val_mIoU', mIoU, epoch)
    writer.add_scalar('val_pixAcc_hb', pixAcc_hb, epoch)
    writer.add_scalar('val_mIoU_hb', mIoU_hb, epoch)
    writer.add_scalar('val_pixAcc_fb', pixAcc_fb, epoch)
    writer.add_scalar('val_mIoU_fb', mIoU_fb, epoch)
    tbar.close()
    return pixAcc, mIoU
def gather(self, outputs, output_device):
    """Gather per-device ``outputs``, but only outside training mode.

    While ``self.training`` is truthy, ``outputs`` is returned untouched.
    Otherwise the outputs are gathered onto ``output_device`` along
    ``self.dim``. Inside a class body, the bare ``gather`` below resolves to
    the module-level gather function, not to this method.
    """
    if not self.training:
        return gather(outputs, output_device, dim=self.dim)
    return outputs
def gather(self, outputs, output_device):
    """Optionally gather per-device ``outputs``.

    Returns ``outputs`` unchanged when ``self.gather_`` is falsy. Otherwise
    gathers them onto ``output_device`` along ``self.dim``. Inside a class
    body, the bare ``gather`` below resolves to the module-level gather
    function, not to this method.
    """
    if not self.gather_:
        return outputs
    return gather(outputs, output_device, dim=self.dim)
def gather(self, outputs, output_device):
    """Gather per-device ``outputs`` onto ``output_device`` along dim 0.

    The gather dimension is fixed to 0; ``self`` is not consulted. Inside a
    class body, the bare ``gather`` below resolves to the module-level
    gather function, not to this method.
    """
    return gather(
        outputs,
        output_device,
        dim=0,
    )