Example 1
def test_make_one_hot():
    arr1 = utils.make_one_hot(np.array([1, 1, 3, 4, 6]))
    assert (np.array_equal(arr1.shape, np.array([5, 6])))
    arr2 = utils.make_one_hot(np.array([0, 1, 1, 3, 4, 6]))
    assert (np.array_equal(arr2.shape, np.array([6, 7])))
    arr3 = utils.make_one_hot(np.array([0, 0, 0, 0]))
    assert (np.array_equal(arr3.shape, np.array([4, 1])))
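The assertions pin down the contract: the column count spans min..max of the labels ((5, 6) for [1, 1, 3, 4, 6], (6, 7) for [0, 1, 1, 3, 4, 6], (4, 1) for all zeros). A minimal NumPy sketch that satisfies exactly these three checks; the real utils.make_one_hot may differ, and other examples on this page use project-specific variants:

import numpy as np

def make_one_hot(labels):
    # shift so the smallest label maps to column 0; columns span min..max
    shifted = labels - labels.min()
    one_hot = np.zeros((labels.size, int(shifted.max()) + 1))
    one_hot[np.arange(labels.size), shifted] = 1.0
    return one_hot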
def create_loss(predicts: torch.Tensor,
                labels: torch.Tensor,
                num_classes,
                cal_miou=True):
    """
    创建loss
    @param predicts: shape=(n, c, h, w)
    @param labels: shape=(n, h, w) or shape=(n, 1, h, w)
    @param num_classes: int should equal to channels of predicts
    @return: loss, mean_iou
    """
    # permute to (n, h, w, c)
    predicts = predicts.permute((0, 2, 3, 1))
    # reshape to (-1, num_classes): one score per class for every pixel
    predicts = predicts.reshape((-1, num_classes))
    # cross-entropy combined with Dice
    ce_loss = F.cross_entropy(predicts, labels.flatten(),
                              reduction='mean')  # applies log-softmax internally
    # one-hot encode the labels so the result has the same shape as predicts
    labels_one_hot = utils.make_one_hot(labels.reshape((-1, 1)), num_classes)
    dice_loss = utils.DiceLoss()(predicts, labels_one_hot.to(
        labels.device))  # torch has no built-in Dice loss; this one comes from the course utils
    loss = ce_loss + dice_loss
    if cal_miou:
        ious = compute_iou(predicts, labels.reshape((-1, 1)), num_classes)
        miou = np.nanmean(ious.numpy())
    else:
        miou = None
    return loss, miou
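A hedged smoke test for create_loss with random tensors; it assumes utils.make_one_hot, utils.DiceLoss and compute_iou from the same module are importable:

predicts = torch.randn(2, 4, 8, 8)       # (n, c, h, w)
labels = torch.randint(0, 4, (2, 8, 8))  # (n, h, w)
loss, miou = create_loss(predicts, labels, num_classes=4)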
Example 3
    def forward(self, input_, target) -> torch.Tensor:
        target = utils.make_one_hot(target, C=self.C)
        # drop the ignore label's channel from the one-hot target
        target = target[:, :self.ignore_label, ...]

        assert input_.size() == target.size(), "Input sizes must be equal."
        assert input_.dim() == 4, "Input must be a 4D Tensor."

        probs = F.softmax(input_, dim=1)
        num = probs * target  # (b, c, h, w): p * g
        num = torch.sum(num, dim=3)  # (b, c, h)
        num = torch.sum(num, dim=2)  # (b, c)

        den1 = probs * probs  # p^2
        den1 = torch.sum(den1, dim=3)  # (b, c, h)
        den1 = torch.sum(den1, dim=2)  # (b, c)

        den2 = target * target  # g^2
        den2 = torch.sum(den2, dim=3)  # (b, c, h)
        den2 = torch.sum(den2, dim=2)  # (b, c)

        dice = 2 * (num / (den1 + den2))
        dice_eso = dice[:, 1:]  # drop the background dice, keep the foreground classes

        dice_total = -1 * torch.sum(dice_eso) / dice_eso.size(0)  # average over the batch (negated so lower is better)

        return dice_total
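For reference, the quantity assembled above is the squared-denominator soft Dice per class c, dice_c = 2 * sum_i(p_ic * g_ic) / (sum_i p_ic^2 + sum_i g_ic^2); summing it over the foreground classes and negating turns overlap maximization into a loss to minimize.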
Example 4
def test_onehot():
    root = '/Users/jizong/workspace/Semi-supervised-cycleGAN/data_utils/VOC2012'
    img_size = 256
    batchsize = 2
    transform = get_transformation(img_size)
    voc = VOCDataset(root_path=root, name='label', ratio=1, transformation=transform, augmentation=None)
    voc_loader = DataLoader(voc, batch_size=batchsize, shuffle=True)
    img = voc[0][0]
    assert img.dim() == 3
    assert img.size(0) == 3
    assert img.size(1) == img_size
    assert img.size(2) == img_size
    img, gt = next(iter(voc_loader))[0:2]
    assert gt.dim() == 4
    assert gt.shape[0] == batchsize
    assert gt.shape[1] == 1
    assert gt.shape[2] == img_size
    assert gt.shape[3] == img_size
    onehot_gt = make_one_hot(gt, 'voc2012')

    # visualization for the first image
    plt.imshow(img[0].squeeze()[0].numpy())
    plt.show()
    plt.imshow(gt[0].squeeze().numpy())
    plt.show()

    onehot_gt = onehot_gt[0]
    for c in range(onehot_gt.shape[0]):
        channel = onehot_gt[c]
        if channel.sum() > 0:
            plt.imshow(channel.squeeze().numpy(), cmap='gray')
            plt.show()

Example 5
def train(train_loader, net, criterion, optimizer, epoch, train_args):
    train_loss = AverageMeter()
    curr_iter = (epoch - 1) * len(train_loader)
    for i, data in enumerate(train_loader):
        inputs, labels = data
        labels = make_one_hot(labels.unsqueeze(1), wp.num_classes)
        # assert inputs.size()[2:] == labels.size()[1:]
        N = inputs.size(0)

        inputs = inputs.cuda()
        labels = labels.cuda()

        optimizer.zero_grad()
        outputs = net(inputs)
        # assert outputs.size()[2:] == labels.size()[1:]
        # assert outputs.size()[1] == wp.num_classes

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss.update(loss.item(), N)
        # print(loss.item())
        curr_iter += 1
        writer.add_scalar('train_loss', train_loss.avg, curr_iter)

        if (i + 1) % train_args['print_freq'] == 0:
            print('[epoch %d], [iter %d / %d], [train loss %.5f]' % (
                epoch, i + 1, len(train_loader), train_loss.avg
            ))
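Several snippets on this page call a PyTorch make_one_hot(labels.unsqueeze(1), num_classes), i.e. the helper receives a (N, 1, H, W) integer label map and must return (N, C, H, W) one-hot floats. A minimal sketch consistent with that calling convention; the body is an assumption, as each project ships its own helper:

import torch

def make_one_hot(labels, num_classes):
    # labels: (N, 1, H, W) integer class ids -> (N, C, H, W) one-hot floats
    shape = (labels.size(0), num_classes, labels.size(2), labels.size(3))
    one_hot = torch.zeros(shape, device=labels.device)
    return one_hot.scatter_(1, labels.long(), 1.0)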
Example 6
def train(train_loader, net, D, criterion, criterion_D, optimizer_AE, optimizer_D, epoch, train_args):
    train_loss = AverageMeter()
    curr_iter = (epoch - 1) * len(train_loader)
    for i, data in enumerate(train_loader):
        # train the discriminator D
        D.zero_grad()
        inputs, labels = data
        labels = make_one_hot(labels.unsqueeze(1), wp.num_classes)

        inputs = inputs.cuda()
        labels = labels.cuda()

        outputs = net(inputs)
        # concatenate image and segmentation along the channel dimension
        origin_outputs = torch.cat(
            [inputs, outputs], dim=1).cuda()  # B,16,H,W
        origin_labels = torch.cat([inputs, labels], dim=1).cuda()  # B,16,H,W
        batch_size = inputs.shape[0]

        output_D = D(origin_labels)  # B
        real_label = torch.ones(batch_size).cuda()  # label for real images is 1
        fake_label = torch.zeros(batch_size).cuda()  # label for fake images is 0
        errD_real = criterion_D(output_D, real_label)
        errD_real.backward()
        # real_data_score = output_D.mean().item()

        output_D = D(origin_outputs)  # B
        errD_fake = criterion_D(output_D, fake_label)
        errD_fake.backward()
        # fake_data_score is printed for inspection: the discriminator's score for fake images, 0 = completely fake, 1 = real
        # fake_data_score = output_D.data.mean()
        errD = errD_real + errD_fake
        optimizer_D.step()
        # print('errD', errD.item())

        # train the AE (generator)
        net.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_AE.step()

        train_loss.update(loss.item(), batch_size)
        # print('loss', loss.item())
        curr_iter += 1
        writer.add_scalar('train_loss', train_loss.avg, curr_iter)

        if (i + 1) % train_args['print_freq'] == 0:
            print('[epoch %d], [iter %d / %d], [train loss %.5f]' % (
                epoch, i + 1, len(train_loader), train_loss.avg
            ))
def dice_loss_integer(input_, target, ignore_label=2, C=2):
    """
    Computes a Dice loss from 2D input of class scores and a target of integer labels.

    Parameters
    ----------
    input_ : torch.Tensor
        size B x C x H x W representing class scores.
    target : torch.Tensor
        integer label representation of the ground truth, same size as the input.
    ignore_label : integer.
        Must be final label in the sequence (to do, generalize).
    C : integer.
        number of classes (including an ignored label if present!)

    Returns
    -------
    dice_total : float.
        total dice loss.
    """
    target = utils.make_one_hot(target, C=C)
    # subindex target without the ignore label
    target = target[:, :ignore_label, ...]

    assert input_.size() == target.size(), "Input sizes must be equal."
    assert input_.dim() == 4, "Input must be a 4D Tensor."

    probs = F.softmax(input_, dim=1)
    num = probs * target  # (b, c, h, w): p * g
    num = torch.sum(num, dim=3)  # (b, c, h)
    num = torch.sum(num, dim=2)  # (b, c)

    den1 = probs * probs  # p^2
    den1 = torch.sum(den1, dim=3)  # (b, c, h)
    den1 = torch.sum(den1, dim=2)  # (b, c)

    den2 = target * target  # g^2
    den2 = torch.sum(den2, dim=3)  # (b, c, h)
    den2 = torch.sum(den2, dim=2)  # (b, c)

    dice = 2 * (num / (den1 + den2))
    dice_eso = dice[:, 1:]  # drop the background dice, keep the foreground classes

    dice_total = -1 * torch.sum(dice_eso) / dice_eso.size(0)  # average over the batch (negated so lower is better)

    return dice_total
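A minimal invocation sketch, assuming utils.make_one_hot accepts a (B, 1, H, W) integer tensor plus the class count C:

scores = torch.randn(2, 2, 16, 16)            # B x C x H x W class scores
labels = torch.randint(0, 2, (2, 1, 16, 16))  # integer ground truth
loss = dice_loss_integer(scores, labels, ignore_label=2, C=2)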
Example 8
def run_all_datasets(datasets, y, names, classifiers, n_folds):
    """
    Loop through a list of datasets, running potentially numerous classifiers on each.
    :param datasets:
    :param y:
    :param names:
    :param classifiers:
    :param n_folds:
    :return: A tuple of pandas DataFrames for each dataset containing (macroF1, microF1)
    """
    try:
        n_data, n_classes = y.shape
    except ValueError:  # data is encoded with integers instead of one-hot
        y = utils.make_one_hot(y)
    results = []
    for data in zip(datasets, names):
        temp = run_detectors(data[0], y, data[1], classifiers, n_folds)
        results.append(temp)
    return results
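The try/except works because unpacking the shape of a 1-D integer label array raises ValueError, which triggers the one-hot conversion; for example (hypothetical array):

y = np.array([0, 2, 1, 2])   # y.shape == (4,), so `n_data, n_classes = y.shape` raises ValueError
y = utils.make_one_hot(y)    # becomes an (n_data, n_classes) indicator matrix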
Example 9
    def _discriminator(self, h, l):
        last_hidden = self._phi(h)
        l_onehot = make_one_hot(l, self.attr)
        term1 = torch.sum(l_onehot * self.W(last_hidden), dim=1)
        term2 = torch.sum(self.v * last_hidden, dim=1)
        return term1 + term2  # (B,), logit not prob
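In effect the logit is a class-conditional affine score: the one-hot label picks out the component of W(phi(h)) belonging to the true label (term1), and v . phi(h) contributes a label-independent bias (term2).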
def validation(args):

    ### For selecting the number of channels
    if args.dataset == 'voc2012':
        n_channels = 21
    elif args.dataset == 'cityscapes':
        n_channels = 20
    elif args.dataset == 'acdc':
        n_channels = 4

    transform = get_transformation((args.crop_height, args.crop_width), resize=True, dataset=args.dataset)

    ## make the dataset choice configurable
    if args.dataset == 'voc2012':
        val_set = VOCDataset(root_path=root, name='val', ratio=0.5, transformation=transform, augmentation=None)
    elif args.dataset == 'cityscapes':
        val_set = CityscapesDataset(root_path=root_cityscapes, name='val', ratio=0.5, transformation=transform, augmentation=None)
    elif args.dataset == 'acdc':
        val_set = ACDCDataset(root_path=root_acdc, name='val', ratio=0.5, transformation=transform, augmentation=None)

    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)

    Gsi = define_Gen(input_nc=3, output_nc=n_channels, ngf=args.ngf, netG='deeplab', 
                                    norm=args.norm, use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids)

    Gis = define_Gen(input_nc=n_channels, output_nc=3, ngf=args.ngf, netG='deeplab',
                              norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)

    ### best_iou
    best_iou = 0

    ### Interpolation
    interp = nn.Upsample(size = (args.crop_height, args.crop_width), mode='bilinear', align_corners=True)

    ### Softmax activation
    activation_softmax = nn.Softmax2d()
    activation_tanh = nn.Tanh()

    if(args.model == 'supervised_model'):

        ### loading the checkpoint
        try:
            ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir))
            Gsi.load_state_dict(ckpt['Gsi'])
            best_iou = ckpt['best_iou']

        except Exception:
            print(' [*] No checkpoint!')

        ### run
        Gsi.eval()
        for i, (image_test, real_segmentation, image_name) in enumerate(val_loader):
            image_test = utils.cuda(image_test, args.gpu_ids)
            seg_map = Gsi(image_test)
            seg_map = interp(seg_map)
            seg_map = activation_softmax(seg_map)

            prediction = seg_map.data.max(1)[1].squeeze_(1).cpu().numpy()   ### To convert from 22 --> 1 channel
            for j in range(prediction.shape[0]):
                new_img = prediction[j]     ### Taking a particular image from the batch
                new_img = utils.colorize_mask(new_img, args.dataset)   ### So as to convert it back to a paletted image

                ### Now the new_img is PIL.Image
                new_img.save(os.path.join(args.validation_dir+'/supervised/'+image_name[j]+'.png'))

            
            print('Batch-', str(i+1), ' Done!')
        
        print('The iou of the resulting segment maps: ', str(best_iou))


    elif(args.model == 'semisupervised_cycleGAN'):

        ### loading the checkpoint
        try:
            ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir))
            Gsi.load_state_dict(ckpt['Gsi'])
            Gis.load_state_dict(ckpt['Gis'])
            best_iou = ckpt['best_iou']

        except Exception:
            print(' [*] No checkpoint!')

        ### run
        Gsi.eval()
        for i, (image_test, real_segmentation, image_name) in enumerate(val_loader):
            image_test, real_segmentation = utils.cuda([image_test, real_segmentation], args.gpu_ids)
            seg_map = Gsi(image_test)
            seg_map = interp(seg_map)
            seg_map = activation_softmax(seg_map)
            fake_img = Gis(seg_map).detach()
            fake_img = interp(fake_img)
            fake_img = activation_tanh(fake_img)

            fake_img_from_labels = Gis(make_one_hot(real_segmentation, args.dataset, args.gpu_ids).float()).detach()
            fake_img_from_labels = interp(fake_img_from_labels)
            fake_img_from_labels = activation_tanh(fake_img_from_labels)
            fake_label_regenerated = Gsi(fake_img_from_labels).detach()
            fake_label_regenerated = interp(fake_label_regenerated)
            fake_label_regenerated = activation_softmax(fake_label_regenerated)

            prediction = seg_map.data.max(1)[1].squeeze_(1).cpu().numpy()   ### To convert from 22 --> 1 channel
            fake_regenerated_label = fake_label_regenerated.data.max(1)[1].squeeze_(1).cpu().numpy()

            fake_img = fake_img.cpu()
            fake_img_from_labels = fake_img_from_labels.cpu()

            ### Revert the normalization applied to these images
            if args.dataset == 'voc2012' or args.dataset == 'cityscapes':
                trans_mean = [0.5, 0.5, 0.5]
                trans_std = [0.5, 0.5, 0.5]
                for k in range(3):
                    fake_img[:, k, :, :] = ((fake_img[:, k, :, :] * trans_std[k]) + trans_mean[k])
                    fake_img_from_labels[:, k, :, :] = ((fake_img_from_labels[:, k, :, :] * trans_std[k]) + trans_mean[k])

            elif args.dataset == 'acdc':
                trans_mean = [0.5]
                trans_std = [0.5]
                for k in range(1):
                    fake_img[:, k, :, :] = ((fake_img[:, k, :, :] * trans_std[k]) + trans_mean[k])
                    fake_img_from_labels[:, k, :, :] = ((fake_img_from_labels[:, k, :, :] * trans_std[k]) + trans_mean[k])

            for j in range(prediction.shape[0]):
                new_img = prediction[j]     ### Taking a particular image from the batch
                new_img = utils.colorize_mask(new_img, args.dataset)   ### So as to convert it back to a paletted image

                regen_label = fake_regenerated_label[j]
                regen_label = utils.colorize_mask(regen_label, args.dataset)

                ### Now the new_img is PIL.Image
                new_img.save(os.path.join(args.validation_dir+'/unsupervised/generated_labels/'+image_name[j]+'.png'))
                regen_label.save(os.path.join(args.validation_dir+'/unsupervised/regenerated_labels/'+image_name[j]+'.png'))
                torchvision.utils.save_image(fake_img[j], os.path.join(args.validation_dir+'/unsupervised/regenerated_image/'+image_name[j]+'.jpg'))
                torchvision.utils.save_image(fake_img_from_labels[j], os.path.join(args.validation_dir+'/unsupervised/image_from_labels/'+image_name[j]+'.jpg'))
            
            print('Batch-', str(i+1), ' Done!')
        
        print('The iou of the resulting segment maps: ', str(best_iou))
Example 11
    print('---------------------------')
    # -----------------------------------------------------------------
    # Find Adversarial Examples
    # -----------------------------------------

    # PGD attack
    epsilon = 0.1
    pgd_iter = 12
    a = 0.2
    random_start = True

    loss = tf.losses.softmax_cross_entropy(my_classifier.y,
                                           my_classifier.logits)
    attack = LinfPGDAttack(loss, my_classifier.x, epsilon, pgd_iter, a,
                           random_start)
    y_labels = utils.make_one_hot(my_classifier.original_y_train)

    X_adv = attack.perturb(my_classifier.original_X_train, y_labels,
                           my_classifier.x, my_classifier.y, sess)
    y_Adv = my_classifier.original_y_train

    # -----------------------------------------------------------------
    # Adversarial Training
    # -----------------------------------------
    my_classifier.set_adv_dataset(batch_size, X_adv, y_Adv)
    my_classifier.train(sess, num_epochs, learning_rate)

    # Test Classifier
    my_classifier.eval_model(sess)

    # Classifier Decision Boundary
    def train(self, args):
        transform = get_transformation((args.crop_height, args.crop_width),
                                       resize=True,
                                       dataset=args.dataset)

        # make the dataset choice configurable
        if self.args.dataset == 'voc2012':
            labeled_set = VOCDataset(root_path=root,
                                     name='label',
                                     ratio=0.2,
                                     transformation=transform,
                                     augmentation=None)
            unlabeled_set = VOCDataset(root_path=root,
                                       name='unlabel',
                                       ratio=0.2,
                                       transformation=transform,
                                       augmentation=None)
            val_set = VOCDataset(root_path=root,
                                 name='val',
                                 ratio=0.5,
                                 transformation=transform,
                                 augmentation=None)
        elif self.args.dataset == 'cityscapes':
            labeled_set = CityscapesDataset(root_path=root_cityscapes,
                                            name='label',
                                            ratio=0.5,
                                            transformation=transform,
                                            augmentation=None)
            unlabeled_set = CityscapesDataset(root_path=root_cityscapes,
                                              name='unlabel',
                                              ratio=0.5,
                                              transformation=transform,
                                              augmentation=None)
            val_set = CityscapesDataset(root_path=root_cityscapes,
                                        name='val',
                                        ratio=0.5,
                                        transformation=transform,
                                        augmentation=None)
        elif self.args.dataset == 'acdc':
            labeled_set = ACDCDataset(root_path=root_acdc,
                                      name='label',
                                      ratio=0.5,
                                      transformation=transform,
                                      augmentation=None)
            unlabeled_set = ACDCDataset(root_path=root_acdc,
                                        name='unlabel',
                                        ratio=0.5,
                                        transformation=transform,
                                        augmentation=None)
            val_set = ACDCDataset(root_path=root_acdc,
                                  name='val',
                                  ratio=0.5,
                                  transformation=transform,
                                  augmentation=None)
        '''
        https://discuss.pytorch.org/t/about-the-relation-between-batch-size-and-length-of-data-loader/10510
        ^^ drop_last=True keeps every batch the same size by dropping the final,
        smaller batch
        '''
        labeled_loader = DataLoader(labeled_set,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    drop_last=True)
        unlabeled_loader = DataLoader(unlabeled_set,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      drop_last=True)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=True,
                                drop_last=True)

        new_img_fake_sample = utils.Sample_from_Pool()
        img_fake_sample = utils.Sample_from_Pool()
        gt_fake_sample = utils.Sample_from_Pool()

        img_dis_loss, gt_dis_loss, unsupervisedloss, fullsupervisedloss = 0, 0, 0, 0

        ### Variable to regulate the frequency of update between Discriminators and Generators
        counter = 0

        for epoch in range(self.start_epoch, args.epochs):
            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            self.Gsi.train()
            self.Gis.train()

            # if (epoch+1)%10 == 0:
            # args.lamda_img = args.lamda_img + 0.08
            # args.lamda_gt = args.lamda_gt + 0.04

            for i, ((l_img, l_gt, _),
                    (unl_img, _,
                     _)) in enumerate(zip(labeled_loader, unlabeled_loader)):
                # step
                step = epoch * min(len(labeled_loader),
                                   len(unlabeled_loader)) + i + 1

                l_img, unl_img, l_gt = utils.cuda([l_img, unl_img, l_gt],
                                                  args.gpu_ids)

                # Generator Computations
                ##################################################

                set_grad([self.Di, self.Ds, self.old_Di], False)
                set_grad([self.old_Gsi, self.old_Gis], False)
                self.g_optimizer.zero_grad()

                # Forward pass through generators
                ##################################################
                fake_img = self.Gis(
                    make_one_hot(l_gt, args.dataset, args.gpu_ids).float())
                fake_gt = self.Gsi(unl_img.float())  ### having 21 channels
                lab_gt = self.Gsi(l_img)  ### having 21 channels

                ### Getting the outputs of the model to correct dimensions
                fake_img = self.interp(fake_img)
                fake_gt = self.interp(fake_gt)
                lab_gt = self.interp(lab_gt)

                # fake_gt = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)  ### will get into no channels
                # fake_gt = fake_gt.unsqueeze(1)   ### will get into 1 channel only
                # fake_gt = make_one_hot(fake_gt, args.dataset, args.gpu_ids)

                lab_loss_CE = self.CE(lab_gt, l_gt.squeeze(1))

                ### Again applying activations
                lab_gt = self.activation_softmax(lab_gt)
                fake_gt = self.activation_softmax(fake_gt)
                # fake_gt = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)
                # fake_gt = fake_gt.unsqueeze(1)
                # fake_gt = make_one_hot(fake_gt, args.dataset, args.gpu_ids)
                # fake_img = self.activation_tanh(fake_img)

                recon_img = self.Gis(fake_gt.float())
                recon_lab_img = self.Gis(lab_gt.float())
                recon_gt = self.Gsi(fake_img.float())

                ### Getting the outputs of the model to correct dimensions
                recon_img = self.interp(recon_img)
                recon_lab_img = self.interp(recon_lab_img)
                recon_gt = self.interp(recon_gt)

                ### This is for the case of the new loss between the recon_img from resnet and deeplab network
                resnet_fake_gt = self.old_Gsi(unl_img.float())
                resnet_lab_gt = self.old_Gsi(l_img)
                resnet_lab_gt = self.activation_softmax(resnet_lab_gt)
                resnet_fake_gt = self.activation_softmax(resnet_fake_gt)
                resnet_recon_img = self.old_Gis(resnet_fake_gt.float())
                resnet_recon_lab_img = self.old_Gis(resnet_lab_gt.float())

                ## Applying the tanh activations
                # recon_img = self.activation_tanh(recon_img)
                # recon_lab_img = self.activation_tanh(recon_lab_img)

                # Adversarial losses
                ###################################################
                fake_img_dis = self.Di(fake_img)
                resnet_fake_img_dis = self.old_Di(recon_img)

                ### For passing different type of input to Ds
                fake_gt_discriminator = fake_gt.data.max(1)[1].squeeze_(
                    1).squeeze_(0)
                fake_gt_discriminator = fake_gt_discriminator.unsqueeze(1)
                fake_gt_discriminator = make_one_hot(fake_gt_discriminator,
                                                     args.dataset,
                                                     args.gpu_ids)
                fake_gt_dis = self.Ds(fake_gt_discriminator.float())
                # lab_gt_dis = self.Ds(lab_gt)

                real_label_gt = utils.cuda(
                    Variable(torch.ones(fake_gt_dis.size())), args.gpu_ids)
                real_label_img = utils.cuda(
                    Variable(torch.ones(fake_img_dis.size())), args.gpu_ids)

                # here is much better to have a cross entropy loss for classification.
                img_gen_loss = self.MSE(fake_img_dis, real_label_img)
                gt_gen_loss = self.MSE(fake_gt_dis, real_label_gt)
                # gt_label_gen_loss = self.MSE(lab_gt_dis, real_label)

                # Cycle consistency losses
                ###################################################
                resnet_img_cycle_loss = self.MSE(resnet_fake_img_dis,
                                                 real_label_img)
                # img_cycle_loss = self.L1(recon_img, unl_img)
                # img_cycle_loss_perceptual = perceptual_loss(recon_img, unl_img, args.gpu_ids)
                gt_cycle_loss = self.CE(recon_gt, l_gt.squeeze(1))
                # lab_img_cycle_loss = self.L1(recon_lab_img, l_img) * args.lamda

                # Total generators losses
                ###################################################
                # lab_loss_CE = self.CE(lab_gt, l_gt.squeeze(1))
                lab_loss_MSE = self.L1(fake_img, l_img)
                # lab_loss_perceptual = perceptual_loss(fake_img, l_img, args.gpu_ids)

                fullsupervisedloss = args.lab_CE_weight * lab_loss_CE + args.lab_MSE_weight * lab_loss_MSE

                unsupervisedloss = args.adversarial_weight * (
                    img_gen_loss + gt_gen_loss
                ) + resnet_img_cycle_loss + gt_cycle_loss * args.lamda_gt

                gen_loss = fullsupervisedloss + unsupervisedloss

                # Update generators
                ###################################################
                gen_loss.backward()

                self.g_optimizer.step()

                if counter % 1 == 0:
                    # Discriminator Computations
                    #################################################

                    set_grad([self.Di, self.Ds, self.old_Di], True)
                    self.d_optimizer.zero_grad()

                    # Sample from history of generated images
                    #################################################
                    if torch.rand(1) < 0.0:
                        fake_img = self.gauss_noise(fake_img.cpu())
                        fake_gt = self.gauss_noise(fake_gt.cpu())

                    recon_img = Variable(
                        torch.Tensor(
                            new_img_fake_sample([recon_img.cpu().data.numpy()
                                                 ])[0]))
                    fake_img = Variable(
                        torch.Tensor(
                            img_fake_sample([fake_img.cpu().data.numpy()])[0]))
                    # lab_gt = Variable(torch.Tensor(gt_fake_sample([lab_gt.cpu().data.numpy()])[0]))
                    fake_gt = Variable(
                        torch.Tensor(
                            gt_fake_sample([fake_gt.cpu().data.numpy()])[0]))

                    recon_img, fake_img, fake_gt = utils.cuda(
                        [recon_img, fake_img, fake_gt], args.gpu_ids)

                    # Forward pass through discriminators
                    #################################################
                    unl_img_dis = self.Di(unl_img)
                    fake_img_dis = self.Di(fake_img)
                    resnet_recon_img_dis = self.old_Di(resnet_recon_img)
                    resnet_fake_img_dis = self.old_Di(recon_img)

                    # lab_gt_dis = self.Ds(lab_gt)

                    l_gt = make_one_hot(l_gt, args.dataset, args.gpu_ids)
                    real_gt_dis = self.Ds(l_gt.float())

                    fake_gt_discriminator = fake_gt.data.max(1)[1].squeeze_(
                        1).squeeze_(0)
                    fake_gt_discriminator = fake_gt_discriminator.unsqueeze(1)
                    fake_gt_discriminator = make_one_hot(
                        fake_gt_discriminator, args.dataset, args.gpu_ids)
                    fake_gt_dis = self.Ds(fake_gt_discriminator.float())

                    real_label_img = utils.cuda(
                        Variable(torch.ones(unl_img_dis.size())), args.gpu_ids)
                    fake_label_img = utils.cuda(
                        Variable(torch.zeros(fake_img_dis.size())),
                        args.gpu_ids)
                    real_label_gt = utils.cuda(
                        Variable(torch.ones(real_gt_dis.size())), args.gpu_ids)
                    fake_label_gt = utils.cuda(
                        Variable(torch.zeros(fake_gt_dis.size())),
                        args.gpu_ids)

                    # Discriminator losses
                    ##################################################
                    img_dis_real_loss = self.MSE(unl_img_dis, real_label_img)
                    img_dis_fake_loss = self.MSE(fake_img_dis, fake_label_img)
                    gt_dis_real_loss = self.MSE(real_gt_dis, real_label_gt)
                    gt_dis_fake_loss = self.MSE(fake_gt_dis, fake_label_gt)
                    # lab_gt_dis_fake_loss = self.MSE(lab_gt_dis, fake_label)

                    cycle_img_dis_real_loss = self.MSE(resnet_recon_img_dis,
                                                       real_label_img)
                    cycle_img_dis_fake_loss = self.MSE(resnet_fake_img_dis,
                                                       fake_label_img)

                    # Total discriminators losses
                    img_dis_loss = (img_dis_real_loss +
                                    img_dis_fake_loss) * 0.5
                    gt_dis_loss = (gt_dis_real_loss + gt_dis_fake_loss) * 0.5
                    # lab_gt_dis_loss = (gt_dis_real_loss + lab_gt_dis_fake_loss)*0.33
                    cycle_img_dis_loss = cycle_img_dis_real_loss + cycle_img_dis_fake_loss

                    # Update discriminators
                    ##################################################
                    discriminator_loss = args.discriminator_weight * (
                        img_dis_loss + gt_dis_loss) + cycle_img_dis_loss
                    discriminator_loss.backward()

                    # lab_gt_dis_loss.backward()
                    self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Dis Loss:%.2e | Unlab Gen Loss:%.2e | Lab Gen loss:%.2e"
                    % (epoch, i + 1,
                       min(len(labeled_loader),
                           len(unlabeled_loader)), img_dis_loss + gt_dis_loss,
                       unsupervisedloss, fullsupervisedloss))

                self.writer_semisuper.add_scalars(
                    'Dis Loss', {
                        'img_dis_loss': img_dis_loss,
                        'gt_dis_loss': gt_dis_loss,
                        'cycle_img_dis_loss': cycle_img_dis_loss
                    },
                    len(labeled_loader) * epoch + i)
                self.writer_semisuper.add_scalars(
                    'Unlabelled Loss', {
                        'img_gen_loss': img_gen_loss,
                        'gt_gen_loss': gt_gen_loss,
                        'img_cycle_loss': resnet_img_cycle_loss,
                        'gt_cycle_loss': gt_cycle_loss
                    },
                    len(labeled_loader) * epoch + i)
                self.writer_semisuper.add_scalars(
                    'Labelled Loss', {
                        'lab_loss_CE': lab_loss_CE,
                        'lab_loss_MSE': lab_loss_MSE
                    },
                    len(labeled_loader) * epoch + i)

                counter += 1

            ### For getting the mean IoU
            self.Gsi.eval()
            self.Gis.eval()
            with torch.no_grad():
                for i, (val_img, val_gt, _) in enumerate(val_loader):
                    val_img, val_gt = utils.cuda([val_img, val_gt],
                                                 args.gpu_ids)

                    outputs = self.Gsi(val_img)
                    outputs = self.interp(outputs)
                    outputs = self.activation_softmax(outputs)

                    pred = outputs.data.max(1)[1].cpu().numpy()
                    gt = val_gt.squeeze().data.cpu().numpy()

                    self.running_metrics_val.update(gt, pred)

            score, class_iou = self.running_metrics_val.get_scores()

            self.running_metrics_val.reset()

            print('The mIoU for the epoch is: ', score["Mean IoU : \t"])

            ### For displaying the images generated by generator on tensorboard using validation images
            val_image, val_gt, _ = next(iter(val_loader))
            val_image, val_gt = utils.cuda([val_image, val_gt], args.gpu_ids)
            with torch.no_grad():
                fake_label = self.Gsi(val_image).detach()
                fake_label = self.interp(fake_label)
                fake_label = self.activation_softmax(fake_label)
                fake_label = fake_label.data.max(1)[1].squeeze_(1).squeeze_(0)
                fake_label = fake_label.unsqueeze(1)
                fake_label = make_one_hot(fake_label, args.dataset,
                                          args.gpu_ids)
                fake_img = self.Gis(fake_label).detach()
                fake_img = self.interp(fake_img)
                # fake_img = self.activation_tanh(fake_img)

                fake_img_from_labels = self.Gis(
                    make_one_hot(val_gt, args.dataset,
                                 args.gpu_ids).float()).detach()
                fake_img_from_labels = self.interp(fake_img_from_labels)
                # fake_img_from_labels = self.activation_tanh(fake_img_from_labels)
                fake_label_regenerated = self.Gsi(
                    fake_img_from_labels).detach()
                fake_label_regenerated = self.interp(fake_label_regenerated)
                fake_label_regenerated = self.activation_softmax(
                    fake_label_regenerated)
            fake_prediction_label = fake_label.data.max(1)[1].squeeze_(
                1).cpu().numpy()
            fake_regenerated_label = fake_label_regenerated.data.max(
                1)[1].squeeze_(1).cpu().numpy()
            val_gt = val_gt.cpu()

            fake_img = fake_img.cpu()
            fake_img_from_labels = fake_img_from_labels.cpu()
            ### Revert the normalization applied to these images
            if self.args.dataset == 'voc2012' or self.args.dataset == 'cityscapes':
                trans_mean = [0.5, 0.5, 0.5]
                trans_std = [0.5, 0.5, 0.5]
                for i in range(3):
                    fake_img[:, i, :, :] = (
                        (fake_img[:, i, :, :] * trans_std[i]) + trans_mean[i])
                    fake_img_from_labels[:, i, :, :] = (
                        (fake_img_from_labels[:, i, :, :] * trans_std[i]) +
                        trans_mean[i])

            elif self.args.dataset == 'acdc':
                trans_mean = [0.5]
                trans_std = [0.5]
                for i in range(1):
                    fake_img[:, i, :, :] = (
                        (fake_img[:, i, :, :] * trans_std[i]) + trans_mean[i])
                    fake_img_from_labels[:, i, :, :] = (
                        (fake_img_from_labels[:, i, :, :] * trans_std[i]) +
                        trans_mean[i])

            ### display_tensor is the final tensor that will be displayed on tensorboard
            display_tensor_label = torch.zeros([
                fake_label.shape[0], 3, fake_label.shape[2],
                fake_label.shape[3]
            ])
            display_tensor_gt = torch.zeros(
                [val_gt.shape[0], 3, val_gt.shape[2], val_gt.shape[3]])
            display_tensor_regen_label = torch.zeros([
                fake_label_regenerated.shape[0], 3,
                fake_label_regenerated.shape[2],
                fake_label_regenerated.shape[3]
            ])
            for i in range(fake_prediction_label.shape[0]):
                new_img_label = fake_prediction_label[i]
                new_img_label = utils.colorize_mask(
                    new_img_label, self.args.dataset
                )  ### So this is the generated image in PIL.Image format
                img_tensor_label = utils.PIL_to_tensor(new_img_label,
                                                       self.args.dataset)
                display_tensor_label[i, :, :, :] = img_tensor_label

                display_tensor_gt[i, :, :, :] = val_gt[i]

                regen_label = fake_regenerated_label[i]
                regen_label = utils.colorize_mask(regen_label,
                                                  self.args.dataset)
                regen_tensor_label = utils.PIL_to_tensor(
                    regen_label, self.args.dataset)
                display_tensor_regen_label[i, :, :, :] = regen_tensor_label

            self.writer_semisuper.add_image(
                'Generated segmented image: ',
                torchvision.utils.make_grid(display_tensor_label,
                                            nrow=2,
                                            normalize=True), epoch)
            self.writer_semisuper.add_image(
                'Generated image back from segmentation: ',
                torchvision.utils.make_grid(fake_img, nrow=2, normalize=True),
                epoch)
            self.writer_semisuper.add_image(
                'Ground truth for the image: ',
                torchvision.utils.make_grid(display_tensor_gt,
                                            nrow=2,
                                            normalize=True), epoch)
            self.writer_semisuper.add_image(
                'Image generated from val labels: ',
                torchvision.utils.make_grid(fake_img_from_labels,
                                            nrow=2,
                                            normalize=True), epoch)
            self.writer_semisuper.add_image(
                'Labels generated back from the cycle: ',
                torchvision.utils.make_grid(display_tensor_regen_label,
                                            nrow=2,
                                            normalize=True), epoch)

            if score["Mean IoU : \t"] >= self.best_iou:
                self.best_iou = score["Mean IoU : \t"]

                # Override the latest checkpoint
                #######################################################
                utils.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'Di': self.Di.state_dict(),
                        'Ds': self.Ds.state_dict(),
                        'Gis': self.Gis.state_dict(),
                        'Gsi': self.Gsi.state_dict(),
                        'd_optimizer': self.d_optimizer.state_dict(),
                        'g_optimizer': self.g_optimizer.state_dict(),
                        'best_iou': self.best_iou,
                        'class_iou': class_iou
                    }, '%s/latest_semisuper_cycleGAN.ckpt' %
                    (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()

        self.writer_semisuper.close()
Example 13
def DAG(model,
        image,
        ground_truth,
        adv_target,
        num_iterations=20,
        gamma=0.07,
        no_background=True,
        background_class=0,
        device='cuda:0',
        verbose=False):
    '''
    Generates an adversarial example for a given image.

    Parameters
    ----------
        model: Torch model
        image: Torch tensor of dtype=float. Requires gradient. [b*c*h*w]
        ground_truth: Torch tensor of labels as one-hot vectors per class
        adv_target: Torch tensor of dtype=float. These are the perturbed labels. [b*classes*h*w]
        num_iterations: number of iterations for the algorithm
        gamma: epsilon value, the maximum change possible per step
        no_background: if True, does not perturb the background class
        background_class: the index of the background class, used to filter background
        device: device to perform the computations on
        verbose: bool. If True, prints the amount of change and the number of values changed in each iteration
    Returns
    -------
        image: adversarial image as torch tensor (None if the stopping condition is reached)
        logits: output of the clean image as torch tensor
        noise_total: list of total noise added up to each iteration, as numpy arrays
        noise_iteration: list of noise added at each iteration, as numpy arrays
        prediction_iteration: list of predictions per iteration, as numpy arrays
        image_iteration: list of images per iteration, as numpy arrays
    '''

    noise_total = []
    noise_iteration = []
    prediction_iteration = []
    image_iteration = []
    background = None
    logits = model(image)
    orig_image = image
    _, predictions_orig = torch.max(logits, 1)
    predictions_orig = make_one_hot(predictions_orig, logits.shape[1], device)

    if (no_background):
        background = torch.zeros(logits.shape)
        background[:, background_class, :, :] = torch.ones(
            (background.shape[2], background.shape[3]))
        background = background.to(device)

    for a in range(num_iterations):
        output = model(image)
        _, predictions = torch.max(output, 1)
        prediction_iteration.append(predictions[0].cpu().numpy())
        predictions = make_one_hot(predictions, logits.shape[1], device)

        condition1 = torch.eq(predictions, ground_truth)
        condition = condition1

        if no_background:
            condition2 = (ground_truth != background)
            condition = torch.mul(condition1, condition2)
        condition = condition.float()

        if (condition.sum() == 0):
            print("Condition Reached")
            image = None
            break

        # Find the pixels to perturb
        adv_log = torch.mul(output, adv_target)
        # Get the values of the original output
        clean_log = torch.mul(output, ground_truth)

        #Finding r_m
        adv_direction = adv_log - clean_log
        r_m = torch.mul(adv_direction, condition)
        r_m.requires_grad_()
        #Summation
        r_m_sum = r_m.sum()
        r_m_sum.requires_grad_()
        #Finding gradient with respect to image
        r_m_grad = torch.autograd.grad(r_m_sum, image, retain_graph=True)
        #Saving gradient for calculation
        r_m_grad_calc = r_m_grad[0]

        #Calculating Magnitude of the gradient
        r_m_grad_mag = r_m_grad_calc.norm()

        if (r_m_grad_mag == 0):
            print("Condition Reached, no gradient")
            #image=None
            break
        #Calculating final value of r_m
        r_m_norm = (gamma / r_m_grad_mag) * r_m_grad_calc

        #if no_background:
        if False:
            condition_image = condition.sum(dim=1)
            condition_image = condition_image.unsqueeze(1)
            r_m_norm = torch.mul(r_m_norm, condition_image)

        #Updating the image
        image = torch.clamp((image + r_m_norm), 0, 1)
        image_iteration.append(image[0][0].detach().cpu().numpy())
        noise_total.append((image - orig_image)[0][0].detach().cpu().numpy())
        noise_iteration.append(r_m_norm[0][0].cpu().numpy())

        if verbose:
            print("Iteration ", a)
            print("Change to the image is ", r_m_norm.sum())
            print("Magnitude of grad is ", r_m_grad_mag)
            print("Condition 1 ", condition1.sum())
            if no_background:
                print("Condition 2 ", condition2.sum())
                print("Condition is", condition.sum())

    return image, logits, noise_total, noise_iteration, prediction_iteration, image_iteration
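A minimal invocation sketch on CPU; the stand-in conv net, the shapes, and the make_one_hot(labels, num_classes, device) signature are assumptions inferred from the calls inside DAG:

import torch

model = torch.nn.Conv2d(3, 21, kernel_size=1)     # stand-in segmentation net
image = torch.rand(1, 3, 32, 32, requires_grad=True)
with torch.no_grad():
    gt = make_one_hot(model(image).argmax(1), 21, 'cpu')
adv_target = gt.roll(shifts=1, dims=1)            # shift every class label by one
adv_img, logits, *_ = DAG(model, image, gt, adv_target,
                          num_iterations=5, device='cpu')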
Example 14
    def __init__(
        self,
        images_dir,
        transform=None,
        image_size=256,
        subset="train",
        random_sampling=True,
        validation_cases=0,
        seed=42,
    ):
        assert subset in ["all", "train", "validation"]

        # read images
        volumes = {}
        masks = {}
        print("reading {} images...".format(subset))
        # os.walk yields dirpath (the current directory), dirnames (its subdirectories), and filenames (its files)
        for (dirpath, dirnames, filenames) in os.walk(images_dir):
            image_slices = []
            mask_slices = []
            mask_path = ""
            # keep only the .gz (NIfTI) files, sorted by filename
            for filename in sorted(filter(lambda f: ".gz" in f, filenames)):  
                if "seg" in filename:
                    mask_path = os.path.join(dirpath,filename)
                    mask_slices.append(load_nii(mask_path))
                else:
                    filepath = os.path.join(dirpath, filename) 
                    image_slices.append(load_nii(filepath))

            # keep only the slices that contain tumor
            if len(image_slices) > 0:
                patient_id = dirpath.split("/")[-1]

                volumes[patient_id] = np.array(image_slices).transpose(1,2,3,0)
                masks[patient_id] = np.array(mask_slices).transpose(1,2,3,0)


        # volumes maps patient_id to its images (without masks); self.patients holds the sorted ids
        self.patients = sorted(volumes)

        # select cases to subset
        if not subset == "all":
            random.seed(seed)
            # split off the validation set
            validation_patients = random.sample(self.patients, k=validation_cases)  # note: k may exceed the number of patients
            if subset == "validation":
                self.patients = validation_patients
            else:
                self.patients = sorted(
                    list(set(self.patients).difference(validation_patients))
                )

        print("preprocessing {} volumes...".format(subset))
        # create list of tuples (volume, mask)
        self.volumes = [(volumes[k], masks[k]) for k in self.patients]

        # probabilities for sampling slices based on masks
        self.slice_weights = [m.sum(axis=-1).sum(axis=-1).sum(axis=-1) for v, m in self.volumes]
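        # smooth and renormalize: w = (s + 0.1*mean(s)) / (1.1*sum(s)); weights sum to 1
        # and slices with empty masks keep a small nonzero sampling probability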
        self.slice_weights = [(s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1) for s in self.slice_weights]
        
        print("one hotting {} masks...".format(subset))
        self.volumes = [(v, make_one_hot(m)) for v,  m in self.volumes]

        print("resizing {} volumes...".format(subset))
        # resize
        self.volumes = [resize_sample(v, size=image_size) for v in self.volumes]

        print("normalizing {} volumes...".format(subset))
        # normalize channel-wise
        self.volumes = [(normalize_volume(v), m) for v,  m in self.volumes]

        print("one hotting {} masks...".format(subset))
        self.volumes = [(v, convert_mask_to_one(m)) for v,  m in self.volumes]

        print("done creating {} dataset".format(subset))

        # create global index for patient and slice (idx -> (p_idx, s_idx))
        num_slices = [v.shape[0] for v, m in self.volumes]
        self.patient_slice_index = list(
            zip(
                sum([[i] * num_slices[i] for i in range(len(num_slices))], []),
                sum([list(range(x)) for x in num_slices], []),
            )
        )

        self.random_sampling = random_sampling
        self.transform = transform
Example 15
                else:
                    seg_loss_ir_only = criterion_semseg_weighted(out_model['pred_label_b'],
                                                                 torch.argmax(out_night_ir_only, 1).squeeze(1).long())
                    cert = F.softmax(out_night_ir_only, dim=1)
                    cert = cert.max(1)[0]

                    if opt.vis:
                        vis_utils.visDepth(cert[0:1, ...],  'weighting_ir')
                        vis_utils.visDepth(out_night_ir_only[0:1,...].max(1)[1].float(), 'night_ir_label')
                    seg_loss_ir_only = torch.mean(cert * seg_loss_ir_only)

                seg_loss += seg_loss_ir_only
                print('Night Seg loss: %f' % seg_loss_ir_only)

            if opt.cert_branch and not night_supervision_active:
                one_hot_label = utils.make_one_hot(label_day.unsqueeze(1), out_model['pred_label_a'].size(1))
                # certainty target: 1 - softmax probability of the ground-truth class, i.e. the per-pixel error probability
                cert = torch.sum((one_hot_label.float() * F.softmax(out_model['pred_label_a'], dim=1)), 1)
                cert = torch.ones_like(cert) - cert
                if opt.vis:
                    vis_utils.visDepth(cert[0:1, ...], 'cert_gt')

                cert_loss = torch.mean((out_model['cert_a'] - cert)**2) * 10
                print('Cert_loss : %f , Seg loss: %f' % (cert_loss, seg_loss))
                seg_loss += cert_loss

            # Visualize training images
            if opt.vis:
                vis_utils.visImage3Chan(rgb_day[0:1,...], 'rgb_day')
                # vis_utils.visDepth(label_day[0:1, ...].float(), 'day_label')
                day_label_colored = color_coder.color_code_labels(label_day[0:1, ...].unsqueeze(0), False)
                cv2.imshow('sup_day_label', day_label_colored)
Example 16
def validate(val_loader, net, criterion, optimizer, epoch, train_args, restore, visualize):
    net.eval()

    val_loss = AverageMeter()
    inputs_all, gts_all, predictions_all = [], [], []

    for vi, data in enumerate(val_loader):
        inputs, gts = data
        N = inputs.size(0)
        inputs = inputs.cuda()
        gts_l = make_one_hot(gts.unsqueeze(1), wp.num_classes).cuda()
        outputs = net(inputs)
        predictions = outputs.data.max(1)[1].squeeze_(
            1).squeeze_(0).cpu().numpy()

        val_loss.update(criterion(outputs, gts_l).item(), N)

        if random.random() > train_args['val_img_sample_rate']:
            inputs_all.append(None)
        else:
            inputs_all.append(inputs.squeeze_(0).cpu())
        gts_all.append(gts.squeeze_(0).cpu().numpy())
        predictions_all.append(predictions)

    acc, acc_cls, mean_iu, fwavacc = evaluate(
        predictions_all, gts_all, wp.num_classes)

    if mean_iu > train_args['best_record']['mean_iu']:
        train_args['best_record']['val_loss'] = val_loss.avg
        train_args['best_record']['epoch'] = epoch
        train_args['best_record']['acc'] = acc
        train_args['best_record']['acc_cls'] = acc_cls
        train_args['best_record']['mean_iu'] = mean_iu
        train_args['best_record']['fwavacc'] = fwavacc
        snapshot_name = 'epoch_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % (
            epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc, optimizer.param_groups[
                1]['lr']
        )
        torch.save(net.state_dict(), os.path.join(
            ckpt_path, exp_name, snapshot_name + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(
            ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth'))

        if train_args['val_save_to_img_file']:
            to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch))
            check_mkdir(to_save_dir)

        val_visual = []
        for idx, data in enumerate(zip(inputs_all, gts_all, predictions_all)):
            if data[0] is None:
                continue
            input_pil = restore(data[0])
            gt_pil = wp.colorize_mask(data[1])
            predictions_pil = wp.colorize_mask(data[2])
            if train_args['val_save_to_img_file']:
                input_pil.save(os.path.join(to_save_dir, '%d_input.png' % idx))
                predictions_pil.save(os.path.join(
                    to_save_dir, '%d_prediction.png' % idx))
                gt_pil.save(os.path.join(to_save_dir, '%d_gt.png' % idx))
            val_visual.extend([visualize(input_pil.convert('RGB')), visualize(gt_pil.convert('RGB')),
                               visualize(predictions_pil.convert('RGB'))])
        val_visual = torch.stack(val_visual, 0)
        val_visual = vutils.make_grid(val_visual, nrow=3, padding=5)
        writer.add_image(snapshot_name, val_visual)

    print('--------------------------------------------------------------------')
    print('[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % (
        epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc))

    print('best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d]' % (
        train_args['best_record']['val_loss'], train_args['best_record']['acc'], train_args['best_record']['acc_cls'],
        train_args['best_record']['mean_iu'], train_args['best_record']['fwavacc'], train_args['best_record']['epoch']))

    print('--------------------------------------------------------------------')

    writer.add_scalar('val_loss', val_loss.avg, epoch)
    writer.add_scalar('acc', acc, epoch)
    writer.add_scalar('acc_cls', acc_cls, epoch)
    writer.add_scalar('mean_iu', mean_iu, epoch)
    writer.add_scalar('fwavacc', fwavacc, epoch)
    writer.add_scalar('lr', optimizer.param_groups[1]['lr'], epoch)

    net.train()
    return val_loss.avg
Example 17
    val_set = wp.Wp('train',
                    transform=input_transform,
                    target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=4,
                            shuffle=False)
    inputs_all, gts_all, predictions_all = [], [], []

    with torch.no_grad():  # replaces the deprecated Variable(..., volatile=True)
        for vi, data in enumerate(val_loader):
            inputs, gts = data
            N = inputs.size(0)
            inputs = inputs.cuda()
            gts_l = make_one_hot(gts.unsqueeze(1), wp.num_classes).cuda()
            outputs = net(inputs)
            predictions = outputs.data.max(1)[1].squeeze_(1).squeeze_(
                0).cpu().numpy()

            if random.random() > 1:
                inputs_all.append(None)
            else:
                inputs_all.append(inputs.squeeze(0).detach().cpu())
            gts_all.append(gts.squeeze(0).detach().cpu().numpy())
            predictions_all.append(predictions)

    acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, gts_all,
                                              wp.num_classes)
    print(mean_iu)
Example 18
        for i, param_group in enumerate(optimizer.param_groups):
            print("Current LR: {} of {}th group".format(param_group['lr'], i))

        train_mse = 0.0
        train_msssim = 0.0
        train_hash = 0.0
        train_loss = 0.0

        with tqdm(total=len(train_loader), desc="Batches") as pbar:
            for i, data in enumerate(train_loader):
                model.train()

                img, label = data
                label = label.to(device)
                label = make_one_hot(label)
                img = img.to(device)

                _, output, hashed_layer = model(img)

                if (i % 100 == 0 and epoch == 0) or (i % 500 == 0 and epoch > 0):

                    # PATH ASSUMES ONLY 1 STAGE
                    save_image(torch.cat((img, output)), "../results/ae_hash/images/train_check/{}_{}.jpg".format(epoch, i), nrow=batch_size)

                loss, mse, msssim = criterion1(output, img)
                loss_hash = criterion2(hashed_layer, label)

                if torch.isnan(loss_hash).any():
                    torch.save(model.state_dict(), "nan_aya_wo_model_weights.pt")
                    torch.save(img, "nan_dene_wala_img_batch.pt")
Example 19
######################################################
# K-Means
# model = models.KM(254)
# model.fit(train_data)

######################################################
# K-Nearest Neighbors
# model = models.KNN(3)
# model.fit(train_data, train_labels)
# model.save("model_saves/KNN")
# model.load("model_saves/KNN")

######################################################
# Deep Learning
train_labels = utils.make_one_hot(train_labels, 3)
model = models.NN(neurons=[5, 4, 3])
model.train(train_data,
            train_labels,
            validation_data,
            validation_labels,
            epochs=8)
# model.save("model_saves/NN_model.ckpt")
# model.load("model_saves/NN_model_e7.ckpt")

######################################################
# Test set accuracy for the model
test_predictions = model.predict(test_data)
correct = np.sum(test_predictions == test_labels)
print("Accuracy:", correct / len(test_data))