Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import functools
   5import itertools
   6import typing as t
   7from collections import deque
   8from decimal import Decimal
   9
  10import sqlglot
  11from sqlglot import Dialect, exp
  12from sqlglot.helper import first, merge_ranges, while_changing
  13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  14
  15if t.TYPE_CHECKING:
  16    from sqlglot.dialects.dialect import DialectType
  17
  18    DateTruncBinaryTransform = t.Callable[
  19        [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
  20    ]
  21
# Final means that an expression should not be simplified.
# `simplify` checks `expression.meta.get(FINAL)` and returns flagged nodes untouched.
FINAL = "final"
  24
  25
  26class UnsupportedUnit(Exception):
  27    pass
  28
  29
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved through `Dialect.get_or_raise` and forwarded to the
            DATE_TRUNC simplification rule

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as FINAL are returned untouched
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # The HAVING clause is frozen under the same rule
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; nested calls always run with root=False
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Point the (possibly new) node at the original parent before the remaining rules run
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        # Only the root call splices its result back into the tree itself
        if root:
            expression.replace(node)
        return node

    # NOTE(review): `while_changing` presumably re-applies `_simplify` until the tree
    # stops changing -- confirm in sqlglot.helper
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 112
 113
 114def catch(*exceptions):
 115    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 116
 117    def decorator(func):
 118        def wrapped(expression, *args, **kwargs):
 119            try:
 120                return func(expression, *args, **kwargs)
 121            except exceptions:
 122                return expression
 123
 124        return wrapped
 125
 126    return decorator
 127
 128
 129def rewrite_between(expression: exp.Expression) -> exp.Expression:
 130    """Rewrite x between y and z to x >= y AND x <= z.
 131
 132    This is done because comparison simplification is only done on lt/lte/gt/gte.
 133    """
 134    if isinstance(expression, exp.Between):
 135        negate = isinstance(expression.parent, exp.Not)
 136
 137        expression = exp.and_(
 138            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 139            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 140            copy=False,
 141        )
 142
 143        if negate:
 144            expression = exp.paren(expression, copy=False)
 145
 146    return expression
 147
 148
# Maps a comparison class to the class `simplify_not` builds when that comparison
# is negated, e.g. NOT (a < b) -> a >= b
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}
 157
 158
 159def simplify_not(expression):
 160    """
 161    Demorgan's Law
 162    NOT (x OR y) -> NOT x AND NOT y
 163    NOT (x AND y) -> NOT x OR NOT y
 164    """
 165    if isinstance(expression, exp.Not):
 166        this = expression.this
 167        if is_null(this):
 168            return exp.null()
 169        if this.__class__ in COMPLEMENT_COMPARISONS:
 170            return COMPLEMENT_COMPARISONS[this.__class__](
 171                this=this.this, expression=this.expression
 172            )
 173        if isinstance(this, exp.Paren):
 174            condition = this.unnest()
 175            if isinstance(condition, exp.And):
 176                return exp.paren(
 177                    exp.or_(
 178                        exp.not_(condition.left, copy=False),
 179                        exp.not_(condition.right, copy=False),
 180                        copy=False,
 181                    )
 182                )
 183            if isinstance(condition, exp.Or):
 184                return exp.paren(
 185                    exp.and_(
 186                        exp.not_(condition.left, copy=False),
 187                        exp.not_(condition.right, copy=False),
 188                        copy=False,
 189                    )
 190                )
 191            if is_null(condition):
 192                return exp.null()
 193        if always_true(this):
 194            return exp.false()
 195        if is_false(this):
 196            return exp.true()
 197        if isinstance(this, exp.Not):
 198            # double negation
 199            # NOT NOT x -> x
 200            return this.this
 201    return expression
 202
 203
 204def flatten(expression):
 205    """
 206    A AND (B AND C) -> A AND B AND C
 207    A OR (B OR C) -> A OR B OR C
 208    """
 209    if isinstance(expression, exp.Connector):
 210        for node in expression.args.values():
 211            child = node.unnest()
 212            if isinstance(child, expression.__class__):
 213                node.replace(child)
 214    return expression
 215
 216
 217def simplify_connectors(expression, root=True):
 218    def _simplify_connectors(expression, left, right):
 219        if left == right:
 220            return left
 221        if isinstance(expression, exp.And):
 222            if is_false(left) or is_false(right):
 223                return exp.false()
 224            if is_null(left) or is_null(right):
 225                return exp.null()
 226            if always_true(left) and always_true(right):
 227                return exp.true()
 228            if always_true(left):
 229                return right
 230            if always_true(right):
 231                return left
 232            return _simplify_comparison(expression, left, right)
 233        elif isinstance(expression, exp.Or):
 234            if always_true(left) or always_true(right):
 235                return exp.true()
 236            if is_false(left) and is_false(right):
 237                return exp.false()
 238            if (
 239                (is_null(left) and is_null(right))
 240                or (is_null(left) and is_false(right))
 241                or (is_false(left) and is_null(right))
 242            ):
 243                return exp.null()
 244            if is_false(left):
 245                return right
 246            if is_false(right):
 247                return left
 248            return _simplify_comparison(expression, left, right, or_=True)
 249
 250    if isinstance(expression, exp.Connector):
 251        return _flat_simplify(expression, _simplify_connectors, root)
 252    return expression
 253
 254
# "Less than" and "greater than" comparison classes, strict and non-strict
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Comparison node types handled by the comparison-based simplification rules
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Inequality classes paired with their mirrored counterparts (LT <-> GT, LTE <-> GTE)
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Operands containing one of these functions are excluded by `_simplify_comparison`
NONDETERMINISTIC = (exp.Rand, exp.Randn)
 274
 275
def _simplify_comparison(expression, left, right, or_=False):
    """Merge two comparisons that share an operand, e.g. `x < 1 AND x < 2`.

    Args:
        expression: the connector that holds `left` and `right`.
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the connector is an OR, False for AND.

    Returns:
        One of the two operands, a FALSE literal, `expression` itself when no
        non-shared operand can be extracted, or None when nothing can be merged.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands shared by both comparisons, restricted to non-constant,
        # deterministic ones
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            # Pick the remaining (non-shared) operand on each side
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # Convert both values to comparable Python objects: floats, strings or dates
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None

            # Try the pair in both orders so each rule only needs to be written once
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 334
 335
 336def remove_complements(expression, root=True):
 337    """
 338    Removing complements.
 339
 340    A AND NOT A -> FALSE
 341    A OR NOT A -> TRUE
 342    """
 343    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 344        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 345
 346        for a, b in itertools.permutations(expression.flatten(), 2):
 347            if is_complement(a, b):
 348                return complement
 349    return expression
 350
 351
 352def uniq_sort(expression, root=True):
 353    """
 354    Uniq and sort a connector.
 355
 356    C AND A AND B AND B -> A AND B AND C
 357    """
 358    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 359        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 360        flattened = tuple(expression.flatten())
 361        deduped = {gen(e): e for e in flattened}
 362        arr = tuple(deduped.items())
 363
 364        # check if the operands are already sorted, if not sort them
 365        # A AND C AND B -> A AND B AND C
 366        for i, (sql, e) in enumerate(arr[1:]):
 367            if sql < arr[i][0]:
 368                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 369                break
 370        else:
 371            # we didn't have to sort but maybe we need to dedup
 372            if len(deduped) < len(flattened):
 373                expression = result_func(*deduped.values(), copy=False)
 374
 375    return expression
 376
 377
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place with boolean literals or shared sub-expressions;
    the passed-in connector itself is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connector of interest is the opposite kind of the outer one
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's, so a is replaced by a constant
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
 416
 417
 418def propagate_constants(expression, root=True):
 419    """
 420    Propagate constants for conjunctions in DNF:
 421
 422    SELECT * FROM t WHERE a = b AND b = 5 becomes
 423    SELECT * FROM t WHERE a = 5 AND b = 5
 424
 425    Reference: https://www.sqlite.org/optoverview.html
 426    """
 427
 428    if (
 429        isinstance(expression, exp.And)
 430        and (root or not expression.same_parent)
 431        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 432    ):
 433        constant_mapping = {}
 434        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 435            if isinstance(expr, exp.EQ):
 436                l, r = expr.left, expr.right
 437
 438                # TODO: create a helper that can be used to detect nested literal expressions such
 439                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 440                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 441                    constant_mapping[l] = (id(l), r)
 442
 443        if constant_mapping:
 444            for column in find_all_in_scope(expression, exp.Column):
 445                parent = column.parent
 446                column_id, constant = constant_mapping.get(column) or (None, None)
 447                if (
 448                    column_id is not None
 449                    and id(column) != column_id
 450                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 451                ):
 452                    column.replace(constant.copy())
 453
 454    return expression
 455
 456
# Date arithmetic functions mapped to the arithmetic operator that undoes them
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# Every operation `simplify_equality` can move to the other side of a comparison
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 469
 470
def _is_number(expression: exp.Expression) -> bool:
    """Return the node's `is_number` flag; used as an operand predicate by `simplify_equality`."""
    return expression.is_number
 473
 474
 475def _is_interval(expression: exp.Expression) -> bool:
 476    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 477
 478
 479@catch(ModuleNotFoundError, UnsupportedUnit)
 480def simplify_equality(expression: exp.Expression) -> exp.Expression:
 481    """
 482    Use the subtraction and addition properties of equality to simplify expressions:
 483
 484        x + 1 = 3 becomes x = 2
 485
 486    There are two binary operations in the above expression: + and =
 487    Here's how we reference all the operands in the code below:
 488
 489          l     r
 490        x + 1 = 3
 491        a   b
 492    """
 493    if isinstance(expression, COMPARISONS):
 494        l, r = expression.left, expression.right
 495
 496        if l.__class__ not in INVERSE_OPS:
 497            return expression
 498
 499        if r.is_number:
 500            a_predicate = _is_number
 501            b_predicate = _is_number
 502        elif _is_date_literal(r):
 503            a_predicate = _is_date_literal
 504            b_predicate = _is_interval
 505        else:
 506            return expression
 507
 508        if l.__class__ in INVERSE_DATE_OPS:
 509            l = t.cast(exp.IntervalOp, l)
 510            a = l.this
 511            b = l.interval()
 512        else:
 513            l = t.cast(exp.Binary, l)
 514            a, b = l.left, l.right
 515
 516        if not a_predicate(a) and b_predicate(b):
 517            pass
 518        elif not a_predicate(b) and b_predicate(a):
 519            a, b = b, a
 520        else:
 521            return expression
 522
 523        return expression.__class__(
 524            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 525        )
 526    return expression
 527
 528
 529def simplify_literals(expression, root=True):
 530    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 531        return _flat_simplify(expression, _simplify_binary, root)
 532
 533    if isinstance(expression, exp.Neg):
 534        this = expression.this
 535        if this.is_number:
 536            value = this.name
 537            if value[0] == "-":
 538                return exp.Literal.number(value[1:])
 539            return exp.Literal.number(f"-{value}")
 540
 541    if type(expression) in INVERSE_DATE_OPS:
 542        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 543
 544    return expression
 545
 546
# Binary operators for which `_simplify_binary` gives up instead of folding a NULL operand
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 548
 549
 550def _simplify_binary(expression, a, b):
 551    if isinstance(expression, exp.Is):
 552        if isinstance(b, exp.Not):
 553            c = b.this
 554            not_ = True
 555        else:
 556            c = b
 557            not_ = False
 558
 559        if is_null(c):
 560            if isinstance(a, exp.Literal):
 561                return exp.true() if not_ else exp.false()
 562            if is_null(a):
 563                return exp.false() if not_ else exp.true()
 564    elif isinstance(expression, NULL_OK):
 565        return None
 566    elif is_null(a) or is_null(b):
 567        return exp.null()
 568
 569    if a.is_number and b.is_number:
 570        num_a = int(a.name) if a.is_int else Decimal(a.name)
 571        num_b = int(b.name) if b.is_int else Decimal(b.name)
 572
 573        if isinstance(expression, exp.Add):
 574            return exp.Literal.number(num_a + num_b)
 575        if isinstance(expression, exp.Mul):
 576            return exp.Literal.number(num_a * num_b)
 577
 578        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 579        if isinstance(expression, exp.Sub):
 580            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 581        if isinstance(expression, exp.Div):
 582            # engines have differing int div behavior so intdiv is not safe
 583            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 584                return None
 585            return exp.Literal.number(num_a / num_b)
 586
 587        boolean = eval_boolean(expression, num_a, num_b)
 588
 589        if boolean:
 590            return boolean
 591    elif a.is_string and b.is_string:
 592        boolean = eval_boolean(expression, a.this, b.this)
 593
 594        if boolean:
 595            return boolean
 596    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 597        a, b = extract_date(a), extract_interval(b)
 598        if a and b:
 599            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 600                return date_literal(a + b)
 601            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 602                return date_literal(a - b)
 603    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 604        a, b = extract_interval(a), extract_date(b)
 605        # you cannot subtract a date from an interval
 606        if a and b and isinstance(expression, exp.Add):
 607            return date_literal(a + b)
 608    elif _is_date_literal(a) and _is_date_literal(b):
 609        if isinstance(expression, exp.Predicate):
 610            a, b = extract_date(a), extract_date(b)
 611            boolean = eval_boolean(expression, a, b)
 612            if boolean:
 613                return boolean
 614
 615    return None
 616
 617
def simplify_parens(expression):
    """Return the content of a Paren node when the parentheses are redundant.

    Non-Paren nodes, parenthesized SELECTs and parentheses that do not match any of
    the conditions below are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if not isinstance(this, exp.Select) and (
        # the parent is neither a condition nor a binary operation
        not isinstance(parent, (exp.Condition, exp.Binary))
        # nested parentheses
        or isinstance(parent, exp.Paren)
        # non-binary content, unless it is a NOT / IS sitting under a predicate
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        # a predicate whose parent is not itself a predicate
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        # (a + b) + c and (a * b) * c
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        # a multiplication nested in an addition or subtraction
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
 640
 641
 642def _is_nonnull_constant(expression: exp.Expression) -> bool:
 643    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 644
 645
 646def _is_constant(expression: exp.Expression) -> bool:
 647    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 648
 649
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    A COALESCE with no fallback arguments, or whose first argument is a non-null
    constant, is replaced by its first argument. A comparison of a COALESCE against
    a constant is split into one branch per NULL-ness of the leading arguments,
    cutting the COALESCE at its first constant argument.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Locate the COALESCE on either side of the comparison
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Keep only the fallback arguments that precede the constant; `arg` is reused below
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
 706
 707
# Concatenation node types handled by `simplify_concat`
CONCATS = (exp.Concat, exp.DPipe)
 709
 710
 711def simplify_concat(expression):
 712    """Reduces all groups that contain string literals by concatenating them."""
 713    if not isinstance(expression, CONCATS) or (
 714        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 715        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 716    ):
 717        return expression
 718
 719    if isinstance(expression, exp.ConcatWs):
 720        sep_expr, *expressions = expression.expressions
 721        sep = sep_expr.name
 722        concat_type = exp.ConcatWs
 723        args = {}
 724    else:
 725        expressions = expression.expressions
 726        sep = ""
 727        concat_type = exp.Concat
 728        args = {
 729            "safe": expression.args.get("safe"),
 730            "coalesce": expression.args.get("coalesce"),
 731        }
 732
 733    new_args = []
 734    for is_string_group, group in itertools.groupby(
 735        expressions or expression.flatten(), lambda e: e.is_string
 736    ):
 737        if is_string_group:
 738            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 739        else:
 740            new_args.extend(group)
 741
 742    if len(new_args) == 1 and new_args[0].is_string:
 743        return new_args[0]
 744
 745    if concat_type is exp.ConcatWs:
 746        new_args = [sep_expr] + new_args
 747
 748    return concat_type(expressions=new_args, **args)
 749
 750
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, the first branch whose condition is always true is returned, branches
    whose condition is always false are removed, and once no branch is left the
    default (or NULL) is returned. A standalone IF is replaced by the matching branch.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                # Drop the dead branch
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        # IF nodes that are branches of a CASE are handled by the CASE logic above
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
 775
 776
 777def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 778    """
 779    Reduces a prefix check to either TRUE or FALSE if both the string and the
 780    prefix are statically known.
 781
 782    Example:
 783        >>> from sqlglot import parse_one
 784        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 785        'TRUE'
 786    """
 787    if (
 788        isinstance(expression, exp.StartsWith)
 789        and expression.this.is_string
 790        and expression.expression.is_string
 791    ):
 792        return exp.convert(expression.name.startswith(expression.expression.name))
 793
 794    return expression
 795
 796
# (start, end) pair of dates; `_datetrunc_range` documents it as the half-open range [min, max)
DateRange = t.Tuple[datetime.date, datetime.date]
 798
 799
 800def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 801    """
 802    Get the date range for a DATE_TRUNC equality comparison:
 803
 804    Example:
 805        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 806    Returns:
 807        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 808    """
 809    floor = date_floor(date, unit, dialect)
 810
 811    if date != floor:
 812        # This will always be False, except for NULL values.
 813        return None
 814
 815    return floor, floor + interval(unit)
 816
 817
 818def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 819    """Get the logical expression for a date range"""
 820    return exp.and_(
 821        left >= date_literal(drange[0]),
 822        left < date_literal(drange[1]),
 823        copy=False,
 824    )
 825
 826
 827def _datetrunc_eq(
 828    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
 829) -> t.Optional[exp.Expression]:
 830    drange = _datetrunc_range(date, unit, dialect)
 831    if not drange:
 832        return None
 833
 834    return _datetrunc_eq_expression(left, drange)
 835
 836
 837def _datetrunc_neq(
 838    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
 839) -> t.Optional[exp.Expression]:
 840    drange = _datetrunc_range(date, unit, dialect)
 841    if not drange:
 842        return None
 843
 844    return exp.and_(
 845        left < date_literal(drange[0]),
 846        left >= date_literal(drange[1]),
 847        copy=False,
 848    )
 849
 850
# Rewrites for `DATE_TRUNC(unit, l) <op> dt`, keyed by comparison class. Each callable is
# invoked as (l, dt, u, d) = (truncated operand, date value, unit, dialect) and returns a
# comparison on `l` itself; the EQ / NEQ helpers may return None.
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, dt, u, d: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
# Node classes that `simplify_datetrunc` attempts to rewrite
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
# Truncation function classes recognized by the DATE_TRUNC rules
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 862
 863
 864def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 865    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
 866
 867
 868@catch(ModuleNotFoundError, UnsupportedUnit)
 869def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
 870    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
 871    comparison = expression.__class__
 872
 873    if isinstance(expression, DATETRUNCS):
 874        date = extract_date(expression.this)
 875        if date and expression.unit:
 876            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
 877    elif comparison not in DATETRUNC_COMPARISONS:
 878        return expression
 879
 880    if isinstance(expression, exp.Binary):
 881        l, r = expression.left, expression.right
 882
 883        if not _is_datetrunc_predicate(l, r):
 884            return expression
 885
 886        l = t.cast(exp.DateTrunc, l)
 887        unit = l.unit.name.lower()
 888        date = extract_date(r)
 889
 890        if not date:
 891            return expression
 892
 893        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
 894    elif isinstance(expression, exp.In):
 895        l = expression.this
 896        rs = expression.expressions
 897
 898        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
 899            l = t.cast(exp.DateTrunc, l)
 900            unit = l.unit.name.lower()
 901
 902            ranges = []
 903            for r in rs:
 904                date = extract_date(r)
 905                if not date:
 906                    return expression
 907                drange = _datetrunc_range(date, unit, dialect)
 908                if drange:
 909                    ranges.append(drange)
 910
 911            if not ranges:
 912                return expression
 913
 914            ranges = merge_ranges(ranges)
 915
 916            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
 917
 918    return expression
 919
 920
 921def sort_comparison(expression: exp.Expression) -> exp.Expression:
 922    if expression.__class__ in COMPLEMENT_COMPARISONS:
 923        l, r = expression.this, expression.expression
 924        l_column = isinstance(l, exp.Column)
 925        r_column = isinstance(r, exp.Column)
 926        l_const = _is_constant(l)
 927        r_const = _is_constant(r)
 928
 929        if (l_column and not r_column) or (r_const and not l_const):
 930            return expression
 931        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
 932            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
 933                this=r, expression=l
 934            )
 935    return expression
 936
 937
# (side, kind) pairs of joins that remove_where_true may turn into a CROSS join
# once their ON condition is always true.
# A CROSS join yields an empty result if the right table is empty, so only
# certain join types can be rewritten this way.
# In other words: LEFT JOIN x ON TRUE != CROSS JOIN x
# NOTE(review): RIGHT [OUTER] JOIN ... ON TRUE appears to differ from CROSS JOIN
# in the same way when the *left* input is empty -- confirm these entries are intended.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
 947
 948
 949def remove_where_true(expression):
 950    for where in expression.find_all(exp.Where):
 951        if always_true(where.this):
 952            where.pop()
 953    for join in expression.find_all(exp.Join):
 954        if (
 955            always_true(join.args.get("on"))
 956            and not join.args.get("using")
 957            and not join.args.get("method")
 958            and (join.side, join.kind) in JOINS
 959        ):
 960            join.args["on"].pop()
 961            join.set("side", None)
 962            join.set("kind", "CROSS")
 963
 964
 965def always_true(expression):
 966    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
 967        expression, exp.Literal
 968    )
 969
 970
 971def always_false(expression):
 972    return is_false(expression) or is_null(expression)
 973
 974
 975def is_complement(a, b):
 976    return isinstance(b, exp.Not) and b.this == a
 977
 978
 979def is_false(a: exp.Expression) -> bool:
 980    return type(a) is exp.Boolean and not a.this
 981
 982
def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is exactly an `exp.Null` node (subclasses do not count)."""
    return type(a) is exp.Null
 985
 986
 987def eval_boolean(expression, a, b):
 988    if isinstance(expression, (exp.EQ, exp.Is)):
 989        return boolean_literal(a == b)
 990    if isinstance(expression, exp.NEQ):
 991        return boolean_literal(a != b)
 992    if isinstance(expression, exp.GT):
 993        return boolean_literal(a > b)
 994    if isinstance(expression, exp.GTE):
 995        return boolean_literal(a >= b)
 996    if isinstance(expression, exp.LT):
 997        return boolean_literal(a < b)
 998    if isinstance(expression, exp.LTE):
 999        return boolean_literal(a <= b)
1000    return None
1001
1002
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`.

    Args:
        value: a datetime (its date part is used), a date (returned as is), or
            an ISO-8601 string.

    Returns:
        The date, or None when `value` cannot be interpreted as one.
    """
    # datetime must be checked first: it is a subclass of date.
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: not an ISO string; TypeError: not a string at all.
        return None
1012
1013
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`.

    Args:
        value: a datetime (returned as is), a date (promoted to midnight of that
            day), or an ISO-8601 string.

    Returns:
        The datetime, or None when `value` cannot be interpreted as one.
    """
    # datetime must be checked first: it is a subclass of date.
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: not an ISO string; TypeError: not a string at all.
        return None
1023
1024
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python date/datetime matching the data type `to`.

    Args:
        value: the raw value to convert; any falsy value yields None.
        to: target data type. DATE yields a date, any other temporal type a datetime.

    Returns:
        The converted value, or None when `value` is falsy, cannot be converted,
        or `to` is not a temporal type.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1033
1034
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Extract the constant date/datetime denoted by a cast of a literal.

    Handles `CAST(<literal> AS <type>)` and format-less `TsOrDsToDate(<literal>)`,
    including nested casts, which are resolved from the inside out.

    Returns:
        The date or datetime, or None when `cast` is not such an expression or
        its literal cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: resolve the inner value first, then re-cast it to `to`.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
1050
1051
def _is_date_literal(expression: exp.Expression) -> bool:
    """Return whether `extract_date` can resolve `expression` to a constant date."""
    return extract_date(expression) is not None
1054
1055
def extract_interval(expression):
    """Convert an interval expression into the delta built by `interval`.

    Returns None when the amount is not an integer, the unit is unsupported, or
    the module needed by `interval` is missing.
    """
    try:
        amount = int(expression.name)
        unit_name = expression.text("unit").lower()
        delta = interval(unit_name, amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
    return delta
1063
1064
def date_literal(date):
    """Build `CAST('<date>' AS DATETIME)` for datetimes and `CAST('<date>' AS DATE)` otherwise."""
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
1074
1075
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` times `unit`.

    Raises:
        ModuleNotFoundError: if `dateutil` is not installed.
        UnsupportedUnit: if `unit` is not one of the supported unit names.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of it make up one unit)
    per_unit = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    if unit not in per_unit:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = per_unit[unit]
    return relativedelta(**{keyword: factor * n})
1097
1098
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Return the start of the `unit` period that contains `d`.

    Args:
        d: the date to truncate.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: supplies `WEEK_OFFSET`, the shift of the first day of the week
            relative to Monday (0 means weeks start on Monday).

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # date.weekday() is 0 for Monday. The modulo keeps the step inside
        # [0, 6] days, so the result is never after `d` and never more than six
        # days before it, whatever the sign of WEEK_OFFSET.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1120
1121
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Return `d` if it already sits on a `unit` boundary, else the next boundary."""
    period_start = date_floor(d, unit, dialect)
    return d if period_start == d else period_start + interval(unit)
1129
1130
def boolean_literal(condition):
    """Return the TRUE expression when `condition` is truthy, else the FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()
1133
1134
def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-simplify the flattened operands of a chained binary expression.

    Args:
        expression: the expression whose operands (`expression.flatten(unnest=False)`)
            are combined.
        simplifier: called as `simplifier(expression, a, b)`; a truthy result that
            is not `expression` itself replaces the pair `a`, `b`.
        root: when False, the work is only done if `expression.same_parent` is falsy.

    Returns:
        A rebuilt chain of `expression.__class__` nodes if at least one pair was
        merged, otherwise `expression` itself.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            # Try to merge `a` with each operand still waiting in the queue.
            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # Merged: drop `b`, and put the result at the front so it is
                    # retried against the remaining operands. The queue is only
                    # mutated right before leaving the loop.
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # `a` could not be merged with anything; keep it as is.
                operands.append(a)

        # Only rebuild when something was merged (fewer operands than at the start).
        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
1159
1160
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    generator = Gen()
    return generator.gen(expression)
1168
1169
class Gen:
    """Minimal, stack-based pseudo-SQL generator.

    `gen` walks an expression iteratively. Handlers named `<expression key>_sql`
    push the pieces of a node onto `self.stack` in reverse output order (the
    stack is popped from the end), and plain strings popped from the stack are
    appended to `self.sqls`.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo-SQL string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # No dedicated handler: emit the upper-cased key followed by
                    # the node's non-None args (pushed by _args).
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator that now sits before the first element. Only
                # do so if something was pushed: a list holding nothing but None
                # must not pop an unrelated stack entry.
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator pushed after the first part.
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator pushed after the first part.
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this <op> expression` (in reverse, since the stack pops from the end)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `<op>this`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `NAME(arg,...)` built from the function's sql name and all its arg values."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push `:name value` pairs for the non-None args of `node`.

        Args:
            node: the expression whose args are emitted.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one pair was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised by the date helpers when asked to handle a unit they do not support."""

    pass

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None):
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved with `Dialect.get_or_raise` and passed to `simplify_datetrunc`

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as FINAL are returned untouched.
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains a group by key.
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for the HAVING clause.
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Give the (possibly new) node the original parent before the remaining
        # rules run; simplify_parens, for one, reads `expression.parent`.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # Re-run the rules through `while_changing`, then clean up always-true
    # WHERE clauses and join conditions.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised

    When one of `exceptions` is raised, the decorated function returns its first
    argument (the expression) unchanged. The wrapper keeps the decorated
    function's name and docstring.
    """

    def decorator(func):
        # Without functools.wraps the decorated function would show up as
        # `wrapped`, with no docstring, in tracebacks and generated docs.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that skips a simplification function, returning the expression unchanged, if any of `exceptions` is raised.

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    The result is parenthesized when the BETWEEN sits directly under a NOT.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)

    lower_check = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_check = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_check, upper_check, copy=False)

    return exp.paren(rewritten, copy=False) if under_not else rewritten

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    De Morgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT over NULL, comparisons, constants and another NOT.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    operand_class = operand.__class__
    if operand_class in COMPLEMENT_COMPARISONS:
        complement_class = COMPLEMENT_COMPARISONS[operand_class]
        return complement_class(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()

        if isinstance(condition, exp.And):
            negated = exp.or_(
                exp.not_(condition.left, copy=False),
                exp.not_(condition.right, copy=False),
                copy=False,
            )
            return exp.paren(negated)

        if isinstance(condition, exp.Or):
            negated = exp.and_(
                exp.not_(condition.left, copy=False),
                exp.not_(condition.right, copy=False),
                copy=False,
            )
            return exp.paren(negated)

        if is_null(condition):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    Operands of a connector that unnest to the same connector class are replaced
    in place by their unnested form. Returns `expression` itself.
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            # Only unwrap when the inner connector has the same class as the outer one.
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """Fold constant operands of AND / OR connectors.

    Operand pairs are combined through `_flat_simplify`; pairs that are not
    constant-foldable are handed to `_simplify_comparison`. Expressions that are
    not connectors are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # x AND x -> x, x OR x -> x
        if left == right:
            return left
        if isinstance(expression, exp.And):
            # FALSE is checked before NULL: FALSE AND NULL -> FALSE
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            # TRUE AND x -> x
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            # TRUE is checked first: TRUE OR NULL -> TRUE
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            # With only NULL / FALSE operands the result is NULL.
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            # FALSE OR x -> x
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Keyed by the pseudo sql of each operand, which drops duplicates.
    by_sql = {gen(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later < earlier for earlier, later in zip(keys, keys[1:]))

    if out_of_order:
        ordered = sorted(by_sql.items())
        expression = build(*(operand for _, operand in ordered), copy=False)
    elif len(by_sql) < len(operands):
        # already sorted, but duplicates still have to go
        expression = build(*by_sql.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place, partly with TRUE / FALSE placeholders;
    `expression` itself is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connector of interest is the opposite of the outer one.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # `aa` is NOT b: swap it for the nested connector's neutral value
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # every operand of `b` also occurs in `a` (strict subset), so
                    # `a` is redundant: swap it for the outer connector's neutral value
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # column -> (id of the column node that defines the constant, the literal)
    constant_mapping = {}
    for candidate in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
        if not isinstance(candidate, exp.EQ):
            continue

        lhs, rhs = candidate.left, candidate.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            constant_mapping[lhs] = (id(lhs), rhs)

    if not constant_mapping:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        known = constant_mapping.get(column)
        if not known:
            continue

        defining_id, constant = known
        # Leave the column of the defining `col = literal` equality alone.
        if id(column) == defining_id:
            continue

        # `col IS NULL` is kept as is.
        parent = column.parent
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
119        def wrapped(expression, *args, **kwargs):
120            try:
121                return func(expression, *args, **kwargs)
122            except exceptions:
123                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: `+` and `=`. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Simplify literal-only sub-expressions.

    Non-connector binary expressions are delegated to `_flat_simplify` with
    `_simplify_binary`; a negated numeric literal is folded into a single number
    literal; date operations listed in `INVERSE_DATE_OPS` are passed to
    `_simplify_binary` together with their interval. Anything else is returned
    unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        # Negating an already negative literal drops the sign instead of stacking it
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    if type(expression) not in INVERSE_DATE_OPS:
        return expression

    return _simplify_binary(expression, expression.this, expression.interval()) or expression
def simplify_parens(expression):
def simplify_parens(expression):
    """
    Drop a `Paren` wrapper when the checks below consider it redundant.

    Returns the wrapped expression if the parentheses can go, otherwise the
    original `Paren` node. Non-`Paren` inputs and parenthesized SELECTs are
    returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    outer_is_predicate = isinstance(outer, exp.Predicate)
    inner_needs_grouping = isinstance(inner, exp.Binary) or (
        isinstance(inner, (exp.Not, exp.Is)) and outer_is_predicate
    )

    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not inner_needs_grouping
        or (isinstance(inner, exp.Predicate) and not outer_is_predicate)
        # (a + b) + c
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        # (a * b) * c, (a * b) + c, (a * b) - c
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )

    return inner if removable else expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """
    Simplify COALESCE, either standalone or as one operand of a comparison.

    - `COALESCE(x)` (no fallback arguments), or a COALESCE whose first argument
      is a non-null constant, is replaced by its first argument.
    - A comparison between a COALESCE and a constant is split on the first
      constant fallback argument, e.g. `COALESCE(x, 1) = 2` becomes
      `(NOT x IS NULL AND x = 2) OR (x IS NULL AND 1 = 2)`.

    The operand order of the original comparison is preserved in the fallback
    branch, so `1 < COALESCE(x, 5)` yields `1 < 5` rather than `5 < 1`.

    Note: the COALESCE node inside `expression` is truncated in place before
    the rewritten expression is built.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Everything from the first constant onwards can never be reached once the
    # earlier arguments are known to be NULL, so drop it (mutates the tree).
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Substitute the constant for the COALESCE on the side it originally occupied;
    # swapping sides would change the meaning of <, <=, > and >=.
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        extra_args = {}
    else:
        operands = expression.expressions
        sep = ""
        concat_type = exp.Concat
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    # Collapse each run of adjacent string literals into one literal, keeping
    # every other operand as is. `flatten()` is the fallback when the node has
    # no `expressions` list.
    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in runs:
        if all_strings:
            merged.append(exp.Literal.string(sep.join(literal.name for literal in run)))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_concat_ws:
        merged = [sep_expr] + merged

    return concat_type(expressions=merged, **extra_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches are visited in order: the result of the first branch whose
    condition is always true is returned, and branches whose condition is always
    false are removed. If every branch is removed, the default (or NULL) is
    returned. A standalone IF is replaced by its true/false result when its
    condition is statically known. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: `case.pop()` below may remove the branch from the
        # live `ifs` list, which would make a direct iteration skip the next WHEN.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    # Both operands must be string literals for the check to be decidable here
    if not (expression.this.is_string and expression.expression.is_string):
        return expression

    prefix = expression.expression.name
    return exp.convert(expression.name.startswith(prefix))

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.In'>}
def simplify_datetrunc(expression, *args, **kwargs):
119        def wrapped(expression, *args, **kwargs):
120            try:
121                return func(expression, *args, **kwargs)
122            except exceptions:
123                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right; when neither rule
    decides, operands are ordered by their `gen` string. Swapping operands also
    swaps the operator via `INVERSE_COMPARISONS`. Expressions whose class is not
    in `COMPLEMENT_COMPARISONS` are returned unchanged.
    """
    comparison = expression.__class__
    if comparison not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_sorted = (lhs_is_column and not rhs_is_column) or (rhs_is_const and not lhs_is_const)
    if already_sorted:
        return expression

    # `gen` is only consulted when the column/constant rules are inconclusive
    must_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if not must_swap:
        return expression

    flipped = INVERSE_COMPARISONS.get(comparison, comparison)
    return flipped(this=rhs, expression=lhs)
JOINS = {('', 'INNER'), ('RIGHT', ''), ('RIGHT', 'OUTER'), ('', '')}
def remove_where_true(expression):
def remove_where_true(expression):
    """
    Remove conditions that are always true, modifying `expression` in place.

    WHERE clauses whose condition is always true are dropped. Joins whose ON
    condition is always true, that have no USING / method, and whose
    (side, kind) pair is listed in `JOINS` are turned into CROSS joins.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """
    Return whether `expression` is statically known to be truthy.

    This holds for the boolean literal TRUE and for numeric literals other than
    zero. A zero literal (e.g. `WHERE 0`) and string literals are not treated as
    true, since a zero is falsy in dialects that accept numbers as conditions and
    the truthiness of a string is dialect-specific.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if isinstance(expression, exp.Literal) and expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # decimal.InvalidOperation: not a plain decimal number, so make no claim
            return False

    return False
def always_false(expression):
def always_false(expression):
    """Return whether `expression` is the boolean literal FALSE or a NULL."""
    if is_false(expression):
        return True
    return is_null(expression)
def is_complement(a, b):
def is_complement(a, b):
    """Return whether `b` is `NOT a`, i.e. a `Not` node wrapping something equal to `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a `Boolean` node (not a subclass) holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a `Null` node (subclasses do not count)."""
    return type(a) is exp.Null
def eval_boolean(expression, a, b):
 988def eval_boolean(expression, a, b):
 989    if isinstance(expression, (exp.EQ, exp.Is)):
 990        return boolean_literal(a == b)
 991    if isinstance(expression, exp.NEQ):
 992        return boolean_literal(a != b)
 993    if isinstance(expression, exp.GT):
 994        return boolean_literal(a > b)
 995    if isinstance(expression, exp.GTE):
 996        return boolean_literal(a >= b)
 997    if isinstance(expression, exp.LT):
 998        return boolean_literal(a < b)
 999    if isinstance(expression, exp.LTE):
1000        return boolean_literal(a <= b)
1001    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`.

    Datetimes are truncated to their date, dates are returned as is, and any
    other value is parsed as an ISO-format string. Returns None when the
    string cannot be parsed.
    """
    if isinstance(value, datetime.date):
        # datetime.datetime is a subclass of datetime.date
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None

    return parsed.date()
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`.

    Datetimes are returned as is, dates become midnight of that day, and any
    other value is parsed as an ISO-format string. Returns None when the
    string cannot be parsed.
    """
    if isinstance(value, datetime.datetime):
        return value

    if isinstance(value, datetime.date):
        return datetime.datetime.combine(value, datetime.time())

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        parsed = None

    return parsed
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert `value` to the Python temporal type matching the SQL type `to`.

    Args:
        value: the value to convert; falsy values yield None.
        to: the target data type.

    Returns:
        A date when `to` is DATE, a datetime when `to` is any other temporal
        type, and None for non-temporal types or unparsable values.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Extract the Python date/datetime represented by a cast of a literal.

    Handles `CAST(<literal> AS <type>)`, format-less `TsOrDsToDate` calls
    (treated as a cast to DATE), and nestings of the two.

    Returns:
        The converted value, or None when `cast` is not one of the supported
        shapes or its value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: resolve the inner one first, then convert its result
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """
    Build the time delta described by an interval expression.

    Returns None when the amount is not an integer, the unit is unsupported,
    or the module `interval` imports is not installed.
    """
    try:
        amount = int(expression.name)
        unit_name = expression.text("unit").lower()
        delta = interval(unit_name, amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        delta = None
    return delta
def date_literal(date):
def date_literal(date):
    """Wrap `date` in a string literal cast to DATETIME (for datetimes) or DATE (otherwise)."""
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """
    Return a `relativedelta` spanning `n` times the given unit.

    Supported units: year, quarter, month, week, day, hour, minute, second.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of it make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = spans[unit]
    return relativedelta(**{keyword: factor * n})
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` down to the start of the year, quarter, month, week or day.

    Args:
        d: the date to floor.
        unit: one of "year", "quarter", "month", "week", "day".
        dialect: only `dialect.WEEK_OFFSET` is read, to locate the first day of the week.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter containing d: 1, 4, 7 or 10
        quarter_start = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=quarter_start, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday .. 6 for Sunday. WEEK_OFFSET shifts the week start
        # (presumably 0 = Monday, -1 = Sunday; confirm against the Dialect definition).
        # The modulo keeps the step within 0..6 days back, so a date that already is the
        # week start maps to itself instead of a full week earlier.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary; a date already on a boundary is returned as is."""
    floored = date_floor(d, unit, dialect)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return the TRUE expression when `condition` is truthy, otherwise the FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: Any) -> str:
def gen(expression: t.Any) -> str:
    """Render `expression` with the lightweight pseudo-SQL generator `Gen`.

    The resulting strings are meant for sorting and de-duplicating expressions
    during optimization, where invoking the real SQL generator would be too
    expensive.
    """
    generator = Gen()
    return generator.gen(expression)

Simple pseudo-SQL generator for quickly generating sortable and unique strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

class Gen:
class Gen:
    """
    Minimal, non-recursive pseudo-SQL generator.

    `gen` drives an explicit work stack: a node is popped, its handler pushes the
    pieces it consists of back onto the stack in reverse emission order (the
    stack is LIFO), and plain values are appended to `sqls`. Handlers are looked
    up by name as `<expression key>_sql`; functions without a handler go through
    `_function` and everything else is rendered as its key followed by its args.
    """

    def __init__(self):
        # Pending work items: expressions, lists of items, or plain values
        self.stack: t.List[t.Any] = []
        # Output fragments, joined into the final string by `gen`
        self.sqls: t.List[str] = []

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` and return the resulting string."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    # `_args` pushes first, so the key pushed here is emitted before the args
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push items in reverse, each followed by a comma, then drop the
                # comma that would otherwise lead the output
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                if node:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        # Emitted as: this AS alias
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Emit `NAME(args)`; quoted identifier names keep their case inside double quotes."""
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        # Emitted as: this BETWEEN low AND high
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        # Emitted as: this[expressions]
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Dotted path of the column parts; the final pop drops the leading "."
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Type name first, then every arg after the first declared one
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        # Emitted as: this IN (<args after the first declared one>)
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Emitted as: (this), then the alias if any, then the args after the first two
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Emitted as: dotted table parts, then the alias if any, then the args after the first four
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        # Emitted as: " AS " this, followed by (columns) when columns are present
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Queue `this <op> expression` (pushed in reverse because the stack is LIFO)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Queue `<op>this`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Queue `SQL_NAME(arg values...)` for a function without a dedicated handler."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Queue the set args of `node` as `:name,value` pairs.

        Args:
            node: the expression whose args are rendered.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one arg was queued, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: sqlglot.expressions.Expression) -> str:
    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` into a compact pseudo-SQL string.

        Works iteratively on `self.stack`: a popped expression is handed to its
        handler, which pushes its pieces back in reverse emission order; lists
        are expanded into comma-separated items; any other non-None value is
        appended to `self.sqls`, which is joined into the result.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                # Prefer a dedicated `<key>_sql` handler when one exists
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    # `_args` pushes first, so the key pushed here is emitted before the args
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push items in reverse, each followed by a comma, then drop the
                # comma that would otherwise lead the output
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                if node:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)
def add_sql(self, e: sqlglot.expressions.Add) -> None:
    def add_sql(self, e: exp.Add) -> None:
        """Queue `this + expression`."""
        self._binary(e, " + ")
def alias_sql(self, e: sqlglot.expressions.Alias) -> None:
    def alias_sql(self, e: exp.Alias) -> None:
        """Queue `this AS alias` (pushed in reverse because the stack is LIFO)."""
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )
def and_sql(self, e: sqlglot.expressions.And) -> None:
    def and_sql(self, e: exp.And) -> None:
        """Queue `this AND expression`."""
        self._binary(e, " AND ")
def anonymous_sql(self, e: sqlglot.expressions.Anonymous) -> None:
    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Queue `NAME(expressions)` for an anonymous function call.

        The name is upper-cased unless it is a quoted identifier, in which case it
        is kept as is and wrapped in double quotes.

        Raises:
            ValueError: if `e.this` is neither a str nor an Identifier.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        # Pushed in reverse: emitted as name ( expressions )
        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )
def between_sql(self, e: sqlglot.expressions.Between) -> None:
    def between_sql(self, e: exp.Between) -> None:
        """Queue `this BETWEEN low AND high` (pushed in reverse because the stack is LIFO)."""
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )
def boolean_sql(self, e: sqlglot.expressions.Boolean) -> None:
    def boolean_sql(self, e: exp.Boolean) -> None:
        """Queue `TRUE` or `FALSE` depending on the truthiness of `e.this`."""
        self.stack.append("TRUE" if e.this else "FALSE")
def bracket_sql(self, e: sqlglot.expressions.Bracket) -> None:
    def bracket_sql(self, e: exp.Bracket) -> None:
        """Queue `this[expressions]` (pushed in reverse because the stack is LIFO)."""
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )
def column_sql(self, e: sqlglot.expressions.Column) -> None:
    def column_sql(self, e: exp.Column) -> None:
        """Queue the column's parts joined by dots."""
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." pushed after the first part; it would otherwise lead the output
        self.stack.pop()
def datatype_sql(self, e: sqlglot.expressions.DataType) -> None:
    def datatype_sql(self, e: exp.DataType) -> None:
        """Queue the type name followed by the args after the first declared one."""
        self._args(e, 1)
        # Pushed last, so the type name is emitted before the args
        self.stack.append(f"{e.this.name} ")
def div_sql(self, e: sqlglot.expressions.Div) -> None:
    def div_sql(self, e: exp.Div) -> None:
        """Queue `this / expression`."""
        self._binary(e, " / ")
def dot_sql(self, e: sqlglot.expressions.Dot) -> None:
    def dot_sql(self, e: exp.Dot) -> None:
        """Queue `this.expression`."""
        self._binary(e, ".")
def eq_sql(self, e: sqlglot.expressions.EQ) -> None:
    def eq_sql(self, e: exp.EQ) -> None:
        """Queue `this = expression`."""
        self._binary(e, " = ")
def from_sql(self, e: sqlglot.expressions.From) -> None:
    def from_sql(self, e: exp.From) -> None:
        """Queue `FROM this`."""
        self.stack.extend((e.this, "FROM "))
def gt_sql(self, e: sqlglot.expressions.GT) -> None:
    def gt_sql(self, e: exp.GT) -> None:
        """Queue `this > expression`."""
        self._binary(e, " > ")
def gte_sql(self, e: sqlglot.expressions.GTE) -> None:
    def gte_sql(self, e: exp.GTE) -> None:
        """Queue `this >= expression`."""
        self._binary(e, " >= ")
def identifier_sql(self, e: sqlglot.expressions.Identifier) -> None:
    def identifier_sql(self, e: exp.Identifier) -> None:
        """Queue the identifier name, wrapped in double quotes when it is quoted."""
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)
def ilike_sql(self, e: sqlglot.expressions.ILike) -> None:
    def ilike_sql(self, e: exp.ILike) -> None:
        """Queue `this ILIKE expression`."""
        self._binary(e, " ILIKE ")
def in_sql(self, e: sqlglot.expressions.In) -> None:
    def in_sql(self, e: exp.In) -> None:
        """Queue `this IN (...)`, where the parentheses hold the args after the first declared one."""
        # Closing paren is pushed first so that it is emitted last
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )
def intdiv_sql(self, e: sqlglot.expressions.IntDiv) -> None:
    def intdiv_sql(self, e: exp.IntDiv) -> None:
        """Queue `this DIV expression`."""
        self._binary(e, " DIV ")
def is_sql(self, e: sqlglot.expressions.Is) -> None:
    def is_sql(self, e: exp.Is) -> None:
        """Queue `this IS expression`."""
        self._binary(e, " IS ")
def like_sql(self, e: sqlglot.expressions.Like) -> None:
    def like_sql(self, e: exp.Like) -> None:
        """Queue `this Like expression` (the operator is emitted in mixed case)."""
        self._binary(e, " Like ")
def literal_sql(self, e: sqlglot.expressions.Literal) -> None:
    def literal_sql(self, e: exp.Literal) -> None:
        """Queue the literal value, wrapped in single quotes when it is a string."""
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)
def lt_sql(self, e: sqlglot.expressions.LT) -> None:
    def lt_sql(self, e: exp.LT) -> None:
        """Queue `this < expression`."""
        self._binary(e, " < ")
def lte_sql(self, e: sqlglot.expressions.LTE) -> None:
    def lte_sql(self, e: exp.LTE) -> None:
        """Queue `this <= expression`."""
        self._binary(e, " <= ")
def mod_sql(self, e: sqlglot.expressions.Mod) -> None:
    def mod_sql(self, e: exp.Mod) -> None:
        """Queue `this % expression`."""
        self._binary(e, " % ")
def mul_sql(self, e: sqlglot.expressions.Mul) -> None:
    def mul_sql(self, e: exp.Mul) -> None:
        """Queue `this * expression`."""
        self._binary(e, " * ")
def neg_sql(self, e: sqlglot.expressions.Neg) -> None:
    def neg_sql(self, e: exp.Neg) -> None:
        """Queue `-this`."""
        self._unary(e, "-")
def neq_sql(self, e: sqlglot.expressions.NEQ) -> None:
    def neq_sql(self, e: exp.NEQ) -> None:
        """Queue `this <> expression`."""
        self._binary(e, " <> ")
def not_sql(self, e: sqlglot.expressions.Not) -> None:
    def not_sql(self, e: exp.Not) -> None:
        """Queue `NOT this`."""
        self._unary(e, "NOT ")
def null_sql(self, e: sqlglot.expressions.Null) -> None:
    def null_sql(self, e: exp.Null) -> None:
        """Queue `NULL`."""
        self.stack.append("NULL")
def or_sql(self, e: sqlglot.expressions.Or) -> None:
    def or_sql(self, e: exp.Or) -> None:
        """Queue `this OR expression`."""
        self._binary(e, " OR ")
def paren_sql(self, e: sqlglot.expressions.Paren) -> None:
    def paren_sql(self, e: exp.Paren) -> None:
        """Queue `(this)` (pushed in reverse because the stack is LIFO)."""
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )
def sub_sql(self, e: sqlglot.expressions.Sub) -> None:
    def sub_sql(self, e: exp.Sub) -> None:
        """Queue `this - expression`."""
        self._binary(e, " - ")
def subquery_sql(self, e: sqlglot.expressions.Subquery) -> None:
    def subquery_sql(self, e: exp.Subquery) -> None:
        """Queue `(this)`, then the alias if any, then the args after the first two declared ones."""
        # Pushed in reverse emission order: trailing args first, parenthesized query last
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))
def table_sql(self, e: sqlglot.expressions.Table) -> None:
    def table_sql(self, e: exp.Table) -> None:
        """Queue the dotted table parts, then the alias if any, then the args after the first four declared ones."""
        # Pushed in reverse emission order: trailing args first, table parts last
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." pushed after the first part; it would otherwise lead the output
        self.stack.pop()
def tablealias_sql(self, e: sqlglot.expressions.TableAlias) -> None:
    def tablealias_sql(self, e: exp.TableAlias) -> None:
        """Queue ` AS this`, followed by `(columns)` when the alias declares columns."""
        columns = e.columns

        # Columns are pushed first so that they are emitted after the alias name
        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))
def var_sql(self, e: sqlglot.expressions.Var) -> None:
    def var_sql(self, e: exp.Var) -> None:
        """Queue the variable's name as is."""
        self.stack.append(e.this)