From f681096bb2f07f555589db3675ea19890c660e13 Mon Sep 17 00:00:00 2001 From: nerosketch Date: Thu, 14 Jun 2018 01:38:46 +0300 Subject: [PATCH] refactoring nas manager --- agent/mod_mikrotik.py | 288 ++++++++++----------- agent/mod_mikrotik_old.py | 514 ++++++++++++++++++++++++++++++++++++++ agent/structs.py | 78 +----- 3 files changed, 652 insertions(+), 228 deletions(-) create mode 100644 agent/mod_mikrotik_old.py diff --git a/agent/mod_mikrotik.py b/agent/mod_mikrotik.py index 5a53711..28d82e0 100644 --- a/agent/mod_mikrotik.py +++ b/agent/mod_mikrotik.py @@ -1,25 +1,22 @@ -# -*- coding: utf-8 -*- import re import socket import binascii -from abc import ABCMeta from hashlib import md5 -from typing import Iterable, Optional, Tuple -from .core import BaseTransmitter, NasFailedResult, NasNetworkError -from djing.lib import Singleton +from typing import Iterable, Optional, Tuple, Generator, Dict +from djing.lib import safe_int from .structs import TariffStruct, AbonStruct, IpStruct, VectorAbon, VectorTariff from . import settings as local_settings from django.conf import settings from djing import ping -from agent.core import NasNetworkError, NasFailedResult +from agent.core import BaseTransmitter, NasNetworkError, NasFailedResult DEBUG = getattr(settings, 'DEBUG', False) LIST_USERS_ALLOWED = 'DjingUsersAllowed' -#LIST_USERS_BLOCKED = 'DjingUsersBlocked' +LIST_DEVICES_ALLOWED = 'DjingDevicesAllowed' -class ApiRos(metaclass=Singleton): +class ApiRos(object): """Routeros api""" sk = None is_login = False @@ -161,8 +158,21 @@ class ApiRos(metaclass=Singleton): ret += s return ret + def __del__(self): + if hasattr(self, 'sk'): + self.sk.close() + + +class IpAddressListObj(IpStruct): + __slots__ = ('__ip', 'mk_id') + + def __init__(self, ip, mk_id): + super(IpAddressListObj, self).__init__(ip) + self.mk_id = str(mk_id).replace('*', '') + + +class MikrotikTransmitter(BaseTransmitter, ApiRos): -class TransmitterManager(BaseTransmitter, metaclass=ABCMeta): def __init__(self, login=None, password=None, ip=None, port=None): ip = ip or getattr(local_settings, 'NAS_IP') if ip is None or ip == '': @@ -172,41 +182,36 @@ class TransmitterManager(BaseTransmitter, metaclass=ABCMeta): 'ip_addr': ip }) try: - self.ar = ApiRos(ip, port) - self.ar.login(login or getattr(local_settings, 'NAS_LOGIN'), - password or getattr(local_settings, 'NAS_PASSW')) + super(MikrotikTransmitter, self).__init__(ip, port) + self.login( + login or getattr(local_settings, 'NAS_LOGIN'), + password or getattr(local_settings, 'NAS_PASSW') + ) except ConnectionRefusedError: raise NasNetworkError('Connection to %s is Refused' % ip) - def __del__(self): - if hasattr(self, 's'): - self.s.close() - - def _exec_cmd(self, cmd: Iterable) -> list: + def _exec_cmd(self, cmd: Iterable) -> Dict: if not isinstance(cmd, (list, tuple)): raise TypeError - result_iter = self.ar.talk_iter(cmd) - res = [] - for rt in result_iter: - if rt[0] == '!trap': - raise NasFailedResult(rt[1]['=message']) - res.append(rt[1]) - return res - - def _exec_cmd_iter(self, cmd: Iterable) -> Iterable: + r = dict() + for k, v in self.talk_iter(cmd): + if k == '!done': + break + r[k] = v or None + return r + + def _exec_cmd_iter(self, cmd: Iterable) -> Generator: if not isinstance(cmd, (list, tuple)): raise TypeError - result_iter = self.ar.talk_iter(cmd) - for rt in result_iter: - if len(rt) < 2: - continue - if rt[0] == '!trap': - raise NasFailedResult(rt[1]['=message']) - yield rt + for k, v in self.talk_iter(cmd): + if k == '!trap': + raise NasFailedResult(v.get('=message')) + if v: + yield v # Build object ShapeItem from information from mikrotik @staticmethod - def _build_shape_obj(info: dict) -> AbonStruct: + def _build_shape_obj(info: Dict) -> AbonStruct: # Переводим приставку скорости Mikrotik в Mbit/s def parse_speed(text_speed): text_speed_digit = float(text_speed[:-1] or 0.0) @@ -239,62 +244,75 @@ class TransmitterManager(BaseTransmitter, metaclass=ABCMeta): except ValueError: pass + ################################################# + # QUEUES + ################################################# -class QueueManager(TransmitterManager, metaclass=ABCMeta): # Find queue by name - def find(self, name: str) -> AbonStruct: + def find_queue(self, name: str) -> Optional[AbonStruct]: ret = self._exec_cmd(('/queue/simple/print', '?name=%s' % name)) if len(ret) > 1: return self._build_shape_obj(ret[0]) - def add(self, user: AbonStruct): + def add_queue(self, user: AbonStruct) -> None: if not isinstance(user, AbonStruct): raise TypeError if user.tariff is None or not isinstance(user.tariff, TariffStruct): return - return self._exec_cmd(('/queue/simple/add', - '=name=uid%d' % user.uid, - # FIXME: тут в разных микротиках или =target-addresses или =target - '=target=%s' % user.ip, - '=max-limit=%.3fM/%.3fM' % (user.tariff.speedOut, user.tariff.speedIn), - '=queue=MikroBILL_SFQ/MikroBILL_SFQ', - '=burst-time=1/1' - )) - - def remove(self, user: AbonStruct): + r = self._exec_cmd(( + '/queue/simple/add', + '=name=uid%d' % user.uid, + # FIXME: тут в разных микротиках или =target-addresses или =target + '=target=%s' % user.ip, + '=max-limit=%.3fM/%.3fM' % (user.tariff.speedOut, user.tariff.speedIn), + '=queue=MikroBILL_SFQ/MikroBILL_SFQ', + '=burst-time=1/1' + )) + print(r) + + def remove_queue(self, user: AbonStruct) -> None: if not isinstance(user, AbonStruct): raise TypeError - q = self.find('uid%d' % user.uid) + q = self.find_queue('uid%d' % user.uid) if q is not None: - return self._exec_cmd(('/queue/simple/remove', '=.id=' + getattr(q, 'queue_id', ''),)) - - def remove_range(self, q_ids: Iterable[str]): - try: - # q_ids = [q.queue_id for q in q_ids] - return self._exec_cmd(('/queue/simple/remove', '=numbers=' + ','.join(q_ids))) - except TypeError as e: - print(e) - - def update(self, user: AbonStruct): + queue_id = safe_int(getattr(q, 'queue_id')) + if queue_id != 0: + r = self._exec_cmd(( + '/queue/simple/remove', + '=.id=%d' % queue_id + )) + print(r) + + def remove_queue_range(self, q_ids: Iterable[str]): + # FIXME: check result from _exec_cmd + r = self._exec_cmd(('/queue/simple/remove', '=numbers=' + ','.join(q_ids))) + return r + + def update_queue(self, user: AbonStruct): if not isinstance(user, AbonStruct): raise TypeError - if user.tariff is None or not isinstance(user.tariff, TariffStruct): + if user.tariff is None: return - queue = self.find('uid%d' % user.uid) + queue = self.find_queue('uid%d' % user.uid) if queue is None: - return self.add(user) + return self.add_queue(user) else: - mk_id = getattr(queue, 'queue_id', '') - return self._exec_cmd(('/queue/simple/set', '=.id=' + mk_id, - '=name=uid%d' % user.uid, - '=max-limit=%.3fM/%.3fM' % (user.tariff.speedOut, user.tariff.speedIn), - # FIXME: тут в разных микротиках или =target-addresses или =target - '=target=%s' % user.ip, - '=queue=MikroBILL_SFQ/MikroBILL_SFQ', - '=burst-time=1/1' - )) - - def read_queue_iter(self): + mk_id = safe_int(getattr(queue, 'queue_id', 0)) + cmd = [ + '/queue/simple/set', + '=name=uid%d' % user.uid, + '=max-limit=%.3fM/%.3fM' % (user.tariff.speedOut, user.tariff.speedIn), + # FIXME: тут в разных микротиках или =target-addresses или =target + '=target=%s' % user.ip, + '=queue=MikroBILL_SFQ/MikroBILL_SFQ', + '=burst-time=1/1' + ] + if mk_id != 0: + cmd.insert(1, '=.id=%d' % mk_id) + r = self._exec_cmd(cmd) + return r + + def read_queue_iter(self) -> Generator: for code, dat in self._exec_cmd_iter(('/queue/simple/print', '=detail')): if code == '!done': return @@ -302,42 +320,11 @@ class QueueManager(TransmitterManager, metaclass=ABCMeta): if sobj is not None: yield sobj - def read_mikroids_iter(self): - queues = self._exec_cmd_iter(('/queue/simple/print', '=detail')) - for queue in queues: - if queue[0] == '!done': - return - yield int(queue[1]['=.id'].replace('*', ''), base=16) + ################################################# + # Ip->firewall->address list + ################################################# - def disable(self, user: AbonStruct): - if not isinstance(user, AbonStruct): - raise TypeError - q = self.find('uid%d' % user.uid) - if q is None: - self.add(user) - return self.disable(user) - else: - return self._exec_cmd(('/queue/simple/disable', '=.id=*' + getattr(q, 'queue_id', ''))) - - def enable(self, user: AbonStruct): - if not isinstance(user, AbonStruct): - raise TypeError - q = self.find('uid%d' % user.uid) - if q is None: - self.add(user) - self.enable(user) - else: - return self._exec_cmd(('/queue/simple/enable', '=.id=*' + getattr(q, 'queue_id', ''))) - - -class IpAddressListObj(IpStruct): - def __init__(self, ip, mk_id): - super(IpAddressListObj, self).__init__(ip) - self.mk_id = str(mk_id).replace('*', '') - - -class IpAddressListManager(TransmitterManager, metaclass=ABCMeta): - def add(self, list_name: str, ip: IpStruct): + def add_ip(self, list_name: str, ip: IpStruct): if not isinstance(ip, IpStruct): raise TypeError commands = ( @@ -345,23 +332,24 @@ class IpAddressListManager(TransmitterManager, metaclass=ABCMeta): '=list=%s' % list_name, '=address=%s' % ip ) - return self._exec_cmd(commands) + r = self._exec_cmd(commands) + return r - def remove(self, mk_id): + def remove_ip(self, mk_id): return self._exec_cmd(( '/ip/firewall/address-list/remove', '=.id=*' + str(mk_id).replace('*', '') )) - def remove_range(self, items: Iterable[IpAddressListObj]): + def remove_ip_range(self, items: Iterable[IpAddressListObj]): ids = tuple(ip.mk_id for ip in items if isinstance(ip, IpAddressListObj)) if len(ids) > 0: - return self._exec_cmd([ + return self._exec_cmd(( '/ip/firewall/address-list/remove', '=numbers=*%s' % ',*'.join(ids) - ]) + )) - def find(self, ip: IpStruct, list_name: str): + def find_ip(self, ip: IpStruct, list_name: str): if not isinstance(ip, IpStruct): raise TypeError return self._exec_cmd(( @@ -370,7 +358,7 @@ class IpAddressListManager(TransmitterManager, metaclass=ABCMeta): '?address=%s' % ip )) - def read_ips_iter(self, list_name: str): + def read_ips_iter(self, list_name: str) -> Generator: ips = self._exec_cmd_iter(( '/ip/firewall/address-list/print', 'where', '?list=%s' % list_name, @@ -380,26 +368,10 @@ class IpAddressListManager(TransmitterManager, metaclass=ABCMeta): if dat != {}: yield IpAddressListObj(dat['=address'], dat['=.id']) - def disable(self, user: AbonStruct): - r = IpAddressListManager.find(self, user.ip, LIST_USERS_ALLOWED) - if len(r) > 1: - mk_id = r[0]['=.id'] - return self._exec_cmd(( - '/ip/firewall/address-list/disable', - '=.id=' + str(mk_id), - )) - - def enable(self, user): - r = IpAddressListManager.find(self, user.ip, LIST_USERS_ALLOWED) - if len(r) > 1: - mk_id = r[0]['=.id'] - return self._exec_cmd(( - '/ip/firewall/address-list/enable', - '=.id=' + str(mk_id), - )) - + ################################################# + # BaseTransmitter implementation + ################################################# -class MikrotikTransmitter(QueueManager, IpAddressListManager): def add_user_range(self, user_list: VectorAbon): for usr in user_list: self.add_user(usr) @@ -408,11 +380,11 @@ class MikrotikTransmitter(QueueManager, IpAddressListManager): if not isinstance(users, (tuple, list, set)): raise ValueError('*users* is used twice, generator does not fit') queue_ids = (usr.queue_id for usr in users if usr is not None) - QueueManager.remove_range(self, queue_ids) + self.remove_queue_range(queue_ids) for ip in (user.ip for user in users if isinstance(user, AbonStruct)): - ip_list_entity = IpAddressListManager.find(self, ip, LIST_USERS_ALLOWED) + ip_list_entity = self.find_ip(ip, LIST_USERS_ALLOWED) if ip_list_entity is not None and len(ip_list_entity) > 1: - IpAddressListManager.remove(self, ip_list_entity[0]['=.id']) + self.remove_ip(ip_list_entity[0]['=.id']) def add_user(self, user: AbonStruct, *args): if not isinstance(user.ip, IpStruct): @@ -422,55 +394,52 @@ class MikrotikTransmitter(QueueManager, IpAddressListManager): if not isinstance(user.tariff, TariffStruct): raise TypeError try: - QueueManager.add(self, user) + self.add_queue(user) except (NasNetworkError, NasFailedResult) as e: print('Error:', e) try: - IpAddressListManager.add(self, LIST_USERS_ALLOWED, user.ip) + self.add_ip(LIST_USERS_ALLOWED, user.ip) except (NasNetworkError, NasFailedResult) as e: print('Error:', e) def remove_user(self, user: AbonStruct): - QueueManager.remove(self, user) - firewall_ip_list_obj = IpAddressListManager.find(self, user.ip, LIST_USERS_ALLOWED) + self.remove_queue(user) + firewall_ip_list_obj = self.find_ip(user.ip, LIST_USERS_ALLOWED) if firewall_ip_list_obj is not None and len(firewall_ip_list_obj) > 1: - IpAddressListManager.remove(self, firewall_ip_list_obj[0]['=.id']) + self.remove_ip(firewall_ip_list_obj[0]['=.id']) def update_user(self, user: AbonStruct, *args): if not isinstance(user.ip, IpStruct): raise TypeError - - find_res = IpAddressListManager.find(self, user.ip, LIST_USERS_ALLOWED) - queue = QueueManager.find(self, 'uid%d' % user.uid) - + find_res = self.find_ip(user.ip, LIST_USERS_ALLOWED) + queue = self.find_queue('uid%d' % user.uid) if not user.is_active: # если не активен - то и обновлять не надо # но и выключить на всяк случай надо, а то вдруг был включён if len(find_res) > 1: # и если найден был - то удалим ip из разрешённых - IpAddressListManager.remove(self, find_res[0]['=.id']) + self.remove_ip(find_res[0]['=.id']) if queue is not None: - QueueManager.remove(self, user) + self.remove_queue(user) return # если нет услуги то её не должно быть и в nas - if user.tariff is None or not isinstance(user.tariff, TariffStruct): + if user.tariff is None: if queue is not None: - QueueManager.remove(self, user) + self.remove_queue(user) return # если не найден (mikrotik возвращает пустой словарь в списке если ничего нет) if len(find_res) < 2: # добавим запись об абоненте - IpAddressListManager.add(self, LIST_USERS_ALLOWED, user.ip) + self.add_ip(LIST_USERS_ALLOWED, user.ip) # Проверяем шейпер - if queue is None: - QueueManager.add(self, user) + self.add_queue(user) return if queue != user: - QueueManager.update(self, user) + self.update_queue(user) def ping(self, host, count=10) -> Optional[Tuple[int, int]]: r = self._exec_cmd(( @@ -487,19 +456,15 @@ class MikrotikTransmitter(QueueManager, IpAddressListManager): received, sent = int(r[-2:][0]['=received']), int(r[-2:][0]['=sent']) return received, sent - # Тарифы хранить нам не надо, так что методы тарифов ниже не реализуем def add_tariff_range(self, tariff_list: VectorTariff): pass - # соответственно и удалять тарифы не надо def remove_tariff_range(self, tariff_list: VectorTariff): pass - # и добавлять тоже def add_tariff(self, tariff: TariffStruct): pass - # и обновлять def update_tariff(self, tariff: TariffStruct): pass @@ -508,14 +473,13 @@ class MikrotikTransmitter(QueueManager, IpAddressListManager): def read_users(self) -> Iterable[AbonStruct]: # shapes is ShapeItem - allowed_ips = set(IpAddressListManager.read_ips_iter(self, LIST_USERS_ALLOWED)) - queues = tuple(q for q in QueueManager.read_queue_iter(self) if q.ip in allowed_ips) + allowed_ips = set(self.read_ips_iter(LIST_USERS_ALLOWED)) + queues = tuple(q for q in self.read_queue_iter() if q.ip in allowed_ips) - ips_from_queues = set(q.ip for q in queues) + ips_from_queues = set((q.ip, q) for q in queues) # delete ip addresses that are in firewall/address-list and there are no corresponding in queues diff = tuple(allowed_ips - ips_from_queues) if len(diff) > 0: - IpAddressListManager.remove_range(self, diff) - + self.remove_ip_range(diff) return queues diff --git a/agent/mod_mikrotik_old.py b/agent/mod_mikrotik_old.py new file mode 100644 index 0000000..b085d64 --- /dev/null +++ b/agent/mod_mikrotik_old.py @@ -0,0 +1,514 @@ +import re +import socket +import binascii +from abc import ABCMeta +from hashlib import md5 +from typing import Iterable, Optional, Tuple +from djing.lib import Singleton +from .structs import TariffStruct, AbonStruct, IpStruct, VectorAbon, VectorTariff +from . import settings as local_settings +from django.conf import settings +from djing import ping +from agent.core import BaseTransmitter, NasNetworkError, NasFailedResult + +DEBUG = getattr(settings, 'DEBUG', False) + +LIST_USERS_ALLOWED = 'DjingUsersAllowed' +LIST_DEVICES_ALLOWED = 'DjingDevicesAllowed' + + +class ApiRos(metaclass=Singleton): + """Routeros api""" + sk = None + is_login = False + + def __init__(self, ip: str, port: int): + if self.sk is None: + sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if port is None: + port = local_settings.NAS_PORT + sk.connect((ip, port or 8728)) + self.sk = sk + + self.currenttag = 0 + + def login(self, username, pwd): + if self.is_login: + return + chal = None + for repl, attrs in self.talk_iter(("/login",)): + chal = binascii.unhexlify(attrs['=ret']) + md = md5() + md.update(b'\x00') + md.update(bytes(pwd, 'utf-8')) + md.update(chal) + for _ in self.talk_iter(("/login", "=name=" + username, + "=response=00" + binascii.hexlify(md.digest()).decode('utf-8'))): + pass + self.is_login = True + + def talk_iter(self, words: Iterable): + if self.write_sentence(words) == 0: + return + while 1: + i = self.read_sentence() + if len(i) == 0: + continue + reply = i[0] + attrs = {} + for w in i[1:]: + j = w.find('=', 1) + if j == -1: + attrs[w] = '' + else: + attrs[w[:j]] = w[j + 1:] + yield (reply, attrs) + if reply == '!done': + return + + def write_sentence(self, words: Iterable): + ret = 0 + for w in words: + self.write_word(w) + ret += 1 + self.write_word('') + return ret + + def read_sentence(self): + r = [] + while 1: + w = self.read_word() + if w == '': + return r + r.append(w) + + def write_word(self, w): + if DEBUG: + print("<<< " + w) + b = bytes(w, "utf-8") + self.write_len(len(b)) + self.write_bytes(b) + + def read_word(self): + ret = self.read_bytes(self.read_len()).decode('utf-8') + if DEBUG: + print(">>> " + ret) + return ret + + def write_len(self, l): + if l < 0x80: + self.write_bytes(bytes((l,))) + elif l < 0x4000: + l |= 0x8000 + self.write_bytes(bytes(((l >> 8) & 0xff, l & 0xff))) + elif l < 0x200000: + l |= 0xC00000 + self.write_bytes(bytes(((l >> 16) & 0xff, (l >> 8) & 0xff, l & 0xff))) + elif l < 0x10000000: + l |= 0xE0000000 + self.write_bytes(bytes(((l >> 24) & 0xff, (l >> 16) & 0xff, (l >> 8) & 0xff, l & 0xff))) + else: + self.write_bytes(bytes((0xf0, (l >> 24) & 0xff, (l >> 16) & 0xff, (l >> 8) & 0xff, l & 0xff))) + + def read_len(self): + c = self.read_bytes(1)[0] + if (c & 0x80) == 0x00: + pass + elif (c & 0xC0) == 0x80: + c &= ~0xC0 + c <<= 8 + c += self.read_bytes(1)[0] + elif (c & 0xE0) == 0xC0: + c &= ~0xE0 + c <<= 8 + c += self.read_bytes(1)[0] + c <<= 8 + c += self.read_bytes(1)[0] + elif (c & 0xF0) == 0xE0: + c &= ~0xF0 + c <<= 8 + c += self.read_bytes(1)[0] + c <<= 8 + c += self.read_bytes(1)[0] + c <<= 8 + c += self.read_bytes(1)[0] + elif (c & 0xF8) == 0xF0: + c = self.read_bytes(1)[0] + c <<= 8 + c += self.read_bytes(1)[0] + c <<= 8 + c += self.read_bytes(1)[0] + c <<= 8 + c += self.read_bytes(1)[0] + return c + + def write_bytes(self, s): + n = 0 + while n < len(s): + r = self.sk.send(s[n:]) + if r == 0: + raise NasFailedResult("connection closed by remote end") + n += r + + def read_bytes(self, length): + ret = b'' + while len(ret) < length: + s = self.sk.recv(length - len(ret)) + if len(s) == 0: + raise NasFailedResult("connection closed by remote end") + ret += s + return ret + + def __del__(self): + if hasattr(self, 'sk'): + self.sk.close() + + +class TransmitterManager(BaseTransmitter, metaclass=ABCMeta): + def __init__(self, login=None, password=None, ip=None, port=None): + ip = ip or getattr(local_settings, 'NAS_IP') + if ip is None or ip == '': + raise NasNetworkError('Ip address of NAS does not specified') + if not ping(ip): + raise NasNetworkError('NAS %(ip_addr)s does not pinged' % { + 'ip_addr': ip + }) + try: + self.ar = ApiRos(ip, port) + self.ar.login(login or getattr(local_settings, 'NAS_LOGIN'), + password or getattr(local_settings, 'NAS_PASSW')) + except ConnectionRefusedError: + raise NasNetworkError('Connection to %s is Refused' % ip) + + def _exec_cmd(self, cmd: Iterable) -> list: + if not isinstance(cmd, (list, tuple)): + raise TypeError + result_iter = self.ar.talk_iter(cmd) + res = [] + for rt in result_iter: + if rt[0] == '!trap': + raise NasFailedResult(rt[1]['=message']) + res.append(rt[1]) + return res + + def _exec_cmd_iter(self, cmd: Iterable) -> Iterable: + if not isinstance(cmd, (list, tuple)): + raise TypeError + result_iter = self.ar.talk_iter(cmd) + for rt in result_iter: + if len(rt) < 2: + continue + if rt[0] == '!trap': + raise NasFailedResult(rt[1]['=message']) + yield rt + + # Build object ShapeItem from information from mikrotik + @staticmethod + def _build_shape_obj(info: dict) -> AbonStruct: + # Переводим приставку скорости Mikrotik в Mbit/s + def parse_speed(text_speed): + text_speed_digit = float(text_speed[:-1] or 0.0) + text_append = text_speed[-1:] + if text_append == 'M': + res = text_speed_digit + elif text_append == 'k': + res = text_speed_digit / 1000 + # elif text_append == 'G': + # res = text_speed_digit * 0x400 + else: + res = float(re.sub(r'[a-zA-Z]', '', text_speed)) / 1000 ** 2 + return res + + speeds = info['=max-limit'].split('/') + t = TariffStruct( + speed_in=parse_speed(speeds[1]), + speed_out=parse_speed(speeds[0]) + ) + try: + a = AbonStruct( + uid=int(info['=name'][3:]), + # FIXME: тут в разных микротиках или =target-addresses или =target + ip=info['=target'][:-3], + tariff=t, + is_active=False if info['=disabled'] == 'false' else True + ) + a.queue_id = info['=.id'] + return a + except ValueError: + pass + + +class QueueManager(TransmitterManager, metaclass=ABCMeta): + # Find queue by name + def find(self, name: str) -> AbonStruct: + ret = self._exec_cmd(('/queue/simple/print', '?name=%s' % name)) + if len(ret) > 1: + return self._build_shape_obj(ret[0]) + + def add(self, user: AbonStruct): + if not isinstance(user, AbonStruct): + raise TypeError + if user.tariff is None or not isinstance(user.tariff, TariffStruct): + return + return self._exec_cmd(('/queue/simple/add', + '=name=uid%d' % user.uid, + # FIXME: тут в разных микротиках или =target-addresses или =target + '=target=%s' % user.ip, + '=max-limit=%.3fM/%.3fM' % (user.tariff.speedOut, user.tariff.speedIn), + '=queue=MikroBILL_SFQ/MikroBILL_SFQ', + '=burst-time=1/1' + )) + + def remove(self, user: AbonStruct): + if not isinstance(user, AbonStruct): + raise TypeError + q = self.find('uid%d' % user.uid) + if q is not None: + return self._exec_cmd(('/queue/simple/remove', '=.id=' + getattr(q, 'queue_id', ''),)) + + def remove_range(self, q_ids: Iterable[str]): + try: + # q_ids = [q.queue_id for q in q_ids] + return self._exec_cmd(('/queue/simple/remove', '=numbers=' + ','.join(q_ids))) + except TypeError as e: + print(e) + + def update(self, user: AbonStruct): + if not isinstance(user, AbonStruct): + raise TypeError + if user.tariff is None or not isinstance(user.tariff, TariffStruct): + return + queue = self.find('uid%d' % user.uid) + if queue is None: + return self.add(user) + else: + mk_id = getattr(queue, 'queue_id', '') + return self._exec_cmd(('/queue/simple/set', '=.id=' + mk_id, + '=name=uid%d' % user.uid, + '=max-limit=%.3fM/%.3fM' % (user.tariff.speedOut, user.tariff.speedIn), + # FIXME: тут в разных микротиках или =target-addresses или =target + '=target=%s' % user.ip, + '=queue=MikroBILL_SFQ/MikroBILL_SFQ', + '=burst-time=1/1' + )) + + def read_queue_iter(self): + for code, dat in self._exec_cmd_iter(('/queue/simple/print', '=detail')): + if code == '!done': + return + sobj = self._build_shape_obj(dat) + if sobj is not None: + yield sobj + + def disable(self, user: AbonStruct): + if not isinstance(user, AbonStruct): + raise TypeError + q = self.find('uid%d' % user.uid) + if q is None: + self.add(user) + return self.disable(user) + else: + return self._exec_cmd(('/queue/simple/disable', '=.id=*' + getattr(q, 'queue_id', ''))) + + def enable(self, user: AbonStruct): + if not isinstance(user, AbonStruct): + raise TypeError + q = self.find('uid%d' % user.uid) + if q is None: + self.add(user) + self.enable(user) + else: + return self._exec_cmd(('/queue/simple/enable', '=.id=*' + getattr(q, 'queue_id', ''))) + + +class IpAddressListObj(IpStruct): + __slots__ = ('__ip', 'mk_id') + + def __init__(self, ip, mk_id): + super(IpAddressListObj, self).__init__(ip) + self.mk_id = str(mk_id).replace('*', '') + + +class IpAddressListManager(TransmitterManager, metaclass=ABCMeta): + def add(self, list_name: str, ip: IpStruct): + if not isinstance(ip, IpStruct): + raise TypeError + commands = ( + '/ip/firewall/address-list/add', + '=list=%s' % list_name, + '=address=%s' % ip + ) + return self._exec_cmd(commands) + + def remove(self, mk_id): + return self._exec_cmd(( + '/ip/firewall/address-list/remove', + '=.id=*' + str(mk_id).replace('*', '') + )) + + def remove_range(self, items: Iterable[IpAddressListObj]): + ids = tuple(ip.mk_id for ip in items if isinstance(ip, IpAddressListObj)) + if len(ids) > 0: + return self._exec_cmd([ + '/ip/firewall/address-list/remove', + '=numbers=*%s' % ',*'.join(ids) + ]) + + def find(self, ip: IpStruct, list_name: str): + if not isinstance(ip, IpStruct): + raise TypeError + return self._exec_cmd(( + '/ip/firewall/address-list/print', 'where', + '?list=%s' % list_name, + '?address=%s' % ip + )) + + def read_ips_iter(self, list_name: str): + ips = self._exec_cmd_iter(( + '/ip/firewall/address-list/print', 'where', + '?list=%s' % list_name, + '?dynamic=no' + )) + for code, dat in ips: + if dat != {}: + yield IpAddressListObj(dat['=address'], dat['=.id']) + + def disable(self, user: AbonStruct): + r = IpAddressListManager.find(self, user.ip, LIST_USERS_ALLOWED) + if len(r) > 1: + mk_id = r[0]['=.id'] + return self._exec_cmd(( + '/ip/firewall/address-list/disable', + '=.id=' + str(mk_id), + )) + + def enable(self, user): + r = IpAddressListManager.find(self, user.ip, LIST_USERS_ALLOWED) + if len(r) > 1: + mk_id = r[0]['=.id'] + return self._exec_cmd(( + '/ip/firewall/address-list/enable', + '=.id=' + str(mk_id), + )) + + +class MikrotikTransmitter(QueueManager, IpAddressListManager): + def add_user_range(self, user_list: VectorAbon): + for usr in user_list: + self.add_user(usr) + + def remove_user_range(self, users: VectorAbon): + if not isinstance(users, (tuple, list, set)): + raise ValueError('*users* is used twice, generator does not fit') + queue_ids = (usr.queue_id for usr in users if usr is not None) + QueueManager.remove_range(self, queue_ids) + for ip in (user.ip for user in users if isinstance(user, AbonStruct)): + ip_list_entity = IpAddressListManager.find(self, ip, LIST_USERS_ALLOWED) + if ip_list_entity is not None and len(ip_list_entity) > 1: + IpAddressListManager.remove(self, ip_list_entity[0]['=.id']) + + def add_user(self, user: AbonStruct, *args): + if not isinstance(user.ip, IpStruct): + raise TypeError + if user.tariff is None: + return + if not isinstance(user.tariff, TariffStruct): + raise TypeError + try: + QueueManager.add(self, user) + except (NasNetworkError, NasFailedResult) as e: + print('Error:', e) + try: + IpAddressListManager.add(self, LIST_USERS_ALLOWED, user.ip) + except (NasNetworkError, NasFailedResult) as e: + print('Error:', e) + + def remove_user(self, user: AbonStruct): + QueueManager.remove(self, user) + firewall_ip_list_obj = IpAddressListManager.find(self, user.ip, LIST_USERS_ALLOWED) + if firewall_ip_list_obj is not None and len(firewall_ip_list_obj) > 1: + IpAddressListManager.remove(self, firewall_ip_list_obj[0]['=.id']) + + def update_user(self, user: AbonStruct, *args): + if not isinstance(user.ip, IpStruct): + raise TypeError + + find_res = IpAddressListManager.find(self, user.ip, LIST_USERS_ALLOWED) + queue = QueueManager.find(self, 'uid%d' % user.uid) + + if not user.is_active: + # если не активен - то и обновлять не надо + # но и выключить на всяк случай надо, а то вдруг был включён + if len(find_res) > 1: + # и если найден был - то удалим ip из разрешённых + IpAddressListManager.remove(self, find_res[0]['=.id']) + if queue is not None: + QueueManager.remove(self, user) + return + + # если нет услуги то её не должно быть и в nas + if user.tariff is None or not isinstance(user.tariff, TariffStruct): + if queue is not None: + QueueManager.remove(self, user) + return + + # если не найден (mikrotik возвращает пустой словарь в списке если ничего нет) + if len(find_res) < 2: + # добавим запись об абоненте + IpAddressListManager.add(self, LIST_USERS_ALLOWED, user.ip) + + # Проверяем шейпер + + if queue is None: + QueueManager.add(self, user) + return + if queue != user: + QueueManager.update(self, user) + + def ping(self, host, count=10) -> Optional[Tuple[int, int]]: + r = self._exec_cmd(( + '/ip/arp/print', + '?address=%s' % host + )) + if r == [{}]: + return + interface = r[0]['=interface'] + r = self._exec_cmd(( + '/ping', '=address=%s' % host, '=arp-ping=yes', '=interval=100ms', '=count=%d' % count, + '=interface=%s' % interface + )) + received, sent = int(r[-2:][0]['=received']), int(r[-2:][0]['=sent']) + return received, sent + + # Тарифы хранить нам не надо, так что методы тарифов ниже не реализуем + def add_tariff_range(self, tariff_list: VectorTariff): + pass + + # соответственно и удалять тарифы не надо + def remove_tariff_range(self, tariff_list: VectorTariff): + pass + + # и добавлять тоже + def add_tariff(self, tariff: TariffStruct): + pass + + # и обновлять + def update_tariff(self, tariff: TariffStruct): + pass + + def remove_tariff(self, tid: int): + pass + + def read_users(self) -> Iterable[AbonStruct]: + # shapes is ShapeItem + allowed_ips = set(IpAddressListManager.read_ips_iter(self, LIST_USERS_ALLOWED)) + queues = tuple(q for q in QueueManager.read_queue_iter(self) if q.ip in allowed_ips) + + ips_from_queues = set(q.ip for q in queues) + + # delete ip addresses that are in firewall/address-list and there are no corresponding in queues + diff = tuple(allowed_ips - ips_from_queues) + if len(diff) > 0: + IpAddressListManager.remove_range(self, diff) + + return queues diff --git a/agent/structs.py b/agent/structs.py index 39c7bac..3b53bd2 100644 --- a/agent/structs.py +++ b/agent/structs.py @@ -1,39 +1,22 @@ # -*- coding: utf-8 -*- -from abc import ABCMeta, abstractmethod -from struct import pack, unpack, calcsize -from typing import Iterable, Optional +from abc import ABCMeta +from typing import Iterable from djing.lib import int2ip, ip2int class BaseStruct(object, metaclass=ABCMeta): - @abstractmethod - def serialize(self) -> Optional[bytes]: - """make binary""" - - @abstractmethod - def deserialize(self, data: bytes, *args): - """restore from binary""" - - def __ne__(self, other): - return not self == other + __slots__ = () class IpStruct(BaseStruct): + __slots__ = ('__ip',) + def __init__(self, ip): if type(ip) is int: self.__ip = ip else: self.__ip = ip2int(str(ip)) - def serialize(self) -> Optional[bytes]: - dt = pack("!I", int(self.__ip)) - return dt - - def deserialize(self, data: bytes, *args): - dt = unpack("!I", data) - self.__ip = int(dt[0]) - return self - def get_int(self): return self.__ip @@ -54,26 +37,17 @@ class IpStruct(BaseStruct): # Как обслуживается абонент class TariffStruct(BaseStruct): + __slots__ = ('tid', 'speedIn', 'speedOut') + def __init__(self, tariff_id=0, speed_in=None, speed_out=None): self.tid = int(tariff_id) self.speedIn = speed_in or 0 self.speedOut = speed_out or 0 - def serialize(self) -> Optional[bytes]: - dt = pack("!Iff", int(self.tid), float(self.speedIn), float(self.speedOut)) - return dt - # Да, если все значения нулевые def is_empty(self): return self.tid == 0 and self.speedIn == 0.001 and self.speedOut == 0.001 - def deserialize(self, data: bytes, *args): - dt = unpack("!Iff", data) - self.tid = int(dt[0]) - self.speedIn = float(dt[1]) - self.speedOut = float(dt[2]) - return self - def __eq__(self, other): # не сравниваем id, т.к. тарифы с одинаковыми скоростями для NAS одинаковы # Да и иногда не удобно доставать из nas id тарифы из базы @@ -90,32 +64,14 @@ class TariffStruct(BaseStruct): # Абонент из базы class AbonStruct(BaseStruct): + __slots__ = ('uid', 'ip', 'tariff', 'is_active', 'queue_id') + def __init__(self, uid=0, ip=None, tariff=None, is_active=True): self.uid = int(uid or 0) self.ip = IpStruct(ip) self.tariff = tariff self.is_active = is_active - - def serialize(self) -> Optional[bytes]: - if self.tariff is None: - return - if not isinstance(self.tariff, TariffStruct): - raise TypeError('Instance must be TariffStruct') - if not isinstance(self.ip, IpStruct): - raise TypeError('Instance must be IpStruct') - dt = pack("!LII?", self.uid, int(self.ip), self.tariff.tid, self.is_active) - return dt - - def deserialize(self, data: bytes, tariff=None): - dt = unpack("!LII?", data) - self.uid = dt[0] - self.ip = IpStruct(dt[1]) - if tariff is not None: - if not isinstance(tariff, TariffStruct): - raise TypeError - self.tariff = tariff - self.is_active = dt['3'] - return self + self.queue_id = 0 def __eq__(self, other): if not isinstance(other, AbonStruct): @@ -133,22 +89,12 @@ class AbonStruct(BaseStruct): # Правило шейпинга в фаере, или ещё можно сказать услуга абонента на NAS class ShapeItem(BaseStruct): + __slots__ = ('abon', 'sid') + def __init__(self, abon, sid): self.abon = abon self.sid = sid - def serialize(self) -> Optional[bytes]: - abon_pack = self.abon.serialize() - dt = pack('!L', self.sid) - return dt + abon_pack - - def deserialize(self, data: bytes, *args): - sz = calcsize('!L') - dt = unpack('!L', data[:sz]) - self.sid = dt - self.abon.deserialize(data[sz:]) - return self - def __eq__(self, other): if not isinstance(other, ShapeItem): raise TypeError