"""Classes and methods for handling filter operators."""
import re
from sqlalchemy.dialects.postgresql import JSONB
class FilterRegistry:
    """Registry of allowed filter operators.

    Attributes:
        data (dict): data store for filter op information. Maps a column
            type (class object) or the string '__ALL__' to a dictionary of
            ``{filter_name: {'comparator_name': ..., 'value_transform': ...}}``.
    """

    def __init__(self):
        self.data = {}
        self.register_standard_filters()

    def register_standard_filters(self):
        """Register standard supported filter operators."""
        # Operators valid for every column type.
        for comparator_name in (
            '__eq__',
            '__ne__',
            'startswith',
            'endswith',
            'contains',
            '__lt__',
            '__gt__',
            '__le__',
            '__ge__',
        ):
            self.register(comparator_name)

        # Transform '*' to '%' for like and ilike, so that '*' in the
        # incoming filter value acts as the SQL wildcard.
        for comparator_name in ('like', 'ilike'):
            self.register(
                comparator_name,
                value_transform=lambda val: re.sub(r'\*', '%', val)
            )

        # JSONB specific operators. 'contains' is deliberately registered
        # again here: get_filter() prefers the type specific entry over the
        # '__ALL__' one.
        for comparator_name in (
            'contains',
            'contained_by',
            'has_all',
            'has_any',
            'has_key',
        ):
            self.register(comparator_name, column_type=JSONB)

    def register(
        self,
        comparator_name,
        filter_name=None,
        value_transform=lambda val: val,
        column_type='__ALL__'
    ):
        """Register a new filter operator.

        Registering the same filter name twice for the same column type
        replaces the earlier entry.

        Args:
            comparator_name (str): name of sqlalchemy comparator method.
            filter_name (str): name of filter param in URL. Defaults to
                comparator_name with any occurrences of '__' removed (so
                '__eq__' defaults to 'eq', for example).
            value_transform (func): function taking the filter value as the
                only argument and returning a transformed value. Defaults to
                a function returning an unmodified value.
            column_type (class): type (class object, not name) for which
                this operator is to be registered. Defaults to '__ALL__'
                (the string) which makes the operator valid for all column
                types.
        """
        registry = self.data.setdefault(column_type, {})
        registry[filter_name or comparator_name.replace('__', '')] = {
            'comparator_name': comparator_name,
            'value_transform': value_transform
        }

    def get_filter(self, column_type, filter_name):
        """Get dictionary of filter information.

        Args:
            column_type (class): type (class object, not name) of a Column.
            filter_name (str): name of filter param in URL.

        Returns:
            dict: information dictionary for filter. Type specific entry if
            it exists, entry from '__ALL__' if it does not.

        Raises:
            KeyError: if filter_name is not in the type specific or ALL
                sections.
        """
        try:
            return self.data[column_type][filter_name]
        except KeyError:
            # Either the type has no specific section or the section lacks
            # this filter: fall back to the operators valid for all types.
            return self.data['__ALL__'][filter_name]

    def valid_filter_names(self, column_types=None):
        """Return set of supported filter operator names.

        Args:
            column_types (iterable): column types (class objects) to report
                on. If omitted or empty, every registered column type is
                included. The '__ALL__' operators are always included.

        Returns:
            set: filter names valid for at least one of the column types.
            A column type with no type specific operators contributes only
            the '__ALL__' names.
        """
        if column_types:
            wanted = set(column_types)
        else:
            wanted = set(self.data)
        wanted.add('__ALL__')

        ops = set()
        for ctype in wanted:
            # .get(): a type with nothing registered specifically for it is
            # still a valid column type and must not raise KeyError.
            ops.update(self.data.get(ctype, ()))
        return ops