Example #1
    def set_model(self, model=None, load_path=None, use_most_recent=True):
        """

        :param model: Can be directly a pytorch Module. Otherwise, will use the default model (Bionet) EDIT: Changed to UNet
        :param load_path: Path indicating from where to load a pretrained set of weights. Can point to a folder containing
        multiples files (see argument below) or directly to a .pth file
        If None, the model is initialized accordingly to the default policy
        :param use_most_recent: If load_path points to a folder and not a file, the loading function will choose the most
        recent file in the folder (usually, the last saved during the training and therefore 'maybe' the best).
        :return:
        """
        if model is None:
            model = UNet(checkpoint=self.config.hp.save_point,
                         config=self.config.model)
        if load_path is not None:
            model.load(load_path, load_most_recent=use_most_recent)

        self.model = model.to(self.device)
        self.setup_optims()

        if self.use_apex:
            self.model, self.optim = self.amp.initialize(self.model,
                                                         self.optim,
                                                         opt_level="O1",
                                                         loss_scale="dynamic")

        if self.multi_gpu:
            self.model = MyDataParallel(self.model,
                                        device_ids=self.config.extern.gpu)
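
A minimal usage sketch of the method above, assuming a hypothetical Trainer class that owns it and a checkpoint directory produced by an earlier run (both names are illustrative, not part of the example):

trainer = Trainer(config)
# default policy: build a fresh UNet from the config and move it to the device
trainer.set_model()
# resume from a folder of .pth files, picking the most recently saved one
trainer.set_model(load_path='checkpoints/run_01/', use_most_recent=True)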
Example #2
def get_unet(pretrain_unet_path):
    unet = UNet(3, depth=5, in_channels=3)
    print(unet)
    print('load unet with depth = 5 and downsampling will be performed 4 times!!')
    unet.load_state_dict(weight_to_cpu(pretrain_unet_path))
    print('load pretrained unet!')
    return unet
Example #3
def main():
    model = UNet().to(device)
    criterion = BinaryDiceLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_dataset = ImageDataset(split='train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=train_batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True)
    valid_dataset = ImageDataset(split='valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=valid_batch_size,
                                               num_workers=num_workers,
                                               shuffle=False,
                                               pin_memory=True)

    for epoch in range(epochs):
        train(epoch, model, train_loader, criterion, optimizer)
        valid(epoch, model, valid_loader, criterion)

    save_checkpoint(
        {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        },
        save_file='pretrained/unet.pth.tar',
        is_best=False)
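
The train and valid helpers called in the loop above are not shown; below is a minimal sketch of what train could look like, assuming the loader yields (image, mask) batches and that BinaryDiceLoss expects probabilities (both are assumptions, not part of the example):

def train(epoch, model, loader, criterion, optimizer):
    # one optimization pass over the training set
    model.train()
    running_loss = 0.0
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        probs = torch.sigmoid(model(images))  # probabilities for the Dice loss
        loss = criterion(probs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('epoch {}: mean train loss {:.4f}'.format(epoch, running_loss / len(loader)))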
Example #4
    def __init__(self, prefix, epoch, data_dir, saved_path):
        self.prefix = prefix
        self.data = data_dir
        self.epoch = epoch
        self.batch_size = 32
        self.power = 2
        self.saved_path = saved_path
        self.dataloader = self.get_dataloader()
        self.config = self.load_config()
        self.unet = UNet(3, depth=self.config['u_depth'], in_channels=3)

        print(self.unet)
        print(
            'load unet with depth %d and downsampling will be performed %d times!!'
            % (self.config['u_depth'], self.config['u_depth'] - 1))
        self.unet.load_state_dict(
            weight_to_cpu('%s/epoch_%s/g.pkl' % (self.prefix, self.epoch)))
        print('load pretrained unet')

        self.d = get_discriminator(self.config['gan_type'],
                                   self.config['d_depth'],
                                   self.config['dowmsampling'])
        self.d.load_state_dict(
            weight_to_cpu('%s/epoch_%s/d.pkl' % (self.prefix, self.epoch)))
        print('load pretrained d')
Example #5
    def get_unet(self):
        unet = UNet(3, depth=self.u_depth, in_channels=3)
        print(unet)
        print(
            'load unet with depth %d and downsampling will be performed %d times!!'
            % (self.u_depth, self.u_depth - 1))
        if self.is_pretrained_unet:
            unet.load_state_dict(weight_to_cpu(self.pretrain_unet_path))
            print('load pretrained unet')
        return unet
Example #6
class BaseMonoDepthEstimator(nn.Module):
    def __init__(self, inchans=3,nframes=3):
        super(BaseMonoDepthEstimator, self).__init__()
        # 3 parameters for translation and 3 for rotation
        self.inchans = inchans
        self.ego_vector_dim = 6
        self.nframes = nframes
        self.ego_prediction_size = self.ego_vector_dim * (self.nframes -1)

        # Depth network - encoder-decoder combo
        self.unet = UNet(3,1)
        self.final_conv_layer0 = nn.Conv2d(1, 1, 1, stride=1, padding=0)
        self.final_conv_layer1 = nn.Conv2d(1, 1, 1, stride=1, padding=0)

        # 2) an ego motion network - use the encoder from (1)
        # and append extra cnns
        self.ego_motion_cnn = ConvolutionStack(self.unet.feature_channels()*nframes)
        self.ego_motion_cnn.append(128,3,2)
        self.ego_motion_cnn.append(64,3,2)
        self.ego_motion_cnn.append(self.ego_prediction_size,1,1)


    def forward(self, x):       
        # for each frame, run through the depth and ego motion networks
        # assume input is BxKxCxWxH
        assert len(x.shape)==5, 'Input must be BxKxCxWxH!'
        assert x.shape[1]==self.nframes, 'Input sequence length must match nframes!, expected {}, found {}'.format(self.nframes, x.shape[1])

        batch_size = x.shape[0]
        unstacked_frames = torch.chunk(x,self.nframes,1)
        unstacked_frames = [torch.squeeze(y,1) for y in unstacked_frames]
        encoded_features = []
        depth_maps = []

        # first predict depth in all frames
        for frame in unstacked_frames:
            depth_map, features = self.unet.forward(frame)
            encoded_features.append(features)
            # depth_map = self.final_conv_layer0(depth_map)
            # depth_map = self.final_conv_layer1(depth_map)
            depth_map = torch.sigmoid(depth_map.squeeze(1))
            depth_maps.append(depth_map)

        # next predict ego_motion, stack feature maps
        stacked_features = torch.cat(encoded_features,1)
        ego_motion_features = self.ego_motion_cnn(stacked_features)
        # reduce mean along the spatial dimensions
        ego_vectors = torch.mean(ego_motion_features,[2,3]).reshape(batch_size,self.nframes-1,-1)
        depth_maps = torch.stack(depth_maps,1)

        # frames, ego_vectors, and depth_maps
        return x, ego_vectors, depth_maps
Example #7
def model_builder():
    classifier = resnet18(is_ptrtrained=False)
    print('use resnet18')
    auto_encoder = UNet(3, depth=5, in_channels=3)
    auto_encoder.load_state_dict(weight_to_cpu(args.pretrain_unet))
    print('load pretrained unet!')

    model = Locator(aer=auto_encoder, classifier=classifier)
    if args.cuda:
        model = DataParallel(model).cuda()
    else:
        raise ValueError('there is no gpu')

    return model
Example #8
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gpuid = config['gpuid']
        self.unet = UNet(config)
        self.keypoint = Keypoint(config)
        self.softmax_matcher = SoftmaxMatcher(config)
        self.svd = SVD(config)
Example #9
    def __init__(self, inchans=3,nframes=3):
        super(BaseMonoDepthEstimator, self).__init__()
        # 3 parameters for translation and 3 for rotation
        self.inchans = inchans
        self.ego_vector_dim = 6
        self.nframes = nframes
        self.ego_prediction_size = self.ego_vector_dim * (self.nframes -1)

        # Depth network - encoder-decoder combo
        self.unet = UNet(3,1)
        self.final_conv_layer0 = nn.Conv2d(1, 1, 1, stride=1, padding=0)
        self.final_conv_layer1 = nn.Conv2d(1, 1, 1, stride=1, padding=0)

        # 2) an ego motion network - use the encoder from (1)
        # and append extra cnns
        self.ego_motion_cnn = ConvolutionStack(self.unet.feature_channels()*nframes)
        self.ego_motion_cnn.append(128,3,2)
        self.ego_motion_cnn.append(64,3,2)
        self.ego_motion_cnn.append(self.ego_prediction_size,1,1)
Example #10
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gpuid = config['gpuid']
        self.unet = UNet(config)
        self.keypoint = Keypoint(config)
        self.softmax_matcher = SoftmaxRefMatcher(config)
        self.solver = SteamSolver(config)
        self.patch_size = config['networks']['keypoint_block']['patch_size']
        self.patch_mean_thres = config['steam']['patch_mean_thres']
Example #11
    def __init__(self, device, num_steps, z_dimension=8):

        # in and out channels for the generator:
        a, b = 2, 3

        G = Generator(a, b) if not USE_UNET else UNet(a, b)
        E = ResNetEncoder(b, z_dimension)

        # conditional discriminators
        D1 = MultiScaleDiscriminator(a + b - 1)
        D2 = MultiScaleDiscriminator(a + b - 1)

        def weights_init(m):
            if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
                init.xavier_normal_(m.weight, gain=0.02)
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif isinstance(m, nn.InstanceNorm2d) and m.affine:
                init.ones_(m.weight)
                init.zeros_(m.bias)

        self.G = G.apply(weights_init).to(device)
        self.E = E.apply(weights_init).to(device)
        self.D1 = D1.apply(weights_init).to(device)
        self.D2 = D2.apply(weights_init).to(device)

        params = {
            'lr': 4e-4,
            'betas': (0.5, 0.999),
            'weight_decay': 1e-8
        }
        generator_groups = [
            {'params': [p for n, p in self.G.named_parameters() if 'mapping' not in n]},
            {'params': self.G.mapping.parameters(), 'lr': 4e-5}
        ]
        self.optimizer = {
            'G': optim.Adam(generator_groups, **params),
            'E': optim.Adam(self.E.parameters(), **params),
            'D1': optim.Adam(self.D1.parameters(), **params),
            'D2': optim.Adam(self.D2.parameters(), **params)
        }

        def lambda_rule(i):
            decay = num_steps // 2
            m = 1.0 if i < decay else 1.0 - (i - decay) / decay
            return max(m, 0.0)

        self.schedulers = []
        for o in self.optimizer.values():
            self.schedulers.append(LambdaLR(o, lr_lambda=lambda_rule))

        self.gan_loss = LSGAN()
        self.z_dimension = z_dimension
        self.device = device
Example #12
def main():
    global args, logger
    args = parser.parse_args()
    # logger = Logger(add_prefix(args.prefix, 'logs'))
    set_prefix(args.prefix, __file__)
    model = UNet(3, depth=5, in_channels=3)
    print(model)
    print('load unet with depth=5')
    if args.cuda:
        model = DataParallel(model).cuda()
    else:
        raise RuntimeError('there is no gpu')
    criterion = nn.L1Loss(reduction='none').cuda()
    print('use l1_loss')
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # accelerate the speed of training
    cudnn.benchmark = True

    data_loader = get_dataloader()
    # class_names=['LESION', 'NORMAL']
    # class_names = data_loader.dataset.class_names
    # print(class_names)

    since = time.time()
    print('-' * 10)
    for epoch in range(1, args.epochs + 1):
        train(data_loader, model, optimizer, criterion, epoch)
        if epoch % 40 == 0:
            validate(model, epoch, data_loader)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    validate(model, args.epochs, data_loader)
    # save model parameter
    torch.save(model.state_dict(),
               add_prefix(args.prefix, 'identical_mapping.pkl'))
    # save running parameter setting to json
    write(vars(args), add_prefix(args.prefix, 'paras.txt'))
Example #13
def net_factory(net_type="unet", in_chns=1, class_num=3):
    if net_type == "unet":
        net = UNet(in_chns=in_chns, class_num=class_num).cuda()
    elif net_type == "unet_ds":
        net = UNet_DS(in_chns=in_chns, class_num=class_num).cuda()
    elif net_type == "efficient_unet":
        net = Effi_UNet('efficientnet-b3',
                        encoder_weights='imagenet',
                        in_channels=in_chns,
                        classes=class_num).cuda()
    elif net_type == "pnet":
        net = PNet2D(in_chns, class_num, 64, [1, 2, 4, 8, 16]).cuda()
    else:
        net = None
    return net
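
A short usage sketch for the factory above (the argument values are illustrative):

# build a 4-class U-Net for single-channel inputs
model = net_factory(net_type="unet", in_chns=1, class_num=4)
# an unrecognized net_type silently returns None, so it is worth guarding
assert model is not None, "unsupported net_type"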
Example #14
def test_easy_dr():
    normalize = transforms.Normalize([0.5, 0.5, 0.5],
                                     [0.5, 0.5, 0.5])
    model = UNet(3, depth=3, in_channels=3)
    train_dir = '../data/gan/normal'
    train_dataset = EasyDR(data_dir=train_dir,
                           pre_transform=None,
                           post_transform=transforms.Compose([
                               transforms.ToTensor(),
                               normalize
                           ]),
                           )
    train_loader = DataLoader(dataset=train_dataset, batch_size=15,
                              shuffle=False, num_workers=2, pin_memory=False)
    for idx, (inputs, target, image_names, weight) in enumerate(train_loader):
        inputs = Variable(inputs)
        target = Variable(target)
        result = model(inputs)
        print(image_names)
        break
Example #15
def train(config):
    # rng
    rng = np.random.RandomState(config["seed"])
    torch.cuda.manual_seed(config["seed"])
    torch.cuda.manual_seed_all(config["seed"])

    # occupy
    occ = Occupier()
    if config["occupy"]:
        occ.occupy()

    # Compute input shape
    c = UNet.get_optimal_shape(output_shape_lower_bound=config["output_size"],
                               steps=config["num_unet_steps"],
                               num_convs=config["num_unet_convs"])
    input_size = [int(ci) for ci in c["input"]]
    config['margin'] = np.asarray(input_size) - np.asarray(
        config["output_size"])
    # m = np.asarray(input_size) - np.asarray(config["output_size"])
    # if len(np.unique(m)) == 1:
    #     config["margin"] = m[0]
    # else:
    #     raise RuntimeError("Should never be here?")
    if len(np.unique(config["margin"])) > 1:
        raise RuntimeError("Beware: this might not work?")
    data = Data(config)

    # writer
    writer = SummaryWriter(log_dir="output/logs/" + config["force_hash"])
    board = {
        'dataset': data.loss_label,
        'loss': config['loss'],
        'writer': writer,
    }

    # Save config file, for reference
    os.system('cp {} {}/{}'.format(config["config_filename"], config["output"],
                                   config["config_filename"].split('/')[-1]))
    fn = config["output"] + "/config.h5"
    print("Storing config file: '{}'".format(fn))
    dd.io.save(fn, config)

    if config["model"] == "UNet":
        print("Instantiating UNet")

        model = UNet(
            steps=config["num_unet_steps"],
            num_input_channels=data.num_channels,
            first_layer_channels=config["num_unet_filters"],
            num_classes=data.num_classes,
            num_convs=config["num_unet_convs"],
            output_size=config["output_size"],
            pooling=config["pooling"],
            activation=config["activation"],
            use_dropout=config["use_dropout"],
            use_batchnorm=config["use_batchnorm"],
            init_type=config["init_type"],
            final_unit=config["final_unit"],
        )

        # Need to overwrite this
        if model.is_3d:
            config["input_size"] = model.input_size
        else:
            config["input_size"] = [model.input_size[1], model.input_size[2]]
        # config["margin"] = model.margin
        print("UNet -> Input size: {}. Output size: {}".format(
            config["input_size"], config["output_size"]))
    else:
        raise RuntimeError("Unknown model")
    model.cuda()

    # Sanity check
    for j in range(len(data.train_images_optim)):
        s = data.train_images_optim[j].shape
        for i in range(len(s) - 1):
            if model.input_size[i] > s[i + 1]:
                raise RuntimeError('Input patch larger than training data '
                                   '({}>{}) for dim #{}, sample #{}'.format(
                                       model.input_size[i], s[i + 1], i, j))
    if data.val_images_mirrored:
        for j in range(len(data.val_images_mirrored)):
            s = data.val_images_mirrored[j].shape
            for i in range(len(s) - 1):
                if model.input_size[i] > s[i + 1]:
                    raise RuntimeError(
                        'Input patch larger than validation data '
                        '({}>{}) for dim #{}, sample #{}'.format(
                            model.input_size[i], s[i + 1], i, j))
    if data.test_images_mirrored:
        for j in range(len(data.test_images_mirrored)):
            s = data.test_images_mirrored[j].shape
            for i in range(len(s) - 1):
                if model.input_size[i] > s[i + 1]:
                    raise RuntimeError(
                        'Input patch larger than test data '
                        '({}>{}) for dim #{}, sample #{}'.format(
                            model.input_size[i], s[i + 1], i, j))

    if config["optimizer"] == "Adam":
        optimizer = optim.Adam(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=config["weight_decay"],
        )
    elif config["optimizer"] == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=config["weight_decay"],
            momentum=config["momentum"],
        )
    elif config["optimizer"] == "RMSprop":
        optimizer = optim.RMSprop(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=config["weight_decay"],
            momentum=config["momentum"],
        )
    else:
        raise RuntimeError("Unsupported optimizer")

    # Load state
    first_batch = 0
    fn = config["output"] + "/state.h5"
    if isfile(fn):
        # print("Loading state: '{}'".format(fn))
        # with open(fn, "rb") as handle:
        #     state = pickle.load(handle)
        state = dd.io.load(fn)
        first_batch = state["cur_batch"] + 1
    else:
        state = {}

    # Load model
    fn = "{}/model-last.pth".format(config["output"])
    if isfile(fn):
        print("Loading model: '{}'".format(fn))
        model.load_state_dict(torch.load(fn))
    else:
        print("No model to load")

    # Load optimizer
    fn = "{}/optim-last.pth".format(config["output"])
    if isfile(fn):
        optimizer.load_state_dict(torch.load(fn))
    else:
        print("No optimizer to load")

    state.setdefault("epoch", 0)
    state.setdefault("cur_batch", 0)
    state.setdefault("loss", np.zeros(config["max_steps"]))
    state.setdefault("res_train", {"batch": [], "metrics": []})
    for t in config["test_thresholds"]:
        state.setdefault("res_train_th_{}".format(t), {
            "batch": [],
            "metrics": []
        })
        state.setdefault("res_val_th_{}".format(t), {
            "batch": [],
            "metrics": []
        })
        state.setdefault("res_test_th_{}".format(t), {
            "batch": [],
            "metrics": []
        })

    # TODO Learn to sample and update this accordingly
    if config["loss"] == "classification":
        # loss_criterion = torch.nn.NLLLoss(data.weights.cuda(), reduce=False)
        loss_criterion = F.nll_loss
    elif config["loss"] == "regression":
        raise RuntimeError("TODO")
    elif config['loss'] == 'jaccard' or config['loss'] == 'dice':
        from loss import OverlapLoss
        loss_criterion = OverlapLoss(config['loss'],
                                     config['overlap_loss_smoothness'],
                                     config['overlap_fp_factor'])
    else:
        raise RuntimeError("TODO")

    if model.is_3d:
        batch = torch.Tensor(
            config["batch_size"],
            data.num_channels,
            config["input_size"][0],
            config["input_size"][1],
            config["input_size"][2],
        )
        # labels = torch.ByteTensor(
        if not data.dot_annotations:
            labels = torch.LongTensor(
                config["batch_size"],
                config["output_size"][0],
                config["output_size"][1],
                config["output_size"][2],
            )
        else:
            labels = []
    else:
        batch = torch.Tensor(
            config["batch_size"],
            data.num_channels,
            config["input_size"][0],
            config["input_size"][1],
        )
        # labels = torch.ByteTensor(
        if not data.dot_annotations:
            labels = torch.LongTensor(
                config["batch_size"],
                config["output_size"][0],
                config["output_size"][1],
            )
        else:
            labels = []

    do_save_state = False
    model.train()

    # Sampler
    print("Instantiating sampler")
    sampler = Sampler(
        model.is_3d,
        {
            "images": data.train_images_optim,
            "labels": data.train_labels_optim,
            "mean": data.train_mean,
            "std": data.train_std
        },
        config,
        rng,
        data.dot_annotations,
    )

    if occ.is_busy():
        occ.free()

    # Loop
    for state["cur_batch"] in range(first_batch, config["max_steps"]):
        # Sample
        ts = time()
        coords = []
        elastic = []
        for i in range(config["batch_size"]):
            b, l, cur_coords, cur_elastic = sampler.sample()
            batch[i] = torch.from_numpy(b)
            if not data.dot_annotations:
                labels[i] = torch.from_numpy(l)
            else:
                labels.append(torch.from_numpy(l))
            coords.append(cur_coords)
            elastic.append(cur_elastic)

        # Forward pass
        inputs = Variable(batch).cuda()
        outputs = model(inputs)
        optimizer.zero_grad()
        if config['loss'] == 'jaccard' or config['loss'] == 'dice':
            targets = Variable(labels.float()).cuda()
            o = F.softmax(outputs, dim=1)[:, 1, :, :]
            loss = loss_criterion.forward(o, targets)
            loss = sum(loss) / len(loss)
        elif config['loss'] == 'classification':
            targets = Variable(labels).cuda()
            if data.is_3d:
                # Do it slice by slice. Ugly but it works!
                loss = []
                for z in range(outputs.shape[2]):
                    loss.append(
                        loss_criterion(F.log_softmax(outputs[:, :, z, :, :],
                                                     dim=1),
                                       targets[:, z, :, :],
                                       weight=data.weights.cuda(),
                                       reduce=True,
                                       ignore_index=2))
                loss = sum(loss) / len(loss)
            else:
                # f(reduce=True) is equivalent to f(reduce=False).mean()
                # no need to average over the batch size then
                loss = loss_criterion(F.log_softmax(outputs, dim=1),
                                      targets,
                                      weight=data.weights.cuda(),
                                      reduce=True,
                                      ignore_index=2)
        else:
            raise RuntimeError('Bad loss type')

        # Sanity check
        # if not data.dot_annotations and loss.data.cpu().sum() > 10:
        #     print("very high loss?")
        #     embed()

        # Backward pass
        loss.backward()
        optimizer.step()

        # Get class stats
        ws = [0, 0]
        for l in labels:
            ws[0] += (l == 0).sum()
            ws[1] += (l == 1).sum()

        # Update state
        cur_loss = loss.data.cpu().sum()
        state["loss"][state["cur_batch"]] = cur_loss
        board['writer'].add_scalar(board['dataset'] + '-loss-' + board['loss'],
                                   cur_loss, state['cur_batch'])

        print(
            "Batch {it:d} -> Avg. loss {loss:.05f}: [{t:.02f} s.] (Range: {rg:.1f})"
            .format(
                it=state["cur_batch"] + 1,
                loss=cur_loss,
                t=time() - ts,
                rg=outputs.data.max() - outputs.data.min(),
            ))

        # Cross-validation
        force_eval = False
        if config["check_val_every"] > 0 and data.evaluate_val:
            if (state["cur_batch"] + 1) % config["check_val_every"] == 0:
                res = model.inference(
                    {
                        "images": data.val_images_mirrored,
                        "mean": data.val_mean,
                        "std": data.val_std,
                    },
                    config['batch_size'],
                    config['use_lcn'],
                )

                is_best = model.validation_by_classification(
                    images=data.val_images,
                    gt=data.val_labels_th,
                    prediction=res,
                    state=state,
                    board=board,
                    output_folder=config['output'],
                    xval_metric=config['xval_metric'],
                    dilation_thresholds=config['test_thresholds'],
                    subset='val',
                    make_stack=data.plot_make_stack,
                    force_save=False,
                )

                # Save models if they are the best at any test threshold
                for k, v in is_best.items():
                    if v is True:
                        save_model(config, state, model, optimizer,
                                   'best_th_{}'.format(k))
                        do_save_state = True

                # Force testing on train/test
                force_eval = any(is_best.values())

        # Test on the training data
        if config["check_train_every"] > 0 and data.evaluate_train:
            if ((state["cur_batch"] + 1) % config["check_train_every"]
                    == 0) or force_eval:
                res = model.inference(
                    {
                        "images":
                        data.train_images_mirrored[:data.num_train_orig],
                        "mean": data.train_mean,
                        "std": data.train_std,
                    },
                    config['batch_size'],
                    config['use_lcn'],
                )

                model.validation_by_classification(
                    images=data.train_images[:data.num_train_orig],
                    gt=data.train_labels_th,
                    prediction=res,
                    state=state,
                    board=board,
                    output_folder=config['output'],
                    xval_metric=config['xval_metric'],
                    dilation_thresholds=config['test_thresholds'],
                    subset='train',
                    make_stack=data.plot_make_stack,
                    force_save=force_eval,
                )

        # Test on the test data
        if config["check_test_every"] > 0 and data.evaluate_test:
            if ((state["cur_batch"] + 1) % config["check_test_every"]
                    == 0) or force_eval:
                res = model.inference(
                    {
                        "images": data.test_images_mirrored,
                        "mean": data.test_mean,
                        "std": data.test_std,
                    },
                    config['batch_size'],
                    config['use_lcn'],
                )

                model.validation_by_classification(
                    images=data.test_images,
                    gt=data.test_labels_th,
                    prediction=res,
                    state=state,
                    board=board,
                    output_folder=config['output'],
                    xval_metric=config['xval_metric'],
                    dilation_thresholds=config['test_thresholds'],
                    subset='test',
                    make_stack=data.plot_make_stack,
                    force_save=force_eval,
                )

        # Also save models periodically, to resume executions
        if config["save_models_every"] > 0:
            if (state["cur_batch"] + 1) % config["save_models_every"] == 0:
                save_model(config, state, model, optimizer, 'last')
                do_save_state = True

        # Save training state periodically (or if forced)
        if do_save_state:
            save_state(config, state)
            do_save_state = False

    board['writer'].close()
Example #16
class generate_scores(base):
    """
    usage:
    python generate_scores.py ../gan254 99 ../data/gan20 ../gan260
    note: this script runs in a GPU environment; it is intended to test the validation dataset once d is fully trained.
    """
    def __init__(self, prefix, epoch, data_dir, saved_path):
        self.prefix = prefix
        self.data = data_dir
        self.epoch = epoch
        self.batch_size = 32
        self.power = 2
        self.saved_path = saved_path
        self.dataloader = self.get_dataloader()
        self.config = self.load_config()
        self.unet = UNet(3, depth=self.config['u_depth'], in_channels=3)

        print(self.unet)
        print(
            'load unet with depth %d and downsampling will be performed %d times!!'
            % (self.config['u_depth'], self.config['u_depth'] - 1))
        self.unet.load_state_dict(
            weight_to_cpu('%s/epoch_%s/g.pkl' % (self.prefix, self.epoch)))
        print('load pretrained unet')

        self.d = get_discriminator(self.config['gan_type'],
                                   self.config['d_depth'],
                                   self.config['dowmsampling'])
        self.d.load_state_dict(
            weight_to_cpu('%s/epoch_%s/d.pkl' % (self.prefix, self.epoch)))
        print('load pretrained d')

    def __call__(self):
        real_data_score = []
        fake_data_score = []
        for i, (lesion_data, _, lesion_names, _, real_data, _, normal_names,
                _) in enumerate(self.dataloader):
            print('id=%d' % i)
            lesion_output = self.d(self.unet(lesion_data))
            fake_data_score += list(
                lesion_output.squeeze().data.numpy().flatten())
            normal_output = self.d(real_data)
            real_data_score += list(
                normal_output.squeeze().data.numpy().flatten())
        if not os.path.exists(self.saved_path):
            os.mkdir(self.saved_path)
        self.plot_hist('%s/score_distribution.png' % self.saved_path,
                       real_data_score, fake_data_score)

    def load_config(self):
        return read(add_prefix(self.prefix, 'para.txt'))

    def get_dataloader(self):
        if self.data == '../data/gan17':
            print('training dataset of real normal data from gan15 in gan246.')
        elif self.data == '../data/gan18':
            print('validate dataset of unet lesion data output from gan246.')
        elif self.data == '../data/gan19':
            print('training dataset of real normal data from gan15 in gan253.')
        elif self.data == '../data/gan20':
            print('validate dataset of unet lesion data output from gan253.')
        else:
            raise ValueError(
                "the parameter data must be one of "
                "['../data/gan17', '../data/gan18', '../data/gan19', '../data/gan20']"
            )
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
        dataset = ConcatDataset(data_dir=self.data,
                                transform=transform,
                                alpha=self.power)
        data_loader = DataLoader(dataset,
                                 batch_size=self.batch_size,
                                 shuffle=True,
                                 num_workers=2,
                                 drop_last=False,
                                 pin_memory=False)
        return data_loader
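
For reference, a programmatic equivalent of the command-line usage given in the class docstring (the argument values simply mirror that example):

# mirrors: python generate_scores.py ../gan254 99 ../data/gan20 ../gan260
scorer = generate_scores(prefix='../gan254', epoch=99,
                         data_dir='../data/gan20', saved_path='../gan260')
scorer()  # writes score_distribution.png into saved_path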
Example #17
                           transform=test_xtransform,
                           target_transform=test_ytransform)
else:
    raise ValueError(
        'Unknown target dataset: the options are drosophila or hela. ')
tar_train_loader = DataLoader(tar_train, batch_size=args.train_batch_size)
tar_test_loader = DataLoader(tar_test, batch_size=args.test_batch_size)
"""
    Setup optimization for finetuning
"""
print('[%s] Setting up optimization for finetuning' %
      (datetime.datetime.now()))
# load best checkpoint
# net = torch.load(os.path.join(args.log_dir, args.target, str(0.0), args.method, 'logs', 'pretraining', 'best_checkpoint.pytorch'))
net = UNet(feature_maps=args.fm,
           levels=args.levels,
           group_norm=(args.group_norm == 1))
optimizer = optim.Adam(net.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer,
                                      step_size=args.step_size,
                                      gamma=args.gamma)
"""
    Finetune on target if necessary
"""
print('[%s] Finetuning with %d percent of the target labels' %
      (datetime.datetime.now(), args.frac_target_labels * 100))
seg_net = net.get_segmentation_net()
if args.frac_target_labels > 0:
    seg_net.train_net(train_loader_src=tar_train_loader,
                      test_loader_src=tar_test_loader,
                      test_loader_tar=None,
Example #18
    # Config
    # with open(params.pickle, 'rb') as f:
    #     config = pickle.load(f)
    config = dd.io.load(params.pickle)

    # Number of threads
    torch.set_num_threads(params.threads)
    if not params.use_cpu:
        print("Running on GPU {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))

    config["batch_size"] = params.batch_size
    config["use_cpu"] = params.use_cpu

    # Compute input shape
    c = UNet.get_optimal_shape(output_shape_lower_bound=config["output_size"],
                               steps=config["num_unet_steps"],
                               num_convs=config["num_unet_convs"])
    input_size = [int(ci) for ci in c["input"]]
    m = np.asarray(input_size) - np.asarray(config["output_size"])
    if len(np.unique(m)) == 1:
        config["margin"] = m[0]
    else:
        raise RuntimeError("Should never be here?")

    if config["model"] == "UNet":
        print("Instantiating UNet")
        model = UNet(
            steps=config["num_unet_steps"],
            num_input_channels=np.int64(1),
            first_layer_channels=config["num_unet_filters"],
            num_classes=np.int64(2),
Example #19
from networks.unet import UNet
import torch

UNet.get_shape_combinations(200, steps=4, num_convs=1)