def predict_model(fcn_version, dataset_name, dataset_path, saved_variables):
    """
    Build a FCN model, restore previously saved weights and predict the masks of a dataset.

    :param fcn_version: The type of FCN, one of `FCN32`, `FCN16`, or `FCN8`.
    :param dataset_name: The name of the dataset to predict on, one of `kitty_road`, `cam_vid`,
        `pascal_voc_2012`, `pascal_plus`.
    :param dataset_path: The path to the dataset root directory. The TFRecords file to predict is
        looked up below this directory.
    :param saved_variables: Filename of the saved FCN weights, joined onto `FLAGS.save_dir`.
    :return: None
    :raises ValueError: If `fcn_version` or `dataset_name` is not one of the supported values.
    """
    supported_versions = ('FCN32', 'FCN16', 'FCN8')
    if fcn_version not in supported_versions:
        raise ValueError('{} is an invalid model'.format(fcn_version))
    print('{} evaluation on {}'.format(fcn_version, dataset_name))

    # The returned session is not used directly here; the binding is kept as in the original.
    session = clear_session()

    # Dataset name -> (dataset class, TFRecords file relative to the dataset root).
    # Note that `kitty_road` predicts on its test records, the other datasets on validation records.
    val_records = 'TFRecords/segmentation_val.tfrecords'
    known_datasets = {
        'kitty_road': (KittyRoadDataset, 'training/TFRecords/segmentation_test.tfrecords'),
        'cam_vid': (CamVidDataset, val_records),
        'pascal_voc_2012': (PascalVOC2012Dataset, val_records),
        'pascal_plus': (PascalPlusDataset, val_records),
    }
    if dataset_name not in known_datasets:
        raise ValueError('{} is an invalid dataset'.format(dataset_name))
    dataset_cls, records_relpath = known_datasets[dataset_name]
    dataset = dataset_cls(FLAGS.augmentation_params)
    dataset_filepath = os.path.join(dataset_path, records_relpath)

    # Build the model
    model = fcn_model.Model(dataset.image_shape, dataset.n_classes, FLAGS.vgg16_weights_path)
    model(fcn_version)

    # TODO: remove unnecessary graph operations
    # decay_rate=0: per the original note, no exponential decay is applied and the learning
    # rate is constant.
    learning_rate_fn = learning_rate_with_exp_decay(
        FLAGS.batch_size,
        dataset.n_images['train'],
        10,
        decay_rate=0,
        base_lr=FLAGS.learning_rate)
    compile_model(model, FLAGS.metrics, learning_rate_fn, weight_decay=FLAGS.weight_decay)

    # Restore the saved weights, then predict the masks.
    weights_path = os.path.join(FLAGS.save_dir, saved_variables)
    model.load_variables(weights_path)
    dataset.predict_dataset(FLAGS.save_dir, dataset_filepath, model, FLAGS.batch_size)
def evaluate_model(fcn_version, dataset_name, dataset_path, metrics, saved_variables):
    """
    Build a FCN model, restore saved weights and evaluate it on the validation dataset.

    The loss is averaged over every evaluated batch. Batches smaller than `FLAGS.batch_size`
    are skipped. The reported metrics are taken from the result of the last evaluated batch
    (presumably streaming metrics that accumulate across batches; confirm against
    `compile_model`). The summary is printed to stdout.

    :param fcn_version: The type of FCN, one of `FCN32`, `FCN16`, or `FCN8`.
    :param dataset_name: The name of the dataset to evaluate on, one of `kitty_road`, `cam_vid`,
        `pascal_voc_2012`, `pascal_plus`.
    :param dataset_path: The path to the dataset root directory.
    :param metrics: The list metrics to calculate during model evaluation. Only the following
        metrics are supported: `acc` for pixel accuracy, `mean_acc` for mean class accuracy, and
        `mean_iou` for mean intersection of union.
    :param saved_variables: Filename of the saved FCN weights, joined onto `FLAGS.save_dir`.
    :return: None
    :raises ValueError: If `fcn_version` or `dataset_name` is not supported, or if the validation
        dataset does not yield a single batch of `FLAGS.batch_size` samples.
    """
    if fcn_version not in ('FCN32', 'FCN16', 'FCN8'):
        raise ValueError('{} is an invalid model'.format(fcn_version))
    print('{} evaluation on {}'.format(fcn_version, dataset_name))
    sess = clear_session()

    dataset_classes = {
        'kitty_road': KittyRoadDataset,
        'cam_vid': CamVidDataset,
        'pascal_voc_2012': PascalVOC2012Dataset,
        'pascal_plus': PascalPlusDataset,
    }
    if dataset_name not in dataset_classes:
        raise ValueError('{} is an invalid dataset'.format(dataset_name))
    dataset = dataset_classes[dataset_name](FLAGS.augmentation_params)
    dataset_val = dataset.load_dataset(is_training=False,
                                       data_dir=dataset_path,
                                       batch_size=FLAGS.batch_size)

    # Build the model
    model = fcn_model.Model(dataset.image_shape, dataset.n_classes, FLAGS.vgg16_weights_path)
    model(fcn_version)

    # TODO: remove unnecessary graph operations
    # decay_rate=0: per the original note, no exponential decay is applied and the learning
    # rate is constant.
    learning_rate_fn = learning_rate_with_exp_decay(
        FLAGS.batch_size,
        dataset.n_images['train'],
        10,
        decay_rate=0,
        base_lr=FLAGS.learning_rate)
    compile_model(model, metrics, learning_rate_fn, weight_decay=FLAGS.weight_decay)
    model.load_variables(os.path.join(FLAGS.save_dir, saved_variables))

    # Initialize an iterator over the validation dataset
    iterator = tf.data.Iterator.from_structure(dataset_val.output_types,
                                               dataset_val.output_shapes)
    next_batch = iterator.get_next()
    val_init_op = iterator.make_initializer(dataset_val)
    sess.run(val_init_op)

    # The fetches do not change between batches, so build them once. A metric named "loss"
    # would take precedence over the model loss, as in the original merge order.
    fetches = {"loss": model.loss, **model.metrics}

    # Evaluate the model
    print('Now evaluating validation set...')
    validation_loss = 0
    n_batches = 0
    res = None
    while True:
        # Only the batch fetch signals the end of the dataset; errors raised while evaluating
        # the model must not be mistaken for it.
        try:
            im_batch, gt_batch = sess.run(next_batch)
        except tf.errors.OutOfRangeError:
            break
        # Skip incomplete batches (typically the last one).
        if len(im_batch) < FLAGS.batch_size:
            continue
        res = sess.run(fetches,
                       feed_dict={
                           model.inputs: im_batch,
                           model.labels: gt_batch,
                           model.keep_prob: 1.0
                       })
        validation_loss += res["loss"]
        n_batches += 1

    if n_batches == 0:
        # Without this guard the averaging below fails with an opaque ZeroDivisionError.
        raise ValueError(
            'The validation set of {} did not yield any batch of {} samples'.format(
                dataset_name, FLAGS.batch_size))

    # Report the validation metrics
    validation_loss /= n_batches
    parts = ['val_loss = {:.3f}'.format(validation_loss)]
    # Each metric result is indexed as a pair: element 0 is the scalar value, element 1 the
    # per-class data (presumably the `(value, update_op)` results of `tf.metrics`; confirm
    # against `compile_model`).
    for metric in model.metrics:
        if metric == 'mean_acc':
            # Remove the void/ignore class accuracy in the mean calculation because its value is 0
            val = np.mean(res[metric][1][:model.n_classes])
        elif metric == 'mean_iou':
            # Remove the void/ignore class IoU in the mean calculation because its value is NaN
            mat = res[metric][1][:model.n_classes, :model.n_classes]
            diagonal = np.diag(mat)
            # Per-class IoU = diagonal / (column sum + row sum - diagonal)
            val = np.mean(diagonal / (mat.sum(axis=0) + mat.sum(axis=1) - diagonal))
        else:
            # No need to adjust other metrics
            val = res[metric][0]
        parts.append('val_{} = {:.3f}'.format(metric, val))
    print(', '.join(parts))
def oneoff_training(fcn_version,
                    dataset_name,
                    dataset_path,
                    metrics,
                    model_name,
                    saved_variables=None):
    """
    Build a FCN model and train it end-to-end in a single (one-off) run.

    Passing `saved_variables` builds the model from previously saved FCN weights, e.g. to train
    a `FCN16` model on top of a trained `FCN32` one.

    :param fcn_version: The type of FCN, one of `FCN32`, `FCN16`, or `FCN8`.
    :param dataset_name: The name of the dataset to use for training, one of `kitty_road`,
        `cam_vid`, `pascal_voc_2012`, `pascal_plus`.
    :param dataset_path: The path to the dataset root directory.
    :param metrics: The list of training and validation metrics to evaluate after each epoch.
        Only the following metrics are supported: `acc` for pixel accuracy, `mean_acc` for mean
        class accuracy, and `mean_iou` for mean intersection of union.
    :param model_name: The name of your model, for example `fcn8_trial`. It is used to save
        weights, and the training curve CSV and plots to disk.
    :param saved_variables: Optional filename with pre-trained `FCN32` or `FCN16` weights to
        load, joined onto `FLAGS.save_dir`. Do not use this parameter to indicate the path to
        VGG16 pre-trained weights.
    :return: The model, and a tuple with the training and validation datasets.
    :raises ValueError: If `fcn_version` or `dataset_name` is not one of the supported values.
    """
    supported_versions = ('FCN32', 'FCN16', 'FCN8')
    if fcn_version not in supported_versions:
        raise ValueError('{} is an invalid model'.format(fcn_version))
    print('One-off {} end-to-end training using {}'.format(fcn_version, dataset_name))
    sess = clear_session()

    dataset_classes = {
        'kitty_road': KittyRoadDataset,
        'cam_vid': CamVidDataset,
        'pascal_voc_2012': PascalVOC2012Dataset,
        'pascal_plus': PascalPlusDataset,
    }
    if dataset_name not in dataset_classes:
        raise ValueError('{} is an invalid dataset'.format(dataset_name))
    dataset = dataset_classes[dataset_name](FLAGS.augmentation_params)

    # Training split first, then the validation split.
    dataset_train = dataset.load_dataset(is_training=True,
                                         data_dir=dataset_path,
                                         batch_size=FLAGS.batch_size)
    dataset_val = dataset.load_dataset(is_training=False,
                                       data_dir=dataset_path,
                                       batch_size=FLAGS.batch_size)

    # Build the model; the optional weights filename lives below the save directory.
    if saved_variables is not None:
        saved_variables = os.path.join(FLAGS.save_dir, saved_variables)
    model = fcn_model.Model(dataset.image_shape, dataset.n_classes, FLAGS.vgg16_weights_path)
    model(fcn_version, saved_variables=saved_variables)

    # decay_rate=0: per the original note, no decay is applied and the learning rate is constant.
    learning_rate_fn = learning_rate_with_exp_decay(
        FLAGS.batch_size,
        dataset.n_images['train'],
        10,
        decay_rate=0,
        base_lr=FLAGS.learning_rate)
    compile_model(model, metrics, learning_rate_fn, weight_decay=FLAGS.weight_decay)

    # Train the model
    fit_model(model, FLAGS.n_epochs, FLAGS.batch_size, dataset_train, dataset_val, model_name)
    total_steps = tf.train.global_step(sess, tf.train.get_global_step())
    print("Total steps = {}".format(total_steps))

    datasets = (dataset_train, dataset_val)
    return model, datasets