operations.py 13.6 KB
Newer Older
Thaksaporn's avatar
Thaksaporn committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333
import datetime
import decimal
import uuid
from functools import lru_cache
from itertools import chain

from django.conf import settings
from django.core.exceptions import FieldError
from django.db import utils
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models import aggregates, fields
from django.db.models.expressions import Col
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.duration import duration_microseconds
from django.utils.functional import cached_property


class DatabaseOperations(BaseDatabaseOperations):
    cast_char_field_without_max_length = 'text'
    cast_data_types = {
        'DateField': 'TEXT',
        'DateTimeField': 'TEXT',
    }
    explain_prefix = 'EXPLAIN QUERY PLAN'

    def bulk_batch_size(self, fields, objs):
        """
        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
        999 variables per query.

        If there's only a single field to insert, the limit is 500
        (SQLITE_MAX_COMPOUND_SELECT).
        """
        if len(fields) == 1:
            return 500
        elif len(fields) > 1:
            return self.connection.features.max_query_params // len(fields)
        else:
            return len(objs)

    def check_expression_support(self, expression):
        bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
        bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev)
        if isinstance(expression, bad_aggregates):
            for expr in expression.get_source_expressions():
                try:
                    output_field = expr.output_field
                except FieldError:
                    # Not every subexpression has an output_field which is fine
                    # to ignore.
                    pass
                else:
                    if isinstance(output_field, bad_fields):
                        raise utils.NotSupportedError(
                            'You cannot use Sum, Avg, StdDev, and Variance '
                            'aggregations on date/time fields in sqlite3 '
                            'since date/time is saved as text.'
                        )
        if isinstance(expression, aggregates.Aggregate) and len(expression.source_expressions) > 1:
            raise utils.NotSupportedError(
                "SQLite doesn't support DISTINCT on aggregate functions "
                "accepting multiple arguments."
            )

    def date_extract_sql(self, lookup_type, field_name):
        """
        Support EXTRACT with a user-defined function django_date_extract()
        that's registered in connect(). Use single quotes because this is a
        string and could otherwise cause a collision with a field name.
        """
        return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name)

    def date_interval_sql(self, timedelta):
        return str(duration_microseconds(timedelta))

    def format_for_duration_arithmetic(self, sql):
        """Do nothing since formatting is handled in the custom function."""
        return sql

    def date_trunc_sql(self, lookup_type, field_name):
        return "django_date_trunc('%s', %s)" % (lookup_type.lower(), field_name)

    def time_trunc_sql(self, lookup_type, field_name):
        return "django_time_trunc('%s', %s)" % (lookup_type.lower(), field_name)

    def _convert_tzname_to_sql(self, tzname):
        return "'%s'" % tzname if settings.USE_TZ else 'NULL'

    def datetime_cast_date_sql(self, field_name, tzname):
        return "django_datetime_cast_date(%s, %s)" % (
            field_name, self._convert_tzname_to_sql(tzname),
        )

    def datetime_cast_time_sql(self, field_name, tzname):
        return "django_datetime_cast_time(%s, %s)" % (
            field_name, self._convert_tzname_to_sql(tzname),
        )

    def datetime_extract_sql(self, lookup_type, field_name, tzname):
        return "django_datetime_extract('%s', %s, %s)" % (
            lookup_type.lower(), field_name, self._convert_tzname_to_sql(tzname),
        )

    def datetime_trunc_sql(self, lookup_type, field_name, tzname):
        return "django_datetime_trunc('%s', %s, %s)" % (
            lookup_type.lower(), field_name, self._convert_tzname_to_sql(tzname),
        )

    def time_extract_sql(self, lookup_type, field_name):
        return "django_time_extract('%s', %s)" % (lookup_type.lower(), field_name)

    def pk_default_value(self):
        return "NULL"

    def _quote_params_for_last_executed_query(self, params):
        """
        Only for last_executed_query! Don't use this to execute SQL queries!
        """
        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
        # number of return values, default = 2000). Since Python's sqlite3
        # module doesn't expose the get_limit() C API, assume the default
        # limits are in effect and split the work in batches if needed.
        BATCH_SIZE = 999
        if len(params) > BATCH_SIZE:
            results = ()
            for index in range(0, len(params), BATCH_SIZE):
                chunk = params[index:index + BATCH_SIZE]
                results += self._quote_params_for_last_executed_query(chunk)
            return results

        sql = 'SELECT ' + ', '.join(['QUOTE(?)'] * len(params))
        # Bypass Django's wrappers and use the underlying sqlite3 connection
        # to avoid logging this query - it would trigger infinite recursion.
        cursor = self.connection.connection.cursor()
        # Native sqlite3 cursors cannot be used as context managers.
        try:
            return cursor.execute(sql, params).fetchone()
        finally:
            cursor.close()

    def last_executed_query(self, cursor, sql, params):
        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
        # pysqlite_statement_bind_parameters(self->statement, parameters, allow_8bit_chars);
        # Unfortunately there is no way to reach self->statement from Python,
        # so we quote and substitute parameters manually.
        if params:
            if isinstance(params, (list, tuple)):
                params = self._quote_params_for_last_executed_query(params)
            else:
                values = tuple(params.values())
                values = self._quote_params_for_last_executed_query(values)
                params = dict(zip(params, values))
            return sql % params
        # For consistency with SQLiteCursorWrapper.execute(), just return sql
        # when there are no parameters. See #13648 and #17158.
        else:
            return sql

    def quote_name(self, name):
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return '"%s"' % name

    def no_limit_value(self):
        return -1

    def __references_graph(self, table_name):
        query = """
        WITH tables AS (
            SELECT %s name
            UNION
            SELECT sqlite_master.name
            FROM sqlite_master
            JOIN tables ON (sql REGEXP %s || tables.name || %s)
        ) SELECT name FROM tables;
        """
        params = (
            table_name,
            r'(?i)\s+references\s+("|\')?',
            r'("|\')?\s*\(',
        )
        with self.connection.cursor() as cursor:
            results = cursor.execute(query, params)
            return [row[0] for row in results.fetchall()]

    @cached_property
    def _references_graph(self):
        # 512 is large enough to fit the ~330 tables (as of this writing) in
        # Django's test suite.
        return lru_cache(maxsize=512)(self.__references_graph)

    def sql_flush(self, style, tables, sequences, allow_cascade=False):
        if tables and allow_cascade:
            # Simulate TRUNCATE CASCADE by recursively collecting the tables
            # referencing the tables to be flushed.
            tables = set(chain.from_iterable(self._references_graph(table) for table in tables))
        sql = ['%s %s %s;' % (
            style.SQL_KEYWORD('DELETE'),
            style.SQL_KEYWORD('FROM'),
            style.SQL_FIELD(self.quote_name(table))
        ) for table in tables]
        # Note: No requirement for reset of auto-incremented indices (cf. other
        # sql_flush() implementations). Just return SQL at this point
        return sql

    def adapt_datetimefield_value(self, value):
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, 'resolve_expression'):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            if settings.USE_TZ:
                value = timezone.make_naive(value, self.connection.timezone)
            else:
                raise ValueError("SQLite backend does not support timezone-aware datetimes when USE_TZ is False.")

        return str(value)

    def adapt_timefield_value(self, value):
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, 'resolve_expression'):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            raise ValueError("SQLite backend does not support timezone-aware times.")

        return str(value)

    def get_db_converters(self, expression):
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == 'DateTimeField':
            converters.append(self.convert_datetimefield_value)
        elif internal_type == 'DateField':
            converters.append(self.convert_datefield_value)
        elif internal_type == 'TimeField':
            converters.append(self.convert_timefield_value)
        elif internal_type == 'DecimalField':
            converters.append(self.get_decimalfield_converter(expression))
        elif internal_type == 'UUIDField':
            converters.append(self.convert_uuidfield_value)
        elif internal_type in ('NullBooleanField', 'BooleanField'):
            converters.append(self.convert_booleanfield_value)
        return converters

    def convert_datetimefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.datetime):
                value = parse_datetime(value)
            if settings.USE_TZ and not timezone.is_aware(value):
                value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_datefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.date):
                value = parse_date(value)
        return value

    def convert_timefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.time):
                value = parse_time(value)
        return value

    def get_decimalfield_converter(self, expression):
        # SQLite stores only 15 significant digits. Digits coming from
        # float inaccuracy must be removed.
        create_decimal = decimal.Context(prec=15).create_decimal_from_float
        if isinstance(expression, Col):
            quantize_value = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places)

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value).quantize(quantize_value, context=expression.output_field.context)
        else:
            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value)
        return converter

    def convert_uuidfield_value(self, value, expression, connection):
        if value is not None:
            value = uuid.UUID(value)
        return value

    def convert_booleanfield_value(self, value, expression, connection):
        return bool(value) if value in (1, 0) else value

    def bulk_insert_sql(self, fields, placeholder_rows):
        return " UNION ALL ".join(
            "SELECT %s" % ", ".join(row)
            for row in placeholder_rows
        )

    def combine_expression(self, connector, sub_expressions):
        # SQLite doesn't have a ^ operator, so use the user-defined POWER
        # function that's registered in connect().
        if connector == '^':
            return 'POWER(%s)' % ','.join(sub_expressions)
        return super().combine_expression(connector, sub_expressions)

    def combine_duration_expression(self, connector, sub_expressions):
        if connector not in ['+', '-']:
            raise utils.DatabaseError('Invalid connector for timedelta: %s.' % connector)
        fn_params = ["'%s'" % connector] + sub_expressions
        if len(fn_params) > 3:
            raise ValueError('Too many params for timedelta operations.')
        return "django_format_dtdelta(%s)" % ', '.join(fn_params)

    def integer_field_range(self, internal_type):
        # SQLite doesn't enforce any integer constraints
        return (None, None)

    def subtract_temporals(self, internal_type, lhs, rhs):
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        if internal_type == 'TimeField':
            return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params
        return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params

    def insert_statement(self, ignore_conflicts=False):
        return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)