コード例 #1
0
def process_trace_set_paths(result,
                            trace_set_paths,
                            conf,
                            request_id=None,
                            keep_trace_sets=False):
    """
    Load and process each trace set in ``trace_set_paths`` in order.

    Every path is loaded with ``emio.get_trace_set`` using ``conf.format`` and
    then handed to ``process_trace_set`` together with ``result``, ``conf``,
    ``request_id`` and ``keep_trace_sets``. Trace sets that fail to load
    (``get_trace_set`` returns None) are logged as a warning and skipped.

    :param result: accumulator object forwarded to ``process_trace_set``.
    :param trace_set_paths: sized sequence of trace set paths to process.
    :param conf: configuration; ``conf.format`` selects the loader format.
    :param request_id: optional identifier forwarded to ``process_trace_set``.
    :param keep_trace_sets: flag forwarded to ``process_trace_set``.
    """
    num_todo = len(trace_set_paths)
    for index, trace_set_path in enumerate(trace_set_paths):
        # The progress index counts every attempted trace set, so skipped
        # (unloadable) sets still advance the "(n/N)" display.
        logger.info("Processing '%s' (%d/%d)", trace_set_path, index, num_todo)

        # Load trace
        trace_set = emio.get_trace_set(trace_set_path,
                                       conf.format,
                                       ignore_malformed=False,
                                       remote=False)
        if trace_set is None:
            logger.warning(
                "Failed to load trace set %s (got None). Skipping...",
                trace_set_path)
            continue

        # Process trace set
        process_trace_set(result, trace_set, conf, request_id, keep_trace_sets)
コード例 #2
0
def compress_trace_set(trace_set_path):
    """
    Compress a ``.npy`` trace set using the settings in its dataset manifest.

    The unpickled ``manifest.emcap`` is read from the directory that contains
    ``trace_set_path``. Its ``conf`` entry is reused as follows:

    - any ``optimize_capture`` action is removed;
    - a ``pca[<manifest path>]`` action is appended if the manifest has a
      ``pca`` entry, otherwise ``corrtest[autoenc]`` if it has ``autoenc``.

    The trace set is then run through ``ops.process_trace_set``. The first
    resulting trace set is saved by calling its ``save`` method with the
    absolute path of that directory.

    :param trace_set_path: path to a trace set file ending in ``.npy``.
    :raises EMMAException: if the path does not end in ``.npy``, if there is
        no ``manifest.emcap`` next to it, or if the trace set cannot be loaded.
    """
    if not trace_set_path.endswith('.npy'):
        raise EMMAException("Not a valid traceset_path in numpy format")

    parent_dataset_path = os.path.dirname(trace_set_path)
    manifest_path = os.path.join(parent_dataset_path, 'manifest.emcap')
    if not os.path.exists(manifest_path):
        raise EMMAException(
            "No manifest.emcap in %s, so don't know how to compress." %
            parent_dataset_path)

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only run this on manifests produced by a trusted capture setup.
    with open(manifest_path, 'rb') as manifest_file:
        manifest = pickle.load(manifest_file)
    conf = manifest['conf']

    # Load trace set. get_trace_set may return None on failure; report that
    # explicitly instead of crashing later inside process_trace_set.
    trace_set = emio.get_trace_set(trace_set_path, 'cw', remote=False)
    if trace_set is None:
        raise EMMAException("Failed to load trace set %s" % trace_set_path)

    # Make sure there is no optimize_capture action anymore
    conf_delete_action(conf, 'optimize_capture')

    # Add appropriate actions
    if 'pca' in manifest:
        conf.actions.append(Action('pca[%s]' % manifest_path))
    elif 'autoenc' in manifest:
        conf.actions.append(Action('corrtest[autoenc]'))

    # Perform compression
    result = EMResult()
    ops.process_trace_set(result, trace_set, conf, keep_trace_sets=True)
    processed_trace_set = result.trace_sets[0]

    # Save compressed trace set
    processed_trace_set.save(os.path.abspath(parent_dataset_path), dry=False)
コード例 #3
0
ファイル: dataset.py プロジェクト: zhihuishuwp/emma
    def _setup(self, emma_conf):
        """
        Determine the trace set paths of this dataset and load its reference
        signal.

        The dataset root is read from option ``datasets_path`` in section
        ``[Datasets]`` of ``settings.conf``, resolved against the current
        working directory. It is stored in ``self.root``.

        How ``self.trace_set_paths`` is filled depends on ``self.format``:

        - ``"cw"``: every regular file in ``<root>/<id>`` whose name contains
          ``_traces.npy``, as lexicographically sorted paths relative to the
          root, e.g.::

              em-arduino/<name>_traces.npy

        - ``"ascad"``: absolute paths to ``<id>.h5-val`` and
          ``<id>.h5-train`` under ``<root>/ASCAD/ASCAD_data/ASCAD_databases``.
          Only the validation set is kept when ``emma_conf`` contains an
          attack or classify op, so the training set is never attacked.
        - ``"sigmf"`` / ``"gnuradio"``: not implemented.

        The first trace set is loaded as the reference:

        - ``self.traces_per_set`` is the number of traces in it;
        - ``self.reference_signal`` is the signal of its trace at
          ``self.reference_index``.

        Relative paths must later be resolved to absolute paths on the
        workers.

        :param emma_conf: EMMA configuration or None; only inspected for the
            ASCAD format.
        :raises NotImplementedError: for the ``sigmf`` and ``gnuradio`` formats.
        :raises Exception: for an unknown format, when no trace set paths are
            found, or when the reference trace set cannot be loaded.
        """
        settings = configparser.RawConfigParser()
        settings.read('settings.conf')
        self.root = settings.get("Datasets", "datasets_path")

        # Assign trace set paths
        if self.format == "cw":  # .npy
            path = join(self.root, self.id)
            self.trace_set_paths = sorted(
                join(self.id, f) for f in listdir(path)
                if isfile(join(path, f)) and '_traces.npy' in f)
        elif self.format in ("sigmf", "gnuradio"):  # .meta / .cfile
            self.trace_set_paths = None
            raise NotImplementedError
        elif self.format == "ascad":  # ASCAD .h5
            # Hack to force split between validation and training set in ASCAD
            validation_set = join(
                self.root,
                'ASCAD/ASCAD_data/ASCAD_databases/%s.h5-val' % self.id)
            training_set = join(
                self.root,
                'ASCAD/ASCAD_data/ASCAD_databases/%s.h5-train' % self.id)

            # Make sure we never use training set when attacking or classifying
            evaluation_ops = ('attack', 'classify', 'dattack', 'spattack',
                              'pattack')
            if emma_conf is not None and any(
                    conf_has_op(emma_conf, op) for op in evaluation_ops):
                self.trace_set_paths = [validation_set]
            else:
                self.trace_set_paths = [validation_set, training_set]
        else:
            raise Exception("Unknown input format '%s'" % self.format)

        # Explicit check rather than `assert`, which is stripped under -O.
        if not self.trace_set_paths:
            raise Exception("No trace sets found for dataset '%s' in '%s'" %
                            (self.id, self.root))

        # Assign reference signal
        reference_path = join(self.root, self.trace_set_paths[0])
        reference_trace_set = emio.get_trace_set(reference_path,
                                                 self.format,
                                                 ignore_malformed=False,
                                                 remote=False)
        # get_trace_set may return None when loading fails; fail clearly
        # instead of with an AttributeError on `.traces`.
        if reference_trace_set is None:
            raise Exception("Failed to load reference trace set '%s'" %
                            reference_path)

        self.traces_per_set = len(reference_trace_set.traces)
        self.reference_signal = reference_trace_set.traces[
            self.reference_index].signal
コード例 #4
0
parser.add_argument("--limit",
                    type=int,
                    default=0,
                    help="Limit number of trace sets (0=infinite)")
args = parser.parse_args()

dataset_name = args.dataset_name
dataset_name_pca = dataset_name + "-pca"

# Gather signals: collect the signal of every trace in the dataset's "cw"
# trace sets, stopping after --limit trace sets when the limit is non-zero.
dataset = get_dataset(dataset_name, remote=False)
all_signals = []
for count, trace_set_path in enumerate(dataset.trace_set_paths):
    # end="" keeps the cursor on the same line so the leading "\r" overwrites
    # the previous progress message instead of printing one line per set.
    print("\rGathering signal %d           " % count, end="", flush=True)
    trace_set_path = os.path.join(dataset.root, trace_set_path)
    trace_set = get_trace_set(trace_set_path, "cw", remote=False)
    if trace_set is None:
        # Skip unloadable trace sets rather than crashing on `.traces`.
        print("\nFailed to load trace set %s (got None). Skipping..." %
              trace_set_path)
    else:
        all_signals.extend(trace.signal for trace in trace_set.traces)
    if count + 1 == args.limit:
        break
print()  # Terminate the in-place progress line.
all_signals = np.array(all_signals)

# Fail with a clear message instead of an opaque error from PCA.fit.
if all_signals.size == 0:
    raise SystemExit("No signals gathered from dataset '%s'; nothing to fit." %
                     dataset_name)

# Do PCA
print("Performing PCA")
pca = PCA()
pca.fit(all_signals)

# Save PCA model
pca_model_path = "pca-components-%s.p" % dataset_name
with open(pca_model_path, 'wb') as f:
    pickle.dump(pca, f)
print("Dumped PCA model to %s" % pca_model_path)
コード例 #5
0
                        nargs='+')
    args, unknown = parser.parse_known_args()

    table = PrettyTable(['Dataset', 'Num items', 'Mean', 'Std'])
    for dataset_name in args.dataset:
        dataset = emio.get_dataset(dataset_name)
        print("Dataset: %s\nFormat: %s" % (dataset.id, dataset.format))
        mean_sum = 0.0
        std_sum = 0.0
        n = 0

        # Calculate mean
        for trace_set_path in dataset.trace_set_paths:
            trace_set = emio.get_trace_set(join(dataset.prefix,
                                                trace_set_path),
                                           dataset.format,
                                           ignore_malformed=False,
                                           remote=False)
            for trace in trace_set.traces:
                mean_sum += np.sum(trace.signal)
                n += len(trace.signal)

        mean = mean_sum / n

        # Calculate std dev
        for trace_set_path in dataset.trace_set_paths:
            trace_set = emio.get_trace_set(join(dataset.prefix,
                                                trace_set_path),
                                           dataset.format,
                                           ignore_malformed=False,
                                           remote=False)
コード例 #6
0
def remote_get_trace_set(trace_set_path, format, ignore_malformed):
    """
    Load a trace set via ``emio.get_trace_set`` with ``remote=False`` forced.

    The three arguments are forwarded positionally and unchanged. The result
    of ``get_trace_set`` is returned as-is, including None if that is what
    it returns. Presumably this is the entry point executed on a worker so
    that the load happens locally — confirm against callers.
    """
    forwarded = (trace_set_path, format, ignore_malformed)
    loaded = emio.get_trace_set(*forwarded, remote=False)
    return loaded