def finalize(self, shapes, batchsize, dtypes=None):
    """Freeze this object's output signature and build a feedable iterator.

    May only be called once per object: the call records ``output_shapes``
    and ``batchsize`` on ``self``.

    Args:
        shapes: iterable of per-component shapes (each an iterable of dims
            without the batch dimension); an unknown (``None``) batch
            dimension is prepended to each.
        batchsize: batch size, stored as ``self.batchsize``.
        dtypes: optional sequence of dtypes, one per entry of ``shapes``.
            Defaults to ``tf.float32`` for every entry (the previous code
            hard-coded exactly two ``tf.float32`` components, which broke for
            any other number of shapes).

    Returns:
        ``(getter, handle)``: the tensors produced by ``iterator.get_next()``
        and the scalar ``tf.string`` placeholder into which a concrete
        iterator's string handle is fed.

    Raises:
        RuntimeError: if ``finalize`` was already called on this object.
        ValueError: if ``dtypes`` is given with a length different from
            ``shapes``.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if hasattr(self, 'output_shapes'):
        raise RuntimeError('finalize() may only be called once per object')
    # Prepend an unknown batch dimension to every component shape.
    ishapes = tuple([None] + list(s) for s in shapes)
    if dtypes is None:
        itypes = tuple(tf.float32 for _ in ishapes)
    else:
        itypes = tuple(dtypes)
        if len(itypes) != len(ishapes):
            raise ValueError(
                'got %d dtypes for %d shapes' % (len(itypes), len(ishapes)))
    self.output_shapes = shapes
    self.batchsize = batchsize
    # Feedable iterator: callers feed the string handle of whichever
    # concrete iterator (train / eval / ...) should supply the next batch.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = Iterator.from_string_handle(handle, itypes, ishapes)
    getter = iterator.get_next()
    return getter, handle
def get_dataset_tensors(args):
    """Prepare the dataset and build a feedable iterator over its splits.

    Downloads and converts the dataset to TFRecords when the prepared files
    are missing. Then, for every split returned by the dataset plugin, it
    builds a cached, repeating, shuffled, batched pipeline with its own
    initializable iterator. The ``'test'`` split is skipped when
    ``args.test_epochs`` is empty.

    Args:
        args: parsed options. Reads ``dataset``, ``dataset_dir``,
            ``raw_dataset_dir``, ``test_epochs``, ``cache_dir``,
            ``buffer_size``, ``seed``, ``batch_size`` and ``n_gpus``.

    Returns:
        ``(next_element, handle, dataset_iterators)``:

        * ``next_element``: output of the feedable iterator's ``get_next()``.
        * ``handle``: scalar ``tf.string`` placeholder selecting the split.
        * ``dataset_iterators``: maps split name to a dict with keys
          ``'x'`` (initializable iterator), ``'n'`` (record count),
          ``'batches'`` (whole global batches per epoch) and ``'handle'``
          (the iterator's string-handle tensor).

    Raises:
        ValueError: if no split remains after filtering. Previously this
            surfaced as an UnboundLocalError on the loop variable.
    """
    with tf.device('/cpu:0'), tf.variable_scope('input_pipeline'):
        # TODO move this to hem.init()
        # find the dataset plugin for the requested dataset
        p = get_dataset(args.dataset)

        # ensure that the dataset exists
        if not p.check_prepared_datasets(args.dataset_dir):
            if not p.check_raw_datasets(args.raw_dataset_dir):
                print('Downloading dataset...')
                # TODO: datasets should be able to be marked as non-downloadable
                p.download(args.raw_dataset_dir)
            print('Converting to tfrecord...')
            p.convert_to_tfrecord(args.raw_dataset_dir, args.dataset_dir)

        # load the dataset
        datasets = p.get_datasets(args)
        dataset_iterators = {}
        # tensor to hold which training/eval phase we are in
        handle = tf.placeholder(tf.string, shape=[])
        # One batch is split across all GPUs.
        global_batch_size = args.batch_size * args.n_gpus

        last_dataset = None
        # add a dataset for train, validation, and testing
        for k, v in datasets.items():
            # skip test set if not needed
            if not args.test_epochs and k == 'test':
                continue
            # Each value is indexed as (tf dataset, tfrecord path).
            d, tfrecord_path = v[0], v[1]
            n = sum(1 for _ in tf.python_io.tf_record_iterator(tfrecord_path))
            if args.cache_dir:
                cache_fn = '{}.cache.{}'.format(args.dataset, k)
                d = d.cache(os.path.join(args.cache_dir, cache_fn))
            else:
                d = d.cache()
            d = d.repeat()
            d = d.shuffle(buffer_size=args.buffer_size, seed=args.seed)
            d = d.batch(global_batch_size)
            x_iterator = d.make_initializable_iterator()
            dataset_iterators[k] = {
                'x': x_iterator,
                'n': n,
                'batches': n // global_batch_size,
                'handle': x_iterator.string_handle(),
            }
            last_dataset = d

        if last_dataset is None:
            raise ValueError(
                'no dataset splits available for {!r}'.format(args.dataset))

        # Feedable dataset that will swap between train/test/val. All splits
        # are assumed to share one element structure, so any split's
        # types/shapes describe it.
        iterator = Iterator.from_string_handle(
            handle, last_dataset.output_types, last_dataset.output_shapes)
        return iterator.get_next(), handle, dataset_iterators
def train(hps, design):
    """Training loop.

    Builds the FuseNet graph on top of a feedable iterator. It then runs
    training steps, plus periodic validation steps when
    ``FLAGS.val_data_path`` is set, until an input iterator raises
    ``tf.errors.OutOfRangeError``.

    Console metrics and summaries are written every ``FLAGS.train_iter``
    (train) and ``FLAGS.val_iter`` (validation) steps. Images are saved every
    ``FLAGS.IMSAVE`` steps when it is positive, and checkpoints every
    ``FLAGS.chkpt_iter`` steps. A final ``-end`` checkpoint is saved when the
    data runs out.

    Args:
        hps: hyper-parameters. ``batch_size`` and ``num_epochs`` are read
            here, and the object is passed through to ``FuseNet``.
        design: model design, passed through to ``FuseNet``.

    Raises:
        ValueError: if ``FLAGS.output_style`` is neither
            ``fuse_cnn_petct.STYLE_SPLIT`` nor ``STYLE_SINGLE``. Previously
            this only failed later with a TypeError on ``None % int``.
    """
    if FLAGS.output_style not in (fuse_cnn_petct.STYLE_SPLIT,
                                  fuse_cnn_petct.STYLE_SINGLE):
        raise ValueError(
            'unsupported output_style: %r' % (FLAGS.output_style,))
    has_val = FLAGS.val_data_path != ''  # skip validation if no path

    # get tfrecord files for train
    train_records = _get_tfrecord_files_from_dir(FLAGS.train_data_path)
    train_iterator = petct_input.build_input(
        train_records, hps.batch_size, hps.num_epochs, FLAGS.mode)
    train_iterator_handle = train_iterator.string_handle()
    if has_val:
        # get tfrecord files for val
        val_records = _get_tfrecord_files_from_dir(FLAGS.val_data_path)
        val_iterator = petct_input.build_input(
            val_records, hps.batch_size, hps.num_epochs, 'valid')
        val_iterator_handle = val_iterator.string_handle()

    # Feedable iterator: the model reads from whichever iterator's handle is
    # fed into `handle` (train or validation).
    handle = tf.placeholder(tf.string, shape=[], name='data')
    iterator = Iterator.from_string_handle(
        handle, train_iterator.output_types, train_iterator.output_shapes)
    ct, pt, ctlb, ptlb, bglb = iterator.get_next()

    model = fuse_cnn_petct.FuseNet(hps, design, ct, pt, ctlb, ptlb, bglb,
                                   FLAGS.mode)
    model.build_cross_modal_model()

    # metrics ops for train and val
    with tf.variable_scope('metrics'):
        (tr_summary_op, tr_precision_op, tr_recall_op, tr_accuracy_op,
         tr_rmse_op) = get_metrics_ops(model, 'train')
        (val_summary_op, val_precision_op, val_recall_op, val_accuracy_op,
         val_rmse_op) = get_metrics_ops(model, 'valid')

    # needed for input handlers
    g_init_op = tf.global_variables_initializer()
    l_init_op = tf.local_variables_initializer()

    config = tf.ConfigProto(allow_soft_placement=True,
                            device_count={'GPU': 1})
    with tf.Session(config=config) as mon_sess:
        # Need a saver to save and restore all the variables.
        saver = tf.train.Saver()
        if FLAGS.DEBUG:
            print('ENABLING DEBUG')
            mon_sess = tf_debug.LocalCLIDebugWrapperSession(mon_sess)
            mon_sess.add_tensor_filter("has_inf_or_nan",
                                       tf_debug.has_inf_or_nan)

        training_handle = mon_sess.run(train_iterator_handle)
        if has_val:
            validation_handle = mon_sess.run(val_iterator_handle)

        train_writer = tf.summary.FileWriter(FLAGS.log_root + '/train',
                                             mon_sess.graph)
        if has_val:
            valid_writer = tf.summary.FileWriter(FLAGS.log_root + '/valid')

        mon_sess.run([g_init_op, l_init_op])

        train_feed = {handle: training_handle, model.is_training: True}
        summary = None
        step = None
        val_summary = None
        # Bound up front so the end-of-data handler cannot hit a NameError
        # when validation never ran.
        val_step = None
        while True:
            try:
                # FIRST RUN TRAINING OP BASED ON OUTPUT STYLE
                if FLAGS.output_style == fuse_cnn_petct.STYLE_SPLIT:
                    # get PET and CT recons separately
                    (_, summary, step, loss, p, r, a, e, cts, pts, trcts,
                     trpts, trbgs, recon_cts, recon_pts, ct_preds,
                     pt_preds) = mon_sess.run(
                         [
                             model.train_op, tr_summary_op, model.global_step,
                             model.cost, tr_precision_op, tr_recall_op,
                             tr_accuracy_op, tr_rmse_op, model.ct, model.pt,
                             model.lbct, model.lbpt, model.lbbg,
                             model.ct_pred, model.pt_pred,
                             model.ct_probabilities, model.pt_probabilities
                         ],
                         feed_dict=train_feed)
                else:  # fuse_cnn_petct.STYLE_SINGLE (validated above)
                    # get PET and CT recons together
                    (_, summary, step, loss, p, r, a, e, cts, pts, trallpos,
                     trbgs, recon_all, all_preds) = mon_sess.run(
                         [
                             model.train_op, tr_summary_op, model.global_step,
                             model.cost, tr_precision_op, tr_recall_op,
                             tr_accuracy_op, tr_rmse_op, model.ct, model.pt,
                             model.lb_pos_gt, model.lbbg, model.all_pred,
                             model.all_probabilities
                         ],
                         feed_dict=train_feed)

                if step % FLAGS.train_iter == 0:
                    print('[TRAIN] STEP: %d, LOSS: %.5f, PRECISION: %.5f, '
                          'RECALL: %.5f, ACCURACY: %.5f, RMSE: %.5f' %
                          (step, loss, p, r, a, e))
                    train_writer.add_summary(summary, step)
                    train_writer.flush()

                if FLAGS.IMSAVE > 0 and step % FLAGS.IMSAVE == 0:
                    print('SAVING IMAGES')
                    if FLAGS.output_style == fuse_cnn_petct.STYLE_SPLIT:
                        _saveImages(hps.batch_size, step, cts, pts,
                                    trcts=trcts, trpts=trpts, trbgs=trbgs,
                                    recon_cts=recon_cts, recon_pts=recon_pts,
                                    ct_preds=ct_preds, pt_preds=pt_preds)
                    else:
                        _saveImages(hps.batch_size, step, cts, pts,
                                    trallpos=trallpos, trbgs=trbgs,
                                    recon_all=recon_all, all_preds=all_preds)

                if has_val and step % FLAGS.val_iter == 0:
                    _, val_summary, loss, p, r, a, e = mon_sess.run(
                        [
                            model.val_op, val_summary_op, model.cost,
                            val_precision_op, val_recall_op,
                            val_accuracy_op, val_rmse_op
                        ],
                        feed_dict={
                            handle: validation_handle,
                            model.is_training: False
                        })
                    val_step = step
                    print('[VALID] STEP: %d, LOSS: %.5f, PRECISION: %.5f, '
                          'RECALL: %.5f, ACCURACY: %.5f, RMSE: %.5f' %
                          (step, loss, p, r, a, e))
                    valid_writer.add_summary(val_summary, step)
                    valid_writer.flush()

                if step % FLAGS.chkpt_iter == 0:
                    save_loc = (FLAGS.log_root + '/' + FLAGS.chkpt_file +
                                str(step) + '.ckpt')
                    save_path = saver.save(mon_sess, save_loc)
                    print('Model saved in path: %s' % save_path)

            except tf.errors.OutOfRangeError:
                print('OUT OF DATA - ENDING')
                # now finished training (either train or validation has run out)
                # Only flush what was actually produced: add_summary(None, ...)
                # fails if no step (or no validation step) completed.
                if summary is not None:
                    train_writer.add_summary(summary, step)
                    train_writer.flush()
                if has_val and val_summary is not None:
                    valid_writer.add_summary(val_summary, val_step)
                    valid_writer.flush()
                save_loc = (FLAGS.log_root + '/' + FLAGS.chkpt_file +
                            str(step) + '-end.ckpt')
                save_path = saver.save(mon_sess, save_loc)
                print('Model saved in path: %s' % save_path)
                break
if __name__ == '__main__':
    # Script entry point: build a feedable (string-handle) input pipeline over
    # the training TFRecords and open a session.
    # NOTE(review): this chunk ends right after variable initialization; the
    # code that fetches the train/validation handles and runs the model is
    # not visible here.
    batch_size = 5
    attr_label_num = 30  # width of each attribute-label row (see reshape below)
    # NOTE(review): the meaning of the positional args 100 and 40000 is defined
    # by image_bleach.ImageLoader, which is not visible here.
    img_loader = image_bleach.ImageLoader(
        [global_configs.pic2attr_tfrecord_train_path], batch_size, 100, 40000)
    train_dataset = img_loader.launch_tfrecord_dataset()
    # NOTE(review): train_iterator is not used in the visible code; presumably
    # its string handle is fetched later and fed into `handle`.
    train_iterator = train_dataset.make_one_shot_iterator()
    # ========================= Data loading =========================
    # ========== Feed data via a handle (feedable iterator) ==========
    # Build a feedable string-handle placeholder; either the training-set
    # iterator handle or the validation-set iterator handle is passed in
    # through it.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = Iterator.from_string_handle(handle, train_dataset.output_types,
                                           train_dataset.output_shapes)
    pic_name_batch, pic_class_batch, attr_label_batch, img_batch = iterator.get_next()
    # Per the original author: the iterator yields 2-D arrays, while the id,
    # effect_len and label values used downstream must be 1-D, so reshape them.
    # NOTE(review): those names do not match the tensors below, and
    # attr_label_batch is reshaped to 2-D [batch_size, attr_label_num]. Each
    # reshape also hard-codes batch_size, so a smaller final batch would fail
    # to reshape.
    pic_name_batch = tf.reshape(pic_name_batch, [batch_size])
    pic_class_batch = tf.reshape(pic_class_batch, [batch_size])
    attr_label_batch = tf.reshape(attr_label_batch, [batch_size, attr_label_num])
    # ========================= end data loading =========================
    with tf.Session() as sess:
        # NOTE(review): queue runners are presumably unnecessary for a tf.data
        # pipeline, and `coord`/`threads` are never stopped or joined in the
        # visible code.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        sess.run(tf.global_variables_initializer())  # initialize all variables
        # Get the train/validation iterator handles, used later to feed data
        # into the model.
def __init__(self, sess, batch_size, stage_of_development,
             learning_rate_decay_factor, type_of_model, summary_dir,
             experiment_folder, type_of_optimizer, num_of_classes,
             total_num_of_training_examples, dropout=0.5, model_path=None,
             beta1=0.9, beta2=0.999, min_bins=None, max_bins=None,
             list_of_tfrecords_for_training=None,
             list_of_tfrecords_for_evaluation=None, training_with_eval=False,
             dict_of_filePath_to_num_of_examples_in_tfrecord=None):
    """Build the input pipelines, the wrapped model, metrics and optimizers.

    In the ``"training"`` stage, one shuffled, repeating dataset is built per
    file in ``list_of_tfrecords_for_training``. Each file gets a batch size
    equal to its share of ``total_num_of_training_examples`` times
    ``batch_size``, rounded up, and ``self.batch_size`` becomes the sum of
    those per-file sizes. Every file has its own feedable string-handle
    placeholder (``self.list_of_handles``) and one-shot iterator
    (``self.list_of_training_iterators``).

    In any other stage, a single dataset is read once from
    ``list_of_tfrecords_for_evaluation`` and fed through
    ``self.single_eval_handle``. In the ``"evaluation"`` stage,
    ``self.single_eval_iterator`` is the one-shot iterator whose handle
    should be fed.

    The rows of the (concatenated) batch that reach the model are chosen by
    feeding ``self.row_indices`` (shape ``(self.batch_size,)``).

    Args:
        sess: session passed to the wrapped model.
        batch_size: requested batch size (see above for training).
        stage_of_development: ``"training"``, ``"evaluation"``,
            ``"resume_training"``, ...
        learning_rate_decay_factor: unused in this constructor.
        type_of_model: ``'DehazeNet'``, ``'VGG'``, ``'ResNet'`` or anything
            else (simple CNN).
        summary_dir, experiment_folder: summaries are written under
            ``summary_dir/experiment_folder/train``.
        type_of_optimizer: ``'adam'`` or anything else (plain SGD); the
            ResNet branch always uses Adam.
        num_of_classes: forwarded to the DehazeNet/VGG/ResNet models.
        total_num_of_training_examples: total examples across all training
            files; used to split the batch between files.
        dropout: unused in this constructor.
        model_path: forwarded to the VGG/ResNet models.
        beta1, beta2: Adam hyper-parameters.
        min_bins, max_bins: forwarded to the DehazeNet/VGG/ResNet models.
        list_of_tfrecords_for_training: training TFRecord paths.
        list_of_tfrecords_for_evaluation: evaluation TFRecord path(s).
        training_with_eval: only printed here.
        dict_of_filePath_to_num_of_examples_in_tfrecord: example count per
            training TFRecord path; required in the ``"training"`` stage.

    Raises:
        ValueError: in the ``"training"`` stage when the per-file example
            counts are missing.
    """
    self.training_batch_size = 0
    self.batch_size = batch_size
    self.list_of_tr_datasets = []
    self.list_of_eval_datasets = []
    is_training_stage = stage_of_development == "training"
    examples_per_file = dict_of_filePath_to_num_of_examples_in_tfrecord

    print("Training with dev", training_with_eval)
    # The counts dict is optional outside training; printing it
    # unconditionally used to crash with AttributeError on None.
    if examples_per_file is not None:
        print(sorted(examples_per_file.keys()))
    elif is_training_stage:
        raise ValueError('dict_of_filePath_to_num_of_examples_in_tfrecord is '
                         'required in the "training" stage')

    # One record parser per model family (shared by train and eval data).
    if type_of_model == 'VGG':
        parse_fn = parse_example_vgg
    elif type_of_model == 'ResNet':
        parse_fn = parse_example_ResNet
    else:
        parse_fn = parse_example

    # DehazeNet inputs are reshaped to (224 + 15) pixels per side, the other
    # models to 224.
    img_side = 224 + 15 if type_of_model == 'DehazeNet' else 224

    def _reshape_batch(imgs, labels):
        """Reshape images to [-1, side, side, 3] and labels to [-1, 1]."""
        return (tf.reshape(imgs, [-1, img_side, img_side, 3]),
                tf.reshape(labels, [-1, 1]))

    # ---- datasets ----------------------------------------------------------
    if is_training_stage:
        total_examples = float(total_num_of_training_examples)
        for tfrecord_path in list_of_tfrecords_for_training:
            num_examples = examples_per_file[tfrecord_path]
            dataset = tf.contrib.data.TFRecordDataset(tfrecord_path)
            dataset = dataset.map(parse_fn)
            dataset = dataset.shuffle(buffer_size=20000)
            dataset = dataset.repeat()
            # This file's share of the batch, rounded up. int() keeps it an
            # integer on Python 2, where math.ceil returns a float.
            file_batch_size = int(math.ceil(
                float(num_examples) / total_examples * self.batch_size))
            self.training_batch_size += file_batch_size
            dataset = dataset.batch(file_batch_size)
            print(tfrecord_path, num_examples, file_batch_size)
            self.list_of_tr_datasets.append(dataset)
        self.batch_size = self.training_batch_size

    self.single_eval_data = None
    if not is_training_stage:
        # Previously the ResNet branch read the misspelled attribute
        # `self.singe_eval_data` and raised AttributeError.
        eval_data = tf.contrib.data.TFRecordDataset(
            list_of_tfrecords_for_evaluation)
        eval_data = eval_data.map(parse_fn)
        eval_data = eval_data.shuffle(buffer_size=10000)
        eval_data = eval_data.repeat(1)
        self.single_eval_data = eval_data.batch(self.batch_size)

    # ---- feedable iterators --------------------------------------------------
    self.list_of_handles = []
    self.list_of_iterators = []
    self.list_of_batch_imgs = []
    self.list_of_batch_labels = []
    self.list_of_batch_imgs_and_batch_labels = []
    if is_training_stage:
        for _ in list_of_tfrecords_for_training:
            file_handle = tf.placeholder(tf.string, shape=[])
            # All training datasets share one element structure.
            feedable = Iterator.from_string_handle(
                file_handle, self.list_of_tr_datasets[0].output_types,
                self.list_of_tr_datasets[0].output_shapes)
            self.list_of_handles.append(file_handle)
            self.list_of_iterators.append(feedable)
            raw_imgs, raw_labels = feedable.get_next()
            imgs, labels = _reshape_batch(raw_imgs, raw_labels)
            self.list_of_batch_imgs.append(imgs)
            self.list_of_batch_labels.append(labels)
    else:
        self.single_eval_handle = tf.placeholder(tf.string, shape=[])
        # Kept under its own name; `self.single_eval_iterator` is reassigned
        # to the concrete one-shot iterator below.
        self.single_eval_feedable_iterator = Iterator.from_string_handle(
            self.single_eval_handle, self.single_eval_data.output_types,
            self.single_eval_data.output_shapes)
        raw_imgs, raw_labels = self.single_eval_feedable_iterator.get_next()
        self.eval_batched_imgs, self.eval_batched_labels = _reshape_batch(
            raw_imgs, raw_labels)

    # ---- concrete iterators whose string handles are fed ---------------------
    self.list_of_training_iterators = []
    self.single_eval_iterator = None
    if is_training_stage:
        self.list_of_training_iterators = [
            dataset.make_one_shot_iterator()
            for dataset in self.list_of_tr_datasets
        ]
    if stage_of_development == "evaluation":
        self.single_eval_iterator = (
            self.single_eval_data.make_one_shot_iterator())

    # ---- row selection ------------------------------------------------------
    self.row_indices = tf.placeholder(tf.int32, (self.batch_size,))
    self.row_indices_reshaped = tf.reshape(self.row_indices,
                                           [self.batch_size, 1])
    if is_training_stage:
        all_imgs = tf.concat(self.list_of_batch_imgs, 0)
        all_labels = tf.concat(self.list_of_batch_labels, 0)
    else:
        # Previously this read the nonexistent `self.eval_batch_imgs`.
        all_imgs = self.eval_batched_imgs
        all_labels = self.eval_batched_labels
    self.batch_inputs = tf.gather_nd(all_imgs, self.row_indices_reshaped)
    self.batch_targets = tf.gather_nd(all_labels, self.row_indices_reshaped)

    self.stage_of_development = stage_of_development
    self.model_path = model_path
    self.type_of_optimizer = type_of_optimizer
    self.model = None
    self.beta1 = beta1
    self.beta2 = beta2
    # Previously this always concatenated `self.list_of_batch_labels`, which
    # is empty outside training (tf.concat of an empty list fails).
    self.pm_values = tf.gather_nd(all_labels, self.row_indices_reshaped)

    # ---- model --------------------------------------------------------------
    self.is_training = tf.placeholder(tf.bool, shape=[])
    if type_of_model == 'DehazeNet':
        self.model = DehazeNetModel(sess, self.batch_inputs, self.batch_targets,
                                    self.stage_of_development, num_of_classes,
                                    min_bins=min_bins, max_bins=max_bins)
    elif type_of_model == "VGG":
        self.model = AQPVGGModel(sess, self.batch_inputs, self.batch_targets,
                                 self.stage_of_development, self.model_path,
                                 num_of_classes, self.is_training,
                                 min_bins=min_bins, max_bins=max_bins)
    elif type_of_model == "ResNet":
        self.model = AQPResNetModel(sess, self.batch_inputs, self.batch_targets,
                                    self.stage_of_development, self.model_path,
                                    num_of_classes, min_bins=min_bins,
                                    max_bins=max_bins)
    else:
        self.model = SimpleCNNModel(sess, self.batch_inputs, self.batch_targets,
                                    self.stage_of_development)

    # ---- metrics (selected by the `is_training` flag via tf.cond) -------------
    def _mae(preds):
        return tf.reduce_mean(tf.abs(tf.subtract(preds, self.model.labels)))

    def _mse(preds):
        return tf.reduce_mean(tf.square(tf.subtract(preds, self.model.labels)))

    def _msle(preds):
        return tf.reduce_mean(tf.square(tf.subtract(
            tf.log(tf.add(preds, 1.0)), tf.log(tf.add(self.model.labels, 1.0)))))

    def _r2(preds):
        labels = self.model.labels
        # Unexplained error
        numerator = tf.reduce_sum(tf.square(tf.subtract(labels, preds)))
        # Total error
        denominator = tf.reduce_sum(
            tf.square(tf.subtract(labels, tf.reduce_mean(labels))))
        return tf.subtract(1.0, tf.divide(numerator, denominator))

    def _cond_metric(metric_fn):
        """tf.cond between the training and validation predictions."""
        return tf.cond(self.is_training,
                       lambda: metric_fn(self.model.predictions),
                       lambda: metric_fn(self.model.validation_predictions))

    self.learning_rate = tf.placeholder(tf.float32, shape=[])
    self.partial_learning_rate = tf.placeholder(tf.float32, shape=[])
    self.global_step = tf.Variable(0, trainable=False)

    self.predictions = tf.cond(self.is_training,
                               lambda: self.model.predictions,
                               lambda: self.model.validation_predictions)
    self.MAE_ = _cond_metric(_mae)
    self.R2_score_ = _cond_metric(_r2)
    self.MSE_ = _cond_metric(_mse)
    self.MSLE_ = _cond_metric(_msle)

    if is_training_stage:
        self.global_eval_step = tf.Variable(0, trainable=False)
        self.global_eval_update_step_variable = tf.assign(
            self.global_eval_step, self.global_eval_step + 1)

    tf.summary.scalar('MAE', self.MAE_)
    tf.summary.scalar('MSE', self.MSE_)
    tf.summary.scalar('MSLE', self.MSLE_)
    tf.summary.scalar('R2 Coefficient', self.R2_score_)

    # ---- optimizers (all minimize MAE) ---------------------------------------
    def _make_optimizer(learning_rate):
        if self.type_of_optimizer == 'adam':
            return tf.train.AdamOptimizer(learning_rate=learning_rate,
                                          beta1=self.beta1, beta2=self.beta2)
        return tf.train.GradientDescentOptimizer(learning_rate=learning_rate)

    def _minimize(optimizer, variables):
        gradients = tf.gradients(self.MAE_, variables)
        return optimizer.apply_gradients(zip(gradients, variables),
                                         global_step=self.global_step)

    if self.stage_of_development in ("training", "resume_training"):
        if type_of_model in ('DehazeNet', 'VGG'):
            # Partial op trains only the from-scratch variables; full op
            # trains all variables.
            self.partial_train_op = _minimize(
                _make_optimizer(self.partial_learning_rate),
                self.model.variables_trained_from_scratch)
            self.train_op = _minimize(_make_optimizer(self.learning_rate),
                                      self.model.all_variables)
        elif type_of_model == 'ResNet':
            partial_opt = tf.train.AdamOptimizer(
                learning_rate=self.partial_learning_rate,
                beta1=self.beta1, beta2=self.beta2)
            self.partial_train_op = slim.learning.create_train_op(
                self.MAE_, partial_opt, global_step=self.global_step,
                variables_to_train=self.model.variables_trained_from_scratch)
            full_opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                              beta1=self.beta1,
                                              beta2=self.beta2)
            self.train_op = slim.learning.create_train_op(
                self.MAE_, full_opt, global_step=self.global_step,
                variables_to_train=self.model.all_variables)
        else:
            self.train_op = _minimize(_make_optimizer(self.learning_rate),
                                      self.model.all_variables)

    self.merged = tf.summary.merge_all()
    self.train_writer = tf.summary.FileWriter(
        summary_dir + '/' + experiment_folder + '/train', sess.graph)
    self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=4)