Beispiel #1
0
def tfreloader(mode, ep, bs, ctr, cte):
    """Create a ``data_input.DataSet`` for ``<data_dir>/<mode>.tfrecords``.

    ``ctr`` is forwarded as the second positional argument of ``DataSet`` when
    ``mode`` is ``'train'``; any other mode forwards ``cte`` instead.
    ``data_dir`` and ``data_input`` are looked up at module scope.
    """
    record_path = data_dir + '/' + mode + '.tfrecords'
    selected_ct = ctr if mode == 'train' else cte
    reader = data_input.DataSet(bs, selected_ct, ep=ep, mode=mode, filename=record_path)
    return reader
Beispiel #2
0
def tfreloader(mode, ep, bs, cls, ctr, cte, cva, data_dir):
    """Create a ``data_input.DataSet`` for ``<data_dir>/<mode>.tfrecords``.

    The second positional ``DataSet`` argument is chosen by mode:
    ``ctr`` for ``'train'``, ``cte`` for ``'test'``, and ``cva`` for any
    other mode. ``ep``, ``cls`` and ``mode`` are passed through as keywords.
    """
    record_path = data_dir + '/' + mode + '.tfrecords'
    selected_ct = ctr if mode == 'train' else (cte if mode == 'test' else cva)
    reader = data_input.DataSet(
        bs, selected_ct, ep=ep, cls=cls, mode=mode, filename=record_path)
    return reader
Beispiel #3
0
    def train(self, data_dir, out_dir, slides, top_k=10,
              valid_data_path=None, sample_rate=None, n_epoch=10, batch_size=128, save=True):
        """Multiple-instance-learning (MIL) training loop.

        Each epoch has two phases:

        1. Inference: each slide file ``data_dir + '/' + slide`` is streamed
           through ``self.inference``; column 1 of its output is used as the
           tile score, and the ``top_k`` highest-scoring tiles of the slide
           (with their labels) are kept.
        2. Training: the kept tiles of all slides are shuffled, batched by
           ``batch_size`` and used to run ``self.train_op``.

        Args:
            data_dir: directory holding the slide files.
            out_dir: directory for the checkpoint written when ``save`` is true.
            slides: iterable of slide file names relative to ``data_dir``.
            top_k: number of highest-scoring tiles kept per slide.
            valid_data_path: optional validation input; one batch of it is
                evaluated after each epoch.
            sample_rate: if truthy, each slide tile is considered for inference
                only when its random draw in [0, 1) is below this value;
                otherwise every tile is used.
            n_epoch: number of epochs.
            batch_size: training batch size (also the inference batch size
                when ``sample_rate`` is set).
            save: save a checkpoint after every epoch and once at the end.

        A ``KeyboardInterrupt`` stops training early; the end-of-training
        report and the final save still run.
        """
        if save:
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)
            outfile = os.path.join(os.path.abspath(out_dir), "{}_MIL_{}_lr_{}_drop_{}".format(
                str(self.datetime), str(self.architecture),
                str(self.learning_rate), str(self.dropout)))

        now = datetime.now().strftime(r"%y-%m-%d %H:%M:%S")

        print("------- Training begin: {} -------\n".format(now))

        slide_fn = tf.placeholder(tf.string, shape=None)
        slide_dataset = data_input.DataSet(inputs=slide_fn, batch_size=64)
        rand_ph = tf.placeholder(tf.float32, shape=None)

        def sample_slide(ds, rand, rate):
            """Return an initializable iterator over the tiles of one slide.

            Without a rate, ``ds.shuffled_iter()`` is used (every tile). With a
            rate, the j-th tile is kept when ``rand[j] < rate`` and the survivors
            are batched by ``batch_size``.
            """
            if not rate:
                return ds.shuffled_iter()

            def sample_fn(data, rind):
                return rind < rate

            dat = ds.get_data()
            rds = tf.data.Dataset.from_tensor_slices(rand)
            # Dataset.zip stops at the shorter input, so at most len(rand) tiles
            # of a slide can be sampled.
            dat_sample = tf.data.Dataset.zip((dat, rds))
            dat_sample = dat_sample.filter(sample_fn)
            dat_sample = dat_sample.map(lambda dat_, rds_: dat_)
            return dat_sample.batch(batch_size=batch_size, drop_remainder=False).make_initializable_iterator()

        slide_tfr_iter = sample_slide(slide_dataset, rand_ph, sample_rate)
        next_tfr_batch = slide_tfr_iter.get_next()

        trn_img_ph = tf.placeholder(tf.uint8)
        trn_lab_ph = tf.placeholder(tf.int64)
        trn_ds = tf.data.Dataset.from_tensor_slices((trn_img_ph, trn_lab_ph))
        trn_ds = trn_ds.shuffle(buffer_size=2000).batch(batch_size=batch_size)
        trn_iter = trn_ds.make_initializable_iterator()
        next_trn_batch = trn_iter.get_next()

        if valid_data_path:
            valid_data = data_input.DataSet(inputs=valid_data_path, batch_size=64)
            valid_iter = valid_data.shuffled_iter()
            next_val_batch = valid_iter.get_next()

        try:
            for epoch in range(n_epoch):
                # ---- Inference run: collect the top-scoring tiles of each slide ----
                now = datetime.now().strftime(r"%y-%m-%d %H:%M:%S")
                print('----------epoch {}: {}----------'.format(epoch, now))
                trn_img_subsets = []
                trn_lab_subsets = []

                for slide in slides:
                    slide_prob = []
                    slide_img = []
                    slide_lab = []
                    slide_counter = 0
                    slide_path = data_dir + '/' + slide

                    # rand_ph is only consumed when sample_rate is set; its
                    # 200000 draws bound how many tiles per slide can be sampled.
                    self.sesh.run(slide_tfr_iter.initializer,
                                  feed_dict={slide_fn: slide_path, rand_ph: np.random.uniform(0., 1., 200000)})

                    while True:
                        try:
                            imgs, labs = self.sesh.run(next_tfr_batch)
                        except tf.errors.OutOfRangeError:
                            break

                        batch_pred = self.inference(imgs)[:, 1]
                        batch_top_ind = batch_pred.argsort()[-top_k:]  # indices of the k largest scores
                        slide_counter += imgs.shape[0]

                        for top_ind in batch_top_ind:
                            top_prob = batch_pred[top_ind]
                            # Accept while fewer than top_k candidates are held, or when
                            # the tile ties/beats the current k-th best. (The previous
                            # slide_counter-based test could index into a list shorter
                            # than top_k and raise IndexError.)
                            if (len(slide_prob) < top_k or
                                    top_prob >= np.sort(np.asarray(slide_prob))[-top_k]):
                                slide_prob.append(top_prob)
                                slide_img.append(imgs[top_ind])
                                slide_lab.append(labs[top_ind])

                        # Prune back to the top_k highest scores seen so far.
                        drop_ind = np.asarray(slide_prob).argsort()[:-top_k]
                        for ind in sorted(drop_ind, reverse=True):
                            del slide_prob[ind]
                            del slide_img[ind]
                            del slide_lab[ind]

                    print('{}: {} tiles inferred from slide.'.format(slide, slide_counter))
                    print('Top {} probabilities: '.format(top_k))
                    print(slide_prob)

                    trn_img_subsets.extend(slide_img)
                    trn_lab_subsets.extend(slide_lab)

                    print('Filtered images: {}'.format(len(trn_img_subsets)))

                # ---- Training run on the selected tiles ----
                trn_img_subsets = np.asarray(trn_img_subsets)
                trn_lab_subsets = np.asarray(trn_lab_subsets)
                self.sesh.run(trn_iter.initializer, feed_dict={trn_img_ph: trn_img_subsets,
                                                               trn_lab_ph: trn_lab_subsets})

                err_train = 0
                epoch_batch = 0

                while True:
                    try:
                        train_X, train_Y = self.sesh.run(next_trn_batch)
                    except tf.errors.OutOfRangeError:
                        break
                    train_X = utils.input_preprocessing(train_X, model=self.architecture)
                    feed = {self.x_in: train_X, self.y_in: train_Y}
                    fetches = [self.merged_summary, self.logits, self.pred,
                               self.cost, self.global_step, self.train_op]
                    summary, logits, pred, cost, i, _ = self.sesh.run(fetches=fetches, feed_dict=feed)
                    err_train += cost
                    epoch_batch += 1

                # Read the step from the graph: the loop-local ``i`` is unbound
                # when no batch ran, and max() avoids dividing by zero.
                i = self.global_step.eval(session=self.sesh)
                print('MIL training epoch {} finished.'.format(epoch))
                print('Global step {}: average train error {}'.format(i, err_train / max(epoch_batch, 1)))

                self.epoch_trained = epoch

                if valid_data_path:
                    self.sesh.run(valid_iter.initializer)
                    valid_X, valid_Y = self.sesh.run(next_val_batch)
                    valid_X = utils.input_preprocessing(valid_X, model=self.architecture)
                    feed = {self.x_in: valid_X, self.y_in: valid_Y, self.training_status: False}
                    fetches = [self.merged_summary, self.pred,
                               self.cost, self.global_step]
                    summary, pred, cost, i = self.sesh.run(fetches=fetches, feed_dict=feed)
                    self.validation_logger.add_summary(summary, i)
                    print('MIL training epoch {} validation cost: {}'.format(self.epoch_trained, cost))
                if save:
                    saver.save(self.sesh, outfile, global_step=None)
                    print('Trained model saved to {}'.format(outfile))

            try:
                self.train_logger.flush()
                self.train_logger.close()
                self.validation_logger.flush()
                self.validation_logger.close()

            except AttributeError:  # not logging
                print('Not logging')

        except KeyboardInterrupt:
            pass

        now = datetime.now().strftime(r"%y-%m-%d %H:%M:%S")
        print("------- Training end: {} -------\n".format(now), flush=True)
        # getattr: an interrupt before the first finished epoch leaves epoch_trained unset
        print('Epochs trained: {}'.format(str(getattr(self, 'epoch_trained', None))))
        i = self.global_step.eval(session=self.sesh)
        print('Global steps: {}'.format(str(i)))

        if save:
            saver.save(self.sesh, outfile, global_step=None)
            print('Trained model saved to {}'.format(outfile))
Beispiel #4
0
    def pre_train(self, pretrain_data_path, out_dir, valid_data_path=None, n_epoch=10, batch_size=128, save=True):
        """
        Pretrain the model with tile level labels before MIL (the ``train`` method).

        Args:
            pretrain_data_path: input for ``data_input.DataSet`` with tile-level labels.
            out_dir: directory for the checkpoint written when ``save`` is true.
            valid_data_path: optional validation input. After each epoch one
                batch of it is evaluated, and pre-training stops as soon as the
                validation cost is higher than the lowest cost seen so far.
            n_epoch: maximum number of epochs.
            batch_size: batch size for both training and validation data.
            save: save a checkpoint once pre-training ends.

        A ``KeyboardInterrupt`` stops pre-training early; the end report and the
        final save still run.
        """
        pretrain_data = data_input.DataSet(inputs=pretrain_data_path, batch_size=batch_size)
        if save:
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)
            outfile = os.path.join(os.path.abspath(out_dir), "{}_preMIL_{}_lr_{}_drop_{}".format(
                str(self.datetime), str(self.architecture),
                str(self.learning_rate), str(self.dropout)))

        valid_costs = []
        pretrain_iter = pretrain_data.shuffled_iter()
        next_pretrn_batch = pretrain_iter.get_next()

        if valid_data_path:
            preval_data = data_input.DataSet(inputs=valid_data_path, batch_size=batch_size)
            preval_iter = preval_data.shuffled_iter()
            next_preval_batch = preval_iter.get_next()

        now = datetime.now().strftime(r"%y-%m-%d %H:%M:%S")

        print("------- Pre-training begin: {} -------\n".format(now))
        try:
            for epoch in range(n_epoch):

                err_train = 0
                epoch_batch = 0
                self.sesh.run(pretrain_iter.initializer)

                while True:
                    try:
                        pretrain_X, pretrain_Y = self.sesh.run(next_pretrn_batch)
                    except tf.errors.OutOfRangeError:
                        break
                    pretrain_X = utils.input_preprocessing(pretrain_X, model=self.architecture)
                    feed = {self.x_in: pretrain_X, self.y_in: pretrain_Y}
                    fetches = [self.merged_summary, self.logits, self.pred,
                               self.cost, self.global_step, self.train_op]
                    summary, logits, pred, cost, i, _ = self.sesh.run(fetches=fetches, feed_dict=feed)
                    self.train_logger.add_summary(summary, i)
                    err_train += cost
                    epoch_batch += 1

                i = self.global_step.eval(session=self.sesh)
                print('Epoch {} finished.'.format(epoch))
                # max() guards against an epoch that produced no batches
                print('Global step {}: average train error {}'.format(i, err_train / max(epoch_batch, 1)))
                # Recorded once per finished epoch (it was previously re-set after every batch).
                self.epoch_pretrained = epoch

                if valid_data_path:
                    self.sesh.run(preval_iter.initializer)
                    valid_X, valid_Y = self.sesh.run(next_preval_batch)
                    valid_X = utils.input_preprocessing(valid_X, model=self.architecture)
                    feed = {self.x_in: valid_X, self.y_in: valid_Y, self.training_status: False}
                    fetches = [self.merged_summary, self.pred,
                               self.cost, self.global_step]
                    summary, pred, cost, i = self.sesh.run(fetches=fetches, feed_dict=feed)
                    self.validation_logger.add_summary(summary, i)
                    print('Tile pre-training epoch {} validation cost: {}'.format(self.epoch_pretrained, cost))
                    valid_costs.append(cost)
                    min_valid_cost = min(valid_costs)

                    # Stop at the first epoch whose cost is not a new minimum (no patience window).
                    if cost > min_valid_cost:
                        print('Validation cost reached plateau. Pre-training stopped.')
                        break

            try:
                # NOTE(review): summaries above are written to self.train_logger, but
                # self.pretrain_logger is what gets flushed here — confirm which is intended.
                self.pretrain_logger.flush()
                self.pretrain_logger.close()
                self.validation_logger.flush()
                self.validation_logger.close()

            except AttributeError:  # not logging
                print('Not logging')

        except KeyboardInterrupt:
            pass

        now = datetime.now().strftime(r"%y-%m-%d %H:%M:%S")
        print("------- Pre-training end: {} -------\n".format(now), flush=True)
        # getattr: an interrupt before the first finished epoch leaves epoch_pretrained unset
        print('Epochs trained: {}'.format(str(getattr(self, 'epoch_pretrained', None))))
        i = self.global_step.eval(session=self.sesh)
        print('Global steps: {}'.format(str(i)))

        if save:
            saver.save(self.sesh, outfile, global_step=None)
            print('Pre-trained model saved to {}'.format(outfile))
Beispiel #5
0
def tfreloader(mode, ep, bs):
    """Return a ``data_input.DataSet`` that reads ``<data_dir>/<mode>.tfrecords``.

    ``data_dir`` is taken from module scope; ``bs`` is passed positionally and
    ``ep``/``mode`` as keywords.
    """
    record_path = data_dir + '/' + mode + '.tfrecords'
    reader = data_input.DataSet(bs, ep=ep, mode=mode, filename=record_path)
    return reader
Beispiel #6
0
def loader(images, bs, ct):
    """Wrap in-memory ``images`` in a ``data_input.DataSet``.

    ``bs`` and ``ct`` are forwarded positionally (``bs`` is presumably the
    batch size — confirm against ``data_input.DataSet``).
    """
    wrapped = data_input.DataSet(bs, ct, images=images)
    return wrapped