def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
    """Build the segmentation network graph inside the given Paddle programs.

    Args:
        main_prog: fluid Program that receives the forward (and, in TRAIN
            phase, backward/optimizer) ops.
        start_prog: fluid startup Program for parameter initialization.
        phase: one of the ModelPhase values (TRAIN / EVAL / VISUAL / PREDICT).

    Returns (phase-dependent):
        PREDICT: (image, logit)
        VISUAL:  (pred, logit) — for 'lanenet', (pred, embedding_logits)
        EVAL:    (data_loader, pred, label, mask, accuracy, fp, fn)
        TRAIN:   (data_loader, avg_loss, decayed_lr, pred, label, mask,
                  disc_loss, seg_loss, accuracy, fp, fn)

    Raises:
        ValueError: if `phase` is not a valid ModelPhase.
        Exception: if cfg.SOLVER.LOSS contains no supported loss name.
    """
    if not ModelPhase.is_valid_phase(phase):
        raise ValueError("ModelPhase {} is not valid!".format(phase))

    # Crop size differs between training and evaluation/inference.
    if ModelPhase.is_train(phase):
        width = cfg.TRAIN_CROP_SIZE[0]
        height = cfg.TRAIN_CROP_SIZE[1]
    else:
        width = cfg.EVAL_CROP_SIZE[0]
        height = cfg.EVAL_CROP_SIZE[1]

    image_shape = [-1, cfg.DATASET.DATA_DIM, height, width]  # NCHW, dynamic batch
    grt_shape = [-1, 1, height, width]
    class_num = cfg.DATASET.NUM_CLASSES

    with fluid.program_guard(main_prog, start_prog):
        with fluid.unique_name.guard():
            image = fluid.data(name='image', shape=image_shape, dtype='float32')
            label = fluid.data(name='label', shape=grt_shape, dtype='int32')
            if cfg.MODEL.MODEL_NAME == 'lanenet':
                # Per-pixel lane-instance ids for the discriminative loss.
                label_instance = fluid.data(
                    name='label_instance', shape=grt_shape, dtype='int32')
            mask = fluid.data(name='mask', shape=grt_shape, dtype='int32')

            # use DataLoader.from_generator when doing training and evaluation
            # NOTE(review): feed_list includes label_instance unconditionally,
            # so non-lanenet TRAIN/EVAL would hit a NameError here — this path
            # appears to be lanenet-only; confirm before reusing elsewhere.
            if ModelPhase.is_train(phase) or ModelPhase.is_eval(phase):
                data_loader = fluid.io.DataLoader.from_generator(
                    feed_list=[image, label, label_instance, mask],
                    capacity=cfg.DATALOADER.BUF_SIZE,
                    iterable=False,
                    use_double_buffer=True)

            loss_type = cfg.SOLVER.LOSS
            if not isinstance(loss_type, list):
                # BUG FIX: the original used list(loss_type), which splits a
                # plain string such as "softmax_loss" into single characters,
                # making the membership test below always fail. Wrap the
                # scalar in a one-element list instead.
                loss_type = [loss_type]

            logits = seg_model(image, class_num)

            if ModelPhase.is_train(phase):
                loss_valid = False
                valid_loss = []
                if cfg.MODEL.MODEL_NAME == 'lanenet':
                    # lanenet returns (seg_logits, embedding_logits).
                    embeding_logit = logits[1]
                    logits = logits[0]
                    disc_loss, _, _, l_reg = discriminative_loss(
                        embeding_logit, label_instance, 4, image_shape[2:],
                        0.5, 3.0, 1.0, 1.0, 0.001)

                if "softmax_loss" in loss_type:
                    weight = None
                    if cfg.MODEL.MODEL_NAME == 'lanenet':
                        weight = get_dynamic_weight(label)
                    seg_loss = multi_softmax_with_loss(
                        logits, label, mask, class_num, weight)
                    loss_valid = True
                    valid_loss.append("softmax_loss")

                if not loss_valid:
                    raise Exception(
                        "SOLVER.LOSS: {} is set wrong. it should "
                        "include one of (softmax_loss, bce_loss, dice_loss) at least"
                        " example: ['softmax_loss']".format(cfg.SOLVER.LOSS))

                invalid_loss = [x for x in loss_type if x not in valid_loss]
                if len(invalid_loss) > 0:
                    print(
                        "Warning: the loss {} you set is invalid. it will not be included in loss computed."
                        .format(invalid_loss))

                # Combined objective: instance discrimination + tiny
                # regularizer + semantic segmentation loss.
                avg_loss = disc_loss + 0.00001 * l_reg + seg_loss

            # get pred result in original size
            if isinstance(logits, tuple):
                logit = logits[0]
            else:
                logit = logits

            if logit.shape[2:] != label.shape[2:]:
                logit = fluid.layers.resize_bilinear(logit, label.shape[2:])

            # return image input and logit output for inference graph prune
            if ModelPhase.is_predict(phase):
                if class_num == 1:
                    logit = sigmoid_to_softmax(logit)
                else:
                    logit = softmax(logit)
                return image, logit

            if class_num == 1:
                out = sigmoid_to_softmax(logit)
                out = fluid.layers.transpose(out, [0, 2, 3, 1])
            else:
                out = fluid.layers.transpose(logit, [0, 2, 3, 1])

            pred = fluid.layers.argmax(out, axis=3)
            pred = fluid.layers.unsqueeze(pred, axes=[3])

            if ModelPhase.is_visual(phase):
                if cfg.MODEL.MODEL_NAME == 'lanenet':
                    # In VISUAL phase logits is still the raw model tuple, so
                    # logits[1] is the instance-embedding branch.
                    return pred, logits[1]
                if class_num == 1:
                    logit = sigmoid_to_softmax(logit)
                else:
                    logit = softmax(logit)
                return pred, logit

            accuracy, fp, fn = compute_metric(pred, label)
            if ModelPhase.is_eval(phase):
                return data_loader, pred, label, mask, accuracy, fp, fn

            if ModelPhase.is_train(phase):
                optimizer = solver.Solver(main_prog, start_prog)
                decayed_lr = optimizer.optimise(avg_loss)
                return (data_loader, avg_loss, decayed_lr, pred, label, mask,
                        disc_loss, seg_loss, accuracy, fp, fn)
def run():
    """Train the lane-instance embedding model (TF1 ENet transfer) end to end.

    Parses CLI arguments, loads the pretrained semantic model, attaches
    transfer layers, then runs the training loop with periodic summary
    writing, scatter-plot evaluation, mean-shift clustering previews, and
    checkpointing; finishes with a validation pass and a final checkpoint.

    Raises:
        IOError: if the source or pretrained-model directory does not exist.
    """
    parser = argparse.ArgumentParser()
    # Directories
    parser.add_argument('-s', '--srcdir', default='data',
                        help="Source directory of TuSimple dataset")
    parser.add_argument('-m', '--modeldir', default='pretrained_semantic_model',
                        help="Output directory of extracted data")
    parser.add_argument('-o', '--outdir', default='saved_model/lane',
                        help="Directory for trained model")
    parser.add_argument('-l', '--logdir', default='log',
                        help="Log directory for tensorboard and evaluation files")
    # Hyperparameters
    parser.add_argument('--epochs', type=int, default=50, help="Number of epochs")
    parser.add_argument('--var', type=float, default=1., help="Weight of variance loss")
    parser.add_argument('--dist', type=float, default=1., help="Weight of distance loss")
    parser.add_argument('--reg', type=float, default=0.001, help="Weight of regularization loss")
    parser.add_argument('--dvar', type=float, default=0.5, help="Cutoff variance")
    parser.add_argument('--ddist', type=float, default=1.5, help="Cutoff distance")
    args = parser.parse_args()

    if not os.path.isdir(args.srcdir):
        raise IOError('Directory does not exist')
    if not os.path.isdir(args.modeldir):
        raise IOError('Directory does not exist')
    if not os.path.isdir(args.logdir):
        os.mkdir(args.logdir)
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    image_shape = (960, 680)  # (width, height)

    # NOTE(review): these absolute paths override the --srcdir argument and
    # are machine-specific — consider restoring `data_dir = args.srcdir`.
    # data_dir = args.srcdir  # os.path.join('.', 'data')
    data_dir = '/media/jintian/sg/permanent/datasets/minieye/minieye_lane/lane_20180308_1700'
    val_data_dir = '/media/jintian/sg/permanent/datasets/minieye/minieye_lane/lane_20180323_300'
    model_dir = args.modeldir
    output_dir = args.outdir
    log_dir = args.logdir

    # image_paths = glob(os.path.join(data_dir, 'images', '*.jpg'))
    # label_paths = glob(os.path.join(data_dir, 'masks', '*.png'))
    image_paths, label_paths = datagenerator.get_lane_f_paths(
        './datasets/lanenet_train.txt')
    # Sorting keeps image/label lists aligned by filename.
    image_paths.sort()
    label_paths.sort()

    image_paths_s = image_paths[0:10]
    print(image_paths_s)
    # label_paths = label_paths[0:10]

    X_train, X_valid, y_train, y_valid = train_test_split(
        image_paths, label_paths, test_size=0.10, random_state=42)
    print(('Number of train samples', len(y_train)))
    print(('Number of valid samples', len(y_valid)))

    # Training / evaluation cadence and optimizer hyperparameters.
    debug_clustering = True
    bandwidth = 0.7
    cluster_cycle = 5000
    eval_cycle = 1000
    save_cycle = 5000
    epochs = args.epochs
    batch_size = 1
    starter_learning_rate = 1e-4
    learning_rate_decay_rate = 0.96
    learning_rate_decay_interval = 5000

    feature_dim = 3
    param_var = args.var
    param_dist = args.dist
    param_reg = args.reg
    delta_v = args.dvar
    delta_d = args.ddist

    # Encode the hyperparameter set into the log sub-directory name.
    param_string = ('fdim' + str(feature_dim) + '_var' + str(param_var)
                    + '_dist' + str(param_dist) + '_reg' + str(param_reg)
                    + '_dv' + str(delta_v) + '_dd' + str(delta_d)
                    + '_lr' + str(starter_learning_rate)
                    + '_btch' + str(batch_size))
    if not os.path.exists(os.path.join(log_dir, param_string)):
        os.makedirs(os.path.join(log_dir, param_string))

    config = tf.ConfigProto()
    # config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    with tf.Session(config=config) as sess:
        # Placeholders are NHWC; image_shape is (W, H), hence the swap.
        input_image = tf.placeholder(
            tf.float32, shape=(None, image_shape[1], image_shape[0], 3))
        correct_label = tf.placeholder(
            dtype=tf.float32, shape=(None, image_shape[1], image_shape[0]))

        # Reuse the pretrained ENet encoder, then bolt on the embedding head.
        last_prelu = utils.load_enet(sess, model_dir, input_image, batch_size)
        prediction = utils.add_transfer_layers_and_initialize(
            sess, last_prelu, feature_dim)
        print(('Number of parameters in the model', utils.count_parameters()))

        global_step = tf.Variable(0, trainable=False)
        sess.run(global_step.initializer)
        learning_rate = tf.train.exponential_decay(
            starter_learning_rate, global_step,
            learning_rate_decay_interval, learning_rate_decay_rate,
            staircase=True)

        trainables = utils.get_trainable_variables_and_initialize(
            sess, debug=False)
        disc_loss, l_var, l_dist, l_reg = discriminative_loss(
            prediction, correct_label, feature_dim, image_shape,
            delta_v, delta_d, param_var, param_dist, param_reg)

        with tf.name_scope('Instance/Adam'):
            train_op = tf.train.AdamOptimizer(
                learning_rate=learning_rate).minimize(
                    disc_loss, var_list=trainables, global_step=global_step)
        # Adam slot variables are created by minimize(); initialize only those.
        adam_initializers = [var.initializer for var in tf.global_variables()
                             if 'Adam' in var.name]
        sess.run(adam_initializers)

        summary_op_train, summary_op_valid = utils.collect_summaries(
            disc_loss, l_var, l_dist, l_reg, input_image, prediction,
            correct_label)
        train_writer = tf.summary.FileWriter(log_dir)

        valid_image_chosen, valid_label_chosen = \
            datagenerator.get_validation_batch_lane(val_data_dir, image_shape)
        print((valid_image_chosen.shape))
        # visualization.save_image_overlay(valid_image_chosen.copy(), valid_label_chosen.copy())

        saver = tf.train.Saver()
        step_train = 0
        step_valid = 0
        for epoch in range(epochs):
            print(('epoch', epoch))
            train_loss = 0
            for image, label in datagenerator.get_batches_fn(
                    batch_size, image_shape, X_train, y_train):
                lr = sess.run(learning_rate)
                if (step_train % eval_cycle != 0):
                    # Plain training step, no summaries.
                    _, step_prediction, step_loss, step_l_var, step_l_dist, step_l_reg = sess.run(
                        [train_op, prediction, disc_loss, l_var, l_dist, l_reg],
                        feed_dict={input_image: image, correct_label: label})
                else:
                    # First run normal training step and record summaries
                    print('Evaluating on chosen images ...')
                    _, summary, step_prediction, step_loss, step_l_var, step_l_dist, step_l_reg = sess.run(
                        [train_op, summary_op_train, prediction, disc_loss,
                         l_var, l_dist, l_reg],
                        feed_dict={input_image: image, correct_label: label})
                    train_writer.add_summary(summary, step_train)

                    # Then run model on some chosen images and save feature
                    # space visualization
                    valid_pred = sess.run(
                        prediction,
                        feed_dict={
                            input_image: np.expand_dims(valid_image_chosen[0], axis=0),
                            correct_label: np.expand_dims(valid_label_chosen[0], axis=0)})
                    visualization.evaluate_scatter_plot(
                        log_dir, valid_pred, valid_label_chosen, feature_dim,
                        param_string, step_train)

                    # Perform mean-shift clustering on prediction
                    if (step_train % cluster_cycle == 0):
                        if debug_clustering:
                            instance_masks = clustering.get_instance_masks(
                                valid_pred, bandwidth)
                            for img_id, mask in enumerate(instance_masks):
                                cv2.imwrite(
                                    os.path.join(
                                        log_dir, param_string,
                                        'cluster_{}_{}.png'.format(
                                            str(step_train).zfill(6),
                                            str(img_id))),
                                    mask)

                step_train += 1
                if (step_train % save_cycle == (save_cycle - 1)):
                    # BUG FIX: the bare `except:` also swallowed
                    # KeyboardInterrupt/SystemExit; narrowed to Exception so a
                    # failed save is logged but the run stays interruptible.
                    try:
                        print('Saving model ...')
                        saver.save(sess, os.path.join(output_dir, 'model.ckpt'),
                                   global_step=step_train)
                    except Exception:
                        print('FAILED saving model')
                # print 'gradient', step_gradient
                print(('step', step_train, '\tloss', step_loss,
                       '\tl_var', step_l_var, '\tl_dist', step_l_dist,
                       '\tl_reg', step_l_reg, '\tcurrent lr', lr))

            # End-of-epoch validation sweep (no gradient updates).
            print('Evaluating current model ...')
            for image, label in datagenerator.get_batches_fn(
                    batch_size, image_shape, X_valid, y_valid):
                if step_valid % 100 == 0:
                    summary, step_prediction, step_loss, step_l_var, step_l_dist, step_l_reg = sess.run(
                        [summary_op_valid, prediction, disc_loss, l_var,
                         l_dist, l_reg],
                        feed_dict={input_image: image, correct_label: label})
                    train_writer.add_summary(summary, step_valid)
                else:
                    step_prediction, step_loss, step_l_var, step_l_dist, step_l_reg = sess.run(
                        [prediction, disc_loss, l_var, l_dist, l_reg],
                        feed_dict={input_image: image, correct_label: label})
                step_valid += 1
                print(('step_valid', step_valid, 'valid loss', step_loss,
                       '\tvalid l_var', step_l_var, '\tvalid l_dist',
                       step_l_dist, '\tvalid l_reg', step_l_reg))

            saver = tf.train.Saver()
            saver.save(sess, os.path.join(output_dir, 'model.ckpt'),
                       global_step=step_train)