示例#1
0
def dstributed_train(ps_hosts, worker_hosts, job_name, task_index):
    """
    Introduction
    ------------
        Run this process as one task of a parameter-server training cluster.

        A 'ps' task only starts its server and blocks in ``server.join()``.
        Any other task builds the YOLO graph (variables placed through
        ``replica_device_setter``), opens a session against this task's
        server and, for each epoch, runs one training pass followed by one
        validation pass. A checkpoint is written every third epoch.
    Parameters
    ----------
        ps_hosts: comma-separated list of parameter-server hosts
        worker_hosts: comma-separated list of worker hosts
        job_name: 'ps' to act as a parameter server; any other value trains
        task_index: index of this task inside its job
    """
    ps_hosts = ps_hosts.split(',')
    worker_hosts = worker_hosts.split(',')
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)
    if job_name == 'ps':
        # Parameter servers only host variables; block here serving requests.
        server.join()
        return

    device_setter = tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % task_index, cluster=cluster)
    with tf.device(device_setter):
        # Input pipelines: only the training reader gets augmentation settings.
        train_data = Reader(
            'train', config.data_dir, config.anchors_path, config.num_classes,
            input_shape=config.input_shape, max_boxes=config.max_boxes,
            jitter=config.jitter, hue=config.hue, sat=config.sat,
            cont=config.cont, bri=config.bri)
        val_data = Reader(
            'val', config.data_dir, config.anchors_path, config.num_classes,
            input_shape=config.input_shape, max_boxes=config.max_boxes)
        train_batch = list(train_data.provide(config.train_batch_size))
        val_batch = list(val_data.provide(config.val_batch_size))

        model = yolo(config.norm_epsilon, config.norm_decay, config.anchors_path,
                     config.classes_path, config.pre_train)
        is_training = tf.placeholder(dtype=tf.bool, shape=[])
        images = tf.placeholder(shape=[None, 416, 416, 3], dtype=tf.float32)
        bbox_true_13 = tf.placeholder(shape=[None, 13, 13, 3, 85], dtype=tf.float32)
        bbox_true_26 = tf.placeholder(shape=[None, 26, 26, 3, 85], dtype=tf.float32)
        bbox_true_52 = tf.placeholder(shape=[None, 52, 52, 3, 85], dtype=tf.float32)
        bbox_true = [bbox_true_13, bbox_true_26, bbox_true_52]
        # Same order as the tensors returned by Reader.provide().
        feeds = [images, bbox_true_13, bbox_true_26, bbox_true_52]

        # Floor division keeps the per-scale anchor count an integer on
        # Python 3 as well ("/" would hand the model a float).
        output = model.yolo_inference(images, config.num_anchors // 3, config.num_classes, is_training)
        loss = model.yolo_loss(output, bbox_true, model.anchors, config.num_classes, config.ignore_thresh)
        tf.summary.scalar('loss', loss)
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(
            config.learning_rate, global_step, 20000, 0.1, staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

        # Run the UPDATE_OPS collection together with every training step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            if config.pre_train:
                # With pre-trained weights only the variables under the 'yolo'
                # scope are optimized; the darknet53 variables stay frozen.
                train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='yolo')
                train_op = optimizer.minimize(loss=loss, global_step=global_step, var_list=train_var)
            else:
                train_op = optimizer.minimize(loss=loss, global_step=global_step)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        # The session must target this task's server; a default in-process
        # session cannot reach the devices of the other cluster tasks.
        # NOTE(review): every worker initializes/restores variables and writes
        # checkpoints and ./logs; consider limiting that to task_index == 0.
        with tf.Session(server.target, config=tf.ConfigProto(log_device_placement=False)) as sess:
            ckpt = tf.train.get_checkpoint_state(config.model_dir)
            if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
                print('restore model', ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                sess.run(init)
            # Same truthiness test as the freezing branch above, so a frozen
            # backbone always has its pre-trained weights loaded.
            if config.pre_train:
                load_ops = model.load_weights(tf.global_variables(scope='darknet53'), config.darknet53_weights_path)
                sess.run(load_ops)

            train_steps = int(config.train_num / config.train_batch_size)
            val_steps = int(config.val_num / config.val_batch_size)
            summary_writer = tf.summary.FileWriter('./logs', sess.graph)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for epoch in range(config.Epoch):
                    for step in range(train_steps):
                        start_time = time.time()
                        batch_values = sess.run(train_batch)
                        feed_dict = dict(zip(feeds, batch_values))
                        feed_dict[is_training] = True
                        train_loss, _ = sess.run([loss, train_op], feed_dict)
                        duration = time.time() - start_time
                        # Throughput is examples divided by elapsed seconds.
                        examples_per_sec = config.train_batch_size / duration if duration > 0 else float('inf')
                        format_str = 'Epoch {} step {},  train loss = {} ( {} examples/sec; {} sec/batch)'
                        print(format_str.format(epoch, step, train_loss, examples_per_sec, duration))
                        # Use a step that keeps growing across epochs so later
                        # epochs do not overwrite earlier points in TensorBoard.
                        summary_writer.add_summary(
                            summary=tf.Summary(value=[tf.Summary.Value(tag="train loss", simple_value=train_loss)]),
                            global_step=epoch * train_steps + step)
                        summary_writer.flush()
                    for step in range(val_steps):
                        start_time = time.time()
                        batch_values = sess.run(val_batch)
                        feed_dict = dict(zip(feeds, batch_values))
                        feed_dict[is_training] = False
                        val_loss = sess.run(loss, feed_dict)
                        duration = time.time() - start_time
                        examples_per_sec = config.val_batch_size / duration if duration > 0 else float('inf')
                        format_str = 'Epoch {} step {}, val loss = {} ({} examples/sec; {} sec/batch)'
                        print(format_str.format(epoch, step, val_loss, examples_per_sec, duration))
                        summary_writer.add_summary(
                            summary=tf.Summary(value=[tf.Summary.Value(tag="val loss", simple_value=val_loss)]),
                            global_step=epoch * val_steps + step)
                        summary_writer.flush()
                    # Save the model once every 3 epochs.
                    if epoch % 3 == 0:
                        checkpoint_path = os.path.join(config.model_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=epoch)
            finally:
                # Always stop the input threads and release the event file,
                # even when training is interrupted by an exception.
                coord.request_stop()
                coord.join(threads)
                summary_writer.close()
示例#2
0
def train():
    """
    Introduction
    ------------
        Train the model in a single process.

        Builds the train/val readers and the YOLO graph, restores the latest
        checkpoint from ``config.model_dir`` when one exists (otherwise runs
        the variable initializer), optionally loads the darknet53 weights,
        then for each epoch runs one training pass followed by one validation
        pass. A checkpoint is written every third epoch.
    """
    # Select which GPU(s) are visible; environment values must be strings.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_index)

    # Input pipelines: only the training reader gets augmentation settings.
    train_data = Reader(
        'train', config.data_dir, config.anchors_path, config.num_classes,
        input_shape=config.input_shape, max_boxes=config.max_boxes,
        jitter=config.jitter, hue=config.hue, sat=config.sat,
        cont=config.cont, bri=config.bri)
    val_data = Reader(
        'val', config.data_dir, config.anchors_path, config.num_classes,
        input_shape=config.input_shape, max_boxes=config.max_boxes)
    train_batch = list(train_data.provide(config.train_batch_size))
    val_batch = list(val_data.provide(config.val_batch_size))

    model = yolo(config.norm_epsilon, config.norm_decay, config.anchors_path,
                 config.classes_path, config.pre_train)
    is_training = tf.placeholder(dtype=tf.bool, shape=[])
    images = tf.placeholder(shape=[None, 416, 416, 3], dtype=tf.float32)
    bbox_true_13 = tf.placeholder(shape=[None, 13, 13, 3, 85], dtype=tf.float32)
    bbox_true_26 = tf.placeholder(shape=[None, 26, 26, 3, 85], dtype=tf.float32)
    bbox_true_52 = tf.placeholder(shape=[None, 52, 52, 3, 85], dtype=tf.float32)
    bbox_true = [bbox_true_13, bbox_true_26, bbox_true_52]
    # Same order as the tensors returned by Reader.provide().
    feeds = [images, bbox_true_13, bbox_true_26, bbox_true_52]

    # Floor division keeps the per-scale anchor count an integer on Python 3
    # as well ("/" would hand the model a float).
    output = model.yolo_inference(images, config.num_anchors // 3, config.num_classes, is_training)
    loss = model.yolo_loss(output, bbox_true, model.anchors, config.num_classes, config.ignore_thresh)
    tf.summary.scalar('loss', loss)
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(config.learning_rate, global_step, 3000, 0.1)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    # Run the UPDATE_OPS collection together with every training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        if config.pre_train:
            # With pre-trained weights only the variables under the 'yolo'
            # scope are optimized; the darknet53 variables stay frozen.
            train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='yolo')
            train_op = optimizer.minimize(loss=loss, global_step=global_step, var_list=train_var)
        else:
            train_op = optimizer.minimize(loss=loss, global_step=global_step)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
        ckpt = tf.train.get_checkpoint_state(config.model_dir)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print('restore model', ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(init)
        # Same truthiness test as the freezing branch above, so a frozen
        # backbone always has its pre-trained weights loaded.
        if config.pre_train:
            load_ops = model.load_weights(tf.global_variables(scope='darknet53'), config.darknet53_weights_path)
            sess.run(load_ops)

        train_steps = int(config.train_num / config.train_batch_size)
        val_steps = int(config.val_num / config.val_batch_size)
        summary_writer = tf.summary.FileWriter('./logs', sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for epoch in range(config.Epoch):
                for step in range(train_steps):
                    start_time = time.time()
                    batch_values = sess.run(train_batch)
                    feed_dict = dict(zip(feeds, batch_values))
                    feed_dict[is_training] = True
                    train_loss, _ = sess.run([loss, train_op], feed_dict)
                    duration = time.time() - start_time
                    # Throughput is examples divided by elapsed seconds.
                    examples_per_sec = config.train_batch_size / duration if duration > 0 else float('inf')
                    format_str = 'Epoch {} step {},  train loss = {} ( {} examples/sec; {} sec/batch)'
                    print(format_str.format(epoch, step, train_loss, examples_per_sec, duration))
                    # Use a step that keeps growing across epochs so later
                    # epochs do not overwrite earlier points in TensorBoard.
                    summary_writer.add_summary(
                        summary=tf.Summary(value=[tf.Summary.Value(tag="train loss", simple_value=train_loss)]),
                        global_step=epoch * train_steps + step)
                    summary_writer.flush()
                for step in range(val_steps):
                    start_time = time.time()
                    batch_values = sess.run(val_batch)
                    feed_dict = dict(zip(feeds, batch_values))
                    feed_dict[is_training] = False
                    val_loss = sess.run(loss, feed_dict)
                    duration = time.time() - start_time
                    examples_per_sec = config.val_batch_size / duration if duration > 0 else float('inf')
                    format_str = 'Epoch {} step {}, val loss = {} ({} examples/sec; {} sec/batch)'
                    print(format_str.format(epoch, step, val_loss, examples_per_sec, duration))
                    summary_writer.add_summary(
                        summary=tf.Summary(value=[tf.Summary.Value(tag="val loss", simple_value=val_loss)]),
                        global_step=epoch * val_steps + step)
                    summary_writer.flush()
                # Save the model once every 3 epochs.
                if epoch % 3 == 0:
                    checkpoint_path = os.path.join(config.model_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=global_step)
        finally:
            # Always stop the input threads and release the event file, even
            # when training is interrupted by an exception.
            coord.request_stop()
            coord.join(threads)
            summary_writer.close()