Exemplo n.º 1
0
def load_model(data, model_path, cuda=True):
    """Load a UNet checkpoint and run one forward pass on ``data``.

    Args:
        data: input tensor of shape (C, H, W) -- a batch dimension is
            added before the forward pass. (Assumed from the unsqueeze
            below -- TODO confirm against callers.)
        model_path: path to a saved ``state_dict``.
        cuda: run on GPU; raises if no GPU is available.

    Returns:
        The network output tensor (on GPU when ``cuda`` is True).
    """
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    unet = UNet()

    if cuda:
        unet = unet.cuda()

    if not cuda:
        # Remap GPU-saved storages onto the CPU.
        unet.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))
    else:
        unet.load_state_dict(torch.load(model_path))

    # Fix: switch to eval mode so dropout/batch-norm layers behave
    # deterministically during inference.
    unet.eval()

    if cuda:
        data = data.cuda()
    data = torch.unsqueeze(data, 0)  # add batch dimension

    # Fix: no_grad avoids building an autograd graph for pure inference.
    # The deprecated Variable() wrapper and the redundant trailing
    # output.cuda() (the output of a CUDA model is already on the GPU)
    # were removed.
    with torch.no_grad():
        output = unet(data)

    return output
Exemplo n.º 2
0
def see_results(n_channels, n_classes, load_weights, dir_img, dir_cmp, savedir,
                title):
    """Run a trained UNet over a dataset and save prediction/GT images.

    Args:
        n_channels, n_classes: UNet constructor arguments.
        load_weights: path to a checkpoint with a 'state_dict' entry.
        dir_img, dir_cmp: image / comparison directories for the loader.
        savedir: output directory (created if missing).
        title: filename stem for the saved images.
    """
    # Use GPU when available
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Build the model; DataParallel is applied before loading, presumably
    # because the checkpoint keys carry the 'module.' prefix -- verify.
    net = UNet(n_channels, n_classes).to(device)
    net = torch.nn.DataParallel(
        net, device_ids=list(range(torch.cuda.device_count()))).to(device)

    # Load old weights (mapped to CPU; .to(device) already placed the model)
    checkpoint = torch.load(load_weights, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])

    # Load the dataset
    loader = get_dataloader_show(dir_img, dir_cmp)

    # Create the output folder if it does not exist
    os.makedirs(savedir, exist_ok=True)

    net.eval()
    with torch.no_grad():
        for (data, gt) in loader:
            data, gt = data.to(device), gt.to(device)

            # Forward
            predictions = net(data)

            # Fix: os.path.join replaces raw string concatenation, which
            # produced broken paths whenever savedir had no trailing '/'.
            # NOTE(review): every batch writes the same two files, so only
            # the last batch survives -- confirm this is intended.
            save_image(predictions, os.path.join(savedir, title + "_pred.png"))
            save_image(gt, os.path.join(savedir, title + "_gt.png"))
Exemplo n.º 3
0
def main():
    """Train ('train' in argv) or evaluate a UNet from a saved checkpoint."""
    # NOTE(review): the dataset setup below is commented out, yet train()
    # is called with x_train/y_train/x_val/y_val/width_out/height_out --
    # those names must come from module globals or this raises NameError.
    # width_in = 284
    # height_in = 284
    # width_out = 196
    # height_out = 196
    # PATH = './unet.pt'
    # x_train, y_train, x_val, y_val = get_dataset(width_in, height_in, width_out, height_out)
    # print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)

    batch_size = 3
    epochs = 1
    epoch_lapse = 50
    threshold = 0.5
    learning_rate = 0.01

    unet = UNet(in_channel=1, out_channel=2)
    if use_gpu:
        unet = unet.cuda()

    criterion = torch.nn.CrossEntropyLoss()
    # Fix: pass the learning_rate variable instead of a duplicated literal,
    # so changing learning_rate above actually affects the optimizer.
    optimizer = torch.optim.SGD(unet.parameters(), lr=learning_rate,
                                momentum=0.99)

    if sys.argv[1] == 'train':
        train(unet, batch_size, epochs, epoch_lapse, threshold, learning_rate,
              criterion, optimizer, x_train, y_train, x_val, y_val, width_out,
              height_out)
    else:
        # Evaluation path: restore weights (remapped to CPU when no GPU).
        if use_gpu:
            unet.load_state_dict(torch.load(PATH))
        else:
            unet.load_state_dict(torch.load(PATH, map_location='cpu'))
        print(unet.eval())
Exemplo n.º 4
0
def lr_find(model: UNet,
            data_loader,
            optimizer: Optimizer,
            criterion,
            use_gpu,
            min_lr=0.0001,
            max_lr=0.1):
    """Sweep the learning rate and plot loss vs. lr (log x-axis).

    Trains through at most one increasing lr ramp of a cyclic schedule,
    recording (lr, loss) pairs, then shows the plot and restores the
    model and optimizer to their pre-sweep state.

    Args:
        model: network under test (mutated during the sweep, restored after).
        data_loader: yields (data, target, class_ids) batches.
        optimizer: optimizer whose lr is driven by the cyclic scheduler.
        criterion: loss function applied to the class-selected output.
        use_gpu: move batches and buffers to CUDA when True.
        min_lr, max_lr: bounds of the lr sweep.
    """
    import copy  # local import: only needed for snapshotting the states

    # Fix: state_dict() returns live references to the parameter tensors,
    # which optimizer.step() mutates in place -- the previous code
    # therefore "restored" already-modified weights at the end. Deep-copy
    # to capture a true snapshot.
    model_state = copy.deepcopy(model.state_dict())
    optimizer_state = copy.deepcopy(optimizer.state_dict())

    losses = []
    lrs = []
    scheduler = CyclicExpLR(optimizer,
                            min_lr,
                            max_lr,
                            step_size_up=100,
                            mode='triangular',
                            cycle_momentum=True)
    model.train()
    for i, (data, target, class_ids) in enumerate(data_loader):
        # (removed the no-op line `data, target = data, target`)
        if use_gpu:
            data = data.cuda()
            target = target.cuda()

        optimizer.zero_grad()
        output_raw = model(data)
        # Project-specific: keep a single channel per sample -- the one
        # matching that sample's class id (1-based, hence class_id - 1).
        output = torch.zeros(output_raw.shape[0], 1, output_raw.shape[2],
                             output_raw.shape[3])

        if use_gpu:
            output = output.cuda()

        for idx, (raw_o, class_id) in enumerate(zip(output_raw, class_ids)):
            output[idx] = raw_o[class_id - 1]

        loss = criterion(output, target)
        loss.backward()
        current_lr = optimizer.param_groups[0]['lr']
        # Stop once the cyclic schedule starts decreasing the lr again
        if len(lrs) > 0 and current_lr < lrs[-1]:
            break
        lrs.append(current_lr)
        losses.append(loss.item())
        optimizer.step()
        scheduler.step()

    # Plot loss against lr on a log scale
    plt.plot(lrs, losses)
    plt.xscale('log')

    plt.show()

    # Revert to the snapshot taken before the sweep
    model.load_state_dict(model_state)
    optimizer.load_state_dict(optimizer_state)
Exemplo n.º 5
0
def load_model(key: str):
    """Load the UNet checkpoint registered under ``key``.

    All four registered models share the same architecture except for the
    ``grid`` flag: True for the denoising models, False for segmentation.

    Args:
        key: one of BRAIN_DENOISING_MODEL, BRAIN_SEG_MODEL,
            ABDOM_DENOISING_MODEL, ABDOM_SEG_MODEL (each is also the
            checkpoint path passed to torch.load).

    Returns:
        The UNet with weights restored from the checkpoint.

    Raises:
        ValueError: for an unknown key. (Previously an unknown key fell
        through to ``return unet`` and raised UnboundLocalError.)
    """
    # Fix: the SEG/ABDOM branches never returned, so every known key was
    # re-checked against all remaining branches; consolidate into a table.
    grid_by_key = {
        BRAIN_DENOISING_MODEL: True,
        BRAIN_SEG_MODEL: False,
        ABDOM_DENOISING_MODEL: True,
        ABDOM_SEG_MODEL: False,
    }
    if key not in grid_by_key:
        raise ValueError("Unknown model key: {}".format(key))

    checkpoint = torch.load(key)
    unet = UNet(in_channels=1, n_classes=1, depth=3, wf=8, padding=True,
                batch_norm=False, up_mode='upconv', grid=grid_by_key[key],
                bias=True)
    unet.load_state_dict(checkpoint["model_state_dict"])
    return unet
Exemplo n.º 6
0
def test(path):
    """Count objects in the input image and save its density map.

    Args:
        path: path to the input RGB image.

    Returns:
        (header, predicted_counts, real_counts) -- counts are density-map
        sums divided by 100 (labels below are scaled by 100).
    """
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Image: HWC float in [0, 1] -> NCHW tensor with batch dim
    image = np.array(Image.open(path), dtype=np.float32) / 255
    image = torch.Tensor(np.transpose(image, (2, 0, 1))).unsqueeze(0)

    # Ground truth: binary mask of the chosen color channel scaled by 100,
    # then Gaussian-smoothed into a density map.
    header = ".".join(path.split('/')[-1].split('.')[:2])
    label_path = opt.label_path + header + '.label.png'
    label = np.array(Image.open(label_path))
    if opt.color == 'red':
        labels = 100.0 * (label[:, :, 0] > 0)
    else:
        labels = 100.0 * (label[:, :, 1] > 0)
    labels = ndimage.gaussian_filter(labels, sigma=(1, 1), order=0)
    labels = torch.Tensor(labels).unsqueeze(0)

    if opt.model.find("UNet") != -1:
        model = UNet(input_filters=3, filters=opt.unet_filters,
                     N=opt.conv).to(device)
    else:
        model = FCRN_A(input_filters=3, filters=opt.unet_filters,
                       N=opt.conv).to(device)
    model = torch.nn.DataParallel(model)

    if os.path.exists('{}.pth'.format(opt.model)):
        model.load_state_dict(torch.load('{}.pth'.format(opt.model)))

    model.eval()
    image = image.to(device)
    labels = labels.to(device)

    # Fix: pure inference -- skip autograd bookkeeping.
    with torch.no_grad():
        out = model(image)
    predicted_counts = torch.sum(out).item() / 100
    real_counts = torch.sum(labels).item() / 100
    print(predicted_counts, real_counts)

    # Fix: use height AND width (shape[2], shape[3]); the previous code
    # used shape[2] twice, which only worked for square images.
    label = np.zeros((image.shape[2], image.shape[3], 3))
    if opt.color == 'red':
        label[:, :, 0] = out[0][0].cpu().detach().numpy()
    else:
        label[:, :, 1] = out[0][0].cpu().detach().numpy()

    imageio.imwrite('example/test_results/density_map_{}.png'.format(header),
                    label)

    return header, predicted_counts, real_counts
Exemplo n.º 7
0
def load_model():
    """Build the MobileNetV2-backed UNet and restore its trained weights.

    Returns:
        The network in eval mode with checkpoint weights loaded.
    """
    weights_path = get_model()

    net = UNet(
        backbone="mobilenetv2",
        num_classes=2,
        pretrained_backbone=None,
    )

    # strict=False tolerates missing/unexpected keys in the checkpoint.
    state = torch.load(weights_path, map_location="cpu")['state_dict']
    net.load_state_dict(state, strict=False)
    net.eval()

    print("model is loaded")

    return net
Exemplo n.º 8
0
def main():
    """Restore a trained UNet checkpoint and export it to ONNX (opset 10)."""
    opt = parser.parse_args()
    print(torch.__version__)
    print(opt)

    # Per-stage block counts for encoder/decoder; channel widths double
    # each stage starting at 16.
    enc_layers = [1, 2, 2, 4]
    dec_layers = [1, 1, 1, 1]
    number_of_channels = [
        int(8 * 2**i) for i in range(1, 1 + len(enc_layers))
    ]  #[16,32,64,128]
    model = UNet(depth=len(enc_layers),
                 encoder_layers=enc_layers,
                 decoder_layers=dec_layers,
                 number_of_channels=number_of_channels,
                 number_of_outputs=3)
    # The checkpoint stores a full model object under 'model'; load on CPU.
    s = torch.load(os.path.join(opt.models_path, opt.name,
                                opt.name + 'best_model.pth'),
                   map_location='cpu')
    new_state_dict = OrderedDict()
    for k, v in s['model'].state_dict().items():
        name = k[7:]  # strip the 7-char 'module.' DataParallel prefix
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.eval()

    # Dummy 4-channel input sized from CLI options; drives ONNX tracing.
    x = torch.randn(1,
                    4,
                    opt.input_size[0],
                    opt.input_size[1],
                    opt.input_size[2],
                    requires_grad=True)

    # Register a custom ONNX symbolic so group_norm exports under opset 10.
    register_op("group_norm", group_norm_symbolic, "", 10)

    # NOTE(review): torch.onnx.export returns None in current PyTorch, so
    # torch_out is not meaningful -- confirm nothing downstream uses it.
    torch_out = torch.onnx.export(
        model,  # model being run
        [
            x,
        ],  # model input (or a tuple for multiple inputs)
        os.path.join(
            opt.models_path, opt.name, opt.name + ".onnx"
        ),  # where to save the model (can be a file or file-like object)
        export_params=True,
        verbose=
        True,  # store the trained parameter weights inside the model file
        opset_version=10)
Exemplo n.º 9
0
def test(args):
    """
    Test some data from trained UNet: predict a multi-channel mask for one
    image, colorize it with the palette from ``args.mask_json_path`` and
    save the result to ``args.test_save_path``.
    """
    image = load_test_image(args.test_image)  # 1 c w h
    net = UNet(in_channels=3, out_channels=5)
    if args.cuda:
        net = net.cuda()
        image = image.cuda()
    print('Loading model param from {}'.format(args.model_state_dict))
    net.load_state_dict(torch.load(args.model_state_dict))
    net.eval()

    print('Predicting for {}...'.format(args.test_image))
    # Fix: inference under no_grad -- no autograd graph needed.
    with torch.no_grad():
        ys_pred = net(image)  # 1 ch w h

    with open(args.mask_json_path, 'r', encoding='utf-8') as mask:
        print('Reading mask colors list from {}'.format(args.mask_json_path))
        colors = json.loads(mask.read())
        colors = [tuple(c) for c in colors]
        print('Mask colors: {}'.format(colors))

    ys_pred = ys_pred.cpu().detach().numpy()[0]
    # Threshold to a binary mask.
    # Fix: np.int was removed in NumPy >= 1.24 (AttributeError); the
    # builtin int is the documented replacement.
    ys_pred = (ys_pred >= 0.5).astype(int)
    image_w = ys_pred.shape[1]
    image_h = ys_pred.shape[2]
    out_image = np.zeros((image_w, image_h, 3))

    # Vectorized color mapping: one boolean-mask assignment per channel
    # replaces the per-pixel triple loop. Iterating channels in order
    # preserves the original behavior (later channels overwrite earlier).
    for ch in range(ys_pred.shape[0]):
        out_image[ys_pred[ch] == 1] = colors[ch]

    out_image = out_image.astype(np.uint8)  # w h c
    out_image = out_image.transpose((1, 0, 2))  # h w c
    out_image = Image.fromarray(out_image)
    out_image.save(args.test_save_path)
    print('Segmentation result has been saved to {}'.format(
        args.test_save_path))
class modelLoader():
    """Thin wrapper that builds a UNet and restores a training checkpoint."""

    def __init__(self, model_folder, in_chnl, out_chnl, checkpoint=None):
        # checkpoint: checkpoint id used to build the weight-file name
        self.checkpoint = checkpoint
        # model_folder: directory (relative to cwd) holding checkpoints
        self.model_folder = model_folder
        self.in_chnl = in_chnl
        self.out_chnl = out_chnl
        # model is populated lazily by set_model()
        self.model = None

    def set_model(self):
        """Instantiate the UNet and load the checkpoint weights in eval mode."""
        # NOTE: relies on a module-level `device` being defined elsewhere.
        self.model = UNet(self.in_chnl, self.out_chnl).to(device)
        PATH = "./{}/checkpoint_{}.pth".format(self.model_folder,
                                               self.checkpoint)
        # Fix: map_location=device lets GPU-saved checkpoints load on
        # CPU-only hosts (torch.load without it fails there).
        self.model.load_state_dict(torch.load(PATH, map_location=device))
        self.model.eval()

    def model_infos(self):
        """Print a one-line summary of the loader configuration."""
        print(
            "Checkpoint: {}, Model folder: ./{}, Input channel: {}, Output channel: {}"
            .format(self.checkpoint, self.model_folder, self.in_chnl,
                    self.out_chnl))
Exemplo n.º 11
0
def test(args):
    """Run the trained UNet over the liver test split (inference only)."""
    model = UNet(n_channels=125, n_classes=10).to(device)
    model.load_state_dict(
        torch.load(args.ckpt % args.num_epochs, map_location='cpu'))
    liver_dataset = LiverDataset('data',
                                 transform=x_transform,
                                 target_transform=y_transform,
                                 train=False)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()
    with torch.no_grad():
        for x, y in dataloaders:
            # Swap axes 1 and 2 before the forward pass -- presumably the
            # loader yields (N, H, C, W) and the model wants NCHW; verify.
            x = x.permute(0, 2, 1, 3)
            print(x.shape)
            x = x.float()
            label = y.float()
            x = x.to(device)
            label = label.to(device)
            outputs = model(x)
            # print(outputs.shape)
            # NOTE(review): outputs and label are computed but never used
            # after this point -- the snippet appears truncated (no metric
            # or save step follows).
            label = label.squeeze(1)
Exemplo n.º 12
0
File: test.py Project: krsrv/UNet
def test(n_class=1,
         in_channel=1,
         load=False,
         img_size=None,
         directory='../Data/train/'):
    """Prepare segmentation datasets and (optionally) restore a saved UNet.

    Populates the module-level ``original``, ``dataset`` and ``model``
    globals. When ``img_size`` is given, test samples are random-cropped
    to (img_size, img_size) before tensor conversion.
    """
    global original, dataset, model
    original = Segmentation(directory, 'training.json', ToTensor())
    if img_size is None:
        dataset = Segmentation(directory, 'test.json', ToTensor())
    else:
        dataset = Segmentation(
            directory, 'test.json',
            T.Compose([RandomCrop((img_size, img_size)), ToTensor()]))

    model = UNet(n_class=n_class, in_channel=in_channel)
    if load:
        filename = "unet.pth"
        # Prefer GPU storage when available; otherwise remap to CPU.
        map_location = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        try:
            checkpoint = torch.load(filename, map_location=map_location)
            model.load_state_dict(checkpoint['state_dict'])
            print("Loaded saved model")
        except Exception:
            # Fix: narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt and SystemExit. Loading stays best-effort.
            print("Unable to load saved model")
Exemplo n.º 13
0
def run_inference(args):
    """Infer a binary forest/non-forest map for one (year, region) PALSAR image.

    Runs the UNet patch-wise over the raster, stitches predictions into a
    full-size map saved as .npy, pickles (un)normalized confusion matrices,
    and prints running focal loss / accuracy per batch.
    """
    model = UNet(input_channels=3, num_classes=3)
    # strict=False tolerates key mismatches between checkpoint and model.
    model.load_state_dict(torch.load(args.model_path, map_location='cpu'),
                          strict=False)
    print('Log: Loaded pretrained {}'.format(args.model_path))
    model.eval()
    # annus/Desktop/palsar/
    test_image_path = '/home/annus/Desktop/palsar/palsar_dataset_full/palsar_dataset/palsar_{}_region_{}.tif'.format(
        args.year, args.region)
    test_label_path = '/home/annus/Desktop/palsar/palsar_dataset_full/palsar_dataset/fnf_{}_region_{}.tif'.format(
        args.year, args.region)
    inference_loader = get_inference_loader(image_path=test_image_path,
                                            label_path=test_label_path,
                                            model_input_size=128,
                                            num_classes=4,
                                            one_hot=True,
                                            batch_size=args.bs,
                                            num_workers=4)
    # we need to fill our new generated test image
    generated_map = np.empty(shape=inference_loader.dataset.get_image_size())
    # Class weights for the focal loss (third class weighted 2x).
    weights = torch.Tensor([1, 1, 2])
    focal_criterion = FocalLoss2d(weight=weights)
    un_confusion_meter = tnt.meter.ConfusionMeter(2, normalized=False)
    confusion_meter = tnt.meter.ConfusionMeter(2, normalized=True)
    total_correct, total_examples = 0, 0
    net_loss = []
    # NOTE(review): the loop runs without torch.no_grad(), so autograd
    # state is built during inference -- confirm whether that is intended.
    for idx, data in enumerate(inference_loader):
        coordinates, test_x, label = data['coordinates'].tolist(
        ), data['input'], data['label']
        out_x, softmaxed = model.forward(test_x)
        pred = torch.argmax(softmaxed, dim=1)
        not_one_hot_target = torch.argmax(label, dim=1)
        # convert to binary classes
        # 0-> noise, 1-> forest, 2-> non-forest, 3-> water
        pred[pred == 0] = 2
        pred[pred == 3] = 2
        not_one_hot_target[not_one_hot_target == 0] = 2
        not_one_hot_target[not_one_hot_target == 3] = 2
        # now convert 1, 2 to 0, 1
        pred -= 1
        not_one_hot_target -= 1
        pred_numpy = pred.numpy().transpose(1, 2, 0)
        # Stitch each patch prediction back at its source coordinates.
        for k in range(test_x.shape[0]):
            x, x_, y, y_ = coordinates[k]
            generated_map[x:x_, y:y_] = pred_numpy[:, :, k]
        loss = focal_criterion(
            softmaxed,
            not_one_hot_target)  # dice_criterion(softmaxed, label) #
        accurate = (pred == not_one_hot_target).sum().item()
        numerator = float(accurate)
        denominator = float(
            pred.view(-1).size(0))  # test_x.size(0) * dimension ** 2)
        total_correct += numerator
        total_examples += denominator
        net_loss.append(loss.item())
        un_confusion_meter.add(predicted=pred.view(-1),
                               target=not_one_hot_target.view(-1))
        confusion_meter.add(predicted=pred.view(-1),
                            target=not_one_hot_target.view(-1))
        # if idx % 5 == 0:
        accuracy = float(numerator) * 100 / denominator
        print(
            '{}, {} -> ({}/{}) output size = {}, loss = {}, accuracy = {}/{} = {:.2f}%'
            .format(args.year, args.region, idx, len(inference_loader),
                    out_x.size(), loss.item(), numerator, denominator,
                    accuracy))
        #################################
    mean_accuracy = total_correct * 100 / total_examples
    mean_loss = np.asarray(net_loss).mean()
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('log: test:: total loss = {:.5f}, total accuracy = {:.5f}%'.format(
        mean_loss, mean_accuracy))
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('---> Confusion Matrix:')
    print(confusion_meter.value())
    # class_names = ['background/clutter', 'buildings', 'trees', 'cars',
    #                'low_vegetation', 'impervious_surfaces', 'noise']
    with open('normalized.pkl', 'wb') as this:
        pkl.dump(confusion_meter.value(), this, protocol=pkl.HIGHEST_PROTOCOL)
    with open('un_normalized.pkl', 'wb') as this:
        pkl.dump(un_confusion_meter.value(),
                 this,
                 protocol=pkl.HIGHEST_PROTOCOL)

    # save_path = 'generated_maps/generated_{}_{}.npy'.format(args.year, args.region)
    save_path = '/home/annus/Desktop/palsar/generated_maps/using_separate_models/generated_{}_{}.npy'.format(
        args.year, args.region)
    np.save(save_path, generated_map)
    #########################################################################################3
    inference_loader.dataset.clear_mem()
    pass
Exemplo n.º 14
0
# Ensure the checkpoint directory exists before training starts.
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

if __name__ == '__main__':
    # model selection
    print('===> Building model')
    model = UNet(input_channels=2, image_channels=1)
    pre_model = DnCNN(image_channels=1)

    # NOTE(review): the DnCNN instance created above is immediately
    # discarded -- torch.load replaces it with a full model object from
    # disk. The constructor call appears to be dead code; confirm.
    pre_model = torch.load(os.path.join(args.load_model_dir, 'model.pth'))

    initial_epoch = findLastCheckpoint(save_dir=save_dir)# load the last model in matconvnet style
    # initial_epoch = 150
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        model.load_state_dict(torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)))
        # model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))

    # UNet trains; the pretrained DnCNN stays frozen in eval mode.
    model.train()
    pre_model.eval()
    # criterion = nn.MSELoss(reduction = 'sum')  # PyTorch 0.4.1
    # criterion = sum_squared_error()
    # criterion = nn.MSELoss()

    criterion = Intensity_loss()
    cri_chk = nn.MSELoss()

    if cuda:
        model = model.cuda()
        pre_model = pre_model.cuda()
        # device_ids = [0]
Exemplo n.º 15
0
import numpy as np
import io
import cv2
import os
from werkzeug.utils import secure_filename
import base64
from model import UNet
import json
from PIL import Image
import re

# Path to the trained weight file and the upload directory shared with
# the backend service.
MODEL_PATH = "model/0.pth"
UPLOAD_PATH = "../backend/uploads/"

# Build the network once at import time and restore CPU-mapped weights.
# NOTE(review): torch, flask, CORS and cross_origin are used in this chunk
# but not imported here -- presumably imported elsewhere in the file;
# verify before running standalone.
model = UNet()
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))

app = flask.Flask(__name__)
app.secret_key = "super secret key"
cors = CORS(app)
app.config["UPLOAD_FOLDER"] = UPLOAD_PATH
app.config['CORS_HEADERS'] = 'Content-Type'

# Running counter of processed uploads (module-level state).
ctr = 0


@app.route("/", methods=["GET", "POST"])
@cross_origin()
def upload_file():
    global ctr
    if request.method == "POST":
Exemplo n.º 16
0
def run_inference(args):
    """Generate per-district forest maps from Landsat-8 rasters.

    For each (district, year) pair, runs the UNet patch-wise over the
    raster, stitches the patch predictions into a full-size map, masks it
    with the district shapefile mask, and saves an RGB visualization
    (green = forest, red = non-forest).
    """
    model = UNet(topology=args.model_topology,
                 input_channels=len(args.bands),
                 num_classes=len(args.classes))
    # strict=False tolerates key mismatches between checkpoint and model.
    model.load_state_dict(torch.load(args.model_path, map_location='cpu'),
                          strict=False)
    print('Log: Loaded pretrained {}'.format(args.model_path))
    model.eval()
    if args.cuda:
        print('log: Using GPU')
        model.cuda(device=args.device)
    # all_districts = ["abbottabad", "battagram", "buner", "chitral", "hangu", "haripur", "karak", "kohat", "kohistan", "lower_dir", "malakand", "mansehra",
    # "nowshehra", "shangla", "swat", "tor_ghar", "upper_dir"]
    all_districts = ["abbottabad"]

    # years = [2014, 2016, 2017, 2018, 2019, 2020]
    years = [2016]
    # change this to do this for all the images in that directory
    for district in all_districts:
        for year in years:
            print("(LOG): On District: {} @ Year: {}".format(district, year))
            # test_image_path = os.path.join(args.data_path, 'landsat8_4326_30_{}_region_{}.tif'.format(year, district))
            test_image_path = os.path.join(args.data_path,
                                           'landsat8_{}_region_{}.tif'.format(
                                               year, district))  #added(nauman)
            inference_loader, adjustment_mask = get_inference_loader(
                rasterized_shapefiles_path=args.rasterized_shapefiles_path,
                district=district,
                image_path=test_image_path,
                model_input_size=128,
                bands=args.bands,
                num_classes=len(args.classes),
                batch_size=args.bs,
                num_workers=4)
            # inference_loader = get_inference_loader(rasterized_shapefiles_path=args.rasterized_shapefiles_path, district=district,
            #                                                          image_path=test_image_path, model_input_size=128, bands=args.bands,
            #                                                          num_classes=len(args.classes), batch_size=args.bs, num_workers=4)
            # we need to fill our new generated test image
            generated_map = np.empty(
                shape=inference_loader.dataset.get_image_size())
            # NOTE(review): no torch.no_grad() around this loop -- autograd
            # state is kept during inference; confirm if intended.
            for idx, data in enumerate(inference_loader):
                coordinates, test_x = data['coordinates'].tolist(
                ), data['input']
                test_x = test_x.cuda(
                    device=args.device) if args.cuda else test_x
                out_x, softmaxed = model.forward(test_x)
                pred = torch.argmax(softmaxed, dim=1)
                pred_numpy = pred.cpu().numpy().transpose(1, 2, 0)
                if idx % 5 == 0:
                    print('LOG: on {} of {}'.format(idx,
                                                    len(inference_loader)))
                # Stitch each patch prediction back at its source coordinates.
                for k in range(test_x.shape[0]):
                    x, x_, y, y_ = coordinates[k]
                    generated_map[x:x_, y:y_] = pred_numpy[:, :, k]
            # adjust the inferred map
            generated_map += 1  # to make forest pixels: 2, non-forest pixels: 1, null pixels: 0
            # Zero out pixels outside the district boundary mask.
            generated_map = np.multiply(generated_map, adjustment_mask)
            # save generated map as png image, not numpy array
            forest_map_rband = np.zeros_like(generated_map)
            forest_map_gband = np.zeros_like(generated_map)
            forest_map_bband = np.zeros_like(generated_map)
            forest_map_gband[generated_map == FOREST_LABEL] = 255
            forest_map_rband[generated_map == NON_FOREST_LABEL] = 255
            forest_map_for_visualization = np.dstack(
                [forest_map_rband, forest_map_gband,
                 forest_map_bband]).astype(np.uint8)
            save_this_map_path = os.path.join(
                args.dest, '{}_{}_inferred_map.png'.format(district, year))
            matimg.imsave(save_this_map_path, forest_map_for_visualization)
            print('Saved: {} @ {}'.format(save_this_map_path,
                                          forest_map_for_visualization.shape))
Exemplo n.º 17
0
            model = UNet(in_channels=n_channels, n_classes=n_classes)

        elif args.model == 'SegNet':
            model = SegNet(in_channels=n_channels, n_classes=n_classes)

        elif args.model == 'DenseNet':
            model = DenseNet(in_channels=n_channels, n_classes=n_classes)

        else:
            print("wrong model : must be UNet, SegNet, or DenseNet")
            raise SystemExit

        summary(model,
                input_size=(n_channels, args.height, args.width),
                device='cpu')

        model.load_state_dict(torch.load(args.model_path))

        adversarial_examples = DAG_Attack(model, test_dataset, args)

        if args.attack_path is None:

            adversarial_path = 'data/' + args.model + '_' + args.attacks + '.pickle'

        else:
            adversarial_path = args.attack_path

    # save adversarial examples([adversarial examples, labels])
    with open(adversarial_path, 'wb') as fp:
        pickle.dump(adversarial_examples, fp)
Exemplo n.º 18
0
        optimizer = optim.Adam(model.parameters(),
                               lr=opt.learning_rate,
                               weight_decay=opt.weight_decay)
        criterion = nn.BCELoss()
        vis = Visualizer(env=opt.env)

        if opt.is_cuda:
            model.cuda()
            criterion.cuda()
            if opt.n_gpu > 1:
                model = nn.DataParallel(model)

        run(model, train_loader, val_loader, criterion, vis)
    else:
        if opt.is_cuda:
            model.cuda()
            if opt.n_gpu > 1:
                model = nn.DataParallel(model)
        test_loader = get_test_loader(batch_size=20,
                                      shuffle=True,
                                      num_workers=opt.num_workers,
                                      pin_memory=opt.pin_memory)
        # load the model and run test
        model.load_state_dict(
            torch.load(
                os.path.join(opt.checkpoint_dir, 'RSNA_UNet_0.895_09210122')))

        img_ids, images, pred_masks = run_test(model, test_loader)

        save_pred_result(img_ids, images, pred_masks)
# %%
# Number of sticks per synthetic sample (dataset parameter).
n_sticks = 40


# %%
# Model
# model = UNet().cuda()
model = UNet()
model.eval()

# Restore trained weights from model/model.pth.
model_dir = Path('model')
model_path = model_dir.joinpath('model.pth')

param = torch.load(model_path)
model.load_state_dict(param)


# %%
# Dataset for inference
# test_dataset = SSSDataset(train=False, n_sticks=n_sticks, data_size=16)
test_dataset = DataLoaderInstanceSegmentation(train = False)
# test_dataloader = DataLoader(test_dataset, batch_size=16,
#                              shuffle=False, num_workers=0,
#                              pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=20,
                             shuffle=False, num_workers=0,
                             pin_memory=True)


# %%
Exemplo n.º 20
0
def main(args):
    """Validate (and nominally train) a blur-aware UNet on blurred video.

    NOTE(review): the entire training inner loop is currently disabled — it
    is kept as a triple-quoted string literal below — so each epoch only
    runs the validation pass and checkpoints on the best validation PSNR.
    Confirm whether training is meant to be re-enabled.

    Args:
        args: parsed CLI namespace; fields read here include dataset_root,
            checkpoint_dir, checkpoint, train_continue, valid, decode_mode,
            sequence_length, num_frame_blur, train_batch_size,
            test_batch_size, init_learning_rate, milestones, epochs and
            progress_iter.

    Returns:
        None.
    """
    writer = SummaryWriter(os.path.join('./logs'))
    # torch.backends.cudnn.benchmark = True
    if not os.path.isdir(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('[MODEL] CUDA DEVICE : {}'.format(device))

    # TODO DEFINE TRAIN AND TEST TRANSFORMS
    train_tf = None
    test_tf = None

    # Channel wise mean calculated on adobe240-fps training dataset
    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])
    # NOTE(review): `normalize`/`transform` are built but never used; both
    # datasets receive `train_tf`, which is None — confirm this is intended.

    test_valid = 'validation' if args.valid else 'test'
    train_data = BlurDataset(os.path.join(args.dataset_root, 'train'),
                             seq_len=args.sequence_length,
                             tau=args.num_frame_blur,
                             delta=5,
                             transform=train_tf)
    test_data = BlurDataset(os.path.join(args.dataset_root, test_valid),
                            seq_len=args.sequence_length,
                            tau=args.num_frame_blur,
                            delta=5,
                            transform=train_tf)

    train_loader = DataLoader(train_data,
                              batch_size=args.train_batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_data,
                             batch_size=args.test_batch_size,
                             shuffle=False)

    # TODO IMPORT YOUR CUSTOM MODEL
    model = UNet(3, 3, device, decode_mode=args.decode_mode)

    # Optional warm start: accept either a raw state_dict file or a
    # checkpoint dict wrapping it under 'state_dict'.
    if args.checkpoint:
        store_dict = torch.load(args.checkpoint)
        try:
            model.load_state_dict(store_dict['state_dict'])
        except KeyError:
            model.load_state_dict(store_dict)

    if args.train_continue:
        store_dict = torch.load(args.checkpoint)
        model.load_state_dict(store_dict['state_dict'])

    else:
        store_dict = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}

    model.to(device)
    model.train(True)

    # model = nn.DataParallel(model)

    # TODO DEFINE MORE CRITERIA
    # input(True if device == torch.device('cuda:0') else False)
    criterion = {
        'MSE': nn.MSELoss(),
        'L1': nn.L1Loss(),
        # 'Perceptual': PerceptualLoss(model='net-lin', net='vgg', dataparallel=True,
        #                              use_gpu=True if device == torch.device('cuda:0') else False)
    }

    # Per-loss weights; 'Perceptual' weight is kept for when that loss is
    # re-enabled above.
    criterion_w = {'MSE': 1.0, 'L1': 10.0, 'Perceptual': 10.0}

    # Define optimizers
    # optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9,weight_decay=5e-4)
    optimizer = optim.Adam(model.parameters(), lr=args.init_learning_rate)

    # Define lr scheduler
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.milestones,
                                               gamma=0.1)

    # best_acc = 0.0
    # start = time.time()
    cLoss = store_dict['loss']
    valLoss = store_dict['valLoss']
    valPSNR = store_dict['valPSNR']
    checkpoint_counter = 0

    loss_tracker = {}
    loss_tracker_test = {}

    # Best-so-far validation metrics; used to keep only the single best
    # checkpoint on disk.
    psnr_old = 0.0
    dssim_old = 0.0

    # NOTE(review): this runs (10 * args.epochs) - 1 epochs, not args.epochs
    # — confirm the 10x multiplier is intentional.
    for epoch in range(1, 10 *
                       args.epochs):  # loop over the dataset multiple times

        # Append and reset
        cLoss.append([])
        valLoss.append([])
        valPSNR.append([])
        running_loss = 0

        # Increment scheduler count
        # NOTE(review): scheduler.step() before any optimizer.step() in the
        # epoch triggers a PyTorch ordering warning (moot while training is
        # disabled, but worth fixing if it is re-enabled).
        scheduler.step()

        tqdm_loader = tqdm(range(len(train_loader)), ncols=150)

        loss = 0.0
        psnr_ = 0.0
        dssim_ = 0.0

        loss_tracker = {}
        for loss_fn in criterion.keys():
            loss_tracker[loss_fn] = 0.0

        # Train
        model.train(True)
        # Start slightly above zero so the train-side averages logged below
        # never divide by zero while the training loop is disabled.
        total_steps = 0.01
        total_steps_test = 0.01
        # The whole training loop below is DISABLED: it is a bare string
        # literal, not executed code.
        '''for train_idx, data in enumerate(train_loader, 1):
            loss = 0.0
            blur_data, sharpe_data = data
            #import pdb; pdb.set_trace()
            # input(sharpe_data.shape)
            #import pdb; pdb.set_trace()
            interp_idx = int(math.ceil((args.num_frame_blur/2) - 0.49))
            #input(interp_idx)
            if args.decode_mode == 'interp':
                sharpe_data = sharpe_data[:, :, 1::2, :, :]
            elif args.decode_mode == 'deblur':
                sharpe_data = sharpe_data[:, :, 0::2, :, :]
            else:
                #print('\nBoth\n')
                sharpe_data = sharpe_data

            #print(sharpe_data.shape)
            #input(blur_data.shape)
            blur_data = blur_data.to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            try:
                sharpe_data = sharpe_data.squeeze().to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            except:
                sharpe_data = sharpe_data.squeeze(3).to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)

            # clear gradient
            optimizer.zero_grad()

            # forward pass
            sharpe_out = model(blur_data)
            # import pdb; pdb.set_trace()
            # input(sharpe_out.shape)

            # compute losses
            # import pdb;
            # pdb.set_trace()
            sharpe_out = sharpe_out.permute(0, 2, 1, 3, 4)
            B, C, S, Fx, Fy = sharpe_out.shape
            for loss_fn in criterion.keys():
                loss_tmp = 0.0

                if loss_fn == 'Perceptual':
                    for bidx in range(B):
                        loss_tmp += criterion_w[loss_fn] * \
                                   criterion[loss_fn](sharpe_out[bidx].permute(1, 0, 2, 3),
                                                      sharpe_data[bidx].permute(1, 0, 2, 3)).sum()
                    # loss_tmp /= B
                else:
                    loss_tmp = criterion_w[loss_fn] * \
                               criterion[loss_fn](sharpe_out, sharpe_data)


                # try:
                # import pdb; pdb.set_trace()
                loss += loss_tmp # if
                # except :
                try:
                    loss_tracker[loss_fn] += loss_tmp.item()
                except KeyError:
                    loss_tracker[loss_fn] = loss_tmp.item()

            # Backpropagate
            loss.backward()
            optimizer.step()

            # statistics
            # import pdb; pdb.set_trace()
            sharpe_out = sharpe_out.detach().cpu().numpy()
            sharpe_data = sharpe_data.cpu().numpy()
            for sidx in range(S):
                for bidx in range(B):
                    psnr_ += psnr(sharpe_out[bidx, :, sidx, :, :], sharpe_data[bidx, :, sidx, :, :]) #, peak=1.0)
                    """dssim_ += dssim(np.moveaxis(sharpe_out[bidx, :, sidx, :, :], 0, 2),
                                    np.moveaxis(sharpe_data[bidx, :, sidx, :, :], 0, 2)
                                    )"""

            """sharpe_out = sharpe_out.reshape(-1,3, sx, sy).detach().cpu().numpy()
            sharpe_data = sharpe_data.reshape(-1, 3, sx, sy).cpu().numpy()
            for idx in range(sharpe_out.shape[0]):
                # import pdb; pdb.set_trace()
                psnr_ += psnr(sharpe_data[idx], sharpe_out[idx])
                dssim_ += dssim(np.swapaxes(sharpe_data[idx], 2, 0), np.swapaxes(sharpe_out[idx], 2, 0))"""

            # psnr_ /= sharpe_out.shape[0]
            # dssim_ /= sharpe_out.shape[0]
            running_loss += loss.item()
            loss_str = ''
            total_steps += B*S
            for key in loss_tracker.keys():
               loss_str += ' {0} : {1:6.4f} '.format(key, 1.0*loss_tracker[key] / total_steps)

            # set display info
            if train_idx % 5 == 0:
                tqdm_loader.set_description(('\r[Training] [Ep {0:6d}] loss: {1:6.4f} PSNR: {2:6.4f} SSIM: {3:6.4f} '.format
                                    (epoch, running_loss / total_steps,
                                     psnr_ / total_steps,
                                     dssim_ / total_steps) + loss_str
                                    ))

                tqdm_loader.update(5)
        tqdm_loader.close()'''

        # Validation
        running_loss_test = 0.0
        psnr_test = 0.0
        dssim_test = 0.0
        # print('len', len(test_loader))
        tqdm_loader_test = tqdm(range(len(test_loader)), ncols=150)
        # import pdb; pdb.set_trace()

        loss_tracker_test = {}
        for loss_fn in criterion.keys():
            loss_tracker_test[loss_fn] = 0.0

        with torch.no_grad():
            model.eval()
            # NOTE(review): reset to exactly 0.0 here (unlike the 0.01 above);
            # the progress logging below divides by this BEFORE it is first
            # incremented, so an early progress_iter hit raises
            # ZeroDivisionError — confirm and guard if needed.
            total_steps_test = 0.0

            for test_idx, data in enumerate(test_loader, 1):
                loss = 0.0
                blur_data, sharpe_data = data
                interp_idx = int(math.ceil((args.num_frame_blur / 2) - 0.49))
                # input(interp_idx)
                # Select which sharp frames are the supervision target:
                # odd indices for interpolation, even for deblurring,
                # everything otherwise.
                if args.decode_mode == 'interp':
                    sharpe_data = sharpe_data[:, :, 1::2, :, :]
                elif args.decode_mode == 'deblur':
                    sharpe_data = sharpe_data[:, :, 0::2, :, :]
                else:
                    # print('\nBoth\n')
                    sharpe_data = sharpe_data

                # print(sharpe_data.shape)
                # input(blur_data.shape)
                # Crop one spatial dim to 352 and swap the last two axes.
                blur_data = blur_data.to(device)[:, :, :, :352, :].permute(
                    0, 1, 2, 4, 3)
                # NOTE(review): bare except — narrows squeeze behavior by
                # trial; prefer `except RuntimeError`/explicit shape check.
                try:
                    sharpe_data = sharpe_data.squeeze().to(
                        device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
                except:
                    sharpe_data = sharpe_data.squeeze(3).to(
                        device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)

                # clear gradient
                optimizer.zero_grad()

                # forward pass
                sharpe_out = model(blur_data)
                # import pdb; pdb.set_trace()
                # input(sharpe_out.shape)

                # compute losses
                sharpe_out = sharpe_out.permute(0, 2, 1, 3, 4)
                B, C, S, Fx, Fy = sharpe_out.shape
                for loss_fn in criterion.keys():
                    loss_tmp = 0.0
                    if loss_fn == 'Perceptual':
                        # Perceptual loss is applied per batch element over
                        # the sequence dimension.
                        for bidx in range(B):
                            loss_tmp += criterion_w[loss_fn] * \
                                        criterion[loss_fn](sharpe_out[bidx].permute(1, 0, 2, 3),
                                                           sharpe_data[bidx].permute(1, 0, 2, 3)).sum()
                        # loss_tmp /= B
                    else:
                        loss_tmp = criterion_w[loss_fn] * \
                                   criterion[loss_fn](sharpe_out, sharpe_data)
                    loss += loss_tmp
                    try:
                        loss_tracker_test[loss_fn] += loss_tmp.item()
                    except KeyError:
                        loss_tracker_test[loss_fn] = loss_tmp.item()

                # Periodic tensorboard logging.
                if ((test_idx % args.progress_iter) == args.progress_iter - 1):
                    itr = test_idx + epoch * len(test_loader)
                    # itr_train
                    writer.add_scalars(
                        'Loss', {
                            'trainLoss': running_loss / total_steps,
                            'validationLoss':
                            running_loss_test / total_steps_test
                        }, itr)
                    writer.add_scalar('Train PSNR', psnr_ / total_steps, itr)
                    writer.add_scalar('Test PSNR',
                                      psnr_test / total_steps_test, itr)
                    # import pdb; pdb.set_trace()
                    # writer.add_image('Validation', sharpe_out.permute(0, 2, 3, 1), itr)

                # statistics
                sharpe_out = sharpe_out.detach().cpu().numpy()
                sharpe_data = sharpe_data.cpu().numpy()
                # Accumulate PSNR/DSSIM per (batch, sequence) frame pair.
                for sidx in range(S):
                    for bidx in range(B):
                        psnr_test += psnr(
                            sharpe_out[bidx, :, sidx, :, :],
                            sharpe_data[bidx, :, sidx, :, :])  #, peak=1.0)
                        dssim_test += dssim(
                            np.moveaxis(sharpe_out[bidx, :, sidx, :, :], 0, 2),
                            np.moveaxis(sharpe_data[bidx, :, sidx, :, :], 0,
                                        2))  #,range=1.0  )

                running_loss_test += loss.item()
                total_steps_test += B * S
                loss_str = ''
                for key in loss_tracker.keys():
                    loss_str += ' {0} : {1:6.4f} '.format(
                        key, 1.0 * loss_tracker_test[key] / total_steps_test)

                # set display info

                tqdm_loader_test.set_description((
                    '\r[Test    ] [Ep {0:6d}] loss: {1:6.4f} PSNR: {2:6.4f} SSIM: {3:6.4f} '
                    .format(epoch, running_loss_test / total_steps_test,
                            psnr_test / total_steps_test,
                            dssim_test / total_steps_test) + loss_str))
                tqdm_loader_test.update(1)
            tqdm_loader_test.close()

        # save model
        # Keep only the single best checkpoint: if validation PSNR improved,
        # delete the previous best file and save the new one.
        if psnr_old < (psnr_test / total_steps_test):
            if epoch != 1:
                os.remove(
                    os.path.join(
                        args.checkpoint_dir,
                        'epoch-{}-test-psnr-{}-ssim-{}.ckpt'.format(
                            epoch_old,
                            str(round(psnr_old, 4)).replace('.', 'pt'),
                            str(round(dssim_old, 4)).replace('.', 'pt'))))
            epoch_old = epoch
            psnr_old = psnr_test / total_steps_test
            dssim_old = dssim_test / total_steps_test

            checkpoint_dict = {
                'epoch': epoch_old,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'train_psnr': psnr_ / total_steps,
                'train_dssim': dssim_ / total_steps,
                'train_mse': loss_tracker['MSE'] / total_steps,
                'train_l1': loss_tracker['L1'] / total_steps,
                # 'train_percp': loss_tracker['Perceptual'] / total_steps,
                'test_psnr': psnr_old,
                'test_dssim': dssim_old,
                'test_mse': loss_tracker_test['MSE'] / total_steps_test,
                'test_l1': loss_tracker_test['L1'] / total_steps_test,
                # 'test_percp': loss_tracker_test['Perceptual'] / total_steps_test,
            }

            torch.save(
                checkpoint_dict,
                os.path.join(
                    args.checkpoint_dir,
                    'epoch-{}-test-psnr-{}-ssim-{}.ckpt'.format(
                        epoch_old,
                        str(round(psnr_old, 4)).replace('.', 'pt'),
                        str(round(dssim_old, 4)).replace('.', 'pt'))))

        # if epoch % args.checkpoint_epoch == 0:
        #    torch.save(model.state_dict(),args.checkpoint_dir + str(int(epoch/100))+".ckpt")

    return None
Exemplo n.º 21
0
class Solver:
    """Training / validation / inference driver for two tasks selected by
    ``config.task``: ``'cls'`` (binary classification) and ``'seg'`` (UNet
    segmentation).

    All hyper-parameters (mode, batch sizes, epochs, logging cadence,
    tensorboard switch, ...) come from the ``config`` object.
    """

    def __init__(self,
                 config,
                 train_loader=None,
                 val_loader=None,
                 test_loader=None):
        """Store config, pick the device, keep loaders and build the model.

        Args:
            config: configuration namespace (task, mode, n_gpus, lr, ...).
            train_loader / val_loader: used when ``config.mode`` is
                'train' or 'test'.
            test_loader: used for any other mode.
        """
        self.cfg = config
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.n_gpus = self.cfg.n_gpus

        # NOTE(review): mode 'test' keeps the train/val loaders here —
        # confirm that is intended rather than using test_loader.
        if self.cfg.mode in ['train', 'test']:
            self.train_loader = train_loader
            self.val_loader = val_loader
        else:
            self.test_loader = test_loader

        # Build model (also creates criterion and optimizer).
        self.build_model()
        if self.cfg.resume:
            self.load_pre_model()
        else:
            self.start_epoch = 0

        # Trigger Tensorboard logger; degrade gracefully when tensorboardX
        # is not installed by turning the feature off.
        if self.cfg.use_tensorboard:
            try:
                from tensorboardX import SummaryWriter
                self.writer = SummaryWriter()
            except ImportError:
                print(
                    '=> There is no module named tensorboardX, tensorboard disabled'
                )
                self.cfg.use_tensorboard = False

    def train_val(self):
        """Run the training loop, validating every ``cfg.val_step`` epochs."""
        # Build record objs
        self.build_recorder()

        # Ceiling division: count a partial final batch as one iteration.
        iter_per_epoch = len(
            self.train_loader.dataset) // self.cfg.train_batch_size
        if len(self.train_loader.dataset) % self.cfg.train_batch_size != 0:
            iter_per_epoch += 1

        for epoch in range(self.start_epoch,
                           self.start_epoch + self.cfg.n_epochs):

            self.model.train()

            # Reset per-epoch meters.
            self.train_time.reset()
            self.train_loss.reset()
            self.train_cls_acc.reset()
            self.train_pix_acc.reset()
            self.train_mIoU.reset()

            for i, (image, label) in enumerate(self.train_loader):
                start_time = time.time()
                image_var = image.to(self.device)
                label_var = label.to(self.device)

                output = self.model(image_var)
                loss = self.criterion(output, label_var)

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                end_time = time.time()

                self.train_time.update(end_time - start_time)
                self.train_loss.update(loss.item())

                if self.cfg.task == 'cls':
                    # Record classification accuracy
                    cls_acc = cal_acc(output, label_var)

                    # Update recorder
                    self.train_cls_acc.update(cls_acc.item())

                    if (i + 1) % self.cfg.log_step == 0:
                        print(
                            'Epoch[{0}][{1}/{2}]\t'
                            'Time {train_time.val:.3f} ({train_time.avg:.3f})\t'
                            'Loss {train_loss.val:.4f} ({train_loss.avg:.4f})\t'
                            'Accuracy {train_cls_acc.val:.4f} ({train_cls_acc.avg:.4f})'
                            .format(epoch + 1,
                                    i + 1,
                                    iter_per_epoch,
                                    train_time=self.train_time,
                                    train_loss=self.train_loss,
                                    train_cls_acc=self.train_cls_acc))

                    if self.cfg.use_tensorboard:
                        self.writer.add_scalar('train/loss', loss.item(),
                                               epoch * iter_per_epoch + i)
                        self.writer.add_scalar('train/accuracy',
                                               cls_acc.item(),
                                               epoch * iter_per_epoch + i)

                elif self.cfg.task == 'seg':
                    # Record mIoU and pixel-wise accuracy
                    pix_acc = cal_pixel_acc(output, label_var)
                    mIoU = cal_mIoU(output, label_var)[-1]
                    mIoU = torch.mean(mIoU)

                    # Update recorders
                    self.train_pix_acc.update(pix_acc.item())
                    self.train_mIoU.update(mIoU.item())

                    if (i + 1) % self.cfg.log_step == 0:
                        print(
                            'Epoch[{0}][{1}/{2}]\t'
                            'Time {train_time.val:.3f} ({train_time.avg:.3f})\t'
                            'Loss {train_loss.val:.4f} ({train_loss.avg:.4f})\t'
                            'Pixel-Acc {train_pix_acc.val:.4f} ({train_pix_acc.avg:.4f})\t'
                            'mIoU {train_mIoU.val:.4f} ({train_mIoU.avg:.4f})'.
                            format(epoch + 1,
                                   i + 1,
                                   iter_per_epoch,
                                   train_time=self.train_time,
                                   train_loss=self.train_loss,
                                   train_pix_acc=self.train_pix_acc,
                                   train_mIoU=self.train_mIoU))

                    if self.cfg.use_tensorboard:
                        self.writer.add_scalar('train/loss', loss.item(),
                                               epoch * iter_per_epoch + i)
                        self.writer.add_scalar('train/pix_acc', pix_acc.item(),
                                               epoch * iter_per_epoch + i)
                        self.writer.add_scalar('train/mIoU', mIoU.item(),
                                               epoch * iter_per_epoch + i)

                #FIXME currently test validation code
                #if (i + 1) % 100 == 0:
            if (epoch + 1) % self.cfg.val_step == 0:
                self.validate(epoch)

        # Close logging — self.writer only exists when tensorboard is
        # enabled (see __init__), so guard the close to avoid AttributeError.
        if self.cfg.use_tensorboard:
            self.writer.close()

    def validate(self, epoch):
        """ Validate with validation dataset """
        self.model.eval()

        self.val_time.reset()
        self.val_loss.reset()
        self.val_cls_acc.reset()
        self.val_mIoU.reset()
        self.val_pix_acc.reset()

        iter_per_epoch = len(
            self.val_loader.dataset) // self.cfg.val_batch_size
        if len(self.val_loader.dataset) % self.cfg.val_batch_size != 0:
            iter_per_epoch += 1

        for i, (image, label) in enumerate(self.val_loader):

            start_time = time.time()
            image_var = image.to(self.device)
            label_var = label.to(self.device)

            output = self.model(image_var)
            loss = self.criterion(output, label_var)

            end_time = time.time()

            self.val_time.update(end_time - start_time)
            self.val_loss.update(loss.item())

            if self.cfg.task == 'cls':
                # Record classification accuracy
                cls_acc = cal_acc(output, label_var)

                # Update recorder
                self.val_cls_acc.update(cls_acc.item())

                if (i + 1) % self.cfg.log_step == 0:
                    print(
                        'Epoch[{0}][{1}/{2}]\t'
                        'Time {val_time.val:.3f} ({val_time.avg:.3f})\t'
                        'Loss {val_loss.val:.4f} ({val_loss.avg:.4f})\t'
                        'Accuracy {val_cls_acc.val:.4f} ({val_cls_acc.avg:.4f})'
                        .format(epoch + 1,
                                i + 1,
                                iter_per_epoch,
                                val_time=self.val_time,
                                val_loss=self.val_loss,
                                val_cls_acc=self.val_cls_acc))

                if self.cfg.use_tensorboard:
                    self.writer.add_scalar('val/loss', loss.item(),
                                           epoch * iter_per_epoch + i)
                    self.writer.add_scalar('val/accuracy', cls_acc.item(),
                                           epoch * iter_per_epoch + i)

            elif self.cfg.task == 'seg':
                # Record mIoU and pixel-wise accuracy
                pix_acc = cal_pixel_acc(output, label_var)
                mIoU = cal_mIoU(output, label_var)[-1]
                mIoU = torch.mean(mIoU)

                # Update recorders
                self.val_pix_acc.update(pix_acc.item())
                self.val_mIoU.update(mIoU.item())

                if (i + 1) % self.cfg.log_step == 0:
                    print(
                        ' ##### Validation\t'
                        'Epoch[{0}][{1}/{2}]\t'
                        'Time {val_time.val:.3f} ({val_time.avg:.3f})\t'
                        'Loss {val_loss.val:.4f} ({val_loss.avg:.4f})\t'
                        'Pixel-Acc {val_pix_acc.val:.4f} ({val_pix_acc.avg:.4f})\t'
                        'mIoU {val_mIoU.val:.4f} ({val_mIoU.avg:.4f})'.format(
                            epoch + 1,
                            i + 1,
                            iter_per_epoch,
                            val_time=self.val_time,
                            val_loss=self.val_loss,
                            val_pix_acc=self.val_pix_acc,
                            val_mIoU=self.val_mIoU))

                if self.cfg.use_tensorboard:
                    self.writer.add_scalar('val/loss', loss.item(),
                                           epoch * iter_per_epoch + i)
                    self.writer.add_scalar('val/pix_acc', pix_acc.item(),
                                           epoch * iter_per_epoch + i)
                    self.writer.add_scalar('val/mIoU', mIoU.item(),
                                           epoch * iter_per_epoch + i)

        if self.cfg.task == 'cls':
            # Keep the checkpoint only when validation accuracy improved.
            if (epoch + 1) % self.cfg.model_save_epoch == 0:
                state = {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optim': self.optim.state_dict()
                }
                if self.best_cls < self.val_cls_acc.avg:
                    self.best_cls = self.val_cls_acc.avg
                    torch.save(
                        state, './model/cls_model_' + str(epoch + 1) + '_' +
                        str(self.val_cls_acc.avg)[0:5] + '.pth')

        elif self.cfg.task == 'seg':
            # Hard predictions from the LAST validation batch; computed up
            # front because both the sample dump and the tensorboard image
            # log below need it (previously `pred` was only defined inside
            # the sample branch, leaving a latent NameError in the
            # tensorboard branch).
            pred = torch.argmax(output, dim=1)

            # Save segmentation samples and model
            if (epoch + 1) % self.cfg.sample_save_epoch == 0:
                save_image(image, './sample/ori_' + str(epoch + 1) + '.png')
                save_image(label.unsqueeze(1),
                           './sample/true_' + str(epoch + 1) + '.png')
                save_image(pred.cpu().unsqueeze(1),
                           './sample/pred_' + str(epoch + 1) + '.png')

            if (epoch + 1) % self.cfg.model_save_epoch == 0:
                state = {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optim': self.optim.state_dict()
                }
                if self.best_seg < self.val_pix_acc.avg:
                    self.best_seg = self.val_pix_acc.avg
                    torch.save(
                        state, './model/seg_model_' + str(epoch + 1) + '_' +
                        str(self.val_pix_acc.avg)[0:5] + '.pth')

            if self.cfg.use_tensorboard:
                image = make_grid(image)
                label = make_grid(label.unsqueeze(1))
                # Fixed typo: `.unqueeze(1)` -> `.unsqueeze(1)` (the former
                # raises AttributeError at runtime).
                pred = make_grid(pred.cpu().unsqueeze(1))
                self.writer.add_image('Origianl', image, epoch + 1)
                self.writer.add_image('Labels', label, epoch + 1)
                self.writer.add_image('Predictions', pred, epoch + 1)

    def build_model(self):
        """ Rough """
        # Model choice is driven by the task; both tasks share the same
        # criterion/optimizer setup.
        if self.cfg.task == 'cls':
            self.model = BinaryClassifier(num_classes=2)
        elif self.cfg.task == 'seg':
            self.model = UNet(num_classes=2)
        self.criterion = nn.CrossEntropyLoss()
        self.optim = torch.optim.Adam(self.model.parameters(),
                                      lr=self.cfg.lr,
                                      betas=(self.cfg.beta0, self.cfg.beta1))
        if self.n_gpus > 1:
            print('### {} of gpus are used!!!'.format(self.n_gpus))
            self.model = nn.DataParallel(self.model)

        self.model = self.model.to(self.device)

    def build_recorder(self):
        """Create AverageMeter objects for train/val time, loss and metrics."""
        # Train recorder
        self.train_time = AverageMeter()
        self.train_loss = AverageMeter()

        # For classification
        self.train_cls_acc = AverageMeter()
        # For segmentation
        self.train_mIoU = AverageMeter()
        self.train_pix_acc = AverageMeter()

        # Validation recorder
        self.val_time = AverageMeter()
        self.val_loss = AverageMeter()

        # For classification
        self.val_cls_acc = AverageMeter()
        # For segmentation
        self.val_mIoU = AverageMeter()
        self.val_pix_acc = AverageMeter()

        # self.logger = Logger('./logs')
        # Best validation scores seen so far (checkpoint gating).
        self.best_cls = 0
        self.best_seg = 0

    def load_pre_model(self):
        """ Load pretrained model """
        print('=> loading checkpoint {}'.format(self.cfg.pre_model))
        checkpoint = torch.load(self.cfg.pre_model)
        self.start_epoch = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optim.load_state_dict(checkpoint['optim'])
        print('=> loaded checkpoint {}(epoch {})'.format(
            self.cfg.pre_model, self.start_epoch))

    #TODO:Inference part:
    def infer(self, data):
        """
        input
            @data: iterable 256 x 256 patches
        output
            @output : segmentation results from each patch
                    i) If classifier's result is that there is a tissue inside of patch, outcome is a masked result.
                    ii) Otherwise, output is segmentated mask which all of pixels are background
        """
        # Data Loading

        # Load models of classification and segmetation and freeze them
        self.freeze()

        # Forward images to Classification model / Select targeted images

        # Forward images to Segmentation model

        # Record Loss / Accuracy / Pixel-Accuracy

        # Print samples out..

    def freeze(self):
        """Placeholder: will freeze the classification/segmentation models."""
        pass
        print('{}, {} have frozen!!!'.format('model_name_1', 'model_name_2'))
Exemplo n.º 22
0
# Test-time inference over the KneeDataset test split with a UNet.
base_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

test_dataset = KneeDataset('./data/test_knee/', transform=base_transform)

# shuffle=False keeps the dataset order so predictions map back to files.
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False)

# Load model
net = UNet().to(device)

# NOTE(review): the optimizer is restored too, but no training happens below
# — it appears to exist only so optimizer_state_dict can be loaded; confirm.
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)

ckpt = torch.load('best_unet.pt')
net.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])

with torch.no_grad():
    net.eval()
    # NOTE(review): `i` is never incremented and `pred` is overwritten each
    # iteration without being saved — the loop currently discards its
    # results; confirm whether a save/accumulate step is missing.
    i = 0
    for img, mask in test_data_loader:
        inputs = img.to(device)
        labels = mask.to(device)
        outputs = net(inputs)

        inputs = inputs.squeeze()
        labels = labels.squeeze()
        outputs = outputs.squeeze()

        # Binarize at 0.5: logical_not(out < 0.5) is equivalent to out >= 0.5.
        pred = np.logical_not(outputs.cpu() < 0.5)
Exemplo n.º 23
0
from model import UNet

# Restore the SAR checkpoint onto CPU and run it over ./testimages.
path = "./checkpoints/checkpoint_SAR.pt"
model_test = UNet()
model_test.load_state_dict(torch.load(path,map_location=torch.device('cpu')))

# NOTE(review): model_test.eval() is never called and inference runs outside
# torch.no_grad(); if lintolog/logtolin do not detach, plt.imsave on a
# grad-tracking tensor will fail — confirm those helpers' behavior.
for filename in os.listdir('./testimages'):
    if not filename.startswith('.'):  # skip hidden files like .DS_Store
        # "LA" = luminance + alpha; only luminance is used further down.
        image = Image.open('./testimages/' + filename).convert("LA")
        # NOTE(review): mean/std are computed but never used — dead code?
        mean, std = np.mean(image), np.std(image)
        x=9  # crop-size exponent: 2**9 = 512 pixels
        data_transforms = transforms.Compose([
                                             transforms.RandomResizedCrop(2**x),
                                             transforms.ToTensor(),
                                             transforms.Normalize(0.5, 0.5)
        ])
        img = data_transforms(image)[0]  # keep only the luminance channel
        img = torch.reshape(img, (1, 1, 2**x, 2**x))  # NCHW for the model
        # Run the model in log space, convert back to linear for output.
        image_log = lintolog(img)
        img_out = logtolin(model_test(image_log))
        plt.imsave('./output/'+filename, img_out[0][0], cmap='gray')
        # Side-by-side comparison: model output (left) vs input crop (right).
        fig2 = plt.figure(figsize = (10,10)) # 10 x 10 inch figure
        ax2 = fig2.add_subplot(121)
        ax2.imshow(img_out[0][0], interpolation='none', cmap='gray')
        ax2 = fig2.add_subplot(122)
        ax2.imshow(img[0][0], interpolation='none', cmap='gray')
        ax2.set_title(filename)
class Instructor:
    """Training, evaluation and inference driver for the UNet model.

    Built from ``opt`` (an argparse-style namespace). Two entry points:

    * ``run``       -- train with Adam + BCE loss, validate every epoch,
                       checkpoint on best dice score, then plot and report.
    * ``inference`` -- segment every test image and save predicted masks.
    """

    def __init__(self, opt):
        self.opt = opt
        if opt.inference:
            # Inference only needs the unlabeled test images.
            self.testset = TestImageDataset(fdir=opt.impaths['test'],
                                            imsize=opt.imsize)
        else:
            self.trainset = ImageDataset(fdir=opt.impaths['train'],
                                         bdir=opt.impaths['btrain'],
                                         imsize=opt.imsize,
                                         mode='train',
                                         aug_prob=opt.aug_prob,
                                         prefetch=opt.prefetch)
            self.valset = ImageDataset(fdir=opt.impaths['val'],
                                       bdir=opt.impaths['bval'],
                                       imsize=opt.imsize,
                                       mode='val',
                                       aug_prob=opt.aug_prob,
                                       prefetch=opt.prefetch)
        self.model = UNet(n_channels=3,
                          n_classes=1,
                          bilinear=self.opt.use_bilinear)
        if opt.checkpoint:
            # map_location lets a GPU-saved checkpoint load on any device.
            self.model.load_state_dict(
                torch.load('./state_dict/{:s}'.format(opt.checkpoint),
                           map_location=self.opt.device))
            print('checkpoint {:s} has been loaded'.format(opt.checkpoint))
        if opt.multi_gpu == 'on':
            self.model = torch.nn.DataParallel(self.model)
        self.model = self.model.to(opt.device)
        self._print_args()

    def _print_args(self):
        """Count (non-)trainable parameters and print all run arguments."""
        n_trainable_params, n_nontrainable_params = 0, 0
        for p in self.model.parameters():
            # p.numel() returns a plain int; the original accumulated 0-d
            # tensors via torch.prod, so the report printed "tensor(N)".
            n_params = p.numel()
            if p.requires_grad:
                n_trainable_params += n_params
            else:
                n_nontrainable_params += n_params
        self.info = 'n_trainable_params: {0}, n_nontrainable_params: {1}\n'.format(
            n_trainable_params, n_nontrainable_params)
        self.info += 'training arguments:\n' + '\n'.join([
            '>>> {0}: {1}'.format(arg, getattr(self.opt, arg))
            for arg in vars(self.opt)
        ])
        if self.opt.device.type == 'cuda':
            # BUGFIX: the original referenced the global ``opt`` here instead
            # of ``self.opt`` and raised NameError when no such global existed.
            print('cuda memory allocated:',
                  torch.cuda.memory_allocated(self.opt.device.index))
        print(self.info)

    def _reset_records(self):
        """Start a fresh bookkeeping dict for one training run."""
        self.records = {
            'best_epoch': 0,
            'best_dice': 0,
            'train_loss': list(),
            'val_loss': list(),
            'val_dice': list(),
            'checkpoints': list()
        }

    def _update_records(self, epoch, train_loss, val_loss, val_dice):
        """Append per-epoch metrics; checkpoint when dice improves."""
        if val_dice > self.records['best_dice']:
            # Temp name carries the dice score plus a timestamp suffix so
            # successive bests never collide; _draw_records renames the last.
            path = './state_dict/{:s}_dice{:.4f}_temp{:s}.pt'.format(
                self.opt.model_name, val_dice,
                str(time.time())[-6:])
            if self.opt.multi_gpu == 'on':
                # Unwrap DataParallel so the saved keys have no "module." prefix.
                torch.save(self.model.module.state_dict(), path)
            else:
                torch.save(self.model.state_dict(), path)
            self.records['best_epoch'] = epoch
            self.records['best_dice'] = val_dice
            self.records['checkpoints'].append(path)
        self.records['train_loss'].append(train_loss)
        self.records['val_loss'].append(val_loss)
        self.records['val_dice'].append(val_dice)

    def _draw_records(self):
        """Summarise the run: keep best checkpoint, plot curves, write a log."""
        timestamp = str(int(time.time()))
        print('best epoch: {:d}'.format(self.records['best_epoch']))
        print('best train loss: {:.4f}, best val loss: {:.4f}'.format(
            min(self.records['train_loss']), min(self.records['val_loss'])))
        print('best val dice {:.4f}'.format(self.records['best_dice']))
        # Promote the latest (= best) temp checkpoint to a permanent name and
        # delete the superseded ones.
        os.rename(
            self.records['checkpoints'][-1],
            './state_dict/{:s}_dice{:.4f}_save{:s}.pt'.format(
                self.opt.model_name, self.records['best_dice'], timestamp))
        for path in self.records['checkpoints'][0:-1]:
            os.remove(path)
        # Draw figures
        plt.figure()
        trainloss, = plt.plot(self.records['train_loss'])
        valloss, = plt.plot(self.records['val_loss'])
        plt.legend([trainloss, valloss], ['train', 'val'], loc='upper right')
        plt.title('{:s} loss curve'.format(timestamp))
        plt.savefig('./figs/{:s}_loss.png'.format(timestamp),
                    format='png',
                    transparent=True,
                    dpi=300)
        plt.figure()
        valdice, = plt.plot(self.records['val_dice'])
        plt.title('{:s} dice curve'.format(timestamp))
        plt.savefig('./figs/{:s}_dice.png'.format(timestamp),
                    format='png',
                    transparent=True,
                    dpi=300)
        # Save report
        report = '\t'.join(
            ['val_dice', 'train_loss', 'val_loss', 'best_epoch', 'timestamp'])
        report += "\n{:.4f}\t{:.4f}\t{:.4f}\t{:d}\t{:s}\n{:s}".format(
            self.records['best_dice'], min(self.records['train_loss']),
            min(self.records['val_loss']), self.records['best_epoch'],
            timestamp, self.info)
        with open('./logs/{:s}_log.txt'.format(timestamp), 'w') as f:
            f.write(report)
        print('report saved:', './logs/{:s}_log.txt'.format(timestamp))

    def _train(self, train_dataloader, criterion, optimizer):
        """Run one training epoch; return the mean per-sample loss."""
        self.model.train()
        train_loss, n_total, n_batch = 0, 0, len(train_dataloader)
        for i_batch, sample_batched in enumerate(train_dataloader):
            inputs, target = sample_batched[0].to(
                self.opt.device), sample_batched[1].to(self.opt.device)
            predict = self.model(inputs)

            optimizer.zero_grad()
            loss = criterion(predict, target)
            loss.backward()
            optimizer.step()

            # BUGFIX: the original weighted by len(sample_batched), which is
            # the size of the (inputs, target) pair (always 2), not the batch
            # size — so the return value was a per-batch, not per-sample, mean.
            batch_size = inputs.size(0)
            train_loss += loss.item() * batch_size
            n_total += batch_size

            # Simple in-place text progress bar.
            ratio = int((i_batch + 1) * 50 / n_batch)
            sys.stdout.write("\r[" + ">" * ratio + " " * (50 - ratio) +
                             "] {}/{} {:.2f}%".format(i_batch + 1, n_batch,
                                                      (i_batch + 1) * 100 /
                                                      n_batch))
            sys.stdout.flush()
        print()
        return train_loss / n_total

    def _evaluation(self, val_dataloader, criterion):
        """Evaluate on the validation set; return (mean loss, mean dice)."""
        self.model.eval()
        val_loss, val_dice, n_total = 0, 0, 0
        with torch.no_grad():
            for sample_batched in val_dataloader:
                inputs, target = sample_batched[0].to(
                    self.opt.device), sample_batched[1].to(self.opt.device)
                predict = self.model(inputs)
                loss = criterion(predict, target)
                dice = dice_coeff(predict, target)
                # BUGFIX: weight by the batch size, not by
                # len(sample_batched) (always 2) — see _train.
                batch_size = inputs.size(0)
                val_loss += loss.item() * batch_size
                val_dice += dice.item() * batch_size
                n_total += batch_size
        return val_loss / n_total, val_dice / n_total

    def run(self):
        """Full training loop over opt.num_epoch epochs with validation."""
        _params = filter(lambda p: p.requires_grad, self.model.parameters())
        optimizer = torch.optim.Adam(_params,
                                     lr=self.opt.lr,
                                     weight_decay=self.opt.l2reg)
        criterion = BCELoss2d()
        train_dataloader = DataLoader(dataset=self.trainset,
                                      batch_size=self.opt.batch_size,
                                      shuffle=True)
        val_dataloader = DataLoader(dataset=self.valset,
                                    batch_size=self.opt.batch_size,
                                    shuffle=False)
        self._reset_records()
        for epoch in range(self.opt.num_epoch):
            train_loss = self._train(train_dataloader, criterion, optimizer)
            val_loss, val_dice = self._evaluation(val_dataloader, criterion)
            self._update_records(epoch, train_loss, val_loss, val_dice)
            print(
                '{:d}/{:d} > train loss: {:.4f}, val loss: {:.4f}, val dice: {:.4f}'
                .format(epoch + 1, self.opt.num_epoch, train_loss, val_loss,
                        val_dice))
        self._draw_records()

    def inference(self):
        """Predict a mask for every test image and save it via the dataset."""
        # BUGFIX: make sure dropout/batch-norm are in inference mode; the
        # original never called eval() before predicting.
        self.model.eval()
        test_dataloader = DataLoader(dataset=self.testset,
                                     batch_size=1,
                                     shuffle=False)
        n_batch = len(test_dataloader)
        with torch.no_grad():
            for i_batch, sample_batched in enumerate(test_dataloader):
                index, inputs = sample_batched[0], sample_batched[1].to(
                    self.opt.device)
                predict = self.model(inputs)
                self.testset.save_img(index.item(), predict, self.opt.use_crf)
                ratio = int((i_batch + 1) * 50 / n_batch)
                sys.stdout.write(
                    "\r[" + ">" * ratio + " " * (50 - ratio) +
                    "] {}/{} {:.2f}%".format(i_batch + 1, n_batch,
                                             (i_batch + 1) * 100 / n_batch))
                sys.stdout.flush()
        print()
Exemplo n.º 25
0
                      default=False,
                      help='use cuda')
    parser.add_option('-l',
                      '--load',
                      dest='load',
                      default=False,
                      help='load file model')

    (options, args) = parser.parse_args()
    return options


if __name__ == '__main__':
    args = get_args()

    # BUGFIX: the original called .cuda() unconditionally when constructing
    # the network, crashing on CPU-only machines even with the GPU flag off,
    # and loaded the checkpoint without a map_location. Mirror the CPU/GPU
    # handling the sibling example in this file already uses: build on the
    # CPU, load with the right map_location, move to the GPU only if asked.
    net = UNet(n_classes=args.n_classes)

    if args.load:
        if args.gpu:
            net.load_state_dict(torch.load(args.load))
        else:
            net.load_state_dict(torch.load(args.load, map_location='cpu'))
        print('Model loaded from %s' % (args.load))

    if args.gpu:
        net.cuda()
        cudnn.benchmark = True  # autotune conv kernels for fixed input sizes

    train_net(net=net,
              epochs=args.epochs,
              n_classes=args.n_classes,
              gpu=args.gpu,
              data_dir=args.data_dir)
Exemplo n.º 26
0
                      '--load',
                      dest='load',
                      default=False,
                      help='load file model')

    (options, args) = parser.parse_args()
    return options


if __name__ == '__main__':
    args = get_args()

    # The network is built on the CPU first; it moves to the GPU further
    # down only when --gpu is enabled.
    net = UNet(n_classes=args.n_classes)

    if args.load:
        # Choose the checkpoint mapping that matches the target device:
        # None keeps torch.load's default (original device), 'cpu' remaps.
        location = None if args.gpu else 'cpu'
        state = torch.load(args.load, map_location=location)
        net.load_state_dict(state)
        print('Model loaded from %s' % (args.load))

    if args.gpu:
        net.cuda()
        cudnn.benchmark = True  # autotune conv kernels for fixed input sizes

    train_net(net=net,
              epochs=args.epochs,
              n_classes=args.n_classes,
              gpu=args.gpu,
              data_dir=args.data_dir)
Exemplo n.º 27
0
# Copy pretrained VGG feature weights and batch-norm statistics into the
# matching layers of the UNet encoder's "down4" double-conv block.
# NOTE(review): the layer-index mapping (conv.4 <- features.38, etc.) is
# assumed from the key names — confirm against the UNet and VGG definitions.
dct['down4.mpconv.1.conv.4.bias'].data.copy_(dctvgg['features.38.bias'])
dct['down4.mpconv.1.conv.4.running_mean'].data.copy_(
    dctvgg['features.38.running_mean'])  #
dct['down4.mpconv.1.conv.4.running_var'].data.copy_(
    dctvgg['features.38.running_var'])

# Second conv of the block: conv weight/bias (conv.6), then its batch-norm
# affine parameters and running statistics (conv.7).
dct['down4.mpconv.1.conv.6.weight'].data.copy_(dctvgg['features.40.weight'])
dct['down4.mpconv.1.conv.6.bias'].data.copy_(dctvgg['features.40.bias'])
dct['down4.mpconv.1.conv.7.weight'].data.copy_(dctvgg['features.41.weight'])  #
dct['down4.mpconv.1.conv.7.bias'].data.copy_(dctvgg['features.41.bias'])
dct['down4.mpconv.1.conv.7.running_mean'].data.copy_(
    dctvgg['features.41.running_mean'])  #
dct['down4.mpconv.1.conv.7.running_var'].data.copy_(
    dctvgg['features.41.running_var'])

# Load the patched state dict back into the model.
model.load_state_dict(dct)

writer = SummaryWriter()  # TensorBoard event writer


def compute_overlaps_masks(masks1, masks2):
    """Computes IoU overlaps between two sets of masks.
        masks1, masks2: [Height, Width, instances]
        """

    # If either set of masks is empty return empty result
    if masks1.shape[-1] == 0 or masks2.shape[-1] == 0:
        return np.zeros((masks1.shape[-1], masks2.shape[-1]))
    # flatten masks and compute their areas
    masks1 = np.reshape(masks1 > .5, (-1, masks1.shape[-1])).astype(np.float32)
    masks2 = np.reshape(masks2 > .5, (-1, masks2.shape[-1])).astype(np.float32)
Exemplo n.º 28
0
if __name__ == "__main__":
    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)

    # 3 input channels, 1 output channel (binary mask).
    net = UNet(3, 1)

    # Wrapped in DataParallel before loading — presumably the checkpoint
    # keys carry the "module." prefix; verify against how it was saved.
    net = torch.nn.DataParallel(net)

    print("Loading model {}".format(args.model))

    if not args.cpu:
        print("Using CUDA version of the net, prepare your GPU !")
        net.cuda()
        net.load_state_dict(torch.load(args.model))
    else:
        net.cpu()
        # map_location='cpu' remaps a GPU-saved checkpoint onto the CPU.
        net.load_state_dict(torch.load(args.model, map_location='cpu'))
        print("Using CPU version of the net, this may be very slow")

    print("Model loaded !")

    for i, fn in enumerate(in_files):
        print("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)
        # Only warns on portrait-oriented images; processing continues.
        if img.size[0] < img.size[1]:
            print("Error: image height larger than the width")

        mask = predict_img(net=net,
Exemplo n.º 29
0
    # Alternative architectures tried during development (kept for reference):
    # model = NestedUNet(n_channels=1, n_classes=20)
    # model = SCSE_UNet(n_channels=1, n_classes=20)
    # model = SCSENestedUNet(n_channels=1, n_classes=18)
    # model = DilatedUNet(in_channels=1, classes=20)
    # model = ResUNet(in_channel=1, n_classes=5)

    # Earlier ResUNet-based stage-1 whole-line model (superseded below):
    # stage1_model_whole = ResUNet(in_channel=1, n_classes=2) #
    # root_dir = "E:\CSI2019\AASCE_stage4_Process/"
    # stage1_model_whole_load_dir = "saved_model/stage1_resunet_whole_line/CP200.pth"
    # stage1_model_whole.load_state_dict(torch.load(os.path.join(root_dir, stage1_model_whole_load_dir)))
    # logger.info('Stage1_Model loaded from {}'.format(stage1_model_whole_load_dir))

    # Stage 1a: UNet predicting the whole spine line (2 classes).
    stage1_model_whole = UNet(n_channels=1, n_classes=2)  #
    root_dir = "./"
    stage1_model_whole_load_dir = "saved_model/stage1_unet_whole_line/CP324.pth"
    stage1_model_whole.load_state_dict(
        torch.load(os.path.join(root_dir, stage1_model_whole_load_dir)))
    logger.info(
        'Stage1_Model loaded from {}'.format(stage1_model_whole_load_dir))

    # Stage 1b: ResUNet for segment-line prediction (2 classes).
    stage1_model_segm = ResUNet(in_channel=1, n_classes=2)  #
    stage1_model_segm_load_dir = "saved_model/stage1_resunet_segm_line/CP1127.pth"
    stage1_model_segm.load_state_dict(
        torch.load(os.path.join(root_dir, stage1_model_segm_load_dir)))
    logger.info(
        'Stage1_Model loaded from {}'.format(stage1_model_segm_load_dir))

    # Stage 2: UNet predicting 18 box-landmark channels.
    stage2_model_box = UNet(n_channels=1, n_classes=18)  #
    stage2_model_box_load_dir = "saved_model/stage2__unet_box/Bestmodel_568.pth"
    stage2_model_box.load_state_dict(
        torch.load(os.path.join(root_dir, stage2_model_box_load_dir)))
    logger.info(
Exemplo n.º 30
0
args = parser.parse_args()

criterionMSE = nn.MSELoss() #.to(device)


transform = transforms.Compose(transform_list)

# File listing the test-sample basenames, one per line.
# NOTE(review): this handle is never closed and is consumed on the first
# epoch iteration — fine only because the epoch range below has length 1.
img_dir = open('/n/holyscratch01/wadduwage_lab/uom_bme/dataset_static_2020/20200105_synthBeads_1/test.txt','r')

avg_mse = 0
avg_psnr = 0

# Evaluate a single checkpoint (epoch 55).
for epochs in range(55,56):
    my_model = 'ckpt/train_deep_tfm_loss_mse/fcn_deep_' + str(epochs) + '.pth'
    netG = UNet(n_classes=args.output_nc)
    netG.load_state_dict(torch.load(my_model))
    netG.eval()  # inference mode for the generator
    p = 0
    f_path = '/n/holyscratch01/wadduwage_lab/uom_bme/dataset_static_2020/20200105_synthBeads_1/tr_data_1sls/'    
    for line in img_dir:
        print(line)
        # line[0:-1] strips the trailing newline from the basename.
        GT_ = io.imread(f_path + str(line[0:-1]) + '_gt.png')
        # Stack the 32 modality images into a (32, 128, 128) volume.
        modalities = np.zeros((32,128,128))
        for i in range(0,32):
             modalities[i,:,:] = io.imread(f_path + str(line[0:-1]) +'_'+str(i+1) +'.png')  
        depth = modalities.shape[2]
        predicted_im = np.zeros((128,128,1))
        # Flag degenerate (constant-valued) ground-truth images.
        if np.min(np.array(GT_))==np.max(np.array(GT_)):
             print('Yes')
        # Normalise by dataset maxima; max_gt/max_im are defined elsewhere
        # in this file — presumably global dataset constants, verify.
        GT = torch.from_numpy(np.divide(GT_,max_gt))
        # [None, :, :] prepends a batch dimension: (1, 32, 128, 128).
        img = torch.from_numpy(np.divide(modalities,max_im)[None, :, :]).float()