Example #1
import logging
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local modules assumed by this example: Dataset, DiceLoss, UNet,
# dsc, get_transforms, and an lr_scheduler module providing LR_Scheduler_Head.


class Train(object):
    def __init__(self, configs):
        # Numeric settings are stored as numbers, not strings, so they can be
        # used directly by range(), the optimizer, and the DataLoader.
        self.batch_size = configs.get("batch_size", 16)
        self.epochs = configs.get("epochs", 100)
        self.lr = configs.get("lr", 1e-4)

        device_args = configs.get("device", "cuda")
        self.device = torch.device(
            "cpu" if not torch.cuda.is_available() else device_args)

        self.workers = configs.get("workers", 4)

        self.vis_images = configs.get("vis_images", 200)
        self.vis_freq = configs.get("vis_freq", 10)

        self.weights = configs.get("weights", "./weights")
        if not os.path.exists(self.weights):
            os.mkdir(self.weights)

        self.logs = configs.get("logs", "./logs")
        if not os.path.exists(self.logs):
            os.mkdir(self.logs)

        self.images_path = configs.get("images_path", "./data")

        self.is_resize = configs.get("is_resize", False)
        self.image_short_side = configs.get("image_short_side", 256)

        self.is_padding = configs.get("is_padding", False)

        is_multi_gpu = configs.get("DataParallel", False)

        pre_train = configs.get("pre_train", False)
        model_path = configs.get("model_path", './weights/unet_idcard_adam.pth')

        # self.image_size = configs.get("image_size", "256")
        # self.aug_scale = configs.get("aug_scale", "0.05")
        # self.aug_angle = configs.get("aug_angle", "15")

        self.step = 0

        self.dsc_loss = DiceLoss()
        self.model = UNet(in_channels=Dataset.in_channels,
                          out_channels=Dataset.out_channels)
        if pre_train:
            self.model.load_state_dict(torch.load(model_path,
                                                  map_location=self.device),
                                       strict=False)

        if is_multi_gpu:
            self.model = nn.DataParallel(self.model)

        self.model.to(self.device)

        self.best_validation_dsc = 0.0

        self.loader_train, self.loader_valid = self.data_loaders()

        self.params = [p for p in self.model.parameters() if p.requires_grad]

        self.optimizer = optim.Adam(self.params,
                                    lr=self.lr,
                                    weight_decay=0.0005)
        # self.optimizer = torch.optim.SGD(self.params, lr=self.lr, momentum=0.9, weight_decay=0.0005)
        self.scheduler = lr_scheduler.LR_Scheduler_Head(
            'poly', self.lr, self.epochs, len(self.loader_train))

    def datasets(self):
        train_datasets = Dataset(
            images_dir=self.images_path,
            # image_size=self.image_size,
            subset="train",  # train
            transform=get_transforms(train=True),
            is_resize=self.is_resize,
            image_short_side=self.image_short_side,
            is_padding=self.is_padding)
        # valid_datasets = train_datasets

        valid_datasets = Dataset(
            images_dir=self.images_path,
            # image_size=self.image_size,
            subset="validation",  # validation
            transform=get_transforms(train=False),
            is_resize=self.is_resize,
            image_short_side=self.image_short_side,
            is_padding=False)
        return train_datasets, valid_datasets

    def data_loaders(self):
        dataset_train, dataset_valid = self.datasets()

        loader_train = DataLoader(
            dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.workers,
        )
        loader_valid = DataLoader(
            dataset_valid,
            batch_size=1,
            drop_last=False,
            num_workers=self.workers,
        )

        return loader_train, loader_valid

    @staticmethod
    def dsc_per_volume(validation_pred, validation_true):
        assert len(validation_pred) == len(validation_true)
        dsc_list = []
        for p in range(len(validation_pred)):
            y_pred = np.array([validation_pred[p]])
            y_true = np.array([validation_true[p]])
            dsc_list.append(dsc(y_pred, y_true))
        return dsc_list

    @staticmethod
    def get_logger(filename, verbosity=1, name=None):
        level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
        formatter = logging.Formatter(
            "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
        )
        logger = logging.getLogger(name)
        logger.setLevel(level_dict[verbosity])

        fh = logging.FileHandler(filename, "w")
        fh.setFormatter(formatter)
        logger.addHandler(fh)

        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)

        return logger

    def train_one_epoch(self, epoch):

        self.model.train()
        loss_train = []
        for i, data in enumerate(self.loader_train):
            self.scheduler(self.optimizer, i, epoch, self.best_validation_dsc)
            x, y_true = data
            x, y_true = x.to(self.device), y_true.to(self.device)

            y_pred = self.model(x)
            loss = self.dsc_loss(y_pred, y_true)

            loss_train.append(loss.item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # lr_scheduler.step()
            if self.step % 200 == 0:
                print('Epoch:[{}/{}]\t iter:[{}]\t loss={:.5f}\t '.format(
                    epoch, self.epochs, i, loss.item()))

            self.step += 1

    def eval_model(self, patience):
        self.model.eval()
        loss_valid = []

        validation_pred = []
        validation_true = []
        # early_stopping = EarlyStopping(patience=patience, verbose=True)

        for i, data in enumerate(self.loader_valid):
            x, y_true = data
            x, y_true = x.to(self.device), y_true.to(self.device)

            with torch.no_grad():
                y_pred = self.model(x)
                loss = self.dsc_loss(y_pred, y_true)

            # Binarize the prediction and write the first mask of the batch
            # to disk for quick visual inspection.
            mask = (y_pred > 0.5) * 255
            mask = mask.cpu().numpy()[0][0].astype(np.uint8)
            cv2.imwrite('result.png', mask)

            loss_valid.append(loss.item())

            y_pred_np = y_pred.detach().cpu().numpy()

            validation_pred.extend(
                [y_pred_np[s] for s in range(y_pred_np.shape[0])])
            y_true_np = y_true.detach().cpu().numpy()
            validation_true.extend(
                [y_true_np[s] for s in range(y_true_np.shape[0])])

        # early_stopping(loss_valid, self.model)
        # if early_stopping.early_stop:
        #     print('Early stopping')
        #     import sys
        #     sys.exit(1)
        mean_dsc = np.mean(
            self.dsc_per_volume(
                validation_pred,
                validation_true,
            ))
        if mean_dsc > self.best_validation_dsc:
            self.best_validation_dsc = mean_dsc
            torch.save(self.model.state_dict(),
                       os.path.join(self.weights, "unet_xia_adam.pth"))
            print("Best validation mean DSC: {:4f}".format(
                self.best_validation_dsc))

    def main(self):

        for epoch in tqdm(range(self.epochs), total=self.epochs):
            self.train_one_epoch(epoch)
            self.eval_model(patience=10)

        torch.save(self.model.state_dict(),
                   os.path.join(self.weights, "unet_final.pth"))
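
A minimal, hypothetical way to drive Example #1 (the keys mirror the configs.get() calls above; the paths and values are placeholders, not the original project's settings):

configs = {
    "batch_size": 8,
    "epochs": 50,
    "lr": 1e-4,
    "device": "cuda",
    "workers": 4,
    "weights": "./weights",
    "logs": "./logs",
    "images_path": "./data",
    "is_resize": True,
    "image_short_side": 256,
    "is_padding": False,
}
trainer = Train(configs)
trainer.main()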
Example #2
        model.cuda(cuda)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

#    loss_all = np.zeros((2000, 4))
    for epoch in range(2000):
        lr = get_learning_rate(epoch)
        for p in optimizer.param_groups:
            p['lr'] = lr
            print("learning rate = {}".format(p['lr']))
        for batch_idx, items in enumerate(train_dataloader):
            image = items[0]
            gt = items[1]
            model.train()

            image = image.float()
            gt = gt.float()
            if have_cuda:
                # The input must be moved to the GPU along with the target,
                # otherwise model(image) fails with a device mismatch.
                image = image.cuda(cuda)
                gt = gt.cuda(cuda)

            pred = model(image)

            loss = (pred-gt).abs().mean() + 5 * ((pred-gt)**2).mean()
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print('epoch:', epoch, 'loss:', loss.item())

        # if epoch % 10 == 5:
        save_pred(epoch, model, test_dataloader)
        torch.save(model.state_dict(), "out/sUNet_microtubule_" + str(epoch + 1) + ".pkl")
Example #3
            loss = criterion(y_pred, y.unsqueeze(0).float())

            l += loss.item()
            loss.backward()
            # Gradients accumulate across iterations; step and reset every 10.
            if i % 10 == 0:
                optimizer.step()
                optimizer.zero_grad()
                print('Stepped')

            print('{0:.4f}%\t\t{1:.6f}'.format(i / len(ids) * 100,
                                               loss.item()))

        l = l / len(ids)
        print('Loss : {}'.format(l))
        torch.save(net.state_dict(),
                   'MODEL_EPOCH{}_LOSS{}.pth'.format(epoch + 1, l))
        print('Saved')


try:
    # Resume from an interrupted run only if a checkpoint exists.
    if os.path.exists('MODEL_INTERRUPTED.pth'):
        net.load_state_dict(torch.load('MODEL_INTERRUPTED.pth'))
    train(net)

except KeyboardInterrupt:
    print('Interrupted')
    torch.save(net.state_dict(), 'MODEL_INTERRUPTED.pth')
    try:
        sys.exit(0)
    except SystemExit:
        os._exit(0)
Example #4
def train(cont=False):

    # for tensorboard tracking
    logger = get_logger()
    logger.info("(1) Initiating Training ... ")
    logger.info("Training on device: {}".format(device))
    writer = SummaryWriter()

    # init model
    aux_layers = None
    if net == "SETR-PUP":
        aux_layers, model = get_SETR_PUP()
    elif net == "SETR-MLA":
        aux_layers, model = get_SETR_MLA()
    elif net == "TransUNet-Base":
        model = get_TransUNet_base()
    elif net == "TransUNet-Large":
        model = get_TransUNet_large()
    elif net == "UNet":
        model = UNet(CLASS_NUM)

    # prepare dataset
    cluster_model = get_clustering_model(logger)
    train_dataset = CityscapeDataset(img_dir=data_dir,
                                     img_dim=IMG_DIM,
                                     mode="train",
                                     cluster_model=cluster_model)
    valid_dataset = CityscapeDataset(img_dir=data_dir,
                                     img_dim=IMG_DIM,
                                     mode="val",
                                     cluster_model=cluster_model)
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False)

    logger.info("(2) Dataset Initiated. ")

    # optimizer
    epochs = epoch_num if epoch_num > 0 else iteration_num // len(
        train_loader) + 1
    optim = SGD(model.parameters(),
                lr=lrate,
                momentum=momentum,
                weight_decay=wdecay)
    # optim = Adam(model.parameters(), lr=lrate)
    scheduler = lr_scheduler.MultiStepLR(
        optim, milestones=[int(epochs * fine_tune_ratio)], gamma=0.1)

    cur_epoch = 0
    best_loss = float('inf')
    epochs_since_improvement = 0

    # for continue training
    if cont:
        model, optim, cur_epoch, best_loss = load_ckpt_continue_training(
            best_ckpt_src, model, optim, logger)
        logger.info("Current best loss: {0}".format(best_loss))
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for i in range(cur_epoch):
                scheduler.step()
    else:
        model = nn.DataParallel(model)
        model = model.to(device)

    logger.info("(3) Model Initiated ... ")
    logger.info("Training model: {}".format(net) + ". Training Started.")

    # loss
    ce_loss = CrossEntropyLoss()
    if use_dice_loss:
        dice_loss = DiceLoss(CLASS_NUM)

    # loop over epochs
    iter_count = 0
    epoch_bar = tqdm.tqdm(total=epochs,
                          desc="Epoch",
                          position=cur_epoch,
                          leave=True)
    logger.info("Total epochs: {0}. Starting from epoch {1}.".format(
        epochs, cur_epoch + 1))

    for e in range(epochs - cur_epoch):
        epoch = e + cur_epoch

        # Training.
        model.train()
        trainLossMeter = LossMeter()
        train_batch_bar = tqdm.tqdm(total=len(train_loader),
                                    desc="TrainBatch",
                                    position=0,
                                    leave=True)

        for batch_num, (orig_img, mask_img) in enumerate(train_loader):
            orig_img, mask_img = orig_img.float().to(
                device), mask_img.float().to(device)

            if net == "TransUNet-Base" or net == "TransUNet-Large":
                pred = model(orig_img)
            elif net == "SETR-PUP" or net == "SETR-MLA":
                if aux_layers is not None:
                    pred, _ = model(orig_img)
                else:
                    pred = model(orig_img)
            elif net == "UNet":
                pred = model(orig_img)

            loss_ce = ce_loss(pred, mask_img.long())
            if use_dice_loss:
                loss_dice = dice_loss(pred, mask_img, softmax=True)
                loss = 0.5 * (loss_ce + loss_dice)
            else:
                loss = loss_ce

            # Backward Propagation, Update weight and metrics
            optim.zero_grad()
            loss.backward()
            optim.step()

            # update learning rate
            for param_group in optim.param_groups:
                orig_lr = param_group['lr']
                param_group['lr'] = orig_lr * (1.0 -
                                               iter_count / iteration_num)**0.9
            iter_count += 1

            # Update loss
            trainLossMeter.update(loss.item())

            # print status
            if (batch_num + 1) % print_freq == 0:
                status = 'Epoch: [{0}][{1}/{2}]\t' \
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch+1, batch_num+1, len(train_loader), loss=trainLossMeter)
                logger.info(status)

            # log loss to tensorboard
            if (batch_num + 1) % tensorboard_freq == 0:
                writer.add_scalar(
                    'Train_Loss_{0}'.format(tensorboard_freq),
                    trainLossMeter.avg,
                    epoch * (len(train_loader) / tensorboard_freq) +
                    (batch_num + 1) / tensorboard_freq)
            train_batch_bar.update(1)

        writer.add_scalar('Train_Loss_epoch', trainLossMeter.avg, epoch)

        # Validation.
        model.eval()
        validLossMeter = LossMeter()
        valid_batch_bar = tqdm.tqdm(total=len(valid_loader),
                                    desc="ValidBatch",
                                    position=0,
                                    leave=True)
        with torch.no_grad():
            for batch_num, (orig_img, mask_img) in enumerate(valid_loader):
                orig_img, mask_img = orig_img.float().to(
                    device), mask_img.float().to(device)

                if net == "TransUNet-Base" or net == "TransUNet-Large":
                    pred = model(orig_img)
                elif net == "SETR-PUP" or net == "SETR-MLA":
                    if aux_layers is not None:
                        pred, _ = model(orig_img)
                    else:
                        pred = model(orig_img)
                elif net == "UNet":
                    pred = model(orig_img)

                loss_ce = ce_loss(pred, mask_img.long())
                if use_dice_loss:
                    loss_dice = dice_loss(pred, mask_img, softmax=True)
                    loss = 0.5 * (loss_ce + loss_dice)
                else:
                    loss = loss_ce

                # Update loss
                validLossMeter.update(loss.item())

                # print status
                if (batch_num + 1) % print_freq == 0:
                    status = 'Validation: [{0}][{1}/{2}]\t' \
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch+1, batch_num+1, len(valid_loader), loss=validLossMeter)
                    logger.info(status)

                # log loss to tensorboard
                if (batch_num + 1) % tensorboard_freq == 0:
                    writer.add_scalar(
                        'Valid_Loss_{0}'.format(tensorboard_freq),
                        validLossMeter.avg,
                        epoch * (len(valid_loader) / tensorboard_freq) +
                        (batch_num + 1) / tensorboard_freq)
                valid_batch_bar.update(1)

        valid_loss = validLossMeter.avg
        writer.add_scalar('Valid_Loss_epoch', valid_loss, epoch)
        logger.info("Validation Loss of epoch [{0}/{1}]: {2}\n".format(
            epoch + 1, epochs, valid_loss))

        # update optim scheduler
        scheduler.step()

        # save checkpoint
        is_best = valid_loss < best_loss
        best_loss_tmp = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            logger.info("Epochs since last improvement: %d\n" %
                        (epochs_since_improvement, ))
            if epochs_since_improvement == early_stop_tolerance:
                break  # early stopping.
        else:
            epochs_since_improvement = 0
            state = {
                'epoch': epoch,
                'loss': best_loss_tmp,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optim.state_dict(),
            }
            torch.save(state, ckpt_src)
            logger.info("Checkpoint updated.")
            best_loss = best_loss_tmp
        epoch_bar.update(1)
    writer.close()
Example #5
            image = items['image_in']
            gt = items['groundtruth']

            model.train()

            image = np.swapaxes(image, 1, 3)
            image = np.swapaxes(image, 2, 3)
            image = image.float()
            image = image.cuda(cuda)

            gt = gt.squeeze()
            gt = gt.float()
            gt = gt.cuda(cuda)

            pred = model(image).squeeze()

            loss = (pred - gt).abs().mean() + 5 * ((pred - gt)**2).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print("[Epoch %d] [Batch %d/%d] [loss: %f]" %
                  (epoch, batch_idx, len(train_dataloader), loss.item()))

        torch.save(
            model.state_dict(),
            "/home/star/0_code_lhj/DL-SIM-github/Training_codes/UNet/UNet_SIM15_microtubule.pkl"
        )
Example #6
                                          str(iteration).zfill(6)),
                                      padding=2,
                                      nrow=4,
                                      normalize=True)
                    vutils.save_image(fake_B,
                                      os.path.join(
                                          samples_path, '%s_fake_B.jpg' %
                                          str(iteration).zfill(6)),
                                      padding=2,
                                      nrow=4,
                                      normalize=True)
                    save_loss_image(list_loss, samples_path)

                iteration += 1

        # -----------------------------------------------------------
        # Save model
        if ((epoch + 1) % 5 == 0) or (epoch == max_epochs - 1):
            print('save model')
            save_path = os.path.join(samples_path,
                                     'checkpoint_iteration_%d.tar' % iteration)
            torch.save(
                {
                    'epoch': epoch,
                    'iteration': iteration,
                    'net_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                    'list_loss': list_loss,
                }, save_path)
Example #7
            p['lr'] = lr
            print("learning rate = {}".format(p['lr']))
            
        for batch_idx, items in enumerate(train_dataloader):
            
            image = items['image_in']
            gt = items['groundtruth']
            
            model.train()
            
            image = np.swapaxes(image, 1, 3)
            image = np.swapaxes(image, 2, 3)
            image = image.float()
            image = image.cuda(cuda)

            gt = gt.squeeze()
            gt = gt.float()
            gt = gt.cuda(cuda)

            pred = model(image).squeeze()

            loss = (pred - gt).abs().mean() + 5 * ((pred - gt)**2).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print("[Epoch %d] [Batch %d/%d] [loss: %f]" %
                  (epoch, batch_idx, len(train_dataloader), loss.item()))

        torch.save(model.state_dict(), "/home/star/0_code_lhj/DL-SIM-github/Training_codes/UNet/UNet_SIM15_LowLight_microtubule.pkl")
Example #8
def train():
    print('Loading the dataset...')
    print('Training ' + model + ' on DRIVE patches')
    # ============ Load the data
    dataset = DRIVE_train_dataset(
        root=config.get('training settings', 'train_data_dir'),
        train_file=config.get('training settings', 'train_file'),
        flip=config.getboolean('training settings', 'flip'))
    print('all patches:', len(dataset))
    # ======== define the dataloader

    train_loader = data.DataLoader(dataset,
                                   batch_size=batch_size,
                                   pin_memory=True,
                                   shuffle=True)
    '''
    patches_imgs_test, patches_masks_test = get_data_testing(
        DRIVE_test_imgs_original=path_data + config.get('data paths', 'test_imgs_original'),  # original
        DRIVE_test_groudTruth=path_data + config.get('data paths', 'test_groundTruth'),  # masks
        Imgs_to_test=int(config.get('testing settings', 'full_images_to_test')),
        patch_height=int(config.get('data attributes', 'patch_height')),
        patch_width=int(config.get('data attributes', 'patch_width'))
    )
    visualize(group_images(patches_imgs_test[0:40, :, :, :], 5),
              './' + name_experiment + '/' + "sample_test_imgs")  # .show()
    visualize(group_images(patches_masks_test[0:40, :, :, :], 5),
              './' + name_experiment + '/' + "sample_test_masks")  # .show()
    test_data = data.TensorDataset(torch.tensor(patches_imgs_test),
                                   torch.tensor(patches_masks_test))
    test_loader = data.DataLoader(test_data, batch_size=30, pin_memory=True, shuffle=True)
    '''
    if model == 'UNet':
        ssd_net = UNet(n_channels=1, n_classes=2)
    elif model == 'UNet_side_loss':
        ssd_net = UNet_side_loss(n_channels=1, n_classes=2)
    elif model == 'UNet_level4_our':
        ssd_net = UNet_level4_our(n_channels=1, n_classes=2)
    elif model == 'UNet_cat':
        ssd_net = UNet_cat(n_channels=1, n_classes=2)
    else:
        ssd_net = UNet_multichannel(n_channels=1, n_classes=2)

    net = ssd_net
    #dummy_input = Variable(torch.rand(1, 1, 48, 48))
    #writer.add_graph(net,dummy_input)

    net = torch.nn.DataParallel(ssd_net)  # device_ids=[0, 1, 2]
    cudnn.benchmark = True
    net = net.cuda()
    '''
    if resume == True:
        ssd_net.load_state_dict(torch.load('./drive_train_with_64_adam_lr3/DRIVE_50epoch.pth'))
    
    if resume == False:
        print('Initializing weights...')

        ssd_net.inc.apply(weights_init)
        ssd_net.down1.apply(weights_init)
        ssd_net.down2.apply(weights_init)
        ssd_net.up1.apply(weights_init)
        ssd_net.up2.apply(weights_init)
        ssd_net.outc.apply(weights_init)
    '''
    #############################
    if optim_select == 'SGD':
        optimizer = optim.SGD(net.parameters(),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
    else:
        optimizer = optim.Adam(net.parameters(),
                               lr=learning_rate,
                               weight_decay=weight_decay)
    # criterion = FocalLoss(class_num=2, alpha=None, gamma=2, size_average=True)
    DRIVE_weight = torch.cuda.FloatTensor([1.0, vessel_weight])
    criterion = CrossEntropyLoss2D(weights=DRIVE_weight)
    # criterion=dilation_loss()
    #criterion=CrossEntropyLoss2D(weights=None)
    # criterion=thin_mid_vessel_loss(thin_weight=80,mid_weight=9)#18
    net.train()
    step_index = 0

    epoch_size = len(dataset) // batch_size  # iterations per epoch
    for epoch in range(start_iter, N_epochs + 1):
        train_loss = 0
        #acc = 0
        #precision = 0
        #recall = 0
        #spec = 0
        #dice_coef = 0
        #AUC = 0
        for i_batch, (images, targets) in enumerate(train_loader):
            iteration = epoch * epoch_size + i_batch
            # Author's empirical notes: around epochs 50-90 the eval metrics
            # barely move, so 1e-4 is too small, and the same holds for epochs
            # 90-150 with 1e-4/1e-5; without any LR adjustment training
            # oscillates, so the schedule decays at a few fixed epochs.
            if epoch in (3, 6, 9):
                step_index += 1
                adjust_learning_rate(optimizer, gamma, step_index, epoch)

            images = images.float().cuda()
            targets = targets.long().cuda()
            # dilation_mask=Variable(dilation_mask.long().cuda())

            # forward
            # t0 = time.time()
            out = net(images)
            # backprop
            optimizer.zero_grad()
            loss = criterion(
                out.permute(0, 2, 3, 1).contiguous().view(-1, 2),
                targets.view(-1))

            loss.backward()
            optimizer.step()
            # t1 = time.time()
            train_loss += loss.item()
            writer.add_scalar('train/batch_loss', loss.item(), iteration)
            ori = F.softmax(out, dim=1)
            #pre = torch.ge(ori.data[:, 1, :, :], 0.5)
            #acc_, precision_, recall_, spec_, _, dice_coef_ = evaluation(targets.data, pre)
            #AUC_ = computeAUC(targets.data,ori.data[:,1,:,:], pos_label=1)

            #acc += acc_
            #precision += precision_
            #recall += recall_
            #spec += spec_
            #dice_coef += dice_coef_
            #AUC += AUC_

        print('epoch ' + repr(epoch) + ' || Loss: %.4f ||' % loss.item())
        #if epoch % 10 == 0:
        print('Saving state, epoch:', epoch)
        torch.save(
            ssd_net.state_dict(),
            name_experiment + '/' + 'DRIVE_' + repr(epoch) + 'epoch.pth')
        '''
        if epoch % 30 == 0:
            visualize(group_images(images.data, 1), './' + name_experiment + '/' + "train_imgs_" + str(epoch))
            visualize(group_images(targets.data, 1), './' + name_experiment + '/' + "train_masks_" + str(epoch))
            visualize(group_images(torch.unsqueeze(ori.data[:, 1, :, :], 1), 1),
                      './' + name_experiment + '/' + "train_pred_" + str(epoch))
        '''
        writer.add_scalar('train/loss', train_loss / epoch_size, epoch)
        '''
        writer.add_scalar('train/acc', acc / epoch_size, epoch)
        writer.add_scalar('train/AUC', AUC / epoch_size, epoch)
        writer.add_scalar('train/precision', precision / epoch_size, epoch)
        writer.add_scalar('train/recall', recall / epoch_size, epoch)
        writer.add_scalar('train/specificity', spec / epoch_size, epoch)
        writer.add_scalar('train/dice_score', dice_coef / epoch_size, epoch)
        '''
Example #9
        optimizer.step()

        print('\rEpoch: ', epoch, 'Batch', n_batch, 'Error:', error.item())

        if (n_batch) % 1 == 0:
            test_images = generator(real_data)
            test_images = test_images.data

            grid_img = make_grid(test_images.cpu().detach(), nrow=3)
            # Display the color image
            # plt.imshow(grid_img.permute(1, 2, 0))
            save_image(real_data.cpu().detach(),
                       save_dir + str(epoch) + '_image_' + str(n_batch) +
                       '_real.png',
                       nrow=4)
            save_image(test_images.cpu().detach(),
                       save_dir + str(epoch) + '_image_' + str(n_batch) +
                       '_gen.png',
                       nrow=4)
            save_image(gt_data.cpu().detach(),
                       save_dir + str(epoch) + '_image_' + str(n_batch) +
                       '_gt.png',
                       nrow=4)

            # Save the latest models to resume training
            torch.save(generator.state_dict(),
                       os.path.join(model_path, 'model_gen_latest'))

    # Save every model of generator
    torch.save(generator.state_dict(),
               os.path.join(model_path, 'model_' + str(epoch)))
Example #10
class UNetObjPrior(nn.Module):
    """ 
    Wrapper around UNet that takes object priors (gaussians) and images 
    as input.
    """
    def __init__(self, params, depth=5):
        super(UNetObjPrior, self).__init__()
        self.in_channels = 4
        self.model = UNet(1, self.in_channels, depth, cuda=params['cuda'])
        self.params = params
        self.device = torch.device('cuda' if params['cuda'] else 'cpu')

    def forward(self, im, obj_prior):
        x = torch.cat((im, obj_prior), dim=1)
        return self.model(x)

    # Note: this method shadows nn.Module.train(); the wrapped UNet is still
    # switched between modes explicitly via self.model.train()/eval() below.
    def train(self, dataloader_train, dataloader_val):

        since = time.time()
        best_loss = float("inf")

        dataloader_train.mode = 'train'
        dataloader_val.mode = 'val'
        dataloaders = {'train': dataloader_train, 'val': dataloader_val}

        optimizer = optim.SGD(self.model.parameters(),
                              momentum=self.params['momentum'],
                              lr=self.params['lr'],
                              weight_decay=self.params['weight_decay'])

        train_logger = LossLogger('train', self.params['batch_size'],
                                  len(dataloader_train),
                                  self.params['out_dir'])

        val_logger = LossLogger('val', self.params['batch_size'],
                                len(dataloader_val), self.params['out_dir'])

        loggers = {'train': train_logger, 'val': val_logger}

        # self.criterion = WeightedMSE(dataloader_train.get_classes_weights(),
        #                              cuda=self.params['cuda'])
        self.criterion = nn.MSELoss()

        for epoch in range(self.params['num_epochs']):
            print('Epoch {}/{}'.format(epoch, self.params['num_epochs'] - 1))
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    #scheduler.step()
                    self.model.train()
                else:
                    self.model.eval()  # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data.
                samp = 1
                for i, data in enumerate(dataloaders[phase]):
                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        out = self.forward(data.image, data.obj_prior)
                        loss = self.criterion(out, data.truth)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    loggers[phase].update(epoch, samp, loss.item())

                    samp += 1

                loggers[phase].print_epoch(epoch)

                # Generate train prediction for check
                if phase == 'train':
                    path = os.path.join(self.params['out_dir'], 'previews',
                                        'epoch_{:04d}.jpg'.format(epoch))
                    data = dataloaders['val'].sample_uniform()
                    pred = self.forward(data.image, data.obj_prior)
                    im_ = data.image[0]
                    truth_ = data.truth[0]
                    pred_ = pred[0, ...]
                    utls.save_tensors(im_, pred_, truth_, path)

                if phase == 'val' and (loggers['val'].get_loss(epoch) <
                                       best_loss):
                    best_loss = loggers['val'].get_loss(epoch)

                loggers[phase].save('log_{}.csv'.format(phase))

                # save checkpoint
                if phase == 'val':
                    is_best = loggers['val'].get_loss(epoch) <= best_loss
                    path = os.path.join(self.params['out_dir'],
                                        'checkpoint.pth.tar')
                    utls.save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'state_dict': self.model.state_dict(),
                            'best_loss': best_loss,
                            'optimizer': optimizer.state_dict()
                        },
                        is_best,
                        path=path)

    def load_checkpoint(self, path, device='gpu'):

        if (device != 'gpu'):
            checkpoint = torch.load(path,
                                    map_location=lambda storage, loc: storage)
        else:
            checkpoint = torch.load(path)

        self.model.load_state_dict(checkpoint['state_dict'])
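
A hypothetical instantiation of the UNetObjPrior wrapper above (the dict keys follow the params lookups in the class; the values are placeholders, not the original project's configuration):

params = {
    'cuda': torch.cuda.is_available(),
    'lr': 1e-3,
    'momentum': 0.9,
    'weight_decay': 5e-4,
    'batch_size': 4,
    'num_epochs': 50,
    'out_dir': './runs/obj_prior',
}
net = UNetObjPrior(params)
# net.train(dataloader_train, dataloader_val) would then run the
# train/val loop defined above.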
Example #11
        return image, label

    def __len__(self):
        # get the size of data set
        return len(self.imgs_path)


if __name__ == "__main__":
    dataset = DataLoader("data/train10/")
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=1,
                                               shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = UNet(n_channels=1, n_classes=1)
    net.to(device=device)

    # These names were undefined in the original loop: an optimizer and the
    # running best loss used for checkpointing. RMSprop is one common choice
    # for this tutorial-style UNet; treat it as an assumption.
    optimizer = torch.optim.RMSprop(net.parameters(), lr=1e-5,
                                    weight_decay=1e-8, momentum=0.9)
    best_loss = float('inf')

    net.train()
    for image, label in train_loader:
        image = image.to(device=device, dtype=torch.float32)
        label = label.to(device=device, dtype=torch.float32)
        optimizer.zero_grad()
        pred = net(image)
        # With n_classes=1, softmax cross-entropy is degenerate (always 0);
        # a binary loss on the single-channel logits is one reasonable fix.
        loss = F.binary_cross_entropy_with_logits(pred, label)
        print('Loss/train', loss.item())
        if loss < best_loss:
            best_loss = loss
            torch.save(net.state_dict(), 'best_model.pth')
        loss.backward()
        optimizer.step()

        print(pred.shape, image.shape, label.shape)
Example #12
    col_heads = [
        'MSE', 'SSIM', 'PSNR', 'Train loss', 'Validation loss', 'LearnRate'
    ]
    #col_heads = ['Train loss', 'Validation loss', 'LearnRate']
    metrics_assess = pd.DataFrame(list(
        zip(valid_mse, valid_ssim, valid_psnr, train_loss, valid_loss,
            learn_rate)),
                                  columns=col_heads)
    #metrics_assess = pd.DataFrame(list(zip(train_loss, valid_loss, learn_rate)), columns=col_heads)
    metrics_assess.to_csv('metrics_2d' + '_' + str(Nthe) + '_' + str(Nphi) +
                          '.csv')
    if save_model_per_interval:
        if epoch % interval == 0:
            print('=> Saving Checkpoint')
            torch.save(
                model.state_dict(), model_loc + 'Model_Final' + '_' +
                str(epoch) + '_' + str(Nthe) + '_' + str(Nphi) + ".pkl")
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': tra_loss,
                    'learn_rate': lr
                }, model_loc + str(epoch) + ".pt")
torch.save(
    model.state_dict(), model_loc + 'Model_Final' + '_' + str(epoch) + '_' +
    str(Nthe) + '_' + str(Nphi) + ".pkl")
torch.save(
    {
        'epoch': epoch,
Example #13
        for batch_idx, items in enumerate(train_dataloader):

            image = items['image_in']
            gt = items['groundtruth']

            model.train()

            image = np.swapaxes(image, 1, 3)
            image = np.swapaxes(image, 2, 3)
            image = image.float()
            image = image.cuda(cuda)

            gt = gt.squeeze()
            gt = gt.float()
            gt = gt.cuda(cuda)

            pred = model(image).squeeze()

            loss = (pred - gt).abs().mean() + 5 * ((pred - gt)**2).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print("[Epoch %d] [Batch %d/%d] [loss: %f]" %
                  (epoch, batch_idx, len(train_dataloader), loss.item()))

        torch.save(model.state_dict(),
                   "save/sUNet_microtubule_" + str(epoch + 1) + ".pkl")
Example #14
			# 2. Train Generator: generate fake data and train G on it
			fake_data = generator(real_data).to(device)

			# # Threshold the fake data
			# zeros=torch.zeros(fake_data.shape)
			# ones=torch.ones(fake_data.shape)

			# fake_data=torch.where(fake_data<0.6,zeros,ones).to(device) 
			# # print(fake_data)
			# Log batch error
			g_error = train_generator(fake_data, gt_data)
		
		print('\rEpoch: ',epoch,'Batch',n_batch,'Gen Loss:',g_error.item(),'Dis Loss:',d_error.item())
		if (n_batch) % 50 == 0: 
			test_images = generator(real_data)
			test_images = test_images.data            
			
			grid_img = make_grid(test_images.cpu().detach(), nrow=3)
					  # Display the color image
			# plt.imshow(grid_img.permute(1, 2, 0))
			save_image(real_data.cpu().detach(),save_dir+str(epoch)+'_image_'+str(n_batch)+'_real.png', nrow=4)
			save_image(test_images.cpu().detach(),save_dir+str(epoch)+'_image_'+str(n_batch)+'_gen.png', nrow=4)
			save_image(gt_data.cpu().detach(),save_dir+str(epoch)+'_image_'+str(n_batch)+'_gt.png', nrow=4)

			# Save the latest models to resume training
			torch.save(generator.state_dict(), os.path.join(model_path,'model_gen_latest'))
			torch.save(discriminator_g.state_dict(), os.path.join(model_path,'model_gdis_latest'))
			torch.save(discriminator_l.state_dict(), os.path.join(model_path,'model_ldis_latest'))

	# Save every model of generator		
	torch.save(generator.state_dict(), os.path.join(model_path,'model_'+str(epoch)))
Example #15
    print('epoch: ', epoch, 'training loss: ', tra_loss, 'LR:', lr)
    epoch_list = range(epochs)
    col_heads = [
        'Train loss', 'Valid loss', 'MSE-1', 'SSIM-1', 'PSNR-1', 'MSE-2',
        'SSIM-2', 'PSNR-2', 'MSE-3', 'SSIM-3', 'PSNR-3', 'LearnRate'
    ]
    metrics_assess = pd.DataFrame(list(
        zip(train_loss, val_metrics['loss'], val_metrics['mse1'],
            val_metrics['ssim1'], val_metrics['psnr1'], val_metrics['mse2'],
            val_metrics['ssim2'], val_metrics['psnr2'], val_metrics['mse3'],
            val_metrics['ssim3'], val_metrics['psnr3'], learn_rate)),
                                  columns=col_heads)
    metrics_assess.to_csv('metrics_3d' + '_' + str(Nthe) + '_' + str(Nphi) +
                          '.csv')
    if save_model_per_epoch:
        torch.save(model.state_dict(), model_loc + str(epoch + 1) + ".pkl")
    if save_model_per_interval:
        if epoch % interval == 0:
            print('=> Saving Checkpoint')
            torch.save(
                model.state_dict(), model_loc + 'Model_Final' + '_' +
                str(epoch) + '_' + str(Nthe) + '_' + str(Nphi) + ".pkl")
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': tra_loss,
                    'learn_rate': lr
                }, model_loc + str(epoch) + ".pt")
torch.save(
Example #16
            print("learning rate = {}".format(p['lr']))
            
        for batch_idx, items in enumerate(train_dataloader):
            
            image = items['image_in']
            gt = items['groundtruth']
            
            model.train()
            
            image = np.swapaxes(image, 1, 3)
            image = np.swapaxes(image, 2, 3)
            image = image.float()
            image = image.cuda(cuda)

            gt = gt.squeeze()
            gt = gt.float()
            gt = gt.cuda(cuda)

            pred = model(image).squeeze()

            loss = (pred - gt).abs().mean() + 5 * ((pred - gt)**2).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print("[Epoch %d] [Batch %d/%d] [loss: %f]" %
                  (epoch, batch_idx, len(train_dataloader), loss.item()))

        if epoch % 50 == 49:
            torch.save(model.state_dict(), "/home/star/0_code_lhj/DL-SIM-github/Training_codes/UNet/UNet_SRRF_microtubule_"+str(epoch+1)+".pkl")
Example #17
            global_step += 1
            # write images to tensorboard every 10 batches
            # if global_step % (len(dataset) // (10 * batch_size)) == 0:
            #     writer.add_images('images', imgs, global_step)
            #     if net.n_classes == 1:
            #         writer.add_images('masks/true', true_masks, global_step)
            #         writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

        epoch_loss_avg = epoch_loss / num_batches_per_epoch
        epoch_loss_list.append(epoch_loss_avg)
        val_score = evaluate(net, val_loader, device, n_val)  # evaluation
        val_score_list.append(val_score)
        print(
            'epoch: {:2d} \t training loss(cross_entropy): {:5f} \t validation score(dice coeff): {:5f}'
            .format(epoch, epoch_loss_avg, val_score))
        # writer.add_scalar('train_loss', epoch_loss_avg, epoch)
        # writer.add_scalar('val_score', val_score, epoch)

        # save checkpoints and history train_loss/val_score every 2 epochs
        if epoch % 2 == 0:
            torch.save(
                {
                    'state_dict': net.state_dict(),
                    'loss_list': epoch_loss_list,
                    'val_score_list': val_score_list
                }, checkpoint_dir + '/' + f'unet_ckpt_{epoch}.pth')
            epoch_loss_list = []
            val_score_list = []

    # writer.close()