4: 'text' } class_labels = [v for v in range((args.number_of_classes + 1))] class_labels[-1] = 255 LOG_FOLDER = './tboard_logs' TEST_DATASET_DIR = "./dataset/tfrecords" TEST_FILE = 'test.tfrecords' test_filenames = [os.path.join(TEST_DATASET_DIR, TEST_FILE)] test_dataset = tf.data.TFRecordDataset(test_filenames) test_dataset = test_dataset.map( tf_record_parser) # Parse the record into tensors. test_dataset = test_dataset.map( lambda image, annotation, image_shape: scale_image_with_crop_padding( image, annotation, image_shape, args.crop_size)) test_dataset = test_dataset.shuffle(buffer_size=100) test_dataset = test_dataset.batch(args.batch_size) iterator = test_dataset.make_one_shot_iterator() batch_images_tf, batch_labels_tf, batch_shapes_tf = iterator.get_next() logits_tf = network.deeplab_v3(batch_images_tf, args, is_training=False, reuse=False) valid_labels_batch_tf, valid_logits_batch_tf = training.get_valid_logits_and_labels( annotation_batch_tensor=batch_labels_tf, logits_batch_tensor=logits_tf, class_labels=class_labels)
resnet_checkpoints_path = './resnet/checkpoints/'
download_resnet_checkpoint_if_necessary(resnet_checkpoints_path, args.resnet_model)


def _crop_pad_example(image, annotation, image_shape):
    """Crop/pad one parsed example to ``crop_size``.

    Shared by the training and validation pipelines so both apply the same
    crop/pad step. ``crop_size`` is the module-level name used by the original
    inline lambdas.
    """
    return scale_image_with_crop_padding(image, annotation, image_shape, crop_size)


# =============================================================================
# Defining training and validation datasets
# =============================================================================
training_filenames = [os.path.join(TRAIN_DATASET_DIR, TRAIN_FILE)]
training_dataset = tf.data.TFRecordDataset(training_filenames)
training_dataset = training_dataset.map(tf_record_parser)  # Parse the record into tensors.
# Per-example transforms, applied in this order.
training_dataset = training_dataset.map(rescale_image_and_annotation_by_factor)
training_dataset = training_dataset.map(distort_randomly_image_color)
training_dataset = training_dataset.map(_crop_pad_example)
training_dataset = training_dataset.map(random_flip_image_and_annotation)
training_dataset = training_dataset.repeat()  # No count argument: repeat indefinitely.
training_dataset = training_dataset.shuffle(buffer_size=500)
training_dataset = training_dataset.batch(args.batch_size)

# Validation data lives alongside the training records and only gets the
# crop/pad step (none of the training-only transforms); it is not repeated.
validation_filenames = [os.path.join(TRAIN_DATASET_DIR, VALIDATION_FILE)]
validation_dataset = tf.data.TFRecordDataset(validation_filenames)
validation_dataset = validation_dataset.map(tf_record_parser)  # Parse the record into tensors.
validation_dataset = validation_dataset.map(_crop_pad_example)
validation_dataset = validation_dataset.shuffle(buffer_size=100)
validation_dataset = validation_dataset.batch(args.batch_size)

# Label ids 0 .. number_of_classes - 1, with the final slot overwritten by 255
# (presumably the "ignore"/void label -- NOTE(review): confirm).
class_labels = list(range(args.number_of_classes + 1))
class_labels[-1] = 255