Example #1
X_train = np.load(TRAIN_DATA_PATH / 'X_train.npy')
Y_train = np.load(TRAIN_DATA_PATH / 'Y_train.npy')

np.random.seed(SEED)
X_train = np.random.permutation(X_train)
np.random.seed(SEED)
Y_train = np.random.permutation(Y_train)

m = X_train.shape[0]
m_val = VAL_BATCH_SIZE
X_val = X_train[:m_val]
Y_val = Y_train[:m_val]
X_train = X_train[m_val:]
Y_train = Y_train[m_val:]
print('Beginning training...')
model.train(X_train,
            Y_train,
            X_val,
            Y_val,
            max_epochs=100,
            batch_size=32,
            learning_rate_init=1e-4,
            reg_param=0,
            learning_rate_decay_type='constant',
            learning_rate_decay_parameter=1,
            early_stopping=True,
            save_path='./models/0/UNet0',
            reset_parameters=True,
            check_val_every_n_batches=100,
            seed=SEED,
            data_on_GPU=False)
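
The paired shuffle above works because reseeding the global RNG makes both permutation calls reorder X_train and Y_train identically. As a sketch only, a hypothetical helper (not part of the original code) that draws a single shared index permutation does the same thing more explicitly:

import numpy as np

def shuffle_together(X, Y, seed):
    # One shared index permutation keeps sample i in X aligned with label i in Y.
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    return X[idx], Y[idx]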
Example #2
transform = transforms.Compose(
    [transforms.CenterCrop(256),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

args = parser.parse_args()

VAR = args.var
DATA_DIR = args.data_dir
CHECKPOINT = args.checkpoint

testset = CustomImageDataset(DATA_DIR, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4)
dataiter = iter(testloader)
checkpoint = torch.load(CHECKPOINT, map_location=torch.device('cpu'))


model_test = UNet(in_channels=3, out_channels=3).double()
model_test.load_state_dict(checkpoint['model_state_dict'])
model_test = model_test.cpu()
model_test.eval()  # inference mode: freeze BatchNorm/Dropout behaviour

noisy = NoisyDataset(var=VAR)

images, _ = next(dataiter)
noisy_images = noisy(images)
# Displaying the Noisy Images
imshow(torchvision.utils.make_grid(noisy_images.cpu()))
# Displaying the Denoised Images (no gradients needed; cast inputs to double to match the model)
with torch.no_grad():
    denoised_images = model_test(noisy_images.cpu().double())
imshow(torchvision.utils.make_grid(denoised_images))
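
imshow is assumed to be defined elsewhere in this script; a plausible sketch, assuming the tensors were normalized to roughly [-1, 1] by the Normalize transform above, is:

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    # Undo the (0.5, 0.5) normalization and put channels last for matplotlib.
    img = img / 2 + 0.5
    npimg = img.detach().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()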
Example #3
def train(train_loader,
          valid_loader,
          loss_type,
          act_type,
          tolerance,
          result_path,
          log_interval=10,
          lr=0.000001,
          max_epochs=500):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # model = UNet(upsample_mode='transpose').to(device)
    model = UNet(upsample_mode='bilinear').to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    best_model_path = result_path + '/best_model.pth'
    model_path = result_path + '/model_epoch'

    train_batch_loss_file = open(result_path + "/train_batch_loss.txt", "w")
    valid_batch_loss_file = open(result_path + "/valid_batch_loss.txt", "w")

    train_all_epochs_loss_file = open(
        result_path + "/train_all_epochs_loss.txt", "w")
    train_all_epochs_loss = []

    valid_all_epochs_loss_file = open(
        result_path + "/valid_all_epochs_loss.txt", "w")
    valid_all_epochs_loss = []

    minimum_loss = np.inf
    finish = False

    for epoch in range(1, max_epochs + 1):
        for phase in ['train', 'val']:
            if phase == 'train':
                idx = list(range(0, len(train_loader)))
                train_smpl = random.sample(idx, 1)
                #train_smpl.append(len(train_loader)-1)
                loader = train_loader
                model.train()
            elif phase == 'val':
                idx = list(range(0, len(valid_loader)))
                val_smpl = random.sample(idx, 1)
                #val_smpl.append(len(valid_loader) - 1)
                loader = valid_loader
                model.eval()

            all_batches_losses = []

            for batch_i, sample in enumerate(loader):
                data, target, loss_weight = sample['image'], sample[
                    'image_anno'], sample['loss_weight']  #/1000
                data, target, loss_weight = data.to(device), target.to(
                    device), loss_weight.to(device)

                optimizer.zero_grad()
                loss_weight = loss_weight / 1000

                with torch.set_grad_enabled(phase == 'train'):
                    output = model(data)
                    # Set activation type (modules moved to the same device as the model):
                    if act_type == 'sigmoid':
                        activation = torch.nn.Sigmoid().to(device)
                    elif act_type == 'tanh':
                        activation = torch.nn.Tanh().to(device)
                    elif act_type == 'soft':
                        activation = torch.nn.Softmax(dim=1).to(device)
                    else:
                        raise ValueError('Unknown act_type: {}'.format(act_type))

                    # Calculate loss:
                    if loss_type == 'wbce':
                        # Weighted BCE with averaging:
                        criterion = torch.nn.BCELoss(weight=loss_weight).to(device)
                        loss = criterion(activation(output), target)
                    elif loss_type == 'bce':
                        # BCE with averaging:
                        criterion = torch.nn.BCELoss().to(device)
                        loss = criterion(activation(output), target)
                    elif loss_type == 'mse':
                        # MSE:
                        loss = F.mse_loss(output, target)
                    else:  # loss_type == 'jac':
                        loss = jaccard_loss(activation(output), target)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                if phase == 'train':
                    train_batch_loss_file.write(str(loss.item()) + "\n")
                    train_batch_loss_file.flush()
                else:
                    valid_batch_loss_file.write(str(loss.item()) + "\n")
                    valid_batch_loss_file.flush()

                all_batches_losses.append(loss.item())

                if batch_i % log_interval == 0:
                    print(
                        '{} Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                            phase, epoch, batch_i * len(data),
                            len(loader.dataset), 100. * batch_i / len(loader),
                            loss.item()))

                if phase == 'train' and batch_i in train_smpl:
                    post_transform = transforms.Compose(
                        [Binarize_Output(threshold=output.mean())])
                    thres = post_transform(output)
                    post_transform_weight = transforms.Compose(
                        [Binarize_Output(threshold=loss_weight.mean())])
                    weight_tresh = post_transform_weight(loss_weight)
                    utils.save_image(
                        data, "{}/train_input_{}_{}.png".format(
                            result_path, epoch, batch_i))
                    utils.save_image(
                        target, "{}/train_target_{}_{}.png".format(
                            result_path, epoch, batch_i))
                    utils.save_image(
                        output, "{}/train_output_{}_{}.png".format(
                            result_path, epoch, batch_i))
                    utils.save_image(
                        thres, "{}/train_thres_{}_{}.png".format(
                            result_path, epoch, batch_i))
                    utils.save_image(
                        weight_tresh, "{}/train_weights_{}_{}.png".format(
                            result_path, epoch, batch_i))

                    if epoch % 25 == 0:
                        torch.save(model.state_dict(),
                                   model_path + '_{}.pth'.format(epoch))

                if phase == 'val' and batch_i in val_smpl:
                    post_transform = transforms.Compose(
                        [Binarize_Output(threshold=output.mean())])
                    thres = post_transform(output)

                    post_transform_weight = transforms.Compose(
                        [Binarize_Output(threshold=loss_weight.mean())])
                    weight_tresh = post_transform_weight(loss_weight)

                    utils.save_image(
                        data, "{}/valid_input_{}_{}.png".format(
                            result_path, epoch, batch_i))
                    utils.save_image(
                        target, "{}/valid_target_{}_{}.png".format(
                            result_path, epoch, batch_i))
                    utils.save_image(
                        output, "{}/valid_output_{}_{}.png".format(
                            result_path, epoch, batch_i))
                    utils.save_image(
                        thres, "{}/valid_thres_{}_{}.png".format(
                            result_path, epoch, batch_i))
                    utils.save_image(
                        weight_tresh, "{}/valid_weights_{}_{}.png".format(
                            result_path, epoch, batch_i))

            if phase == 'train':
                train_last_avg_loss = np.mean(all_batches_losses)
                print("------average %s loss %f" %
                      (phase, train_last_avg_loss))
                train_all_epochs_loss_file.write(
                    str(train_last_avg_loss) + "\n")
                train_all_epochs_loss_file.flush()
            if phase == 'val':
                valid_last_avg_loss = np.mean(all_batches_losses)
                print("------average %s loss %f" %
                      (phase, valid_last_avg_loss))
                valid_all_epochs_loss_file.write(
                    str(valid_last_avg_loss) + "\n")
                valid_all_epochs_loss_file.flush()
                valid_all_epochs_loss.append(valid_last_avg_loss)
                if valid_last_avg_loss < minimum_loss:
                    minimum_loss = valid_last_avg_loss
                    #--------------------- Saving the best found model -----------------------
                    torch.save(model.state_dict(), best_model_path)
                    print("Minimum Average Loss so far:", minimum_loss)
                if early_stopping(epoch, valid_all_epochs_loss, tolerance):
                    finish = True
                    break

        if finish:
            break
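
early_stopping and jaccard_loss are project helpers that are not shown here. A minimal sketch of a patience-style early_stopping consistent with how it is called above (epoch, list of per-epoch validation losses, tolerance) might be:

def early_stopping(epoch, valid_all_epochs_loss, tolerance):
    # Hypothetical patience rule: stop once the validation loss has failed to
    # improve on its best value for `tolerance` consecutive epochs.
    if epoch <= tolerance:
        return False
    best = min(valid_all_epochs_loss)
    return all(loss > best for loss in valid_all_epochs_loss[-tolerance:])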
Example #4
np.random.seed(SEED)
X_train = np.random.permutation(X_train)
np.random.seed(SEED)
Y_train = np.random.permutation(Y_train)

m = X_train.shape[0]
m_val = VAL_BATCH_SIZE
X_val = X_train[:m_val]
Y_val = Y_train[:m_val]
X_train = X_train[m_val:]
Y_train = Y_train[m_val:]
print('Beginning training...')
model.train(X_train,
            Y_train,
            X_val,
            Y_val,
            max_epochs=int(1e9),
            batch_size=16,
            learning_rate_init=2e-3,
            reg_param=0,
            learning_rate_decay_type='inverse',
            learning_rate_decay_parameter=10,
            keep_prob=[0.7, 0.8],
            early_stopping=True,
            save_path=MODEL_PATH,
            reset_parameters=True,
            val_checks_per_epoch=10,
            seed=SEED,
            data_on_GPU=False)
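
This model.train API is project-specific. Assuming 'inverse' decay follows the common inverse-time schedule, the effective learning rate per epoch would look roughly like the sketch below; the exact formula used by the project's implementation may differ.

def inverse_decay_lr(lr_init, decay_parameter, epoch):
    # Assumed inverse-time decay: the rate falls off as 1 / (1 + epoch / decay_parameter).
    return lr_init / (1.0 + epoch / decay_parameter)

# e.g. with learning_rate_init=2e-3 and learning_rate_decay_parameter=10,
# epoch 10 would use about 1e-3 under this assumed schedule.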
Example #5
trainset = TrainSet(IMG_ROOT, LABEL_ROOT, data_type)
trainloader = DataLoader(trainset, BATCH_SIZE, shuffle=True)
print('loader done')

# Defining the model and optimization method
device = 'cuda:0'
#device = 'cpu'
unet = UNet(in_channel=3, class_num=2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(unet.parameters(), lr=0.0005, amsgrad=True)

epochs = 1
lsize = len(trainloader)
itr = 0
p_itr = 10  # print every N iteration
unet.train()
tloss = 0
loss_history = []

for epoch in range(epochs):
    with tqdm(total=lsize) as pbar:
        for x, y, path in trainloader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = unet(x)
            # CrossEntropyLoss expects class-index targets, so drop the channel dim
            loss = criterion(output, y[:, 0, :, :])
            loss.backward()
            optimizer.step()
            tloss += loss.item()
            loss_history.append(loss.item())
            itr += 1
            if itr % p_itr == 0:
                print('[{}/{}] iteration {} - mean loss: {:.4f}'.format(
                    epoch + 1, epochs, itr, tloss / p_itr))
                tloss = 0
            pbar.update(1)
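
The loop above collects loss_history but the snippet ends before reporting or saving anything. A short follow-up, assuming matplotlib is available (file names below are placeholders), could be:

import matplotlib.pyplot as plt

plt.plot(loss_history)
plt.xlabel('iteration')
plt.ylabel('loss')
plt.savefig('loss_history.png')
plt.close()
torch.save(unet.state_dict(), 'unet_checkpoint.pth')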
Example #6
def train(args, Dataset):
    ####################################### Initializing Model #######################################
    step = args.lr
    #experiment_dir = args['--experiment_dir']
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device:{}".format(device))
    print_every = int(args.print_every)
    num_epochs = int(args.num_epochs)
    save_every = int(args.save_every)
    save_path = str(args.model_save_path)
    batch_size = int(args.batch_size)
    #train_data_path = str(args['--data_path'])
    in_ch = int(args.in_ch)
    val_split = args.val_split
    img_directory = args.image_directory
    #model = MW_Unet(in_ch=in_ch)
    model = UNet(in_ch=in_ch)
    #model = model
    model.to(device)
    model.apply(init_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=step)

    #criterion = nn.MSELoss()
    criterion = torch.nn.L1Loss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    ######################################### Loading Data ##########################################

    dataset_total = Dataset
    dataset_size = len(dataset_total)
    indices = list(range(dataset_size))
    split = int(np.floor(val_split * dataset_size))
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    #train_indices, val_indices = indices[:1], indices[1:2]
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    dataloader_train = torch.utils.data.DataLoader(dataset_total,
                                                   batch_size=batch_size,
                                                   sampler=train_sampler,
                                                   num_workers=8)
    dataloader_val = torch.utils.data.DataLoader(dataset_total,
                                                 batch_size=batch_size,
                                                 sampler=valid_sampler,
                                                 num_workers=2)

    print("length of train set: ", len(train_indices))
    print("length of val set: ", len(val_indices))

    #best_val_PSNR = 0.0
    best_val_MSE, best_val_PSNR, best_val_SSIM = np.inf, -1, -1

    train_PSNRs = []
    train_losses = []
    train_SSIMs = []
    train_MSEs = []

    val_PSNRs = []
    val_losses = []
    val_SSIMs = []
    val_MSEs = []

    try:
        for epoch in range(1, num_epochs + 1):
            # INITIATE dataloader_train
            print("epoch: ", epoch)
            with tqdm(total=len(dataloader_train)) as pbar:
                for index, sample in enumerate(dataloader_train):
                    model.train()

                    target, model_input, features = sample['target'], sample[
                        'input'], sample['features']
                    N, P, C, H, W = model_input.shape
                    N, P, C_feat, H, W = features.shape
                    model_input = torch.reshape(model_input, (-1, C, H, W))
                    features = torch.reshape(features, (-1, C_feat, H, W))
                    albedo = features[:, 3:, :, :]
                    albedo = albedo.to(device)
                    eps = torch.tensor(1e-2)
                    eps = eps.to(device)
                    model_input = model_input.to(device)
                    model_input /= (albedo + eps)
                    target = torch.reshape(target, (-1, C, H, W))
                    features = features.to(device)
                    model_input = torch.cat((model_input, features), dim=1)
                    target = target.to(device)
                    model_input = model_input.to(device)

                    #print(model_input.dtype)
                    #print(model_input.shape)
                    # print(index)

                    output = model(model_input)
                    output *= (albedo + eps)

                    train_loss = utils.backprop(optimizer, output, target,
                                                criterion)
                    train_PSNR = utils.get_PSNR(output, target)
                    train_MSE = utils.get_MSE(output, target)
                    train_SSIM = utils.get_SSIM(output, target)

                    avg_val_PSNR = []
                    avg_val_loss = []
                    avg_val_MSE = []
                    avg_val_SSIM = []
                    model.eval()
                    #output_val = 0;

                    train_losses.append(train_loss.cpu().detach().numpy())
                    train_PSNRs.append(train_PSNR)
                    train_MSEs.append(train_MSE)
                    train_SSIMs.append(train_SSIM)

                    if index == len(dataloader_train) - 1:
                        with torch.no_grad():
                            for val_index, val_sample in enumerate(
                                    dataloader_val):
                                target_val, model_input_val, features_val = val_sample[
                                    'target'], val_sample['input'], val_sample[
                                        'features']
                                N, P, C, H, W = model_input_val.shape
                                N, P, C_feat, H, W = features_val.shape
                                model_input_val = torch.reshape(
                                    model_input_val, (-1, C, H, W))
                                features_val = torch.reshape(
                                    features_val, (-1, C_feat, H, W))
                                albedo = features_val[:, 3:, :, :]
                                albedo = albedo.to(device)
                                eps = torch.tensor(1e-2)
                                eps = eps.to(device)
                                model_input_val = model_input_val.to(device)
                                model_input_val /= (albedo + eps)
                                target_val = torch.reshape(
                                    target_val, (-1, C, H, W))
                                features_val = features_val.to(device)
                                model_input_val = torch.cat(
                                    (model_input_val, features_val), dim=1)
                                target_val = target_val.to(device)
                                model_input_val = model_input_val.to(device)
                                output_val = model(model_input_val)
                                output_val *= (albedo + eps)
                                loss_fn = criterion
                                loss_val = loss_fn(output_val, target_val)
                                PSNR = utils.get_PSNR(output_val, target_val)
                                MSE = utils.get_MSE(output_val, target_val)
                                SSIM = utils.get_SSIM(output_val, target_val)
                                avg_val_PSNR.append(PSNR)
                                avg_val_loss.append(
                                    loss_val.cpu().detach().numpy())
                                avg_val_MSE.append(MSE)
                                avg_val_SSIM.append(SSIM)

                        avg_val_PSNR = np.mean(avg_val_PSNR)
                        avg_val_loss = np.mean(avg_val_loss)
                        avg_val_MSE = np.mean(avg_val_MSE)
                        avg_val_SSIM = np.mean(avg_val_SSIM)

                        val_PSNRs.append(avg_val_PSNR)
                        val_losses.append(avg_val_loss)
                        val_MSEs.append(avg_val_MSE)
                        val_SSIMs.append(avg_val_SSIM)
                        scheduler.step(avg_val_loss)

                        img_grid = output.data[:9]
                        img_grid = torchvision.utils.make_grid(img_grid)
                        real_grid = target.data[:9]
                        real_grid = torchvision.utils.make_grid(real_grid)
                        input_grid = model_input.data[:9, :3, :, :]
                        input_grid = torchvision.utils.make_grid(input_grid)
                        val_grid = output_val.data[:9]
                        val_grid = torchvision.utils.make_grid(val_grid)
                        #save_image(input_grid, '{}train_input_img.png'.format(img_directory))
                        #save_image(img_grid, '{}train_img_{}.png'.format(img_directory, epoch))
                        #save_image(real_grid, '{}train_real_img_{}.png'.format(img_directory, epoch))
                        #print('train images')
                        fig, ax = plt.subplots(4)
                        fig.subplots_adjust(hspace=0.5)
                        ax[0].set_title('target')
                        ax[0].imshow(real_grid.cpu().numpy().transpose(
                            (1, 2, 0)))
                        ax[1].set_title('input')
                        ax[1].imshow(input_grid.cpu().numpy().transpose(
                            (1, 2, 0)))
                        ax[2].set_title('output_train')
                        ax[2].imshow(img_grid.cpu().numpy().transpose(
                            (1, 2, 0)))
                        ax[3].set_title('output_val')
                        ax[3].imshow(val_grid.cpu().numpy().transpose(
                            (1, 2, 0)))
                        #plt.show()
                        plt.savefig('{}train_output_target_img_{}.png'.format(
                            img_directory, epoch))
                        plt.close()

                    pbar.update(1)
            if epoch % print_every == 0:
                print(
                    "Epoch: {}, Loss: {}, Train MSE: {} Train PSNR: {}, Train SSIM: {}"
                    .format(epoch, train_loss, train_MSE, train_PSNR,
                            train_SSIM))
                print(
                    "Epoch: {}, Avg Val Loss: {}, Avg Val MSE: {}, Avg Val PSNR: {}, Avg Val SSIM: {}"
                    .format(epoch, avg_val_loss, avg_val_MSE, avg_val_PSNR,
                            avg_val_SSIM))
                plt.figure()
                plt.semilogy(np.linspace(0, epoch, len(train_losses)),
                             train_losses)
                plt.xlabel("Epoch")
                plt.ylabel("Loss")
                plt.savefig("{}train_loss.png".format(img_directory))
                plt.close()

                plt.figure()
                plt.semilogy(np.linspace(0, epoch, len(val_losses)),
                             val_losses)
                plt.xlabel("Epoch")
                plt.ylabel("Loss")
                plt.savefig("{}val_loss.png".format(img_directory))
                plt.close()

                plt.figure()
                plt.plot(np.linspace(0, epoch, len(train_PSNRs)), train_PSNRs)
                plt.xlabel("Epoch")
                plt.ylabel("PSNR")
                plt.savefig("{}train_PSNR.png".format(img_directory))
                plt.close()

                plt.figure()
                plt.plot(np.linspace(0, epoch, len(val_PSNRs)), val_PSNRs)
                plt.xlabel("Epoch")
                plt.ylabel("PSNR")
                plt.savefig("{}val_PSNR.png".format(img_directory))
                plt.close()

                plt.figure()
                plt.semilogy(np.linspace(0, epoch, len(train_MSEs)),
                             train_MSEs)
                plt.xlabel("Epoch")
                plt.ylabel("MSE")
                plt.savefig("{}train_MSE.png".format(img_directory))
                plt.close()

                plt.figure()
                plt.semilogy(np.linspace(0, epoch, len(val_MSEs)), val_MSEs)
                plt.xlabel("Epoch")
                plt.ylabel("MSE")
                plt.savefig("{}val_MSE.png".format(img_directory))
                plt.close()

                plt.figure()
                plt.plot(np.linspace(0, epoch, len(train_SSIMs)), train_SSIMs)
                plt.xlabel("Epoch")
                plt.ylabel("SSIM")
                plt.savefig("{}train_SSIM.png".format(img_directory))
                plt.close()

                plt.figure()
                plt.plot(np.linspace(0, epoch, len(val_SSIMs)), val_SSIMs)
                plt.xlabel("Epoch")
                plt.ylabel("SSIM")
                plt.savefig("{}val_SSIM.png".format(img_directory))
                plt.close()

            if best_val_MSE > avg_val_MSE:
                best_val_MSE, best_val_PSNR, best_val_SSIM = avg_val_MSE, avg_val_PSNR, avg_val_SSIM
                print("new best Avg Val MSE: {}".format(best_val_MSE))
                print("Saving model to {}".format(save_path))
                torch.save(
                    {
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': train_loss
                    }, save_path + "best_model.pth")
                print("Saved successfully to {}".format(save_path))

    except KeyboardInterrupt:
        print("Training interupted...")
        print("Saving model to {}".format(save_path))
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': train_loss
            }, save_path + "checkpoint{}.pth".format(epoch))
        print("Saved successfully to {}".format(save_path))

        print("Training completed.")

    print("Best MSE: %.10f, Best PSNR: %.10f, Best SSIM: %.10f" %
          (best_val_MSE, best_val_PSNR, best_val_SSIM))
    return (train_losses, train_PSNRs, val_losses, val_PSNRs, best_val_MSE)
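
utils.backprop, get_PSNR, get_MSE and get_SSIM come from a project utility module that is not shown. For reference, a plausible get_PSNR, assuming pixel values scaled to [0, 1], is sketched below:

import torch

def get_PSNR(output, target, max_val=1.0):
    # PSNR = 10 * log10(max_val^2 / MSE), computed over the whole batch.
    mse = torch.mean((output - target) ** 2)
    return (10.0 * torch.log10(max_val ** 2 / mse)).item()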
Example #7
def main_loop(data_path,
              batch_size=batch_size,
              model_type='UNet',
              green=False,
              tensorboard=True):
    # Load train and val data
    tasks = ['EX']
    data_path = data_path
    n_labels = len(tasks)
    n_channels = 1 if green else 3  # green or RGB
    train_loader, val_loader = load_train_val_data(tasks=tasks,
                                                   data_path=data_path,
                                                   batch_size=batch_size,
                                                   green=green)

    if model_type == 'UNet':
        lr = learning_rate
        model = UNet(n_channels, n_labels)
        # Choose loss function
        criterion = nn.MSELoss()
        # criterion = dice_loss
        # criterion = mean_dice_loss
        # criterion = nn.BCELoss()

    elif model_type == 'GCN':
        lr = 1e-4
        model = GCN(n_labels, image_size[0])
        criterion = weighted_BCELoss
        # criterion = nn.BCELoss()

    else:
        raise TypeError('Please enter a valid name for the model type')

    try:
        loss_name = criterion._get_name()
    except AttributeError:
        loss_name = criterion.__name__

    if loss_name == 'BCEWithLogitsLoss':
        lr = 1e-4
        print('learning rate: ', lr)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # Choose optimizer
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              verbose=True,
                                                              patience=7)

    if tensorboard:
        log_dir = tensorboard_folder + session_name + '/'
        print('log dir: ', log_dir)
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)
        writer = SummaryWriter(log_dir)
    else:
        writer = None

    max_aupr = 0.0
    for epoch in range(epochs):  # loop over the dataset multiple times
        print('******** Epoch [{}/{}]  ********'.format(epoch + 1, epochs))
        print(session_name)

        # train for one epoch
        model.train(True)
        print('Training with batch size : ', batch_size)
        train_loop(train_loader,
                   model,
                   criterion,
                   optimizer,
                   writer,
                   epoch,
                   lr_scheduler=lr_scheduler,
                   model_type=model_type)

        # evaluate on validation set
        print('Validation')
        with torch.no_grad():
            model.eval()
            val_loss, val_aupr = train_loop(val_loader, model, criterion,
                                            optimizer, writer, epoch)

        # Save best model
        if val_aupr > max_aupr and epoch > 3:
            print('\t Saving best model, mean aupr on validation set: {:.4f}'.
                  format(val_aupr))
            max_aupr = val_aupr
            save_checkpoint(
                {
                    'epoch': epoch,
                    'best_model': True,
                    'model': model_type,
                    'state_dict': model.state_dict(),
                    'val_loss': val_loss,
                    'loss': loss_name,
                    'optimizer': optimizer.state_dict()
                }, model_path)

        elif save_model and (epoch + 1) % save_frequency == 0:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'best_model': False,
                    'model': model_type,
                    'loss': loss_name,
                    'state_dict': model.state_dict(),
                    'val_loss': val_loss,
                    'optimizer': optimizer.state_dict()
                }, model_path)

    return model
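
train_loop and save_checkpoint are defined elsewhere in this project. A minimal save_checkpoint consistent with how it is called above could simply wrap torch.save (the file naming below is an assumption):

import os
import torch

def save_checkpoint(state, model_path):
    # Write the caller-provided state dict; best models overwrite a fixed file,
    # periodic saves are tagged with the epoch number.
    os.makedirs(model_path, exist_ok=True)
    name = 'best_model.pth' if state.get('best_model') else 'epoch_{}.pth'.format(state['epoch'])
    torch.save(state, os.path.join(model_path, name))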
Example #8
def train(net: UNet,
          train_ids_file_path: str,
          val_ids_file_path: str,
          in_dir_path: str,
          mask_dir_path: str,
          check_points: str,
          epochs=10,
          batch_size=4,
          learning_rate=0.1,
          device=torch.device("cpu")):
    train_data_set = ImageSet(train_ids_file_path, in_dir_path, mask_dir_path)

    train_data_loader = DataLoader(train_data_set,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=1)

    net = net.to(device)

    loss_func = nn.BCELoss()
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=learning_rate,
                                momentum=0.99)
    writer = SummaryWriter("tensorboard")
    g_step = 0

    for epoch in range(epochs):
        net.train()
        total_loss = 0

        with tqdm(total=len(train_data_set),
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for step, (imgs, masks) in enumerate(train_data_loader):
                imgs = imgs.to(device)
                masks = masks.to(device)

                outputs = net(imgs)
                loss = loss_func(outputs, masks)
                total_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # record
                writer.add_scalar("Loss/Train", loss.item(), g_step)
                writer.flush()
                pbar.set_postfix(**{'loss (batch)': loss.item()})
                pbar.update(imgs.shape[0])
                g_step += 1

                if g_step % 10 == 0:
                    writer.add_images('masks/origin', imgs, g_step)
                    writer.add_images('masks/true', masks, g_step)
                    writer.add_images('masks/pred', outputs > 0.5, g_step)
                    writer.flush()

        try:
            os.mkdir(check_points)
            logging.info('Created checkpoint directory')
        except OSError:
            pass
        torch.save(net.state_dict(), os.path.join(check_points, f'CP_epoch{epoch + 1}.pth'))
        logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
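
Note that val_ids_file_path is accepted but never used in this snippet. A hedged sketch of a validation pass that could be called once per epoch, reusing ImageSet and BCELoss from above, might look like:

def evaluate(net: UNet, val_ids_file_path: str, in_dir_path: str,
             mask_dir_path: str, device=torch.device("cpu"), batch_size=4):
    # Average BCE loss over the validation split (no gradient tracking).
    val_set = ImageSet(val_ids_file_path, in_dir_path, mask_dir_path)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    loss_func = nn.BCELoss()
    net.eval()
    total = 0.0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            total += loss_func(net(imgs), masks).item()
    return total / max(len(val_loader), 1)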
Example #9
def run():
    print('loop')
    # torch.backends.cudnn.enabled = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")
    # Assuming that we are on a CUDA machine, this should print a CUDA device:
    print(device)

    Dx = Discriminator().to(device)
    Gx = UNet(3, 3).to(device)

    Dy = Discriminator().to(device)
    Gy = UNet(3, 3).to(device)

    ld = False
    if ld:
        try:
            Gx.load_state_dict(torch.load('./genx'))
            Dx.load_state_dict(torch.load('./fcnx'))
            Gy.load_state_dict(torch.load('./geny'))
            Dy.load_state_dict(torch.load('./fcny'))
            print('net loaded')
        except Exception as e:
            print(e)

    dataset = 'ukiyoe2photo'
    # A 562
    image_path_A = './datasets/' + dataset + '/trainA/*.jpg'
    image_path_B = './datasets/' + dataset + '/trainB/*.jpg'

    plt.ion()

    train_image_paths_A = glob.glob(image_path_A)
    train_image_paths_B = glob.glob(image_path_B)
    print(len(train_image_paths_A), len(train_image_paths_B))

    b_size = 8

    train_dataset_A = CustomDataset(train_image_paths_A, train=True)
    train_loader_A = torch.utils.data.DataLoader(train_dataset_A,
                                                 batch_size=b_size,
                                                 shuffle=True,
                                                 num_workers=4,
                                                 pin_memory=False,
                                                 drop_last=True)

    train_dataset_B = CustomDataset(train_image_paths_B, True, 562, train=True)
    train_loader_B = torch.utils.data.DataLoader(train_dataset_B,
                                                 batch_size=b_size,
                                                 shuffle=True,
                                                 num_workers=4,
                                                 pin_memory=False,
                                                 drop_last=True)

    Gx.train()
    Dx.train()

    Gy.train()
    Dy.train()

    criterion = nn.BCEWithLogitsLoss().to(device)
    # criterion2 = nn.SmoothL1Loss().to(device)
    criterion2 = nn.L1Loss().to(device)

    g_lr = 2e-4
    d_lr = 2e-4
    optimizer_x = optim.Adam(Gx.parameters(), lr=g_lr, betas=(0.5, 0.999))
    optimizer_x_d = optim.Adam(Dx.parameters(), lr=d_lr, betas=(0.5, 0.999))

    optimizer_y = optim.Adam(Gy.parameters(), lr=g_lr, betas=(0.5, 0.999))
    optimizer_y_d = optim.Adam(Dy.parameters(), lr=d_lr, betas=(0.5, 0.999))

    # cp = cropper().to(device)

    _zero = torch.from_numpy(np.zeros((b_size, 1))).float().to(device)
    _zero.requires_grad = False

    _one = torch.from_numpy(np.ones((b_size, 1))).float().to(device)
    _one.requires_grad = False

    for epoch in trange(100, desc='epoch'):
        # loop = tqdm(zip(train_loader_A, train_loader_B), desc='iteration')
        loop = zip(tqdm(train_loader_A, desc='iteration'), train_loader_B)
        batch_idx = 0
        for data_A, data_B in loop:
            batch_idx += 1
            zero = _zero
            one = _one
            _data_A = data_A.to(device)
            _data_B = data_B.to(device)

            # Dy loss (A -> B)
            gen = Gy(_data_A)

            optimizer_y_d.zero_grad()

            output2_p = Dy(_data_B.detach())
            output_p = Dy(gen.detach())

            errD = (
                criterion(output2_p - torch.mean(output_p), one.detach()) +
                criterion(output_p - torch.mean(output2_p), zero.detach())) / 2
            errD.backward()
            optimizer_y_d.step()

            # Dx loss (B -> A)
            gen = Gx(_data_B)

            optimizer_x_d.zero_grad()

            output2_p = Dx(_data_A.detach())
            output_p = Dx(gen.detach())

            errD = (
                criterion(output2_p - torch.mean(output_p), one.detach()) +
                criterion(output_p - torch.mean(output2_p), zero.detach())) / 2
            errD.backward()
            optimizer_x_d.step()

            # Gy loss (A -> B)
            optimizer_y.zero_grad()
            gen = Gy(_data_A)
            output_p = Dy(gen)
            output2_p = Dy(_data_B.detach())
            g_loss = (
                criterion(output2_p - torch.mean(output_p), zero.detach()) +
                criterion(output_p - torch.mean(output2_p), one.detach())) / 2

            # Gy cycle loss (B -> A -> B)
            fA = Gx(_data_B)
            gen = Gy(fA.detach())
            c_loss = criterion2(gen, _data_B)

            errG = g_loss + c_loss
            errG.backward()
            optimizer_y.step()

            if batch_idx % 10 == 0:

                fig = plt.figure(1)
                fig.clf()
                plt.imshow((np.transpose(_data_B.detach().cpu().numpy()[0],
                                         (1, 2, 0)) + 1) / 2)
                fig.canvas.draw()
                fig.canvas.flush_events()

                fig = plt.figure(2)
                fig.clf()
                plt.imshow((np.transpose(fA.detach().cpu().numpy()[0],
                                         (1, 2, 0)) + 1) / 2)
                fig.canvas.draw()
                fig.canvas.flush_events()

                fig = plt.figure(3)
                fig.clf()
                plt.imshow((np.transpose(gen.detach().cpu().numpy()[0],
                                         (1, 2, 0)) + 1) / 2)
                fig.canvas.draw()
                fig.canvas.flush_events()

            # Gx loss (B -> A)
            optimizer_x.zero_grad()
            gen = Gx(_data_B)
            output_p = Dx(gen)
            output2_p = Dx(_data_A.detach())
            g_loss = (
                criterion(output2_p - torch.mean(output_p), zero.detach()) +
                criterion(output_p - torch.mean(output2_p), one.detach())) / 2

            # Gx cycle loss (A -> B -> A)
            fB = Gy(_data_A)
            gen = Gx(fB.detach())
            c_loss = criterion2(gen, _data_A)

            errG = g_loss + c_loss
            errG.backward()
            optimizer_x.step()

        torch.save(Gx.state_dict(), './genx')
        torch.save(Dx.state_dict(), './fcnx')
        torch.save(Gy.state_dict(), './geny')
        torch.save(Dy.state_dict(), './fcny')
    print('\nFinished Training')
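
The relativistic-average loss is written out four times above (twice for the discriminators, twice for the generators). A small helper, sketched under the assumption that criterion remains BCEWithLogitsLoss, would keep the updates symmetric:

def relativistic_loss(criterion, real_logits, fake_logits, real_label, fake_label):
    # Push real scores above the mean fake score and fake scores below the mean real score.
    loss_real = criterion(real_logits - torch.mean(fake_logits), real_label)
    loss_fake = criterion(fake_logits - torch.mean(real_logits), fake_label)
    return (loss_real + loss_fake) / 2

# Discriminator update: relativistic_loss(criterion, output2_p, output_p, one, zero)
# Generator update:     relativistic_loss(criterion, output2_p, output_p, zero, one)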
Example #10
def train():
    # number of training epochs
    epoch = 500
    # image directory
    img_dir = "./data/training/images"
    # mask directory
    mask_dir = "./data/training/1st_manual"
    # network input image size
    img_size = (512, 512)
    # build the training and validation loaders
    tr_loader = DataLoader(DRIVE_Loader(img_dir, mask_dir, img_size, 'train'),
                           batch_size=4,
                           shuffle=True,
                           num_workers=2,
                           pin_memory=True,
                           drop_last=True)
    val_loader = DataLoader(DRIVE_Loader(img_dir, mask_dir, img_size, 'val'),
                            batch_size=4,
                            shuffle=True,
                            num_workers=2,
                            pin_memory=True,
                            drop_last=True)
    # define the loss function
    criterion = DiceBCELoss()
    # move the network onto the GPU
    network = UNet().cuda()
    # define the optimizer
    optimizer = Adam(network.parameters(), weight_decay=0.0001)
    best_score = 1.0
    for i in range(epoch):
        # switch to training mode so that BN and Dropout parameters are updated
        network.train()
        train_step = 0
        train_loss = 0
        val_loss = 0
        val_step = 0
        # training
        for batch in tr_loader:
            # unpack the images and masks for this batch
            imgs, mask = batch
            # move the data onto the GPU
            imgs = imgs.cuda()
            mask = mask.cuda()
            # feed the images through the network to get a prediction
            mask_pred = network(imgs)
            # compute the loss from the prediction and the mask
            loss = criterion(mask_pred, mask)
            # accumulate the training loss
            train_loss += loss.item()
            train_step += 1
            # zero the gradients
            optimizer.zero_grad()
            # backpropagate the loss to get gradients
            loss.backward()
            # update the parameters with the Adam step
            optimizer.step()
        # switch to eval mode so that BN and Dropout parameters are not updated
        network.eval()
        # validation
        with torch.no_grad():
            for batch in val_loader:
                imgs, mask = batch
                imgs = imgs.cuda()
                mask = mask.cuda()
                # compute the evaluation metric, the Dice loss in this case
                val_loss += DiceLoss()(network(imgs), mask).item()
                val_step += 1
        # compute the epoch-average training loss and validation metric
        train_loss /= train_step
        val_loss /= val_step
        # if the validation metric improves on the best score so far, save the current weights
        if val_loss < best_score:
            best_score = val_loss
            torch.save(network.state_dict(), "./checkpoint.pth")
        # report progress
        print(str(i), "train_loss:", train_loss, "val_dice:", val_loss)
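
DiceBCELoss and DiceLoss are defined elsewhere in this project. A common soft-Dice formulation, assuming the network output already holds probabilities in [0, 1], is sketched below for reference:

import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        # Soft Dice over the flattened batch: 1 - 2|P*T| / (|P| + |T|).
        pred = pred.reshape(-1)
        target = target.reshape(-1)
        intersection = (pred * target).sum()
        dice = (2.0 * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)
        return 1.0 - dice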