Example #1
def get_result(net, gpu=False):
    ids = get_ids(dir_img)

    val = get_imgs_and_masks(ids, dir_img, dir_mask, 1.0)

    val_dice = eval_net(net, val, gpu)
    print('Validation Dice Coeff: {}'.format(val_dice))
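Every example on this page imports eval_net from its repo's eval.py, which is not reproduced here. For orientation, below is a minimal sketch of a Dice-based eval_net compatible with Example #1's call; the (image, mask) iteration, the sigmoid activation, and the 0.5 threshold are assumptions, not the original implementation.

import torch

def eval_net(net, dataset, gpu=False):
    # Minimal sketch (assumed behavior): mean Dice coefficient over a
    # generator of (image, mask) numpy pairs, as produced by
    # get_imgs_and_masks in Example #1.
    net.eval()
    total_dice, n = 0.0, 0
    with torch.no_grad():
        for img, true_mask in dataset:
            img = torch.from_numpy(img).unsqueeze(0).float()
            true_mask = torch.from_numpy(true_mask).float()
            if gpu:
                img, true_mask = img.cuda(), true_mask.cuda()
            pred = (torch.sigmoid(net(img)).squeeze(0) > 0.5).float()
            inter = (pred * true_mask).sum()
            total_dice += ((2 * inter + 1e-6) /
                           (pred.sum() + true_mask.sum() + 1e-6)).item()
            n += 1
    return total_dice / max(n, 1)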
Example #2
def batch_calc(args):
    from eval import eval_net

    # get net
    net = UNet(n_channels=3, n_classes=num_classes, bilinear=True)

    logging.info(f'Loading model {args.model}')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.device:
        device = torch.device(args.device)
    logging.info(f'Using device {device}')
    net.to(device=device)
    net.load_state_dict(torch.load(args.model, map_location=device))
    logging.info('Model loaded !')

    val_dataset = CityscapesDataset(type='val', scale=args.scale)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batchsize,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True,
                            drop_last=True)
    miou, ious, hist = eval_net(net, val_loader, device, type='miou')
    logging.info(f'total mIoU value: {miou}\n'
                 f'per category\'s IoU value: \n{ious}\n'
                 f'hist save as {args.histname}')
    np.savetxt(path.join('./test/hist/', args.histname), hist, delimiter=',')
    return miou
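Here eval_net returns the total mIoU, the per-class IoUs, and a confusion matrix hist. A sketch of how the IoUs would typically be derived from hist, assuming rows index ground truth and columns index predictions (this helper is illustrative, not the repo's code):

import numpy as np

def ious_from_hist(hist):
    # IoU_c = TP_c / (TP_c + FP_c + FN_c), read straight off the
    # confusion matrix: diagonal = TP, column sums = TP + FP,
    # row sums = TP + FN.
    tp = np.diag(hist).astype(np.float64)
    fp = hist.sum(axis=0) - tp
    fn = hist.sum(axis=1) - tp
    ious = tp / np.maximum(tp + fp + fn, 1e-12)
    return ious, ious.mean()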
Example #3
def train_net(net,
              train_data,
              val_data,
              optimizer,
              lr_scheduler,
              loss_fn,
              epochs,
              gpu=True,
              save_model=None,
              save_image=False,
              print_step=100):
    num_batch = len(train_data)
    for epoch in range(epochs):
        net.train()
        print("Epoch %d/%d" % (epoch + 1, epochs))
        epoch_loss = 0.0
        for step, data in enumerate(train_data):
            imgs, masks = data['img'], data['mask']
            if gpu:
                imgs = imgs.cuda()
                masks = masks.cuda()
            pred_masks = net(imgs)

            pred_masks_flat = pred_masks.view(-1)
            masks_flat = masks.view(-1)

            loss = loss_fn(pred_masks_flat, masks_flat)
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step + 1) % print_step == 0:
                print("Step %d/%d  Loss: %.4f" %
                      (step + 1, num_batch, epoch_loss / (step + 1)))

        print("Epoch %d/%d Finished!   Train loss: %.4f" %
              (epoch + 1, epochs, epoch_loss / (step + 1)))
        lr_scheduler.step(epoch_loss)
        val_dice, val_loss = eval_net(net,
                                      val_data,
                                      loss_fn,
                                      gpu,
                                      save_image,
                                      epoch=epoch + 1)
        print('Validation dice coeff is %f, loss is %.4f' %
              (val_dice, val_loss))

        if save_model:
            if not os.path.exists(save_model):
                os.mkdir(save_model)
            save_path = os.path.join(
                save_model, "epoch_%d_dice_%f.pth" % (epoch + 1, val_dice))
            torch.save(net.state_dict(), save_path)
Example #4
def train_net(net, args, epochs, batch_size, lr, device):

    train = BasicDataset(args, False)
    val = BasicDataset(args, True)
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=1,
                            shuffle=False,
                            pin_memory=True,
                            drop_last=True)
    n_val = len(val)
    n_train = len(train)
    #
    writer = SummaryWriter('./records/tensorboard')
    global_step = 0
    optimizer = optim.Adam(net.parameters(), lr=lr)
    #
    criterion = nn.BCEWithLogitsLoss()
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                masks = batch['mask']
                name = batch['name'][0]
                imgs = imgs.to(device=device, dtype=torch.float32)
                masks = masks.to(device=device, dtype=torch.float32)
                #
                masks_pred = net(imgs)
                loss = criterion(masks_pred, masks)
                pbar.set_postfix(**{'loss': loss.item()})
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update(imgs.shape[0])
                #
                global_step += 1
                writer.add_scalar('train', loss.item(), global_step)
                if global_step % (n_train // batch_size) == 0:
                    valid_score = eval_net(args, net, val_loader, device)
                    # scheduler.step(val_score)
                    writer.add_scalar('valid', valid_score, global_step)
                    torch.save(
                        net.state_dict(), args.checkpoints_dir +
                        'naive_baseline_{}.pth'.format(epoch))
Example #5
    def validation(self, number_of_train_data, epoch):
        loss = self.epoch_loss / (number_of_train_data + 1)
        print("Epoch finished ! Loss: {}".format(loss))
        torch.save(
            self.net.state_dict(),
            str(
                self.save_weight_path.parent.joinpath(
                    "epoch_weight/{:05d}.pth".format(epoch))),
        )
        val_loss = eval_net(
            self.net,
            self.val_loader,
            self.gpu,
            self.loss_flag,
            self.vis,
            self.img_view_val,
            self.gt_view_val,
            self.pred_view_val,
            self.criterion,
        )

        print("val_loss: {}".format(val_loss))
        try:
            if min(self.val_losses) > val_loss:
                torch.save(self.net.state_dict(), str(self.save_weight_path))
                self.bad = 0
                print("update bad")
                with self.save_weight_path.parent.joinpath("best.txt").open(
                        "w") as f:
                    f.write("{}".format(epoch))
            else:
                self.bad += 1
                print("bad ++")
        except ValueError:
            torch.save(self.net.state_dict(), str(self.save_weight_path))
        self.val_losses.append(val_loss)

        if self.need_vis:
            self.update_vis_plot(
                iteration=epoch,
                loss=loss,
                val=[loss, val_loss],
                window1=self.iter_plot,
                window2=self.epoch_plot,
                update_type="append",
            )
        print("bad = {}".format(self.bad))
        self.epoch_loss = 0
Example #6
def validation(net, imgs, true_masks, masks_pred, writer, val_loader, n_val,
               global_step, device):
    val_score = eval_net(net, val_loader, device, n_val)
    if net.n_classes > 1:
        logging.info('Validation cross entropy: {}'.format(val_score))
        writer.add_scalar('Loss/test', val_score, global_step)

    else:
        logging.info('Validation Dice Coeff: {}'.format(val_score))
        writer.add_scalar('Dice/test', val_score, global_step)

    writer.add_images('images', imgs, global_step)
    if net.n_classes == 1:
        writer.add_images('masks/true', true_masks, global_step)
        writer.add_images('masks/pred',
                          torch.sigmoid(masks_pred) > 0.5, global_step)
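The binary branch above logs a Dice coefficient and visualizes torch.sigmoid(masks_pred) > 0.5. A minimal sketch of the Dice computation implied here (the smoothing constant eps is an assumption):

import torch

def dice_coeff(pred_mask, true_mask, eps=1e-6):
    # Dice = 2|A ∩ B| / (|A| + |B|) over binarized, flattened masks
    pred = pred_mask.float().view(-1)
    true = true_mask.float().view(-1)
    inter = (pred * true).sum()
    return ((2 * inter + eps) / (pred.sum() + true.sum() + eps)).item()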
Example #7
def validation_only(net,
                    device,
                    batch_size=1,
                    img_width=0, 
                    img_height=0,
                    img_scale=1.0,
                    use_bw=False,
                    standardize=False,
                    compute_statistics=False):

    load_statistics = not compute_statistics
    dataset = BasicDataset(dir_img_test, dir_mask_test, img_width, img_height,
                           img_scale, use_bw, standardize=standardize,
                           load_statistics=load_statistics, save_statistics=True)
    val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True, drop_last=True)
    val_score = eval_net(net, val_loader, device)
    if net.n_classes > 1:
        logging.info('Validation cross entropy: {}'.format(val_score))
    else:
        logging.info('Validation Dice Coeff: {}'.format(val_score))
Example #8
def test_net(net, device, batch_size=4, scale=512, threshold=0.5):

    dataset = BasicDataset(dir_img, dir_mask, 512, False, 5)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=8,
                        pin_memory=True)

    tm = TimeManager()

    val_score, precision, recall = eval_net(net, loader, device, threshold)

    if net.n_classes > 1:
        print('Validation cross entropy:', val_score)
    else:
        print('Validation Dice Coeff:', val_score)
        print('Validation Precision:', precision)
        print('Validation Recall:', recall)

    tm.show()
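Example #8's eval_net also returns pixel-wise precision and recall at the given threshold. A hypothetical helper showing how those two metrics fall out of a thresholded prediction (assumes logits as input; not the repo's code):

import torch

def precision_recall(logits, true_mask, threshold=0.5, eps=1e-6):
    # Precision = TP / (TP + FP); Recall = TP / (TP + FN), per pixel
    pred = (torch.sigmoid(logits) > threshold).float().view(-1)
    true = true_mask.float().view(-1)
    tp = (pred * true).sum()
    return (tp / (pred.sum() + eps)).item(), (tp / (true.sum() + eps)).item()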
Example #9
def train_net(net,
              epochs=5,
              batch_size=1,
              lr=0.1,
              val_percent=0.05,
              save_cp=True,
              gpu=False,
              img_scale=0.5):

    # dir_img = 'data/train/'
    # dir_mask = 'data/train_masks/'
    dir_img = 'E:/git/dataset/tgs-salt-identification-challenge/train/images/'
    dir_mask = 'E:/git/dataset/tgs-salt-identification-challenge/train/masks/'
    # dir_img = 'E:/git/dataset/tgs-salt-identification-challenge/train/my_images/'
    # dir_mask = 'E:/git/dataset/tgs-salt-identification-challenge/train/my_masks/'
    dir_checkpoint = 'checkpoints/'

    ids = get_ids(dir_img)
    ids = split_ids(ids)

    iddataset = split_train_val(ids, val_percent)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(save_cp), str(gpu)))

    N_train = len(iddataset['train'])

    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)

    criterion = nn.BCELoss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()

        # reset the generators
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask,
                                   img_scale)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask,
                                 img_scale)

        epoch_loss = 0

        for i, b in enumerate(batch(train, batch_size)):
            imgs = np.array([i[0] for i in b]).astype(np.float32)
            # true_masks = np.array([i[1] for i in b])#np.rot90(m)
            true_masks = np.array([i[1].T / 65535 for i in b])  #np.rot90(m)

            # show_batch_image(true_masks)
            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)

            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

            # show_batch_image(imgs)

            masks_pred = net(imgs)
            masks_probs_flat = masks_pred.view(-1)

            true_masks_flat = true_masks.view(-1)

            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()

            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train,
                                                     loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

        val_dice = eval_net(net, val, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        if save_cp:
            torch.save(net.state_dict(),
                       dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))
Example #10
def train_net(
    net,
    writer,
    load,
    epochs=5,
    batch_size=1,
    lr=0.1,
    val_percent=0.1,
    save_cp=False,
    gpu=True,
):

    image_dir = 'train/images_cut/'
    mask_dir = 'train/masks_cut/'
    checkpoint_dir = 'checkpoints/'

    name_list = get_names(image_dir)
    split_list = train_val_split(name_list, val_percent)

    print('''
        Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(split_list['train']),
               len(split_list['val']), str(save_cp), str(gpu)))
    N_train = len(split_list['train'])
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=0.005)

    print('Model loaded from {}'.format(load))
    model_dict = net.state_dict()
    pretrained_dict = torch.load(load)
    #        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    net.load_state_dict(model_dict)
    train_params = []
    if args.fix:
        print("fixing parameters")
        for k, v in net.named_parameters():
            train_params.append(k)
            pref = k[:12]
            if pref == 'module.conv1' or pref == 'module.conv2':
                v.requires_grad = False
                train_params.remove(k)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      net.parameters()),
                               lr=lr,
                               weight_decay=0.005)

    criterion = mixloss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()

        train = get_train_pics(image_dir, mask_dir, split_list)

        epoch_loss = 0

        for i, samps in enumerate(batch(train, batch_size)):
            images = np.array([samp['image'] for samp in samps])
            masks = np.array([samp['mask'] for samp in samps])

            images = torch.from_numpy(images).type(torch.FloatTensor)
            masks = torch.from_numpy(masks).type(torch.FloatTensor)

            true_masks = masks
            if gpu:
                images = images.cuda()
                true_masks = true_masks.cuda()

            masks_pred = net(images)

            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = true_masks.view(-1)
            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()

            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train,
                                                     loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        avg_train_loss = epoch_loss / (i + 1)
        print('Epoch finished ! Loss: {}'.format(avg_train_loss))

        val = get_val_pics(image_dir, mask_dir, split_list)

        val_iou, val_ls = eval_net(net, val, gpu)
        print('Validation IoU: {} Loss: {}'.format(val_iou, val_ls))

        writer.add_scalar('train/loss', avg_train_loss, epoch)
        writer.add_scalar('val/loss', val_ls, epoch)
        writer.add_scalar('val/IoU', val_iou, epoch)

        torch.save(net.state_dict(),
                   checkpoint_dir + 'CP{}.pth'.format(epoch + 1))
        print('Checkpoint {} saved !'.format(epoch + 1))
Example #11
def train_net(
        net,
        epochs=5,
        batch_size=2,
        lr=0.0001,
        save_cp=True,
        gpu=True,
        target_path='',
        checkpoint_path='/mnt/HDD1/Frederic/Segmentation/Seg_deepv3/checkpoints/'
):

    #Set path to store checkpoint
    dir_checkpoint = checkpoint_path
    result_path = result_path_global

    #Print training details
    print('''
	Starting training, details:
		Epochs: {}
		Batch size: {}
		Learning rate: {}
		Checkpoints: {}
		CUDA: {}
	'''.format(epochs, batch_size, lr, str(save_cp), str(gpu)))

    #loss function and optimizer
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

    criterion = nn.CrossEntropyLoss()
    # criterion = nn.BCELoss()

    #Train iteration
    logger = Logger(result_path + 'log.txt', title='ISIC2016_U_Net')
    logger.set_names(['Epochs', 'Avg_Trainning_Loss', 'Val_Dice_coefficient'])
    #load data
    val_sets = load_validation_data()
    start_epoch = args.start_epoch
    best_dice, best_epoch = 0, 0
    for epoch in range(start_epoch, start_epoch + epochs):
        net.train()
        #use epoch_loss to store total loss for whole iteration
        trainloader, datasize = load_train_data(args.batchsize)
        epoch_loss = 0
        if epoch == 75 or epoch == 150 or epoch == 225:
            lr = lr * 0.1
            optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
        with tqdm(total=datasize // batch_size) as pbar:
            for ite, data in enumerate(trainloader[0]):
                imgs = data[0]
                true_masks = data[1]

                if gpu:
                    imgs = imgs.cuda()
                    true_masks = true_masks.cuda()

                masks_pred = net(imgs)
                true_masks = true_masks.squeeze(dim=1)

                # loss = DiceLoss(masks_pred,true_masks)
                loss = criterion(masks_pred, true_masks.long())
                epoch_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                #set bar for training
                pbar.set_description(
                    'Epoch:[%d|%d],loss: %.4f ' %
                    (epoch + 1, args.epochs + start_epoch, loss))
                pbar.update(1)

        avg_loss = epoch_loss / (ite + 1)
        print('Epoch finished ! Loss: {}'.format(avg_loss))

        # save the sample output every 40 epochs
        save_sample_mask = False
        if (epoch + 1) % 50 == 0:
            save_sample_mask = True
        val_dice = eval_net(net, val_sets, epoch, gpu, save_sample_mask,
                            result_path)
        print('Validation Dice Coeff: {}'.format(val_dice))
        logger.append([epoch + 1, avg_loss, val_dice])

        #save best epoch and checkpoint
        if best_dice < val_dice:
            best_dice = val_dice
            best_epoch = epoch
            torch.save(net.state_dict(),
                       dir_checkpoint + 'best_checkpoint.pth')

        print('best checkpoint is epoch {} with dice {} '.format(
            best_epoch, best_dice))
        #save normal epoch
        if save_cp and (epoch + 1) % 50 == 0:
            torch.save(net.state_dict(),
                       dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            # print('Checkpoint {} saved !'.format(epoch + 1))

    #plot fig after train
    logger.close()
    logger.plot()
Example #12
def train_net(net, epochs=100, batch_size=2, lr=0.02, val_percent=0.05,
              cp=True, gpu=False):
    dir_img = '/home/wdh/DataSets/hand-segmentation/GTEA_gaze_part/Resize/Images/'
    dir_mask = '/home/wdh/DataSets/hand-segmentation/GTEA_gaze_part/Resize/Masks_1/'
    dir_checkpoint = 'checkpoints/'

    ids = get_ids(dir_img)
    ids = split_ids(ids)

    iddataset = split_train_val(ids, val_percent)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(cp), str(gpu)))

    N_train = len(iddataset['train'])

    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.99))
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))

        # reset the generators
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)

        epoch_loss = 0

        val_dice = eval_net(net, val, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        net.train()
        for i, b in enumerate(batch(train, batch_size)):
            X = np.array([i[0] for i in b])
            y = np.array([i[1] for i in b])

            X = torch.FloatTensor(X)
            y = torch.ByteTensor(y)

            if gpu:
                X = X.cuda()
                y = y.cuda()

            y_pred = net(X)
            probs = torch.sigmoid(y_pred)  # F.sigmoid and Variable are deprecated
            probs_flat = probs.view(-1)

            y_flat = y.view(-1)

            loss = criterion(probs_flat, y_flat.float())
            epoch_loss += loss.item()  # .data[0] is deprecated since PyTorch 0.4

            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train,
                                                     loss.item()))

            optimizer.zero_grad()

            loss.backward()

            optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

        if cp:
            torch.save(net.state_dict(),
                       dir_checkpoint + 'CP{}.pth'.format(epoch + 1))

            print('Checkpoint {} saved !'.format(epoch + 1))
Example #13
    torch.save(net.state_dict(), save_file)
    log.info('Saving pruned to {}...'.format(save_file))

    save_txt = osp.join(save_dir, "pruned_channels.txt")
    pruner.channel_save(save_txt)
    log.info('Pruned channels to {}...'.format(save_txt))

    del net, pruner
    net = UNet(n_channels=3, n_classes=1, f_channels=save_txt)
    log.info("Re-Built model using {}...".format(save_txt))
    if args.gpu:
        net.cuda()
    if args.load:
        net.load_state_dict(torch.load(save_file))
        log.info('Re-Loaded checkpoint from {}...'.format(save_file))

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=0.0005)

    # Use epochs or iterations for fine-tuning
    save_file = osp.join(save_dir, "Finetuned.pth")

    finetune(net, optimizer, criterion, iddataset['train'], log, save_file,
             args.iters, args.epochs, args.batch_size, args.gpu, args.scale)

    val_dice = eval_net(net, val, len(iddataset['val']), args.gpu,
                        args.batch_size)
    log.info('Validation Dice Coeff: {}'.format(val_dice))
Example #14
def train_net(net,
              epochs=5,
              batch_size=1,
              lr=0.1,
              val_percent=0.2,
              save_cp=True,
              gpu=False,
              img_scale=0.5):
    path = [['data/ori1/', 'data/gt1/'],
            ['data/original1/', 'data/ground_truth1/'],
            ['data/Original/', 'data/Ground_Truth/']]
    dir_img = path[0][0]
    dir_mask = path[0][1]
    dir_checkpoint = 'sdgcheck/'

    ids = get_ids(dir_img)
    ids = split_ids(ids)

    iddataset = split_train_val(ids, val_percent)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(save_cp), str(gpu)))

    N_train = len(iddataset['train'])

    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.7,
                          weight_decay=0.005)
    '''
    optimizer = optim.Adam(net.parameters(),
                      lr=lr,

                      weight_decay=0.0005)
    '''
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()

        # reset the generators
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask,
                                   img_scale)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask,
                                 img_scale)

        epoch_loss = 0
        x = 0
        for i, b in enumerate(batch(train, batch_size)):
            imgs = np.array([i[0] for i in b]).astype(np.float32)
            true_masks = np.array([i[1] for i in b])
            '''
            ori=np.transpose(imgs[0], axes=[1, 2, 0])   
            scipy.misc.imsave("ori/ori_"+str(x)+'.jpg', ori)
            
            gt = np.stack((true_masks[0],)*3, axis=-1)
            
            #gt=np.transpose(true_masks[0], axes=[1, 2, 0])
            scipy.misc.imsave("gt/gt_"+str(x)+'.jpg', gt)
            '''
            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)

            x += 1
            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

            masks_pred = net(imgs)
            masks_probs_flat = masks_pred.view(-1)

            true_masks_flat = true_masks.view(-1)

            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()

            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train,
                                                     loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

        val_dice = eval_net(net, val, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        if save_cp:
            torch.save(net.state_dict(),
                       dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))
Example #15
def train_net(net,
              device,
              epochs=300,
              batch_size=1,
              lr=0.1,
              val_percent=0.5,
              save_cp=True,
              img_scale=0.366):

    #histogram_matching()
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    #n_val = int(len(dataset) * val_percent)
    #n_train = len(dataset) - n_val
    #train, val = random_split(dataset, [n_train, n_val])
    n_val = 1
    n_train = 1
    #train = dataset
    val = list(dataset)[0:1]
    train = list(dataset)[0:1]
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=False,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True)

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        adjust_learning_rate(optimizer, epoch)
        lr = optimizer.param_groups[0]['lr']
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                #true_mask_pro = true_masks.squeeze().numpy().astype(np.uint8)
                #img_show = sitk.GetImageFromArray(true_mask_pro)
                #sitk.WriteImage(img_show, './data/pred/scale{}_mask.nii'.format(img_scale))
                '''
                imgs_pnp = imgs.squeeze().cpu().detach().numpy()
                imgs_show = sitk.GetImageFromArray(imgs_pnp)
                sitk.WriteImage(imgs_show, f'./data/debug/epoch300_41Img/input_round1/{epoch+1}_{global_step+1}.nii.gz')
                img_pro = (imgs*32768).squeeze().numpy().astype(np.int16)
                true_masks_pnp = true_masks.squeeze().cpu().detach().numpy()
                true_masks_show = sitk.GetImageFromArray(true_masks_pnp)
                sitk.WriteImage(true_masks_show, f'./data/debug/epoch300_41Img/gt_round1/{epoch+1}_{global_step+1}.nii.gz')
                img_show = sitk.GetImageFromArray(img_pro)
                #sitk.WriteImage(img_show, './data/pred/scale{}_input.nii'.format(img_scale))
                '''
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)
 

                masks_pred = net(imgs)
                # if (epoch+1) % 20 == 0:
                #     masks_pnp = masks_pred.squeeze().cpu().detach().numpy().astype(np.int16)
                #     masks_show = sitk.GetImageFromArray(masks_pnp)
                #     sitk.WriteImage(masks_show, f'./data/debug/scale0.5_1Img/model_out/{epoch+1}_{global_step+1}.nii')
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (n_train // (batch_size)) == 0:
                #if global_step % (1) == 0:
                    val_score = eval_net(net, val_loader, device, n_val)
                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)

                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                #writer.add_images('images', imgs, global_step)
                #if net.n_classes == 1:
                    #writer.add_images('masks/true', true_masks, global_step)
                    #writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

        if save_cp:
            if (epoch+1) % 10 == 0:
                try:
                    # os.makedirs also creates the '72Img_80160_f4/' subdir;
                    # os.mkdir(dir_checkpoint) alone would make the save fail
                    os.makedirs(dir_checkpoint + '72Img_80160_f4/')
                    logging.info('Created checkpoint directory')
                except OSError:
                    pass
                torch.save(net.state_dict(),
                           dir_checkpoint + '72Img_80160_f4/' +
                           f'CP_epoch{epoch + 1}.pth')
                logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
Example #16
def train_net(
    net,
    device,
    writer_detail,
    writer_main,
    random_seed,
    epochs=5,
    batch_size=1,
    lr=0.001,
    save_cp=True,
    val_i=4,
):

    train_dataset, val_dataset, train_list, val_list = make_dataset(
        root_dir, dataset_dir, val_i)
    n_val = len(val_dataset)
    n_train = len(train_dataset)
    # note: np.random.seed(...) must be wrapped in a callable; calling it
    # inline would seed once at construction and pass None as worker_init_fn
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=0,
                              pin_memory=True,
                              drop_last=True,
                              worker_init_fn=lambda _: np.random.seed(random_seed))
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0,
                            pin_memory=True,
                            drop_last=False,
                            worker_init_fn=lambda _: np.random.seed(random_seed))

    global_step = 0

    logging.info(f'''Starting training:
        val_i:           {val_i}
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
    ''')

    # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    optimizer = optim.Adam(params=net.parameters(), lr=lr)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if n_classes > 1 else 'max', patience=2)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                 gamma=np.power(
                                                     0.1, 1 / epochs))

    if n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss(reduction='mean')

    best_score = 0
    best_net = net
    best_e = 0
    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                contour_masks = batch['mask_contour']
                true_biMasks = (true_masks > 0).int()
                contour_masks = (contour_masks > 0).int()
                assert imgs.shape[1] == n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if n_classes == 1 else torch.long
                true_biMasks = true_biMasks.to(device=device, dtype=mask_type)
                contour_masks = contour_masks.to(device=device,
                                                 dtype=mask_type)
                # combine_masks = torch.cat([true_biMasks, contour_masks], dim=1) #

                masks_pred = net(imgs)
                loss_mask = criterion(masks_pred[:, 0:1], true_biMasks)
                loss_contour = criterion(masks_pred[:, 1:2], contour_masks)
                loss = loss_mask + 10 * loss_contour
                epoch_loss += loss.item()
                writer_main.add_scalar('val_%d_Loss/train' % val_i,
                                       loss.item(), global_step)

                pbar.set_postfix(loss_mask=loss_mask.item(),
                                 loss_contour=loss_contour.item())

                optimizer.zero_grad()
                loss.backward()
                # nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % ((n_train + n_val) // (2 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer_detail.add_histogram(
                            'val_%d_weights/' % val_i + tag,
                            value.data.cpu().numpy(), global_step)
                        writer_detail.add_histogram(
                            'val_%d_grads/' % val_i + tag,
                            value.grad.data.cpu().numpy(), global_step)

                    writer_detail.add_images('val_%d_images' % val_i,
                                             imgs[0:1], global_step)
                    if n_classes == 1:
                        writer_detail.add_images('val_%d_masks/true' % val_i,
                                                 true_biMasks[0:1],
                                                 global_step)
                        writer_detail.add_images(
                            'val_%d_masks/pred' % val_i,
                            torch.sigmoid(masks_pred[0:1, 0:1]) > 0.5,
                            global_step)
                        writer_detail.add_images(
                            'val_%d_masks_contour/true' % val_i,
                            contour_masks[0:1], global_step)
                        writer_detail.add_images(
                            'val_%d_masks_contour/pred' % val_i,
                            torch.sigmoid(masks_pred[0:1, 1:2]) > 0.5,
                            global_step)

        val_score_mask, val_score_contour = eval_net(net, val_loader, device,
                                                     n_classes)
        # scheduler.step(val_score)
        writer_main.add_scalar('val_%d_learning_rate' % val_i,
                               optimizer.param_groups[0]['lr'], global_step)

        # if val_score>=best_score:
        #     best_score=val_score
        #     best_net=net
        #     best_e=epoch

        logging.info(
            'val {} Validation score mask: {} Validation score contour: {}'.
            format(val_i, val_score_mask, val_score_contour))
        # if val_score<10:
        writer_main.add_scalar('val_%d_score_mask/test' % val_i,
                               val_score_mask, global_step)
        writer_main.add_scalar('val_%d_score_contour/test' % val_i,
                               val_score_contour, global_step)

        scheduler.step()

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(
                net.state_dict(), dir_checkpoint +
                f'val {val_i}_CP_epoch{epoch + 1}_scoreMask_%.4f_scoreContour_%.4f.pth'
                % (val_score_mask, val_score_contour))
        # logging.info(f'val {val_i} Checkpoint {epoch + 1} saved, best score: %.4f !'%best_score)

    # result_path='result_img'
    # if not os.path.exists(result_path):
    #     os.mkdir(result_path)
    # jaccard_vali = predict_netOnDataList(root_dir, best_net, val_list, device, result_path=result_path, vali=val_i, mask_threhold=0.5, vis_flag=True)
    # logging.info('val %d epoch %d jaccard score: %.6f'%(val_i, best_e+1, jaccard_vali))

    # if save_cp:
    #     try:
    #         os.mkdir(dir_checkpoint)
    #         logging.info('Created checkpoint directory')
    #     except OSError:
    #         pass
    #     torch.save(best_net.state_dict(),
    #                dir_checkpoint + f'val {val_i}_CP_epoch{best_e + 1}_score_%.4f_jaccard_%.4f.pth'%(best_score,jaccard_vali))
    #     logging.info(f'val {val_i} Checkpoint {best_e + 1} saved, best score: %.4f !'%best_score)

    # return jaccard_vali
    return 0
Example #17
def train_net(net,
              epochs=5,
              batch_size=1,
              lr=1e-3,
              val_percent=0.05,
              save_cp=True,
              gpu=False,
              img_scale=0.5):

    dir_img = '/home/xyj/data/spacenet/vegas/images_rgb_1300/'
    dir_mask = '/home/xyj/test/Pytorch-UNet/data/train_mask_point/'
    dir_checkpoint = 'checkpoints_point/'

    if not os.path.exists(dir_checkpoint):
        os.mkdir(dir_checkpoint)

    # ids = get_ids(dir_img)  # returns a generator over the file names in the train folder (minus the last 4 characters, e.g. a ".jpg" extension)
    with open('train_list.txt', 'r') as f:
        lines = f.readlines()
        ids = (i.strip('\n')[:-4] for i in lines)

    ids = split_ids(
        ids)  # yields (id, i) tuples with id from ids and i in range(n), effectively multiplying the training images n-fold; a generator of tuples

    iddataset = split_train_val(
        ids, val_percent
    )  # split by validation percentage into a dict = {"train": [...], "val": [...]}

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(save_cp), str(gpu)))

    N_train = len(iddataset['train'])

    #     optimizer = optim.SGD(net.parameters(),
    #                           lr=lr,
    #                           momentum=0.9,
    #                           weight_decay=0.0005)
    optimizer = optim.Adam(net.parameters(),
                           lr=lr,
                           betas=(0.9, 0.999),
                           eps=1e-3)
    #     scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=40,gamma = 0.3)

    criterion = nn.BCELoss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()

        # reset the generators
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask,
                                   img_scale)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask,
                                 img_scale)
        epoch_loss = 0

        for i, b in enumerate(batch(train, batch_size)):
            imgs = np.array([i[0] for i in b]).astype(np.float32)
            true_masks = np.array([i[1] // 200 for i in b])

            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)

            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

            masks_pred = net(imgs)
            masks_probs_flat = masks_pred.view(-1)

            true_masks_flat = true_masks.view(-1)

            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()

            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train,
                                                     loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            # scheduler.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

        if 1:
            val_dice = eval_net(net, val, gpu)
            print('Validation Dice Coeff: {}'.format(val_dice))

        if save_cp:
            torch.save(net.state_dict(),
                       dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))
Example #18
def train_net(net,
              device,
              epochs=5,
              batch_size=2,
              lr=0.0001,
              val_percent=0.2,
              save_cp=True,
              img_scale=1):

    # Init dataset and train/test split
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])

    # Call DataLoader
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)

    # Writer to tensorboard
    writer = SummaryWriter(
        comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
            Epochs:          {epochs}
            Batch size:      {batch_size}
            Learning rate:   {lr}
            Training size:   {n_train}
            Validation size: {n_val}
            Checkpoints:     {save_cp}
            Device:          {device.type}
            Images scaling:  {img_scale}
        ''')

    # Init optimizer and define lr_scheduler
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)

    # In this version, we use BCEWithLogitsLoss
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']

                # Set-up device
                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long

                true_masks = true_masks.to(device=device, dtype=mask_type)

                # Forward
                masks_pred = net(imgs)

                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1

                if global_step % (n_train // (2 * batch_size)) == 0:
                    # if global_step % 100 == 0:
                    # Track weight and gradient
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        writer.add_histogram('grads/' + tag,
                                             value.grad.data.cpu().numpy(),
                                             global_step)

                    val_score, val_score_iou = eval_net(
                        net, val_loader, device)
                    scheduler.step(val_score)

                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)

                    # Visualize val scores
                    logging.info('Validation Dice Coeff: {}'.format(val_score))
                    logging.info(
                        'Validation IoU Coeff: {}'.format(val_score_iou))
                    writer.add_scalar('Dice/test', val_score, global_step)
                    writer.add_scalar('IoU/test', val_score_iou, global_step)
                    writer.add_images('images', imgs, global_step)

                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks,
                                          global_step)
                        writer.add_images('masks/pred',
                                          torch.sigmoid(masks_pred) > 0.5,
                                          global_step)

        # test with sample images
        test_folder = 'test/test_epoch_{}_new'.format(epoch)
        os.makedirs(test_folder, exist_ok=True)
        dirs = os.listdir('test/test_set')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        for file in dirs:
            img = Image.open(os.path.join('test/test_set', file))
            mask = predict_img(net=net, full_img=img, device=device)

            result = mask_to_image(mask)
            result.save(os.path.join(test_folder, file))

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
Example #19
def train_net(net,
              epochs=5,
              batch_size=1,
              lr=0.1,
              val_percent=0.05,
              save_cp=True,
              gpu=False,
              img_scale=0.5,
              dir_img=None,
              dir_mask=None,
              dir_checkpoint=None,
              channels=1,
              classes=1):

    ids = os.listdir(dir_img)

    if not os.path.exists(dir_checkpoint):
        os.makedirs(dir_checkpoint, mode=0o755)

    iddataset = split_train_val(ids, val_percent)

    print('Starting training:')
    print('Epochs: ' + str(epochs))
    print('Batch size: ' + str(batch_size))
    print('Learning rate: ' + str(lr))
    print('Training size: ' + str(len(iddataset['train'])))
    print('Validation size: ' + str(len(iddataset['val'])))
    print('Checkpoints: ' + str(save_cp))

    N_train = len(iddataset['train'])

    optimizer = optim.RMSprop(net.parameters(), lr=lr)
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()

        # reset the generators
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask,
                                   img_scale)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask,
                                 img_scale)

        epoch_loss = 0

        # Run Batch
        for i, b in enumerate(batch(train, batch_size)):

            # Grab data
            try:
                imgs = np.array([i[0] for i in b]).astype(np.float32)
                true_masks = np.array([i[1] for i in b])
            except ValueError:
                print('Probable dimension issues: wrong orientations or '
                      'half-reconstructed images')
                continue
            # Deal with dimension issues
            if channels == 1:
                imgs = np.expand_dims(imgs, 1)
            if classes > 1:
                true_masks = to_categorical(true_masks, num_classes=classes)

            # Play in torch's sandbox
            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)

            # Send to GPU
            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

            # Predicted segmentations
            masks_pred = net(imgs)

            # Flatten
            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = true_masks.view(-1)

            # Calculate losses btwn true/predicted
            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()

            # Batch Loss
            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train,
                                                     loss.item()))

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Epoch Loss
        print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

        val_dice = eval_net(net, val, epoch, dir_checkpoint, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        if save_cp:
            torch.save(
                net.state_dict(),
                os.path.join(dir_checkpoint, 'CP{}.pth'.format(epoch + 1)))
            print('Checkpoint {} saved !'.format(epoch + 1))
Example #20
def train_net(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.1,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):

    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True)

    writer = SummaryWriter(
        comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
    print("net class num: {}".format(net.n_classes))
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']

                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)

                # print("predict mask: {}".format(masks_pred.size()))
                # print("true mask: {}".format(true_masks.size()))
                # print(type(true_masks))
                # a = np.array(true_masks)
                # # cnt = 0
                # # for row in a[0]:
                # np.savetxt("a.csv", a[0], delimiter=",")
                # #     cnt += 1
                # true_masks = torch.LongTensor(np.zeros((1,250,250)))
                # true_masks = true_masks.to(device=device, dtype=mask_type)
                print("predict mask: {}".format(masks_pred.size()))
                print("true mask: {}".format(true_masks.size()))

                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                print('success==========================')

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (len(dataset) // (10 * batch_size)) == 0:
                    val_score = eval_net(net, val_loader, device, n_val)
                    if net.n_classes > 1:
                        logging.info(
                            'Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)

                    else:
                        logging.info(
                            'Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks,
                                          global_step)
                        writer.add_images('masks/pred',
                                          torch.sigmoid(masks_pred) > 0.5,
                                          global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
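
A convention shared by most of these examples: the criterion is chosen from the class count, and the mask dtype must match it. nn.CrossEntropyLoss expects long class indices of shape (N, H, W) against (N, C, H, W) logits, while nn.BCEWithLogitsLoss expects float masks with the same shape as the single-channel logits. A self-contained shape check (synthetic tensors, not data from any example here):

import torch
import torch.nn as nn

logits_mc = torch.randn(2, 5, 64, 64)            # (N, C, H, W) multi-class logits
target_mc = torch.randint(0, 5, (2, 64, 64))     # long class indices, no channel dim
print(nn.CrossEntropyLoss()(logits_mc, target_mc))

logits_bin = torch.randn(2, 1, 64, 64)           # single-channel logits
target_bin = torch.rand(2, 1, 64, 64).round()    # float {0, 1} mask, same shape
print(nn.BCEWithLogitsLoss()(logits_bin, target_bin))
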
Example #21
def train_net(net,
              device,
              epochs=100,
              batch_size=1,
              lr=0.1,
              val_percent=0.2,
              save_cp=True,
              img_scale=1):

    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=16,
                              pin_memory=True,
                              drop_last=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)

    gene_eval_data(val_loader, dir='./data/val/')

    writer = SummaryWriter(
        comment='LR_{}_BS_{}_SCALE_{}'.format(lr, batch_size, img_scale))
    global_step = 0

    logging.info('''Starting training:
        Epochs:          {}
        Batch size:      {}
        Learning rate:   {}
        Training size:   {}
        Validation size: {}
        Checkpoints:     {}
        Device:          {}
        Images scaling:  {}
    '''.format(epochs, batch_size, lr, n_train, n_val, save_cp, device.type,
               img_scale))

    # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        'min' if net.n_classes > 1 else 'max',
        factor=0.5,
        patience=20)

    criterion = dice_loss
    # criterion = nn.BCELoss()

    last_loss = float('inf')
    last_val_score = 0
    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        step = 0
        mybatch_size = 4  # accumulate gradients over this many batches per optimizer step
        with tqdm(total=n_train,
                  desc='Epoch {}/{}'.format(epoch + 1, epochs),
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    'Network has been defined with {} input channels, '.format(net.n_channels) + \
                    'but loaded images have {} channels. Please check that '.format(imgs.shape[1]) + \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                global_step += 1
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})
                step += 1
                if step % mybatch_size == 0:

                    optimizer.step()
                    optimizer.zero_grad()
                    step = 0

                pbar.update(imgs.shape[0])


        for tag, value in net.named_parameters():
            tag = tag.replace('.', '/')
            writer.add_histogram('weights/' + tag,
                                 value.data.cpu().numpy(), global_step)
            writer.add_histogram('grads/' + tag,
                                 value.grad.data.cpu().numpy(), global_step)
        val_score = eval_net(net, val_loader, device)
        scheduler.step(val_score)
        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'],
                          global_step)

        if net.n_classes > 1:
            logging.info('Validation cross entropy: {}'.format(val_score))
            writer.add_scalar('Loss/test', val_score, global_step)
        else:
            logging.info('Train Loss: {}    Validation Dice Coeff: {} '.format(
                epoch_loss / n_train, val_score))
            writer.add_scalar('Dice/test', val_score, global_step)

            writer.add_images('images', imgs, global_step)
            if net.n_classes == 1:
                writer.add_images('masks/true', true_masks, global_step)
                writer.add_images('masks/pred',
                                  torch.sigmoid(masks_pred) > 0.3, global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            if last_loss > epoch_loss or last_val_score < val_score:
                last_loss = min(last_loss, epoch_loss)
                last_val_score = max(last_val_score, val_score)
                ckpt_name = 'CP_epoch{}_Trainloss{:.4f}_ValDice{:.4f}.pt'.format(
                    epoch + 1, epoch_loss / n_train, val_score)
                torch.save(net, dir_checkpoint + ckpt_name)
                logging.info('Checkpoint {} saved as {}'.format(
                    epoch + 1, ckpt_name))

    writer.close()
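
Example #21 above calls optimizer.step() only once every mybatch_size batches, i.e. plain gradient accumulation to emulate a larger effective batch. A minimal sketch of the same pattern (all names are illustrative; note the loss scaling by the accumulation count, which the example omits):

import torch

def train_with_accumulation(net, loader, optimizer, criterion, device, accum_steps=4):
    net.train()
    optimizer.zero_grad()
    for step, batch in enumerate(loader):
        imgs = batch['image'].to(device=device, dtype=torch.float32)
        masks = batch['mask'].to(device=device, dtype=torch.float32)
        # divide so the accumulated gradient equals the mean over accum_steps batches
        loss = criterion(net(imgs), masks) / accum_steps
        loss.backward()
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
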
Example #22
def train_net(dir_checkpoint,
              n_classes,
              n_channels,
              device,
              epochs=30,
              save_cp=True,
              img_scale=1):
    global best_val_iou_score
    global best_test_iou_score

    net = PAN()
    net.to(device=device)
    batch_size = 4
    lr = 1e-5
    writer = SummaryWriter(
        comment=
        f'_{net.__class__.__name__}_LR_{lr}_BS_{batch_size}_categoryFirstEntropy_ACQUISITION'
    )
    global_step = 0

    logging.basicConfig(
        filename="./logging_one32nd_category.txt",
        filemode='a',
        format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.DEBUG)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Checkpoints:     {save_cp}
        Device:          {device.type}
    ''')

    optimizer = optim.RMSprop(net.parameters(),
                              lr=lr,
                              weight_decay=1e-8,
                              momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if n_classes > 1 else 'max', patience=2)
    if n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()
    num_phases = 25  # 2689 images in total; each phase fetches 100 more into the training set
    training_pool_ids_path = "data_one32nd_category.json"
    all_training_data = "data_all.json"

    for phase in range(num_phases):
        # Within each phase: save the best-epoch checkpoint (highest test IoU),
        # log its test IoU to TensorBoard, and resume from the previous phase's checkpoint.
        selected_images = get_pool_data(training_pool_ids_path)
        data_train = RestrictedDataset(dir_img, dir_mask, selected_images)
        data_test = BasicDataset(imgs_dir=dir_img_test,
                                 masks_dir=dir_mask_test,
                                 train=False,
                                 scale=img_scale)

        train_loader = DataLoader(data_train,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=4,
                                  pin_memory=True)
        test_loader = DataLoader(data_test,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=2,
                                 pin_memory=True,
                                 drop_last=True)
        prev_ckpt_path = Path(dir_checkpoint + 'ckpt.pth')
        if prev_ckpt_path.is_file():
            net.load_state_dict(
                torch.load(dir_checkpoint + 'ckpt.pth', map_location=device))
        for epoch in range(epochs):
            net.train()
            epoch_loss = 0
            n_train = len(data_train)
            with tqdm(total=n_train,
                      desc=f'Epoch {epoch + 1}/{epochs}',
                      unit='img') as pbar:
                for batch in train_loader:
                    imgs = batch['image']
                    true_masks = batch['mask']
                    assert imgs.shape[1] == n_channels, \
                        f'Network has been defined with {n_channels} input channels, ' \
                        f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                        'the images are loaded correctly.'

                    imgs = imgs.to(device=device, dtype=torch.float32)
                    mask_type = torch.float32 if n_classes == 1 else torch.long

                    true_masks = true_masks.to(device=device, dtype=mask_type)
                    masks_pred = net(imgs)  # return BCHW = 8_1_256_256
                    true_masks = true_masks[:, :1, :, :]
                    loss = criterion(masks_pred, true_masks)
                    epoch_loss += loss.item()
                    # writer.add_scalar('Loss/train', loss.item(), global_step)
                    pbar.set_postfix(**{'loss (batch)': loss.item()})
                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_value_(net.parameters(), 0.1)
                    optimizer.step()
                    pbar.update(imgs.shape[0])
                    global_step += 1
            # Compute Dice and IoU on the test set and log them to TensorBoard.
            test_score_dice, test_score_iou = eval_net(net, test_loader,
                                                       n_classes, device)
            if test_score_iou > best_test_iou_score:
                best_test_iou_score = test_score_iou
                try:
                    os.mkdir(dir_checkpoint)
                    logging.info('Created checkpoint directory')
                except OSError:
                    pass
                torch.save(
                    net.state_dict(),
                    dir_checkpoint + f'best_CP_epoch{epoch + 1}_one32th_.pth')
                logging.info(f'Checkpoint {epoch + 1} saved !')
            logging.info('Test Dice Coeff: {}'.format(test_score_dice))
            print('Test Dice Coeff: {}'.format(test_score_dice))
            writer.add_scalar(f'Phase_{phase}_Dice/test', test_score_dice,
                              epoch)

            logging.info('Test IOU : {}'.format(test_score_iou))
            print('Test IOU : {}'.format(test_score_iou))
            writer.add_scalar(f'Phase_{phase}_IOU/test', test_score_iou, epoch)
        print(f"Phase_{phase}_best iou: ", best_test_iou_score)
        torch.save(net.state_dict(), dir_checkpoint + 'ckpt.pth')
        writer.add_scalar('Phase_IOU/test', best_test_iou_score, phase)
        # Fetching data for next phase - Update pooling images.
        update_training_pool_ids_2(net,
                                   training_pool_ids_path,
                                   all_training_data,
                                   device,
                                   acquisition_func="cfe")

    writer.close()
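
Example #22 is an active-learning loop: each phase trains on the current labelled pool, evaluates, then calls update_training_pool_ids_2 to pull the next 100 images into the pool via an acquisition function. The cycle, reduced to a schematic with hypothetical helpers:

pool = load_ids('data_one32nd_category.json')        # ids currently in the training pool
candidates = load_ids('data_all.json')               # ids of the full training set

for phase in range(num_phases):
    train_for_some_epochs(net, pool)                 # supervised training on the pool
    scores = {i: acquisition_score(net, i)           # e.g. category-first entropy ("cfe")
              for i in candidates if i not in pool}
    pool += sorted(scores, key=scores.get, reverse=True)[:100]
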
Example #23
def train_net(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.1,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):

    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True)

    #writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()
    best_score = 0
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        start = time.time()
        with tqdm(total=n_train, desc=f'Epoch {epoch}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                #writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update(imgs.shape[0])
        cost_time = time.time() - start
        logging.info(f"{epoch} loss: {epoch_loss:.5f} time {cost_time:.3f}s")
        val_score = eval_net(net, val_loader, device, n_val)
        if net.n_classes > 1:
            logging.info('Validation cross entropy: {:.5f}'.format(val_score))
            #writer.add_scalar('Loss/test', val_score, global_step)
        else:
            logging.info('Validation Dice Coeff: {:.5f}'.format(val_score))
            #writer.add_scalar('Dice/test', val_score, global_step)
            #writer.add_images('images', imgs, global_step)
            # if net.n_classes == 1:
            #     writer.add_images('masks/true', true_masks, global_step)
            #     writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
        if val_score > best_score:
            torch.save(net.state_dict(), log_dir + '/best.pth')
            best_score = val_score
            logging.info(f'best improved to {val_score:.5f}')
        torch.save(net.state_dict(), log_dir + "/latest.pth")
Example #24
def train_net(net,
              device,
              figpath,
              epochs=5,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5,
              img_size=512,
              noise_fraction=0):

    dir_img = 'ISIC-2017_Training_Data/'
    dir_mask = 'ISIC-2017_Training_Part1_GroundTruth'
    dir_val_img = 'ISIC-2017_Training_Data_validation/'
    dir_val_mask = 'ISIC-2017_Training_Part1_GroundTruth_validation/'
    dir_cle_img = 'ISIC-2017_Training_Data_clean/'
    dir_cle_mask = 'ISIC-2017_Training_Part1_GroundTruth_validation_clean/'
    dir_checkpoint = 'checkpoints/'

    if noise_fraction != 0:
        dir_mask = dir_mask + '_' + str(noise_fraction) + '/'
        print(dir_mask)
    else:
        dir_mask = dir_mask + '/'
        print(dir_mask)

    train = BasicDataset(dir_img, dir_mask, img_scale, img_size)
    val = BasicDataset(dir_val_img, dir_val_mask, img_scale, img_size)
    cle = BasicDataset(dir_cle_img, dir_cle_mask, img_scale, img_size)
    # n_val = int(len(dataset) * val_percent)
    # n_train = len(dataset) - n_val
    # train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)
    cle_loader = DataLoader(cle,
                            batch_size=5,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True)

    batch = next(iter(cle_loader))
    clean_data = batch['image']
    clean_labels = batch['mask']
    clean_data = clean_data.to(device=device, dtype=torch.float32)
    clean_labels = clean_labels.to(device=device, dtype=torch.float32)

    writer = SummaryWriter(
        comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0
    net_losses = []
    acc_test = []
    acc_train = []
    dice_train = []
    dice_test = []
    loss_train = []
    num_batch = len(train_loader)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {len(train)}
        Validation size: {len(val)}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Images size:     {img_size}
        Noise fraction:  {noise_fraction}
    ''')

    # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    optimizer = optim.RMSprop(net.parameters(),
                              lr=lr,
                              weight_decay=0,
                              momentum=0.99)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        # criterion = nn.CrossEntropyLoss()
        criterion = nn.CrossEntropyLoss(reduction="none")
    else:
        # criterion = nn.BCEWithLogitsLoss()
        criterion = nn.BCEWithLogitsLoss(reduction="none")

    for epoch in range(epochs):
        net.train()
        tot = 0
        num_val = 0
        tot_val = 0
        epoch_loss = 0
        with tqdm(total=len(train),
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

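                # Learning-to-reweight (in the spirit of Ren et al., 2018):
                # take a differentiable inner SGD step on a copy of the net with
                # per-example weights eps (initialised to zero), evaluate that
                # copy on a small trusted clean batch, and use the gradient of
                # the clean loss w.r.t. eps to decide how much weight each
                # (possibly noisy) training example gets in the real update below.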
                with higher.innerloop_ctx(net,
                                          optimizer) as (meta_net, meta_opt):
                    y_f_hat = meta_net(imgs)
                    loss = criterion(y_f_hat, true_masks[:, 0:1])
                    eps = torch.zeros(loss.size(), device=device, requires_grad=True)
                    l_f_meta = torch.sum(loss * eps)
                    meta_opt.step(l_f_meta)

                    y_g_hat = meta_net(clean_data)
                    l_g_meta = torch.mean(
                        criterion(y_g_hat, clean_labels[:, 0:1]))
                    grad_eps = torch.autograd.grad(
                        l_g_meta, eps, only_inputs=True,
                        allow_unused=True)[0].detach()

                w_tild = torch.clamp(-grad_eps, min=0)
                norm_c = torch.sum(w_tild)

                if norm_c != 0:
                    w = w_tild / norm_c
                else:
                    w = w_tild

                masks_pred = net(imgs)
                pred = torch.sigmoid(masks_pred)
                pred = (pred > 0.5).float()
                batch_dice = dice_coeff(pred, true_masks[:, 0:1]).item()
                tot += batch_dice
                dice_train.append(batch_dice)
                writer.add_scalar('Dice/train', batch_dice, global_step)

                if batch_dice <= 0.3:  # log poorly-segmented batches for inspection
                    writer.add_images('masks/true', true_masks[:, 0:1],
                                      global_step)
                    writer.add_images('masks/pred', pred, global_step)

                cost = criterion(masks_pred, true_masks[:, 0:1])
                loss = torch.sum(cost * w)
                epoch_loss += loss.item()
                net_losses.append(loss.item())
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1

                if global_step % (len(train) // (10 * batch_size)) == 0:
                    num_val += 1
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        writer.add_histogram('grads/' + tag,
                                             value.grad.data.cpu().numpy(),
                                             global_step)
                    val_score = eval_net(net, val_loader, device)
                    dice_test.append(val_score)
                    tot_val += val_score
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)

                    if net.n_classes > 1:
                        logging.info(
                            'Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info(
                            'Validation Dice Coeff: {}'.format(val_score))
                        print('Step Validation Dice: ', val_score)
                        writer.add_scalar('Dice/test', val_score, global_step)

                    # writer.add_images('images', imgs, global_step)
                    # if net.n_classes == 1:
                    #     writer.add_images('masks/true', true_masks, global_step)
                    #     writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

        print('Epoch: ', epoch)
        print('Epoch Loss: ', epoch_loss / num_batch)
        loss_train.append(epoch_loss / num_batch)

        print('Train EpochDice: ', tot / num_batch)
        acc_train.append(tot / num_batch)
        writer.add_scalar('EpochDice/train', tot / num_batch, epoch)

        print('Val EpochDice: ', tot_val / num_val)
        acc_test.append(tot_val / num_val)
        writer.add_scalar('EpochDice/test', tot_val / num_val, epoch)

        path = dir_checkpoint + figpath + '_' + str(epoch) + '_model.pth'
        # path = 'baseline/' + str(args.noise_fraction) + '/model.pth'
        torch.save(net.state_dict(), path)

    IPython.display.clear_output()
    fig, axes = plt.subplots(3, 2, figsize=(13, 5))
    ax1, ax2, ax3, ax4, ax5, ax6 = axes.ravel()

    ax1.plot(net_losses, label='iteration_losses')
    ax1.set_ylabel("Losses")
    ax1.set_xlabel("Iteration")
    ax1.legend()

    ax2.plot(loss_train, label='epoch_losses')
    ax2.set_ylabel('Losses')
    ax2.set_xlabel('Epoch')
    ax2.legend()

    ax3.plot(acc_train, label='dice_train_epoch')
    ax3.set_ylabel('EpochDice/train')
    ax3.set_xlabel('Epoch')
    ax3.legend()

    ax4.plot(acc_test, label='dice_test_epoch')
    ax4.set_ylabel('EpochDice/test')
    ax4.set_xlabel('Epoch')
    ax4.legend()

    ax5.plot(dice_train, label='dice_train_iteration')
    ax5.set_ylabel('IterationDice/train')
    ax5.set_xlabel('Iteration')
    ax5.legend()

    ax6.plot(dice_test, label='dice_test_iteration')
    ax6.set_ylabel('IterationDice/test')
    ax6.set_xlabel('Iteration')
    ax6.legend()

    plt.savefig(figpath + '.png')

    writer.close()
    return net
Example #25
def train_net(net,
              epochs=5,
              batch_size=1,
              lr=0.01,
              val_percent=0.05,
              save_cp=True,
              gpu=True):

    # Define directories
    dir_img = 'E:/Dataset/Dataset10k/images/training/'
    dir_mask = 'E:/Dataset/Dataset10k/annotations/training/'

    val_dir_img = 'E:/Dataset/Dataset10k/images/validation/'
    val_dir_mask = 'E:/Dataset/Dataset10k/annotations/validation/'

    dir_checkpoint = 'checkpoints/'

    # Get list of images and annotations
    train_images = os.listdir(dir_img)
    train_masks = os.listdir(dir_mask)
    train_size = len(train_images)

    val_images = os.listdir(val_dir_img)
    val_masks = os.listdir(val_dir_mask)
    val_size = len(val_images)

    val_imgs = np.array([read_image(val_dir_img + i)
                         for i in val_images]).astype(np.float32)
    val_true_masks = np.array(
        [read_masks(val_dir_mask + i) for i in val_masks])
    val = list(zip(val_imgs, val_true_masks))  # a list, so it can be re-iterated every epoch

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, train_size, val_size, str(save_cp),
               str(gpu)))

    # Define optimizer and loss functions
    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)

    criterion = nn.BCELoss()

    # Start training epochs
    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()

        epoch_loss = 0

        for i in range(train_size // batch_size):
            imgs = train_images[i * batch_size:(i + 1) * batch_size]
            true_masks = train_masks[i * batch_size:(i + 1) * batch_size]

            imgs = np.array([read_image(dir_img + i)
                             for i in imgs]).astype(np.float32)
            true_masks = np.array(
                [read_masks(dir_mask + i) for i in true_masks])

            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)

            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

            masks_pred = net(imgs)

            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = true_masks.view(-1)

            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()

            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / train_size,
                                                     loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

        val_dice = eval_net(net, val, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        if save_cp:
            torch.save(net.state_dict(),
                       dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))
Example #26
def train_nets(gen_net, gen_optimizer, gen_scheduler, args):

    # if args.dataset == 'Aspect':
    #     train_dataset = AspectDataset(args.train_dir, args)
    #     val_dataset = AspectDataset(args.val_dir, args, validtion_flag=True)

    if args.dataset == 'IXI':
        train_dataset = IXIataset(args.train_dir, args)
        val_dataset = IXIataset(args.val_dir, args, validtion_flag=True)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batchsize,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=True)
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batchsize,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )  # shuffle is True just to show different images on TensorBoard

    #TODO: better name for checkpoints dir
    writer = SummaryWriter(log_dir=dir_checkpoint + '/runs',
                           comment=f'LR_{args.lr}_BS_{args.batchsize}')

    logging.info(f'''Starting training:
        Epochs:          {args.epochs_n}
        Batch size:      {args.batchsize}
        Learning rate:   {args.lr}
        Checkpoints:     {args.save_cp}
        Device:          {args.device}
    ''')

    gen_net.to(device=args.device)
    start_epoch = 0

    if args.load:
        checkpoint = torch.load(args.load, map_location=args.device)
        gen_net.load_state_dict(checkpoint['model_state_dict'])

        if args.load_scheduler_optimizer:
            start_epoch = int(checkpoint['epoch'])
            gen_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            gen_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            logging.info(
                f'Model, optimizer and scheduler loaded from {args.load}')
        else:
            logging.info(f'Model weights only loaded from {args.load}')

    criterion = netLoss(args)

    for epoch in range(start_epoch, args.epochs_n):
        gen_net.train()
        epoch_loss = 0
        progress_img = 0
        with tqdm(desc=f'Epoch {epoch + 1}/{args.epochs_n}',
                  unit=' imgs') as pbar:
            #train
            for batch in train_loader:

                masked_Kspaces = batch['masked_Kspaces']
                target_Kspace = batch['target_Kspace']
                target_img = batch['target_img']

                masked_Kspaces = masked_Kspaces.to(device=args.device,
                                                   dtype=torch.float32)
                target_Kspace = target_Kspace.to(device=args.device,
                                                 dtype=torch.float32)
                target_img = target_img.to(device=args.device,
                                           dtype=torch.float32)

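                # The generator presumably returns the image-domain reconstruction,
                # the reconstructed k-space, and the image recovered from that
                # k-space (judging by the names; not verified against the model).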
                rec_img, rec_Kspace, F_rec_Kspace = gen_net(masked_Kspaces)

                FullLoss, ImL2, ImL1, KspaceL2 = criterion.calc(
                    rec_img, rec_Kspace, target_img, target_Kspace)

                epoch_loss += FullLoss.item()
                writer.add_scalar('train/FullLoss', FullLoss.item(), epoch)
                writer.add_scalar('train/ImL2', ImL2.item(), epoch)
                writer.add_scalar('train/ImL1', ImL1.item(), epoch)
                writer.add_scalar('train/KspaceL2', KspaceL2.item(), epoch)

                progress_img += 100 * target_Kspace.shape[0] / len(
                    train_dataset)
                pbar.set_postfix(
                    **{
                        'FullLoss': FullLoss.item(),
                        'ImL2': ImL2.item(),
                        'ImL1': ImL1.item(),
                        'KspaceL2': KspaceL2.item(),
                        'Prctg of train set': progress_img
                    })

                gen_optimizer.zero_grad()
                FullLoss.backward()
                #TODO: Do we need this clipping?
                nn.utils.clip_grad_value_(gen_net.parameters(), 0.1)
                gen_optimizer.step()

                pbar.update(target_Kspace.shape[0])  # current batch size

            writer.add_images('train/Fully_sampled_images', target_img, epoch)
            writer.add_images('train/rec_images', rec_img, epoch)
            writer.add_images('train/Kspace_rec_images', F_rec_Kspace, epoch)

            for tag, value in gen_net.named_parameters():
                tag = tag.replace('.', '/')
                writer.add_histogram('weights/' + tag,
                                     value.data.cpu().numpy(), epoch)
                writer.add_histogram('grads/' + tag,
                                     value.grad.data.cpu().numpy(), epoch)

            # validation:
            val_rec_img, val_full_img, val_F_rec_Kspace, val_FullLoss, val_ImL2, val_ImL1, val_KspaceL2, val_PSNR =\
                eval_net(gen_net, val_loader, criterion, args.device)
            gen_scheduler.step(val_FullLoss)

            writer.add_images('validation/Fully_sampled_images', val_full_img,
                              epoch)
            writer.add_images('validation/rec_images', val_rec_img, epoch)
            writer.add_images('validation/Kspace_rec_images', val_F_rec_Kspace,
                              epoch)

            writer.add_scalar('learning_rate',
                              gen_optimizer.param_groups[0]['lr'], epoch)

            logging.info(
                'Validation full score: {}, ImL2: {}. ImL1: {}, KspaceL2: {}, PSNR: {}'
                .format(val_FullLoss, val_ImL2, val_ImL1, val_KspaceL2,
                        val_PSNR))
            writer.add_scalar('validation/FullLoss', val_FullLoss, epoch)
            writer.add_scalar('validation/ImL2', val_ImL2, epoch)
            writer.add_scalar('validation/ImL1', val_ImL1, epoch)
            writer.add_scalar('validation/KspaceL2', val_KspaceL2, epoch)
            writer.add_scalar('validation/PSNR', val_PSNR, epoch)

        if args.save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': gen_net.state_dict(),
                    'optimizer_state_dict': gen_optimizer.state_dict(),
                    'scheduler_state_dict': gen_scheduler.state_dict(),
                }, dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')
    writer.close()
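
Unlike the other examples, this one checkpoints the optimizer and scheduler state next to the model, so a run can be resumed exactly. Restoring follows the same dict layout (path and epoch number here are illustrative):

checkpoint = torch.load(dir_checkpoint + 'CP_epoch10.pth', map_location=args.device)
gen_net.load_state_dict(checkpoint['model_state_dict'])
gen_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
gen_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = int(checkpoint['epoch']) + 1  # + 1, since the saved epoch already ran
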
Example #27
def train_net(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.1,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5,
              data_augment=True):

    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=0,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0,
                            pin_memory=True)

    global_step = 0

    logging.info(f'''Starting training:
    Epochs:          {epochs}
    Batch size:      {batch_size}
    Learning rate:   {lr}
    Training size:   {n_train}
    Validation size: {n_val}
    Checkpoints:     {save_cp}
    Device:          {device.type}
    Images scaling:  {img_scale}
    ''')

    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
    criterion = nn.BCEWithLogitsLoss()  # 1 class
    best_score = 0.

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                assert true_masks.shape[1] == net.n_classes, \
                    f'Network has been defined with {net.n_classes} output classes, ' \
                    f'but loaded masks have {true_masks.shape[1]} channels. Please check that ' \
                    'the masks are loaded correctly.'

                if data_augment:
                    for i in range(imgs.__len__()):
                        imgs[i], true_masks[i] = my_segmentation_transforms(
                            imgs[i], true_masks[i])

                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.float32)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1

                if global_step % (len(dataset) // (10 * batch_size)) == 0:
                    val_score = eval_net(net, val_loader, device, n_val)
                    logging.info('Validation Dice Coeff: {}'.format(val_score))
                    print(" ")
                    print('Validation Dice Coeff: {}'.format(val_score))

        if best_score < val_score:
            torch.save(net.state_dict(), 'BEST.pth')
            logging.info('New best model saved!')
            best_score = val_score

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')
Example #28
def train_net(net,
              epochs=5,
              batch_size=10,
              lr=0.1,
              val_percent=0.05,
              cp=True,
              gpu=False,
              mask_type="depth",
              half_scale=True):
    prefix = "/data/chc631/project/"
    dir_img = prefix + 'data/train/'
    # use depth map as target
    if mask_type == "depth":
        dir_mask = prefix + "data/train_masks_depth_map/"
    # use color map as target
    else:
        dir_mask = prefix + 'data/train_masks/'
    dir_checkpoint = "/data/chc631/project/data/checkpoints/" + options.dir
    if not os.path.exists(dir_checkpoint):
        os.makedirs(dir_checkpoint)

    ids = get_ids(dir_img)
    ids = split_ids(ids)

    iddataset = split_train_val(ids, val_percent)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(cp), str(gpu)))

    N_train = len(iddataset['train'])
    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)
    criterion = scaleInvarLoss()

    for epoch in range(epochs):
        net.train()
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        epoch_loss = 0

        if half_scale:
            print("half_scale")
            train = get_imgs_and_masks(iddataset['train'],
                                       dir_img,
                                       dir_mask,
                                       scale=0.5)
            val = get_imgs_and_masks(iddataset['val'],
                                     dir_img,
                                     dir_mask,
                                     scale=0.5)
        else:
            train = get_imgs_and_masks(iddataset['train'],
                                       dir_img,
                                       dir_mask,
                                       scale=1)
            val = get_imgs_and_masks(iddataset['val'],
                                     dir_img,
                                     dir_mask,
                                     scale=1)
        # train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
        for i, b in enumerate(batch(train, batch_size)):
            X = np.array([pair[0] for pair in b])
            y = np.array([pair[1] for pair in b])

            X = torch.FloatTensor(X)
            y = torch.FloatTensor(y)
            y = y.unsqueeze(1)  # add a channel dimension (N, 1, H, W) for conv2d

            if gpu:
                X = Variable(X).cuda()
                y = Variable(y).cuda()
            else:
                X = Variable(X)
                y = Variable(y)

            y_pred = net(X)
            y_pred_flat = y_pred.view(-1)

            if half_scale:
                # 2x2 sum-pool the target to match the half-resolution prediction
                conv_mat = torch.ones(1, 1, 2, 2)
                if gpu:
                    conv_mat = conv_mat.cuda()
                y = F.conv2d(y, conv_mat, stride=2)
                y = torch.squeeze(y)

            y_flat = y.view(-1)

            loss = criterion(y_pred_flat, y_flat.float())
            epoch_loss += loss.item()

            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train,
                                                     loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

        if cp:
            torch.save(net.state_dict(),
                       dir_checkpoint + "/" + 'CP{}.pth'.format(epoch + 1))

            print('Checkpoint {} saved !'.format(epoch + 1))
            val_err = eval_net(net, val, gpu, half_scale)
            print('Validation Error: {}'.format(val_err))
            with open(dir_checkpoint + "/ValidationError.txt", 'a') as outfile:
                outfile.write(str(val_err) + '\n')
            with open(dir_checkpoint + "/TrainingError.txt", 'a') as outfile:
                outfile.write(str(epoch_loss / (i + 1)) + '\n')
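
Example #28 regresses a depth map with a scaleInvarLoss criterion. Assuming it is the usual scale-invariant log loss of Eigen et al. (2014), a sketch of that definition:

import torch

def scale_invariant_loss(pred, target, lam=0.5, eps=1e-8):
    # d: per-pixel log difference; the second term cancels the cost of a
    # global scale offset (lam = 0.5 in the paper)
    d = torch.log(pred.clamp(min=eps)) - torch.log(target.clamp(min=eps))
    return (d ** 2).mean() - lam * d.mean() ** 2
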
Example #29
def train_net(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):

    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)

    writer = SummaryWriter(
        comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(),
                              lr=lr,
                              weight_decay=1e-8,
                              momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (n_train // (10 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        writer.add_histogram('grads/' + tag,
                                             value.grad.data.cpu().numpy(),
                                             global_step)
                    val_score = eval_net(net, val_loader, device)
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)

                    if net.n_classes > 1:
                        logging.info(
                            'Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info(
                            'Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks,
                                          global_step)
                        writer.add_images('masks/pred',
                                          torch.sigmoid(masks_pred) > 0.5,
                                          global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
Example #30
def train_net(net, trainset, valset, device, epochs, batch_size, lr,
              weight_decay, log_save_path):

    train_loader = DataLoader(trainset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=True)
    val_loader = DataLoader(valset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=2,
                            pin_memory=True,
                            drop_last=True)

    writer = SummaryWriter(log_dir=log_save_path)

    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer,
                                                 gamma=0.95)
    criterion = DiceBCELoss()

    best_DSC = 0.0
    for epoch in range(epochs):
        logging.info(f'Epoch {epoch + 1}')
        epoch_loss = 0
        epoch_dice = 0
        with tqdm(total=len(trainset),
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                net.train()
                imgs = batch['image']
                true_masks = batch['mask']

                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.float32)
                masks_pred = net(imgs)

                pred = torch.sigmoid(masks_pred)
                pred = (pred > 0.5).float()
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                epoch_dice += dice_coeff(pred, true_masks).item()
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 5)
                optimizer.step()

                pbar.set_postfix(**{'loss (batch)': loss.item()})
                pbar.update(imgs.shape[0])

        scheduler.step()

        logging.info('Training loss:   {}'.format(epoch_loss /
                                                  len(train_loader)))
        writer.add_scalar('Train/loss', epoch_loss / len(train_loader), epoch)
        logging.info('Training DSC:    {}'.format(epoch_dice /
                                                  len(train_loader)))
        writer.add_scalar('Train/dice', epoch_dice / len(train_loader), epoch)

        val_dice, val_loss = eval_net(net, val_loader, device, criterion)
        logging.info('Validation Loss: {}'.format(val_loss))
        writer.add_scalar('Val/loss', val_loss, epoch)
        logging.info('Validation DSC:  {}'.format(val_dice))
        writer.add_scalar('Val/dice', val_dice, epoch)

        if val_dice > best_DSC:
            # track the best validation DSC; the save path below is an assumed location
            best_DSC = val_dice
            torch.save(net.state_dict(), log_save_path + '/best.pth')
            logging.info('Best DSC improved to {}'.format(best_DSC))

        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'],
                          epoch)

        # writer.add_images('images', imgs, epoch)
        writer.add_images('masks/true', true_masks, epoch)
        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, epoch)

    writer.close()
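
Several examples above (#24, #30) score with a dice_coeff helper that is never shown. A common definition, given here as an assumption about what that helper computes:

import torch

def dice_coeff(pred, target, eps=1e-6):
    # pred, target: binary float masks of shape (N, 1, H, W)
    inter = (pred * target).sum(dim=(1, 2, 3))
    total = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
    return ((2 * inter + eps) / (total + eps)).mean()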