Ejemplo n.º 1
0
    def _data_parallel(self, batch):
        """
        Run one forward pass of ``self._model`` on several GPUs and return
        the averaged loss.

        ``batch`` is scattered as keyword arguments (there are no positional
        inputs) along dimension 0 over ``self._cuda_devices``.  Only as many
        devices as there are scattered chunks are used; the model is
        replicated onto those devices and applied in parallel.

        Returns a dict whose single key ``'loss'`` holds the mean of the
        per-replica losses, gathered on the first device that was used.
        """
        scattered_args, scattered_kwargs = scatter_kwargs(
            (), batch, self._cuda_devices, 0
        )
        active_devices = self._cuda_devices[:len(scattered_args)]
        model_copies = replicate(self._model, active_devices)
        replica_outputs = parallel_apply(
            model_copies, scattered_args, scattered_kwargs, active_devices
        )

        # Only the 'loss' entry of each replica output is needed.  Every loss
        # gets a new leading dimension so the gathered tensor carries one
        # entry per replica (shape (num_gpu,) if each loss is a scalar).
        per_replica_losses = []
        for replica_output in replica_outputs:
            per_replica_losses.append(replica_output['loss'].unsqueeze(0))
        stacked_losses = gather(per_replica_losses, active_devices[0], 0)
        return {'loss': stacked_losses.mean()}
 def gather(self, outputs, output_device):
     """Gather replica outputs on ``output_device`` and return their mean.

     The outputs are combined along ``self.dim`` by the module-level
     ``gather`` function (presumably the one from ``torch.nn.parallel``),
     and ``.mean()`` is then taken of the combined result.
     """
     combined = gather(outputs, output_device, dim=self.dim)
     return combined.mean()
Ejemplo n.º 3
0
    def parallel_forward(self, dense_x, lS_o, lS_i):
        """Run the forward pass across several CUDA devices.

        The bottom and top MLPs (``self.bot_l`` / ``self.top_l``) are
        replicated on every used device (data parallelism), while embedding
        table ``k`` of ``self.emb_l`` is moved to device ``k % ndevices``
        (model parallelism).  Replication and placement are redone whenever
        ``self.sync_dense_params`` is set or the batch size differs from the
        one the parallel model was last prepared for.

        Args:
            dense_x: Dense input; its first dimension is read as the batch
                size and it is scattered across devices along that dimension.
            lS_o: Indexable with one entry per embedding table; entry ``k`` is
                moved to the device of table ``k`` and passed to
                ``self.apply_emb`` (presumably lookup offsets -- confirm
                against ``apply_emb``).
            lS_i: Same layout as ``lS_o`` (presumably lookup indices --
                confirm against ``apply_emb``).

        Returns:
            The top-MLP outputs gathered along dimension 0 on
            ``self.output_d``, clamped to
            ``[loss_threshold, 1 - loss_threshold]`` when
            ``0 < self.loss_threshold < 1``.

        Raises:
            SystemExit: If ``lS_o`` / ``lS_i`` do not have one entry per
                embedding table, or if ``self.apply_emb`` returns a result of
                a different length than ``self.emb_l``.
        """
        ### prepare model (overwrite) ###
        # The number of devices actually used never exceeds the batch size or
        # the number of embedding tables.
        batch_size = dense_x.size()[0]
        ndevices = min(self.ndevices, batch_size, len(self.emb_l))
        device_ids = range(ndevices)
        # WARNING: the model must be redistributed if the mini-batch size
        # changes (this is common for the last mini-batch, when the dataset
        # size is not a multiple of the batch size).
        if self.parallel_model_batch_size != batch_size:
            self.parallel_model_is_not_prepared = True

        if self.sync_dense_params or self.parallel_model_is_not_prepared:
            # replicate mlp (data parallelism)
            self.bot_l_replicas = replicate(self.bot_l, device_ids)
            self.top_l_replicas = replicate(self.top_l, device_ids)
            # distribute embeddings (model parallelism); each table is moved
            # exactly once to its round-robin device
            self.emb_l = nn.ModuleList(
                [
                    emb.to(torch.device("cuda:" + str(k % ndevices)))
                    for k, emb in enumerate(self.emb_l)
                ]
            )
            self.parallel_model_batch_size = batch_size
            self.parallel_model_is_not_prepared = False

        ### prepare input (overwrite) ###
        # scatter dense features (data parallelism)
        dense_x = scatter(dense_x, device_ids, dim=0)
        # distribute sparse features (model parallelism)
        if (len(self.emb_l) != len(lS_o)) or (len(self.emb_l) != len(lS_i)):
            sys.exit("ERROR: corrupted model input detected in parallel_forward call")

        t_list = []
        i_list = []
        for k in range(len(self.emb_l)):
            # sparse inputs follow their embedding table's device
            d = torch.device("cuda:" + str(k % ndevices))
            t_list.append(lS_o[k].to(d))
            i_list.append(lS_i[k].to(d))
        lS_o = t_list
        lS_i = i_list

        ### compute results in parallel ###
        # bottom mlp
        # WARNING: self.bot_l_replicas holds the bottom mlp replicated across
        # devices, while dense_x is the dense input scattered across devices on
        # the first (batch) dimension.  The output is a list of tensors
        # distributed across devices in the same way as dense_x.
        x = parallel_apply(self.bot_l_replicas, dense_x, None, device_ids)

        # embeddings
        ly = self.apply_emb(lS_o, lS_i, self.emb_l)

        # butterfly shuffle (implemented inefficiently for now)
        # WARNING: at this point each device holds the embedding lookup result
        # of its own tables for the entire batch.  We want every device to hold
        # the results of all embedding lookups for its part of the batch,
        # matching the distribution of the bottom mlp output, so that both can
        # be used for the subsequent interactions on each device.
        if len(self.emb_l) != len(ly):
            sys.exit("ERROR: corrupted intermediate result in parallel_forward call")

        t_list = [
            scatter(ly[k], device_ids, dim=0) for k in range(len(self.emb_l))
        ]
        # adjust the list to be ordered per device
        ly = [list(y) for y in zip(*t_list)]

        # interactions
        z = [self.interact_features(x[k], ly[k]) for k in range(ndevices)]

        # top mlp
        # WARNING: self.top_l_replicas holds the top mlp replicated across
        # devices, while z is a list of interaction results that by
        # construction are distributed across devices on the first (batch)
        # dimension.  The output follows the distribution of z.
        p = parallel_apply(self.top_l_replicas, z, None, device_ids)

        ### gather the distributed results ###
        p0 = gather(p, self.output_d, dim=0)

        # clamp output if needed
        if 0.0 < self.loss_threshold < 1.0:
            return torch.clamp(
                p0, min=self.loss_threshold, max=(1.0 - self.loss_threshold)
            )
        return p0
Ejemplo n.º 4
0
def validation(model, val_loader, epoch, writer):
    """Evaluate ``model`` on ``val_loader`` and log the metrics to ``writer``.

    The model output is gathered onto device 0 and its first three entries
    are treated as the part, half-body (``hb``) and full-body (``fb``)
    predictions.  Each is bilinearly resized to the target's spatial size
    before running pixel accuracy and confusion-matrix based IoU are
    accumulated.  Every 50th batch, image / label / prediction grids are
    written to ``writer``.

    Reads the module-level ``args`` for ``num_classes``, ``hbody_cls``,
    ``fbody_cls`` and ``save_num``.

    Args:
        model: Callable model; switched to eval mode before iterating.
        val_loader: Sized iterable yielding
            ``(image, target, hlabel, flabel, _)`` batches.
        epoch: Epoch index used as the global step for the scalar logs and
            as the base of the image-log step.
        writer: Object providing ``add_image`` and ``add_scalar``.

    Returns:
        Tuple ``(pixAcc, mIoU)`` for the part predictions.  ``pixAcc`` is
        ``0.0`` when ``val_loader`` yields no batches.
    """
    # set evaluate mode
    model.eval()

    total_correct, total_label = 0, 0
    total_correct_hb, total_label_hb = 0, 0
    total_correct_fb, total_label_fb = 0, 0
    hist = np.zeros((args.num_classes, args.num_classes))
    hist_hb = np.zeros((args.hbody_cls, args.hbody_cls))
    hist_fb = np.zeros((args.fbody_cls, args.fbody_cls))

    # Bound up front so an empty ``val_loader`` reports 0.0 instead of
    # raising UnboundLocalError at the logging / return statements below.
    pixAcc = pixAcc_hb = pixAcc_fb = 0.0

    # Iterate over data.
    from tqdm import tqdm
    tbar = tqdm(val_loader)
    for idx, batch in enumerate(tbar):
        num_batches = len(val_loader)
        image, target, hlabel, flabel, _ = batch
        image, target, hlabel, flabel = (
            image.cuda(), target.cuda(), hlabel.cuda(), flabel.cuda())
        with torch.no_grad():
            h, w = target.size(1), target.size(2)
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            # Resize the three prediction heads to the label resolution.
            preds, preds_hb, preds_fb = [
                F.interpolate(input=outputs[branch],
                              size=(h, w),
                              mode='bilinear',
                              align_corners=True)
                for branch in range(3)
            ]
            if idx % 50 == 0:
                img_vis = inv_preprocess(image, num_images=args.save_num)
                label_vis = decode_predictions(target.int(),
                                               num_images=args.save_num,
                                               num_classes=args.num_classes)
                pred_vis = decode_predictions(torch.argmax(preds, dim=1),
                                              num_images=args.save_num,
                                              num_classes=args.num_classes)

                # visual grids; the arrays are transposed (0, 3, 1, 2)
                # before being turned into tensors for make_grid
                grids = [
                    torchvision.utils.make_grid(
                        torch.from_numpy(vis.transpose(0, 3, 1, 2)))
                    for vis in (img_vis, label_vis, pred_vis)
                ]
                global_step = epoch * num_batches + idx + 1
                tags = ('val_images', 'val_labels', 'val_preds')
                for tag, grid in zip(tags, grids):
                    writer.add_image(tag, grid, global_step)

            # pixelAcc
            correct, labeled = batch_pix_accuracy(preds.data, target)
            correct_hb, labeled_hb = batch_pix_accuracy(preds_hb.data, hlabel)
            correct_fb, labeled_fb = batch_pix_accuracy(preds_fb.data, flabel)
            # mIoU
            hist += fast_hist(preds, target, args.num_classes)
            hist_hb += fast_hist(preds_hb, hlabel, args.hbody_cls)
            hist_fb += fast_hist(preds_fb, flabel, args.fbody_cls)

            total_correct += correct
            total_correct_hb += correct_hb
            total_correct_fb += correct_fb
            total_label += labeled
            total_label_hb += labeled_hb
            total_label_fb += labeled_fb
            # np.spacing(1) keeps the denominator non-zero
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = round(np.nanmean(per_class_iu(hist)) * 100, 2)
            pixAcc_hb = 1.0 * total_correct_hb / (np.spacing(1) +
                                                  total_label_hb)
            IoU_hb = round(np.nanmean(per_class_iu(hist_hb)) * 100, 2)
            pixAcc_fb = 1.0 * total_correct_fb / (np.spacing(1) +
                                                  total_label_fb)
            IoU_fb = round(np.nanmean(per_class_iu(hist_fb)) * 100, 2)
            # plot progress
            tbar.set_description(
                ('{} / {} | {pixAcc:.4f}, {IoU:.4f} |'
                 '{pixAcc_hb:.4f}, {IoU_hb:.4f} |'
                 '{pixAcc_fb:.4f}, {IoU_fb:.4f}').format(
                     idx + 1, num_batches,
                     pixAcc=pixAcc, IoU=IoU,
                     pixAcc_hb=pixAcc_hb, IoU_hb=IoU_hb,
                     pixAcc_fb=pixAcc_fb, IoU_fb=IoU_fb))

    print('\n per class iou part: {}'.format(per_class_iu(hist) * 100))
    print('per class iou hb: {}'.format(per_class_iu(hist_hb) * 100))
    print('per class iou fb: {}'.format(per_class_iu(hist_fb) * 100))

    mIoU = round(np.nanmean(per_class_iu(hist)) * 100, 2)
    mIoU_hb = round(np.nanmean(per_class_iu(hist_hb)) * 100, 2)
    mIoU_fb = round(np.nanmean(per_class_iu(hist_fb)) * 100, 2)

    writer.add_scalar('val_pixAcc', pixAcc, epoch)
    writer.add_scalar('val_mIoU', mIoU, epoch)
    writer.add_scalar('val_pixAcc_hb', pixAcc_hb, epoch)
    writer.add_scalar('val_mIoU_hb', mIoU_hb, epoch)
    writer.add_scalar('val_pixAcc_fb', pixAcc_fb, epoch)
    writer.add_scalar('val_mIoU_fb', mIoU_fb, epoch)
    tbar.close()
    return pixAcc, mIoU
Ejemplo n.º 5
0
 def gather(self, outputs, output_device):
     """Return ``outputs`` untouched while training; gather them otherwise.

     When ``self.training`` is falsy, the outputs are combined on
     ``output_device`` along ``self.dim`` by the module-level ``gather``
     function (presumably the one from ``torch.nn.parallel``).
     """
     if not self.training:
         return gather(outputs, output_device, dim=self.dim)
     return outputs
Ejemplo n.º 6
0
    def gather(self, outputs, output_device):
        """Gather ``outputs`` only when ``self.gather_`` is truthy.

        With the flag unset the outputs are handed back unchanged; otherwise
        they are combined on ``output_device`` along ``self.dim`` by the
        module-level ``gather`` function (presumably the one from
        ``torch.nn.parallel``).
        """
        if not self.gather_:
            return outputs
        return gather(outputs, output_device, dim=self.dim)
Ejemplo n.º 7
0
 def gather(self, outputs, output_device):
     """Gather the replica outputs on ``output_device`` along dimension 0.

     Delegates to the module-level ``gather`` function (presumably the one
     from ``torch.nn.parallel``) with a fixed ``dim=0``.
     """
     combined = gather(outputs, output_device, dim=0)
     return combined