Add config parsing unit tests

This commit is contained in:
s3lph 2022-04-15 23:40:11 +02:00
parent 294f299175
commit 2f7b654d57
3 changed files with 134 additions and 5 deletions

View file

@ -107,7 +107,7 @@ class ConflictMessage:
class KeyConflictResolution: class KeyConflictResolution:
def __init__(self, smtp: 'SmtpClient', interval: int, statefile: str, template: str): def __init__(self, smtp: SmtpClient, interval: int, statefile: str, template: str):
self._smtp = smtp self._smtp = smtp
self._interval: int = interval self._interval: int = interval
self._state_file: str = statefile self._state_file: str = statefile
@ -118,7 +118,7 @@ class KeyConflictResolution:
def resolve(self, def resolve(self,
target: str, target: str,
mail_from: str, mail_from: str,
subscriptions: List['SchleuderSubscriber']) -> List['SchleuderSubscriber']: subscriptions: List[SchleuderSubscriber]) -> List[SchleuderSubscriber]:
subs: Dict[str, List[SchleuderSubscriber]] = OrderedDict() subs: Dict[str, List[SchleuderSubscriber]] = OrderedDict()
for s in subscriptions: for s in subscriptions:
subs.setdefault(s.email, []).append(s) subs.setdefault(s.email, []).append(s)
@ -128,7 +128,7 @@ class KeyConflictResolution:
def _resolve(self, def _resolve(self,
target: str, target: str,
mail_from: str, mail_from: str,
subscriptions: List['SchleuderSubscriber']) -> 'SchleuderSubscriber': subscriptions: List[SchleuderSubscriber]) -> SchleuderSubscriber:
if len(subscriptions) == 1: if len(subscriptions) == 1:
return subscriptions[0] return subscriptions[0]
if len({s.key.blob for s in subscriptions}) == 1: if len({s.key.blob for s in subscriptions}) == 1:

View file

@ -5,6 +5,7 @@ import argparse
import yaml import yaml
from multischleuder import __version__
from multischleuder.api import SchleuderApi from multischleuder.api import SchleuderApi
from multischleuder.conflict import KeyConflictResolution from multischleuder.conflict import KeyConflictResolution
from multischleuder.processor import MultiList from multischleuder.processor import MultiList
@ -39,13 +40,13 @@ def parse_config(ns: argparse.Namespace) -> List['MultiList']:
api.dry_run() api.dry_run()
smtp_config = c.get('smtp', {}) smtp_config = c.get('smtp', {})
smtp = SmtpClient.parse(**smtp_config) smtp = SmtpClient.parse(smtp_config)
kcr_config = c.get('conflict', {}) kcr_config = c.get('conflict', {})
kcr = KeyConflictResolution(smtp, **kcr_config) kcr = KeyConflictResolution(smtp, **kcr_config)
lists = [] lists = []
for clist in c['lists']: for clist in c.get('lists', []):
ml = parse_list_config(api, kcr, clist) ml = parse_list_config(api, kcr, clist)
lists.append(ml) lists.append(ml)
return lists return lists
@ -56,6 +57,7 @@ def main():
ap.add_argument('--config', '-c', type=str, default='/etc/multischleuder/config.yml') ap.add_argument('--config', '-c', type=str, default='/etc/multischleuder/config.yml')
ap.add_argument('--dry-run', '-n', action='store_true', default=False) ap.add_argument('--dry-run', '-n', action='store_true', default=False)
ap.add_argument('--verbose', '-v', action='store_true', default=False) ap.add_argument('--verbose', '-v', action='store_true', default=False)
ap.add_argument('--version', action='version', version=__version__)
ns = ap.parse_args(sys.argv[1:]) ns = ap.parse_args(sys.argv[1:])
if ns.verbose: if ns.verbose:
logger = logging.getLogger().setLevel('DEBUG') logger = logging.getLogger().setLevel('DEBUG')

View file

@ -0,0 +1,127 @@
import unittest
from unittest.mock import MagicMock, mock_open, patch
from multischleuder.main import parse_config, parse_list_config
from multischleuder.test.test_api import _KEY_RESPONSE, _LIST_RESPONSE, _SUBSCRIBER_RESPONSE
_MINIMAL = '''
api:
url: https://localhost:4443
token: securetoken
conflict:
interval: 3600
statefile: /tmp/state.json
template: ''
'''
_CONFIG = '''---
api:
url: https://localhost:4443
token: securetoken
cafile: /tmp/ca.pem
smtp:
host: smtp.example.org
port: 26
tls: STARTTLS
username: multischleuder@example.org
password: supersecurepassword
conflict:
interval: 3600
statefile: /tmp/state.json
template: |
Dear {subscriber}
This is a test and should not be used in production:
{affected}
If you ever receive this text via email, notify your admin:
{chosen}
Regards, {schleuder}
lists:
- target: test-global@schleuder.example.org
from: test-global-owner@schleuder.example.org
sources:
- test-north@schleuder.example.org
- test-east@schleuder.example.org
- test-south@schleuder.example.org
- test-west@schleuder.example.org
unmanaged:
- admin@example.org
banned:
- aspammer@example.org
- anotherspammer@example.org
- target: test2-global@schleuder.example.org
from: test2-global-owner@schleuder.example.org
sources:
- test-north@schleuder.example.org
- test-south@schleuder.example.org
unmanaged:
- admin@example.org
- admin2@example.org
banned: []
'''
class TestSchleuderTypes(unittest.TestCase):
def test_parse_minimal_config(self):
ns = MagicMock()
ns.config = '/tmp/config.yml'
ns.dry_run = False
ns.verbose = False
with patch('builtins.open', mock_open(read_data=_MINIMAL)) as mock:
lists = parse_config(ns)
self.assertEqual(0, len(lists))
def test_parse_config(self):
ns = MagicMock()
ns.config = '/tmp/config.yml'
ns.dry_run = False
ns.verbose = False
with patch('builtins.open', mock_open(read_data=_CONFIG)) as mock:
lists = parse_config(ns)
self.assertEqual(2, len(lists))
list1, list2 = lists
self.assertEqual('https://localhost:4443', list1._api._url)
self.assertEqual('Basic c2NobGV1ZGVyOnNlY3VyZXRva2Vu', list2._api._headers['Authorization'])
self.assertEqual('/tmp/ca.pem', list1._api._cafile)
self.assertEqual(False, list1._api._dry_run)
self.assertEqual(3600, list2._kcr._interval)
self.assertEqual('/tmp/state.json', list1._kcr._state_file)
self.assertIn('Regards, {schleuder}', list2._kcr._template)
self.assertEqual('test-global@schleuder.example.org', list1._target)
self.assertEqual('test-global-owner@schleuder.example.org', list1._mail_from)
self.assertEqual(['admin@example.org'], list1._unmanaged)
self.assertEqual(['aspammer@example.org', 'anotherspammer@example.org'], list1._banned)
self.assertEqual(4, len(list1._sources))
self.assertIn('test-north@schleuder.example.org', list1._sources)
self.assertIn('test-east@schleuder.example.org', list1._sources)
self.assertIn('test-south@schleuder.example.org', list1._sources)
self.assertIn('test-west@schleuder.example.org', list1._sources)
self.assertEqual('test2-global@schleuder.example.org', list2._target)
self.assertEqual('test2-global-owner@schleuder.example.org', list2._mail_from)
self.assertEqual(['admin@example.org', 'admin2@example.org'], list2._unmanaged)
self.assertEqual([], list2._banned)
self.assertEqual(2, len(list2._sources))
self.assertIn('test-north@schleuder.example.org', list2._sources)
self.assertIn('test-south@schleuder.example.org', list2._sources)
def test_parse_dry_run(self):
ns = MagicMock()
ns.config = '/tmp/config.yml'
ns.dry_run = True
ns.verbose = False
with patch('builtins.open', mock_open(read_data=_CONFIG)) as mock:
lists = parse_config(ns)
self.assertEqual(True, lists[0]._api._dry_run)