Code Example #1
def train(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Generate the train and validation sets for the model:
    split_train_val(args, per_val=args.per_val)

    current_time = datetime.now().strftime('%b%d_%H%M%S')
    log_dir = os.path.join('runs', current_time +
                           "_{}_{}".format(args.arch, args.loss))
    writer = SummaryWriter(log_dir=log_dir)
    # Setup Augmentations
    if args.aug:
        data_aug = Compose(
            [RandomRotate(25), RandomHorizontallyFlip(), AddNoise()])
    else:
        data_aug = None

    train_set = section_loader(is_transform=True,
                               split='train',
                               augmentations=data_aug)

    # Without Augmentation:
    val_set = section_loader(is_transform=True,
                             split='val',)

    n_classes = train_set.n_classes

    # Create sampler:

    shuffle = False  # must be False when using a custom sampler
    with open(pjoin('data', 'splits', 'section_train.txt'), 'r') as f:
        train_list = f.read().splitlines()
    with open(pjoin('data', 'splits', 'section_val.txt'), 'r') as f:
        val_list = f.read().splitlines()

    class CustomSamplerTrain(torch.utils.data.Sampler):
        # Each epoch, randomly pick one section orientation ('i' = inline,
        # 'x' = crossline) and yield only sections of that orientation,
        # in shuffled order.
        def __iter__(self):
            char = 'i' if np.random.randint(2) == 1 else 'x'
            self.indices = [idx for (idx, name) in enumerate(train_list)
                            if char in name]
            return (self.indices[i] for i in torch.randperm(len(self.indices)))

    class CustomSamplerVal(torch.utils.data.Sampler):
        def __iter__(self):
            char = 'i' if np.random.randint(2) == 1 else 'x'
            self.indices = [idx for (idx, name) in enumerate(val_list)
                            if char in name]
            return (self.indices[i] for i in torch.randperm(len(self.indices)))

    trainloader = data.DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  sampler=CustomSamplerTrain(train_list),
                                  num_workers=4,
                                  shuffle=shuffle)
    valloader = data.DataLoader(val_set,
                                batch_size=args.batch_size,
                                sampler=CustomSamplerVal(val_list),
                                num_workers=4)

    # Setup Metrics
    running_metrics = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            model = torch.load(args.resume)
        else:
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(args.resume))
    else:
        model = get_model(args.arch, args.pretrained, n_classes)

    # Use as many GPUs as we can
    model = torch.nn.DataParallel(
        model, device_ids=range(torch.cuda.device_count()))
    model = model.to(device)  # Send to GPU

    # PYTORCH NOTE: always construct optimizers after the model has been moved to the GPU/CPU.

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        # optimizer = torch.optim.Adadelta(model.parameters())
        optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)

    if args.loss == 'FL':
        loss_fn = core.loss.focal_loss2d
    else:
        loss_fn = core.loss.cross_entropy

    if args.class_weights:
        # weights are inversely proportional to the frequency of the classes in the training set
        class_weights = torch.tensor(
            [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852], device=device, requires_grad=False)
    else:
        class_weights = None

    best_iou = -100.0
    class_names = ['upper_ns', 'middle_ns', 'lower_ns',
                   'rijnland_chalk', 'scruff', 'zechstein']

    for arg in vars(args):
        text = arg + ': ' + str(getattr(args, arg))
        writer.add_text('Parameters/', text)

    # training
    for epoch in range(args.n_epoch):
        # Training Mode:
        model.train()
        loss_train, total_iteration = 0, 0

        for i, (images, labels) in enumerate(trainloader):
            image_original, labels_original = images, labels
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            pred = outputs.detach().max(1)[1].cpu().numpy()
            gt = labels.detach().cpu().numpy()
            running_metrics.update(gt, pred)

            loss = loss_fn(input=outputs, target=labels, weight=class_weights)
            loss_train += loss.item()
            loss.backward()

            # gradient clipping
            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            total_iteration = total_iteration + 1

            if i % 20 == 0:
                print("Epoch [%d/%d] training Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.item()))

            numbers = [0]
            if i in numbers:
                # number 0 image in the batch
                tb_original_image = vutils.make_grid(
                    image_original[0][0], normalize=True, scale_each=True)
                writer.add_image('train/original_image',
                                 tb_original_image, epoch + 1)

                labels_original = labels_original.numpy()[0]
                correct_label_decoded = train_set.decode_segmap(
                    np.squeeze(labels_original))
                writer.add_image('train/original_label',
                                 np_to_tb(correct_label_decoded), epoch + 1)
                out = F.softmax(outputs, dim=1)

                # this returns the max. channel number:
                prediction = out.max(1)[1].cpu().numpy()[0]
                # this returns the confidence:
                confidence = out.max(1)[0].cpu().detach()[0]
                tb_confidence = vutils.make_grid(
                    confidence, normalize=True, scale_each=True)

                decoded = train_set.decode_segmap(np.squeeze(prediction))
                writer.add_image('train/predicted', np_to_tb(decoded), epoch + 1)
                writer.add_image('train/confidence', tb_confidence, epoch + 1)

                unary = outputs.cpu().detach()
                unary_max = torch.max(unary)
                unary_min = torch.min(unary)
                unary = unary.add((-1*unary_min))
                unary = unary/(unary_max - unary_min)

                for channel in range(0, len(class_names)):
                    decoded_channel = unary[0][channel]
                    tb_channel = vutils.make_grid(
                        decoded_channel, normalize=True, scale_each=True)
                    writer.add_image(
                        f'train_classes/_{class_names[channel]}', tb_channel, epoch + 1)

        # Average metrics, and save in writer()
        loss_train /= total_iteration
        score, class_iou = running_metrics.get_scores()
        writer.add_scalar('train/Pixel Acc', score['Pixel Acc: '], epoch+1)
        writer.add_scalar('train/Mean Class Acc',
                          score['Mean Class Acc: '], epoch+1)
        writer.add_scalar('train/Freq Weighted IoU',
                          score['Freq Weighted IoU: '], epoch+1)
        writer.add_scalar('train/Mean_IoU', score['Mean IoU: '], epoch+1)
        running_metrics.reset()
        writer.add_scalar('train/loss', loss_train, epoch+1)

        if args.per_val != 0:
            with torch.no_grad():  # operations inside don't track history
                # Validation Mode:
                model.eval()
                loss_val, total_iteration_val = 0, 0

                for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
                    image_original, labels_original = images_val, labels_val
                    images_val, labels_val = images_val.to(
                        device), labels_val.to(device)

                    outputs_val = model(images_val)
                    pred = outputs_val.detach().max(1)[1].cpu().numpy()
                    gt = labels_val.detach().cpu().numpy()

                    running_metrics_val.update(gt, pred)

                    loss = loss_fn(input=outputs_val, target=labels_val)
                    loss_val += loss.item()

                    total_iteration_val = total_iteration_val + 1

                    if i_val % 20 == 0:
                        print("Epoch [%d/%d] validation Loss: %.4f" %
                              (epoch + 1, args.n_epoch, loss.item()))

                    numbers = [0]
                    if i_val in numbers:
                        # number 0 image in the batch
                        tb_original_image = vutils.make_grid(
                            image_original[0][0], normalize=True, scale_each=True)
                        writer.add_image('val/original_image',
                                         tb_original_image, epoch + 1)
                        labels_original = labels_original.numpy()[0]
                        correct_label_decoded = train_set.decode_segmap(
                            np.squeeze(labels_original))
                        writer.add_image('val/original_label',
                                         np_to_tb(correct_label_decoded), epoch + 1)

                        out = F.softmax(outputs_val, dim=1)

                        # this returns the max. channel number:
                        prediction = out.max(1)[1].cpu().detach().numpy()[0]
                        # this returns the confidence:
                        confidence = out.max(1)[0].cpu().detach()[0]
                        tb_confidence = vutils.make_grid(
                            confidence, normalize=True, scale_each=True)

                        decoded = train_set.decode_segmap(np.squeeze(prediction))
                        writer.add_image('val/predicted', np_to_tb(decoded), epoch + 1)
                        writer.add_image('val/confidence',tb_confidence, epoch + 1)

                        unary = outputs_val.cpu().detach()
                        unary_max, unary_min = torch.max(
                            unary), torch.min(unary)
                        unary = unary.add((-1*unary_min))
                        unary = unary/(unary_max - unary_min)

                        for channel in range(0, len(class_names)):
                            tb_channel = vutils.make_grid(
                                unary[0][channel], normalize=True, scale_each=True)
                            writer.add_image(
                                f'val_classes/_{class_names[channel]}', tb_channel, epoch + 1)

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)

                writer.add_scalar(
                    'val/Pixel Acc', score['Pixel Acc: '], epoch+1)
                writer.add_scalar('val/Mean IoU', score['Mean IoU: '], epoch+1)
                writer.add_scalar('val/Mean Class Acc',
                                  score['Mean Class Acc: '], epoch+1)
                writer.add_scalar('val/Freq Weighted IoU',
                                  score['Freq Weighted IoU: '], epoch+1)

                writer.add_scalar('val/loss', loss_val / total_iteration_val, epoch+1)
                running_metrics_val.reset()

                if score['Mean IoU: '] >= best_iou:
                    best_iou = score['Mean IoU: ']
                    model_dir = os.path.join(
                        log_dir, f"{args.arch}_{args.loss}_model.pkl")
                    torch.save(model, model_dir)

        else:  # validation is turned off:
            # just save the latest model:
            if (epoch+1) % 10 == 0:
                model_dir = os.path.join(
                    log_dir, f"{args.arch}_{args.loss}_ep{epoch+1}_model.pkl")
                torch.save(model, model_dir)

    writer.close()
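
(Supplementary sketch, not part of the original: the command-line wiring that train() above appears to expect. The flag names mirror the attributes the function reads; every default value here is an assumption.)

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Section-based training')
    parser.add_argument('--arch', type=str, default='section_deconvnet',
                        help='architecture name passed to get_model()')
    parser.add_argument('--loss', type=str, default='CE',
                        help="'FL' selects focal loss; anything else uses cross-entropy")
    parser.add_argument('--n_epoch', type=int, default=60)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--per_val', type=float, default=0.1,
                        help='fraction held out for validation (0 disables validation)')
    parser.add_argument('--aug', action='store_true',
                        help='enable training-time augmentation')
    parser.add_argument('--class_weights', action='store_true',
                        help='weight the loss by inverse class frequency')
    parser.add_argument('--clip', type=float, default=0.1,
                        help='gradient-norm clipping threshold (0 disables)')
    parser.add_argument('--pretrained', action='store_true')
    parser.add_argument('--resume', type=str, default=None,
                        help='checkpoint to load instead of building a fresh model')
    train(parser.parse_args())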
Code Example #2
def test(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    log_dir, model_name = os.path.split(args.model_path)
    # load model:
    model = torch.load(args.model_path)
    model = model.to(device)  # Send to GPU if available
    writer = SummaryWriter(log_dir=log_dir)

    class_names = ['upper_ns', 'middle_ns', 'lower_ns',
                   'rijnland_chalk', 'scruff', 'zechstein']
    running_metrics_overall = runningScore(6)

    splits = ['test1', 'test2'] if 'both' in args.split else [args.split]
    for sdx, split in enumerate(splits):
        # define indices of the array
        labels = np.load(pjoin('data', 'test_once', split + '_labels.npy'))
        irange, xrange, depth = labels.shape

        if args.inline:
            i_list = list(range(irange))
            i_list = ['i_'+str(inline) for inline in i_list]
        else:
            i_list = []

        if args.crossline:
            x_list = list(range(xrange))
            x_list = ['x_'+str(crossline) for crossline in x_list]
        else:
            x_list = []

        list_test = i_list + x_list

        with open(pjoin('data', 'splits', 'section_' + split + '.txt'), 'w') as file_object:
            file_object.write('\n'.join(list_test))

        test_set = section_loader(is_transform=True,
                                  split=split,
                                  augmentations=None)
        n_classes = test_set.n_classes

        test_loader = data.DataLoader(test_set,
                                      batch_size=1,
                                      num_workers=4,
                                      shuffle=False)

        # print the results of this split:
        running_metrics_split = runningScore(n_classes)

        # testing mode:
        with torch.no_grad():  # operations inside don't track history
            model.eval()
            total_iteration = 0
            for i, (images, labels) in enumerate(test_loader):
                total_iteration = total_iteration + 1
                image_original, labels_original = images, labels
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)

                pred = outputs.detach().max(1)[1].cpu().numpy()
                gt = labels.detach().cpu().numpy()
                running_metrics_split.update(gt, pred)
                running_metrics_overall.update(gt, pred)

                numbers = [0, 99, 149, 399, 499]

                if i in numbers:
                    tb_original_image = vutils.make_grid(
                        image_original[0][0], normalize=True, scale_each=True)
                    writer.add_image('test/original_image',
                                     tb_original_image, i)

                    labels_original = labels_original.numpy()[0]
                    correct_label_decoded = test_set.decode_segmap(np.squeeze(labels_original))
                    writer.add_image('test/original_label',
                                     np_to_tb(correct_label_decoded), i)
                    out = F.softmax(outputs, dim=1)

                    # this returns the max. channel number:
                    prediction = out.max(1)[1].cpu().numpy()[0]
                    # this returns the confidence:
                    confidence = out.max(1)[0].cpu().detach()[0]
                    tb_confidence = vutils.make_grid(
                        confidence, normalize=True, scale_each=True)

                    decoded = test_set.decode_segmap(np.squeeze(prediction))
                    writer.add_image('test/predicted', np_to_tb(decoded), i)
                    writer.add_image('test/confidence', tb_confidence, i)

                    # visualize the per-class heatmaps (min-max normalized logits):
                    unary = outputs.cpu().detach()
                    unary_max = torch.max(unary)
                    unary_min = torch.min(unary)
                    unary = unary.add((-1*unary_min))
                    unary = unary/(unary_max - unary_min)

                    for channel in range(0, len(class_names)):
                        decoded_channel = unary[0][channel]
                        tb_channel = vutils.make_grid(decoded_channel, normalize=True, scale_each=True)
                        writer.add_image(f'test_classes/_{class_names[channel]}', tb_channel, i)

        # get scores and save in writer()
        score, class_iou = running_metrics_split.get_scores()

        # Add split results to TB:
        writer.add_text(f'test__{split}/',
                        f'Pixel Acc: {score["Pixel Acc: "]:.3f}', 0)
        for cdx, class_name in enumerate(class_names):
            writer.add_text(
                f'test__{split}/', f'  {class_name}_accuracy {score["Class Accuracy: "][cdx]:.3f}', 0)

        writer.add_text(
            f'test__{split}/', f'Mean Class Acc: {score["Mean Class Acc: "]:.3f}', 0)
        writer.add_text(
            f'test__{split}/', f'Freq Weighted IoU: {score["Freq Weighted IoU: "]:.3f}', 0)
        writer.add_text(f'test__{split}/',
                        f'Mean IoU: {score["Mean IoU: "]:0.3f}', 0)

        running_metrics_split.reset()

    # FINAL TEST RESULTS:
    score, class_iou = running_metrics_overall.get_scores()

    # Add split results to TB:
    writer.add_text('test_final', f'Pixel Acc: {score["Pixel Acc: "]:.3f}', 0)
    for cdx, class_name in enumerate(class_names):
        writer.add_text(
            'test_final', f'  {class_name}_accuracy {score["Class Accuracy: "][cdx]:.3f}', 0)

    writer.add_text(
        'test_final', f'Mean Class Acc: {score["Mean Class Acc: "]:.3f}', 0)
    writer.add_text(
        'test_final', f'Freq Weighted IoU: {score["Freq Weighted IoU: "]:.3f}', 0)
    writer.add_text('test_final', f'Mean IoU: {score["Mean IoU: "]:0.3f}', 0)

    print('--------------- FINAL RESULTS -----------------')
    print(f'Pixel Acc: {score["Pixel Acc: "]:.3f}')
    for cdx, class_name in enumerate(class_names):
        print(
            f'     {class_name}_accuracy {score["Class Accuracy: "][cdx]:.3f}')
    print(f'Mean Class Acc: {score["Mean Class Acc: "]:.3f}')
    print(f'Freq Weighted IoU: {score["Freq Weighted IoU: "]:.3f}')
    print(f'Mean IoU: {score["Mean IoU: "]:0.3f}')

    # Save confusion matrix: 
    confusion = score['confusion_matrix']
    np.savetxt(pjoin(log_dir,'confusion.csv'), confusion, delimiter=" ")

    writer.close()
    return
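
(The runningScore class used throughout these examples is not shown. Below is a minimal sketch consistent with how it is called: update() accumulates a confusion matrix over batches, and get_scores() returns a metrics dict, whose keys keep the trailing ': ' seen above, plus a per-class IoU dict. The get_classification_scores() variant used in Code Example #3 is not sketched here.)

import numpy as np

class runningScore(object):
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def _fast_hist(self, label_true, label_pred, n_class):
        # joint histogram of (true, predicted) labels for one sample
        mask = (label_true >= 0) & (label_true < n_class)
        return np.bincount(
            n_class * label_true[mask].astype(int) + label_pred[mask],
            minlength=n_class ** 2).reshape(n_class, n_class)

    def update(self, label_trues, label_preds):
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(
                lt.flatten(), lp.flatten(), self.n_classes)

    def get_scores(self):
        hist = self.confusion_matrix
        acc = np.diag(hist).sum() / hist.sum()          # pixel accuracy
        acc_cls = np.diag(hist) / hist.sum(axis=1)      # per-class accuracy
        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0)
                              - np.diag(hist))          # per-class IoU
        freq = hist.sum(axis=1) / hist.sum()
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        cls_iu = dict(zip(range(self.n_classes), iu))
        return ({'Pixel Acc: ': acc,
                 'Class Accuracy: ': acc_cls,
                 'Mean Class Acc: ': np.nanmean(acc_cls),
                 'Freq Weighted IoU: ': fwavacc,
                 'Mean IoU: ': np.nanmean(iu),
                 'confusion_matrix': self.confusion_matrix}, cls_iu)

    def reset(self):
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))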
Code Example #3
def train(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE: val_fname (the output directory) is assumed to be defined
    # outside this function.
    model_fname = pjoin(val_fname, 'best_checkpoint.pth')

    log_dir = os.path.join(val_fname, "Tensorboard_Records")
    writer = SummaryWriter(log_dir=log_dir)

    # Augmentation setup (unused in this variant):
    # if args.aug:
    #     data_aug = Compose(
    #         [RandomRotate(10), RandomHorizontallyFlip(), AddNoise()])
    # else:
    #     data_aug = None

    source_train_set = data_loader_netherlands.PatchLoader(is_transform=True,
                                                           split='train',
                                                           augmentations=None)

    # Without Augmentation:
    source_val_set = data_loader_netherlands.PatchLoader(is_transform=True,
                                                         split='valid')

    n_classes = source_train_set.n_classes

    source_trainloader = data.DataLoader(
        source_train_set,
        batch_size=args.netherlands_batch_size,
        num_workers=1,
        shuffle=True)

    source_valloader = data.DataLoader(source_val_set,
                                       batch_size=args.netherlands_batch_size,
                                       num_workers=1)

    # Setup Metrics
    running_metrics = runningScore(n_classes)
    source_running_metrics_val = runningScore(n_classes)

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            model = torch.load(args.resume)
        else:
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(args.resume))
    else:
        model = classification.SeismicNet_New(num_classes=n_classes)

    # Use as many GPUs as we can
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model = model.to(device)  # Send to GPU

    # PYTORCH NOTE: always construct optimizers after the model has been moved to the GPU/CPU.

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        optim_name = args.optim.lower()
        if optim_name == "sgd":
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.base_lr,
                                        weight_decay=0.0001,
                                        momentum=0.9)
        elif optim_name == "adam":
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.base_lr,
                                         weight_decay=0.0001,
                                         amsgrad=True)
        elif optim_name == "adadelta":
            optimizer = torch.optim.Adadelta(model.parameters(),
                                             lr=args.base_lr,
                                             rho=0.9,
                                             eps=1e-06,
                                             weight_decay=0.0001)
        else:
            raise ValueError(
                "Unknown optimizer! Choose from [sgd, adam, adadelta]")

    if args.train:

        loss_fn = F.cross_entropy
        model.train()
        if args.freeze_bn:
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False

        start_epoch = 0

    if args.class_weights:
        # weights are inversely proportional to the frequency of the classes in the training set
        class_weights = torch.tensor(
            [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852],
            device=device,
            requires_grad=False)
    else:
        class_weights = None

    best_mca = -100.0

    # NOTE: change these class names to match your dataset:
    class_names = [
        'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk', 'scruff',
        'zechstein'
    ]

    for arg in vars(args):
        text = arg + ': ' + str(getattr(args, arg))
        writer.add_text('Parameters/', text)

    for epoch in range(args.n_epoch):
        # Training Mode:
        model.train()
        classification_loss_train, total_iteration = 0, 0

        for i, (source_image, source_label) in enumerate(source_trainloader):

            source_image_original, source_labels_original = source_image, source_label
            source_image, source_label = source_image.to(
                device), source_label.to(device)

            optimizer.zero_grad()

            source_output = model(source_image)

            pred = source_output.detach().max(1)[1].cpu().numpy()
            gt = source_label.detach().cpu().numpy()

            running_metrics.update(gt, pred)

            classification_loss = loss_fn(input=source_output,
                                          target=torch.squeeze(source_label),
                                          weight=class_weights)
            classification_loss_train += classification_loss.item()
            classification_loss.backward()

            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()

            total_iteration = total_iteration + 1

            if i % 20 == 0:
                print(
                    'epoch: {0}/{1}\t\t iter: {2}/{3}\t\t Classification Loss:{4:.4f}'
                    .format(epoch + 1, args.n_epoch, i + 1,
                            len(source_trainloader),
                            classification_loss.item()))

        # Average metrics, and save in writer()
        classification_loss_train /= total_iteration

        score = running_metrics.get_classification_scores()
        writer.add_scalar('train/Mean Class Acc', score['Mean Class Acc: '],
                          epoch + 1)
        writer.add_scalar('train/Loss', classification_loss_train, epoch + 1)
        running_metrics.reset()

        if args.per_val != 0:
            with torch.no_grad():  # operations inside don't track history
                # Validation Mode:
                model.eval()

                source_classification_loss_val, total_iteration_source_val = 0, 0
                print()
                print("=" * 100)
                print()

                for i_val_source, (
                        source_image_val,
                        source_label_val) in enumerate(source_valloader):

                    source_image_val, source_label_val = source_image_val.to(
                        device), source_label_val.to(device)

                    source_outputs_val = model(source_image_val)

                    source_pred = source_outputs_val.detach().max(1)[1].cpu().numpy()
                    source_gt = source_label_val.detach().cpu().numpy()

                    source_running_metrics_val.update(source_gt, source_pred)

                    source_classification_loss = loss_fn(
                        input=source_outputs_val,
                        target=torch.squeeze(source_label_val),
                        weight=class_weights)
                    source_classification_loss_val += source_classification_loss.item()

                    total_iteration_source_val = total_iteration_source_val + 1

                    if i_val_source % 20 == 0:
                        print(
                            'epoch: {0}/{1}\t\t iter: {2}/{3}\t\t Validation Classification Loss(SOURCE): {4:.4f}'
                            .format(epoch + 1, args.n_epoch, i_val_source + 1,
                                    len(source_valloader),
                                    source_classification_loss.item()))

                source_classification_loss_val /= total_iteration_source_val

                source_score = source_running_metrics_val.get_classification_scores()

                save_classification_csv(source_score, val_fname, epoch,
                                        "Source", n_classes)

                writer.add_scalar('val_Source/Mean Class Acc',
                                  source_score['Mean Class Acc: '], epoch + 1)
                writer.add_scalar('val_Source/Loss',
                                  source_classification_loss_val, epoch + 1)
                source_running_metrics_val.reset()

                # Model saving based on source-validation mean class accuracy:
                if source_score['Mean Class Acc: '] >= best_mca:
                    best_mca = source_score['Mean Class Acc: ']
                    best_epoch = epoch + 1
                    torch.save(
                        {
                            'epoch': epoch + 1,
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        }, model_fname)

                print()
                print()
                print("Most Recent Checkpoint Saved at Epoch Number {0}."
                      .format(best_epoch))
                print()
                print("=" * 100)
                print()


    # The else-branch (validation off: save the latest state_dict checkpoint
    # every 5 epochs) was dropped from this variant as unnecessary.
    # writer.add_scalar('train/epoch_lr', optimizer.param_groups[0]["lr"], epoch+1)

    writer.close()
    print("Best Checkpoint Saved at epoch number {0}.".format(best_epoch))
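
(Code Example #3 saves checkpoints as a dict of 'epoch', 'state_dict' and 'optimizer', while its --resume branch loads a whole pickled model. A sketch of restoring the dict format it actually writes; since the model was wrapped in torch.nn.DataParallel before saving, the state_dict keys carry a 'module.' prefix, so wrap the model the same way before loading.)

import torch

def load_checkpoint(path, model, optimizer, device):
    # Restore a dict-style checkpoint as written by the training loop above.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch']  # epoch to resume training from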
Code Example #4
def train(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Generate the train and validation sets for the model:
    split_train_val(args, per_val=args.per_val)

    current_time = datetime.now().strftime('%b%d_%H%M%S')
    log_dir = os.path.join('runs', current_time +
                           "_{}".format(args.arch))
    writer = SummaryWriter(log_dir=log_dir)
    # Setup Augmentations
    if args.aug:
        data_aug = Compose(
            [RandomRotate(10), RandomHorizontallyFlip(), AddNoise()])
    else:
        data_aug = None

    train_set = PatchLoader(is_transform=True,
                            split='train',
                            stride=args.stride,
                            patch_size=args.patch_size,
                            augmentations=data_aug)

    # Without Augmentation:
    val_set = PatchLoader(is_transform=True,
                          split='val',
                          stride=args.stride,
                          patch_size=args.patch_size)

    n_classes = train_set.n_classes

    trainloader = data.DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  num_workers=1,
                                  shuffle=True)
    valloader = data.DataLoader(val_set,
                                batch_size=args.batch_size,
                                num_workers=1)

    # Setup Metrics
    running_metrics = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)
   

    # Setup Model (edited by Tannistha):
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            model = torch.load(args.resume)
        else:
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(args.resume))
    else:
        # edited by Tannistha: ResNet9 backbone in place of the original
        # deeplab resnet101 (which was built with pretrained=(not args.scratch)):
        model = getattr(ResNet9, 'resnet9')(
            pretrained=args.scratch,
            num_classes=n_classes,
            num_groups=args.groups,
            weight_std=args.weight_std,
            beta=args.beta)

    # Use as many GPUs as we can
    model = torch.nn.DataParallel(
        model, device_ids=range(torch.cuda.device_count()))
    model = model.to(device)  # Send to GPU

    # PYTORCH NOTE: always construct optimizers after the model has been moved to the GPU/CPU.

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        # optimizer = torch.optim.Adadelta(model.parameters())
        optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr,
                                    weight_decay=0.0001, momentum=0.9)
        # optimizer = torch.optim.Adam(model.parameters(), lr=args.base_lr,
        #                              weight_decay=0.0001, amsgrad=True)

    # edited by Tannistha to work with the new optimizer:
    if args.train:
        criterion = nn.CrossEntropyLoss(ignore_index=255)
        model.train()
        if args.freeze_bn:
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False
                    
        # NOTE: this re-creates the optimizer, overriding any custom or
        # architecture-specific optimizer selected above:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr,
                                    weight_decay=0.0001, momentum=0.9)
        # optimizer = torch.optim.Adam(model.parameters(), lr=args.base_lr,
        #                              weight_decay=0.0001, amsgrad=True)
        
        start_epoch = 0

    loss_fn = core.loss.cross_entropy

    if args.class_weights:
        # weights are inversely proportional to the frequency of the classes in the training set
        class_weights = torch.tensor(
            [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852], device=device, requires_grad=False)
    else:
        class_weights = None

    #best_iou = -100.0
    best_mca = -100.0
    class_names = ['upper_ns', 'middle_ns', 'lower_ns',
                   'rijnland_chalk', 'scruff', 'zechstein']
    
    for arg in vars(args):
        text = arg + ': ' + str(getattr(args, arg))
        writer.add_text('Parameters/', text)
        
    model_fname = 'data/deeplab_' + str(args.base_lr) + '_batch_size_' + str(args.batch_size) + '_' + args.exp + '_epoch_%d.pth'
    val_fname = 'val_lr_' + str(args.base_lr) + '_batch_size_' + str(args.batch_size) + '_' + args.exp
    
    for epoch in range(args.n_epoch):
        # Training Mode:
        model.train()
        loss_train, total_iteration = 0, 0

        for i, (images, labels) in enumerate(trainloader):
            
            image_original, labels_original = images, labels
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            
            pred = outputs.detach().max(1)[1].cpu().numpy()
            gt = labels.detach().cpu().numpy()
            running_metrics.update(gt, pred)

            loss = loss_fn(input=outputs, target=labels, weight=class_weights)

            loss_train += loss.item()
            optimizer.zero_grad()
            loss.backward()

            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            
            total_iteration = total_iteration + 1

            if i % 20 == 0:
                print('epoch: {0}/{1}\t\t'
                      'iter: {2}/{3}\t\t'
                      'training Loss:{4:.4f}'.format(
                          epoch + 1, args.n_epoch, i + 1,
                          len(trainloader), loss.item()))

            numbers = [0]
            if i in numbers:
                # number 0 image in the batch
                tb_original_image = vutils.make_grid(
                    image_original[0][0], normalize=True, scale_each=True)
                writer.add_image('train/original_image',
                                 tb_original_image, epoch + 1)

                labels_original = labels_original.numpy()[0]
                correct_label_decoded = train_set.decode_segmap(np.squeeze(labels_original))
                writer.add_image('train/original_label',np_to_tb(correct_label_decoded), epoch + 1)
                out = F.softmax(outputs, dim=1)

                # this returns the max. channel number:
                prediction = out.max(1)[1].cpu().numpy()[0]
                # this returns the confidence:
                confidence = out.max(1)[0].cpu().detach()[0]
                tb_confidence = vutils.make_grid(
                    confidence, normalize=True, scale_each=True)

                decoded = train_set.decode_segmap(np.squeeze(prediction))
                writer.add_image('train/predicted', np_to_tb(decoded), epoch + 1)
                writer.add_image('train/confidence', tb_confidence, epoch + 1)

                unary = outputs.cpu().detach()
                unary_max = torch.max(unary)
                unary_min = torch.min(unary)
                unary = unary.add((-1*unary_min))
                unary = unary/(unary_max - unary_min)

                for channel in range(0, len(class_names)):
                    decoded_channel = unary[0][channel]
                    tb_channel = vutils.make_grid(
                        decoded_channel, normalize=True, scale_each=True)
                    writer.add_image(f'train_classes/_{class_names[channel]}', tb_channel, epoch + 1)

        # Average metrics, and save in writer()
        loss_train /= total_iteration
        score, class_iou = running_metrics.get_scores()
        writer.add_scalar('train/Pixel Acc', score['Pixel Acc: '], epoch+1)
        writer.add_scalar('train/Mean Class Acc',
                          score['Mean Class Acc: '], epoch+1)
        writer.add_scalar('train/Freq Weighted IoU',
                          score['Freq Weighted IoU: '], epoch+1)
        writer.add_scalar('train/Mean_IoU', score['Mean IoU: '], epoch+1)
        running_metrics.reset()
        writer.add_scalar('train/loss', loss_train, epoch+1)
        
        if args.per_val != 0:
            with torch.no_grad():  # operations inside don't track history
                # Validation Mode:
                model.eval()
                loss_val, total_iteration_val = 0, 0

                for i_val, (images_val, labels_val) in enumerate(valloader):
                    image_original, labels_original = images_val, labels_val
                    images_val, labels_val = images_val.to(
                        device), labels_val.to(device)
                    #image_val = to_3_channels(images_val)
                    outputs_val = model(images_val)
                    #outputs_val = model(image_val)
                    pred = outputs_val.detach().max(1)[1].cpu().numpy()
                    gt = labels_val.detach().cpu().numpy()

                    running_metrics_val.update(gt, pred)

                    loss = loss_fn(input=outputs_val, target=labels_val)
                    
                    loss_val += loss.item()

                    total_iteration_val = total_iteration_val + 1

                    if i_val % 20 == 0:
                        print("Epoch [%d/%d] validation Loss: %.4f" %
                              (epoch+1, args.n_epoch, loss.item()))

                    numbers = [0]
                    if i_val in numbers:
                        # number 0 image in the batch
                        tb_original_image = vutils.make_grid(
                            image_original[0][0], normalize=True, scale_each=True)
                        writer.add_image('val/original_image',
                                         tb_original_image, epoch + 1)
                        labels_original = labels_original.numpy()[0]
                        correct_label_decoded = train_set.decode_segmap(
                            np.squeeze(labels_original))
                        writer.add_image('val/original_label',
                                         np_to_tb(correct_label_decoded), epoch + 1)

                        out = F.softmax(outputs_val, dim=1)

                        # this returns the max. channel number:
                        prediction = out.max(1)[1].cpu().detach().numpy()[0]
                        # this returns the confidence:
                        confidence = out.max(1)[0].cpu().detach()[0]
                        tb_confidence = vutils.make_grid(
                            confidence, normalize=True, scale_each=True)

                        decoded = train_set.decode_segmap(
                            np.squeeze(prediction))
                        writer.add_image('val/predicted', np_to_tb(decoded), epoch + 1)
                        writer.add_image('val/confidence',
                                         tb_confidence, epoch + 1)

                        unary = outputs_val.cpu().detach()
                        unary_max, unary_min = torch.max(
                            unary), torch.min(unary)
                        unary = unary.add((-1*unary_min))
                        unary = unary/(unary_max - unary_min)

                        for channel in range(0, len(class_names)):
                            tb_channel = vutils.make_grid(
                                unary[0][channel], normalize=True, scale_each=True)
                            writer.add_image(
                                f'val_classes/_{class_names[channel]}', tb_channel, epoch + 1)
                loss_val /= total_iteration_val
                score, class_iou = running_metrics_val.get_scores()
                
                # append this epoch's metrics to the csv traces (write the
                # header only on the first epoch):
                pd.DataFrame([score["Pixel Acc: "]]).to_csv(
                    os.path.join(val_fname, "metrics", "pixel_acc.csv"),
                    index=False, mode='a', header=(epoch == 0))
                pd.DataFrame([score["Mean Class Acc: "]]).to_csv(
                    os.path.join(val_fname, "metrics", "mean_class_acc.csv"),
                    index=False, mode='a', header=(epoch == 0))
                pd.DataFrame([score["Freq Weighted IoU: "]]).to_csv(
                    os.path.join(val_fname, "metrics", "freq_weighted_iou.csv"),
                    index=False, mode='a', header=(epoch == 0))
                pd.DataFrame([score["Mean IoU: "]]).to_csv(
                    os.path.join(val_fname, "metrics", "mean_iou.csv"),
                    index=False, mode='a', header=(epoch == 0))

                cname = os.path.join(val_fname, "metrics", "confusion_matrix",
                                     "confusion_matrix_" + str(epoch + 1) + ".csv")
                pd.DataFrame(score["confusion_matrix"]).to_csv(cname, index=False)

                pd.DataFrame(score["Class Accuracy: "].reshape((1, 6)),
                             columns=[0, 1, 2, 3, 4, 5]).to_csv(
                                 os.path.join(val_fname, "metrics", "class_acc.csv"),
                                 index=False, mode="a", header=(epoch == 0))
                pd.DataFrame(class_iou, columns=[0, 1, 2, 3, 4, 5],
                             index=[0]).to_csv(
                                 os.path.join(val_fname, "metrics", "cls_iu.csv"),
                                 mode="a", header=(epoch == 0))
                

                writer.add_scalar(
                    'val/Pixel Acc', score['Pixel Acc: '], epoch+1)
                writer.add_scalar('val/Mean IoU', score['Mean IoU: '], epoch+1)
                writer.add_scalar('val/Mean Class Acc',
                                  score['Mean Class Acc: '], epoch+1)
                writer.add_scalar('val/Freq Weighted IoU',
                                  score['Freq Weighted IoU: '], epoch+1)

                writer.add_scalar('val/loss', loss_val, epoch+1)
                running_metrics_val.reset()

                # save a checkpoint whenever mean class accuracy improves:
                if score['Mean Class Acc: '] >= best_mca:
                    best_mca = score['Mean Class Acc: ']
                    torch.save({'epoch': epoch + 1,
                                'state_dict': model.state_dict(),
                                'optimizer': optimizer.state_dict()},
                               model_fname % (epoch + 1))


        else:  # validation is turned off:
            # just save the latest model:
            if (epoch+1) % 5 == 0:
                torch.save({'epoch': epoch + 1,
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict()},
                           model_fname % (epoch + 1))
        
    writer.close()
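
(The hard-coded class weights above sum to almost exactly n_classes - 1, which is consistent with weights of the form 1 - class_frequency. A sketch of deriving them from a label volume; the exact procedure behind [0.7151, ...] is not shown in the original, so treat this as an assumption.)

import numpy as np

def inverse_frequency_weights(labels, n_classes=6):
    # fraction of pixels belonging to each class:
    counts = np.bincount(labels.flatten().astype(int), minlength=n_classes)
    freq = counts / counts.sum()
    # rare classes get weights near 1, frequent classes lower weights:
    return 1.0 - freq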
Code Example #5
def test(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    log_dir, model_name = os.path.split(args.model_path)
    # load model:
    model = torch.load(args.model_path)
    model = model.to(device)  # Send to GPU if available
    writer = SummaryWriter(log_dir=log_dir)

    class_names = [
        'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk', 'scruff',
        'zechstein'
    ]
    running_metrics_overall = runningScore(6)

    splits = ['test1', 'test2'] if 'both' in args.split else [args.split]
    for sdx, split in enumerate(splits):
        # define indices of the array
        labels = np.load(pjoin('data', 'test_once', split + '_labels.npy'))
        irange, xrange, depth = labels.shape

        if args.inline:
            i_list = list(range(irange))
            i_list = ['i_' + str(inline) for inline in i_list]
        else:
            i_list = []

        if args.crossline:
            x_list = list(range(xrange))
            x_list = ['x_' + str(crossline) for crossline in x_list]
        else:
            x_list = []

        list_test = i_list + x_list
        print(list_test)

        with open(pjoin('data', 'splits', 'section_' + split + '.txt'), 'w') as file_object:
            file_object.write('\n'.join(list_test))

        test_set = SectionLoader(is_transform=True,
                                 split=split,
                                 augmentations=None)
        n_classes = test_set.n_classes

        test_loader = data.DataLoader(test_set,
                                      batch_size=1,
                                      num_workers=4,
                                      shuffle=False)

        running_metrics_split = runningScore(n_classes)
        pred_class = []
        pixel_acc = []
        class_acc = []
        mean_class_acc = []
        fre_weighted_iou = []
        mean_iou = []
        confusion_mat = []
        cls_iu = []

        # testing mode:
        with torch.no_grad():  # operations inside don't track history
            model.eval()
            total_iteration = 0
            for i, (images, labels) in enumerate(test_loader):
                print(f'split: {split}, section: {i}')
                total_iteration = total_iteration + 1
                image_original, labels_original = images, labels

                outputs = patch_label_2d(model=model,
                                         img=images,
                                         patch_size=args.train_patch_size,
                                         stride=args.test_stride)

                pred = outputs.detach().max(1)[1].numpy()
                gt = labels.numpy()
                running_metrics_split.update(gt, pred)
                running_metrics_overall.update(gt, pred)
                pred_class.append(pred)
                # compute the running overall scores once per section and
                # append to the metric traces (copy the confusion matrix so
                # later updates don't mutate the saved snapshots):
                overall_score, overall_cls_iu = running_metrics_overall.get_scores()
                pixel_acc.append(overall_score["Pixel Acc: "])
                class_acc.append(overall_score["Class Accuracy: "])
                mean_class_acc.append(overall_score["Mean Class Acc: "])
                fre_weighted_iou.append(overall_score["Freq Weighted IoU: "])
                mean_iou.append(overall_score["Mean IoU: "])
                confusion_mat.append(overall_score["confusion_matrix"].copy())
                cls_iu.append(overall_cls_iu)

                #numbers = [0, 99, 149, 399, 499]
                numbers = range(100)
                if i in numbers:
                    tb_original_image = vutils.make_grid(image_original[0][0],
                                                         normalize=True,
                                                         scale_each=True)
                    writer.add_image('original_image', tb_original_image, i)
                    #torchvision.transforms.ToPILImage()(tb_original_image)
                    my_string = 'image_original_' + str(i)
                    original_image = tb_original_image.permute(1, 2, 0).numpy()
                    fig, ax = plt.subplots(figsize=(14, 8))
                    ax.imshow(original_image)
                    plt.savefig("test/original_image/{}.jpg".format(my_string))
                    plt.close()

                    labels_original = labels_original.numpy()[0]
                    correct_label_decoded = test_set.decode_segmap(
                        np.squeeze(labels_original))
                    print(correct_label_decoded.shape)
                    fig, ax1 = plt.subplots(figsize=(14, 8))
                    ax1.imshow(correct_label_decoded)
                    my_string1 = 'correct_label_' + str(i)
                    plt.savefig("test/original_label/{}.jpg".format(my_string1))
                    plt.close()

                    out = F.softmax(outputs, dim=1)
                    # this returns the max. channel number:
                    prediction = out.max(1)[1].cpu().numpy()[0]
                    # this returns the confidence:
                    confidence = out.max(1)[0].cpu().detach()[0]
                    tb_confidence = vutils.make_grid(confidence,
                                                     normalize=True,
                                                     scale_each=True)

                    decoded = test_set.decode_segmap(np.squeeze(prediction))
                    print(decoded.shape)
                    fig, ax2 = plt.subplots(figsize=(14, 8))
                    my_string2 = 'predicted_' + str(i)
                    ax2.imshow(decoded)
                    plt.savefig("test/predicted/{}.png".format(my_string2))
                    plt.close()

    pd.DataFrame(pixel_acc).to_csv("test/metrics/pixel_acc.csv", index=False)
    pd.DataFrame(class_acc).to_csv("test/metrics/class_acc.csv", index=False)
    pd.DataFrame(mean_class_acc).to_csv("test/metrics/mean_class_acc.csv",
                                        index=False)
    pd.DataFrame(fre_weighted_iou).to_csv("test/metrics/freq_weighted_iou.csv",
                                          index=False)
    pd.DataFrame(mean_iou).to_csv("test/metrics/mean_iou.csv", index=False)
    pd.DataFrame(cls_iu).to_csv("test/metrics/cls_iu.csv", index=False)

    for idx in range(len(confusion_mat)):
        name = ("test/metrics/confusion_matrix/confusion_matrix_"
                + str(idx) + ".csv")
        pd.DataFrame(confusion_mat[idx]).to_csv(name, index=False)

    # uncomment (inside the test loop above) if you want to visualize the
    # per-class heatmaps:
    # unary = outputs.cpu().detach()
    # unary_max = torch.max(unary)
    # unary_min = torch.min(unary)
    # unary = unary.add((-1*unary_min))
    # unary = unary/(unary_max - unary_min)

    # for channel in range(0, len(class_names)):
    #     decoded_channel = unary[0][channel]
    #     tb_channel = vutils.make_grid(decoded_channel, normalize=True, scale_each=True)
    #     writer.add_image(f'test_classes/_{class_names[channel]}', tb_channel, i)

    # get scores for the last processed split and save in writer():
    score, class_iou = running_metrics_split.get_scores()

    # Add split results to TB:
    writer.add_text(f'test__{split}/',
                    f'Pixel Acc: {score["Pixel Acc: "]:.3f}', 0)
    for cdx, class_name in enumerate(class_names):
        writer.add_text(
            f'test__{split}/',
            f'  {class_name}_accuracy {score["Class Accuracy: "][cdx]:.3f}',
            0)

    writer.add_text(f'test__{split}/',
                    f'Mean Class Acc: {score["Mean Class Acc: "]:.3f}', 0)
    writer.add_text(
        f'test__{split}/',
        f'Freq Weighted IoU: {score["Freq Weighted IoU: "]:.3f}', 0)
    writer.add_text(f'test__{split}/',
                    f'Mean IoU: {score["Mean IoU: "]:0.3f}', 0)

    running_metrics_split.reset()

    # FINAL TEST RESULTS:
    score, class_iou = running_metrics_overall.get_scores()

    # Add split results to TB:
    writer.add_text('test_final', f'Pixel Acc: {score["Pixel Acc: "]:.3f}', 0)
    for cdx, class_name in enumerate(class_names):
        writer.add_text(
            'test_final',
            f'  {class_name}_accuracy {score["Class Accuracy: "][cdx]:.3f}', 0)

    writer.add_text('test_final',
                    f'Mean Class Acc: {score["Mean Class Acc: "]:.3f}', 0)
    writer.add_text('test_final',
                    f'Freq Weighted IoU: {score["Freq Weighted IoU: "]:.3f}',
                    0)
    writer.add_text('test_final', f'Mean IoU: {score["Mean IoU: "]:0.3f}', 0)

    print('--------------- FINAL RESULTS -----------------')
    print(f'Pixel Acc: {score["Pixel Acc: "]:.3f}')
    for cdx, class_name in enumerate(class_names):
        print(
            f'     {class_name}_accuracy {score["Class Accuracy: "][cdx]:.3f}')
    print(f'Mean Class Acc: {score["Mean Class Acc: "]:.3f}')
    print(f'Freq Weighted IoU: {score["Freq Weighted IoU: "]:.3f}')
    print(f'Mean IoU: {score["Mean IoU: "]:0.3f}')

    # Save confusion matrix:
    confusion = score['confusion_matrix']
    np.savetxt(pjoin(log_dir, 'confusion.csv'), confusion, delimiter=" ")

    writer.close()
    return
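
(patch_label_2d(), used by Code Examples #5 and #6 but not shown, labels a full section by sliding-window inference. A minimal sketch under assumptions: the section divides evenly into the stride grid, overlapping patch logits are averaged, and no edge padding is done; the extra `testing` argument passed in Code Example #6 is omitted.)

import torch

def patch_label_2d(model, img, patch_size, stride):
    # img: one section as a (1, 1, H, W) float tensor on the CPU.
    model.eval()
    device = next(model.parameters()).device
    _, _, h, w = img.shape
    output, counts = None, None
    with torch.no_grad():
        for y in range(0, h - patch_size + 1, stride):
            for x in range(0, w - patch_size + 1, stride):
                patch = img[:, :, y:y + patch_size, x:x + patch_size]
                logits = model(patch.to(device)).cpu()
                if output is None:
                    # allocate score volumes once the class count is known:
                    output = torch.zeros(1, logits.shape[1], h, w)
                    counts = torch.zeros(1, 1, h, w)
                output[:, :, y:y + patch_size, x:x + patch_size] += logits
                counts[:, :, y:y + patch_size, x:x + patch_size] += 1
    # average the logits where patches overlapped:
    return output / counts.clamp(min=1)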
Code Example #6
def test(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # select the per-section labeling routine; it is forwarded to
    # patch_label_2d() below as its `testing` argument:
    if args.tta:
        testing_code = test_time_augmentation
    else:
        testing_code = test

    log_dir, model_name = os.path.split(args.model_path)
    # load model:
    model = torch.load(args.model_path)
    model = model.to(device)  # Send to GPU if available
    writer = SummaryWriter(log_dir=log_dir)

    class_names = [
        'null', 'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk',
        'scruff', 'zechstein'
    ]
    running_metrics_overall = runningScore(7)

    splits = ['test1', 'test2'] if 'both' in args.split else [args.split]

    for sdx, split in enumerate(splits):
        # define indices of the array
        labels = np.load(pjoin('data', 'test_once', split + '_labels.npy'))
        irange, xrange, depth = labels.shape

        if split == 'test1':
            result_volume = torch.zeros([7, 200, 701, 255],
                                        dtype=torch.float32,
                                        device='cpu',
                                        requires_grad=False)
        elif split == 'test2':
            result_volume = torch.zeros([7, 601, 200, 255],
                                        dtype=torch.float32,
                                        device='cpu',
                                        requires_grad=False)

        if args.inline:
            i_list = list(range(irange))
            i_list = ['i_' + str(inline) for inline in i_list]
        else:
            i_list = []

        if args.crossline:
            x_list = list(range(xrange))
            x_list = ['x_' + str(crossline) for crossline in x_list]
        else:
            x_list = []

        num_iline = len(i_list)
        num_xline = len(x_list)

        list_test = i_list + x_list

        with open(pjoin('data', 'splits', 'section_' + split + '.txt'), 'w') as file_object:
            file_object.write('\n'.join(list_test))

        test_set = section_loader(is_transform=True,
                                  split=split,
                                  augmentations=None)
        n_classes = test_set.n_classes

        test_loader = data.DataLoader(test_set,
                                      batch_size=1,
                                      num_workers=8,
                                      shuffle=False)

        running_metrics_split = runningScore(n_classes)

        # testing mode:
        with torch.no_grad():  # operations inside don't track history
            model.eval()
            total_iteration = 0
            for i, (imgs, lbls) in enumerate(test_loader):
                print(f'split: {split}, section: {i}')
                total_iteration = total_iteration + 1

                # get sections labeled (a 7-channel output is produced for each section)
                outputs = patch_label_2d(model=model,
                                         img=imgs,
                                         patch_size=args.train_patch_size,
                                         stride=args.test_stride,
                                         testing=testing_code)

                # detach, send to cpu and add to corresponding location:
                update = outputs.detach().cpu()
                # this is the tricky part: determining where in the volume to add it
                if split == 'test1':
                    if i < num_iline and args.inline:  # inline -- split 1
                        assert update.shape == torch.Size(
                            [1, 7, 255, 701]), 'Hmm, something is wrong.'
                        result_volume[:, i, :, :] += update.squeeze().permute(
                            (0, 2, 1))
                    elif i < num_iline + num_xline and args.crossline:  # crossline -- split 1
                        assert update.shape == torch.Size(
                            [1, 7, 255, 200]), 'Hmm, something is wrong.'
                        result_volume[:, :,
                                      i - num_iline, :] += update.squeeze(
                                      ).permute((0, 2, 1))
                    else:  # ???
                        raise ValueError(
                            'Something is wrong with the value of i')
                elif split == 'test2':
                    if i < num_iline and args.inline:  # inline -- split 2
                        assert update.shape == torch.Size(
                            [1, 7, 255, 200]), 'Hmm, something is wrong.'
                        result_volume[:, i, :, :] += update.squeeze().permute(
                            (0, 2, 1))
                    elif i < num_iline + num_xline and args.crossline:  # crossline -- split 2
                        assert update.shape == torch.Size(
                            [1, 7, 255, 601]), 'Hmm, something is wrong.'
                        result_volume[:, :,
                                      i - num_iline, :] += update.squeeze(
                                      ).permute((0, 2, 1))
                    else:  # ???
                        raise ValueError(
                            'Something is wrong with the value of i')

                if i == num_iline - 1:  # last iteration in inline:
                    np.save(pjoin(log_dir, f'volume_split_{split}_iline.npy'),
                            result_volume)
                elif i == num_iline + num_xline - 1:  # last iteration in crossline:
                    np.save(pjoin(log_dir, f'volume_split_{split}_xline.npy'),
                            result_volume)

            # FLATTEN THE VOLUMES (GT AND PRED) AND COMPUTE THE METRICS:
            final_volume = result_volume.max(0)[1].numpy()
            pred = final_volume.flatten()
            gt = labels.flatten() + 1  # make 1-indexed like pred
            running_metrics_split.update(gt, pred)
            running_metrics_overall.update(gt, pred)

            # SAVE THE RESULTING LABELS AS NP NDARRAY: (TO VISUALIZE RESULT LATER)
            np.save(pjoin(log_dir, f'final_volume_split_{split}.npy'),
                    final_volume)

        # ------------------------- Split sdx is done -------------------------

        # get scores and save in writer()
        score, class_iou = running_metrics_split.get_scores()

        # Add split results to TB:
        writer.add_text(f'test__{split}/',
                        f'Pixel Acc: {score["Pixel Acc: "]:.3f}', 0)
        for cdx, class_name in enumerate(class_names[1:]):
            writer.add_text(
                f'test__{split}/',
                f'  {class_name}_accuracy {score["Class Accuracy: "][cdx]:.3f}',
                0)

        writer.add_text(f'test__{split}/',
                        f'Mean Class Acc: {score["Mean Class Acc: "]:.3f}', 0)
        writer.add_text(
            f'test__{split}/',
            f'Freq Weighted IoU: {score["Freq Weighted IoU: "]:.3f}', 0)
        writer.add_text(f'test__{split}/',
                        f'Mean IoU: {score["Mean IoU: "]:0.3f}', 0)
        running_metrics_split.reset()

    # FINAL TEST RESULTS:
    score, class_iou = running_metrics_overall.get_scores()

    # Add split results to TB:
    writer.add_text('test_final', f'Pixel Acc: {score["Pixel Acc: "]:.3f}', 0)
    for cdx, class_name in enumerate(class_names[1:]):
        writer.add_text(
            'test_final',
            f'  {class_name}_accuracy {score["Class Accuracy: "][cdx]:.3f}', 0)

    writer.add_text('test_final',
                    f'Mean Class Acc: {score["Mean Class Acc: "]:.3f}', 0)
    writer.add_text('test_final',
                    f'Freq Weighted IoU: {score["Freq Weighted IoU: "]:.3f}',
                    0)
    writer.add_text('test_final', f'Mean IoU: {score["Mean IoU: "]:0.3f}', 0)
    writer.close()

    print('--------------- FINAL RESULTS -----------------')
    print(f'Pixel Acc: {score["Pixel Acc: "]:.3f}')
    for cdx, class_name in enumerate(class_names[1:]):
        print(
            f'     {class_name}_accuracy {score["Class Accuracy: "][cdx]:.3f}')
    print(f'Mean Class Acc: {score["Mean Class Acc: "]:.3f}')
    print(f'Freq Weighted IoU: {score["Freq Weighted IoU: "]:.3f}')
    print(f'Mean IoU: {score["Mean IoU: "]:0.3f}')

    confusion = score['confusion_matrix']
    np.savetxt(pjoin(log_dir, 'confusion.csv'), confusion, delimiter=" ")

    proper_class_names = [
        'Upper N.S.', 'Middle N.S.', 'Lower N.S.', 'Rijnland/Chalk', 'Scruff',
        'Zechstein'
    ]
    # normalize the confusion matrix row-wise:
    confusion = confusion.astype('float') / confusion.sum(axis=1)[:, np.newaxis]
    df_cm = pd.DataFrame(confusion,
                         index=proper_class_names,
                         columns=proper_class_names)
    plt.figure(figsize=(10, 7))
    ax = sn.heatmap(df_cm, annot=True)
    fig = ax.get_figure()
    fig.savefig(pjoin(log_dir, 'confusion.png'), dpi=300)

    return
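
`patch_label_2d`, used in the example above, labels a full seismic section by sliding a patch window across it and accumulating per-patch class scores. It is not defined in this listing, so the following is only a minimal sketch of that idea: the function and parameter names match the call site, but the internals (including the fixed class count and the lack of border handling) are assumptions.

import torch
import torch.nn.functional as F

def patch_label_2d(model, img, patch_size, stride, testing=None):
    """img: [1, 1, H, W] section; returns accumulated class scores [1, K, H, W]."""
    _, _, h, w = img.shape
    num_classes = 7  # assumed, matching the result volumes above
    output = torch.zeros(1, num_classes, h, w)
    for top in range(0, h - patch_size + 1, stride):
        for left in range(0, w - patch_size + 1, stride):
            patch = img[:, :, top:top + patch_size, left:left + patch_size]
            if testing is not None:       # e.g. a test-time-augmentation hook
                scores = testing(model, patch)
            else:
                scores = F.softmax(model(patch), dim=1)
            output[:, :, top:top + patch_size, left:left + patch_size] += scores
    return output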
def test(args):
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")  # select the device

    log_dir, model_name = os.path.split(
        args.model_path)  # split the path into directory and model name
    # load model:
    model = torch.load(args.model_path)  #load model
    model = model.to(device)  # Send to GPU if available
    writer = SummaryWriter(log_dir=log_dir)  #open summary writer

    class_names = [
        'null', 'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk',
        'scruff', 'zechstein'
    ]  #class names
    running_metrics_overall = runningScore(6)  #ToDo

    splits = ['test1', 'test2'] if 'both' in args.split else [
        args.split
    ]  # check if both tests are required
    for sdx, split in enumerate(
            splits
    ):  # sdx: test index; split: test name (loop over the tests)
        # define indices of the array
        labels = np.load(
            pjoin('data', 'test_once',
                  split + '_labels.npy'))  #load labels of the required test
        irange, xrange, depth = labels.shape  #get the number of test images in that test

        if args.inline:  # if inline sections are required
            i_list = list(range(irange))  # create a list of inline indices
            i_list = ['i_' + str(inline)
                      for inline in i_list]  # create a list of inline names
        else:
            i_list = []  # otherwise use an empty list

        if args.crossline:  # if crossline sections are required
            x_list = list(range(xrange))  # create a list of crossline indices
            x_list = ['x_' + str(crossline) for crossline in x_list
                      ]  # create a list of crossline names
        else:
            x_list = []

        list_test = i_list + x_list  # combine the inline and crossline names

        file_object = open(
            pjoin('data', 'splits', 'section_' + split + '.txt'),
            'w')  # open a text file named after the test split in write mode
        file_object.write('\n'.join(list_test))  # write the list to the file
        file_object.close()  # close the file

        test_set = section_loader(
            is_transform=True, split=split,
            augmentations=None)  # call the custom data loader
        n_classes = test_set.n_classes  # set the number of classes
        test_loader = data.DataLoader(
            test_set, batch_size=1, num_workers=4, shuffle=False
        )  # batch size equals 1: each section is a batch by itself; pass the
        # custom data loader with its overridden methods to the pytorch DataLoader

        # metrics for this split:
        running_metrics_split = runningScore(n_classes)

        # testing mode: start testing
        with torch.no_grad():  # stop gradient tracking
            model.eval()  # switch to evaluation mode
            total_iteration = 0  # iteration counter
            for i, (images, labels
                    ) in enumerate(test_loader):  # load batches one by one
                #print(labels.shape)
                print(f'split: {split}, section: {i}'
                      )  # print the test name and section number
                total_iteration = total_iteration + 1  # increment the total number of iterations
                image_original, labels_original = images, labels  # keep copies of the images and labels
                images, labels = images.to(device), labels.to(
                    device)  # move images and labels to the device
                outputs = model(images)  # feed the image forward through the model
                pred = outputs.detach().max(
                    1)[1].cpu().numpy()  # get the predicted class
                #print('iteration',i,'images',images.shape,'labels',labels,'prediction', pred)
                gt = labels.detach().cpu().numpy(
                )  # get the ground-truth labels

                running_metrics_split.update(
                    gt, pred
                )  # update the per-split metrics (initialized with n_classes)
                running_metrics_overall.update(
                    gt, pred
                )  # update the overall metrics (initialized above with 6 classes)

                numbers = [0, 99, 149, 399,
                           499]  #images to consider during testing
                if i in numbers:
                    tb_original_image = vutils.make_grid(
                        image_original[0][0], normalize=True,
                        scale_each=True)  #convert tensor to image
                    writer.add_image('test/original_image', tb_original_image,
                                     i)  #send image to writer

                    labels_original = labels_original.numpy()[
                        0]  #get the ground truth labels
                    correct_label_decoded = test_set.decode_segmap(
                        np.squeeze(labels_original)
                    )  #get the color map of the ground truth
                    writer.add_image('test/original_label',
                                     np_to_tb(correct_label_decoded),
                                     i)  #send the color map to writer
                    out = F.softmax(outputs, dim=1)  # softmax over the network output

                    # this returns the max. channel number:
                    prediction = out.max(1)[1].cpu().numpy()[
                        0]  # get the predictions
                    # this returns the confidence:
                    confidence = out.max(1)[0].cpu().detach()[
                        0]  # get the confidence
                    tb_confidence = vutils.make_grid(
                        confidence, normalize=True, scale_each=True
                    )  # convert the confidence tensor to an image

                    decoded = test_set.decode_segmap(np.squeeze(
                        prediction))  #get the colour map of the prediction
                    writer.add_image('test/predicted', np_to_tb(decoded),
                                     i)  #send colour map to writer
                    writer.add_image('test/confidence', tb_confidence,
                                     i)  #send confidence to writer

                    # uncomment if you want to visualize the different class heatmaps
                    # unary = outputs.cpu().detach()
                    # unary_max = torch.max(unary)
                    # unary_min = torch.min(unary)
                    # unary = unary.add((-1*unary_min))
                    # unary = unary/(unary_max - unary_min)

                    # for channel in range(0, len(class_names)):
                    #     decoded_channel = unary[0][channel]
                    #     tb_channel = vutils.make_grid(decoded_channel, normalize=True, scale_each=True)
                    #     writer.add_image(f'test_classes/_{class_names[channel]}', tb_channel, i)

        # after finishing one test split, get its scores and save them in the writer
        score, class_iou = running_metrics_split.get_scores()
        #print('score',score)
        # Add split results to TB:
        writer.add_text(f'test__{split}/',
                        f'Pixel Acc: {score["Pixel Acc: "]:.3f}',
                        0)  #send scores to writer
        for cdx, class_name in enumerate(class_names[1:]):
            #print('index',cdx,'class name',class_name)
            writer.add_text(
                f'test__{split}/',
                f'  {class_name}_accuracy {score["Class Accuracy: "][cdx]:.3f}',
                0)  #send individual  class scores to the writer

        writer.add_text(f'test__{split}/',
                        f'Mean Class Acc: {score["Mean Class Acc: "]:.3f}',
                        0)  #send averages to the writer
        writer.add_text(
            f'test__{split}/',
            f'Freq Weighted IoU: {score["Freq Weighted IoU: "]:.3f}', 0)
        writer.add_text(f'test__{split}/',
                        f'Mean IoU: {score["Mean IoU: "]:0.3f}', 0)
        confusion = score['confusion_matrix']
        writer.add_image(f'test/confusion matrix', np_to_tb(confusion), 0)

        running_metrics_split.reset()  #clear the confusion matrix
    #after finishing both tests
    # FINAL TEST RESULTS:
    score, class_iou = running_metrics_overall.get_scores(
    )  # get scores of both tests

    # Add split results to TB:
    writer.add_text('test_final', f'Pixel Acc: {score["Pixel Acc: "]:.3f}',
                    0)  #send results to writer
    for cdx, class_name in enumerate(class_names[1:]):
        writer.add_text(
            'test_final',
            f'  {class_name}_accuracy {score["Class Accuracy: "][cdx]:.3f}',
            0)  #send individual class scores to writer

    writer.add_text('test_final',
                    f'Mean Class Acc: {score["Mean Class Acc: "]:.3f}',
                    0)  #send average test results to the writer
    writer.add_text('test_final',
                    f'Freq Weighted IoU: {score["Freq Weighted IoU: "]:.3f}',
                    0)
    writer.add_text('test_final', f'Mean IoU: {score["Mean IoU: "]:0.3f}', 0)
    confusion = score['confusion_matrix']
    writer.add_image(f'test/FINAL confusion matrix', np_to_tb(confusion), 0)

    print(
        '--------------- FINAL RESULTS -----------------')  #print the results
    print(f'Pixel Acc: {score["Pixel Acc: "]:.3f}')
    for cdx, class_name in enumerate(class_names[1:]):
        print(
            f'     {class_name}_accuracy {score["Class Accuracy: "][cdx]:.3f}')
    print(f'Mean Class Acc: {score["Mean Class Acc: "]:.3f}')
    print(f'Freq Weighted IoU: {score["Freq Weighted IoU: "]:.3f}')
    print(f'Mean IoU: {score["Mean IoU: "]:0.3f}')

    confusion = score['confusion_matrix']
    np.savetxt(pjoin(log_dir, 'confusion.csv'), confusion,
               delimiter=" ")  # save the confusion matrix as text

    writer.close()  # close the writer
    return
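
`decode_segmap`, called throughout these examples, converts an integer class map into an RGB image for TensorBoard. It lives in the dataset loader and is not shown here, so the sketch below is only illustrative: the color table is hypothetical, and only the index-to-color mapping reflects how such a decoder works.

import numpy as np

# hypothetical colors, one RGB triple per class (the real table lives in the loader):
LABEL_COLORS = np.asarray([
    [0, 0, 0], [69, 117, 180], [145, 191, 219], [224, 243, 248],
    [254, 224, 144], [252, 141, 89], [215, 48, 39],
], dtype=np.float32) / 255.0

def decode_segmap(label_mask: np.ndarray) -> np.ndarray:
    """Map an (H, W) array of class indices to an (H, W, 3) RGB image."""
    return LABEL_COLORS[label_mask]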
Code Example #8
def train(args):
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")  #Selects Torch Device
    split_train_val(
        args, per_val=args.per_val
    )  #Generate the train and validation sets for the model as text files:

    current_time = datetime.now().strftime(
        '%b%d_%H%M%S')  #Gets Current Time and Date
    log_dir = os.path.join(
        'runs', current_time +
        f"_{args.arch}_{args.model_name}")  # create the log directory
    writer = SummaryWriter(
        log_dir=log_dir)  #Initialize the tensorboard summary writer

    # Setup Augmentations
    if args.aug:  #if augmentation is true
        data_aug = Compose(
            [RandomRotate(10),
             RandomHorizontallyFlip(),
             AddNoise()])  #compose some augmentation functions
    else:
        data_aug = None

    loader = section_loader  #name the loader
    train_set = loader(
        is_transform=True, split='train', augmentations=data_aug
    )  #use custom data loader to get the training set (instance of the loader class)
    val_set = loader(
        is_transform=True,
        split='val')  # use the custom data loader to get the validation set

    n_classes = train_set.n_classes  # initialize the number of classes, which is hard-coded in the dataloader

    # Create sampler:

    shuffle = False  # must turn False if using a custom sampler
    with open(pjoin('data', 'splits', 'section_train.txt'), 'r') as f:
        train_list = f.read().splitlines(
        )  #load the section train list previously stored in a text file created by split_train_val() function
    with open(pjoin('data', 'splits', 'section_val.txt'), 'r') as f:
        val_list = f.read().splitlines(
        )  # load the section validation list stored in a text file created by the split_train_val() function

    class CustomSamplerTrain(torch.utils.data.Sampler
                             ):  #create a custom sampler
        def __iter__(self):
            char = ['i' if np.random.randint(2) == 1 else 'x'
                    ]  #choose randomly between letter i and letter x
            self.indices = [
                idx for (idx, name) in enumerate(train_list) if char[0] in name
            ]  # select either all inline or all crossline indices from the training list created by split_train_val()
            return (self.indices[i] for i in torch.randperm(len(self.indices))
                    )  # shuffle the indices and return them

    class CustomSamplerVal(torch.utils.data.Sampler):
        def __iter__(self):
            char = ['i' if np.random.randint(2) == 1 else 'x'
                    ]  #choose randomly between letter i and letter x
            self.indices = [
                idx for (idx, name) in enumerate(val_list) if char[0] in name
            ]  # select either all inline or all crossline indices from the validation list created by split_train_val()
            return (self.indices[i] for i in torch.randperm(len(self.indices))
                    )  # shuffle the indices and return them

    trainloader = data.DataLoader(
        train_set, batch_size=args.batch_size, num_workers=12, shuffle=True
    )  # batch the training set; note the custom samplers defined above are
    # not passed here, so with shuffle=True the default sampler is used
    valloader = data.DataLoader(
        val_set, batch_size=args.batch_size, num_workers=12
    )  # use the pytorch data loader to batch the validation set

    # Setup Metrics
    running_metrics = runningScore(
        n_classes
    )  #initialize class instance for evaluation metrics for training
    running_metrics_val = runningScore(
        n_classes
    )  # initialize a class instance for validation evaluation metrics

    # Setup Model
    if args.resume is not None:  #Check if we have a stored model or not
        if os.path.isfile(args.resume):  #if yes then load the stored model
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            model = torch.load(args.resume)
        else:
            print("No checkpoint found at '{}'".format(
                args.resume))  #if stored model requested with invalid path
    else:  # if no stored model is requested, build the requested architecture
        #n_classes=64
        model = get_model(name=args.arch,
                          pretrained=args.pretrained,
                          batch_size=args.batch_size,
                          growth_rate=32,
                          drop_rate=0,
                          n_classes=n_classes)  # get a fresh model

    model = torch.nn.DataParallel(
        model, device_ids=range(
            torch.cuda.device_count()))  #Use as many GPUs as we can
    model = model.to(device)  # Send to GPU

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
            amsgrad=True,
            weight_decay=args.weight_decay,
            eps=args.eps
        )  # if no custom optimizer is specified, use the default one

    loss_fn = core.loss.focal_loss2d  # initialize the loss function

    if args.class_weights:  # if class weights are to be used, initialize them
        # weights are inversely proportional to the frequency of the classes in the training set
        class_weights = torch.tensor(
            [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852],
            device=device,
            requires_grad=False)
    else:
        class_weights = None  #if no class weights then no need to use them

    best_iou = -100.0
    class_names = [
        'null', 'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk',
        'scruff', 'zechstein'
    ]  #initialize the name of different classes

    for arg in vars(
            args
    ):  # before training starts, write a summary of the parameters
        text = arg + ': ' + str(getattr(
            args, arg))  # format the attribute name and value as a string
        writer.add_text('Parameters/', text)  # store the whole string

    # training
    for epoch in range(args.n_epoch):  #for loop on the number of epochs
        # Training Mode:
        model.train()  # switch to training mode
        loss_train, total_iteration = 0, 0  # initialize the training loss and total number of iterations

        for i, (images, labels) in enumerate(
                trainloader
        ):  # start the epoch; i is the batch number within the epoch
            image_original, labels_original = images, labels  # keep the image and label batch in new variables
            images, labels = images.to(device), labels.to(
                device)  # move images and labels to the GPU

            optimizer.zero_grad()  # clear the accumulated gradients
            outputs = model(
                images
            )  # feed the images forward through the model (outputs is a 7-channel o/p)

            pred = outputs.detach().max(1)[1].cpu().numpy(
            )  #get the model o/p from GPU, select the index of the maximum channel and send it back to CPU
            gt = labels.detach().cpu().numpy(
            )  # get the true labels from the GPU and send them to the CPU
            running_metrics.update(
                gt, pred
            )  # update the metrics with the ground truth and the predicted classes

            loss = loss_fn(input=outputs,
                           target=labels,
                           gamma=args.gamma,
                           loss_type=args.loss_parameters
                           )  # call the loss function to calculate the loss
            loss_train += loss.item()  #gets the scalar value held in the loss.
            loss.backward(
            )  # Use autograd to compute the backward pass. This call will compute the gradient of loss with respect to all Tensors with requires_grad=True.

            # gradient clipping
            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip
                )  # the norm is computed over all gradients together, as if
                # concatenated into a single vector; gradients are modified in-place

            optimizer.step(
            )  #step the optimizer (update the model weights with the new gradients)
            total_iteration = total_iteration + 1  #increment the total number of iterations by 1

            if i % 20 == 0:  # every 20 iterations
                print(
                    "Epoch [%d/%d] training Loss: %.4f" %
                    (epoch + 1, args.n_epoch, loss.item())
                )  # print the current epoch, total number of epochs, and the current training loss

            numbers = [0, 14, 29, 49, 99]  # batch indices to visualize
            if i in numbers:  # if the current batch number is in numbers
                # number 0 image in the batch
                tb_original_image = vutils.make_grid(
                    image_original[0][0], normalize=True, scale_each=True
                )  # take the first image in the batch and build a tensorboard grid from the tensor
                writer.add_image('train/original_image', tb_original_image,
                                 epoch + 1)  # send the image to the writer

                labels_original = labels_original.numpy(
                )[0]  # convert the ground-truth labels of the first image in the batch to a numpy array
                correct_label_decoded = train_set.decode_segmap(
                    np.squeeze(labels_original)
                )  #Decode segmentation class labels into a color image
                writer.add_image('train/original_label',
                                 np_to_tb(correct_label_decoded),
                                 epoch + 1)  #send the image to the writer
                out = F.softmax(outputs, dim=1)  #softmax of the network o/p
                prediction = out.max(1)[1].cpu().numpy()[
                    0]  #get the index of the maximum value after softmax
                confidence = out.max(1)[0].cpu().detach()[
                    0]  # this returns the confidence in the chosen class

                tb_confidence = vutils.make_grid(
                    confidence, normalize=True, scale_each=True
                )  #convert the confidence from tensor to image

                decoded = train_set.decode_segmap(np.squeeze(
                    prediction))  #Decode predicted classes to colours
                writer.add_image(
                    'train/predicted', np_to_tb(decoded), epoch + 1
                )  #send predicted map to writer along with the epoch number
                writer.add_image(
                    'train/confidence', tb_confidence, epoch + 1
                )  #send the confidence to writer along with the epoch number

                unary = outputs.cpu().detach(
                )  # get the network output for the whole batch
                unary_max = torch.max(unary)
                unary_min = torch.min(unary)
                unary = unary.add((-1 * unary_min))
                unary = unary / (unary_max - unary_min)  # min-max normalize over the whole batch

                for channel in range(0, len(class_names)):
                    decoded_channel = unary[0][
                        channel]  #get the normalized o/p for the first image in the batch
                    tb_channel = vutils.make_grid(
                        decoded_channel, normalize=True,
                        scale_each=True)  #prepare a image from tensor
                    writer.add_image(f'train_classes/_{class_names[channel]}',
                                     tb_channel,
                                     epoch + 1)  #send image to writer

        # Average metrics after finishing all batches for the whole epoch, and save in writer()
        loss_train /= total_iteration  #total loss for all iterations/ number of iterations
        score, class_iou = running_metrics.get_scores(
        )  # returns a dictionary of the calculated accuracy metrics and the per-class IoU
        writer.add_scalar(
            'train/Pixel Acc', score['Pixel Acc: '],
            epoch + 1)  # store the epoch metrics in the tensorboard writer
        writer.add_scalar('train/Mean Class Acc', score['Mean Class Acc: '],
                          epoch + 1)
        writer.add_scalar('train/Freq Weighted IoU',
                          score['Freq Weighted IoU: '], epoch + 1)
        writer.add_scalar('train/Mean_IoU', score['Mean IoU: '], epoch + 1)
        confusion = score['confusion_matrix']
        writer.add_image(f'train/confusion matrix', np_to_tb(confusion),
                         epoch + 1)

        running_metrics.reset()  #resets the confusion matrix
        writer.add_scalar('train/loss', loss_train,
                          epoch + 1)  #store the training loss
        # finished one epoch of training; starting one round of validation
        if args.per_val != 0:  # if validation is required
            with torch.no_grad():  # operations inside don't track history
                # Validation Mode:
                model.eval()  #start validation mode
                loss_val, total_iteration_val = 0, 0  # initialize validation loss and total number of iterations

                for i_val, (images_val, labels_val) in tqdm(
                        enumerate(valloader)):  # iterate over the validation batches
                    image_original, labels_original = images_val, labels_val  # keep the original validation images and labels
                    images_val, labels_val = images_val.to(
                        device), labels_val.to(
                            device)  # send validation images and labels to the GPU

                    outputs_val = model(images_val)  #feedforward the image
                    pred = outputs_val.detach().max(
                        1)[1].cpu().numpy()  #get the network class prediction
                    gt = labels_val.detach().cpu().numpy(
                    )  #get the ground truth from the GPU

                    running_metrics_val.update(
                        gt, pred)  #run metrics on the validation data

                    loss = loss_fn(input=outputs_val,
                                   target=labels_val,
                                   gamma=args.gamma,
                                   loss_type=args.loss_parameters
                                   )  #calculate the loss function
                    total_iteration_val = total_iteration_val + 1  #increment the loop counter

                    if i_val % 20 == 0:  # every 20 validation batches, print the validation loss
                        print("Epoch [%d/%d] validation Loss: %.4f" %
                              (epoch + 1, args.n_epoch, loss.item()))

                    numbers = [0]
                    if i_val in numbers:  #select batch number 0
                        # number 0 image in the batch
                        tb_original_image = vutils.make_grid(
                            image_original[0][0],
                            normalize=True,
                            scale_each=True
                        )  #make first tensor in the batch as image
                        writer.add_image('val/original_image',
                                         tb_original_image,
                                         epoch)  #send image to writer
                        labels_original = labels_original.numpy()[
                            0]  # get the original labels of image 0
                        correct_label_decoded = train_set.decode_segmap(
                            np.squeeze(labels_original)
                        )  #convert the labels to colour map
                        writer.add_image('val/original_label',
                                         np_to_tb(correct_label_decoded),
                                         epoch +
                                         1)  #send the coloured map to writer

                        out = F.softmax(
                            outputs_val,
                            dim=1)  #get soft max of the network 7 channel o/p

                        # this returns the max. channel number:
                        prediction = out.max(1)[1].cpu().detach().numpy(
                        )[0]  #get the position of the max o/p across different channels
                        # this returns the confidence:
                        confidence = out.max(1)[0].cpu().detach(
                        )[0]  #get the maximum o/p of the Nw across different channels
                        tb_confidence = vutils.make_grid(
                            confidence, normalize=True,
                            scale_each=True)  #convert tensor to image

                        decoded = train_set.decode_segmap(
                            np.squeeze(prediction)
                        )  #convert predicted classes to colour maps
                        writer.add_image('val/predicted', np_to_tb(decoded),
                                         epoch + 1)  #send prediction to writer
                        writer.add_image('val/confidence', tb_confidence,
                                         epoch + 1)  #send confidence to writer

                        unary = outputs_val.cpu().detach(
                        )  # get the network output of the current validation batch
                        unary_max, unary_min = torch.max(unary), torch.min(
                            unary)  # min-max normalize over the whole output
                        unary = unary.add((-1 * unary_min))
                        unary = unary / (unary_max - unary_min)

                        for channel in range(
                                0, len(class_names)
                        ):  #for all the 7 channels of the Nw op
                            tb_channel = vutils.make_grid(
                                unary[0][channel],
                                normalize=True,
                                scale_each=True
                            )  #convert the channel o/p of the class to image
                            writer.add_image(
                                f'val_classes/_{class_names[channel]}',
                                tb_channel, epoch + 1)  #send image to writer
                # finished one cycle of validation after iterating over all validation batches
                score, class_iou = running_metrics_val.get_scores(
                )  # returns a dictionary of the calculated accuracy metrics and the per-class IoU
                for k, v in score.items():  # print all validation scores
                    print(k, v)

                writer.add_scalar('val/Pixel Acc', score['Pixel Acc: '],
                                  epoch + 1)  #send metrics to writer
                writer.add_scalar('val/Mean IoU', score['Mean IoU: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean Class Acc',
                                  score['Mean Class Acc: '], epoch + 1)
                writer.add_scalar('val/Freq Weighted IoU',
                                  score['Freq Weighted IoU: '], epoch + 1)
                confusion = score['confusion_matrix']
                writer.add_image(f'val/confusion matrix', np_to_tb(confusion),
                                 epoch + 1)
                writer.add_scalar('val/loss', loss.item(), epoch + 1)
                running_metrics_val.reset()  #reset confusion matrix

                if score['Mean IoU: '] >= best_iou:  # compare this epoch's validation mean IoU with the best stored value
                    best_iou = score[
                        'Mean IoU: ']  # if better, store it and save the current model as the best model
                    model_dir = os.path.join(log_dir,
                                             f"{args.arch}_model_best.pkl")
                    torch.save(model, model_dir)

                if epoch % 10 == 0:  #every 10 epochs store the current model
                    model_dir = os.path.join(
                        log_dir, f"{args.arch}_ep{epoch}_model.pkl")
                    torch.save(model, model_dir)

        else:  # validation is turned off:
            # just save the latest model every 10 epochs:
            if (epoch + 1) % 10 == 0:
                model_dir = os.path.join(
                    log_dir, f"{args.arch}_ep{epoch + 1}_model.pkl")
                torch.save(model, model_dir)

    writer.close()  #close the writer
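
`core.loss.focal_loss2d` is used as the loss function above but is not shown in this listing. The sketch below is a minimal 2D focal loss in the spirit of Lin et al. (2017); it assumes integer class targets and omits the extra `loss_type`, `conf`, and `sim` arguments that the repository's version evidently accepts.

import torch
import torch.nn.functional as F

def focal_loss2d(input, target, gamma=2.0, weight=None):
    """input: [N, C, H, W] raw logits; target: [N, H, W] integer class indices.

    Down-weights well-classified pixels by a factor of (1 - p_t) ** gamma."""
    logp = F.log_softmax(input, dim=1)                        # [N, C, H, W]
    logp_t = logp.gather(1, target.unsqueeze(1)).squeeze(1)   # [N, H, W]
    p_t = logp_t.exp()
    loss = -((1.0 - p_t) ** gamma) * logp_t
    if weight is not None:             # optional per-class weights, shape [C]
        loss = loss * weight[target]
    return loss.mean()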
def train(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Generate the train and validation sets for the model:
    split_train_val_weak(args, per_val=args.per_val)
    loader = patch_loader_weak

    current_time = datetime.now().strftime('%b%d_%H%M%S')
    log_dir = os.path.join('runs',
                           current_time + f"_{args.arch}_{args.model_name}")
    writer = SummaryWriter(log_dir=log_dir)
    # Setup Augmentations
    if args.aug:
        data_aug = Compose(
            [RandomRotate(15),
             RandomHorizontallyFlip(),
             AddNoise()])
    else:
        data_aug = None

    train_set = loader(is_transform=True,
                       split='train',
                       augmentations=data_aug)

    # Without Augmentation:
    val_set = loader(is_transform=True,
                     split='val',
                     patch_size=args.patch_size)

    #if args.mixup:
    #    train_set1 = loader(is_transform=True,
    #                       split='train',
    #                       augmentations=data_aug)

    n_classes = train_set.n_classes

    trainloader = data.DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)

    #####################################################################
    # shuffle and load
    random.shuffle(train_set.patches['train'])  # shuffle the list of IDs
    alpha = 0.5  # beta-distribution parameter used for mixup below
    trainloader1 = data.DataLoader(
        train_set, batch_size=args.batch_size, num_workers=4,
        shuffle=True)  # load the shuffled data again in another loader
    ######################################################################

    valloader = data.DataLoader(val_set,
                                batch_size=args.batch_size,
                                num_workers=4)

    # Setup Metrics
    running_metrics = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            model = torch.load(args.resume)
        else:
            print("No checkpoint found at '{}'".format(args.resume))
    else:
        model = get_model(args.arch, args.pretrained, n_classes)

    # Use as many GPUs as we can
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model = model.to(device)  # Send to GPU

    # PYTROCH NOTE: ALWAYS CONSTRUCT OPTIMIZERS AFTER MODEL IS PUSHED TO GPU/CPU,

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)

    loss_fn = core.loss.focal_loss2d

    if args.class_weights:
        # weights are inversely proportional to the frequency of the classes in the training set
        class_weights = torch.tensor(
            [0, 0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852],
            device=device,
            requires_grad=False)
    else:
        class_weights = None

    best_iou = -100.0
    class_names = [
        'null', 'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk',
        'scruff', 'zechstein'
    ]

    for arg in vars(args):
        text = arg + ': ' + str(getattr(args, arg))
        writer.add_text('Parameters/', text)

    # training
    for epoch in range(args.n_epoch):
        # Training Mode:
        model.train()
        loss_train, total_iteration = 0, 0

        for (i, (images, labels, confs,
                 sims)), (i1, (images1, labels1, confs1,
                               sims1)) in zip(enumerate(trainloader),
                                              enumerate(trainloader1)):

            N, c, w, h = labels.shape
            one_hot = torch.FloatTensor(N, 7, w, h).zero_()
            labels_hot = one_hot.scatter_(
                1, labels.data,
                1)  # create one hot representation for the labels

            if args.mixup:  # if mixup is enabled, mix the two batches
                lam = torch.from_numpy(
                    np.random.beta(alpha, alpha,
                                   (N, 1, 1, 1))).float()  # sample lambda per example
                one_hot = torch.FloatTensor(N, 7, w, h).zero_()
                labels_hot1 = one_hot.scatter_(
                    1, labels1.data,
                    1)  # create a one-hot representation for the labels
                # per-example convex combinations of the two batches (mixup):
                images = lam * images + (1 - lam) * images1
                labels = lam * labels.float() + (1 - lam) * labels1.float()
                labels_hot = lam * labels_hot + (1 - lam) * labels_hot1
                confs = lam * confs.squeeze() + (1 - lam) * confs1.squeeze()
                sims = (lam.squeeze() * sims.float() +
                        (1 - lam).squeeze() * sims1.float())

            image_original = images  # TODO Q: are the passed original labels correct, in the context of the comparison in line 233?
            images, labels_hot, confs, sims = images.to(device), labels_hot.to(
                device), confs.to(device), sims.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            pred = outputs.detach().max(1)[1].cpu().numpy()
            labels_original = confs.squeeze().permute(
                0, 3, 1, 2).detach().max(1)[1].cpu().numpy()
            running_metrics.update(labels_original, pred)
            loss = loss_fn(input=outputs,
                           target=labels_hot,
                           conf=confs,
                           alpha=class_weights,
                           sim=sims,
                           gamma=args.gamma,
                           loss_type=args.loss_parameters,
                           soft_dev=args.soft_dev)
            loss_train += loss.item()
            loss.backward()

            # gradient clipping
            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            total_iteration = total_iteration + 1

            if i % 20 == 0:
                print("Epoch [%d/%d] training Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.item()))

            numbers = [0, 14, 29]
            if i in numbers:

                tb_original_image = vutils.make_grid(image_original[0][0],
                                                     normalize=True,
                                                     scale_each=True)  # first image in the batch
                writer.add_image('train/original_image', tb_original_image,
                                 epoch + 1)

                # tb_confs_original = vutils.make_grid(confs_tb, normalize=True, scale_each=True)
                # writer.add_image('train/confs_original',tb_confs_original, epoch +1)

                labels_original = labels_original[0]  # first example in the batch
                correct_label_decoded = train_set.decode_segmap(
                    np.squeeze(labels_original))
                writer.add_image('train/original_label',
                                 np_to_tb(correct_label_decoded), epoch + 1)
                out = F.softmax(outputs, dim=1)

                # this returns the max. channel number:
                prediction = out.max(1)[1].cpu().numpy()[0]
                # this returns the confidence:
                confidence = out.max(1)[0].cpu().detach()[0]
                tb_confidence = vutils.make_grid(confidence,
                                                 normalize=True,
                                                 scale_each=True)

                decoded = train_set.decode_segmap(np.squeeze(prediction))
                writer.add_image('train/predicted', np_to_tb(decoded),
                                 epoch + 1)
                writer.add_image('train/confidence', tb_confidence, epoch + 1)

                unary = outputs.cpu().detach()
                unary_max = torch.max(unary)
                unary_min = torch.min(unary)
                unary = unary.add((-1 * unary_min))
                unary = unary / (unary_max - unary_min)

                for channel in range(0, len(class_names)):
                    decoded_channel = unary[0][channel]
                    tb_channel = vutils.make_grid(decoded_channel,
                                                  normalize=True,
                                                  scale_each=True)
                    writer.add_image(f'train_classes/_{class_names[channel]}',
                                     tb_channel, epoch + 1)

        # Average metrics, and save in writer()
        loss_train /= total_iteration
        score, class_iou = running_metrics.get_scores()
        writer.add_scalar('train/Pixel Acc', score['Pixel Acc: '], epoch + 1)
        writer.add_scalar('train/Mean Class Acc', score['Mean Class Acc: '],
                          epoch + 1)
        writer.add_scalar('train/Freq Weighted IoU',
                          score['Freq Weighted IoU: '], epoch + 1)
        writer.add_scalar('train/Mean_IoU', score['Mean IoU: '], epoch + 1)

        confusion = score['confusion_matrix']
        writer.add_image(f'train/confusion matrix', np_to_tb(confusion),
                         epoch + 1)

        running_metrics.reset()
        writer.add_scalar('train/loss', loss_train, epoch + 1)

        if args.per_val != 0:
            with torch.no_grad():  # operations inside don't track history
                # Validation Mode:
                model.eval()
                loss_val, total_iteration_val = 0, 0

                for i_val, (images_val, labels_val, conf_val,
                            sim_val) in tqdm(enumerate(valloader)):

                    N, c, w, h = labels_val.shape
                    one_hot = torch.FloatTensor(N, 7, w, h).zero_()
                    labels_hot_val = one_hot.scatter_(
                        1, labels_val.data,
                        1)  # create one hot representation for the labels

                    image_original, labels_original = images_val, labels_val
                    images_val, labels_hot_val, conf_val, sim_val = images_val.to(
                        device), labels_hot_val.to(device), conf_val.to(
                            device), sim_val.to(device)

                    outputs_val = model(images_val)
                    pred = outputs_val.detach().max(1)[1].cpu().numpy()
                    gt = labels_val.numpy()

                    running_metrics_val.update(gt, pred)

                    loss = loss_fn(input=outputs_val,
                                   target=labels_hot_val,
                                   conf=conf_val,
                                   alpha=class_weights,
                                   sim=sim_val,
                                   gamma=args.gamma,
                                   loss_type=args.loss_parameters,
                                   soft_dev=args.soft_dev)

                    total_iteration_val = total_iteration_val + 1

                    if i_val % 20 == 0:
                        print("Epoch [%d/%d] validation Loss: %.4f" %
                              (epoch + 1, args.n_epoch, loss.item()))

                    numbers = [0]
                    if i_val in numbers:

                        tb_original_image = vutils.make_grid(
                            image_original[i_val][0],
                            normalize=True,
                            scale_each=True)
                        writer.add_image('val/original_image',
                                         tb_original_image, epoch)
                        labels_original = labels_original.numpy()[0]
                        correct_label_decoded = train_set.decode_segmap(
                            np.squeeze(labels_original))
                        writer.add_image('val/original_label',
                                         np_to_tb(correct_label_decoded),
                                         epoch + 1)

                        out = F.softmax(outputs_val, dim=1)

                        # this returns the max. channel number:
                        prediction = out.max(1)[1].cpu().detach().numpy()[0]
                        # this returns the confidence:
                        confidence = out.max(1)[0].cpu().detach()[0]
                        tb_confidence = vutils.make_grid(confidence,
                                                         normalize=True,
                                                         scale_each=True)

                        decoded = train_set.decode_segmap(
                            np.squeeze(prediction))
                        writer.add_image('val/predicted', np_to_tb(decoded),
                                         epoch + 1)
                        writer.add_image('val/confidence', tb_confidence,
                                         epoch + 1)

                        unary = outputs_val.cpu().detach()
                        unary_max, unary_min = torch.max(unary), torch.min(
                            unary)
                        unary = unary.add((-1 * unary_min))
                        unary = unary / (unary_max - unary_min)

                        for channel in range(0, len(class_names)):
                            tb_channel = vutils.make_grid(unary[0][channel],
                                                          normalize=True,
                                                          scale_each=True)
                            writer.add_image(
                                f'val_classes/_{class_names[channel]}',
                                tb_channel, epoch + 1)

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)

                writer.add_scalar('val/Pixel Acc', score['Pixel Acc: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean IoU', score['Mean IoU: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean Class Acc',
                                  score['Mean Class Acc: '], epoch + 1)
                writer.add_scalar('val/Freq Weighted IoU',
                                  score['Freq Weighted IoU: '], epoch + 1)

                confusion = score['confusion_matrix']
                writer.add_image(f'val/confusion matrix', np_to_tb(confusion),
                                 epoch + 1)
                writer.add_scalar('val/loss', loss.item(), epoch + 1)
                running_metrics_val.reset()

                if score['Mean IoU: '] >= best_iou:
                    best_iou = score['Mean IoU: ']
                    model_dir = os.path.join(log_dir,
                                             f"{args.arch}_model_best.pkl")
                    torch.save(model, model_dir)

                if epoch % 10 == 0:
                    model_dir = os.path.join(
                        log_dir, f"{args.arch}_ep{epoch}_model.pkl")
                    torch.save(model, model_dir)

        else:  # validation is turned off:
            # just save the latest model every 10 epochs:
            if (epoch + 1) % 10 == 0:
                model_dir = os.path.join(log_dir,
                                         f"{args.arch}_ep{epoch+1}_model.pkl")
                torch.save(model, model_dir)

    writer.close()
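
For reference, the mixup step in the training loop above boils down to a per-example convex combination of two batches. The standalone sketch below illustrates the idea under the same assumptions (one-hot labels and a per-example lambda drawn from a Beta distribution); the function name is ours, not the repository's.

import numpy as np
import torch

def mixup_batch(x0, y0, x1, y1, alpha=0.5):
    """x*: [N, C, H, W] images; y*: [N, K, H, W] one-hot labels.

    Returns convex combinations with per-example lambda ~ Beta(alpha, alpha)."""
    n = x0.shape[0]
    lam = torch.from_numpy(np.random.beta(alpha, alpha, (n, 1, 1, 1))).float()
    x = lam * x0 + (1.0 - lam) * x1
    y = lam * y0 + (1.0 - lam) * y1
    return x, y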