387 lines
17 KiB
Python
387 lines
17 KiB
Python
import random
|
|
from array import array
|
|
|
|
import numpy as np
|
|
|
|
from tests.awre.AWRETestCase import AWRETestCase
|
|
from tests.utils_testing import get_path_for_data_file
|
|
from urh.awre.FormatFinder import FormatFinder
|
|
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
|
|
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
|
from urh.awre.engines.AddressEngine import AddressEngine
|
|
from urh.signalprocessing.FieldType import FieldType
|
|
from urh.signalprocessing.Message import Message
|
|
from urh.signalprocessing.Participant import Participant
|
|
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
|
|
from urh.util import util
|
|
|
|
|
|
class TestAddressEngine(AWRETestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.alice = Participant("Alice", "A", address_hex="1234")
|
|
self.bob = Participant("Bob", "B", address_hex="cafe")
|
|
|
|
def test_one_participant(self):
|
|
"""
|
|
Test a simple protocol with
|
|
preamble, sync and length field (8 bit) and some random data
|
|
|
|
:return:
|
|
"""
|
|
mb = MessageTypeBuilder("simple_address_test")
|
|
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
|
mb.add_label(FieldType.Function.SYNC, 16)
|
|
mb.add_label(FieldType.Function.LENGTH, 8)
|
|
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
|
|
|
num_messages_by_data_length = {8: 5, 16: 10, 32: 15}
|
|
pg = ProtocolGenerator([mb.message_type],
|
|
syncs_by_mt={mb.message_type: "0x9a9d"},
|
|
participants=[self.alice])
|
|
for data_length, num_messages in num_messages_by_data_length.items():
|
|
for i in range(num_messages):
|
|
pg.generate_message(data=pg.decimal_to_bits(22 * i, data_length), source=self.alice)
|
|
|
|
#self.save_protocol("address_one_participant", pg)
|
|
|
|
self.clear_message_types(pg.protocol.messages)
|
|
ff = FormatFinder(pg.protocol.messages)
|
|
|
|
address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
|
|
address_dict = address_engine.find_addresses()
|
|
|
|
self.assertEqual(len(address_dict), 0)
|
|
|
|
def test_two_participants(self):
|
|
mb = MessageTypeBuilder("address_two_participants")
|
|
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
|
mb.add_label(FieldType.Function.SYNC, 16)
|
|
mb.add_label(FieldType.Function.LENGTH, 8)
|
|
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
|
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
|
|
|
num_messages = 50
|
|
|
|
pg = ProtocolGenerator([mb.message_type],
|
|
syncs_by_mt={mb.message_type: "0x9a9d"},
|
|
participants=[self.alice, self.bob])
|
|
|
|
for i in range(num_messages):
|
|
if i % 2 == 0:
|
|
source, destination = self.alice, self.bob
|
|
data_length = 8
|
|
else:
|
|
source, destination = self.bob, self.alice
|
|
data_length = 16
|
|
pg.generate_message(data=pg.decimal_to_bits(4 * i, data_length), source=source, destination=destination)
|
|
|
|
#self.save_protocol("address_two_participants", pg)
|
|
|
|
self.clear_message_types(pg.protocol.messages)
|
|
ff = FormatFinder(pg.protocol.messages)
|
|
|
|
address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
|
|
address_dict = address_engine.find_addresses()
|
|
self.assertEqual(len(address_dict), 2)
|
|
addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0]))
|
|
addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1]))
|
|
self.assertIn(self.alice.address_hex, addresses_1)
|
|
self.assertIn(self.alice.address_hex, addresses_2)
|
|
self.assertIn(self.bob.address_hex, addresses_1)
|
|
self.assertIn(self.bob.address_hex, addresses_2)
|
|
|
|
ff.known_participant_addresses.clear()
|
|
self.assertEqual(len(ff.known_participant_addresses), 0)
|
|
|
|
ff.perform_iteration()
|
|
|
|
self.assertEqual(len(ff.known_participant_addresses), 2)
|
|
self.assertIn(bytes([int(h, 16) for h in self.alice.address_hex]),
|
|
map(bytes, ff.known_participant_addresses.values()))
|
|
self.assertIn(bytes([int(h, 16) for h in self.bob.address_hex]),
|
|
map(bytes, ff.known_participant_addresses.values()))
|
|
|
|
self.assertEqual(len(ff.message_types), 1)
|
|
mt = ff.message_types[0]
|
|
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
|
|
self.assertIsNotNone(dst_addr)
|
|
self.assertEqual(dst_addr.start, 32)
|
|
self.assertEqual(dst_addr.length, 16)
|
|
src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
|
|
self.assertIsNotNone(src_addr)
|
|
self.assertEqual(src_addr.start, 48)
|
|
self.assertEqual(src_addr.length, 16)
|
|
|
|
def test_two_participants_with_ack_messages(self):
|
|
mb = MessageTypeBuilder("data")
|
|
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
|
mb.add_label(FieldType.Function.SYNC, 16)
|
|
mb.add_label(FieldType.Function.LENGTH, 8)
|
|
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
|
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
|
mb_ack = MessageTypeBuilder("ack")
|
|
mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
|
|
mb_ack.add_label(FieldType.Function.SYNC, 16)
|
|
mb_ack.add_label(FieldType.Function.LENGTH, 8)
|
|
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
|
|
|
|
num_messages = 50
|
|
|
|
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
|
|
syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"},
|
|
participants=[self.alice, self.bob])
|
|
|
|
random.seed(0)
|
|
for i in range(num_messages):
|
|
if i % 2 == 0:
|
|
source, destination = self.alice, self.bob
|
|
data_length = 8
|
|
else:
|
|
source, destination = self.bob, self.alice
|
|
data_length = 16
|
|
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
|
|
source=source, destination=destination)
|
|
pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination)
|
|
|
|
#self.save_protocol("address_two_participants_with_acks", pg)
|
|
|
|
self.clear_message_types(pg.protocol.messages)
|
|
ff = FormatFinder(pg.protocol.messages)
|
|
address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
|
|
address_dict = address_engine.find_addresses()
|
|
self.assertEqual(len(address_dict), 2)
|
|
addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0]))
|
|
addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1]))
|
|
self.assertIn(self.alice.address_hex, addresses_1)
|
|
self.assertIn(self.alice.address_hex, addresses_2)
|
|
self.assertIn(self.bob.address_hex, addresses_1)
|
|
self.assertIn(self.bob.address_hex, addresses_2)
|
|
|
|
ff.known_participant_addresses.clear()
|
|
ff.perform_iteration()
|
|
self.assertEqual(len(ff.message_types), 2)
|
|
mt = ff.message_types[1]
|
|
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
|
|
self.assertIsNotNone(dst_addr)
|
|
self.assertEqual(dst_addr.start, 32)
|
|
self.assertEqual(dst_addr.length, 16)
|
|
src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
|
|
self.assertIsNotNone(src_addr)
|
|
self.assertEqual(src_addr.start, 48)
|
|
self.assertEqual(src_addr.length, 16)
|
|
|
|
mt = ff.message_types[0]
|
|
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
|
|
self.assertIsNotNone(dst_addr)
|
|
self.assertEqual(dst_addr.start, 32)
|
|
self.assertEqual(dst_addr.length, 16)
|
|
|
|
def test_two_participants_with_ack_messages_and_type(self):
|
|
mb = MessageTypeBuilder("data")
|
|
mb.add_label(FieldType.Function.PREAMBLE, 8)
|
|
mb.add_label(FieldType.Function.SYNC, 16)
|
|
mb.add_label(FieldType.Function.LENGTH, 8)
|
|
mb.add_label(FieldType.Function.TYPE, 8)
|
|
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
|
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
|
mb_ack = MessageTypeBuilder("ack")
|
|
mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
|
|
mb_ack.add_label(FieldType.Function.SYNC, 16)
|
|
mb_ack.add_label(FieldType.Function.LENGTH, 8)
|
|
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
|
|
|
|
num_messages = 50
|
|
|
|
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
|
|
syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"},
|
|
participants=[self.alice, self.bob])
|
|
|
|
random.seed(0)
|
|
for i in range(num_messages):
|
|
if i % 2 == 0:
|
|
source, destination = self.alice, self.bob
|
|
data_length = 8
|
|
else:
|
|
source, destination = self.bob, self.alice
|
|
data_length = 16
|
|
pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
|
|
source=source, destination=destination)
|
|
pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination)
|
|
|
|
#self.save_protocol("address_two_participants_with_acks_and_types", pg)
|
|
|
|
self.clear_message_types(pg.protocol.messages)
|
|
ff = FormatFinder(pg.protocol.messages)
|
|
address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
|
|
address_dict = address_engine.find_addresses()
|
|
self.assertEqual(len(address_dict), 2)
|
|
addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0]))
|
|
addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1]))
|
|
self.assertIn(self.alice.address_hex, addresses_1)
|
|
self.assertIn(self.alice.address_hex, addresses_2)
|
|
self.assertIn(self.bob.address_hex, addresses_1)
|
|
self.assertIn(self.bob.address_hex, addresses_2)
|
|
|
|
ff.known_participant_addresses.clear()
|
|
ff.perform_iteration()
|
|
self.assertEqual(len(ff.message_types), 2)
|
|
mt = ff.message_types[1]
|
|
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
|
|
self.assertIsNotNone(dst_addr)
|
|
self.assertEqual(dst_addr.start, 40)
|
|
self.assertEqual(dst_addr.length, 16)
|
|
src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
|
|
self.assertIsNotNone(src_addr)
|
|
self.assertEqual(src_addr.start, 56)
|
|
self.assertEqual(src_addr.length, 16)
|
|
|
|
mt = ff.message_types[0]
|
|
dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
|
|
self.assertIsNotNone(dst_addr)
|
|
self.assertEqual(dst_addr.start, 32)
|
|
self.assertEqual(dst_addr.length, 16)
|
|
|
|
def test_three_participants_with_ack(self):
|
|
alice = Participant("Alice", address_hex="1337")
|
|
bob = Participant("Bob", address_hex="4711")
|
|
carl = Participant("Carl", address_hex="cafe")
|
|
|
|
mb = MessageTypeBuilder("data")
|
|
mb.add_label(FieldType.Function.PREAMBLE, 16)
|
|
mb.add_label(FieldType.Function.SYNC, 16)
|
|
mb.add_label(FieldType.Function.LENGTH, 8)
|
|
mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
|
|
mb.add_label(FieldType.Function.DST_ADDRESS, 16)
|
|
mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
|
|
|
|
mb_ack = MessageTypeBuilder("ack")
|
|
mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
|
|
mb_ack.add_label(FieldType.Function.SYNC, 16)
|
|
mb_ack.add_label(FieldType.Function.LENGTH, 8)
|
|
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
|
|
|
|
pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
|
|
syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
|
|
preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
|
|
participants=[alice, bob, carl])
|
|
|
|
i = -1
|
|
while len(pg.protocol.messages) < 20:
|
|
i += 1
|
|
source = pg.participants[i % len(pg.participants)]
|
|
destination = pg.participants[(i + 1) % len(pg.participants)]
|
|
if i % 2 == 0:
|
|
data_bytes = 8
|
|
else:
|
|
data_bytes = 16
|
|
|
|
data = "".join(random.choice(["0", "1"]) for _ in range(data_bytes * 8))
|
|
pg.generate_message(data=data, source=source, destination=destination)
|
|
|
|
if "ack" in (msg_type.name for msg_type in pg.protocol.message_types):
|
|
pg.generate_message(message_type=1, data="", source=destination, destination=source)
|
|
|
|
self.clear_message_types(pg.protocol.messages)
|
|
ff = FormatFinder(pg.protocol.messages)
|
|
ff.known_participant_addresses.clear()
|
|
self.assertEqual(len(ff.known_participant_addresses), 0)
|
|
ff.run()
|
|
|
|
# Since there are ACKS in this protocol, the engine must be able to assign the correct participant addresses
|
|
# IN CORRECT ORDER!
|
|
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[0]), "1337")
|
|
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[1]), "4711")
|
|
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[2]), "cafe")
|
|
|
|
def test_protocol_with_acks_and_checksum(self):
|
|
proto_file = get_path_for_data_file("ack_frames_with_crc.proto.xml")
|
|
protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
|
|
protocol.from_xml_file(filename=proto_file, read_bits=True)
|
|
|
|
self.clear_message_types(protocol.messages)
|
|
|
|
ff = FormatFinder(protocol.messages)
|
|
ff.known_participant_addresses.clear()
|
|
|
|
ff.run()
|
|
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[0]), "1337")
|
|
self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[1]), "4711")
|
|
|
|
for mt in ff.message_types:
|
|
preamble = mt.get_first_label_with_type(FieldType.Function.PREAMBLE)
|
|
self.assertEqual(preamble.start, 0)
|
|
self.assertEqual(preamble.length, 16)
|
|
sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
|
|
self.assertEqual(sync.start, 16)
|
|
self.assertEqual(sync.length, 16)
|
|
length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
|
|
self.assertEqual(length.start, 32)
|
|
self.assertEqual(length.length, 8)
|
|
|
|
def test_address_engine_performance(self):
|
|
ff, messages = self.get_format_finder_from_protocol_file("35_messages.proto.xml", return_messages=True)
|
|
|
|
engine = AddressEngine(ff.hexvectors, ff.participant_indices)
|
|
engine.find()
|
|
|
|
def test_paper_example(self):
|
|
alice = Participant("Alice", "A")
|
|
bob = Participant("Bob", "B")
|
|
participants = [alice, bob]
|
|
msg1 = Message.from_plain_hex_str("aabb1234")
|
|
msg1.participant = alice
|
|
msg2 = Message.from_plain_hex_str("aabb6789")
|
|
msg2.participant = alice
|
|
msg3 = Message.from_plain_hex_str("bbaa4711")
|
|
msg3.participant = bob
|
|
msg4 = Message.from_plain_hex_str("bbaa1337")
|
|
msg4.participant = bob
|
|
|
|
protocol = ProtocolAnalyzer(None)
|
|
protocol.messages.extend([msg1, msg2, msg3, msg4])
|
|
#self.save_protocol("paper_example", protocol)
|
|
|
|
bitvectors = FormatFinder.get_bitvectors_from_messages(protocol.messages)
|
|
hexvectors = FormatFinder.get_hexvectors(bitvectors)
|
|
address_engine = AddressEngine(hexvectors, participant_indices=[participants.index(msg.participant) for msg in
|
|
protocol.messages])
|
|
|
|
def test_find_common_sub_sequence(self):
|
|
from urh.cythonext import awre_util
|
|
str1 = "0612345678"
|
|
str2 = "0756781234"
|
|
|
|
seq1 = np.array(list(map(int, str1)), dtype=np.uint8, order="C")
|
|
seq2 = np.array(list(map(int, str2)), dtype=np.uint8, order="C")
|
|
|
|
indices = awre_util.find_longest_common_sub_sequence_indices(seq1, seq2)
|
|
self.assertEqual(len(indices), 2)
|
|
for ind in indices:
|
|
s = str1[slice(*ind)]
|
|
self.assertIn(s, ("5678", "1234"))
|
|
self.assertIn(s, str1)
|
|
self.assertIn(s, str2)
|
|
|
|
def test_find_first_occurrence(self):
|
|
from urh.cythonext import awre_util
|
|
str1 = "00" * 100 + "1234500012345" + "00" * 100
|
|
str2 = "12345"
|
|
|
|
seq1 = np.array(list(map(int, str1)), dtype=np.uint8, order="C")
|
|
seq2 = np.array(list(map(int, str2)), dtype=np.uint8, order="C")
|
|
indices = awre_util.find_occurrences(seq1, seq2)
|
|
self.assertEqual(len(indices), 2)
|
|
index = indices[0]
|
|
self.assertEqual(str1[index:index + len(str2)], str2)
|
|
|
|
# Test with ignoring indices
|
|
indices = awre_util.find_occurrences(seq1, seq2, array("L", list(range(0, 205))))
|
|
self.assertEqual(len(indices), 1)
|
|
|
|
# Test with ignoring indices
|
|
indices = awre_util.find_occurrences(seq1, seq2, array("L", list(range(0, 210))))
|
|
self.assertEqual(len(indices), 0)
|
|
|
|
self.assertEqual(awre_util.find_occurrences(seq1, np.ones(10, dtype=np.uint8)), [])
|