Add urh
This commit is contained in:
65
Software/Universal Radio Hacker/tests/awre/AWRETestCase.py
Normal file
65
Software/Universal Radio Hacker/tests/awre/AWRETestCase.py
Normal file
@ -0,0 +1,65 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy
|
||||
from urh.awre.FormatFinder import FormatFinder
|
||||
|
||||
from tests.utils_testing import get_path_for_data_file
|
||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
|
||||
|
||||
from urh.signalprocessing.MessageType import MessageType
|
||||
|
||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
|
||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
|
||||
|
||||
class AWRETestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
numpy.set_printoptions(linewidth=80)
|
||||
self.field_types = self.__init_field_types()
|
||||
|
||||
def get_format_finder_from_protocol_file(self, filename: str, clear_participant_addresses=True, return_messages=False):
|
||||
proto_file = get_path_for_data_file(filename)
|
||||
protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
|
||||
protocol.from_xml_file(filename=proto_file, read_bits=True)
|
||||
|
||||
self.clear_message_types(protocol.messages)
|
||||
|
||||
ff = FormatFinder(protocol.messages)
|
||||
if clear_participant_addresses:
|
||||
ff.known_participant_addresses.clear()
|
||||
|
||||
if return_messages:
|
||||
return ff, protocol.messages
|
||||
else:
|
||||
return ff
|
||||
|
||||
@staticmethod
|
||||
def __init_field_types():
|
||||
result = []
|
||||
for field_type_function in FieldType.Function:
|
||||
result.append(FieldType(field_type_function.value, field_type_function))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def clear_message_types(messages: list):
|
||||
mt = MessageType("empty")
|
||||
for msg in messages:
|
||||
msg.message_type = mt
|
||||
|
||||
@staticmethod
|
||||
def save_protocol(name, protocol_generator, silent=False):
|
||||
filename = os.path.join(tempfile.gettempdir(), name + ".proto")
|
||||
if isinstance(protocol_generator, ProtocolGenerator):
|
||||
protocol_generator.to_file(filename)
|
||||
elif isinstance(protocol_generator, ProtocolAnalyzer):
|
||||
participants = list(set(msg.participant for msg in protocol_generator.messages))
|
||||
protocol_generator.to_xml_file(filename, [], participants=participants, write_bits=True)
|
||||
info = "Protocol written to " + filename
|
||||
if not silent:
|
||||
print()
|
||||
print("-" * len(info))
|
||||
print(info)
|
||||
print("-" * len(info))
|
791
Software/Universal Radio Hacker/tests/awre/AWRExperiments.py
Normal file
791
Software/Universal Radio Hacker/tests/awre/AWRExperiments.py
Normal file
@ -0,0 +1,791 @@
|
||||
import array
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from tests.awre.AWRETestCase import AWRETestCase
|
||||
from tests.utils_testing import get_path_for_data_file
|
||||
from urh.awre.FormatFinder import FormatFinder
|
||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
|
||||
from urh.awre.Preprocessor import Preprocessor
|
||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
||||
from urh.awre.engines.Engine import Engine
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
from urh.signalprocessing.Message import Message
|
||||
from urh.signalprocessing.MessageType import MessageType
|
||||
from urh.signalprocessing.Participant import Participant
|
||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
|
||||
from urh.util.GenericCRC import GenericCRC
|
||||
|
||||
|
||||
def run_for_num_broken(protocol_nr, num_broken: list, num_messages: int, num_runs: int) -> list:
|
||||
random.seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
result = []
|
||||
for broken in num_broken:
|
||||
tmp_accuracies = np.empty(num_runs, dtype=np.float64)
|
||||
tmp_accuracies_without_broken = np.empty(num_runs, dtype=np.float64)
|
||||
for i in range(num_runs):
|
||||
protocol, expected_labels = AWRExperiments.get_protocol(protocol_nr,
|
||||
num_messages=num_messages,
|
||||
num_broken_messages=broken,
|
||||
silent=True)
|
||||
|
||||
AWRExperiments.run_format_finder_for_protocol(protocol)
|
||||
accuracy = AWRExperiments.calculate_accuracy(protocol.messages, expected_labels)
|
||||
accuracy_without_broken = AWRExperiments.calculate_accuracy(protocol.messages, expected_labels, broken)
|
||||
tmp_accuracies[i] = accuracy
|
||||
tmp_accuracies_without_broken[i] = accuracy_without_broken
|
||||
|
||||
avg_accuracy = np.mean(tmp_accuracies)
|
||||
avg_accuracy_without_broken = np.mean(tmp_accuracies_without_broken)
|
||||
|
||||
result.append((avg_accuracy, avg_accuracy_without_broken))
|
||||
print("Protocol {} with {} broken: {:>3}% {:>3}%".format(protocol_nr, broken, int(avg_accuracy),
|
||||
int(avg_accuracy_without_broken)))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class AWRExperiments(AWRETestCase):
|
||||
@staticmethod
|
||||
def _prepare_protocol_1() -> ProtocolGenerator:
|
||||
alice = Participant("Alice", address_hex="dead")
|
||||
bob = Participant("Bob", address_hex="beef")
|
||||
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x1337"},
|
||||
participants=[alice, bob])
|
||||
return pg
|
||||
|
||||
@staticmethod
|
||||
def _prepare_protocol_2() -> ProtocolGenerator:
|
||||
alice = Participant("Alice", address_hex="dead01")
|
||||
bob = Participant("Bob", address_hex="beef24")
|
||||
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 72)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 24)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 24)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x1337"},
|
||||
preambles_by_mt={mb.message_type: "10" * 36},
|
||||
sequence_number_increment=32,
|
||||
participants=[alice, bob])
|
||||
|
||||
return pg
|
||||
|
||||
@staticmethod
|
||||
def _prepare_protocol_3() -> ProtocolGenerator:
|
||||
alice = Participant("Alice", address_hex="1337")
|
||||
bob = Participant("Bob", address_hex="beef")
|
||||
|
||||
checksum = GenericCRC.from_standard_checksum("CRC8 CCITT")
|
||||
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 16)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
|
||||
mb.add_label(FieldType.Function.DATA, 10 * 8)
|
||||
mb.add_checksum_label(8, checksum)
|
||||
|
||||
mb_ack = MessageTypeBuilder("ack")
|
||||
mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
|
||||
mb_ack.add_label(FieldType.Function.SYNC, 16)
|
||||
mb_ack.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb_ack.add_checksum_label(8, checksum)
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
|
||||
preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
|
||||
participants=[alice, bob])
|
||||
|
||||
return pg
|
||||
|
||||
@staticmethod
|
||||
def _prepare_protocol_4() -> ProtocolGenerator:
|
||||
alice = Participant("Alice", address_hex="1337")
|
||||
bob = Participant("Bob", address_hex="beef")
|
||||
|
||||
checksum = GenericCRC.from_standard_checksum("CRC16 CCITT")
|
||||
|
||||
mb = MessageTypeBuilder("data1")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 16)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.DATA, 8 * 8)
|
||||
mb.add_checksum_label(16, checksum)
|
||||
|
||||
mb2 = MessageTypeBuilder("data2")
|
||||
mb2.add_label(FieldType.Function.PREAMBLE, 16)
|
||||
mb2.add_label(FieldType.Function.SYNC, 16)
|
||||
mb2.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb2.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb2.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb2.add_label(FieldType.Function.DATA, 64 * 8)
|
||||
mb2.add_checksum_label(16, checksum)
|
||||
|
||||
mb_ack = MessageTypeBuilder("ack")
|
||||
mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
|
||||
mb_ack.add_label(FieldType.Function.SYNC, 16)
|
||||
mb_ack.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb_ack.add_checksum_label(16, checksum)
|
||||
|
||||
mt1, mt2, mt3 = mb.message_type, mb2.message_type, mb_ack.message_type
|
||||
|
||||
preamble = "10001000" * 2
|
||||
|
||||
pg = ProtocolGenerator([mt1, mt2, mt3],
|
||||
syncs_by_mt={mt1: "0x9a7d", mt2: "0x9a7d", mt3: "0x9a7d"},
|
||||
preambles_by_mt={mt1: preamble, mt2: preamble, mt3: preamble},
|
||||
participants=[alice, bob])
|
||||
|
||||
return pg
|
||||
|
||||
@staticmethod
|
||||
def _prepare_protocol_5() -> ProtocolGenerator:
|
||||
alice = Participant("Alice", address_hex="1337")
|
||||
bob = Participant("Bob", address_hex="beef")
|
||||
carl = Participant("Carl", address_hex="cafe")
|
||||
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 16)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
|
||||
|
||||
mb_ack = MessageTypeBuilder("ack")
|
||||
mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
|
||||
mb_ack.add_label(FieldType.Function.SYNC, 16)
|
||||
mb_ack.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
|
||||
preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
|
||||
participants=[alice, bob, carl])
|
||||
|
||||
return pg
|
||||
|
||||
@staticmethod
|
||||
def _prepare_protocol_6() -> ProtocolGenerator:
|
||||
alice = Participant("Alice", address_hex="24")
|
||||
broadcast = Participant("Bob", address_hex="ff")
|
||||
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 8)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x8e88"},
|
||||
preambles_by_mt={mb.message_type: "10" * 8},
|
||||
participants=[alice, broadcast])
|
||||
|
||||
return pg
|
||||
|
||||
@staticmethod
|
||||
def _prepare_protocol_7() -> ProtocolGenerator:
|
||||
alice = Participant("Alice", address_hex="313370")
|
||||
bob = Participant("Bob", address_hex="031337")
|
||||
charly = Participant("Charly", address_hex="110000")
|
||||
daniel = Participant("Daniel", address_hex="001100")
|
||||
# broadcast = Participant("Broadcast", address_hex="ff") #TODO: Sometimes messages to broadcast
|
||||
|
||||
checksum = GenericCRC.from_standard_checksum("CRC16 CC1101")
|
||||
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 16)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 24)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 24)
|
||||
mb.add_label(FieldType.Function.DATA, 8 * 8)
|
||||
mb.add_checksum_label(16, checksum)
|
||||
|
||||
mb_ack = MessageTypeBuilder("ack")
|
||||
mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb_ack.add_label(FieldType.Function.SYNC, 16)
|
||||
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 24)
|
||||
mb_ack.add_checksum_label(16, checksum)
|
||||
|
||||
mb_kex = MessageTypeBuilder("kex")
|
||||
mb_kex.add_label(FieldType.Function.PREAMBLE, 24)
|
||||
mb_kex.add_label(FieldType.Function.SYNC, 16)
|
||||
mb_kex.add_label(FieldType.Function.DST_ADDRESS, 24)
|
||||
mb_kex.add_label(FieldType.Function.SRC_ADDRESS, 24)
|
||||
mb_kex.add_label(FieldType.Function.DATA, 64 * 8)
|
||||
mb_kex.add_checksum_label(16, checksum)
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type, mb_kex.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x0420", mb_ack.message_type: "0x2222",
|
||||
mb_kex.message_type: "0x6767"},
|
||||
preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 4,
|
||||
mb_kex.message_type: "10" * 12},
|
||||
participants=[alice, bob, charly, daniel])
|
||||
|
||||
return pg
|
||||
|
||||
@staticmethod
|
||||
def _prepare_protocol_8() -> ProtocolGenerator:
|
||||
alice = Participant("Alice")
|
||||
|
||||
mb = MessageTypeBuilder("data1")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 4)
|
||||
mb.add_label(FieldType.Function.SYNC, 4)
|
||||
mb.add_label(FieldType.Function.LENGTH, 16)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
|
||||
mb.add_label(FieldType.Function.DATA, 8 * 542)
|
||||
|
||||
mb2 = MessageTypeBuilder("data2")
|
||||
mb2.add_label(FieldType.Function.PREAMBLE, 4)
|
||||
mb2.add_label(FieldType.Function.SYNC, 4)
|
||||
mb2.add_label(FieldType.Function.LENGTH, 16)
|
||||
mb2.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
|
||||
mb2.add_label(FieldType.Function.DATA, 8 * 260)
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type, mb2.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x9", mb2.message_type: "0x9"},
|
||||
preambles_by_mt={mb.message_type: "10" * 2, mb2.message_type: "10" * 2},
|
||||
sequence_number_increment=32,
|
||||
participants=[alice],
|
||||
little_endian=True)
|
||||
|
||||
return pg
|
||||
|
||||
def test_export_to_latex(self):
|
||||
filename = os.path.expanduser("~/GIT/publications/awre/USENIX/protocols.tex")
|
||||
if os.path.isfile(filename):
|
||||
os.remove(filename)
|
||||
|
||||
for i in range(1, 9):
|
||||
pg = getattr(self, "_prepare_protocol_" + str(i))()
|
||||
pg.export_to_latex(filename, i)
|
||||
|
||||
@classmethod
|
||||
def get_protocol(cls, protocol_number: int, num_messages, num_broken_messages=0, silent=False):
|
||||
if protocol_number == 1:
|
||||
pg = cls._prepare_protocol_1()
|
||||
elif protocol_number == 2:
|
||||
pg = cls._prepare_protocol_2()
|
||||
elif protocol_number == 3:
|
||||
pg = cls._prepare_protocol_3()
|
||||
elif protocol_number == 4:
|
||||
pg = cls._prepare_protocol_4()
|
||||
elif protocol_number == 5:
|
||||
pg = cls._prepare_protocol_5()
|
||||
elif protocol_number == 6:
|
||||
pg = cls._prepare_protocol_6()
|
||||
elif protocol_number == 7:
|
||||
pg = cls._prepare_protocol_7()
|
||||
elif protocol_number == 8:
|
||||
pg = cls._prepare_protocol_8()
|
||||
else:
|
||||
raise ValueError("Unknown protocol number")
|
||||
|
||||
messages_types_with_data_field = [mt for mt in pg.protocol.message_types
|
||||
if mt.get_first_label_with_type(FieldType.Function.DATA)]
|
||||
i = -1
|
||||
while len(pg.protocol.messages) < num_messages:
|
||||
i += 1
|
||||
source = pg.participants[i % len(pg.participants)]
|
||||
destination = pg.participants[(i + 1) % len(pg.participants)]
|
||||
if i % 2 == 0:
|
||||
data_bytes = 8
|
||||
else:
|
||||
# data_bytes = 16
|
||||
data_bytes = 64
|
||||
|
||||
if len(messages_types_with_data_field) == 0:
|
||||
# set data automatically
|
||||
data = "".join(random.choice(["0", "1"]) for _ in range(data_bytes * 8))
|
||||
pg.generate_message(data=data, source=source, destination=destination)
|
||||
else:
|
||||
# search for message type with right data length
|
||||
mt = messages_types_with_data_field[i % len(messages_types_with_data_field)]
|
||||
data_length = mt.get_first_label_with_type(FieldType.Function.DATA).length
|
||||
data = "".join(random.choice(["0", "1"]) for _ in range(data_length))
|
||||
pg.generate_message(message_type=mt, data=data, source=source, destination=destination)
|
||||
|
||||
ack_message_type = next((mt for mt in pg.protocol.message_types if "ack" in mt.name), None)
|
||||
if ack_message_type:
|
||||
pg.generate_message(message_type=ack_message_type, data="", source=destination, destination=source)
|
||||
|
||||
for i in range(num_broken_messages):
|
||||
msg = pg.protocol.messages[i]
|
||||
pos = random.randint(0, len(msg.plain_bits) // 2)
|
||||
msg.plain_bits[pos:] = array.array("B",
|
||||
[random.randint(0, 1) for _ in range(len(msg.plain_bits) - pos)])
|
||||
|
||||
if num_broken_messages == 0:
|
||||
cls.save_protocol("protocol{}_{}_messages".format(protocol_number, num_messages), pg, silent=silent)
|
||||
else:
|
||||
cls.save_protocol("protocol{}_{}_broken".format(protocol_number, num_broken_messages), pg, silent=silent)
|
||||
|
||||
expected_message_types = [msg.message_type for msg in pg.protocol.messages]
|
||||
|
||||
# Delete message type information -> no prior knowledge
|
||||
cls.clear_message_types(pg.protocol.messages)
|
||||
|
||||
# Delete data labels if present
|
||||
for mt in expected_message_types:
|
||||
data_lbl = mt.get_first_label_with_type(FieldType.Function.DATA)
|
||||
if data_lbl:
|
||||
mt.remove(data_lbl)
|
||||
|
||||
return pg.protocol, expected_message_types
|
||||
|
||||
@staticmethod
|
||||
def calculate_accuracy(messages, expected_labels, num_broken_messages=0):
|
||||
"""
|
||||
Calculate the accuracy of labels compared to expected labels
|
||||
Accuracy is 100% when labels == expected labels
|
||||
Accuracy drops by 1 / len(expected_labels) for every expected label not present in labels
|
||||
|
||||
:type messages: list of Message
|
||||
:type expected_labels: list of MessageType
|
||||
:return:
|
||||
"""
|
||||
accuracy = sum(len(set(expected_labels[i]) & set(messages[i].message_type)) / len(expected_labels[i])
|
||||
for i in range(num_broken_messages, len(messages)))
|
||||
try:
|
||||
accuracy /= (len(messages) - num_broken_messages)
|
||||
except ZeroDivisionError:
|
||||
accuracy = 0
|
||||
|
||||
return accuracy * 100
|
||||
|
||||
def test_against_num_messages(self):
|
||||
num_messages = list(range(1, 24, 1))
|
||||
accuracies = defaultdict(list)
|
||||
|
||||
protocols = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
|
||||
random.seed(0)
|
||||
np.random.seed(0)
|
||||
for protocol_nr in protocols:
|
||||
for n in num_messages:
|
||||
protocol, expected_labels = self.get_protocol(protocol_nr, num_messages=n)
|
||||
self.run_format_finder_for_protocol(protocol)
|
||||
|
||||
accuracy = self.calculate_accuracy(protocol.messages, expected_labels)
|
||||
accuracies["protocol {}".format(protocol_nr)].append(accuracy)
|
||||
|
||||
self.__plot(num_messages, accuracies, xlabel="Number of messages", ylabel="Accuracy in %", grid=True)
|
||||
self.__export_to_csv("/tmp/accuray-vs-messages", num_messages, accuracies)
|
||||
|
||||
def test_against_error(self):
|
||||
Engine._DEBUG_ = False
|
||||
Preprocessor._DEBUG_ = False
|
||||
|
||||
num_runs = 100
|
||||
|
||||
num_messages = 30
|
||||
num_broken_messages = list(range(0, num_messages + 1))
|
||||
accuracies = defaultdict(list)
|
||||
accuracies_without_broken = defaultdict(list)
|
||||
|
||||
protocols = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
|
||||
random.seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
with multiprocessing.Pool() as p:
|
||||
result = p.starmap(run_for_num_broken,
|
||||
[(i, num_broken_messages, num_messages, num_runs) for i in protocols])
|
||||
for i, acc in enumerate(result):
|
||||
accuracies["protocol {}".format(i + 1)] = [a[0] for a in acc]
|
||||
accuracies_without_broken["protocol {}".format(i + 1)] = [a[1] for a in acc]
|
||||
|
||||
self.__plot(100 * np.array(num_broken_messages) / num_messages, accuracies,
|
||||
title="Overall Accuracy vs percentage of broken messages",
|
||||
xlabel="Broken messages in %",
|
||||
ylabel="Accuracy in %", grid=True)
|
||||
self.__plot(100 * np.array(num_broken_messages) / num_messages, accuracies_without_broken,
|
||||
title=" Accuracy of unbroken vs percentage of broken messages",
|
||||
xlabel="Broken messages in %",
|
||||
ylabel="Accuracy in %", grid=True)
|
||||
self.__export_to_csv("/tmp/accuray-vs-error", num_broken_messages, accuracies, relative=num_messages)
|
||||
self.__export_to_csv("/tmp/accuray-vs-error-without-broken", num_broken_messages, accuracies_without_broken,
|
||||
relative=num_messages)
|
||||
|
||||
def test_performance(self):
|
||||
Engine._DEBUG_ = False
|
||||
Preprocessor._DEBUG_ = False
|
||||
|
||||
num_messages = list(range(200, 205, 5))
|
||||
protocols = [1]
|
||||
|
||||
random.seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
performances = defaultdict(list)
|
||||
|
||||
for protocol_nr in protocols:
|
||||
print("Running for protocol", protocol_nr)
|
||||
for messages in num_messages:
|
||||
protocol, _ = self.get_protocol(protocol_nr, messages, silent=True)
|
||||
|
||||
t = time.time()
|
||||
self.run_format_finder_for_protocol(protocol)
|
||||
performances["protocol {}".format(protocol_nr)].append(time.time() - t)
|
||||
|
||||
# self.__plot(num_messages, performances, xlabel="Number of messages", ylabel="Time in seconds", grid=True)
|
||||
|
||||
def test_performance_real_protocols(self):
|
||||
Engine._DEBUG_ = False
|
||||
Preprocessor._DEBUG_ = False
|
||||
|
||||
num_runs = 100
|
||||
|
||||
num_messages = list(range(8, 512, 4))
|
||||
protocol_names = ["enocean", "homematic", "rwe"]
|
||||
|
||||
random.seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
performances = defaultdict(list)
|
||||
|
||||
for protocol_name in protocol_names:
|
||||
for messages in num_messages:
|
||||
if protocol_name == "homematic":
|
||||
protocol = self.generate_homematic(messages, save_protocol=False)
|
||||
elif protocol_name == "enocean":
|
||||
protocol = self.generate_enocean(messages, save_protocol=False)
|
||||
elif protocol_name == "rwe":
|
||||
protocol = self.generate_rwe(messages, save_protocol=False)
|
||||
else:
|
||||
raise ValueError("Unknown protocol name")
|
||||
|
||||
tmp_performances = np.empty(num_runs, dtype=np.float64)
|
||||
for i in range(num_runs):
|
||||
print("\r{0} with {1:02d} messages ({2}/{3} runs)".format(protocol_name, messages, i + 1, num_runs),
|
||||
flush=True, end="")
|
||||
|
||||
t = time.time()
|
||||
self.run_format_finder_for_protocol(protocol)
|
||||
tmp_performances[i] = time.time() - t
|
||||
self.clear_message_types(protocol.messages)
|
||||
|
||||
mean_performance = tmp_performances.mean()
|
||||
print(" {:.2f}s".format(mean_performance))
|
||||
performances["{}".format(protocol_name)].append(mean_performance)
|
||||
|
||||
self.__plot(num_messages, performances, xlabel="Number of messages", ylabel="Time in seconds", grid=True)
|
||||
self.__export_to_csv("/tmp/performance.csv", num_messages, performances)
|
||||
|
||||
@staticmethod
|
||||
def __export_to_csv(filename: str, x: list, y: dict, relative=None):
|
||||
if not filename.endswith(".csv"):
|
||||
filename += ".csv"
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.write("N,")
|
||||
if relative is not None:
|
||||
f.write("NRel,")
|
||||
for y_cap in sorted(y):
|
||||
f.write(y_cap + ",")
|
||||
f.write("\n")
|
||||
|
||||
for i, x_val in enumerate(x):
|
||||
f.write("{},".format(x_val))
|
||||
if relative is not None:
|
||||
f.write("{},".format(100 * x_val / relative))
|
||||
|
||||
for y_cap in sorted(y):
|
||||
f.write("{},".format(y[y_cap][i]))
|
||||
f.write("\n")
|
||||
|
||||
@staticmethod
|
||||
def __plot(x: list, y: dict, xlabel: str, ylabel: str, grid=False, title=None):
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel(ylabel)
|
||||
|
||||
for y_cap, y_values in sorted(y.items()):
|
||||
plt.plot(x, y_values, label=y_cap)
|
||||
|
||||
if grid:
|
||||
plt.grid()
|
||||
|
||||
if title:
|
||||
plt.title(title)
|
||||
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
@staticmethod
|
||||
def run_format_finder_for_protocol(protocol: ProtocolAnalyzer):
|
||||
ff = FormatFinder(protocol.messages)
|
||||
ff.known_participant_addresses.clear()
|
||||
ff.run()
|
||||
|
||||
for msg_type, indices in ff.existing_message_types.items():
|
||||
for i in indices:
|
||||
protocol.messages[i].message_type = msg_type
|
||||
|
||||
@classmethod
|
||||
def generate_homematic(cls, num_messages: int, save_protocol=True):
|
||||
mb_m_frame = MessageTypeBuilder("mframe")
|
||||
mb_c_frame = MessageTypeBuilder("cframe")
|
||||
mb_r_frame = MessageTypeBuilder("rframe")
|
||||
mb_a_frame = MessageTypeBuilder("aframe")
|
||||
|
||||
participants = [Participant("CCU", address_hex="3927cc"), Participant("Switch", address_hex="3101cc")]
|
||||
|
||||
checksum = GenericCRC.from_standard_checksum("CRC16 CC1101")
|
||||
for mb_builder in [mb_m_frame, mb_c_frame, mb_r_frame, mb_a_frame]:
|
||||
mb_builder.add_label(FieldType.Function.PREAMBLE, 32)
|
||||
mb_builder.add_label(FieldType.Function.SYNC, 32)
|
||||
mb_builder.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb_builder.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
|
||||
mb_builder.add_label(FieldType.Function.TYPE, 16)
|
||||
mb_builder.add_label(FieldType.Function.SRC_ADDRESS, 24)
|
||||
mb_builder.add_label(FieldType.Function.DST_ADDRESS, 24)
|
||||
if mb_builder.name == "mframe":
|
||||
mb_builder.add_label(FieldType.Function.DATA, 16, name="command")
|
||||
elif mb_builder.name == "cframe":
|
||||
mb_builder.add_label(FieldType.Function.DATA, 16 * 4, name="command+challenge+magic")
|
||||
elif mb_builder.name == "rframe":
|
||||
mb_builder.add_label(FieldType.Function.DATA, 32 * 4, name="cipher")
|
||||
elif mb_builder.name == "aframe":
|
||||
mb_builder.add_label(FieldType.Function.DATA, 10 * 4, name="command + auth")
|
||||
mb_builder.add_checksum_label(16, checksum)
|
||||
|
||||
message_types = [mb_m_frame.message_type, mb_c_frame.message_type, mb_r_frame.message_type,
|
||||
mb_a_frame.message_type]
|
||||
preamble = "0xaaaaaaaa"
|
||||
sync = "0xe9cae9ca"
|
||||
initial_sequence_number = 36
|
||||
pg = ProtocolGenerator(message_types, participants,
|
||||
preambles_by_mt={mt: preamble for mt in message_types},
|
||||
syncs_by_mt={mt: sync for mt in message_types},
|
||||
sequence_numbers={mt: initial_sequence_number for mt in message_types},
|
||||
message_type_codes={mb_m_frame.message_type: 42560,
|
||||
mb_c_frame.message_type: 40962,
|
||||
mb_r_frame.message_type: 40963,
|
||||
mb_a_frame.message_type: 32770})
|
||||
|
||||
for i in range(num_messages):
|
||||
mt = pg.message_types[i % 4]
|
||||
data_length = mt.get_first_label_with_type(FieldType.Function.DATA).length
|
||||
data = "".join(random.choice(["0", "1"]) for _ in range(data_length))
|
||||
pg.generate_message(mt, data, source=pg.participants[i % 2], destination=pg.participants[(i + 1) % 2])
|
||||
|
||||
if save_protocol:
|
||||
cls.save_protocol("homematic", pg)
|
||||
|
||||
cls.clear_message_types(pg.messages)
|
||||
return pg.protocol
|
||||
|
||||
@classmethod
|
||||
def generate_enocean(cls, num_messages: int, save_protocol=True):
|
||||
filename = get_path_for_data_file("enocean_bits.txt")
|
||||
enocean_bits = []
|
||||
with open(filename, "r") as f:
|
||||
for line in map(str.strip, f):
|
||||
enocean_bits.append(line)
|
||||
|
||||
protocol = ProtocolAnalyzer(None)
|
||||
message_type = MessageType("empty")
|
||||
for i in range(num_messages):
|
||||
msg = Message.from_plain_bits_str(enocean_bits[i % len(enocean_bits)])
|
||||
msg.message_type = message_type
|
||||
protocol.messages.append(msg)
|
||||
|
||||
if save_protocol:
|
||||
cls.save_protocol("enocean", protocol)
|
||||
|
||||
return protocol
|
||||
|
||||
@classmethod
|
||||
def generate_rwe(cls, num_messages: int, save_protocol=True):
|
||||
proto_file = get_path_for_data_file("rwe.proto.xml")
|
||||
protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
|
||||
protocol.from_xml_file(filename=proto_file, read_bits=True)
|
||||
messages = protocol.messages
|
||||
|
||||
result = ProtocolAnalyzer(None)
|
||||
message_type = MessageType("empty")
|
||||
for i in range(num_messages):
|
||||
msg = messages[i % len(messages)] # type: Message
|
||||
msg.message_type = message_type
|
||||
result.messages.append(msg)
|
||||
|
||||
if save_protocol:
|
||||
cls.save_protocol("rwe", result)
|
||||
|
||||
return result
|
||||
|
||||
def test_export_latex_table(self):
|
||||
def bold_latex(s):
|
||||
return r"\textbf{" + str(s) + r"}"
|
||||
|
||||
comments = {
|
||||
1: "common protocol",
|
||||
2: "unusual field sizes",
|
||||
3: "contains ack and CRC8 CCITT",
|
||||
4: "contains ack and CRC16 CCITT",
|
||||
5: "three participants with ack frame",
|
||||
6: "short address",
|
||||
7: "four participants, varying preamble size, varying sync words",
|
||||
8: "nibble fields + LE"
|
||||
}
|
||||
|
||||
bold = {i: defaultdict(bool) for i in range(1, 9)}
|
||||
bold[2][FieldType.Function.PREAMBLE] = True
|
||||
bold[2][FieldType.Function.SRC_ADDRESS] = True
|
||||
bold[2][FieldType.Function.DST_ADDRESS] = True
|
||||
|
||||
bold[3][FieldType.Function.CHECKSUM] = True
|
||||
|
||||
bold[4][FieldType.Function.CHECKSUM] = True
|
||||
|
||||
bold[6][FieldType.Function.SRC_ADDRESS] = True
|
||||
|
||||
bold[7][FieldType.Function.PREAMBLE] = True
|
||||
bold[7][FieldType.Function.SYNC] = True
|
||||
bold[7][FieldType.Function.SRC_ADDRESS] = True
|
||||
bold[7][FieldType.Function.DST_ADDRESS] = True
|
||||
|
||||
bold[8][FieldType.Function.PREAMBLE] = True
|
||||
bold[8][FieldType.Function.SYNC] = True
|
||||
|
||||
filename = os.path.expanduser("~/GIT/publications/awre/USENIX/protocol_table.tex")
|
||||
rowcolors = [r"\rowcolor{black!10}", r"\rowcolor{black!20}"]
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.write(r"\begin{table*}[!h]" + "\n")
|
||||
f.write(
|
||||
"\t" + r"\caption{Properties of tested protocols whereby $\times$ means field is not present and $N_P$ is the number of participants.}" + "\n")
|
||||
f.write("\t" + r"\label{tab:protocols}" + "\n")
|
||||
f.write("\t" + r"\centering" + "\n")
|
||||
f.write("\t" + r"\begin{tabularx}{\linewidth}{cp{2.5cm}llcccccccc}" + "\n")
|
||||
f.write("\t\t" + r"\hline" + "\n")
|
||||
f.write("\t\t" + r"\rowcolor{black!90}" + "\n")
|
||||
f.write("\t\t" + r"\textcolor{white}{\textbf{\#}} & "
|
||||
r"\textcolor{white}{\textbf{Comment}} & "
|
||||
r"\textcolor{white}{$\mathbf{ N_P }$} & "
|
||||
r"\textcolor{white}{\textbf{Message}} & "
|
||||
r"\textcolor{white}{\textbf{Even/odd}} & "
|
||||
r"\multicolumn{7}{c}{\textcolor{white}{\textbf{Size of field in bit (BE=Big Endian, LE=Little Endian)}}}\\"
|
||||
"\n\t\t"
|
||||
r"\rowcolor{black!90}"
|
||||
"\n\t\t"
|
||||
r"& & & \textcolor{white}{\textbf{Type}} & \textcolor{white}{\textbf{message data}} &"
|
||||
r"\textcolor{white}{Preamble} & "
|
||||
r"\textcolor{white}{Sync} & "
|
||||
r"\textcolor{white}{Length} & "
|
||||
r"\textcolor{white}{SRC} & "
|
||||
r"\textcolor{white}{DST} & "
|
||||
r"\textcolor{white}{SEQ Nr} & "
|
||||
r"\textcolor{white}{CRC} \\" + "\n")
|
||||
f.write("\t\t" + r"\hline" + "\n")
|
||||
|
||||
rowcolor_index = 0
|
||||
for i in range(1, 9):
|
||||
pg = getattr(self, "_prepare_protocol_" + str(i))()
|
||||
assert isinstance(pg, ProtocolGenerator)
|
||||
|
||||
try:
|
||||
data1 = next(mt for mt in pg.message_types if mt.name == "data1")
|
||||
data2 = next(mt for mt in pg.message_types if mt.name == "data2")
|
||||
|
||||
data1_len = data1.get_first_label_with_type(FieldType.Function.DATA).length // 8
|
||||
data2_len = data2.get_first_label_with_type(FieldType.Function.DATA).length // 8
|
||||
|
||||
except StopIteration:
|
||||
data1_len, data2_len = 8, 64
|
||||
|
||||
rowcolor = rowcolors[rowcolor_index % len(rowcolors)]
|
||||
rowcount = 0
|
||||
for j, mt in enumerate(pg.message_types):
|
||||
if mt.name == "data2":
|
||||
continue
|
||||
|
||||
rowcount += 1
|
||||
if j == 0:
|
||||
protocol_nr, participants = str(i), len(pg.participants)
|
||||
if participants > 2:
|
||||
participants = bold_latex(participants)
|
||||
else:
|
||||
protocol_nr, participants = " ", " "
|
||||
|
||||
f.write("\t\t" + rowcolor + "\n")
|
||||
|
||||
if len(pg.message_types) == 1 or (
|
||||
mt.name == "data1" and "ack" not in {m.name for m in pg.message_types}):
|
||||
f.write("\t\t{} & {} & {} & {} &".format(protocol_nr, comments[i], participants,
|
||||
mt.name.replace("1", "")))
|
||||
elif j == len(pg.message_types) - 1:
|
||||
f.write(
|
||||
"\t\t{} & \\multirow{{{}}}{{\\linewidth}}{{{}}} & {} & {} &".format(protocol_nr, -rowcount,
|
||||
comments[i],
|
||||
participants,
|
||||
mt.name.replace("1",
|
||||
"")))
|
||||
else:
|
||||
f.write("\t\t{} & & {} & {} &".format(protocol_nr, participants, mt.name.replace("1", "")))
|
||||
data_lbl = mt.get_first_label_with_type(FieldType.Function.DATA)
|
||||
|
||||
if mt.name == "data1" or mt.name == "data2":
|
||||
f.write("{}/{} byte &".format(data1_len, data2_len))
|
||||
elif mt.name == "data" and data_lbl is None:
|
||||
f.write("{}/{} byte &".format(data1_len, data2_len))
|
||||
elif data_lbl is not None:
|
||||
f.write("{0}/{0} byte & ".format(data_lbl.length // 8))
|
||||
else:
|
||||
f.write(r"$ \times $ & ")
|
||||
|
||||
for t in (FieldType.Function.PREAMBLE, FieldType.Function.SYNC, FieldType.Function.LENGTH,
|
||||
FieldType.Function.SRC_ADDRESS, FieldType.Function.DST_ADDRESS,
|
||||
FieldType.Function.SEQUENCE_NUMBER,
|
||||
FieldType.Function.CHECKSUM):
|
||||
lbl = mt.get_first_label_with_type(t)
|
||||
if lbl is not None:
|
||||
if bold[i][lbl.field_type.function]:
|
||||
f.write(bold_latex(lbl.length))
|
||||
else:
|
||||
f.write(str(lbl.length))
|
||||
if lbl.length > 8 and t in (FieldType.Function.LENGTH, FieldType.Function.SEQUENCE_NUMBER):
|
||||
f.write(" ({})".format(bold_latex("LE") if pg.little_endian else "BE"))
|
||||
else:
|
||||
f.write(r"$ \times $")
|
||||
|
||||
if t != FieldType.Function.CHECKSUM:
|
||||
f.write(" & ")
|
||||
else:
|
||||
f.write(r"\\" + "\n")
|
||||
|
||||
rowcolor_index += 1
|
||||
|
||||
f.write("\t" + r"\end{tabularx}" + "\n")
|
||||
|
||||
f.write(r"\end{table*}" + "\n")
|
179
Software/Universal Radio Hacker/tests/awre/TestAWREHistograms.py
Normal file
179
Software/Universal Radio Hacker/tests/awre/TestAWREHistograms.py
Normal file
@ -0,0 +1,179 @@
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from tests.awre.AWRETestCase import AWRETestCase
|
||||
from urh.awre.FormatFinder import FormatFinder
|
||||
from urh.awre.Histogram import Histogram
|
||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
|
||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
from urh.signalprocessing.Participant import Participant
|
||||
|
||||
SHOW_PLOTS = True
|
||||
|
||||
class TestAWREHistograms(AWRETestCase):
|
||||
def test_very_simple_protocol(self):
|
||||
"""
|
||||
Test a very simple protocol consisting just of a preamble, sync and some random data
|
||||
:return:
|
||||
"""
|
||||
mb = MessageTypeBuilder("very_simple_test")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 8)
|
||||
|
||||
num_messages = 10
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x9a"})
|
||||
for _ in range(num_messages):
|
||||
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 255), 8))
|
||||
|
||||
self.save_protocol("very_simple", pg)
|
||||
|
||||
h = Histogram(FormatFinder.get_bitvectors_from_messages(pg.protocol.messages))
|
||||
if SHOW_PLOTS:
|
||||
h.plot()
|
||||
|
||||
def test_simple_protocol(self):
|
||||
"""
|
||||
Test a simple protocol with preamble, sync and length field and some random data
|
||||
:return:
|
||||
"""
|
||||
mb = MessageTypeBuilder("simple_test")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
|
||||
num_messages_by_data_length = {8: 5, 16: 10, 32: 15}
|
||||
pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x9a9d"})
|
||||
for data_length, num_messages in num_messages_by_data_length.items():
|
||||
for _ in range(num_messages):
|
||||
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** data_length - 1), data_length))
|
||||
|
||||
self.save_protocol("simple", pg)
|
||||
|
||||
plt.subplot("221")
|
||||
plt.title("All messages")
|
||||
format_finder = FormatFinder(pg.protocol.messages)
|
||||
|
||||
for i, sync_end in enumerate(format_finder.sync_ends):
|
||||
self.assertEqual(sync_end, 24, msg=str(i))
|
||||
|
||||
h = Histogram(format_finder.bitvectors)
|
||||
h.subplot_on(plt)
|
||||
|
||||
bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages)
|
||||
bitvectors_by_length = defaultdict(list)
|
||||
for bitvector in bitvectors:
|
||||
bitvectors_by_length[len(bitvector)].append(bitvector)
|
||||
|
||||
for i, (message_length, bitvectors) in enumerate(bitvectors_by_length.items()):
|
||||
plt.subplot(2, 2, i + 2)
|
||||
plt.title("Messages with length {} ({})".format(message_length, len(bitvectors)))
|
||||
Histogram(bitvectors).subplot_on(plt)
|
||||
|
||||
if SHOW_PLOTS:
|
||||
plt.show()
|
||||
|
||||
def test_medium_protocol(self):
|
||||
"""
|
||||
Test a protocol with preamble, sync, length field, 2 participants and addresses and seq nr and random data
|
||||
:return:
|
||||
"""
|
||||
mb = MessageTypeBuilder("medium_test")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 8)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
|
||||
|
||||
alice = Participant("Alice", "A", "1234", color_index=0)
|
||||
bob = Participant("Bob", "B", "5a9d", color_index=1)
|
||||
|
||||
num_messages = 100
|
||||
pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x1c"}, little_endian=False)
|
||||
for i in range(num_messages):
|
||||
len_data = random.randint(1, 5)
|
||||
data = "".join(pg.decimal_to_bits(random.randint(0, 2 ** 8 - 1), 8) for _ in range(len_data))
|
||||
if i % 2 == 0:
|
||||
source, dest = alice, bob
|
||||
else:
|
||||
source, dest = bob, alice
|
||||
pg.generate_message(data=data, source=source, destination=dest)
|
||||
|
||||
self.save_protocol("medium", pg)
|
||||
|
||||
plt.subplot(2, 2, 1)
|
||||
plt.title("All messages")
|
||||
bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages)
|
||||
h = Histogram(bitvectors)
|
||||
h.subplot_on(plt)
|
||||
|
||||
for i, (participant, bitvectors) in enumerate(
|
||||
sorted(self.get_bitvectors_by_participant(pg.protocol.messages).items())):
|
||||
plt.subplot(2, 2, i + 3)
|
||||
plt.title("Messages with participant {} ({})".format(participant.shortname, len(bitvectors)))
|
||||
Histogram(bitvectors).subplot_on(plt)
|
||||
|
||||
if SHOW_PLOTS:
|
||||
plt.show()
|
||||
|
||||
def get_bitvectors_by_participant(self, messages):
|
||||
import numpy as np
|
||||
result = defaultdict(list)
|
||||
for msg in messages: # type: Message
|
||||
result[msg.participant].append(np.array(msg.decoded_bits, dtype=np.uint8, order="C"))
|
||||
return result
|
||||
|
||||
def test_ack_protocol(self):
|
||||
"""
|
||||
Test a protocol with acks
|
||||
:return:
|
||||
"""
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 8)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
|
||||
|
||||
mb_ack = MessageTypeBuilder("ack")
|
||||
mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb_ack.add_label(FieldType.Function.SYNC, 8)
|
||||
mb_ack.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
|
||||
alice = Participant("Alice", "A", "1234", color_index=0)
|
||||
bob = Participant("Bob", "B", "5a9d", color_index=1)
|
||||
|
||||
num_messages = 50
|
||||
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
|
||||
syncs_by_mt={mb.message_type: "0xbf", mb_ack.message_type: "0xbf"},
|
||||
little_endian=False)
|
||||
for i in range(num_messages):
|
||||
if i % 2 == 0:
|
||||
source, dest = alice, bob
|
||||
else:
|
||||
source, dest = bob, alice
|
||||
pg.generate_message(data="0xffff", source=source, destination=dest)
|
||||
pg.generate_message(data="", source=dest, destination=source, message_type=mb_ack.message_type)
|
||||
|
||||
self.save_protocol("proto_with_acks", pg)
|
||||
|
||||
plt.subplot(2, 2, 1)
|
||||
plt.title("All messages")
|
||||
bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages)
|
||||
h = Histogram(bitvectors)
|
||||
h.subplot_on(plt)
|
||||
|
||||
for i, (participant, bitvectors) in enumerate(
|
||||
sorted(self.get_bitvectors_by_participant(pg.protocol.messages).items())):
|
||||
plt.subplot(2, 2, i + 3)
|
||||
plt.title("Messages with participant {} ({})".format(participant.shortname, len(bitvectors)))
|
||||
Histogram(bitvectors).subplot_on(plt)
|
||||
|
||||
if SHOW_PLOTS:
|
||||
plt.show()
|
@ -0,0 +1,386 @@
|
||||
import random
|
||||
from array import array
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tests.awre.AWRETestCase import AWRETestCase
|
||||
from tests.utils_testing import get_path_for_data_file
|
||||
from urh.awre.FormatFinder import FormatFinder
|
||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
|
||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
||||
from urh.awre.engines.AddressEngine import AddressEngine
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
from urh.signalprocessing.Message import Message
|
||||
from urh.signalprocessing.Participant import Participant
|
||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
|
||||
from urh.util import util
|
||||
|
||||
|
||||
class TestAddressEngine(AWRETestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.alice = Participant("Alice", "A", address_hex="1234")
|
||||
self.bob = Participant("Bob", "B", address_hex="cafe")
|
||||
|
||||
def test_one_participant(self):
|
||||
"""
|
||||
Test a simple protocol with
|
||||
preamble, sync and length field (8 bit) and some random data
|
||||
|
||||
:return:
|
||||
"""
|
||||
mb = MessageTypeBuilder("simple_address_test")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
|
||||
num_messages_by_data_length = {8: 5, 16: 10, 32: 15}
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x9a9d"},
|
||||
participants=[self.alice])
|
||||
for data_length, num_messages in num_messages_by_data_length.items():
|
||||
for i in range(num_messages):
|
||||
pg.generate_message(data=pg.decimal_to_bits(22 * i, data_length), source=self.alice)
|
||||
|
||||
#self.save_protocol("address_one_participant", pg)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
|
||||
address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
|
||||
address_dict = address_engine.find_addresses()
|
||||
|
||||
self.assertEqual(len(address_dict), 0)
|
||||
|
||||
def test_two_participants(self):
|
||||
mb = MessageTypeBuilder("address_two_participants")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
|
||||
num_messages = 50
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x9a9d"},
|
||||
participants=[self.alice, self.bob])
|
||||
|
||||
for i in range(num_messages):
|
||||
if i % 2 == 0:
|
||||
source, destination = self.alice, self.bob
|
||||
data_length = 8
|
||||
else:
|
||||
source, destination = self.bob, self.alice
|
||||
data_length = 16
|
||||
pg.generate_message(data=pg.decimal_to_bits(4 * i, data_length), source=source, destination=destination)
|
||||
|
||||
#self.save_protocol("address_two_participants", pg)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
|
||||
address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
|
||||
address_dict = address_engine.find_addresses()
|
||||
self.assertEqual(len(address_dict), 2)
|
||||
addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0]))
|
||||
addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1]))
|
||||
self.assertIn(self.alice.address_hex, addresses_1)
|
||||
self.assertIn(self.alice.address_hex, addresses_2)
|
||||
self.assertIn(self.bob.address_hex, addresses_1)
|
||||
self.assertIn(self.bob.address_hex, addresses_2)
|
||||
|
||||
ff.known_participant_addresses.clear()
|
||||
self.assertEqual(len(ff.known_participant_addresses), 0)
|
||||
|
||||
ff.perform_iteration()
|
||||
|
||||
self.assertEqual(len(ff.known_participant_addresses), 2)
|
||||
self.assertIn(bytes([int(h, 16) for h in self.alice.address_hex]),
|
||||
map(bytes, ff.known_participant_addresses.values()))
|
||||
self.assertIn(bytes([int(h, 16) for h in self.bob.address_hex]),
|
||||
map(bytes, ff.known_participant_addresses.values()))
|
||||
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
mt = ff.message_types[0]
|
||||
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
|
||||
self.assertIsNotNone(dst_addr)
|
||||
self.assertEqual(dst_addr.start, 32)
|
||||
self.assertEqual(dst_addr.length, 16)
|
||||
src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
|
||||
self.assertIsNotNone(src_addr)
|
||||
self.assertEqual(src_addr.start, 48)
|
||||
self.assertEqual(src_addr.length, 16)
|
||||
|
||||
def test_two_participants_with_ack_messages(self):
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb_ack = MessageTypeBuilder("ack")
|
||||
mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb_ack.add_label(FieldType.Function.SYNC, 16)
|
||||
mb_ack.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
|
||||
num_messages = 50
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"},
|
||||
participants=[self.alice, self.bob])
|
||||
|
||||
random.seed(0)
|
||||
for i in range(num_messages):
|
||||
if i % 2 == 0:
|
||||
source, destination = self.alice, self.bob
|
||||
data_length = 8
|
||||
else:
|
||||
source, destination = self.bob, self.alice
|
||||
data_length = 16
|
||||
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
|
||||
source=source, destination=destination)
|
||||
pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination)
|
||||
|
||||
#self.save_protocol("address_two_participants_with_acks", pg)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
|
||||
address_dict = address_engine.find_addresses()
|
||||
self.assertEqual(len(address_dict), 2)
|
||||
addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0]))
|
||||
addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1]))
|
||||
self.assertIn(self.alice.address_hex, addresses_1)
|
||||
self.assertIn(self.alice.address_hex, addresses_2)
|
||||
self.assertIn(self.bob.address_hex, addresses_1)
|
||||
self.assertIn(self.bob.address_hex, addresses_2)
|
||||
|
||||
ff.known_participant_addresses.clear()
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(len(ff.message_types), 2)
|
||||
mt = ff.message_types[1]
|
||||
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
|
||||
self.assertIsNotNone(dst_addr)
|
||||
self.assertEqual(dst_addr.start, 32)
|
||||
self.assertEqual(dst_addr.length, 16)
|
||||
src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
|
||||
self.assertIsNotNone(src_addr)
|
||||
self.assertEqual(src_addr.start, 48)
|
||||
self.assertEqual(src_addr.length, 16)
|
||||
|
||||
mt = ff.message_types[0]
|
||||
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
|
||||
self.assertIsNotNone(dst_addr)
|
||||
self.assertEqual(dst_addr.start, 32)
|
||||
self.assertEqual(dst_addr.length, 16)
|
||||
|
||||
def test_two_participants_with_ack_messages_and_type(self):
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.TYPE, 8)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb_ack = MessageTypeBuilder("ack")
|
||||
mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb_ack.add_label(FieldType.Function.SYNC, 16)
|
||||
mb_ack.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
|
||||
num_messages = 50
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"},
|
||||
participants=[self.alice, self.bob])
|
||||
|
||||
random.seed(0)
|
||||
for i in range(num_messages):
|
||||
if i % 2 == 0:
|
||||
source, destination = self.alice, self.bob
|
||||
data_length = 8
|
||||
else:
|
||||
source, destination = self.bob, self.alice
|
||||
data_length = 16
|
||||
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
|
||||
source=source, destination=destination)
|
||||
pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination)
|
||||
|
||||
#self.save_protocol("address_two_participants_with_acks_and_types", pg)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
|
||||
address_dict = address_engine.find_addresses()
|
||||
self.assertEqual(len(address_dict), 2)
|
||||
addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0]))
|
||||
addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1]))
|
||||
self.assertIn(self.alice.address_hex, addresses_1)
|
||||
self.assertIn(self.alice.address_hex, addresses_2)
|
||||
self.assertIn(self.bob.address_hex, addresses_1)
|
||||
self.assertIn(self.bob.address_hex, addresses_2)
|
||||
|
||||
ff.known_participant_addresses.clear()
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(len(ff.message_types), 2)
|
||||
mt = ff.message_types[1]
|
||||
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
|
||||
self.assertIsNotNone(dst_addr)
|
||||
self.assertEqual(dst_addr.start, 40)
|
||||
self.assertEqual(dst_addr.length, 16)
|
||||
src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
|
||||
self.assertIsNotNone(src_addr)
|
||||
self.assertEqual(src_addr.start, 56)
|
||||
self.assertEqual(src_addr.length, 16)
|
||||
|
||||
mt = ff.message_types[0]
|
||||
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
|
||||
self.assertIsNotNone(dst_addr)
|
||||
self.assertEqual(dst_addr.start, 32)
|
||||
self.assertEqual(dst_addr.length, 16)
|
||||
|
||||
def test_three_participants_with_ack(self):
|
||||
alice = Participant("Alice", address_hex="1337")
|
||||
bob = Participant("Bob", address_hex="4711")
|
||||
carl = Participant("Carl", address_hex="cafe")
|
||||
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 16)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
|
||||
|
||||
mb_ack = MessageTypeBuilder("ack")
|
||||
mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
|
||||
mb_ack.add_label(FieldType.Function.SYNC, 16)
|
||||
mb_ack.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
|
||||
preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
|
||||
participants=[alice, bob, carl])
|
||||
|
||||
i = -1
|
||||
while len(pg.protocol.messages) < 20:
|
||||
i += 1
|
||||
source = pg.participants[i % len(pg.participants)]
|
||||
destination = pg.participants[(i + 1) % len(pg.participants)]
|
||||
if i % 2 == 0:
|
||||
data_bytes = 8
|
||||
else:
|
||||
data_bytes = 16
|
||||
|
||||
data = "".join(random.choice(["0", "1"]) for _ in range(data_bytes * 8))
|
||||
pg.generate_message(data=data, source=source, destination=destination)
|
||||
|
||||
if "ack" in (msg_type.name for msg_type in pg.protocol.message_types):
|
||||
pg.generate_message(message_type=1, data="", source=destination, destination=source)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
ff.known_participant_addresses.clear()
|
||||
self.assertEqual(len(ff.known_participant_addresses), 0)
|
||||
ff.run()
|
||||
|
||||
# Since there are ACKS in this protocol, the engine must be able to assign the correct participant addresses
|
||||
# IN CORRECT ORDER!
|
||||
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[0]), "1337")
|
||||
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[1]), "4711")
|
||||
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[2]), "cafe")
|
||||
|
||||
def test_protocol_with_acks_and_checksum(self):
|
||||
proto_file = get_path_for_data_file("ack_frames_with_crc.proto.xml")
|
||||
protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
|
||||
protocol.from_xml_file(filename=proto_file, read_bits=True)
|
||||
|
||||
self.clear_message_types(protocol.messages)
|
||||
|
||||
ff = FormatFinder(protocol.messages)
|
||||
ff.known_participant_addresses.clear()
|
||||
|
||||
ff.run()
|
||||
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[0]), "1337")
|
||||
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[1]), "4711")
|
||||
|
||||
for mt in ff.message_types:
|
||||
preamble = mt.get_first_label_with_type(FieldType.Function.PREAMBLE)
|
||||
self.assertEqual(preamble.start, 0)
|
||||
self.assertEqual(preamble.length, 16)
|
||||
sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
|
||||
self.assertEqual(sync.start, 16)
|
||||
self.assertEqual(sync.length, 16)
|
||||
length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
|
||||
self.assertEqual(length.start, 32)
|
||||
self.assertEqual(length.length, 8)
|
||||
|
||||
def test_address_engine_performance(self):
|
||||
ff, messages = self.get_format_finder_from_protocol_file("35_messages.proto.xml", return_messages=True)
|
||||
|
||||
engine = AddressEngine(ff.hexvectors, ff.participant_indices)
|
||||
engine.find()
|
||||
|
||||
def test_paper_example(self):
|
||||
alice = Participant("Alice", "A")
|
||||
bob = Participant("Bob", "B")
|
||||
participants = [alice, bob]
|
||||
msg1 = Message.from_plain_hex_str("aabb1234")
|
||||
msg1.participant = alice
|
||||
msg2 = Message.from_plain_hex_str("aabb6789")
|
||||
msg2.participant = alice
|
||||
msg3 = Message.from_plain_hex_str("bbaa4711")
|
||||
msg3.participant = bob
|
||||
msg4 = Message.from_plain_hex_str("bbaa1337")
|
||||
msg4.participant = bob
|
||||
|
||||
protocol = ProtocolAnalyzer(None)
|
||||
protocol.messages.extend([msg1, msg2, msg3, msg4])
|
||||
#self.save_protocol("paper_example", protocol)
|
||||
|
||||
bitvectors = FormatFinder.get_bitvectors_from_messages(protocol.messages)
|
||||
hexvectors = FormatFinder.get_hexvectors(bitvectors)
|
||||
address_engine = AddressEngine(hexvectors, participant_indices=[participants.index(msg.participant) for msg in
|
||||
protocol.messages])
|
||||
|
||||
def test_find_common_sub_sequence(self):
|
||||
from urh.cythonext import awre_util
|
||||
str1 = "0612345678"
|
||||
str2 = "0756781234"
|
||||
|
||||
seq1 = np.array(list(map(int, str1)), dtype=np.uint8, order="C")
|
||||
seq2 = np.array(list(map(int, str2)), dtype=np.uint8, order="C")
|
||||
|
||||
indices = awre_util.find_longest_common_sub_sequence_indices(seq1, seq2)
|
||||
self.assertEqual(len(indices), 2)
|
||||
for ind in indices:
|
||||
s = str1[slice(*ind)]
|
||||
self.assertIn(s, ("5678", "1234"))
|
||||
self.assertIn(s, str1)
|
||||
self.assertIn(s, str2)
|
||||
|
||||
def test_find_first_occurrence(self):
|
||||
from urh.cythonext import awre_util
|
||||
str1 = "00" * 100 + "1234500012345" + "00" * 100
|
||||
str2 = "12345"
|
||||
|
||||
seq1 = np.array(list(map(int, str1)), dtype=np.uint8, order="C")
|
||||
seq2 = np.array(list(map(int, str2)), dtype=np.uint8, order="C")
|
||||
indices = awre_util.find_occurrences(seq1, seq2)
|
||||
self.assertEqual(len(indices), 2)
|
||||
index = indices[0]
|
||||
self.assertEqual(str1[index:index + len(str2)], str2)
|
||||
|
||||
# Test with ignoring indices
|
||||
indices = awre_util.find_occurrences(seq1, seq2, array("L", list(range(0, 205))))
|
||||
self.assertEqual(len(indices), 1)
|
||||
|
||||
# Test with ignoring indices
|
||||
indices = awre_util.find_occurrences(seq1, seq2, array("L", list(range(0, 210))))
|
||||
self.assertEqual(len(indices), 0)
|
||||
|
||||
self.assertEqual(awre_util.find_occurrences(seq1, np.ones(10, dtype=np.uint8)), [])
|
@ -0,0 +1,256 @@
|
||||
import random
|
||||
|
||||
from tests.awre.AWRETestCase import AWRETestCase
|
||||
from urh.awre.FormatFinder import FormatFinder
|
||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
|
||||
from urh.awre.Preprocessor import Preprocessor
|
||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
from urh.signalprocessing.Message import Message
|
||||
from urh.signalprocessing.Participant import Participant
|
||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestAWREPreprocessing(AWRETestCase):
|
||||
def test_very_simple_sync_word_finding(self):
|
||||
preamble = "10101010"
|
||||
sync = "1101"
|
||||
|
||||
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync)],
|
||||
num_messages=(20,),
|
||||
data=(lambda i: 10 * i,))
|
||||
|
||||
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
|
||||
|
||||
possible_syncs = preprocessor.find_possible_syncs()
|
||||
#self.save_protocol("very_simple_sync_test", pg)
|
||||
self.assertGreaterEqual(len(possible_syncs), 1)
|
||||
self.assertEqual(preprocessor.find_possible_syncs()[0], sync)
|
||||
|
||||
def test_simple_sync_word_finding(self):
|
||||
preamble = "10101010"
|
||||
sync = "1001"
|
||||
|
||||
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "1010", sync)],
|
||||
num_messages=(20, 5),
|
||||
data=(lambda i: 10 * i, lambda i: 22 * i))
|
||||
|
||||
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
|
||||
|
||||
possible_syncs = preprocessor.find_possible_syncs()
|
||||
#self.save_protocol("simple_sync_test", pg)
|
||||
self.assertGreaterEqual(len(possible_syncs), 1)
|
||||
self.assertEqual(preprocessor.find_possible_syncs()[0], sync)
|
||||
|
||||
def test_sync_word_finding_odd_preamble(self):
|
||||
preamble = "0101010"
|
||||
sync = "1101"
|
||||
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "10", sync)],
|
||||
num_messages=(20, 5),
|
||||
data=(lambda i: 10 * i, lambda i: i))
|
||||
|
||||
# If we have a odd preamble length, the last bit of the preamble is counted to the sync
|
||||
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
|
||||
possible_syncs = preprocessor.find_possible_syncs()
|
||||
|
||||
#self.save_protocol("odd_preamble", pg)
|
||||
self.assertEqual(preamble[-1] + sync[:-1], possible_syncs[0])
|
||||
|
||||
def test_sync_word_finding_special_preamble(self):
|
||||
preamble = "111001110011100"
|
||||
sync = "0110"
|
||||
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "10", sync)],
|
||||
num_messages=(20, 5),
|
||||
data=(lambda i: 10 * i, lambda i: i))
|
||||
|
||||
# If we have a odd preamble length, the last bit of the preamble is counted to the sync
|
||||
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
|
||||
possible_syncs = preprocessor.find_possible_syncs()
|
||||
|
||||
#self.save_protocol("special_preamble", pg)
|
||||
self.assertEqual(sync, possible_syncs[0])
|
||||
|
||||
def test_sync_word_finding_errored_preamble(self):
|
||||
preamble = "00010101010" # first bits are wrong
|
||||
sync = "0110"
|
||||
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "10", sync)],
|
||||
num_messages=(20, 5),
|
||||
data=(lambda i: 10 * i, lambda i: i))
|
||||
|
||||
# If we have a odd preamble length, the last bit of the preamble is counted to the sync
|
||||
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
|
||||
possible_syncs = preprocessor.find_possible_syncs()
|
||||
|
||||
#self.save_protocol("errored_preamble", pg)
|
||||
self.assertIn(preamble[-1] + sync[:-1], possible_syncs)
|
||||
|
||||
def test_sync_word_finding_with_two_sync_words(self):
|
||||
preamble = "0xaaaa"
|
||||
sync1, sync2 = "0x1234", "0xcafe"
|
||||
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync1), (preamble, sync2)],
|
||||
num_messages=(15, 10),
|
||||
data=(lambda i: 12 * i, lambda i: 16 * i))
|
||||
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
|
||||
possible_syncs = preprocessor.find_possible_syncs()
|
||||
#self.save_protocol("two_syncs", pg)
|
||||
self.assertGreaterEqual(len(possible_syncs), 2)
|
||||
self.assertIn(ProtocolGenerator.to_bits(sync1), possible_syncs)
|
||||
self.assertIn(ProtocolGenerator.to_bits(sync2), possible_syncs)
|
||||
|
||||
def test_multiple_sync_words(self):
|
||||
hex_messages = [
|
||||
"aaS1234",
|
||||
"aaScafe",
|
||||
"aaSdead",
|
||||
"aaSbeef",
|
||||
]
|
||||
|
||||
for i in range(1, 256):
|
||||
messages = []
|
||||
sync = "{0:02x}".format(i)
|
||||
if sync.startswith("a"):
|
||||
continue
|
||||
|
||||
for msg in hex_messages:
|
||||
messages.append(Message.from_plain_hex_str(msg.replace("S", sync)))
|
||||
|
||||
for i in range(1, len(messages)):
|
||||
messages[i].message_type = messages[0].message_type
|
||||
|
||||
ff = FormatFinder(messages)
|
||||
ff.run()
|
||||
|
||||
self.assertEqual(len(ff.message_types), 1, msg=sync)
|
||||
|
||||
preamble = ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)
|
||||
self.assertEqual(preamble.start, 0, msg=sync)
|
||||
self.assertEqual(preamble.length, 8, msg=sync)
|
||||
|
||||
sync = ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC)
|
||||
self.assertEqual(sync.start, 8, msg=sync)
|
||||
self.assertEqual(sync.length, 8, msg=sync)
|
||||
|
||||
def test_sync_word_finding_varying_message_length(self):
|
||||
hex_messages = [
|
||||
"aaaa9a7d0f1337471100009a44ebdd13517bf9",
|
||||
"aaaa9a7d4747111337000134a4473c002b909630b11df37e34728c79c60396176aff2b5384e82f31511581d0cbb4822ad1b6734e2372ad5cf4af4c9d6b067e5f7ec359ec443c3b5ddc7a9e",
|
||||
"aaaa9a7d0f13374711000205ee081d26c86b8c",
|
||||
"aaaa9a7d474711133700037cae4cda789885f88f5fb29adc9acf954cb2850b9d94e7f3b009347c466790e89f2bcd728987d4670690861bbaa120f71f14d4ef8dc738a6d7c30e7d2143c267",
|
||||
"aaaa9a7d0f133747110004c2906142300427f3"
|
||||
]
|
||||
|
||||
messages = [Message.from_plain_hex_str(hex_msg) for hex_msg in hex_messages]
|
||||
for i in range(1, len(messages)):
|
||||
messages[i].message_type = messages[0].message_type
|
||||
|
||||
ff = FormatFinder(messages)
|
||||
ff.run()
|
||||
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
preamble = ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)
|
||||
self.assertEqual(preamble.start, 0)
|
||||
self.assertEqual(preamble.length, 16)
|
||||
|
||||
sync = ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC)
|
||||
self.assertEqual(sync.start, 16)
|
||||
self.assertEqual(sync.length, 16)
|
||||
|
||||
def test_sync_word_finding_common_prefix(self):
|
||||
"""
|
||||
Messages are very similar (odd and even ones are the same)
|
||||
However, they do not have two different sync words!
|
||||
The algorithm needs to check for a common prefix of the two found sync words
|
||||
|
||||
:return:
|
||||
"""
|
||||
sync = "0x1337"
|
||||
num_messages = 10
|
||||
|
||||
alice = Participant("Alice", address_hex="dead01")
|
||||
bob = Participant("Bob", address_hex="beef24")
|
||||
|
||||
mb = MessageTypeBuilder("protocol_with_one_message_type")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 72)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 24)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 24)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x1337"},
|
||||
preambles_by_mt={mb.message_type: "10" * 36},
|
||||
participants=[alice, bob])
|
||||
|
||||
random.seed(0)
|
||||
for i in range(num_messages):
|
||||
if i % 2 == 0:
|
||||
source, destination = alice, bob
|
||||
data_length = 8
|
||||
else:
|
||||
source, destination = bob, alice
|
||||
data_length = 16
|
||||
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
|
||||
source=source, destination=destination)
|
||||
|
||||
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
|
||||
possible_syncs = preprocessor.find_possible_syncs()
|
||||
#self.save_protocol("sync_by_common_prefix", pg)
|
||||
self.assertEqual(len(possible_syncs), 1)
|
||||
|
||||
# +0000 is okay, because this will get fixed by correction in FormatFinder
|
||||
self.assertIn(possible_syncs[0], [ProtocolGenerator.to_bits(sync), ProtocolGenerator.to_bits(sync) + "0000"])
|
||||
|
||||
def test_with_given_preamble_and_sync(self):
|
||||
preamble = "10101010"
|
||||
sync = "10011"
|
||||
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync)],
|
||||
num_messages=(20,),
|
||||
data=(lambda i: 10 * i,))
|
||||
|
||||
# If we have a odd preamble length, the last bit of the preamble is counted to the sync
|
||||
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages],
|
||||
existing_message_types={i: msg.message_type for i, msg in
|
||||
enumerate(pg.protocol.messages)})
|
||||
preamble_starts, preamble_lengths, sync_len = preprocessor.preprocess()
|
||||
|
||||
#self.save_protocol("given_preamble", pg)
|
||||
|
||||
self.assertTrue(all(preamble_start == 0 for preamble_start in preamble_starts))
|
||||
self.assertTrue(all(preamble_length == len(preamble) for preamble_length in preamble_lengths))
|
||||
self.assertEqual(sync_len, len(sync))
|
||||
|
||||
@staticmethod
|
||||
def build_protocol_generator(preamble_syncs: list, num_messages: tuple, data: tuple) -> ProtocolGenerator:
|
||||
message_types = []
|
||||
preambles_by_mt = dict()
|
||||
syncs_by_mt = dict()
|
||||
|
||||
assert len(preamble_syncs) == len(num_messages) == len(data)
|
||||
|
||||
for i, (preamble, sync_word) in enumerate(preamble_syncs):
|
||||
assert isinstance(preamble, str)
|
||||
assert isinstance(sync_word, str)
|
||||
|
||||
preamble, sync_word = map(ProtocolGenerator.to_bits, (preamble, sync_word))
|
||||
|
||||
mb = MessageTypeBuilder("message type #{0}".format(i))
|
||||
mb.add_label(FieldType.Function.PREAMBLE, len(preamble))
|
||||
mb.add_label(FieldType.Function.SYNC, len(sync_word))
|
||||
|
||||
message_types.append(mb.message_type)
|
||||
preambles_by_mt[mb.message_type] = preamble
|
||||
syncs_by_mt[mb.message_type] = sync_word
|
||||
|
||||
pg = ProtocolGenerator(message_types, preambles_by_mt=preambles_by_mt, syncs_by_mt=syncs_by_mt)
|
||||
for i, msg_type in enumerate(message_types):
|
||||
for j in range(num_messages[i]):
|
||||
if callable(data[i]):
|
||||
msg_data = pg.decimal_to_bits(data[i](j), num_bits=8)
|
||||
else:
|
||||
msg_data = data[i]
|
||||
|
||||
pg.generate_message(message_type=msg_type, data=msg_data)
|
||||
|
||||
return pg
|
@ -0,0 +1,149 @@
|
||||
from tests.awre.AWRETestCase import AWRETestCase
|
||||
from tests.utils_testing import get_path_for_data_file
|
||||
from urh.awre.CommonRange import CommonRange
|
||||
from urh.awre.FormatFinder import FormatFinder
|
||||
from urh.awre.Preprocessor import Preprocessor
|
||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
from urh.signalprocessing.Message import Message
|
||||
from urh.signalprocessing.MessageType import MessageType
|
||||
from urh.signalprocessing.Participant import Participant
|
||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
|
||||
import numpy as np
|
||||
|
||||
class TestAWRERealProtocols(AWRETestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
alice = Participant("Alice", "A")
|
||||
bob = Participant("Bob", "B")
|
||||
self.participants = [alice, bob]
|
||||
|
||||
def test_format_finding_enocean(self):
|
||||
enocean_protocol = ProtocolAnalyzer(None)
|
||||
with open(get_path_for_data_file("enocean_bits.txt")) as f:
|
||||
for line in f:
|
||||
enocean_protocol.messages.append(Message.from_plain_bits_str(line.replace("\n", "")))
|
||||
enocean_protocol.messages[-1].message_type = enocean_protocol.default_message_type
|
||||
|
||||
ff = FormatFinder(enocean_protocol.messages)
|
||||
ff.perform_iteration()
|
||||
|
||||
message_types = ff.message_types
|
||||
self.assertEqual(len(message_types), 1)
|
||||
|
||||
preamble = message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)
|
||||
self.assertEqual(preamble.start, 0)
|
||||
self.assertEqual(preamble.length, 8)
|
||||
|
||||
sync = message_types[0].get_first_label_with_type(FieldType.Function.SYNC)
|
||||
self.assertEqual(sync.start, 8)
|
||||
self.assertEqual(sync.length, 4)
|
||||
|
||||
checksum = message_types[0].get_first_label_with_type(FieldType.Function.CHECKSUM)
|
||||
self.assertEqual(checksum.start, 56)
|
||||
self.assertEqual(checksum.length, 4)
|
||||
|
||||
self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
|
||||
self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
|
||||
self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
|
||||
self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER))
|
||||
|
||||
def test_format_finding_rwe(self):
|
||||
ff, messages = self.get_format_finder_from_protocol_file("rwe.proto.xml", return_messages=True)
|
||||
ff.run()
|
||||
|
||||
sync1, sync2 = "0x9a7d9a7d", "0x67686768"
|
||||
|
||||
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in messages])
|
||||
possible_syncs = preprocessor.find_possible_syncs()
|
||||
self.assertIn(ProtocolGenerator.to_bits(sync1), possible_syncs)
|
||||
self.assertIn(ProtocolGenerator.to_bits(sync2), possible_syncs)
|
||||
|
||||
ack_messages = (3, 5, 7, 9, 11, 13, 15, 17, 20)
|
||||
ack_message_type = next(mt for mt, messages in ff.existing_message_types.items() if ack_messages[0] in messages)
|
||||
self.assertTrue(all(ack_msg in ff.existing_message_types[ack_message_type] for ack_msg in ack_messages))
|
||||
|
||||
for mt in ff.message_types:
|
||||
preamble = mt.get_first_label_with_type(FieldType.Function.PREAMBLE)
|
||||
self.assertEqual(preamble.start, 0)
|
||||
self.assertEqual(preamble.length, 32)
|
||||
|
||||
sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
|
||||
self.assertEqual(sync.start, 32)
|
||||
self.assertEqual(sync.length, 32)
|
||||
|
||||
length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
|
||||
self.assertEqual(length.start, 64)
|
||||
self.assertEqual(length.length, 8)
|
||||
|
||||
dst = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
|
||||
self.assertEqual(dst.length, 24)
|
||||
|
||||
if mt == ack_message_type or 1 in ff.existing_message_types[mt]:
|
||||
self.assertEqual(dst.start, 72)
|
||||
else:
|
||||
self.assertEqual(dst.start, 88)
|
||||
|
||||
if mt != ack_message_type and 1 not in ff.existing_message_types[mt]:
|
||||
src = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
|
||||
self.assertEqual(src.start, 112)
|
||||
self.assertEqual(src.length, 24)
|
||||
elif 1 in ff.existing_message_types[mt]:
|
||||
# long ack
|
||||
src = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
|
||||
self.assertEqual(src.start, 96)
|
||||
self.assertEqual(src.length, 24)
|
||||
|
||||
crc = mt.get_first_label_with_type(FieldType.Function.CHECKSUM)
|
||||
self.assertIsNotNone(crc)
|
||||
|
||||
def test_homematic(self):
|
||||
proto_file = get_path_for_data_file("homematic.proto.xml")
|
||||
protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
|
||||
protocol.message_types = []
|
||||
protocol.from_xml_file(filename=proto_file, read_bits=True)
|
||||
# prevent interfering with preassinged labels
|
||||
protocol.message_types = [MessageType("Default")]
|
||||
|
||||
participants = sorted({msg.participant for msg in protocol.messages})
|
||||
|
||||
self.clear_message_types(protocol.messages)
|
||||
ff = FormatFinder(protocol.messages, participants=participants)
|
||||
ff.known_participant_addresses.clear()
|
||||
ff.perform_iteration()
|
||||
|
||||
self.assertGreater(len(ff.message_types), 0)
|
||||
|
||||
for i, message_type in enumerate(ff.message_types):
|
||||
preamble = message_type.get_first_label_with_type(FieldType.Function.PREAMBLE)
|
||||
self.assertEqual(preamble.start, 0)
|
||||
self.assertEqual(preamble.length, 32)
|
||||
|
||||
sync = message_type.get_first_label_with_type(FieldType.Function.SYNC)
|
||||
self.assertEqual(sync.start, 32)
|
||||
self.assertEqual(sync.length, 32)
|
||||
|
||||
length = message_type.get_first_label_with_type(FieldType.Function.LENGTH)
|
||||
self.assertEqual(length.start, 64)
|
||||
self.assertEqual(length.length, 8)
|
||||
|
||||
seq = message_type.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
|
||||
self.assertEqual(seq.start, 72)
|
||||
self.assertEqual(seq.length, 8)
|
||||
|
||||
src = message_type.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
|
||||
self.assertEqual(src.start, 96)
|
||||
self.assertEqual(src.length, 24)
|
||||
|
||||
dst = message_type.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
|
||||
self.assertEqual(dst.start, 120)
|
||||
self.assertEqual(dst.length, 24)
|
||||
|
||||
checksum = message_type.get_first_label_with_type(FieldType.Function.CHECKSUM)
|
||||
self.assertEqual(checksum.length, 16)
|
||||
self.assertIn("CC1101", checksum.checksum.caption)
|
||||
|
||||
for msg_index in ff.existing_message_types[message_type]:
|
||||
msg_len = len(protocol.messages[msg_index])
|
||||
self.assertEqual(checksum.start, msg_len-16)
|
||||
self.assertEqual(checksum.end, msg_len)
|
@ -0,0 +1,102 @@
|
||||
import array
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tests.awre.AWRETestCase import AWRETestCase
|
||||
from urh.awre.CommonRange import ChecksumRange
|
||||
from urh.awre.FormatFinder import FormatFinder
|
||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
|
||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
||||
from urh.awre.engines.ChecksumEngine import ChecksumEngine
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
from urh.util import util
|
||||
from urh.util.GenericCRC import GenericCRC
|
||||
from urh.cythonext import util as c_util
|
||||
|
||||
class TestChecksumEngine(AWRETestCase):
|
||||
def test_find_crc8(self):
|
||||
messages = ["aabbcc7d", "abcdee24", "dacafe33"]
|
||||
message_bits = [np.array(msg, dtype=np.uint8) for msg in map(util.hex2bit, messages)]
|
||||
|
||||
checksum_engine = ChecksumEngine(message_bits, n_gram_length=8)
|
||||
result = checksum_engine.find()
|
||||
self.assertEqual(len(result), 1)
|
||||
checksum_range = result[0] # type: ChecksumRange
|
||||
self.assertEqual(checksum_range.length, 8)
|
||||
self.assertEqual(checksum_range.start, 24)
|
||||
|
||||
reference = GenericCRC()
|
||||
reference.set_polynomial_from_hex("0x07")
|
||||
self.assertEqual(checksum_range.crc.polynomial, reference.polynomial)
|
||||
|
||||
self.assertEqual(checksum_range.message_indices, {0, 1, 2})
|
||||
|
||||
def test_find_crc16(self):
|
||||
messages = ["12345678347B", "abcdefffABBD", "cafe1337CE12"]
|
||||
message_bits = [np.array(msg, dtype=np.uint8) for msg in map(util.hex2bit, messages)]
|
||||
|
||||
checksum_engine = ChecksumEngine(message_bits, n_gram_length=8)
|
||||
result = checksum_engine.find()
|
||||
self.assertEqual(len(result), 1)
|
||||
checksum_range = result[0] # type: ChecksumRange
|
||||
self.assertEqual(checksum_range.start, 32)
|
||||
self.assertEqual(checksum_range.length, 16)
|
||||
|
||||
reference = GenericCRC()
|
||||
reference.set_polynomial_from_hex("0x8005")
|
||||
self.assertEqual(checksum_range.crc.polynomial, reference.polynomial)
|
||||
|
||||
self.assertEqual(checksum_range.message_indices, {0, 1, 2})
|
||||
|
||||
def test_find_crc32(self):
|
||||
messages = ["deadcafe5D7F3F5A", "47111337E3319242", "beefaffe0DCD0E15"]
|
||||
message_bits = [np.array(msg, dtype=np.uint8) for msg in map(util.hex2bit, messages)]
|
||||
|
||||
checksum_engine = ChecksumEngine(message_bits, n_gram_length=8)
|
||||
result = checksum_engine.find()
|
||||
self.assertEqual(len(result), 1)
|
||||
checksum_range = result[0] # type: ChecksumRange
|
||||
self.assertEqual(checksum_range.start, 32)
|
||||
self.assertEqual(checksum_range.length, 32)
|
||||
|
||||
reference = GenericCRC()
|
||||
reference.set_polynomial_from_hex("0x04C11DB7")
|
||||
self.assertEqual(checksum_range.crc.polynomial, reference.polynomial)
|
||||
|
||||
self.assertEqual(checksum_range.message_indices, {0, 1, 2})
|
||||
|
||||
def test_find_generated_crc16(self):
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.DATA, 32)
|
||||
mb.add_checksum_label(16, GenericCRC.from_standard_checksum("CRC16 CCITT"))
|
||||
|
||||
mb2 = MessageTypeBuilder("data2")
|
||||
mb2.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb2.add_label(FieldType.Function.SYNC, 16)
|
||||
mb2.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb2.add_label(FieldType.Function.DATA, 16)
|
||||
|
||||
mb2.add_checksum_label(16, GenericCRC.from_standard_checksum("CRC16 CCITT"))
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type, mb2.message_type], syncs_by_mt={mb.message_type: "0x1234", mb2.message_type: "0x1234"})
|
||||
|
||||
num_messages = 5
|
||||
|
||||
for i in range(num_messages):
|
||||
pg.generate_message(data="{0:032b}".format(i), message_type=mb.message_type)
|
||||
pg.generate_message(data="{0:016b}".format(i), message_type=mb2.message_type)
|
||||
|
||||
#self.save_protocol("crc16_test", pg)
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
ff.run()
|
||||
|
||||
self.assertEqual(len(ff.message_types), 2)
|
||||
for mt in ff.message_types:
|
||||
checksum_label = mt.get_first_label_with_type(FieldType.Function.CHECKSUM)
|
||||
self.assertEqual(checksum_label.length, 16)
|
||||
self.assertEqual(checksum_label.checksum.caption, "CRC16 CCITT")
|
@ -0,0 +1,35 @@
|
||||
import unittest
|
||||
|
||||
from urh.awre.CommonRange import CommonRange
|
||||
|
||||
|
||||
class TestCommonRange(unittest.TestCase):
|
||||
def test_ensure_not_overlaps(self):
|
||||
test_range = CommonRange(start=4, length=8, value="12345678")
|
||||
self.assertEqual(test_range.end, 11)
|
||||
|
||||
# no overlapping
|
||||
self.assertEqual(test_range, test_range.ensure_not_overlaps(0, 3)[0])
|
||||
self.assertEqual(test_range, test_range.ensure_not_overlaps(20, 24)[0])
|
||||
|
||||
# overlapping on left
|
||||
result = test_range.ensure_not_overlaps(2, 6)[0]
|
||||
self.assertEqual(result.start, 6)
|
||||
self.assertEqual(result.end, 11)
|
||||
|
||||
# overlapping on right
|
||||
result = test_range.ensure_not_overlaps(6, 14)[0]
|
||||
self.assertEqual(result.start, 4)
|
||||
self.assertEqual(result.end, 5)
|
||||
|
||||
# full overlapping
|
||||
self.assertEqual(len(test_range.ensure_not_overlaps(3, 14)), 0)
|
||||
|
||||
# overlapping in the middle
|
||||
result = test_range.ensure_not_overlaps(6, 9)
|
||||
self.assertEqual(len(result), 2)
|
||||
left, right = result[0], result[1]
|
||||
self.assertEqual(left.start, 4)
|
||||
self.assertEqual(left.end, 5)
|
||||
self.assertEqual(right.start, 10)
|
||||
self.assertEqual(right.end, 11)
|
102
Software/Universal Radio Hacker/tests/awre/test_format_finder.py
Normal file
102
Software/Universal Radio Hacker/tests/awre/test_format_finder.py
Normal file
@ -0,0 +1,102 @@
|
||||
import numpy as np
|
||||
|
||||
from tests.awre.AWRETestCase import AWRETestCase
|
||||
from urh.awre.CommonRange import CommonRange, CommonRangeContainer
|
||||
from urh.awre.FormatFinder import FormatFinder
|
||||
|
||||
|
||||
class TestFormatFinder(AWRETestCase):
|
||||
def test_create_message_types_1(self):
|
||||
rng1 = CommonRange(0, 8, "1" * 8, score=1, field_type="Length")
|
||||
rng1.message_indices = {0, 1, 2}
|
||||
rng2 = CommonRange(8, 8, "1" * 8, score=1, field_type="Address")
|
||||
rng2.message_indices = {0, 1, 2}
|
||||
|
||||
message_types = FormatFinder.create_common_range_containers({rng1, rng2})
|
||||
self.assertEqual(len(message_types), 1)
|
||||
|
||||
expected = CommonRangeContainer([rng1, rng2], message_indices={0, 1, 2})
|
||||
self.assertEqual(message_types[0], expected)
|
||||
|
||||
def test_create_message_types_2(self):
|
||||
rng1 = CommonRange(0, 8, "1" * 8, score=1, field_type="Length")
|
||||
rng1.message_indices = {0, 2, 4, 6, 8, 12}
|
||||
rng2 = CommonRange(8, 8, "1" * 8, score=1, field_type="Address")
|
||||
rng2.message_indices = {1, 2, 3, 4, 5, 12}
|
||||
rng3 = CommonRange(16, 8, "1" * 8, score=1, field_type="Seq")
|
||||
rng3.message_indices = {1, 3, 5, 7, 12}
|
||||
|
||||
message_types = FormatFinder.create_common_range_containers({rng1, rng2, rng3})
|
||||
expected1 = CommonRangeContainer([rng1], message_indices={0, 6, 8})
|
||||
expected2 = CommonRangeContainer([rng1, rng2], message_indices={2, 4})
|
||||
expected3 = CommonRangeContainer([rng1, rng2, rng3], message_indices={12})
|
||||
expected4 = CommonRangeContainer([rng2, rng3], message_indices={1, 3, 5})
|
||||
expected5 = CommonRangeContainer([rng3], message_indices={7})
|
||||
|
||||
self.assertEqual(len(message_types), 5)
|
||||
|
||||
self.assertIn(expected1, message_types)
|
||||
self.assertIn(expected2, message_types)
|
||||
self.assertIn(expected3, message_types)
|
||||
self.assertIn(expected4, message_types)
|
||||
self.assertIn(expected5, message_types)
|
||||
|
||||
def test_retransform_message_indices(self):
|
||||
sync_ends = np.array([12, 12, 12, 14, 14])
|
||||
|
||||
rng = CommonRange(0, 8, "1" * 8, score=1, field_type="length", message_indices={0, 1, 2, 3, 4})
|
||||
retransformed_ranges = FormatFinder.retransform_message_indices([rng], [0, 1, 2, 3, 4], sync_ends)
|
||||
|
||||
# two different sync ends
|
||||
self.assertEqual(len(retransformed_ranges), 2)
|
||||
|
||||
expected1 = CommonRange(12, 8, "1" * 8, score=1, field_type="length", message_indices={0, 1, 2})
|
||||
expected2 = CommonRange(14, 8, "1" * 8, score=1, field_type="length", message_indices={3, 4})
|
||||
|
||||
self.assertIn(expected1, retransformed_ranges)
|
||||
self.assertIn(expected2, retransformed_ranges)
|
||||
|
||||
def test_handle_no_overlapping_conflict(self):
|
||||
rng1 = CommonRange(0, 8, "1" * 8, score=1, field_type="Length")
|
||||
rng1.message_indices = {0, 1, 2}
|
||||
rng2 = CommonRange(8, 8, "1" * 8, score=1, field_type="Address")
|
||||
rng2.message_indices = {0, 1, 2}
|
||||
|
||||
container = CommonRangeContainer([rng1, rng2], message_indices={0, 1, 2})
|
||||
|
||||
# no conflict
|
||||
result = FormatFinder.handle_overlapping_conflict([container])
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(len(result[0]), 2)
|
||||
self.assertIn(rng1, result[0])
|
||||
self.assertEqual(result[0].message_indices, {0, 1, 2})
|
||||
self.assertIn(rng2, result[0])
|
||||
|
||||
def test_handle_easy_overlapping_conflict(self):
|
||||
# Easy conflict: First Label has higher score
|
||||
rng1 = CommonRange(8, 8, "1" * 8, score=1, field_type="Length")
|
||||
rng1.message_indices = {0, 1, 2}
|
||||
rng2 = CommonRange(8, 8, "1" * 8, score=0.8, field_type="Address")
|
||||
rng2.message_indices = {0, 1, 2}
|
||||
|
||||
container = CommonRangeContainer([rng1, rng2], message_indices={0, 1, 2})
|
||||
result = FormatFinder.handle_overlapping_conflict([container])
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(len(result[0]), 1)
|
||||
self.assertIn(rng1, result[0])
|
||||
self.assertEqual(result[0].message_indices, {0, 1, 2})
|
||||
|
||||
def test_handle_medium_overlapping_conflict(self):
|
||||
rng1 = CommonRange(8, 8, "1" * 8, score=1, field_type="Length")
|
||||
rng2 = CommonRange(4, 10, "1" * 8, score=0.8, field_type="Address")
|
||||
rng3 = CommonRange(15, 20, "1" * 8, score=1, field_type="Seq")
|
||||
rng4 = CommonRange(60, 80, "1" * 8, score=0.8, field_type="Type")
|
||||
rng5 = CommonRange(70, 90, "1" * 8, score=0.9, field_type="Data")
|
||||
|
||||
container = CommonRangeContainer([rng1, rng2, rng3, rng4, rng5])
|
||||
result = FormatFinder.handle_overlapping_conflict([container])
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(len(result[0]), 3)
|
||||
self.assertIn(rng1, result[0])
|
||||
self.assertIn(rng3, result[0])
|
||||
self.assertIn(rng5, result[0])
|
@ -0,0 +1,236 @@
|
||||
from tests.awre.AWRETestCase import AWRETestCase
|
||||
from urh.awre import AutoAssigner
|
||||
from urh.awre.FormatFinder import FormatFinder
|
||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
|
||||
from urh.awre.Preprocessor import Preprocessor
|
||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
from urh.signalprocessing.Participant import Participant
|
||||
from urh.util import util
|
||||
|
||||
|
||||
class TestGeneratedProtocols(AWRETestCase):
|
||||
def __check_addresses(self, messages, format_finder, known_participant_addresses):
|
||||
"""
|
||||
Use the AutoAssigner used also in main GUI to test assigned participant addresses to get same results
|
||||
as in main program and not rely on cache of FormatFinder, because values there might be false
|
||||
but SRC address labels still on right position which is the basis for Auto Assigner
|
||||
|
||||
:param messages:
|
||||
:param format_finder:
|
||||
:param known_participant_addresses:
|
||||
:return:
|
||||
"""
|
||||
|
||||
for msg_type, indices in format_finder.existing_message_types.items():
|
||||
for i in indices:
|
||||
messages[i].message_type = msg_type
|
||||
|
||||
participants = list(set(m.participant for m in messages))
|
||||
for p in participants:
|
||||
p.address_hex = ""
|
||||
AutoAssigner.auto_assign_participant_addresses(messages, participants)
|
||||
|
||||
for i in range(len(participants)):
|
||||
self.assertIn(participants[i].address_hex,
|
||||
list(map(util.convert_numbers_to_hex_string, known_participant_addresses.values())),
|
||||
msg=" [ " + " ".join(p.address_hex for p in participants) + " ]")
|
||||
|
||||
def test_without_preamble(self):
|
||||
alice = Participant("Alice", address_hex="24")
|
||||
broadcast = Participant("Broadcast", address_hex="ff")
|
||||
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 8)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x8e88"},
|
||||
preambles_by_mt={mb.message_type: "10" * 8},
|
||||
participants=[alice, broadcast])
|
||||
|
||||
for i in range(20):
|
||||
data_bits = 16 if i % 2 == 0 else 32
|
||||
source = pg.participants[i % 2]
|
||||
destination = pg.participants[(i + 1) % 2]
|
||||
pg.generate_message(data="1010" * (data_bits // 4), source=source, destination=destination)
|
||||
|
||||
#self.save_protocol("without_preamble", pg)
|
||||
self.clear_message_types(pg.messages)
|
||||
ff = FormatFinder(pg.messages)
|
||||
ff.known_participant_addresses.clear()
|
||||
|
||||
ff.run()
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
|
||||
mt = ff.message_types[0]
|
||||
sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
|
||||
self.assertEqual(sync.start, 0)
|
||||
self.assertEqual(sync.length, 16)
|
||||
|
||||
length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
|
||||
self.assertEqual(length.start, 16)
|
||||
self.assertEqual(length.length, 8)
|
||||
|
||||
dst = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
|
||||
self.assertEqual(dst.start, 24)
|
||||
self.assertEqual(dst.length, 8)
|
||||
|
||||
seq = mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
|
||||
self.assertEqual(seq.start, 32)
|
||||
self.assertEqual(seq.length, 8)
|
||||
|
||||
def test_without_preamble_random_data(self):
|
||||
ff = self.get_format_finder_from_protocol_file("without_ack_random_data.proto.xml")
|
||||
ff.run()
|
||||
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
|
||||
mt = ff.message_types[0]
|
||||
sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
|
||||
self.assertEqual(sync.start, 0)
|
||||
self.assertEqual(sync.length, 16)
|
||||
|
||||
length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
|
||||
self.assertEqual(length.start, 16)
|
||||
self.assertEqual(length.length, 8)
|
||||
|
||||
dst = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
|
||||
self.assertEqual(dst.start, 24)
|
||||
self.assertEqual(dst.length, 8)
|
||||
|
||||
seq = mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
|
||||
self.assertEqual(seq.start, 32)
|
||||
self.assertEqual(seq.length, 8)
|
||||
|
||||
def test_without_preamble_random_data2(self):
|
||||
ff = self.get_format_finder_from_protocol_file("without_ack_random_data2.proto.xml")
|
||||
ff.run()
|
||||
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
|
||||
mt = ff.message_types[0]
|
||||
sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
|
||||
self.assertEqual(sync.start, 0)
|
||||
self.assertEqual(sync.length, 16)
|
||||
|
||||
length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
|
||||
self.assertEqual(length.start, 16)
|
||||
self.assertEqual(length.length, 8)
|
||||
|
||||
dst = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
|
||||
self.assertEqual(dst.start, 24)
|
||||
self.assertEqual(dst.length, 8)
|
||||
|
||||
seq = mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
|
||||
self.assertEqual(seq.start, 32)
|
||||
self.assertEqual(seq.length, 8)
|
||||
|
||||
def test_with_checksum(self):
|
||||
ff = self.get_format_finder_from_protocol_file("with_checksum.proto.xml", clear_participant_addresses=False)
|
||||
known_participant_addresses = ff.known_participant_addresses.copy()
|
||||
ff.known_participant_addresses.clear()
|
||||
ff.run()
|
||||
|
||||
self.assertIn(known_participant_addresses[0].tostring(),
|
||||
list(map(bytes, ff.known_participant_addresses.values())))
|
||||
self.assertIn(known_participant_addresses[1].tostring(),
|
||||
list(map(bytes, ff.known_participant_addresses.values())))
|
||||
|
||||
self.assertEqual(len(ff.message_types), 3)
|
||||
|
||||
def test_with_only_one_address(self):
|
||||
ff = self.get_format_finder_from_protocol_file("only_one_address.proto.xml", clear_participant_addresses=False)
|
||||
known_participant_addresses = ff.known_participant_addresses.copy()
|
||||
ff.known_participant_addresses.clear()
|
||||
|
||||
ff.run()
|
||||
|
||||
self.assertIn(known_participant_addresses[0].tostring(),
|
||||
list(map(bytes, ff.known_participant_addresses.values())))
|
||||
self.assertIn(known_participant_addresses[1].tostring(),
|
||||
list(map(bytes, ff.known_participant_addresses.values())))
|
||||
|
||||
def test_with_four_broken(self):
|
||||
ff, messages = self.get_format_finder_from_protocol_file("four_broken.proto.xml",
|
||||
clear_participant_addresses=False,
|
||||
return_messages=True)
|
||||
|
||||
assert isinstance(ff, FormatFinder)
|
||||
known_participant_addresses = ff.known_participant_addresses.copy()
|
||||
ff.known_participant_addresses.clear()
|
||||
|
||||
ff.run()
|
||||
|
||||
self.__check_addresses(messages, ff, known_participant_addresses)
|
||||
|
||||
for i in range(4, len(messages)):
|
||||
mt = next(mt for mt, indices in ff.existing_message_types.items() if i in indices)
|
||||
self.assertIsNotNone(mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER))
|
||||
|
||||
def test_with_one_address_one_message_type(self):
|
||||
ff, messages = self.get_format_finder_from_protocol_file("one_address_one_mt.proto.xml",
|
||||
clear_participant_addresses=False,
|
||||
return_messages=True)
|
||||
|
||||
self.assertEqual(len(messages), 17)
|
||||
self.assertEqual(len(ff.hexvectors), 17)
|
||||
|
||||
known_participant_addresses = ff.known_participant_addresses.copy()
|
||||
ff.known_participant_addresses.clear()
|
||||
|
||||
ff.run()
|
||||
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
|
||||
self.assertIn(known_participant_addresses[0].tostring(),
|
||||
list(map(bytes, ff.known_participant_addresses.values())))
|
||||
self.assertIn(known_participant_addresses[1].tostring(),
|
||||
list(map(bytes, ff.known_participant_addresses.values())))
|
||||
|
||||
def test_without_preamble_24_messages(self):
|
||||
ff, messages = self.get_format_finder_from_protocol_file("no_preamble24.proto.xml",
|
||||
clear_participant_addresses=False,
|
||||
return_messages=True)
|
||||
|
||||
known_participant_addresses = ff.known_participant_addresses.copy()
|
||||
ff.known_participant_addresses.clear()
|
||||
|
||||
ff.run()
|
||||
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
|
||||
self.assertIn(known_participant_addresses[0].tostring(),
|
||||
list(map(bytes, ff.known_participant_addresses.values())))
|
||||
self.assertIn(known_participant_addresses[1].tostring(),
|
||||
list(map(bytes, ff.known_participant_addresses.values())))
|
||||
|
||||
def test_with_three_syncs_different_preamble_lengths(self):
|
||||
ff, messages = self.get_format_finder_from_protocol_file("three_syncs.proto.xml", return_messages=True)
|
||||
preprocessor = Preprocessor(ff.get_bitvectors_from_messages(messages))
|
||||
sync_words = preprocessor.find_possible_syncs()
|
||||
self.assertIn("0000010000100000", sync_words, msg="Sync 1")
|
||||
self.assertIn("0010001000100010", sync_words, msg="Sync 2")
|
||||
self.assertIn("0110011101100111", sync_words, msg="Sync 3")
|
||||
|
||||
ff.run()
|
||||
|
||||
expected_sync_ends = [32, 24, 40, 24, 32, 24, 40, 24, 32, 24, 40, 24, 32, 24, 40, 24]
|
||||
|
||||
for i, (s1, s2) in enumerate(zip(expected_sync_ends, ff.sync_ends)):
|
||||
self.assertEqual(s1, s2, msg=str(i))
|
||||
|
||||
def test_with_four_participants(self):
|
||||
ff, messages = self.get_format_finder_from_protocol_file("four_participants.proto.xml",
|
||||
clear_participant_addresses=False,
|
||||
return_messages=True)
|
||||
|
||||
known_participant_addresses = ff.known_participant_addresses.copy()
|
||||
ff.known_participant_addresses.clear()
|
||||
|
||||
ff.run()
|
||||
|
||||
self.__check_addresses(messages, ff, known_participant_addresses)
|
||||
self.assertEqual(len(ff.message_types), 3)
|
167
Software/Universal Radio Hacker/tests/awre/test_length_engine.py
Normal file
167
Software/Universal Radio Hacker/tests/awre/test_length_engine.py
Normal file
@ -0,0 +1,167 @@
|
||||
import random
|
||||
|
||||
from tests.awre.AWRETestCase import AWRETestCase
|
||||
from urh.awre.FormatFinder import FormatFinder
|
||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
|
||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
||||
from urh.awre.engines.LengthEngine import LengthEngine
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
from urh.signalprocessing.ProtocoLabel import ProtocolLabel
|
||||
|
||||
|
||||
class TestLengthEngine(AWRETestCase):
|
||||
def test_simple_protocol(self):
|
||||
"""
|
||||
Test a simple protocol with
|
||||
preamble, sync and length field (8 bit) and some random data
|
||||
|
||||
:return:
|
||||
"""
|
||||
mb = MessageTypeBuilder("simple_length_test")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
|
||||
num_messages_by_data_length = {8: 5, 16: 10, 32: 15}
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x9a9d"})
|
||||
random.seed(0)
|
||||
for data_length, num_messages in num_messages_by_data_length.items():
|
||||
for i in range(num_messages):
|
||||
pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(data_length)]))
|
||||
|
||||
#self.save_protocol("simple_length", pg)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
|
||||
length_engine = LengthEngine(ff.bitvectors)
|
||||
highscored_ranges = length_engine.find(n_gram_length=8)
|
||||
self.assertEqual(len(highscored_ranges), 3)
|
||||
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
self.assertGreater(len(ff.message_types[0]), 0)
|
||||
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)
|
||||
self.assertEqual(label.start, 24)
|
||||
self.assertEqual(label.length, 8)
|
||||
|
||||
def test_easy_protocol(self):
|
||||
"""
|
||||
preamble, sync, sequence number, length field (8 bit) and some random data
|
||||
|
||||
:return:
|
||||
"""
|
||||
mb = MessageTypeBuilder("easy_length_test")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 16)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
|
||||
|
||||
num_messages_by_data_length = {32: 10, 64: 15, 16: 5, 24: 7}
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
preambles_by_mt={mb.message_type: "10" * 8},
|
||||
syncs_by_mt={mb.message_type: "0xcafe"})
|
||||
for data_length, num_messages in num_messages_by_data_length.items():
|
||||
for i in range(num_messages):
|
||||
if i % 4 == 0:
|
||||
data = "1" * data_length
|
||||
elif i % 4 == 1:
|
||||
data = "0" * data_length
|
||||
elif i % 4 == 2:
|
||||
data = "10" * (data_length // 2)
|
||||
else:
|
||||
data = "01" * (data_length // 2)
|
||||
|
||||
pg.generate_message(data=data)
|
||||
|
||||
#self.save_protocol("easy_length", pg)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
|
||||
length_engine = LengthEngine(ff.bitvectors)
|
||||
highscored_ranges = length_engine.find(n_gram_length=8)
|
||||
self.assertEqual(len(highscored_ranges), 4)
|
||||
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
self.assertGreater(len(ff.message_types[0]), 0)
|
||||
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)
|
||||
self.assertIsInstance(label, ProtocolLabel)
|
||||
self.assertEqual(label.start, 32)
|
||||
self.assertEqual(label.length, 8)
|
||||
|
||||
def test_medium_protocol(self):
|
||||
"""
|
||||
Protocol with two message types. Length field only present in one of them
|
||||
|
||||
:return:
|
||||
"""
|
||||
mb1 = MessageTypeBuilder("data")
|
||||
mb1.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb1.add_label(FieldType.Function.SYNC, 8)
|
||||
mb1.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb1.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
|
||||
|
||||
mb2 = MessageTypeBuilder("ack")
|
||||
mb2.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb2.add_label(FieldType.Function.SYNC, 8)
|
||||
|
||||
pg = ProtocolGenerator([mb1.message_type, mb2.message_type],
|
||||
syncs_by_mt={mb1.message_type: "11110011",
|
||||
mb2.message_type: "11110011"})
|
||||
num_messages_by_data_length = {8: 5, 16: 10, 32: 5}
|
||||
for data_length, num_messages in num_messages_by_data_length.items():
|
||||
for i in range(num_messages):
|
||||
pg.generate_message(data=pg.decimal_to_bits(10 * i, data_length), message_type=mb1.message_type)
|
||||
pg.generate_message(message_type=mb2.message_type, data="0xaf")
|
||||
|
||||
#self.save_protocol("medium_length", pg)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(len(ff.message_types), 2)
|
||||
length_mt = next(
|
||||
mt for mt in ff.message_types if mt.get_first_label_with_type(FieldType.Function.LENGTH) is not None)
|
||||
length_label = length_mt.get_first_label_with_type(FieldType.Function.LENGTH)
|
||||
|
||||
for i, sync_end in enumerate(ff.sync_ends):
|
||||
self.assertEqual(sync_end, 16, msg=str(i))
|
||||
|
||||
self.assertEqual(16, length_label.start)
|
||||
self.assertEqual(8, length_label.length)
|
||||
|
||||
def test_little_endian_16_bit(self):
|
||||
mb = MessageTypeBuilder("little_endian_16_length_test")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 16)
|
||||
|
||||
num_messages_by_data_length = {256*8: 5, 16: 4, 512: 2}
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x9a9d"},
|
||||
little_endian=True)
|
||||
|
||||
random.seed(0)
|
||||
for data_length, num_messages in num_messages_by_data_length.items():
|
||||
for i in range(num_messages):
|
||||
pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(data_length)]))
|
||||
|
||||
#self.save_protocol("little_endian_16_length_test", pg)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
|
||||
length_engine = LengthEngine(ff.bitvectors)
|
||||
highscored_ranges = length_engine.find(n_gram_length=8)
|
||||
self.assertEqual(len(highscored_ranges), 3)
|
||||
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
self.assertGreater(len(ff.message_types[0]), 0)
|
||||
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)
|
||||
self.assertEqual(label.start, 24)
|
||||
self.assertEqual(label.length, 16)
|
@ -0,0 +1,198 @@
|
||||
import copy
|
||||
import random
|
||||
|
||||
from urh.signalprocessing.MessageType import MessageType
|
||||
|
||||
from urh.awre.FormatFinder import FormatFinder
|
||||
|
||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
|
||||
from tests.awre.AWRETestCase import AWRETestCase
|
||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
|
||||
from urh.signalprocessing.Participant import Participant
|
||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
|
||||
|
||||
|
||||
class TestPartiallyLabeled(AWRETestCase):
|
||||
"""
|
||||
Some tests if there are already information about the message types present
|
||||
|
||||
"""
|
||||
def test_fully_labeled(self):
|
||||
"""
|
||||
For fully labeled protocol, nothing should be done
|
||||
|
||||
:return:
|
||||
"""
|
||||
protocol = self.__prepare_example_protocol()
|
||||
message_types = sorted(copy.deepcopy(protocol.message_types), key=lambda x: x.name)
|
||||
ff = FormatFinder(protocol.messages)
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(len(message_types), len(ff.message_types))
|
||||
|
||||
for mt1, mt2 in zip(message_types, ff.message_types):
|
||||
self.assertTrue(self.__message_types_have_same_labels(mt1, mt2))
|
||||
|
||||
def test_one_message_type_empty(self):
|
||||
"""
|
||||
Empty the "ACK" message type, the labels should be find by FormatFinder
|
||||
|
||||
:return:
|
||||
"""
|
||||
protocol = self.__prepare_example_protocol()
|
||||
n_message_types = len(protocol.message_types)
|
||||
ack_mt = next(mt for mt in protocol.message_types if mt.name == "ack")
|
||||
ack_mt.clear()
|
||||
self.assertEqual(len(ack_mt), 0)
|
||||
|
||||
ff = FormatFinder(protocol.messages)
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(n_message_types, len(ff.message_types))
|
||||
|
||||
self.assertEqual(len(ack_mt), 4, msg=str(ack_mt))
|
||||
|
||||
def test_given_address_information(self):
|
||||
"""
|
||||
Empty both message types and see if addresses are found, when information of participant addresses is given
|
||||
|
||||
:return:
|
||||
"""
|
||||
protocol = self.__prepare_example_protocol()
|
||||
self.clear_message_types(protocol.messages)
|
||||
|
||||
ff = FormatFinder(protocol.messages)
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(2, len(ff.message_types))
|
||||
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE))
|
||||
self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.PREAMBLE))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC))
|
||||
self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.SYNC))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
|
||||
self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.LENGTH))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
|
||||
self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
|
||||
self.assertIsNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
|
||||
self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
|
||||
|
||||
def test_type_part_already_labeled(self):
|
||||
protocol = self.__prepare_simple_example_protocol()
|
||||
self.clear_message_types(protocol.messages)
|
||||
ff = FormatFinder(protocol.messages)
|
||||
|
||||
# overlaps type
|
||||
ff.message_types[0].add_protocol_label_start_length(32, 8)
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(1, len(ff.message_types))
|
||||
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
|
||||
|
||||
def test_length_part_already_labeled(self):
|
||||
protocol = self.__prepare_simple_example_protocol()
|
||||
self.clear_message_types(protocol.messages)
|
||||
ff = FormatFinder(protocol.messages)
|
||||
|
||||
# overlaps length
|
||||
ff.message_types[0].add_protocol_label_start_length(24, 8)
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(1, len(ff.message_types))
|
||||
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC))
|
||||
self.assertIsNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
|
||||
|
||||
def test_address_part_already_labeled(self):
|
||||
protocol = self.__prepare_simple_example_protocol()
|
||||
self.clear_message_types(protocol.messages)
|
||||
ff = FormatFinder(protocol.messages)
|
||||
|
||||
# overlaps dst address
|
||||
ff.message_types[0].add_protocol_label_start_length(40, 16)
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(1, len(ff.message_types))
|
||||
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
|
||||
self.assertIsNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
|
||||
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
|
||||
|
||||
@staticmethod
|
||||
def __message_types_have_same_labels(mt1: MessageType, mt2: MessageType):
|
||||
if len(mt1) != len(mt2):
|
||||
return False
|
||||
|
||||
for i, lbl in enumerate(mt1):
|
||||
if lbl != mt2[i]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def __prepare_example_protocol(self) -> ProtocolAnalyzer:
|
||||
alice = Participant("Alice", "A", address_hex="1234")
|
||||
bob = Participant("Bob", "B", address_hex="cafe")
|
||||
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.TYPE, 8)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb_ack = MessageTypeBuilder("ack")
|
||||
mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb_ack.add_label(FieldType.Function.SYNC, 16)
|
||||
mb_ack.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
|
||||
num_messages = 50
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"},
|
||||
participants=[alice, bob])
|
||||
|
||||
random.seed(0)
|
||||
for i in range(num_messages):
|
||||
if i % 2 == 0:
|
||||
source, destination = alice, bob
|
||||
data_length = 8
|
||||
else:
|
||||
source, destination = bob, alice
|
||||
data_length = 16
|
||||
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
|
||||
source=source, destination=destination)
|
||||
pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination)
|
||||
|
||||
#self.save_protocol("labeled_protocol", pg)
|
||||
|
||||
return pg.protocol
|
||||
|
||||
def __prepare_simple_example_protocol(self):
|
||||
random.seed(0)
|
||||
alice = Participant("Alice", "A", address_hex="1234")
|
||||
bob = Participant("Bob", "B", address_hex="cafe")
|
||||
|
||||
mb = MessageTypeBuilder("data")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.TYPE, 8)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x6768"},
|
||||
participants=[alice, bob])
|
||||
|
||||
for i in range(10):
|
||||
pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(16)]), source=alice, destination=bob)
|
||||
pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(8)]), source=bob, destination=alice)
|
||||
|
||||
return pg.protocol
|
@ -0,0 +1,182 @@
|
||||
from tests.awre.AWRETestCase import AWRETestCase
|
||||
from urh.awre.CommonRange import CommonRange
|
||||
from urh.awre.FormatFinder import FormatFinder
|
||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
|
||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
||||
from urh.awre.engines.SequenceNumberEngine import SequenceNumberEngine
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
from urh.signalprocessing.Participant import Participant
|
||||
|
||||
|
||||
class TestSequenceNumberEngine(AWRETestCase):
|
||||
def test_simple_protocol(self):
|
||||
"""
|
||||
Test a simple protocol with
|
||||
preamble, sync and increasing sequence number (8 bit) and some constant data
|
||||
|
||||
:return:
|
||||
"""
|
||||
mb = MessageTypeBuilder("simple_seq_test")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
|
||||
|
||||
num_messages = 20
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x9a9d"})
|
||||
|
||||
for i in range(num_messages):
|
||||
pg.generate_message(data="0xcafe")
|
||||
|
||||
#self.save_protocol("simple_sequence_number", pg)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
|
||||
seq_engine = SequenceNumberEngine(ff.bitvectors, n_gram_length=8)
|
||||
highscored_ranges = seq_engine.find()
|
||||
self.assertEqual(len(highscored_ranges), 1)
|
||||
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
self.assertGreater(len(ff.message_types[0]), 0)
|
||||
self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1)
|
||||
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
|
||||
self.assertEqual(label.start, 24)
|
||||
self.assertEqual(label.length, 8)
|
||||
|
||||
def test_16bit_seq_nr(self):
|
||||
mb = MessageTypeBuilder("16bit_seq_test")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
|
||||
|
||||
num_messages = 10
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x9a9d"}, sequence_number_increment=64)
|
||||
|
||||
for i in range(num_messages):
|
||||
pg.generate_message(data="0xcafe")
|
||||
|
||||
#self.save_protocol("16bit_seq", pg)
|
||||
|
||||
bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages, sync_ends=[24]*num_messages)
|
||||
seq_engine = SequenceNumberEngine(bitvectors, n_gram_length=8)
|
||||
highscored_ranges = seq_engine.find()
|
||||
self.assertEqual(len(highscored_ranges), 1)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
ff.perform_iteration()
|
||||
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
self.assertGreater(len(ff.message_types[0]), 0)
|
||||
self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1)
|
||||
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
|
||||
self.assertEqual(label.start, 24)
|
||||
self.assertEqual(label.length, 16)
|
||||
|
||||
def test_16bit_seq_nr_with_zeros_in_first_part(self):
|
||||
mb = MessageTypeBuilder("16bit_seq_first_byte_zero_test")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
|
||||
|
||||
num_messages = 10
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x9a9d"}, sequence_number_increment=1)
|
||||
|
||||
for i in range(num_messages):
|
||||
pg.generate_message(data="0xcafe" + "abc" * i)
|
||||
|
||||
#self.save_protocol("16bit_seq_first_byte_zero_test", pg)
|
||||
|
||||
bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages, sync_ends=[24]*num_messages)
|
||||
seq_engine = SequenceNumberEngine(bitvectors, n_gram_length=8)
|
||||
highscored_ranges = seq_engine.find()
|
||||
self.assertEqual(len(highscored_ranges), 1)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
ff.perform_iteration()
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
self.assertGreater(len(ff.message_types[0]), 0)
|
||||
self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1)
|
||||
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
|
||||
|
||||
# Not consider constants as part of SEQ Nr!
|
||||
self.assertEqual(label.start, 40)
|
||||
self.assertEqual(label.length, 8)
|
||||
|
||||
def test_no_sequence_number(self):
|
||||
"""
|
||||
Ensure no sequence number is labeled, when it cannot be found
|
||||
|
||||
:return:
|
||||
"""
|
||||
alice = Participant("Alice", address_hex="dead")
|
||||
bob = Participant("Bob", address_hex="beef")
|
||||
|
||||
mb = MessageTypeBuilder("protocol_with_one_message_type")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.LENGTH, 8)
|
||||
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
||||
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
||||
|
||||
num_messages = 3
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x1337"},
|
||||
participants=[alice, bob])
|
||||
|
||||
for i in range(num_messages):
|
||||
if i % 2 == 0:
|
||||
source, destination = alice, bob
|
||||
else:
|
||||
source, destination = bob, alice
|
||||
pg.generate_message(data="", source=source, destination=destination)
|
||||
|
||||
#self.save_protocol("protocol_1", pg)
|
||||
|
||||
# Delete message type information -> no prior knowledge
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
ff.known_participant_addresses.clear()
|
||||
ff.perform_iteration()
|
||||
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
|
||||
self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 0)
|
||||
|
||||
def test_sequence_number_little_endian_16_bit(self):
|
||||
mb = MessageTypeBuilder("16bit_seq_test")
|
||||
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
||||
mb.add_label(FieldType.Function.SYNC, 16)
|
||||
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
|
||||
|
||||
num_messages = 8
|
||||
|
||||
pg = ProtocolGenerator([mb.message_type],
|
||||
syncs_by_mt={mb.message_type: "0x9a9d"},
|
||||
little_endian=True, sequence_number_increment=64)
|
||||
|
||||
for i in range(num_messages):
|
||||
pg.generate_message(data="0xcafe")
|
||||
|
||||
#self.save_protocol("16bit_litte_endian_seq", pg)
|
||||
|
||||
self.clear_message_types(pg.protocol.messages)
|
||||
ff = FormatFinder(pg.protocol.messages)
|
||||
ff.perform_iteration()
|
||||
|
||||
self.assertEqual(len(ff.message_types), 1)
|
||||
self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1)
|
||||
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
|
||||
self.assertEqual(label.start, 24)
|
||||
self.assertEqual(label.length, 16)
|
Reference in New Issue
Block a user