# Example #1
# 0
def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
    """Build the (LaneNet-aware) segmentation graph inside ``main_prog``.

    Args:
        main_prog: fluid main Program that ops are added to.
        start_prog: fluid startup Program for parameter initialization.
        phase: a ModelPhase value (TRAIN / EVAL / VISUAL / PREDICT).

    Returns:
        A phase-dependent tuple:
          * PREDICT: (image, logit)
          * VISUAL:  (pred, embedding-or-logit)
          * EVAL:    (data_loader, pred, label, mask, accuracy, fp, fn)
          * TRAIN:   (data_loader, avg_loss, decayed_lr, pred, label, mask,
                      disc_loss, seg_loss, accuracy, fp, fn)

    Raises:
        ValueError: if ``phase`` is not a valid ModelPhase.
        Exception: if cfg.SOLVER.LOSS names no supported loss.
    """
    if not ModelPhase.is_valid_phase(phase):
        raise ValueError("ModelPhase {} is not valid!".format(phase))
    if ModelPhase.is_train(phase):
        width = cfg.TRAIN_CROP_SIZE[0]
        height = cfg.TRAIN_CROP_SIZE[1]
    else:
        width = cfg.EVAL_CROP_SIZE[0]
        height = cfg.EVAL_CROP_SIZE[1]

    # NCHW input; -1 leaves the batch dimension dynamic.
    image_shape = [-1, cfg.DATASET.DATA_DIM, height, width]
    grt_shape = [-1, 1, height, width]
    class_num = cfg.DATASET.NUM_CLASSES

    with fluid.program_guard(main_prog, start_prog):
        with fluid.unique_name.guard():
            image = fluid.data(name='image',
                               shape=image_shape,
                               dtype='float32')
            label = fluid.data(name='label', shape=grt_shape, dtype='int32')
            is_lanenet = cfg.MODEL.MODEL_NAME == 'lanenet'
            if is_lanenet:
                # Per-pixel instance ids used only by the discriminative loss.
                label_instance = fluid.data(name='label_instance',
                                            shape=grt_shape,
                                            dtype='int32')
            mask = fluid.data(name='mask', shape=grt_shape, dtype='int32')

            # use DataLoader.from_generator when doing training and evaluation
            if ModelPhase.is_train(phase) or ModelPhase.is_eval(phase):
                # BUG FIX: label_instance only exists for lanenet; feeding it
                # unconditionally raised NameError for every other model.
                feed_list = ([image, label, label_instance, mask]
                             if is_lanenet else [image, label, mask])
                data_loader = fluid.io.DataLoader.from_generator(
                    feed_list=feed_list,
                    capacity=cfg.DATALOADER.BUF_SIZE,
                    iterable=False,
                    use_double_buffer=True)

            loss_type = cfg.SOLVER.LOSS
            if not isinstance(loss_type, list):
                # BUG FIX: list("softmax_loss") splits the string into single
                # characters, so the membership test below could never match.
                # Wrap the scalar in a one-element list instead.
                loss_type = [loss_type]

            logits = seg_model(image, class_num)

            if ModelPhase.is_train(phase):
                loss_valid = False
                valid_loss = []
                # Defaults keep avg_loss well-defined for non-lanenet models,
                # which have no discriminative/regularization terms.
                disc_loss = 0.0
                l_reg = 0.0
                if is_lanenet:
                    # lanenet returns (segmentation logits, embedding logits).
                    embeding_logit = logits[1]
                    logits = logits[0]
                    disc_loss, _, _, l_reg = discriminative_loss(
                        embeding_logit, label_instance, 4, image_shape[2:],
                        0.5, 3.0, 1.0, 1.0, 0.001)

                if "softmax_loss" in loss_type:
                    weight = None
                    if is_lanenet:
                        weight = get_dynamic_weight(label)
                    seg_loss = multi_softmax_with_loss(logits, label, mask,
                                                       class_num, weight)
                    loss_valid = True
                    valid_loss.append("softmax_loss")

                if not loss_valid:
                    raise Exception(
                        "SOLVER.LOSS: {} is set wrong. it should "
                        "include one of (softmax_loss, bce_loss, dice_loss) at least"
                        " example: ['softmax_loss']".format(cfg.SOLVER.LOSS))

                invalid_loss = [x for x in loss_type if x not in valid_loss]
                if len(invalid_loss) > 0:
                    print(
                        "Warning: the loss {} you set is invalid. it will not be included in loss computed."
                        .format(invalid_loss))

                # Total loss: instance-embedding loss + tiny weight decay on
                # the embedding + the segmentation loss.
                avg_loss = disc_loss + 0.00001 * l_reg + seg_loss

            #get pred result in original size
            if isinstance(logits, tuple):
                logit = logits[0]
            else:
                logit = logits

            if logit.shape[2:] != label.shape[2:]:
                logit = fluid.layers.resize_bilinear(logit, label.shape[2:])

            # return image input and logit output for inference graph prune
            if ModelPhase.is_predict(phase):
                if class_num == 1:
                    logit = sigmoid_to_softmax(logit)
                else:
                    logit = softmax(logit)
                return image, logit

            # NCHW -> NHWC so argmax over the channel axis yields a class map.
            if class_num == 1:
                out = sigmoid_to_softmax(logit)
                out = fluid.layers.transpose(out, [0, 2, 3, 1])
            else:
                out = fluid.layers.transpose(logit, [0, 2, 3, 1])

            pred = fluid.layers.argmax(out, axis=3)
            pred = fluid.layers.unsqueeze(pred, axes=[3])
            if ModelPhase.is_visual(phase):
                if is_lanenet:
                    return pred, logits[1]
                if class_num == 1:
                    logit = sigmoid_to_softmax(logit)
                else:
                    logit = softmax(logit)
                return pred, logit

            accuracy, fp, fn = compute_metric(pred, label)
            if ModelPhase.is_eval(phase):
                return data_loader, pred, label, mask, accuracy, fp, fn

            if ModelPhase.is_train(phase):
                optimizer = solver.Solver(main_prog, start_prog)
                decayed_lr = optimizer.optimise(avg_loss)
                return data_loader, avg_loss, decayed_lr, pred, label, mask, disc_loss, seg_loss, accuracy, fp, fn
# Example #2
# 0
def run():
    """Train the LaneNet instance-embedding branch with discriminative loss.

    Parses CLI arguments, loads a pretrained ENet semantic model, attaches
    transfer layers producing a ``feature_dim``-dimensional embedding per
    pixel, and trains with Adam. Periodically writes summaries, scatter-plot
    visualizations, mean-shift cluster images, and model checkpoints.

    Raises:
        IOError: if the source or pretrained-model directory is missing.
    """
    parser = argparse.ArgumentParser()
    # Directories
    parser.add_argument('-s', '--srcdir', default='data', help="Source directory of TuSimple dataset")
    parser.add_argument('-m', '--modeldir', default='pretrained_semantic_model',
                        help="Output directory of extracted data")
    parser.add_argument('-o', '--outdir', default='saved_model/lane', help="Directory for trained model")
    parser.add_argument('-l', '--logdir', default='log', help="Log directory for tensorboard and evaluation files")
    # Hyperparameters
    parser.add_argument('--epochs', type=int, default=50, help="Number of epochs")
    parser.add_argument('--var', type=float, default=1., help="Weight of variance loss")
    parser.add_argument('--dist', type=float, default=1., help="Weight of distance loss")
    parser.add_argument('--reg', type=float, default=0.001, help="Weight of regularization loss")
    parser.add_argument('--dvar', type=float, default=0.5, help="Cutoff variance")
    parser.add_argument('--ddist', type=float, default=1.5, help="Cutoff distance")

    args = parser.parse_args()

    # BUG FIX: the two missing-directory errors were indistinguishable;
    # include the offending path in each message.
    if not os.path.isdir(args.srcdir):
        raise IOError('Directory does not exist: {}'.format(args.srcdir))
    if not os.path.isdir(args.modeldir):
        raise IOError('Directory does not exist: {}'.format(args.modeldir))
    if not os.path.isdir(args.logdir):
        os.mkdir(args.logdir)
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    image_shape = (960, 680)  # (width, height) fed to the placeholders below
    # NOTE(review): these absolute dataset paths override --srcdir; they look
    # machine-specific and should probably come from args — confirm intent.
    data_dir = '/media/jintian/sg/permanent/datasets/minieye/minieye_lane/lane_20180308_1700'

    val_data_dir = '/media/jintian/sg/permanent/datasets/minieye/minieye_lane/lane_20180323_300'
    model_dir = args.modeldir
    output_dir = args.outdir
    log_dir = args.logdir

    # Training file list comes from a fixed manifest, not data_dir.
    image_paths, label_paths = datagenerator.get_lane_f_paths('./datasets/lanenet_train.txt')

    image_paths.sort()
    label_paths.sort()

    image_paths_s = image_paths[0:10]
    print(image_paths_s)

    X_train, X_valid, y_train, y_valid = train_test_split(image_paths, label_paths, test_size=0.10, random_state=42)

    print(('Number of train samples', len(y_train)))
    print(('Number of valid samples', len(y_valid)))

    # Periodic-action cadence (in training steps).
    debug_clustering = True
    bandwidth = 0.7
    cluster_cycle = 5000
    eval_cycle = 1000
    save_cycle = 5000

    epochs = args.epochs
    batch_size = 1
    starter_learning_rate = 1e-4
    learning_rate_decay_rate = 0.96
    learning_rate_decay_interval = 5000

    # Discriminative-loss hyperparameters (Brabandere et al.).
    feature_dim = 3
    param_var = args.var
    param_dist = args.dist
    param_reg = args.reg
    delta_v = args.dvar
    delta_d = args.ddist

    # Encodes the run's hyperparameters into the log subdirectory name.
    param_string = 'fdim' + str(feature_dim) + '_var' + str(param_var) + '_dist' + str(param_dist) + '_reg' + str(
        param_reg) \
                   + '_dv' + str(delta_v) + '_dd' + str(delta_d) \
                   + '_lr' + str(starter_learning_rate) + '_btch' + str(batch_size)

    if not os.path.exists(os.path.join(log_dir, param_string)):
        os.makedirs(os.path.join(log_dir, param_string))

    config = tf.ConfigProto()

    with tf.Session(config=config) as sess:

        # Placeholders are NHWC; image_shape is (W, H), hence the index swap.
        input_image = tf.placeholder(tf.float32, shape=(None, image_shape[1], image_shape[0], 3))
        correct_label = tf.placeholder(dtype=tf.float32, shape=(None, image_shape[1], image_shape[0]))

        # Frozen pretrained ENet backbone + trainable transfer head.
        last_prelu = utils.load_enet(sess, model_dir, input_image, batch_size)
        prediction = utils.add_transfer_layers_and_initialize(sess, last_prelu, feature_dim)

        print(('Number of parameters in the model', utils.count_parameters()))
        global_step = tf.Variable(0, trainable=False)
        sess.run(global_step.initializer)
        learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
                                                   learning_rate_decay_interval, learning_rate_decay_rate,
                                                   staircase=True)

        trainables = utils.get_trainable_variables_and_initialize(sess, debug=False)

        disc_loss, l_var, l_dist, l_reg = discriminative_loss(prediction, correct_label, feature_dim, image_shape,
                                                              delta_v, delta_d, param_var, param_dist, param_reg)
        with tf.name_scope('Instance/Adam'):
            train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(disc_loss, var_list=trainables,
                                                                                    global_step=global_step)
        # Adam slot variables are created by minimize(); initialize just those.
        adam_initializers = [var.initializer for var in tf.global_variables() if 'Adam' in var.name]
        sess.run(adam_initializers)

        summary_op_train, summary_op_valid = utils.collect_summaries(disc_loss, l_var, l_dist, l_reg, input_image,
                                                                     prediction, correct_label)

        train_writer = tf.summary.FileWriter(log_dir)

        valid_image_chosen, valid_label_chosen = datagenerator.get_validation_batch_lane(val_data_dir, image_shape)
        print((valid_image_chosen.shape))

        saver = tf.train.Saver()
        step_train = 0
        step_valid = 0
        for epoch in range(epochs):
            print(('epoch', epoch))

            train_loss = 0
            for image, label in datagenerator.get_batches_fn(batch_size, image_shape, X_train, y_train):

                lr = sess.run(learning_rate)

                if (step_train % eval_cycle != 0):
                    # Plain training step, no summaries.
                    _, step_prediction, step_loss, step_l_var, step_l_dist, step_l_reg = sess.run([
                        train_op,
                        prediction,
                        disc_loss,
                        l_var,
                        l_dist,
                        l_reg],
                        feed_dict={input_image: image, correct_label: label})
                else:
                    # First run normal training step and record summaries
                    print('Evaluating on chosen images ...')
                    _, summary, step_prediction, step_loss, step_l_var, step_l_dist, step_l_reg = sess.run([
                        train_op,
                        summary_op_train,
                        prediction,
                        disc_loss,
                        l_var,
                        l_dist,
                        l_reg],
                        feed_dict={input_image: image, correct_label: label})
                    train_writer.add_summary(summary, step_train)

                    # Then run model on some chosen images and save feature space visualization
                    valid_pred = sess.run(prediction,
                                          feed_dict={input_image: np.expand_dims(valid_image_chosen[0], axis=0),
                                                     correct_label: np.expand_dims(valid_label_chosen[0], axis=0)})
                    visualization.evaluate_scatter_plot(log_dir, valid_pred, valid_label_chosen, feature_dim,
                                                        param_string, step_train)

                    # Perform mean-shift clustering on prediction
                    if (step_train % cluster_cycle == 0):
                        if debug_clustering:
                            instance_masks = clustering.get_instance_masks(valid_pred, bandwidth)
                            for img_id, mask in enumerate(instance_masks):
                                cv2.imwrite(os.path.join(log_dir, param_string,
                                                         'cluster_{}_{}.png'.format(str(step_train).zfill(6),
                                                                                    str(img_id))), mask)

                step_train += 1

                if (step_train % save_cycle == (save_cycle - 1)):
                    # Best-effort checkpoint: a failed save should not abort
                    # training, but BUG FIX — the bare except also swallowed
                    # KeyboardInterrupt/SystemExit; narrow it and report why.
                    try:
                        print('Saving model ...')
                        saver.save(sess, os.path.join(output_dir, 'model.ckpt'), global_step=step_train)
                    except Exception as e:
                        print('FAILED saving model')
                        print(e)
                print(('step', step_train, '\tloss', step_loss, '\tl_var', step_l_var, '\tl_dist', step_l_dist,
                       '\tl_reg', step_l_reg, '\tcurrent lr', lr))

            print('Evaluating current model ...')
            for image, label in datagenerator.get_batches_fn(batch_size, image_shape, X_valid, y_valid):
                if step_valid % 100 == 0:
                    # Every 100th validation batch also records summaries.
                    summary, step_prediction, step_loss, step_l_var, step_l_dist, step_l_reg = sess.run([
                        summary_op_valid,
                        prediction,
                        disc_loss,
                        l_var,
                        l_dist,
                        l_reg],
                        feed_dict={input_image: image, correct_label: label})
                    train_writer.add_summary(summary, step_valid)
                else:
                    step_prediction, step_loss, step_l_var, step_l_dist, step_l_reg = sess.run([
                        prediction,
                        disc_loss,
                        l_var,
                        l_dist,
                        l_reg],
                        feed_dict={input_image: image, correct_label: label})
                step_valid += 1

                print(('step_valid', step_valid, 'valid loss', step_loss, '\tvalid l_var', step_l_var, '\tvalid l_dist',
                       step_l_dist, '\tvalid l_reg', step_l_reg))

        # Final checkpoint; reuse the existing saver instead of building a
        # redundant second tf.train.Saver().
        saver.save(sess, os.path.join(output_dir, 'model.ckpt'), global_step=step_train)