def main():
    """Train the segmentation U-Net on unweighted-average base maps.

    Reads paths and hyper-parameters from the global ``cfg``, builds the
    train/test loaders, fuses the occluded input views into a base map by
    plain averaging, trains ``Unet_class`` with a class-weighted
    cross-entropy loss, checkpoints the best model by validation loss and
    writes loss logs plus a loss-curve plot into ``cfg.train.out_dir``.
    """
    # setting output directory
    out_dir = cfg.train.out_dir
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    else:
        print('Folder already exists. Are you sure you want to overwrite results?')
        print('Debug')  # put a break point here

    print('Configuration:')
    print(cfg)

    ## Data loaders
    # Training data loader
    cfg.train.mode = 'train'
    ds_train = get_dataset(cfg.train.mode)

    # validation data loader
    cfg.train.mode = 'test'
    ds_test = get_dataset(cfg.train.mode)  # ds
    print('Data loaders have been prepared!')

    ## Getting the model
    cfg.train.mode = 'train'

    # segmentation network
    net = Unet_class()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)

    print('Network loaded. Starting training...')

    # class weights for building, road, BG (to balance class frequencies)
    my_weight = torch.from_numpy(np.asarray([1, 2, 0.5])).type('torch.cuda.FloatTensor')
    criterion = torch.nn.CrossEntropyLoss(weight=my_weight)

    # L1 loss used only to monitor basemap quality against the clean images
    l1_maps = torch.nn.L1Loss(reduction='sum')

    # optimizer
    optim = torch.optim.Adam(net.parameters(), lr=cfg.train.learning_rate,
                             weight_decay=cfg.train.learning_rate_decay)

    # learning rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=5, gamma=0.5)

    ep = cfg.train.num_epochs  # number of epochs

    # loss logs, initialized high so the first epoch always checkpoints
    loss_train = 9999.0 * np.ones(ep)
    temp_train_loss = 0
    loss_val = 9999.0 * np.ones(ep)

    # training the network
    for epoch in range(ep):
        running_loss = 0.0
        running_ctr = 0

        # switch model to training mode, clear gradient accumulators
        net.train()
        optim.zero_grad()

        t1 = datetime.now()
        for i, data in enumerate(ds_train, 0):

            optim.zero_grad()

            # clean reference images (used only for the L1 monitoring metric)
            images = data[0].type('torch.cuda.FloatTensor')

            # segmentation labels
            labels = data[1].type('torch.cuda.LongTensor')

            # occluded input views
            occluded_imgs = data[2]

            # computing unweighted average of the occluded views
            base_map = 0.0 * occluded_imgs[0].type('torch.cuda.FloatTensor')  # initialize with zero
            for j in range(len(occluded_imgs)):
                base_map = base_map + occluded_imgs[j].type('torch.cuda.FloatTensor')  # avoiding inline operation i.e. +=

            # BUGFIX: np.float was removed in NumPy 1.20+/1.24; builtin float is equivalent
            base_map = base_map / float(len(occluded_imgs))

            predicted = net(base_map)

            loss = criterion(predicted, labels)

            loss.backward()

            optim.step()

            # print statistics every 25 batches
            running_loss += loss.item()
            running_ctr += 1
            if i % 25 == 0:
                t2 = datetime.now()
                delta = t2 - t1
                t_print = delta.total_seconds()
                temp_train_loss = running_loss / 25.0
                print('[%d, %5d out of %5d] loss: %f, time = %f' %
                      (epoch + 1, i + 1, len(ds_train), running_loss / running_ctr, t_print))

                iou_build, iou_road, iou_bg = IoU(predicted, labels)
                print('building IoU = ' + str(iou_build) + ', road IoU = ' + str(iou_road) + ', background IoU = ' + str(iou_bg))

                basemap_error = l1_maps(base_map, images)
                print('L1 error (base map, true image) = ' + str(basemap_error.item()))
                running_loss = 0.0
                running_ctr = 0
                t1 = t2

        # BUGFIX: step the scheduler once per epoch AFTER the optimizer
        # updates; calling it before any optim.step() (as before) skips the
        # initial learning rate and triggers the PyTorch >=1.1 ordering warning
        scheduler.step()

        # at the end of every epoch, calculating val loss
        net.eval()
        val_loss = 0

        with torch.no_grad():
            for i, data in enumerate(ds_test, 0):
                # get clean images
                images = data[0].type('torch.cuda.FloatTensor')

                # labels
                labels = data[1].type('torch.cuda.LongTensor')

                # occluded images, fused by plain averaging as in training
                occluded_imgs = data[2]
                base_map = 0.0 * occluded_imgs[0].type('torch.cuda.FloatTensor')  # initialize with zero
                for j in range(len(occluded_imgs)):
                    base_map = base_map + occluded_imgs[j].type(
                        'torch.cuda.FloatTensor')  # avoiding inline operation i.e. +=

                base_map = base_map / float(len(occluded_imgs))
                predicted = net(base_map)

                loss = criterion(predicted, labels)  # for cross entropy

                # accumulate val loss
                val_loss += loss.item()

            # print statistics
            val_loss = val_loss / len(ds_test)
            print('End of epoch ' + str(epoch + 1) + '. Val loss is ' + str(val_loss))

            print('Following stats are only for the last batch of the test set:')
            iou_build, iou_road, iou_bg = IoU(predicted, labels)
            print('building IoU = ' + str(iou_build) + ', road IoU = ' + str(
                iou_road) + ', background IoU = ' + str(iou_bg))
            basemap_error = l1_maps(base_map, images)
            print('L1 error (base map, true image) = ' + str(basemap_error.item()))

            # Model check point: save whenever this epoch beats the best val loss so far
            if val_loss < np.min(loss_val, axis=0):
                model_path = os.path.join(out_dir, "trained_model_checkpoint.pth")
                torch.save(net, model_path)   # segmentation network

            # saving losses
            loss_val[epoch] = val_loss
            loss_train[epoch] = temp_train_loss

            temp_train_loss = 0  # setting additive losses to zero

    print('Training finished')
    # saving model
    model_path = os.path.join(out_dir, "trained_model_end.pth")
    torch.save(net, model_path)

    # Saving logs
    log_name = os.path.join(out_dir, "logging.txt")
    with open(log_name, 'w') as result_file:
        result_file.write('Logging... \n')
        result_file.write('Validation loss ')
        result_file.write(str(loss_val))
        result_file.write('\nTraining loss  ')
        result_file.write(str(loss_train))

    print('Model saved')

    # saving loss curves
    # BUGFIX: loss_val / loss_train are plain numpy arrays, so the previous
    # .cpu().detach().numpy() calls raised AttributeError here
    a = loss_val
    b = loss_train
    print(a[0:epoch])

    plt.figure()
    plt.plot(b[0:epoch])
    plt.plot(a[0:epoch])
    plt.title('Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(['Training loss', 'Validation Loss'])
    fname1 = str('loss.png')
    plt.savefig(os.path.join(out_dir, fname1), bbox_inches='tight')

    print('All done!')
# ===== Example 2 =====
def main():
    """Save qualitative fusion results for the test (or Potsdam) set.

    Loads the trained quality network, computes a per-view quality mask,
    fuses the occluded views into a basemap weighted by softmaxed quality
    scores, and writes target / basemap / average-baseline / input / mask
    images into ``<cfg.train.out_dir>/qual_results``.
    """
    out_dir = cfg.train.out_dir
    if not os.path.exists(out_dir):
        raise ValueError(
            'The folder does not exist. Make sure to set the correct folder variable cfg.train.out_dir in config.py'
        )

    if os.path.exists(os.path.join(out_dir, 'qual_results')):
        raise ValueError(
            'The validation folder image_results already exists. Delete the folder if those results are not needed'
        )
    else:
        os.makedirs(os.path.join(out_dir, 'qual_results'))

    # quality network checkpoint saved as a whole model
    qual_net = torch.load(
        os.path.join(out_dir, "trained_basemap_checkpoint.pth"))

    print('Network loaded...')
    print(cfg)

    ## Data loader
    # only test/validation set is needed
    if eval_potsdam:
        cfg.train.mode = 'test_potsdam'
    else:
        cfg.train.mode = 'test'

    ds_test = get_dataset(cfg.train.mode)
    print('Data loaders have been prepared!')

    qual_net.eval()

    ctr = 0
    with torch.no_grad():
        for i, data in enumerate(ds_test, 0):
            # reading clean images
            images = data[0].type('torch.cuda.FloatTensor')

            # occluded images
            occluded_imgs = data[2]

            # quality scores of all views, stacked along dim 1
            q_pre = torch.zeros(
                occluded_imgs[0].shape[0], len(occluded_imgs),
                occluded_imgs[0].shape[1],
                occluded_imgs[0].shape[2]).type('torch.cuda.FloatTensor')

            for j in range(len(occluded_imgs)):  # compute all the quality masks
                q_now = qual_net(
                    occluded_imgs[j].type('torch.cuda.FloatTensor'))
                q_pre[:, j, :, :] = q_now[:, 0, :, :]

            # softmax across the quality-mask (view) dimension
            q_final = F.softmax(1 * q_pre, dim=1)

            # make the final basemap, weighting each view by its quality score
            base_map = 0.0 * occluded_imgs[0].type(
                'torch.cuda.FloatTensor')  # initialization with zero
            for j in range(len(occluded_imgs)):
                image_now = occluded_imgs[j].type('torch.cuda.FloatTensor')
                base_map = base_map + q_final[:, j, :, :].view(
                    q_now.shape).permute(0, 2, 3, 1) * image_now

            # computing unweighted average as baseline
            average_image = 0.0 * occluded_imgs[0].type(
                'torch.cuda.FloatTensor')  # initialize with zero
            for j in range(len(occluded_imgs)):
                average_image = average_image + occluded_imgs[j].type(
                    'torch.cuda.FloatTensor'
                )  # avoiding inline operation i.e. +=

            # BUGFIX: np.float was removed in NumPy 1.20+; builtin float is equivalent
            average_image = average_image / float(len(occluded_imgs))

            num_fig = np.minimum(base_map.shape[0], 18)

            plt.ioff()
            # save results of this batch; plt.close() after each savefig
            # prevents matplotlib from accumulating open figures over the
            # whole dataset (previously a memory leak / warning flood)
            for k in range(num_fig):
                # target output
                plt.figure()
                plt.imshow(images[k, :, :, :].detach().cpu().numpy())
                plt.axis('off')
                fname1 = str(str(ctr) + '_target' + '.png')
                plt.savefig(os.path.join(out_dir, 'qual_results', fname1),
                            bbox_inches='tight')
                plt.close()

                # basemap (NOTE: the original saved this figure twice to the
                # same filename; the redundant duplicate save was removed)
                plt.figure()
                plt.imshow(base_map[k, :, :, :].detach().cpu().numpy())
                plt.axis('off')
                fname1 = str(str(ctr) + '_out_basemap' + '.png')
                plt.savefig(os.path.join(out_dir, 'qual_results', fname1),
                            bbox_inches='tight')
                plt.close()

                # baseline
                plt.figure()
                plt.imshow(average_image[k, :, :, :].detach().cpu().numpy())
                plt.axis('off')
                fname1 = str(str(ctr) + '_out_average' + '.png')
                plt.savefig(os.path.join(out_dir, 'qual_results', fname1),
                            bbox_inches='tight')
                plt.close()

                # input images
                for j in range(len(occluded_imgs)):
                    plt.figure()
                    plt.imshow(occluded_imgs[j][k, :, :, :])
                    plt.axis('off')
                    fname1 = str(str(ctr) + '_image' + str(j) + '.png')
                    plt.savefig(os.path.join(out_dir, 'qual_results', fname1),
                                bbox_inches='tight')
                    plt.close()

                # quality masks
                for j in range(len(occluded_imgs)):
                    plt.figure()
                    plt.imshow(q_final[k, j, :, :].detach().cpu().numpy())
                    plt.axis('off')
                    fname1 = str(str(ctr) + '_mask' + str(j) + '.png')
                    plt.savefig(os.path.join(out_dir, 'qual_results', fname1),
                                bbox_inches='tight')
                    plt.close()

                ctr += 1
# ===== Example 3 =====
def main():
    """Evaluate the most recent trained segmentation model on the test set.

    Reads ``config.json`` from the project root, loads the newest run
    directory under ``<root>/models``, accumulates IoU / accuracy metrics
    over ``repeat_times`` passes of the test loader, prints them and
    writes them to a timestamped file under ``<root>/reports``.
    """
    ROOT_DIR = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    with open(os.path.join(ROOT_DIR, 'config.json')) as json_file:
        config = json.load(json_file)

    out_dir = os.path.join(ROOT_DIR, 'models')

    if not os.path.exists(out_dir):
        raise ValueError(
            'The folder does not exist. Make sure the training has been completed before running this script'
        )

    segment_net = Unet_class()

    # glob() returns full paths (ROOT_DIR is absolute), so latest_model is
    # already rooted under out_dir; max() picks the lexicographically newest
    # timestamped run directory. The previous os.path.join(out_dir, ...)
    # wrappers were no-ops on these absolute paths and have been dropped.
    latest_model = max(glob(out_dir + "/*"))
    trained_model = glob(os.path.join(latest_model, "*.pth"))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # loading segmentation net: when both the checkpoint and the
    # end-of-training model exist, prefer the second entry (the checkpoint),
    # matching the original index choice
    model_file = trained_model[1] if len(trained_model) == 2 else trained_model[0]
    segment_net.load_state_dict(torch.load(model_file))

    segment_net.eval()
    segment_net.to(device)

    print('Network loaded...')
    print(config)

    ## getting the dataset
    mode = 'test'

    ds_test = get_dataset(mode, config)
    print('Data loaders have been prepared!')

    # Initialize metrics
    iou_build = 0
    iou_road = 0
    iou_bg = 0
    mIoU = 0
    fwIou = 0
    acc = 0

    with torch.no_grad():
        for t in range(repeat_times):  # evaluate everything repeat_times times

            for i, data in enumerate(ds_test, 0):
                images = data[0].type('torch.FloatTensor')  # reading images

                # labels
                labels = data[1].type('torch.LongTensor')

                # segmentation performance
                predicted = segment_net(images)
                i1, i2, i3, i4, i5, i6 = IoU(predicted, labels, extra=True)
                iou_build += i1
                iou_road += i2
                iou_bg += i3
                mIoU += i4
                fwIou += i5
                acc += i6

            # BUGFIX: added the missing space between the counter and 'out of'
            print('Completed ' + str(t) + ' out of ' + str(repeat_times))

    # average of segmentation numbers
    iou_build /= (len(ds_test) * repeat_times)
    iou_road /= (len(ds_test) * repeat_times)
    iou_bg /= (len(ds_test) * repeat_times)
    mIoU /= (len(ds_test) * repeat_times)
    fwIou /= (len(ds_test) * repeat_times)

    acc /= (len(ds_test) * repeat_times)

    print('Building IoU on test set = ' + str(iou_build))
    print('Road IoU on test set = ' + str(iou_road))
    print('BG IoU on test set = ' + str(iou_bg))
    print('mIoU on test set = ' + str(mIoU))
    print('Frequency weighted IoU on test set = ' + str(fwIou))
    print('Pixel accuracy on test set = ' + str(acc))

    # NOTE(review): `now` is presumably a module-level datetime captured at
    # import time — confirm it exists in the enclosing module
    fname = os.path.join(ROOT_DIR, 'reports', now.strftime("%Y%m%d-%H%M%S"))
    if not os.path.exists(fname):
        os.makedirs(fname)
    fname = os.path.join(fname, 'eval_results.txt')

    # saving results on disk
    with open(fname, 'w') as result_file:
        result_file.write('Logging... \n')
        result_file.write('\nBuilding IoU on test set =   ')
        result_file.write(str(iou_build))
        result_file.write('\nRoad IoU on test set =   ')
        result_file.write(str(iou_road))
        result_file.write('\nBG IoU on test set =   ')
        result_file.write(str(iou_bg))
        result_file.write('\nMean IoU on test set =   ')
        result_file.write(str(mIoU))
        result_file.write('\nfrequency weighted IoU on test set =   ')
        result_file.write(str(fwIou))
        result_file.write('\nPixel accuracy on test set =   ')
        result_file.write(str(acc))

    print('All done. Results saved in reports directory')
# ===== Example 4 =====
from torch.autograd import Variable

from loss import loss_function
from utils import Gaussian_Conv_Update, apply_warp, af_plus_loss, get_grid

# setup checkpoint directory (warn rather than fail if it already exists)
out_dir = cfg.train.out_dir
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
else:
    print('Folder already exists. Are you sure you want to overwrite results?')

print('Configuration: \n', cfg)

# dataloaders for the dataset named in the config
train_loader = get_dataset(dataset_name=cfg.data.name, mode='train')
test_loader = get_dataset(dataset_name=cfg.data.name, mode='test')
print('Data loaders have been prepared!')

# networks
# NOTE(review): get_network is project-local; the string keys presumably
# select architectures registered elsewhere — confirm against its module
af_plus = get_network('AF_plus')          # AF++ generator (pretrained, see below)
fds = get_network('FDS')
fusion = get_network('GAF')
feature_loss = get_network('feature_loss')
discrim = get_network('discriminator')    # GAN discriminator
gan_loss = get_network('gan_loss')

# load weights of a trained AF++ and switch it to eval mode
# (eval() affects dropout/batchnorm only; it does not freeze gradients)
af_plus.load_state_dict(
    torch.load(os.path.join('checkpoints', 'af_plus_dict.pth')))
af_plus.eval()
# ===== Example 5 =====
def main():
    """Evaluate fused basemaps (quality-net weighted or plain average).

    Depending on the module-level flags ``eval_baseline`` and
    ``eval_potsdam``, fuses the occluded views either with the trained
    quality network or by plain averaging, then measures L1 error against
    the clean images and segmentation IoU / accuracy over ``repeat_times``
    passes of the test loader. Results are printed and written to a text
    file in ``cfg.train.out_dir``.
    """
    out_dir = cfg.train.out_dir
    if not os.path.exists(out_dir):
        raise ValueError(
            'The folder does not exist. Make sure to set the correct folder variable cfg.train.out_dir in config.py'
        )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if not eval_baseline:  # quality net based method
        # loading quality net (saved as a whole model)
        qual_net = torch.load(
            os.path.join(out_dir, "trained_basemap_checkpoint.pth"))
        qual_net.eval()
        qual_net.to(device)

    # loading segmentation net
    segment_net = torch.load(
        os.path.join(out_dir, "trained_model_checkpoint.pth"))
    segment_net.eval()
    segment_net.to(device)

    print('Network loaded...')
    print(cfg)

    ## getting the dataset
    if eval_potsdam:
        cfg.train.mode = 'test_potsdam'  # Potsdam
    else:
        cfg.train.mode = 'test'  # Val set of Berlin

    ds_test = get_dataset(cfg.train.mode)
    print('Data loaders have been prepared!')

    # Metrics for pixel-wise comparison of fused images to respective clean images
    l1_error_abs = torch.nn.L1Loss(
        reduction='sum')  # This is dependent on image size
    l1_error_mean = torch.nn.L1Loss(
        reduction='mean')  # This is reported in the paper

    # Initialize metrics
    abs_error = 0
    mean_error = 0

    iou_build = 0
    iou_road = 0
    iou_bg = 0
    mIoU = 0
    fwIou = 0
    acc = 0

    with torch.no_grad():
        for t in range(repeat_times):  # evaluate everything repeat_times times

            for i, data in enumerate(ds_test, 0):
                images = data[0].type(
                    'torch.cuda.FloatTensor')  # reading images

                # labels
                labels = data[1].type('torch.cuda.LongTensor')
                occluded_imgs = data[2]

                if not eval_baseline:
                    # quality scores of all views, stacked along dim 1
                    q_pre = torch.zeros(occluded_imgs[0].shape[0],
                                        len(occluded_imgs),
                                        occluded_imgs[0].shape[1],
                                        occluded_imgs[0].shape[2]).type(
                                            'torch.cuda.FloatTensor')

                    for j in range(len(
                            occluded_imgs)):  # compute all the quality masks
                        q_now = qual_net(
                            occluded_imgs[j].type('torch.cuda.FloatTensor'))
                        q_pre[:, j, :, :] = q_now[:, 0, :, :]

                    # softmax across the quality-mask (view) dimension
                    q_final = F.softmax(q_pre, dim=1)

                    # synthesize the fused image by combining views,
                    # weighted by their quality scores
                    base_map = 0.0 * occluded_imgs[0].type(
                        'torch.cuda.FloatTensor')  # initialization with zero
                    for j in range(len(occluded_imgs)):
                        image_now = occluded_imgs[j].type(
                            'torch.cuda.FloatTensor')
                        base_map = base_map + q_final[:, j, :, :].view(
                            q_now.shape).permute(0, 2, 3, 1) * image_now
                else:
                    # Baseline: plain unweighted average of the views
                    base_map = 0.0 * occluded_imgs[0].type(
                        'torch.cuda.FloatTensor')  # initialize with zero
                    for j in range(len(occluded_imgs)):
                        base_map = base_map + occluded_imgs[j].type(
                            'torch.cuda.FloatTensor'
                        )  # avoiding inline operation i.e. +=

                    # BUGFIX: divide ONCE after summing. The division used to
                    # sit inside the loop above, so the running sum was
                    # repeatedly scaled down by 1/N every iteration instead of
                    # producing the mean. Also np.float -> float (np.float was
                    # removed from NumPy 1.20+).
                    base_map = base_map / float(len(occluded_imgs))

                loss_abs = l1_error_abs(base_map, images)
                loss_mean = l1_error_mean(base_map, images)

                abs_error += loss_abs.item()
                mean_error += loss_mean.item()

                # segmentation performance
                predicted = segment_net(base_map)
                i1, i2, i3, i4, i5, i6 = IoU(predicted, labels, extra=True)
                iou_build += i1
                iou_road += i2
                iou_bg += i3
                mIoU += i4
                fwIou += i5
                acc += i6

            # BUGFIX: added the missing space between the counter and 'out of'
            print('Completed ' + str(t) + ' out of ' + str(repeat_times))

    # computing average
    abs_error /= (len(ds_test) * repeat_times)
    mean_error /= (len(ds_test) * repeat_times)

    # average of segmentation numbers
    iou_build /= (len(ds_test) * repeat_times)
    iou_road /= (len(ds_test) * repeat_times)
    iou_bg /= (len(ds_test) * repeat_times)
    mIoU /= (len(ds_test) * repeat_times)
    fwIou /= (len(ds_test) * repeat_times)

    acc /= (len(ds_test) * repeat_times)

    print('Mean error on test set = ' + str(mean_error))
    print('Absolute error on test set = ' + str(abs_error))
    print('Building IoU on test set = ' + str(iou_build))
    print('Road IoU on test set = ' + str(iou_road))
    print('BG IoU on test set = ' + str(iou_bg))
    print('mIoU on test set = ' + str(mIoU))
    print('Frequency weighted IoU on test set = ' + str(fwIou))
    print('Pixel accuracy on test set = ' + str(acc))

    # choose the results filename from the evaluation flags
    if eval_potsdam:
        if eval_baseline:
            n1 = str('eval_result_Potsdam_baseline_multiple.txt')
            fname = os.path.join(out_dir, n1)
        else:
            n1 = str('eval_result_Potsdam_multiple.txt')
            fname = os.path.join(out_dir, n1)
    else:
        fname = os.path.join(out_dir, 'eval_result_Berlin.txt')

    # saving results on disk
    with open(fname, 'w') as result_file:
        result_file.write('Logging... \n')
        result_file.write('Mean error on test set =  ')
        result_file.write(str(mean_error))
        result_file.write('\nAbsolute error on test set =   ')
        result_file.write(str(abs_error))
        result_file.write('\nBuilding IoU on test set =   ')
        result_file.write(str(iou_build))
        result_file.write('\nRoad IoU on test set =   ')
        result_file.write(str(iou_road))
        result_file.write('\nBG IoU on test set =   ')
        result_file.write(str(iou_bg))
        result_file.write('\nMean IoU on test set =   ')
        result_file.write(str(mIoU))
        result_file.write('\nfrequency weighted IoU on test set =   ')
        result_file.write(str(fwIou))
        result_file.write('\nPixel accuracy on test set =   ')
        result_file.write(str(acc))

    print('All done. Results saved in eval_result.txt in output directory')
def main():
    """Train the plain segmentation U-Net directly on clean images.

    Reads ``config.json`` from the project root, creates a timestamped run
    directory under ``<root>/models``, trains ``Unet_class`` with a
    class-weighted cross-entropy loss, checkpoints the best model (by
    validation loss) as a state dict, and writes loss logs plus a
    loss-curve plot into the run directory.
    """
    ROOT_DIR = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    with open(os.path.join(ROOT_DIR, 'config.json')) as json_file:
        config = json.load(json_file)

    # NOTE(review): `now` is presumably a module-level datetime captured at
    # import time — confirm it exists in the enclosing module
    out_dir = os.path.join(ROOT_DIR, 'models', now.strftime("%Y%m%d-%H%M%S"))

    # setting output directory
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    else:
        print('Folder already exists. overwriting the results......')

    print('Configuration:')
    print(config['train'])

    ## Data loaders
    mode = 'train'  # Training data loader
    ds_train = get_dataset(mode, config)

    mode = 'test'  # validation data loader
    ds_test = get_dataset(mode, config)
    print('Data loaders have been prepared!')

    ## Model
    mode = 'train'
    net = Unet_class()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)

    print('Network loaded. Starting training...')

    # weights for  # building, road, BG - To balance the classes
    my_weight = torch.from_numpy(np.asarray([1, 2,
                                             0.5])).type('torch.FloatTensor')
    criterion = torch.nn.CrossEntropyLoss(weight=my_weight)

    optim = torch.optim.Adam(
        net.parameters(),
        lr=config['train']['learning_rate'],
        weight_decay=config['train']['learning_rate_decay'])

    # learning rate scheduler (halves the LR every 5 epochs)
    scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=5, gamma=0.5)

    ep = config['train']['num_epochs']  # number of epochs

    # loss logs, initialized high so the first epoch always checkpoints
    loss_train = 9999.0 * np.ones(ep)
    temp_train_loss = 0
    loss_val = 9999.0 * np.ones(ep)

    # training the network
    for epoch in range(ep):
        running_loss = 0.0
        running_ctr = 0

        # switch model to training mode, clear gradient accumulators
        net.train()

        # scheduler.step()  # update learning rate

        optim.zero_grad()

        t1 = datetime.now()

        for i, data in enumerate(ds_train, 0):
            optim.zero_grad()

            # reading images
            images = data[0].type('torch.FloatTensor')
            # labels
            labels = data[1].type('torch.LongTensor')

            predicted = net(images)

            loss = criterion(predicted, labels)

            loss.backward()

            optim.step()

            # print statistics every 25 batches
            running_loss += loss.item()
            running_ctr += 1
            if i % 25 == 0:
                t2 = datetime.now()
                delta = t2 - t1
                t_print = delta.total_seconds()
                # last windowed average; logged for this epoch further below
                temp_train_loss = running_loss / 25.0
                print('[%d, %5d out of %5d] loss: %f, time = %f' %
                      (epoch + 1, i + 1, len(ds_train),
                       running_loss / running_ctr, t_print))

                iou_build, iou_road, iou_bg = IoU(predicted, labels)
                print('building IoU = ' + str(iou_build) + ', road IoU = ' +
                      str(iou_road) + ', background IoU = ' + str(iou_bg))

                running_loss = 0.0
                running_ctr = 0
                t1 = t2

        net.eval()
        # stepping AFTER the epoch's optimizer updates (PyTorch >=1.1 order)
        scheduler.step()  # update learning rate
        val_loss = 0

        with torch.no_grad():
            for i, data in enumerate(ds_test, 0):
                # reading images
                images = data[0].type('torch.FloatTensor')
                # labels
                labels = data[1].type('torch.LongTensor')

                predicted = net(images)

                loss = criterion(predicted, labels)

                # Val loss
                val_loss += loss.item()

            # print statistics
            val_loss = val_loss / len(ds_test)
            print('End of epoch ' + str(epoch + 1) + '. Val loss is ' +
                  str(val_loss))

            print(
                'Following stats are only for the last batch of the test set:')
            iou_build, iou_road, iou_bg = IoU(predicted, labels)
            print('building IoU = ' + str(iou_build) + ', road IoU = ' +
                  str(iou_road) + ', background IoU = ' + str(iou_bg))

            # Model check point: save when this epoch beats the best val loss
            # recorded so far (loss_val is updated after this check)
            if val_loss < np.min(loss_val, axis=0):
                model_path = os.path.join(out_dir,
                                          "trained_model_checkpoint.pth")
                #    torch.save(net, model_path)
                torch.save(net.state_dict(), model_path)
                print('Model saved at epoch ' + str(epoch + 1))

            # saving losses
            loss_val[epoch] = val_loss
            loss_train[epoch] = temp_train_loss

            temp_train_loss = 0  # setting additive losses to zero

    print('Training finished')
    # saving model
    model_path = os.path.join(out_dir, "trained_model_end.pth")
    # torch.save(net, model_path)
    torch.save(net.state_dict(), model_path)

    print('Model saved')

    # Saving logs in a text file in the output directory
    log_name = os.path.join(out_dir, "logging.txt")
    with open(log_name, 'w') as result_file:
        result_file.write('Logging... \n')
        result_file.write('Validation loss ')
        result_file.write(str(loss_val))
        result_file.write('\nTraining loss  ')
        result_file.write(str(loss_train))

    # saving loss curves (both are plain numpy arrays)
    a = loss_val
    b = loss_train

    plt.figure()
    plt.plot(b)
    plt.plot(a)
    plt.title('Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(['Training loss', 'Validation Loss'])
    fname1 = str('loss.png')
    plt.savefig(os.path.join(out_dir, fname1), bbox_inches='tight')

    print('Training finished!!!')