def call(self, inputs, training=None):
    if training is None:
        training = K.learning_phase()
    output = super(BatchNormalization, self).call(inputs, training=training)
    if training is K.learning_phase():
        output._uses_learning_phase = True  # pylint: disable=protected-access
    return output
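# A minimal sketch (not from the original source) of what the override above
# buys you: because `training` defaults to K.learning_phase(), the same graph
# can be run with batch statistics or moving averages just by feeding the
# learning-phase tensor. Names below are illustrative.
import numpy as np
from keras import backend as K
from keras.layers import Input, BatchNormalization

inp = Input(shape=(4,))
out = BatchNormalization()(inp)
f = K.function([inp, K.learning_phase()], [out])

x = np.random.random((8, 4)).astype('float32')
train_out = f([x, 1])[0]  # normalizes with the batch statistics
test_out = f([x, 0])[0]   # normalizes with the moving averages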
def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}

    if self.validation_data and self.histogram_freq:
        if epoch % self.histogram_freq == 0:
            # TODO(fchollet): implement batched calls to sess.run
            # (current call will likely go OOM on GPU)
            if self.model.uses_learning_phase:
                cut_v_data = len(self.model.inputs)
                val_data = self.validation_data[:cut_v_data] + [0]
                tensors = self.model.inputs + [K.learning_phase()]
            else:
                val_data = self.validation_data
                tensors = self.model.inputs
            feed_dict = dict(zip(tensors, val_data))
            result = self.sess.run([self.merged], feed_dict=feed_dict)
            summary_str = result[0]
            self.writer.add_summary(summary_str, epoch)

    for name, value in logs.items():
        if name in ['batch', 'size']:
            continue
        summary = tf_summary.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value.item()
        summary_value.tag = name
        self.writer.add_summary(summary, epoch)
    self.writer.flush()
def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}

    if self.validation_data and self.histogram_freq:
        if epoch % self.histogram_freq == 0:
            # TODO(fchollet): implement batched calls to sess.run
            # (current call will likely go OOM on GPU)
            if self.model.uses_learning_phase:
                cut_v_data = len(self.model.inputs)
                val_data = self.validation_data[:cut_v_data] + [0]
                tensors = self.model.inputs + [K.learning_phase()]
            else:
                val_data = self.validation_data
                tensors = self.model.inputs
            feed_dict = dict(zip(tensors, val_data))
            result = self.sess.run([self.merged], feed_dict=feed_dict)
            summary_str = result[0]
            self.writer.add_summary(summary_str, epoch)

    if self.embeddings_freq and self.embeddings_logs:
        if epoch % self.embeddings_freq == 0:
            for log in self.embeddings_logs:
                self.saver.save(self.sess, log, epoch)

    for name, value in logs.items():
        if name in ['batch', 'size']:
            continue
        summary = tf_summary.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value.item()
        summary_value.tag = name
        self.writer.add_summary(summary, epoch)
    self.writer.flush()
def predict_stochastic(self, X, batch_size=128, verbose=0):
    '''Generate output predictions for the input samples batch by batch,
    using stochastic forward passes. If dropout is used at training time,
    network units will also be dropped at random during prediction.
    This procedure can be used for MC dropout (see
    [ModelTest callbacks](callbacks.md)).

    # Arguments
        X: the input data, as a numpy array.
        batch_size: integer.
        verbose: verbosity mode, 0 or 1.

    # Returns
        A numpy array of predictions.

    # References
        - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](http://jmlr.org/papers/v15/srivastava14a.html)
        - [Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning](http://arxiv.org/abs/1506.02142)
    '''
    # https://stackoverflow.com/questions/44351054/keras-forward-pass-with-dropout
    X = _standardize_input_data(
        X,
        self.model.model._feed_input_names,
        self.model.model._feed_input_shapes,
        check_batch_axis=False)
    if self._predict_stochastic is None:  # we only get self.model after init
        self._predict_stochastic = K.function(
            [self.model.layers[0].input, K.learning_phase()],
            [self.model.layers[-1].output])
    return self.model._predict_loop(
        self._predict_stochastic, X + [1.], batch_size, verbose)[:, 0]
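# Hedged usage sketch (not from the original source): because the learning
# phase is fed as 1., dropout stays active at prediction time, so repeated
# calls return different samples; averaging them gives the MC-dropout
# estimate. `model_test`, `X_val`, and T are illustrative names.
T = 50
samples = np.stack([model_test.predict_stochastic(X_val, batch_size=128)
                    for _ in range(T)])
mc_mean = samples.mean(axis=0)  # MC-dropout predictive mean
mc_std = samples.std(axis=0)    # spread serves as an uncertainty proxy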
def get_inputs(batch_size=1, target_size=(900, 1200), fmap=(112, 150), phoc_dim=0):
    """Get model inputs.

    Generate graph placeholders and return them as a namedtuple.
    """
    image = tf.placeholder(tf.float32,
                           [batch_size, target_size[1], target_size[0], 3],
                           name='image')
    box_viz = tf.placeholder(tf.float32,
                             [batch_size, target_size[1], target_size[0], 3],
                             name='box_viz_image')
    heatmap = tf.placeholder(tf.float32,
                             [batch_size, target_size[1], target_size[0], 1],
                             name='heatmap')
    # First coord is the batch index
    tf_gt_boxes = tf.placeholder(tf.float32, [None, 5], name='gt_boxes')
    gt_phoc_tensor = tf.placeholder(tf.float32, [None, phoc_dim + 1],
                                    name='gt_phocs')
    relative_points = tf.placeholder(tf.float32, [1, fmap[0] * fmap[1], 2],
                                     name='relative_points')
    cntr_box_targets = tf.placeholder(tf.float32, [None, fmap[0] * fmap[1], 4],
                                      name='cntr_box_target')
    cntr_box_labels = tf.placeholder(tf.float32, [None, fmap[0] * fmap[1]],
                                     name='cntr_box_labels')
    # Use Keras' train-mode placeholder
    is_training = K.learning_phase()
    return placeholders(image, box_viz, heatmap, tf_gt_boxes, gt_phoc_tensor,
                        relative_points, cntr_box_targets, cntr_box_labels,
                        is_training)
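# The `placeholders` container returned above is not defined in the snippet;
# a plausible definition (field names inferred from the call site, not taken
# from the original source):
from collections import namedtuple

placeholders = namedtuple('placeholders', [
    'image', 'box_viz', 'heatmap', 'gt_boxes', 'gt_phocs',
    'relative_points', 'cntr_box_targets', 'cntr_box_labels', 'is_training'])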
def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}

    if not self.validation_data and self.histogram_freq:
        raise ValueError('If printing histograms, validation_data must be '
                         'provided, and cannot be a generator.')
    if self.validation_data and self.histogram_freq:
        if epoch % self.histogram_freq == 0:
            val_data = self.validation_data
            tensors = (self.model.inputs +
                       self.model.targets +
                       self.model.sample_weights)

            if self.model.uses_learning_phase:
                tensors += [K.learning_phase()]

            assert len(val_data) == len(tensors)
            val_size = val_data[0].shape[0]
            i = 0
            while i < val_size:
                step = min(self.batch_size, val_size - i)
                if self.model.uses_learning_phase:
                    # do not slice the learning phase
                    batch_val = [x[i:i + step] for x in val_data[:-1]]
                    batch_val.append(val_data[-1])
                else:
                    batch_val = [x[i:i + step] for x in val_data]
                feed_dict = dict(zip(tensors, batch_val))
                result = self.sess.run([self.merged], feed_dict=feed_dict)
                summary_str = result[0]
                self.writer.add_summary(summary_str, epoch)
                i += self.batch_size

    if self.embeddings_freq and self.embeddings_ckpt_path:
        if epoch % self.embeddings_freq == 0:
            self.saver.save(self.sess, self.embeddings_ckpt_path, epoch)

    for name, value in logs.items():
        if name in ['batch', 'size']:
            continue
        summary = tf_summary.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value.item()
        summary_value.tag = name
        self.writer.add_summary(summary, epoch)
    self.writer.flush()
def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}

    if self.validation_data and self.histogram_freq:
        if epoch % self.histogram_freq == 0:
            val_data = self.validation_data
            tensors = (self.model.inputs +
                       self.model.targets +
                       self.model.sample_weights)

            if self.model.uses_learning_phase:
                tensors += [K.learning_phase()]

            assert len(val_data) == len(tensors)
            val_size = val_data[0].shape[0]
            i = 0
            while i < val_size:
                step = min(self.batch_size, val_size - i)
                batch_val = []
                batch_val.append(val_data[0][i:i + step])
                batch_val.append(val_data[1][i:i + step])
                batch_val.append(val_data[2][i:i + step])
                if self.model.uses_learning_phase:
                    batch_val.append(val_data[3])
                feed_dict = dict(zip(tensors, batch_val))
                result = self.sess.run([self.merged], feed_dict=feed_dict)
                summary_str = result[0]
                self.writer.add_summary(summary_str, epoch)
                i += self.batch_size

    if self.embeddings_freq and self.embeddings_ckpt_path:
        if epoch % self.embeddings_freq == 0:
            self.saver.save(self.sess, self.embeddings_ckpt_path, epoch)

    for name, value in logs.items():
        if name in ['batch', 'size']:
            continue
        summary = tf_summary.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value.item()
        summary_value.tag = name
        self.writer.add_summary(summary, epoch)
    self.writer.flush()
y_ = tf.placeholder(tf.float32, [None, 10])

# define TF graph
y_pred = mlp_model(x)
loss = tf.losses.softmax_cross_entropy(y_, y_pred)
train_step = tf.train.AdagradOptimizer(0.05).minimize(loss)
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print('Training...')
    for i in range(10001):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        train_fd = {x: batch_xs, y_: batch_ys, K.learning_phase(): 1}
        train_step.run(feed_dict=train_fd)
        if i % 1000 == 0:
            batch_xv, batch_yv = mnist.test.next_batch(200)
            val_accuracy = accuracy.eval({
                x: batch_xv, y_: batch_yv, K.learning_phase(): 0})
            print('  step, accuracy = %6d: %6.3f' % (i, val_accuracy))

    test_fd = {x: mnist.test.images, y_: mnist.test.labels,
               K.learning_phase(): 0}
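    # The snippet ends after building test_fd; a plausible final evaluation
    # (not in the original source), run inside the same session with the
    # learning phase pinned to test mode:
    print('test accuracy = %6.3f' % accuracy.eval(feed_dict=test_fd))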
def main():
    training_images, training_labels, test_images, test_labels = load_dataset()
    # plt.imshow(training_images[:, :, 0], cmap='gray')
    # plt.show()
    N = training_labels.size
    Nt = test_labels.size
    perm_train = np.random.permutation(N)
    training_labels = training_labels[perm_train]
    training_images = training_images[perm_train, :, :] / 255.0
    training_images = np.expand_dims(training_images, -1)
    print(training_images.shape)

    test_images = test_images / 255.0
    test_images = np.expand_dims(test_images, -1)

    # pdb.set_trace()
    training_labels = to_categorical(training_labels, NUM_CLASSES)
    test_labels = to_categorical(test_labels, NUM_CLASSES)

    BATCH_SIZE = 32 * 8
    WIDTH, HEIGHT = 28, 28
    epochs = 5

    # Defining the placeholders
    input_data = tf.placeholder(dtype=tf.float32,
                                shape=[None, HEIGHT, WIDTH, 1], name='data')
    input_labels = tf.placeholder(dtype=tf.float32,
                                  shape=[None, NUM_CLASSES], name='labels')
    do_rate = tf.placeholder(dtype=tf.float32, name='dropout_rate')
    # pdb.set_trace()
    '''
    with tf.name_scope('conv1'):
        with tf.variable_scope('conv1'):
            W_conv1 = tf.get_variable('w', [3, 3, 1, 32])
            b_conv1 = tf.get_variable('b', [32])
        conv1 = tf.nn.conv2d(input=input_data, filter=W_conv1,
                             strides=[1, 1, 1, 1], padding='SAME')
        relu1 = tf.nn.relu(conv1 + b_conv1)
    with tf.name_scope('pool1'):
        pool1 = tf.nn.max_pool(value=relu1, ksize=[1, 2, 2, 1],
                               strides=[1, 2, 2, 1], padding='SAME')
    with tf.name_scope('conv2'):
        with tf.variable_scope('conv2'):
            W_conv2 = tf.get_variable('w', [3, 3, 32, 32])
            b_conv2 = tf.get_variable('b', [32])
        conv2 = tf.nn.conv2d(input=pool1, filter=W_conv2,
                             strides=[1, 1, 1, 1], padding='VALID')
        relu2 = tf.nn.relu(conv2 + b_conv2)
    with tf.name_scope('pool2'):
        pool2 = tf.nn.max_pool(value=relu2, ksize=[1, 2, 2, 1],
                               strides=[1, 2, 2, 1], padding='SAME')
    with tf.name_scope('dense1'):
        with tf.variable_scope('dense1'):
            W_dense1 = tf.get_variable('w', [6 * 6 * 32, 128])
            b_dense1 = tf.get_variable('b', 128)
        flat = tf.reshape(pool2, [-1, 6 * 6 * 32], name='reshape')
        dense1 = tf.matmul(flat, W_dense1)
        relu3 = tf.nn.relu(dense1 + b_dense1)
    with tf.name_scope('dropout'):
        dropout = tf.nn.dropout(relu3, do_rate)
    with tf.name_scope('output'):
        with tf.variable_scope('output'):
            W_out = tf.get_variable('w', [128, NUM_CLASSES])
            b_out = tf.get_variable('b', [NUM_CLASSES])
        output = tf.matmul(dropout, W_out) + b_out
    '''
    print('-------------------------------------------------------')
    """ Using Keras layers instead """
    # input_layer = Input(shape=(HEIGHT, WIDTH, 1), name='input_layer')
    Kcnn1 = Conv2D(filters=32, kernel_size=3, strides=(1, 1),
                   padding='same', activation='relu')(input_data)
    Kmaxpool = MaxPooling2D(pool_size=2)(Kcnn1)
    """
    with tf.name_scope('conv2'):
        with tf.variable_scope('conv2'):
            W_conv2 = tf.get_variable('w', [3, 3, 32, 32])
            b_conv2 = tf.get_variable('b', [32])
        conv2 = tf.nn.conv2d(input=Kmaxpool, filter=W_conv2,
                             strides=[1, 1, 1, 1], padding='VALID')
        Kcnn2 = tf.nn.relu(conv2 + b_conv2)
    """
    Kcnn2 = Conv2D(filters=32, kernel_size=3, strides=(1, 1),
                   padding='valid', activation='relu')(Kmaxpool)
    Kmaxpool = MaxPooling2D(pool_size=2)(Kcnn2)
    Kflat = Flatten()(Kmaxpool)
    Kdense1 = Dense(units=128, activation='relu')(Kflat)
    Kdropout = Dropout(.5)(Kdense1)
    output = Dense(units=NUM_CLASSES, activation='linear')(Kdropout)
    """
    The rest of the code is almost the same as in pure_tf_mnist.py, except
    for the feed_dict: instead of TensorFlow's do_rate placeholder, we need
    to feed the Keras-specific 'learning_phase' tensor from the Keras backend.
    """
    print('-------------------------------------------------------')
    print('\n\n')
    print('-------------------------------------------------------')
    print('--------------- Trainable parameters ------------------')
    print('-------------------------------------------------------')
    total_parameters = 0
    for v in tf.trainable_variables():
        shape = v.get_shape()
        print(shape)
        # pdb.set_trace()
        params = 1
        for dim in shape:
            params *= dim.value
        total_parameters += params
    print('total_parameters = {}'.format(total_parameters))
    print('-------------------------------------------------------\n\n')

    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        labels=input_labels, logits=output, name='loss'))
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
    accuracy = tf.cast(tf.equal(tf.argmax(input_labels, 1),
                                tf.argmax(output, 1)), tf.float32)

    # Training:
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter('graph', sess.graph)

    for i in range(epochs):
        steps = int(np.ceil(float(N) / float(BATCH_SIZE)))
        total_l = 0
        total_acc = 0
        for step in range(steps):
            x_in, y_in = get_batch(step, BATCH_SIZE,
                                   training_images, training_labels)
            l, acc, _ = sess.run([loss, accuracy, train_op],
                                 {input_data: x_in, input_labels: y_in,
                                  learning_phase(): 1})  # do_rate: 0.5
            total_l += l
            total_acc += np.sum(acc)
        # pdb.set_trace()
        total_acc /= np.float32(N)
        print("Epoch {}: Training loss = {}, Training accuracy = {}".format(
            i, total_l, total_acc))

    # Test:
    total_acc = 0
    steps = int(np.ceil(float(Nt) / float(BATCH_SIZE)))
    for step in range(steps):
        x_in, y_in = get_batch(step, BATCH_SIZE, test_images, test_labels)
        acc = sess.run([accuracy],
                       {input_data: x_in, input_labels: y_in,
                        learning_phase(): 0})  # do_rate: 1
        total_acc += np.sum(acc)
    total_acc /= np.float32(Nt)
    print('\n--------------------------\n')
    print("Test accuracy = {}".format(total_acc))
    sess.close()
    writer.close()
def main():
    test_files = []
    with open('test.txt') as f:
        for line in f:
            test_files.append(line[:-1])

    num_labels = 5
    OFFSET_H = options['offset_h']
    OFFSET_W = options['offset_w']
    OFFSET_C = options['offset_c']
    HSIZE = options['hsize']
    WSIZE = options['wsize']
    CSIZE = options['csize']
    PSIZE = options['psize']
    SAVE_PATH = options['model_path']
    model_name = options['model_name']
    OFFSET_PH = (HSIZE - PSIZE) / 2
    OFFSET_PW = (WSIZE - PSIZE) / 2
    OFFSET_PC = (CSIZE - PSIZE) / 2

    batches_w = int(np.ceil((240 - WSIZE) / float(OFFSET_W))) + 1
    batches_h = int(np.ceil((240 - HSIZE) / float(OFFSET_H))) + 1
    batches_c = int(np.ceil((155 - CSIZE) / float(OFFSET_C))) + 1

    flair_t2_node = tf.placeholder(dtype=tf.float32,
                                   shape=(None, HSIZE, WSIZE, CSIZE, 2))
    t1_t1ce_node = tf.placeholder(dtype=tf.float32,
                                  shape=(None, HSIZE, WSIZE, CSIZE, 2))

    if model_name == 'dense48':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat_large(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat_large(
            input=t1_t1ce_node, name='t1')
    elif model_name == 'no_dense':
        flair_t2_15, flair_t2_27 = tf_models.PlainCounterpart(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.PlainCounterpart(
            input=t1_t1ce_node, name='t1')
    elif model_name == 'dense24':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=t1_t1ce_node, name='t1')
    elif model_name == 'dense24_nocorrection':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=t1_t1ce_node, name='t1')
    else:
        print 'No such model name'

    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])
    t1_t1ce_15 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same',
                        name='t1_t1ce_15_cls')(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same',
                        name='t1_t1ce_27_cls')(t1_t1ce_27)

    t1_t1ce_score = t1_t1ce_15[:, 13:25, 13:25, 13:25, :] + \
                    t1_t1ce_27[:, 13:25, 13:25, 13:25, :]

    saver = tf.train.Saver()
    data_gen_test = vox_generator_test(test_files)
    dice_whole, dice_core, dice_et = [], [], []
    with tf.Session() as sess:
        saver.restore(sess, SAVE_PATH)
        for i in range(len(test_files) - 1, 0, -1):
            print i
            print 'predicting %s' % test_files[i]
            x, x_n, y = gen_test_data(test_files[i])
            pred = np.zeros([240, 240, 155, 5])
            for hi in range(batches_h):
                offset_h = min(OFFSET_H * hi, 240 - HSIZE)
                offset_ph = offset_h + OFFSET_PH
                for wi in range(batches_w):
                    offset_w = min(OFFSET_W * wi, 240 - WSIZE)
                    offset_pw = offset_w + OFFSET_PW
                    for ci in range(batches_c):
                        offset_c = min(OFFSET_C * ci, 155 - CSIZE)
                        offset_pc = offset_c + OFFSET_PC
                        data = x[offset_h:offset_h + HSIZE,
                                 offset_w:offset_w + WSIZE,
                                 offset_c:offset_c + CSIZE, :]
                        data_norm = x_n[offset_h:offset_h + HSIZE,
                                        offset_w:offset_w + WSIZE,
                                        offset_c:offset_c + CSIZE, :]
                        data_norm = np.expand_dims(data_norm, 0)
                        if not np.max(data) == 0 and np.min(data) == 0:
                            score = sess.run(
                                fetches=t1_t1ce_score,
                                feed_dict={
                                    flair_t2_node: data_norm[:, :, :, :, :2],
                                    t1_t1ce_node: data_norm[:, :, :, :, 2:],
                                    learning_phase(): 0
                                })
                            pred[offset_ph:offset_ph + PSIZE,
                                 offset_pw:offset_pw + PSIZE,
                                 offset_pc:offset_pc + PSIZE, :] += np.squeeze(score)
            pred = np.argmax(pred, axis=-1)
            pred = pred.astype(int)
            print 'calculating dice...'
            print options['save_path'] + test_files[i] + '_prediction'
            np.save(options['save_path'] + test_files[i] + '_prediction', pred)

            whole_pred = (pred > 0).astype(int)
            whole_gt = (y > 0).astype(int)
            core_pred = (pred == 1).astype(int) + (pred == 4).astype(int)
            core_gt = (y == 1).astype(int) + (y == 4).astype(int)
            et_pred = (pred == 4).astype(int)
            et_gt = (y == 4).astype(int)

            dice_whole_batch = dice_coef_np(whole_gt, whole_pred, 2)
            dice_core_batch = dice_coef_np(core_gt, core_pred, 2)
            dice_et_batch = dice_coef_np(et_gt, et_pred, 2)
            dice_whole.append(dice_whole_batch)
            dice_core.append(dice_core_batch)
            dice_et.append(dice_et_batch)
            print dice_whole_batch
            print dice_core_batch
            print dice_et_batch

        dice_whole = np.array(dice_whole)
        dice_core = np.array(dice_core)
        dice_et = np.array(dice_et)
        print 'mean dice whole:'
        print np.mean(dice_whole, axis=0)
        print 'mean dice core:'
        print np.mean(dice_core, axis=0)
        print 'mean dice enhance:'
        print np.mean(dice_et, axis=0)
        np.save(model_name + '_dice_whole', dice_whole)
        np.save(model_name + '_dice_core', dice_core)
        np.save(model_name + '_dice_enhance', dice_et)
        print 'pred saved'
input_array_small = np.random.random((500, 10)) * 2
target_small = np.random.random((500, 1))

input_tensor = Input(shape=(10,))
bn_tensor = BatchNormalization()(input_tensor)
dp_tensor = Dropout(0.7)(input_tensor)

#### Access BatchNormalization layer's output as arrays in both test and train mode

# test mode from Model method
model_bn = Model(input_tensor, bn_tensor)
bn_array = model_bn.predict(input_array_small)

# test and train mode from K.function method
k_bn = K.function([input_tensor, K.learning_phase()], [bn_tensor])
bn_array_test = k_bn([input_array_small, 0])[0]
bn_array_train = k_bn([input_array_small, 1])[0]

# Is the test-mode output the same, and does the train-mode output differ?
(bn_array == bn_array_test).sum()   # compare for equality
bn_array.shape
(bn_array == bn_array_train).sum()  # totally different

#### Access Dropout layer's output as array in both test and train mode

# test mode from Model method
model_dp = Model(input_tensor, dp_tensor)
dp_array = model_dp.predict(input_array_small)

# test and train mode from K.function method
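# The snippet breaks off here; a plausible continuation (not in the original
# source), mirroring the k_bn example above for the Dropout tensor:
k_dp = K.function([input_tensor, K.learning_phase()], [dp_tensor])
dp_array_test = k_dp([input_array_small, 0])[0]   # dropout inactive
dp_array_train = k_dp([input_array_small, 1])[0]  # ~70% of units zeroed

(dp_array == dp_array_test).sum()   # identical in test mode
(dp_array == dp_array_train).sum()  # differs wherever units were dropped or rescaled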
Note that if your model behaves differently in the training and testing
phases (e.g. if it uses `Dropout`, `BatchNormalization`, etc.), you will
need to pass the learning phase flag to your function:
"""

from tensorflow.contrib.keras.python.keras.layers import Dropout, BatchNormalization

input_tensor = Input(shape=(100,), name="input_tensor")
inter_tensor = Dense(300, name="my_layer")(input_tensor)
bn_tensor = BatchNormalization()(inter_tensor)
drop_tensor = Dropout(0.7)(bn_tensor)
final_tensor = Dense(30, name="final_layer")(drop_tensor)
model_dp_bn = Model(input_tensor, final_tensor)  # create the original model

# K.function lets us choose between test mode and training mode: in training
# mode Dropout is active (and BatchNormalization uses batch statistics),
# whereas in test mode Dropout is not applied
get_3rd_layer_output = K.function(
    [model_dp_bn.layers[0].input, K.learning_phase()],
    [model_dp_bn.layers[3].output])

# output in test mode = 0: no zeros in the output
layer_output_0 = get_3rd_layer_output([input_array1, 0])[0]
layer_output_0.max()
layer_output_0.min()

# output in train mode = 1: lots of zeros in the output
layer_output_1 = get_3rd_layer_output([input_array1, 1])[0]
layer_output_1.max()
layer_output_1.min()

# with a plain Model we can only get test-mode output, not training mode
model2 = Model(input_tensor, bn_tensor)
model2.predict(input_array1).max()
print('\n VAE fitting...')
vae1.train(sess, mnist, training_epochs=100)

# training the classifier (classify_w_encoded)
print('\nTraining...')
print('number of train samples in this process = ', mnist.validation.num_examples)
batch_size_tr = 100
epochs = 101
n_loop_train = int(mnist.validation.num_examples / batch_size_tr)
for e in range(epochs):
    for i in range(n_loop_train):
        batch_x, batch_y = mnist.validation.next_batch(batch_size_tr)
        batch_z = vae1.transform(sess, batch_x)
        train_fd = {z: batch_z, y_: batch_y, K.learning_phase(): 1}
        train_step.run(feed_dict=train_fd)
    if e % 10 == 0:
        val_fd = {z: batch_z, y_: batch_y, K.learning_phase(): 0}
        tr_loss, tr_accu = sess.run([clf_loss, clf_accuracy], val_fd)
        print('Epoch, loss, accuracy = {:>3d}: {:>8.4f}, {:>8.4f}'.format(
            e, tr_loss, tr_accu))

# test process
batch_size_te = 100
n_loops_test = int(mnist.test.num_examples / batch_size_te)
test_accu = []
for i in range(n_loops_test):
    batch_xte, batch_yte = mnist.test.next_batch(batch_size_te)
    batch_z = vae1.transform(sess, batch_xte)
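    # The snippet breaks off inside the test loop; a plausible completion
    # (not in the original source), evaluating the classifier in test mode:
    test_fd = {z: batch_z, y_: batch_yte, K.learning_phase(): 0}
    test_accu.append(sess.run(clf_accuracy, test_fd))
print('test accuracy = {:>8.4f}'.format(np.mean(test_accu)))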
def train():
    NUM_EPOCHS = options['num_epochs']
    LOAD_PATH = options['load_path']
    SAVE_PATH = options['save_path']
    PSIZE = options['psize']
    HSIZE = options['hsize']
    WSIZE = options['wsize']
    CSIZE = options['csize']
    model_name = options['model_name']
    BATCH_SIZE = options['batch_size']
    continue_training = options['continue_training']

    files = []
    num_labels = 5
    with open('train.txt') as f:
        for line in f:
            files.append(line[:-1])
    print '%d training samples' % len(files)

    flair_t2_node = tf.placeholder(dtype=tf.float32,
                                   shape=(None, HSIZE, WSIZE, CSIZE, 2))
    t1_t1ce_node = tf.placeholder(dtype=tf.float32,
                                  shape=(None, HSIZE, WSIZE, CSIZE, 2))
    flair_t2_gt_node = tf.placeholder(dtype=tf.int32,
                                      shape=(None, PSIZE, PSIZE, PSIZE, 2))
    t1_t1ce_gt_node = tf.placeholder(dtype=tf.int32,
                                     shape=(None, PSIZE, PSIZE, PSIZE, 5))

    if model_name == 'dense48':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat_large(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat_large(
            input=t1_t1ce_node, name='t1')
    elif model_name == 'no_dense':
        flair_t2_15, flair_t2_27 = tf_models.PlainCounterpart(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.PlainCounterpart(
            input=t1_t1ce_node, name='t1')
    elif model_name == 'dense24':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(
            input=t1_t1ce_node, name='t1')
    else:
        print 'No such model name'

    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    flair_t2_15 = Conv3D(2, kernel_size=1, strides=1, padding='same',
                         name='flair_t2_15_cls')(flair_t2_15)
    flair_t2_27 = Conv3D(2, kernel_size=1, strides=1, padding='same',
                         name='flair_t2_27_cls')(flair_t2_27)
    t1_t1ce_15 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same',
                        name='t1_t1ce_15_cls')(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same',
                        name='t1_t1ce_27_cls')(t1_t1ce_27)

    flair_t2_score = flair_t2_15[:, 13:25, 13:25, 13:25, :] + \
                     flair_t2_27[:, 13:25, 13:25, 13:25, :]
    t1_t1ce_score = t1_t1ce_15[:, 13:25, 13:25, 13:25, :] + \
                    t1_t1ce_27[:, 13:25, 13:25, 13:25, :]

    loss = segmentation_loss(flair_t2_gt_node, flair_t2_score, 2) + \
           segmentation_loss(t1_t1ce_gt_node, t1_t1ce_score, 5)
    acc_flair_t2 = acc_tf(y_pred=flair_t2_score, y_true=flair_t2_gt_node)
    acc_t1_t1ce = acc_tf(y_pred=t1_t1ce_score, y_true=t1_t1ce_gt_node)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=5e-4).minimize(loss)

    saver = tf.train.Saver(max_to_keep=15)
    data_gen_train = vox_generator(all_files=files, n_pos=200, n_neg=200,
                                   correction=options['correction'])

    with tf.Session() as sess:
        if continue_training:
            saver.restore(sess, LOAD_PATH)
        else:
            sess.run(tf.global_variables_initializer())
        for ei in range(NUM_EPOCHS):
            for pi in range(len(files)):
                acc_pi, loss_pi = [], []
                data, labels, centers = data_gen_train.next()
                n_batches = int(np.ceil(float(centers.shape[1]) / BATCH_SIZE))
                for nb in range(n_batches):
                    offset_batch = min(nb * BATCH_SIZE,
                                       centers.shape[1] - BATCH_SIZE)
                    data_batch, label_batch = get_patches_3d(
                        data, labels,
                        centers[:, offset_batch:offset_batch + BATCH_SIZE],
                        HSIZE, WSIZE, CSIZE, PSIZE, False)
                    label_batch = label_transform(label_batch, 5)
                    _, l, acc_ft, acc_t1c = sess.run(
                        fetches=[optimizer, loss, acc_flair_t2, acc_t1_t1ce],
                        feed_dict={
                            flair_t2_node: data_batch[:, :, :, :, :2],
                            t1_t1ce_node: data_batch[:, :, :, :, 2:],
                            flair_t2_gt_node: label_batch[0],
                            t1_t1ce_gt_node: label_batch[1],
                            learning_phase(): 1
                        })
                    acc_pi.append([acc_ft, acc_t1c])
                    loss_pi.append(l)
                    n_pos_sum = np.sum(np.reshape(label_batch[0], (-1, 2)), axis=0)
                    print 'epoch-patient: %d, %d, iter: %d-%d, p%%: %.4f, loss: %.4f, acc_flair_t2: %.2f%%, acc_t1_t1ce: %.2f%%' % \
                          (ei + 1, pi + 1, nb + 1, n_batches,
                           n_pos_sum[1] / float(np.sum(n_pos_sum)), l, acc_ft, acc_t1c)
                print 'patient loss: %.4f, patient acc: %.4f' % (
                    np.mean(loss_pi), np.mean(acc_pi))
            saver.save(sess, SAVE_PATH, global_step=ei)
            print 'model saved'