Example #1
def main():

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(funcname())

    prs = argparse.ArgumentParser()
    prs.add_argument('--name',
                     help='name used for checkpoints',
                     default='unet',
                     type=str)

    subprs = prs.add_subparsers(title='actions',
                                description='Choose from one of the actions.')
    subprs_trn = subprs.add_parser('train', help='Run training.')
    subprs_trn.set_defaults(which='train')
    subprs_trn.add_argument('-w', '--weights', help='path to keras weights')

    subprs_sbt = subprs.add_parser('submit', help='Make submission.')
    subprs_sbt.set_defaults(which='submit')
    subprs_sbt.add_argument('-w',
                            '--weights',
                            help='path to keras weights',
                            required=True)
    subprs_sbt.add_argument('-t',
                            '--tiff',
                            help='path to tiffs',
                            default='data/test-volume.tif')

    args = vars(prs.parse_args())
    assert args['which'] in ['train', 'submit']

    model = UNet(args['name'])

    if not path.exists(model.checkpoint_path):
        makedirs(model.checkpoint_path)

    def load_weights():
        if args['weights'] is not None:
            logger.info('Loading weights from %s.' % args['weights'])
            model.net.load_weights(args['weights'])

    if args['which'] == 'train':
        model.compile()
        load_weights()
        # model.net.summary()
        model.load_data()
        history = model.train()
        save_history(history, Path('checkpoints/unet_64'))

    elif args['which'] == 'submit':
        out_path = '%s/test-volume-masks.tif' % model.checkpoint_path
        model.config['input_shape'] = (512, 512)
        model.config['output_shape'] = (512, 512)
        model.compile()
        load_weights()
        model.net.summary()
        imgs_sbt = tiff.imread(args['tiff'])
        msks_sbt = model.predict(imgs_sbt)
        logger.info('Writing predicted masks to %s' % out_path)
        tiff.imsave(out_path, msks_sbt)
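
The dispatch on args['which'] in Example #1 relies on set_defaults() tagging each subparser with the chosen action. Below is a minimal, self-contained sketch of that pattern; the argument values and the weights path are placeholders, not taken from the project.

import argparse

prs = argparse.ArgumentParser()
prs.add_argument('--name', default='unet', type=str)
subprs = prs.add_subparsers(title='actions')
subprs_trn = subprs.add_parser('train')
subprs_trn.set_defaults(which='train')  # tags the parsed namespace with the chosen action
subprs_trn.add_argument('-w', '--weights')

args = vars(prs.parse_args(['--name', 'unet', 'train', '-w', 'weights.h5']))
assert args == {'name': 'unet', 'which': 'train', 'weights': 'weights.h5'}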
Example #2
    def load_data(self):
        """output range [0,255]
        """
        logger = logging.getLogger(funcname())
        logger.info('Reading images from %s.' % self.config['data_path'])

        imgs = tiff.imread('%s/train-volume.tif' % self.config['data_path'])
        msks = tiff.imread('%s/train-labels.tif' %
                           self.config['data_path']).round()

        nb_trn = int(len(imgs) * self.config['prop_trn'])
        nb_val = int(len(imgs) * self.config['prop_val'])
        idx = np.arange(len(imgs))

        # Randomize selection for training and validation.
        if self.config['random_split']:
            np.random.shuffle(idx)
            idx_trn, idx_val = idx[:nb_trn], idx[-nb_val:]
        else:
            idx_trn, idx_val = idx[-nb_trn:], idx[:nb_val]
            np.random.shuffle(idx_trn)
            np.random.shuffle(idx_val)

        logger.debug('Training indices: %s' % idx_trn)
        logger.debug('Validation indices: %s' % idx_val)

        H, W = self.config['img_shape']
        logger.info('Combining images and masks into montages.')
        imgs_trn, msks_trn = imgs[idx_trn], msks[idx_trn]
        nb_row, nb_col = self.config['montage_trn_shape']
        assert nb_row * nb_col == len(imgs_trn) == len(msks_trn)
        self.imgs_montage_trn = np.empty((nb_row * H, nb_col * W))
        self.msks_montage_trn = np.empty((nb_row * H, nb_col * W))
        imgs_trn, msks_trn = iter(imgs_trn), iter(msks_trn)
        for y0 in range(0, nb_row * H, H):
            for x0 in range(0, nb_col * W, W):
                y1, x1 = y0 + H, x0 + W
                self.imgs_montage_trn[y0:y1, x0:x1] = next(imgs_trn)
                self.msks_montage_trn[y0:y1, x0:x1] = next(msks_trn)

        logger.info('Combining validation images and masks into montages')
        imgs_val, msks_val = imgs[idx_val], msks[idx_val]
        nb_row, nb_col = self.config['montage_val_shape']
        assert nb_row * nb_col == len(imgs_val) == len(msks_val)
        self.imgs_montage_val = np.empty((nb_row * H, nb_col * W))
        self.msks_montage_val = np.empty((nb_row * H, nb_col * W))
        imgs_val, msks_val = iter(imgs_val), iter(msks_val)
        for y0 in range(0, nb_row * H, H):
            for x0 in range(0, nb_col * W, W):
                y1, x1 = y0 + H, x0 + W
                self.imgs_montage_val[y0:y1, x0:x1] = next(imgs_val)
                self.msks_montage_val[y0:y1, x0:x1] = next(msks_val)

        # Correct the types.
        self.imgs_montage_trn = self.imgs_montage_trn.astype(np.float32)
        self.msks_montage_trn = self.msks_montage_trn.astype(np.uint8)
        self.imgs_montage_val = self.imgs_montage_val.astype(np.float32)
        self.msks_montage_val = self.msks_montage_val.astype(np.uint8)

        return
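
load_data() above reads a fixed set of keys from self.config. The sketch below lists those keys with illustrative values only; the numbers are assumptions chosen so the montage grids match a 30-slice ISBI training volume, not values taken from the project.

# Illustrative values; the keys are exactly the ones load_data() reads.
example_config = {
    'data_path': 'data',            # directory holding train-volume.tif / train-labels.tif
    'prop_trn': 0.8,                # fraction of slices used for training
    'prop_val': 0.2,                # fraction of slices used for validation
    'random_split': True,           # shuffle indices before splitting
    'img_shape': (512, 512),        # (H, W) of a single slice
    'montage_trn_shape': (6, 4),    # rows x cols; 6 * 4 == int(30 * 0.8)
    'montage_val_shape': (2, 3),    # rows x cols; 2 * 3 == int(30 * 0.2)
}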
Example #3
def isbi_get_data_montage(imgs_path, msks_path, nb_rows, nb_cols, rng):
    '''Reads the images and masks and arranges them in a montage for sampling in training.'''
    logger = logging.getLogger(funcname())
    imgs = tiff.imread('%s/0.tif' % imgs_path)
    imgs = imgs.transpose(2, 0, 1)
    for i in range(1, 25):
        im = tiff.imread('%s/%d.tif' % (imgs_path, i))
        im = im.transpose(2, 0, 1)
        imgs = np.concatenate((imgs, im), axis=0)
    msks = tiff.imread('%s/0.tif' % msks_path)
    msks = msks.transpose(2, 0, 1)
    for i in range(1, 25):
        ms = tiff.imread('%s/%d.tif' % (msks_path, i))
        ms = ms.transpose(2, 0, 1)
        msks = np.concatenate((msks, ms), axis=0)
    msks = msks / 255

    # imgs, msks = tiff.imread(imgs_path), tiff.imread(msks_path) / 255
    montage_imgs = np.empty((nb_rows * imgs.shape[1], nb_cols * imgs.shape[2]), dtype=np.float32)
    montage_msks = np.empty((nb_rows * imgs.shape[1], nb_cols * imgs.shape[2]), dtype=np.int8)

    idxs = np.arange(imgs.shape[0])
    rng.shuffle(idxs)
    idxs = iter(idxs)

    for y0 in range(0, montage_imgs.shape[0], imgs.shape[1]):
        for x0 in range(0, montage_imgs.shape[1], imgs.shape[2]):
            y1, x1 = y0 + imgs.shape[1], x0 + imgs.shape[2]
            idx = next(idxs)
            montage_imgs[y0:y1, x0:x1] = imgs[idx]
            montage_msks[y0:y1, x0:x1] = msks[idx]

    return montage_imgs, montage_msks
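
A hypothetical call to isbi_get_data_montage(); the directory paths and grid size are placeholders. The loops above expect 25 multi-page TIFFs named 0.tif through 24.tif in each directory, and nb_rows * nb_cols must not exceed the total number of slices.

import numpy as np

rng = np.random.RandomState(42)
montage_imgs, montage_msks = isbi_get_data_montage(imgs_path='data/isbi/images',
                                                   msks_path='data/isbi/labels',
                                                   nb_rows=5,
                                                   nb_cols=5,
                                                   rng=rng)
print(montage_imgs.shape, montage_msks.shape)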
Example #4
    def save_config(self):
        logger = logging.getLogger(funcname())

        if self.config['checkpoint_path_config']:
            logger.info('Saving model config to %s.' %
                        self.config['checkpoint_path_config'])
            with open(self.config['checkpoint_path_config'], 'wb') as f:
                pickle.dump(self.config, f)

        return
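
Example #7 below calls a load_config() counterpart whose implementation is not shown here. A hypothetical inverse of the pickling above could look like the following; the real method may differ.

    def load_config(self, path):
        with open(path, 'rb') as f:
            self.config = pickle.load(f)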
Example #5
    def train(self):

        logger = logging.getLogger(funcname())

        gen_trn = self.batch_gen_trn(imgs=self.imgs_trn,
                                     msks=self.msks_trn,
                                     batch_size=self.config['batch_size'],
                                     transform=self.config['transform_train'])
        gen_val = self.batch_gen_trn(imgs=self.imgs_val,
                                     msks=self.msks_val,
                                     batch_size=self.config['batch_size'],
                                     transform=self.config['transform_train'])

        cb = [
            ReduceLROnPlateau(monitor='loss',
                              factor=0.9,
                              patience=5,
                              cooldown=3,
                              min_lr=1e-5,
                              verbose=1),
            ReduceLROnPlateau(monitor='val_loss',
                              factor=0.9,
                              patience=5,
                              cooldown=3,
                              min_lr=1e-5,
                              verbose=1),
            EarlyStopping(monitor='val_loss',
                          min_delta=1e-3,
                          patience=15,
                          verbose=1,
                          mode='min'),
            ModelCheckpoint(self.checkpoint_path + '/weights_loss_val.weights',
                            monitor='val_loss',
                            save_best_only=True,
                            verbose=1),
            ModelCheckpoint(self.checkpoint_path + '/weights_loss_trn.weights',
                            monitor='loss',
                            save_best_only=True,
                            verbose=1)
        ]

        logger.info('Training for %d epochs.' % self.config['nb_epoch'])

        history = self.net.fit_generator(generator=gen_trn,
                                         steps_per_epoch=100,
                                         epochs=self.config['nb_epoch'],
                                         validation_data=gen_val,
                                         validation_steps=20,
                                         verbose=1,
                                         callbacks=cb)

        return history
Example #6
def isbi_get_test_data_montage(imgs_path, nb_rows, nb_cols, rng):
    '''Reads the test images and arranges them in a montage for prediction.'''
    logger = logging.getLogger(funcname())

    imgs = tiff.imread(imgs_path) / 255
    montage_imgs = np.empty((nb_rows * imgs.shape[1], nb_cols * imgs.shape[2]),
                            dtype=np.float32)
    idxs = np.arange(imgs.shape[0])
    rng.shuffle(idxs)
    idxs = iter(idxs)

    for y0 in range(0, montage_imgs.shape[0], imgs.shape[1]):
        for x0 in range(0, montage_imgs.shape[1], imgs.shape[2]):
            y1, x1 = y0 + imgs.shape[1], x0 + imgs.shape[2]
            idx = next(idxs)
            montage_imgs[y0:y1, x0:x1] = imgs[idx]

    return montage_imgs
Example #7
def submit(args):
    logger = logging.getLogger(funcname())

    model = UNet()

    if args['model']:
        logger.info('Loading model from %s.' % args['model'])
        model.load_config(args['model'])

    # Get the checkpoint name before tweaking input shape, etc.
    chkpt_name = model.checkpoint_name

    model.config['input_shape'] = (model.config['img_shape'] +
                                   model.config['input_shape'][-1:])
    model.config['output_shape'] = (model.config['img_shape'] +
                                    model.config['output_shape'][-1:])
    model.config['output_shape_onehot'] = (model.config['img_shape'] +
                                           model.config['output_shape_onehot'][-1:])

    model.compile()
    model.net.summary()

    if args['net']:
        logger.info('Loading saved weights from %s.' % args['net'])
        model.net.load_weights(args['net'])

    logger.info('Loading testing images...')
    img_stack = tiff.imread('data/test-volume.tif')
    X_batch, coords = model.batch_gen_submit(img_stack)

    logger.info('Making predictions on batch...')
    prd_batch = model.net.predict(X_batch,
                                  batch_size=model.config['batch_size'])

    logger.info('Reconstructing images...')
    prd_stack = np.empty(img_stack.shape)
    for prd_wdw, (img_idx, y0, y1, x0, x1) in zip(prd_batch, coords):
        prd_stack[img_idx, y0:y1, x0:x1] = prd_wdw.reshape(y1 - y0, x1 - x0)
    prd_stack = prd_stack.astype('float32')

    logger.info('Saving full size predictions...')
    tiff.imsave(chkpt_name + '.submission.tif', prd_stack)
    logger.info('Done - saved file to %s.' % (chkpt_name + '.submission.tif'))
Example #8
def train(args):

    logger = logging.getLogger(funcname())

    model = UNet()
    model.config['checkpoint_path_config'] = model.checkpoint_name + '.config'
    model.config['checkpoint_path_history'] = model.checkpoint_name + '.history'
    model.config['transform_train'] = True
    model.config['nb_epoch'] = 250

    np.random.seed(model.config['seed'])
    model.load_data()
    model.save_config()
    model.compile()
    model.net.summary()
    if args['net']:
        logger.info('Loading saved weights from %s.' % args['net'])
        model.net.load_weights(args['net'])

    model.train()
    logger.info(model.evaluate())
    model.save_config()
    return
Example #9
    def train(self):
        logger = logging.getLogger(funcname())

        gen_trn = self.batch_gen(imgs=self.imgs_montage_trn,
                                 msks=self.msks_montage_trn,
                                 infinite=True,
                                 re_seed=True,
                                 batch_size=self.config['batch_size'],
                                 transform=self.config['transform_train'])
        gen_val = self.batch_gen(imgs=self.imgs_montage_val,
                                 msks=self.msks_montage_val,
                                 infinite=True,
                                 re_seed=True,
                                 batch_size=self.config['batch_size'])

        cb = []
        cb.append(
            ReduceLROnPlateau(monitor='val_loss',
                              factor=0.5,
                              patience=3,
                              cooldown=5,
                              min_lr=1e-8,
                              verbose=1))
        cb.append(
            EarlyStopping(monitor='val_loss',
                          min_delta=1e-3,
                          patience=self.config['early_stop_patience'],
                          verbose=1,
                          mode='min'))
        cb.append(
            ModelCheckpoint(self.checkpoint_name + '_val_loss.net',
                            monitor='val_loss',
                            save_best_only=True,
                            verbose=1))
        cb.append(
            ModelCheckpoint(self.checkpoint_name + '_trn_loss.net',
                            monitor='loss',
                            save_best_only=True,
                            verbose=1))
        cb.append(
            TensorBoard(log_dir=self.checkpoint_name,
                        histogram_freq=0,
                        batch_size=1,
                        write_graph=True,
                        write_grads=False,
                        write_images=True,
                        update_freq='epoch'))
        history_plot_cb = KerasHistoryPlotCallback()
        history_plot_cb.file_name = self.checkpoint_name + '.history.png'
        cb.append(history_plot_cb)

        logger.info('Training for %d epochs.' % self.config['nb_epoch'])

        result = self.net.fit_generator(
            generator=gen_trn,
            steps_per_epoch=self.config['steps'],
            epochs=self.config['nb_epoch'],
            validation_data=gen_val,
            validation_steps=100,
            callbacks=cb,
            verbose=1)

        self.history = result.history
        if self.config['checkpoint_path_history'] is not None:
            logger.info('Saving history to %s.' %
                        self.config['checkpoint_path_history'])
            with open(self.config['checkpoint_path_history'], 'wb') as f:
                pickle.dump(self.history, f)

        return
Example #10
    def train(self):

        logger = logging.getLogger(funcname())

        gen_trn = self.batch_gen(imgs=self.imgs_montage_trn,
                                 msks=self.msks_montage_trn,
                                 infinite=True,
                                 re_seed=True,
                                 batch_size=self.config['batch_size'],
                                 transform=self.config['transform_train'])
        gen_val = self.batch_gen(imgs=self.imgs_montage_val,
                                 msks=self.msks_montage_val,
                                 infinite=True,
                                 re_seed=True,
                                 batch_size=self.config['batch_size'])

        cb = []
        cb.append(
            ReduceLROnPlateau(monitor='val_loss',
                              factor=0.5,
                              patience=5,
                              cooldown=3,
                              min_lr=1e-6,
                              verbose=1))
        cb.append(
            EarlyStopping(monitor='val_loss',
                          min_delta=1e-3,
                          patience=15,
                          verbose=1,
                          mode='min'))
        cb.append(
            ModelCheckpoint(self.checkpoint_name + '_val_loss.net',
                            monitor='val_loss',
                            save_best_only=True,
                            verbose=1))
        cb.append(
            ModelCheckpoint(self.checkpoint_name + '_trn_loss.net',
                            monitor='loss',
                            save_best_only=True,
                            verbose=1))

        history_plot_cb = KerasHistoryPlotCallback()
        history_plot_cb.file_name = self.checkpoint_name + '.history.png'
        cb.append(history_plot_cb)

        logger.info('Training for %d epochs.' % self.config['nb_epoch'])

        # Note: this call uses the Keras 1.x fit_generator keyword names
        # (nb_epoch, samples_per_epoch, nb_val_samples); see the Keras 2
        # sketch after this example.
        result = self.net.fit_generator(
            nb_epoch=self.config['nb_epoch'],
            samples_per_epoch=max(self.config['batch_size'] * 50, 2048),
            generator=gen_trn,
            nb_val_samples=max(self.config['batch_size'] * 25, 1024),
            validation_data=gen_val,
            initial_epoch=0,
            callbacks=cb,
            class_weight='auto',
            verbose=1)

        self.history = result.history

        if self.config['checkpoint_path_history'] is not None:
            logger.info('Saving history to %s.' %
                        self.config['checkpoint_path_history'])
            with open(self.config['checkpoint_path_history'], 'wb') as f:
                pickle.dump(self.history, f)

        return
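
The fit_generator() call in Example #10 uses the Keras 1.x keyword names, while Examples #5 and #9 use the Keras 2 names. A hedged sketch of an equivalent call under the Keras 2 signature, assuming the same generators, callbacks, and config (class_weight is dropped; sample counts become step counts by dividing by the batch size):

        # nb_epoch          -> epochs
        # samples_per_epoch -> steps_per_epoch * batch_size
        # nb_val_samples    -> validation_steps * batch_size
        batch_size = self.config['batch_size']
        result = self.net.fit_generator(generator=gen_trn,
                                        steps_per_epoch=max(50, 2048 // batch_size),
                                        epochs=self.config['nb_epoch'],
                                        validation_data=gen_val,
                                        validation_steps=max(25, 1024 // batch_size),
                                        callbacks=cb,
                                        verbose=1)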