def __init__(self, filePathTrain):
        # Hyperparameters
        self.batchSize = 1
        self.numEpochs = 10
        self.learningRate = 0.001
        self.validPercent = 0.1
        self.trainShuffle = True
        self.testShuffle = False
        self.momentum = 0.99
        self.imageDim = 128

        # Variables
        self.imageDirectory = filePathTrain
        self.labelDirectory = filePathTrain
        self.numChannels = 3
        self.numClasses = 1

        # Device configuration
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Load dataset
        self.trainLoader = self.getTrainingLoader()
        #self.testLoader  = self.getTestLoader()

        # Setup model
        self.model = UNet(n_channels=self.numChannels,
                          n_classes=self.numClasses,
                          bilinear=True).to(self.device)
        #self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=self.learningRate, weight_decay=self.weightDecay, momentum=self.momentum)
        #self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learningRate)
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.learningRate,
                                         momentum=self.momentum)
        #self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min' if self.numClasses > 1 else 'max', patience=2)
        self.criterion = DiceLoss()
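
For reference, every example on this page instantiates DiceLoss() from a project-local loss module whose source is not shown. A minimal binary soft Dice loss consistent with that usage is sketched below; this is an assumed implementation for orientation only, and the actual loss.py behind these snippets may differ (smoothing term, multi-class handling, ignore_index support).

import torch
import torch.nn as nn


class SoftDiceLossSketch(nn.Module):
    """Illustrative binary soft Dice loss (not the loss.py used by the examples)."""

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        # y_pred: probabilities in [0, 1]; y_true: binary masks of the same shape
        y_pred = y_pred.reshape(y_pred.size(0), -1)
        y_true = y_true.reshape(y_true.size(0), -1)
        intersection = (y_pred * y_true).sum(dim=1)
        union = y_pred.sum(dim=1) + y_true.sum(dim=1)
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - dice.mean()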
Example #2
def main(args):

    global run
    run = Run.get_context()

    print("Current directory:", os.getcwd())
    print("Data directory:", args.images)
    print("Training directory content:", os.listdir(args.images))

    makedirs(args)
    snapshotargs(args)

    device = torch.device(
        "cpu" if not torch.cuda.is_available() else args.device)
    print("Using device:", device)

    loader_train, loader_valid = data_loaders(args)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=Dataset.in_channels,
                out_channels=Dataset.out_channels)

    unet = unet.to(device)
    unet = torch.nn.DataParallel(unet)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    optimizer = optim.Adam(unet.parameters(), lr=args.lr)

    logger = Logger(args.logs)
    loss_train = []
    loss_valid = []

    step = 0

    for epoch in tqdm(range(args.epochs), total=args.epochs):
        for phase in ["train", "valid"]:

            start = time.time()

            if phase == "train":
                unet.train()
            else:
                unet.eval()

            validation_pred = []
            validation_true = []

            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1

                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)

                optimizer.zero_grad()

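                # Build the autograd graph only in the training phase; validation runs without gradients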
                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)

                    loss = dsc_loss(y_pred, y_true)

                    if phase == "valid":
                        loss_valid.append(loss.item())
                        y_pred_np = y_pred.detach().cpu().numpy()
                        validation_pred.extend(
                            [y_pred_np[s] for s in range(y_pred_np.shape[0])])
                        y_true_np = y_true.detach().cpu().numpy()
                        validation_true.extend(
                            [y_true_np[s] for s in range(y_true_np.shape[0])])
                        if (epoch % args.vis_freq
                                == 0) or (epoch == args.epochs - 1):
                            if i * args.batch_size < args.vis_images:
                                tag = "image/{}".format(i)
                                num_images = args.vis_images - i * args.batch_size
                                logger.image_list_summary(
                                    tag,
                                    log_images(x, y_true, y_pred)[:num_images],
                                    step,
                                )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()

                if phase == "train" and (step + 1) % 10 == 0:
                    log_loss_summary(logger, loss_train, step)
                    loss_train = []

            if phase == "valid":
                log_loss_summary(logger, loss_valid, step, prefix="val_")
                mean_dsc = np.mean(
                    dsc_per_volume(
                        validation_pred,
                        validation_true,
                        loader_valid.dataset.patient_slice_index,
                    ))
                logger.scalar_summary("val_dsc", mean_dsc, step)
                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    #torch.save(unet.state_dict(), os.path.join(args.weights, "unet_epoch_" + str(epoch) + ".pt"))
                    torch.save(unet.state_dict(),
                               os.path.join(args.weights, "unet.pt"))
                loss_valid = []

            run.log("time_" + phase, time.time() - start)

    print("Best validation mean DSC: {:4f}".format(best_validation_dsc))
    run.log("best_validation_mean_dsv", best_validation_dsc)
Example #3
    def __init__(self, configs):
        self.batch_size = int(configs.get("batch_size", 16))
        self.epochs = int(configs.get("epochs", 100))
        self.lr = float(configs.get("lr", 0.0001))

        device_args = configs.get("device", "cuda")
        self.device = torch.device(
            "cpu" if not torch.cuda.is_available() else device_args)

        self.workers = int(configs.get("workers", 4))

        self.vis_images = int(configs.get("vis_images", 200))
        self.vis_freq = int(configs.get("vis_freq", 10))

        self.weights = configs.get("weights", "./weights")
        if not os.path.exists(self.weights):
            os.mkdir(self.weights)

        self.logs = configs.get("logs", "./logs")
        if not os.path.exists(self.logs):
            os.mkdir(self.logs)

        self.images_path = configs.get("images_path", "./data")

        self.is_resize = configs.get("is_resize", False)
        self.image_short_side = configs.get("image_short_side", 256)

        self.is_padding = configs.get("is_padding", False)

        is_multi_gpu = configs.get("DateParallel", False)

        pre_train = configs.get("pre_train", False)
        model_path = configs.get("model_path", './weights/unet_idcard_adam.pth')

        # self.image_size = configs.get("image_size", "256")
        # self.aug_scale = configs.get("aug_scale", "0.05")
        # self.aug_angle = configs.get("aug_angle", "15")

        self.step = 0

        self.dsc_loss = DiceLoss()
        self.model = UNet(in_channels=Dataset.in_channels,
                          out_channels=Dataset.out_channels)
        if pre_train:
            self.model.load_state_dict(torch.load(model_path,
                                                  map_location=self.device),
                                       strict=False)

        if is_multi_gpu:
            self.model = nn.DataParallel(self.model)

        self.model.to(self.device)

        self.best_validation_dsc = 0.0

        self.loader_train, self.loader_valid = self.data_loaders()

        self.params = [p for p in self.model.parameters() if p.requires_grad]

        self.optimizer = optim.Adam(self.params,
                                    lr=self.lr,
                                    weight_decay=0.0005)
        # self.optimizer = torch.optim.SGD(self.params, lr=self.lr, momentum=0.9, weight_decay=0.0005)
        self.scheduler = lr_scheduler.LR_Scheduler_Head(
            'poly', self.lr, self.epochs, len(self.loader_train))
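
This trainer is driven by a plain configs dictionary rather than argparse. A hypothetical instantiation is shown below; the enclosing class name Trainer is assumed, since it is not visible in the snippet, and the keys simply mirror the get() calls above.

# Hypothetical usage; "Trainer" stands in for the unnamed class defined above.
configs = {
    "batch_size": 16,
    "epochs": 100,
    "lr": 1e-4,
    "device": "cuda",
    "workers": 4,
    "weights": "./weights",
    "logs": "./logs",
    "images_path": "./data",
    "is_resize": True,
    "image_short_side": 256,
    "pre_train": False,
}
trainer = Trainer(configs)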
Example #4
        weight = Variable(weight.cuda())

    else:
        weight = args.weight  # weight is None

    print("weight: {}".format(weight))

    # criterion
    if args.criterion == 'nll':
        criterion = nn.NLLLoss(weight=weight)
    elif args.criterion == 'ce':
        criterion = nn.CrossEntropyLoss(weight=weight)
    elif args.criterion == 'dice':
        criterion = DiceLoss(weight=weight,
                             ignore_index=None,
                             weight_type=args.weight_type,
                             cal_zerogt=args.cal_zerogt)

    elif args.criterion == 'gdl_inv_square':
        criterion = GeneralizedDiceLoss(weight=weight,
                                        ignore_index=None,
                                        weight_type='inv_square',
                                        alpha=args.alpha)
    elif args.criterion == 'gdl_others_one_gt':
        criterion = GeneralizedDiceLoss(weight=weight,
                                        ignore_index=None,
                                        weight_type='others_one_gt',
                                        alpha=args.alpha)
    elif args.criterion == 'gdl_others_one_pred':
        criterion = GeneralizedDiceLoss(weight=weight,
                                        ignore_index=None,
Example #5
def train(args, model, optimizer, dataloader_train, dataloader_val):
    # The object that prints to screen what is happening (TensorBoard SummaryWriter)
    writer = SummaryWriter(
        comment=''.format(args.optimizer, args.context_path))
    # set up the loss
    if args.loss == 'dice':
        # class defined by the authors in loss.py
        loss_func = DiceLoss()
    elif args.loss == 'crossentropy':
        loss_func = torch.nn.CrossEntropyLoss(ignore_index=255)
    # initialize the counters
    max_miou = 0
    step = 0
    # start the training
    for epoch in range(args.num_epochs):
        # set the learning rate for this epoch (polynomial decay)
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=epoch,
                               max_iter=args.num_epochs)
        # start the training pass (switch to train mode)
        model.train()
        # sequential progress-bar display
        tq = tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        # We believe this is the list of the losses of every batch:
        loss_record = []

        # Per image or per batch? We assume it runs over each single mini-batch
        for i, (data, label) in enumerate(dataloader_train):

            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda().long()

            # We take:
            # - the final result after the FFM
            # - the 16x-downsampled result of the context path, after the ARM, modified (?)
            # - the 32x-downsampled result of the context path, after the ARM, modified (?)
            output, output_sup1, output_sup2 = model(data)

            # Compute the loss
            # Principal loss function (l_p in the paper):
            loss1 = loss_func(output, label)
            # Auxiliary loss functions (l_i, for i=2, 3 in the paper):
            loss2 = loss_func(output_sup1, label)
            loss3 = loss_func(output_sup2, label)

            # alpha = 1, compute equation 2:
            loss = loss1 + loss2 + loss3

            # progress-bar display code
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            '''
            zero_grad clears old gradients from the last step (otherwise you’d just accumulate the gradients from all loss.backward() calls).
            loss.backward() computes the derivative of the loss w.r.t. the parameters (or anything requiring gradients) using backpropagation.
            opt.step() causes the optimizer to take a step based on the gradients of the parameters.
            '''
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # increment the counter
            step += 1
            # add the values for the plot
            writer.add_scalar('loss_step', loss, step)
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('epoch/loss_epoch_train', float(loss_train_mean),
                          epoch)
        print('loss for train : %f' % (loss_train_mean))

        # save the model trained so far
        if epoch % args.checkpoint_step == 0 and epoch != 0:
            import os
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(model.state_dict(),
                       os.path.join(args.save_model_path, 'model.pth'))

        # run validation every args.validation_step epochs
        if epoch % args.validation_step == 0 and epoch != 0:

            # call the val function, which outputs the metrics
            precision, miou = val(args, model, dataloader_val)

            # save the max mIoU and the corresponding best model
            if miou > max_miou:
                max_miou = miou
                import os
                os.makedirs(args.save_model_path, exist_ok=True)
                torch.save(
                    model.state_dict(),
                    os.path.join(args.save_model_path, 'best_dice_loss.pth'))

            writer.add_scalar('epoch/precision_val', precision, epoch)
            writer.add_scalar('epoch/miou val', miou, epoch)
    # try closing the writer to see whether it prints something
    writer.close()
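
The loop above (and the later examples) calls poly_lr_scheduler(optimizer, init_lr, iter=epoch, max_iter=args.num_epochs) once per epoch and uses the returned value as the current learning rate. A common polynomial-decay implementation matching that call signature is sketched here; the exponent and the iteration unit are assumptions, and the project's own helper may differ.

def poly_lr_scheduler_sketch(optimizer, init_lr, iter, max_iter, power=0.9):
    """Polynomial learning-rate decay: lr = init_lr * (1 - iter / max_iter) ** power."""
    lr = init_lr * (1 - iter / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr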
Example #6
def main(args):
    makedirs(args)
    snapshotargs(args)
    device = torch.device("cpu" if not torch.cuda.is_available() else args.device)

    loader_train, loader_valid = data_loaders(args)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
    unet.to(device)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    optimizer = optim.Adam(unet.parameters(), lr=args.lr)
    print("Learning rate = ", args.lr)                     #AP knowing lr
    print("Batch-size = ", args.batch_size)  # AP knowing batch-size
    print("Number of visualization images to save in log file = ", args.vis_images)  # AP knowing batch-size

    logger = Logger(args.logs)
    loss_train = []
    loss_valid = []

    step = 0

    for epoch in tqdm(range(args.epochs), total=args.epochs):
        for phase in ["train", "valid"]:
            if phase == "train":
                unet.train()
            else:
                unet.eval()

            validation_pred = []
            validation_true = []

            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1

                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)

                    loss = dsc_loss(y_pred, y_true)

                    if phase == "valid":
                        loss_valid.append(loss.item())
                        y_pred_np = y_pred.detach().cpu().numpy()
                        validation_pred.extend(
                            [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                        )
                        y_true_np = y_true.detach().cpu().numpy()
                        validation_true.extend(
                            [y_true_np[s] for s in range(y_true_np.shape[0])]
                        )
                        if (epoch % args.vis_freq == 0) or (epoch == args.epochs - 1):
                            if i * args.batch_size < args.vis_images:
                                tag = "image/{}".format(i)
                                num_images = args.vis_images - i * args.batch_size
                                logger.image_list_summary(
                                    tag,
                                    log_images(x, y_true, y_pred)[:num_images],
                                    step,
                                )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()

                if phase == "train" and (step + 1) % 10 == 0:
                    log_loss_summary(logger, loss_train, step)
                    loss_train = []

            if phase == "valid":
                log_loss_summary(logger, loss_valid, step, prefix="val_")
                mean_dsc = np.mean(
                    dsc_per_volume(
                        validation_pred,
                        validation_true,
                        loader_valid.dataset.patient_slice_index,
                    )
                )
                logger.scalar_summary("val_dsc", mean_dsc, step)
                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    torch.save(unet.state_dict(), os.path.join(args.weights, "unet.pt"))
                loss_valid = []

    print("Best validation mean DSC: {:4f}".format(best_validation_dsc))
Example #7
def main(args):
    makedirs(args)
    snapshotargs(args)
    device = torch.device("cpu" if not torch.cuda.is_available() else args.device)

    loader_train, loader_valid = data_loaders(args)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
    # unet.apply(weights_init)
    unet.to(device)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    optimizer = optim.Adam(unet.parameters(), lr=args.lr, weight_decay=1e-3)
    # optimizer = optim.Adam(unet.parameters(), lr=args.lr)

    logger = Logger(args.logs)
    loss_train = []
    loss_valid = []

    log_train = []
    log_valid = []

    validation_pred = []
    validation_true = []
    step = 0

    for epoch in tqdm(range(args.epochs), total=args.epochs):
        for phase in ["train", "valid"]:
            if phase == "train":
                unet.train()
            else:
                unet.eval()


            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1
                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)
                    loss = dsc_loss(y_pred, y_true)
                    print(loss)

                    # if phase == "valid":
                    if phase == "train":
                        loss_valid.append(loss.item())
                        y_pred_np = y_pred.detach().cpu().numpy()
                        validation_pred.extend(
                            [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                        )
                        y_true_np = y_true.detach().cpu().numpy()
                        validation_true.extend(
                            [y_true_np[s] for s in range(y_true_np.shape[0])]
                        )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()


            if phase == "valid":
                dsc, label_dsc = dsc_per_volume(
                          validation_pred,
                          validation_true,
                          # loader_valid.dataset.patient_slice_index,
                          loader_train.dataset.patient_slice_index,
                          )
                mean_dsc = np.mean(dsc)
                print(mean_dsc)
                print(np.array(label_dsc).mean(axis=0))

                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    best_label_dsc = label_dsc
                    torch.save(unet.state_dict(), os.path.join(args.weights, "unet.pt"))
                    opt = epoch
                log_valid.append(np.mean(loss_valid))
                loss_valid = []
                validation_pred = []
                validation_true = []
        log_train.append(np.mean(loss_train))
        loss_train = []

    plt.plot(log_valid)
    plt.plot(log_train)
    plt.savefig("Test")
    print("Best validation mean DSC: {:4f}".format(best_validation_dsc))
    print(opt)
Example #8
import torch
from torchvision.transforms import transforms as T
import argparse  # the argparse module parses command-line arguments, e.g. python parseTest.py input.txt --port=8080
from Newunet import Insensee_3Dunet
from torch import optim
import MRI2IMG_dataset
from torch.utils.data import DataLoader
# from advanced_model import CleanU_Net
from networks.unet_model import UNet
from ResNetUNet import ResNetUNet
from HSC82 import CleanU_Net
from transform import imageaug
from loss import DiceLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
learning_rate = 1e-3
DICE_loss = DiceLoss()


def train_model(model, criterion, optimizer, dataload, num_epochs=200):
    # model.load_state_dict(torch.load('./3dunet_model_save/weights_199.pth'))
    for epoch in range(num_epochs):
        save_loss = []
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('learning_rate:',
              optimizer.state_dict()['param_groups'][0]['lr'])
        print('-' * 10)
        dataset_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0
        for img, label, _, _ in dataload:
            img_train_tensor = img
Example #9
def main(args):
    presentParameters(vars(args))
    results_path = args.results
    if not os.path.exists(results_path):
        os.makedirs(results_path)

    save_args(args, modelpath=results_path)

    device = torch.device(args.device)
    if args.model == 'u-net':
        from unet.model import UNet
        model = UNet(in_channels=3, n_classes=1).to(device)
    elif args.model == 'fcd-net':
        from tiramisu.model import FCDenseNet
        # select model architecture so it can be trained on a 16 GB GPU
        model = FCDenseNet(in_channels=3,
                           n_classes=1,
                           n_filter_first_conv=48,
                           n_pool=4,
                           growth_rate=8,
                           n_layers_per_block=3,
                           dropout_p=0.2).to(device)
    else:
        raise ValueError(
            'Parsed model argument "{}" invalid. Possible choices are "u-net" or "fcd-net"'
            .format(args.model))

    # Init weights for model
    model = model.apply(weights_init)

    transforms = my_transforms(scale=args.aug_scale,
                               angle=args.aug_angle,
                               flip_prob=args.aug_flip)
    print('Trainable parameters for model {}: {}'.format(
        args.model, get_number_params(model)))

    # create pytorch dataset
    dataset = DataSetfromNumpy(
        image_npy_path='data/train_img_{}x{}.npy'.format(
            args.image_size, args.image_size),
        mask_npy_path='data/train_mask_{}x{}.npy'.format(
            args.image_size, args.image_size),
        transform=transforms)

    # create training and validation set
    n_val = int(len(dataset) * args.val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])

    ## hacky solution: only add CustomToTensor transform in validation
    from utils.transform import CustomToTensor
    val.dataset.transform = CustomToTensor()

    print('Training the model with n_train: {} and n_val: {} images/masks'.
          format(n_train, n_val))
    train_loader = DataLoader(train,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers)
    val_loader = DataLoader(val,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    dc_loss = DiceLoss()
    writer = SummaryWriter(log_dir=os.path.join(args.logs, args.model))
    optimizer = Adam(params=model.parameters(), lr=args.lr)
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.9,
                                                           patience=5)

    loss_train = []
    loss_valid = []

    # training loop:
    global_step = 0
    for epoch in range(args.epochs):
        eval_count = 0
        epoch_start_time = datetime.datetime.now().replace(microsecond=0)
        # set model into train mode
        model = model.train()
        train_epoch_loss = 0
        valid_epoch_loss = 0
        # tqdm progress bar
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{args.epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                # retrieve images and masks and send to pytorch device
                imgs = batch['image'].to(device=device, dtype=torch.float32)
                true_masks = batch['mask'].to(
                    device=device,
                    dtype=torch.float32
                    if model.n_classes == 1 else torch.long)

                # compute prediction masks
                predicted_masks = model(imgs)
                if model.n_classes == 1:
                    predicted_masks = torch.sigmoid(predicted_masks)
                elif model.n_classes > 1:
                    predicted_masks = F.softmax(predicted_masks, dim=1)

                # compute dice loss
                loss = dc_loss(y_true=true_masks, y_pred=predicted_masks)
                train_epoch_loss += loss.item()
                # update model network weights
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # logging
                writer.add_scalar('Loss/train', loss.item(), global_step)
                # update progress bar
                pbar.update(imgs.shape[0])
                # Do evaluation every 25 training steps
                if global_step % 25 == 0:
                    eval_count += 1
                    val_loss = np.mean(
                        eval_net(model, val_loader, device, dc_loss))
                    valid_epoch_loss += val_loss
                    writer.add_scalar('Loss/validation', val_loss, global_step)
                    if model.n_classes > 1:
                        pbar.set_postfix(
                            **{
                                'Training CE loss (batch)': loss.item(),
                                'Validation CE (val set)': val_loss
                            })
                    else:
                        pbar.set_postfix(
                            **{
                                'Training dice loss (batch)': loss.item(),
                                'Validation dice loss (val set)': val_loss
                            })

                global_step += 1
                # save images as well as true + predicted masks into writer
                if global_step % args.vis_images == 0:
                    writer.add_images('images', imgs, global_step)
                    if model.n_classes == 1:
                        writer.add_images('masks/true', true_masks,
                                          global_step)
                        writer.add_images('masks/pred', predicted_masks > 0.5,
                                          global_step)

            # Get estimation of training and validation loss for entire epoch
            valid_epoch_loss /= eval_count
            train_epoch_loss /= len(train_loader)

            # Apply learning rate scheduler per epoch
            scheduler.step(valid_epoch_loss)
            # Only save the model in case the validation metric is best. For the first epoch, directly save
            if epoch > 0:
                best_model_bool = [valid_epoch_loss < l for l in loss_valid]
                best_model_bool = np.all(best_model_bool)
            else:
                best_model_bool = True

            # append
            loss_train.append(train_epoch_loss)
            loss_valid.append(valid_epoch_loss)

            if best_model_bool:
                print(
                    '\nSaving model and optimizers at epoch {} with best validation loss of {}'
                    .format(epoch, valid_epoch_loss))
                torch.save(obj={
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr_scheduler': scheduler.state_dict(),
                },
                           f=results_path +
                           '/model_epoch-{}_val_loss-{}.pth'.format(
                               epoch, np.round(valid_epoch_loss, 4)))
                epoch_time_difference = datetime.datetime.now().replace(
                    microsecond=0) - epoch_start_time
                print('Epoch: {:3d} time execution: {}'.format(
                    epoch, epoch_time_difference))

    print(
        'Finished training the segmentation model.\nAll results can be found at: {}'
        .format(results_path))
    # save scalars dictionary as json file
    scalars = {'loss_train': loss_train, 'loss_valid': loss_valid}
    with open('{}/all_scalars.json'.format(results_path), 'w') as fp:
        json.dump(scalars, fp)

    print('Logging file for tensorboard is stored at {}'.format(args.logs))
    writer.close()
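
Validation in this loop is delegated to eval_net(model, val_loader, device, dc_loss), whose return value is averaged with np.mean. The helper is not shown; a plausible reconstruction, assuming it returns one loss value per validation batch and uses the same batch dictionary keys as the training loop, is:

import torch


@torch.no_grad()
def eval_net_sketch(model, val_loader, device, loss_fn):
    # Assumed behaviour: one loss value per validation batch, model left in train mode afterwards.
    model.eval()
    losses = []
    for batch in val_loader:
        imgs = batch['image'].to(device=device, dtype=torch.float32)
        true_masks = batch['mask'].to(device=device, dtype=torch.float32)
        preds = torch.sigmoid(model(imgs))
        losses.append(loss_fn(y_true=true_masks, y_pred=preds).item())
    model.train()
    return losses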
Example #10
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
            Classification loss.
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import BertTokenizer, BertForTokenClassification
        import torch

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForTokenClassification.from_pretrained('bert-base-uncased')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)

        loss, scores = outputs[:2]

        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attentions if they are here
        if labels is not None:

            # loss_fct = CrossEntropyLoss()
            # loss_fct = FocalLoss()
            loss_fct = DiceLoss()
            # loss_fct = DSCLoss()
            # loss_fct= LabelSmoothingCrossEntropy()

            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1),
                    torch.tensor(loss_fct.ignore_index).type_as(labels))
                # print(active_loss, active_loss.shape, \
                #      active_logits,active_logits.shape,\
                #      active_labels,active_labels.shape,\
                #      labels, labels.shape)
                #2048 2048*435 2048 8*256
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels),
                                labels.view(-1))
            outputs = (loss, ) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)
Example #11
def train(args, model_G, model_D, optimizer_G, optimizer_D, CamVid_dataloader_train, CamVid_dataloader_val, IDDA_dataloader, curr_epoch, max_miou):
    # We need the CamVid data loaders and we modify the val dataloader; we don't need the CamVid train data loader on its own because we also use the IDDA dataloader
    writer = SummaryWriter(comment=''.format(args.optimizer_G, args.optimizer_D, args.context_path))  # not important for now


    scaler = amp.GradScaler()
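    # The scaler rescales the loss so fp16 gradients produced inside the amp.autocast blocks below do not underflow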
    if args.loss_G == 'dice':
        loss_func_G = DiceLoss()
    elif args.loss_G == 'crossentropy':
        loss_func_G = torch.nn.CrossEntropyLoss()
        
    loss_func_adv = torch.nn.BCEWithLogitsLoss()
    loss_func_D = torch.nn.BCEWithLogitsLoss()
        
    step = 0
    for epoch in range(curr_epoch + 1, args.num_epochs + 1):  # added +1 shift to finish with an eval
        lr_G = poly_lr_scheduler(optimizer_G, args.learning_rate_G, iter=epoch, max_iter=args.num_epochs)
        lr_D = poly_lr_scheduler(optimizer_D, args.learning_rate_D, iter=epoch, max_iter=args.num_epochs)
        model_G.train()
        model_D.train()
        tq = tqdm(total=len(CamVid_dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr_G %f , lr_D %f' % (epoch, lr_G ,lr_D )) 

        # set the ground truth for the discriminator
        source_label = 0
        target_label = 1
        # initialize the lists used to track the losses
        loss_G_record = []
        loss_adv_record = []  # we added a new list to track the adversarial loss of the generator
        loss_D_record = []  # we added a new list to track the discriminator loss
        
        source_train_loader = enumerate(IDDA_dataloader)
        s_size = len(IDDA_dataloader)
        target_loader = enumerate(CamVid_dataloader_train)
        t_size = len(CamVid_dataloader_train)

        for i in range(t_size):

            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

        #train G:
        
            for param in model_D.parameters():
                param.requires_grad = False

            #train with source:

            _, batch = next(source_train_loader)
            data, label = batch
            #label = label.type(torch.LongTensor)
            data = data.cuda()
            label = label.long().cuda()

            with amp.autocast():
                output_s, output_sup1, output_sup2 = model_G(data)
                loss1 = loss_func_G(output_s, label)
                loss2 = loss_func_G(output_sup1, label)
                loss3 = loss_func_G(output_sup2, label)
                loss_G = loss1 + loss2 + loss3

            scaler.scale(loss_G).backward()

            #train with target:

            #try:
            _, batch = next(target_loader)
            #except:
            #    target_loader = enumerate(CamVid_dataloader_train)
            #    _, batch = next(target_loader)

            data, _ = batch
            data = data.cuda()
            with amp.autocast():

                output_t, output_sup1, output_sup2 = model_G(data)
                D_out = model_D(F.softmax(output_t))
                loss_adv = loss_func_adv(D_out, Variable(torch.FloatTensor(D_out.data.size()).fill_(source_label)).cuda())  # I modified these: try to fool the discriminator
                loss_adv = loss_adv * args.lambda_adv  # 0.001 or 0.01, CHECK

            scaler.scale(loss_adv).backward()

        # train D:
            for param in model_D.parameters():
                param.requires_grad = True

            #train with source:

            output_s = output_s.detach()
            with amp.autocast():
                D_out = model_D(F.softmax(output_s))  # we feed the discriminator with the output of the model
                loss_D = loss_func_D(D_out, Variable(torch.FloatTensor(D_out.data.size()).fill_(source_label)).cuda())   # add the adversarial loss
                loss_D = loss_D / 2
            scaler.scale(loss_D).backward()

            #train with target:

            output_t = output_t.detach()
            with amp.autocast():
                D_out = model_D(F.softmax(output_t))  # we feed the discriminator with the output of the model
                loss_D = loss_func_D(D_out, Variable(torch.FloatTensor(D_out.data.size()).fill_(target_label)).cuda())  # add the adversarial loss
                loss_D = loss_D / 2
            scaler.scale(loss_D).backward()

            tq.update(args.batch_size)
            losses = {"loss_seg" : '%.6f' %(loss_G.item())  , "loss_adv" : '%.6f' %(loss_adv.item()) , "loss_D" : '%.6f'%(loss_D.item()) } # add dictionary to print losses
            tq.set_postfix(losses)

            loss_G_record.append(loss_G.item())
            loss_adv_record.append(loss_adv.item())
            loss_D_record.append(loss_D.item())           
            step += 1
            writer.add_scalar('loss_G_step', loss_G, step)  # track the segmentation loss
            writer.add_scalar('loss_adv_step', loss_adv, step)  # track the adversarial loss
            writer.add_scalar('loss_D_step', loss_D, step)  # track the discriminator loss
            scaler.step(optimizer_G)  # update the generator optimizer
            scaler.step(optimizer_D)  # update the discriminator optimizer
            scaler.update()

        tq.close()
        loss_G_train_mean = np.mean(loss_G_record)
        loss_adv_train_mean = np.mean(loss_adv_record)
        loss_D_train_mean = np.mean(loss_D_record)
        writer.add_scalar('epoch/loss_G_train_mean', float(loss_G_train_mean), epoch)
        writer.add_scalar('epoch/loss_adv_train_mean', float(loss_adv_train_mean), epoch)
        writer.add_scalar('epoch/loss_D_train_mean', float(loss_D_train_mean), epoch)

        # the checkpoint needs rewriting
        if epoch % args.checkpoint_step == 0 and epoch != 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            state = {
                "epoch": epoch,
                "model_G_state": model_G.module.state_dict(),
                "optimizer_G": optimizer_G.state_dict() ,
                "model_D_state": model_D.module.state_dict(), 
                "optimizer_D": optimizer_D.state_dict(),
                "max_miou": max_miou
            }

            torch.save(state, os.path.join(args.save_model_path, 'latest_dice_loss.pth'))

            print("*** epoch " + str(epoch) + " saved as recent checkpoint!!!")

        if epoch % args.validation_step == 0 and epoch != 0:
            precision, miou = val(args, model_G, CamVid_dataloader_val)
            if miou > max_miou:
                max_miou = miou
                os.makedirs(args.save_model_path, exist_ok=True)
                state = {
                    "epoch": epoch,
                    "model_state": model_G.module.state_dict(),
                    "optimizer": optimizer_G.state_dict(),
                    "max_miou": max_miou
                }
                torch.save(state, os.path.join(args.save_model_path, 'best_dice_loss.pth'))
                print("*** epoch " + str(epoch) + " saved as best checkpoint!!!")
            writer.add_scalar('epoch/precision_val', precision, epoch)
            writer.add_scalar('epoch/miou val', miou, epoch)
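
The discriminator model_D above takes the softmax of a segmentation output and produces a score map that is compared against all-source or all-target labels with BCEWithLogitsLoss. Its architecture is not shown; the sketch below is an assumed fully-convolutional discriminator with that interface, modelled on common adversarial domain-adaptation setups rather than on this project's actual model_D.

import torch.nn as nn


class FCDiscriminatorSketch(nn.Module):
    """Illustrative patch discriminator: class-probability maps in, single-channel logits out."""

    def __init__(self, num_classes, ndf=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        return self.net(x)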
Example #12
def train(args):

    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined parameters
    model_name = args.model_name
    model_type = args.model_type
    lstm_backbone = args.lstmbase
    unet_backbone = args.unetbase
    layer_num = args.layer_num
    nb_shortcut = args.nb_shortcut
    loss_fn = args.loss_fn
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    return_sequence = args.return_sequence
    variant = args.LSTM_variant
    epochs = args.epoch
    is_pretrain = args.is_pretrain

    # system setup parameters
    config_file = 'config.yaml'
    config = load_config(config_file)
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    lr = float(config['PARAMETERS']['lr'])
    optimizer = config['PARAMETERS']['optimizer']
    connect = config['PARAMETERS']['connect']
    conv_type = config['PARAMETERS']['lstm_convtype']

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata',
                                          model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('labels {} are ignored'.format(ignore_idx))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)

    # load training set and validation set
    data_class = data_split()
    train, val, test = data_construction(data_class)
    train_dict = time_parser(train, time_patch=time_step)
    val_dict = time_parser(val, time_patch=time_step)

    # LSTM initialization

    if model_type == 'LSTM':
        net = LSTMSegNet(lstm_backbone=lstm_backbone,
                         input_dim=input_modalites,
                         output_dim=output_channels,
                         hidden_dim=base_channel,
                         kernel_size=3,
                         num_layers=layer_num,
                         conv_type=conv_type,
                         return_sequence=return_sequence)
    elif model_type == 'UNet_LSTM':
        if variant == 'back':
            net = BackLSTM(input_dim=input_modalites,
                           hidden_dim=base_channel,
                           output_dim=output_channels,
                           kernel_size=3,
                           num_layers=layer_num,
                           conv_type=conv_type,
                           lstm_backbone=lstm_backbone,
                           unet_module=unet_backbone,
                           base_channel=base_channel,
                           return_sequence=return_sequence,
                           is_pretrain=is_pretrain)
            logger.info(
                'the pretrained status of backbone is {}'.format(is_pretrain))
        elif variant == 'center':
            net = CenterLSTM(input_modalites=input_modalites,
                             output_channels=output_channels,
                             base_channel=base_channel,
                             num_layers=layer_num,
                             conv_type=conv_type,
                             return_sequence=return_sequence,
                             is_pretrain=is_pretrain)
        elif variant == 'bicenter':
            net = BiCenterLSTM(input_modalites=input_modalites,
                               output_channels=output_channels,
                               base_channel=base_channel,
                               num_layers=layer_num,
                               connect=connect,
                               conv_type=conv_type,
                               return_sequence=return_sequence,
                               is_pretrain=is_pretrain)
        elif variant == 'directcenter':
            net = DirectCenterLSTM(input_modalites=input_modalites,
                                   output_channels=output_channels,
                                   base_channel=base_channel,
                                   num_layers=layer_num,
                                   conv_type=conv_type,
                                   return_sequence=return_sequence,
                                   is_pretrain=is_pretrain)
        elif variant == 'bidirectcenter':
            net = BiDirectCenterLSTM(input_modalites=input_modalites,
                                     output_channels=output_channels,
                                     base_channel=base_channel,
                                     num_layers=layer_num,
                                     connect=connect,
                                     conv_type=conv_type,
                                     return_sequence=return_sequence,
                                     is_pretrain=is_pretrain)
        elif variant == 'rescenter':
            net = ResCenterLSTM(input_modalites=input_modalites,
                                output_channels=output_channels,
                                base_channel=base_channel,
                                num_layers=layer_num,
                                conv_type=conv_type,
                                return_sequence=return_sequence,
                                is_pretrain=is_pretrain)
        elif variant == 'birescenter':
            net = BiResCenterLSTM(input_modalites=input_modalites,
                                  output_channels=output_channels,
                                  base_channel=base_channel,
                                  num_layers=layer_num,
                                  connect=connect,
                                  conv_type=conv_type,
                                  return_sequence=return_sequence,
                                  is_pretrain=is_pretrain)
        elif variant == 'shortcut':
            net = ShortcutLSTM(input_modalites=input_modalites,
                               output_channels=output_channels,
                               base_channel=base_channel,
                               num_layers=layer_num,
                               num_connects=nb_shortcut,
                               conv_type=conv_type,
                               return_sequence=return_sequence,
                               is_pretrain=is_pretrain)
    else:
        raise NotImplementedError()

    # loss and optimizer setup
    if loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_fn == 'GDice':
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_fn == 'WCE':
        criterion = WeightedCrossEntropyLoss(labels=labels)
    else:
        raise NotImplementedError()

    if optimizer == 'adam':
        optimizer = optim.Adam(net.parameters(), lr=0.001)
        # optimizer = optim.Adam(net.parameters())
    elif optimizer == 'sgd':
        optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # net, optimizer = amp.initialize(net, optimizer, opt_level="O1")

    if torch.cuda.device_count() > 1:
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:38366',
            rank=rank,
            world_size=world_size)
    if distributed_is_initialized():
        print('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net,
                                                  find_unused_parameters=True)
    else:
        print('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Minimum training loss from last time is {}'.format(min_loss))
            logger.info(
                'Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'
                .format(train_info['label_0_acc'][-1],
                        train_info['label_1_acc'][-1],
                        train_info['label_2_acc'][-1],
                        train_info['label_3_acc'][-1],
                        train_info['label_4_acc'][-1]))

        except Exception:
            logger.warning(
                'No checkpoint available, start training from scratch')

    for epoch in range(start_epoch, epochs):

        train_set = data_loader(train_dict,
                                batch_size=batch_size,
                                key='train',
                                num_works=num_workers,
                                time_step=time_step,
                                patch=crop_size,
                                model_type='RNN')
        n_train = len(train_set)

        val_set = data_loader(val_dict,
                              batch_size=batch_size,
                              key='val',
                              num_works=num_workers,
                              time_step=time_step,
                              patch=crop_size,
                              model_type='CNN')
        n_val = len(val_set)

        logger.info('Dataset loading finished!')

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info(
            '{} images will be used in total, {} for training and {} for validation'
            .format(n_total, n_train, n_val))

        train_loader = train_set.load()

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for i, data in enumerate(train_loader, 0):

                # i : patient
                images, segs = data['image'].to(device), data['seg'].to(device)

                # clear accumulated gradients before the backward pass
                optimizer.zero_grad()
                outputs = net(images)
                loss = criterion(outputs, segs)
                loss.backward()
                optimizer.step()

                # if i == 0:
                #     in_images = images.detach().cpu().numpy()[0]
                #     in_segs = segs.detach().cpu().numpy()[0]
                #     in_pred = outputs.detach().cpu().numpy()[0]
                #     heatmap_plot(image=in_images, mask=in_segs, pred=in_pred, name=model_name, epoch=epoch+1, is_train=True)

                running_loss += loss.detach().item()

                outputs = outputs.view(-1, outputs.shape[-4],
                                       outputs.shape[-3], outputs.shape[-2],
                                       outputs.shape[-1])
                segs = segs.view(-1, segs.shape[-3], segs.shape[-2],
                                 segs.shape[-1])
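                # Hard predictions: index of the highest-scoring class along the channel dimension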
                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(),
                                  segs.data.cpu(),
                                  ignore_idx=None)

                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'training loss': loss.detach().item(),
                        'Training accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                global_step += 1
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(net,
                                                             val_set,
                                                             criterion,
                                                             device,
                                                             batch_size,
                                                             ignore_idx=None,
                                                             name=model_name,
                                                             epoch=epoch + 1)
                    net.train()

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

        # save the best trained model
        scheduler.step(running_loss / nb_batches)
        logger.info('Epoch: {}, LR: {}'.format(
            epoch + 1, optimizer.param_groups[0]['lr']))

        if min_loss > running_loss / nb_batches:
            min_loss = running_loss / nb_batches
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        # summarize the training results of this epoch
        logger.info('The average training loss for this epoch is {}'.format(
            running_loss / nb_batches))
        logger.info('The best training loss till now is {}'.format(min_loss))
        logger.info(
            'Validation dice loss: {}; Validation (avg) accuracy of the last timestep: {}'
            .format(val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save,
                                     '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)

        for name, layer in net.named_parameters():
            if layer.requires_grad:
                writer.add_histogram(name + '_grad',
                                     layer.grad.cpu().data.numpy(), epoch)
                writer.add_histogram(name + '_data',
                                     layer.cpu().data.numpy(), epoch)
        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    loss_plot(train_info_file, name=model_name)
    logger.info('finish training!')

    return
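The `dice(...)` metric helper called in the loop above is defined elsewhere in the project. A minimal sketch with the same call signature and result keys is given below; the class-index-to-tissue mapping and the smoothing constant are assumptions for illustration only.

import torch

# Hypothetical sketch of the dice metric used above; assumes preds and segs are
# integer label maps where classes 0..4 correspond to 'bg', 'csf', 'gm', 'wm', 'tm'.
def dice(preds, segs, ignore_idx=None, eps=1e-6):
    class_names = ['bg', 'csf', 'gm', 'wm', 'tm']
    scores = {}
    for label, name in enumerate(class_names):
        if ignore_idx is not None and label == ignore_idx:
            continue
        pred_mask = (preds == label).float()
        true_mask = (segs == label).float()
        intersection = (pred_mask * true_mask).sum()
        denominator = pred_mask.sum() + true_mask.sum()
        scores[name] = ((2.0 * intersection + eps) / (denominator + eps)).item()
    scores['avg'] = sum(scores.values()) / len(scores)
    return scores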
Example #13
0
def train(args, model, optimizer, dataloader_train, dataloader_val):
    writer = SummaryWriter(
        comment='_{}_{}'.format(args.optimizer, args.context_path))
    if args.loss == 'dice':
        loss_func = DiceLoss()
    elif args.loss == 'crossentropy':
        loss_func = torch.nn.CrossEntropyLoss(ignore_index=255)
    max_miou = 0
    step = 0
    for epoch in range(args.num_epochs):
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=epoch,
                               max_iter=args.num_epochs)
        model.train()
        tq = tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        loss_record = []
        for i, (data, label) in enumerate(dataloader_train):
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda().long()

            with torch.cuda.amp.autocast():
                output, output_sup1, output_sup2 = model(data)
                loss1 = loss_func(output, label)
                loss2 = loss_func(output_sup1, label)
                loss3 = loss_func(output_sup2, label)
                loss = loss1 + loss2 + loss3
                tq.update(args.batch_size)
                tq.set_postfix(loss='%.6f' % loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            writer.add_scalar('loss_step', loss, step)
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('epoch/loss_epoch_train', float(loss_train_mean),
                          epoch)
        print('loss for train : %f' % (loss_train_mean))
        if epoch % args.checkpoint_step == 0 and epoch != 0:
            import os
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(model.module.state_dict(),
                       os.path.join(args.save_model_path, 'model.pth'))

        if epoch % args.validation_step == 0 and epoch != 0:
            precision, miou = val(args, model, dataloader_val)
            if miou > max_miou:
                max_miou = miou
                import os
                os.makedirs(args.save_model_path, exist_ok=True)
                torch.save(
                    model.module.state_dict(),
                    os.path.join(args.save_model_path, 'best_dice_loss.pth'))
            writer.add_scalar('epoch/precision_val', precision, epoch)
            writer.add_scalar('epoch/miou val', miou, epoch)
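`poly_lr_scheduler` is imported from elsewhere in this project; a common polynomial decay implementation with the same signature as the call above is sketched here (the `power=0.9` default is an assumption, not taken from this repository).

# Sketch of a polynomial learning-rate decay matching the call
# poly_lr_scheduler(optimizer, init_lr, iter=epoch, max_iter=num_epochs).
def poly_lr_scheduler(optimizer, init_lr, iter=0, max_iter=300, power=0.9):
    lr = init_lr * (1 - iter / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr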
Example #14
0
def main():
    axis = 'ax1'
    # CUDA for PyTorch
    device = train_device()

    # Training dataset
    train_params = {'batch_size': 10, 'shuffle': True, 'num_workers': 4}

    data_path = './dataset/dataset_' + axis + '/train/'
    train_dataset = Dataset(data_path,
                            transform=transforms.Compose([Preprocessing()]))
    length = len(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_params)

    # Validation dataset
    data_path = './dataset/dataset_' + axis + '/valid/'
    valid_dataset = Dataset(data_path,
                            transform=transforms.Compose([Preprocessing()]))
    valid_params = {'batch_size': 10, 'shuffle': True, 'num_workers': 4}
    val_loader = torch.utils.data.DataLoader(valid_dataset, **valid_params)

    # Training params
    learning_rate = 1e-4
    max_epochs = 100

    # Use a pretrained model and change its input channels from 3 to 1
    model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch',
                           'unet',
                           in_channels=3,
                           out_channels=1,
                           init_features=32,
                           pretrained=True)
    model.encoder1.enc1conv1 = nn.Conv2d(1,
                                         32,
                                         kernel_size=(3, 3),
                                         stride=(1, 1),
                                         padding=(1, 1),
                                         bias=False)
    model.to(device)

    # Optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    dsc_loss = DiceLoss()

    # Metrics
    train_loss = AverageMeter('Training loss', ':.6f')
    val_loss = AverageMeter('Validation loss', ':.6f')
    best_loss = float('inf')
    nb_of_batches = length // train_params['batch_size']

    for epoch in range(max_epochs):
        val_loss.avg = 0
        train_loss.avg = 0
        if not epoch:
            logg_file = loggs.Loggs(['epoch', 'train_loss', 'val_loss'])
        # switch back to training mode every epoch (the validation loop below sets eval mode)
        model.train()
        for i, (image, label) in enumerate(train_loader):
            torch.cuda.empty_cache()
            image, label = image.to(device), label.to(device)
            optimizer.zero_grad()
            y_pred = model(image)
            loss = dsc_loss(y_pred, label)
            del y_pred
            train_loss.update(loss.item(), image.size(0))
            loss.backward()
            optimizer.step()
            loggs.training_bar(i,
                               nb_of_batches,
                               prefix='Epoch: %d/%d' % (epoch, max_epochs),
                               suffix='Loss: %.6f' % loss.item())
        print(train_loss.avg)

        with torch.no_grad():
            for i, (x_val, y_val) in enumerate(val_loader):
                x_val, y_val = x_val.to(device), y_val.to(device)
                model.eval()
                yhat = model(x_val)
                loss = dsc_loss(yhat, y_val)
                val_loss.update(loss.item(), x_val.size(0))
            print(val_loss)
            logg_file.save([epoch, train_loss.avg, val_loss.avg])

            # Save the best model with minimum validation loss
            if best_loss > val_loss.avg:
                print('Updated model with validation loss %.6f ---> %.6f' %
                      (best_loss, val_loss.avg))
                best_loss = val_loss.avg
                torch.save(model, './model_' + axis + '/best_model.pt')
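Replacing `enc1conv1` with a freshly initialized convolution, as above, discards the pretrained weights of that layer. If keeping them is preferred, one option, shown here only as an assumption and not as what this example does, is to average the pretrained RGB kernels into a single input channel:

import torch
import torch.nn as nn

# Alternative sketch: reuse the pretrained 3-channel kernels of the hub model
# by averaging them over the channel dimension instead of re-initializing.
# `model` refers to the torch.hub U-Net loaded in the example above.
old_conv = model.encoder1.enc1conv1
new_conv = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=False)
with torch.no_grad():
    new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
model.encoder1.enc1conv1 = new_conv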
Example #15
0
def train():
    # number of training epochs
    epoch = 500
    # image directory
    img_dir = "./data/training/images"
    # mask directory
    mask_dir = "./data/training/1st_manual"
    # network input image size
    img_size = (512, 512)
    # create the training and validation loaders
    tr_loader = DataLoader(DRIVE_Loader(img_dir, mask_dir, img_size, 'train'),
                           batch_size=4,
                           shuffle=True,
                           num_workers=2,
                           pin_memory=True,
                           drop_last=True)
    val_loader = DataLoader(DRIVE_Loader(img_dir, mask_dir, img_size, 'val'),
                            batch_size=4,
                            shuffle=True,
                            num_workers=2,
                            pin_memory=True,
                            drop_last=True)
    # define the loss function
    criterion = DiceBCELoss()
    # move the network to the GPU
    network = UNet().cuda()
    # define the optimizer
    optimizer = Adam(network.parameters(), weight_decay=0.0001)
    best_score = 1.0
    for i in range(epoch):
        # switch to training mode so that BatchNorm and Dropout are updated
        network.train()
        train_step = 0
        train_loss = 0
        val_loss = 0
        val_step = 0
        # training
        for batch in tr_loader:
            # read the images and masks of this batch
            imgs, mask = batch
            # move the data to the GPU
            imgs = imgs.cuda()
            mask = mask.cuda()
            # feed the data into the network to get a prediction
            mask_pred = network(imgs)
            # compute the loss from the prediction and the mask
            loss = criterion(mask_pred, mask)
            # accumulate the training loss
            train_loss += loss.item()
            train_step += 1
            # zero the gradients
            optimizer.zero_grad()
            # backpropagate the loss to compute gradients
            loss.backward()
            # update the parameters with Adam
            optimizer.step()
        # switch to evaluation mode so that BatchNorm and Dropout are frozen
        network.eval()
        # validation
        with torch.no_grad():
            for batch in val_loader:
                imgs, mask = batch
                imgs = imgs.cuda()
                mask = mask.cuda()
                # compute the evaluation metric (Dice loss used here)
                val_loss += DiceLoss()(network(imgs), mask).item()
                val_step += 1
        # compute the average training loss and validation metric for the whole epoch
        train_loss /= train_step
        val_loss /= val_step
        # if the validation metric is better than the best value so far, save the current model parameters
        if val_loss < best_score:
            best_score = val_loss
            torch.save(network.state_dict(), "./checkpoint.pth")
        # print results
        print(str(i), "train_loss:", train_loss, "val_dice", val_loss)
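The `DiceBCELoss` used above combines a Dice term with binary cross-entropy. A minimal sketch of such a combined loss for a single-channel output is shown below; treating the network output as raw logits, the 1:1 weighting, and the smoothing constant are all assumptions.

import torch
import torch.nn as nn

# Sketch of a combined Dice + BCE loss for binary segmentation.
# Assumes logits of shape (N, 1, H, W) and targets of the same shape in {0, 1}.
class DiceBCELoss(nn.Module):
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        bce_loss = self.bce(logits, targets)
        probs = torch.sigmoid(logits).flatten(1)
        targets_flat = targets.flatten(1)
        intersection = (probs * targets_flat).sum(dim=1)
        dice = (2.0 * intersection + self.smooth) / (
            probs.sum(dim=1) + targets_flat.sum(dim=1) + self.smooth)
        return bce_loss + (1.0 - dice.mean())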
Example #16
0
    def compute_loss(self, start_logits, end_logits, start_labels, end_labels,
                     label_mask):
        """compute loss on squad task."""
        if len(start_labels.size()) > 1:
            start_labels = start_labels.squeeze(-1)
        if len(end_labels.size()) > 1:
            end_labels = end_labels.squeeze(-1)

        # sometimes the start/end positions are outside our model inputs; we ignore these terms
        batch_size, ignored_index = start_logits.shape  # ignored_index: seq_len
        start_labels.clamp_(0, ignored_index)
        end_labels.clamp_(0, ignored_index)

        if self.loss_type != "ce":
            # start_labels/end_labels: position index of the answer start/end within the document.
            # F.one_hot maps the position index to a sequence of 0/1 labels.
            start_labels = F.one_hot(start_labels, num_classes=ignored_index)
            end_labels = F.one_hot(end_labels, num_classes=ignored_index)

        if self.loss_type == "ce":
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_labels)
            end_loss = loss_fct(end_logits, end_labels)
        elif self.loss_type == "bce":
            start_loss = F.binary_cross_entropy_with_logits(
                start_logits.view(-1),
                start_labels.view(-1).float(),
                reduction="none")
            end_loss = F.binary_cross_entropy_with_logits(
                end_logits.view(-1),
                end_labels.view(-1).float(),
                reduction="none")

            start_loss = (start_loss *
                          label_mask.view(-1)).sum() / label_mask.sum()
            end_loss = (end_loss *
                        label_mask.view(-1)).sum() / label_mask.sum()
        elif self.loss_type == "focal":
            loss_fct = FocalLoss(gamma=self.args.focal_gamma, reduction="none")
            start_loss = loss_fct(
                FocalLoss.convert_binary_pred_to_two_dimension(
                    start_logits.view(-1)), start_labels.view(-1))
            end_loss = loss_fct(
                FocalLoss.convert_binary_pred_to_two_dimension(
                    end_logits.view(-1)), end_labels.view(-1))
            start_loss = (start_loss *
                          label_mask.view(-1)).sum() / label_mask.sum()
            end_loss = (end_loss *
                        label_mask.view(-1)).sum() / label_mask.sum()

        elif self.loss_type in ["dice", "adaptive_dice"]:
            loss_fct = DiceLoss(with_logits=True,
                                smooth=self.args.dice_smooth,
                                ohem_ratio=self.args.dice_ohem,
                                alpha=self.args.dice_alpha,
                                square_denominator=self.args.dice_square)
            # add to test
            # start_logits, end_logits = start_logits.view(batch_size, -1), end_logits.view(batch_size, -1)
            # start_labels, end_labels = start_labels.view(batch_size, -1), end_labels.view(batch_size, -1)
            start_logits, end_logits = start_logits.view(-1,
                                                         1), end_logits.view(
                                                             -1, 1)
            start_labels, end_labels = start_labels.view(-1,
                                                         1), end_labels.view(
                                                             -1, 1)
            # label_mask = label_mask.view(batch_size, -1)
            label_mask = label_mask.view(-1, 1)
            start_loss = loss_fct(start_logits, start_labels, mask=label_mask)
            end_loss = loss_fct(end_logits, end_labels, mask=label_mask)
        else:
            raise ValueError("This type of loss function does not exist.")

        total_loss = (start_loss + end_loss) / 2

        return total_loss, start_loss, end_loss
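The `DiceLoss(with_logits=..., smooth=..., ohem_ratio=..., alpha=..., square_denominator=...)` used above is project-specific. A heavily simplified sketch of the masked binary Dice behaviour it relies on, modelling only the `with_logits` and `smooth` options and the `loss_fct(logits, labels, mask=...)` call pattern, would be:

import torch
import torch.nn as nn

# Much-simplified sketch of a masked binary Dice loss; the OHEM ratio,
# alpha weighting and square-denominator options of the real DiceLoss are omitted.
class MaskedBinaryDiceLoss(nn.Module):
    def __init__(self, with_logits=True, smooth=1e-4):
        super().__init__()
        self.with_logits = with_logits
        self.smooth = smooth

    def forward(self, logits, labels, mask=None):
        probs = torch.sigmoid(logits) if self.with_logits else logits
        probs = probs.view(-1)
        labels = labels.view(-1).float()
        if mask is not None:
            mask = mask.view(-1).float()
            probs = probs * mask
            labels = labels * mask
        intersection = (probs * labels).sum()
        dice = (2.0 * intersection + self.smooth) / (
            probs.sum() + labels.sum() + self.smooth)
        return 1.0 - dice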
Example #17
0
def train(model, setting, optimizer, scheduler, epochs, batchSize, logger,
          resultsPath, testResults, testResultsTTA, tbWriter,
          allClassEvaluators):

    model.to(device)
    if torch.cuda.device_count() > 1 and useAllAvailableGPU:
        logger.info('# {} GPUs utilized! #'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    # needed so that each DataLoader worker produces different numpy random numbers;
    # otherwise all workers yield batches with identical random augmentations (a known numpy seeding issue)
    def worker_init_fn(worker_id):
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    # allocate and separately load train / val / test data sets
    dataset_Train = CustomDataSetRAM('train', logger)
    dataloader_Train = DataLoader(dataset=dataset_Train,
                                  batch_size=batchSize,
                                  shuffle=True,
                                  num_workers=4,
                                  worker_init_fn=worker_init_fn)

    if 'val' in setting:
        dataset_Val = CustomDataSetRAM('val', logger)
        dataloader_Val = DataLoader(dataset=dataset_Val,
                                    batch_size=batchSize,
                                    shuffle=False,
                                    num_workers=1,
                                    worker_init_fn=worker_init_fn)

    if 'test' in setting:
        dataset_Test = CustomDataSetRAM('test', logger)
        dataloader_Test = DataLoader(dataset=dataset_Test,
                                     batch_size=batchSize,
                                     shuffle=False,
                                     num_workers=1,
                                     worker_init_fn=worker_init_fn)

    logger.info('####### DATA LOADED - TRAINING STARTS... #######')

    # Utilize dice loss and weighted cross entropy loss
    Dice_Loss = DiceLoss(ignore_index=8).to(device)
    CE_Loss = nn.CrossEntropyLoss(weight=torch.FloatTensor(
        [1., 1., 1., 1., 1., 1., 1., 10.]),
                                  ignore_index=8).to(device)
    # WCE_Loss = nn.CrossEntropyLoss(weight=getWeightsForCEloss(dataset, train_idx, areLabelsOnehotEncoded=False, device=device, logger=logger)).to(device)

    for epoch in range(epochs):
        model.train(True)

        epochCELoss = 0
        epochDiceLoss = 0
        epochLoss = 0

        np.random.seed()
        start = time.time()
        for batch in dataloader_Train:
            # get data and put onto device
            imgBatch, segBatch = batch

            imgBatch = imgBatch.to(device)
            segBatch = segBatch.to(device)

            optimizer.zero_grad()

            # forward image batch, compute loss and backprop
            prediction = model(imgBatch)

            CEloss = CE_Loss(prediction, segBatch)
            diceLoss = Dice_Loss(prediction, segBatch)

            loss = CEloss + diceLoss

            epochCELoss += CEloss.item()
            epochDiceLoss += diceLoss.item()
            epochLoss += loss.item()

            loss.backward()
            # nn.utils.clip_grad_norm(model.parameters(), 10)
            optimizer.step()

        epochTrainLoss = epochLoss / dataloader_Train.__len__()

        end = time.time()
        # print current loss
        logger.info('[Epoch ' + str(epoch + 1) + '] Train-Loss: ' +
                    str(round(epochTrainLoss, 5)) + ', DiceLoss: ' +
                    str(round(epochDiceLoss / dataloader_Train.__len__(), 5)) +
                    ', CELoss: ' +
                    str(round(epochCELoss / dataloader_Train.__len__(), 5)) +
                    '  [took ' + str(round(end - start, 3)) + 's]')

        # use tensorboard for visualization of training progress
        tbWriter.add_scalars(
            'Plot/train', {
                'loss': epochTrainLoss,
                'CEloss': epochCELoss / dataloader_Train.__len__(),
                'DiceLoss': epochDiceLoss / dataloader_Train.__len__()
            }, epoch)

        # each 50th epoch add prediction image to tensorboard
        if epoch % 50 == 0:
            with torch.no_grad():
                tbWriter.add_image(
                    'Train/_img',
                    torch.round(
                        (imgBatch[0, :, :, :] + 1.6) / 3.2 * 255.0).byte(),
                    epoch)
                tbWriter.add_image(
                    'Train/GT',
                    convert_labelmap_to_rgb(segBatch[0, :, :].cpu()), epoch)
                tbWriter.add_image(
                    'Train/pred',
                    convert_labelmap_to_rgb(
                        prediction[0, :, :, :].argmax(0).cpu()), epoch)

        if epoch % 100 == 0:
            logger.info('[Epoch ' + str(epoch + 1) + '] ' +
                        parse_nvidia_smi(GPUno))
            logger.info('[Epoch ' + str(epoch + 1) + '] ' + parse_RAM_info())

        # if validation is active, compute dice scores on validation data
        if 'val' in setting:
            model.train(False)

            diceScores_Val = []

            start = time.time()
            for batch in dataloader_Val:
                imgBatch, segBatch = batch
                imgBatch = imgBatch.to(device)
                # segBatch = segBatch.to(device)

                with torch.no_grad():
                    prediction = model(imgBatch).to('cpu')

                    diceScores_Val.append(getDiceScores(prediction, segBatch))

            diceScores_Val = np.concatenate(
                diceScores_Val,
                0)  # <- all dice scores of val data (batchSize x amountClasses-1)
            # ignore the last column (border dice scores)
            diceScores_Val = diceScores_Val[:, :-1]

            mean_DiceScores_Val, epoch_val_mean_score = getMeanDiceScores(
                diceScores_Val, logger)

            end = time.time()
            logger.info('[Epoch ' + str(epoch + 1) +
                        '] Val-Score (mean label dice scores): ' +
                        str(np.round(mean_DiceScores_Val, 4)) + ', Mean: ' +
                        str(round(epoch_val_mean_score, 4)) + '  [took ' +
                        str(round(end - start, 3)) + 's]')

            tbWriter.add_scalar('Plot/val', epoch_val_mean_score, epoch)

            if epoch % 50 == 0:
                with torch.no_grad():
                    tbWriter.add_image(
                        'Val/_img',
                        torch.round(
                            (imgBatch[0, :, :, :] + 1.6) / 3.2 * 255.0).byte(),
                        epoch)
                    tbWriter.add_image(
                        'Val/GT',
                        convert_labelmap_to_rgb(segBatch[0, :, :].cpu()),
                        epoch)
                    tbWriter.add_image(
                        'Val/pred',
                        convert_labelmap_to_rgb(
                            prediction[0, :, :, :].argmax(0).cpu()), epoch)

            if epoch % 100 == 0:
                logger.info('[Epoch ' + str(epoch + 1) +
                            ' - After Validation] ' + parse_nvidia_smi(GPUno))
                logger.info('[Epoch ' + str(epoch + 1) +
                            ' - After Validation] ' + parse_RAM_info())

        # scheduler.step()
        if 'val' in setting:
            endLoop = scheduler.stepTrainVal(epoch_val_mean_score, logger)
        else:
            endLoop = scheduler.stepTrain(epochTrainLoss, logger)

        # when no early stop is performed, load the best validation model into the current model for the later save
        if epoch == (epochs - 1):
            logger.info(
                '### No early stop performed! Best val model loaded... ####')
            if 'val' in setting:
                scheduler.loadBestValIntoModel()

        # if test is active, compute global dice scores on test data for fast and coarse performance check
        if 'test' in setting:
            model.train(False)

            diceScores_Test = []

            start = time.time()
            for batch in dataloader_Test:
                imgBatch, segBatch = batch
                imgBatch = imgBatch.to(device)
                # segBatch = segBatch.to(device)

                with torch.no_grad():
                    prediction = model(imgBatch).to('cpu')

                    diceScores_Test.append(getDiceScores(prediction, segBatch))

            diceScores_Test = np.concatenate(
                diceScores_Test,
                0)  # <- all dice scores of test data (amountTestData x amountClasses-1)
            # ignore the last column (border dice scores)
            diceScores_Test = diceScores_Test[:, :-1]

            mean_DiceScores_Test, test_mean_score = getMeanDiceScores(
                diceScores_Test, logger)

            end = time.time()
            logger.info('[Epoch ' + str(epoch + 1) +
                        '] Test-Score (mean label dice scores): ' +
                        str(np.round(mean_DiceScores_Test, 4)) + ', Mean: ' +
                        str(round(test_mean_score, 4)) + '  [took ' +
                        str(round(end - start, 3)) + 's]')

            tbWriter.add_scalar('Plot/test', test_mean_score, epoch)

            if epoch % 50 == 0:
                with torch.no_grad():
                    tbWriter.add_image(
                        'Test/_img',
                        torch.round(
                            (imgBatch[0, :, :, :] + 1.6) / 3.2 * 255.0).byte(),
                        epoch)
                    tbWriter.add_image(
                        'Test/GT',
                        convert_labelmap_to_rgb(segBatch[0, :, :].cpu()),
                        epoch)
                    tbWriter.add_image(
                        'Test/pred',
                        convert_labelmap_to_rgb(
                            prediction[0, :, :, :].argmax(0).cpu()), epoch)

            if epoch % 100 == 0:
                logger.info('[Epoch ' + str(epoch + 1) + ' - After Testing] ' +
                            parse_nvidia_smi(GPUno))
                logger.info('[Epoch ' + str(epoch + 1) + ' - After Testing] ' +
                            parse_RAM_info())

            with torch.no_grad():
                ### if training is over ###
                if endLoop or (epoch == epochs - 1):

                    diceScores_Test = []
                    diceScores_Test_TTA = []

                    test_idx = np.arange(sum(testDatasetsSizes))
                    for sampleNo in test_idx:
                        diseaseID = -1
                        if sampleNo < sum(testDatasetsSizes[:1]):
                            diseaseID = 0  # Healthy test sample
                        elif sampleNo < sum(testDatasetsSizes[:2]):
                            diseaseID = 2  # UUO test sample
                        elif sampleNo < sum(testDatasetsSizes[:3]):
                            diseaseID = 4  # Adenine test sample
                        elif sampleNo < sum(testDatasetsSizes[:4]):
                            diseaseID = 6  # Alport test sample
                        elif sampleNo < sum(testDatasetsSizes[:5]):
                            diseaseID = 8  # IRI test sample
                        elif sampleNo < sum(testDatasetsSizes[:6]):
                            diseaseID = 10  # NTN test sample

                        # get test sample, forward it through network in evaluation mode, and compute performance
                        imgBatch, segBatch = dataset_Test.__getitem__(sampleNo)

                        imgBatch = imgBatch.unsqueeze(0).to(device)
                        segBatch = segBatch.unsqueeze(0)

                        prediction = model(imgBatch)

                        predictionCPU = prediction.to("cpu")

                        # apply post-processing
                        postprocessedPrediction, outputPrediction, preprocessedGT = postprocessPredictionAndGT(
                            prediction,
                            segBatch.squeeze(0).numpy(),
                            device=device,
                            predictionsmoothing=True,
                            holefilling=True)
                        classInstancePredictionList, classInstanceGTList, finalPredictionRGB, preprocessedGTrgb = extractInstanceChannels(
                            postprocessedPrediction,
                            preprocessedGT,
                            tubuliDilation=True)

                        # evaluate performance (TP, NP, FP counting and dice score computation)
                        for i in range(6):  #number classes to evaluate = 6
                            allClassEvaluators[diseaseID][i].add_example(
                                classInstancePredictionList[i],
                                classInstanceGTList[i])

                        # compute global dice scores as coarse performance check
                        diceScores_Test.append(
                            getDiceScores(predictionCPU, segBatch))

                        if saveFinalTestResults:
                            figFolder = resultsPath + '/' + diseaseModels[
                                diseaseID // 2]
                            if not os.path.exists(figFolder):
                                os.makedirs(figFolder)

                            imgBatchCPU = torch.round(
                                (imgBatch[0, :, :, :].to("cpu") + 1.6) / 3.2 *
                                255.0).byte().numpy().transpose(1, 2, 0)
                            figPath = figFolder + '/test_idx_' + str(
                                sampleNo) + '_result.png'
                            saveFigureResults(imgBatchCPU,
                                              outputPrediction,
                                              postprocessedPrediction,
                                              finalPredictionRGB,
                                              segBatch.squeeze(0).numpy(),
                                              preprocessedGT,
                                              preprocessedGTrgb,
                                              fullResultPath=figPath,
                                              alpha=0.4)

                        # perform exactly the same when applying TTA
                        if applyTestTimeAugmentation:
                            prediction = torch.softmax(prediction, 1)

                            imgBatch = imgBatch.flip(2)
                            prediction += torch.softmax(model(imgBatch),
                                                        1).flip(2)

                            imgBatch = imgBatch.flip(3)
                            prediction += torch.softmax(model(imgBatch),
                                                        1).flip(3).flip(2)

                            imgBatch = imgBatch.flip(2)
                            prediction += torch.softmax(model(imgBatch),
                                                        1).flip(3)

                            prediction /= 4.

                            predictionCPU = prediction.to("cpu")

                            postprocessedPrediction, outputPrediction, preprocessedGT = postprocessPredictionAndGT(
                                prediction,
                                segBatch.squeeze(0).numpy(),
                                device=device,
                                predictionsmoothing=False,
                                holefilling=True)

                            classInstancePredictionList, classInstanceGTList, finalPredictionRGB, preprocessedGTrgb = extractInstanceChannels(
                                postprocessedPrediction,
                                preprocessedGT,
                                tubuliDilation=False)

                            for i in range(6):
                                allClassEvaluators[
                                    diseaseID + 1][i].add_example(
                                        classInstancePredictionList[i],
                                        classInstanceGTList[i])

                            diceScores_Test_TTA.append(
                                getDiceScores(predictionCPU, segBatch))

                            if saveFinalTestResults:
                                figPath = figFolder + '/test_idx_' + str(
                                    sampleNo) + '_result_TTA.png'
                                saveFigureResults(imgBatchCPU,
                                                  outputPrediction,
                                                  postprocessedPrediction,
                                                  finalPredictionRGB,
                                                  segBatch.squeeze(0).numpy(),
                                                  preprocessedGT,
                                                  preprocessedGTrgb,
                                                  fullResultPath=figPath,
                                                  alpha=0.4)

                    # print global dice scores as coarse performance check
                    diceScores_Test = np.concatenate(
                        diceScores_Test,
                        0)  # <- all dice scores of test data (amountTestData x amountClasses-1)
                    # ignore the last column (border dice scores)
                    diceScores_Test = diceScores_Test[:, :-1]
                    mean_DiceScores_Test, test_mean_score = getMeanDiceScores(
                        diceScores_Test, logger)
                    logger.info('[FINAL RESULT] [Epoch ' + str(epoch + 1) +
                                '] Test-Score (mean label dice scores): ' +
                                str(np.round(mean_DiceScores_Test, 4)) +
                                ', Mean: ' + str(round(test_mean_score, 4)))
                    testResults.append(diceScores_Test)

                    # print global dice scores of TTA as coarse performance check
                    if applyTestTimeAugmentation:
                        diceScores_Test_TTA = np.concatenate(
                            diceScores_Test_TTA,
                            0)  # <- all dice scores of test data (amountTestData x amountClasses-1)
                        # ignore the last column (border dice scores)
                        diceScores_Test_TTA = diceScores_Test_TTA[:, :-1]
                        mean_DiceScores_Test_TTA, test_mean_score_TTA = getMeanDiceScores(
                            diceScores_Test_TTA, logger)
                        logger.info(
                            '[FINAL TTA RESULT] [Epoch ' + str(epoch + 1) +
                            '] Test-Score (mean label dice scores): ' +
                            str(np.round(mean_DiceScores_Test_TTA, 4)) +
                            ', Mean: ' + str(round(test_mean_score_TTA, 4)))
                        testResultsTTA.append(diceScores_Test_TTA)

        if endLoop:
            logger.info('### Early network training stop at epoch ' +
                        str(epoch + 1) + '! ###')
            break

    logger.info('[Epoch ' + str(epoch + 1) + '] ### Training done! ###')

    return model
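The flip-based test-time augmentation in the test branch above can also be expressed as a small helper; the sketch below reproduces the same four-view averaging (identity, flip along H, flip along W, flip along both), assuming the model returns raw logits over which softmax is applied.

import torch

# Sketch of the flip-based TTA used above: average softmax probabilities over
# the original image and its three flipped variants, flipping each prediction
# back before averaging.
@torch.no_grad()
def tta_predict(model, img_batch):
    prediction = torch.softmax(model(img_batch), 1)
    prediction += torch.softmax(model(img_batch.flip(2)), 1).flip(2)
    prediction += torch.softmax(model(img_batch.flip(3)), 1).flip(3)
    prediction += torch.softmax(model(img_batch.flip(2).flip(3)), 1).flip(3).flip(2)
    return prediction / 4.0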
Example #18
0
def train_net(args):
    cropsize = [cfgs.crop_height, cfgs.crop_width]
    # dataset_train = CityScapes(cfgs.data_dir, cropsize=cropsize, mode='train')
    dataset_train = ContextVoc(cfgs.train_file,
                               cropsize=cropsize,
                               mode='train')
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  drop_last=True)
    # dataset_val = CityScapes(cfgs.data_dir,  mode='val')
    dataset_val = ContextVoc(cfgs.val_file, cropsize=cropsize, mode='train')
    dataloader_val = DataLoader(dataset_val,
                                batch_size=1,
                                shuffle=True,
                                num_workers=args.num_workers,
                                drop_last=True)
    # build net
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    if torch.cuda.is_available() and args.use_gpu:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # model = BiSeNet(args.num_classes, args.context_path)
    net = DeeplabV3plus(cfgs).to(device)
    # net = SCAR(load_weights=True).to(device)
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        state_dict = torch.load(args.pretrained_model_path,
                                map_location=device)
        state_dict = renamedict(state_dict)
        net.load_state_dict(state_dict, strict=False)
        # net.load_state_dict(torch.load(args.pretrained_model_path))
        print('Done!')
    if args.mulgpu:
        net = torch.nn.DataParallel(net)
    net.train()
    # build optimizer
    if args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(net.parameters(), args.learning_rate)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(),
                                    args.learning_rate,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), args.learning_rate)
    else:
        print('unsupported optimizer\n')
        optimizer = None
    # build the loss function
    if args.losstype == 'dice':
        criterion = DiceLoss()
    elif args.losstype == 'crossentropy':
        criterion = torch.nn.CrossEntropyLoss()
    elif args.losstype == 'ohem':
        score_thres = 0.7
        n_min = args.batch_size * cfgs.crop_height * cfgs.crop_width // 16
        criterion = OhemCELoss(thresh=score_thres, n_min=n_min)
    elif args.losstype == 'focal':
        # criterion = SoftmaxFocalLoss()
        criterion = FocalLoss()
    elif args.losstype == 'multi':
        criterion = Multiloss(4)
    return net, optimizer, criterion, dataloader_train, dataloader_val
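`OhemCELoss` above is constructed with a probability threshold and a minimum pixel count. A typical online hard example mining cross-entropy with that constructor signature is sketched below; the `ignore_lb=255` default and the conversion of the probability threshold into a loss threshold are assumptions based on common BiSeNet-style implementations, not taken from this project.

import math
import torch
import torch.nn as nn

# Sketch of an OHEM cross-entropy loss: keep only the hardest pixels (per-pixel
# CE above a threshold, but never fewer than n_min pixels) and average over them.
class OhemCELoss(nn.Module):
    def __init__(self, thresh, n_min, ignore_lb=255):
        super().__init__()
        self.thresh = -math.log(thresh)  # probability threshold -> loss threshold (assumption)
        self.n_min = n_min
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')

    def forward(self, logits, labels):
        loss = self.criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)
        if loss[self.n_min] > self.thresh:
            loss = loss[loss > self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)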
Example #19
0
def train_val(config):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers)
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    if config.model_type not in [
            'UNet', 'R2UNet', 'AUNet', 'R2AUNet', 'SEUNet', 'SEUNet++',
            'UNet++', 'DAUNet', 'DANet', 'AUNetR', 'RendDANet', "BASNet"
    ]:
        print('ERROR!! model_type should be one of the supported models')
        print('Choose model %s' % config.model_type)
        return
    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "AUNet":
        model = AUNet()
    elif config.model_type == "R2UNet":
        model = R2UNet()
    elif config.model_type == "SEUNet":
        model = SEUNet(useCSE=False, useSSE=False, useCSSE=True)
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "DANet":
        model = DANet(backbone='resnet101', nclass=1)
    elif config.model_type == "AUNetR":
        model = AUNet_R16(n_classes=1, learned_bilinear=True)
    elif config.model_type == "RendDANet":
        model = RendDANet(backbone='resnet101', nclass=1)
    elif config.model_type == "BASNet":
        model = BASNet(n_channels=3, n_classes=1)
    else:
        model = UNet()

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device, dtype=torch.float)

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-6,
                        momentum=0.9)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    if config.loss == "dice":
        criterion = DiceLoss()
    elif config.loss == "bce":
        criterion = nn.BCELoss()
    elif config.loss == "bas":
        criterion = BasLoss()
    else:
        criterion = MixLoss()

    scheduler = lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
    global_step = 0
    best_dice = 0.0
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img') as train_pbar:
            model.train()
            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float)
                mask = mask.to(device, dtype=torch.float)
                d0, d1, d2, d3, d4, d5, d6, d7 = model(image)
                loss = criterion(d0, d1, d2, d3, d4, d5, d6, d7, mask)
                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1

                # if global_step % 100 == 0:
                #     writer.add_images('masks/true', mask, global_step)
                #     writer.add_images('masks/pred', d0 > 0.5, global_step)
            scheduler.step()
        epoch_dice = 0.0
        epoch_acc = 0.0
        epoch_sen = 0.0
        epoch_spe = 0.0
        epoch_pre = 0.0
        current_num = 0
        with tqdm(total=config.num_val,
                  desc="Epoch %d / %d validation round" %
                  (epoch + 1, config.num_epochs),
                  unit='img') as val_pbar:
            model.eval()
            locker = 0
            for image, mask in val_loader:
                current_num += image.shape[0]
                image = image.to(device, dtype=torch.float)
                mask = mask.to(device, dtype=torch.float)
                d0, d1, d2, d3, d4, d5, d6, d7 = model(image)
                batch_dice = dice_coeff(mask, d0).item()
                epoch_dice += batch_dice * image.shape[0]
                epoch_acc += get_accuracy(pred=d0, true=mask) * image.shape[0]
                epoch_sen += get_sensitivity(pred=d0,
                                             true=mask) * image.shape[0]
                epoch_spe += get_specificity(pred=d0,
                                             true=mask) * image.shape[0]
                epoch_pre += get_precision(pred=d0, true=mask) * image.shape[0]
                if locker == 200:
                    writer.add_images('masks/true', mask, epoch + 1)
                    writer.add_images('masks/pred', d0 > 0.5, epoch + 1)
                val_pbar.set_postfix(**{'dice (batch)': batch_dice})
                val_pbar.update(image.shape[0])
                locker += 1
            epoch_dice /= float(current_num)
            epoch_acc /= float(current_num)
            epoch_sen /= float(current_num)
            epoch_spe /= float(current_num)
            epoch_pre /= float(current_num)
            epoch_f1 = get_F1(SE=epoch_sen, PR=epoch_pre)
            if epoch_dice > best_dice:
                best_dice = epoch_dice
                writer.add_scalar('Best Dice/test', best_dice, epoch + 1)
                torch.save(
                    model, config.result_path + "/%s_%s_%d.pth" %
                    (config.model_type, str(epoch_dice), epoch + 1))
            logging.info('Validation Dice Coeff: {}'.format(epoch_dice))
            print("epoch dice: " + str(epoch_dice))
            writer.add_scalar('Dice/test', epoch_dice, epoch + 1)
            writer.add_scalar('Acc/test', epoch_acc, epoch + 1)
            writer.add_scalar('Sen/test', epoch_sen, epoch + 1)
            writer.add_scalar('Spe/test', epoch_spe, epoch + 1)
            writer.add_scalar('Pre/test', epoch_pre, epoch + 1)
            writer.add_scalar('F1/test', epoch_f1, epoch + 1)

    writer.close()
    print("Training finished")
Example #20
0
def train(args, model, optimizer, dataloader_train, dataloader_val):
    plotting.output_file('learning_curve_%s_%s.html' %
                         (args.optimizer, args.context_path))
    fig_loss = plotting.figure(title='Loss Curve',
                               x_axis_label='epochs',
                               y_axis_label='loss',
                               plot_width=600,
                               plot_height=600)
    fig_precision = plotting.figure(title='Precision Curve',
                                    x_axis_label='epochs',
                                    y_axis_label='precision',
                                    plot_width=600,
                                    plot_height=600)
    fig_miou = plotting.figure(title='mIOU Curve',
                               x_axis_label='epochs',
                               y_axis_label='mIOU',
                               plot_width=600,
                               plot_height=600)

    if args.loss == 'dice':
        loss_func = DiceLoss()
    elif args.loss == 'crossentropy':
        loss_func = torch.nn.CrossEntropyLoss()

    max_miou = 0
    loss_list = []
    epoch_x = []
    precision_list = []
    miou_list = []
    for epoch in range(args.num_epochs):
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=epoch,
                               max_iter=args.num_epochs)
        model.train()
        tq = tqdm.tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        loss_record = []
        for i, (data, label) in enumerate(dataloader_train):
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda()
            output, output_sup1, output_sup2 = model(data)
            loss1 = loss_func(output, label)
            loss2 = loss_func(output_sup1, label)
            loss3 = loss_func(output_sup2, label)
            loss = loss1 + loss2 + loss3
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        loss_list.append(loss_train_mean)
        print('loss for train : %f' % (loss_train_mean))

        if epoch % args.checkpoint_step == 0 and epoch != 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(
                model.module.state_dict(),
                os.path.join(args.save_model_path, 'latest_dice_loss.pth'))

        if epoch % args.validation_step == 0 or epoch == (args.num_epochs - 1):
            precision, miou = val(args, model, dataloader_val)
            if miou > max_miou:
                max_miou = miou
                torch.save(
                    model.module.state_dict(),
                    os.path.join(args.save_model_path, 'best_dice_loss.pth'))

            precision_list.append(precision)
            miou_list.append(miou)
            epoch_x.append(epoch)

    fig_loss.line(range(len(loss_list)),
                  loss_list,
                  legend_label='train loss, min: %.4f' % min(loss_list),
                  line_width=2,
                  line_color='red')
    fig_precision.line(epoch_x,
                       precision_list,
                       legend_label='precision, max: %.4f' %
                       max(precision_list),
                       line_width=2,
                       line_color='blue')
    fig_miou.line(epoch_x,
                  miou_list,
                  legend_label='miou, max: %.4f' % max(miou_list),
                  line_width=2,
                  line_color='green')
    plotting.save(row(fig_loss, fig_precision, fig_miou))
Example #21
0
def main():
    # load data
    print('\nloading the dataset ...')
    assert opt.dataset == "ISIC2016" or opt.dataset == "ISIC2017"
    if opt.dataset == "ISIC2016":
        num_aug = 5
        normalize = Normalize((0.7012, 0.5517, 0.4875),
                              (0.0942, 0.1331, 0.1521))
    elif opt.dataset == "ISIC2017":
        num_aug = 2
        normalize = Normalize((0.6820, 0.5312, 0.4736),
                              (0.0840, 0.1140, 0.1282))
    if opt.over_sample:
        print('data is offline oversampled ...')
        train_file = 'train_oversample.csv'
    else:
        print('no offline oversampling ...')
        train_file = 'train.csv'
    im_size = 224
    transform_train = torch_transforms.Compose([
        RatioCenterCrop(0.8),
        Resize((256, 256)),
        RandomCrop((224, 224)),
        RandomRotate(),
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        ToTensor(), normalize
    ])
    transform_val = torch_transforms.Compose([
        RatioCenterCrop(0.8),
        Resize((256, 256)),
        CenterCrop((224, 224)),
        ToTensor(), normalize
    ])
    trainset = ISIC(csv_file=train_file, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=8,
        worker_init_fn=_worker_init_fn_(),
        drop_last=True)
    valset = ISIC(csv_file='val.csv', transform=transform_val)
    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=64,
                                            shuffle=False,
                                            num_workers=8)
    print('done\n')

    # load models
    print('\nloading the model ...')

    if not opt.no_attention:
        print('turn on attention ...')
        if opt.normalize_attn:
            print('use softmax for attention map ...')
        else:
            print('use sigmoid for attention map ...')
    else:
        print('turn off attention ...')

    net = AttnVGG(num_classes=2,
                  attention=not opt.no_attention,
                  normalize_attn=opt.normalize_attn)
    dice = DiceLoss()
    if opt.focal_loss:
        print('use focal loss ...')
        criterion = FocalLoss(gama=2., size_average=True, weight=None)
    else:
        print('use cross entropy loss ...')
        criterion = nn.CrossEntropyLoss()
    print('done\n')

    # move to GPU
    print('\nmoving models to GPU ...')
    model = nn.DataParallel(net, device_ids=device_ids).to(device)
    criterion.to(device)
    dice.to(device)
    print('done\n')

    # optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.lr,
                          momentum=0.9,
                          weight_decay=1e-4,
                          nesterov=True)
    lr_lambda = lambda epoch: np.power(0.1, epoch // 10)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    # training
    print('\nstart training ...\n')
    step = 0
    EMA_accuracy = 0
    AUC_val = 0
    writer = SummaryWriter(opt.outf)
    if opt.log_images:
        data_iter = iter(valloader)
        fixed_batch = next(data_iter)
        fixed_batch = fixed_batch['image'][0:16, :, :, :].to(device)
    for epoch in range(opt.epochs):
        torch.cuda.empty_cache()
        # adjust learning rate
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('train/learning_rate', current_lr, epoch)
        print("\nepoch %d learning rate %f\n" % (epoch + 1, current_lr))
        # run for one epoch
        for aug in range(num_aug):
            for i, data in enumerate(trainloader, 0):
                # warm up
                model.train()
                model.zero_grad()
                optimizer.zero_grad()
                inputs, seg, labels = data['image'], data['image_seg'], data[
                    'label']
                seg = seg[:, -1:, :, :]
                seg_1 = F.adaptive_avg_pool2d(seg,
                                              im_size // opt.base_up_factor)
                seg_2 = F.adaptive_avg_pool2d(
                    seg, im_size // opt.base_up_factor // 2)
                inputs, seg_1, seg_2, labels = inputs.to(device), seg_1.to(
                    device), seg_2.to(device), labels.to(device)
                # forward
                pred, a1, a2 = model(inputs)
                # backward
                loss_c = criterion(pred, labels)
                loss_seg1 = dice(a1, seg_1)
                loss_seg2 = dice(a2, seg_2)
                loss = loss_c + 0.001 * loss_seg1 + 0.01 * loss_seg2
                loss.backward()
                optimizer.step()
                # display results
                if i % 10 == 0:
                    model.eval()
                    pred, __, __ = model(inputs)
                    predict = torch.argmax(pred, 1)
                    total = labels.size(0)
                    correct = torch.eq(predict, labels).sum().double().item()
                    accuracy = correct / total
                    EMA_accuracy = 0.9 * EMA_accuracy + 0.1 * accuracy
                    writer.add_scalar('train/loss_c', loss_c.item(), step)
                    writer.add_scalar('train/loss_seg1', loss_seg1.item(),
                                      step)
                    writer.add_scalar('train/loss_seg2', loss_seg2.item(),
                                      step)
                    writer.add_scalar('train/accuracy', accuracy, step)
                    writer.add_scalar('train/EMA_accuracy', EMA_accuracy, step)
                    print(
                        "[epoch %d][aug %d/%d][iter %d/%d] loss_c %.4f loss_seg1 %.4f loss_seg2 %.4f accuracy %.2f%% EMA %.2f%%"
                        %
                        (epoch + 1, aug + 1, num_aug, i + 1, len(trainloader),
                         loss.item(), loss_seg1.item(), loss_seg2.item(),
                         (100 * accuracy), (100 * EMA_accuracy)))
                step += 1
        # the end of each epoch - validation results
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            with open('val_results.csv', 'wt', newline='') as csv_file:
                csv_writer = csv.writer(csv_file, delimiter=',')
                for i, data in enumerate(valloader, 0):
                    images_val, labels_val = data['image'], data['label']
                    images_val, labels_val = images_val.to(
                        device), labels_val.to(device)
                    pred_val, __, __ = model(images_val)
                    predict = torch.argmax(pred_val, 1)
                    total += labels_val.size(0)
                    correct += torch.eq(predict,
                                        labels_val).sum().double().item()
                    # record predictions
                    responses = F.softmax(pred_val,
                                          dim=1).squeeze().cpu().numpy()
                    responses = [
                        responses[i] for i in range(responses.shape[0])
                    ]
                    csv_writer.writerows(responses)
            AP, AUC, precision_mean, precision_mel, recall_mean, recall_mel = compute_metrics(
                'val_results.csv', 'val.csv')
            # save checkpoints
            print('\nsaving checkpoints ...\n')
            checkpoint = {
                'state_dict': model.module.state_dict(),
                'opt_state_dict': optimizer.state_dict(),
            }
            torch.save(checkpoint,
                       os.path.join(opt.outf, 'checkpoint_latest.pth'))
            if AUC > AUC_val:  # save optimal validation model
                torch.save(checkpoint, os.path.join(opt.outf,
                                                    'checkpoint.pth'))
                AUC_val = AUC
            # log scalars
            writer.add_scalar('val/accuracy', correct / total, epoch)
            writer.add_scalar('val/mean_precision', precision_mean, epoch)
            writer.add_scalar('val/mean_recall', recall_mean, epoch)
            writer.add_scalar('val/precision_mel', precision_mel, epoch)
            writer.add_scalar('val/recall_mel', recall_mel, epoch)
            writer.add_scalar('val/AP', AP, epoch)
            writer.add_scalar('val/AUC', AUC, epoch)
            print("\n[epoch %d] val result: accuracy %.2f%%" %
                  (epoch + 1, 100 * correct / total))
            print(
                "\nmean precision %.2f%% mean recall %.2f%% \nprecision for mel %.2f%% recall for mel %.2f%%"
                % (100 * precision_mean, 100 * recall_mean,
                   100 * precision_mel, 100 * recall_mel))
            print("\nAP %.4f AUC %.4f\n optimal AUC: %.4f" %
                  (AP, AUC, AUC_val))
            # log images
            if opt.log_images:
                print('\nlog images ...\n')
                I_train = utils.make_grid(inputs[0:16, :, :, :],
                                          nrow=4,
                                          normalize=True,
                                          scale_each=True)
                I_seg_1 = utils.make_grid(seg_1[0:16, :, :, :],
                                          nrow=4,
                                          normalize=True,
                                          scale_each=True)
                I_seg_2 = utils.make_grid(seg_2[0:16, :, :, :],
                                          nrow=4,
                                          normalize=True,
                                          scale_each=True)
                writer.add_image('train/image', I_train, epoch)
                writer.add_image('train/seg1', I_seg_1, epoch)
                writer.add_image('train/seg2', I_seg_2, epoch)
                if epoch == 0:
                    I_val = utils.make_grid(fixed_batch,
                                            nrow=4,
                                            normalize=True,
                                            scale_each=True)
                    writer.add_image('val/image', I_val, epoch)
            if opt.log_images and (not opt.no_attention):
                print('\nlog attention maps ...\n')
                # training data
                __, a1, a2 = model(inputs[0:16, :, :, :])
                if a1 is not None:
                    attn1 = visualize_attn(I_train,
                                           a1,
                                           up_factor=opt.base_up_factor,
                                           nrow=4)
                    writer.add_image('train/attention_map_1', attn1, epoch)
                if a2 is not None:
                    attn2 = visualize_attn(I_train,
                                           a2,
                                           up_factor=2 * opt.base_up_factor,
                                           nrow=4)
                    writer.add_image('train/attention_map_2', attn2, epoch)
                # val data
                __, a1, a2 = model(fixed_batch)
                if a1 is not None:
                    attn1 = visualize_attn(I_val,
                                           a1,
                                           up_factor=opt.base_up_factor,
                                           nrow=4)
                    writer.add_image('val/attention_map_1', attn1, epoch)
                if a2 is not None:
                    attn2 = visualize_attn(I_val,
                                           a2,
                                           up_factor=2 * opt.base_up_factor,
                                           nrow=4)
                    writer.add_image('val/attention_map_2', attn2, epoch)
Example #22
def main():
    # load the config
    config = parse_train_config()
    # load the model
    model = Unet3d(in_channels=config.in_channels,
                   out_channels=config.out_channels,
                   interpolate=config.interpolate,
                   concatenate=config.concatenate,
                   norm_type=config.norm_type,
                   init_channels=config.init_channels,
                   scale_factor=(2, 2, 2))

    if config.init_weight:
        model.apply(init_weight)

    # get the device to train on
    gpu_all = tuple(config.gpu_index)
    gpu_main = gpu_all[0]

    device = torch.device(
        'cuda:' + str(gpu_main) if torch.cuda.is_available() else 'cpu')
    model = nn.DataParallel(model, device_ids=gpu_all)
    model.to(device)

    # load the saved checkpoint - update model parameters
    utils.load_checkpoint(config.model_path, device, model)
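    # freeze the pretrained network; only the parameters of model2 are optimized below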
    for params in model.parameters():
        params.requires_grad = False

    model2 = smallmodel(out_channels=config.out_channels,
                        interpolate=config.interpolate,
                        norm_type=config.norm_type,
                        init_channels=config.init_channels)

    if config.init_weight:
        model2.apply(init_weight)

    model2 = nn.DataParallel(model2, device_ids=gpu_all)
    model2.to(device)

    # load data
    phase = 'train'
    train_dataset = Hdf5Dataset(config.data_path, phase,
                                config.train_sub_index)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=config.train_batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True)
    val_dataset = Hdf5Dataset(config.data_path, phase, config.val_sub_index)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=config.val_batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)

    # define accuracy
    accuracy_criterion = DiceAccuracy()

    # define loss
    if config.loss_weight is None:
        loss_criterion = DiceLoss()
    else:
        loss_criterion = DiceLoss(weight=config.loss_weight)

    # define optimizer
    optimizer = torch.optim.Adam(model2.parameters(),
                                 lr=config.learning_rate,
                                 weight_decay=config.weight_decay)

    trainer = Trainer(config, model, model2, device, train_loader, val_loader,
                      accuracy_criterion, loss_criterion, optimizer)
    trainer.main()
Example #23
def display_dataset_details(dataset, train_set, val_set):
    x = 'Total Images: ' + str(len(dataset))
    y = 'Total Training Images: ' + str(len(train_set))
    z = 'Total Validation Images: ' + str(len(val_set))

    return x, y, z

if(add_selectbox == 'Training'):
    st.header(add_selectbox)
    st.markdown('Device Detected : '+str(device))
    st.write('Select Training Parameters')
    epochs = st.number_input('Epochs', min_value=1, value=2)
    lr = st.number_input('Learning Rate',
                         min_value=0.0001,
                         max_value=None,
                         value=0.0010,
                         step=0.001,
                         format='%f')

    if st.button('Load Data'):
        dsc_loss = DiceLoss()
        dataset = SyntheticCellDataset('dataset')
        indices = torch.randperm(len(dataset)).tolist()
        sr = int(0.2 * len(dataset))
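        # hold out the last 20% of the shuffled indices for validation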
        train_set = torch.utils.data.Subset(dataset, indices[:-sr])
        val_set = torch.utils.data.Subset(dataset, indices[-sr:])
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=2, shuffle=True, pin_memory=True)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=2, shuffle=False, pin_memory=True)
        st.write('Data Loaded')
        x,y,z = display_dataset_details(dataset, train_set, val_set)
        st.write(x)
        st.write(y)
        st.write(z)
        #if st.button('Start Training'):
        model = UNet()
        model.to(device)
Example #24
def train(args, model, optimizer, dataloader_train, dataloader_val_train,
          dataloader_test):
    writer = SummaryWriter(log_dir='runs_50_adadelta',
                           comment='{}_{}'.format(args.optimizer,
                                                  args.context_path))
    if args.loss == 'dice':
        loss_func = DiceLoss()
    elif args.loss == 'crossentropy':
        loss_func = torch.nn.CrossEntropyLoss()
    max_miou = 0
    step = 0
    for epoch in range(args.epoch_start_i, args.num_epochs):
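        # polynomial decay of the learning rate over the training schedule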
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=epoch,
                               max_iter=args.num_epochs)
        model.train()
        tq = tqdm.tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        loss_record = []
        for i, (data, label) in enumerate(dataloader_train):
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda()
            output, output_sup1, output_sup2 = model(data)
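            # the main prediction and both auxiliary outputs are supervised with the same loss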
            loss1 = loss_func(output, label)
            loss2 = loss_func(output_sup1, label)
            loss3 = loss_func(output_sup2, label)
            loss = loss1 + loss2 + loss3
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            writer.add_scalar('loss_step', loss, step)
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('epoch/loss_epoch_train', float(loss_train_mean),
                          epoch)
        print('loss for train : %f' % (loss_train_mean))
        if epoch % args.checkpoint_step == 0 and epoch != 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(
                model.module.state_dict(),
                os.path.join(args.save_model_path, 'latest_dice_loss.pth'))

        if epoch % args.validation_step == 0:
            #precision, miou = val(args, model, dataloader_val)
            oa, miou, cm, cks, f1 = val(args, model, dataloader_val_train,
                                        'train')
            oa_test, miou_test, cm_test, cks_test, f1_test = val(
                args, model, dataloader_test, 'test')
            if miou > max_miou:
                max_miou = miou
                torch.save(
                    model.module.state_dict(),
                    os.path.join(args.save_model_path, 'best_dice_loss.pth'))
            #writer.add_scalar('epoch/precision_val', precision, epoch)
            writer.add_scalar('epoch/oa_train', oa, epoch)
            writer.add_scalar('epoch/oa_test', oa_test, epoch)
            #writer.add_scalar('epoch/miou val', miou, epoch)
            writer.add_scalar('epoch/miou_train', miou, epoch)
            writer.add_scalar('epoch/miou_test', miou_test, epoch)
            writer.add_scalar('epoch/cks_train', cks, epoch)
            writer.add_scalar('epoch/cks_test', cks_test, epoch)
            writer.add_scalar('epoch/f1_train', f1, epoch)
            writer.add_scalar('epoch/f1_test', f1_test, epoch)
            with open(os.path.join(args.save_model_path,
                                   'classification_results.txt'),
                      mode='a') as f:
                f.write('epoch: ' + str(epoch) + '\n')
                # f.write('train time:\t' + str(train_time))
                # f.write('\ntest time:\t' + str(test_time))
                f.write('\nmiou:\t' + str(miou))
                f.write('\noverall accuracy:\t' + str(oa))
                f.write('\ncohen kappa:\t' + str(cks))
                f.write('\nconfusion matrix:\n')
                f.write(str(cm))
                f.write('\nf1:\t' + str(f1))
                f.write('\n\n')
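None of the listed snippets includes the definition of DiceLoss itself, and each repository may ship its own variant. Purely as a point of reference, a minimal soft-Dice criterion for binary masks might look like the sketch below; the class name SoftDiceLoss, the smooth term, and the assumption that the model emits raw logits are illustrative choices, not taken from any of these examples.

import torch
import torch.nn as nn


class SoftDiceLoss(nn.Module):
    """Illustrative soft-Dice loss (1 - Dice coefficient) for binary segmentation."""

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth  # keeps the ratio finite when both masks are empty

    def forward(self, logits, targets):
        # logits: raw model outputs of shape (N, 1, H, W); targets: binary masks of the same shape
        probs = torch.sigmoid(logits)
        probs = probs.reshape(probs.size(0), -1)
        targets = targets.reshape(targets.size(0), -1)
        intersection = (probs * targets).sum(dim=1)
        dice = (2.0 * intersection + self.smooth) / (
            probs.sum(dim=1) + targets.sum(dim=1) + self.smooth)
        return 1.0 - dice.mean()

A multi-class variant typically applies the same computation per channel (after a softmax) and averages the per-class Dice scores, which is what names such as MulticlassDiceLoss in the later examples suggest.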
Example #25
data_dir = './data'
train_csv_path = os.path.join(data_dir, 'train.csv')
test_csv_path = os.path.join(data_dir, 'test.csv')

train_images_dir = os.path.join(data_dir, 'stage_1_train_images/')
test_images_dir = os.path.join(data_dir, 'stage_1_test_images/')

train_df, train_loader, dev_pids, dev_loader, dev_dataset_for_predict, dev_loader_for_predict, test_loader, test_df, test_pids, boxes_by_pid_dict, min_box_area = load_data(
    train_csv_path, test_csv_path, train_images_dir, test_images_dir,
    batch_size, validation_prop, rescale_factor)
min_box_area = int(round(min_box_area / float(rescale_factor**2)))

# model = torch.nn.DataParallel(LeakyUNET().cuda(), device_ids=[0, 1, 2, 3, 4, 5, 6, 7])
model = torch.nn.DataParallel(LeakyUNET().cuda(), device_ids=[0, 1, 2, 3])

loss_fn = DiceLoss().cuda()

init_learning_rate = 0.5

num_epochs = 1 if debug else 5
num_train_steps = 5 if debug else len(train_loader)
num_dev_steps = 5 if debug else len(dev_loader)

img_dim = int(round(original_dim / rescale_factor))

print("Training for {} epochs".format(num_epochs))
histories, best_models = train_and_evaluate(model,
                                            train_loader,
                                            dev_loader,
                                            init_learning_rate,
                                            loss_fn,
Example #26
parser.add_argument('--restart', type=int, default=50,
                    help='restart learning rate every <restart> epochs')
parser.add_argument('--resume_model',
                    type=str,
                    default=None,
                    help='path to load previously saved model')
args = parser.parse_args(argv)
print(args)

is_cuda = torch.cuda.is_available()

net = UNet3D(1, 1, use_bias=True, inplanes=16)
if args.resume_model is not None:
    transfer_weights(net, args.resume_model)
bce_crit = nn.BCELoss()
dice_crit = DiceLoss()
last_bce_loss = 0
last_dice_loss = 0


def criterion(pred, labels, weights=[0.1, 0.9]):
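    # weighted sum of the BCE and Dice terms; the individual values are stashed in globals for logging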
    _bce_loss = bce_crit(pred, labels)
    _dice_loss = dice_crit(pred, labels)
    global last_bce_loss, last_dice_loss
    last_bce_loss = _bce_loss.item()
    last_dice_loss = _dice_loss.item()
    return weights[0] * _bce_loss + weights[1] * _dice_loss


size = args.volume_size * 3 if len(args.volume_size) == 1 else args.volume_size
assert len(size) == 3
Example #27
def main():
    args = parser.parse_args()
    save_path = 'Trainid_' + args.id
    writer = SummaryWriter(log_dir='runs/' + args.tag + str(time.time()))
    if not os.path.isdir(save_path):
        os.mkdir(save_path)
        os.mkdir(save_path + '/Checkpoint')

    train_dataset_path = 'data/train'
    val_dataset_path = 'data/valid'
    train_transform = transforms.Compose([ToTensor()])
    val_transform = transforms.Compose([ToTensor()])
    train_dataset = TrainDataset(path=train_dataset_path,
                                 transform=train_transform)
    val_dataset = TrainDataset(path=val_dataset_path, transform=val_transform)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=4)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                pin_memory=True,
                                num_workers=4)

    size_train = len(train_dataloader)
    size_val = len(val_dataloader)
    print('Number of Training Images: {}'.format(size_train))
    print('Number of Validation Images: {}'.format(size_val))
    start_epoch = 0
    model = Res(n_ch=4, n_classes=9)
    class_weights = torch.Tensor([1, 1, 1, 1, 1, 1, 1, 1, 0]).cuda()
    criterion = DiceLoss()
    criterion1 = torch.nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    if args.gpu:
        model = model.cuda()
        criterion = criterion.cuda()
        criterion1 = criterion1.cuda()

    if args.resume is not None:
        weight_path = sorted(os.listdir(save_path + '/Checkpoint/'),
                             key=lambda x: float(x[:-8]))[0]
        checkpoint = torch.load(save_path + '/Checkpoint/' + weight_path)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('Loaded Checkpoint of Epoch: {}'.format(args.resume))

    for epoch in range(start_epoch, int(args.epoch) + start_epoch):
        adjust_learning_rate(optimizer, epoch)
        train(model, train_dataloader, criterion, criterion1, optimizer, epoch,
              writer, size_train)
        print('')
        val_loss = val(model, val_dataloader, criterion, criterion1, epoch,
                       writer, size_val)
        print('')
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            filename=save_path + '/Checkpoint/' + str(val_loss) + '.pth.tar')
    writer.export_scalars_to_json(save_path + '/log.json')
Example #28
def train_net(image_size, batch_size, num_epochs, lr, num_workers, checkpoint):
    train_loader, val_loader = data_loaders(image_size=(image_size,
                                                        image_size),
                                            batch_size=batch_size)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Unet().to(device)
    if checkpoint:
        model.load_state_dict(torch.load(checkpoint))

    criterion = DiceLoss().to(device)
    optimizer = Adam(model.parameters(), lr=lr)

    logging.info(f'Start training:\n'
                 f'Num epochs:               {num_epochs}\n'
                 f'Batch size:               {batch_size}\n'
                 f'Learning rate:            {lr}\n'
                 f'Num workers:              {num_workers}\n'
                 f'Scale image size:         {image_size}\n'
                 f'Device:                   {device}\n'
                 f'Checkpoint:               {checkpoint}\n')

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')  # best validation loss seen so far, tracked across epochs

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}: ')
        train_batch_losses = []
        val_batch_losses = []

        for x_train, y_train in tqdm(train_loader):
            x_train = x_train.to(device)
            y_train = y_train.to(device)
            y_pred = model(x_train)

            optimizer.zero_grad()
            loss = criterion(y_pred, y_train)
            train_batch_losses.append(loss.item())
            loss.backward()
            optimizer.step()

        train_losses.append(sum(train_batch_losses) / len(train_batch_losses))
        print(
            f'-----------------------Train loss: {train_losses[-1]} -------------------------------'
        )

        # validation pass: switch to eval mode and disable gradient tracking
        model.eval()
        with torch.no_grad():
            for x_val, y_val in tqdm(val_loader):
                x_val = x_val.to(device)
                y_val = y_val.to(device)
                y_pred = model(x_val)

                loss = criterion(y_pred, y_val)
                val_batch_losses.append(loss.item())
        model.train()

        val_losses.append(sum(val_batch_losses) / len(val_batch_losses))
        print(
            f'-----------------------Val loss: {val_losses[-1]} -------------------------------'
        )
        if val_losses[-1] < best_val_loss:
            best_val_loss = val_losses[-1]
            if not os.path.isdir('weights/'):
                os.mkdir('weights/')
            torch.save(model.state_dict(), f'weights/checkpoint{epoch+1}.pth')
            print(f'Save checkpoint in: weights/checkpoint{epoch+1}.pth')
Example #29
# optimizer = torch.optim.RMSprop(params, lr=config.lr, alpha = 0.95)
# optimizer = RAdam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)
# optimizer = PlainRAdam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)
if os.path.exists(config.init_optimizer):
    ckpt = torch.load(config.init_optimizer)
    optimizer.load_state_dict(ckpt['optimizer'])

# lr_scheduler
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.3)
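# cosine annealing over the whole run, entered through a gradual warmup phase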
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.num_epochs*len(data_loader))
lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=100, 
                                      total_epoch=min(1000, len(data_loader)-1), 
                                      after_scheduler=scheduler_cosine)

# loss function
criterion = DiceLoss()
# criterion = Weight_Soft_Dice_Loss(weight=[0.1, 0.9])
# criterion = BCELoss()
# criterion = MixedLoss(10.0, 2.0)
# criterion = Weight_BCELoss(weight_pos=0.25, weight_neg=0.75)
# criterion = Lovasz_Loss(margin=[1, 5])

print('start training...')
train_start = time.time()
for epoch in range(config.num_epochs):
    epoch_start = time.time()
    model_ft, optimizer = train_one_epoch(model_ft, data_loader, criterion, 
                                          optimizer, lr_scheduler=lr_scheduler, device=device, 
                                          epoch=epoch, vis=vis)
    do_valid(model_ft, dataloader_val, criterion, epoch, device, vis=vis)
    print('Epoch time: {:.3f} min\n'.format((time.time()-epoch_start)/60))
Example #30
    def __init__(self, args, logger):
        """Constructor for training algorithm.

        Args:
            args: From command line, picked up by `argparse`.

        Initializes:
            - Data: train, val and test.
            - Model: shared and controller.
            - Inference: optimizers for shared and controller parameters.
            - Criticism: cross-entropy loss for training the shared model.
        """
        self.args = args
        self.controller_step = 0
        self.cuda = args.cuda
        self.epoch = 0
        self.shared_step = 0
        self.start_epoch = 0
        self.logger = logger
        self.baseline = None
        """Load dataset"""
        self.load_dataset()
        if args.mode == 'train':
            self.train_data_loader.restart()

        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        self.build_model()

        if self.args.load_path:
            self.load_model()

        shared_optimizer = _get_optimizer(self.args.shared_optim)
        controller_optimizer = _get_optimizer(self.args.controller_optim)

        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
        )

        self.controller_optim = controller_optimizer(
            self.controller.parameters(), lr=self.args.controller_lr)

        self.ce = nn.CrossEntropyLoss()
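        # segmentation loss: MulticlassDiceLoss when requested, plain DiceLoss otherwise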
        if self.args.loss == 'MulticlassDiceLoss':
            self.model_loss = MulticlassDiceLoss()
        else:
            self.model_loss = DiceLoss()
        self.time = time.time()
        self.dag_file = open(
            self.args.model_dir + '/' + self.args.mode + '_dag.log', 'a')

        cnn_type_index = {}
        for i, action in enumerate(self.args.shared_cnn_types):
            cnn_type_index[action] = i
        if self.args.use_ref:
            self.ref_arch_num = []
            ip = []
            action = []
            for i, block in enumerate(self.args.ref_arch):
                ip.append(block[0])
                action.append(cnn_type_index[block[1]])

            for i in range(len(ip) // 2):
                self.ref_arch_num.append(
                    [ip[i], ip[i + 1], action[i], action[i + 1]])

            self.ref_arch_num = np.array(self.ref_arch_num)
            self.ref_arch_num = self.ref_arch_num.reshape(
                1, 2 * len(self.ref_arch_num))
            """