def train():
    """
    Introduction
    ------------
        Train the model.
    """
    train_reader = Reader('train',
                          config.data_dir,
                          config.anchors_path,
                          config.num_classes,
                          input_shape=config.input_shape,
                          max_boxes=config.max_boxes)
    train_data = train_reader.build_dataset(config.train_batch_size)
    is_training = tf.placeholder(tf.bool, shape=[])
    iterator = train_data.make_one_shot_iterator()
    images, bbox, bbox_true_13, bbox_true_26, bbox_true_52 = iterator.get_next()
    images.set_shape([None, config.input_shape, config.input_shape, 3])
    bbox.set_shape([None, config.max_boxes, 5])
    grid_shapes = [
        config.input_shape // 32, config.input_shape // 16,
        config.input_shape // 8
    ]
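    # Ground-truth tensors for the three YOLO detection scales (strides 32, 16, 8);
    # each has shape [batch, grid, grid, 3 anchors, 4 box coords + objectness + num_classes].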
    bbox_true_13.set_shape(
        [None, grid_shapes[0], grid_shapes[0], 3, 5 + config.num_classes])
    bbox_true_26.set_shape(
        [None, grid_shapes[1], grid_shapes[1], 3, 5 + config.num_classes])
    bbox_true_52.set_shape(
        [None, grid_shapes[2], grid_shapes[2], 3, 5 + config.num_classes])
    draw_box(images, bbox)
    model = yolo(config.norm_epsilon, config.norm_decay, config.anchors_path,
                 config.classes_path, config.pre_train)
    bbox_true = [bbox_true_13, bbox_true_26, bbox_true_52]
    output = model.yolo_inference(images, config.num_anchors // 3,
                                  config.num_classes, is_training)
    loss = model.yolo_loss(output, bbox_true, model.anchors,
                           config.num_classes, config.ignore_thresh)
    l2_loss = tf.losses.get_regularization_loss()
    loss += l2_loss
    tf.summary.scalar('loss', loss)
    merged_summary = tf.summary.merge_all()
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(config.learning_rate,
                                    global_step,
                                    decay_steps=2000,
                                    decay_rate=0.8)
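    # tf.train.exponential_decay: lr = learning_rate * decay_rate ** (global_step / decay_steps);
    # with the default staircase=False the decay is applied smoothly at every step.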
    # decay_steps = 3 * int(config.train_num / config.train_batch_size)

    optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    # If pretrained weights are loaded, freeze the darknet53 variables (train only the 'yolo' scope)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
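    # tf.GraphKeys.UPDATE_OPS collects the batch-norm moving-average update ops;
    # running the train op under this control dependency keeps those statistics current.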
    with tf.control_dependencies(update_ops):
        if config.pre_train:
            train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='yolo')
            train_op = optimizer.minimize(loss=loss,
                                          global_step=global_step,
                                          var_list=train_var)
        else:
            train_op = optimizer.minimize(loss=loss, global_step=global_step)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver(max_to_keep=5)
    with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
        ckpt = tf.train.get_checkpoint_state(config.model_dir)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print('restore model', ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(init)
        if config.pre_train is True:
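            # Only the darknet53 backbone variables are initialized from the darknet53
            # weights file; the detection head keeps its random initialization.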
            load_ops = load_weights(tf.global_variables(scope='darknet53'),
                                    config.darknet53_weights_path)
            sess.run(load_ops)
        summary_writer = tf.summary.FileWriter(config.log_dir, sess.graph)
        loss_value = 0
        for epoch in range(config.Epoch):
            for step in range(int(config.train_num / config.train_batch_size)):
                start_time = time.time()
                train_loss, summary, global_step_value, _ = sess.run(
                    [loss, merged_summary, global_step, train_op],
                    {is_training: True})
                loss_value += train_loss
                duration = time.time() - start_time
                examples_per_sec = config.train_batch_size / float(duration)
                format_str = (
                    'Epoch {} step {}, step_loss = {} train loss = {} ( {} examples/sec; {} '
                    'sec/batch)')
                print(
                    format_str.format(epoch, step, train_loss,
                                      loss_value / global_step_value,
                                      examples_per_sec, duration))
                summary_writer.add_summary(summary=tf.Summary(value=[
                    tf.Summary.Value(tag="train loss", simple_value=train_loss)
                ]),
                                           global_step=global_step_value)
                summary_writer.add_summary(summary, global_step_value)
                summary_writer.flush()
                # Save a checkpoint every 30 steps
                if step % 30 == 0:
                    checkpoint_path = os.path.join(config.model_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=global_step)
Example #2
def train():
    """
    Introduction
    ------------
        Train the model.
    """
    print("start")
    train_reader = Reader('train',
                          config.data_dir,
                          config.anchors_path,
                          config.num_classes,
                          input_shape=config.input_shape,
                          max_boxes=config.max_boxes)
    train_data = train_reader.build_dataset(config.train_batch_size)
    is_training = tf.placeholder(tf.bool, shape=[])

    print("ssssssssssss")

    iterator = train_data.make_one_shot_iterator()
    images, bbox, bbox_true_13, bbox_true_26, bbox_true_52 = iterator.get_next()
    # each box: x, y, x, y, label
    lr_images = tf.image.resize_images(
        images,
        size=[config.input_shape // 4, config.input_shape // 4],
        method=tf.image.ResizeMethod.BILINEAR,
        align_corners=False)
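    # lr_images: the batch bilinearly downsampled by 4x, used as the low-resolution
    # input to the generator; the MSE terms below push its output toward the originals.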
    print("bbox")

    images.set_shape([None, config.input_shape, config.input_shape, 3])
    lr_images.set_shape(
        [None, config.input_shape // 4, config.input_shape // 4, 3])
    bbox.set_shape([None, config.max_boxes, 5])
    grid_shapes = [
        config.input_shape // 32, config.input_shape // 16,
        config.input_shape // 8
    ]
    bbox_true_13.set_shape(
        [None, grid_shapes[0], grid_shapes[0], 3, 5 + config.num_classes])
    bbox_true_26.set_shape(
        [None, grid_shapes[1], grid_shapes[1], 3, 5 + config.num_classes])
    bbox_true_52.set_shape(
        [None, grid_shapes[2], grid_shapes[2], 3, 5 + config.num_classes])
    draw_box(images, bbox)

    model = gan_model(config.norm_epsilon, config.norm_decay,
                      config.anchors_path, config.classes_path,
                      config.pre_train)
    bbox_true = [bbox_true_13, bbox_true_26, bbox_true_52]

    #------------------------model-------------------------
    g_img1 = model.GAN_g1(lr_images)
    g_img = model.GAN_g(lr_images)
    d_real = model.d_inference(images, config.num_anchors // 3,
                               config.num_classes, is_training)
    d_fake = model.d_inference(g_img.outputs, config.num_anchors // 3,
                               config.num_classes, is_training)

    #------------------------d_loss-----------------------------
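    # Discriminator loss: YOLO loss on real images treated as positives (flag 1)
    # plus YOLO loss on generator outputs treated as negatives (flag 0).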
    d_loss1 = model.yolo_loss(d_real, bbox_true, model.anchors,
                              config.num_classes, 1, config.ignore_thresh)
    d_loss2 = model.yolo_loss(d_fake, bbox_true, model.anchors,
                              config.num_classes, 0, config.ignore_thresh)
    d_loss = d_loss1 + d_loss2

    #------------------------g_loss-----------------------------
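    # Generator loss: a weighted adversarial term on the discriminator's output for
    # generated images, pixel-wise MSE between both generator stages and the real
    # images, and the detection loss on generated images. Note: `alpha` is assumed
    # to be defined elsewhere (e.g. in config) as the adversarial-loss weight.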
    adv_loss = alpha * tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(d_fake), logits=d_fake)
    mse_loss1 = tl.cost.mean_squared_error(g_img1.outputs,
                                           images,
                                           is_mean=True)
    mse_loss2 = tl.cost.mean_squared_error(g_img.outputs, images, is_mean=True)
    mse_loss = mse_loss1 + mse_loss2
    clc_loss = d_loss2
    g_loss = mse_loss + clc_loss + adv_loss

    l2_loss = tf.losses.get_regularization_loss()
    d_loss += l2_loss
    g_loss += l2_loss

    tf.summary.scalar('d_loss', d_loss)
    tf.summary.scalar('g_loss', g_loss)

    merged_summary = tf.summary.merge_all()
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(config.learning_rate,
                                    global_step,
                                    decay_steps=2000,
                                    decay_rate=0.8)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    # If pretrained weights are loaded, freeze the darknet53 variables (train only the 'yolo' scope)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        if config.pre_train:
            train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='yolo')
            train_opd = optimizer.minimize(loss=d_loss,
                                           global_step=global_step,
                                           var_list=train_var)
            train_opg = optimizer.minimize(loss=g_loss,
                                           global_step=global_step,
                                           var_list=train_var)

        else:
            train_opd = optimizer.minimize(loss=d_loss,
                                           global_step=global_step)
            train_opg = optimizer.minimize(loss=g_loss,
                                           global_step=global_step)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
        ckpt = tf.train.get_checkpoint_state(config.model_dir)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print('restore model', ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(init)
        if config.pre_train is True:
            load_ops = load_weights(tf.global_variables(scope='darknet53'),
                                    config.darknet53_weights_path)
            sess.run(load_ops)
        summary_writer = tf.summary.FileWriter(config.log_dir, sess.graph)
        dloss_value = 0
        gloss_value = 0
        for epoch in range(config.Epoch):
            for step in range(int(config.train_num / config.train_batch_size)):
                start_time = time.time()
                dloss, summary, global_step_value, _ = sess.run(
                    [d_loss, merged_summary, global_step, train_opd],
                    {is_training: True})
                gloss, summary, global_step_value, _ = sess.run(
                    [g_loss, merged_summary, global_step, train_opg],
                    {is_training: True})

                dloss_value += dloss
                gloss_value += gloss

                duration = time.time() - start_time
                examples_per_sec = config.train_batch_size / float(duration)
                format_str1 = (
                    'Epoch {} step {},  d loss = {} ( {} examples/sec; {} '
                    'sec/batch)')
                print(
                    format_str1.format(epoch, step,
                                       dloss_value / global_step_value,
                                       examples_per_sec, duration))
                format_str2 = (
                    'Epoch {} step {},  g loss = {} ( {} examples/sec; {} '
                    'sec/batch)')
                print(
                    format_str2.format(epoch, step,
                                       gloss_value / global_step_value,
                                       examples_per_sec, duration))

                summary_writer.add_summary(summary=tf.Summary(value=[
                    tf.Summary.Value(tag="d loss", simple_value=dloss)
                ]),
                                           global_step=global_step_value)
                summary_writer.add_summary(summary=tf.Summary(value=[
                    tf.Summary.Value(tag="g loss", simple_value=gloss)
                ]),
                                           global_step=global_step_value)
                summary_writer.add_summary(summary, global_step_value)
                summary_writer.flush()
            # Save a checkpoint every 3 epochs
            if epoch % 3 == 0:
                checkpoint_path = os.path.join(config.model_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=global_step)
Example #3
def train():
    """
    Introduction
    ------------
        Train the model.
    """
    with tf.Graph().as_default(), tf.device("/cpu:0"):
        train_reader = Reader('train',
                              config.data_dir,
                              config.anchors_path,
                              config.num_classes,
                              input_shape=config.input_shape,
                              max_boxes=config.max_boxes,
                              shuffle_size=config.shuffle_size)
        train_data = train_reader.build_dataset(config.train_batch_size)
        is_training = tf.placeholder(tf.bool, shape=[])
        iterator = train_data.make_one_shot_iterator()
        images, bbox, bbox_true_13, bbox_true_26, bbox_true_52 = iterator.get_next()
        images.set_shape([None, config.input_shape, config.input_shape, 3])
        bbox.set_shape([None, config.max_boxes, 5])
        grid_shapes = [
            config.input_shape // 32, config.input_shape // 16,
            config.input_shape // 8
        ]
        bbox_true_13.set_shape(
            [None, grid_shapes[0], grid_shapes[0], 3, 5 + config.num_classes])
        bbox_true_26.set_shape(
            [None, grid_shapes[1], grid_shapes[1], 3, 5 + config.num_classes])
        bbox_true_52.set_shape(
            [None, grid_shapes[2], grid_shapes[2], 3, 5 + config.num_classes])
        #draw_box(images, bbox)
        #split data for training
        images_list = tf.split(images, gpu_num)
        #bbox_list = tf.split(bbox, gpu_num)
        bbox_true_13_list = tf.split(bbox_true_13, gpu_num)
        bbox_true_26_list = tf.split(bbox_true_26, gpu_num)
        bbox_true_52_list = tf.split(bbox_true_52, gpu_num)
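        # Split the batch along axis 0 so each of the gpu_num towers gets an equal
        # slice of the images and of the per-scale ground-truth tensors.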

        global_step = tf.Variable(0, trainable=False)
        lr = tf.train.exponential_decay(
            config.learning_rate,
            global_step,
            decay_steps=1000,
            decay_rate=0.8
        )  # decay_steps roughly equals one epoch: total samples / batch size, i.e. int(config.train_num / config.train_batch_size)
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)

        tower_grads = []
        tower_loss = []
        with tf.variable_scope(tf.get_variable_scope()):
            for gpu_id in range(gpu_num):
                with tf.device('/gpu:%d' % gpu_id):
                    with tf.name_scope('%s_%d' % ('tower', gpu_id)):
                        model = yolo(config.norm_epsilon, config.norm_decay,
                                     config.anchors_path, config.classes_path,
                                     config.pre_train)
                        bbox_true = [
                            bbox_true_13_list[gpu_id],
                            bbox_true_26_list[gpu_id],
                            bbox_true_52_list[gpu_id]
                        ]
                        output = model.yolo_inference(images_list[gpu_id],
                                                      config.num_anchors // 3,
                                                      config.num_classes,
                                                      is_training)
                        loss = model.yolo_loss(output, bbox_true,
                                               model.anchors,
                                               config.num_classes,
                                               config.ignore_thresh)
                        l2_loss = tf.losses.get_regularization_loss()
                        loss += l2_loss

                        tf.get_variable_scope().reuse_variables()
                        grads = optimizer.compute_gradients(loss)
                        tower_grads.append(grads)
                        tower_loss.append(loss)
        loss = average_loss(tower_loss)
        grads = average_gradients(tower_grads)
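        # Average the per-tower losses and gradients and apply one synchronous update;
        # all towers share variables via tf.get_variable_scope().reuse_variables().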
        #train_op = optimizer.apply_gradients(grads, global_step=global_step)

        tf.summary.scalar('loss', loss)
        merged_summary = tf.summary.merge_all()

        # If pretrained weights are loaded, freeze the darknet53 network variables
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            if config.pre_train:
                train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                              scope='yolo')
                #train_op = optimizer.minimize(loss = loss, global_step = global_step, var_list = train_var)
                train_op = optimizer.apply_gradients(grads,
                                                     global_step=global_step)
            else:
                train_op = optimizer.apply_gradients(grads,
                                                     global_step=global_step)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        with tf.Session(
                config=tf.ConfigProto(log_device_placement=False,
                                      allow_soft_placement=True)) as sess:
            ckpt = tf.train.get_checkpoint_state(config.model_dir)
            if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
                print('restore model', ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                sess.run(init)
            if config.pre_train is True:
                load_ops = load_weights(tf.global_variables(scope='darknet53'),
                                        config.darknet53_weights_path)
                sess.run(load_ops)
                print("pretrained the model")
            summary_writer = tf.summary.FileWriter(config.log_dir, sess.graph)
            loss_value = 0
            for epoch in range(config.Epoch):
                for step in range(
                        int(config.train_num / config.train_batch_size)):
                    start_time = time.time()
                    train_loss, summary, global_step_value, lr_step, _ = sess.run(
                        [loss, merged_summary, global_step, lr, train_op],
                        {is_training: True})
                    loss_value += train_loss
                    duration = time.time() - start_time
                    examples_per_sec = config.train_batch_size / float(duration)
                    format_str = (
                        'Epoch {} step {} lr = {}, loss_step = {} train loss = {} ( {} examples/sec; {} '
                        'sec/batch)')
                    print(
                        format_str.format(epoch, step, lr_step, train_loss,
                                          loss_value / global_step_value,
                                          examples_per_sec, duration))
                    summary_writer.add_summary(
                        summary=tf.Summary(value=[
                            tf.Summary.Value(tag="train loss",
                                             simple_value=train_loss)
                        ]),
                        global_step=step + epoch *
                        int(config.train_num / config.train_batch_size))
                    summary_writer.add_summary(
                        summary,
                        global_step=step + epoch *
                        int(config.train_num / config.train_batch_size))
                    summary_writer.flush()
                # Save a checkpoint every 2 epochs
                if epoch % 2 == 0:
                    checkpoint_path = os.path.join(config.model_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=global_step)
Example #4
def train():
    """
    Introduction
    ------------
        Train the model.
    """
    # gpu_num = check_available_gpus()
    #
    # for gpu_id in range(int(gpu_num)):
    # with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)):

    # with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)):
    # with tf.variable_scope(tf.get_variable_scope(), reuse=False):

    #-----------------------train_data-------------------------
    train_reader = Reader('train',
                          config.data_dir,
                          config.anchors_path2,
                          config.num_classes,
                          input_shape=config.input_shape,
                          max_boxes=config.max_boxes)
    train_data = train_reader.build_dataset(config.train_batch_size)
    is_training = tf.placeholder(tf.bool, shape=[])
    iterator = train_data.make_one_shot_iterator()
    images, bbox, bbox_true_13, bbox_true_26, bbox_true_52 = iterator.get_next()

    #-----------------------  definition-------------------------
    images.set_shape([None, config.input_shape, config.input_shape, 3])
    bbox.set_shape([None, config.max_boxes, 5])
    grid_shapes = [
        config.input_shape // 32, config.input_shape // 16,
        config.input_shape // 8
    ]
    lr_images = tf.image.resize_images(
        images,
        size=[config.input_shape // 4, config.input_shape // 4],
        method=tf.image.ResizeMethod.BILINEAR,
        align_corners=False)
    lr_images.set_shape(
        [None, config.input_shape // 4, config.input_shape // 4, 3])
    bbox_true_13.set_shape(
        [None, grid_shapes[0], grid_shapes[0], 3, 5 + config.num_classes])
    bbox_true_26.set_shape(
        [None, grid_shapes[1], grid_shapes[1], 3, 5 + config.num_classes])
    bbox_true_52.set_shape(
        [None, grid_shapes[2], grid_shapes[2], 3, 5 + config.num_classes])
    bbox_true = [bbox_true_13, bbox_true_26, bbox_true_52]

    #------------------------summary + draw-----------------------------------
    tf.summary.image('input1', images, max_outputs=3)
    draw_box(images, bbox)

    #------------------------------model---------------------------------
    model = yolo(config.norm_epsilon, config.norm_decay, config.anchors_path2,
                 config.classes_path, config.pre_train)
    # with tf.variable_scope("train_var"):
    # g_img1 = model.GAN_g1(lr_images)
    # print(g_img1.outputs)
    # tf.summary.image('img', g_img1.outputs, 3)
    # g_img2 = model.GAN_g2(g_img1)
    # print(model.g_variables)
    # net_g1 = model.GAN_g1(lr_images, is_train=True)
    with tf.variable_scope("model_gd"):
        net_g1 = model.GAN_g(lr_images, is_train=True, mask=False)
        net_g = model.GAN_g(lr_images, is_train=True, reuse=True, mask=True)

        d_real = model.yolo_inference(images,
                                      config.num_anchors // 3,
                                      config.num_classes,
                                      training=True)
        tf.get_variable_scope().reuse_variables()
        d_fake = model.yolo_inference(net_g.outputs,
                                      config.num_anchors // 3,
                                      config.num_classes,
                                      training=True)
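        # reuse_variables() makes the second yolo_inference call share weights with the
        # first, so real and generated images are scored by the same discriminator.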

    #---------------------------d_loss---------------------------------
    d_loss1 = model.yolo_loss(d_real, bbox_true, model.anchors,
                              config.num_classes, 1, config.ignore_thresh)
    d_loss2 = model.yolo_loss(d_fake, bbox_true, model.anchors,
                              config.num_classes, 0, config.ignore_thresh)
    d_loss = d_loss1 + d_loss2
    l2_loss = tf.losses.get_regularization_loss()
    d_loss += l2_loss

    #--------------------------g_loss------------------------------------
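    # Generator objective: content MSE against the real images (both generator stages),
    # an adversarial term on the detector's logits for generated images, and the
    # detection loss on those images (clc_loss), plus L2 regularization.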
    adv_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(
        d_fake[3]),
                                                       logits=d_fake[3])
    # adv_loss = 1e-3 * tf.reduce_sum(adv_loss) / tf.cast(tf.shape(d_fake[3])[0], tf.float32)
    adv_loss = tf.reduce_sum(adv_loss) / tf.cast(
        tf.shape(d_fake[3])[0], tf.float32)
    mse_loss1 = tl.cost.mean_squared_error(net_g1.outputs,
                                           images,
                                           is_mean=True)
    mse_loss1 = tf.reduce_sum(mse_loss1) / tf.cast(
        tf.shape(net_g1.outputs)[0], tf.float32)
    mse_loss2 = tl.cost.mean_squared_error(net_g.outputs, images, is_mean=True)
    mse_loss2 = tf.reduce_sum(mse_loss2) / tf.cast(
        tf.shape(net_g.outputs)[0], tf.float32)
    mse_loss = mse_loss1 + mse_loss2
    # clc_loss = 2e-6 * d_loss2
    clc_loss = model.yolo_loss(d_fake, bbox_true, model.anchors,
                               config.num_classes, 1, config.ignore_thresh)
    g_loss = mse_loss + adv_loss + clc_loss
    l2_loss = tf.losses.get_regularization_loss()
    g_loss += l2_loss

    #----------------summary loss-------------------------
    # tf.summary.image('img', images, 3)
    tf.summary.scalar('d_loss', d_loss)
    tf.summary.scalar('g_loss', g_loss)
    merged_summary = tf.summary.merge_all()

    #----------------------optimizer---------------------------
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(config.learning_rate,
                                    global_step,
                                    decay_steps=2000,
                                    decay_rate=0.8)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    # If pretrained weights are loaded, freeze the darknet53 network variables
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # print(tf.all_variables())
    with tf.control_dependencies(update_ops):
        if config.pre_train:
            # aaa = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
            train_varg1 = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES,
                scope='model_gd/generator/generator1')
            train_varg2 = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES,
                scope='model_gd/generator/generator2')
            train_varg = train_varg1 + train_varg2
            # print(train_varg)
            train_vard = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES,
                scope='model_gd/yolo_inference/discriminator')
            # print(train_vard)

            train_opg = optimizer.minimize(loss=g_loss,
                                           global_step=global_step,
                                           var_list=train_varg)
            train_opd = optimizer.minimize(loss=d_loss,
                                           global_step=global_step,
                                           var_list=train_vard)
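            # With pre_train, the generator ('model_gd/generator/*') and the detector
            # ('model_gd/yolo_inference/discriminator') are updated by separate train ops
            # restricted to their own variable lists.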
        else:
            train_opd = optimizer.minimize(loss=d_loss,
                                           global_step=global_step)
            train_opg = optimizer.minimize(loss=g_loss,
                                           global_step=global_step)

    #-------------------------session-----------------------------------
    init = tf.global_variables_initializer()
    # tl.layers.print_all_variables()
    saver = tf.train.Saver()
    with tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                          allow_soft_placement=True)) as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        ckpt = tf.train.get_checkpoint_state(config.model_dir)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print('restore model', ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(init)
        if config.pre_train is True:
            load_ops = load_weights(tf.global_variables(scope='darknet53'),
                                    config.darknet53_weights_path)
            sess.run(load_ops)
        summary_writer = tf.summary.FileWriter(config.log_dir, sess.graph)
        dloss_value = 0
        gloss_value = 0
        for epoch in range(config.Epoch):
            for step in range(int(config.train_num / config.train_batch_size)):
                start_time = time.time()
                train_dloss, summary, global_step_value, _ = sess.run(
                    [d_loss, merged_summary, global_step, train_opd],
                    {is_training: True})
                train_gloss, summary, global_step_value, _ = sess.run(
                    [g_loss, merged_summary, global_step, train_opg],
                    {is_training: True})
                dloss_value += train_dloss
                gloss_value += train_gloss
                duration = time.time() - start_time
                examples_per_sec = config.train_batch_size / float(duration)
                print(global_step_value)

                #------------------------print(epoch)--------------------------
                format_str1 = (
                    'Epoch {} step {},  train dloss = {} train gloss = {} ( {} examples/sec; {} '
                    'sec/batch)')
                print(
                    format_str1.format(epoch, step,
                                       dloss_value / global_step_value,
                                       gloss_value / global_step_value,
                                       examples_per_sec, duration))
                # print(format_str1.format(epoch, step, train_dloss, train_gloss, examples_per_sec, duration))

                #----------------------------summary loss------------------------
                summary_writer.add_summary(summary=tf.Summary(value=[
                    tf.Summary.Value(tag="train dloss",
                                     simple_value=train_dloss)
                ]),
                                           global_step=global_step_value)
                summary_writer.add_summary(summary=tf.Summary(value=[
                    tf.Summary.Value(tag="train gloss",
                                     simple_value=train_gloss)
                ]),
                                           global_step=global_step_value)
                summary_writer.add_summary(summary, global_step_value)
                summary_writer.flush()

            #--------------------------save model------------------------------
            # Save a checkpoint every 3 epochs
            if epoch % 3 == 0:
                checkpoint_path = os.path.join(config.model_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=global_step)