Example #1
def train_model():
    """Build the TF1 segmentation graph and run the training loop.

    Builds the graph (input pipeline, model, loss, optimizer, hard-dice
    metric), restores the latest checkpoint from ``ckpt`` when
    ``restore_model`` is set, then trains for ``epoch`` passes over the
    data, writing summaries every 20 steps and saving a checkpoint every
    ``save_epoch_n`` epochs.

    NOTE(review): all configuration (``train_data_number``, ``batch_size``,
    ``ckpt``, ``epoch``, ``lr_range``, ``input_shape``, ``labels_shape``,
    ``logdir``, ...) is read from module-level globals — presumably set in a
    config module imported elsewhere; confirm before reusing.
    """
    # Optimizer steps per epoch; floor division drops any final partial batch.
    n_batch_train = int(train_data_number // batch_size)
    print('n_batch_train: ', n_batch_train)
    os.makedirs(ckpt, exist_ok=True)
    session_config = dk.set_gpu()

    with tf.Session(config=session_config) as sess:
        # Optional tfdbg CLI wrapper for hunting down inf/nan values.  It
        # only works from a real terminal (not the PyCharm runner): after
        # launch, type `run -f has_inf_or_nan` in the CLI.  Offending
        # tensors are listed in creation order, so the first entry is the
        # most likely origin; inspect it with `node_info`, `list_inputs`.
        # Tutorial: https://blog.csdn.net/tanmx219/article/details/82318133
        if use_tensoflow_debug:
            sess = tfdbg.LocalCLIDebugWrapperSession(sess)
            sess.add_tensor_filter("has_inf_or_nan", tfdbg.has_inf_or_nan)

        # Input pipeline tensors plus the feed placeholders they are fed into.
        train_x, train_y = create_inputs(is_train)
        x = tf.placeholder(tf.float32, shape=input_shape)
        y = tf.placeholder(tf.float32, shape=labels_shape)

        # Forward pass.
        prediction, endpoint = model(images=x,
                                     is_train=is_train,
                                     size=input_shape,
                                     l2_reg=0.0001)
        # Print the network structure as a sanity check.
        dk.print_model_struct(endpoint)

        # Loss and optimizer.
        the_loss = get_loss(choose_loss)
        loss = the_loss(y, prediction, labels_shape_vec)
        global_step, train_step = dk.set_optimizer(
            lr_range=lr_range, num_batches_per_epoch=n_batch_train, loss=loss)

        # Hard dice is a better segmentation metric here than plain accuracy.
        dice_hard = dk.dice_hard(y,
                                 prediction,
                                 threshold=0.5,
                                 axis=[1, 2, 3],
                                 smooth=1e-5)
        # dice_hard = dk.iou_metric(prediction, y)

        # Variables + queue-runner threads, summaries, checkpoint restore.
        coord, threads = dk.init_variables_and_start_thread(sess)
        summary_dict = {'loss': loss, 'dice_hard': dice_hard}
        summary_writer, summary_op = dk.set_summary(sess, logdir, summary_dict)
        saver, start_epoch = dk.restore_model(sess,
                                              ckpt,
                                              restore_model=restore_model)
        # Report the trainable parameter count.
        dk.show_parament_numbers()

        # Training loop.
        total_step = n_batch_train * epoch
        step = 0  # fix: defined even if an epoch has zero batches (was a NameError below)
        for epoch_n in range(start_epoch, epoch):
            dice_hard_value_list = []  # reset the per-epoch metric history
            since = time.time()
            for n_batch in range(n_batch_train):
                batch_x, batch_y = sess.run([train_x, train_y])
                ##########################   data augmentation   #########################
                # Normalizing to [0, 1] shrinks the loss by orders of magnitude.
                batch_x = batch_x / 255.0
                batch_x, batch_y = augmentImages(batch_x, batch_y)
                ##########################   end   #######################################
                # One optimization step.
                _, loss_value, dice_hard_value, summary_str, step = sess.run(
                    [train_step, loss, dice_hard, summary_op, global_step],
                    feed_dict={
                        x: batch_x,
                        y: batch_y
                    })
                # Per-batch progress line.
                dk.print_effect_message(epoch_n, n_batch, n_batch_train,
                                        loss_value, dice_hard_value)
                # Persist summaries every 20 steps.
                if (step + 1) % 20 == 0:
                    summary_writer.add_summary(summary_str, step)
                # Collect the metric for the epoch report.
                dice_hard_value_list.append(dice_hard_value)

            # Epoch report: progress, mean step time, min/max/mean dice.
            seconds_mean = (time.time() - since) / n_batch_train
            dk.print_progress_and_time_massge(seconds_mean, step, total_step,
                                              dice_hard_value_list)

            # Checkpoint every `save_epoch_n` epochs.
            if (epoch_n + 1) % save_epoch_n == 0:
                # fix: message typo 'movdel' -> 'model'
                print('epoch_n :{} saving model.......'.format(epoch_n))
                saver.save(sess,
                           os.path.join(ckpt, 'model_{}.ckpt'.format(epoch_n)),
                           global_step=global_step)

        dk.stop_threads(coord, threads)
Example #2
        # NOTE(review): fragment — the enclosing `def`/`with tf.Session(...)`
        # is outside this chunk, so this block is documented in place only.
        # Restore checkpoint (returns the epoch to resume from), then show
        # the trainable parameter count.
        saver,start_epoch = dk.restore_model(sess, ckpt, restore_model=restore_model)        # show parameter count
        dk.show_parament_numbers()
        # If restoring a model, recompute start_epoch and resume from there
        # start_epoch = 0
        # if restore_model:
        #     step = sess.run(global_step)
        #     start_epoch = int(step/n_batch_train/save_epoch_n)*save_epoch_n
        # Training loop
        total_step = n_batch_train * epoch
        for epoch_n in range(start_epoch,epoch):
            since = time.time()
            for n_batch in range(n_batch_train):
                batch_x, batch_y = sess.run([train_x, train_y])
                ##########################   data augmentation   #########################
                # Normalize to [0, 1]; this shrinks the loss by orders of magnitude
                batch_x = batch_x / 255.0
                batch_x, batch_y = augmentImages(batch_x, batch_y)
                ##########################   end   #######################################
                # One training step
                _, loss_value, dice_hard_value, summary_str, step = sess.run(
                    [train_step, loss, dice_hard, summary_op, global_step],
                    feed_dict={x: batch_x, y: batch_y})
                # Print per-batch results
                dk.print_effect_message(epoch_n,n_batch,n_batch_train,loss_value,dice_hard_value)
                # Persist summaries every 20 steps
                if (step + 1) % 20 == 0:
                    summary_writer.add_summary(summary_str, step)

            # Show progress and mean step time for the epoch
            seconds_mean = (time.time() - since) / n_batch_train
            dk.print_progress_and_time_massge(seconds_mean,step,total_step)
            # Save the model (body continues past this chunk)
Example #3
File: train.py — Project: qq191513/mySeg
        # NOTE(review): fragment — the enclosing `def`/`with` is outside this
        # chunk; `model` here is an object exposing .fit()/.save(), unlike
        # the graph-building `model(...)` call elsewhere in this file.
        # 2. Initialize variables and start the input-queue threads
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # 3. Train the model
        # num_epochs=10000
        for i in range(num_epochs):
            since = time.time()
            # 1. Read one batch of images and their masks from the queue
            pics, pics_masks = sess.run([images, labels])  # fetch one batch of images
            ##########################   data augmentation   #########################
            pics = pics / 255  # normalize to [0, 1]
            pics, pics_masks = augmentImages(pics, pics_masks)
            ##########################   end   #######################################
            # 2. One training step (writes a summary tagged with step i)
            loss_value = model.fit(pics, pics_masks, summary_step=i)
            # 3. Measure elapsed time for this step
            interval_time = time.time() - since
            # 4. Print and log the result
            # print('{}/{} acc_value: {:.3f} loss_value: {:.3f} ,time used: {:.3f}s'.format(i,num_epochs,acc_value,loss_value,interval_time))
            # print('{}/{} loss_value: {:.3f} ,time used: {:.3f}s'.format(i,num_epochs,loss_value,interval_time))
            message = '{}/{} loss_value: {:.3f} ,time used: {:.3f}s'.format(
                i, num_epochs, loss_value, interval_time)
            print_and_save_txt(str=message, filename=train_print_log)

            # 5. Save the model every 2000 steps
            if (i + 1) % 2000 == 0:
                model.save()