def train_model():
    """Train the segmentation model with a TF1 queue-input pipeline.

    Reads hyperparameters from module-level config (train_data_number,
    batch_size, ckpt, epoch, save_epoch_n, ...), builds the graph, restores
    the latest checkpoint if requested, runs the epoch/batch training loop,
    and periodically writes summaries and checkpoints.

    NOTE(review): this source arrived with its line structure stripped; the
    indentation below is reconstructed from the statement semantics.
    """
    # --- setup ---
    n_batch_train = int(train_data_number // batch_size)  # batches per epoch
    print('n_batch_train: ', n_batch_train)
    os.makedirs(ckpt, exist_ok=True)  # checkpoint directory
    session_config = dk.set_gpu()
    with tf.Session(config=session_config) as sess:
        # Optional tfdbg wrapper (TF1 debugging tool, mainly for locating
        # inf/nan). Cannot be used under PyCharm's debugger; run from a plain
        # terminal (e.g. xshell) instead.
        if use_tensoflow_debug:
            sess = tfdbg.LocalCLIDebugWrapperSession(sess)
            sess.add_tensor_filter("has_inf_or_nan", tfdbg.has_inf_or_nan)
            # Then inside the tfdbg CLI run: run -f has_inf_or_nan
            # Once an inf/nan appears, the UI lists every tensor containing
            # such pathological values, ordered by time — the first one is
            # most likely the node where inf/nan first occurred. Use
            # node_info, list_inputs, etc. to inspect a node's type and
            # inputs and find the root cause.
            # Tutorial: https://blog.csdn.net/tanmx219/article/details/82318133

        # --- input pipeline ---
        train_x, train_y = create_inputs(is_train)
        x = tf.placeholder(tf.float32, shape=input_shape)
        y = tf.placeholder(tf.float32, shape=labels_shape)

        # --- build network and prediction ---
        prediction, endpoint = model(images=x, is_train=is_train,
                                     size=input_shape, l2_reg=0.0001)
        # Print the model structure.
        dk.print_model_struct(endpoint)

        # --- loss ---
        the_loss = get_loss(choose_loss)
        loss = the_loss(y, prediction, labels_shape_vec)

        # --- optimizer ---
        global_step, train_step = dk.set_optimizer(
            lr_range=lr_range, num_batches_per_epoch=n_batch_train, loss=loss)

        # Hard Dice metric — plain accuracy is not appropriate for
        # segmentation masks.
        dice_hard = dk.dice_hard(y, prediction, threshold=0.5,
                                 axis=[1, 2, 3], smooth=1e-5)
        # dice_hard = dk.iou_metric(prediction, y)

        # Initialize variables and start the queue-runner threads.
        coord, threads = dk.init_variables_and_start_thread(sess)

        # --- training summaries ---
        summary_dict = {'loss': loss, 'dice_hard': dice_hard}
        summary_writer, summary_op = dk.set_summary(sess, logdir, summary_dict)

        # Restore model (start_epoch resumes from the restored checkpoint).
        saver, start_epoch = dk.restore_model(sess, ckpt,
                                              restore_model=restore_model)
        # Show the number of parameters.
        dk.show_parament_numbers()

        # --- training loop ---
        total_step = n_batch_train * epoch
        for epoch_n in range(start_epoch, epoch):
            dice_hard_value_list = []  # reset per-epoch metric history
            since = time.time()
            for n_batch in range(n_batch_train):
                batch_x, batch_y = sess.run([train_x, train_y])
                # ------------------- data augmentation -------------------
                # Normalize to [0, 1]; this line reduced the loss value by
                # tens of times.
                batch_x = batch_x / 255.0
                batch_x, batch_y = augmentImages(batch_x, batch_y)
                # ------------------------- end ---------------------------
                # Run one training step.
                _, loss_value, dice_hard_value, summary_str, step = sess.run(
                    [train_step, loss, dice_hard, summary_op, global_step],
                    feed_dict={
                        x: batch_x,
                        y: batch_y
                    })
                # Print the per-batch results.
                dk.print_effect_message(epoch_n, n_batch, n_batch_train,
                                        loss_value, dice_hard_value)
                # Save summary every 20 steps.
                if (step + 1) % 20 == 0:
                    summary_writer.add_summary(summary_str, step)
                # Record the metric for the epoch statistics below.
                dice_hard_value_list.append(dice_hard_value)
            # Show progress, mean step time, and min/max/mean of the metric.
            # NOTE(review): original indentation was lost; this per-epoch
            # placement matches the per-epoch `since` timer — confirm.
            seconds_mean = (time.time() - since) / n_batch_train
            dk.print_progress_and_time_massge(seconds_mean, step, total_step,
                                              dice_hard_value_list)
            # Save the model every `save_epoch_n` epochs.
            if (((epoch_n + 1) % save_epoch_n)) == 0:
                print('epoch_n :{} saving movdel.......'.format(epoch_n))
                saver.save(sess,
                           os.path.join(ckpt,
                                        'model_{}.ckpt'.format(epoch_n)),
                           global_step=global_step)
        dk.stop_threads(coord, threads)
saver,start_epoch = dk.restore_model(sess, ckpt, restore_model=restore_model) # 显示参数量 dk.show_parament_numbers() # 若恢复model,则重新计算start_epoch继续 # start_epoch = 0 # if restore_model: # step = sess.run(global_step) # start_epoch = int(step/n_batch_train/save_epoch_n)*save_epoch_n # 训练loop total_step = n_batch_train * epoch for epoch_n in range(start_epoch,epoch): since = time.time() for n_batch in range(n_batch_train): batch_x, batch_y = sess.run([train_x, train_y]) ########################## 数据增强 ################################### batch_x = batch_x / 255.0 # 归一化,加了这句话loss值小了几十倍 batch_x, batch_y = augmentImages(batch_x, batch_y) ########################## end ####################################### # 训练一个step _, loss_value, dice_hard_value, summary_str, step = sess.run( [train_step, loss, dice_hard, summary_op, global_step], feed_dict={x: batch_x, y: batch_y}) # 显示结果batch_size dk.print_effect_message(epoch_n,n_batch,n_batch_train,loss_value,dice_hard_value) # 保存summary if (step + 1) % 20 == 0: summary_writer.add_summary(summary_str, step) # 显示进度和耗时 seconds_mean = (time.time() - since) / n_batch_train dk.print_progress_and_time_massge(seconds_mean,step,total_step) # 保存model
# 2、初始化和启动线程 tf.global_variables_initializer().run() tf.local_variables_initializer().run() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 3、训练模型 # num_epochs=10000 for i in range(num_epochs): since = time.time() #1、读图 pics, pics_masks = sess.run([images, labels]) # 取出一个batchsize的图片 ########################## 数据增强 ################################### pics = pics / 255 # 归一化 pics, pics_masks = augmentImages(pics, pics_masks) ########################## end ####################################### # 2、训练 loss_value = model.fit(pics, pics_masks, summary_step=i) # 3、计算耗时 interval_time = time.time() - since # 4、打印结果 # print('{}/{} acc_value: {:.3f} loss_value: {:.3f} ,time used: {:.3f}s'.format(i,num_epochs,acc_value,loss_value,interval_time)) # print('{}/{} loss_value: {:.3f} ,time used: {:.3f}s'.format(i,num_epochs,loss_value,interval_time)) message = '{}/{} loss_value: {:.3f} ,time used: {:.3f}s'.format( i, num_epochs, loss_value, interval_time) print_and_save_txt(str=message, filename=train_print_log) # 5、保存model if (i + 1) % 2000 == 0: model.save()