Example #1
def get_ascad_trace_set(name, data, meta, limit=None):
    """
    Convert ASCAD data (a (traces, labels) tuple) and its per-trace metadata
    to a TraceSet object, optionally limited to the first `limit` traces.
    """
    data_x, data_y = data
    traces = []
    plaintexts = []
    keys = []
    masks = []
    limit = len(data_x) if limit is None else min(len(data_x), limit)

    for i in range(0, limit):
        traces.append(data_x[i])
        plaintexts.append(meta[i]['plaintext'])
        keys.append(meta[i]['key'])
        masks.append(meta[i]['masks'])

    traces = np.array(traces)
    plaintexts = np.array(plaintexts)
    keys = np.array(keys)
    masks = np.array(masks)

    trace_set = TraceSet(name='ascad-%s' % name, traces=traces, plaintexts=plaintexts, ciphertexts=None, keys=keys, masks=masks)
    trace_set.window = Window(begin=0, end=len(trace_set.traces[0].signal))
    trace_set.windowed = True

    return trace_set
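A minimal usage sketch (the HDF5 path is illustrative; Example #15 below shows the real call site): the data argument is one of the (traces, labels) tuples returned by the ASCAD loader, and meta is the matching metadata array.

from ASCAD_train_models import load_ascad

train_set, attack_set, (metadata_train, metadata_attack) = load_ascad(
    'ASCAD.h5', load_metadata=True)
trace_set = get_ascad_trace_set('train', train_set, metadata_train, limit=1000)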
Example #2
    def test_fft_trace_set(self):
        traces = np.array([[0, 1, 2]])

        ts = TraceSet(traces=traces)
        ops.fft_trace_set(ts, None, None, None)

        self.assertListEqual([round(x, 8) for x in list(ts.traces[0].signal)],
                             [3. + 0.j, -1.5 + 0.8660254j, -1.5 - 0.8660254j])
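The expected values are the three-point DFT of [0, 1, 2]; assuming fft_trace_set wraps np.fft.fft, they can be reproduced directly:

import numpy as np

# Sanity check for the expected values above:
print(np.round(np.fft.fft([0, 1, 2]), 8))
# [ 3. +0.j        -1.5+0.8660254j -1.5-0.8660254j]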
Example #3
    def test_spectogram_trace_set(self):
        traces = np.array([[0, 1, 2]])

        ts = TraceSet(traces=traces)
        ops.spectogram_trace_set(ts, None, None, None)

        self.assertListEqual([round(x, 8) for x in list(ts.traces[0].signal)],
                             [9., 3., 3.])
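The expected output is the squared magnitude of the FFT from Example #2, suggesting spectogram_trace_set computes |FFT(x)|^2; a quick NumPy check:

import numpy as np

# |FFT([0, 1, 2])|^2 reproduces the expected values:
print(np.abs(np.fft.fft([0, 1, 2])) ** 2)  # [9. 3. 3.]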
Example #4
    def test_filterkey_trace_set(self):
        traces = np.array([[0], [1], [2]])
        keys = np.array([[0], [1], [2]])

        ts = TraceSet(traces=traces, keys=keys)
        conf = Namespace()
        ops.filterkey_trace_set(ts, None, conf, params=['01'])

        self.assertEqual(len(ts.traces), 1)
        self.assertListEqual(list(ts.traces[0].signal), list(traces[1]))
Example #5
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        if epoch % self.metric_freq != 0 or epoch == 0:
            return
        if self.trace_set is not None:
            # Fetch inputs from trace_set
            x = AIInput(self.conf).get_trace_set_inputs(self.trace_set)

            if self.cnn:
                x = np.expand_dims(x, axis=-1)

            encodings = self.model.predict(x)  # Output: [?, 16]

            # Store encodings as fake traceset
            keys = np.array([trace.key for trace in self.trace_set.traces])
            plaintexts = np.array(
                [trace.plaintext for trace in self.trace_set.traces])
            fake_ts = TraceSet(traces=encodings,
                               plaintexts=plaintexts,
                               keys=keys,
                               name="fake_ts")
            fake_ts.window = Window(begin=0, end=encodings.shape[1])
            fake_ts.windowed = True

            for i in range(self.key_low, self.key_high):
                if len(set(keys[:, i])) > 1:
                    print(
                        "Warning: non-identical key bytes detected. Skipping rank calculation."
                    )
                    print("Subkey %d:" % i)
                    print(keys[:, i])
                    break
                rank, confidence = calculate_traceset_rank(
                    fake_ts, i, keys[0][i], self.conf
                )  # TODO: It is assumed here that all true keys of the test set are the same
                self._save_best_rank_model(rank, confidence)
                logs['rank %d' % i] = rank
                logs['confidence %d' % i] = confidence
            #self._save_best_rank_model(np.mean(ranks))
        else:
            print("Warning: no trace_set supplied to RankCallback")
Example #6
    def test_window_trace_set(self):
        traces = np.array([[1], [1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4]], dtype=object)  # Ragged array; dtype=object is required on modern NumPy
        params = [1, 5]
        expected = np.array([[0, 0, 0, 0], [2, 3, 4, 5], [2, 3, 4, 0]])

        ts = TraceSet(traces=traces)
        conf = Namespace(windowing_method="rectangular")
        ops.window_trace_set(ts, None, conf, params=params)

        for i in range(0, len(traces)):
            self.assertListEqual(list(ts.traces[i].signal), list(expected[i]))
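The expected values follow from a rectangular window over [begin, end) with right zero-padding; a standalone sketch of that behavior (not the actual ops implementation):

import numpy as np

def rectangular_window(signal, begin, end):
    # Keep samples in [begin, end); zero-pad on the right to the window length.
    out = np.zeros(end - begin, dtype=signal.dtype)
    chunk = signal[begin:end]
    out[:len(chunk)] = chunk
    return out

print(rectangular_window(np.array([1, 2, 3, 4]), 1, 5))  # [2 3 4 0]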
Example #7
    def test_normalize_trace_set(self):
        traces = np.array([
            [10, 16, 19],
        ])
        expected = np.array([
            [-5, 1, 4],
        ])

        ts = TraceSet(traces=traces)
        ops.normalize_trace_set(ts, None, None, None)
        for i in range(0, len(traces)):
            self.assertListEqual(list(ts.traces[i].signal), list(expected[i]))
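The expected values show that normalization here is mean subtraction only, with no division by the standard deviation (the mean of [10, 16, 19] is 15):

import numpy as np

print(np.array([10, 16, 19]) - np.mean([10, 16, 19]))  # [-5.  1.  4.]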
Example #8
    def test_align_trace_set(self):
        traces = np.array([[0, 1, 0, 8, 10, 8, 0, 1, 0], [8, 8, 11, 8],
                           [8, 10, 8, 0]], dtype=object)  # Ragged; dtype=object is required on modern NumPy
        expected = np.array([[8, 10, 8, 0, 1, 0], [8, 11, 8], [8, 10, 8, 0]], dtype=object)
        reference_signal = np.array([8, 10, 8])
        conf = Namespace(reference_signal=reference_signal,
                         butter_cutoff=0.1,
                         butter_order=1)

        ts = TraceSet(traces=traces, name='test')
        ops.align_trace_set(ts, None, conf, params=[0, len(reference_signal)])
        for i in range(0, len(ts.traces)):
            self.assertListEqual(list(ts.traces[i].signal), expected[i])
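A rough illustration of reference-based alignment via cross-correlation (the real align_trace_set additionally applies a Butterworth low-pass filter, per the butter_cutoff/butter_order parameters above):

import numpy as np

signal = np.array([0, 1, 0, 8, 10, 8, 0, 1, 0])
reference = np.array([8, 10, 8])
offset = np.argmax(np.correlate(signal, reference))  # best match at index 3
print(signal[offset:])  # [ 8 10  8  0  1  0] -- matches expected[0] above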
Example #9
    def get_all_as_trace_set(self, limit=None):
        if limit is None:
            traces_to_get = self.trace_set_paths
        else:
            traces_to_get = self.trace_set_paths[0:limit]

        result = EMResult(
            task_id=self.request_id)  # Make new collection of results
        ops.process_trace_set_paths(
            result,
            traces_to_get,
            self.conf,
            keep_trace_sets=True,
            request_id=self.request_id)  # Store processed trace path in result

        all_traces = []
        for trace_set in result.trace_sets:
            all_traces.extend(trace_set.traces)

        all_traces_set = TraceSet(name="all_traces")  # Avoid reusing 'result' (the EMResult above)
        all_traces_set.set_traces(all_traces)

        return all_traces_set
Example #10
def simulate_traces_random(args, train=True):
    """
    Untested. Simulate traces with random keys and no artificial noise.
    Useful for seeing the true effect of random keys, but slow.
    :param args: parsed command-line arguments (algorithm, trace counts, output paths)
    :param train: if True, draw a fresh random key per trace; if False, use one fixed test key
    :return: None; trace sets are saved to args.output_directory
    """
    specs = get_algorithm_specs(args.algorithm)
    key = random_bytes(specs.key_len)
    if train is False:
        print("Test set key: " + bytearray(key).hex())

    for i in range(0, args.num_trace_sets):
        traces = []
        print("\rSimulating trace set %d/%d...                          " %
              (i, args.num_trace_sets),
              end='')
        for j in range(0, args.num_traces_per_set):
            if train:
                key = random_bytes(specs.key_len)
            plaintext = random_bytes(specs.plaintext_len)
            key_string = binascii.hexlify(key).decode('utf-8')
            plaintext_string = binascii.hexlify(plaintext).decode('utf-8')

            sim = ProgramSimulation(specs.executable,
                                    (key_string, plaintext_string),
                                    specs.method,
                                    REGS_TO_CHECK,
                                    args=args)
            signal = sim.run()

            t = Trace(signal=signal,
                      plaintext=plaintext,
                      ciphertext=None,
                      key=key,
                      mask=None)
            traces.append(t)

        # Make TraceSet
        ts = TraceSet(name="sim-%s-%d" % (args.algorithm, i))
        ts.set_traces(traces)
        dataset_name = "sim-%s-%s" % (args.algorithm, args.mode)
        ts.save(join(args.output_directory, dataset_name + args.suffix))
Example #11
def simulate_traces_noisy(args):
    specs = get_algorithm_specs(args.algorithm)

    key = random_bytes(specs.key_len)
    for i in range(0, 256):  # Sweep key byte 2 over all 256 values, one trace set per value
        print("\rSimulating noisy trace sets for key %d...      " % i, end='')
        key[2] = i
        plaintext = random_bytes(specs.plaintext_len)
        key_string = binascii.hexlify(key).decode('utf-8')
        plaintext_string = binascii.hexlify(plaintext).decode('utf-8')

        sim = ProgramSimulation(specs.executable,
                                (key_string, plaintext_string),
                                specs.method,
                                REGS_TO_CHECK,
                                args=args)
        signal = sim.run()

        traces = []
        for j in range(0, args.num_traces_per_set):
            mod_signal = signal + np.random.normal(args.mu, args.sigma,
                                                   len(signal))
            t = Trace(signal=mod_signal,
                      plaintext=plaintext,
                      ciphertext=None,
                      key=key,
                      mask=None)
            traces.append(t)

            # Debug
            if args.debug:
                plt.plot(mod_signal)
                plt.show()

        # Make TraceSet
        ts = TraceSet(name="sim-noisy-%s-%d" % (args.algorithm, i))
        ts.set_traces(traces)
        dataset_name = "sim-noisy-%s" % args.algorithm
        ts.save(join(args.output_directory, dataset_name + args.suffix))
Example #12
    def test_autoenctrain(self):
        """
        Artificial example to test AutoEncoder
        """

        # ------------------------------
        # Generate data
        # ------------------------------
        traces = [  # Contains abs(trace). Shape = [trace, point]
            [1, 1, 1, -15],
            [-4, 1, 2, -12],
            [10, 1, 3, 8],
            [8, 1, 1, -14],
            [9, 1, -3, 8],
        ]

        plaintexts = [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]

        keys = [
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]

        # Convert to numpy
        traces = np.array(traces)
        plaintexts = np.array(plaintexts)
        keys = np.array(keys)

        trace_set = TraceSet(name='test',
                             traces=traces,
                             plaintexts=plaintexts,
                             keys=keys)

        # ------------------------------
        # Preprocess data
        # ------------------------------
        conf = Namespace(
            max_cache=0,
            augment_roll=False,
            augment_noise=False,
            normalize=False,
            traces_per_set=4,
            online=False,
            dataset_id='qa',
            cnn=False,
            leakage_model=LeakageModelType.HAMMING_WEIGHT_SBOX,
            input_type=AIInputType.SIGNAL,
            augment_shuffle=True,
            n_hidden_layers=1,
            n_hidden_nodes=256,
            activation='leakyrelu',
            metric_freq=100,
            regularizer=None,
            reglambda=0.001,
            model_suffix=None,
            use_bias=True,
            batch_norm=True,
            hamming=False,
            key_low=2,
            key_high=3,
            loss_type='correlation',
            lr=0.0001,
            epochs=2000,
            batch_size=512,
            norank=False,
        )
        it_dummy = AutoEncoderSignalIterator([],
                                             conf,
                                             batch_size=10000,
                                             request_id=None,
                                             stream_server=None)
        x, y = it_dummy._preprocess_trace_set(trace_set)

        # ------------------------------
        # Train and obtain encodings
        # ------------------------------
        model = ai.AutoEncoder(conf, input_dim=4, name="test")
        print(model.info())

        # Find optimal weights
        print("X, Y")
        print(x)
        print(y)
        print(
            "When feeding x through the model without training, the encodings become:"
        )
        print(model.predict(x))
        print("Training now")
        model.train_set(x, y, epochs=conf.epochs)
        print("Done training")

        # Get the encodings of the input data using the same approach used in ops.py corrtest (iterate over rows)
        result = []
        for i in range(0, x.shape[0]):
            result.append(
                model.predict(np.array([x[i, :]], dtype=float))[0]
            )  # Result contains sum of points such that corr with y[key_index] is maximal for all key indices. Shape = [trace, 16]
        result = np.array(result)

        for i in range(result.shape[0]):
            rounded_result = np.round(result[i])
            print("Original x    : %s" % x[i])
            print("Rounded result: %s" % rounded_result)
            self.assertListEqual(list(rounded_result), list(x[i]))
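In other words, the assertion checks that the autoencoder learns (approximately) the identity on this toy set: after training, round(model.predict(x)) should reproduce each input row x[i].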
Example #13
    def test_corrtrain_correlation_multi(self):
        """
        Artificial example to test AICorrNet and trace processing with
        multiple leakage values and multiple subkeys.
        """
        from leakagemodels import LeakageModel

        # ------------------------------
        # Generate data
        # ------------------------------
        traces = [  # Contains abs(trace). Shape = [trace, point]
            [1, 1, 1, -15],
            [-4, 2, 2, -12],
            [10, 3, 3, 8],
            [8, 1, 1, -14],
            [9, 0, -3, 8],
        ]

        plaintexts = [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 13, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 15, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 8, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]

        keys = [
            [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]

        # Convert to numpy
        traces = np.array(traces)
        plaintexts = np.array(plaintexts)
        keys = np.array(keys)

        trace_set = TraceSet(name='test',
                             traces=traces,
                             plaintexts=plaintexts,
                             keys=keys)

        # ------------------------------
        # Preprocess data
        # ------------------------------
        conf = Namespace(
            max_cache=0,
            augment_roll=False,
            augment_noise=False,
            normalize=False,
            traces_per_set=4,
            online=False,
            dataset_id='qa',
            cnn=False,
            leakage_model=LeakageModelType.AES_MULTI,
            input_type=AIInputType.SIGNAL,
            augment_shuffle=True,
            n_hidden_layers=1,
            n_hidden_nodes=256,
            activation='leakyrelu',
            metric_freq=100,
            regularizer=None,
            reglambda=0.001,
            model_suffix=None,
            use_bias=True,
            batch_norm=True,
            hamming=False,
            key_low=1,
            key_high=3,
            loss_type='correlation',
            lr=0.001,
            epochs=5000,
            batch_size=512,
            norank=False,
        )
        it_dummy = AICorrSignalIterator([],
                                        conf,
                                        batch_size=10000,
                                        request_id=None,
                                        stream_server=None)
        x, y = it_dummy._preprocess_trace_set(trace_set)

        # ------------------------------
        # Train and obtain encodings
        # ------------------------------
        model = ai.AICorrNet(conf, input_dim=4, name="test")
        print(model.info())
        rank_cb = rank.CorrRankCallback(conf,
                                        '/tmp/deleteme/',
                                        save_best=False,
                                        save_path=None)
        rank_cb.set_trace_set(trace_set)

        if model.using_regularization:
            print(
                "Warning: can't do correlation loss test because the regularizer would influence the loss function"
            )
            return

        # Find optimal weights
        print("The x (EM samples) and y (leakage model values) are:")
        print(x)
        print(y)
        print(
            "When feeding x through the model without training, the encodings become:"
        )
        print(model.predict(x))
        print("Training now")
        model.train_set(x,
                        y,
                        save=False,
                        epochs=conf.epochs,
                        extra_callbacks=[rank_cb])
        print("Done training")

        # Get the encodings of the input data using the same approach used in ops.py corrtest (iterate over rows)
        result = []
        for i in range(0, x.shape[0]):
            result.append(
                model.predict(np.array([x[i, :]], dtype=float))[0]
            )  # Result contains sum of points such that corr with y[key_index] is maximal for all key indices. Shape = [trace, 16]
        result = np.array(result)
        print(
            "When feeding x through the model after training, the encodings for key bytes %d to %d become:\n %s"
            % (conf.key_low, conf.key_high, str(result)))

        # ------------------------------
        # Check loss function
        # ------------------------------
        # Evaluate the model to get the loss for the encodings
        predicted_loss = model.model.evaluate(x, y, verbose=0)

        # Manually calculate the loss using numpy to verify that we are learning a correct correlation
        calculated_loss = 0
        num_keys = (conf.key_high - conf.key_low)
        num_outputs = LeakageModel.get_num_outputs(conf) // num_keys
        for i in range(0, num_keys):
            subkey_hws = y[:, i * num_outputs:(i + 1) * num_outputs]
            subkey_encodings = result[:, i * num_outputs:(i + 1) * num_outputs]
            print("Subkey %d HWs   : %s" % (i + conf.key_low, str(subkey_hws)))
            print("Subkey %d encodings: %s" %
                  (i + conf.key_low, str(subkey_encodings)))
            y_key = subkey_hws.reshape([-1, 1])
            y_pred = subkey_encodings.reshape([-1, 1])
            print("Flattened subkey %d HWs   : %s" %
                  (i + conf.key_low, str(y_key)))
            print("Flattened subkey %d encodings: %s" %
                  (i + conf.key_low, str(y_pred)))

            # Calculate correlation (numpy approach)
            corr_key_i = np.corrcoef(y_pred[:, 0], y_key[:, 0],
                                     rowvar=False)[1, 0]
            print("corr_num: %s" % corr_key_i)

            calculated_loss += 1.0 - corr_key_i

        print("These values should be close:")
        print("Predicted loss: %s" % str(predicted_loss))
        print("Calculated loss: %s" % str(calculated_loss))
        self.assertAlmostEqual(predicted_loss, calculated_loss, places=2)
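For reference, the quantity verified here is the correlation loss summed over the trained subkeys, loss = sum_i (1 - rho(encodings_i, leakage_i)) with rho the Pearson correlation; with key_low=1 and key_high=3 the sum runs over subkeys 1 and 2.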
Example #14
    def process_ctrl_packet(self, pkt_type, payload):
        if pkt_type == CtrlPacketType.SIGNAL_START:
            logger.debug("Starting for payload: %s" % binary_to_hex(payload))
            self.parse_ies(payload)
            self.sdr.start()

            # Spinlock until data
            timeout = 3
            current_time = 0.0
            while len(self.stored_data) <= self.wait_num_chunks:
                sleep(0.0001)
                current_time += 0.0001
                if current_time >= timeout:
                    logger.warning("Timeout while waiting for data. Did the SDR crash? Reinstantiating...")
                    del self.sdr
                    self.data_socket.socket.close()
                    self.data_socket = SocketWrapper(socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM), ('127.0.0.1', 3884), self.cb_data)
                    self.data_socket.start()
                    self.sdr = SDR(**self.cap_kwargs)
                    self.process_ctrl_packet(pkt_type, payload)
        elif pkt_type == CtrlPacketType.SIGNAL_END:
            # self.sdr.sdr_source.stop()
            self.sdr.stop()
            self.sdr.wait()

            logger.debug("Stopped after receiving %d chunks" % len(self.stored_data))
            #sleep(0.5)
            #logger.debug("After sleep we have %d chunks" % len(self.stored_data))

            # Successful capture (no errors or timeouts)
            if len(self.stored_data) > 0:  # We have at least one chunk
                # Data to file
                np_data = np.frombuffer(b"".join(self.stored_data), dtype=np.complex64)  # np.fromstring is deprecated for binary data
                self.trace_set.append(np_data)
                self.plaintexts.append(self.stored_plaintext)
                self.keys.append(self.stored_key)

                if len(self.trace_set) >= self.kwargs['traces_per_set']:
                    assert(len(self.trace_set) == len(self.plaintexts))
                    assert(len(self.trace_set) == len(self.keys))

                    np_trace_set = np.array(self.trace_set)
                    np_plaintexts = np.array(self.plaintexts, dtype=np.uint8)
                    np_keys = np.array(self.keys, dtype=np.uint8)

                    if self.online is not None:  # Stream online
                        ts = TraceSet(name="online %d" % self.online_counter, traces=np_trace_set, plaintexts=np_plaintexts, ciphertexts=None, keys=np_keys)
                        logger.info("Pickling")
                        ts_p = pickle.dumps(ts)
                        logger.info("Size is %d" % len(ts_p))
                        stream_payload = ts_p
                        stream_payload_len = len(stream_payload)
                        logger.info("Streaming trace set of %d bytes to server" % stream_payload_len)
                        stream_hdr = struct.pack(">BI", 0, stream_payload_len)
                        self.emma_client.send(stream_hdr + stream_payload)
                        self.online_counter += 1
                    else: # Save to disk
                        if not self.kwargs['dry']:
                            # Write metadata to sigmf file
                            # if sigmf
                            #with open(test_meta_path, 'w') as f:
                            #    test_sigmf = SigMFFile(data_file=test_data_path, global_info=copy.deepcopy(self.global_meta))
                            #    test_sigmf.add_capture(0, metadata=capture_meta)
                            #    test_sigmf.dump(f, pretty=True)
                            # elif chipwhisperer:
                            logger.info("Dumping %d traces to file" % len(self.trace_set))
                            filename = str(datetime.utcnow()).replace(" ", "_").replace(".", "_")
                            output_dir = self.kwargs['output_dir']
                            np.save(os.path.join(output_dir, "%s_traces.npy" % filename), np_trace_set)  # TODO abstract this in trace_set class
                            np.save(os.path.join(output_dir, "%s_textin.npy" % filename), np_plaintexts)
                            np.save(os.path.join(output_dir, "%s_knownkey.npy" % filename), np_keys)
                            if self.compress:
                                logger.info("Calling emcap-compress...")
                                subprocess.call(['/usr/bin/python', 'emcap-compress.py', os.path.join(output_dir, "%s_traces.npy" % filename)])

                        self.limit_counter += len(self.trace_set)
                        if self.limit_counter >= self.limit:
                            print("Done")
                            exit(0)

                    # Clear results
                    self.trace_set = []
                    self.plaintexts = []
                    self.keys = []

                # Clear
                self.stored_data = []
                self.stored_plaintext = []
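The stream header packed with struct.pack(">BI", ...) above is one type byte plus a 4-byte big-endian payload length; a receiver-side counterpart might look like this (hypothetical helper, not part of this codebase):

import struct

def parse_stream_header(data):
    # 1-byte packet type, 4-byte big-endian payload length, then the payload.
    pkt_type, payload_len = struct.unpack(">BI", data[:5])
    return pkt_type, payload_len, data[5:5 + payload_len]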
Example #15
def _get_trace_set(trace_set_path, format, ignore_malformed=True):
    """
    Load traces in from absolute path trace_set_path into a TraceSet object depending on the format.
    """

    if format == "cw":
        name = trace_set_path.rpartition('_traces')[0]
        plaintext_set_path = name + '_textin.npy'
        ciphertext_set_path = name + '_textout.npy'
        key_set_path = name + '_knownkey.npy'

        existing_properties = []
        try:
            traces = np.load(trace_set_path, encoding="bytes")
            existing_properties.append(traces)
        except FileNotFoundError:
            traces = None

        try:
            plaintexts = np.load(plaintext_set_path, encoding="bytes")
            existing_properties.append(plaintexts)
        except FileNotFoundError:
            print("WARNING: No plaintext for trace %s" % name)
            plaintexts = None

        try:
            ciphertexts = np.load(ciphertext_set_path, encoding="bytes")
            existing_properties.append(ciphertexts)
        except FileNotFoundError:
            ciphertexts = None

        try:
            keys = np.load(key_set_path, encoding="bytes")
            existing_properties.append(keys)
        except FileNotFoundError:
            keys = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]] * traces.shape[0])
            print("No key file found! Using 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15")
            #keys = None

        masks = None  # No masks for Arduino experiments

        if ignore_malformed:  # Discard malformed traces
            for prop in existing_properties:  # 'prop' avoids shadowing the built-in 'property'
                if traces.shape[0] != prop.shape[0]:
                    return None

            return TraceSet(name=name, traces=traces, plaintexts=plaintexts, ciphertexts=ciphertexts, keys=keys, masks=masks)
        else:  # Just truncate malformed traces instead of discarding
            if traces is not None:
                traces = traces[0:len(plaintexts)]
            if ciphertexts is not None:
                ciphertexts = ciphertexts[0:len(plaintexts)]
            if keys is not None:
                keys = keys[0:len(plaintexts)]
            if masks is not None:
                masks = masks[0:len(plaintexts)]

            return TraceSet(name=name, traces=traces, plaintexts=plaintexts, ciphertexts=ciphertexts, keys=keys, masks=masks)
    elif format == "sigmf":  # .meta
        raise NotImplementedError
    elif format == "gnuradio":  # .cfile
        raise NotImplementedError
    elif format == "ascad":
        from ASCAD_train_models import load_ascad
        h5_path = trace_set_path.rpartition('-')[0]
        train_set, attack_set, metadata_set = load_ascad(h5_path, load_metadata=True)
        metadata_train, metadata_attack = metadata_set

        if trace_set_path.endswith('-train'):
            return get_ascad_trace_set('train', train_set, metadata_train)
        elif trace_set_path.endswith('-val'):
            return get_ascad_trace_set('validation', attack_set, metadata_attack)
    else:
        print("Unknown trace input format '%s'" % format)
        exit(1)

    return None
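A minimal invocation sketch (the path is illustrative; for the "cw" format it must contain '_traces' so the companion _textin/_textout/_knownkey files can be derived):

ts = _get_trace_set('/data/captures/run1_traces.npy', 'cw')
if ts is not None:
    print("Loaded trace set '%s' with %d traces" % (ts.name, len(ts.traces)))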