Example #1
def train(args):
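    """Patch-based training loop.

    Builds train/val patch loaders, sets up a segmentation model, and logs
    losses, metrics, and sample images to TensorBoard. Relies on the
    surrounding repository for imports (torch, core.loss, patch_loader, ...).
    """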

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Generate the train and validation sets for the model:
    split_train_val(args, per_val=args.per_val)

    current_time = datetime.now().strftime('%b%d_%H%M%S')
    log_dir = os.path.join(
        'runs', current_time + "_{}_{}".format(args.arch, args.loss))
    writer = SummaryWriter(log_dir=log_dir)
    # Setup Augmentations
    if args.aug:
        data_aug = Compose(
            [RandomRotate(30),
             RandomHorizontallyFlip(),
             AddNoise()])
    else:
        data_aug = None

    train_set = patch_loader(is_transform=True,
                             split='train',
                             stride=args.stride,
                             patch_size=args.patch_size,
                             augmentations=data_aug)

    # Without Augmentation:
    val_set = patch_loader(is_transform=True,
                           split='val',
                           stride=args.stride,
                           patch_size=args.patch_size)

    n_classes = train_set.n_classes

    trainloader = data.DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)
    valloader = data.DataLoader(val_set,
                                batch_size=args.batch_size,
                                num_workers=4)

    # Setup Metrics
    running_metrics = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            model = torch.load(args.resume)
        else:
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(args.resume))
    else:
        model = get_model(args.arch, args.pretrained, n_classes)

    # Use as many GPUs as we can
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model = model.to(device)  # Send to GPU

    # PYTORCH NOTE: ALWAYS CONSTRUCT OPTIMIZERS AFTER THE MODEL IS PUSHED TO GPU/CPU.

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        # optimizer = torch.optim.Adadelta(model.parameters())
        optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)

    if (args.loss == 'FL'):
        loss_fn = core.loss.focal_loss2d
    else:
        loss_fn = core.loss.cross_entropy

    if args.class_weights:
        # weights are inversely proportional to the frequency of the classes in the training set
        class_weights = torch.tensor(
            [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852],
            device=device,
            requires_grad=False)
    else:
        class_weights = None
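    # A minimal sketch (an assumption, not part of this script) of how such
    # weights could be derived from hypothetical per-class pixel counts;
    # the hard-coded values above appear consistent with
    # weight = 1 - class_frequency:
    # freq = label_counts / label_counts.sum()  # label_counts: numpy array
    # class_weights = torch.tensor(1.0 - freq, device=device).float()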

    best_iou = -100.0
    class_names = [
        'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk', 'scruff',
        'zechstein'
    ]

    for arg in vars(args):
        text = arg + ': ' + str(getattr(args, arg))
        writer.add_text('Parameters/', text)

    # training
    for epoch in range(args.n_epoch):
        # Training Mode:
        model.train()
        loss_train, total_iteration = 0, 0

        for i, (images, labels) in enumerate(trainloader):
            image_original, labels_original = images, labels
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            pred = outputs.detach().max(1)[1].cpu().numpy()
            gt = labels.detach().cpu().numpy()
            running_metrics.update(gt, pred)

            loss = loss_fn(input=outputs, target=labels, weight=class_weights)
            loss_train += loss.item()
            loss.backward()

            # gradient clipping
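            # (clip_grad_norm_ computes a single norm over all gradients and
            # rescales them in place whenever that norm exceeds args.clip)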
            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            total_iteration = total_iteration + 1

            if i % 20 == 0:
                print("Epoch [%d/%d] training Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.item()))

            numbers = [0]
            if i in numbers:
                # number 0 image in the batch
                tb_original_image = vutils.make_grid(image_original[0][0],
                                                     normalize=True,
                                                     scale_each=True)
                writer.add_image('train/original_image', tb_original_image,
                                 epoch + 1)

                labels_original = labels_original.numpy()[0]
                correct_label_decoded = train_set.decode_segmap(
                    np.squeeze(labels_original))
                writer.add_image('train/original_label', correct_label_decoded,
                                 epoch + 1)
                out = F.softmax(outputs, dim=1)

                # this returns the max. channel number:
                prediction = out.max(1)[1].cpu().numpy()[0]
                # this returns the confidence:
                confidence = out.max(1)[0].cpu().detach()[0]
                tb_confidence = vutils.make_grid(confidence,
                                                 normalize=True,
                                                 scale_each=True)

                decoded = train_set.decode_segmap(np.squeeze(prediction))
                writer.add_image('train/predicted', decoded, epoch + 1)
                writer.add_image('train/confidence', tb_confidence, epoch + 1)

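                # min-max normalize the raw logits to [0, 1] so each class
                # channel can be rendered as an image: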
                unary = outputs.cpu().detach()
                unary_max = torch.max(unary)
                unary_min = torch.min(unary)
                unary = unary.add((-1 * unary_min))
                unary = unary / (unary_max - unary_min)

                for channel in range(0, len(class_names)):
                    decoded_channel = unary[0][channel]
                    tb_channel = vutils.make_grid(decoded_channel,
                                                  normalize=True,
                                                  scale_each=True)
                    writer.add_image(f'train_classes/_{class_names[channel]}',
                                     tb_channel, epoch + 1)

        # Average metrics, and save in writer()
        loss_train /= total_iteration
        score, class_iou = running_metrics.get_scores()
        writer.add_scalar('train/Pixel Acc', score['Pixel Acc: '], epoch + 1)
        writer.add_scalar('train/Mean Class Acc', score['Mean Class Acc: '],
                          epoch + 1)
        writer.add_scalar('train/Freq Weighted IoU',
                          score['Freq Weighted IoU: '], epoch + 1)
        writer.add_scalar('train/Mean_IoU', score['Mean IoU: '], epoch + 1)
        running_metrics.reset()
        writer.add_scalar('train/loss', loss_train, epoch + 1)

        if args.per_val != 0:
            with torch.no_grad():  # operations inside don't track history
                # Validation Mode:
                model.eval()
                loss_val, total_iteration_val = 0, 0

                for i_val, (images_val,
                            labels_val) in tqdm(enumerate(valloader)):
                    image_original, labels_original = images_val, labels_val
                    images_val, labels_val = images_val.to(
                        device), labels_val.to(device)

                    outputs_val = model(images_val)
                    pred = outputs_val.detach().max(1)[1].cpu().numpy()
                    gt = labels_val.detach().cpu().numpy()

                    running_metrics_val.update(gt, pred)

                    loss = loss_fn(input=outputs_val, target=labels_val)
                    loss_val += loss.item()

                    total_iteration_val = total_iteration_val + 1

                    if i_val % 20 == 0:
                        print("Epoch [%d/%d] validation Loss: %.4f" %
                              (epoch + 1, args.n_epoch, loss.item()))

                    numbers = [0]
                    if i_val in numbers:
                        # number 0 image in the batch
                        tb_original_image = vutils.make_grid(
                            image_original[0][0],
                            normalize=True,
                            scale_each=True)
                        writer.add_image('val/original_image',
                                         tb_original_image, epoch + 1)
                        labels_original = labels_original.numpy()[0]
                        correct_label_decoded = train_set.decode_segmap(
                            np.squeeze(labels_original))
                        writer.add_image('val/original_label',
                                         correct_label_decoded, epoch + 1)

                        out = F.softmax(outputs_val, dim=1)

                        # this returns the max. channel number:
                        prediction = out.max(1)[1].cpu().detach().numpy()[0]
                        # this returns the confidence:
                        confidence = out.max(1)[0].cpu().detach()[0]
                        tb_confidence = vutils.make_grid(confidence,
                                                         normalize=True,
                                                         scale_each=True)

                        decoded = train_set.decode_segmap(
                            np.squeeze(prediction))
                        writer.add_image('val/predicted', decoded, epoch + 1)
                        writer.add_image('val/confidence', tb_confidence,
                                         epoch + 1)

                        unary = outputs_val.cpu().detach()
                        unary_max, unary_min = torch.max(unary), torch.min(
                            unary)
                        unary = unary.add((-1 * unary_min))
                        unary = unary / (unary_max - unary_min)

                        for channel in range(0, len(class_names)):
                            tb_channel = vutils.make_grid(unary[0][channel],
                                                          normalize=True,
                                                          scale_each=True)
                            writer.add_image(
                                f'val_classes/_{class_names[channel]}',
                                tb_channel, epoch + 1)

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)

                writer.add_scalar('val/Pixel Acc', score['Pixel Acc: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean IoU', score['Mean IoU: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean Class Acc',
                                  score['Mean Class Acc: '], epoch + 1)
                writer.add_scalar('val/Freq Weighted IoU',
                                  score['Freq Weighted IoU: '], epoch + 1)

                writer.add_scalar('val/loss',
                                  loss_val / total_iteration_val, epoch + 1)
                running_metrics_val.reset()

                if score['Mean IoU: '] >= best_iou:
                    best_iou = score['Mean IoU: ']
                    model_dir = os.path.join(
                        log_dir, f"{args.arch}_{args.loss}_model.pkl")
                    torch.save(model, model_dir)

        else:  # validation is turned off:
            # just save the latest model:
            if (epoch + 1) % 5 == 0:
                model_dir = os.path.join(
                    log_dir, f"{args.arch}_{args.loss}_ep{epoch+1}_model.pkl")
                torch.save(model, model_dir)

    writer.close()
Example #2
def train(args):
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")  #Selects Torch Device
    split_train_val(
        args, per_val=args.per_val
    )  #Generate the train and validation sets for the model as text files:

    current_time = datetime.now().strftime(
        '%b%d_%H%M%S')  #Gets Current Time and Date
    log_dir = os.path.join(
        'runs', current_time +
        f"_{args.arch}_{args.model_name}")  #Create the log directory
    writer = SummaryWriter(
        log_dir=log_dir)  #Initialize the tensorboard summary writer

    # Setup Augmentations
    if args.aug:  #if augmentation is true
        data_aug = Compose(
            [RandomRotate(10),
             RandomHorizontallyFlip(),
             AddNoise()])  #compose some augmentation functions
    else:
        data_aug = None

    loader = section_loader  #alias the section-based loader
    train_set = loader(
        is_transform=True, split='train', augmentations=data_aug
    )  #use custom data loader to get the training set (instance of the loader class)
    val_set = loader(
        is_transform=True,
        split='val')  #use the custom data loader to get the validation set

    n_classes = train_set.n_classes  #initialize the number of classes, which is hard-coded in the data loader

    # Create sampler:

    shuffle = False  # must be False if using a custom sampler
    with open(pjoin('data', 'splits', 'section_train.txt'), 'r') as f:
        train_list = f.read().splitlines(
        )  #load the section train list previously stored in a text file created by split_train_val() function
    with open(pjoin('data', 'splits', 'section_val.txt'), 'r') as f:
        val_list = f.read().splitlines(
        )  #load the section validation list previously stored in a text file created by the split_train_val() function

    class CustomSamplerTrain(torch.utils.data.Sampler
                             ):  #create a custom sampler (defined here but never passed to the DataLoaders below)
        def __iter__(self):
            char = ['i' if np.random.randint(2) == 1 else 'x'
                    ]  #choose randomly between letter i and letter x
            self.indices = [
                idx for (idx, name) in enumerate(train_list) if char[0] in name
            ]  #keep the indices of all inlines or all crosslines from the training list created by the split_train_val() function
            return (self.indices[i] for i in torch.randperm(len(self.indices))
                    )  #shuffle the indices and return them

    class CustomSamplerVal(torch.utils.data.Sampler):
        def __iter__(self):
            char = ['i' if np.random.randint(2) == 1 else 'x'
                    ]  #choose randomly between letter i and letter x
            self.indices = [
                idx for (idx, name) in enumerate(val_list) if char[0] in name
            ]  #keep the indices of all inlines or all crosslines from the validation list created by the split_train_val() function
            return (self.indices[i] for i in torch.randperm(len(self.indices))
                    )  #shuffle the indices and return them

    trainloader = data.DataLoader(
        train_set, batch_size=args.batch_size, num_workers=12, shuffle=True
    )  #use pytorch data loader to get the batches of training set
    valloader = data.DataLoader(
        val_set, batch_size=args.batch_size, num_workers=12
    )  #use pytorch data loader to get the batches of validation set
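    # A minimal sketch (not used in this example) of how the custom samplers
    # defined above could be wired in; `shuffle` would then have to stay False:
    # trainloader = data.DataLoader(train_set, batch_size=args.batch_size,
    #                               num_workers=12, shuffle=shuffle,
    #                               sampler=CustomSamplerTrain(train_list))
    # valloader = data.DataLoader(val_set, batch_size=args.batch_size,
    #                             num_workers=12,
    #                             sampler=CustomSamplerVal(val_list))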

    # Setup Metrics
    running_metrics = runningScore(
        n_classes
    )  #initialize class instance for evaluation metrics for training
    running_metrics_val = runningScore(
        n_classes
    )  #initialize class instance for evaluation metrics for validation

    # Setup Model
    if args.resume is not None:  #Check if we have a stored model or not
        if os.path.isfile(args.resume):  #if yes then load the stored model
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            model = torch.load(args.resume)
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(
                args.resume))  #a stored model was requested with an invalid path
    else:  #if no stored model was requested, build the requested architecture
        #n_classes=64
        model = get_model(name=args.arch,
                          pretrained=args.pretrained,
                          batch_size=args.batch_size,
                          growth_rate=32,
                          drop_rate=0,
                          n_classes=n_classes)  #build the requested model

    model = torch.nn.DataParallel(
        model, device_ids=range(
            torch.cuda.device_count()))  #Use as many GPUs as we can
    model = model.to(device)  # Send to GPU

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
            amsgrad=True,
            weight_decay=args.weight_decay,
            eps=args.eps
        )  #if no optimizer is specified, load the default optimizer

    loss_fn = core.loss.focal_loss2d  #initialize the loss function

    if args.class_weights:  #if class weights are to be used, initialize them
        # weights are inversely proportional to the frequency of the classes in the training set
        class_weights = torch.tensor(
            [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852],
            device=device,
            requires_grad=False)
    else:
        class_weights = None  #otherwise no class weights are used

    best_iou = -100.0
    class_names = [
        'null', 'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk',
        'scruff', 'zechstein'
    ]  #initialize the name of different classes

    for arg in vars(
            args
    ):  #before training starts, write a summary of the parameters
        text = arg + ': ' + str(getattr(
            args, arg))  #format the attribute name and value as a string
        writer.add_text('Parameters/', text)  #store the whole string

    # training
    for epoch in range(args.n_epoch):  #loop over the epochs
        # Training Mode:
        model.train()  #switch to training mode
        loss_train, total_iteration = 0, 0  #initialize training loss and total number of iterations

        for i, (images, labels) in enumerate(
                trainloader
        ):  #iterate over the training batches; i is the batch number
            image_original, labels_original = images, labels  #keep the original image and label batch in new variables
            images, labels = images.to(device), labels.to(
                device)  #move images and labels to the GPU

            optimizer.zero_grad()  #zero the gradients before the backward pass
            outputs = model(
                images
            )  #feed forward the images through the model (outputs is a 7 channel o/p)

            pred = outputs.detach().max(1)[1].cpu().numpy(
            )  #select the index of the maximum channel of the model o/p and send it back to the CPU
            gt = labels.detach().cpu().numpy(
            )  #get the true labels from the GPU and send them to the CPU
            running_metrics.update(
                gt, pred
            )  #call the function update and pass the ground truth and the predicted classes

            loss = loss_fn(input=outputs,
                           target=labels,
                           gamma=args.gamma,
                           loss_type=args.loss_parameters
                           )  #call the loss function to calculate the loss
            loss_train += loss.item()  #gets the scalar value held in the loss.
            loss.backward(
            )  # Use autograd to compute the backward pass. This call will compute the gradient of loss with respect to all Tensors with requires_grad=True.

            # gradient clipping
            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip
                )  #the norm is computed over all gradients together, as if concatenated into a single vector; gradients are modified in place

            optimizer.step(
            )  #step the optimizer (update the model weights with the new gradients)
            total_iteration = total_iteration + 1  #increment the total number of iterations by 1

            if i % 20 == 0:  #every 20 batches
                print(
                    "Epoch [%d/%d] training Loss: %.4f" %
                    (epoch + 1, args.n_epoch, loss.item())
                )  #print the current epoch, total number of epochs and the current training loss

            numbers = [0, 14, 29, 49, 99]  #batch indices at which to log images
            if i in numbers:  #if the current batch number is in numbers
                # number 0 image in the batch
                tb_original_image = vutils.make_grid(
                    image_original[0][0], normalize=True, scale_each=True
                )  #select the first image in the batch and create a tensorboard grid from the image tensor
                writer.add_image('train/original_image', tb_original_image,
                                 epoch + 1)  #send the image to writer

                labels_original = labels_original.numpy(
                )[0]  #convert the ground-truth labels of the first image in the batch to a numpy array
                correct_label_decoded = train_set.decode_segmap(
                    np.squeeze(labels_original)
                )  #Decode segmentation class labels into a color image
                writer.add_image('train/original_label',
                                 np_to_tb(correct_label_decoded),
                                 epoch + 1)  #send the image to the writer
                out = F.softmax(outputs, dim=1)  #softmax of the network o/p
                prediction = out.max(1)[1].cpu().numpy()[
                    0]  #get the index of the maximum value after softmax
                confidence = out.max(1)[0].cpu().detach()[
                    0]  # this returns the confidence in the chosen class

                tb_confidence = vutils.make_grid(
                    confidence, normalize=True, scale_each=True
                )  #convert the confidence from tensor to image

                decoded = train_set.decode_segmap(np.squeeze(
                    prediction))  #Decode predicted classes to colours
                writer.add_image(
                    'train/predicted', np_to_tb(decoded), epoch + 1
                )  #send predicted map to writer along with the epoch number
                writer.add_image(
                    'train/confidence', tb_confidence, epoch + 1
                )  #send the confidence to writer along with the epoch number

                unary = outputs.cpu().detach(
                )  #get the Nw o/p for the whole batch
                unary_max = torch.max(
                    unary)  #normalize the Nw o/p w.r.t whole batch
                unary_min = torch.min(unary)
                unary = unary.add((-1 * unary_min))
                unary = unary / (unary_max - unary_min)

                for channel in range(0, len(class_names)):
                    decoded_channel = unary[0][
                        channel]  #get the normalized o/p for the first image in the batch
                    tb_channel = vutils.make_grid(
                        decoded_channel, normalize=True,
                        scale_each=True)  #prepare an image from the tensor
                    writer.add_image(f'train_classes/_{class_names[channel]}',
                                     tb_channel,
                                     epoch + 1)  #send image to writer

        # Average metrics after finishing all batches for the whole epoch, and save in writer()
        loss_train /= total_iteration  #total loss for all iterations/ number of iterations
        score, class_iou = running_metrics.get_scores(
        )  #returns a dictionary of the calculated accuracy metrics and class iu
        writer.add_scalar(
            'train/Pixel Acc', score['Pixel Acc: '],
            epoch + 1)  # store the epoch metrics in the tensorboard writer
        writer.add_scalar('train/Mean Class Acc', score['Mean Class Acc: '],
                          epoch + 1)
        writer.add_scalar('train/Freq Weighted IoU',
                          score['Freq Weighted IoU: '], epoch + 1)
        writer.add_scalar('train/Mean_IoU', score['Mean IoU: '], epoch + 1)
        confusion = score['confusion_matrix']
        writer.add_image(f'train/confusion matrix', np_to_tb(confusion),
                         epoch + 1)

        running_metrics.reset()  #resets the confusion matrix
        writer.add_scalar('train/loss', loss_train,
                          epoch + 1)  #store the training loss
        #Finished one epoch of training, starting one epoch of validation
        if args.per_val != 0:  # if validation is required
            with torch.no_grad():  # operations inside don't track history
                # Validation Mode:
                model.eval()  #start validation mode
                loss_val, total_iteration_val = 0, 0  # initialize validation loss and total number of iterations

                for i_val, (images_val, labels_val) in tqdm(
                        enumerate(valloader)):  #iterate over the validation batches
                    image_original, labels_original = images_val, labels_val  #keep the original validation images and labels
                    images_val, labels_val = images_val.to(
                        device), labels_val.to(
                            device)  #send validation images and labels to GPU

                    outputs_val = model(images_val)  #feedforward the image
                    pred = outputs_val.detach().max(
                        1)[1].cpu().numpy()  #get the network class prediction
                    gt = labels_val.detach().cpu().numpy(
                    )  #get the ground truth from the GPU

                    running_metrics_val.update(
                        gt, pred)  #run metrics on the validation data

                    loss = loss_fn(input=outputs_val,
                                   target=labels_val,
                                   gamma=args.gamma,
                                   loss_type=args.loss_parameters
                                   )  #calculate the loss function
                    loss_val += loss.item()  #accumulate the validation loss
                    total_iteration_val = total_iteration_val + 1  #increment the loop counter

                    if i_val % 20 == 0:  #every 20 validation batches, print the validation loss
                        print("Epoch [%d/%d] validation Loss: %.4f" %
                              (epoch + 1, args.n_epoch, loss.item()))

                    numbers = [0]
                    if i_val in numbers:  #select batch number 0
                        # number 0 image in the batch
                        tb_original_image = vutils.make_grid(
                            image_original[0][0],
                            normalize=True,
                            scale_each=True
                        )  #make first tensor in the batch as image
                        writer.add_image('val/original_image',
                                         tb_original_image,
                                         epoch + 1)  #send image to writer
                        labels_original = labels_original.numpy()[
                            0]  #get the original labels of image 0
                        correct_label_decoded = train_set.decode_segmap(
                            np.squeeze(labels_original)
                        )  #convert the labels to colour map
                        writer.add_image('val/original_label',
                                         np_to_tb(correct_label_decoded),
                                         epoch +
                                         1)  #send the coloured map to writer

                        out = F.softmax(
                            outputs_val,
                            dim=1)  #get soft max of the network 7 channel o/p

                        # this returns the max. channel number:
                        prediction = out.max(1)[1].cpu().detach().numpy(
                        )[0]  #get the position of the max o/p across different channels
                        # this returns the confidence:
                        confidence = out.max(1)[0].cpu().detach(
                        )[0]  #get the maximum o/p of the Nw across different channels
                        tb_confidence = vutils.make_grid(
                            confidence, normalize=True,
                            scale_each=True)  #convert tensor to image

                        decoded = train_set.decode_segmap(
                            np.squeeze(prediction)
                        )  #convert predicted classes to colour maps
                        writer.add_image('val/predicted', np_to_tb(decoded),
                                         epoch + 1)  #send prediction to writer
                        writer.add_image('val/confidence', tb_confidence,
                                         epoch + 1)  #send confidence to writer

                        unary = outputs_val.cpu().detach(
                        )  #get the network o/p of the current validation batch
                        unary_max, unary_min = torch.max(unary), torch.min(
                            unary)  #normalize across all the Nw o/p
                        unary = unary.add((-1 * unary_min))
                        unary = unary / (unary_max - unary_min)

                        for channel in range(
                                0, len(class_names)
                        ):  #for all the 7 channels of the Nw op
                            tb_channel = vutils.make_grid(
                                unary[0][channel],
                                normalize=True,
                                scale_each=True
                            )  #convert the channel o/p of the class to image
                            writer.add_image(
                                f'val_classes/_{class_names[channel]}',
                                tb_channel, epoch + 1)  #send image to writer
                # finished one cycle of validation after iterating over all validation batches
                score, class_iou = running_metrics_val.get_scores(
                )  #returns a dictionary of the calculated accuracy metrics and class iu
                for k, v in score.items():  #print all validation metrics
                    print(k, v)

                writer.add_scalar('val/Pixel Acc', score['Pixel Acc: '],
                                  epoch + 1)  #send metrics to writer
                writer.add_scalar('val/Mean IoU', score['Mean IoU: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean Class Acc',
                                  score['Mean Class Acc: '], epoch + 1)
                writer.add_scalar('val/Freq Weighted IoU',
                                  score['Freq Weighted IoU: '], epoch + 1)
                confusion = score['confusion_matrix']
                writer.add_image(f'val/confusion matrix', np_to_tb(confusion),
                                 epoch + 1)
                writer.add_scalar('val/loss',
                                  loss_val / total_iteration_val, epoch + 1)
                running_metrics_val.reset()  #reset confusion matrix

                if score['Mean IoU: '] >= best_iou:  #compare the current epoch's validation mean IoU with the best stored validation mean IoU
                    best_iou = score[
                        'Mean IoU: ']  #if better, update best_iou and save the current model as the best model
                    model_dir = os.path.join(log_dir,
                                             f"{args.arch}_model_best.pkl")
                    torch.save(model, model_dir)

                if epoch % 10 == 0:  #every 10 epochs store the current model
                    model_dir = os.path.join(
                        log_dir, f"{args.arch}_ep{epoch}_model.pkl")
                    torch.save(model, model_dir)

        else:  # validation is turned off:
            # just save the latest model every 10 epochs:
            if (epoch + 1) % 10 == 0:
                model_dir = os.path.join(
                    log_dir, f"{args.arch}_ep{epoch + 1}_model.pkl")
                torch.save(model, model_dir)

    writer.close()  #close the writer
Example #3
def train(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Generate the train and validation sets for the model:
    split_train_val(args, per_val=args.per_val)

    current_time = datetime.now().strftime('%b%d_%H%M%S')
    log_dir = os.path.join('runs', current_time +
                           "_{}".format(args.arch))
    writer = SummaryWriter(log_dir=log_dir)
    # Setup Augmentations
    if args.aug:
        data_aug = Compose(
            [RandomRotate(10), RandomHorizontallyFlip(), AddNoise()])
    else:
        data_aug = None

    train_set = PatchLoader(is_transform=True,
                            split='train',
                            stride=args.stride,
                            patch_size=args.patch_size,
                            augmentations=data_aug)

    # Without Augmentation:
    val_set = PatchLoader(is_transform=True,
                          split='val',
                          stride=args.stride,
                          patch_size=args.patch_size)

    n_classes = train_set.n_classes

    trainloader = data.DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  num_workers=1,
                                  shuffle=True)
    valloader = data.DataLoader(val_set,
                                batch_size=args.batch_size,
                                num_workers=1)

    # Setup Metrics
    running_metrics = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)
   

    # Setup Model edited by Tannistha
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            model = torch.load(args.resume)
        else:
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(args.resume))
    else:
        #model = getattr(deeplab, 'resnet101')(
        #pretrained=(not args.scratch),
        #num_classes=n_classes,
        #num_groups=args.groups,
        #weight_std=args.weight_std,
        #beta=args.beta)
        # edited by Tannistha
        model = getattr(ResNet9, 'resnet9')(
            pretrained=(args.scratch),
            num_classes=n_classes,
            num_groups=args.groups,
            weight_std=args.weight_std,
            beta=args.beta)

    # Use as many GPUs as we can
    model = torch.nn.DataParallel(
        model, device_ids=range(torch.cuda.device_count()))
    model = model.to(device)  # Send to GPU

    # PYTORCH NOTE: ALWAYS CONSTRUCT OPTIMIZERS AFTER THE MODEL IS PUSHED TO GPU/CPU.

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        # optimizer = torch.optim.Adadelta(model.parameters())
        optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr,
                                    weight_decay=0.0001, momentum=0.9)
        # optimizer = torch.optim.Adam(model.parameters(), lr=args.base_lr,
        #                              weight_decay=0.0001, amsgrad=True)

    # edited by Tannistha to work with the new optimizer
    if args.train:
        criterion = nn.CrossEntropyLoss(ignore_index=255)
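        # NOTE: `criterion` is never used below; the training loop calls
        # core.loss.cross_entropy via loss_fn instead.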
        model.train()
        if args.freeze_bn:
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False
                    
        # optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr,
        #                             weight_decay=0.0001, momentum=0.9)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.base_lr,
                                     weight_decay=0.0001, amsgrad=True)
        
        start_epoch = 0

    loss_fn = core.loss.cross_entropy

    if args.class_weights:
        # weights are inversely proportional to the frequency of the classes in the training set
        class_weights = torch.tensor(
            [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852], device=device, requires_grad=False)
    else:
        class_weights = None

    best_iou = -100.0
    class_names = ['upper_ns', 'middle_ns', 'lower_ns',
                   'rijnland_chalk', 'scruff', 'zechstein']
    
    for arg in vars(args):
        text = arg + ': ' + str(getattr(args, arg))
        writer.add_text('Parameters/', text)
        
    model_fname = ('data/deeplab_' + str(args.base_lr) + '_batch_size_' +
                   str(args.batch_size) + '_' + args.exp + '_epoch_%d.pth')
    val_fname = ('val_lr_' + str(args.base_lr) + '_batch_size_' +
                 str(args.batch_size) + '_' + args.exp)
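    # NOTE: val_fname is used below as a directory prefix for the metrics
    # CSVs (val_fname/metrics/...); those directories must exist beforehand.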
    
    for epoch in range(args.n_epoch):
        # Training Mode:
        model.train()
        loss_train, total_iteration = 0, 0

        for i, (images, labels) in enumerate(trainloader):
            
            image_original, labels_original = images, labels
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            
            pred = outputs.detach().max(1)[1].cpu().numpy()
            gt = labels.detach().cpu().numpy()
            running_metrics.update(gt, pred)

            loss = loss_fn(input=outputs, target=labels, weight=class_weights)

            loss_train += loss.item()
            optimizer.zero_grad()
            loss.backward()

            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            
            total_iteration = total_iteration + 1

            if i % 20 == 0:
                print('epoch: {0}/{1}\t\t'
                      'iter: {2}/{3}\t\t'
                      'training Loss: {4:.4f}'.format(
                          epoch + 1, args.n_epoch, i + 1, len(trainloader),
                          loss.item()))

            numbers = [0]
            if i in numbers:
                # number 0 image in the batch
                tb_original_image = vutils.make_grid(
                    image_original[0][0], normalize=True, scale_each=True)
                writer.add_image('train/original_image',
                                 tb_original_image, epoch + 1)

                labels_original = labels_original.numpy()[0]
                correct_label_decoded = train_set.decode_segmap(np.squeeze(labels_original))
                writer.add_image('train/original_label',np_to_tb(correct_label_decoded), epoch + 1)
                out = F.softmax(outputs, dim=1)

                # this returns the max. channel number:
                prediction = out.max(1)[1].cpu().numpy()[0]
                # this returns the confidence:
                confidence = out.max(1)[0].cpu().detach()[0]
                tb_confidence = vutils.make_grid(
                    confidence, normalize=True, scale_each=True)

                decoded = train_set.decode_segmap(np.squeeze(prediction))
                writer.add_image('train/predicted', np_to_tb(decoded), epoch + 1)
                writer.add_image('train/confidence', tb_confidence, epoch + 1)

                unary = outputs.cpu().detach()
                unary_max = torch.max(unary)
                unary_min = torch.min(unary)
                unary = unary.add((-1*unary_min))
                unary = unary/(unary_max - unary_min)

                for channel in range(0, len(class_names)):
                    decoded_channel = unary[0][channel]
                    tb_channel = vutils.make_grid(
                        decoded_channel, normalize=True, scale_each=True)
                    writer.add_image(f'train_classes/_{class_names[channel]}', tb_channel, epoch + 1)

        # Average metrics, and save in writer()
        loss_train /= total_iteration
        score, class_iou = running_metrics.get_scores()
        writer.add_scalar('train/Pixel Acc', score['Pixel Acc: '], epoch+1)
        writer.add_scalar('train/Mean Class Acc',
                          score['Mean Class Acc: '], epoch+1)
        writer.add_scalar('train/Freq Weighted IoU',
                          score['Freq Weighted IoU: '], epoch+1)
        writer.add_scalar('train/Mean_IoU', score['Mean IoU: '], epoch+1)
        running_metrics.reset()
        writer.add_scalar('train/loss', loss_train, epoch+1)
        
        if args.per_val != 0:
            with torch.no_grad():  # operations inside don't track history
                # Validation Mode:
                model.eval()
                loss_val, total_iteration_val = 0, 0

                for i_val, (images_val, labels_val) in enumerate(valloader):
                    image_original, labels_original = images_val, labels_val
                    images_val, labels_val = images_val.to(
                        device), labels_val.to(device)
                    #image_val = to_3_channels(images_val)
                    outputs_val = model(images_val)
                    #outputs_val = model(image_val)
                    pred = outputs_val.detach().max(1)[1].cpu().numpy()
                    gt = labels_val.detach().cpu().numpy()

                    running_metrics_val.update(gt, pred)

                    loss = loss_fn(input=outputs_val, target=labels_val)
                    
                    loss_val += loss.item()

                    total_iteration_val = total_iteration_val + 1

                    if i_val % 20 == 0:
                        print("Epoch [%d/%d] validation Loss: %.4f" %
                              (epoch+1, args.n_epoch, loss.item()))

                    numbers = [0]
                    if i_val in numbers:
                        # number 0 image in the batch
                        tb_original_image = vutils.make_grid(
                            image_original[0][0], normalize=True, scale_each=True)
                        writer.add_image('val/original_image',
                                         tb_original_image, epoch + 1)
                        labels_original = labels_original.numpy()[0]
                        correct_label_decoded = train_set.decode_segmap(
                            np.squeeze(labels_original))
                        writer.add_image('val/original_label',
                                         np_to_tb(correct_label_decoded), epoch + 1)

                        out = F.softmax(outputs_val, dim=1)

                        # this returns the max. channel number:
                        prediction = out.max(1)[1].cpu().detach().numpy()[0]
                        # this returns the confidence:
                        confidence = out.max(1)[0].cpu().detach()[0]
                        tb_confidence = vutils.make_grid(
                            confidence, normalize=True, scale_each=True)

                        decoded = train_set.decode_segmap(
                            np.squeeze(prediction))
                        writer.add_image('val/predicted', np_to_tb(decoded), epoch + 1)
                        writer.add_image('val/confidence',
                                         tb_confidence, epoch + 1)

                        unary = outputs_val.cpu().detach()
                        unary_max, unary_min = torch.max(
                            unary), torch.min(unary)
                        unary = unary.add((-1*unary_min))
                        unary = unary/(unary_max - unary_min)

                        for channel in range(0, len(class_names)):
                            tb_channel = vutils.make_grid(
                                unary[0][channel], normalize=True, scale_each=True)
                            writer.add_image(
                                f'val_classes/_{class_names[channel]}', tb_channel, epoch + 1)
                loss_val /= total_iteration_val
                score, class_iou = running_metrics_val.get_scores()
                
                # Append per-epoch metrics to CSVs (header only on epoch 0):
                metrics_dir = os.path.join(val_fname, "metrics")
                first = (epoch == 0)
                for name, key in [("pixel_acc", "Pixel Acc: "),
                                  ("mean_class_acc", "Mean Class Acc: "),
                                  ("freq_weighted_iou", "Freq Weighted IoU: "),
                                  ("mean_iou", "Mean IoU: ")]:
                    pd.DataFrame([score[key]]).to_csv(
                        os.path.join(metrics_dir, name + ".csv"),
                        index=False, mode='a', header=first)

                cname = os.path.join(metrics_dir, "confusion_matrix",
                                     f"confusion_matrix_{epoch + 1}.csv")
                pd.DataFrame(score["confusion_matrix"]).to_csv(cname,
                                                               index=False)

                pd.DataFrame(score["Class Accuracy: "].reshape((1, 6)),
                             columns=list(range(6))).to_csv(
                                 os.path.join(metrics_dir, "class_acc.csv"),
                                 index=False, mode="a", header=first)
                pd.DataFrame(class_iou, columns=list(range(6)),
                             index=[0]).to_csv(
                                 os.path.join(metrics_dir, "cls_iu.csv"),
                                 mode="a", header=first)
                

                writer.add_scalar(
                    'val/Pixel Acc', score['Pixel Acc: '], epoch+1)
                writer.add_scalar('val/Mean IoU', score['Mean IoU: '], epoch+1)
                writer.add_scalar('val/Mean Class Acc',
                                  score['Mean Class Acc: '], epoch+1)
                writer.add_scalar('val/Freq Weighted IoU',
                                  score['Freq Weighted IoU: '], epoch+1)

                writer.add_scalar('val/loss', loss_val, epoch+1)
                running_metrics_val.reset()

                if score['Mean IoU: '] >= best_iou:
                    best_iou = score['Mean IoU: ']
                    model_dir = os.path.join(
                        log_dir, f"{args.arch}_model.pkl")
                    #torch.save(model, model_dir)

                    torch.save({'epoch': epoch + 1,
                                'state_dict': model.state_dict(),
                                'optimizer': optimizer.state_dict()},
                               model_fname % (epoch + 1))
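                    # A sketch of how to resume from such a checkpoint
                    # (assumes the same model and optimizer are constructed
                    # first; `some_epoch` is a placeholder):
                    # ckpt = torch.load(model_fname % some_epoch)
                    # model.load_state_dict(ckpt['state_dict'])
                    # optimizer.load_state_dict(ckpt['optimizer'])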


        else:  # validation is turned off:
            # just save the latest model:
            if (epoch + 1) % 5 == 0:
                model_dir = os.path.join(
                    log_dir, f"{args.arch}_ep{epoch+1}_model.pkl")
                #torch.save(model, model_dir)
                torch.save({'epoch': epoch + 1,
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict()},
                           model_fname % (epoch + 1))
                        
        writer.add_scalar('train/epoch_lr', optimizer.param_groups[0]["lr"], epoch+1)
        
    writer.close()
Example #4
def train(args):
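    # Weakly-labeled variant: the patch loaders also yield per-pixel
    # confidences (confs) and similarities (sims), and two shuffled copies
    # of the training set can optionally be blended via mixup.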

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Generate the train and validation sets for the model:
    split_train_val_weak(args, per_val=args.per_val)
    loader = patch_loader_weak

    current_time = datetime.now().strftime('%b%d_%H%M%S')
    log_dir = os.path.join('runs',
                           current_time + f"_{args.arch}_{args.model_name}")
    writer = SummaryWriter(log_dir=log_dir)
    # Setup Augmentations
    if args.aug:
        data_aug = Compose(
            [RandomRotate(15),
             RandomHorizontallyFlip(),
             AddNoise()])
    else:
        data_aug = None

    train_set = loader(is_transform=True,
                       split='train',
                       augmentations=data_aug)

    # Without Augmentation:
    val_set = loader(is_transform=True,
                     split='val',
                     patch_size=args.patch_size)

    #if args.mixup:
    #    train_set1 = loader(is_transform=True,
    #                       split='train',
    #                       augmentations=data_aug)

    n_classes = train_set.n_classes

    trainloader = data.DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)

    #####################################################################
    #shuffle and load
    random.shuffle(train_set.patches['train'])  #shuffle the list of patch IDs
    alpha = 0.5
    trainloader1 = data.DataLoader(
        train_set, batch_size=args.batch_size, num_workers=4,
        shuffle=True)  #load the shuffled data again in another loader
    ######################################################################

    valloader = data.DataLoader(val_set,
                                batch_size=args.batch_size,
                                num_workers=4)

    # Setup Metrics
    running_metrics = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            model = torch.load(args.resume)
        else:
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(args.resume))
    else:
        model = get_model(args.arch, args.pretrained, n_classes)

    # Use as many GPUs as we can
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model = model.to(device)  # Send to GPU

    # PYTORCH NOTE: ALWAYS CONSTRUCT OPTIMIZERS AFTER THE MODEL IS PUSHED TO GPU/CPU.

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)

    loss_fn = core.loss.focal_loss2d

    if args.class_weights:
        # weights are inversely proportional to the frequency of the classes in the training set
        class_weights = torch.tensor(
            [0, 0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852],
            device=device,
            requires_grad=False)
    else:
        class_weights = None

    best_iou = -100.0
    class_names = [
        'null', 'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk',
        'scruff', 'zechstein'
    ]

    for arg in vars(args):
        text = arg + ': ' + str(getattr(args, arg))
        writer.add_text('Parameters/', text)

    # training
    for epoch in range(args.n_epoch):
        # Training Mode:
        model.train()
        loss_train, total_iteration = 0, 0

        for (i, (images, labels, confs,
                 sims)), (i1, (images1, labels1, confs1,
                               sims1)) in zip(enumerate(trainloader),
                                              enumerate(trainloader1)):

            N, c, w, h = labels.shape
            one_hot = torch.FloatTensor(N, 7, w, h).zero_()
            labels_hot = one_hot.scatter_(
                1, labels.data,
                1)  # create one hot representation for the labels
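            # labels arrives as (N, 1, w, h) integer class ids; scatter_ along
            # dim 1 expands it into an (N, 7, w, h) one-hot volume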

            if args.mixup:  #if mixup is enabled, mix the two batches
                lam = torch.from_numpy(
                    np.random.beta(alpha, alpha,
                                   (N, 1, 1, 1))).float()  #sample lambda
                one_hot = torch.FloatTensor(N, 7, w, h).zero_()
                labels_hot1 = one_hot.scatter_(
                    1, labels1.data,
                    1)  # one-hot representation for the second batch's labels
                images = lam * images + (1 - lam) * images1
                labels = lam * labels.float() + (1 - lam) * labels1.float()
                labels_hot = lam * labels_hot + (1 - lam) * labels_hot1
                confs = (lam * confs.squeeze() +
                         (1 - lam) * confs1.squeeze())
                sims = (lam.squeeze() * sims.float() +
                        (1 - lam).squeeze() * sims1.float())  #mixup
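            # lam has shape (N, 1, 1, 1), so every sample in the batch gets
            # its own mixing coefficient, broadcast over channels and pixels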

            image_original = images  #TODO Q: Are the passed original labels correct? in the context of the following comparison in line 233
            images, labels_hot, confs, sims = images.to(device), labels_hot.to(
                device), confs.to(device), sims.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            pred = outputs.detach().max(1)[1].cpu().numpy()
            labels_original = confs.squeeze().permute(
                0, 3, 1, 2).detach().max(1)[1].cpu().numpy()
            running_metrics.update(labels_original, pred)
            loss = loss_fn(input=outputs,
                           target=labels_hot,
                           conf=confs,
                           alpha=class_weights,
                           sim=sims,
                           gamma=args.gamma,
                           loss_type=args.loss_parameters,
                           soft_dev=args.soft_dev)
            loss_train += loss.item()
            loss.backward()

            # gradient clipping
            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            total_iteration = total_iteration + 1

            if i % 20 == 0:
                print("Epoch [%d/%d] training Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.item()))

            numbers = [0, 14, 29]
            if i in numbers:

                tb_original_image = vutils.make_grid(image_original[i][0],
                                                     normalize=True,
                                                     scale_each=True)
                writer.add_image('train/original_image', tb_original_image,
                                 epoch + 1)

                # tb_confs_original = vutils.make_grid(confs_tb, normalize=True, scale_each=True)
                # writer.add_image('train/confs_original',tb_confs_original, epoch +1)

                labels_original = labels_original[i]
                correct_label_decoded = train_set.decode_segmap(
                    np.squeeze(labels_original))
                writer.add_image('train/original_label',
                                 np_to_tb(correct_label_decoded), epoch + 1)
                out = F.softmax(outputs, dim=1)

                # this returns the max. channel number:
                prediction = out.max(1)[1].cpu().numpy()[0]
                # this returns the confidence:
                confidence = out.max(1)[0].cpu().detach()[0]
                tb_confidence = vutils.make_grid(confidence,
                                                 normalize=True,
                                                 scale_each=True)

                decoded = train_set.decode_segmap(np.squeeze(prediction))
                writer.add_image('train/predicted', np_to_tb(decoded),
                                 epoch + 1)
                writer.add_image('train/confidence', tb_confidence, epoch + 1)

                unary = outputs.cpu().detach()
                unary_max = torch.max(unary)
                unary_min = torch.min(unary)
                unary = unary.add((-1 * unary_min))
                unary = unary / (unary_max - unary_min)

                for channel in range(0, len(class_names)):
                    decoded_channel = unary[0][channel]
                    tb_channel = vutils.make_grid(decoded_channel,
                                                  normalize=True,
                                                  scale_each=True)
                    writer.add_image(f'train_classes/_{class_names[channel]}',
                                     tb_channel, epoch + 1)

        # Average metrics, and save in writer()
        loss_train /= total_iteration
        score, class_iou = running_metrics.get_scores()
        writer.add_scalar('train/Pixel Acc', score['Pixel Acc: '], epoch + 1)
        writer.add_scalar('train/Mean Class Acc', score['Mean Class Acc: '],
                          epoch + 1)
        writer.add_scalar('train/Freq Weighted IoU',
                          score['Freq Weighted IoU: '], epoch + 1)
        writer.add_scalar('train/Mean_IoU', score['Mean IoU: '], epoch + 1)

        confusion = score['confusion_matrix']
        writer.add_image(f'train/confusion matrix', np_to_tb(confusion),
                         epoch + 1)

        running_metrics.reset()
        writer.add_scalar('train/loss', loss_train, epoch + 1)

        if args.per_val != 0:
            with torch.no_grad():  # operations inside don't track history
                # Validation Mode:
                model.eval()
                loss_val, total_iteration_val = 0, 0

                for i_val, (images_val, labels_val, conf_val,
                            sim_val) in tqdm(enumerate(valloader)):

                    N, c, w, h = labels_val.shape
                    one_hot = torch.FloatTensor(N, 7, w, h).zero_()
                    labels_hot_val = one_hot.scatter_(
                        1, labels_val.data,
                        1)  # create one hot representation for the labels

                    image_original, labels_original = images_val, labels_val
                    images_val, labels_hot_val, conf_val, sim_val = images_val.to(
                        device), labels_hot_val.to(device), conf_val.to(
                            device), sim_val.to(device)

                    outputs_val = model(images_val)
                    pred = outputs_val.detach().max(1)[1].cpu().numpy()
                    gt = labels_val.numpy()

                    running_metrics_val.update(gt, pred)

                    loss = loss_fn(input=outputs_val,
                                   target=labels_hot_val,
                                   conf=conf_val,
                                   alpha=class_weights,
                                   sim=sim_val,
                                   gamma=args.gamma,
                                   loss_type=args.loss_parameters,
                                   soft_dev=args.soft_dev)
                    loss_val += loss.item()

                    total_iteration_val = total_iteration_val + 1

                    if i_val % 20 == 0:
                        print("Epoch [%d/%d] validation Loss: %.4f" %
                              (epoch + 1, args.n_epoch, loss.item()))

                    numbers = [0]
                    if i_val in numbers:

                        tb_original_image = vutils.make_grid(
                            image_original[i_val][0],
                            normalize=True,
                            scale_each=True)
                        writer.add_image('val/original_image',
                                         tb_original_image, epoch + 1)
                        labels_original = labels_original.numpy()[0]
                        correct_label_decoded = train_set.decode_segmap(
                            np.squeeze(labels_original))
                        writer.add_image('val/original_label',
                                         np_to_tb(correct_label_decoded),
                                         epoch + 1)

                        out = F.softmax(outputs_val, dim=1)

                        # this returns the max. channel number:
                        prediction = out.max(1)[1].cpu().detach().numpy()[0]
                        # this returns the confidence:
                        confidence = out.max(1)[0].cpu().detach()[0]
                        tb_confidence = vutils.make_grid(confidence,
                                                         normalize=True,
                                                         scale_each=True)

                        decoded = train_set.decode_segmap(
                            np.squeeze(prediction))
                        writer.add_image('val/predicted', np_to_tb(decoded),
                                         epoch + 1)
                        writer.add_image('val/confidence', tb_confidence,
                                         epoch + 1)

                        unary = outputs_val.cpu().detach()
                        unary_max, unary_min = torch.max(unary), torch.min(
                            unary)
                        unary = unary.add((-1 * unary_min))
                        unary = unary / (unary_max - unary_min)

                        for channel in range(0, len(class_names)):
                            tb_channel = vutils.make_grid(unary[0][channel],
                                                          normalize=True,
                                                          scale_each=True)
                            writer.add_image(
                                f'val_classes/_{class_names[channel]}',
                                tb_channel, epoch + 1)

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)

                writer.add_scalar('val/Pixel Acc', score['Pixel Acc: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean IoU', score['Mean IoU: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean Class Acc',
                                  score['Mean Class Acc: '], epoch + 1)
                writer.add_scalar('val/Freq Weighted IoU',
                                  score['Freq Weighted IoU: '], epoch + 1)

                confusion = score['confusion_matrix']
                writer.add_image(f'val/confusion matrix', np_to_tb(confusion),
                                 epoch + 1)
                writer.add_scalar('val/loss',
                                  loss_val / total_iteration_val, epoch + 1)
                running_metrics_val.reset()

                if score['Mean IoU: '] >= best_iou:
                    best_iou = score['Mean IoU: ']
                    model_dir = os.path.join(log_dir,
                                             f"{args.arch}_model_best.pkl")
                    torch.save(model, model_dir)

                if epoch % 10 == 0:
                    model_dir = os.path.join(
                        log_dir, f"{args.arch}_ep{epoch}_model.pkl")
                    torch.save(model, model_dir)

        else:  # validation is turned off:
            # just save the latest model:
            if epoch % 10 == 0:
                model_dir = os.path.join(log_dir,
                                         f"{args.arch}_ep{epoch+1}_model.pkl")
                torch.save(model, model_dir)

    writer.close()