"""Django-esque declarative fields for sqlalchemy."""
from contextlib import suppress
import sqlalchemy as sa
from django import forms as djangoforms
from django.core import validators as django_validators
from django.db.backends.base import operations
from django.forms import fields as djangofields
from django.utils.module_loading import import_string
from .url import DIALECT_MAP_TO_DJANGO
# Public API of this module: the django-esque column field classes.
# (ValidateIntegerFieldMixin is intentionally not exported.)
__all__ = [
    "BigIntegerField",
    "BinaryField",
    "BooleanField",
    "CharField",
    "DateField",
    "DateTimeField",
    "DecimalField",
    "DurationField",
    "EmailField",
    "EnumField",
    "Field",
    "FloatField",
    "IntegerField",
    "NullBooleanField",
    "SlugField",
    "SmallIntegerField",
    "TextField",
    "TimeField",
    "TimestampField",
    "URLField",
]
class Field(sa.Column):
    """Base django-esque field.

    A ``sqlalchemy.Column`` that accepts django-style keyword arguments and
    stores django metadata in ``Column.info`` under the keys ``validators``,
    ``required``, ``form_class`` and, when provided, ``label`` and
    ``widget_class``.

    Subclasses customise behaviour through the class attributes below and by
    overriding the ``get_*`` hooks.
    """

    # validators placed ahead of any user-supplied ``validators`` (copied per field)
    default_validators = []
    # sqlalchemy type class instantiated when no explicit column type is given
    type_class = None
    # django form field class recorded in ``info["form_class"]`` by default
    form_class = None
    # optional django widget class recorded in ``info["widget_class"]``
    widget_class = None

    def __init__(self, *args, **kwargs):
        """Build the column from django-style arguments.

        Args:
            *args: optional column name (first positional ``str``), then an
                optional sqlalchemy type (an object exposing ``_sqla_type``);
                any remaining positional args are forwarded to ``sa.Column``.
            **kwargs: ``sa.Column`` keyword args (see ``get_column_kwargs``)
                plus the django-esque extras ``db``, ``validators``,
                ``required``, ``label``, ``form_class`` and any type kwargs
                consumed by ``get_type_kwargs``.
        """
        self.db = kwargs.pop("db", None)
        name = None
        args = list(args)
        if args and isinstance(args[0], str):
            name = args.pop(0)
        column_type = kwargs.pop("type_", None)
        # a sqlalchemy type passed positionally takes precedence over ``type_``
        if args and hasattr(args[0], "_sqla_type"):
            column_type = args.pop(0)
        validators = kwargs.pop("validators", [])
        required = kwargs.pop("required", None)
        column_kwargs = self.get_column_kwargs(kwargs)
        if column_type is None:
            type_class = self.get_type_class(kwargs)
            type_kwargs = self.get_type_kwargs(type_class, kwargs)
            column_type = self.get_type(type_class, type_kwargs)
        column_kwargs["type_"] = column_type
        column_kwargs.setdefault("name", name)
        super().__init__(*args, **column_kwargs)
        # resolved after Column.__init__ because some subclasses read self.type
        # (e.g. its length or precision) while building validators
        self.info["validators"] = self.get_validators(validators)
        self.info["required"] = required if required is not None else not self.nullable
        label = kwargs.pop("label", None)
        if label:
            self.info["label"] = label
        self.info["form_class"] = kwargs.pop("form_class", None) or self.get_form_class(kwargs)
        if self.widget_class:
            self.info["widget_class"] = self.widget_class

    def get_type_kwargs(self, type_class, kwargs):
        """Returns sqlalchemy type kwargs.

        Pops from ``kwargs`` every keyword accepted by ``type_class``'s
        constructor (as reported by ``sa.util.get_cls_kwargs``).
        """
        type_args = sa.util.get_cls_kwargs(type_class)
        return {k: kwargs.pop(k) for k in type_args if not k.startswith("*") and k in kwargs}

    def get_column_kwargs(self, kwargs):
        """Returns sqlalchemy column kwargs.

        Pops the recognised ``sa.Column`` keywords from ``kwargs``.
        ``primary_key`` defaults to ``False`` and ``nullable`` defaults to
        ``not primary_key``.
        """
        column_args = [
            "autoincrement",
            "comment",
            "default",
            "doc",
            "index",
            "info",
            "key",
            "name",
            "onupdate",
            "server_default",
            "server_onupdate",
            "type_",
            "unique",
            "_proxies",
        ]
        column_kwargs = {k: kwargs.pop(k) for k in column_args if k in kwargs}
        column_kwargs["primary_key"] = kwargs.pop("primary_key", False)
        column_kwargs["nullable"] = kwargs.pop("nullable", not column_kwargs["primary_key"])
        return column_kwargs

    def get_type_class(self, kwargs):
        """Returns sqlalchemy column type."""
        return self.type_class

    def get_form_class(self, kwargs):
        """Returns the django form field class for the field.

        Called from ``__init__`` whenever no explicit ``form_class`` kwarg is
        given; subclasses may override it to choose a class from ``kwargs``.
        """
        return self.form_class

    def get_validators(self, validators):
        """Returns django validators for the field (defaults first, as a new list)."""
        return self.default_validators[:] + validators

    def get_type(self, type_class, type_kwargs):
        """Returns sqlalchemy column type instance for the field."""
        return type_class(**type_kwargs)
class BooleanField(Field):
    """Boolean column mirroring django's BooleanField: never NULL, False by default."""

    form_class = djangofields.BooleanField
    type_class = sa.Boolean

    def get_column_kwargs(self, kwargs):
        """Column kwargs with NULL forbidden and a ``False`` default unless one is given."""
        result = super().get_column_kwargs(kwargs)
        result.update(nullable=False)
        if "default" not in result:
            result["default"] = False
        return result

    def get_type_kwargs(self, type_class, kwargs):
        """Type kwargs, mapping ``constraint_name`` onto the type's ``name``."""
        result = super().get_type_kwargs(type_class, kwargs)
        constraint = kwargs.pop("constraint_name", None)
        result["name"] = constraint
        return result
class CharField(Field):
    """Length-limited string column in the spirit of django's CharField."""

    form_class = djangofields.CharField
    length_is_required = True
    type_class = sa.String

    def get_type_kwargs(self, type_class, kwargs):
        """Resolve the string length from ``length``, falling back to ``max_length``.

        Raises:
            TypeError: when no length is available and ``length_is_required`` is set.
        """
        type_kwargs = super().get_type_kwargs(type_class, kwargs)
        length = type_kwargs.get("length")
        if not length:
            length = kwargs.get("max_length")
        type_kwargs["length"] = length
        if self.length_is_required and not length:
            raise TypeError('Missing length parameter. Must provide either "max_length" or "length" parameter')
        return type_kwargs

    def get_validators(self, validators):
        """Add a MaxLengthValidator for the column length unless one is already present."""
        validators = super().get_validators(validators)
        limit = self.type.length
        if limit:
            already_limited = any(isinstance(v, django_validators.MaxLengthValidator) for v in validators)
            if not already_limited:
                validators.append(django_validators.MaxLengthValidator(limit))
        return validators
class DateField(Field):
    """Date column paired with django's DateField form field."""

    form_class = djangofields.DateField
    type_class = sa.Date
class DateTimeField(Field):
    """Date-and-time column paired with django's DateTimeField form field."""

    form_class = djangofields.DateTimeField
    type_class = sa.DateTime
class DurationField(Field):
    """Interval column paired with django's DurationField form field."""

    form_class = djangofields.DurationField
    type_class = sa.Interval
class DecimalField(Field):
    """Fixed-point numeric column accepting django's ``max_digits``/``decimal_places``."""

    form_class = djangofields.DecimalField
    type_class = sa.Numeric

    def get_type_kwargs(self, type_class, kwargs):
        """Map django's decimal kwargs onto ``precision``/``scale`` (explicit ones win)."""
        type_kwargs = super().get_type_kwargs(type_class, kwargs)
        django_names = {
            "precision": kwargs.pop("max_digits", None),
            "scale": kwargs.pop("decimal_places", None),
        }
        for key, value in django_names.items():
            type_kwargs.setdefault(key, value)
        type_kwargs["asdecimal"] = True
        return type_kwargs

    def get_validators(self, validators):
        """Append a DecimalValidator built from the column's precision and scale."""
        validators = super().get_validators(validators)
        numeric = self.type
        validators.append(django_validators.DecimalValidator(numeric.precision, numeric.scale))
        return validators
class EmailField(CharField):
    """CharField whose values are validated as email addresses."""

    form_class = djangofields.EmailField
    default_validators = [django_validators.validate_email]
class EnumField(Field):
    """Django like choice field that uses an enum sqlalchemy type.

    Choices come from the ``choices`` kwarg or, if that is missing or empty,
    from ``enum_class``. A list or tuple is treated as the sequence of allowed
    values; any other object (e.g. a python ``enum.Enum`` subclass) is passed
    to ``sa.Enum`` as its single positional argument.
    """

    type_class = sa.Enum

    def get_type_kwargs(self, type_class, kwargs):
        """Returns ``sa.Enum`` kwargs; the positional values are stored under ``"choices"``.

        ``constraint_name`` is mapped onto the type's ``name``, and the
        enum-specific options listed below are moved over from ``kwargs``.

        Raises:
            TypeError: if neither ``choices`` nor ``enum_class`` is provided.
        """
        type_kwargs = super().get_type_kwargs(type_class, kwargs)
        choices = kwargs.pop("choices", None) or kwargs.pop("enum_class", None)
        if choices is None:
            raise TypeError('Missing choices. Must provide either "choices" or "enum_class" parameter')
        # only the container type matters: a one-element list is still a list of
        # values and must not be wrapped into a nested list
        type_kwargs["choices"] = choices if isinstance(choices, (list, tuple)) else [choices]
        type_kwargs["name"] = kwargs.pop("constraint_name", None)
        enum_args = [
            "native_enum",
            "create_constraint",
            "values_callable",
            "convert_unicode",
            "validate_strings",
            "schema",
            "quote",
            "_create_events",
        ]
        for k in enum_args:
            if k in kwargs:
                type_kwargs[k] = kwargs.pop(k)
        return type_kwargs

    def get_type(self, type_class, type_kwargs):
        """Instantiate ``type_class`` with the choices as positional arguments."""
        choices = type_kwargs.pop("choices")
        return type_class(*choices, **type_kwargs)
class FloatField(Field):
    """Floating-point column; django's ``max_digits`` doubles as ``precision``."""

    form_class = djangofields.FloatField
    type_class = sa.Float

    def get_type_kwargs(self, type_class, kwargs):
        """Use ``max_digits`` as the precision when no truthy ``precision`` was given."""
        type_kwargs = super().get_type_kwargs(type_class, kwargs)
        if not type_kwargs.get("precision"):
            type_kwargs["precision"] = kwargs.pop("max_digits", None)
        return type_kwargs
class ValidateIntegerFieldMixin:
    """A mixin that provides default min/max validators for integer types.

    Expects the host class to provide ``db`` (``Field.__init__`` sets it) and
    to have a django integer field class name (e.g. ``IntegerField``) in its
    MRO so the dialect range can be looked up.
    """

    def get_django_dialect_ranges(self):
        """Returns django's ``integer_field_ranges`` for the current dialect.

        Uses the django backend mapped from the sqlalchemy dialect of
        ``self.db`` when ``db`` is set, the dialect has a django counterpart,
        and that backend imports; otherwise falls back to
        ``BaseDatabaseOperations``.
        """
        ops = operations.BaseDatabaseOperations
        if self.db:
            with suppress(ImportError):
                backend = DIALECT_MAP_TO_DJANGO.get(self.db.url.get_dialect().name)
                # unmapped dialects keep the base ranges rather than failing on None + str
                if backend:
                    ops = import_string(backend + ".base.DatabaseOperations")
        return ops.integer_field_ranges

    def get_dialect_range(self):
        """Returns the ``(min, max)`` range supported by the dialect.

        The first class in the MRO whose name is a key of the ranges mapping is
        used, so subclasses of e.g. ``IntegerField`` inherit its range.

        Raises:
            KeyError: if no class in the MRO has a known range.
        """
        ranges = self.get_django_dialect_ranges()
        for klass in type(self).__mro__:
            if klass.__name__ in ranges:
                return ranges[klass.__name__]
        raise KeyError(type(self).__name__)

    def get_validators(self, validators):
        """Returns django integer min/max validators supported by the
        database, unless the caller already supplied them."""
        validators = super().get_validators(validators)
        min_int, max_int = self.get_dialect_range()
        if not any(isinstance(i, django_validators.MinValueValidator) for i in validators):
            validators.append(django_validators.MinValueValidator(min_int))
        if not any(isinstance(i, django_validators.MaxValueValidator) for i in validators):
            validators.append(django_validators.MaxValueValidator(max_int))
        return validators
class IntegerField(ValidateIntegerFieldMixin, Field):
    """Integer column with database-range min/max validators."""

    form_class = djangofields.IntegerField
    type_class = sa.Integer
    default_validators = [django_validators.validate_integer]
class BigIntegerField(ValidateIntegerFieldMixin, Field):
    """Big integer column with database-range min/max validators."""

    form_class = djangofields.IntegerField
    type_class = sa.BigInteger
    default_validators = [django_validators.validate_integer]
class SmallIntegerField(ValidateIntegerFieldMixin, Field):
    """Small integer column with database-range min/max validators."""

    form_class = djangofields.IntegerField
    type_class = sa.SmallInteger
    default_validators = [django_validators.validate_integer]
class NullBooleanField(BooleanField):
    """Django like nullable boolean field."""

    form_class = djangofields.NullBooleanField

    def get_column_kwargs(self, kwargs):
        """Returns sqlalchemy column kwargs with NULL always allowed.

        Skips ``BooleanField.get_column_kwargs`` (which forces
        ``nullable=False`` and a ``False`` default) and uses the base ``Field``
        filtering. As a result only recognised column kwargs reach
        ``sa.Column``, and the returned dict is independent of ``kwargs``.
        """
        kwargs["nullable"] = True
        return super(BooleanField, self).get_column_kwargs(kwargs)
class SlugField(CharField):
    """CharField restricted to slug values (letters, digits, hyphens, underscores)."""

    form_class = djangofields.SlugField
    default_validators = [django_validators.validate_slug]
class TextField(CharField):
    """Text column with optional length, rendered with a django Textarea widget."""

    widget_class = djangoforms.Textarea
    form_class = djangofields.CharField
    length_is_required = False
    type_class = sa.Text
class TimeField(Field):
    """Time-of-day column paired with django's TimeField form field."""

    form_class = djangofields.TimeField
    type_class = sa.Time
class TimestampField(DateTimeField):
    """DateTimeField variant whose column type is sqlalchemy's TIMESTAMP."""

    type_class = sa.types.TIMESTAMP
class URLField(CharField):
    """CharField whose values are validated as URLs."""

    form_class = djangofields.URLField
    default_validators = [django_validators.URLValidator()]
class BinaryField(Field):
    """Large binary column in the spirit of django's BinaryField."""

    # NOTE(review): Field itself does not read this flag (CharField does);
    # confirm it is still needed here.
    length_is_required = False
    type_class = sa.LargeBinary