import asyncio
import datetime
import functools
import json
import math
import os
import random
from collections.abc import Mapping
from uuid import UUID

from aiohttp import web
from bson import ObjectId

from prozorro_sale.tools.environment import Environment # noqa
from prozorro_sale.tools.logger import get_custom_logger
from prozorro_sale.tools.middlewares import request_unpack_params # noqa
# Module-level logger from the project's logging helper; retry_on uses it to
# report each caught retryable error.
LOG = get_custom_logger(__name__)
class classproperty:
    """Descriptor that computes an attribute value from the owning class.

    Decorating a function with ``@classproperty`` makes ``Owner.name`` (and
    ``instance.name``) evaluate ``func(Owner)`` on every access.
    """

    def __init__(self, getter):
        # The wrapped callable; it is always given the class, never the instance.
        self.getter = getter

    def __get__(self, instance, owner):
        # ``instance`` is None for class access and the object for instance
        # access; either way only the owner class is forwarded.
        compute = self.getter
        return compute(owner)
class ConcurrencyError(RuntimeError):
    """Base class for concurrency exceptions.

    ``retry_on_concurrency_error_middleware`` retries non-GET requests whose
    handler raises this error.
    """
class ApplicationError(Exception):
    """Base class for application-level exceptions.

    Raised, for example, by ``check_required_env_vars`` when required
    environment variables are missing.
    """
def check_required_fields(required_fields: set, data: dict):
    """
    Method for checking expected values in data.

    Each entry of ``required_fields`` is a dot-separated path into ``data``
    (e.g. ``'value.amount'``). A path is reported as missing when any segment
    is absent, or when an intermediate value is not a mapping and therefore
    cannot contain the next segment. A present key whose value is ``None``
    counts as present.

    Args:
        required_fields (set): dot-separated field paths expected in data
        data (dict): data to check

    Returns:
        dict: maps each missing field path to the ``'Required field'`` error
        message; empty when every field is present.
    """
    data_errors = {}
    error_msg = 'Required field'
    for field in required_fields:
        data_to_search = data
        for nested_field in field.split('.'):
            # A non-mapping intermediate (str, list, None, ...) cannot hold the
            # next segment. Without this guard, ``in`` would do a substring or
            # element test (or raise on None), and the subscript below would
            # raise TypeError.
            if not isinstance(data_to_search, Mapping) or nested_field not in data_to_search:
                data_errors[field] = error_msg
                break
            data_to_search = data_to_search[nested_field]
    return data_errors
def retry_on(error, times=5):
    """Request retries wrapper.

    Decorate a coroutine function so that it is re-invoked when it raises
    ``error``. Between attempts it sleeps ``exp(attempt / 3) - 1`` seconds
    plus up to 0.1 s of random jitter. There is no sleep after the final
    attempt.

    Args:
        error: exception class that triggers a retry. It must be callable
            with a single message argument, because a fresh instance is
            raised once all attempts fail.
        times (int): maximum number of attempts.

    Returns:
        Callable: decorator producing the retrying coroutine function.

    Raises:
        error: ``error('Retried too many times. abort request')`` after
            ``times`` failed attempts, chained (``__cause__``) to the last
            caught exception.
    """
    def wrapper(func):
        @functools.wraps(func)
        async def handler(request, *args, **kwargs):
            last_exc = None
            for _try in range(times):
                try:
                    return await func(request, *args, **kwargs)
                except error as exc:
                    # Keep a reference: ``exc`` is unbound when the except block ends.
                    last_exc = exc
                    LOG.info('Caught concurrency error')
                    # Backing off after the last attempt only delays the failure.
                    if _try + 1 < times:
                        await asyncio.sleep(math.exp(_try / 3) - 1 + random.random() / 10)  # nosec
            raise error('Retried too many times. abort request') from last_exc
        return handler
    return wrapper
@web.middleware
async def retry_on_concurrency_error_middleware(request, handler):
    """Retry non-GET requests whose handler raises ConcurrencyError.

    GET requests are passed to the handler unchanged. Every other method is
    dispatched through ``retry_on(ConcurrencyError)``, so the handler is
    re-invoked when it raises ``ConcurrencyError``.
    """
    if request.method == 'GET':
        return await handler(request)
    retrying_handler = retry_on(ConcurrencyError)(handler)
    return await retrying_handler(request)
def check_required_env_vars(required_vars: set):
    """
    Ensure every named variable is present in the process environment.

    Args:
        required_vars (set): names of environment variables that must be set

    Raises:
        ApplicationError: if any of the names is missing from ``os.environ``
    """
    defined_names = set(os.environ)
    missing = required_vars - defined_names
    if not missing:
        return
    raise ApplicationError(f'Env vars: {missing} are not defined.')
class DefaultTypesEncoder(json.JSONEncoder):
    """Custom encoder for json.

    Values of the following types, which the stdlib encoder rejects, are
    serialized as their ``str()`` form:

    * ``datetime.date`` and ``datetime.datetime``. Datetimes keep ``str()``'s
      space separator (``'2020-01-02 03:04:05'``), not ``isoformat()``'s ``T``.
    * ``uuid.UUID``
    * ``bson.ObjectId``

    Anything else is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        # ``datetime.datetime`` subclasses ``datetime.date``, so this check
        # covers plain dates as well without changing the datetime output.
        if isinstance(obj, (datetime.date, UUID)):
            return str(obj)
        if isinstance(obj, ObjectId):
            return str(obj)
        return super().default(obj)
def default_type_encoder(data):
    """
    Serialize ``data`` to a JSON string using ``DefaultTypesEncoder``.

    Args:
        data (dict): structure to encode; values of the extra types that
            DefaultTypesEncoder understands are converted to strings.

    Returns:
        str: the encoded JSON document.
    """
    # Equivalent to json.dumps(data, cls=DefaultTypesEncoder): with no other
    # options, dumps builds the encoder with default settings and calls encode().
    encoder = DefaultTypesEncoder()
    return encoder.encode(data)