This commit is contained in:
RocketGod
2022-09-22 10:41:47 -07:00
parent e51463a5d1
commit 8347d2f50e
565 changed files with 165005 additions and 0 deletions

View File

@ -0,0 +1,65 @@
import os
import tempfile
import unittest
import numpy
from urh.awre.FormatFinder import FormatFinder
from tests.utils_testing import get_path_for_data_file
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
from urh.signalprocessing.MessageType import MessageType
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.signalprocessing.FieldType import FieldType
class AWRETestCase(unittest.TestCase):
def setUp(self):
numpy.set_printoptions(linewidth=80)
self.field_types = self.__init_field_types()
def get_format_finder_from_protocol_file(self, filename: str, clear_participant_addresses=True, return_messages=False):
proto_file = get_path_for_data_file(filename)
protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
protocol.from_xml_file(filename=proto_file, read_bits=True)
self.clear_message_types(protocol.messages)
ff = FormatFinder(protocol.messages)
if clear_participant_addresses:
ff.known_participant_addresses.clear()
if return_messages:
return ff, protocol.messages
else:
return ff
@staticmethod
def __init_field_types():
result = []
for field_type_function in FieldType.Function:
result.append(FieldType(field_type_function.value, field_type_function))
return result
@staticmethod
def clear_message_types(messages: list):
mt = MessageType("empty")
for msg in messages:
msg.message_type = mt
@staticmethod
def save_protocol(name, protocol_generator, silent=False):
filename = os.path.join(tempfile.gettempdir(), name + ".proto")
if isinstance(protocol_generator, ProtocolGenerator):
protocol_generator.to_file(filename)
elif isinstance(protocol_generator, ProtocolAnalyzer):
participants = list(set(msg.participant for msg in protocol_generator.messages))
protocol_generator.to_xml_file(filename, [], participants=participants, write_bits=True)
info = "Protocol written to " + filename
if not silent:
print()
print("-" * len(info))
print(info)
print("-" * len(info))

View File

@ -0,0 +1,791 @@
import array
import multiprocessing
import os
import random
import time
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
from tests.awre.AWRETestCase import AWRETestCase
from tests.utils_testing import get_path_for_data_file
from urh.awre.FormatFinder import FormatFinder
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
from urh.awre.Preprocessor import Preprocessor
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.awre.engines.Engine import Engine
from urh.signalprocessing.FieldType import FieldType
from urh.signalprocessing.Message import Message
from urh.signalprocessing.MessageType import MessageType
from urh.signalprocessing.Participant import Participant
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
from urh.util.GenericCRC import GenericCRC
def run_for_num_broken(protocol_nr, num_broken: list, num_messages: int, num_runs: int) -> list:
random.seed(0)
np.random.seed(0)
result = []
for broken in num_broken:
tmp_accuracies = np.empty(num_runs, dtype=np.float64)
tmp_accuracies_without_broken = np.empty(num_runs, dtype=np.float64)
for i in range(num_runs):
protocol, expected_labels = AWRExperiments.get_protocol(protocol_nr,
num_messages=num_messages,
num_broken_messages=broken,
silent=True)
AWRExperiments.run_format_finder_for_protocol(protocol)
accuracy = AWRExperiments.calculate_accuracy(protocol.messages, expected_labels)
accuracy_without_broken = AWRExperiments.calculate_accuracy(protocol.messages, expected_labels, broken)
tmp_accuracies[i] = accuracy
tmp_accuracies_without_broken[i] = accuracy_without_broken
avg_accuracy = np.mean(tmp_accuracies)
avg_accuracy_without_broken = np.mean(tmp_accuracies_without_broken)
result.append((avg_accuracy, avg_accuracy_without_broken))
print("Protocol {} with {} broken: {:>3}% {:>3}%".format(protocol_nr, broken, int(avg_accuracy),
int(avg_accuracy_without_broken)))
return result
class AWRExperiments(AWRETestCase):
@staticmethod
def _prepare_protocol_1() -> ProtocolGenerator:
alice = Participant("Alice", address_hex="dead")
bob = Participant("Bob", address_hex="beef")
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x1337"},
participants=[alice, bob])
return pg
@staticmethod
def _prepare_protocol_2() -> ProtocolGenerator:
alice = Participant("Alice", address_hex="dead01")
bob = Participant("Bob", address_hex="beef24")
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.PREAMBLE, 72)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 24)
mb.add_label(FieldType.Function.DST_ADDRESS, 24)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x1337"},
preambles_by_mt={mb.message_type: "10" * 36},
sequence_number_increment=32,
participants=[alice, bob])
return pg
@staticmethod
def _prepare_protocol_3() -> ProtocolGenerator:
alice = Participant("Alice", address_hex="1337")
bob = Participant("Bob", address_hex="beef")
checksum = GenericCRC.from_standard_checksum("CRC8 CCITT")
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.PREAMBLE, 16)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
mb.add_label(FieldType.Function.DATA, 10 * 8)
mb.add_checksum_label(8, checksum)
mb_ack = MessageTypeBuilder("ack")
mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
mb_ack.add_label(FieldType.Function.SYNC, 16)
mb_ack.add_label(FieldType.Function.LENGTH, 8)
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
mb_ack.add_checksum_label(8, checksum)
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
participants=[alice, bob])
return pg
@staticmethod
def _prepare_protocol_4() -> ProtocolGenerator:
alice = Participant("Alice", address_hex="1337")
bob = Participant("Bob", address_hex="beef")
checksum = GenericCRC.from_standard_checksum("CRC16 CCITT")
mb = MessageTypeBuilder("data1")
mb.add_label(FieldType.Function.PREAMBLE, 16)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
mb.add_label(FieldType.Function.DATA, 8 * 8)
mb.add_checksum_label(16, checksum)
mb2 = MessageTypeBuilder("data2")
mb2.add_label(FieldType.Function.PREAMBLE, 16)
mb2.add_label(FieldType.Function.SYNC, 16)
mb2.add_label(FieldType.Function.LENGTH, 8)
mb2.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb2.add_label(FieldType.Function.DST_ADDRESS, 16)
mb2.add_label(FieldType.Function.DATA, 64 * 8)
mb2.add_checksum_label(16, checksum)
mb_ack = MessageTypeBuilder("ack")
mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
mb_ack.add_label(FieldType.Function.SYNC, 16)
mb_ack.add_label(FieldType.Function.LENGTH, 8)
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
mb_ack.add_checksum_label(16, checksum)
mt1, mt2, mt3 = mb.message_type, mb2.message_type, mb_ack.message_type
preamble = "10001000" * 2
pg = ProtocolGenerator([mt1, mt2, mt3],
syncs_by_mt={mt1: "0x9a7d", mt2: "0x9a7d", mt3: "0x9a7d"},
preambles_by_mt={mt1: preamble, mt2: preamble, mt3: preamble},
participants=[alice, bob])
return pg
@staticmethod
def _prepare_protocol_5() -> ProtocolGenerator:
alice = Participant("Alice", address_hex="1337")
bob = Participant("Bob", address_hex="beef")
carl = Participant("Carl", address_hex="cafe")
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.PREAMBLE, 16)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
mb_ack = MessageTypeBuilder("ack")
mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
mb_ack.add_label(FieldType.Function.SYNC, 16)
mb_ack.add_label(FieldType.Function.LENGTH, 8)
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
participants=[alice, bob, carl])
return pg
@staticmethod
def _prepare_protocol_6() -> ProtocolGenerator:
alice = Participant("Alice", address_hex="24")
broadcast = Participant("Bob", address_hex="ff")
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 8)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x8e88"},
preambles_by_mt={mb.message_type: "10" * 8},
participants=[alice, broadcast])
return pg
@staticmethod
def _prepare_protocol_7() -> ProtocolGenerator:
alice = Participant("Alice", address_hex="313370")
bob = Participant("Bob", address_hex="031337")
charly = Participant("Charly", address_hex="110000")
daniel = Participant("Daniel", address_hex="001100")
# broadcast = Participant("Broadcast", address_hex="ff") #TODO: Sometimes messages to broadcast
checksum = GenericCRC.from_standard_checksum("CRC16 CC1101")
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.PREAMBLE, 16)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.DST_ADDRESS, 24)
mb.add_label(FieldType.Function.SRC_ADDRESS, 24)
mb.add_label(FieldType.Function.DATA, 8 * 8)
mb.add_checksum_label(16, checksum)
mb_ack = MessageTypeBuilder("ack")
mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
mb_ack.add_label(FieldType.Function.SYNC, 16)
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 24)
mb_ack.add_checksum_label(16, checksum)
mb_kex = MessageTypeBuilder("kex")
mb_kex.add_label(FieldType.Function.PREAMBLE, 24)
mb_kex.add_label(FieldType.Function.SYNC, 16)
mb_kex.add_label(FieldType.Function.DST_ADDRESS, 24)
mb_kex.add_label(FieldType.Function.SRC_ADDRESS, 24)
mb_kex.add_label(FieldType.Function.DATA, 64 * 8)
mb_kex.add_checksum_label(16, checksum)
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type, mb_kex.message_type],
syncs_by_mt={mb.message_type: "0x0420", mb_ack.message_type: "0x2222",
mb_kex.message_type: "0x6767"},
preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 4,
mb_kex.message_type: "10" * 12},
participants=[alice, bob, charly, daniel])
return pg
@staticmethod
def _prepare_protocol_8() -> ProtocolGenerator:
alice = Participant("Alice")
mb = MessageTypeBuilder("data1")
mb.add_label(FieldType.Function.PREAMBLE, 4)
mb.add_label(FieldType.Function.SYNC, 4)
mb.add_label(FieldType.Function.LENGTH, 16)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
mb.add_label(FieldType.Function.DATA, 8 * 542)
mb2 = MessageTypeBuilder("data2")
mb2.add_label(FieldType.Function.PREAMBLE, 4)
mb2.add_label(FieldType.Function.SYNC, 4)
mb2.add_label(FieldType.Function.LENGTH, 16)
mb2.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
mb2.add_label(FieldType.Function.DATA, 8 * 260)
pg = ProtocolGenerator([mb.message_type, mb2.message_type],
syncs_by_mt={mb.message_type: "0x9", mb2.message_type: "0x9"},
preambles_by_mt={mb.message_type: "10" * 2, mb2.message_type: "10" * 2},
sequence_number_increment=32,
participants=[alice],
little_endian=True)
return pg
def test_export_to_latex(self):
filename = os.path.expanduser("~/GIT/publications/awre/USENIX/protocols.tex")
if os.path.isfile(filename):
os.remove(filename)
for i in range(1, 9):
pg = getattr(self, "_prepare_protocol_" + str(i))()
pg.export_to_latex(filename, i)
@classmethod
def get_protocol(cls, protocol_number: int, num_messages, num_broken_messages=0, silent=False):
if protocol_number == 1:
pg = cls._prepare_protocol_1()
elif protocol_number == 2:
pg = cls._prepare_protocol_2()
elif protocol_number == 3:
pg = cls._prepare_protocol_3()
elif protocol_number == 4:
pg = cls._prepare_protocol_4()
elif protocol_number == 5:
pg = cls._prepare_protocol_5()
elif protocol_number == 6:
pg = cls._prepare_protocol_6()
elif protocol_number == 7:
pg = cls._prepare_protocol_7()
elif protocol_number == 8:
pg = cls._prepare_protocol_8()
else:
raise ValueError("Unknown protocol number")
messages_types_with_data_field = [mt for mt in pg.protocol.message_types
if mt.get_first_label_with_type(FieldType.Function.DATA)]
i = -1
while len(pg.protocol.messages) < num_messages:
i += 1
source = pg.participants[i % len(pg.participants)]
destination = pg.participants[(i + 1) % len(pg.participants)]
if i % 2 == 0:
data_bytes = 8
else:
# data_bytes = 16
data_bytes = 64
if len(messages_types_with_data_field) == 0:
# set data automatically
data = "".join(random.choice(["0", "1"]) for _ in range(data_bytes * 8))
pg.generate_message(data=data, source=source, destination=destination)
else:
# search for message type with right data length
mt = messages_types_with_data_field[i % len(messages_types_with_data_field)]
data_length = mt.get_first_label_with_type(FieldType.Function.DATA).length
data = "".join(random.choice(["0", "1"]) for _ in range(data_length))
pg.generate_message(message_type=mt, data=data, source=source, destination=destination)
ack_message_type = next((mt for mt in pg.protocol.message_types if "ack" in mt.name), None)
if ack_message_type:
pg.generate_message(message_type=ack_message_type, data="", source=destination, destination=source)
for i in range(num_broken_messages):
msg = pg.protocol.messages[i]
pos = random.randint(0, len(msg.plain_bits) // 2)
msg.plain_bits[pos:] = array.array("B",
[random.randint(0, 1) for _ in range(len(msg.plain_bits) - pos)])
if num_broken_messages == 0:
cls.save_protocol("protocol{}_{}_messages".format(protocol_number, num_messages), pg, silent=silent)
else:
cls.save_protocol("protocol{}_{}_broken".format(protocol_number, num_broken_messages), pg, silent=silent)
expected_message_types = [msg.message_type for msg in pg.protocol.messages]
# Delete message type information -> no prior knowledge
cls.clear_message_types(pg.protocol.messages)
# Delete data labels if present
for mt in expected_message_types:
data_lbl = mt.get_first_label_with_type(FieldType.Function.DATA)
if data_lbl:
mt.remove(data_lbl)
return pg.protocol, expected_message_types
@staticmethod
def calculate_accuracy(messages, expected_labels, num_broken_messages=0):
"""
Calculate the accuracy of labels compared to expected labels
Accuracy is 100% when labels == expected labels
Accuracy drops by 1 / len(expected_labels) for every expected label not present in labels
:type messages: list of Message
:type expected_labels: list of MessageType
:return:
"""
accuracy = sum(len(set(expected_labels[i]) & set(messages[i].message_type)) / len(expected_labels[i])
for i in range(num_broken_messages, len(messages)))
try:
accuracy /= (len(messages) - num_broken_messages)
except ZeroDivisionError:
accuracy = 0
return accuracy * 100
def test_against_num_messages(self):
num_messages = list(range(1, 24, 1))
accuracies = defaultdict(list)
protocols = [1, 2, 3, 4, 5, 6, 7, 8]
random.seed(0)
np.random.seed(0)
for protocol_nr in protocols:
for n in num_messages:
protocol, expected_labels = self.get_protocol(protocol_nr, num_messages=n)
self.run_format_finder_for_protocol(protocol)
accuracy = self.calculate_accuracy(protocol.messages, expected_labels)
accuracies["protocol {}".format(protocol_nr)].append(accuracy)
self.__plot(num_messages, accuracies, xlabel="Number of messages", ylabel="Accuracy in %", grid=True)
self.__export_to_csv("/tmp/accuray-vs-messages", num_messages, accuracies)
def test_against_error(self):
Engine._DEBUG_ = False
Preprocessor._DEBUG_ = False
num_runs = 100
num_messages = 30
num_broken_messages = list(range(0, num_messages + 1))
accuracies = defaultdict(list)
accuracies_without_broken = defaultdict(list)
protocols = [1, 2, 3, 4, 5, 6, 7, 8]
random.seed(0)
np.random.seed(0)
with multiprocessing.Pool() as p:
result = p.starmap(run_for_num_broken,
[(i, num_broken_messages, num_messages, num_runs) for i in protocols])
for i, acc in enumerate(result):
accuracies["protocol {}".format(i + 1)] = [a[0] for a in acc]
accuracies_without_broken["protocol {}".format(i + 1)] = [a[1] for a in acc]
self.__plot(100 * np.array(num_broken_messages) / num_messages, accuracies,
title="Overall Accuracy vs percentage of broken messages",
xlabel="Broken messages in %",
ylabel="Accuracy in %", grid=True)
self.__plot(100 * np.array(num_broken_messages) / num_messages, accuracies_without_broken,
title=" Accuracy of unbroken vs percentage of broken messages",
xlabel="Broken messages in %",
ylabel="Accuracy in %", grid=True)
self.__export_to_csv("/tmp/accuray-vs-error", num_broken_messages, accuracies, relative=num_messages)
self.__export_to_csv("/tmp/accuray-vs-error-without-broken", num_broken_messages, accuracies_without_broken,
relative=num_messages)
def test_performance(self):
Engine._DEBUG_ = False
Preprocessor._DEBUG_ = False
num_messages = list(range(200, 205, 5))
protocols = [1]
random.seed(0)
np.random.seed(0)
performances = defaultdict(list)
for protocol_nr in protocols:
print("Running for protocol", protocol_nr)
for messages in num_messages:
protocol, _ = self.get_protocol(protocol_nr, messages, silent=True)
t = time.time()
self.run_format_finder_for_protocol(protocol)
performances["protocol {}".format(protocol_nr)].append(time.time() - t)
# self.__plot(num_messages, performances, xlabel="Number of messages", ylabel="Time in seconds", grid=True)
def test_performance_real_protocols(self):
Engine._DEBUG_ = False
Preprocessor._DEBUG_ = False
num_runs = 100
num_messages = list(range(8, 512, 4))
protocol_names = ["enocean", "homematic", "rwe"]
random.seed(0)
np.random.seed(0)
performances = defaultdict(list)
for protocol_name in protocol_names:
for messages in num_messages:
if protocol_name == "homematic":
protocol = self.generate_homematic(messages, save_protocol=False)
elif protocol_name == "enocean":
protocol = self.generate_enocean(messages, save_protocol=False)
elif protocol_name == "rwe":
protocol = self.generate_rwe(messages, save_protocol=False)
else:
raise ValueError("Unknown protocol name")
tmp_performances = np.empty(num_runs, dtype=np.float64)
for i in range(num_runs):
print("\r{0} with {1:02d} messages ({2}/{3} runs)".format(protocol_name, messages, i + 1, num_runs),
flush=True, end="")
t = time.time()
self.run_format_finder_for_protocol(protocol)
tmp_performances[i] = time.time() - t
self.clear_message_types(protocol.messages)
mean_performance = tmp_performances.mean()
print(" {:.2f}s".format(mean_performance))
performances["{}".format(protocol_name)].append(mean_performance)
self.__plot(num_messages, performances, xlabel="Number of messages", ylabel="Time in seconds", grid=True)
self.__export_to_csv("/tmp/performance.csv", num_messages, performances)
@staticmethod
def __export_to_csv(filename: str, x: list, y: dict, relative=None):
if not filename.endswith(".csv"):
filename += ".csv"
with open(filename, "w") as f:
f.write("N,")
if relative is not None:
f.write("NRel,")
for y_cap in sorted(y):
f.write(y_cap + ",")
f.write("\n")
for i, x_val in enumerate(x):
f.write("{},".format(x_val))
if relative is not None:
f.write("{},".format(100 * x_val / relative))
for y_cap in sorted(y):
f.write("{},".format(y[y_cap][i]))
f.write("\n")
@staticmethod
def __plot(x: list, y: dict, xlabel: str, ylabel: str, grid=False, title=None):
plt.xlabel(xlabel)
plt.ylabel(ylabel)
for y_cap, y_values in sorted(y.items()):
plt.plot(x, y_values, label=y_cap)
if grid:
plt.grid()
if title:
plt.title(title)
plt.legend()
plt.show()
@staticmethod
def run_format_finder_for_protocol(protocol: ProtocolAnalyzer):
ff = FormatFinder(protocol.messages)
ff.known_participant_addresses.clear()
ff.run()
for msg_type, indices in ff.existing_message_types.items():
for i in indices:
protocol.messages[i].message_type = msg_type
@classmethod
def generate_homematic(cls, num_messages: int, save_protocol=True):
mb_m_frame = MessageTypeBuilder("mframe")
mb_c_frame = MessageTypeBuilder("cframe")
mb_r_frame = MessageTypeBuilder("rframe")
mb_a_frame = MessageTypeBuilder("aframe")
participants = [Participant("CCU", address_hex="3927cc"), Participant("Switch", address_hex="3101cc")]
checksum = GenericCRC.from_standard_checksum("CRC16 CC1101")
for mb_builder in [mb_m_frame, mb_c_frame, mb_r_frame, mb_a_frame]:
mb_builder.add_label(FieldType.Function.PREAMBLE, 32)
mb_builder.add_label(FieldType.Function.SYNC, 32)
mb_builder.add_label(FieldType.Function.LENGTH, 8)
mb_builder.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
mb_builder.add_label(FieldType.Function.TYPE, 16)
mb_builder.add_label(FieldType.Function.SRC_ADDRESS, 24)
mb_builder.add_label(FieldType.Function.DST_ADDRESS, 24)
if mb_builder.name == "mframe":
mb_builder.add_label(FieldType.Function.DATA, 16, name="command")
elif mb_builder.name == "cframe":
mb_builder.add_label(FieldType.Function.DATA, 16 * 4, name="command+challenge+magic")
elif mb_builder.name == "rframe":
mb_builder.add_label(FieldType.Function.DATA, 32 * 4, name="cipher")
elif mb_builder.name == "aframe":
mb_builder.add_label(FieldType.Function.DATA, 10 * 4, name="command + auth")
mb_builder.add_checksum_label(16, checksum)
message_types = [mb_m_frame.message_type, mb_c_frame.message_type, mb_r_frame.message_type,
mb_a_frame.message_type]
preamble = "0xaaaaaaaa"
sync = "0xe9cae9ca"
initial_sequence_number = 36
pg = ProtocolGenerator(message_types, participants,
preambles_by_mt={mt: preamble for mt in message_types},
syncs_by_mt={mt: sync for mt in message_types},
sequence_numbers={mt: initial_sequence_number for mt in message_types},
message_type_codes={mb_m_frame.message_type: 42560,
mb_c_frame.message_type: 40962,
mb_r_frame.message_type: 40963,
mb_a_frame.message_type: 32770})
for i in range(num_messages):
mt = pg.message_types[i % 4]
data_length = mt.get_first_label_with_type(FieldType.Function.DATA).length
data = "".join(random.choice(["0", "1"]) for _ in range(data_length))
pg.generate_message(mt, data, source=pg.participants[i % 2], destination=pg.participants[(i + 1) % 2])
if save_protocol:
cls.save_protocol("homematic", pg)
cls.clear_message_types(pg.messages)
return pg.protocol
@classmethod
def generate_enocean(cls, num_messages: int, save_protocol=True):
filename = get_path_for_data_file("enocean_bits.txt")
enocean_bits = []
with open(filename, "r") as f:
for line in map(str.strip, f):
enocean_bits.append(line)
protocol = ProtocolAnalyzer(None)
message_type = MessageType("empty")
for i in range(num_messages):
msg = Message.from_plain_bits_str(enocean_bits[i % len(enocean_bits)])
msg.message_type = message_type
protocol.messages.append(msg)
if save_protocol:
cls.save_protocol("enocean", protocol)
return protocol
@classmethod
def generate_rwe(cls, num_messages: int, save_protocol=True):
proto_file = get_path_for_data_file("rwe.proto.xml")
protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
protocol.from_xml_file(filename=proto_file, read_bits=True)
messages = protocol.messages
result = ProtocolAnalyzer(None)
message_type = MessageType("empty")
for i in range(num_messages):
msg = messages[i % len(messages)] # type: Message
msg.message_type = message_type
result.messages.append(msg)
if save_protocol:
cls.save_protocol("rwe", result)
return result
def test_export_latex_table(self):
def bold_latex(s):
return r"\textbf{" + str(s) + r"}"
comments = {
1: "common protocol",
2: "unusual field sizes",
3: "contains ack and CRC8 CCITT",
4: "contains ack and CRC16 CCITT",
5: "three participants with ack frame",
6: "short address",
7: "four participants, varying preamble size, varying sync words",
8: "nibble fields + LE"
}
bold = {i: defaultdict(bool) for i in range(1, 9)}
bold[2][FieldType.Function.PREAMBLE] = True
bold[2][FieldType.Function.SRC_ADDRESS] = True
bold[2][FieldType.Function.DST_ADDRESS] = True
bold[3][FieldType.Function.CHECKSUM] = True
bold[4][FieldType.Function.CHECKSUM] = True
bold[6][FieldType.Function.SRC_ADDRESS] = True
bold[7][FieldType.Function.PREAMBLE] = True
bold[7][FieldType.Function.SYNC] = True
bold[7][FieldType.Function.SRC_ADDRESS] = True
bold[7][FieldType.Function.DST_ADDRESS] = True
bold[8][FieldType.Function.PREAMBLE] = True
bold[8][FieldType.Function.SYNC] = True
filename = os.path.expanduser("~/GIT/publications/awre/USENIX/protocol_table.tex")
rowcolors = [r"\rowcolor{black!10}", r"\rowcolor{black!20}"]
with open(filename, "w") as f:
f.write(r"\begin{table*}[!h]" + "\n")
f.write(
"\t" + r"\caption{Properties of tested protocols whereby $\times$ means field is not present and $N_P$ is the number of participants.}" + "\n")
f.write("\t" + r"\label{tab:protocols}" + "\n")
f.write("\t" + r"\centering" + "\n")
f.write("\t" + r"\begin{tabularx}{\linewidth}{cp{2.5cm}llcccccccc}" + "\n")
f.write("\t\t" + r"\hline" + "\n")
f.write("\t\t" + r"\rowcolor{black!90}" + "\n")
f.write("\t\t" + r"\textcolor{white}{\textbf{\#}} & "
r"\textcolor{white}{\textbf{Comment}} & "
r"\textcolor{white}{$\mathbf{ N_P }$} & "
r"\textcolor{white}{\textbf{Message}} & "
r"\textcolor{white}{\textbf{Even/odd}} & "
r"\multicolumn{7}{c}{\textcolor{white}{\textbf{Size of field in bit (BE=Big Endian, LE=Little Endian)}}}\\"
"\n\t\t"
r"\rowcolor{black!90}"
"\n\t\t"
r"& & & \textcolor{white}{\textbf{Type}} & \textcolor{white}{\textbf{message data}} &"
r"\textcolor{white}{Preamble} & "
r"\textcolor{white}{Sync} & "
r"\textcolor{white}{Length} & "
r"\textcolor{white}{SRC} & "
r"\textcolor{white}{DST} & "
r"\textcolor{white}{SEQ Nr} & "
r"\textcolor{white}{CRC} \\" + "\n")
f.write("\t\t" + r"\hline" + "\n")
rowcolor_index = 0
for i in range(1, 9):
pg = getattr(self, "_prepare_protocol_" + str(i))()
assert isinstance(pg, ProtocolGenerator)
try:
data1 = next(mt for mt in pg.message_types if mt.name == "data1")
data2 = next(mt for mt in pg.message_types if mt.name == "data2")
data1_len = data1.get_first_label_with_type(FieldType.Function.DATA).length // 8
data2_len = data2.get_first_label_with_type(FieldType.Function.DATA).length // 8
except StopIteration:
data1_len, data2_len = 8, 64
rowcolor = rowcolors[rowcolor_index % len(rowcolors)]
rowcount = 0
for j, mt in enumerate(pg.message_types):
if mt.name == "data2":
continue
rowcount += 1
if j == 0:
protocol_nr, participants = str(i), len(pg.participants)
if participants > 2:
participants = bold_latex(participants)
else:
protocol_nr, participants = " ", " "
f.write("\t\t" + rowcolor + "\n")
if len(pg.message_types) == 1 or (
mt.name == "data1" and "ack" not in {m.name for m in pg.message_types}):
f.write("\t\t{} & {} & {} & {} &".format(protocol_nr, comments[i], participants,
mt.name.replace("1", "")))
elif j == len(pg.message_types) - 1:
f.write(
"\t\t{} & \\multirow{{{}}}{{\\linewidth}}{{{}}} & {} & {} &".format(protocol_nr, -rowcount,
comments[i],
participants,
mt.name.replace("1",
"")))
else:
f.write("\t\t{} & & {} & {} &".format(protocol_nr, participants, mt.name.replace("1", "")))
data_lbl = mt.get_first_label_with_type(FieldType.Function.DATA)
if mt.name == "data1" or mt.name == "data2":
f.write("{}/{} byte &".format(data1_len, data2_len))
elif mt.name == "data" and data_lbl is None:
f.write("{}/{} byte &".format(data1_len, data2_len))
elif data_lbl is not None:
f.write("{0}/{0} byte & ".format(data_lbl.length // 8))
else:
f.write(r"$ \times $ & ")
for t in (FieldType.Function.PREAMBLE, FieldType.Function.SYNC, FieldType.Function.LENGTH,
FieldType.Function.SRC_ADDRESS, FieldType.Function.DST_ADDRESS,
FieldType.Function.SEQUENCE_NUMBER,
FieldType.Function.CHECKSUM):
lbl = mt.get_first_label_with_type(t)
if lbl is not None:
if bold[i][lbl.field_type.function]:
f.write(bold_latex(lbl.length))
else:
f.write(str(lbl.length))
if lbl.length > 8 and t in (FieldType.Function.LENGTH, FieldType.Function.SEQUENCE_NUMBER):
f.write(" ({})".format(bold_latex("LE") if pg.little_endian else "BE"))
else:
f.write(r"$ \times $")
if t != FieldType.Function.CHECKSUM:
f.write(" & ")
else:
f.write(r"\\" + "\n")
rowcolor_index += 1
f.write("\t" + r"\end{tabularx}" + "\n")
f.write(r"\end{table*}" + "\n")

View File

@ -0,0 +1,179 @@
import random
from collections import defaultdict
import matplotlib.pyplot as plt
from tests.awre.AWRETestCase import AWRETestCase
from urh.awre.FormatFinder import FormatFinder
from urh.awre.Histogram import Histogram
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.signalprocessing.FieldType import FieldType
from urh.signalprocessing.Participant import Participant
SHOW_PLOTS = True
class TestAWREHistograms(AWRETestCase):
def test_very_simple_protocol(self):
"""
Test a very simple protocol consisting just of a preamble, sync and some random data
:return:
"""
mb = MessageTypeBuilder("very_simple_test")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 8)
num_messages = 10
pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x9a"})
for _ in range(num_messages):
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 255), 8))
self.save_protocol("very_simple", pg)
h = Histogram(FormatFinder.get_bitvectors_from_messages(pg.protocol.messages))
if SHOW_PLOTS:
h.plot()
def test_simple_protocol(self):
"""
Test a simple protocol with preamble, sync and length field and some random data
:return:
"""
mb = MessageTypeBuilder("simple_test")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
num_messages_by_data_length = {8: 5, 16: 10, 32: 15}
pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x9a9d"})
for data_length, num_messages in num_messages_by_data_length.items():
for _ in range(num_messages):
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** data_length - 1), data_length))
self.save_protocol("simple", pg)
plt.subplot("221")
plt.title("All messages")
format_finder = FormatFinder(pg.protocol.messages)
for i, sync_end in enumerate(format_finder.sync_ends):
self.assertEqual(sync_end, 24, msg=str(i))
h = Histogram(format_finder.bitvectors)
h.subplot_on(plt)
bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages)
bitvectors_by_length = defaultdict(list)
for bitvector in bitvectors:
bitvectors_by_length[len(bitvector)].append(bitvector)
for i, (message_length, bitvectors) in enumerate(bitvectors_by_length.items()):
plt.subplot(2, 2, i + 2)
plt.title("Messages with length {} ({})".format(message_length, len(bitvectors)))
Histogram(bitvectors).subplot_on(plt)
if SHOW_PLOTS:
plt.show()
def test_medium_protocol(self):
"""
Test a protocol with preamble, sync, length field, 2 participants and addresses and seq nr and random data
:return:
"""
mb = MessageTypeBuilder("medium_test")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 8)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
alice = Participant("Alice", "A", "1234", color_index=0)
bob = Participant("Bob", "B", "5a9d", color_index=1)
num_messages = 100
pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x1c"}, little_endian=False)
for i in range(num_messages):
len_data = random.randint(1, 5)
data = "".join(pg.decimal_to_bits(random.randint(0, 2 ** 8 - 1), 8) for _ in range(len_data))
if i % 2 == 0:
source, dest = alice, bob
else:
source, dest = bob, alice
pg.generate_message(data=data, source=source, destination=dest)
self.save_protocol("medium", pg)
plt.subplot(2, 2, 1)
plt.title("All messages")
bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages)
h = Histogram(bitvectors)
h.subplot_on(plt)
for i, (participant, bitvectors) in enumerate(
sorted(self.get_bitvectors_by_participant(pg.protocol.messages).items())):
plt.subplot(2, 2, i + 3)
plt.title("Messages with participant {} ({})".format(participant.shortname, len(bitvectors)))
Histogram(bitvectors).subplot_on(plt)
if SHOW_PLOTS:
plt.show()
def get_bitvectors_by_participant(self, messages):
import numpy as np
result = defaultdict(list)
for msg in messages: # type: Message
result[msg.participant].append(np.array(msg.decoded_bits, dtype=np.uint8, order="C"))
return result
def test_ack_protocol(self):
"""
Test a protocol with acks
:return:
"""
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 8)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
mb_ack = MessageTypeBuilder("ack")
mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
mb_ack.add_label(FieldType.Function.SYNC, 8)
mb_ack.add_label(FieldType.Function.LENGTH, 8)
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
alice = Participant("Alice", "A", "1234", color_index=0)
bob = Participant("Bob", "B", "5a9d", color_index=1)
num_messages = 50
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
syncs_by_mt={mb.message_type: "0xbf", mb_ack.message_type: "0xbf"},
little_endian=False)
for i in range(num_messages):
if i % 2 == 0:
source, dest = alice, bob
else:
source, dest = bob, alice
pg.generate_message(data="0xffff", source=source, destination=dest)
pg.generate_message(data="", source=dest, destination=source, message_type=mb_ack.message_type)
self.save_protocol("proto_with_acks", pg)
plt.subplot(2, 2, 1)
plt.title("All messages")
bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages)
h = Histogram(bitvectors)
h.subplot_on(plt)
for i, (participant, bitvectors) in enumerate(
sorted(self.get_bitvectors_by_participant(pg.protocol.messages).items())):
plt.subplot(2, 2, i + 3)
plt.title("Messages with participant {} ({})".format(participant.shortname, len(bitvectors)))
Histogram(bitvectors).subplot_on(plt)
if SHOW_PLOTS:
plt.show()

View File

View File

@ -0,0 +1,386 @@
import random
from array import array
import numpy as np
from tests.awre.AWRETestCase import AWRETestCase
from tests.utils_testing import get_path_for_data_file
from urh.awre.FormatFinder import FormatFinder
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.awre.engines.AddressEngine import AddressEngine
from urh.signalprocessing.FieldType import FieldType
from urh.signalprocessing.Message import Message
from urh.signalprocessing.Participant import Participant
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
from urh.util import util
class TestAddressEngine(AWRETestCase):
def setUp(self):
super().setUp()
self.alice = Participant("Alice", "A", address_hex="1234")
self.bob = Participant("Bob", "B", address_hex="cafe")
def test_one_participant(self):
"""
Test a simple protocol with
preamble, sync and length field (8 bit) and some random data
:return:
"""
mb = MessageTypeBuilder("simple_address_test")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
num_messages_by_data_length = {8: 5, 16: 10, 32: 15}
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x9a9d"},
participants=[self.alice])
for data_length, num_messages in num_messages_by_data_length.items():
for i in range(num_messages):
pg.generate_message(data=pg.decimal_to_bits(22 * i, data_length), source=self.alice)
#self.save_protocol("address_one_participant", pg)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
address_dict = address_engine.find_addresses()
self.assertEqual(len(address_dict), 0)
def test_two_participants(self):
mb = MessageTypeBuilder("address_two_participants")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
num_messages = 50
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x9a9d"},
participants=[self.alice, self.bob])
for i in range(num_messages):
if i % 2 == 0:
source, destination = self.alice, self.bob
data_length = 8
else:
source, destination = self.bob, self.alice
data_length = 16
pg.generate_message(data=pg.decimal_to_bits(4 * i, data_length), source=source, destination=destination)
#self.save_protocol("address_two_participants", pg)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
address_dict = address_engine.find_addresses()
self.assertEqual(len(address_dict), 2)
addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0]))
addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1]))
self.assertIn(self.alice.address_hex, addresses_1)
self.assertIn(self.alice.address_hex, addresses_2)
self.assertIn(self.bob.address_hex, addresses_1)
self.assertIn(self.bob.address_hex, addresses_2)
ff.known_participant_addresses.clear()
self.assertEqual(len(ff.known_participant_addresses), 0)
ff.perform_iteration()
self.assertEqual(len(ff.known_participant_addresses), 2)
self.assertIn(bytes([int(h, 16) for h in self.alice.address_hex]),
map(bytes, ff.known_participant_addresses.values()))
self.assertIn(bytes([int(h, 16) for h in self.bob.address_hex]),
map(bytes, ff.known_participant_addresses.values()))
self.assertEqual(len(ff.message_types), 1)
mt = ff.message_types[0]
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
self.assertIsNotNone(dst_addr)
self.assertEqual(dst_addr.start, 32)
self.assertEqual(dst_addr.length, 16)
src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
self.assertIsNotNone(src_addr)
self.assertEqual(src_addr.start, 48)
self.assertEqual(src_addr.length, 16)
def test_two_participants_with_ack_messages(self):
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb_ack = MessageTypeBuilder("ack")
mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
mb_ack.add_label(FieldType.Function.SYNC, 16)
mb_ack.add_label(FieldType.Function.LENGTH, 8)
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
num_messages = 50
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"},
participants=[self.alice, self.bob])
random.seed(0)
for i in range(num_messages):
if i % 2 == 0:
source, destination = self.alice, self.bob
data_length = 8
else:
source, destination = self.bob, self.alice
data_length = 16
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
source=source, destination=destination)
pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination)
#self.save_protocol("address_two_participants_with_acks", pg)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
address_dict = address_engine.find_addresses()
self.assertEqual(len(address_dict), 2)
addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0]))
addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1]))
self.assertIn(self.alice.address_hex, addresses_1)
self.assertIn(self.alice.address_hex, addresses_2)
self.assertIn(self.bob.address_hex, addresses_1)
self.assertIn(self.bob.address_hex, addresses_2)
ff.known_participant_addresses.clear()
ff.perform_iteration()
self.assertEqual(len(ff.message_types), 2)
mt = ff.message_types[1]
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
self.assertIsNotNone(dst_addr)
self.assertEqual(dst_addr.start, 32)
self.assertEqual(dst_addr.length, 16)
src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
self.assertIsNotNone(src_addr)
self.assertEqual(src_addr.start, 48)
self.assertEqual(src_addr.length, 16)
mt = ff.message_types[0]
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
self.assertIsNotNone(dst_addr)
self.assertEqual(dst_addr.start, 32)
self.assertEqual(dst_addr.length, 16)
def test_two_participants_with_ack_messages_and_type(self):
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.TYPE, 8)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb_ack = MessageTypeBuilder("ack")
mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
mb_ack.add_label(FieldType.Function.SYNC, 16)
mb_ack.add_label(FieldType.Function.LENGTH, 8)
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
num_messages = 50
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"},
participants=[self.alice, self.bob])
random.seed(0)
for i in range(num_messages):
if i % 2 == 0:
source, destination = self.alice, self.bob
data_length = 8
else:
source, destination = self.bob, self.alice
data_length = 16
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
source=source, destination=destination)
pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination)
#self.save_protocol("address_two_participants_with_acks_and_types", pg)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
address_dict = address_engine.find_addresses()
self.assertEqual(len(address_dict), 2)
addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0]))
addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1]))
self.assertIn(self.alice.address_hex, addresses_1)
self.assertIn(self.alice.address_hex, addresses_2)
self.assertIn(self.bob.address_hex, addresses_1)
self.assertIn(self.bob.address_hex, addresses_2)
ff.known_participant_addresses.clear()
ff.perform_iteration()
self.assertEqual(len(ff.message_types), 2)
mt = ff.message_types[1]
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
self.assertIsNotNone(dst_addr)
self.assertEqual(dst_addr.start, 40)
self.assertEqual(dst_addr.length, 16)
src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
self.assertIsNotNone(src_addr)
self.assertEqual(src_addr.start, 56)
self.assertEqual(src_addr.length, 16)
mt = ff.message_types[0]
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
self.assertIsNotNone(dst_addr)
self.assertEqual(dst_addr.start, 32)
self.assertEqual(dst_addr.length, 16)
def test_three_participants_with_ack(self):
alice = Participant("Alice", address_hex="1337")
bob = Participant("Bob", address_hex="4711")
carl = Participant("Carl", address_hex="cafe")
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.PREAMBLE, 16)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
mb_ack = MessageTypeBuilder("ack")
mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
mb_ack.add_label(FieldType.Function.SYNC, 16)
mb_ack.add_label(FieldType.Function.LENGTH, 8)
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
participants=[alice, bob, carl])
i = -1
while len(pg.protocol.messages) < 20:
i += 1
source = pg.participants[i % len(pg.participants)]
destination = pg.participants[(i + 1) % len(pg.participants)]
if i % 2 == 0:
data_bytes = 8
else:
data_bytes = 16
data = "".join(random.choice(["0", "1"]) for _ in range(data_bytes * 8))
pg.generate_message(data=data, source=source, destination=destination)
if "ack" in (msg_type.name for msg_type in pg.protocol.message_types):
pg.generate_message(message_type=1, data="", source=destination, destination=source)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
ff.known_participant_addresses.clear()
self.assertEqual(len(ff.known_participant_addresses), 0)
ff.run()
# Since there are ACKS in this protocol, the engine must be able to assign the correct participant addresses
# IN CORRECT ORDER!
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[0]), "1337")
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[1]), "4711")
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[2]), "cafe")
def test_protocol_with_acks_and_checksum(self):
proto_file = get_path_for_data_file("ack_frames_with_crc.proto.xml")
protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
protocol.from_xml_file(filename=proto_file, read_bits=True)
self.clear_message_types(protocol.messages)
ff = FormatFinder(protocol.messages)
ff.known_participant_addresses.clear()
ff.run()
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[0]), "1337")
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[1]), "4711")
for mt in ff.message_types:
preamble = mt.get_first_label_with_type(FieldType.Function.PREAMBLE)
self.assertEqual(preamble.start, 0)
self.assertEqual(preamble.length, 16)
sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
self.assertEqual(sync.start, 16)
self.assertEqual(sync.length, 16)
length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
self.assertEqual(length.start, 32)
self.assertEqual(length.length, 8)
def test_address_engine_performance(self):
ff, messages = self.get_format_finder_from_protocol_file("35_messages.proto.xml", return_messages=True)
engine = AddressEngine(ff.hexvectors, ff.participant_indices)
engine.find()
def test_paper_example(self):
alice = Participant("Alice", "A")
bob = Participant("Bob", "B")
participants = [alice, bob]
msg1 = Message.from_plain_hex_str("aabb1234")
msg1.participant = alice
msg2 = Message.from_plain_hex_str("aabb6789")
msg2.participant = alice
msg3 = Message.from_plain_hex_str("bbaa4711")
msg3.participant = bob
msg4 = Message.from_plain_hex_str("bbaa1337")
msg4.participant = bob
protocol = ProtocolAnalyzer(None)
protocol.messages.extend([msg1, msg2, msg3, msg4])
#self.save_protocol("paper_example", protocol)
bitvectors = FormatFinder.get_bitvectors_from_messages(protocol.messages)
hexvectors = FormatFinder.get_hexvectors(bitvectors)
address_engine = AddressEngine(hexvectors, participant_indices=[participants.index(msg.participant) for msg in
protocol.messages])
def test_find_common_sub_sequence(self):
from urh.cythonext import awre_util
str1 = "0612345678"
str2 = "0756781234"
seq1 = np.array(list(map(int, str1)), dtype=np.uint8, order="C")
seq2 = np.array(list(map(int, str2)), dtype=np.uint8, order="C")
indices = awre_util.find_longest_common_sub_sequence_indices(seq1, seq2)
self.assertEqual(len(indices), 2)
for ind in indices:
s = str1[slice(*ind)]
self.assertIn(s, ("5678", "1234"))
self.assertIn(s, str1)
self.assertIn(s, str2)
def test_find_first_occurrence(self):
from urh.cythonext import awre_util
str1 = "00" * 100 + "1234500012345" + "00" * 100
str2 = "12345"
seq1 = np.array(list(map(int, str1)), dtype=np.uint8, order="C")
seq2 = np.array(list(map(int, str2)), dtype=np.uint8, order="C")
indices = awre_util.find_occurrences(seq1, seq2)
self.assertEqual(len(indices), 2)
index = indices[0]
self.assertEqual(str1[index:index + len(str2)], str2)
# Test with ignoring indices
indices = awre_util.find_occurrences(seq1, seq2, array("L", list(range(0, 205))))
self.assertEqual(len(indices), 1)
# Test with ignoring indices
indices = awre_util.find_occurrences(seq1, seq2, array("L", list(range(0, 210))))
self.assertEqual(len(indices), 0)
self.assertEqual(awre_util.find_occurrences(seq1, np.ones(10, dtype=np.uint8)), [])

View File

@ -0,0 +1,256 @@
import random
from tests.awre.AWRETestCase import AWRETestCase
from urh.awre.FormatFinder import FormatFinder
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
from urh.awre.Preprocessor import Preprocessor
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.signalprocessing.FieldType import FieldType
from urh.signalprocessing.Message import Message
from urh.signalprocessing.Participant import Participant
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
import numpy as np
class TestAWREPreprocessing(AWRETestCase):
def test_very_simple_sync_word_finding(self):
preamble = "10101010"
sync = "1101"
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync)],
num_messages=(20,),
data=(lambda i: 10 * i,))
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
possible_syncs = preprocessor.find_possible_syncs()
#self.save_protocol("very_simple_sync_test", pg)
self.assertGreaterEqual(len(possible_syncs), 1)
self.assertEqual(preprocessor.find_possible_syncs()[0], sync)
def test_simple_sync_word_finding(self):
preamble = "10101010"
sync = "1001"
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "1010", sync)],
num_messages=(20, 5),
data=(lambda i: 10 * i, lambda i: 22 * i))
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
possible_syncs = preprocessor.find_possible_syncs()
#self.save_protocol("simple_sync_test", pg)
self.assertGreaterEqual(len(possible_syncs), 1)
self.assertEqual(preprocessor.find_possible_syncs()[0], sync)
def test_sync_word_finding_odd_preamble(self):
preamble = "0101010"
sync = "1101"
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "10", sync)],
num_messages=(20, 5),
data=(lambda i: 10 * i, lambda i: i))
# If we have a odd preamble length, the last bit of the preamble is counted to the sync
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
possible_syncs = preprocessor.find_possible_syncs()
#self.save_protocol("odd_preamble", pg)
self.assertEqual(preamble[-1] + sync[:-1], possible_syncs[0])
def test_sync_word_finding_special_preamble(self):
preamble = "111001110011100"
sync = "0110"
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "10", sync)],
num_messages=(20, 5),
data=(lambda i: 10 * i, lambda i: i))
# If we have a odd preamble length, the last bit of the preamble is counted to the sync
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
possible_syncs = preprocessor.find_possible_syncs()
#self.save_protocol("special_preamble", pg)
self.assertEqual(sync, possible_syncs[0])
def test_sync_word_finding_errored_preamble(self):
preamble = "00010101010" # first bits are wrong
sync = "0110"
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "10", sync)],
num_messages=(20, 5),
data=(lambda i: 10 * i, lambda i: i))
# If we have a odd preamble length, the last bit of the preamble is counted to the sync
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
possible_syncs = preprocessor.find_possible_syncs()
#self.save_protocol("errored_preamble", pg)
self.assertIn(preamble[-1] + sync[:-1], possible_syncs)
def test_sync_word_finding_with_two_sync_words(self):
preamble = "0xaaaa"
sync1, sync2 = "0x1234", "0xcafe"
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync1), (preamble, sync2)],
num_messages=(15, 10),
data=(lambda i: 12 * i, lambda i: 16 * i))
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
possible_syncs = preprocessor.find_possible_syncs()
#self.save_protocol("two_syncs", pg)
self.assertGreaterEqual(len(possible_syncs), 2)
self.assertIn(ProtocolGenerator.to_bits(sync1), possible_syncs)
self.assertIn(ProtocolGenerator.to_bits(sync2), possible_syncs)
def test_multiple_sync_words(self):
hex_messages = [
"aaS1234",
"aaScafe",
"aaSdead",
"aaSbeef",
]
for i in range(1, 256):
messages = []
sync = "{0:02x}".format(i)
if sync.startswith("a"):
continue
for msg in hex_messages:
messages.append(Message.from_plain_hex_str(msg.replace("S", sync)))
for i in range(1, len(messages)):
messages[i].message_type = messages[0].message_type
ff = FormatFinder(messages)
ff.run()
self.assertEqual(len(ff.message_types), 1, msg=sync)
preamble = ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)
self.assertEqual(preamble.start, 0, msg=sync)
self.assertEqual(preamble.length, 8, msg=sync)
sync = ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC)
self.assertEqual(sync.start, 8, msg=sync)
self.assertEqual(sync.length, 8, msg=sync)
def test_sync_word_finding_varying_message_length(self):
hex_messages = [
"aaaa9a7d0f1337471100009a44ebdd13517bf9",
"aaaa9a7d4747111337000134a4473c002b909630b11df37e34728c79c60396176aff2b5384e82f31511581d0cbb4822ad1b6734e2372ad5cf4af4c9d6b067e5f7ec359ec443c3b5ddc7a9e",
"aaaa9a7d0f13374711000205ee081d26c86b8c",
"aaaa9a7d474711133700037cae4cda789885f88f5fb29adc9acf954cb2850b9d94e7f3b009347c466790e89f2bcd728987d4670690861bbaa120f71f14d4ef8dc738a6d7c30e7d2143c267",
"aaaa9a7d0f133747110004c2906142300427f3"
]
messages = [Message.from_plain_hex_str(hex_msg) for hex_msg in hex_messages]
for i in range(1, len(messages)):
messages[i].message_type = messages[0].message_type
ff = FormatFinder(messages)
ff.run()
self.assertEqual(len(ff.message_types), 1)
preamble = ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)
self.assertEqual(preamble.start, 0)
self.assertEqual(preamble.length, 16)
sync = ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC)
self.assertEqual(sync.start, 16)
self.assertEqual(sync.length, 16)
def test_sync_word_finding_common_prefix(self):
"""
Messages are very similar (odd and even ones are the same)
However, they do not have two different sync words!
The algorithm needs to check for a common prefix of the two found sync words
:return:
"""
sync = "0x1337"
num_messages = 10
alice = Participant("Alice", address_hex="dead01")
bob = Participant("Bob", address_hex="beef24")
mb = MessageTypeBuilder("protocol_with_one_message_type")
mb.add_label(FieldType.Function.PREAMBLE, 72)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 24)
mb.add_label(FieldType.Function.DST_ADDRESS, 24)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x1337"},
preambles_by_mt={mb.message_type: "10" * 36},
participants=[alice, bob])
random.seed(0)
for i in range(num_messages):
if i % 2 == 0:
source, destination = alice, bob
data_length = 8
else:
source, destination = bob, alice
data_length = 16
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
source=source, destination=destination)
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
possible_syncs = preprocessor.find_possible_syncs()
#self.save_protocol("sync_by_common_prefix", pg)
self.assertEqual(len(possible_syncs), 1)
# +0000 is okay, because this will get fixed by correction in FormatFinder
self.assertIn(possible_syncs[0], [ProtocolGenerator.to_bits(sync), ProtocolGenerator.to_bits(sync) + "0000"])
def test_with_given_preamble_and_sync(self):
preamble = "10101010"
sync = "10011"
pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync)],
num_messages=(20,),
data=(lambda i: 10 * i,))
# If we have a odd preamble length, the last bit of the preamble is counted to the sync
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages],
existing_message_types={i: msg.message_type for i, msg in
enumerate(pg.protocol.messages)})
preamble_starts, preamble_lengths, sync_len = preprocessor.preprocess()
#self.save_protocol("given_preamble", pg)
self.assertTrue(all(preamble_start == 0 for preamble_start in preamble_starts))
self.assertTrue(all(preamble_length == len(preamble) for preamble_length in preamble_lengths))
self.assertEqual(sync_len, len(sync))
@staticmethod
def build_protocol_generator(preamble_syncs: list, num_messages: tuple, data: tuple) -> ProtocolGenerator:
message_types = []
preambles_by_mt = dict()
syncs_by_mt = dict()
assert len(preamble_syncs) == len(num_messages) == len(data)
for i, (preamble, sync_word) in enumerate(preamble_syncs):
assert isinstance(preamble, str)
assert isinstance(sync_word, str)
preamble, sync_word = map(ProtocolGenerator.to_bits, (preamble, sync_word))
mb = MessageTypeBuilder("message type #{0}".format(i))
mb.add_label(FieldType.Function.PREAMBLE, len(preamble))
mb.add_label(FieldType.Function.SYNC, len(sync_word))
message_types.append(mb.message_type)
preambles_by_mt[mb.message_type] = preamble
syncs_by_mt[mb.message_type] = sync_word
pg = ProtocolGenerator(message_types, preambles_by_mt=preambles_by_mt, syncs_by_mt=syncs_by_mt)
for i, msg_type in enumerate(message_types):
for j in range(num_messages[i]):
if callable(data[i]):
msg_data = pg.decimal_to_bits(data[i](j), num_bits=8)
else:
msg_data = data[i]
pg.generate_message(message_type=msg_type, data=msg_data)
return pg

View File

@ -0,0 +1,149 @@
from tests.awre.AWRETestCase import AWRETestCase
from tests.utils_testing import get_path_for_data_file
from urh.awre.CommonRange import CommonRange
from urh.awre.FormatFinder import FormatFinder
from urh.awre.Preprocessor import Preprocessor
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.signalprocessing.FieldType import FieldType
from urh.signalprocessing.Message import Message
from urh.signalprocessing.MessageType import MessageType
from urh.signalprocessing.Participant import Participant
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
import numpy as np
class TestAWRERealProtocols(AWRETestCase):
def setUp(self):
super().setUp()
alice = Participant("Alice", "A")
bob = Participant("Bob", "B")
self.participants = [alice, bob]
def test_format_finding_enocean(self):
enocean_protocol = ProtocolAnalyzer(None)
with open(get_path_for_data_file("enocean_bits.txt")) as f:
for line in f:
enocean_protocol.messages.append(Message.from_plain_bits_str(line.replace("\n", "")))
enocean_protocol.messages[-1].message_type = enocean_protocol.default_message_type
ff = FormatFinder(enocean_protocol.messages)
ff.perform_iteration()
message_types = ff.message_types
self.assertEqual(len(message_types), 1)
preamble = message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)
self.assertEqual(preamble.start, 0)
self.assertEqual(preamble.length, 8)
sync = message_types[0].get_first_label_with_type(FieldType.Function.SYNC)
self.assertEqual(sync.start, 8)
self.assertEqual(sync.length, 4)
checksum = message_types[0].get_first_label_with_type(FieldType.Function.CHECKSUM)
self.assertEqual(checksum.start, 56)
self.assertEqual(checksum.length, 4)
self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER))
def test_format_finding_rwe(self):
ff, messages = self.get_format_finder_from_protocol_file("rwe.proto.xml", return_messages=True)
ff.run()
sync1, sync2 = "0x9a7d9a7d", "0x67686768"
preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in messages])
possible_syncs = preprocessor.find_possible_syncs()
self.assertIn(ProtocolGenerator.to_bits(sync1), possible_syncs)
self.assertIn(ProtocolGenerator.to_bits(sync2), possible_syncs)
ack_messages = (3, 5, 7, 9, 11, 13, 15, 17, 20)
ack_message_type = next(mt for mt, messages in ff.existing_message_types.items() if ack_messages[0] in messages)
self.assertTrue(all(ack_msg in ff.existing_message_types[ack_message_type] for ack_msg in ack_messages))
for mt in ff.message_types:
preamble = mt.get_first_label_with_type(FieldType.Function.PREAMBLE)
self.assertEqual(preamble.start, 0)
self.assertEqual(preamble.length, 32)
sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
self.assertEqual(sync.start, 32)
self.assertEqual(sync.length, 32)
length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
self.assertEqual(length.start, 64)
self.assertEqual(length.length, 8)
dst = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
self.assertEqual(dst.length, 24)
if mt == ack_message_type or 1 in ff.existing_message_types[mt]:
self.assertEqual(dst.start, 72)
else:
self.assertEqual(dst.start, 88)
if mt != ack_message_type and 1 not in ff.existing_message_types[mt]:
src = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
self.assertEqual(src.start, 112)
self.assertEqual(src.length, 24)
elif 1 in ff.existing_message_types[mt]:
# long ack
src = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
self.assertEqual(src.start, 96)
self.assertEqual(src.length, 24)
crc = mt.get_first_label_with_type(FieldType.Function.CHECKSUM)
self.assertIsNotNone(crc)
def test_homematic(self):
proto_file = get_path_for_data_file("homematic.proto.xml")
protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
protocol.message_types = []
protocol.from_xml_file(filename=proto_file, read_bits=True)
# prevent interfering with preassinged labels
protocol.message_types = [MessageType("Default")]
participants = sorted({msg.participant for msg in protocol.messages})
self.clear_message_types(protocol.messages)
ff = FormatFinder(protocol.messages, participants=participants)
ff.known_participant_addresses.clear()
ff.perform_iteration()
self.assertGreater(len(ff.message_types), 0)
for i, message_type in enumerate(ff.message_types):
preamble = message_type.get_first_label_with_type(FieldType.Function.PREAMBLE)
self.assertEqual(preamble.start, 0)
self.assertEqual(preamble.length, 32)
sync = message_type.get_first_label_with_type(FieldType.Function.SYNC)
self.assertEqual(sync.start, 32)
self.assertEqual(sync.length, 32)
length = message_type.get_first_label_with_type(FieldType.Function.LENGTH)
self.assertEqual(length.start, 64)
self.assertEqual(length.length, 8)
seq = message_type.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
self.assertEqual(seq.start, 72)
self.assertEqual(seq.length, 8)
src = message_type.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
self.assertEqual(src.start, 96)
self.assertEqual(src.length, 24)
dst = message_type.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
self.assertEqual(dst.start, 120)
self.assertEqual(dst.length, 24)
checksum = message_type.get_first_label_with_type(FieldType.Function.CHECKSUM)
self.assertEqual(checksum.length, 16)
self.assertIn("CC1101", checksum.checksum.caption)
for msg_index in ff.existing_message_types[message_type]:
msg_len = len(protocol.messages[msg_index])
self.assertEqual(checksum.start, msg_len-16)
self.assertEqual(checksum.end, msg_len)

View File

@ -0,0 +1,102 @@
import array
import numpy as np
from tests.awre.AWRETestCase import AWRETestCase
from urh.awre.CommonRange import ChecksumRange
from urh.awre.FormatFinder import FormatFinder
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.awre.engines.ChecksumEngine import ChecksumEngine
from urh.signalprocessing.FieldType import FieldType
from urh.util import util
from urh.util.GenericCRC import GenericCRC
from urh.cythonext import util as c_util
class TestChecksumEngine(AWRETestCase):
def test_find_crc8(self):
messages = ["aabbcc7d", "abcdee24", "dacafe33"]
message_bits = [np.array(msg, dtype=np.uint8) for msg in map(util.hex2bit, messages)]
checksum_engine = ChecksumEngine(message_bits, n_gram_length=8)
result = checksum_engine.find()
self.assertEqual(len(result), 1)
checksum_range = result[0] # type: ChecksumRange
self.assertEqual(checksum_range.length, 8)
self.assertEqual(checksum_range.start, 24)
reference = GenericCRC()
reference.set_polynomial_from_hex("0x07")
self.assertEqual(checksum_range.crc.polynomial, reference.polynomial)
self.assertEqual(checksum_range.message_indices, {0, 1, 2})
def test_find_crc16(self):
messages = ["12345678347B", "abcdefffABBD", "cafe1337CE12"]
message_bits = [np.array(msg, dtype=np.uint8) for msg in map(util.hex2bit, messages)]
checksum_engine = ChecksumEngine(message_bits, n_gram_length=8)
result = checksum_engine.find()
self.assertEqual(len(result), 1)
checksum_range = result[0] # type: ChecksumRange
self.assertEqual(checksum_range.start, 32)
self.assertEqual(checksum_range.length, 16)
reference = GenericCRC()
reference.set_polynomial_from_hex("0x8005")
self.assertEqual(checksum_range.crc.polynomial, reference.polynomial)
self.assertEqual(checksum_range.message_indices, {0, 1, 2})
def test_find_crc32(self):
messages = ["deadcafe5D7F3F5A", "47111337E3319242", "beefaffe0DCD0E15"]
message_bits = [np.array(msg, dtype=np.uint8) for msg in map(util.hex2bit, messages)]
checksum_engine = ChecksumEngine(message_bits, n_gram_length=8)
result = checksum_engine.find()
self.assertEqual(len(result), 1)
checksum_range = result[0] # type: ChecksumRange
self.assertEqual(checksum_range.start, 32)
self.assertEqual(checksum_range.length, 32)
reference = GenericCRC()
reference.set_polynomial_from_hex("0x04C11DB7")
self.assertEqual(checksum_range.crc.polynomial, reference.polynomial)
self.assertEqual(checksum_range.message_indices, {0, 1, 2})
def test_find_generated_crc16(self):
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.DATA, 32)
mb.add_checksum_label(16, GenericCRC.from_standard_checksum("CRC16 CCITT"))
mb2 = MessageTypeBuilder("data2")
mb2.add_label(FieldType.Function.PREAMBLE, 8)
mb2.add_label(FieldType.Function.SYNC, 16)
mb2.add_label(FieldType.Function.LENGTH, 8)
mb2.add_label(FieldType.Function.DATA, 16)
mb2.add_checksum_label(16, GenericCRC.from_standard_checksum("CRC16 CCITT"))
pg = ProtocolGenerator([mb.message_type, mb2.message_type], syncs_by_mt={mb.message_type: "0x1234", mb2.message_type: "0x1234"})
num_messages = 5
for i in range(num_messages):
pg.generate_message(data="{0:032b}".format(i), message_type=mb.message_type)
pg.generate_message(data="{0:016b}".format(i), message_type=mb2.message_type)
#self.save_protocol("crc16_test", pg)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
ff.run()
self.assertEqual(len(ff.message_types), 2)
for mt in ff.message_types:
checksum_label = mt.get_first_label_with_type(FieldType.Function.CHECKSUM)
self.assertEqual(checksum_label.length, 16)
self.assertEqual(checksum_label.checksum.caption, "CRC16 CCITT")

View File

@ -0,0 +1,35 @@
import unittest
from urh.awre.CommonRange import CommonRange
class TestCommonRange(unittest.TestCase):
def test_ensure_not_overlaps(self):
test_range = CommonRange(start=4, length=8, value="12345678")
self.assertEqual(test_range.end, 11)
# no overlapping
self.assertEqual(test_range, test_range.ensure_not_overlaps(0, 3)[0])
self.assertEqual(test_range, test_range.ensure_not_overlaps(20, 24)[0])
# overlapping on left
result = test_range.ensure_not_overlaps(2, 6)[0]
self.assertEqual(result.start, 6)
self.assertEqual(result.end, 11)
# overlapping on right
result = test_range.ensure_not_overlaps(6, 14)[0]
self.assertEqual(result.start, 4)
self.assertEqual(result.end, 5)
# full overlapping
self.assertEqual(len(test_range.ensure_not_overlaps(3, 14)), 0)
# overlapping in the middle
result = test_range.ensure_not_overlaps(6, 9)
self.assertEqual(len(result), 2)
left, right = result[0], result[1]
self.assertEqual(left.start, 4)
self.assertEqual(left.end, 5)
self.assertEqual(right.start, 10)
self.assertEqual(right.end, 11)

View File

@ -0,0 +1,102 @@
import numpy as np
from tests.awre.AWRETestCase import AWRETestCase
from urh.awre.CommonRange import CommonRange, CommonRangeContainer
from urh.awre.FormatFinder import FormatFinder
class TestFormatFinder(AWRETestCase):
def test_create_message_types_1(self):
rng1 = CommonRange(0, 8, "1" * 8, score=1, field_type="Length")
rng1.message_indices = {0, 1, 2}
rng2 = CommonRange(8, 8, "1" * 8, score=1, field_type="Address")
rng2.message_indices = {0, 1, 2}
message_types = FormatFinder.create_common_range_containers({rng1, rng2})
self.assertEqual(len(message_types), 1)
expected = CommonRangeContainer([rng1, rng2], message_indices={0, 1, 2})
self.assertEqual(message_types[0], expected)
def test_create_message_types_2(self):
rng1 = CommonRange(0, 8, "1" * 8, score=1, field_type="Length")
rng1.message_indices = {0, 2, 4, 6, 8, 12}
rng2 = CommonRange(8, 8, "1" * 8, score=1, field_type="Address")
rng2.message_indices = {1, 2, 3, 4, 5, 12}
rng3 = CommonRange(16, 8, "1" * 8, score=1, field_type="Seq")
rng3.message_indices = {1, 3, 5, 7, 12}
message_types = FormatFinder.create_common_range_containers({rng1, rng2, rng3})
expected1 = CommonRangeContainer([rng1], message_indices={0, 6, 8})
expected2 = CommonRangeContainer([rng1, rng2], message_indices={2, 4})
expected3 = CommonRangeContainer([rng1, rng2, rng3], message_indices={12})
expected4 = CommonRangeContainer([rng2, rng3], message_indices={1, 3, 5})
expected5 = CommonRangeContainer([rng3], message_indices={7})
self.assertEqual(len(message_types), 5)
self.assertIn(expected1, message_types)
self.assertIn(expected2, message_types)
self.assertIn(expected3, message_types)
self.assertIn(expected4, message_types)
self.assertIn(expected5, message_types)
def test_retransform_message_indices(self):
sync_ends = np.array([12, 12, 12, 14, 14])
rng = CommonRange(0, 8, "1" * 8, score=1, field_type="length", message_indices={0, 1, 2, 3, 4})
retransformed_ranges = FormatFinder.retransform_message_indices([rng], [0, 1, 2, 3, 4], sync_ends)
# two different sync ends
self.assertEqual(len(retransformed_ranges), 2)
expected1 = CommonRange(12, 8, "1" * 8, score=1, field_type="length", message_indices={0, 1, 2})
expected2 = CommonRange(14, 8, "1" * 8, score=1, field_type="length", message_indices={3, 4})
self.assertIn(expected1, retransformed_ranges)
self.assertIn(expected2, retransformed_ranges)
def test_handle_no_overlapping_conflict(self):
rng1 = CommonRange(0, 8, "1" * 8, score=1, field_type="Length")
rng1.message_indices = {0, 1, 2}
rng2 = CommonRange(8, 8, "1" * 8, score=1, field_type="Address")
rng2.message_indices = {0, 1, 2}
container = CommonRangeContainer([rng1, rng2], message_indices={0, 1, 2})
# no conflict
result = FormatFinder.handle_overlapping_conflict([container])
self.assertEqual(len(result), 1)
self.assertEqual(len(result[0]), 2)
self.assertIn(rng1, result[0])
self.assertEqual(result[0].message_indices, {0, 1, 2})
self.assertIn(rng2, result[0])
def test_handle_easy_overlapping_conflict(self):
# Easy conflict: First Label has higher score
rng1 = CommonRange(8, 8, "1" * 8, score=1, field_type="Length")
rng1.message_indices = {0, 1, 2}
rng2 = CommonRange(8, 8, "1" * 8, score=0.8, field_type="Address")
rng2.message_indices = {0, 1, 2}
container = CommonRangeContainer([rng1, rng2], message_indices={0, 1, 2})
result = FormatFinder.handle_overlapping_conflict([container])
self.assertEqual(len(result), 1)
self.assertEqual(len(result[0]), 1)
self.assertIn(rng1, result[0])
self.assertEqual(result[0].message_indices, {0, 1, 2})
def test_handle_medium_overlapping_conflict(self):
rng1 = CommonRange(8, 8, "1" * 8, score=1, field_type="Length")
rng2 = CommonRange(4, 10, "1" * 8, score=0.8, field_type="Address")
rng3 = CommonRange(15, 20, "1" * 8, score=1, field_type="Seq")
rng4 = CommonRange(60, 80, "1" * 8, score=0.8, field_type="Type")
rng5 = CommonRange(70, 90, "1" * 8, score=0.9, field_type="Data")
container = CommonRangeContainer([rng1, rng2, rng3, rng4, rng5])
result = FormatFinder.handle_overlapping_conflict([container])
self.assertEqual(len(result), 1)
self.assertEqual(len(result[0]), 3)
self.assertIn(rng1, result[0])
self.assertIn(rng3, result[0])
self.assertIn(rng5, result[0])

View File

@ -0,0 +1,236 @@
from tests.awre.AWRETestCase import AWRETestCase
from urh.awre import AutoAssigner
from urh.awre.FormatFinder import FormatFinder
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
from urh.awre.Preprocessor import Preprocessor
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.signalprocessing.FieldType import FieldType
from urh.signalprocessing.Participant import Participant
from urh.util import util
class TestGeneratedProtocols(AWRETestCase):
def __check_addresses(self, messages, format_finder, known_participant_addresses):
"""
Use the AutoAssigner used also in main GUI to test assigned participant addresses to get same results
as in main program and not rely on cache of FormatFinder, because values there might be false
but SRC address labels still on right position which is the basis for Auto Assigner
:param messages:
:param format_finder:
:param known_participant_addresses:
:return:
"""
for msg_type, indices in format_finder.existing_message_types.items():
for i in indices:
messages[i].message_type = msg_type
participants = list(set(m.participant for m in messages))
for p in participants:
p.address_hex = ""
AutoAssigner.auto_assign_participant_addresses(messages, participants)
for i in range(len(participants)):
self.assertIn(participants[i].address_hex,
list(map(util.convert_numbers_to_hex_string, known_participant_addresses.values())),
msg=" [ " + " ".join(p.address_hex for p in participants) + " ]")
def test_without_preamble(self):
alice = Participant("Alice", address_hex="24")
broadcast = Participant("Broadcast", address_hex="ff")
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 8)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x8e88"},
preambles_by_mt={mb.message_type: "10" * 8},
participants=[alice, broadcast])
for i in range(20):
data_bits = 16 if i % 2 == 0 else 32
source = pg.participants[i % 2]
destination = pg.participants[(i + 1) % 2]
pg.generate_message(data="1010" * (data_bits // 4), source=source, destination=destination)
#self.save_protocol("without_preamble", pg)
self.clear_message_types(pg.messages)
ff = FormatFinder(pg.messages)
ff.known_participant_addresses.clear()
ff.run()
self.assertEqual(len(ff.message_types), 1)
mt = ff.message_types[0]
sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
self.assertEqual(sync.start, 0)
self.assertEqual(sync.length, 16)
length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
self.assertEqual(length.start, 16)
self.assertEqual(length.length, 8)
dst = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
self.assertEqual(dst.start, 24)
self.assertEqual(dst.length, 8)
seq = mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
self.assertEqual(seq.start, 32)
self.assertEqual(seq.length, 8)
def test_without_preamble_random_data(self):
ff = self.get_format_finder_from_protocol_file("without_ack_random_data.proto.xml")
ff.run()
self.assertEqual(len(ff.message_types), 1)
mt = ff.message_types[0]
sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
self.assertEqual(sync.start, 0)
self.assertEqual(sync.length, 16)
length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
self.assertEqual(length.start, 16)
self.assertEqual(length.length, 8)
dst = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
self.assertEqual(dst.start, 24)
self.assertEqual(dst.length, 8)
seq = mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
self.assertEqual(seq.start, 32)
self.assertEqual(seq.length, 8)
def test_without_preamble_random_data2(self):
ff = self.get_format_finder_from_protocol_file("without_ack_random_data2.proto.xml")
ff.run()
self.assertEqual(len(ff.message_types), 1)
mt = ff.message_types[0]
sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
self.assertEqual(sync.start, 0)
self.assertEqual(sync.length, 16)
length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
self.assertEqual(length.start, 16)
self.assertEqual(length.length, 8)
dst = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
self.assertEqual(dst.start, 24)
self.assertEqual(dst.length, 8)
seq = mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
self.assertEqual(seq.start, 32)
self.assertEqual(seq.length, 8)
def test_with_checksum(self):
ff = self.get_format_finder_from_protocol_file("with_checksum.proto.xml", clear_participant_addresses=False)
known_participant_addresses = ff.known_participant_addresses.copy()
ff.known_participant_addresses.clear()
ff.run()
self.assertIn(known_participant_addresses[0].tostring(),
list(map(bytes, ff.known_participant_addresses.values())))
self.assertIn(known_participant_addresses[1].tostring(),
list(map(bytes, ff.known_participant_addresses.values())))
self.assertEqual(len(ff.message_types), 3)
def test_with_only_one_address(self):
ff = self.get_format_finder_from_protocol_file("only_one_address.proto.xml", clear_participant_addresses=False)
known_participant_addresses = ff.known_participant_addresses.copy()
ff.known_participant_addresses.clear()
ff.run()
self.assertIn(known_participant_addresses[0].tostring(),
list(map(bytes, ff.known_participant_addresses.values())))
self.assertIn(known_participant_addresses[1].tostring(),
list(map(bytes, ff.known_participant_addresses.values())))
def test_with_four_broken(self):
ff, messages = self.get_format_finder_from_protocol_file("four_broken.proto.xml",
clear_participant_addresses=False,
return_messages=True)
assert isinstance(ff, FormatFinder)
known_participant_addresses = ff.known_participant_addresses.copy()
ff.known_participant_addresses.clear()
ff.run()
self.__check_addresses(messages, ff, known_participant_addresses)
for i in range(4, len(messages)):
mt = next(mt for mt, indices in ff.existing_message_types.items() if i in indices)
self.assertIsNotNone(mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER))
def test_with_one_address_one_message_type(self):
ff, messages = self.get_format_finder_from_protocol_file("one_address_one_mt.proto.xml",
clear_participant_addresses=False,
return_messages=True)
self.assertEqual(len(messages), 17)
self.assertEqual(len(ff.hexvectors), 17)
known_participant_addresses = ff.known_participant_addresses.copy()
ff.known_participant_addresses.clear()
ff.run()
self.assertEqual(len(ff.message_types), 1)
self.assertIn(known_participant_addresses[0].tostring(),
list(map(bytes, ff.known_participant_addresses.values())))
self.assertIn(known_participant_addresses[1].tostring(),
list(map(bytes, ff.known_participant_addresses.values())))
def test_without_preamble_24_messages(self):
ff, messages = self.get_format_finder_from_protocol_file("no_preamble24.proto.xml",
clear_participant_addresses=False,
return_messages=True)
known_participant_addresses = ff.known_participant_addresses.copy()
ff.known_participant_addresses.clear()
ff.run()
self.assertEqual(len(ff.message_types), 1)
self.assertIn(known_participant_addresses[0].tostring(),
list(map(bytes, ff.known_participant_addresses.values())))
self.assertIn(known_participant_addresses[1].tostring(),
list(map(bytes, ff.known_participant_addresses.values())))
def test_with_three_syncs_different_preamble_lengths(self):
ff, messages = self.get_format_finder_from_protocol_file("three_syncs.proto.xml", return_messages=True)
preprocessor = Preprocessor(ff.get_bitvectors_from_messages(messages))
sync_words = preprocessor.find_possible_syncs()
self.assertIn("0000010000100000", sync_words, msg="Sync 1")
self.assertIn("0010001000100010", sync_words, msg="Sync 2")
self.assertIn("0110011101100111", sync_words, msg="Sync 3")
ff.run()
expected_sync_ends = [32, 24, 40, 24, 32, 24, 40, 24, 32, 24, 40, 24, 32, 24, 40, 24]
for i, (s1, s2) in enumerate(zip(expected_sync_ends, ff.sync_ends)):
self.assertEqual(s1, s2, msg=str(i))
def test_with_four_participants(self):
ff, messages = self.get_format_finder_from_protocol_file("four_participants.proto.xml",
clear_participant_addresses=False,
return_messages=True)
known_participant_addresses = ff.known_participant_addresses.copy()
ff.known_participant_addresses.clear()
ff.run()
self.__check_addresses(messages, ff, known_participant_addresses)
self.assertEqual(len(ff.message_types), 3)

View File

@ -0,0 +1,167 @@
import random
from tests.awre.AWRETestCase import AWRETestCase
from urh.awre.FormatFinder import FormatFinder
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.awre.engines.LengthEngine import LengthEngine
from urh.signalprocessing.FieldType import FieldType
from urh.signalprocessing.ProtocoLabel import ProtocolLabel
class TestLengthEngine(AWRETestCase):
def test_simple_protocol(self):
"""
Test a simple protocol with
preamble, sync and length field (8 bit) and some random data
:return:
"""
mb = MessageTypeBuilder("simple_length_test")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
num_messages_by_data_length = {8: 5, 16: 10, 32: 15}
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x9a9d"})
random.seed(0)
for data_length, num_messages in num_messages_by_data_length.items():
for i in range(num_messages):
pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(data_length)]))
#self.save_protocol("simple_length", pg)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
length_engine = LengthEngine(ff.bitvectors)
highscored_ranges = length_engine.find(n_gram_length=8)
self.assertEqual(len(highscored_ranges), 3)
ff.perform_iteration()
self.assertEqual(len(ff.message_types), 1)
self.assertGreater(len(ff.message_types[0]), 0)
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)
self.assertEqual(label.start, 24)
self.assertEqual(label.length, 8)
def test_easy_protocol(self):
"""
preamble, sync, sequence number, length field (8 bit) and some random data
:return:
"""
mb = MessageTypeBuilder("easy_length_test")
mb.add_label(FieldType.Function.PREAMBLE, 16)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
num_messages_by_data_length = {32: 10, 64: 15, 16: 5, 24: 7}
pg = ProtocolGenerator([mb.message_type],
preambles_by_mt={mb.message_type: "10" * 8},
syncs_by_mt={mb.message_type: "0xcafe"})
for data_length, num_messages in num_messages_by_data_length.items():
for i in range(num_messages):
if i % 4 == 0:
data = "1" * data_length
elif i % 4 == 1:
data = "0" * data_length
elif i % 4 == 2:
data = "10" * (data_length // 2)
else:
data = "01" * (data_length // 2)
pg.generate_message(data=data)
#self.save_protocol("easy_length", pg)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
length_engine = LengthEngine(ff.bitvectors)
highscored_ranges = length_engine.find(n_gram_length=8)
self.assertEqual(len(highscored_ranges), 4)
ff.perform_iteration()
self.assertEqual(len(ff.message_types), 1)
self.assertGreater(len(ff.message_types[0]), 0)
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)
self.assertIsInstance(label, ProtocolLabel)
self.assertEqual(label.start, 32)
self.assertEqual(label.length, 8)
def test_medium_protocol(self):
"""
Protocol with two message types. Length field only present in one of them
:return:
"""
mb1 = MessageTypeBuilder("data")
mb1.add_label(FieldType.Function.PREAMBLE, 8)
mb1.add_label(FieldType.Function.SYNC, 8)
mb1.add_label(FieldType.Function.LENGTH, 8)
mb1.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
mb2 = MessageTypeBuilder("ack")
mb2.add_label(FieldType.Function.PREAMBLE, 8)
mb2.add_label(FieldType.Function.SYNC, 8)
pg = ProtocolGenerator([mb1.message_type, mb2.message_type],
syncs_by_mt={mb1.message_type: "11110011",
mb2.message_type: "11110011"})
num_messages_by_data_length = {8: 5, 16: 10, 32: 5}
for data_length, num_messages in num_messages_by_data_length.items():
for i in range(num_messages):
pg.generate_message(data=pg.decimal_to_bits(10 * i, data_length), message_type=mb1.message_type)
pg.generate_message(message_type=mb2.message_type, data="0xaf")
#self.save_protocol("medium_length", pg)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
ff.perform_iteration()
self.assertEqual(len(ff.message_types), 2)
length_mt = next(
mt for mt in ff.message_types if mt.get_first_label_with_type(FieldType.Function.LENGTH) is not None)
length_label = length_mt.get_first_label_with_type(FieldType.Function.LENGTH)
for i, sync_end in enumerate(ff.sync_ends):
self.assertEqual(sync_end, 16, msg=str(i))
self.assertEqual(16, length_label.start)
self.assertEqual(8, length_label.length)
def test_little_endian_16_bit(self):
mb = MessageTypeBuilder("little_endian_16_length_test")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 16)
num_messages_by_data_length = {256*8: 5, 16: 4, 512: 2}
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x9a9d"},
little_endian=True)
random.seed(0)
for data_length, num_messages in num_messages_by_data_length.items():
for i in range(num_messages):
pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(data_length)]))
#self.save_protocol("little_endian_16_length_test", pg)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
length_engine = LengthEngine(ff.bitvectors)
highscored_ranges = length_engine.find(n_gram_length=8)
self.assertEqual(len(highscored_ranges), 3)
ff.perform_iteration()
self.assertEqual(len(ff.message_types), 1)
self.assertGreater(len(ff.message_types[0]), 0)
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)
self.assertEqual(label.start, 24)
self.assertEqual(label.length, 16)

View File

@ -0,0 +1,198 @@
import copy
import random
from urh.signalprocessing.MessageType import MessageType
from urh.awre.FormatFinder import FormatFinder
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.signalprocessing.FieldType import FieldType
from tests.awre.AWRETestCase import AWRETestCase
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
from urh.signalprocessing.Participant import Participant
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
class TestPartiallyLabeled(AWRETestCase):
"""
Some tests if there are already information about the message types present
"""
def test_fully_labeled(self):
"""
For fully labeled protocol, nothing should be done
:return:
"""
protocol = self.__prepare_example_protocol()
message_types = sorted(copy.deepcopy(protocol.message_types), key=lambda x: x.name)
ff = FormatFinder(protocol.messages)
ff.perform_iteration()
self.assertEqual(len(message_types), len(ff.message_types))
for mt1, mt2 in zip(message_types, ff.message_types):
self.assertTrue(self.__message_types_have_same_labels(mt1, mt2))
def test_one_message_type_empty(self):
"""
Empty the "ACK" message type, the labels should be find by FormatFinder
:return:
"""
protocol = self.__prepare_example_protocol()
n_message_types = len(protocol.message_types)
ack_mt = next(mt for mt in protocol.message_types if mt.name == "ack")
ack_mt.clear()
self.assertEqual(len(ack_mt), 0)
ff = FormatFinder(protocol.messages)
ff.perform_iteration()
self.assertEqual(n_message_types, len(ff.message_types))
self.assertEqual(len(ack_mt), 4, msg=str(ack_mt))
def test_given_address_information(self):
"""
Empty both message types and see if addresses are found, when information of participant addresses is given
:return:
"""
protocol = self.__prepare_example_protocol()
self.clear_message_types(protocol.messages)
ff = FormatFinder(protocol.messages)
ff.perform_iteration()
self.assertEqual(2, len(ff.message_types))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE))
self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.PREAMBLE))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC))
self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.SYNC))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.LENGTH))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
self.assertIsNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
def test_type_part_already_labeled(self):
protocol = self.__prepare_simple_example_protocol()
self.clear_message_types(protocol.messages)
ff = FormatFinder(protocol.messages)
# overlaps type
ff.message_types[0].add_protocol_label_start_length(32, 8)
ff.perform_iteration()
self.assertEqual(1, len(ff.message_types))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
def test_length_part_already_labeled(self):
protocol = self.__prepare_simple_example_protocol()
self.clear_message_types(protocol.messages)
ff = FormatFinder(protocol.messages)
# overlaps length
ff.message_types[0].add_protocol_label_start_length(24, 8)
ff.perform_iteration()
self.assertEqual(1, len(ff.message_types))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC))
self.assertIsNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
def test_address_part_already_labeled(self):
protocol = self.__prepare_simple_example_protocol()
self.clear_message_types(protocol.messages)
ff = FormatFinder(protocol.messages)
# overlaps dst address
ff.message_types[0].add_protocol_label_start_length(40, 16)
ff.perform_iteration()
self.assertEqual(1, len(ff.message_types))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
self.assertIsNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
@staticmethod
def __message_types_have_same_labels(mt1: MessageType, mt2: MessageType):
if len(mt1) != len(mt2):
return False
for i, lbl in enumerate(mt1):
if lbl != mt2[i]:
return False
return True
def __prepare_example_protocol(self) -> ProtocolAnalyzer:
alice = Participant("Alice", "A", address_hex="1234")
bob = Participant("Bob", "B", address_hex="cafe")
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.TYPE, 8)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb_ack = MessageTypeBuilder("ack")
mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
mb_ack.add_label(FieldType.Function.SYNC, 16)
mb_ack.add_label(FieldType.Function.LENGTH, 8)
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
num_messages = 50
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"},
participants=[alice, bob])
random.seed(0)
for i in range(num_messages):
if i % 2 == 0:
source, destination = alice, bob
data_length = 8
else:
source, destination = bob, alice
data_length = 16
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
source=source, destination=destination)
pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination)
#self.save_protocol("labeled_protocol", pg)
return pg.protocol
def __prepare_simple_example_protocol(self):
random.seed(0)
alice = Participant("Alice", "A", address_hex="1234")
bob = Participant("Bob", "B", address_hex="cafe")
mb = MessageTypeBuilder("data")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.TYPE, 8)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x6768"},
participants=[alice, bob])
for i in range(10):
pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(16)]), source=alice, destination=bob)
pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(8)]), source=bob, destination=alice)
return pg.protocol

View File

@ -0,0 +1,182 @@
from tests.awre.AWRETestCase import AWRETestCase
from urh.awre.CommonRange import CommonRange
from urh.awre.FormatFinder import FormatFinder
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.awre.engines.SequenceNumberEngine import SequenceNumberEngine
from urh.signalprocessing.FieldType import FieldType
from urh.signalprocessing.Participant import Participant
class TestSequenceNumberEngine(AWRETestCase):
def test_simple_protocol(self):
"""
Test a simple protocol with
preamble, sync and increasing sequence number (8 bit) and some constant data
:return:
"""
mb = MessageTypeBuilder("simple_seq_test")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
num_messages = 20
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x9a9d"})
for i in range(num_messages):
pg.generate_message(data="0xcafe")
#self.save_protocol("simple_sequence_number", pg)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
seq_engine = SequenceNumberEngine(ff.bitvectors, n_gram_length=8)
highscored_ranges = seq_engine.find()
self.assertEqual(len(highscored_ranges), 1)
ff.perform_iteration()
self.assertEqual(len(ff.message_types), 1)
self.assertGreater(len(ff.message_types[0]), 0)
self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1)
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
self.assertEqual(label.start, 24)
self.assertEqual(label.length, 8)
def test_16bit_seq_nr(self):
mb = MessageTypeBuilder("16bit_seq_test")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
num_messages = 10
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x9a9d"}, sequence_number_increment=64)
for i in range(num_messages):
pg.generate_message(data="0xcafe")
#self.save_protocol("16bit_seq", pg)
bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages, sync_ends=[24]*num_messages)
seq_engine = SequenceNumberEngine(bitvectors, n_gram_length=8)
highscored_ranges = seq_engine.find()
self.assertEqual(len(highscored_ranges), 1)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
ff.perform_iteration()
self.assertEqual(len(ff.message_types), 1)
self.assertGreater(len(ff.message_types[0]), 0)
self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1)
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
self.assertEqual(label.start, 24)
self.assertEqual(label.length, 16)
def test_16bit_seq_nr_with_zeros_in_first_part(self):
mb = MessageTypeBuilder("16bit_seq_first_byte_zero_test")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
num_messages = 10
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x9a9d"}, sequence_number_increment=1)
for i in range(num_messages):
pg.generate_message(data="0xcafe" + "abc" * i)
#self.save_protocol("16bit_seq_first_byte_zero_test", pg)
bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages, sync_ends=[24]*num_messages)
seq_engine = SequenceNumberEngine(bitvectors, n_gram_length=8)
highscored_ranges = seq_engine.find()
self.assertEqual(len(highscored_ranges), 1)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
ff.perform_iteration()
self.assertEqual(len(ff.message_types), 1)
self.assertGreater(len(ff.message_types[0]), 0)
self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1)
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
# Not consider constants as part of SEQ Nr!
self.assertEqual(label.start, 40)
self.assertEqual(label.length, 8)
def test_no_sequence_number(self):
"""
Ensure no sequence number is labeled, when it cannot be found
:return:
"""
alice = Participant("Alice", address_hex="dead")
bob = Participant("Bob", address_hex="beef")
mb = MessageTypeBuilder("protocol_with_one_message_type")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
num_messages = 3
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x1337"},
participants=[alice, bob])
for i in range(num_messages):
if i % 2 == 0:
source, destination = alice, bob
else:
source, destination = bob, alice
pg.generate_message(data="", source=source, destination=destination)
#self.save_protocol("protocol_1", pg)
# Delete message type information -> no prior knowledge
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
ff.known_participant_addresses.clear()
ff.perform_iteration()
self.assertEqual(len(ff.message_types), 1)
self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 0)
def test_sequence_number_little_endian_16_bit(self):
mb = MessageTypeBuilder("16bit_seq_test")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
num_messages = 8
pg = ProtocolGenerator([mb.message_type],
syncs_by_mt={mb.message_type: "0x9a9d"},
little_endian=True, sequence_number_increment=64)
for i in range(num_messages):
pg.generate_message(data="0xcafe")
#self.save_protocol("16bit_litte_endian_seq", pg)
self.clear_message_types(pg.protocol.messages)
ff = FormatFinder(pg.protocol.messages)
ff.perform_iteration()
self.assertEqual(len(ff.message_types), 1)
self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1)
label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
self.assertEqual(label.start, 24)
self.assertEqual(label.length, 16)