HackRF-Treasure-Chest/Software/Universal Radio Hacker/tests/awre/test_address_engine.py
2022-09-22 13:46:47 -07:00

387 lines
17 KiB
Python

import random
from array import array
import numpy as np
from tests.awre.AWRETestCase import AWRETestCase
from tests.utils_testing import get_path_for_data_file
from urh.awre.FormatFinder import FormatFinder
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.awre.engines.AddressEngine import AddressEngine
from urh.signalprocessing.FieldType import FieldType
from urh.signalprocessing.Message import Message
from urh.signalprocessing.Participant import Participant
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
from urh.util import util
class TestAddressEngine(AWRETestCase):
    """
    Tests for address detection in AWRE: the AddressEngine, the address handling of
    FormatFinder, and the awre_util sequence helpers used by the engine.
    """

    def setUp(self):
        super().setUp()
        self.alice = Participant("Alice", "A", address_hex="1234")
        self.bob = Participant("Bob", "B", address_hex="cafe")

    # ------------------------------------------------------------------ helpers

    def _assert_addresses_found(self, address_dict):
        """
        Assert that *address_dict* has exactly two entries and that the candidate
        address lists stored under keys 0 and 1 both contain Alice's and Bob's address.

        :param address_dict: result of AddressEngine.find_addresses(); keys 0 and 1 are
                             presumably participant indices — verify against AddressEngine
        """
        self.assertEqual(len(address_dict), 2)
        for key in (0, 1):
            addresses = [util.convert_numbers_to_hex_string(address) for address in address_dict[key]]
            self.assertIn(self.alice.address_hex, addresses)
            self.assertIn(self.bob.address_hex, addresses)

    def _assert_label(self, message_type, function, start, length):
        """
        Assert that *message_type* has a label of the given *function* that starts at
        bit *start* and spans *length* bits.
        """
        label = message_type.get_first_label_with_type(function)
        self.assertIsNotNone(label)
        self.assertEqual(label.start, start)
        self.assertEqual(label.length, length)

    def _generate_alice_bob_protocol_with_acks(self, data_mt, ack_mt, num_messages=50):
        """
        Generate a protocol in which Alice and Bob alternately send data messages of
        type *data_mt* (8 bit payload from Alice, 16 bit payload from Bob). Each data
        message is answered by an empty *ack_mt* message from its receiver.
        The random generator is seeded so that the payloads are reproducible.

        :return: the ProtocolGenerator holding the generated messages
        """
        pg = ProtocolGenerator([data_mt, ack_mt],
                               syncs_by_mt={data_mt: "0x6768", ack_mt: "0x6768"},
                               participants=[self.alice, self.bob])

        random.seed(0)
        for i in range(num_messages):
            if i % 2 == 0:
                source, destination = self.alice, self.bob
                data_length = 8
            else:
                source, destination = self.bob, self.alice
                data_length = 16
            pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
                                source=source, destination=destination)
            # The ack travels back from the receiver to the original sender
            pg.generate_message(data="", message_type=ack_mt, destination=source, source=destination)
        return pg

    # -------------------------------------------------------------------- tests

    def test_one_participant(self):
        """
        Test a protocol with a single participant whose messages consist of
        preamble, sync, length field (8 bit), source address (16 bit) and some data.
        With only one participant, the address engine must not report any addresses.
        :return:
        """
        mb = MessageTypeBuilder("simple_address_test")
        mb.add_label(FieldType.Function.PREAMBLE, 8)
        mb.add_label(FieldType.Function.SYNC, 16)
        mb.add_label(FieldType.Function.LENGTH, 8)
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)

        num_messages_by_data_length = {8: 5, 16: 10, 32: 15}
        pg = ProtocolGenerator([mb.message_type],
                               syncs_by_mt={mb.message_type: "0x9a9d"},
                               participants=[self.alice])
        for data_length, num_messages in num_messages_by_data_length.items():
            for i in range(num_messages):
                pg.generate_message(data=pg.decimal_to_bits(22 * i, data_length), source=self.alice)

        # self.save_protocol("address_one_participant", pg)

        self.clear_message_types(pg.protocol.messages)
        ff = FormatFinder(pg.protocol.messages)

        address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
        address_dict = address_engine.find_addresses()

        self.assertEqual(len(address_dict), 0)

    def test_two_participants(self):
        """
        Alice and Bob alternately exchange messages carrying destination and source address.
        The AddressEngine and a FormatFinder iteration must both recover the two addresses,
        and the FormatFinder must label the address fields at the generated positions.
        """
        mb = MessageTypeBuilder("address_two_participants")
        mb.add_label(FieldType.Function.PREAMBLE, 8)
        mb.add_label(FieldType.Function.SYNC, 16)
        mb.add_label(FieldType.Function.LENGTH, 8)
        # DST before SRC, matching the label positions asserted at the end of this test
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)

        num_messages = 50

        pg = ProtocolGenerator([mb.message_type],
                               syncs_by_mt={mb.message_type: "0x9a9d"},
                               participants=[self.alice, self.bob])

        for i in range(num_messages):
            if i % 2 == 0:
                source, destination = self.alice, self.bob
                data_length = 8
            else:
                source, destination = self.bob, self.alice
                data_length = 16
            pg.generate_message(data=pg.decimal_to_bits(4 * i, data_length), source=source, destination=destination)

        # self.save_protocol("address_two_participants", pg)

        self.clear_message_types(pg.protocol.messages)
        ff = FormatFinder(pg.protocol.messages)

        address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
        self._assert_addresses_found(address_engine.find_addresses())

        # Let the FormatFinder rediscover the participant addresses from scratch
        ff.known_participant_addresses.clear()
        self.assertEqual(len(ff.known_participant_addresses), 0)

        ff.perform_iteration()

        self.assertEqual(len(ff.known_participant_addresses), 2)
        known_addresses = [bytes(address) for address in ff.known_participant_addresses.values()]
        # Each hex character becomes one byte, the representation compared against known_addresses
        self.assertIn(bytes([int(h, 16) for h in self.alice.address_hex]), known_addresses)
        self.assertIn(bytes([int(h, 16) for h in self.bob.address_hex]), known_addresses)

        self.assertEqual(len(ff.message_types), 1)

        mt = ff.message_types[0]
        self._assert_label(mt, FieldType.Function.DST_ADDRESS, 32, 16)
        self._assert_label(mt, FieldType.Function.SRC_ADDRESS, 48, 16)

    def test_two_participants_with_ack_messages(self):
        """
        Alice and Bob exchange data messages (dst + src address), each answered by an
        ack message that carries only a destination address.
        """
        mb = MessageTypeBuilder("data")
        mb.add_label(FieldType.Function.PREAMBLE, 8)
        mb.add_label(FieldType.Function.SYNC, 16)
        mb.add_label(FieldType.Function.LENGTH, 8)
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)

        mb_ack = MessageTypeBuilder("ack")
        mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
        mb_ack.add_label(FieldType.Function.SYNC, 16)
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)

        pg = self._generate_alice_bob_protocol_with_acks(mb.message_type, mb_ack.message_type)

        # self.save_protocol("address_two_participants_with_acks", pg)

        self.clear_message_types(pg.protocol.messages)
        ff = FormatFinder(pg.protocol.messages)

        address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
        self._assert_addresses_found(address_engine.find_addresses())

        ff.known_participant_addresses.clear()
        ff.perform_iteration()
        self.assertEqual(len(ff.message_types), 2)

        # Index 1 is checked for both address labels (the data type); index 0 only for the
        # destination address (presumably the ack type)
        self._assert_label(ff.message_types[1], FieldType.Function.DST_ADDRESS, 32, 16)
        self._assert_label(ff.message_types[1], FieldType.Function.SRC_ADDRESS, 48, 16)
        self._assert_label(ff.message_types[0], FieldType.Function.DST_ADDRESS, 32, 16)

    def test_two_participants_with_ack_messages_and_type(self):
        """
        Like test_two_participants_with_ack_messages, but the data messages carry an
        additional 8 bit type field before the addresses, which shifts them by 8 bits.
        """
        mb = MessageTypeBuilder("data")
        mb.add_label(FieldType.Function.PREAMBLE, 8)
        mb.add_label(FieldType.Function.SYNC, 16)
        mb.add_label(FieldType.Function.LENGTH, 8)
        mb.add_label(FieldType.Function.TYPE, 8)
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)

        mb_ack = MessageTypeBuilder("ack")
        mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
        mb_ack.add_label(FieldType.Function.SYNC, 16)
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)

        pg = self._generate_alice_bob_protocol_with_acks(mb.message_type, mb_ack.message_type)

        # self.save_protocol("address_two_participants_with_acks_and_types", pg)

        self.clear_message_types(pg.protocol.messages)
        ff = FormatFinder(pg.protocol.messages)

        address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
        self._assert_addresses_found(address_engine.find_addresses())

        ff.known_participant_addresses.clear()
        ff.perform_iteration()
        self.assertEqual(len(ff.message_types), 2)

        self._assert_label(ff.message_types[1], FieldType.Function.DST_ADDRESS, 40, 16)
        self._assert_label(ff.message_types[1], FieldType.Function.SRC_ADDRESS, 56, 16)
        self._assert_label(ff.message_types[0], FieldType.Function.DST_ADDRESS, 32, 16)

    def test_three_participants_with_ack(self):
        """
        Three participants send data messages round-robin, and each is acknowledged.
        The FormatFinder must assign the participant addresses in participant order.
        """
        alice = Participant("Alice", address_hex="1337")
        bob = Participant("Bob", address_hex="4711")
        carl = Participant("Carl", address_hex="cafe")

        mb = MessageTypeBuilder("data")
        mb.add_label(FieldType.Function.PREAMBLE, 16)
        mb.add_label(FieldType.Function.SYNC, 16)
        mb.add_label(FieldType.Function.LENGTH, 8)
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)

        mb_ack = MessageTypeBuilder("ack")
        mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
        mb_ack.add_label(FieldType.Function.SYNC, 16)
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)

        pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
                               syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
                               preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
                               participants=[alice, bob, carl])

        # Seed like the other tests so the random payloads (and thus the result) are reproducible
        random.seed(0)

        i = -1
        while len(pg.protocol.messages) < 20:
            i += 1
            source = pg.participants[i % len(pg.participants)]
            destination = pg.participants[(i + 1) % len(pg.participants)]
            if i % 2 == 0:
                data_bytes = 8
            else:
                data_bytes = 16

            data = "".join(random.choice(["0", "1"]) for _ in range(data_bytes * 8))
            pg.generate_message(data=data, source=source, destination=destination)

            if "ack" in (msg_type.name for msg_type in pg.protocol.message_types):
                pg.generate_message(message_type=1, data="", source=destination, destination=source)

        self.clear_message_types(pg.protocol.messages)
        ff = FormatFinder(pg.protocol.messages)
        ff.known_participant_addresses.clear()
        self.assertEqual(len(ff.known_participant_addresses), 0)

        ff.run()

        # Since there are ACKS in this protocol, the engine must be able to assign the correct participant addresses
        # IN CORRECT ORDER!
        self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[0]), "1337")
        self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[1]), "4711")
        self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[2]), "cafe")

    def test_protocol_with_acks_and_checksum(self):
        """
        Run the FormatFinder on a recorded protocol with ack frames and a CRC. Check the
        participant addresses and the preamble/sync/length labels of every message type.
        """
        proto_file = get_path_for_data_file("ack_frames_with_crc.proto.xml")
        protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
        protocol.from_xml_file(filename=proto_file, read_bits=True)

        self.clear_message_types(protocol.messages)

        ff = FormatFinder(protocol.messages)
        ff.known_participant_addresses.clear()

        ff.run()
        self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[0]), "1337")
        self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[1]), "4711")

        for mt in ff.message_types:
            # _assert_label checks for None first, so a missing label fails cleanly
            # instead of raising AttributeError
            self._assert_label(mt, FieldType.Function.PREAMBLE, 0, 16)
            self._assert_label(mt, FieldType.Function.SYNC, 16, 16)
            self._assert_label(mt, FieldType.Function.LENGTH, 32, 8)

    def test_address_engine_performance(self):
        """
        Run the AddressEngine on the 35 message recording. No timing is asserted here;
        the test only checks that find() completes without raising.
        """
        ff, _ = self.get_format_finder_from_protocol_file("35_messages.proto.xml", return_messages=True)

        engine = AddressEngine(ff.hexvectors, ff.participant_indices)
        engine.find()

    def test_paper_example(self):
        """
        Build the small example protocol: two participants, two messages each, with
        the address pairs aabb / bbaa followed by two bytes of payload.
        """
        alice = Participant("Alice", "A")
        bob = Participant("Bob", "B")
        participants = [alice, bob]
        msg1 = Message.from_plain_hex_str("aabb1234")
        msg1.participant = alice
        msg2 = Message.from_plain_hex_str("aabb6789")
        msg2.participant = alice
        msg3 = Message.from_plain_hex_str("bbaa4711")
        msg3.participant = bob
        msg4 = Message.from_plain_hex_str("bbaa1337")
        msg4.participant = bob

        protocol = ProtocolAnalyzer(None)
        protocol.messages.extend([msg1, msg2, msg3, msg4])
        # self.save_protocol("paper_example", protocol)

        bitvectors = FormatFinder.get_bitvectors_from_messages(protocol.messages)
        hexvectors = FormatFinder.get_hexvectors(bitvectors)
        # NOTE(review): this test makes no assertions; it only checks that building the
        # engine does not raise. Consider asserting on address_engine.find_addresses().
        address_engine = AddressEngine(hexvectors, participant_indices=[participants.index(msg.participant)
                                                                        for msg in protocol.messages])

    def test_find_common_sub_sequence(self):
        """The two longest common substrings of the digit sequences must be "1234" and "5678"."""
        from urh.cythonext import awre_util
        str1 = "0612345678"
        str2 = "0756781234"

        seq1 = np.array(list(map(int, str1)), dtype=np.uint8, order="C")
        seq2 = np.array(list(map(int, str2)), dtype=np.uint8, order="C")

        indices = awre_util.find_longest_common_sub_sequence_indices(seq1, seq2)
        self.assertEqual(len(indices), 2)
        for ind in indices:
            s = str1[slice(*ind)]
            self.assertIn(s, ("5678", "1234"))
            self.assertIn(s, str1)
            self.assertIn(s, str2)

    def test_find_first_occurrence(self):
        """
        find_occurrences must report both occurrences of "12345" (at positions 200 and 208).
        It must skip occurrences whose start index is in the ignore array.
        """
        from urh.cythonext import awre_util
        str1 = "00" * 100 + "1234500012345" + "00" * 100
        str2 = "12345"
        seq1 = np.array(list(map(int, str1)), dtype=np.uint8, order="C")
        seq2 = np.array(list(map(int, str2)), dtype=np.uint8, order="C")
        indices = awre_util.find_occurrences(seq1, seq2)
        self.assertEqual(len(indices), 2)
        # Every reported index must point at an actual occurrence, not just the first one
        for index in indices:
            self.assertEqual(str1[index:index + len(str2)], str2)

        # Test with ignoring indices: ignoring 0..204 hides the occurrence at 200
        indices = awre_util.find_occurrences(seq1, seq2, array("L", list(range(0, 205))))
        self.assertEqual(len(indices), 1)

        # Test with ignoring indices: ignoring 0..209 hides both occurrences
        indices = awre_util.find_occurrences(seq1, seq2, array("L", list(range(0, 210))))
        self.assertEqual(len(indices), 0)

        self.assertEqual(awre_util.find_occurrences(seq1, np.ones(10, dtype=np.uint8)), [])