示例#1
0
class DRAM(object):
    def __init__(self, config):
        """Store the configuration, then build the data pipeline and the model.

        Args:
            config: Configuration object; kept on ``self.config`` and read by
                the other methods of this class.
        """
        self.config = config
        # Data first: model_init() finishes by calling setup_optimization(),
        # which reads self.generator created by data_init().
        for build_stage in (self.data_init, self.model_init):
            build_stage()

    def data_init(self):
        """Create the dataset wrapper and the batch generator built on it.

        Sets ``self.dataset`` and ``self.generator``; both are constructed
        from ``self.config``.
        """
        print("\nData init")
        dataset = Dataset(self.config)
        self.dataset = dataset
        self.generator = Generator(self.config, dataset)

    def model_init(self):
        """Build the recurrent-attention TensorFlow graph and open a session.

        Creates the input placeholders, the glimpse and location networks, an
        LSTM unrolled over ``config.num_glimpses`` inputs, the per-step
        baseline and classification heads, the hybrid loss and its clipped
        gradients. It then calls ``setup_optimization()`` and initialises all
        global variables in a new ``tf.Session`` stored on ``self.session``.

        The configuration is read from ``self.config``, which ``__init__``
        sets before calling this method.
        """

        self.rnn_cell = tf.contrib.rnn
        self.regularizer = tf.contrib.layers.l2_regularizer(
            scale=self.config.regularizer)
        self.initializer = tf.contrib.layers.xavier_initializer()
        self.images_ph = tf.placeholder(
            tf.float32,
            [None, self.config.input_shape, self.config.input_shape, 3])
        self.labels_ph = tf.placeholder(tf.int64, [None])
        # Dynamic batch size: leading dimension of the image placeholder.
        self.N = tf.shape(self.images_ph)[0]

        # ------- GlimpseNet / LocNet -------

        with tf.variable_scope('glimpse_net'):
            self.gl = ConvGlimpseNetwork(self.config, self.images_ph)

        with tf.variable_scope('loc_net'):
            self.loc_net = LocNet(self.config)

        self.init_loc = tf.zeros(shape=[self.N, 2], dtype=tf.float32)
        # Uses the same scope path that rnn_decoder() opens around
        # get_next_input(), with AUTO_REUSE, so the first glimpse and the
        # later ones are built under one variable scope.
        with tf.variable_scope("rnn_decoder/loop_function",
                               reuse=tf.AUTO_REUSE):
            self.init_glimpse = self.gl(self.init_loc)

        # Only the first input is a real tensor. The 0 entries are fillers:
        # rnn_decoder() replaces them with get_next_input() from step 1 on.
        self.inputs = [self.init_glimpse]
        self.inputs.extend([0] * (self.config.num_glimpses - 1))

        # ------- Recurrent network -------

        def get_next_input(output, i):
            """Pick the next location from `output` and return its glimpse.

            Also records the location mean, the sampled location and the
            glimpse network's `glimpse` attribute for the loss and for
            inspection. `i` is the step index and is not used.
            """

            loc, loc_mean = self.loc_net(output)
            gl_next = self.gl(loc)

            self.loc_mean_arr.append(loc_mean)
            self.sampled_loc_arr.append(loc)
            self.glimpses.append(self.gl.glimpse)

            return gl_next

        def rnn_decoder(decoder_inputs,
                        initial_state,
                        cell,
                        loop_function=None):
            """Unroll `cell` over `decoder_inputs`, feeding outputs back.

            When `loop_function` is given, every input after the first is
            replaced by `loop_function(previous_output, step_index)`.
            Returns the list of per-step outputs and the final state.
            """

            with tf.variable_scope("rnn_decoder"):
                state = initial_state
                outputs = []
                prev = None

                for i, inp in enumerate(decoder_inputs):
                    if loop_function is not None and prev is not None:
                        with tf.variable_scope("loop_function",
                                               reuse=tf.AUTO_REUSE):
                            inp = loop_function(prev, i)

                    # Share the cell's variables across all steps after the
                    # first one.
                    if i > 0:
                        tf.get_variable_scope().reuse_variables()

                    output, state = cell(inp, state)
                    outputs.append(output)

                    if loop_function is not None:
                        prev = output

            return outputs, state

        # These lists must exist before rnn_decoder() runs, because
        # get_next_input() appends to them. Each starts with the step-0 entry.
        self.loc_mean_arr = [self.init_loc]
        self.sampled_loc_arr = [self.init_loc]
        self.glimpses = [self.gl.glimpse]

        self.lstm_cell = self.rnn_cell.LSTMCell(self.config.cell_size,
                                                state_is_tuple=True,
                                                activation=tf.nn.tanh,
                                                forget_bias=1.)
        self.init_state = self.lstm_cell.zero_state(self.N, tf.float32)
        self.outputs, self.rnn_state = rnn_decoder(
            self.inputs,
            self.init_state,
            self.lstm_cell,
            loop_function=get_next_input)

        # ------- Classification -------

        # One baseline prediction per step, from a dense layer shared across
        # steps through AUTO_REUSE.
        # NOTE(review): units=2 gives two baseline values per example, so
        # self.baselines is presumably (2, N, T) and only lines up with the
        # (N, T) rewards through broadcasting - confirm whether units=1 was
        # intended.
        baselines = []
        for t, output in enumerate(self.outputs):
            with tf.variable_scope('baseline', reuse=tf.AUTO_REUSE):
                baseline_t = tf.layers.dense(
                    inputs=output,
                    units=2,
                    kernel_initializer=self.initializer)
            baseline_t = tf.squeeze(baseline_t)
            baselines.append(baseline_t)

        baselines = tf.stack(baselines)
        self.baselines = tf.transpose(baselines)

        # Per-step class probabilities, for inspection only: stop_gradient
        # keeps them from contributing gradients to the shared 'FCCN' layer.
        with tf.variable_scope('classification', reuse=tf.AUTO_REUSE):
            self.class_prob_arr = []
            for t, op in enumerate(self.outputs):
                self.glimpse_logit = tf.layers.dense(
                    inputs=op,
                    units=self.config.num_classes,
                    kernel_initializer=self.initializer,
                    name='FCCN',
                    reuse=tf.AUTO_REUSE)
                self.glimpse_logit = tf.stop_gradient(self.glimpse_logit)
                self.glimpse_logit = tf.nn.softmax(self.glimpse_logit)
                self.class_prob_arr.append(self.glimpse_logit)
            self.class_prob_arr = tf.stack(self.class_prob_arr, axis=1)

        # The trained prediction uses the last RNN output and the same 'FCCN'
        # layer as the per-step probabilities above.
        self.output = self.outputs[-1]
        with tf.variable_scope('classification', reuse=tf.AUTO_REUSE):
            self.logits = tf.layers.dense(inputs=self.output,
                                          units=self.config.num_classes,
                                          kernel_initializer=self.initializer,
                                          name='FCCN',
                                          reuse=tf.AUTO_REUSE)

            self.softmax = tf.nn.softmax(self.logits)

        # Location tensors rearranged so the batch dimension comes first.
        self.sampled_locations = tf.concat(self.sampled_loc_arr, axis=0)
        self.mean_locations = tf.concat(self.loc_mean_arr, axis=0)
        self.sampled_locations = tf.reshape(
            self.sampled_locations, (self.config.num_glimpses, self.N, 2))
        self.sampled_locations = tf.transpose(self.sampled_locations,
                                              [1, 0, 2])
        self.mean_locations = tf.reshape(self.mean_locations,
                                         (self.config.num_glimpses, self.N, 2))
        self.mean_locations = tf.transpose(self.mean_locations, [1, 0, 2])
        # NOTE(review): both location lists already start with init_loc, so
        # this prefix appears to add the initial location a second time -
        # confirm this is what the plotting code expects.
        prefix = tf.expand_dims(self.init_loc, 1)
        self.sampled_locations = tf.concat([prefix, self.sampled_locations],
                                           axis=1)
        self.mean_locations = tf.concat([prefix, self.mean_locations], axis=1)
        self.glimpses = tf.stack(self.glimpses, axis=1)

        # Losses/reward

        def loglikelihood(mean_arr, sampled_arr, sigma):
            """Log-probability of the sampled locations under Normal(mean, sigma).

            Stacks both lists, sums the log-probabilities over axis 2 and
            transposes the result.
            """
            mu = tf.stack(mean_arr)
            sampled = tf.stack(sampled_arr)
            gaussian = tf.contrib.distributions.Normal(mu, sigma)
            logll = gaussian.log_prob(sampled)
            logll = tf.reduce_sum(logll, 2)
            logll = tf.transpose(logll)
            return logll

        self.xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=self.logits, labels=self.labels_ph)
        self.xent = tf.reduce_mean(self.xent)

        # Reward is 1 for a correct final prediction and 0 otherwise; it is
        # repeated for every glimpse.
        self.pred_labels = tf.argmax(self.logits, 1)
        self.reward = tf.cast(tf.equal(self.pred_labels, self.labels_ph),
                              tf.float32)
        self.rewards = tf.expand_dims(self.reward, 1)
        self.rewards = tf.tile(self.rewards, [1, self.config.num_glimpses])
        self.logll = loglikelihood(self.loc_mean_arr, self.sampled_loc_arr,
                                   self.config.loc_std)
        # The baseline is excluded from the gradient of the policy term; it is
        # trained only through baselines_mse below.
        self.advs = self.rewards - tf.stop_gradient(self.baselines)
        self.logllratio = tf.reduce_mean(self.logll * self.advs)

        # From here on self.reward is the batch-mean reward (a scalar).
        self.reward = tf.reduce_mean(self.reward)

        self.baselines_mse = tf.reduce_mean(
            tf.square((self.rewards - self.baselines)))
        self.var_list = tf.trainable_variables()

        self.loss = -self.logllratio + self.xent + self.baselines_mse
        self.grads = tf.gradients(self.loss, self.var_list)
        self.grads, _ = tf.clip_by_global_norm(self.grads,
                                               self.config.max_grad_norm)

        # Needs self.grads and self.var_list; creates global_step, so it has
        # to run before the variable initializer below.
        self.setup_optimization()

        # session
        self.session_config = tf.ConfigProto()
        self.session_config.gpu_options.visible_device_list = self.config.gpu
        self.session_config.gpu_options.allow_growth = True
        self.session = tf.Session(config=self.session_config)
        self.session.run(tf.global_variables_initializer())

    def setup_optimization(self):
        """Create the global step, the learning-rate schedule and the train op.

        The learning rate decays exponentially (rate 0.70, non-staircase) with
        a decay period of one epoch worth of steps, and is floored at
        ``config.lr_min``. The gradients in ``self.grads`` are applied to
        ``self.var_list`` with a Nesterov momentum optimizer.
        """
        self.global_step = tf.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0),
            trainable=False)

        n_training_ids = len(self.generator.training_ids)
        self.training_steps_per_epoch = int(
            n_training_ids // self.config.batch_size)
        print('Training Step Per Epoch:', self.training_steps_per_epoch)

        self.starter_learning_rate = self.config.lr_start
        decayed_lr = tf.train.exponential_decay(
            learning_rate=self.starter_learning_rate,
            global_step=self.global_step,
            decay_steps=self.training_steps_per_epoch,
            decay_rate=0.70,
            staircase=False)
        # Never let the schedule drop below the configured minimum.
        self.learning_rate = tf.maximum(decayed_lr, self.config.lr_min)

        # Alternative that was tried: tf.train.AdamOptimizer(self.learning_rate)
        self.optimizer = tf.train.MomentumOptimizer(self.learning_rate,
                                                    momentum=0.90,
                                                    use_nesterov=True)
        grads_and_vars = zip(self.grads, self.var_list)
        self.train_op = self.optimizer.apply_gradients(
            grads_and_vars, global_step=self.global_step)

    def setup_logger(self):
        """Build the summary/eval op dictionaries and create the Logger."""

        # Scalars that are both written as summaries and fetched at eval time.
        tracked_scalars = [
            ('reward', self.reward),
            ('hybrid_loss', self.loss),
            ('cross_entropy', self.xent),
            ('baseline_mse', self.baselines_mse),
            ('logllratio', self.logllratio),
            ('lr', self.learning_rate),
        ]

        self.summary_ops = {
            tag: tf.summary.scalar(tag, tensor)
            for tag, tensor in tracked_scalars
        }
        # An image summary of self.glimpses (reshaped to
        # [-1, glimpse_size, glimpse_size, 3], max_outputs=8) was previously
        # included under the key 'glimpses' and is currently disabled.

        self.eval_ops = {
            'labels': self.labels_ph,
            'pred_labels': self.pred_labels,
        }
        self.eval_ops.update(tracked_scalars)

        self.logger = Logger(self.config.logdir,
                             sess=self.session,
                             summary_ops=self.summary_ops,
                             global_step=self.global_step,
                             eval_ops=self.eval_ops,
                             n_verbose=self.config.n_verbose,
                             var_list=self.var_list)

    def train(self):
        """Run ``config.steps + 1`` training steps, then test on the last one.

        On the first step the image and metric output directories under
        ``config.logdir`` are wiped and recreated. Every step draws a batch
        from the generator, runs one optimisation step, prints the fetched
        values and logs them under the 'train' tag.
        """
        cfg = self.config

        print('\n\n\n------------ Starting training ------------  \nT -- %s x %s \n' \
              'Model:  %s glimpses, glimpse size %s x %s \n\n\n' % (
                  cfg.input_shape, cfg.input_shape, cfg.num_glimpses, cfg.glimpse_size,
                  cfg.glimpse_size))

        self.setup_logger()

        for i in range(cfg.steps + 1):

            if i == 0:
                # Start every run from empty output directories.
                output_dirs = (
                    cfg.logdir + '/image/locations',
                    cfg.logdir + '/image/trajectories',
                    cfg.logdir + '/metrics/ROCs_AUCs/',
                    cfg.logdir + '/metrics/PRs/',
                )
                for directory in output_dirs:
                    if os.path.exists(directory):
                        shutil.rmtree(directory)
                    os.makedirs(directory)

            self.logger.step = i

            images, labels = self.generator.generate()
            images = images.reshape(
                (-1, cfg.input_shape, cfg.input_shape, 3))
            feed_dict = {self.images_ph: images, self.labels_ph: labels[0]}

            # (printed name, tensor) pairs; train_op is what performs the
            # update, the rest is fetched for reporting.
            named_fetches = [
                ('output', self.output),
                ('rewards', self.rewards),
                ('reward', self.reward),
                ('labels', self.labels_ph),
                ('pred_labels', self.pred_labels),
                ('logits', self.logits),
                ('train_op', self.train_op),
                ('hybrid_loss', self.loss),
                ('cross_entropy', self.xent),
                ('baseline_mse', self.baselines_mse),
                ('logllratio', self.logllratio),
                ('lr', self.learning_rate),
                ('locations', self.loc_mean_arr),
            ]
            fetches = [tensor for _, tensor in named_fetches]
            values = self.session.run(fetches, feed_dict)
            fetched = dict(zip((name for name, _ in named_fetches), values))

            print('\n------ Step %s ------' % (i))
            for key in ('reward', 'labels', 'pred_labels', 'hybrid_loss',
                        'cross_entropy', 'baseline_mse', 'logllratio', 'lr'):
                print(key, fetched[key])
            # Only the location means of the last glimpse are printed.
            print('locations', fetched['locations'][-1])
            print('logits', fetched['logits'])
            self.logger.log('train', feed_dict=feed_dict)

            # Periodic validation is currently disabled:
            # if i > 0 and i % 100 == 0:
            #     self.eval(i)
            #     self.logger.log('val', feed_dict=feed_dict)

            if i == cfg.steps:
                self.test(i)

            # Periodic glimpse/trajectory plotting is currently disabled:
            # if i > 0 and i % 100 == 0:
            #     glimpse_images = self.session.run(self.glimpses, feed_dict)
            #     mean_locations = self.session.run(self.mean_locations, feed_dict)
            #     probs = self.session.run(self.class_prob_arr, feed_dict)
            #     plot_glimpses(config=self.config, glimpse_images=glimpse_images,
            #                   pred_labels=pred_labels, probs=probs,
            #                   sampled_loc=mean_locations, X=images,
            #                   labels=real_labels, file_name=loc_dir_name, step=i)
            #     plot_trajectories(config=self.config, locations=mean_locations,
            #                       X=images, labels=real_labels,
            #                       pred_labels=pred_labels,
            #                       file_name=traj_dir_name, step=i)
            #     self.logger.save()

    def eval(self, step):
        """Run ``evaluate`` on the model's own session, placeholders and softmax.

        Args:
            step: Step identifier forwarded to ``evaluate``.

        Returns:
            Whatever ``evaluate`` returns.
        """
        graph_handles = (self.session, self.images_ph, self.labels_ph,
                         self.softmax)
        return self.evaluate(*graph_handles, step)

    def evaluate(self, sess, images_ph, labels_ph, softmax, step):
        """Score the validation partition and write ROC plot and metrics.

        Args:
            sess: Session used to run ``softmax``.
            images_ph: Placeholder fed with each image batch.
            labels_ph: Placeholder fed with each label batch.
            softmax: Tensor evaluated to obtain the class scores.
            step: Step identifier passed on to ``metrics_ROCs`` and
                ``metrics``, which use it in their output file names.
        """
        cfg = self.config
        print('Evaluating (%s x %s) using %s glimpses' %
              (cfg.input_shape, cfg.input_shape, cfg.num_glimpses))
        self.X_val, self.y_val = self.dataset.convert_to_arrays(
            self.dataset._partition[0]['val'],
            size=cfg.sampling_size_val)
        print('Validation set has %s patients' % len(self.y_val))

        batch_size = cfg.eval_batch_size
        # Whole batches only; a trailing partial batch is not evaluated.
        n_batches = self.X_val.shape[0] // batch_size

        y_trues = []
        y_scores = []
        for batch_idx in tqdm(iter(range(n_batches))):
            images, labels_val = self.dataset.next_batch(
                self.X_val, self.y_val[0], batch_size, batch_idx)
            batch_scores = sess.run(
                softmax, feed_dict={images_ph: images, labels_ph: labels_val})
            y_trues.extend(labels_val)
            y_scores.extend(batch_scores)

        y_preds = np.argmax(y_scores, 1)
        score_matrix = np.array(y_scores)

        self.metrics_ROCs(y_trues, y_preds, score_matrix, step)
        self.metrics(y_trues, y_preds, step)

    def count_params(self):
        """Return ``count_parameters`` evaluated on the model's own session."""
        session = self.session
        return self.count_parameters(session)

    def count_parameters(self, sess):
        """Print every trainable variable's shape and size; return the total.

        Args:
            sess: Session used to fetch the current variable values.

        Returns:
            The summed ``size`` of all trainable variables.
        """
        separator = '-'.center(140, '-')
        names = [variable.name for variable in tf.trainable_variables()]

        total = 0
        # The variables are fetched by name in a single run call.
        for name, value in zip(names, sess.run(names)):
            print(separator)
            print('%s \t Shape: %s \t %s parameters' %
                  (name, value.shape, value.size))
            total += value.size

        print(separator)
        print('Total # parameters:\t\t %s \n\n' % (total))
        return total

    def metrics_ROCs(self, y_trues, y_preds, y_scores, step, stage=None):
        """Plot the ROC curve with its AUC and save it under the log directory.

        The figure is written to ``<config.logdir>/metrics/ROCs_AUCs/<step>``
        and closed afterwards, so repeated calls do not accumulate open
        matplotlib figures.

        Args:
            y_trues: Ground-truth labels handed to ``roc_curve``.
            y_preds: Predicted labels. Kept for interface compatibility; not
                used by this method.
            y_scores: Scores handed to ``roc_curve``. Unless ``stage`` is
                ``'test'``, column 1 is selected (``y_scores[:, 1]``), so the
                argument must support 2-D indexing in that case.
            step: Integer used as the output file name.
            stage: When equal to ``'test'``, ``y_scores`` is used as given
                instead of taking column 1.
        """
        if stage == 'test':
            scores = y_scores
        else:
            scores = y_scores[:, 1]
        fpr, tpr, _ = roc_curve(y_trues, scores)
        roc_auc = auc(fpr, tpr)

        fig = plt.figure()
        try:
            plt.plot(fpr,
                     tpr,
                     label='ROC curve (AUC = {0:0.2f})'.format(roc_auc),
                     color='navy',
                     linestyle=':',
                     linewidth=4)

            # Diagonal reference line.
            plt.plot([0, 1], [0, 1], 'k--', lw=2)
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('Receiving Operating Characteristic Curves')
            plt.legend(loc="lower right")
            plt.savefig(self.config.logdir + '/metrics/ROCs_AUCs/%i' % step)
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)
        return

    def metrics(self, y_trues, y_preds, step):
        """Compute accuracy, F1, recall and precision and write them to CSV.

        The values are saved, indexed by metric name, to
        ``<config.logdir>/metrics/metrics_<step>.csv``.

        Args:
            y_trues: Ground-truth labels.
            y_preds: Predicted labels.
            step: Integer used in the output file name.
        """
        scorers = [
            ('accuracy', accuracy_score),
            ('f1_score', f1_score),
            ('recall', recall_score),
            ('precision', precision_score),
        ]
        names = [name for name, _ in scorers]
        values = [scorer(y_trues, y_preds) for _, scorer in scorers]

        target = self.config.logdir + '/metrics/metrics_%i.csv' % step
        pd.DataFrame(data=np.array(values), index=names).to_csv(target)

    def load(self, checkpoint_dir):
        """Restore ``self.var_list`` from ``<checkpoint_dir>/checkpoints``.

        A saver is always created and stored on ``self.saver``. When no
        checkpoint state (or no checkpoint path) is found, nothing is restored
        and the method returns silently.

        Args:
            checkpoint_dir: Directory containing the ``checkpoints`` folder.
        """
        folder = os.path.join(checkpoint_dir, 'checkpoints')
        print('\nLoading model from <<{}>>.\n'.format(folder))

        self.saver = tf.train.Saver(self.var_list)
        ckpt = tf.train.get_checkpoint_state(folder)

        if not ckpt or not ckpt.model_checkpoint_path:
            return

        print(ckpt)
        self.saver.restore(self.session, ckpt.model_checkpoint_path)

    def patch_to_image(self, y_patches, proba=True):
        """Aggregate per-patch values into one value per image.

        Consecutive runs of ``config.sampling_size_test`` patches are treated
        as one image; a trailing incomplete run is ignored.

        Args:
            y_patches: Sequence of per-patch values (scores or labels).
            proba: When ``True``, each group is averaged along axis 0 and the
                result keeps any trailing dimensions. Otherwise each group is
                reduced to 1 if its overall mean exceeds 0.5, else 0.

        Returns:
            ``numpy`` array with one entry (or row) per image; a flat integer
            array in the non-probability case.
        """
        group_size = self.config.sampling_size_test
        n_images = int(len(y_patches) / group_size)
        groups = (y_patches[k * group_size:(k + 1) * group_size]
                  for k in range(n_images))

        if proba == True:
            return np.array([np.mean(group, axis=0) for group in groups])

        majority = np.array([np.mean(group) > 0.5 for group in groups])
        votes = majority.reshape((-1, 1)).astype(int)
        return np.asarray(votes.flatten())

    def test(self, step):
        return self.testing(self.session, self.images_ph, self.labels_ph,
                            self.softmax, step)

    def testing(self, sess, images_ph, labels_ph, softmax, step):
        """Score the test partition per image and write ROC plot and metrics.

        Patch-level labels and scores are collected batch by batch, then
        reduced to image level with ``patch_to_image`` before the metrics are
        computed.

        Args:
            sess: Session used to run ``softmax``.
            images_ph: Placeholder fed with each image batch.
            labels_ph: Placeholder fed with each label batch.
            softmax: Tensor evaluated to obtain the class scores.
            step: Step identifier passed on to ``metrics_ROCs`` and
                ``metrics``, which use it in their output file names.
        """
        cfg = self.config
        print('Testing (%s x %s) using %s glimpses' %
              (cfg.input_shape, cfg.input_shape, cfg.num_glimpses))
        print(self.dataset._partition[0]['test'])
        self.X_test, self.y_test = self.dataset.convert_to_arrays(
            self.dataset._partition[0]['test'],
            size=cfg.sampling_size_test)
        print('y_test', self.y_test)

        batch_size = cfg.test_batch_size
        # Whole batches only; a trailing partial batch is not evaluated.
        n_batches = self.X_test.shape[0] // batch_size

        patch_trues = []
        patch_scores = []
        for batch_idx in tqdm(iter(range(n_batches))):
            images, labels_test = self.dataset.next_batch(
                self.X_test, self.y_test[0], batch_size, batch_idx)

            print(labels_test)

            batch_scores = sess.run(
                softmax, feed_dict={images_ph: images, labels_ph: labels_test})
            patch_trues.extend(labels_test)
            patch_scores.extend(batch_scores)

        # Collapse patch-level results to one label / score row per image.
        y_trues = self.patch_to_image(patch_trues, proba=False)
        y_scores = self.patch_to_image(patch_scores, proba=True)
        y_preds = np.argmax(y_scores, 1)

        print('Test Set', self.dataset._partition[0]['test'])
        print(y_trues)
        print(y_preds)

        self.metrics_ROCs(y_trues, y_preds, y_scores, step)
        self.metrics(y_trues, y_preds, step)
示例#2
0
class Model(object):
    def __init__(self, config):
        """Store the configuration, then load the data and build the network.

        Args:
            config: Configuration object; kept on ``self.config`` and read by
                the other methods of this class.
        """
        self.config = config
        for build_stage in (self.data_init, self.model_init):
            build_stage()

    def data_init(self):
        """Create the dataset, the training generator and the val/test arrays.

        Sets ``self.dataset``, ``self.train_generator``, ``self.X_val`` /
        ``self.y_val`` and ``self.X_test`` / ``self.y_test``. The test labels
        are reduced with ``patch_to_image(..., proba=False)`` after loading.
        """
        print("\nData init")
        # A TCGA_Dataset(self.config) was used here previously.
        self.dataset = Dataset(self.config)
        self.train_generator = Generator(self.config, self.dataset).generate()

        def load_split(phase, size):
            # _partition[0] and _partition[1] are both indexed by phase name.
            return self.dataset.convert_to_arrays(
                self.dataset._partition[0][phase],
                self.dataset._partition[1][phase],
                phase=phase,
                size=size)

        self.X_val, self.y_val = load_split(
            'val', self.config.sampling_size_val)
        self.X_test, self.y_test = load_split(
            'test', self.config.sampling_size_test)

        self.y_test = self.patch_to_image(self.y_test, proba=False)

    def plot_ROCs(self, y_scores):
        """Plot the test-set ROC curve and save it to ``output/ROC_curve``.

        Ground truth is taken from ``self.y_test``. The legend shows the area
        under the curve computed with ``roc_auc_score``.

        Args:
            y_scores: Predicted scores for the test set, passed to
                ``roc_curve`` and ``roc_auc_score`` together with
                ``self.y_test``.
        """
        fig = plt.figure(figsize=(10, 10))
        try:
            y_true = self.y_test
            fpr, tpr, _ = roc_curve(y_true, y_scores)
            # Named auc_value so it cannot be confused with an `auc` function.
            auc_value = roc_auc_score(y_true, y_scores)
            plt.plot(fpr,
                     tpr,
                     lw=2,
                     c='r',
                     alpha=0.8,
                     label=r'ROC curve (AUC = %0.2f)' % auc_value)
            # Diagonal reference line.
            plt.plot([0, 1], [0, 1],
                     linestyle='--',
                     lw=2,
                     color='black',
                     label='Luck',
                     alpha=.8)
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title("ROC curve")
            plt.legend(loc="lower right")
            fig.savefig("output/ROC_curve")
        finally:
            # Release the figure even if scoring or saving fails.
            plt.close(fig)

    def plot_PRs(self, y_scores):
        """Plot the test-set precision-recall curve; save to ``output/PR_curve``.

        Ground truth is taken from ``self.y_test``. The legend shows the
        average precision, computed from the curve as the recall-weighted sum
        of precisions.

        Args:
            y_scores: Predicted scores for the test set, passed to
                ``precision_recall_curve`` together with ``self.y_test``.
        """
        fig = plt.figure(figsize=(10, 10))
        try:
            y_true = self.y_test
            precision, recall, _ = precision_recall_curve(y_true, y_scores)
            # AP = sum_n (R_n - R_{n-1}) * P_n. The recall values come back
            # in decreasing order, hence the leading minus sign.
            avg_precision = -np.sum(
                np.diff(recall) * np.asarray(precision)[:-1])
            plt.plot(recall,
                     precision,
                     lw=2,
                     c='b',
                     alpha=0.8,
                     label=r'PR curve (AP = %0.2f)' % avg_precision)
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.title("PR curve")
            plt.legend(loc="lower right")
            fig.savefig("output/PR_curve")
        finally:
            # Release the figure even if scoring or saving fails.
            plt.close(fig)

    def model_init(self):
        """Build a DenseNet169 backbone with a small sigmoid classifier head.

        Sets ``self.base_model`` (ImageNet weights, no top, 224x224x3 input)
        and ``self.model`` (backbone input to a single sigmoid unit).
        """
        print("\nModel init")
        backbone = DenseNet169(include_top=False,
                               weights='imagenet',
                               input_shape=(224, 224, 3),
                               pooling=None)
        self.base_model = backbone

        pooled = GlobalAveragePooling2D()(backbone.output)
        hidden = Dense(2048, activation='relu',
                       kernel_regularizer=l2(0.1))(pooled)
        # Called with training=True, unlike the second dropout below.
        # NOTE(review): presumably this keeps dropout active at prediction
        # time, which would explain the repeated passes in predict() - confirm.
        hidden = Dropout(0.30)(hidden, training=True)
        hidden = Dense(100, activation='relu',
                       kernel_regularizer=l2(0.1))(hidden)
        hidden = Dropout(0.30)(hidden)
        prediction = Dense(1, activation='sigmoid')(hidden)

        self.model = keras.models.Model(inputs=backbone.input,
                                        outputs=prediction)

    def set_trainable(self, from_idx=0):
        """Mark every layer of ``self.model`` as trainable.

        Args:
            from_idx: Accepted but not used; all layers are affected
                regardless of its value.
        """
        print("\nTraining")
        # Freezing the base_model layers first was tried and is disabled.
        for network_layer in self.model.layers:
            network_layer.trainable = True

    def train(self, lr=1e-4, epochs=10, from_idx=0):
        """Compile the model and fit it on the training generator.

        All layers are made trainable, the model is compiled with Adam and
        binary cross-entropy, and training runs through
        ``custom_fit_generator`` with early stopping on ``val_loss``
        (patience 5). The result is stored on ``self.history``.

        Args:
            lr: Learning rate for Adam.
            epochs: Number of epochs passed to the fit call.
            from_idx: Accepted but not used by this method.
        """
        self.set_trainable()

        adam = Adam(lr=lr,
                    beta_1=0.9,
                    beta_2=0.999,
                    epsilon=1e-08,
                    decay=self.config.lr_decay)
        self.model.compile(optimizer=adam,
                           loss='binary_crossentropy',
                           metrics=['accuracy'])

        n_train_ids = len(self.dataset._partition[0]['train'])
        # True division: this is a float unless batch_size divides evenly.
        steps_per_epoch = n_train_ids / self.config.batch_size

        stop_on_plateau = EarlyStopping(monitor='val_loss',
                                        min_delta=0,
                                        patience=5,
                                        verbose=0,
                                        mode='auto')

        fit_options = dict(
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            verbose=1,
            validation_data=(self.X_val, self.y_val),
            shuffle=True,
            max_queue_size=30,
            workers=30,
            use_multiprocessing=True,
            callbacks=[stop_on_plateau],
        )
        self.history = custom_fit_generator(model=self.model,
                                            generator=self.train_generator,
                                            **fit_options)

    def predict(self):
        """Score every sample ten times and write per-id mean scores to CSV.

        Samples come from ``dataset.get_binarized_data()``. Each of the ten
        passes predicts on ``self.X_feat``, averages the outputs per id and
        writes them to ``pathology_scores_<pass>.csv`` in the working
        directory.
        """
        binarized = self.dataset.get_binarized_data()
        sample_ids = binarized.index
        sample_labels = binarized.values

        print("\nPredicting")
        # Maps the backbone input to the output of the model's last layer.
        scoring_model = keras.models.Model(
            inputs=self.base_model.input, outputs=self.model.layers[-1].output)

        patches_per_id = self.config.sample_size_feat
        self.X_feat, self.y_feat = self.dataset.convert_to_arrays(
            list(sample_ids),
            sample_labels,
            phase='train',
            size=patches_per_id)

        # One id entry per patch, so rows can be grouped back by id.
        patch_ids = np.repeat(np.asarray(sample_ids), patches_per_id)

        for pass_idx in range(10):
            patch_scores = scoring_model.predict(self.X_feat)
            frame = pd.DataFrame(data=patch_scores, index=patch_ids)
            frame["ids"] = patch_ids
            per_id_mean = frame.groupby(["ids"]).mean()
            per_id_mean.to_csv("pathology_scores_%s.csv" % pass_idx)

#       intermediate_output = intermediate_layer_model.predict(self.X_test, batch_size= self.config.batch_size)
#  print(len(intermediate_output))
#  print(len(intermediate_output[1]))
#       print('intermediate_output',intermediate_output.shape)

#  y_scores = self.model.predict(self.X_test, batch_size= self.config.batch_size)

#   y_scores = self.patch_to_image(y_scores, proba=True)
#  print('y_scores', y_scores)
#  y_preds = np.array([(y_score>0.5).astype(int) for y_score in y_scores]).flatten()
#  pd.DataFrame(data = y_preds, index =self.dataset._partition[0]['test'] ).to_csv('Results.csv')
# print(self.dataset._partition[0]['test'], y_preds)
# return y_scores, y_preds

    def train_predict(self):
        """Train with the configured hyper-parameters, then run ``predict``."""
        cfg = self.config
        self.train(cfg.lr, cfg.epochs, cfg.from_idx)
        # Plotting the loss and returning scores/predictions are disabled:
        # self.plot_loss()
        # y_scores, y_preds = self.predict()
        self.predict()

    #  np.save("output/y_scores", y_scores)
    #  np.save("output/y_preds", y_preds)

    # return y_scores, y_preds

    def patch_to_image(self, y_patches, proba=True):
        """Aggregate per-patch values into one value per image.

        Consecutive runs of ``config.sampling_size_test`` patches are treated
        as one image; a trailing incomplete run is ignored.

        Args:
            y_patches: Sequence of per-patch values (scores or labels).
            proba: When ``True``, each image gets the mean of its group.
                Otherwise it gets 1 if that mean exceeds 0.5, else 0.

        Returns:
            Flat ``numpy`` array with one entry per image (integers in the
            non-probability case).
        """
        group_size = self.config.sampling_size_test
        n_images = int(len(y_patches) / group_size)
        group_means = [
            np.mean(y_patches[k * group_size:(k + 1) * group_size])
            for k in range(n_images)
        ]

        if proba == True:
            per_image = np.array(group_means).reshape((-1, 1))
        else:
            majority = np.array([mean > 0.5 for mean in group_means])
            per_image = majority.reshape((-1, 1)).astype(int)
        return np.asarray(per_image.flatten())

    def plot_loss(self):
        """Plot accuracy and loss curves from ``self.history`` and save them.

        Accuracy curves are averaged over all history keys ending in "acc"
        (training keys and "val"-prefixed keys separately). The figure is
        shown, then saved to ``output/learning_curve``.
        """
        history = self.history.history

        def mean_curve(selected_keys):
            # Element-wise mean over the selected per-epoch series.
            return np.mean([history[key] for key in selected_keys], axis=0)

        acc_like = [key for key in history if key.endswith("acc")]
        val_acc = mean_curve(
            [key for key in acc_like if key.startswith("val")])
        acc = mean_curve(
            [key for key in acc_like if not key.startswith("val")])
        loss = history["loss"]
        val_loss = history["val_loss"]

        fig = plt.figure(figsize=(10, 10))

        # Left panel: accuracy.
        accuracy_axis = plt.subplot(121)
        accuracy_axis.tick_params(labelsize=10)
        plt.plot(acc)
        plt.plot(val_acc)
        plt.title('Mean accuracy', size=14)
        plt.ylabel('accuracy', size=12)
        plt.xlabel('epoch', size=12)
        plt.legend(['train', 'test'], loc='upper left', fontsize=12)

        # Right panel: loss.
        loss_axis = plt.subplot(122)
        loss_axis.tick_params(labelsize=10)
        plt.plot(loss)
        plt.plot(val_loss)
        plt.title('Mean loss', size=14)
        plt.ylabel('loss', size=12)
        plt.xlabel('epoch', size=12)
        plt.legend(['train', 'test'], loc='upper left', fontsize=12)

        plt.show()
        fig.savefig("output/learning_curve")
        plt.close()

    def plot(self):
        """Write a diagram of ``self.model`` to ``output/model.png``."""
        print("\nPlotting model")
        diagram_path = 'output/model.png'
        plot_model(self.model, to_file=diagram_path)

    def get_metrics(self, y_scores, y_preds):
        list_of_metrics = [
            "accuracy", "precision", "recall", "f1score", "AUC", "AP"
        ]
        self.metrics = pd.DataFrame(data=np.zeros((1, len(list_of_metrics))),
                                    columns=list_of_metrics)

        #      y_true = self.y_test
        y_pred = y_preds
        y_score = y_scores

        #     print('y_true',y_true)
        print('y_score', y_score)
        print('y_pred', y_pred)