sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator, unsupported_args
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time, subsecond_precision
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    from sqlglot.optimizer.annotate_types import TypeAnnotator

    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
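Since `Dialects` is a `str` enum, its members can be used anywhere a dialect name string is accepted. A minimal hedged sketch:

    import sqlglot
    from sqlglot.dialects.dialect import Dialects

    # Enum members compare equal to their string values ("mysql", "duckdb", ...)
    sqlglot.transpile("SELECT 1", read=Dialects.MYSQL, write=Dialects.DUCKDB)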
class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        klass.INVERSE_CREATABLE_KIND_MAPPING = {
            v: k for k, v in klass.CREATABLE_KIND_MAPPING.items()
        }

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
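Because `_Dialect` registers every `Dialect` subclass in `classes` at definition time, dialects can be looked up or compared by name. A hedged sketch:

    from sqlglot.dialects.dialect import Dialect

    duckdb = Dialect["duckdb"]             # registry lookup through the metaclass
    assert Dialect.get("duckdb") is duckdb
    assert duckdb == "duckdb"              # _Dialect.__eq__ compares a class to its name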
class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.

    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    PRESERVE_ORIGINAL_NAMES: bool = False
    """
    Whether the name of the function should be preserved inside the node's metadata. This can be
    useful for roundtripping deprecated vs new functions that share an AST node, e.g. JSON_VALUE
    vs JSON_EXTRACT_SCALAR in BigQuery.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """
291 """ 292 293 SAFE_DIVISION = False 294 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 295 296 CONCAT_COALESCE = False 297 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 298 299 HEX_LOWERCASE = False 300 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 301 302 DATE_FORMAT = "'%Y-%m-%d'" 303 DATEINT_FORMAT = "'%Y%m%d'" 304 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 305 306 TIME_MAPPING: t.Dict[str, str] = {} 307 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 308 309 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 310 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 311 FORMAT_MAPPING: t.Dict[str, str] = {} 312 """ 313 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 314 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 315 """ 316 317 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 318 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 319 320 PSEUDOCOLUMNS: t.Set[str] = set() 321 """ 322 Columns that are auto-generated by the engine corresponding to this dialect. 323 For example, such columns may be excluded from `SELECT *` queries. 324 """ 325 326 PREFER_CTE_ALIAS_COLUMN = False 327 """ 328 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 329 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 330 any projection aliases in the subquery. 331 332 For example, 333 WITH y(c) AS ( 334 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 335 ) SELECT c FROM y; 336 337 will be rewritten as 338 339 WITH y(c) AS ( 340 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 341 ) SELECT c FROM y; 342 """ 343 344 COPY_PARAMS_ARE_CSV = True 345 """ 346 Whether COPY statement parameters are separated by comma or whitespace 347 """ 348 349 FORCE_EARLY_ALIAS_REF_EXPANSION = False 350 """ 351 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 
    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before
    column qualification (_qualify_columns()).

    For example:
        WITH data AS (
            SELECT
                1 AS id,
                2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id
        HAVING
            my_id = 1

    In most dialects, "my_id" would refer to "data.my_id" across the query, except:
        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e.
          it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - ClickHouse, which will forward the alias across the query, i.e. it resolves
          to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    SUPPORTS_ORDER_BY_ALL = False
    """
    Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
    """

    HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
    """
    Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3),
    as the former is of type INT[] vs the latter which is SUPER
    """

    SUPPORTS_FIXED_SIZE_ARRAYS = False
    """
    Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts, e.g.
    in DuckDB. In dialects which don't support fixed-size arrays, such as Snowflake, this should
    be interpreted as a subscript/index operator.
    """

    STRICT_JSON_PATH_SYNTAX = True
    """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning."""

    ON_CONDITION_EMPTY_BEFORE_ERROR = True
    """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle)."""

    ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True
    """Whether ArrayAgg needs to filter NULL values."""

    PROMOTE_TO_INFERRED_DATETIME_TYPE = False
    """
    This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted
    to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal
    is cast to x's type to match it instead.
    """

    SUPPORTS_VALUES_DEFAULT = True
    """Whether the DEFAULT keyword is supported in the VALUES clause."""

    REGEXP_EXTRACT_DEFAULT_GROUP = 0
    """The default value for the capturing group."""

    SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = {
        exp.Except: True,
        exp.Intersect: True,
        exp.Union: True,
    }
    """
    Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL`
    must be explicitly specified.
    """

    CREATABLE_KIND_MAPPING: dict[str, str] = {}
    """
    Helper for dialects that use a different name for the same creatable kind. For example, the
    ClickHouse equivalent of CREATE SCHEMA is CREATE DATABASE.
    """
434 """ 435 436 # --- Autofilled --- 437 438 tokenizer_class = Tokenizer 439 jsonpath_tokenizer_class = JSONPathTokenizer 440 parser_class = Parser 441 generator_class = Generator 442 443 # A trie of the time_mapping keys 444 TIME_TRIE: t.Dict = {} 445 FORMAT_TRIE: t.Dict = {} 446 447 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 448 INVERSE_TIME_TRIE: t.Dict = {} 449 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 450 INVERSE_FORMAT_TRIE: t.Dict = {} 451 452 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 453 454 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 455 456 # Delimiters for string literals and identifiers 457 QUOTE_START = "'" 458 QUOTE_END = "'" 459 IDENTIFIER_START = '"' 460 IDENTIFIER_END = '"' 461 462 # Delimiters for bit, hex, byte and unicode literals 463 BIT_START: t.Optional[str] = None 464 BIT_END: t.Optional[str] = None 465 HEX_START: t.Optional[str] = None 466 HEX_END: t.Optional[str] = None 467 BYTE_START: t.Optional[str] = None 468 BYTE_END: t.Optional[str] = None 469 UNICODE_START: t.Optional[str] = None 470 UNICODE_END: t.Optional[str] = None 471 472 DATE_PART_MAPPING = { 473 "Y": "YEAR", 474 "YY": "YEAR", 475 "YYY": "YEAR", 476 "YYYY": "YEAR", 477 "YR": "YEAR", 478 "YEARS": "YEAR", 479 "YRS": "YEAR", 480 "MM": "MONTH", 481 "MON": "MONTH", 482 "MONS": "MONTH", 483 "MONTHS": "MONTH", 484 "D": "DAY", 485 "DD": "DAY", 486 "DAYS": "DAY", 487 "DAYOFMONTH": "DAY", 488 "DAY OF WEEK": "DAYOFWEEK", 489 "WEEKDAY": "DAYOFWEEK", 490 "DOW": "DAYOFWEEK", 491 "DW": "DAYOFWEEK", 492 "WEEKDAY_ISO": "DAYOFWEEKISO", 493 "DOW_ISO": "DAYOFWEEKISO", 494 "DW_ISO": "DAYOFWEEKISO", 495 "DAY OF YEAR": "DAYOFYEAR", 496 "DOY": "DAYOFYEAR", 497 "DY": "DAYOFYEAR", 498 "W": "WEEK", 499 "WK": "WEEK", 500 "WEEKOFYEAR": "WEEK", 501 "WOY": "WEEK", 502 "WY": "WEEK", 503 "WEEK_ISO": "WEEKISO", 504 "WEEKOFYEARISO": "WEEKISO", 505 "WEEKOFYEAR_ISO": "WEEKISO", 506 "Q": "QUARTER", 507 "QTR": "QUARTER", 508 "QTRS": "QUARTER", 509 "QUARTERS": "QUARTER", 510 "H": "HOUR", 511 "HH": "HOUR", 512 "HR": "HOUR", 513 "HOURS": "HOUR", 514 "HRS": "HOUR", 515 "M": "MINUTE", 516 "MI": "MINUTE", 517 "MIN": "MINUTE", 518 "MINUTES": "MINUTE", 519 "MINS": "MINUTE", 520 "S": "SECOND", 521 "SEC": "SECOND", 522 "SECONDS": "SECOND", 523 "SECS": "SECOND", 524 "MS": "MILLISECOND", 525 "MSEC": "MILLISECOND", 526 "MSECS": "MILLISECOND", 527 "MSECOND": "MILLISECOND", 528 "MSECONDS": "MILLISECOND", 529 "MILLISEC": "MILLISECOND", 530 "MILLISECS": "MILLISECOND", 531 "MILLISECON": "MILLISECOND", 532 "MILLISECONDS": "MILLISECOND", 533 "US": "MICROSECOND", 534 "USEC": "MICROSECOND", 535 "USECS": "MICROSECOND", 536 "MICROSEC": "MICROSECOND", 537 "MICROSECS": "MICROSECOND", 538 "USECOND": "MICROSECOND", 539 "USECONDS": "MICROSECOND", 540 "MICROSECONDS": "MICROSECOND", 541 "NS": "NANOSECOND", 542 "NSEC": "NANOSECOND", 543 "NANOSEC": "NANOSECOND", 544 "NSECOND": "NANOSECOND", 545 "NSECONDS": "NANOSECOND", 546 "NANOSECS": "NANOSECOND", 547 "EPOCH_SECOND": "EPOCH", 548 "EPOCH_SECONDS": "EPOCH", 549 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 550 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 551 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 552 "TZH": "TIMEZONE_HOUR", 553 "TZM": "TIMEZONE_MINUTE", 554 "DEC": "DECADE", 555 "DECS": "DECADE", 556 "DECADES": "DECADE", 557 "MIL": "MILLENIUM", 558 "MILS": "MILLENIUM", 559 "MILLENIA": "MILLENIUM", 560 "C": "CENTURY", 561 "CENT": "CENTURY", 562 "CENTS": "CENTURY", 563 "CENTURIES": "CENTURY", 564 } 565 566 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 567 exp.DataType.Type.BIGINT: 
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.Datetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.ToDouble,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Levenshtein,
            exp.Sign,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.JSON: {
            exp.ParseJSON,
        },
        exp.DataType.Type.TIME: {
            exp.Time,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
            exp.Quarter,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.String,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    ANNOTATORS: AnnotatorsType = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Count: lambda self, e: self._annotate_with_type(
            e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT
        ),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Extract: lambda self, e: self._annotate_extract(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<TIMESTAMP>")
        ),
        exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }
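    # Usage note (a sketch, not part of the original source): these annotators drive
    # sqlglot.optimizer.annotate_types. For example, the following should yield a
    # VARCHAR type, since exp.Lower is registered under Type.VARCHAR above:
    #
    #     from sqlglot import parse_one
    #     from sqlglot.optimizer.annotate_types import annotate_types
    #
    #     annotate_types(parse_one("SELECT LOWER('x') AS y")).selects[0].type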
    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_strings = dialect.split(",")
                kv_pairs = (kv.split("=") for kv in kv_strings)
                kwargs = {}
                for pair in kv_pairs:
                    key = pair[0].strip()
                    value: t.Union[bool, str, None] = None

                    if len(pair) == 1:
                        # Default initialize standalone settings to True
                        value = True
                    elif len(pair) == 2:
                        value = pair[1].strip()

                        # Coerce the value to boolean if it matches the truthy/falsy values below
                        value_lower = value.lower()
                        if value_lower in ("true", "1"):
                            value = True
                        elif value_lower in ("false", "0"):
                            value = False

                    kwargs[key] = value

            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )
" 785 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 786 ) 787 788 result = cls.get(dialect_name.strip()) 789 if not result: 790 from difflib import get_close_matches 791 792 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 793 if similar: 794 similar = f" Did you mean {similar}?" 795 796 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 797 798 return result(**kwargs) 799 800 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 801 802 @classmethod 803 def format_time( 804 cls, expression: t.Optional[str | exp.Expression] 805 ) -> t.Optional[exp.Expression]: 806 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 807 if isinstance(expression, str): 808 return exp.Literal.string( 809 # the time formats are quoted 810 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 811 ) 812 813 if expression and expression.is_string: 814 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 815 816 return expression 817 818 def __init__(self, **kwargs) -> None: 819 normalization_strategy = kwargs.pop("normalization_strategy", None) 820 821 if normalization_strategy is None: 822 self.normalization_strategy = self.NORMALIZATION_STRATEGY 823 else: 824 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 825 826 self.settings = kwargs 827 828 def __eq__(self, other: t.Any) -> bool: 829 # Does not currently take dialect state into account 830 return type(self) == other 831 832 def __hash__(self) -> int: 833 # Does not currently take dialect state into account 834 return hash(type(self)) 835 836 def normalize_identifier(self, expression: E) -> E: 837 """ 838 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 839 840 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 841 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 842 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 843 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 844 845 There are also dialects like Spark, which are case-insensitive even when quotes are 846 present, and dialects like MySQL, whose resolution rules match those employed by the 847 underlying operating system, for example they may always be case-sensitive in Linux. 848 849 Finally, the normalization behavior of some engines can even be controlled through flags, 850 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 851 852 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 853 that it can analyze queries in the optimizer and successfully capture their semantics. 
854 """ 855 if ( 856 isinstance(expression, exp.Identifier) 857 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 858 and ( 859 not expression.quoted 860 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 861 ) 862 ): 863 expression.set( 864 "this", 865 ( 866 expression.this.upper() 867 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 868 else expression.this.lower() 869 ), 870 ) 871 872 return expression 873 874 def case_sensitive(self, text: str) -> bool: 875 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 876 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 877 return False 878 879 unsafe = ( 880 str.islower 881 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 882 else str.isupper 883 ) 884 return any(unsafe(char) for char in text) 885 886 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 887 """Checks if text can be identified given an identify option. 888 889 Args: 890 text: The text to check. 891 identify: 892 `"always"` or `True`: Always returns `True`. 893 `"safe"`: Only returns `True` if the identifier is case-insensitive. 894 895 Returns: 896 Whether the given text can be identified. 897 """ 898 if identify is True or identify == "always": 899 return True 900 901 if identify == "safe": 902 return not self.case_sensitive(text) 903 904 return False 905 906 def quote_identifier(self, expression: E, identify: bool = True) -> E: 907 """ 908 Adds quotes to a given identifier. 909 910 Args: 911 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 912 identify: If set to `False`, the quotes will only be added if the identifier is deemed 913 "unsafe", with respect to its characters and this dialect's normalization strategy. 914 """ 915 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 916 name = expression.this 917 expression.set( 918 "quoted", 919 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 920 ) 921 922 return expression 923 924 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 925 if isinstance(path, exp.Literal): 926 path_text = path.name 927 if path.is_number: 928 path_text = f"[{path_text}]" 929 try: 930 return parse_json_path(path_text, self) 931 except ParseError as e: 932 if self.STRICT_JSON_PATH_SYNTAX: 933 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 934 935 return path 936 937 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 938 return self.parser(**opts).parse(self.tokenize(sql), sql) 939 940 def parse_into( 941 self, expression_type: exp.IntoType, sql: str, **opts 942 ) -> t.List[t.Optional[exp.Expression]]: 943 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 944 945 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 946 return self.generator(**opts).generate(expression, copy=copy) 947 948 def transpile(self, sql: str, **opts) -> t.List[str]: 949 return [ 950 self.generate(expression, copy=False, **opts) if expression else "" 951 for expression in self.parse(sql) 952 ] 953 954 def tokenize(self, sql: str) -> t.List[Token]: 955 return self.tokenizer.tokenize(sql) 956 957 @property 958 def tokenizer(self) -> Tokenizer: 959 return self.tokenizer_class(dialect=self) 960 961 @property 962 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 963 return self.jsonpath_tokenizer_class(dialect=self) 964 965 def parser(self, **opts) -> Parser: 966 return self.parser_class(dialect=self, **opts) 967 968 def generator(self, **opts) -> Generator: 969 return self.generator_class(dialect=self, **opts) 970 971 972DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 973 974 975def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 976 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 977 978 979@unsupported_args("accuracy") 980def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 981 return self.func("APPROX_COUNT_DISTINCT", expression.this) 982 983 984def if_sql( 985 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 986) -> t.Callable[[Generator, exp.If], str]: 987 def _if_sql(self: Generator, expression: exp.If) -> str: 988 return self.func( 989 name, 990 expression.this, 991 expression.args.get("true"), 992 expression.args.get("false") or false_value, 993 ) 994 995 return _if_sql 996 997 998def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 999 this = expression.this 1000 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 1001 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 1002 1003 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 1004 1005 1006def inline_array_sql(self: Generator, expression: exp.Array) -> str: 1007 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 1008 1009 1010def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 1011 elem = seq_get(expression.expressions, 0) 1012 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 1013 return self.func("ARRAY", elem) 1014 return inline_array_sql(self, expression) 1015 1016 1017def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 1018 return self.like_sql( 1019 exp.Like( 1020 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 1021 ) 1022 ) 1023 1024 1025def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 1026 zone = self.sql(expression, "this") 1027 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 1028 1029 1030def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 1031 if expression.args.get("recursive"): 1032 self.unsupported("Recursive CTEs are unsupported") 
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide, if_sql: str = "IF") -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"{if_sql}(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def property_sql(self: Generator, expression: exp.Property) -> str:
    return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}"


def str_position_sql(
    self: Generator,
    expression: exp.StrPosition,
    generate_instance: bool = False,
    str_position_func_name: str = "STRPOS",
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func(str_position_func_name, this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
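`build_formatted_time` is meant for a dialect's `Parser.FUNCTIONS` registry: it builds the target expression and normalizes the dialect-specific format string through `format_time`. A hedged registration sketch (mirroring how Hive-style dialects handle DATE_FORMAT):

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    FUNCTIONS = {
        # DATE_FORMAT(x, 'yyyy-MM-dd') -> exp.TimeToStr with a strftime-style format
        "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "hive"),
    }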
1129 """ 1130 1131 def _builder(args: t.List): 1132 return exp_class( 1133 this=seq_get(args, 0), 1134 format=Dialect[dialect].format_time( 1135 seq_get(args, 1) 1136 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1137 ), 1138 ) 1139 1140 return _builder 1141 1142 1143def time_format( 1144 dialect: DialectType = None, 1145) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1146 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1147 """ 1148 Returns the time format for a given expression, unless it's equivalent 1149 to the default time format of the dialect of interest. 1150 """ 1151 time_format = self.format_time(expression) 1152 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1153 1154 return _time_format 1155 1156 1157def build_date_delta( 1158 exp_class: t.Type[E], 1159 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1160 default_unit: t.Optional[str] = "DAY", 1161) -> t.Callable[[t.List], E]: 1162 def _builder(args: t.List) -> E: 1163 unit_based = len(args) == 3 1164 this = args[2] if unit_based else seq_get(args, 0) 1165 unit = None 1166 if unit_based or default_unit: 1167 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1168 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1169 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1170 1171 return _builder 1172 1173 1174def build_date_delta_with_interval( 1175 expression_class: t.Type[E], 1176) -> t.Callable[[t.List], t.Optional[E]]: 1177 def _builder(args: t.List) -> t.Optional[E]: 1178 if len(args) < 2: 1179 return None 1180 1181 interval = args[1] 1182 1183 if not isinstance(interval, exp.Interval): 1184 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1185 1186 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1187 1188 return _builder 1189 1190 1191def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1192 unit = seq_get(args, 0) 1193 this = seq_get(args, 1) 1194 1195 if isinstance(this, exp.Cast) and this.is_type("date"): 1196 return exp.DateTrunc(unit=unit, this=this) 1197 return exp.TimestampTrunc(this=this, unit=unit) 1198 1199 1200def date_add_interval_sql( 1201 data_type: str, kind: str 1202) -> t.Callable[[Generator, exp.Expression], str]: 1203 def func(self: Generator, expression: exp.Expression) -> str: 1204 this = self.sql(expression, "this") 1205 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1206 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1207 1208 return func 1209 1210 1211def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1212 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1213 args = [unit_to_str(expression), expression.this] 1214 if zone: 1215 args.append(expression.args.get("zone")) 1216 return self.func("DATE_TRUNC", *args) 1217 1218 return _timestamptrunc_sql 1219 1220 1221def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1222 zone = expression.args.get("zone") 1223 if not zone: 1224 from sqlglot.optimizer.annotate_types import annotate_types 1225 1226 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1227 return self.sql(exp.cast(expression.this, target_type)) 1228 if zone.name.lower() in TIMEZONES: 1229 return self.sql( 1230 
def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)


def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(
    self: Generator,
    expression: exp.TimeStrToTime,
    include_precision: bool = False,
) -> str:
    datatype = exp.DataType.build(
        exp.DataType.Type.TIMESTAMPTZ
        if expression.args.get("zone")
        else exp.DataType.Type.TIMESTAMP
    )

    if isinstance(expression.this, exp.Literal) and include_precision:
        precision = subsecond_precision(expression.this.name)
        if precision > 0:
            datatype = exp.DataType.build(
                datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))]
            )

    return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and DuckDB which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


@unsupported_args("position", "occurrence", "parameters")
def regexp_extract_sql(
    self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll
) -> str:
    group = expression.args.get("group")

    # Do not render group if it's the default value for this dialect
    if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP):
        group = None

    return self.func(expression.sql_name(), expression.this, expression.expression, group)


@unsupported_args("position", "occurrence", "modifiers")
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
1414 """ 1415 agg_all_unquoted = agg.transform( 1416 lambda node: ( 1417 exp.Identifier(this=node.name, quoted=False) 1418 if isinstance(node, exp.Identifier) 1419 else node 1420 ) 1421 ) 1422 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1423 1424 return names 1425 1426 1427def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1428 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1429 1430 1431# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1432def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1433 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1434 1435 1436def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1437 return self.func("MAX", expression.this) 1438 1439 1440def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1441 a = self.sql(expression.left) 1442 b = self.sql(expression.right) 1443 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1444 1445 1446def is_parse_json(expression: exp.Expression) -> bool: 1447 return isinstance(expression, exp.ParseJSON) or ( 1448 isinstance(expression, exp.Cast) and expression.is_type("json") 1449 ) 1450 1451 1452def isnull_to_is_null(args: t.List) -> exp.Expression: 1453 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1454 1455 1456def generatedasidentitycolumnconstraint_sql( 1457 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1458) -> str: 1459 start = self.sql(expression, "start") or "1" 1460 increment = self.sql(expression, "increment") or "1" 1461 return f"IDENTITY({start}, {increment})" 1462 1463 1464def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1465 @unsupported_args("count") 1466 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1467 return self.func(name, expression.this, expression.expression) 1468 1469 return _arg_max_or_min_sql 1470 1471 1472def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1473 this = expression.this.copy() 1474 1475 return_type = expression.return_type 1476 if return_type.is_type(exp.DataType.Type.DATE): 1477 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1478 # can truncate timestamp strings, because some dialects can't cast them to DATE 1479 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1480 1481 expression.this.replace(exp.cast(this, return_type)) 1482 return expression 1483 1484 1485def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1486 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1487 if cast and isinstance(expression, exp.TsOrDsAdd): 1488 expression = ts_or_ds_add_cast(expression) 1489 1490 return self.func( 1491 name, 1492 unit_to_var(expression), 1493 expression.expression, 1494 expression.this, 1495 ) 1496 1497 return _delta_sql 1498 1499 1500def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1501 unit = expression.args.get("unit") 1502 1503 if isinstance(unit, exp.Placeholder): 1504 return unit 1505 if unit: 1506 return exp.Literal.string(unit.name) 1507 return exp.Literal.string(default) if default else None 1508 1509 1510def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1511 unit = expression.args.get("unit") 1512 1513 if isinstance(unit, (exp.Var, exp.Placeholder)): 1514 
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        # only remove the target names from the THEN clause
        # they're still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED
        # ref: https://github.com/TobikoData/sqlmesh/issues/2934
        then = when.args.get("then")
        if then:
            then.transform(
                lambda node: (
                    exp.column(node.this)
                    if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                    else node
                ),
                copy=False,
            )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
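`build_json_extract_path` is another `Parser.FUNCTIONS` helper; Postgres-style dialects use it to parse path-based JSON accessors into `exp.JSONExtract` nodes. A hedged registration sketch:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_json_extract_path

    FUNCTIONS = {
        "JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract),
        "JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar),
    }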
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        escape = path.args.get("escape")

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    if escape:
                        path = self.escape_str(path)

                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder


def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)


def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)


def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str:
    start = expression.args.get("start")
    end = expression.args.get("end")
    step = expression.args.get("step")

    if isinstance(start, exp.Cast):
        target_type = start.to
    elif isinstance(end, exp.Cast):
        target_type = end.to
    else:
        target_type = None

    if start and end and target_type and target_type.is_type("date", "timestamp"):
        if isinstance(start, exp.Cast) and target_type is start.to:
            end = exp.cast(end, target_type)
        else:
            start = exp.cast(start, target_type)

    return self.func("SEQUENCE", start, end, step)


def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
    def _builder(args: t.List, dialect: Dialect) -> E:
        return expr_type(
            this=seq_get(args, 0),
            expression=seq_get(args, 1),
            group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP),
            parameters=seq_get(args, 3),
        )

    return _builder
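`build_regexp_extract` pairs with `REGEXP_EXTRACT_DEFAULT_GROUP`: when no capturing group is given, the parsed node records the dialect's default. A hedged registration sketch:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_regexp_extract

    FUNCTIONS = {
        "REGEXP_EXTRACT": build_regexp_extract(exp.RegexpExtract),
        "REGEXP_EXTRACT_ALL": build_regexp_extract(exp.RegexpExtractAll),
    }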
exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1721 parameters=seq_get(args, 3), 1722 ) 1723 1724 return _builder 1725 1726 1727def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1728 if isinstance(expression.this, exp.Explode): 1729 return self.sql( 1730 exp.Join( 1731 this=exp.Unnest( 1732 expressions=[expression.this.this], 1733 alias=expression.args.get("alias"), 1734 offset=isinstance(expression.this, exp.Posexplode), 1735 ), 1736 kind="cross", 1737 ) 1738 ) 1739 return self.lateral_sql(expression) 1740 1741 1742def timestampdiff_sql(self: Generator, expression: exp.DatetimeDiff | exp.TimestampDiff) -> str: 1743 return self.func("TIMESTAMPDIFF", expression.unit, expression.expression, expression.this) 1744 1745 1746def no_make_interval_sql(self: Generator, expression: exp.MakeInterval, sep: str = ", ") -> str: 1747 args = [] 1748 for unit, value in expression.args.items(): 1749 if isinstance(value, exp.Kwarg): 1750 value = value.expression 1751 1752 args.append(f"{value} {unit}") 1753 1754 return f"INTERVAL '{self.format_args(*args, sep=sep)}'"
Dialects supported by SQLGlot.
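The enum values double as the dialect names accepted throughout the API. A minimal sketch (the query itself is illustrative):

import sqlglot

# Any Dialects value (or its string form) can be passed as read/write.
print(sqlglot.transpile("SELECT 1 AS x", read="duckdb", write="snowflake")[0])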
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
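A strategy can also be set per instance via the settings string understood by Dialect.get_or_raise (documented later in this module). A small sketch:

import sqlglot
from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Override MySQL's default strategy for this instance only.
d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
print(d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE)  # True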
219class Dialect(metaclass=_Dialect): 220 INDEX_OFFSET = 0 221 """The base index offset for arrays.""" 222 223 WEEK_OFFSET = 0 224 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 225 226 UNNEST_COLUMN_ONLY = False 227 """Whether `UNNEST` table aliases are treated as column aliases.""" 228 229 ALIAS_POST_TABLESAMPLE = False 230 """Whether the table alias comes after tablesample.""" 231 232 TABLESAMPLE_SIZE_IS_PERCENT = False 233 """Whether a size in the table sample clause represents percentage.""" 234 235 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 236 """Specifies the strategy according to which identifiers should be normalized.""" 237 238 IDENTIFIERS_CAN_START_WITH_DIGIT = False 239 """Whether an unquoted identifier can start with a digit.""" 240 241 DPIPE_IS_STRING_CONCAT = True 242 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 243 244 STRICT_STRING_CONCAT = False 245 """Whether `CONCAT`'s arguments must be strings.""" 246 247 SUPPORTS_USER_DEFINED_TYPES = True 248 """Whether user-defined data types are supported.""" 249 250 SUPPORTS_SEMI_ANTI_JOIN = True 251 """Whether `SEMI` or `ANTI` joins are supported.""" 252 253 SUPPORTS_COLUMN_JOIN_MARKS = False 254 """Whether the old-style outer join (+) syntax is supported.""" 255 256 COPY_PARAMS_ARE_CSV = True 257 """Separator of COPY statement parameters.""" 258 259 NORMALIZE_FUNCTIONS: bool | str = "upper" 260 """ 261 Determines how function names are going to be normalized. 262 Possible values: 263 "upper" or True: Convert names to uppercase. 264 "lower": Convert names to lowercase. 265 False: Disables function name normalization. 266 """ 267 268 PRESERVE_ORIGINAL_NAMES: bool = False 269 """ 270 Whether the name of the function should be preserved inside the node's metadata, 271 can be useful for roundtripping deprecated vs new functions that share an AST node 272 e.g JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery 273 """ 274 275 LOG_BASE_FIRST: t.Optional[bool] = True 276 """ 277 Whether the base comes first in the `LOG` function. 278 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 279 """ 280 281 NULL_ORDERING = "nulls_are_small" 282 """ 283 Default `NULL` ordering method to use if not explicitly set. 284 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 285 """ 286 287 TYPED_DIVISION = False 288 """ 289 Whether the behavior of `a / b` depends on the types of `a` and `b`. 290 False means `a / b` is always float division. 291 True means `a / b` is integer division if both `a` and `b` are integers. 
292 """ 293 294 SAFE_DIVISION = False 295 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 296 297 CONCAT_COALESCE = False 298 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 299 300 HEX_LOWERCASE = False 301 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 302 303 DATE_FORMAT = "'%Y-%m-%d'" 304 DATEINT_FORMAT = "'%Y%m%d'" 305 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 306 307 TIME_MAPPING: t.Dict[str, str] = {} 308 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 309 310 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 311 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 312 FORMAT_MAPPING: t.Dict[str, str] = {} 313 """ 314 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 315 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 316 """ 317 318 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 319 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 320 321 PSEUDOCOLUMNS: t.Set[str] = set() 322 """ 323 Columns that are auto-generated by the engine corresponding to this dialect. 324 For example, such columns may be excluded from `SELECT *` queries. 325 """ 326 327 PREFER_CTE_ALIAS_COLUMN = False 328 """ 329 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 330 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 331 any projection aliases in the subquery. 332 333 For example, 334 WITH y(c) AS ( 335 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 336 ) SELECT c FROM y; 337 338 will be rewritten as 339 340 WITH y(c) AS ( 341 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 342 ) SELECT c FROM y; 343 """ 344 345 COPY_PARAMS_ARE_CSV = True 346 """ 347 Whether COPY statement parameters are separated by comma or whitespace 348 """ 349 350 FORCE_EARLY_ALIAS_REF_EXPANSION = False 351 """ 352 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 
353 354 For example: 355 WITH data AS ( 356 SELECT 357 1 AS id, 358 2 AS my_id 359 ) 360 SELECT 361 id AS my_id 362 FROM 363 data 364 WHERE 365 my_id = 1 366 GROUP BY 367 my_id, 368 HAVING 369 my_id = 1 370 371 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 372 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 373 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 374 - Clickhouse, which will forward the alias across the query i.e it resolves 375 to "WHERE id = 1 GROUP BY id HAVING id = 1" 376 """ 377 378 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 379 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 380 381 SUPPORTS_ORDER_BY_ALL = False 382 """ 383 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 384 """ 385 386 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 387 """ 388 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 389 as the former is of type INT[] vs the latter which is SUPER 390 """ 391 392 SUPPORTS_FIXED_SIZE_ARRAYS = False 393 """ 394 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 395 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 396 be interpreted as a subscript/index operator. 397 """ 398 399 STRICT_JSON_PATH_SYNTAX = True 400 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 401 402 ON_CONDITION_EMPTY_BEFORE_ERROR = True 403 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 404 405 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 406 """Whether ArrayAgg needs to filter NULL values.""" 407 408 PROMOTE_TO_INFERRED_DATETIME_TYPE = False 409 """ 410 This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted 411 to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal 412 is cast to x's type to match it instead. 413 """ 414 415 SUPPORTS_VALUES_DEFAULT = True 416 """Whether the DEFAULT keyword is supported in the VALUES clause.""" 417 418 REGEXP_EXTRACT_DEFAULT_GROUP = 0 419 """The default value for the capturing group.""" 420 421 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 422 exp.Except: True, 423 exp.Intersect: True, 424 exp.Union: True, 425 } 426 """ 427 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 428 must be explicitly specified. 429 """ 430 431 CREATABLE_KIND_MAPPING: dict[str, str] = {} 432 """ 433 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 434 equivalent of CREATE SCHEMA is CREATE DATABASE. 
435 """ 436 437 # --- Autofilled --- 438 439 tokenizer_class = Tokenizer 440 jsonpath_tokenizer_class = JSONPathTokenizer 441 parser_class = Parser 442 generator_class = Generator 443 444 # A trie of the time_mapping keys 445 TIME_TRIE: t.Dict = {} 446 FORMAT_TRIE: t.Dict = {} 447 448 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 449 INVERSE_TIME_TRIE: t.Dict = {} 450 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 451 INVERSE_FORMAT_TRIE: t.Dict = {} 452 453 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 454 455 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 456 457 # Delimiters for string literals and identifiers 458 QUOTE_START = "'" 459 QUOTE_END = "'" 460 IDENTIFIER_START = '"' 461 IDENTIFIER_END = '"' 462 463 # Delimiters for bit, hex, byte and unicode literals 464 BIT_START: t.Optional[str] = None 465 BIT_END: t.Optional[str] = None 466 HEX_START: t.Optional[str] = None 467 HEX_END: t.Optional[str] = None 468 BYTE_START: t.Optional[str] = None 469 BYTE_END: t.Optional[str] = None 470 UNICODE_START: t.Optional[str] = None 471 UNICODE_END: t.Optional[str] = None 472 473 DATE_PART_MAPPING = { 474 "Y": "YEAR", 475 "YY": "YEAR", 476 "YYY": "YEAR", 477 "YYYY": "YEAR", 478 "YR": "YEAR", 479 "YEARS": "YEAR", 480 "YRS": "YEAR", 481 "MM": "MONTH", 482 "MON": "MONTH", 483 "MONS": "MONTH", 484 "MONTHS": "MONTH", 485 "D": "DAY", 486 "DD": "DAY", 487 "DAYS": "DAY", 488 "DAYOFMONTH": "DAY", 489 "DAY OF WEEK": "DAYOFWEEK", 490 "WEEKDAY": "DAYOFWEEK", 491 "DOW": "DAYOFWEEK", 492 "DW": "DAYOFWEEK", 493 "WEEKDAY_ISO": "DAYOFWEEKISO", 494 "DOW_ISO": "DAYOFWEEKISO", 495 "DW_ISO": "DAYOFWEEKISO", 496 "DAY OF YEAR": "DAYOFYEAR", 497 "DOY": "DAYOFYEAR", 498 "DY": "DAYOFYEAR", 499 "W": "WEEK", 500 "WK": "WEEK", 501 "WEEKOFYEAR": "WEEK", 502 "WOY": "WEEK", 503 "WY": "WEEK", 504 "WEEK_ISO": "WEEKISO", 505 "WEEKOFYEARISO": "WEEKISO", 506 "WEEKOFYEAR_ISO": "WEEKISO", 507 "Q": "QUARTER", 508 "QTR": "QUARTER", 509 "QTRS": "QUARTER", 510 "QUARTERS": "QUARTER", 511 "H": "HOUR", 512 "HH": "HOUR", 513 "HR": "HOUR", 514 "HOURS": "HOUR", 515 "HRS": "HOUR", 516 "M": "MINUTE", 517 "MI": "MINUTE", 518 "MIN": "MINUTE", 519 "MINUTES": "MINUTE", 520 "MINS": "MINUTE", 521 "S": "SECOND", 522 "SEC": "SECOND", 523 "SECONDS": "SECOND", 524 "SECS": "SECOND", 525 "MS": "MILLISECOND", 526 "MSEC": "MILLISECOND", 527 "MSECS": "MILLISECOND", 528 "MSECOND": "MILLISECOND", 529 "MSECONDS": "MILLISECOND", 530 "MILLISEC": "MILLISECOND", 531 "MILLISECS": "MILLISECOND", 532 "MILLISECON": "MILLISECOND", 533 "MILLISECONDS": "MILLISECOND", 534 "US": "MICROSECOND", 535 "USEC": "MICROSECOND", 536 "USECS": "MICROSECOND", 537 "MICROSEC": "MICROSECOND", 538 "MICROSECS": "MICROSECOND", 539 "USECOND": "MICROSECOND", 540 "USECONDS": "MICROSECOND", 541 "MICROSECONDS": "MICROSECOND", 542 "NS": "NANOSECOND", 543 "NSEC": "NANOSECOND", 544 "NANOSEC": "NANOSECOND", 545 "NSECOND": "NANOSECOND", 546 "NSECONDS": "NANOSECOND", 547 "NANOSECS": "NANOSECOND", 548 "EPOCH_SECOND": "EPOCH", 549 "EPOCH_SECONDS": "EPOCH", 550 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 551 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 552 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 553 "TZH": "TIMEZONE_HOUR", 554 "TZM": "TIMEZONE_MINUTE", 555 "DEC": "DECADE", 556 "DECS": "DECADE", 557 "DECADES": "DECADE", 558 "MIL": "MILLENIUM", 559 "MILS": "MILLENIUM", 560 "MILLENIA": "MILLENIUM", 561 "C": "CENTURY", 562 "CENT": "CENTURY", 563 "CENTS": "CENTURY", 564 "CENTURIES": "CENTURY", 565 } 566 567 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 568 exp.DataType.Type.BIGINT: 
{ 569 exp.ApproxDistinct, 570 exp.ArraySize, 571 exp.Length, 572 }, 573 exp.DataType.Type.BOOLEAN: { 574 exp.Between, 575 exp.Boolean, 576 exp.In, 577 exp.RegexpLike, 578 }, 579 exp.DataType.Type.DATE: { 580 exp.CurrentDate, 581 exp.Date, 582 exp.DateFromParts, 583 exp.DateStrToDate, 584 exp.DiToDate, 585 exp.StrToDate, 586 exp.TimeStrToDate, 587 exp.TsOrDsToDate, 588 }, 589 exp.DataType.Type.DATETIME: { 590 exp.CurrentDatetime, 591 exp.Datetime, 592 exp.DatetimeAdd, 593 exp.DatetimeSub, 594 }, 595 exp.DataType.Type.DOUBLE: { 596 exp.ApproxQuantile, 597 exp.Avg, 598 exp.Exp, 599 exp.Ln, 600 exp.Log, 601 exp.Pow, 602 exp.Quantile, 603 exp.Round, 604 exp.SafeDivide, 605 exp.Sqrt, 606 exp.Stddev, 607 exp.StddevPop, 608 exp.StddevSamp, 609 exp.ToDouble, 610 exp.Variance, 611 exp.VariancePop, 612 }, 613 exp.DataType.Type.INT: { 614 exp.Ceil, 615 exp.DatetimeDiff, 616 exp.DateDiff, 617 exp.TimestampDiff, 618 exp.TimeDiff, 619 exp.DateToDi, 620 exp.Levenshtein, 621 exp.Sign, 622 exp.StrPosition, 623 exp.TsOrDiToDi, 624 }, 625 exp.DataType.Type.JSON: { 626 exp.ParseJSON, 627 }, 628 exp.DataType.Type.TIME: { 629 exp.Time, 630 }, 631 exp.DataType.Type.TIMESTAMP: { 632 exp.CurrentTime, 633 exp.CurrentTimestamp, 634 exp.StrToTime, 635 exp.TimeAdd, 636 exp.TimeStrToTime, 637 exp.TimeSub, 638 exp.TimestampAdd, 639 exp.TimestampSub, 640 exp.UnixToTime, 641 }, 642 exp.DataType.Type.TINYINT: { 643 exp.Day, 644 exp.Month, 645 exp.Week, 646 exp.Year, 647 exp.Quarter, 648 }, 649 exp.DataType.Type.VARCHAR: { 650 exp.ArrayConcat, 651 exp.Concat, 652 exp.ConcatWs, 653 exp.DateToDateStr, 654 exp.GroupConcat, 655 exp.Initcap, 656 exp.Lower, 657 exp.Substring, 658 exp.String, 659 exp.TimeToStr, 660 exp.TimeToTimeStr, 661 exp.Trim, 662 exp.TsOrDsToDateStr, 663 exp.UnixToStr, 664 exp.UnixToTimeStr, 665 exp.Upper, 666 }, 667 } 668 669 ANNOTATORS: AnnotatorsType = { 670 **{ 671 expr_type: lambda self, e: self._annotate_unary(e) 672 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 673 }, 674 **{ 675 expr_type: lambda self, e: self._annotate_binary(e) 676 for expr_type in subclasses(exp.__name__, exp.Binary) 677 }, 678 **{ 679 expr_type: _annotate_with_type_lambda(data_type) 680 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 681 for expr_type in expressions 682 }, 683 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 684 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 685 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 686 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 687 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 688 exp.Bracket: lambda self, e: self._annotate_bracket(e), 689 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 690 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 691 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 692 exp.Count: lambda self, e: self._annotate_with_type( 693 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 694 ), 695 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 696 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 697 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 698 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 699 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 700 exp.Div: lambda self, e: self._annotate_div(e), 701 exp.Dot: lambda 
self, e: self._annotate_dot(e), 702 exp.Explode: lambda self, e: self._annotate_explode(e), 703 exp.Extract: lambda self, e: self._annotate_extract(e), 704 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 705 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 706 e, exp.DataType.build("ARRAY<DATE>") 707 ), 708 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 709 e, exp.DataType.build("ARRAY<TIMESTAMP>") 710 ), 711 exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 712 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 713 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 714 exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 715 exp.Literal: lambda self, e: self._annotate_literal(e), 716 exp.Map: lambda self, e: self._annotate_map(e), 717 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 718 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 719 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 720 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 721 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 722 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 723 exp.Struct: lambda self, e: self._annotate_struct(e), 724 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 725 exp.Timestamp: lambda self, e: self._annotate_with_type( 726 e, 727 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 728 ), 729 exp.ToMap: lambda self, e: self._annotate_to_map(e), 730 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 731 exp.Unnest: lambda self, e: self._annotate_unnest(e), 732 exp.VarMap: lambda self, e: self._annotate_map(e), 733 } 734 735 @classmethod 736 def get_or_raise(cls, dialect: DialectType) -> Dialect: 737 """ 738 Look up a dialect in the global dialect registry and return it if it exists. 739 740 Args: 741 dialect: The target dialect. If this is a string, it can be optionally followed by 742 additional key-value pairs that are separated by commas and are used to specify 743 dialect settings, such as whether the dialect's identifiers are case-sensitive. 744 745 Example: 746 >>> dialect = dialect_class = get_or_raise("duckdb") 747 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 748 749 Returns: 750 The corresponding Dialect instance. 751 """ 752 753 if not dialect: 754 return cls() 755 if isinstance(dialect, _Dialect): 756 return dialect() 757 if isinstance(dialect, Dialect): 758 return dialect 759 if isinstance(dialect, str): 760 try: 761 dialect_name, *kv_strings = dialect.split(",") 762 kv_pairs = (kv.split("=") for kv in kv_strings) 763 kwargs = {} 764 for pair in kv_pairs: 765 key = pair[0].strip() 766 value: t.Union[bool | str | None] = None 767 768 if len(pair) == 1: 769 # Default initialize standalone settings to True 770 value = True 771 elif len(pair) == 2: 772 value = pair[1].strip() 773 774 # Coerce the value to boolean if it matches to the truthy/falsy values below 775 value_lower = value.lower() 776 if value_lower in ("true", "1"): 777 value = True 778 elif value_lower in ("false", "0"): 779 value = False 780 781 kwargs[key] = value 782 783 except ValueError: 784 raise ValueError( 785 f"Invalid dialect format: '{dialect}'. 
" 786 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 787 ) 788 789 result = cls.get(dialect_name.strip()) 790 if not result: 791 from difflib import get_close_matches 792 793 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 794 if similar: 795 similar = f" Did you mean {similar}?" 796 797 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 798 799 return result(**kwargs) 800 801 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 802 803 @classmethod 804 def format_time( 805 cls, expression: t.Optional[str | exp.Expression] 806 ) -> t.Optional[exp.Expression]: 807 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 808 if isinstance(expression, str): 809 return exp.Literal.string( 810 # the time formats are quoted 811 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 812 ) 813 814 if expression and expression.is_string: 815 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 816 817 return expression 818 819 def __init__(self, **kwargs) -> None: 820 normalization_strategy = kwargs.pop("normalization_strategy", None) 821 822 if normalization_strategy is None: 823 self.normalization_strategy = self.NORMALIZATION_STRATEGY 824 else: 825 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 826 827 self.settings = kwargs 828 829 def __eq__(self, other: t.Any) -> bool: 830 # Does not currently take dialect state into account 831 return type(self) == other 832 833 def __hash__(self) -> int: 834 # Does not currently take dialect state into account 835 return hash(type(self)) 836 837 def normalize_identifier(self, expression: E) -> E: 838 """ 839 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 840 841 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 842 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 843 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 844 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 845 846 There are also dialects like Spark, which are case-insensitive even when quotes are 847 present, and dialects like MySQL, whose resolution rules match those employed by the 848 underlying operating system, for example they may always be case-sensitive in Linux. 849 850 Finally, the normalization behavior of some engines can even be controlled through flags, 851 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 852 853 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 854 that it can analyze queries in the optimizer and successfully capture their semantics. 
855 """ 856 if ( 857 isinstance(expression, exp.Identifier) 858 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 859 and ( 860 not expression.quoted 861 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 862 ) 863 ): 864 expression.set( 865 "this", 866 ( 867 expression.this.upper() 868 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 869 else expression.this.lower() 870 ), 871 ) 872 873 return expression 874 875 def case_sensitive(self, text: str) -> bool: 876 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 877 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 878 return False 879 880 unsafe = ( 881 str.islower 882 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 883 else str.isupper 884 ) 885 return any(unsafe(char) for char in text) 886 887 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 888 """Checks if text can be identified given an identify option. 889 890 Args: 891 text: The text to check. 892 identify: 893 `"always"` or `True`: Always returns `True`. 894 `"safe"`: Only returns `True` if the identifier is case-insensitive. 895 896 Returns: 897 Whether the given text can be identified. 898 """ 899 if identify is True or identify == "always": 900 return True 901 902 if identify == "safe": 903 return not self.case_sensitive(text) 904 905 return False 906 907 def quote_identifier(self, expression: E, identify: bool = True) -> E: 908 """ 909 Adds quotes to a given identifier. 910 911 Args: 912 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 913 identify: If set to `False`, the quotes will only be added if the identifier is deemed 914 "unsafe", with respect to its characters and this dialect's normalization strategy. 915 """ 916 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 917 name = expression.this 918 expression.set( 919 "quoted", 920 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 921 ) 922 923 return expression 924 925 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 926 if isinstance(path, exp.Literal): 927 path_text = path.name 928 if path.is_number: 929 path_text = f"[{path_text}]" 930 try: 931 return parse_json_path(path_text, self) 932 except ParseError as e: 933 if self.STRICT_JSON_PATH_SYNTAX: 934 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 935 936 return path 937 938 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 939 return self.parser(**opts).parse(self.tokenize(sql), sql) 940 941 def parse_into( 942 self, expression_type: exp.IntoType, sql: str, **opts 943 ) -> t.List[t.Optional[exp.Expression]]: 944 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 945 946 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 947 return self.generator(**opts).generate(expression, copy=copy) 948 949 def transpile(self, sql: str, **opts) -> t.List[str]: 950 return [ 951 self.generate(expression, copy=False, **opts) if expression else "" 952 for expression in self.parse(sql) 953 ] 954 955 def tokenize(self, sql: str) -> t.List[Token]: 956 return self.tokenizer.tokenize(sql) 957 958 @property 959 def tokenizer(self) -> Tokenizer: 960 return self.tokenizer_class(dialect=self) 961 962 @property 963 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 964 return self.jsonpath_tokenizer_class(dialect=self) 965 966 def parser(self, **opts) -> Parser: 967 return self.parser_class(dialect=self, **opts) 968 969 def generator(self, **opts) -> Generator: 970 return self.generator_class(dialect=self, **opts)
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the name of the function should be preserved inside the node's metadata; this can be useful for roundtripping deprecated vs. new functions that share an AST node, e.g. JSON_VALUE vs. JSON_EXTRACT_SCALAR in BigQuery.
Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG).
Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".
Whether the behavior of a / b depends on the types of a and b.
False means a / b is always float division.
True means a / b is integer division if both a and b are integers.
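Both division flags are plain class attributes, so they can be inspected per dialect. A sketch (the dialect names are just examples; the values are printed rather than asserted):

import sqlglot
from sqlglot.dialects.dialect import Dialect

for name in ("duckdb", "mysql", "tsql"):
    d = Dialect.get_or_raise(name)
    print(name, d.TYPED_DIVISION, d.SAFE_DIVISION)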
A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.
Mapping of an escaped sequence (\\n) to its unescaped version (\n).
Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
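The rewrite is applied during qualification; a sketch that runs the example above through the optimizer's qualify rule for Snowflake (output not asserted here):

import sqlglot
from sqlglot.optimizer.qualify import qualify

sql = "WITH y(c) AS (SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0) SELECT c FROM y"
qualified = qualify(sqlglot.parse_one(sql, read="snowflake"), dialect="snowflake")
print(qualified.sql("snowflake"))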
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS (
    SELECT
        1 AS id,
        2 AS my_id
)
SELECT
    id AS my_id
FROM
    data
WHERE
    my_id = 1
GROUP BY
    my_id
HAVING
    my_id = 1

In most dialects, "my_id" would refer to "data.my_id" across the query, except:
- BigQuery, which will forward the alias to the GROUP BY + HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- ClickHouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (it expands to all the selected columns), as in DuckDB and Spark3/Databricks.
Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3), as the former is of type INT[] while the latter is SUPER.
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator.
Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.
Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).
This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal is cast to x's type to match it instead.
Whether a set operation uses DISTINCT by default. This is None when either DISTINCT or ALL must be explicitly specified.
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
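Like the other class-level settings, the mapping can be inspected directly. A sketch (the contents are dialect-specific and not asserted here):

import sqlglot
from sqlglot.dialects.dialect import Dialect

print(Dialect.get_or_raise("clickhouse").CREATABLE_KIND_MAPPING)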
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
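The error path is also worth knowing: unknown names raise a ValueError with a close-match suggestion. A sketch:

import sqlglot
from sqlglot.dialects.dialect import Dialect

try:
    Dialect.get_or_raise("postgrez")  # deliberate typo
except ValueError as e:
    print(e)  # Unknown dialect 'postgrez'. Did you mean postgres?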
Converts a time format in this dialect to its equivalent Python strftime format.
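Note that string inputs are expected to carry their surrounding quotes, which are stripped before the mapping is applied. A sketch (assuming MySQL's TIME_MAPPING translates %i to Python's %M; unmapped characters pass through unchanged):

import sqlglot
from sqlglot.dialects.dialect import Dialect

lit = Dialect.get_or_raise("mysql").format_time("'%Y-%m-%d %i'")
print(lit.name)  # expected: %Y-%m-%d %M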
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
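The Postgres/Snowflake behavior described above can be checked directly. A small sketch (copies are made because normalize_identifier mutates its argument):

import sqlglot
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted
print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO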
Checks if text contains any case-sensitive characters, based on the dialect's rules.
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
    "always" or True: Always returns True.
    "safe": Only returns True if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
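A sketch tying this to case_sensitive, using Postgres (which lowercases unquoted identifiers, so uppercase characters are the "unsafe" ones):

import sqlglot
from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")
print(d.case_sensitive("foo"), d.case_sensitive("Foo"))                # False True
print(d.can_identify("Foo", "safe"), d.can_identify("Foo", "always"))  # False True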
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an Identifier, this method is a no-op.
- identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
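A sketch of the identify=False behavior for Postgres: quotes are added only when the name would not survive normalization unquoted:

import sqlglot
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")
print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql(dialect="postgres"))  # foo
print(d.quote_identifier(exp.to_identifier("Foo"), identify=False).sql(dialect="postgres"))  # "Foo"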
985def if_sql( 986 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 987) -> t.Callable[[Generator, exp.If], str]: 988 def _if_sql(self: Generator, expression: exp.If) -> str: 989 return self.func( 990 name, 991 expression.this, 992 expression.args.get("true"), 993 expression.args.get("false") or false_value, 994 ) 995 996 return _if_sql
999def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1000 this = expression.this 1001 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 1002 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 1003 1004 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1074def str_position_sql( 1075 self: Generator, 1076 expression: exp.StrPosition, 1077 generate_instance: bool = False, 1078 str_position_func_name: str = "STRPOS", 1079) -> str: 1080 this = self.sql(expression, "this") 1081 substr = self.sql(expression, "substr") 1082 position = self.sql(expression, "position") 1083 instance = expression.args.get("instance") if generate_instance else None 1084 position_offset = "" 1085 1086 if position: 1087 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1088 this = self.func("SUBSTR", this, position) 1089 position_offset = f" + {position} - 1" 1090 1091 return self.func(str_position_func_name, this, substr, instance) + position_offset
1100def var_map_sql( 1101 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1102) -> str: 1103 keys = expression.args["keys"] 1104 values = expression.args["values"] 1105 1106 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1107 self.unsupported("Cannot convert array columns into map.") 1108 return self.func(map_func_name, keys, values) 1109 1110 args = [] 1111 for key, value in zip(keys.expressions, values.expressions): 1112 args.append(self.sql(key)) 1113 args.append(self.sql(value)) 1114 1115 return self.func(map_func_name, *args)
1118def build_formatted_time( 1119 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1120) -> t.Callable[[t.List], E]: 1121 """Helper used for time expressions. 1122 1123 Args: 1124 exp_class: the expression class to instantiate. 1125 dialect: target sql dialect. 1126 default: the default format, True being time. 1127 1128 Returns: 1129 A callable that can be used to return the appropriately formatted time expression. 1130 """ 1131 1132 def _builder(args: t.List): 1133 return exp_class( 1134 this=seq_get(args, 0), 1135 format=Dialect[dialect].format_time( 1136 seq_get(args, 1) 1137 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1138 ), 1139 ) 1140 1141 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
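A sketch that builds a MySQL-flavored exp.StrToTime and shows the converted format (assuming MySQL's mapping translates %i to %M; the column name is illustrative):

import sqlglot
from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.StrToTime, "mysql")
node = builder([exp.column("ts"), exp.Literal.string("%Y-%m-%d %i")])
print(node.args["format"].name)  # expected: %Y-%m-%d %M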
1144def time_format( 1145 dialect: DialectType = None, 1146) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1147 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1148 """ 1149 Returns the time format for a given expression, unless it's equivalent 1150 to the default time format of the dialect of interest. 1151 """ 1152 time_format = self.format_time(expression) 1153 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1154 1155 return _time_format
1158def build_date_delta( 1159 exp_class: t.Type[E], 1160 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1161 default_unit: t.Optional[str] = "DAY", 1162) -> t.Callable[[t.List], E]: 1163 def _builder(args: t.List) -> E: 1164 unit_based = len(args) == 3 1165 this = args[2] if unit_based else seq_get(args, 0) 1166 unit = None 1167 if unit_based or default_unit: 1168 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1169 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1170 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1171 1172 return _builder
1175def build_date_delta_with_interval( 1176 expression_class: t.Type[E], 1177) -> t.Callable[[t.List], t.Optional[E]]: 1178 def _builder(args: t.List) -> t.Optional[E]: 1179 if len(args) < 2: 1180 return None 1181 1182 interval = args[1] 1183 1184 if not isinstance(interval, exp.Interval): 1185 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1186 1187 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1188 1189 return _builder
1192def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1193 unit = seq_get(args, 0) 1194 this = seq_get(args, 1) 1195 1196 if isinstance(this, exp.Cast) and this.is_type("date"): 1197 return exp.DateTrunc(unit=unit, this=this) 1198 return exp.TimestampTrunc(this=this, unit=unit)
1201def date_add_interval_sql( 1202 data_type: str, kind: str 1203) -> t.Callable[[Generator, exp.Expression], str]: 1204 def func(self: Generator, expression: exp.Expression) -> str: 1205 this = self.sql(expression, "this") 1206 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1207 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1208 1209 return func
1212def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1213 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1214 args = [unit_to_str(expression), expression.this] 1215 if zone: 1216 args.append(expression.args.get("zone")) 1217 return self.func("DATE_TRUNC", *args) 1218 1219 return _timestamptrunc_sql
1222def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1223 zone = expression.args.get("zone") 1224 if not zone: 1225 from sqlglot.optimizer.annotate_types import annotate_types 1226 1227 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1228 return self.sql(exp.cast(expression.this, target_type)) 1229 if zone.name.lower() in TIMEZONES: 1230 return self.sql( 1231 exp.AtTimeZone( 1232 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1233 zone=zone, 1234 ) 1235 ) 1236 return self.func("TIMESTAMP", expression.this, zone)
1239def no_time_sql(self: Generator, expression: exp.Time) -> str: 1240 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1241 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1242 expr = exp.cast( 1243 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1244 ) 1245 return self.sql(expr)
1248def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1249 this = expression.this 1250 expr = expression.expression 1251 1252 if expr.name.lower() in TIMEZONES: 1253 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1254 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1255 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1256 return self.sql(this) 1257 1258 this = exp.cast(this, exp.DataType.Type.DATE) 1259 expr = exp.cast(expr, exp.DataType.Type.TIME) 1260 1261 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1293def timestrtotime_sql( 1294 self: Generator, 1295 expression: exp.TimeStrToTime, 1296 include_precision: bool = False, 1297) -> str: 1298 datatype = exp.DataType.build( 1299 exp.DataType.Type.TIMESTAMPTZ 1300 if expression.args.get("zone") 1301 else exp.DataType.Type.TIMESTAMP 1302 ) 1303 1304 if isinstance(expression.this, exp.Literal) and include_precision: 1305 precision = subsecond_precision(expression.this.name) 1306 if precision > 0: 1307 datatype = exp.DataType.build( 1308 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1309 ) 1310 1311 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
1319def encode_decode_sql( 1320 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1321) -> str: 1322 charset = expression.args.get("charset") 1323 if charset and charset.name.lower() != "utf-8": 1324 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1325 1326 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1339def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1340 cond = expression.this 1341 1342 if isinstance(expression.this, exp.Distinct): 1343 cond = expression.this.expressions[0] 1344 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1345 1346 return self.func("sum", exp.func("if", cond, 1, 0))
1349def trim_sql(self: Generator, expression: exp.Trim) -> str: 1350 target = self.sql(expression, "this") 1351 trim_type = self.sql(expression, "position") 1352 remove_chars = self.sql(expression, "expression") 1353 collation = self.sql(expression, "collation") 1354 1355 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1356 if not remove_chars: 1357 return self.trim_sql(expression) 1358 1359 trim_type = f"{trim_type} " if trim_type else "" 1360 remove_chars = f"{remove_chars} " if remove_chars else "" 1361 from_part = "FROM " if trim_type or remove_chars else "" 1362 collation = f" COLLATE {collation}" if collation else "" 1363 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1384@unsupported_args("position", "occurrence", "parameters") 1385def regexp_extract_sql( 1386 self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll 1387) -> str: 1388 group = expression.args.get("group") 1389 1390 # Do not render group if it's the default value for this dialect 1391 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1392 group = None 1393 1394 return self.func(expression.sql_name(), expression.this, expression.expression, group)
1404def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1405 names = [] 1406 for agg in aggregations: 1407 if isinstance(agg, exp.Alias): 1408 names.append(agg.alias) 1409 else: 1410 """ 1411 This case corresponds to aggregations without aliases being used as suffixes 1412 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1413 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1414 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1415 """ 1416 agg_all_unquoted = agg.transform( 1417 lambda node: ( 1418 exp.Identifier(this=node.name, quoted=False) 1419 if isinstance(node, exp.Identifier) 1420 else node 1421 ) 1422 ) 1423 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1424 1425 return names
1465def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1466 @unsupported_args("count") 1467 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1468 return self.func(name, expression.this, expression.expression) 1469 1470 return _arg_max_or_min_sql
1473def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1474 this = expression.this.copy() 1475 1476 return_type = expression.return_type 1477 if return_type.is_type(exp.DataType.Type.DATE): 1478 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1479 # can truncate timestamp strings, because some dialects can't cast them to DATE 1480 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1481 1482 expression.this.replace(exp.cast(this, return_type)) 1483 return expression
1486def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1487 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1488 if cast and isinstance(expression, exp.TsOrDsAdd): 1489 expression = ts_or_ds_add_cast(expression) 1490 1491 return self.func( 1492 name, 1493 unit_to_var(expression), 1494 expression.expression, 1495 expression.this, 1496 ) 1497 1498 return _delta_sql
1501def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1502 unit = expression.args.get("unit") 1503 1504 if isinstance(unit, exp.Placeholder): 1505 return unit 1506 if unit: 1507 return exp.Literal.string(unit.name) 1508 return exp.Literal.string(default) if default else None