def __init__(self, batch_size, reuse=False, tf_record_dir=None, num_epochs=0, weighting=(1, 1)):
    """Build the input, inference, Dice, loss and metric ops of the graph.

    Inputs come from ``RecordReader.input_pipeline`` when ``tf_record_dir``
    is truthy and from float32 placeholders otherwise.

    NOTE(review): the nesting of statements under the CPU device and the
    individual variable scopes is reconstructed; confirm it matches the
    intended graph layout.

    Args:
        batch_size: Stored on ``self`` and forwarded to the input pipeline.
        reuse: When true, variables of the "inference" scope are reused and
            the pipeline is called with ``(False, batch_size, None, dir)``
            instead of ``(True, batch_size, num_epochs, dir)``.
        tf_record_dir: Location handed to the record reader; when falsy,
            ``self.X`` / ``self.Y`` are placeholders to be fed by the caller.
        num_epochs: Forwarded to the input pipeline when ``reuse`` is false.
        weighting: Not used by this constructor.
    """
    self.batch_size = batch_size
    record_reader = RecordReader(StandardProcessTup(data_loader))
    x_shape = [-1] + list(segment_size_in) + [1]
    y_shape = [-1] + list(segment_size_out) + [1]

    with tf.device('/cpu:0'):
        with tf.variable_scope("input"):
            if tf_record_dir:
                if reuse:
                    X, Y = record_reader.input_pipeline(False, batch_size, None, tf_record_dir)
                else:
                    X, Y = record_reader.input_pipeline(
                        True, batch_size, num_epochs, tf_record_dir)
                self.X = tf.reshape(X, x_shape)
                self.Y = tf.reshape(Y, y_shape)
            else:
                self.X = tf.placeholder(
                    dtype=tf.float32, shape=[None] + segment_size_in.tolist() + [1])
                self.Y = tf.placeholder(
                    dtype=tf.float32, shape=[None] + segment_size_out.tolist() + [1])
            X = self.X
            # Drop the trailing singleton channel and expand the integer
            # labels into a two-channel one-hot tensor.
            Y = tf.cast(
                tf.one_hot(
                    tf.reshape(tf.cast(self.Y, tf.uint8), [-1] + list(segment_size_out)),
                    2),
                tf.float32)

    with tf.variable_scope("inference") as scope:
        if reuse:
            scope.reuse_variables()
        logits = self.build_net(X, reuse=bool(reuse))

    with tf.variable_scope("pred"):
        softmax_logits = tf.nn.softmax(logits)
        self.pred = tf.cast(tf.argmax(softmax_logits, axis=4), tf.float32)

    with tf.variable_scope("dice"):
        # Dice = 2 * overlap / (size_a + size_b); the factor 2 was missing.
        # NOTE(review): the numerator sums over both one-hot channels while
        # the denominator adds predicted class indices to all one-hot
        # entries -- confirm this is the intended Dice variant.
        self.dice_op = tf.divide(
            2 * tf.reduce_sum(tf.multiply(softmax_logits, Y)),
            tf.reduce_sum(self.pred) + tf.reduce_sum(Y),
            name='dice')

    with tf.variable_scope("loss"):
        self.loss_op = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=Y),
            name='cross_entropy')

    # Accuracy tests exact equality and precision/recall cast their inputs to
    # bool, so they need hard class predictions and class-index labels; raw
    # softmax probabilities would almost never equal a one-hot label and are
    # always truthy. MSE keeps comparing probabilities to one-hot labels.
    label_classes = tf.cast(tf.argmax(Y, axis=4), tf.float32)
    names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map({
        'accuracy': tf.contrib.metrics.streaming_accuracy(self.pred, label_classes),
        'precision': tf.contrib.metrics.streaming_precision(self.pred, label_classes),
        'recall': tf.contrib.metrics.streaming_recall(self.pred, label_classes),
        'mse': tf.contrib.metrics.streaming_mean_squared_error(softmax_logits, Y),
    })
    self.mse = names_to_values['mse']

    with tf.variable_scope("metrics"):
        self.metric_update_ops = list(names_to_updates.values())

    if tf_record_dir:
        tf.summary.scalar('dice', self.dice_op)
        tf.summary.scalar('loss', self.loss_op)
        for metric_name, metric_value in names_to_values.items():
            tf.summary.scalar(metric_name, metric_value)
def __init__(self, batch_size, reuse=False, tf_record_dir=None, num_epochs=0, weighting=(1, 1)):
    """Build a two-GPU data-parallel graph: input, towers, loss and metrics.

    The batch is split evenly along axis 0 and each part is run through
    ``self.build_net`` on its own GPU with shared variables. Loss,
    predictions, Dice and streaming metrics are computed over the outputs of
    all towers.

    NOTE(review): the nesting of statements under the devices and variable
    scopes is reconstructed; confirm it matches the intended graph layout.

    Args:
        batch_size: Stored on ``self`` and forwarded to the input pipeline.
        reuse: Selects the input pipeline call: ``(False, batch_size, None,
            dir)`` when true, ``(True, batch_size, num_epochs, dir)``
            otherwise. Tower variables are shared via ``tf.AUTO_REUSE``
            regardless of this flag.
        tf_record_dir: Location handed to the record reader; when falsy,
            ``self.X`` / ``self.Y`` are placeholders to be fed by the caller.
        num_epochs: Forwarded to the input pipeline when ``reuse`` is false.
        weighting: Two per-class factors multiplied onto the logits before
            the cross-entropy is taken.
    """
    num_towers = 2  # one replica per GPU; the batch size must be divisible by it
    self.batch_size = batch_size
    record_reader = RecordReader(StandardProcessTup(data_loader))

    with tf.device('/cpu:0'):
        with tf.variable_scope("input"):
            if tf_record_dir:
                x_shape = [-1] + list(segment_size_in) + [1]
                y_shape = [-1] + list(segment_size_out) + [1]
                if reuse:
                    X, Y = record_reader.input_pipeline(False, batch_size, None, tf_record_dir)
                else:
                    X, Y = record_reader.input_pipeline(
                        True, batch_size, num_epochs, tf_record_dir)
                self.X = tf.reshape(X, x_shape)
                self.Y = tf.reshape(Y, y_shape)
            else:
                self.X = tf.placeholder(
                    dtype=tf.float32, shape=[None] + segment_size_in.tolist() + [1])
                self.Y = tf.placeholder(
                    dtype=tf.float32, shape=[None] + segment_size_out.tolist() + [1])
            X = self.X
            # Drop the trailing singleton channel and expand the integer
            # labels into a two-channel one-hot tensor.
            Y = tf.cast(
                tf.one_hot(
                    tf.reshape(tf.cast(self.Y, tf.uint8), [-1] + list(segment_size_out)),
                    2),
                tf.float32)
            X_A = tf.split(X, num_towers)
            Y_A = tf.split(Y, num_towers)

    with tf.variable_scope("inference"):
        # Loop-invariant: broadcastable per-class factors for the logits.
        class_weight = tf.reshape(tf.constant(weighting, tf.float32), [-1, 1, 1, 1, 2])

        losses = []
        preds = []
        probabilities = []
        for gpu_id in range(num_towers):
            with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)):
                with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
                    # The first tower creates the variables, later ones reuse them.
                    # NOTE(review): self.feats ends up holding only the last
                    # tower's features -- confirm that is what callers expect.
                    logits, self.feats = self.build_net(X_A[gpu_id], gpu_id > 0)

                with tf.variable_scope("pred"):
                    softmax_logits = tf.nn.softmax(logits)
                    probabilities.append(softmax_logits)
                    preds.append(tf.cast(tf.argmax(softmax_logits, axis=4), tf.float32))

                with tf.variable_scope("loss"):
                    weighted_logits = tf.multiply(logits, class_weight)
                    losses.append(
                        tf.nn.softmax_cross_entropy_with_logits_v2(
                            logits=weighted_logits, labels=Y_A[gpu_id]))

        self.loss_op = tf.reduce_mean(tf.concat(losses, axis=0))
        self.pred = tf.concat(preds, axis=0)
        softmax_all = tf.concat(probabilities, axis=0)

        with tf.variable_scope("dice"):
            # Computed once over the whole batch instead of being overwritten
            # by every tower, which left only the last tower's share.
            # NOTE(review): the numerator sums over both one-hot channels
            # while the denominator adds predicted class indices to all
            # one-hot entries -- confirm this is the intended Dice variant.
            self.dice_op = tf.divide(
                2 * tf.reduce_sum(tf.multiply(softmax_all, Y)),
                tf.reduce_sum(self.pred) + tf.reduce_sum(Y),
                name='dice')

    # One set of streaming metrics over all towers. Accuracy tests exact
    # equality and precision/recall cast their inputs to bool, so they get
    # hard class predictions and class-index labels; MSE keeps comparing
    # probabilities to one-hot labels.
    label_classes = tf.cast(tf.argmax(Y, axis=4), tf.float32)
    names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map({
        'accuracy': tf.contrib.metrics.streaming_accuracy(self.pred, label_classes),
        'precision': tf.contrib.metrics.streaming_precision(self.pred, label_classes),
        'recall': tf.contrib.metrics.streaming_recall(self.pred, label_classes),
        'mse': tf.contrib.metrics.streaming_mean_squared_error(softmax_all, Y),
    })
    self.mse = names_to_values['mse']

    with tf.variable_scope("metrics"):
        self.metric_update_ops = list(names_to_updates.values())

    if tf_record_dir:
        tf.summary.scalar('dice', self.dice_op)
        tf.summary.scalar('loss', self.loss_op)
        for metric_name, metric_value in names_to_values.items():
            tf.summary.scalar(metric_name, metric_value)