792 lines
35 KiB
Python
792 lines
35 KiB
Python
|
import array
|
||
|
import multiprocessing
|
||
|
import os
|
||
|
import random
|
||
|
import time
|
||
|
from collections import defaultdict
|
||
|
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
|
||
|
from tests.awre.AWRETestCase import AWRETestCase
|
||
|
from tests.utils_testing import get_path_for_data_file
|
||
|
from urh.awre.FormatFinder import FormatFinder
|
||
|
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
|
||
|
from urh.awre.Preprocessor import Preprocessor
|
||
|
from urh.awre.ProtocolGenerator import ProtocolGenerator
|
||
|
from urh.awre.engines.Engine import Engine
|
||
|
from urh.signalprocessing.FieldType import FieldType
|
||
|
from urh.signalprocessing.Message import Message
|
||
|
from urh.signalprocessing.MessageType import MessageType
|
||
|
from urh.signalprocessing.Participant import Participant
|
||
|
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
|
||
|
from urh.util.GenericCRC import GenericCRC
|
||
|
|
||
|
|
||
|
def run_for_num_broken(protocol_nr, num_broken: list, num_messages: int, num_runs: int) -> list:
    """
    Evaluate the format finder on one synthetic protocol for several amounts of broken messages.

    For every entry of ``num_broken`` the protocol is generated ``num_runs`` times. The format finder is
    run on each generated protocol, and two accuracies are averaged over the runs:
    - the accuracy over all messages;
    - the accuracy over only the messages that were not broken.

    :param protocol_nr: number of the synthetic protocol (see ``AWRExperiments.get_protocol``)
    :param num_broken: amounts of broken messages to evaluate
    :param num_messages: number of messages to generate per protocol
    :param num_runs: repetitions per amount of broken messages
    :return: one ``(mean_accuracy, mean_accuracy_without_broken)`` tuple per entry of ``num_broken``
    """
    # Seed both RNGs so that every (worker) invocation is reproducible.
    random.seed(0)
    np.random.seed(0)

    averages = []
    for n_broken in num_broken:
        overall = np.empty(num_runs, dtype=np.float64)
        unbroken_only = np.empty(num_runs, dtype=np.float64)

        for run in range(num_runs):
            proto, expected = AWRExperiments.get_protocol(protocol_nr,
                                                          num_messages=num_messages,
                                                          num_broken_messages=n_broken,
                                                          silent=True)
            AWRExperiments.run_format_finder_for_protocol(proto)

            overall[run] = AWRExperiments.calculate_accuracy(proto.messages, expected)
            unbroken_only[run] = AWRExperiments.calculate_accuracy(proto.messages, expected, n_broken)

        mean_overall = np.mean(overall)
        mean_unbroken = np.mean(unbroken_only)
        averages.append((mean_overall, mean_unbroken))

        summary = "Protocol {} with {} broken: {:>3}% {:>3}%"
        print(summary.format(protocol_nr, n_broken, int(mean_overall), int(mean_unbroken)))

    return averages
|
||
|
|
||
|
|
||
|
class AWRExperiments(AWRETestCase):
    """
    Experiments for the automatic format finder (AWRE).

    The experiments run on eight synthetic protocols (``_prepare_protocol_1`` .. ``_prepare_protocol_8``)
    and on the homematic, enocean and rwe protocols. They measure accuracy and runtime, and export the
    results as matplotlib plots, CSV files (under /tmp) and LaTeX.
    """

    @staticmethod
    def _prepare_protocol_1() -> ProtocolGenerator:
        """Protocol 1 ("common protocol"): 16 bit addresses, 8 bit length and sequence number, two participants."""
        alice = Participant("Alice", address_hex="dead")
        bob = Participant("Bob", address_hex="beef")

        mb = MessageTypeBuilder("data")
        mb.add_label(FieldType.Function.PREAMBLE, 8)
        mb.add_label(FieldType.Function.SYNC, 16)
        mb.add_label(FieldType.Function.LENGTH, 8)
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)

        pg = ProtocolGenerator([mb.message_type],
                               syncs_by_mt={mb.message_type: "0x1337"},
                               participants=[alice, bob])
        return pg

    @staticmethod
    def _prepare_protocol_2() -> ProtocolGenerator:
        """
        Protocol 2 ("unusual field sizes").

        The protocol has a 72 bit preamble, 24 bit addresses and a 16 bit sequence number that increments by 32.
        """
        alice = Participant("Alice", address_hex="dead01")
        bob = Participant("Bob", address_hex="beef24")

        mb = MessageTypeBuilder("data")
        mb.add_label(FieldType.Function.PREAMBLE, 72)
        mb.add_label(FieldType.Function.SYNC, 16)
        mb.add_label(FieldType.Function.LENGTH, 8)
        mb.add_label(FieldType.Function.SRC_ADDRESS, 24)
        mb.add_label(FieldType.Function.DST_ADDRESS, 24)
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)

        pg = ProtocolGenerator([mb.message_type],
                               syncs_by_mt={mb.message_type: "0x1337"},
                               preambles_by_mt={mb.message_type: "10" * 36},
                               sequence_number_increment=32,
                               participants=[alice, bob])

        return pg

    @staticmethod
    def _prepare_protocol_3() -> ProtocolGenerator:
        """Protocol 3: "data" (10 data bytes) and "ack" message types, both protected by CRC8 CCITT."""
        alice = Participant("Alice", address_hex="1337")
        bob = Participant("Bob", address_hex="beef")

        checksum = GenericCRC.from_standard_checksum("CRC8 CCITT")

        mb = MessageTypeBuilder("data")
        mb.add_label(FieldType.Function.PREAMBLE, 16)
        mb.add_label(FieldType.Function.SYNC, 16)
        mb.add_label(FieldType.Function.LENGTH, 8)
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
        mb.add_label(FieldType.Function.DATA, 10 * 8)
        mb.add_checksum_label(8, checksum)

        mb_ack = MessageTypeBuilder("ack")
        mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
        mb_ack.add_label(FieldType.Function.SYNC, 16)
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
        mb_ack.add_checksum_label(8, checksum)

        pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
                               syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
                               preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
                               participants=[alice, bob])

        return pg

    @staticmethod
    def _prepare_protocol_4() -> ProtocolGenerator:
        """Protocol 4: "data1" (8 byte), "data2" (64 byte) and "ack" message types with CRC16 CCITT."""
        alice = Participant("Alice", address_hex="1337")
        bob = Participant("Bob", address_hex="beef")

        checksum = GenericCRC.from_standard_checksum("CRC16 CCITT")

        mb = MessageTypeBuilder("data1")
        mb.add_label(FieldType.Function.PREAMBLE, 16)
        mb.add_label(FieldType.Function.SYNC, 16)
        mb.add_label(FieldType.Function.LENGTH, 8)
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
        mb.add_label(FieldType.Function.DATA, 8 * 8)
        mb.add_checksum_label(16, checksum)

        mb2 = MessageTypeBuilder("data2")
        mb2.add_label(FieldType.Function.PREAMBLE, 16)
        mb2.add_label(FieldType.Function.SYNC, 16)
        mb2.add_label(FieldType.Function.LENGTH, 8)
        mb2.add_label(FieldType.Function.SRC_ADDRESS, 16)
        mb2.add_label(FieldType.Function.DST_ADDRESS, 16)
        mb2.add_label(FieldType.Function.DATA, 64 * 8)
        mb2.add_checksum_label(16, checksum)

        mb_ack = MessageTypeBuilder("ack")
        mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
        mb_ack.add_label(FieldType.Function.SYNC, 16)
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
        mb_ack.add_checksum_label(16, checksum)

        mt1, mt2, mt3 = mb.message_type, mb2.message_type, mb_ack.message_type

        preamble = "10001000" * 2

        pg = ProtocolGenerator([mt1, mt2, mt3],
                               syncs_by_mt={mt1: "0x9a7d", mt2: "0x9a7d", mt3: "0x9a7d"},
                               preambles_by_mt={mt1: preamble, mt2: preamble, mt3: preamble},
                               participants=[alice, bob])

        return pg

    @staticmethod
    def _prepare_protocol_5() -> ProtocolGenerator:
        """Protocol 5: three participants, "data" and "ack" message types, no checksum."""
        alice = Participant("Alice", address_hex="1337")
        bob = Participant("Bob", address_hex="beef")
        carl = Participant("Carl", address_hex="cafe")

        mb = MessageTypeBuilder("data")
        mb.add_label(FieldType.Function.PREAMBLE, 16)
        mb.add_label(FieldType.Function.SYNC, 16)
        mb.add_label(FieldType.Function.LENGTH, 8)
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)

        mb_ack = MessageTypeBuilder("ack")
        mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
        mb_ack.add_label(FieldType.Function.SYNC, 16)
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)

        pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
                               syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
                               preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
                               participants=[alice, bob, carl])

        return pg

    @staticmethod
    def _prepare_protocol_6() -> ProtocolGenerator:
        """
        Protocol 6 ("short address").

        The protocol has a single 8 bit source address field and no destination address field. The
        message type has no preamble label, although a preamble is configured for the generator.
        """
        alice = Participant("Alice", address_hex="24")
        broadcast = Participant("Bob", address_hex="ff")

        mb = MessageTypeBuilder("data")
        mb.add_label(FieldType.Function.SYNC, 16)
        mb.add_label(FieldType.Function.LENGTH, 8)
        mb.add_label(FieldType.Function.SRC_ADDRESS, 8)
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)

        pg = ProtocolGenerator([mb.message_type],
                               syncs_by_mt={mb.message_type: "0x8e88"},
                               preambles_by_mt={mb.message_type: "10" * 8},
                               participants=[alice, broadcast])

        return pg

    @staticmethod
    def _prepare_protocol_7() -> ProtocolGenerator:
        """
        Protocol 7: four participants with 24 bit addresses and CRC16 CC1101.

        Each of the three message types ("data", "ack", "kex") uses its own preamble length and sync word.
        """
        alice = Participant("Alice", address_hex="313370")
        bob = Participant("Bob", address_hex="031337")
        charly = Participant("Charly", address_hex="110000")
        daniel = Participant("Daniel", address_hex="001100")
        # TODO: sometimes send messages to a broadcast participant, e.g. Participant("Broadcast", address_hex="ff")

        checksum = GenericCRC.from_standard_checksum("CRC16 CC1101")

        mb = MessageTypeBuilder("data")
        mb.add_label(FieldType.Function.PREAMBLE, 16)
        mb.add_label(FieldType.Function.SYNC, 16)
        mb.add_label(FieldType.Function.LENGTH, 8)
        mb.add_label(FieldType.Function.DST_ADDRESS, 24)
        mb.add_label(FieldType.Function.SRC_ADDRESS, 24)
        mb.add_label(FieldType.Function.DATA, 8 * 8)
        mb.add_checksum_label(16, checksum)

        mb_ack = MessageTypeBuilder("ack")
        mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
        mb_ack.add_label(FieldType.Function.SYNC, 16)
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 24)
        mb_ack.add_checksum_label(16, checksum)

        mb_kex = MessageTypeBuilder("kex")
        mb_kex.add_label(FieldType.Function.PREAMBLE, 24)
        mb_kex.add_label(FieldType.Function.SYNC, 16)
        mb_kex.add_label(FieldType.Function.DST_ADDRESS, 24)
        mb_kex.add_label(FieldType.Function.SRC_ADDRESS, 24)
        mb_kex.add_label(FieldType.Function.DATA, 64 * 8)
        mb_kex.add_checksum_label(16, checksum)

        pg = ProtocolGenerator([mb.message_type, mb_ack.message_type, mb_kex.message_type],
                               syncs_by_mt={mb.message_type: "0x0420", mb_ack.message_type: "0x2222",
                                            mb_kex.message_type: "0x6767"},
                               preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 4,
                                                mb_kex.message_type: "10" * 12},
                               participants=[alice, bob, charly, daniel])

        return pg

    @staticmethod
    def _prepare_protocol_8() -> ProtocolGenerator:
        """
        Protocol 8 ("nibble fields + LE").

        The protocol has a 4 bit preamble and sync, and 16 bit little endian length and sequence number
        fields. It has a single participant and two data message types (542 and 260 bytes).
        """
        alice = Participant("Alice")

        mb = MessageTypeBuilder("data1")
        mb.add_label(FieldType.Function.PREAMBLE, 4)
        mb.add_label(FieldType.Function.SYNC, 4)
        mb.add_label(FieldType.Function.LENGTH, 16)
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
        mb.add_label(FieldType.Function.DATA, 8 * 542)

        mb2 = MessageTypeBuilder("data2")
        mb2.add_label(FieldType.Function.PREAMBLE, 4)
        mb2.add_label(FieldType.Function.SYNC, 4)
        mb2.add_label(FieldType.Function.LENGTH, 16)
        mb2.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
        mb2.add_label(FieldType.Function.DATA, 8 * 260)

        pg = ProtocolGenerator([mb.message_type, mb2.message_type],
                               syncs_by_mt={mb.message_type: "0x9", mb2.message_type: "0x9"},
                               preambles_by_mt={mb.message_type: "10" * 2, mb2.message_type: "10" * 2},
                               sequence_number_increment=32,
                               participants=[alice],
                               little_endian=True)

        return pg

    def test_export_to_latex(self):
        """Export all eight synthetic protocols into one LaTeX file, replacing any previous export."""
        filename = os.path.expanduser("~/GIT/publications/awre/USENIX/protocols.tex")
        if os.path.isfile(filename):
            os.remove(filename)

        for i in range(1, 9):
            pg = getattr(self, "_prepare_protocol_" + str(i))()
            pg.export_to_latex(filename, i)

    @classmethod
    def get_protocol(cls, protocol_number: int, num_messages, num_broken_messages=0, silent=False):
        """
        Generate a synthetic protocol and the message types the format finder is expected to infer.

        Messages alternate between the participants, and every message is answered by an ack message
        when the protocol has an "ack" message type. The result may therefore contain one message
        more than ``num_messages``.

        :param protocol_number: which ``_prepare_protocol_<n>`` generator to use (1 to 8)
        :param num_messages: minimum number of messages to generate
        :param num_broken_messages: number of leading messages to corrupt. Each is corrupted by
                                    randomizing its bits from a random position in its first half onwards.
        :param silent: forwarded to ``save_protocol``
        :return: a tuple ``(protocol, expected_message_types)``.
                 The protocol's message type information has been cleared. ``expected_message_types``
                 holds one message type per message, with the DATA labels removed.
        :raises ValueError: if ``protocol_number`` is unknown
        """
        preparers = {
            1: cls._prepare_protocol_1,
            2: cls._prepare_protocol_2,
            3: cls._prepare_protocol_3,
            4: cls._prepare_protocol_4,
            5: cls._prepare_protocol_5,
            6: cls._prepare_protocol_6,
            7: cls._prepare_protocol_7,
            8: cls._prepare_protocol_8,
        }
        try:
            prepare = preparers[protocol_number]
        except KeyError:
            raise ValueError("Unknown protocol number") from None
        pg = prepare()

        messages_types_with_data_field = [mt for mt in pg.protocol.message_types
                                          if mt.get_first_label_with_type(FieldType.Function.DATA) is not None]
        # MessageType is list-like, so compare against None instead of relying on truthiness
        ack_message_type = next((mt for mt in pg.protocol.message_types if "ack" in mt.name), None)

        i = -1
        while len(pg.protocol.messages) < num_messages:
            i += 1
            source = pg.participants[i % len(pg.participants)]
            destination = pg.participants[(i + 1) % len(pg.participants)]

            if len(messages_types_with_data_field) == 0:
                # No explicit data field: alternate between 8 and 64 byte payloads
                data_bytes = 8 if i % 2 == 0 else 64
                data = "".join(random.choice(["0", "1"]) for _ in range(data_bytes * 8))
                pg.generate_message(data=data, source=source, destination=destination)
            else:
                # Cycle through the message types with a data field and fill their data label exactly
                mt = messages_types_with_data_field[i % len(messages_types_with_data_field)]
                data_length = mt.get_first_label_with_type(FieldType.Function.DATA).length
                data = "".join(random.choice(["0", "1"]) for _ in range(data_length))
                pg.generate_message(message_type=mt, data=data, source=source, destination=destination)

            if ack_message_type is not None:
                pg.generate_message(message_type=ack_message_type, data="", source=destination, destination=source)

        for i in range(num_broken_messages):
            msg = pg.protocol.messages[i]
            pos = random.randint(0, len(msg.plain_bits) // 2)
            msg.plain_bits[pos:] = array.array("B",
                                               [random.randint(0, 1) for _ in range(len(msg.plain_bits) - pos)])

        if num_broken_messages == 0:
            cls.save_protocol("protocol{}_{}_messages".format(protocol_number, num_messages), pg, silent=silent)
        else:
            cls.save_protocol("protocol{}_{}_broken".format(protocol_number, num_broken_messages), pg, silent=silent)

        expected_message_types = [msg.message_type for msg in pg.protocol.messages]

        # Delete message type information -> no prior knowledge
        cls.clear_message_types(pg.protocol.messages)

        # Delete data labels if present; the format finder is not expected to find them
        for mt in expected_message_types:
            data_lbl = mt.get_first_label_with_type(FieldType.Function.DATA)
            if data_lbl is not None:
                mt.remove(data_lbl)

        return pg.protocol, expected_message_types

    @staticmethod
    def calculate_accuracy(messages, expected_labels, num_broken_messages=0):
        """
        Calculate the accuracy of the assigned message types compared to the expected ones.

        Each evaluated message scores the fraction of its expected labels that are present in the message
        type assigned to it. The accuracy is the mean score over all evaluated messages, in percent.
        It is 100% when every expected label of every evaluated message is present.

        The first ``num_broken_messages`` messages are skipped, so that broken messages can be excluded.

        :type messages: list of Message
        :type expected_labels: list of MessageType
        :param num_broken_messages: number of leading messages to ignore
        :return: accuracy in percent, or 0 if no message is left to evaluate
        """
        num_evaluated = len(messages) - num_broken_messages
        if num_evaluated <= 0:
            # Previously only the exact-zero case was caught; a negative count yielded -0.0
            return 0

        accuracy = sum(len(set(expected_labels[i]) & set(messages[i].message_type)) / len(expected_labels[i])
                       for i in range(num_broken_messages, len(messages)))
        accuracy /= num_evaluated

        return accuracy * 100

    def test_against_num_messages(self):
        """Plot and export the accuracy of every synthetic protocol for 1 to 23 messages."""
        num_messages = list(range(1, 24, 1))
        accuracies = defaultdict(list)

        protocols = [1, 2, 3, 4, 5, 6, 7, 8]

        random.seed(0)
        np.random.seed(0)
        for protocol_nr in protocols:
            for n in num_messages:
                protocol, expected_labels = self.get_protocol(protocol_nr, num_messages=n)
                self.run_format_finder_for_protocol(protocol)

                accuracy = self.calculate_accuracy(protocol.messages, expected_labels)
                accuracies["protocol {}".format(protocol_nr)].append(accuracy)

        self.__plot(num_messages, accuracies, xlabel="Number of messages", ylabel="Accuracy in %", grid=True)
        self.__export_to_csv("/tmp/accuray-vs-messages", num_messages, accuracies)

    def test_against_error(self):
        """
        Plot and export the accuracy versus the share of broken messages (0% to 100%).

        Each protocol is evaluated in its own worker process.
        """
        Engine._DEBUG_ = False
        Preprocessor._DEBUG_ = False

        num_runs = 100

        num_messages = 30
        num_broken_messages = list(range(0, num_messages + 1))
        accuracies = defaultdict(list)
        accuracies_without_broken = defaultdict(list)

        protocols = [1, 2, 3, 4, 5, 6, 7, 8]

        random.seed(0)
        np.random.seed(0)

        with multiprocessing.Pool() as p:
            result = p.starmap(run_for_num_broken,
                               [(i, num_broken_messages, num_messages, num_runs) for i in protocols])

        # Pair results with the actual protocol numbers (starmap preserves input order); this used to
        # rely on the enumerate index + 1 and would mislabel any other protocol selection.
        for protocol_nr, acc in zip(protocols, result):
            accuracies["protocol {}".format(protocol_nr)] = [a[0] for a in acc]
            accuracies_without_broken["protocol {}".format(protocol_nr)] = [a[1] for a in acc]

        self.__plot(100 * np.array(num_broken_messages) / num_messages, accuracies,
                    title="Overall Accuracy vs percentage of broken messages",
                    xlabel="Broken messages in %",
                    ylabel="Accuracy in %", grid=True)
        self.__plot(100 * np.array(num_broken_messages) / num_messages, accuracies_without_broken,
                    title=" Accuracy of unbroken vs percentage of broken messages",
                    xlabel="Broken messages in %",
                    ylabel="Accuracy in %", grid=True)
        self.__export_to_csv("/tmp/accuray-vs-error", num_broken_messages, accuracies, relative=num_messages)
        self.__export_to_csv("/tmp/accuray-vs-error-without-broken", num_broken_messages, accuracies_without_broken,
                             relative=num_messages)

    def test_performance(self):
        """Measure the runtime of the format finder on synthetic protocol 1 with 200 messages."""
        Engine._DEBUG_ = False
        Preprocessor._DEBUG_ = False

        num_messages = list(range(200, 205, 5))
        protocols = [1]

        random.seed(0)
        np.random.seed(0)

        performances = defaultdict(list)

        for protocol_nr in protocols:
            print("Running for protocol", protocol_nr)
            for messages in num_messages:
                protocol, _ = self.get_protocol(protocol_nr, messages, silent=True)

                t = time.time()
                self.run_format_finder_for_protocol(protocol)
                performances["protocol {}".format(protocol_nr)].append(time.time() - t)

        # self.__plot(num_messages, performances, xlabel="Number of messages", ylabel="Time in seconds", grid=True)

    def test_performance_real_protocols(self):
        """
        Measure the mean runtime of the format finder on the enocean, homematic and rwe protocols.

        Message counts range from 8 to 508 in steps of 4, and each measurement is averaged over 100 runs.
        """
        Engine._DEBUG_ = False
        Preprocessor._DEBUG_ = False

        num_runs = 100

        num_messages = list(range(8, 512, 4))
        protocol_names = ["enocean", "homematic", "rwe"]

        random.seed(0)
        np.random.seed(0)

        performances = defaultdict(list)

        for protocol_name in protocol_names:
            for messages in num_messages:
                if protocol_name == "homematic":
                    protocol = self.generate_homematic(messages, save_protocol=False)
                elif protocol_name == "enocean":
                    protocol = self.generate_enocean(messages, save_protocol=False)
                elif protocol_name == "rwe":
                    protocol = self.generate_rwe(messages, save_protocol=False)
                else:
                    raise ValueError("Unknown protocol name")

                tmp_performances = np.empty(num_runs, dtype=np.float64)
                for i in range(num_runs):
                    print("\r{0} with {1:02d} messages ({2}/{3} runs)".format(protocol_name, messages, i + 1, num_runs),
                          flush=True, end="")

                    t = time.time()
                    self.run_format_finder_for_protocol(protocol)
                    tmp_performances[i] = time.time() - t
                    # Reset inferred message types so every run starts without prior knowledge
                    self.clear_message_types(protocol.messages)

                mean_performance = tmp_performances.mean()
                print(" {:.2f}s".format(mean_performance))
                performances[protocol_name].append(mean_performance)

        self.__plot(num_messages, performances, xlabel="Number of messages", ylabel="Time in seconds", grid=True)
        self.__export_to_csv("/tmp/performance.csv", num_messages, performances)

    @staticmethod
    def __export_to_csv(filename: str, x: list, y: dict, relative=None):
        """
        Write x values and one column per key of ``y`` (sorted by key) to a CSV file.

        Every header and data line ends with a trailing comma.

        :param filename: target path; ".csv" is appended if missing
        :param x: x values, one per row
        :param y: mapping of column caption to a list of values aligned with ``x``
        :param relative: if given, an extra column "NRel" is written with the value ``100 * x / relative``
        """
        if not filename.endswith(".csv"):
            filename += ".csv"

        captions = sorted(y)  # loop-invariant, computed once

        with open(filename, "w") as f:
            f.write("N,")
            if relative is not None:
                f.write("NRel,")
            for y_cap in captions:
                f.write(y_cap + ",")
            f.write("\n")

            for i, x_val in enumerate(x):
                f.write("{},".format(x_val))
                if relative is not None:
                    f.write("{},".format(100 * x_val / relative))

                for y_cap in captions:
                    f.write("{},".format(y[y_cap][i]))
                f.write("\n")

    @staticmethod
    def __plot(x: list, y: dict, xlabel: str, ylabel: str, grid=False, title=None):
        """Plot one line per key of ``y`` (sorted by key) against ``x`` and show the figure."""
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)

        for y_cap, y_values in sorted(y.items()):
            plt.plot(x, y_values, label=y_cap)

        if grid:
            plt.grid()

        if title:
            plt.title(title)

        plt.legend()
        plt.show()

    @staticmethod
    def run_format_finder_for_protocol(protocol: ProtocolAnalyzer):
        """Run the FormatFinder without known participant addresses and assign the found message types."""
        ff = FormatFinder(protocol.messages)
        ff.known_participant_addresses.clear()
        ff.run()

        for msg_type, indices in ff.existing_message_types.items():
            for i in indices:
                protocol.messages[i].message_type = msg_type

    @classmethod
    def generate_homematic(cls, num_messages: int, save_protocol=True):
        """
        Generate a homematic-like protocol with m-, c-, r- and a-frames and CRC16 CC1101.

        The generated messages cycle through the four frame types and alternate between the two participants.

        :return: the generated ProtocolAnalyzer with message type information cleared
        """
        mb_m_frame = MessageTypeBuilder("mframe")
        mb_c_frame = MessageTypeBuilder("cframe")
        mb_r_frame = MessageTypeBuilder("rframe")
        mb_a_frame = MessageTypeBuilder("aframe")

        participants = [Participant("CCU", address_hex="3927cc"), Participant("Switch", address_hex="3101cc")]

        checksum = GenericCRC.from_standard_checksum("CRC16 CC1101")
        for mb_builder in [mb_m_frame, mb_c_frame, mb_r_frame, mb_a_frame]:
            mb_builder.add_label(FieldType.Function.PREAMBLE, 32)
            mb_builder.add_label(FieldType.Function.SYNC, 32)
            mb_builder.add_label(FieldType.Function.LENGTH, 8)
            mb_builder.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
            mb_builder.add_label(FieldType.Function.TYPE, 16)
            mb_builder.add_label(FieldType.Function.SRC_ADDRESS, 24)
            mb_builder.add_label(FieldType.Function.DST_ADDRESS, 24)
            if mb_builder.name == "mframe":
                mb_builder.add_label(FieldType.Function.DATA, 16, name="command")
            elif mb_builder.name == "cframe":
                mb_builder.add_label(FieldType.Function.DATA, 16 * 4, name="command+challenge+magic")
            elif mb_builder.name == "rframe":
                mb_builder.add_label(FieldType.Function.DATA, 32 * 4, name="cipher")
            elif mb_builder.name == "aframe":
                mb_builder.add_label(FieldType.Function.DATA, 10 * 4, name="command + auth")
            mb_builder.add_checksum_label(16, checksum)

        message_types = [mb_m_frame.message_type, mb_c_frame.message_type, mb_r_frame.message_type,
                         mb_a_frame.message_type]
        preamble = "0xaaaaaaaa"
        sync = "0xe9cae9ca"
        initial_sequence_number = 36
        pg = ProtocolGenerator(message_types, participants,
                               preambles_by_mt={mt: preamble for mt in message_types},
                               syncs_by_mt={mt: sync for mt in message_types},
                               sequence_numbers={mt: initial_sequence_number for mt in message_types},
                               message_type_codes={mb_m_frame.message_type: 42560,
                                                   mb_c_frame.message_type: 40962,
                                                   mb_r_frame.message_type: 40963,
                                                   mb_a_frame.message_type: 32770})

        for i in range(num_messages):
            mt = pg.message_types[i % 4]
            data_length = mt.get_first_label_with_type(FieldType.Function.DATA).length
            data = "".join(random.choice(["0", "1"]) for _ in range(data_length))
            pg.generate_message(mt, data, source=pg.participants[i % 2], destination=pg.participants[(i + 1) % 2])

        if save_protocol:
            cls.save_protocol("homematic", pg)

        cls.clear_message_types(pg.messages)
        return pg.protocol

    @classmethod
    def generate_enocean(cls, num_messages: int, save_protocol=True):
        """
        Build a protocol from recorded enocean bit strings (one message per line of enocean_bits.txt).

        Lines are reused cyclically when ``num_messages`` exceeds the number of recordings. Every message
        gets the same empty message type.
        """
        filename = get_path_for_data_file("enocean_bits.txt")
        with open(filename, "r") as f:
            enocean_bits = [line.strip() for line in f]

        protocol = ProtocolAnalyzer(None)
        message_type = MessageType("empty")
        for i in range(num_messages):
            msg = Message.from_plain_bits_str(enocean_bits[i % len(enocean_bits)])
            msg.message_type = message_type
            protocol.messages.append(msg)

        if save_protocol:
            cls.save_protocol("enocean", protocol)

        return protocol

    @classmethod
    def generate_rwe(cls, num_messages: int, save_protocol=True):
        """
        Build a protocol from the recorded rwe messages in rwe.proto.xml.

        Messages are reused cyclically when ``num_messages`` exceeds the number of recordings; repeated
        entries are the same Message objects. Every message gets the same empty message type.
        """
        proto_file = get_path_for_data_file("rwe.proto.xml")
        protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
        protocol.from_xml_file(filename=proto_file, read_bits=True)
        messages = protocol.messages

        result = ProtocolAnalyzer(None)
        message_type = MessageType("empty")
        for i in range(num_messages):
            msg = messages[i % len(messages)]  # type: Message
            msg.message_type = message_type
            result.messages.append(msg)

        if save_protocol:
            cls.save_protocol("rwe", result)

        return result

    def test_export_latex_table(self):
        """Write a LaTeX table summarizing the field layout of the eight synthetic protocols."""
        def bold_latex(s):
            return r"\textbf{" + str(s) + r"}"

        comments = {
            1: "common protocol",
            2: "unusual field sizes",
            3: "contains ack and CRC8 CCITT",
            4: "contains ack and CRC16 CCITT",
            5: "three participants with ack frame",
            6: "short address",
            7: "four participants, varying preamble size, varying sync words",
            8: "nibble fields + LE"
        }

        # Per protocol: which field sizes are highlighted in bold in the table
        bold = {i: defaultdict(bool) for i in range(1, 9)}
        bold[2][FieldType.Function.PREAMBLE] = True
        bold[2][FieldType.Function.SRC_ADDRESS] = True
        bold[2][FieldType.Function.DST_ADDRESS] = True

        bold[3][FieldType.Function.CHECKSUM] = True

        bold[4][FieldType.Function.CHECKSUM] = True

        bold[6][FieldType.Function.SRC_ADDRESS] = True

        bold[7][FieldType.Function.PREAMBLE] = True
        bold[7][FieldType.Function.SYNC] = True
        bold[7][FieldType.Function.SRC_ADDRESS] = True
        bold[7][FieldType.Function.DST_ADDRESS] = True

        bold[8][FieldType.Function.PREAMBLE] = True
        bold[8][FieldType.Function.SYNC] = True

        filename = os.path.expanduser("~/GIT/publications/awre/USENIX/protocol_table.tex")
        rowcolors = [r"\rowcolor{black!10}", r"\rowcolor{black!20}"]

        with open(filename, "w") as f:
            f.write(r"\begin{table*}[!h]" + "\n")
            f.write(
                "\t" + r"\caption{Properties of tested protocols whereby $\times$ means field is not present and $N_P$ is the number of participants.}" + "\n")
            f.write("\t" + r"\label{tab:protocols}" + "\n")
            f.write("\t" + r"\centering" + "\n")
            f.write("\t" + r"\begin{tabularx}{\linewidth}{cp{2.5cm}llcccccccc}" + "\n")
            f.write("\t\t" + r"\hline" + "\n")
            f.write("\t\t" + r"\rowcolor{black!90}" + "\n")
            f.write("\t\t" + r"\textcolor{white}{\textbf{\#}} & "
                             r"\textcolor{white}{\textbf{Comment}} & "
                             r"\textcolor{white}{$\mathbf{ N_P }$} & "
                             r"\textcolor{white}{\textbf{Message}} & "
                             r"\textcolor{white}{\textbf{Even/odd}} & "
                             r"\multicolumn{7}{c}{\textcolor{white}{\textbf{Size of field in bit (BE=Big Endian, LE=Little Endian)}}}\\"
                             "\n\t\t"
                             r"\rowcolor{black!90}"
                             "\n\t\t"
                             r"& & & \textcolor{white}{\textbf{Type}} & \textcolor{white}{\textbf{message data}} &"
                             r"\textcolor{white}{Preamble} & "
                             r"\textcolor{white}{Sync} & "
                             r"\textcolor{white}{Length} & "
                             r"\textcolor{white}{SRC} & "
                             r"\textcolor{white}{DST} & "
                             r"\textcolor{white}{SEQ Nr} & "
                             r"\textcolor{white}{CRC} \\" + "\n")
            f.write("\t\t" + r"\hline" + "\n")

            rowcolor_index = 0
            for i in range(1, 9):
                pg = getattr(self, "_prepare_protocol_" + str(i))()
                assert isinstance(pg, ProtocolGenerator)

                try:
                    data1 = next(mt for mt in pg.message_types if mt.name == "data1")
                    data2 = next(mt for mt in pg.message_types if mt.name == "data2")

                    data1_len = data1.get_first_label_with_type(FieldType.Function.DATA).length // 8
                    data2_len = data2.get_first_label_with_type(FieldType.Function.DATA).length // 8

                except StopIteration:
                    # Protocols without data1/data2 types use the 8/64 byte defaults of get_protocol
                    data1_len, data2_len = 8, 64

                rowcolor = rowcolors[rowcolor_index % len(rowcolors)]
                rowcount = 0
                for j, mt in enumerate(pg.message_types):
                    # data2 shares its row with data1 (shown as "data1_len/data2_len byte")
                    if mt.name == "data2":
                        continue

                    rowcount += 1
                    if j == 0:
                        protocol_nr, participants = str(i), len(pg.participants)
                        if participants > 2:
                            participants = bold_latex(participants)
                    else:
                        protocol_nr, participants = " ", " "

                    f.write("\t\t" + rowcolor + "\n")

                    if len(pg.message_types) == 1 or (
                            mt.name == "data1" and "ack" not in {m.name for m in pg.message_types}):
                        f.write("\t\t{} & {} & {} & {} &".format(protocol_nr, comments[i], participants,
                                                                 mt.name.replace("1", "")))
                    elif j == len(pg.message_types) - 1:
                        # Last row of a multi-row protocol: negative multirow spans the comment upwards
                        f.write(
                            "\t\t{} & \\multirow{{{}}}{{\\linewidth}}{{{}}} & {} & {} &".format(protocol_nr, -rowcount,
                                                                                               comments[i],
                                                                                               participants,
                                                                                               mt.name.replace("1",
                                                                                                               "")))
                    else:
                        f.write("\t\t{} & & {} & {} &".format(protocol_nr, participants, mt.name.replace("1", "")))
                    data_lbl = mt.get_first_label_with_type(FieldType.Function.DATA)

                    if mt.name == "data1" or mt.name == "data2":
                        f.write("{}/{} byte &".format(data1_len, data2_len))
                    elif mt.name == "data" and data_lbl is None:
                        f.write("{}/{} byte &".format(data1_len, data2_len))
                    elif data_lbl is not None:
                        f.write("{0}/{0} byte & ".format(data_lbl.length // 8))
                    else:
                        f.write(r"$ \times $ & ")

                    for t in (FieldType.Function.PREAMBLE, FieldType.Function.SYNC, FieldType.Function.LENGTH,
                              FieldType.Function.SRC_ADDRESS, FieldType.Function.DST_ADDRESS,
                              FieldType.Function.SEQUENCE_NUMBER,
                              FieldType.Function.CHECKSUM):
                        lbl = mt.get_first_label_with_type(t)
                        if lbl is not None:
                            if bold[i][lbl.field_type.function]:
                                f.write(bold_latex(lbl.length))
                            else:
                                f.write(str(lbl.length))
                            if lbl.length > 8 and t in (FieldType.Function.LENGTH, FieldType.Function.SEQUENCE_NUMBER):
                                f.write(" ({})".format(bold_latex("LE") if pg.little_endian else "BE"))
                        else:
                            f.write(r"$ \times $")

                        if t != FieldType.Function.CHECKSUM:
                            f.write(" & ")
                        else:
                            f.write(r"\\" + "\n")

                rowcolor_index += 1

            f.write("\t" + r"\end{tabularx}" + "\n")

            f.write(r"\end{table*}" + "\n")
|