def __init__(self, sess, input_placer, gold_placer, learn_rate_placer, keep_prob,
             window_size=15, model_path='./model/version_15',
             image_path='./data/neuron/train', valid_flag=False, valid_num=2048,
             nn=None, predict_flag=True, wd=0.001, wdo=0.001):
    """Build the network graph, initialise its variables and restore a checkpoint.

    Args:
        sess: session used to run the variable initialiser; kept as ``self.sess``.
        input_placer: input tensor/placeholder, forwarded to ``get_nn``.
        gold_placer: gold-label tensor/placeholder, forwarded to ``get_nn``.
        learn_rate_placer: learning-rate tensor/placeholder, forwarded to ``get_nn``.
        keep_prob: keep-probability tensor/placeholder, forwarded to ``get_nn``.
        window_size: stored as ``self.window_size``; when ``valid_flag`` is set,
            ``(window_size - 1) // 2`` is passed to ``ImageListBatchGenerator``.
        model_path: checkpoint location passed to ``self.restore``.
        image_path: image source for the validation batch generator.
        valid_flag: when true, create a batch generator and pre-fetch
            ``valid_num`` samples into ``self.valid``.
        valid_num: number of samples requested from ``self.next_batch``.
        nn: network object stored as ``self.nn``; a new ``FbUp()`` is created
            for this instance when omitted.
        predict_flag: stored as ``self.predict_flag``.
        wd: forwarded to ``get_nn``.
        wdo: accepted for compatibility; ``get_nn`` takes no such argument,
            so it is currently unused.
    """
    self.sess = sess
    self.batch_generator = None
    # Build the default network per instance: a default of ``FbUp()`` in the
    # signature would be evaluated once and shared by every instance.
    self.nn = FbUp() if nn is None else nn
    self.saver = None
    self.predict_flag = predict_flag
    self.fg_para_dict = dict()
    self.valid_flag = valid_flag
    self.window_size = window_size
    if valid_flag:
        self.batch_generator = ImageListBatchGenerator(
            (window_size - 1) // 2, image_path=image_path)
        self.valid = self.next_batch(valid_num)
    # Pass by keyword so each tensor reaches the matching get_nn parameter
    # (get_nn declares keep_prob before learn_rate_placer) and only arguments
    # get_nn actually declares are passed (it has no wdo/window_size).
    # NOTE(review): get_nn reads self.train_flag, which is not set here;
    # confirm it is defined elsewhere on the class.
    self.get_nn(input_placer, gold_placer,
                keep_prob=keep_prob, learn_rate_placer=learn_rate_placer, wd=wd)
    print("Creation Complete\nIntialization")
    init_op = tf.initialize_all_variables()
    print("Start Intialization")
    sess.run(init_op)
    print("Intialization Done")
    self.restore(model_path)
    print("Saving")
def get_nn(self, input_placer, gold_placer, keep_prob, learn_rate_placer, wd):
    """Build the network graph and expose its tensors on ``self``.

    In training mode (``self.train_flag`` true), the outputs at index
    (3, 3, 3) are scored against ``gold_placer`` with a cross-entropy loss.
    That loss is summed with everything in the ``'losses'`` collection and
    minimised by an Adam ``train_step``. An ``accuracy`` tensor is also built.
    Otherwise, ``dilation3D`` copies of the full outputs are stored in
    ``self.para_dict``. In both modes the intermediate tensors returned by
    ``generate_flow`` are copied into ``self.para_dict``.

    Args:
        input_placer: network input, stored as ``self.x``.
        gold_placer: gold labels, stored as ``self.y_``; used only in training mode.
        keep_prob: keep-probability, stored and passed to ``generate_flow``.
        learn_rate_placer: Adam learning rate, stored as ``self.l_rate``.
        wd: passed to ``generate_flow`` as ``wd``.
    """
    self.x = input_placer
    self.y_ = gold_placer
    self.l_rate = learn_rate_placer
    self.keep_prob = keep_prob
    self.para_dict = dict()

    # Use the network object supplied to the constructor (self.nn) instead of
    # silently ignoring it; fall back to a fresh FbUp when none is available.
    nn = getattr(self, 'nn', None)
    if nn is None:
        nn = FbUp()
    flow = nn.generate_flow(input_placer, keep_prob, wd=wd)

    self.y_conv = flow['y_conv']
    self.y_res = flow['y_res']
    # Both modes keep the full outputs and a dilated copy of y_conv.
    self.y_conv_all = self.y_conv
    self.y_conv_2x = dilation3D(self.y_conv_all)
    self.y_res_all = self.y_res

    # NOTE(review): train_flag is not set in __init__; confirm where it is defined.
    if self.train_flag:
        self.shape = tf.shape(self.y_res)
        # Score only the voxel at index (3, 3, 3); presumably the centre of the
        # output volume — TODO confirm against FbUp's output shape.
        self.y_conv = self.y_conv[:, 3, 3, 3, :]
        self.y_res = self.y_res[:, 3, 3, 3]
        # Floor the log input so an exact zero yields a large finite loss
        # instead of -inf / NaN gradients.
        log_y_conv = tf.log(tf.maximum(self.y_conv, 1e-10))
        cross_entropy_mean = -tf.reduce_mean(self.y_ * log_y_conv)
        tf.add_to_collection('losses', cross_entropy_mean)
        cross_entropy = tf.add_n(tf.get_collection('losses'), name='total_loss')
        self.train_step = tf.train.AdamOptimizer(self.l_rate).minimize(cross_entropy)
        correct_prediction = tf.equal(tf.argmax(self.y_conv, 1),
                                      tf.argmax(self.y_, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    else:
        self.y_res_2x = dilation3D(self.y_res_all)
        self.para_dict['y_conv'] = self.y_conv_2x
        self.para_dict['y_res'] = self.y_res_2x

    # Expose intermediate tensors from generate_flow for inspection.
    for key in ('h_conv4', 'h_pool2', 'x', 'h_conv3', 'h_conv2', 'h_conv_out'):
        self.para_dict[key] = flow[key]