Example #1
    def _build_loaders(self):
        train_path = self.get("data/paths/train", ensure_exists=True)
        train_transforms = get_transforms(self.get("data/transforms/train", {}))
        validate_path = self.get("data/paths/validate", ensure_exists=True)
        validate_transforms = get_transforms(self.get("data/transforms/validate", {}))
        self.train_loader = get_dataloader(
            path=train_path,
            transforms=train_transforms,
            **self.get("data/loader_kwargs", ensure_exists=True),
        )
        self.validate_loader = get_dataloader(
            path=validate_path,
            transforms=validate_transforms,
            **self.get("data/loader_kwargs", ensure_exists=True),
        )
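Taken together, the call sites on this page suggest a signature for get_dataloader along the following lines. This is a sketch inferred from the examples, not the actual definition in ctt.data_loading.loader; the parameter names, defaults, and the trailing **dataset_kwargs are assumptions.

# Hypothetical signature, reconstructed from the call sites in these examples.
# The real get_dataloader in ctt.data_loading.loader may differ.
def get_dataloader(path, batch_size=1, shuffle=True, num_workers=0,
                   transforms=None, pre_transforms=None, rng=None,
                   bit_encoded_age=False, **dataset_kwargs):
    """Build a DataLoader over one dataset path or a list of paths."""
    ...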
Example #2
    def test_loader(self):
        from ctt.data_loading.loader import get_dataloader

        path = self.DATASET_PATH
        batch_size = 5
        dataloader = get_dataloader(
            batch_size=batch_size, shuffle=False, num_workers=0, path=path
        )
        batch = next(iter(dataloader))
        self.assertEqual(len(batch), self.NUM_KEYS_IN_BATCH)
        # Check that every key in the batch has batch_size entries
        for key in batch:
            self.assertEqual(len(batch[key]), batch_size)
        # get_dataloader should also accept a list of dataset paths
        dataloader = get_dataloader(
            batch_size=batch_size, shuffle=False, num_workers=0, path=[path, path]
        )
        batch = next(iter(dataloader))
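When path is a list, one plausible implementation is to build one dataset per path and concatenate them. The sketch below only illustrates that pattern; the stand-in dataset class is hypothetical, not the ctt implementation.

from torch.utils.data import ConcatDataset, DataLoader, Dataset

class _PathDataset(Dataset):
    # Stand-in for the real per-path dataset; the contents are placeholders.
    def __init__(self, path):
        self.samples = list(range(10))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def make_loader(path, **loader_kwargs):
    # Accept a single path or a list of paths and concatenate the datasets.
    paths = path if isinstance(path, (list, tuple)) else [path]
    dataset = ConcatDataset([_PathDataset(p) for p in paths])
    return DataLoader(dataset, **loader_kwargs)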
Example #3
    def _build_train_loader(self):
        train_path = self.get("data/paths/train", ensure_exists=True)
        train_transforms = get_transforms(self.get("data/transforms/train", {}))
        train_pretransforms = get_pre_transforms(self.get("data/pre_transforms", {}))
        self.train_loader = get_dataloader(
            path=train_path,
            transforms=train_transforms,
            pre_transforms=train_pretransforms,
            rng=np.random.RandomState(self.epoch),
            **self.get("data/loader_kwargs", ensure_exists=True),
        )
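Note the rng=np.random.RandomState(self.epoch) argument: seeding the loader's RNG with the epoch number makes per-epoch shuffling and augmentation reproducible across runs while still varying between epochs. A minimal, self-contained illustration of the idea:

import numpy as np

def epoch_rng(epoch):
    # One RandomState per epoch: identical across runs for the same epoch,
    # but producing a different stream for each epoch.
    return np.random.RandomState(epoch)

assert epoch_rng(3).randint(1 << 30) == epoch_rng(3).randint(1 << 30)
assert epoch_rng(3).randint(1 << 30) != epoch_rng(4).randint(1 << 30)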
Example #4
    def _build_validate_loader(self):
        validate_path = self.get("data/paths/validate", ensure_exists=True)
        validate_transforms = get_transforms(self.get("data/transforms/validate", {}))
        validate_pretransforms = get_pre_transforms(self.get("data/pre_transforms", {}))
        # Prep loader kwargs (override things if required)
        loader_kwargs = deepcopy(self.get("data/loader_kwargs", ensure_exists=True))
        loader_kwargs.update(self.get("data/validation_loader_kwargs", {}))
        self.validate_loader = get_dataloader(
            path=validate_path,
            transforms=validate_transforms,
            pre_transforms=validate_pretransforms,
            rng=np.random.RandomState(self.epoch),
            **loader_kwargs,
        )
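The deepcopy-then-update idiom above lets the validation loader reuse the training loader_kwargs while overriding only selected entries, without mutating the shared configuration. A self-contained illustration (the values are made up):

from copy import deepcopy

base_kwargs = {"batch_size": 64, "shuffle": True, "num_workers": 4}
overrides = {"shuffle": False}  # e.g. validation usually isn't shuffled

loader_kwargs = deepcopy(base_kwargs)  # leave the shared config untouched
loader_kwargs.update(overrides)

assert base_kwargs["shuffle"] is True
assert loader_kwargs["shuffle"] is False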
Example #5
def convert_pytorch_model_fixed_messages(pytorch_model, nb_messages,
                                         working_directory, dataset_path):
    """Convert a PyTorch model to ONNX, TensorFlow, and TFLite for a fixed
    number of messages, sanity-check the converted models, and return the
    maximum absolute PyTorch-to-TFLite output difference."""

    # Make sure we are converting the inference graph
    pytorch_model.eval()

    # Setup working directory
    if not os.path.exists(working_directory):
        os.makedirs(working_directory)

    # Load dataset (used for sanity checking the converted models)
    dataloader = get_dataloader(batch_size=1,
                                shuffle=False,
                                num_workers=0,
                                path=dataset_path,
                                bit_encoded_age=True)
    batch = next(iter(dataloader))

    # Get a padded batch to use for the conversion to TF and TFLite
    batch = pad_messages_minibatch(batch, nb_messages)

    # Get the list of input names, in the same order as the keys in the batch
    input_names = list(batch)
    output_names = ['encounter_variables', 'latent_variable']

    # Convert PyTorch model to ONNX format
    onnx_model_path = os.path.join(working_directory, "model_onnx_10.onnx")
    torch.onnx.export(pytorch_model,
                      batch,
                      onnx_model_path,
                      export_params=True,
                      opset_version=10,
                      do_constant_folding=True,
                      input_names=input_names,
                      output_names=output_names)

    # Load ONNX model and convert to TF model
    onnx_model = onnx.load(onnx_model_path)
    tf_model = prepare(onnx_model)

    # Convert the tf graph to a TF Saved Model
    tf_model_path = os.path.join(working_directory, "tf_model")
    if os.path.isdir(tf_model_path):
        print('Already saved a TF model, cleaning up')
        shutil.rmtree(tf_model_path, ignore_errors=True)

    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(tf_model_path)
    with tf.compat.v1.Session(graph=tf_model.graph) as sess:

        input_spec = {}
        output_spec = {}
        for name in tf_model.inputs:
            input_spec[name] = tf_model.tensor_dict[name]
        for name in output_names:
            output_spec[name] = tf_model.tensor_dict[name]

        sigs = {
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            tf.compat.v1.saved_model.signature_def_utils.predict_signature_def(
                input_spec, output_spec)
        }

        builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING],
                                             signature_def_map=sigs)
        builder.save()

    # Convert Saved Model to TFLite model
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
    converter.allow_custom_ops = True
    converter.experimental_new_converter = True
    converter.enable_mlir_converter = True
    # converter.optimizations = [tf.lite.Optimize.DEFAULT]  # 8-bit weight quantization
    tflite_model = converter.convert()
    tflite_model_path = os.path.join(working_directory,
                                     "model_%i_messages.tflite" % nb_messages)
    with open(tflite_model_path, "wb") as f:
        f.write(tflite_model)

    # Sanity-check the TensorFlow and TFLite models on examples that fit
    # within the fixed number of messages the converted model can handle.
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    pytorch_tf_deltas = []
    tf_tflite_deltas = []
    pytorch_tflite_deltas = []
    nb_validation_samples = 0
    for batch in iter(dataloader):

        # If this example is too big for the model, skip it.
        batch_nb_messages = batch["mask"].shape[1]
        if batch_nb_messages > nb_messages:
            continue

        # Generate padded inputs for the TF and TFLite models
        padded_batch = pad_messages_minibatch(batch, nb_messages)

        # Get pytorch model output
        pytorch_output = pytorch_model(batch)
        pytorch_padded_output = pytorch_model(padded_batch)

        # Get TF model output
        tf_padded_output = tf_model.run(padded_batch)
        tf_output = {
            "encounter_variables":
                tf_padded_output.encounter_variables[:, :batch_nb_messages],
            "latent_variable": tf_padded_output.latent_variable,
        }

        # Send inputs to the TFLite model
        interpreter.allocate_tensors()
        for inp_detail in interpreter.get_input_details():
            inp_name = inp_detail["name"]
            interpreter.set_tensor(inp_detail["index"], padded_batch[inp_name])

        # Get TFLite model outputs
        tflite_output = {}
        interpreter.invoke()
        for out_name, out_detail in zip(output_names,
                                        interpreter.get_output_details()):
            if out_name == "encounter_variables":
                # Remove the message padding
                tflite_output[out_name] = interpreter.get_tensor(
                    out_detail["index"])[:, :batch_nb_messages]
            else:
                tflite_output[out_name] = interpreter.get_tensor(
                    out_detail["index"])

        # Compare the three models pairwise on this example
        for k in pytorch_output.keys():
            pytorch_out_np = pytorch_output[k].detach().numpy()
            pytorch_tf_deltas.append(pytorch_out_np - tf_output[k])
            tf_tflite_deltas.append(tf_output[k] - tflite_output[k])
            pytorch_tflite_deltas.append(pytorch_out_np - tflite_output[k])

        # Limit the testing to avoid spending too much time on it.
        nb_validation_samples += 1
        if nb_validation_samples >= NB_EXAMPLES_FOR_SANITY_CHECK:
            break

    # Log the results of the conversion sanity check
    if nb_validation_samples == 0:
        raise ValueError(
            "Model generated for %i messages could not be tested. All validation "
            "samples have too many messages to be used for this model." %
            nb_messages)

    log_filename = os.path.join(working_directory,
                                "model_%i_messages.txt" % nb_messages)
    with open(log_filename, "w") as f:

        abs_pytorch_tf_deltas = numpy.abs(numpy.hstack(pytorch_tf_deltas))
        abs_tf_tflite_deltas = numpy.abs(numpy.hstack(tf_tflite_deltas))
        abs_pytorch_tflite_deltas = numpy.abs(
            numpy.hstack(pytorch_tflite_deltas))

        f.write("Models compared on %i validation samples\n\n" %
                nb_validation_samples)

        f.write("Conversion from pytorch model to TF model\n")
        f.write("  Min abs diff between outputs : %f \n" %
                abs_pytorch_tf_deltas.min())
        f.write("  Mean abs diff between outputs : %f \n" %
                abs_pytorch_tf_deltas.mean())
        f.write("  Max abs diff between outputs : %f \n\n" %
                abs_pytorch_tf_deltas.max())

        f.write("Conversion from TF model to TFLite model\n")
        f.write("  Min abs diff between outputs : %f \n" %
                abs_tf_tflite_deltas.min())
        f.write("  Mean abs diff between outputs : %f \n" %
                abs_tf_tflite_deltas.mean())
        f.write("  Max abs diff between outputs : %f \n\n" %
                abs_tf_tflite_deltas.max())

        f.write("Overall conversion from pytorch model to TFLite model\n")
        f.write("  Min abs diff between outputs : %f \n" %
                abs_pytorch_tflite_deltas.min())
        f.write("  Mean abs diff between outputs : %f \n" %
                abs_pytorch_tflite_deltas.mean())
        f.write("  Max abs diff between outputs : %f \n\n" %
                abs_pytorch_tflite_deltas.max())

    return abs_pytorch_tflite_deltas.max()
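
A possible driver for the function above, sketched under assumptions: the message counts and paths are placeholders, and model is an already-loaded, eval-ready PyTorch model rather than anything from the original project.

def convert_for_message_counts(model, counts=(10, 50, 100)):
    # Build one TFLite model per fixed message count and collect the
    # worst-case PyTorch-to-TFLite deviation reported for each.
    max_deltas = {}
    for nb_messages in counts:
        max_deltas[nb_messages] = convert_pytorch_model_fixed_messages(
            pytorch_model=model,
            nb_messages=nb_messages,
            working_directory="converted_models/%i_messages" % nb_messages,
            dataset_path="data/validation",
        )
    return max_deltas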