Code example #1
def get_unet(pretrain_unet_path):
    unet = UNet(3, depth=5, in_channels=3)
    print(unet)
    print('load unet with depth = 5 and downsampling will be performed 4 times!!')
    unet.load_state_dict(weight_to_cpu(pretrain_unet_path))
    print('load pretrained unet!')
    return unet
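
weight_to_cpu is a project-specific helper that does not appear in these snippets. A minimal sketch of what it presumably does, assuming it only loads a checkpoint with every tensor mapped to CPU memory (the real implementation may differ):

# Hypothetical stand-in for the project's weight_to_cpu helper (assumption,
# not the original implementation).
import torch

def weight_to_cpu(checkpoint_path):
    # Map all tensors in the checkpoint to CPU so the weights can be loaded
    # on a machine without a GPU.
    return torch.load(checkpoint_path, map_location='cpu')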
Code example #2
File: train_crf.py  Project: czifan/ConvCRF.pytorch
def main():
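    # restore the pretrained UNet backbone; the ConvCRF2d layer defined below refines its predictions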
    model = UNet().to(device)
    model.load_state_dict(
        torch.load(pretrained_file, map_location=device)['state_dict'])
    model_crf = ConvCRF2d(config, kernel_size=5).to(device)
    criterion = BinaryDiceLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_dataset = ImageDataset(model, split='train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=train_batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True)
    valid_dataset = ImageDataset(model, split='valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=valid_batch_size,
                                               num_workers=num_workers,
                                               shuffle=False,
                                               pin_memory=True)

    for epoch in range(epochs):
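        # one training pass over the train set, then evaluation on the validation split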
        train(epoch, model_crf, train_loader, criterion, optimizer)
        valid(epoch, model_crf, valid_loader, criterion)

    save_checkpoint(
        {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        },
        save_file='pretrained/convcrf2d.pth.tar',
        is_best=False)
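
save_checkpoint is also a project helper. A minimal sketch under the assumption that it simply serializes the given dictionary with torch.save and optionally keeps a copy of the best model; the actual czifan/ConvCRF.pytorch implementation may differ:

# Hypothetical sketch of save_checkpoint (assumption, not the original code).
import os
import shutil
import torch

def save_checkpoint(state, save_file, is_best=False):
    # Persist the epoch, model state_dict and optimizer state_dict.
    os.makedirs(os.path.dirname(save_file), exist_ok=True)
    torch.save(state, save_file)
    if is_best:
        # Keep a separate copy of the best-performing checkpoint.
        shutil.copyfile(save_file, save_file + '.best')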
Code example #3
def get_unet(self):
    unet = UNet(3, depth=self.u_depth, in_channels=3)
    print(unet)
    print(
        'load unet with depth %d and downsampling will be performed %d times!!'
        % (self.u_depth, self.u_depth - 1))
    if self.is_pretrained_unet:
        unet.load_state_dict(weight_to_cpu(self.pretrain_unet_path))
        print('load pretrained unet')
    return unet
Code example #4
def model_builder():
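    # pair a ResNet-18 classifier with the pretrained UNet autoencoder and wrap the combined Locator in DataParallel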
    classifier = resnet18(is_ptrtrained=False)
    print('use resnet18')
    auto_encoder = UNet(3, depth=5, in_channels=3)
    auto_encoder.load_state_dict(weight_to_cpu(args.pretrain_unet))
    print('load pretrained unet!')

    model = Locator(aer=auto_encoder, classifier=classifier)
    if args.cuda:
        model = DataParallel(model).cuda()
    else:
        raise ValueError('there is no gpu')

    return model
Code example #5
        )

        # Need to overwrite this
        if model.is_3d:
            config["input_size"] = model.input_size
        else:
            config["input_size"] = [model.input_size[1], model.input_size[2]]
        config["margin"] = model.margin
        print("UNet -> Input size: {}. Output size: {}".format(
            config["input_size"], config["output_size"]))
    else:
        raise RuntimeError("Unknown model")

    # Load weights
    print("Loading model: '{}'".format(params.weights))
    model.load_state_dict(torch.load(params.weights))

    if not config["use_cpu"]:
        model.cuda()

    if model.is_3d:
        config["input_size"] = model.input_size
    else:
        config["input_size"] = [model.input_size[1], model.input_size[2]]

    # Load some data
    stack = tifffile.imread(params.input)
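    # insert a singleton channel axis, e.g. (Z, Y, X) -> (Z, 1, Y, X)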
    stack = np.expand_dims(stack, axis=1)

    # too much
    # stack = stack[:5, :1200, :1200]
Code example #6
class generate_scores(base):
    """
    usage:
    python generate_scores.py ../gan254 99 ../data/gan20 ../gan260
    note: the script runs in gpu environment.the script intends to test validate dataset if d is fully trained.
    """
    def __init__(self, prefix, epoch, data_dir, saved_path):
        self.prefix = prefix
        self.data = data_dir
        self.epoch = epoch
        self.batch_size = 32
        self.power = 2
        self.saved_path = saved_path
        self.dataloader = self.get_dataloader()
        self.config = self.load_config()
        self.unet = UNet(3, depth=self.config['u_depth'], in_channels=3)

        print(self.unet)
        print(
            'load unet with depth %d and downsampling will be performed %d times!!'
            % (self.config['u_depth'], self.config['u_depth'] - 1))
        self.unet.load_state_dict(
            weight_to_cpu('%s/epoch_%s/g.pkl' % (self.prefix, self.epoch)))
        print('load pretrained unet')

        self.d = get_discriminator(self.config['gan_type'],
                                   self.config['d_depth'],
                                   self.config['dowmsampling'])
        self.d.load_state_dict(
            weight_to_cpu('%s/epoch_%s/d.pkl' % (self.prefix, self.epoch)))
        print('load pretrained d')

    def __call__(self):
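        # score generator (UNet) outputs for lesion images and real normal images
        # with the discriminator, then plot both score distributions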
        real_data_score = []
        fake_data_score = []
        for i, (lesion_data, _, lesion_names, _, real_data, _, normal_names,
                _) in enumerate(self.dataloader):
            print('id=%d' % i)
            lesion_output = self.d(self.unet(lesion_data))
            fake_data_score += list(
                lesion_output.squeeze().data.numpy().flatten())
            normal_output = self.d(real_data)
            real_data_score += list(
                normal_output.squeeze().data.numpy().flatten())
        if not os.path.exists(self.saved_path):
            os.mkdir(self.saved_path)
        self.plot_hist('%s/score_distribution.png' % self.saved_path,
                       real_data_score, fake_data_score)

    def load_config(self):
        return read(add_prefix(self.prefix, 'para.txt'))

    def get_dataloader(self):
        if self.data == '../data/gan17':
            print('training dataset of real normal data from gan15 in gan246.')
        elif self.data == '../data/gan18':
            print('validate dataset of unet lesion data output from gan246.')
        elif self.data == '../data/gan19':
            print('training dataset of real normal data from gan15 in gan253.')
        elif self.data == '../data/gan20':
            print('validate dataset of unet lesion data output from gan253.')
        else:
            raise ValueError(
                "the parameter data must be one of ['../data/gan17', '../data/gan18', "
                "'../data/gan19', '../data/gan20']"
            )
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
        dataset = ConcatDataset(data_dir=self.data,
                                transform=transform,
                                alpha=self.power)
        data_loader = DataLoader(dataset,
                                 batch_size=self.batch_size,
                                 shuffle=True,
                                 num_workers=2,
                                 drop_last=False,
                                 pin_memory=False)
        return data_loader
Code example #7
def train(config):
    # rng
    rng = np.random.RandomState(config["seed"])
    torch.cuda.manual_seed(config["seed"])
    torch.cuda.manual_seed_all(config["seed"])

    # occupy
    occ = Occupier()
    if config["occupy"]:
        occ.occupy()

    # Compute input shape
    c = UNet.get_optimal_shape(output_shape_lower_bound=config["output_size"],
                               steps=config["num_unet_steps"],
                               num_convs=config["num_unet_convs"])
    input_size = [int(ci) for ci in c["input"]]
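    # margin: per-axis difference between the UNet input size and the requested output size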
    config['margin'] = np.asarray(input_size) - np.asarray(
        config["output_size"])
    # m = np.asarray(input_size) - np.asarray(config["output_size"])
    # if len(np.unique(m)) == 1:
    #     config["margin"] = m[0]
    # else:
    #     raise RuntimeError("Should never be here?")
    if len(np.unique(config["margin"])) > 1:
        raise RuntimeError("Beware: this might not work?")
    data = Data(config)

    # writer
    writer = SummaryWriter(log_dir="output/logs/" + config["force_hash"])
    board = {
        'dataset': data.loss_label,
        'loss': config['loss'],
        'writer': writer,
    }

    # Save config file, for reference
    os.system('cp {} {}/{}'.format(config["config_filename"], config["output"],
                                   config["config_filename"].split('/')[-1]))
    fn = config["output"] + "/config.h5"
    print("Storing config file: '{}'".format(fn))
    dd.io.save(fn, config)

    if config["model"] == "UNet":
        print("Instantiating UNet")

        model = UNet(
            steps=config["num_unet_steps"],
            num_input_channels=data.num_channels,
            first_layer_channels=config["num_unet_filters"],
            num_classes=data.num_classes,
            num_convs=config["num_unet_convs"],
            output_size=config["output_size"],
            pooling=config["pooling"],
            activation=config["activation"],
            use_dropout=config["use_dropout"],
            use_batchnorm=config["use_batchnorm"],
            init_type=config["init_type"],
            final_unit=config["final_unit"],
        )

        # Need to overwrite this
        if model.is_3d:
            config["input_size"] = model.input_size
        else:
            config["input_size"] = [model.input_size[1], model.input_size[2]]
        # config["margin"] = model.margin
        print("UNet -> Input size: {}. Output size: {}".format(
            config["input_size"], config["output_size"]))
    else:
        raise RuntimeError("Unknown model")
    model.cuda()

    # Sanity check
    for j in range(len(data.train_images_optim)):
        s = data.train_images_optim[j].shape
        for i in range(len(s) - 1):
            if model.input_size[i] > s[i + 1]:
                raise RuntimeError('Input patch larger than training data '
                                   '({}>{}) for dim #{}, sample #{}'.format(
                                       model.input_size[i], s[i + 1], i, j))
    if data.val_images_mirrored:
        for j in range(len(data.val_images_mirrored)):
            s = data.val_images_mirrored[j].shape
            for i in range(len(s) - 1):
                if model.input_size[i] > s[i + 1]:
                    raise RuntimeError(
                        'Input patch larger than validation data '
                        '({}>{}) for dim #{}, sample #{}'.format(
                            model.input_size[i], s[i + 1], i, j))
    if data.test_images_mirrored:
        for j in range(len(data.test_images_mirrored)):
            s = data.test_images_mirrored[j].shape
            for i in range(len(s) - 1):
                if model.input_size[i] > s[i + 1]:
                    raise RuntimeError(
                        'Input patch larger than test data '
                        '({}>{}) for dim #{}, sample #{}'.format(
                            model.input_size[i], s[i + 1], i, j))

    if config["optimizer"] == "Adam":
        optimizer = optim.Adam(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=config["weight_decay"],
        )
    elif config["optimizer"] == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=config["weight_decay"],
            momentum=config["momentum"],
        )
    elif config["optimizer"] == "RMSprop":
        optimizer = optim.RMSprop(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=config["weight_decay"],
            momentum=config["momentum"],
        )
    else:
        raise RuntimeError("Unsupported optimizer")

    # Load state
    first_batch = 0
    fn = config["output"] + "/state.h5"
    if isfile(fn):
        # print("Loading state: '{}'".format(fn))
        # with open(fn, "rb") as handle:
        #     state = pickle.load(handle)
        state = dd.io.load(fn)
        first_batch = state["cur_batch"] + 1
    else:
        state = {}

    # Load model
    fn = "{}/model-last.pth".format(config["output"])
    if isfile(fn):
        print("Loading model: '{}'".format(fn))
        model.load_state_dict(torch.load(fn))
    else:
        print("No model to load")

    # Load optimizer
    fn = "{}/optim-last.pth".format(config["output"])
    if isfile(fn):
        optimizer.load_state_dict(torch.load(fn))
    else:
        print("No optimizer to load")

    state.setdefault("epoch", 0)
    state.setdefault("cur_batch", 0)
    state.setdefault("loss", np.zeros(config["max_steps"]))
    state.setdefault("res_train", {"batch": [], "metrics": []})
    for t in config["test_thresholds"]:
        state.setdefault("res_train_th_{}".format(t), {
            "batch": [],
            "metrics": []
        })
        state.setdefault("res_val_th_{}".format(t), {
            "batch": [],
            "metrics": []
        })
        state.setdefault("res_test_th_{}".format(t), {
            "batch": [],
            "metrics": []
        })

    # TODO Learn to sample and update this accordingly
    if config["loss"] == "classification":
        # loss_criterion = torch.nn.NLLLoss(data.weights.cuda(), reduce=False)
        loss_criterion = F.nll_loss
    elif config["loss"] == "regression":
        raise RuntimeError("TODO")
    elif config['loss'] == 'jaccard' or config['loss'] == 'dice':
        from loss import OverlapLoss
        loss_criterion = OverlapLoss(config['loss'],
                                     config['overlap_loss_smoothness'],
                                     config['overlap_fp_factor'])
    else:
        raise RuntimeError("TODO")

    if model.is_3d:
        batch = torch.Tensor(
            config["batch_size"],
            data.num_channels,
            config["input_size"][0],
            config["input_size"][1],
            config["input_size"][2],
        )
        # labels = torch.ByteTensor(
        if not data.dot_annotations:
            labels = torch.LongTensor(
                config["batch_size"],
                config["output_size"][0],
                config["output_size"][1],
                config["output_size"][2],
            )
        else:
            labels = []
    else:
        batch = torch.Tensor(
            config["batch_size"],
            data.num_channels,
            config["input_size"][0],
            config["input_size"][1],
        )
        # labels = torch.ByteTensor(
        if not data.dot_annotations:
            labels = torch.LongTensor(
                config["batch_size"],
                config["output_size"][0],
                config["output_size"][1],
            )
        else:
            labels = []

    do_save_state = False
    model.train()

    # Sampler
    print("Instantiating sampler")
    sampler = Sampler(
        model.is_3d,
        {
            "images": data.train_images_optim,
            "labels": data.train_labels_optim,
            "mean": data.train_mean,
            "std": data.train_std
        },
        config,
        rng,
        data.dot_annotations,
    )

    if occ.is_busy():
        occ.free()

    # Loop
    for state["cur_batch"] in range(first_batch, config["max_steps"]):
        # Sample
        ts = time()
        coords = []
        elastic = []
        for i in range(config["batch_size"]):
            b, l, cur_coords, cur_elastic = sampler.sample()
            batch[i] = torch.from_numpy(b)
            if not data.dot_annotations:
                labels[i] = torch.from_numpy(l)
            else:
                labels.append(torch.from_numpy(l))
            coords.append(cur_coords)
            elastic.append(cur_elastic)

        # Forward pass
        inputs = Variable(batch).cuda()
        outputs = model(inputs)
        optimizer.zero_grad()
        if config['loss'] == 'jaccard' or config['loss'] == 'dice':
            targets = Variable(labels.float()).cuda()
            o = F.softmax(outputs, dim=1)[:, 1, :, :]
            loss = loss_criterion.forward(o, targets)
            loss = sum(loss) / len(loss)
        elif config['loss'] == 'classification':
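            # per-pixel negative log-likelihood on log-softmax outputs; pixels labeled 2 are ignored (ignore_index=2)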
            targets = Variable(labels).cuda()
            if data.is_3d:
                # Do it slice by slice. Ugly but it works!
                loss = []
                for z in range(outputs.shape[2]):
                    loss.append(
                        loss_criterion(F.log_softmax(outputs[:, :, z, :, :],
                                                     dim=1),
                                       targets[:, z, :, :],
                                       weight=data.weights.cuda(),
                                       reduce=True,
                                       ignore_index=2))
                loss = sum(loss) / len(loss)
            else:
                # f(reduce=True) is equivalent to f(reduce=False).mean()
                # no need to average over the batch size then
                loss = loss_criterion(F.log_softmax(outputs, dim=1),
                                      targets,
                                      weight=data.weights.cuda(),
                                      reduce=True,
                                      ignore_index=2)
        else:
            raise RuntimeError('Bad loss type')

        # Sanity check
        # if not data.dot_annotations and loss.data.cpu().sum() > 10:
        #     print("very high loss?")
        #     embed()

        # Backward pass
        loss.backward()
        optimizer.step()

        # Get class stats
        ws = [0, 0]
        for l in labels:
            ws[0] += (l == 0).sum()
            ws[1] += (l == 1).sum()

        # Update state
        cur_loss = loss.data.cpu().sum()
        state["loss"][state["cur_batch"]] = cur_loss
        board['writer'].add_scalar(board['dataset'] + '-loss-' + board['loss'],
                                   cur_loss, state['cur_batch'])

        print(
            "Batch {it:d} -> Avg. loss {loss:.05f}: [{t:.02f} s.] (Range: {rg:.1f})"
            .format(
                it=state["cur_batch"] + 1,
                loss=cur_loss,
                t=time() - ts,
                rg=outputs.data.max() - outputs.data.min(),
            ))

        # Cross-validation
        force_eval = False
        if config["check_val_every"] > 0 and data.evaluate_val:
            if (state["cur_batch"] + 1) % config["check_val_every"] == 0:
                res = model.inference(
                    {
                        "images": data.val_images_mirrored,
                        "mean": data.val_mean,
                        "std": data.val_std,
                    },
                    config['batch_size'],
                    config['use_lcn'],
                )

                is_best = model.validation_by_classification(
                    images=data.val_images,
                    gt=data.val_labels_th,
                    prediction=res,
                    state=state,
                    board=board,
                    output_folder=config['output'],
                    xval_metric=config['xval_metric'],
                    dilation_thresholds=config['test_thresholds'],
                    subset='val',
                    make_stack=data.plot_make_stack,
                    force_save=False,
                )

                # Save models if they are the best at any test threshold
                for k, v in is_best.items():
                    if v is True:
                        save_model(config, state, model, optimizer,
                                   'best_th_{}'.format(k))
                        do_save_state = True

                # Force testing on train/test
                force_eval = any(is_best.values())

        # Test on the training data
        if config["check_train_every"] > 0 and data.evaluate_train:
            if ((state["cur_batch"] + 1) % config["check_train_every"]
                    == 0) or force_eval:
                res = model.inference(
                    {
                        "images":
                        data.train_images_mirrored[:data.num_train_orig],
                        "mean": data.train_mean,
                        "std": data.train_std,
                    },
                    config['batch_size'],
                    config['use_lcn'],
                )

                model.validation_by_classification(
                    images=data.train_images[:data.num_train_orig],
                    gt=data.train_labels_th,
                    prediction=res,
                    state=state,
                    board=board,
                    output_folder=config['output'],
                    xval_metric=config['xval_metric'],
                    dilation_thresholds=config['test_thresholds'],
                    subset='train',
                    make_stack=data.plot_make_stack,
                    force_save=force_eval,
                )

        # Test on the test data
        if config["check_test_every"] > 0 and data.evaluate_test:
            if ((state["cur_batch"] + 1) % config["check_test_every"]
                    == 0) or force_eval:
                res = model.inference(
                    {
                        "images": data.test_images_mirrored,
                        "mean": data.test_mean,
                        "std": data.test_std,
                    },
                    config['batch_size'],
                    config['use_lcn'],
                )

                model.validation_by_classification(
                    images=data.test_images,
                    gt=data.test_labels_th,
                    prediction=res,
                    state=state,
                    board=board,
                    output_folder=config['output'],
                    xval_metric=config['xval_metric'],
                    dilation_thresholds=config['test_thresholds'],
                    subset='test',
                    make_stack=data.plot_make_stack,
                    force_save=force_eval,
                )

        # Also save models periodically, to resume executions
        if config["save_models_every"] > 0:
            if (state["cur_batch"] + 1) % config["save_models_every"] == 0:
                save_model(config, state, model, optimizer, 'last')
                do_save_state = True

        # Save training state periodically (or if forced)
        if do_save_state:
            save_state(config, state)
            do_save_state = False

    board['writer'].close()