More refactoring, more testing

This commit is contained in:
s3lph 2022-04-18 19:44:19 +02:00
parent 6dba5cc37f
commit 6aaf2f3baa
12 changed files with 432 additions and 101 deletions

View file

@ -50,7 +50,7 @@ class SchleuderApi:
# Perform the actual request # Perform the actual request
req = urllib.request.Request(url, data=payload, method=method, headers=self._headers) req = urllib.request.Request(url, data=payload, method=method, headers=self._headers)
resp = urllib.request.urlopen(req, context=context) resp = urllib.request.urlopen(req, context=context)
respdata: bytes = resp.read().decode() respdata: str = resp.read().decode()
if len(respdata) > 0: if len(respdata) > 0:
return json.loads(respdata) return json.loads(respdata)
return None return None

View file

@ -16,7 +16,7 @@ from datetime import datetime
import pgpy # type: ignore import pgpy # type: ignore
from multischleuder.reporting import ConflictMessage from multischleuder.reporting import ConflictMessage, Message
from multischleuder.types import SchleuderKey, SchleuderSubscriber from multischleuder.types import SchleuderKey, SchleuderSubscriber
@ -35,13 +35,13 @@ class KeyConflictResolution:
def resolve(self, def resolve(self,
target: str, target: str,
mail_from: str, mail_from: str,
subscriptions: List[SchleuderSubscriber]) -> Tuple[List[SchleuderSubscriber], List[ConflictMessage]]: subscriptions: List[SchleuderSubscriber]) -> Tuple[List[SchleuderSubscriber], List[Optional[Message]]]:
subs: Dict[str, List[SchleuderSubscriber]] = OrderedDict() subs: Dict[str, List[SchleuderSubscriber]] = OrderedDict()
for s in subscriptions: for s in subscriptions:
subs.setdefault(s.email, []).append(s) subs.setdefault(s.email, []).append(s)
# Perform conflict resolution for each set of subscribers with the same email # Perform conflict resolution for each set of subscribers with the same email
resolved: List[SchleuderSubscriber] = [] resolved: List[SchleuderSubscriber] = []
conflicts: List[ConflictMessage] = [] conflicts: List[Optional[Message]] = []
for c in subs.values(): for c in subs.values():
r, m = self._resolve(target, mail_from, c) r, m = self._resolve(target, mail_from, c)
if r is not None: if r is not None:
@ -86,6 +86,7 @@ class KeyConflictResolution:
def _should_send(self, digest: str) -> bool: def _should_send(self, digest: str) -> bool:
now = int(datetime.utcnow().timestamp()) now = int(datetime.utcnow().timestamp())
try:
with open(self._state_file, 'a+') as f: with open(self._state_file, 'a+') as f:
state: Dict[str, int] = {} state: Dict[str, int] = {}
if f.tell() > 0: if f.tell() > 0:
@ -109,6 +110,9 @@ class KeyConflictResolution:
f.truncate() f.truncate()
json.dump(state, f) json.dump(state, f)
return send return send
except BaseException:
self._logger.exception('Cannot open or write statefile. Not sending any messages!')
return False
def _make_digest(self, chosen: SchleuderSubscriber, candidates: List[SchleuderSubscriber]) -> str: def _make_digest(self, chosen: SchleuderSubscriber, candidates: List[SchleuderSubscriber]) -> str:
# Sort so the hash stays the same if the set of subscriptions is the same. # Sort so the hash stays the same if the set of subscriptions is the same.

View file

@ -1,5 +1,5 @@
from typing import Any, Dict, List from typing import Any, Dict, List, Tuple
import argparse import argparse
import logging import logging
@ -11,18 +11,21 @@ from multischleuder import __version__
from multischleuder.api import SchleuderApi from multischleuder.api import SchleuderApi
from multischleuder.conflict import KeyConflictResolution from multischleuder.conflict import KeyConflictResolution
from multischleuder.processor import MultiList from multischleuder.processor import MultiList
from multischleuder.reporting import Reporter
from multischleuder.smtp import SmtpClient from multischleuder.smtp import SmtpClient
def parse_list_config(api: SchleuderApi, def parse_list_config(api: SchleuderApi,
kcr: KeyConflictResolution, kcr: KeyConflictResolution,
smtp: SmtpClient,
config: Dict[str, Any]) -> 'MultiList': config: Dict[str, Any]) -> 'MultiList':
target = config['target'] target = config['target']
default_from = target.replace('@', '-owner@') default_from = target.replace('@', '-owner@')
mail_from = config.get('from', default_from) mail_from = config.get('from', default_from)
banned = config.get('banned', []) banned = config.get('banned', [])
unmanaged = config.get('unmanaged', []) unmanaged = config.get('unmanaged', [])
reporter = Reporter(
send_admin_reports=config.get('send_admin_reports', True),
send_conflict_messages=config.get('send_conflict_messages', True))
return MultiList( return MultiList(
sources=config['sources'], sources=config['sources'],
target=target, target=target,
@ -31,13 +34,11 @@ def parse_list_config(api: SchleuderApi,
mail_from=mail_from, mail_from=mail_from,
api=api, api=api,
kcr=kcr, kcr=kcr,
smtp=smtp, reporter=reporter
send_admin_reports=config.get('send_admin_reports', True),
send_conflict_messages=config.get('send_conflict_messages', True),
) )
def parse_config(ns: argparse.Namespace) -> List['MultiList']: def parse_config(ns: argparse.Namespace) -> Tuple[List['MultiList'], SmtpClient]:
with open(ns.config, 'r') as f: with open(ns.config, 'r') as f:
c = yaml.safe_load(f) c = yaml.safe_load(f)
@ -56,9 +57,9 @@ def parse_config(ns: argparse.Namespace) -> List['MultiList']:
lists = [] lists = []
for clist in c.get('lists', []): for clist in c.get('lists', []):
ml = parse_list_config(api, kcr, smtp, clist) ml = parse_list_config(api, kcr, clist)
lists.append(ml) lists.append(ml)
return lists return lists, smtp
def main(): def main():
@ -72,6 +73,7 @@ def main():
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel('DEBUG') logger.setLevel('DEBUG')
logger.debug('Verbose logging enabled') logger.debug('Verbose logging enabled')
lists = parse_config(ns) lists, smtp = parse_config(ns)
for lst in lists: for lst in lists:
lst.process(ns.dry_run) lst.process(ns.dry_run)
smtp.send_messages(Reporter.get_messages())

View file

@ -5,8 +5,7 @@ import logging
from multischleuder.api import SchleuderApi from multischleuder.api import SchleuderApi
from multischleuder.conflict import KeyConflictResolution from multischleuder.conflict import KeyConflictResolution
from multischleuder.reporting import AdminReport, Message from multischleuder.reporting import AdminReport, Message, Reporter
from multischleuder.smtp import SmtpClient
from multischleuder.types import SchleuderKey, SchleuderList, SchleuderSubscriber from multischleuder.types import SchleuderKey, SchleuderList, SchleuderSubscriber
@ -20,9 +19,7 @@ class MultiList:
mail_from: str, mail_from: str,
api: SchleuderApi, api: SchleuderApi,
kcr: KeyConflictResolution, kcr: KeyConflictResolution,
smtp: SmtpClient, reporter: Reporter):
send_admin_reports: bool,
send_conflict_messages: bool):
self._sources: List[str] = sources self._sources: List[str] = sources
self._target: str = target self._target: str = target
self._unmanaged: List[str] = unmanaged self._unmanaged: List[str] = unmanaged
@ -30,10 +27,7 @@ class MultiList:
self._mail_from: str = mail_from self._mail_from: str = mail_from
self._api: SchleuderApi = api self._api: SchleuderApi = api
self._kcr: KeyConflictResolution = kcr self._kcr: KeyConflictResolution = kcr
self._smtp: SmtpClient = smtp self._reporter: Reporter = reporter
self._send_admin_reports: bool = send_admin_reports
self._send_conflict_messages: bool = send_conflict_messages
self._messages: List[Message] = []
self._logger: logging.Logger = logging.getLogger() self._logger: logging.Logger = logging.getLogger()
def process(self, dry_run: bool = False): def process(self, dry_run: bool = False):
@ -57,8 +51,7 @@ class MultiList:
all_subs.append(s) all_subs.append(s)
# ... which is taken care of by the key conflict resolution routine # ... which is taken care of by the key conflict resolution routine
resolved, conflicts = self._kcr.resolve(self._target, self._mail_from, all_subs) resolved, conflicts = self._kcr.resolve(self._target, self._mail_from, all_subs)
if self._send_conflict_messages: self._reporter.add_messages(conflicts)
self._messages.extend(conflicts)
intended_subs: Set[SchleuderSubscriber] = set(resolved) intended_subs: Set[SchleuderSubscriber] = set(resolved)
intended_keys: Set[SchleuderKey] = {s.key for s in intended_subs if s.key is not None} intended_keys: Set[SchleuderKey] = {s.key for s in intended_subs if s.key is not None}
# Determine the change set # Determine the change set
@ -86,21 +79,11 @@ class MultiList:
if len(to_add) + len(to_subscribe) + len(to_unsubscribe) + len(to_remove) == 0: if len(to_add) + len(to_subscribe) + len(to_unsubscribe) + len(to_remove) == 0:
self._logger.info(f'No changes for {self._target}') self._logger.info(f'No changes for {self._target}')
else:
if self._send_admin_reports:
for admin in target_admins: for admin in target_admins:
report = AdminReport(self._target, admin.email, self._mail_from, report = AdminReport(self._target, admin.email, self._mail_from,
admin.key.blob if admin.key is not None else None, admin.key.blob if admin.key is not None else None,
to_subscribe, to_unsubscribe, to_update, to_add, to_remove) to_subscribe, to_unsubscribe, to_update, to_add, to_remove)
self._messages.append(report) self._reporter.add_message(report)
print(str(report))
# Finally, send any queued messages.
if len(self._messages) > 0:
self._logger.info(f'Sending f{len(self._messages)} messages')
self._smtp.send_messages(self._messages)
self._messages = []
self._logger.info(f'Finished processing: {self._target}') self._logger.info(f'Finished processing: {self._target}')
def _lists_by_name(self) -> Tuple[SchleuderList, List[SchleuderList]]: def _lists_by_name(self) -> Tuple[SchleuderList, List[SchleuderList]]:

View file

@ -134,7 +134,8 @@ class AdminReport(Message):
removed: Set[SchleuderKey]): removed: Set[SchleuderKey]):
if len(subscribed) == 0 and len(unsubscribed) == 0 and \ if len(subscribed) == 0 and len(unsubscribed) == 0 and \
len(removed) == 0 and len(added) == 0 and len(updated) == 0: len(removed) == 0 and len(added) == 0 and len(updated) == 0:
raise ValueError('No changes, not creating admin report') # No changes, not creating admin report
return None
content = f''' content = f'''
== Admin Report for MultiSchleuder {schleuder} == == Admin Report for MultiSchleuder {schleuder} ==
''' '''
@ -179,3 +180,36 @@ class AdminReport(Message):
encrypt_to=[encrypt_to] if encrypt_to is not None else [] encrypt_to=[encrypt_to] if encrypt_to is not None else []
) )
self.mime['Subject'] = f'MultiSchleuder Admin Report: {self._schleuder}' self.mime['Subject'] = f'MultiSchleuder Admin Report: {self._schleuder}'
class Reporter:
_messages: List['Message'] = []
_logger: logging.Logger = logging.getLogger()
def __init__(self,
send_conflict_messages: bool,
send_admin_reports: bool):
self._send_conflict_messages: bool = send_conflict_messages
self._send_admin_reports: bool = send_admin_reports
def add_message(self, message: Optional[Message]):
if message is None:
return
if not self._send_conflict_messages and isinstance(message, ConflictMessage):
return
if not self._send_admin_reports and isinstance(message, AdminReport):
return
self.__class__._messages.append(message)
def add_messages(self, messages: List[Optional[Message]]):
for msg in messages:
self.add_message(msg)
@classmethod
def get_messages(cls) -> List['Message']:
return list(cls._messages)
@classmethod
def clear_messages(cls):
cls._messages.clear()

View file

@ -64,9 +64,9 @@ class SmtpClient:
for m in messages: for m in messages:
msg = m.mime msg = m.mime
self._logger.debug(f'MIME Message:\n{str(msg)}') self._logger.debug(f'MIME Message:\n{str(msg)}')
self.send_message(msg) self._send_message(msg)
def send_message(self, msg: email.message.Message): def _send_message(self, msg: email.message.Message):
if self._smtp is None: if self._smtp is None:
raise RuntimeError('SMTP connection is not established') raise RuntimeError('SMTP connection is not established')
if not self._dry_run: if not self._dry_run:

View file

@ -71,7 +71,7 @@ _SUBSCRIBER_RESPONSE = '''
"list_id": 42, "list_id": 42,
"email": "andy.example@example.org", "email": "andy.example@example.org",
"fingerprint": "ADB9BC679FF53CC8EF66FAC39348FDAB7A7663F9", "fingerprint": "ADB9BC679FF53CC8EF66FAC39348FDAB7A7663F9",
"admin": false, "admin": true,
"delivery_enabled": true, "delivery_enabled": true,
"created_at": "2022-04-15T01:11:12.123Z", "created_at": "2022-04-15T01:11:12.123Z",
"updated_at": "2022-04-15T01:11:12.123Z" "updated_at": "2022-04-15T01:11:12.123Z"
@ -86,7 +86,7 @@ _SUBSCRIBER_RESPONSE_NOKEY = '''
"list_id": 42, "list_id": 42,
"email": "andy.example@example.org", "email": "andy.example@example.org",
"fingerprint": "", "fingerprint": "",
"admin": true, "admin": false,
"delivery_enabled": true, "delivery_enabled": true,
"created_at": "2022-04-15T01:11:12.123Z", "created_at": "2022-04-15T01:11:12.123Z",
"updated_at": "2022-04-15T01:11:12.123Z" "updated_at": "2022-04-15T01:11:12.123Z"
@ -117,6 +117,8 @@ class TestSchleuderApi(unittest.TestCase):
def read(): def read():
url = mock.call_args_list[-1][0][0].get_full_url() url = mock.call_args_list[-1][0][0].get_full_url()
method = mock.call_args_list[-1][0][0].method
if method == 'GET':
if '/lists' in url: if '/lists' in url:
return _LIST_RESPONSE.encode() return _LIST_RESPONSE.encode()
if '/subscriptions' in url: if '/subscriptions' in url:
@ -126,6 +128,8 @@ class TestSchleuderApi(unittest.TestCase):
if '/keys' in url: if '/keys' in url:
return _KEY_RESPONSE.encode() return _KEY_RESPONSE.encode()
return b'null' return b'null'
else:
return b''
m.read = read m.read = read
m.__enter__.return_value = m m.__enter__.return_value = m
mock.return_value = m mock.return_value = m
@ -146,16 +150,16 @@ class TestSchleuderApi(unittest.TestCase):
@patch('urllib.request.urlopen') @patch('urllib.request.urlopen')
def test_get_list_admins(self, mock): def test_get_list_admins(self, mock):
api = self._mock_api(mock)
admins = api.get_list_admins(SchleuderList(42, '', ''))
self.assertEqual(0, len(admins))
api = self._mock_api(mock, nokey=True) api = self._mock_api(mock, nokey=True)
admins = api.get_list_admins(SchleuderList(42, '', '')) admins = api.get_list_admins(SchleuderList(42, '', ''))
self.assertEqual(0, len(admins))
api = self._mock_api(mock)
admins = api.get_list_admins(SchleuderList(42, '', ''))
self.assertEqual(1, len(admins)) self.assertEqual(1, len(admins))
self.assertEqual(24, admins[0].id) self.assertEqual(23, admins[0].id)
self.assertEqual('andy.example@example.org', admins[0].email) self.assertEqual('andy.example@example.org', admins[0].email)
self.assertEqual(42, admins[0].schleuder) self.assertEqual(42, admins[0].schleuder)
self.assertIsNone(admins[0].key) self.assertEqual('ADB9BC679FF53CC8EF66FAC39348FDAB7A7663F9', admins[0].key.fingerprint)
@patch('urllib.request.urlopen') @patch('urllib.request.urlopen')
def test_get_subscribers(self, mock): def test_get_subscribers(self, mock):

View file

@ -25,7 +25,7 @@ api:
cafile: /tmp/ca.pem cafile: /tmp/ca.pem
smtp: smtp:
host: smtp.example.org hostname: smtp.example.org
port: 26 port: 26
tls: STARTTLS tls: STARTTLS
username: multischleuder@example.org username: multischleuder@example.org
@ -70,7 +70,7 @@ lists:
''' '''
class TestSchleuderTypes(unittest.TestCase): class TestConfig(unittest.TestCase):
def test_parse_minimal_config(self): def test_parse_minimal_config(self):
ns = MagicMock() ns = MagicMock()
@ -78,7 +78,7 @@ class TestSchleuderTypes(unittest.TestCase):
ns.dry_run = False ns.dry_run = False
ns.verbose = False ns.verbose = False
with patch('builtins.open', mock_open(read_data=_MINIMAL)) as mock: with patch('builtins.open', mock_open(read_data=_MINIMAL)) as mock:
lists = parse_config(ns) lists, _ = parse_config(ns)
self.assertEqual(0, len(lists)) self.assertEqual(0, len(lists))
def test_parse_config(self): def test_parse_config(self):
@ -87,7 +87,7 @@ class TestSchleuderTypes(unittest.TestCase):
ns.dry_run = False ns.dry_run = False
ns.verbose = False ns.verbose = False
with patch('builtins.open', mock_open(read_data=_CONFIG)) as mock: with patch('builtins.open', mock_open(read_data=_CONFIG)) as mock:
lists = parse_config(ns) lists, smtp = parse_config(ns)
self.assertEqual(2, len(lists)) self.assertEqual(2, len(lists))
list1, list2 = lists list1, list2 = lists
@ -117,11 +117,13 @@ class TestSchleuderTypes(unittest.TestCase):
self.assertIn('test-north@schleuder.example.org', list2._sources) self.assertIn('test-north@schleuder.example.org', list2._sources)
self.assertIn('test-south@schleuder.example.org', list2._sources) self.assertIn('test-south@schleuder.example.org', list2._sources)
self.assertEqual('smtp+starttls://multischleuder@example.org@smtp.example.org:26', str(smtp))
def test_parse_dry_run(self): def test_parse_dry_run(self):
ns = MagicMock() ns = MagicMock()
ns.config = '/tmp/config.yml' ns.config = '/tmp/config.yml'
ns.dry_run = True ns.dry_run = True
ns.verbose = False ns.verbose = False
with patch('builtins.open', mock_open(read_data=_CONFIG)) as mock: with patch('builtins.open', mock_open(read_data=_CONFIG)) as mock:
lists = parse_config(ns) lists, _ = parse_config(ns)
self.assertEqual(True, lists[0]._api._dry_run) self.assertEqual(True, lists[0]._api._dry_run)

View file

@ -219,6 +219,75 @@ class TestKeyConflictResolution(unittest.TestCase):
self.assertEqual(0, len(resolved)) self.assertEqual(0, len(resolved))
self.assertEqual(0, len(messages)) self.assertEqual(0, len(messages))
def test_send_messages_nofile(self):
sch1 = SchleuderList(42, 'test-north@schleuder.example.org', '474777DA74528A7021184C8A0017324A6CFFBF92')
key1 = SchleuderKey(_PRIVKEY_1.fingerprint.replace(' ', ''), 'foo@example.org', str(_PRIVKEY_1.pubkey), sch1.id)
date1 = datetime(2022, 4, 15, 5, 23, 42, 0, tzinfo=tzutc())
date2 = datetime(2022, 4, 13, 5, 23, 42, 0, tzinfo=tzutc())
sub1 = SchleuderSubscriber(3, 'foo@example.org', key1, sch1.id, date1)
# This subscription is older, so its key will be preferred
sch2 = SchleuderList(23, 'test-south@schleuder.example.org', 'AF586C0625CF77BBB659747515D41C5D84BF99D3')
key2 = SchleuderKey(_PRIVKEY_2.fingerprint.replace(' ', ''), 'foo@example.org', str(_PRIVKEY_2.pubkey), sch2.id)
sub2 = SchleuderSubscriber(7, 'foo@example.org', key2, sch2.id, date2)
kcr = KeyConflictResolution(3600, '/nonexistent/directory/state.json', _TEMPLATE)
resolved, msgs = kcr.resolve(
target='test@schleuder.example.org',
mail_from='test-owner@schleuder.example.org',
subscriptions=[sub1, sub2])
self.assertEqual(0, len(msgs))
def test_send_messages_brokenstate(self):
sch1 = SchleuderList(42, 'test-north@schleuder.example.org', '474777DA74528A7021184C8A0017324A6CFFBF92')
key1 = SchleuderKey(_PRIVKEY_1.fingerprint.replace(' ', ''), 'foo@example.org', str(_PRIVKEY_1.pubkey), sch1.id)
date1 = datetime(2022, 4, 15, 5, 23, 42, 0, tzinfo=tzutc())
date2 = datetime(2022, 4, 13, 5, 23, 42, 0, tzinfo=tzutc())
sub1 = SchleuderSubscriber(3, 'foo@example.org', key1, sch1.id, date1)
# This subscription is older, so its key will be preferred
sch2 = SchleuderList(23, 'test-south@schleuder.example.org', 'AF586C0625CF77BBB659747515D41C5D84BF99D3')
key2 = SchleuderKey(_PRIVKEY_2.fingerprint.replace(' ', ''), 'foo@example.org', str(_PRIVKEY_2.pubkey), sch2.id)
sub2 = SchleuderSubscriber(7, 'foo@example.org', key2, sch2.id, date2)
kcr = KeyConflictResolution(3600, '/tmp/state.json', _TEMPLATE)
contents = io.StringIO('[[intentionally/broken\\json]]')
contents.seek(io.SEEK_END) # Opened with 'a+'
with patch('builtins.open', mock_open(read_data='[[intentionally/broken\\json]]')) as mock_statefile:
mock_statefile().__enter__.return_value = contents
resolved, msgs = kcr.resolve(
target='test@schleuder.example.org',
mail_from='test-owner@schleuder.example.org',
subscriptions=[sub1, sub2])
self.assertEqual(0, len(msgs))
def test_send_messages_emptystate(self):
sch1 = SchleuderList(42, 'test-north@schleuder.example.org', '474777DA74528A7021184C8A0017324A6CFFBF92')
key1 = SchleuderKey(_PRIVKEY_1.fingerprint.replace(' ', ''), 'foo@example.org', str(_PRIVKEY_1.pubkey), sch1.id)
date1 = datetime(2022, 4, 15, 5, 23, 42, 0, tzinfo=tzutc())
date2 = datetime(2022, 4, 13, 5, 23, 42, 0, tzinfo=tzutc())
sub1 = SchleuderSubscriber(3, 'foo@example.org', key1, sch1.id, date1)
# This subscription is older, so its key will be preferred
sch2 = SchleuderList(23, 'test-south@schleuder.example.org', 'AF586C0625CF77BBB659747515D41C5D84BF99D3')
key2 = SchleuderKey(_PRIVKEY_2.fingerprint.replace(' ', ''), 'foo@example.org', str(_PRIVKEY_2.pubkey), sch2.id)
sub2 = SchleuderSubscriber(7, 'foo@example.org', key2, sch2.id, date2)
kcr = KeyConflictResolution(3600, '/tmp/state.json', _TEMPLATE)
contents = io.StringIO()
with patch('builtins.open', mock_open(read_data='')) as mock_statefile:
mock_statefile().__enter__.return_value = contents
resolved, msgs = kcr.resolve(
target='test@schleuder.example.org',
mail_from='test-owner@schleuder.example.org',
subscriptions=[sub1, sub2])
self.assertEqual(1, len(msgs))
now = datetime.utcnow().timestamp()
mock_statefile.assert_called_with('/tmp/state.json', 'a+')
contents.seek(0)
state = json.loads(contents.read())
self.assertEqual(1, len(state))
self.assertIn(msgs[0].mime['X-MultiSchleuder-Digest'], state)
self.assertLess(now - state[msgs[0].mime['X-MultiSchleuder-Digest']], 60)
def test_send_messages_nostate(self): def test_send_messages_nostate(self):
sch1 = SchleuderList(42, 'test-north@schleuder.example.org', '474777DA74528A7021184C8A0017324A6CFFBF92') sch1 = SchleuderList(42, 'test-north@schleuder.example.org', '474777DA74528A7021184C8A0017324A6CFFBF92')
key1 = SchleuderKey(_PRIVKEY_1.fingerprint.replace(' ', ''), 'foo@example.org', str(_PRIVKEY_1.pubkey), sch1.id) key1 = SchleuderKey(_PRIVKEY_1.fingerprint.replace(' ', ''), 'foo@example.org', str(_PRIVKEY_1.pubkey), sch1.id)

View file

@ -67,6 +67,16 @@ def _get_key(fpr: str, schleuder: SchleuderList):
}[fpr] }[fpr]
def _get_admins(schleuder: SchleuderList):
if schleuder.id != 2:
return []
key = SchleuderKey('966842467B3254143F994D5E5C408C012D216471',
'admin@example.org', 'BEGIN PGP 2D216471', schleuder.id)
date = datetime(2022, 4, 15, 5, 23, 42, 0, tzinfo=tzutc())
admin = SchleuderSubscriber(0, 'admin@example.org', key, schleuder.id, date)
return [admin]
def _get_subs(schleuder: SchleuderList): def _get_subs(schleuder: SchleuderList):
key1 = SchleuderKey('966842467B3254143F994D5E5C408C012D216471', key1 = SchleuderKey('966842467B3254143F994D5E5C408C012D216471',
'admin@example.org', 'BEGIN PGP 2D216471', schleuder.id) 'admin@example.org', 'BEGIN PGP 2D216471', schleuder.id)
@ -125,6 +135,14 @@ def _get_subs(schleuder: SchleuderList):
return [] return []
def _get_equal_subs(schleuder: SchleuderList):
schleuder = SchleuderList(2, schleuder.name, schleuder.fingerprint)
subs = _get_subs(schleuder)
return [s
for s in subs
if s.email not in ['admin@example.org', 'aspammer@example.org', 'anotherspammer@example.org']]
def _get_sub(email: str, schleuder: SchleuderList): def _get_sub(email: str, schleuder: SchleuderList):
subs = _get_subs(schleuder) subs = _get_subs(schleuder)
return [s for s in subs if s.email == email][0] return [s for s in subs if s.email == email][0]
@ -137,15 +155,18 @@ class TestMultiList(unittest.TestCase):
kcr.resolve = _resolve kcr.resolve = _resolve
return kcr return kcr
def _api_mock(self): def _api_mock(self, nochange=False):
api = MagicMock() api = MagicMock()
api.get_lists = _list_lists api.get_lists = _list_lists
api.get_list_admins = _get_admins
api.get_subscribers = _get_subs api.get_subscribers = _get_subs
if nochange:
api.get_subscribers = _get_equal_subs
api.get_subscriber = _get_sub api.get_subscriber = _get_sub
api.get_key = _get_key api.get_key = _get_key
return api return api
def test_create(self): def test_full(self):
sources = [ sources = [
'test-north@schleuder.example.org', 'test-north@schleuder.example.org',
'test-east@schleuder.example.org', 'test-east@schleuder.example.org',
@ -153,7 +174,7 @@ class TestMultiList(unittest.TestCase):
'test-west@schleuder.example.org' 'test-west@schleuder.example.org'
] ]
api = self._api_mock() api = self._api_mock()
smtp = MagicMock() reporter = MagicMock()
ml = MultiList(sources=sources, ml = MultiList(sources=sources,
target='test-global@schleuder.example.org', target='test-global@schleuder.example.org',
unmanaged=['admin@example.org'], unmanaged=['admin@example.org'],
@ -161,9 +182,7 @@ class TestMultiList(unittest.TestCase):
mail_from='test-global-owner@schleuder.example.org', mail_from='test-global-owner@schleuder.example.org',
api=api, api=api,
kcr=self._kcr_mock(), kcr=self._kcr_mock(),
smtp=smtp, reporter=reporter)
send_admin_reports=True,
send_conflict_messages=True)
ml.process() ml.process()
# Key uploads # Key uploads
@ -218,3 +237,32 @@ class TestMultiList(unittest.TestCase):
self.assertEqual('8258FAF8B161B3DD8F784874F73E2DDF045AE2D6', c[0].fingerprint) self.assertEqual('8258FAF8B161B3DD8F784874F73E2DDF045AE2D6', c[0].fingerprint)
# Todo: check message queue # Todo: check message queue
def test_no_changes(self):
sources = [
'test-north@schleuder.example.org',
'test-east@schleuder.example.org',
'test-south@schleuder.example.org',
'test-west@schleuder.example.org'
]
api = self._api_mock(nochange=True)
reporter = MagicMock()
ml = MultiList(sources=sources,
target='test-global@schleuder.example.org',
unmanaged=['admin@example.org'],
banned=['aspammer@example.org', 'anotherspammer@example.org'],
mail_from='test-global-owner@schleuder.example.org',
api=api,
kcr=self._kcr_mock(),
reporter=reporter)
ml.process()
# Key uploads
self.assertEqual(0, len(api.post_key.call_args_list))
# Subscriptions
self.assertEqual(0, len(api.subscribe.call_args_list))
# Key updates
self.assertEqual(0, len(api.update_fingerprint.call_args_list))
# Unsubscribes
self.assertEqual(0, len(api.unsubscribe.call_args_list))
# Key deletions
self.assertEqual(0, len(api.delete_key.call_args_list))

View file

@ -0,0 +1,78 @@
import unittest
from datetime import datetime
from multischleuder.reporting import ConflictMessage, AdminReport, Reporter
from multischleuder.types import SchleuderKey, SchleuderList, SchleuderSubscriber
def one_of_each_kind():
sub = SchleuderSubscriber(1, 'foo@example.org', None, 1, datetime.utcnow())
msg1 = ConflictMessage(
schleuder='test@example.org',
chosen=sub,
affected=[sub],
digest='digest',
mail_from='test-owner@example.org',
template='averylongmessage')
msg2 = AdminReport(
schleuder='test@example.org',
mail_to='admin@example.org',
mail_from='test-owner@example.org',
encrypt_to=None,
subscribed={},
unsubscribed={sub},
updated={},
added={},
removed={})
return [msg1, msg2]
class TestReporting(unittest.TestCase):
def test_reporter_config_all_enabled(self):
msgs = one_of_each_kind()
r = Reporter(send_conflict_messages=True,
send_admin_reports=True)
r.add_messages(msgs)
self.assertEquals(2, len(Reporter.get_messages()))
self.assertIsInstance(Reporter.get_messages()[-2], ConflictMessage)
self.assertEquals('foo@example.org', Reporter.get_messages()[-2]._to)
self.assertIsInstance(Reporter.get_messages()[-1], AdminReport)
self.assertEquals('admin@example.org', Reporter.get_messages()[-1]._to)
Reporter.clear_messages()
def test_reporter_config_conflict_only(self):
msgs = one_of_each_kind()
r = Reporter(send_conflict_messages=True,
send_admin_reports=False)
r.add_messages(msgs)
self.assertEquals(1, len(Reporter.get_messages()))
self.assertIsInstance(Reporter.get_messages()[-1], ConflictMessage)
self.assertEquals('foo@example.org', Reporter.get_messages()[-1]._to)
Reporter.clear_messages()
def test_reporter_config_admin_only(self):
msgs = one_of_each_kind()
r = Reporter(send_conflict_messages=False,
send_admin_reports=True)
r.add_messages(msgs)
self.assertEquals(1, len(Reporter.get_messages()))
self.assertIsInstance(Reporter.get_messages()[-1], AdminReport)
self.assertEquals('admin@example.org', Reporter.get_messages()[-1]._to)
Reporter.clear_messages()
def test_reporter_config_all_disabled(self):
msgs = one_of_each_kind()
r = Reporter(send_conflict_messages=False,
send_admin_reports=False)
r.add_messages(msgs)
self.assertEquals(0, len(Reporter.get_messages()))
def test_reporter_null_message(self):
r = Reporter(send_conflict_messages=True,
send_admin_reports=True)
r.add_messages([None])
self.assertEquals(0, len(Reporter.get_messages()))
Reporter.clear_messages()

View file

@ -1,21 +1,40 @@
import asyncio import asyncio
import unittest import unittest
from datetime import datetime
from email.mime.text import MIMEText from email.mime.text import MIMEText
from aiosmtpd.controller import Controller from aiosmtpd.controller import Controller
from aiosmtpd.smtp import AuthResult, SMTP from aiosmtpd.smtp import AuthResult, SMTP
from multischleuder.reporting import ConflictMessage, AdminReport
from multischleuder.smtp import SmtpClient, TlsMode from multischleuder.smtp import SmtpClient, TlsMode
from multischleuder.types import SchleuderSubscriber
class MockSmtpHandler: class MockSmtpHandler:
def __init__(self): def __init__(self):
self.rcpt = None self.rcpt = []
self.connected = False
async def handle_HELO(self, server, session, envelope, hostname):
self.connected = True
session.host_name = hostname
return '250 dummy.example.org'
async def handle_EHLO(self, server, session, envelope, hostname, responses):
self.connected = True
session.host_name = hostname
return [
'250-dummy.example.org',
'250-AUTH PLAIN LOGIN',
'250-AUTH=PLAIN LOGIN',
'250 HELP'
]
async def handle_RCPT(self, server, session, envelope, address, rcpt_options): async def handle_RCPT(self, server, session, envelope, address, rcpt_options):
self.rcpt = address self.rcpt.append(address)
envelope.rcpt_tos.append(address) envelope.rcpt_tos.append(address)
return '250 OK' return '250 OK'
@ -76,9 +95,28 @@ class TestSmtpClient(unittest.TestCase):
self.assertEqual(465, SmtpClient.parse({'tls': 'SMTPS'})._port) self.assertEqual(465, SmtpClient.parse({'tls': 'SMTPS'})._port)
self.assertEqual(587, SmtpClient.parse({'tls': 'STARTTLS'})._port) self.assertEqual(587, SmtpClient.parse({'tls': 'STARTTLS'})._port)
def test_send_message_auth(self): def test_send_message_dryrun(self):
ctrl = MockController(handler=MockSmtpHandler(), hostname='127.0.0.1', port=10025) ctrl = MockController(handler=MockSmtpHandler(), hostname='127.0.0.1', port=10025)
ctrl.start() ctrl.start()
client = SmtpClient(
hostname='127.0.0.1',
port=ctrl.port,
username='example',
password='supersecurepassword')
client.dry_run()
msg = MIMEText('Hello World!')
msg['From'] = 'foo@example.org'
msg['To'] = 'bar@example.org'
with client:
client._send_message(msg)
ctrl.stop()
self.assertEqual('example', ctrl.received_user)
self.assertEqual('supersecurepassword', ctrl.received_pass)
self.assertEqual(0, len(ctrl.handler.rcpt))
def test_send_message_auth(self):
ctrl = MockController(handler=MockSmtpHandler(), hostname='127.0.0.1', port=10026)
ctrl.start()
client = SmtpClient( client = SmtpClient(
hostname='127.0.0.1', hostname='127.0.0.1',
port=ctrl.port, port=ctrl.port,
@ -88,22 +126,91 @@ class TestSmtpClient(unittest.TestCase):
msg['From'] = 'foo@example.org' msg['From'] = 'foo@example.org'
msg['To'] = 'bar@example.org' msg['To'] = 'bar@example.org'
with client: with client:
client.send_message(msg) client._send_message(msg)
ctrl.stop() ctrl.stop()
self.assertEqual('example', ctrl.received_user) self.assertEqual('example', ctrl.received_user)
self.assertEqual('supersecurepassword', ctrl.received_pass) self.assertEqual('supersecurepassword', ctrl.received_pass)
self.assertEqual('bar@example.org', ctrl.handler.rcpt) self.assertEqual('bar@example.org', ctrl.handler.rcpt[0])
def test_send_message(self): def test_send_message_noauth(self):
ctrl = MockController(handler=MockSmtpHandler(), hostname='127.0.0.1', port=10026) ctrl = MockController(handler=MockSmtpHandler(), hostname='127.0.0.1', port=10027)
ctrl.start() ctrl.start()
client = SmtpClient(hostname='127.0.0.1', port=ctrl.port) client = SmtpClient(hostname='127.0.0.1', port=ctrl.port)
msg = MIMEText('Hello World!') msg = MIMEText('Hello World!')
msg['From'] = 'foo@example.org' msg['From'] = 'foo@example.org'
msg['To'] = 'baz@example.org' msg['To'] = 'baz@example.org'
with client: with client:
client.send_message(msg) client._send_message(msg)
ctrl.stop() ctrl.stop()
self.assertIsNone(ctrl.received_user) self.assertIsNone(ctrl.received_user)
self.assertIsNone(ctrl.received_pass) self.assertIsNone(ctrl.received_pass)
self.assertEqual('baz@example.org', ctrl.handler.rcpt) self.assertTrue(ctrl.handler.connected)
self.assertEqual('baz@example.org', ctrl.handler.rcpt[0])
def test_send_no_messages(self):
ctrl = MockController(handler=MockSmtpHandler(), hostname='127.0.0.1', port=10028)
ctrl.start()
client = SmtpClient(
hostname='127.0.0.1',
port=ctrl.port,
username='example',
password='supersecurepassword')
client.send_messages([])
ctrl.stop()
self.assertFalse(ctrl.handler.connected)
self.assertIsNone(ctrl.received_user)
self.assertIsNone(ctrl.received_pass)
def test_send_multiple_messages(self):
ctrl = MockController(handler=MockSmtpHandler(), hostname='127.0.0.1', port=10029)
ctrl.start()
client = SmtpClient(
hostname='127.0.0.1',
port=ctrl.port,
username='example',
password='supersecurepassword')
sub = SchleuderSubscriber(1, 'foo@example.org', None, 1, datetime.utcnow())
msg1 = ConflictMessage(
schleuder='test@example.org',
chosen=sub,
affected=[sub],
digest='digest',
mail_from='test-owner@example.org',
template='averylongmessage')
msg2 = AdminReport(
schleuder='test@example.org',
mail_to='admin@example.org',
mail_from='test-owner@example.org',
encrypt_to=None,
subscribed={},
unsubscribed={sub},
updated={},
added={},
removed={})
client.send_messages([msg1, msg2])
ctrl.stop()
self.assertTrue(ctrl.handler.connected)
self.assertEqual('foo@example.org', ctrl.handler.rcpt[0])
self.assertEqual('admin@example.org', ctrl.handler.rcpt[1])
def test_send_dry_run(self):
ctrl = MockController(handler=MockSmtpHandler(), hostname='127.0.0.1', port=10030)
ctrl.start()
client = SmtpClient(
hostname='127.0.0.1',
port=ctrl.port,
username='example',
password='supersecurepassword')
client.dry_run()
sub = SchleuderSubscriber(1, 'foo@example.org', None, 1, datetime.utcnow())
msg1 = ConflictMessage(
schleuder='test@example.org',
chosen=sub,
affected=[sub],
digest='digest',
mail_from='test-owner@example.org',
template='averylongmessage')
client.send_messages([msg1])
ctrl.stop()
self.assertFalse(ctrl.handler.connected)
self.assertEqual(0, len(ctrl.handler.rcpt))