def test(test_dir, weight_path, outputs_dir, img_size=(512, 512)):
    """Predict a binary mask for every image in a folder with a trained UNet.

    :param test_dir: folder containing the images to predict
    :param weight_path: path of the trained weights (a ``state_dict``)
    :param outputs_dir: folder the masks are written to (created if missing);
        each mask is saved under its input image's file name
    :param img_size: ``(width, height)`` every image is resized to
    :return: None

    Masks are saved as 8-bit images: 0 where the network output is < 0.5 and
    255 where it is >= 0.5. (A float32 array would become a mode-``F`` image,
    which the PNG and JPEG writers reject.)
    """
    # Build the network on the GPU and load the trained weights.
    network = UNet().cuda()
    network.load_state_dict(torch.load(weight_path))
    # Inference mode: dropout off, batch-norm uses its running statistics.
    network.eval()
    os.makedirs(outputs_dir, exist_ok=True)

    # No autograd graph is needed for prediction.
    with torch.no_grad():
        for f in os.listdir(test_dir):
            # Read the image and resize it to the network input size.
            img = np.array(
                Image.open(os.path.join(test_dir, f)).resize(img_size, Image.BILINEAR))
            # HWC -> BCHW float32. NOTE(review): assumes a 3-D (H, W, C) array,
            # i.e. a multi-channel image; grayscale inputs would fail here.
            img = np.expand_dims(img, axis=0).transpose((0, 3, 1, 2)).astype(np.float32)
            pred = network(torch.from_numpy(img).cuda()).cpu().numpy()
            # Binarise the first channel of the single-image batch at 0.5.
            mask = np.where(pred[0, 0, :, :] >= 0.5, 255, 0).astype(np.uint8)
            Image.fromarray(mask).save(os.path.join(outputs_dir, f))
Example #2
0
 def get_model(self):
     """Assemble the denoising model together with its training companions.

     Returns a ``(model, noisy, optimizer, criterion)`` tuple:
     a double-precision 3-in / 3-out ``UNet`` (moved to CUDA when
     ``self.use_cuda`` is truthy), a ``NoisyDataset`` built with ``self.VAR``,
     an Adam optimizer over the model parameters with learning rate
     ``self.LR``, and a ``RegularizedLoss`` instance.
     """
     net = UNet(in_channels=3, out_channels=3).double()
     net = net.cuda() if self.use_cuda else net
     noise_source = NoisyDataset(var=self.VAR)
     adam = torch.optim.Adam(net.parameters(), lr=self.LR)
     return net, noise_source, adam, RegularizedLoss()
Example #3
0
def main():
    """Train a freshly initialised UNet on the GPU with the project ``Trainer``."""
    # The model is moved to the GPU before the Trainer receives it
    # (original note on this step: "cuda expected but got cpu").
    trainer = Trainer(UNet().cuda())
    trainer.start()
    def __init__(self, pars):
        """Build the wrapped ``UNet(pars)`` and set three fixed constants.

        ``fs`` = 16000, ``win_len`` = 0.064 and ``hop_len`` = 0.016 —
        presumably a sample rate in Hz and window / hop lengths in seconds
        (TODO confirm).
        """
        super(SEUNet, self).__init__()
        self.unet = UNet(pars)
        self.fs, self.win_len, self.hop_len = 16000, 0.064, 0.016
    def __init__(self,
                 layer_level,
                 input_shape,
                 Loss,
                 lr,
                 num_modal,
                 num_chn=3,
                 num_class=2,
                 VISUALISATION=False,
                 SALIENCY=False,
                 DILATION=False,
                 dilation_factor=None):
        """Select, build and compile one of three UNet variants.

        :param layer_level: network depth; stored and passed to the chosen model
        :param input_shape: model input shape; for the plain UNet a copy with
            index 2 replaced by ``num_chn`` is used (``self.input_shape`` keeps
            the value as given)
        :param Loss: NOTE(review): not used in this method — the models receive
            ``self.loss``, which is not assigned here and is presumably set by
            ``train_setting()``; confirm whether ignoring this is intended
        :param lr: learning rate forwarded to the model constructor
        :param num_modal: stored; forwarded to the saliency variants only
        :param num_chn: stored; written into ``input_shape[2]`` for the plain UNet
        :param num_class: stored; forwarded to the saliency variants only
        :param VISUALISATION: if truthy, call ``self.visualise()`` at the end
        :param SALIENCY: select a saliency variant
        :param DILATION: together with ``SALIENCY`` selects
            ``Dilated_Saliency_UNet``. NOTE(review): ``DILATION`` without
            ``SALIENCY`` falls through to the plain UNet and ``dilation_factor``
            is ignored
        :param dilation_factor: forwarded to ``Dilated_Saliency_UNet`` only
        """
        self.layer_level = layer_level
        self.input_shape = input_shape
        self.num_class = num_class
        self.num_modal = num_modal
        self.num_chn = num_chn
        # Runs after the attributes above are set; self.loss and self.activation
        # are used below but not assigned in this method — presumably this call
        # defines them (confirm).
        self.train_setting()
        self.init_epoch = 0
        self.test_time = 0

        if DILATION and SALIENCY:
            self.model = Dilated_Saliency_UNet(input_shape, self.layer_level,
                                               num_modal, num_class, lr,
                                               self.loss, self.activation,
                                               dilation_factor)
            self.net_name = 'Dilated_Saliency_UNet'

        elif SALIENCY:
            self.model = Saliency_UNet(input_shape, self.layer_level,
                                       num_modal, num_class, lr, self.loss,
                                       self.activation)
            self.net_name = 'Saliency_UNet'
        else:
            # Override the channel entry (index 2) of a local copy of the shape.
            input_shape = list(input_shape)
            input_shape[2] = num_chn
            input_shape = tuple(input_shape)
            self.model = UNet(input_shape, self.layer_level, lr, self.loss,
                              self.activation)
            self.net_name = 'UNet_depth' + str(self.layer_level)

        # Replace the wrapper object with the network returned by compiled_network().
        self.model = self.model.compiled_network()

        if VISUALISATION:
            self.visualise()
Example #6
0
def predict(load_path, image_path, scal=1):
    """Segment an image with a trained UNet and return it with the mask applied.

    :param load_path: path of the saved UNet ``state_dict``
    :param image_path: path of the input image; it is converted to RGB because
        the network is built with 3 input channels
    :param scal: integer down-scaling factor; the image is resized to
        ``(w // scal, h // scal)`` before inference
    :return: PIL image of the resized input where pixels whose network output
        is < 0.5 are set to zero
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = UNet(in_channels=3, classes=1)
    net.load_state_dict(torch.load(load_path, map_location=device))
    net.to(device)
    # Inference mode: dropout off, batch-norm uses its running statistics.
    net.eval()

    # Grayscale / RGBA / palette images would not match the 3-channel input.
    img = Image.open(image_path).convert('RGB')
    w, h = img.size
    img = img.resize((w // scal, h // scal))
    img = ToTensor()(img).unsqueeze(0).to(device)

    with torch.no_grad():
        # NOTE(review): thresholding at 0.5 assumes the network outputs
        # probabilities (e.g. ends in a sigmoid) — confirm.
        masks = net(img) >= 0.5
    # Apply the mask (assumed single-channel, broadcast over the RGB channels).
    out_img = (img * masks).squeeze().cpu()
    return ToPILImage()(out_img)
Example #7
0
def generate(output_directory, ckpt_path, ckpt_epoch, n, T, beta_0, beta_T,
             unet_config):
    """
    Generate images using the pretrained UNet model

    Parameters:

    output_directory (str):     output generated images to this path (created if missing)
    ckpt_path (str):            path of the checkpoints
    ckpt_epoch (int or 'max'):  the pretrained model checkpoint to be loaded;
                                automatically selects the maximum epoch if 'max' is selected
    n (int):                    number of images to generate
    T (int):                    the number of diffusion steps
    beta_0 and beta_T (float):  first and last value of the linearly spaced beta schedule
    unet_config (dict):         dictionary of UNet parameters

    Raises:

    Exception('No valid model found') when the checkpoint cannot be read or
    loaded; the original error is chained as its __cause__.
    """
    # Compute diffusion hyperparameters.
    Beta = torch.linspace(beta_0, beta_T, T).cuda()
    Alpha = 1 - Beta
    # Alpha_bar[t] = Alpha[0] * ... * Alpha[t] (running product).
    Alpha_bar = torch.ones(T).cuda()
    # Copy of Beta (so Beta itself is untouched), rescaled in place below to
    # Beta[t] * (1 - Alpha_bar[t-1]) / (1 - Alpha_bar[t]) for t > 0.
    Beta_tilde = Beta + 0
    for t in range(T):
        Alpha_bar[t] *= Alpha[t] * Alpha_bar[t - 1] if t else Alpha[t]
        if t > 0:
            Beta_tilde[t] *= (1 - Alpha_bar[t - 1]) / (1 - Alpha_bar[t])
    Sigma = torch.sqrt(Beta_tilde)

    # Build the model once; it is moved to the GPU after its weights load.
    net = UNet(**unet_config)
    print_size(net)

    # Load checkpoint
    if ckpt_epoch == 'max':
        ckpt_epoch = find_max_epoch(ckpt_path, 'unet_ckpt')
    model_path = os.path.join(ckpt_path,
                              'unet_ckpt_' + str(ckpt_epoch) + '.pkl')
    try:
        checkpoint = torch.load(model_path, map_location='cpu')
        print('Model at epoch %s has been trained for %s seconds' %
              (ckpt_epoch, checkpoint['training_time_seconds']))
        net.load_state_dict(checkpoint['model_state_dict'])
    except Exception as err:
        raise Exception('No valid model found') from err
    # Outside the try so GPU errors are not reported as a missing model.
    net = net.cuda()

    # Generation
    time0 = time.time()
    X_gen = sampling(net, (n, 3, 256, 256), T, Alpha, Alpha_bar, Sigma)
    print('generated %s samples at epoch %s in %s seconds' %
          (n, ckpt_epoch, int(time.time() - time0)))

    # Save generated images
    os.makedirs(output_directory, exist_ok=True)
    for i in range(n):
        save_image(rescale(X_gen[i]),
                   os.path.join(output_directory, 'img_{}.jpg'.format(i)))
    print('saved generated samples at epoch %s' % ckpt_epoch)
Example #8
0
    def __init__(self, in_channels, out_channels):
        """Wrap a depth-4 UNet using ReLU, double convolutions and dilation 1.

        :param in_channels: number of input channels forwarded to ``UNet``
        :param out_channels: number of output channels forwarded to ``UNet``
        """
        super().__init__()
        settings = {
            "depth": 4,
            "activation": "relu",
            "channels_sequence": [32, 128, 256, 512],
            "conv_type": "double",
            "dilation": 1,
        }
        self.model = UNet(in_channels=in_channels,
                          out_channels=out_channels,
                          **settings)
Example #9
0
    def __init__(self, in_channels, out_channels):
        """Wrap a depth-4 UNet with PReLU, batch-norm and per-level dilations.

        :param in_channels: number of input channels forwarded to ``UNet``
        :param out_channels: number of output channels forwarded to ``UNet``
        """
        super().__init__()
        settings = {
            "depth": 4,
            "channels_sequence": [32, 128, 256, 512],
            "conv_type": "double",
            # Five [a, b] pairs for depth 4 — presumably one per level plus
            # the bottleneck (confirm against UNet).
            "dilation": [[1, 1], [2, 2], [4, 4], [8, 8], [1, 1]],
            "batchnorm": True,
            "residual_bottleneck": True,
            "downsample_type": 'conv_stride',
            "activation": "prelu",
            "big_upsample": True,
            "advanced_bottleneck": True,
        }
        self.model = UNet(in_channels=in_channels,
                          out_channels=out_channels,
                          **settings)
Example #10
0
def run():
    """Translate every image of the test split with generator ``G`` and save the results.

    Weights are loaded from ``./genx``. If that fails, the error is printed and
    the freshly initialised generator is used. Images are read from
    ``./datasets/ukiyoe2photo/testB/``, and ``0001.jpg``, ``0002.jpg``, ...
    are written to ``./datasets/ukiyoe2photo/genA/``, which is created if
    missing.
    """
    import os

    random.seed()
    np.random.seed()
    torch.multiprocessing.freeze_support()
    print('loop')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    G = UNet(3, 3).to(device)

    try:
        # map_location lets GPU-saved weights load on a CPU-only machine
        # instead of silently falling back to random weights.
        G.load_state_dict(torch.load('./genx', map_location=device))
        print('net loaded')
    except Exception as e:
        # Best-effort: keep going with the untrained generator.
        print(e)

    plt.ion()
    dataset = 'ukiyoe2photo'

    real_image = 'testB'
    save_image = 'genA'
    save_prefix = './datasets/' + dataset + '/' + save_image + '/'
    os.makedirs(save_prefix, exist_ok=True)

    image_path_B = './datasets/' + dataset + '/' + real_image + '/*'

    train_image_paths_B = glob.glob(image_path_B)
    print(len(train_image_paths_B))

    # Batch size 1: each generated batch is squeezed to a single image below.
    b_size = 1

    train_dataset_B = CustomDataset(train_image_paths_B, train=False)
    train_loader_B = torch.utils.data.DataLoader(train_dataset_B,
                                                 batch_size=b_size,
                                                 shuffle=False,
                                                 num_workers=1,
                                                 pin_memory=False)

    G.eval()

    unloader = transforms.ToPILImage()
    with torch.no_grad():
        loop = tqdm(train_loader_B, desc='inf')
        for idx, im in enumerate(loop, start=1):
            gen = G(im.to(device))
            # Map [-1, 1] to [0, 1] and clamp: ToPILImage multiplies floats by
            # 255 and casts to uint8, so out-of-range values would not survive.
            gen = ((gen + 1) / 2.0).clamp(0.0, 1.0)
            unloader(gen.squeeze(0).cpu()).save(save_prefix + '%04d.jpg' % idx)
Example #11
0
    def compile(self):
        """Build the TF1 training graph once: placeholders, UNet, loss, optimizer, sampler, saver.

        A repeated call only prints a message and returns. Reads ``self.seed``,
        ``self.learning_rate`` and ``self.learning_rate_decay``, plus
        ``self.learning_rate_decay_steps`` / ``self.learning_rate_decay_rate``
        when decay is enabled.
        """

        # Guard so the graph is only built once per instance.
        if self.compiled:
            print('Model already compiled.')
            return
        self.compiled = True

        # Placeholders: X is (N, 32, 32, 1), Y is (N, 32, 32, 2), v is a single
        # (32, 32, 1) tensor (v is not used inside this method).
        self.X = tf.placeholder(tf.float32, shape=(None, 32, 32, 1), name='X')
        self.Y = tf.placeholder(tf.float32, shape=(None, 32, 32, 2), name='Y')
        self.v = tf.placeholder(tf.float32, shape=(32, 32, 1), name='v')

        # U-Net.
        net = UNet(self.seed)
        self.out = net.forward(self.X)

        # Loss and metrics: mean squared error between Y and the UNet output.
        # TODO: try with MAE.
        self.loss = tf.keras.losses.MeanSquaredError()(self.Y, self.out)

        # Global step: non-trainable counter incremented by minimize() below.
        self.global_step = tf.Variable(0, trainable=False, name='Global_Step')

        # Learning rate: exponentially decayed with the global step, or constant.
        if self.learning_rate_decay:
            self.lr = tf.train.exponential_decay(
                self.learning_rate,
                self.global_step,
                self.learning_rate_decay_steps,
                self.learning_rate_decay_rate,
                name='learning_rate_decay')
        else:
            self.lr = tf.constant(self.learning_rate)

        # Optimizer. Note: self.optimizer holds the training op returned by
        # minimize(), not the AdamOptimizer object.
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=self.lr).minimize(self.loss,
                                            global_step=self.global_step)

        # Sampler: a second UNet with is_training=False that reuses the training
        # network's variables (reuse_vars=True) — presumably for inference.
        gen_sample = UNet(self.seed, is_training=False)
        self.sampler = gen_sample.forward(self.X, reuse_vars=True)

        # Tensorboard: registers a scalar summary for the loss (no merge here).
        tf.summary.scalar('loss', self.loss)

        self.saver = tf.train.Saver()
    def __init__(self,
                 name='U-Net',
                 number_of_filters_for_convolution_blocks=None,
                 number_of_convolutions_per_block=5,
                 activation_function=tf.nn.relu,
                 use_batch_normalization=False,
                 dropout_rate=0.,
                 use_multiscale_predictions=True,
                 data_format='channels_first'):
        """Build the network architecture selected by ``name``.

        :param name: 'U-Net' or 'Tiramisu'
        :param number_of_filters_for_convolution_blocks: filters per block;
            defaults to a fresh ``[128, 128, 128]`` on every call. For
            'Tiramisu' its first entry also sets the number of preprocessing
            convolution filters.
        :param number_of_convolutions_per_block: convolutions in each block
        :param activation_function: activation forwarded to the architecture
        :param use_batch_normalization: forwarded to the architecture
        :param dropout_rate: forwarded to the architecture
        :param use_multiscale_predictions: forwarded as ``use_multiscale_output``
        :param data_format: forwarded to the architecture
        :raises AssertionError: if ``name`` is neither 'U-Net' nor 'Tiramisu'
        """
        if number_of_filters_for_convolution_blocks is None:
            # A list literal as default would be one object shared by every
            # instance; any mutation by an architecture would leak across calls.
            number_of_filters_for_convolution_blocks = [128, 128, 128]
        self.name = name

        if self.name == 'U-Net':
            self.architecture = UNet(
                number_of_filters_for_convolution_blocks=
                number_of_filters_for_convolution_blocks,
                number_of_convolutions_per_block=
                number_of_convolutions_per_block,
                use_multiscale_output=use_multiscale_predictions,
                activation_function=activation_function,
                use_batch_normalization=use_batch_normalization,
                dropout_rate=dropout_rate,
                data_format=data_format)
        else:
            # NOTE(review): assert is stripped under `python -O`, which would let
            # an unknown name silently build a Tiramisu.
            assert self.name == 'Tiramisu'
            self.architecture = Tiramisu(
                # TODO: Make it configurable (DeepBlender)
                number_of_preprocessing_convolution_filters=
                number_of_filters_for_convolution_blocks[0],
                number_of_filters_for_convolution_blocks=
                number_of_filters_for_convolution_blocks,
                number_of_convolutions_per_block=
                number_of_convolutions_per_block,
                use_multiscale_output=use_multiscale_predictions,
                activation_function=activation_function,
                use_batch_normalization=use_batch_normalization,
                dropout_rate=dropout_rate,
                data_format=data_format)
Example #13
0
# Center-crop to 256x256, convert to a tensor, normalise with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [transforms.CenterCrop(256),
     transforms.ToTensor(),
     transforms.Normalize((0.5), (0.5))])

args = parser.parse_args()

VAR = args.var
DATA_DIR = args.data_dir
CHECKPOINT = args.checkpoint

testset = CustomImageDataset(DATA_DIR, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4)
dataiter = iter(testloader)
# Load on the CPU so GPU-saved checkpoints open on any machine.
checkpoint = torch.load(CHECKPOINT, map_location=torch.device('cpu'))


model_test = UNet(in_channels=3, out_channels=3).double()
model_test.load_state_dict(checkpoint['model_state_dict'])
model_test = model_test.cpu()
# Inference only: eval() stops dropout / batch-norm from using batch statistics.
model_test.eval()

noisy = NoisyDataset(var=VAR)

# DataLoader iterators no longer provide a .next() method; use builtin next().
images, _ = next(dataiter)
noisy_images = noisy(images)
# Displaying the Noisy Images
imshow(torchvision.utils.make_grid(noisy_images.cpu()))
# Displaying the Denoised Images (no_grad: the output is only plotted, and
# tensors that require grad cannot be converted to NumPy).
with torch.no_grad():
    denoised = model_test(noisy_images.cpu())
imshow(torchvision.utils.make_grid(denoised))
Example #14
0
# import wandb
# import visdom

# wandb.init(project="unet")

# Image and mask folders handed to CustomDataset below.
train_data_dir = 'data'
train_mask_dir = 'mask'

# Read at import time, outside the __main__ guard.
train_df = pd.read_csv('train.csv')

if __name__ == '__main__':

    # vis = visdom.Visdom()
    train = CustomDataset(train_df, train_data_dir, train_mask_dir)

    # UNet with 3 input channels and 2 output classes.
    net = UNet(n_channels=3, n_classes=2)

    net = net.cuda()
    # NOTE(review): BCELoss expects probabilities in [0, 1]; confirm the UNet
    # ends in a sigmoid, otherwise BCEWithLogitsLoss is the safer choice.
    criterion = nn.BCELoss().cuda()
    optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)

    train_loader = DataLoader(train, batch_size=4, shuffle=True, num_workers=8)

    for epoch in range(100):

        # Per-epoch batch counter and loss accumulator.
        index = 0
        epoch_loss = 0

        # NOTE(review): the loop body appears truncated here — `img` is read
        # but no forward/backward step is visible.
        for item in train_loader:
            index += 1
            img = item['img']
import cv2
import matplotlib.pyplot as plt

from DataLoader import DataLoader
from UNet import UNet

import tensorflow as tf
from tensorflow.keras import backend as K

# TF1 session whose GPU memory grows on demand (allow_growth) instead of being
# reserved up front; registered as the Keras backend session.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
K.set_session(session)

# Restore the UNet weights tagged with epoch 12150 (presumably a training
# checkpoint — confirm) and fetch the underlying model.
unet = UNet()
unet.load_weights(epoch=12150)
model = unet.get_model()

# Run the project DataLoader's audio evaluation on one recording.
data_loader = DataLoader(None)
data_loader.evaluate_audio('./eval/My recording 11.mp3', model)

Example #16
0
        print("Dice score:", dice)

        return label, dice

    else:
        dice = None
        return label, dice


if __name__ == "__main__":

    if args.multi_gpu is True:
        os.environ[
            'CUDA_VISIBLE_DEVICES'] = args.gpu_id  # Multi-gpu selector for training
        net = torch.nn.DataParallel(
            (UNet(residual='pool')).cuda())  # load the network Unet

    else:
        torch.cuda.set_device(args.gpu_id)
        net = UNet(residual='pool').cuda()

    net.load_state_dict(torch.load(args.weights))

    result, dice = inference(True,
                             net,
                             args.image,
                             args.label,
                             args.result,
                             args.resample,
                             args.new_resolution,
                             args.patch_size[0],
Example #17
0
# Dataset locations (absolute, cluster-specific paths).
IMG_ROOT = '/gpfs/workdir/houdberta/dataset/train_image'
LABEL_ROOT = '/gpfs/workdir/houdberta/dataset/train_label'

BATCH_SIZE = 8

data_type = 'core'
# Load dataset
trainset = TrainSet(IMG_ROOT, LABEL_ROOT, data_type)
trainloader = DataLoader(trainset, BATCH_SIZE, shuffle=True)
print('loader done')

# Define the model and the optimization method (GPU 0 is hard-coded).
device = 'cuda:0'
#device = 'cpu'
unet = UNet(in_channel=3, class_num=2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(unet.parameters(), lr=0.0005, amsgrad=True)

epochs = 1
lsize = len(trainloader)  # batches per epoch; used as the progress-bar total
itr = 0
p_itr = 10  # print every N iteration
unet.train()
tloss = 0
loss_history = []

for epoch in range(epochs):
    with tqdm(total=lsize) as pbar:
        # NOTE(review): the loop body appears truncated after moving the batch
        # to the device; `path` and `pbar` are unused in what is visible.
        for x, y, path in trainloader:
            x, y = x.to(device), y.to(device)
Example #18
0
def _safe_div(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or 0.0 when the denominator is 0."""
    return float(numerator) / float(denominator) if denominator else 0.0


def _write_metric(path, value):
    """Overwrite ``path`` with ``str(value)`` followed by a newline."""
    with open(path, "w") as handle:
        handle.write(str(value) + "\n")


def _loss_and_prediction(output, mask, loss_weight, loss_type):
    """Compute one batch's loss and its binarised prediction as a uint8 NumPy array.

    'mse': MSE on the raw output; the prediction thresholds the raw output at its
    mean via ``Binarize_Output``. 'wbce' / 'bce' / anything else (Jaccard loss):
    the loss is computed on ``sigmoid(output)`` and the prediction is
    ``sigmoid(output) > 0.5``. The leading (batch) dimension is squeezed.
    """
    if loss_type == 'mse':
        loss = F.mse_loss(output, mask)
        binarized = Binarize_Output(threshold=output.mean())(output)
        pred = torch.squeeze(binarized, dim=0).cpu().numpy().astype(np.uint8)
        return loss, pred

    probs = torch.sigmoid(output)
    if loss_type == 'wbce':
        # Weighted BCE with averaging.
        loss = torch.nn.BCELoss(weight=loss_weight)(probs, mask)
    elif loss_type == 'bce':
        # BCE with averaging.
        loss = torch.nn.BCELoss()(probs, mask)
    else:
        loss = jaccard_loss(probs, mask)
    pred = torch.squeeze(probs > 0.5, dim=0).cpu().numpy().astype(np.uint8)
    return loss, pred


def eval_UNet(test_loader, model_path, test_output_path, act_type='sigmoid', loss_type='mse'):
    """Evaluate a trained UNet and report F1, Jaccard (IoU), precision, recall and object Dice.

    F1, Jaccard, precision and recall are computed from TP / FP / FN counts
    accumulated over all images via ``find_TP_FP_FN``. Object-level Dice:

        for every image, load target and output
            add the target / output pixel sums to Gp / Sq
            for every gland in the target, find its best match in the output and
                add the overlap score to gt_overlap (done by calc_overlap;
                intended as the match's Dice weighted by the gland's pixel count)
            do the same in reverse (output -> target) into seg_overlap
        final dice = 0.5 * [(gt_overlap / Gp) + (seg_overlap / Sq)]

    :param test_loader: yields dicts with 'image', 'image_anno', 'loss_weight'
        and 'name'. Batch size 1 is assumed: only ``name[0]`` is used and the
        batch dimension is squeezed away.
    :param model_path: path of the UNet ``state_dict``
    :param test_output_path: existing directory receiving per-image PNGs,
        ``loss.txt`` (one loss per image) and one-value files ``F1.txt``,
        ``IoU.txt``, ``precision.txt``, ``recall.txt``, ``objDice.txt``
    :param act_type: unused
    :param loss_type: 'wbce', 'bce', 'mse' or anything else (Jaccard loss)
    :return: ``(F1_score, Jac_index, precision, recall, obj_dice)``; ratios
        with a zero denominator are reported as 0.0
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNet(upsample_mode='bilinear').to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    TP = FP = FN = 0
    Gp = Sq = 0  # Total target pixels & total output pixels
    gt_overlap = seg_overlap = 0

    with open(test_output_path + "/loss.txt", "w") as test_loss_file, torch.no_grad():
        for batch_i, sample in enumerate(test_loader):

            # Load every image and its annotation.
            data, mask, loss_weight, img_name = (sample['image'], sample['image_anno'],
                                                 sample['loss_weight'], sample['name'][0])
            data, mask, loss_weight = data.to(device), mask.to(device), loss_weight.to(device)
            loss_weight = loss_weight / 1000
            output = model(data)

            loss, pred = _loss_and_prediction(output, mask, loss_weight, loss_type)

            print("Evaluating image number ", batch_i)
            # CHW -> HWC, then drop the singleton channel for connected_components.
            bin_pred = np.squeeze(pred.transpose(1, 2, 0))
            _, pred_component = connected_components(bin_pred, display=False)

            target = torch.squeeze(mask, dim=0).cpu().numpy().astype(np.uint8)
            bin_target = np.squeeze(target.transpose(1, 2, 0))
            _, target_component = connected_components(bin_target, display=False)

            # ============================ Saving Images ===============================
            if loss_type == 'mse':
                trsh = output.mean()
                utils.save_image(output, "{}/{}-output.png".format(test_output_path, img_name))
            else:
                trsh = 0.5
                utils.save_image(torch.sigmoid(output),
                                 "{}/{}-output.png".format(test_output_path, img_name))
            # NOTE(review): for the sigmoid losses this thresholds the raw output,
            # not its sigmoid, at 0.5 — kept as-is; confirm intended.
            thres = Binarize_Output(threshold=trsh)(output)

            utils.save_image(data, "{}/{}-input.png".format(test_output_path, img_name))
            utils.save_image(mask, "{}/{}-target.png".format(test_output_path, img_name))
            utils.save_image(loss_weight, "{}/{}-weights.png".format(test_output_path, img_name),
                             normalize=True)
            utils.save_image(thres, "{}/{}-thres.png".format(test_output_path, img_name))

            # ============================= F1 and Jaccard ============================
            _TP, _FP, _FN = find_TP_FP_FN(np.array(pred_component), np.array(target_component),
                                          bin_pred, bin_target)
            TP += _TP
            FP += _FP
            FN += _FN

            # Flush per image so partial results survive an interrupted run.
            test_loss_file.write(str(loss.item()) + "\n")
            test_loss_file.flush()

            # ============================ object level dice ==========================
            # Sum with a 64-bit accumulator: Python's sum() over uint8 rows
            # overflows once a column holds more than 255 foreground pixels.
            Gp += int(bin_target.sum(dtype=np.int64))
            Sq += int(bin_pred.sum(dtype=np.int64))

            # g->s: best output match for every target gland.
            gt_overlap += calc_overlap(target_component, pred_component)
            # s->g: best target match for every output gland.
            seg_overlap += calc_overlap(pred_component, target_component)

    # ============================ Final step of F1 & Jac =========================
    eps = 1e-15  # keeps the Jaccard index defined (= 1.0) when TP = FP = FN = 0
    Jac_index = float(TP + eps) / float(TP + FP + FN + eps)
    precision = _safe_div(TP, TP + FP)
    recall = _safe_div(TP, TP + FN)
    # Plain harmonic mean: the former eps form returned 2.0 when precision = recall = 0.
    F1_score = _safe_div(2 * precision * recall, precision + recall)
    print("F1 score", F1_score, "Jac_index", Jac_index)

    # =========================== Final object dice ===============================
    obj_dice = 0.5 * (_safe_div(gt_overlap, Gp) + _safe_div(seg_overlap, Sq))
    print("obj dice is:", obj_dice)

    # =========================== Saving results ===============================
    _write_metric(test_output_path + "/F1.txt", F1_score)
    _write_metric(test_output_path + "/IoU.txt", Jac_index)
    _write_metric(test_output_path + "/precision.txt", precision)
    _write_metric(test_output_path + "/recall.txt", recall)
    _write_metric(test_output_path + "/objDice.txt", obj_dice)

    return F1_score, Jac_index, precision, recall, obj_dice
Example #19
0
def main_loop(data_path,
              batch_size=batch_size,
              model_type='UNet',
              green=False,
              tensorboard=True):
    """Train ``model_type`` on the 'EX' task, validate each epoch and checkpoint it.

    :param data_path: dataset root forwarded to ``load_train_val_data``
    :param batch_size: batch size (defaults to the module-level ``batch_size``)
    :param model_type: 'UNet' (MSE loss, module-level ``learning_rate``) or
        'GCN' (``weighted_BCELoss``, lr 1e-4)
    :param green: use only the green channel (1 input channel) instead of RGB
    :param tensorboard: log to ``tensorboard_folder + session_name + '/'``
    :return: the trained model
    :raises TypeError: for an unknown ``model_type``

    A checkpoint is saved to ``model_path`` whenever the validation AUPR
    improves after epoch index 3; otherwise one is saved every
    ``save_frequency`` epochs when ``save_model`` is set.
    """
    # Load train and val data
    tasks = ['EX']
    n_labels = len(tasks)
    n_channels = 1 if green else 3  # green or RGB
    train_loader, val_loader = load_train_val_data(tasks=tasks,
                                                   data_path=data_path,
                                                   batch_size=batch_size,
                                                   green=green)

    if model_type == 'UNet':
        lr = learning_rate
        model = UNet(n_channels, n_labels)
        criterion = nn.MSELoss()
    elif model_type == 'GCN':
        lr = 1e-4
        model = GCN(n_labels, image_size[0])
        criterion = weighted_BCELoss
    else:
        raise TypeError('Please enter a valid name for the model type')

    # nn.Module losses expose _get_name(); plain loss functions only have __name__.
    try:
        loss_name = criterion._get_name()
    except AttributeError:
        loss_name = criterion.__name__

    if loss_name == 'BCEWithLogitsLoss':
        lr = 1e-4
        print('learning rate: ', lr)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              verbose=True,
                                                              patience=7)

    if tensorboard:
        log_dir = tensorboard_folder + session_name + '/'
        print('log dir: ', log_dir)
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir)
    else:
        writer = None

    max_aupr = 0.0
    try:
        for epoch in range(epochs):  # loop over the dataset multiple times
            # The loop runs `epochs` times, so that is the displayed total.
            print('******** Epoch [{}/{}]  ********'.format(epoch + 1, epochs))
            print(session_name)

            # train for one epoch
            model.train(True)
            print('Training with batch size : ', batch_size)
            train_loop(train_loader,
                       model,
                       criterion,
                       optimizer,
                       writer,
                       epoch,
                       lr_scheduler=lr_scheduler,
                       model_type=model_type)

            # evaluate on validation set
            print('Validation')
            with torch.no_grad():
                model.eval()
                val_loss, val_aupr = train_loop(val_loader, model, criterion,
                                                optimizer, writer, epoch)

            # Save best model
            if val_aupr > max_aupr and epoch > 3:
                print('\t Saving best model, mean aupr on validation set: {:.4f}'.
                      format(val_aupr))
                max_aupr = val_aupr
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'best_model': True,
                        'model': model_type,
                        'state_dict': model.state_dict(),
                        'val_loss': val_loss,
                        'loss': loss_name,
                        'optimizer': optimizer.state_dict()
                    }, model_path)

            elif save_model and (epoch + 1) % save_frequency == 0:
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'best_model': False,
                        'model': model_type,
                        'loss': loss_name,
                        'state_dict': model.state_dict(),
                        'val_loss': val_loss,
                        'optimizer': optimizer.state_dict()
                    }, model_path)
    finally:
        # Flush and release the event file even if training is interrupted.
        if writer is not None:
            writer.close()

    return model
Example #20
0
import numpy as np
from pathlib import Path
from UNet import UNet
import utils as u

# Where this model variant is stored and where the training data lives.
MODEL_PATH = Path('./models/4/UNet4')
TRAIN_DATA_PATH = Path('../Datasets/NucleusSegmentation/stage1_train')
#K_FOLDS = 5
VAL_BATCH_SIZE = 24
SEED = 0

# Construct computational graph
#models = {k:UNet() for k in range(K_FOLDS)}
#for k in models:
print('Constructing graphs...')
model = UNet()  # models[k]
# The layer-building calls below run in order; presumably each appends a layer,
# so this sequence defines the architecture (confirm against UNet).
model.inception_block(f_sizes=[1, 3],
                      f_channels=[16, 16],
                      s=1,
                      activation='relu',
                      use_shield=False)
# model.inception_block(f_sizes=[3,5], f_channels=[16,16], s=1, activation='relu', use_shield=True)
# NOTE(review): if the argument is a keep-probability, dropout(1) is a no-op — confirm.
model.dropout(1)
model.inception_block(f_sizes=[3, 5],
                      f_channels=[16, 16],
                      s=1,
                      activation='relu',
                      use_shield=True)
# model.inception_block(f_sizes=[3,5], f_channels=[16,16], s=1, activation='relu', use_shield=True)
model.dropout(1)
model.squeeze_inception_block(f_sizes=[2, 3],
Example #21
0
    # All three arguments are required: if any is missing (None), print the
    # usage text and exit with status 2.
    if None in (model_name, learning_rate, epochs):
        print_usage()
        sys.exit(2)

    return model_name, learning_rate, epochs


if __name__ == '__main__':
    model_name, learning_rate, epochs = get_arguments()

    device = select_device(force_cpu=False)

    unet = UNet(
        in_channel=1,
        out_channel=6)  # out_channel represents number of classes desired
    unet = unet.to(device)

    paths = get_paths()

    training_set = PatchDataset(
        paths['out_dir'], device, use_wmap=False
    )  # TODO: use_wmap=True when weight_maps are implemented.
    validation_set = PatchDataset(paths['val_dir'], device, use_wmap=False)

    train_UNet(model_name,
               device,
               unet,
               training_set,
               validation_set,
Example #22
0
    # Training batches are shuffled; validation batches keep dataset order.
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=config["num_workers"])
    val_loader = torch.utils.data.DataLoader(dataset_val,
                                             batch_size=config["batch_size"],
                                             shuffle=False,
                                             num_workers=config["num_workers"])

    device = config["device"]

    # define the network structure -- UNet
    # the output size is not always equal to your input size !!!
    model = UNet(img_ch=config["in_channels"], output_ch=config["n_classes"])
    # model = nn.DataParallel(model)
    model.to(device)

    # please enter the mask dir
    # mask_dir = ""
    # mask_dict = pickle.load(open(mask_dir, "rb"))
    # mask_ = torch.from_numpy(mask_dict[city]["sum"] > 0).bool()

    # Count the trainable parameters (product of each trainable tensor's dimensions).
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("# of parameters: ", params)

    trainNet(model, train_loader, val_loader, device)
Example #23
0
    set_gpu_usage()

    # Load data; each NpzFile is closed when its `with` block ends.
    with np.load(conf.train_load_path) as f:
        train_images = f['images']
        train_masks = f['masks']

    with np.load(conf.test_load_path) as f:
        test_images = f['images']
        test_image_shapes = f['shapes']

    # Scale pixel values by 1/255 (presumably 8-bit images — confirm).
    train_data = train_images / 255
    # Append a trailing channel axis to the masks.
    train_labels = np.expand_dims(train_masks, axis=-1)

    # Model
    model = UNet()
    if conf.loss == 'dice_coef_loss':
        loss_fn = dice_coef_loss
    elif conf.loss == 'binary_cross_entropy':
        # Keras registers this loss as 'binary_crossentropy'; the config value
        # 'binary_cross_entropy' is not a valid Keras identifier.
        loss_fn = 'binary_crossentropy'
    else:
        # Previously `raise()`, which only produced an unrelated TypeError.
        raise ValueError('Unsupported loss in conf.loss: {!r}'.format(conf.loss))
    model.compile(optimizer=conf.optimizer, loss=loss_fn, metrics=[mean_iou])
    model.summary()

    # Save weights every 5 epochs; separately keep the weights with the best
    # validation mean IoU.
    checkpointer = ModelCheckpoint(filepath=conf.weight_path, verbose=1, period=5, save_weights_only=True)
    best_keeper = ModelCheckpoint(filepath=conf.best_path, verbose=1, save_weights_only=True,
                                  monitor='val_mean_iou', save_best_only=True, period=1, mode='max')

    csv_logger = CSVLogger(conf.csv_path)
    tensorboard = TensorBoard(log_dir=conf.log_path)
Example #24
0
def _prepare_denoise_batch(sample, device, eps):
    """Flatten one batch and build the network input for the albedo-divided denoiser.

    ``sample['input']`` and ``sample['features']`` are unpacked as 5-D tensors
    ``(N, P, C, H, W)``; the unpacking raises ``ValueError`` for any other rank.
    Their first two axes are merged. The input is divided by ``albedo + eps``,
    where albedo is feature channels ``3:``. The features are then concatenated
    onto the input along the channel axis.

    Returns:
        (model_input, target, albedo), all moved to ``device``.
    """
    _, _, C, _, _ = sample['input'].shape
    # H and W come from the features tensor; input and target are reshaped with them.
    _, _, C_feat, H, W = sample['features'].shape
    model_input = torch.reshape(sample['input'], (-1, C, H, W)).to(device)
    features = torch.reshape(sample['features'], (-1, C_feat, H, W))
    albedo = features[:, 3:, :, :].to(device)
    model_input /= (albedo + eps)
    target = torch.reshape(sample['target'], (-1, C, H, W)).to(device)
    model_input = torch.cat((model_input, features.to(device)), dim=1)
    return model_input, target, albedo


def _validate_denoiser(model, dataloader_val, criterion, device, eps):
    """Run ``model`` over every validation batch without gradients.

    Returns:
        (mean loss, mean MSE, mean PSNR, mean SSIM, last output batch or None).
        The means are NaN (with a NumPy warning) when the loader is empty.
    """
    losses, mses, psnrs, ssims = [], [], [], []
    output_val = None
    with torch.no_grad():
        for val_sample in dataloader_val:
            model_input_val, target_val, albedo = _prepare_denoise_batch(
                val_sample, device, eps)
            output_val = model(model_input_val)
            output_val *= (albedo + eps)
            loss_val = criterion(output_val, target_val)
            psnrs.append(utils.get_PSNR(output_val, target_val))
            mses.append(utils.get_MSE(output_val, target_val))
            ssims.append(utils.get_SSIM(output_val, target_val))
            losses.append(loss_val.cpu().detach().numpy())
    return (np.mean(losses), np.mean(mses), np.mean(psnrs), np.mean(ssims),
            output_val)


def _save_denoise_grid(output, target, model_input, output_val, img_directory,
                       epoch):
    """Save a 4-row figure (target / input / train output / val output) of up to 9 images each.

    The 'output_val' row is left empty when no validation batch was produced.
    """
    panels = [
        ('target', target.data[:9]),
        ('input', model_input.data[:9, :3, :, :]),
        ('output_train', output.data[:9]),
        ('output_val', None if output_val is None else output_val.data[:9]),
    ]
    fig, ax = plt.subplots(4)
    fig.subplots_adjust(hspace=0.5)
    for axis, (title, images) in zip(ax, panels):
        axis.set_title(title)
        if images is not None:
            grid = torchvision.utils.make_grid(images)
            axis.imshow(grid.cpu().numpy().transpose((1, 2, 0)))
    plt.savefig('{}train_output_target_img_{}.png'.format(img_directory, epoch))
    plt.close()


def _save_metric_curve(values, epoch, ylabel, path, log_scale):
    """Plot ``values`` spread evenly over ``[0, epoch]`` and save the figure to ``path``."""
    plt.figure()
    x = np.linspace(0, epoch, len(values))
    if log_scale:
        plt.semilogy(x, values)
    else:
        plt.plot(x, values)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.savefig(path)
    plt.close()


def train(args, Dataset):
    """Train a UNet denoiser on ``Dataset`` while tracking L1 loss, PSNR, MSE and SSIM.

    The dataset is split at random into train/validation subsets using
    ``args.val_split``. Each sample must be a dict with keys ``'target'``,
    ``'input'`` and ``'features'``. The network sees the input divided by
    ``albedo + eps`` (albedo = feature channels ``3:``) concatenated with the
    features, and its output is multiplied back by ``albedo + eps``.

    Validation runs once per epoch, after the last training batch. Its mean
    loss drives a ``ReduceLROnPlateau`` scheduler. A sample image grid is
    written to ``args.image_directory`` every epoch, and metric curves every
    ``args.print_every`` epochs. Whenever the mean validation MSE improves, the
    model/optimizer state is saved to ``<model_save_path>best_model.pth``. On
    Ctrl-C, a ``checkpoint<epoch>.pth`` is saved and the function still
    returns normally.

    Args:
        args: namespace with lr, print_every, num_epochs, save_every,
            model_save_path, batch_size, in_ch, val_split and image_directory.
            ``model_save_path`` and ``image_directory`` are used as string
            prefixes, so they presumably end with a path separator.
        Dataset: a map-style dataset supporting ``len()`` and indexing.

    Returns:
        (train_losses, train_PSNRs, val_losses, val_PSNRs, best_val_MSE)
    """
    ####################################### Initializing Model #######################################
    step = args.lr
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device:{}".format(device))
    print_every = int(args.print_every)
    num_epochs = int(args.num_epochs)
    save_every = int(args.save_every)  # parsed (and validated) but not used below
    save_path = str(args.model_save_path)
    batch_size = int(args.batch_size)
    in_ch = int(args.in_ch)
    val_split = args.val_split
    img_directory = args.image_directory

    model = UNet(in_ch=in_ch)
    model.to(device)
    model.apply(init_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=step)
    criterion = torch.nn.L1Loss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    ######################################### Loading Data ##########################################
    dataset_total = Dataset
    dataset_size = len(dataset_total)
    indices = list(range(dataset_size))
    split = int(np.floor(val_split * dataset_size))
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    dataloader_train = torch.utils.data.DataLoader(dataset_total,
                                                   batch_size=batch_size,
                                                   sampler=train_sampler,
                                                   num_workers=8)
    dataloader_val = torch.utils.data.DataLoader(dataset_total,
                                                 batch_size=batch_size,
                                                 sampler=valid_sampler,
                                                 num_workers=2)

    print("length of train set: ", len(train_indices))
    print("length of val set: ", len(val_indices))

    best_val_MSE, best_val_PSNR, best_val_SSIM = 100.0, -1, -1

    train_PSNRs, train_losses, train_SSIMs, train_MSEs = [], [], [], []
    val_PSNRs, val_losses, val_SSIMs, val_MSEs = [], [], [], []

    # Small constant keeping the albedo division finite; built once, not per batch.
    eps = torch.tensor(1e-2).to(device)

    # Bound up front so the epoch summary, the best-model check and the
    # Ctrl-C handler never hit NameError (e.g. empty loaders or an early interrupt).
    epoch = 0
    train_loss = None
    train_MSE = train_PSNR = train_SSIM = float('nan')
    avg_val_loss = avg_val_MSE = avg_val_PSNR = avg_val_SSIM = float('nan')

    try:
        for epoch in range(1, num_epochs + 1):
            print("epoch: ", epoch)
            with tqdm(total=len(dataloader_train)) as pbar:
                for index, sample in enumerate(dataloader_train):
                    model.train()
                    model_input, target, albedo = _prepare_denoise_batch(
                        sample, device, eps)

                    output = model(model_input)
                    output *= (albedo + eps)

                    train_loss = utils.backprop(optimizer, output, target,
                                                criterion)
                    train_PSNR = utils.get_PSNR(output, target)
                    train_MSE = utils.get_MSE(output, target)
                    train_SSIM = utils.get_SSIM(output, target)

                    train_losses.append(train_loss.cpu().detach().numpy())
                    train_PSNRs.append(train_PSNR)
                    train_MSEs.append(train_MSE)
                    train_SSIMs.append(train_SSIM)

                    # Validate once per epoch, right after the last training batch.
                    if index == len(dataloader_train) - 1:
                        model.eval()
                        (avg_val_loss, avg_val_MSE, avg_val_PSNR, avg_val_SSIM,
                         output_val) = _validate_denoiser(
                             model, dataloader_val, criterion, device, eps)

                        val_PSNRs.append(avg_val_PSNR)
                        val_losses.append(avg_val_loss)
                        val_MSEs.append(avg_val_MSE)
                        val_SSIMs.append(avg_val_SSIM)
                        scheduler.step(avg_val_loss)

                        _save_denoise_grid(output, target, model_input,
                                           output_val, img_directory, epoch)

                    pbar.update(1)

            if epoch % print_every == 0:
                print(
                    "Epoch: {}, Loss: {}, Train MSE: {} Train PSNR: {}, Train SSIM: {}"
                    .format(epoch, train_loss, train_MSE, train_PSNR,
                            train_SSIM))
                print(
                    "Epoch: {}, Avg Val Loss: {}, Avg Val MSE: {}, Avg Val PSNR: {}, Avg Val SSIM: {}"
                    .format(epoch, avg_val_loss, avg_val_MSE, avg_val_PSNR,
                            avg_val_SSIM))
                # (values, y label, file stem, log-scale y axis)
                curves = (
                    (train_losses, "Loss", "train_loss", True),
                    (val_losses, "Loss", "val_loss", True),
                    (train_PSNRs, "PSNR", "train_PSNR", False),
                    (val_PSNRs, "PSNR", "val_PSNR", False),
                    (train_MSEs, "MSE", "train_MSE", True),
                    (val_MSEs, "MSE", "val_MSE", True),
                    (train_SSIMs, "SSIM", "train_SSIM", False),
                    (val_SSIMs, "SSIM", "val_SSIM", False),
                )
                for values, ylabel, stem, log_scale in curves:
                    _save_metric_curve(values, epoch, ylabel,
                                       "{}{}.png".format(img_directory, stem),
                                       log_scale)

            # NaN (no validation data) compares False, so nothing is saved then.
            if best_val_MSE > avg_val_MSE:
                best_val_MSE, best_val_PSNR, best_val_SSIM = avg_val_MSE, avg_val_PSNR, avg_val_SSIM
                print("new best Avg Val MSE: {}".format(best_val_MSE))
                print("Saving model to {}".format(save_path))
                torch.save(
                    {
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': train_loss
                    }, save_path + "best_model.pth")
                print("Saved successfully to {}".format(save_path))

    except KeyboardInterrupt:
        print("Training interupted...")
        print("Saving model to {}".format(save_path))
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': train_loss
            }, save_path + "checkpoint{}.pth".format(epoch))
        print("Saved successfully to {}".format(save_path))

    # Printed after normal completion as well as after an interrupt.
    print("Training completed.")

    print("Best MSE: %.10f, Best PSNR: %.10f, Best SSIM: %.10f" %
          (best_val_MSE, best_val_PSNR, best_val_SSIM))
    return (train_losses, train_PSNRs, val_losses, val_PSNRs, best_val_MSE)
Example #25
0
def train(train_loader,
          valid_loader,
          loss_type,
          act_type,
          tolerance,
          result_path,
          log_interval=10,
          lr=0.000001,
          max_epochs=500):
    """Train a bilinear-upsampling UNet, logging losses and sample images to ``result_path``.

    Each epoch runs a training phase followed by a validation phase.
    Per-batch losses and per-epoch average losses are written to text files
    in ``result_path``. Each log file is flushed after every line, so it can be
    followed while training runs.

    For one randomly chosen batch per phase, the following PNGs are saved:
    input, target, raw output, the output binarized at its mean, and the loss
    weights binarized at their mean.

    The weights with the lowest average validation loss go to
    ``best_model.pth``. Every 25th epoch a ``model_epoch_<epoch>.pth``
    snapshot is also written. Training stops after ``max_epochs``, or as soon
    as ``early_stopping(epoch, valid_losses, tolerance)`` returns true.

    Args:
        train_loader, valid_loader: loaders yielding dicts with keys
            'image', 'image_anno' and 'loss_weight' (the latter is divided by
            1000 before use).
        loss_type: 'wbce' (BCE weighted by loss_weight), 'bce' or 'mse';
            any other value uses ``jaccard_loss``.
        act_type: 'sigmoid', 'tanh' or 'soft' (softmax), applied to the
            network output before every loss except 'mse'.
        tolerance: forwarded to ``early_stopping``.
        result_path: existing directory for logs, images and checkpoints.
        log_interval: print progress every this many batches.
        lr: Adam learning rate (weight decay is fixed at 1e-4).
        max_epochs: upper bound on the number of epochs.

    Raises:
        ValueError: if ``act_type`` is unknown and ``loss_type`` needs an activation.
    """
    import contextlib

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # model = UNet(upsample_mode='transpose').to(device)
    model = UNet(upsample_mode='bilinear').to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    best_model_path = result_path + '/best_model.pth'
    model_path = result_path + '/model_epoch'

    # Resolve the activation once and keep everything on `device` (never
    # `.cuda()`), so the function also runs on CPU-only machines.
    activation_types = {
        'sigmoid': torch.nn.Sigmoid,
        'tanh': torch.nn.Tanh,
        'soft': torch.nn.Softmax,
    }
    if act_type in activation_types:
        activation = activation_types[act_type]().to(device)
    elif loss_type == 'mse':
        activation = None  # 'mse' is computed on the raw network output
    else:
        raise ValueError("Unknown act_type {!r} for loss_type {!r}".format(
            act_type, loss_type))

    def compute_loss(output, target, loss_weight):
        """Return the configured (mean-reduced) loss for one batch."""
        if loss_type == 'wbce':
            # Weighted BCE; the weight tensor changes from batch to batch.
            return F.binary_cross_entropy(activation(output), target,
                                          weight=loss_weight)
        if loss_type == 'bce':
            return F.binary_cross_entropy(activation(output), target)
        if loss_type == 'mse':
            return F.mse_loss(output, target)
        return jaccard_loss(activation(output), target)  # 'jac' and anything else

    def log_value(log_file, value):
        """Append one value per line and flush so the log is readable during training."""
        log_file.write(str(value) + "\n")
        log_file.flush()

    def save_samples(prefix, data, target, output, loss_weight, epoch, batch_i):
        """Save input/target/output and the mean-thresholded output and weight maps as PNGs."""
        thres = transforms.Compose(
            [Binarize_Output(threshold=output.mean())])(output)
        weight_tresh = transforms.Compose(
            [Binarize_Output(threshold=loss_weight.mean())])(loss_weight)
        images = (('input', data), ('target', target), ('output', output),
                  ('thres', thres), ('weights', weight_tresh))
        for name, image in images:
            utils.save_image(
                image, "{}/{}_{}_{}_{}.png".format(result_path, prefix, name,
                                                   epoch, batch_i))

    minimum_loss = np.inf
    train_all_epochs_loss = []
    valid_all_epochs_loss = []

    # ExitStack guarantees every log file is closed, including on early stop or error.
    with contextlib.ExitStack() as stack:

        def open_log(name):
            # "w" truncates any log left over from a previous run.
            return stack.enter_context(open(result_path + "/" + name, "w"))

        train_batch_loss_file = open_log("train_batch_loss.txt")
        valid_batch_loss_file = open_log("valid_batch_loss.txt")
        train_all_epochs_loss_file = open_log("train_all_epochs_loss.txt")
        valid_all_epochs_loss_file = open_log("valid_all_epochs_loss.txt")

        for epoch in range(1, max_epochs + 1):
            for phase in ('train', 'val'):
                is_train = phase == 'train'
                loader = train_loader if is_train else valid_loader
                # One random batch per phase gets its images saved; none if the loader is empty.
                sampled_batches = random.sample(range(len(loader)),
                                                min(1, len(loader)))
                model.train(is_train)
                batch_loss_file = (train_batch_loss_file
                                   if is_train else valid_batch_loss_file)
                all_batches_losses = []

                for batch_i, sample in enumerate(loader):
                    data = sample['image'].to(device)
                    target = sample['image_anno'].to(device)
                    loss_weight = sample['loss_weight'].to(device)

                    optimizer.zero_grad()
                    loss_weight = loss_weight / 1000

                    with torch.set_grad_enabled(is_train):
                        output = model(data)
                        loss = compute_loss(output, target, loss_weight)
                        if is_train:
                            loss.backward()
                            optimizer.step()

                    loss_value = loss.item()
                    log_value(batch_loss_file, loss_value)
                    all_batches_losses.append(loss_value)

                    if batch_i % log_interval == 0:
                        print(
                            '{} Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                                phase, epoch, batch_i * len(data),
                                len(loader.dataset), 100. * batch_i / len(loader),
                                loss_value))

                    if batch_i in sampled_batches:
                        save_samples('train' if is_train else 'valid', data,
                                     target, output, loss_weight, epoch, batch_i)
                        if is_train and epoch % 25 == 0:
                            torch.save(model.state_dict(),
                                       model_path + '_{}.pth'.format(epoch))

                avg_loss = np.mean(all_batches_losses)
                print("------average %s loss %f" % (phase, avg_loss))
                if is_train:
                    log_value(train_all_epochs_loss_file, avg_loss)
                    train_all_epochs_loss.append(avg_loss)
                else:
                    log_value(valid_all_epochs_loss_file, avg_loss)
                    valid_all_epochs_loss.append(avg_loss)
                    if avg_loss < minimum_loss:
                        minimum_loss = avg_loss
                        #--------------------- Saving the best found model -----------------------
                        torch.save(model.state_dict(), best_model_path)
                        print("Minimum Average Loss so far:", minimum_loss)
                    if early_stopping(epoch, valid_all_epochs_loss, tolerance):
                        return
Example #26
0
def train(net: UNet,
          train_ids_file_path: str,
          val_ids_file_path: str,
          in_dir_path: str,
          mask_dir_path: str,
          check_points: str,
          epochs=10,
          batch_size=4,
          learning_rate=0.1,
          device=torch.device("cpu")):
    """Train ``net`` for segmentation with BCE loss and SGD (momentum 0.99).

    The per-batch loss is logged to TensorBoard (``./tensorboard``). Every 10
    global steps, the input images, true masks and predicted masks
    (``output > 0.5``) are logged as well. After every epoch, the mean batch
    loss is logged via ``logging``, and a checkpoint ``CP_epoch<N>.pth`` is
    written into ``check_points``.

    Args:
        net: the UNet to train; it is moved to ``device``. Assumes its outputs
            are probabilities in [0, 1], as ``nn.BCELoss`` requires — TODO confirm.
        train_ids_file_path: id list passed to ``ImageSet`` for the training split.
        val_ids_file_path: accepted for interface compatibility; currently
            unused (no validation pass is run).
        in_dir_path: image directory passed to ``ImageSet``.
        mask_dir_path: mask directory passed to ``ImageSet``.
        check_points: checkpoint directory; created (with parents) if missing.
        epochs: number of training epochs.
        batch_size: DataLoader batch size.
        learning_rate: SGD learning rate.
        device: device to train on.
    """
    train_data_set = ImageSet(train_ids_file_path, in_dir_path, mask_dir_path)

    train_data_loader = DataLoader(train_data_set,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=1)

    net = net.to(device)

    loss_func = nn.BCELoss()
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=learning_rate,
                                momentum=0.99)
    writer = SummaryWriter("tensorboard")
    g_step = 0

    try:
        for epoch in range(epochs):
            net.train()
            total_loss = 0.0

            with tqdm(total=len(train_data_set),
                      desc=f'Epoch {epoch + 1}/{epochs}',
                      unit='img') as pbar:
                # Iterate the loader directly; `pbar` already tracks progress
                # (a second tqdm around the loader drew a duplicate bar).
                for imgs, masks in train_data_loader:
                    imgs = imgs.to(device)
                    masks = masks.to(device)

                    outputs = net(imgs)
                    loss = loss_func(outputs, masks)
                    total_loss += loss.item()

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # record
                    writer.add_scalar("Loss/Train", loss.item(), g_step)
                    writer.flush()
                    pbar.set_postfix(**{'loss (batch)': loss.item()})
                    pbar.update(imgs.shape[0])
                    g_step += 1

                    if g_step % 10 == 0:
                        writer.add_images('masks/origin', imgs, g_step)
                        writer.add_images('masks/true', masks, g_step)
                        writer.add_images('masks/pred', outputs > 0.5, g_step)
                        writer.flush()

            logging.info('Epoch %d mean batch loss: %.6f', epoch + 1,
                         total_loss / max(len(train_data_loader), 1))

            # makedirs also creates missing parents; real errors (e.g. permissions)
            # now surface instead of being swallowed and failing later in torch.save.
            if check_points and not os.path.isdir(check_points):
                os.makedirs(check_points, exist_ok=True)
                logging.info('Created checkpoint directory')
            torch.save(net.state_dict(),
                       os.path.join(check_points, f'CP_epoch{epoch + 1}.pth'))
            logging.info(f'Checkpoint {epoch + 1} saved !')
    finally:
        # Always release the TensorBoard writer, even if training raises.
        writer.close()
    dataiter)['mask'], next(dataiter)['dpmap']

# Log one grid of the background/foreground sample images to TensorBoard.
bgfg_grid = torchvision.utils.make_grid(bgfg)
writer.add_image('BG_FG Images', bgfg_grid)

# IPython magics stay commented out so this file remains valid Python.
# In a notebook, launch TensorBoard with e.g.:
#   %tensorboard --logdir logs/tensorboard

from torchsummary import summary
import torch

cuda = torch.cuda.is_available()
device = torch.device("cuda") if cuda else torch.device("cpu")

# Print a layer summary for a 6-channel 192x192 input and a 1-channel output.
unet_model = UNet(6, 1).to(device)
summary(unet_model, input_size=(6, 192, 192))
"""**LOSS FUNCTIONS & CO-EFFICIENT**"""

from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Build a fresh model for training rather than reusing the one passed to
# summary() (which may have run a forward pass through it).
unet_model = UNet(6, 1).to(device)

# Criteria: IoU loss for depth, BCE-with-logits for the mask.
depth_c = IoULoss()
mask_c = nn.BCEWithLogitsLoss()

optimizer = SGD(unet_model.parameters(), lr=0.01)
current_lr = optimizer.param_groups[0]['lr']
Example #28
0
import _init_paths
from utils import to_categorical_4d, to_categorical_4d_reverse
from UNet import UNet
from test import *
import tensorflow as tf
import numpy as np
import ear_pen
import math
import cv2

# Checkpoint location handed to work() (presumably provided by `from test import *`).
model_store_path = '../model/UNet/UNet.ckpt'

if __name__ == '__main__':
    # Placeholder geometry: batches of 104x78 RGB images plus a
    # single-channel int32 annotation map of the same spatial size.
    height, width = 104, 78
    img_ph = tf.placeholder(tf.float32, [None, height, width, 3])
    ann_ph = tf.placeholder(tf.int32, [None, height, width, 1])
    net = UNet(img_ph, ann_ph)
    work(img_ph, ann_ph, net, model_store_path)
Example #29
0
import numpy as np
from pathlib import Path
from UNet import UNet
import utils as u

MODEL_PATH = Path('./models/0')
TRAIN_DATA_PATH = Path('../Datasets/NucleusSegmentation/stage1_train')
VAL_BATCH_SIZE = 32
SEED = 0
# A K-fold variant (K_FOLDS = 5, one UNet per fold) is currently disabled.

# Construct computational graph
print('Constructing graphs...')
model = UNet()

# Encoder: at each width, a 3x3/stride-1 convolution then a 2x2/stride-2 squeeze convolution.
for _width in (32, 64, 128, 256):
    model.convolution(f=3, s=1, n_out=_width, activation='relu')
    model.squeeze_convolution(f=2, s=2, n_out=_width, activation='relu')

# Bottom of the "U": two 3x3 convolutions at width 512.
model.convolution(f=3, s=1, n_out=512, activation='relu')
model.convolution(f=3, s=1, n_out=512, activation='relu')

# Decoder: a 2x2/stride-2 transpose convolution, then a 3x3 convolution, per width.
for _width in (256, 128):
    model.stretch_transpose_convolution(f=2, s=2, n_out=_width, activation='relu')
    model.convolution(f=3, s=1, n_out=_width, activation='relu')
model.stretch_transpose_convolution(f=2, s=2, n_out=64, activation='relu')
del _width
Example #30
0
        torch.save(net.state_dict(), check_points + f'CP_epoch{epoch + 1}.pth')
        logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()


if __name__ == '__main__':
    in_dir = "data/kaggle/train"
    out_dir = "data/kaggle/train_masks"
    train_lst_path = "data/kaggle/train.txt"
    val_lst_path = "data/kaggle/val.txt"
    check_points = "check_points/"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(F"Using {device}")
    net = UNet(in_channels=5, classes=1)

    load_path = ""
    if load_path != "":
        net.load_state_dict(torch.load(load_path, map_location=device))

    try:
        train(net=net,
              train_ids_file_path=train_lst_path,
              val_ids_file_path=val_lst_path,
              in_dir_path=in_dir,
              mask_dir_path=out_dir,
              check_points=check_points,
              epochs=2,
              batch_size=1,
              learning_rate=0.1,