Example #1
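This first variant builds a single-tower graph: the input pipeline (a TFRecord reader, or feed placeholders when no record directory is given) is pinned to the CPU, one build_net pass produces the logits, and Dice, softmax cross-entropy, and streaming tf.contrib.metrics ops are stacked on top.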
 def __init__(self, batch_size, reuse=False, tf_record_dir=None, num_epochs=0, weighting=[1]*2):
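     # Note: the weighting argument is unused in this variant; Example #2 applies it.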
     self.batch_size = batch_size
     record_reader = RecordReader(StandardProcessTup(data_loader))
     x_shape = [-1] + list(segment_size_in) + [1]
     y_shape = [-1] + list(segment_size_out) + [1]
     with tf.device('/cpu:0'):
         with tf.variable_scope("input"):
             if tf_record_dir:
                 if reuse:
                     X, Y = record_reader.input_pipeline(False, batch_size, None, tf_record_dir)
                 else:
                     X, Y = record_reader.input_pipeline(True, batch_size, num_epochs, tf_record_dir)
                 self.X = tf.reshape(X, x_shape)
                 self.Y = tf.reshape(Y, y_shape)
             else:
                 self.X = tf.placeholder(
                     dtype=tf.float32,
                     shape=[None] + segment_size_in.tolist() + [1])
                 self.Y = tf.placeholder(
                     dtype=tf.float32,
                     shape=[None] + segment_size_out.tolist() + [1])
             X = self.X
             Y = tf.cast(tf.one_hot(tf.reshape(tf.cast(self.Y, tf.uint8), [-1] + list(segment_size_out)), 2),
                         tf.float32)
     with tf.variable_scope("inference") as scope:
         if reuse:
             scope.reuse_variables()
             logits = self.build_net(X, reuse=True)
         else:
             logits = self.build_net(X, reuse=False)
         with tf.variable_scope("pred"):
             softmax_logits = tf.nn.softmax(logits)
             self.pred = tf.cast(tf.argmax(softmax_logits, axis=4), tf.float32)
         with tf.variable_scope("dice"):
             # Dice coefficient: 2*|A∩B| / (|A| + |B|)
             self.dice_op = tf.divide(2 * tf.reduce_sum(tf.multiply(softmax_logits, Y)),
                                      tf.reduce_sum(self.pred) + tf.reduce_sum(Y), name='dice')
         with tf.variable_scope("loss") as scope:
             self.loss_op = tf.reduce_mean(
                 tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=Y),
                 name='cross_entropy')
             #reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
             #reg_constant = 0.0005  # Choose an appropriate one.
             #self.loss_op += reg_constant * sum(reg_losses)
         # Choose the metrics to compute:
         names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map({
             'accuracy': tf.contrib.metrics.streaming_accuracy(softmax_logits, Y),
             'precision': tf.contrib.metrics.streaming_precision(softmax_logits, Y),
             'recall': tf.contrib.metrics.streaming_recall(softmax_logits, Y),
             'mse': tf.contrib.metrics.streaming_mean_squared_error(softmax_logits, Y),
         })
         self.mse = names_to_values['mse']
         with tf.variable_scope("metrics"):
             self.metric_update_ops = list(names_to_updates.values())
         if tf_record_dir:
             tf.summary.scalar('dice', self.dice_op)
             #tf.summary.scalar('precision', self.precision_op)
             #tf.summary.scalar('recall', self.recall_op)
             #tf.summary.scalar('mse', self.mse_op)
             tf.summary.scalar('loss', self.loss_op)
         for metric_name, metric_value in names_to_values.items():
             tf.summary.scalar(metric_name, metric_value)
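
A minimal usage sketch for this constructor, assuming TensorFlow 1.x and that the method belongs to a model class (called SegmentationModel here purely for illustration; the listing does not show the class name, and the record directory and hyperparameters are placeholders). The tf.contrib.metrics streaming ops keep their state in local variables, so both initializers must run, and the TFRecord pipeline needs queue runners:

import tensorflow as tf

model = SegmentationModel(batch_size=4, tf_record_dir='/path/to/records',
                          num_epochs=10)
train_op = tf.train.AdamOptimizer(1e-4).minimize(model.loss_op)

with tf.Session() as sess:
    # Streaming metrics live in local variables; initialize those too.
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            # One step: optimize, refresh the streaming metrics, read scalars.
            _, _, dice, loss = sess.run([train_op, model.metric_update_ops,
                                         model.dice_op, model.loss_op])
    except tf.errors.OutOfRangeError:
        pass  # raised by the pipeline once num_epochs is exhausted
    finally:
        coord.request_stop()
        coord.join(threads)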
Example #2
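This second variant keeps the same CPU-side input pipeline but splits each batch in two and builds one inference tower per GPU, sharing weights through tf.variable_scope("model", reuse=tf.AUTO_REUSE). The per-tower losses and predictions are concatenated back together, and the loss is class-weighted via the weighting argument.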
    def __init__(self, batch_size, reuse=False, tf_record_dir=None, num_epochs=0, weighting=[1]*2):
        self.batch_size = batch_size
        record_reader = RecordReader(StandardProcessTup(data_loader))
        with tf.device('/cpu:0'):
            with tf.variable_scope("input"):
                if tf_record_dir:
                    x_shape = [-1] + list(segment_size_in) + [1]
                    y_shape = [-1] + list(segment_size_out) + [1]
                    if reuse:
                        X, Y = record_reader.input_pipeline(False, batch_size, None, tf_record_dir)
                    else:
                        X, Y = record_reader.input_pipeline(True, batch_size, num_epochs, tf_record_dir)
                    self.X = tf.reshape(X, x_shape)
                    self.Y = tf.reshape(Y, y_shape)
                else:
                    self.X = tf.placeholder(
                        dtype=tf.float32,
                        shape=[None] + segment_size_in.tolist() + [1])
                    self.Y = tf.placeholder(
                        dtype=tf.float32,
                        shape=[None] + segment_size_out.tolist() + [1])
                X = self.X
                Y = tf.cast(tf.one_hot(tf.reshape(tf.cast(self.Y, tf.uint8), [-1] + list(segment_size_out)), 2),
                            tf.float32)
                X_A = tf.split(X, 2)
                Y_A = tf.split(Y, 2)

        with tf.variable_scope("inference") as scope:
            # if tf_record_dir:
            losses = []
            preds = []
            for gpu_id in range(2):
                with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)):
                    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
                        if gpu_id == 0:
                            logits, self.feats = self.build_net(X_A[gpu_id], False)
                        else:
                            logits, self.feats = self.build_net(X_A[gpu_id], True)
                        with tf.variable_scope("pred"):
                            softmax_logits = tf.nn.softmax(logits)
                            pred = tf.cast(tf.argmax(softmax_logits, axis=4), tf.float32)
                            preds.append(pred)
                        with tf.variable_scope("dice"):
                            self.dice_op = tf.divide(2 * tf.reduce_sum(tf.multiply(softmax_logits, Y_A[gpu_id])),
                                                     tf.reduce_sum(pred) + tf.reduce_sum(Y_A[gpu_id]), name='dice')
                        with tf.variable_scope("loss"):
                            class_weight = tf.constant(weighting, tf.float32)
                            weighted_logits = tf.multiply(logits, tf.reshape(class_weight, [-1, 1, 1, 1, 2]))
                            loss_op = tf.nn.softmax_cross_entropy_with_logits(logits=weighted_logits,
                                                                              labels=Y_A[gpu_id])
                            losses.append(loss_op)
                        # Choose the metrics to compute:
                        names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map({
                            'accuracy': tf.contrib.metrics.streaming_accuracy(softmax_logits, Y_A[gpu_id]),
                            'precision': tf.contrib.metrics.streaming_precision(softmax_logits, Y_A[gpu_id]),
                            'recall': tf.contrib.metrics.streaming_recall(softmax_logits, Y_A[gpu_id]),
                            'mse': tf.contrib.metrics.streaming_mean_squared_error(softmax_logits, Y_A[gpu_id]),
                        })
            self.loss_op = tf.reduce_mean(tf.concat(losses, axis=0))
            self.pred = tf.cast(tf.concat(preds, axis=0), tf.float32)
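            # names_to_values / names_to_updates are rebuilt on every loop
            # iteration, so the dicts used below hold the last GPU tower's metrics.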
            self.mse = names_to_values['mse']
            with tf.variable_scope("metrics"):
                self.metric_update_ops = list(names_to_updates.values())
            if tf_record_dir:
                tf.summary.scalar('dice', self.dice_op)
                tf.summary.scalar('loss', self.loss_op)
            for metric_name, metric_value in names_to_values.items():
                tf.summary.scalar(metric_name, metric_value)
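
The distinctive part of this second variant is the two-tower data parallelism. Below is a standalone sketch of the sharing mechanism with illustrative layer and tensor names (only the scope/device structure is taken from the example above): with reuse=tf.AUTO_REUSE, tower 0 creates the variables and tower 1 reuses them, so both GPUs apply the same weights to their half of the batch.

import tensorflow as tf

def tower(x):
    # Both towers resolve to the same "model/dense" variables.
    return tf.layers.dense(x, units=2, name='dense')

x = tf.placeholder(tf.float32, [None, 8])
x_halves = tf.split(x, 2)  # half of the batch per GPU, as in the example
logits = []
for gpu_id in range(2):
    with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)):
        # Tower 0 creates the "model/..." variables; tower 1 reuses them.
        with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
            logits.append(tower(x_halves[gpu_id]))

# Only one copy of the weights exists, e.g.
# ['model/dense/kernel:0', 'model/dense/bias:0']:
print([v.name for v in tf.trainable_variables()])

Concatenating the per-tower losses before tf.reduce_mean, as the example does, matches the mean over the full batch exactly because tf.split(X, 2) produces equal halves.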