Ejemplo n.º 1
0
def train(opt, model, use_cuda):
    """Train ``model`` with SGD and an unweighted ``CrossEntropyLoss2d``.

    Args:
        opt: Options object. The attributes read here are ``root``,
            ``workers``, ``batch``, ``epochs``, ``port``, ``steps_plot``,
            ``steps_loss`` and ``steps_save``.
        model: Network to optimise. It is put in training mode and updated
            in place; this function does not move it to another device.
        use_cuda: When true, every batch (and the criterion) is moved to the
            GPU with ``.cuda()``.

    Side effects:
        * every ``opt.steps_loss`` steps: prints the mean loss of the epoch
          so far;
        * every ``opt.steps_plot`` steps: sends the first input, prediction
          and target of the batch to a ``Dashboard``;
        * every ``opt.steps_save`` steps: writes the model ``state_dict`` to
          ``fcn8-<epoch>-<step>.pth`` in the current working directory.
        A value <= 0 disables the corresponding action.
    """
    model.train()
    loader = DataLoader(dt_ex(opt.root, input_transform, target_transform, 512),
                        num_workers=opt.workers,
                        batch_size=opt.batch,
                        pin_memory=True,
                        shuffle=True)

    # The previous revision built a class-weighted loss (weight[0] = 0) and
    # then unconditionally overwrote it with an unweighted one that ignored
    # ``use_cuda``. The unweighted loss is what was actually used, so it is
    # kept; it now follows ``use_cuda`` like the rest of the function.
    criterion = CrossEntropyLoss2d()
    if use_cuda:
        criterion = criterion.cuda()

    # NOTE(review): the original positional call SGD(params, 1e-4, .9, 2e-5)
    # binds 2e-5 to ``dampening`` (4th parameter), not ``weight_decay``.
    # Keywords below keep that exact behaviour; confirm whether weight decay
    # was intended.
    optimizer = SGD(model.parameters(), lr=1e-4, momentum=.9, dampening=2e-5)

    if opt.steps_plot > 0:
        board = Dashboard(opt.port)

    # NOTE(review): this runs opt.epochs + 1 passes (epoch = 0..opt.epochs);
    # kept as in the original.
    for epoch in range(opt.epochs + 1):
        epoch_loss = []
        for step, (images, labels) in enumerate(loader):
            if use_cuda:
                images = images.cuda()
                labels = labels.cuda()
            inputs = Variable(images)
            targets = Variable(labels)

            outputs = model(inputs)

            optimizer.zero_grad()

            # targets[:, 0] selects index 0 along dimension 1 of the labels.
            loss = criterion(outputs, targets[:, 0])
            loss.backward()
            optimizer.step()

            # ``loss.data[0]`` fails on 0-dim tensors in current PyTorch;
            # ``item()`` is the supported way to read the scalar.
            epoch_loss.append(loss.item())

            if opt.steps_plot > 0 and step % opt.steps_plot == 0:
                image = inputs[0].cpu().data
                # Per-channel x * std + mean; presumably undoes the
                # normalisation applied by ``input_transform`` (confirm).
                for channel, (std, mean) in enumerate(
                        zip((.229, .224, .225), (.485, .456, .406))):
                    image[channel] = image[channel] * std + mean
                board.image(image,
                            f'input (epoch: {epoch}, step: {step})')
                board.image(color_transform(outputs[0].cpu().max(0)[1].data),
                            f'output (epoch: {epoch}, step: {step})')
                board.image(color_transform(targets[0].cpu().data),
                            f'target (epoch: {epoch}, step: {step})')
            if opt.steps_loss > 0 and step % opt.steps_loss == 0:
                average = sum(epoch_loss) / len(epoch_loss)
                print(f'loss: {average} (epoch: {epoch}, step: {step})')
            if opt.steps_save > 0 and step % opt.steps_save == 0:
                filename = f'fcn8-{epoch:03}-{step:04}.pth'
                torch.save(model.state_dict(), filename)
                print(f'save: {filename} (epoch: {epoch}, step: {step})')
Ejemplo n.º 2
0
def train(args, model):
    """Train ``model`` on the ``dt_ma`` dataset with an unweighted loss.

    Args:
        args: Options object. The attributes read here are ``datadir``,
            ``num_workers``, ``batch_size``, ``cuda``, ``model``,
            ``num_epochs``, ``port``, ``steps_plot``, ``steps_loss`` and
            ``steps_save``.
        model: Network to optimise. It is put in training mode and updated
            in place; this function does not move it to another device.

    The optimizer depends on the prefix of ``args.model``: ``'FCN'``,
    ``'PSP'`` and ``'Seg'`` each get an SGD configuration, anything else
    falls back to ``Adam`` with default settings.

    Side effects:
        Prints the running mean loss, plots to a ``Dashboard`` and writes
        ``<args.model>-<epoch>-<step>.pth`` state dicts to the current
        working directory, at the intervals given by ``args.steps_loss``,
        ``args.steps_plot`` and ``args.steps_save`` (<= 0 disables each).
    """
    model.train()

    loader = DataLoader(dt_ma(args.datadir, input_transform, target_transform),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=True)

    # The previous revision built a class-weighted loss (weight[0] = 0) and
    # then unconditionally overwrote it with this unweighted one, so the
    # weighted variant never took effect. Only the effective loss is kept.
    criterion = CrossEntropyLoss2d()

    # Build the optimizer once instead of creating Adam and replacing it.
    # NOTE(review): the original positional calls SGD(params, lr, .9, x) bind
    # ``x`` to ``dampening`` (4th parameter), not ``weight_decay``. Keywords
    # below keep that exact behaviour; confirm whether weight decay was meant.
    if args.model.startswith('FCN'):
        optimizer = SGD(model.parameters(), lr=1e-4, momentum=.9,
                        dampening=2e-5)
    elif args.model.startswith('PSP'):
        optimizer = SGD(model.parameters(), lr=1e-2, momentum=.9,
                        dampening=1e-4)
    elif args.model.startswith('Seg'):
        optimizer = SGD(model.parameters(), lr=1e-3, momentum=.9)
    else:
        optimizer = Adam(model.parameters())

    if args.steps_plot > 0:
        board = Dashboard(args.port)

    for epoch in range(1, args.num_epochs + 1):
        epoch_loss = []

        for step, (images, labels) in enumerate(loader):
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            inputs = Variable(images)
            targets = Variable(labels)
            outputs = model(inputs)

            optimizer.zero_grad()
            # targets[:, 0] selects index 0 along dimension 1 of the labels.
            loss = criterion(outputs, targets[:, 0])
            loss.backward()
            optimizer.step()

            # ``loss.data[0]`` fails on 0-dim tensors in current PyTorch;
            # ``item()`` is the supported way to read the scalar.
            epoch_loss.append(loss.item())
            if args.steps_plot > 0 and step % args.steps_plot == 0:
                image = inputs[0].cpu().data
                # Per-channel x * std + mean; presumably undoes the
                # normalisation applied by ``input_transform`` (confirm).
                for channel, (std, mean) in enumerate(
                        zip((.229, .224, .225), (.485, .456, .406))):
                    image[channel] = image[channel] * std + mean
                board.image(image, f'input (epoch: {epoch}, step: {step})')
                board.image(color_transform(outputs[0].cpu().max(0)[1].data),
                            f'output (epoch: {epoch}, step: {step})')
                board.image(color_transform(targets[0].cpu().data),
                            f'target (epoch: {epoch}, step: {step})')
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss) / len(epoch_loss)
                print(f'loss: {average} (epoch: {epoch}, step: {step})')
            if args.steps_save > 0 and step % args.steps_save == 0:
                filename = f'{args.model}-{epoch:03}-{step:04}.pth'
                torch.save(model.state_dict(), filename)
                print(f'save: {filename} (epoch: {epoch}, step: {step})')
def main():
    """Run the retinal image segmentation script in the phase ``args.phase``.

    ``'train'``: trains and validates for ``args.epoch`` epochs. Whenever the
    validation loss improves, the model ``state_dict`` is written to a
    timestamped output directory; after every epoch the train and validation
    log lines are appended to ``train.log`` in that directory.

    ``'predict'``: calls ``pred_image`` for every PNG found in a hard-coded
    directory, using 128-pixel patches with stride 64 on 512-pixel images.

    Raises:
        ValueError: If ``args.phase`` is neither ``'train'`` nor
            ``'predict'``.
    """
    print('====> Retinal Image Segmentation: ')
    args = parse_args()
    print('====> Parsing Options: ')
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    os.makedirs(args.output, exist_ok=True)
    time_stamp = time.strftime('%Y%m%d%H%M%S')
    # The '_segmentaion_' spelling is part of existing directory names and is
    # deliberately left unchanged.
    output_dir = os.path.join(
        args.output, args.dataset + '_segmentaion_' + args.phase + '_' +
        time_stamp + '_' + args.model + '_' + args.exp)
    if not os.path.exists(output_dir):
        print('====> Creating ', output_dir)
        os.makedirs(output_dir)
    print('====> load model: ')
    model = get_model(args.model, args.num_classes, args.weight)
    criterion = CrossEntropyLoss2d()
    patches_per_image = 200
    print('====> start visualize dashboard: ')
    board = Dashboard(args.port)
    if args.phase == 'train':
        print('=====> Training model:')
        train_dataloader = DataLoader(
            RetinalVesselTrainingDS(args.root, args.size, args.patch_size,
                                    patches_per_image),
            batch_size=args.batch,
            shuffle=True,
            pin_memory=True)

        val_dataloader = DataLoader(
            RetinalVesselValidationDS(args.root_val, args.size,
                                      args.patch_size),
            batch_size=args.batch,
            shuffle=False,
            pin_memory=False)
        best_val_loss = 1e4
        log_path = os.path.join(output_dir, 'train.log')
        for epoch in range(args.epoch):
            # args.lr for the first args.fix epochs, afterwards scaled by
            # 0.1 ** (epoch // args.step).
            if epoch < args.fix:
                lr = args.lr
            else:
                lr = args.lr * (0.1**(epoch // args.step))
            # NOTE(review): a fresh optimizer is built every epoch, so any
            # per-parameter optimizer state (e.g. momentum buffers) restarts
            # each epoch. Also, if the returned class is torch.optim.SGD the
            # 4th positional argument (args.wd) lands on ``dampening``, not
            # ``weight_decay`` -- confirm against get_optimizer.
            optimizer = get_optimizer(args.optim)
            optimizer = optimizer(model.parameters(), lr, args.mom, args.wd)
            # The model is re-wrapped and moved to the GPU on every call
            # because saving a checkpoint below moves it back to the CPU.
            train_logger, _ = train(train_dataloader,
                                    nn.DataParallel(model).cuda(),
                                    criterion, optimizer, epoch,
                                    args.steps_plot, args.steps_loss,
                                    args.steps_save, board)
            val_logger, val_loss = val(val_dataloader,
                                       nn.DataParallel(model).cuda(),
                                       criterion,
                                       epoch,
                                       board=board)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                print_info = 'Current Validation Loss: {}'.format(
                    best_val_loss)
                val_logger.append(print_info)
                print(print_info)
                tmp_file = os.path.join(
                    output_dir, args.dataset + '_segmentation_' + args.model +
                    '_%04d' % epoch + '_best.pth')
                print_info = '====> Save model: {}'.format(tmp_file)
                torch.save(model.cpu().state_dict(), tmp_file)
                val_logger.append(print_info)
                print(print_info)
            # Single append; the options header is written only when the log
            # file does not exist yet.
            is_new_log = not os.path.isfile(log_path)
            with open(log_path, 'a') as fp:
                if is_new_log:
                    fp.write(str(args) + '\n\n')
                fp.write('\n' + '\n'.join(train_logger))
                fp.write('\n' + '\n'.join(val_logger))
    elif args.phase == 'predict':
        # NOTE(review): the input directory is hard-coded to one developer's
        # machine; args.root_val looks like the intended source -- confirm.
        image_files = glob(
            os.path.join(
                '/home/weidong/code/github/DiabeticRetinopathy_solution/data/zhizhen_new/LabelImages/512_ahe',
                '*.png'))
        # Loop-invariant settings, built once instead of once per image.
        size = 512
        patch_size = 128
        stride = 64
        MEAN = [.485, .456, .406]
        STD = [.229, .224, .225]
        input_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(MEAN, STD)])
        for image_file in image_files:
            ds = RetinalVesselPredictImage(image_file, input_transform, size,
                                           patch_size, stride)
            data_loader = DataLoader(ds,
                                     batch_size=20,
                                     shuffle=False,
                                     pin_memory=False)
            pred_image(image_file, data_loader,
                       nn.DataParallel(model).cuda(), size, patch_size, stride,
                       board)
    else:
        raise ValueError('No phase found')
Ejemplo n.º 4
0
def train(args, model):
    """Optimise ``model`` with Adam (lr 1e-5) and a class-weighted loss.

    Data comes from ``VOCTrain(args.datadir, 'train', ...)``. Class 0 enters
    the loss with weight 0.1; all other classes with weight 1.

    Args:
        args: Options object. The attributes read here are ``datadir``,
            ``num_workers``, ``batch_size``, ``cuda``, ``num_epochs``,
            ``port``, ``model``, ``steps_plot``, ``steps_loss`` and
            ``steps_save``.
        model: Network to optimise; put in training mode and updated in
            place.

    Side effects:
        Prints the running mean loss, plots to a ``Dashboard`` and writes
        ``<args.model>-<epoch>-<step>.pth`` state dicts to the current
        working directory, at the intervals given by ``args.steps_loss``,
        ``args.steps_plot`` and ``args.steps_save`` (<= 0 disables each).
    """

    def due(interval, step):
        # An action fires on every ``interval``-th step, step 0 included.
        return interval > 0 and step % interval == 0

    model.train()

    class_weights = torch.ones(NUM_CLASSES)
    class_weights[0] = 0.1

    dataset = VOCTrain(args.datadir, 'train', input_transform,
                       target_transform)
    loader = DataLoader(dataset,
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=True)

    criterion = CrossEntropyLoss2d(
        class_weights.cuda() if args.cuda else class_weights)

    optimizer = Adam(model.parameters(), lr=1e-5)

    if args.steps_plot > 0:
        board = Dashboard(args.port)

    for epoch in range(1, args.num_epochs + 1):
        losses = []

        for step, (images, labels) in enumerate(loader):
            if args.cuda:
                images, labels = images.cuda(), labels.cuda()

            inputs, targets = Variable(images), Variable(labels)
            outputs = model(inputs)

            optimizer.zero_grad()
            loss = criterion(outputs, targets[:, 0])
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

            where = f'(epoch: {epoch}, step: {step})'

            if due(args.steps_plot, step):
                preview = inputs[0].cpu().data
                # Rescale the first three channels in place: x * 0.5 + 0.5.
                for channel in range(3):
                    preview[channel] = preview[channel] * .5 + .5
                board.image(preview, f'input {where}')

                predicted = outputs[0].cpu().max(0, keepdim=True)[1].data
                board.image(color_transform(predicted), f'output {where}')

                board.image(color_transform(targets[0].cpu().data),
                            f'target {where}')

            if due(args.steps_loss, step):
                average = sum(losses) / len(losses)
                print(f'loss: {average} {where}')

            if due(args.steps_save, step):
                filename = f'{args.model}-{epoch:03}-{step:04}.pth'
                torch.save(model.state_dict(), filename)
                print(f'save: {filename} {where}')
Ejemplo n.º 5
0
def _class_weights(enc):
    """Return the hand-set per-class loss weights as a NUM_CLASSES tensor."""
    encoder_values = (
        2.3653597831726, 4.4237880706787, 2.9691488742828, 5.3442072868347,
        5.2983593940735, 5.2275490760803, 5.4394111633301, 5.3659925460815,
        3.4170460700989, 5.2414722442627, 4.7376127243042, 5.2286224365234,
        5.455126285553, 4.3019247055054, 5.4264230728149, 5.4331531524658,
        5.433765411377, 5.4631009101868, 5.3947434425354,
    )
    decoder_values = (
        2.8149201869965, 6.9850029945374, 3.7890393733978, 9.9428062438965,
        9.7702074050903, 9.5110931396484, 10.311357498169, 10.026463508606,
        4.6323022842407, 9.5608062744141, 7.8698215484619, 9.5168733596802,
        10.373730659485, 6.6616044044495, 10.260489463806, 10.287888526917,
        10.289801597595, 10.405355453491, 10.138095855713,
    )
    weight = torch.ones(NUM_CLASSES)
    for class_id, value in enumerate(encoder_values if enc else decoder_values):
        weight[class_id] = value
    # Index 19 gets zero weight, so it does not contribute to the loss.
    weight[19] = 0
    return weight


def _update_conf_matrix(outputs, labels, confMatrix, perImageStats):
    """Add every prediction/ground-truth pair of one batch to ``confMatrix``."""
    if isinstance(outputs, list):
        # A list of tensors is merged along dimension 0 (the original
        # comment describes this as merging per-GPU tensors).
        outputs_cpu = torch.cat([part.cpu() for part in outputs], 0)
    else:
        outputs_cpu = outputs.cpu()

    to_pil = ToPILImage()
    for i in range(outputs_cpu.size(0)):
        prediction = to_pil(outputs_cpu[i].max(0)[1].data.unsqueeze(0).byte())
        groundtruth = to_pil(labels[i].cpu().byte())
        evalIoU.evaluatePairPytorch(prediction, groundtruth, confMatrix,
                                    perImageStats, evalIoU.args)


def _mean_iou(confMatrix):
    """Return ``(mean IoU as float, coloured display string)`` for a matrix."""
    classScoreList = {}
    for label in evalIoU.args.evalLabels:
        labelName = evalIoU.trainId2label[label].name
        classScoreList[labelName] = evalIoU.getIouScoreForTrainLabel(
            label, confMatrix, evalIoU.args)
    average = evalIoU.getScoreAverage(classScoreList, evalIoU.args)
    iouAvgStr = (evalIoU.getColorEntry(average, evalIoU.args) +
                 "{avg:5.3f}".format(avg=average) + evalIoU.args.nocol)
    return float(average), iouAvgStr


def _plot_batch(board, inputs, outputs, targets, prefix, epoch, step):
    """Send the first input, prediction and target of a batch to ``board``."""
    start_time_plot = time.time()
    where = f'(epoch: {epoch}, step: {step})'
    board.image(inputs[0].cpu().data, f'{prefix}input {where}')
    first_output = outputs[0][0] if isinstance(outputs, list) else outputs[0]
    board.image(
        color_transform(first_output.cpu().max(0)[1].data.unsqueeze(0)),
        f'{prefix}output {where}')
    board.image(color_transform(targets[0].cpu().data),
                f'{prefix}target {where}')
    print("Time to paint images: ", time.time() - start_time_plot)


def train(args, model, enc=False):
    """Train and validate ``model`` on the ``cityscapes`` dataset.

    Each epoch runs one pass over the ``'train'`` split followed by one pass
    over the ``'val'`` split, then writes a checkpoint, optionally a
    per-epoch model file, the best model so far, and one row in the
    automated log.

    Args:
        args: Options object. The attributes read here are ``datadir``,
            ``height``, ``num_workers``, ``batch_size``, ``cuda``,
            ``savedir``, ``resume``, ``num_epochs``, ``visualize``,
            ``steps_plot``, ``port``, ``iouTrain``, ``iouVal``,
            ``steps_loss`` and ``epochs_save``.
        model: Network to train; called as ``model(inputs, only_encode=enc)``.
        enc: Selects the encoder variant: a different set of class weights,
            ``only_encode=True`` in the forward pass and ``*_enc`` /
            ``*_encoder`` file names.

    Returns:
        The trained ``model`` (convenience for encoder-decoder training).

    Side effects:
        Reads and updates the module-level ``best_acc``; writes files under
        ``../save/<args.savedir>``.
    """
    global best_acc

    # TODO: calculate weights by processing the dataset histogram (they are
    # currently set by hand).
    weight = _class_weights(enc)

    assert os.path.exists(
        args.datadir), "Error: datadir (dataset directory) could not be loaded"

    co_transform = MyCoTransform(enc, augment=True, height=args.height)
    co_transform_val = MyCoTransform(enc, augment=False, height=args.height)
    dataset_train = cityscapes(args.datadir, co_transform, 'train')
    dataset_val = cityscapes(args.datadir, co_transform_val, 'val')

    loader = DataLoader(dataset_train,
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=True)
    loader_val = DataLoader(dataset_val,
                            num_workers=args.num_workers,
                            batch_size=args.batch_size,
                            shuffle=False)

    if args.cuda:
        criterion = CrossEntropyLoss2d(weight.cuda())
    else:
        criterion = CrossEntropyLoss2d(weight)

    print(type(criterion))

    savedir = f'../save/{args.savedir}'

    # All output paths depend only on ``enc``; resolve them once.
    if enc:
        automated_log_path = savedir + "/automated_log_encoder.txt"
        modeltxtpath = savedir + "/model_encoder.txt"
        filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
        filenameBest = savedir + '/model_best_enc.pth.tar'
        filenamebest = f'{savedir}/model_encoder_best.pth'
        best_txt_path = savedir + "/best_encoder.txt"
    else:
        automated_log_path = savedir + "/automated_log.txt"
        modeltxtpath = savedir + "/model.txt"
        filenameCheckpoint = savedir + '/checkpoint.pth.tar'
        filenameBest = savedir + '/model_best.pth.tar'
        filenamebest = f'{savedir}/model_best.pth'
        best_txt_path = savedir + "/best.txt"

    # Write the column header only when the log does not exist yet.
    if not os.path.exists(automated_log_path):
        with open(automated_log_path, "a") as myfile:
            myfile.write(
                "Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate"
            )

    with open(modeltxtpath, "w") as myfile:
        myfile.write(str(model))

    # TODO: reduce memory in first gpu:
    # https://discuss.pytorch.org/t/multi-gpu-training-memory-usage-in-balance/4163/4
    # https://github.com/pytorch/pytorch/issues/1893
    optimizer = Adam(model.parameters(),
                     5e-4, (0.9, 0.999),
                     eps=1e-08,
                     weight_decay=1e-4)

    start_epoch = 1
    if args.resume:
        # Restore weights, optimizer state, next epoch and best value.
        assert os.path.exists(
            filenameCheckpoint
        ), "Error: resume option was used but checkpoint was not found in folder"
        checkpoint = torch.load(filenameCheckpoint)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_acc = checkpoint['best_acc']
        print("=> Loaded checkpoint at epoch {})".format(checkpoint['epoch']))

    # Learning-rate factor (1 - (epoch - 1) / num_epochs) ** 0.9.
    lambda1 = lambda epoch: pow(
        (1 - ((epoch - 1) / args.num_epochs)), 0.9)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

    if args.visualize and args.steps_plot > 0:
        board = Dashboard(args.port)

    for epoch in range(start_epoch, args.num_epochs + 1):
        print("----- TRAINING - EPOCH", epoch, "-----")

        scheduler.step(epoch)

        epoch_loss = []
        time_train = []

        doIouTrain = args.iouTrain
        doIouVal = args.iouVal

        # TODO: remake the evalIoU.py code to avoid using "evalIoU.args"
        confMatrix = evalIoU.generateMatrixTrainId(evalIoU.args)
        perImageStats = {}

        usedLr = 0
        for param_group in optimizer.param_groups:
            print("LEARNING RATE: ", param_group['lr'])
            usedLr = float(param_group['lr'])

        model.train()
        for step, (images, labels) in enumerate(loader):

            start_time = time.time()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            inputs = Variable(images)
            targets = Variable(labels)
            outputs = model(inputs, only_encode=enc)

            optimizer.zero_grad()
            loss = criterion(outputs, targets[:, 0])
            loss.backward()
            optimizer.step()

            # ``loss.data[0]`` fails on 0-dim tensors in current PyTorch;
            # ``item()`` is the supported way to read the scalar.
            epoch_loss.append(loss.item())
            time_train.append(time.time() - start_time)

            if doIouTrain:
                _update_conf_matrix(outputs, labels, confMatrix,
                                    perImageStats)

            if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
                _plot_batch(board, inputs, outputs, targets, '', epoch, step)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss) / len(epoch_loss)
                print(
                    f'loss: {average:0.4} (epoch: {epoch}, step: {step})',
                    "// Avg time/img: %.4f s" %
                    (sum(time_train) / len(time_train) / args.batch_size))

        average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss)

        iouTrain = 0
        if doIouTrain:
            iouTrain, iouAvgStr = _mean_iou(confMatrix)
            print("EPOCH IoU on TRAIN set: ", iouAvgStr)

        # Validate on the val split after each epoch of training.
        print("----- VALIDATING - EPOCH", epoch, "-----")
        model.eval()
        epoch_loss_val = []
        time_val = []

        # Fresh confusion matrix for the validation pass.
        confMatrix = evalIoU.generateMatrixTrainId(evalIoU.args)
        perImageStats = {}

        for step, (images, labels) in enumerate(loader_val):
            start_time = time.time()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            # ``Variable(..., volatile=True)`` no longer disables autograd on
            # PyTorch >= 0.4; ``torch.no_grad()`` is the replacement.
            with torch.no_grad():
                inputs = Variable(images)
                targets = Variable(labels)
                outputs = model(inputs, only_encode=enc)
                loss = criterion(outputs, targets[:, 0])

            epoch_loss_val.append(loss.item())
            time_val.append(time.time() - start_time)

            if doIouVal:
                _update_conf_matrix(outputs, labels, confMatrix,
                                    perImageStats)

            if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
                _plot_batch(board, inputs, outputs, targets, 'VAL ', epoch,
                            step)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss_val) / len(epoch_loss_val)
                print(
                    f'VAL loss: {average:0.4} (epoch: {epoch}, step: {step})',
                    "// Avg time/img: %.4f s" %
                    (sum(time_val) / len(time_val) / args.batch_size))

        average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val)

        iouVal = 0
        if doIouVal:
            iouVal, iouAvgStr = _mean_iou(confMatrix)
            print("EPOCH IoU on VAL set: ", iouAvgStr)

        # Remember the best value and save a checkpoint.
        # NOTE(review): when no IoU is available the validation *loss* is
        # compared with ``>``, so a higher loss counts as "better". Left
        # unchanged because ``best_acc`` is initialised outside this
        # function -- confirm the intended metric.
        if iouVal == 0:
            current_acc = average_epoch_loss_val
        else:
            current_acc = iouVal
        is_best = current_acc > best_acc
        best_acc = max(current_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': str(model),
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, filenameCheckpoint, filenameBest)

        # Save the model after the epoch.
        if enc:
            filename = f'{savedir}/model_encoder-{epoch:03}.pth'
        else:
            filename = f'{savedir}/model-{epoch:03}.pth'
        # The periodic save is keyed on the epoch number. The previous code
        # tested ``step``, i.e. the index of the last validation batch.
        if args.epochs_save > 0 and epoch % args.epochs_save == 0:
            torch.save(model.state_dict(), filename)
            print(f'save: {filename} (epoch: {epoch})')
        if is_best:
            torch.save(model.state_dict(), filenamebest)
            print(f'save: {filenamebest} (epoch: {epoch})')
            with open(best_txt_path, "w") as myfile:
                myfile.write("Best epoch is %d, with Val-IoU= %.4f" %
                             (epoch, iouVal))

        # One row per epoch:
        # Epoch, Train-loss, Test-loss, Train-IoU, Test-IoU, learningRate
        with open(automated_log_path, "a") as myfile:
            myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" %
                         (epoch, average_epoch_loss_train,
                          average_epoch_loss_val, iouTrain, iouVal, usedLr))

    return (model)  # return model (convenience for encoder-decoder training)
# Module-level training setup: dataset, loader, loss, optimizer and dashboard.
# ``args``, ``data_root``, ``img_file`` and ``model`` are expected to be
# defined before this point (not visible in this part of the file).
mask_file = 'DRIVE_dataset_groundTruth_train.hdf5'
# NOTE(review): patch width/height are taken from ``args.batch`` -- confirm
# this is intended rather than ``args.patch_size``.
patch_w = args.batch
patch_h = args.batch
# Patches per image: (2048 / patch_size) squared, rounded up.
patch_num_per_img = math.ceil(
    (2048 / args.patch_size) * (2048 / args.patch_size))
data_set = RetinalVesselTrainingDS(data_root, img_file, mask_file, patch_w,
                                   patch_h, patch_num_per_img)
data_loader = torch.utils.data.DataLoader(dataset=data_set,
                                          batch_size=10,
                                          shuffle=True,
                                          pin_memory=True)

criterion = CrossEntropyLoss2d()
# NOTE(review): positionally these are lr=1e-4, momentum=.9 and
# dampening=2e-5 (the 4th SGD parameter is ``dampening``, not
# ``weight_decay``) -- confirm which was intended.
optimizer = SGD(model.parameters(), 1e-4, .9, 2e-5)

# Dashboard is constructed with the port given as a string.
board = Dashboard('8097')
color_transform = Colorize()


def train(train_dataloader, model, criterion, optimizer, epoch, display):
    """Run one optimisation pass over ``train_dataloader`` on the GPU.

    Every batch is moved to the GPU, pushed through ``model``, and used for
    one ``optimizer`` step; the batch loss is printed after each step.

    Args:
        train_dataloader: Iterable yielding ``(img, mask)`` batches.
        model: Network to optimise; put in training mode and updated in
            place.
        criterion: Loss called as ``criterion(output, target[:, 0])``.
        optimizer: Optimizer stepped once per batch.
        epoch: Not used by the code in this block.
        display: Not used by the code in this block.
    """
    model.train()

    for index, (img, mask) in enumerate(train_dataloader):
        input = Variable(img.cuda())
        # Cast the mask to a LongTensor before moving it to the GPU.
        target = Variable(mask.type(torch.LongTensor).cuda())
        output = model(input)
        optimizer.zero_grad()
        loss = criterion(output, target[:, 0])
        loss.backward()
        optimizer.step()
        # ``loss.data[0]`` fails on 0-dim tensors in current PyTorch;
        # ``item()`` is the supported way to read the scalar.
        print('loss: {}'.format(loss.item()))