# Copyright 2026 Daitum
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
``DerivedTable`` and its supporting enums and helpers.
A derived table is built from another source table with optional grouping,
filtering, sorting and aggregation. The aggregation helpers in this module
validate which ``AggregationMethod`` values are legal for each
``DataType``/``ObjectDataType``/``MapDataType`` and compute the resulting
data type.
"""
from __future__ import annotations
from enum import Enum
from typeguard import typechecked
from ._buildable import Buildable
from .data_types import PRIMITIVE_DATA_TYPES, BaseDataType, DataType, MapDataType, ObjectDataType
from .fields import DataField, Field
from .tables import Table
class SortDirection(Enum):
    """
    Directions in which a `SortKey` of a `DerivedTable` can order its field.
    """

    ASCENDING = "ASCENDING"
    """Order by the sort field in ascending order."""

    DESCENDING = "DESCENDING"
    """Order by the sort field in descending order."""
class AggregationMethod(Enum):
    """
    Aggregation methods available for aggregated fields.

    Every member is documented with the data types it accepts ("Allowed DataTypes") and the
    data type it produces ("Return DataTypes").
    """

    BLANK = "BLANK"
    """Yields a typed null for the target field; no input values are aggregated.

    Allowed DataTypes: any type.
    Return DataTypes: same as the allowed type.
    """

    COUNT = "COUNT"
    """Counts the non-null input items and returns that count. The result type is INTEGER
    irrespective of the source item type.

    Allowed DataTypes: any type.
    Return DataTypes: INTEGER.
    """

    SUM = "SUM"
    """Sums numeric inputs. For array targets the sum is performed element-wise across arrays.

    Allowed DataTypes: INTEGER, DECIMAL, INTEGER_ARRAY, DECIMAL_ARRAY.
    Return DataTypes: same as the allowed type.
    """

    MIN = "MIN"
    """Returns the minimum of the inputs, by natural ordering.

    Allowed DataTypes: any comparable scalar type (INTEGER, DECIMAL, STRING, BOOLEAN,
    DATE, TIME, DATETIME) and other comparable values. Not intended for arrays.
    Return DataTypes: same as the allowed type.
    """

    MAX = "MAX"
    """Returns the maximum of the inputs, by natural ordering.

    Allowed DataTypes: any comparable scalar type (INTEGER, DECIMAL, STRING, BOOLEAN,
    DATE, TIME, DATETIME) and other comparable values. Not intended for arrays.
    Return DataTypes: same as the allowed type.
    """

    AVERAGE = "AVERAGE"
    """Computes the arithmetic mean of numeric inputs and returns it as a DECIMAL value.

    Allowed DataTypes: INTEGER, DECIMAL.
    Return DataTypes: DECIMAL.
    """

    FIRST = "FIRST"
    """Returns the first input value encountered (typed), or null when there is none.

    Allowed DataTypes: any type.
    Return DataTypes: same as the allowed type.
    """

    LAST = "LAST"
    """Returns the last input value encountered (typed), or null when there is none.

    Allowed DataTypes: any type.
    Return DataTypes: same as the allowed type.
    """

    EQUAL = "EQUAL"
    """Returns the common value when all inputs are equal; otherwise returns a typed null.

    Allowed DataTypes: any type.
    Return DataTypes: same as the allowed type.
    """

    AND = "AND"
    """Logical AND over boolean inputs. Nulls are treated as false by booleanValue()
    conversion.

    Allowed DataTypes: BOOLEAN.
    Return DataTypes: same as the allowed type.
    """

    OR = "OR"
    """Logical OR over boolean inputs. Nulls are treated as false by booleanValue()
    conversion.

    Allowed DataTypes: BOOLEAN.
    Return DataTypes: same as the allowed type.
    """

    ARRAY = "ARRAY"
    """Collects the input items, in order, into an array of the target element type.

    Allowed DataTypes: any non-array types except MapDataType (e.g., INTEGER, DECIMAL,
    STRING, BOOLEAN, DATE, TIME, DATETIME, OBJECT).
    Return DataTypes: the corresponding array type of the allowed type (e.g., INTEGER_ARRAY,
    DECIMAL_ARRAY, STRING_ARRAY, BOOLEAN_ARRAY, DATE_ARRAY, TIME_ARRAY, DATETIME_ARRAY,
    OBJECT_ARRAY).
    """

    REFERENCE = "REFERENCE"
    """Creates an OBJECT_ARRAY of references from row ids provided by the evaluation engine.

    Allowed DataTypes: any type.
    Return DataTypes: OBJECT_ARRAY for a specific table.
    """

    INTERSECTION = "INTERSECTION"
    """Computes the set intersection across array inputs and returns an array result.

    Allowed DataTypes: any scalar or array type.
    Return DataTypes: the array of the scalar type when a scalar target is provided, or the
    same array type when the target is already an array.
    """

    UNION = "UNION"
    """Computes the set union of the inputs. Array inputs have their elements unioned;
    scalar inputs are unioned as individual values.

    Allowed DataTypes: any scalar or array type.
    Return DataTypes: the array of the scalar type when a scalar target is provided, or the
    same array type when the target is already an array.
    """
def _is_valid_aggregation(
    data_type: BaseDataType,
    aggregation_method: AggregationMethod,
) -> bool:
    """
    Check whether *aggregation_method* may be applied to a field of *data_type*.

    Args:
        data_type: Data type of the field that would be aggregated.
        aggregation_method: Aggregation method under consideration.

    Returns:
        ``True`` when the pairing is supported; ``False`` when it is not, or when the method
        is not one of the recognised ``AggregationMethod`` members.
    """
    method = aggregation_method

    # These methods accept every source data type.
    if method in (
        AggregationMethod.BLANK,
        AggregationMethod.COUNT,
        AggregationMethod.FIRST,
        AggregationMethod.LAST,
        AggregationMethod.EQUAL,
        AggregationMethod.REFERENCE,
    ):
        return True

    if method == AggregationMethod.SUM:
        summable = (
            DataType.INTEGER,
            DataType.DECIMAL,
            DataType.INTEGER_ARRAY,
            DataType.DECIMAL_ARRAY,
        )
        return data_type in summable

    if method in (AggregationMethod.MIN, AggregationMethod.MAX):
        return data_type in PRIMITIVE_DATA_TYPES

    if method == AggregationMethod.AVERAGE:
        return data_type in (DataType.INTEGER, DataType.DECIMAL)

    if method in (AggregationMethod.AND, AggregationMethod.OR):
        return data_type == DataType.BOOLEAN

    if method == AggregationMethod.ARRAY:
        # Primitives qualify outright; object types only when they are not already arrays.
        if data_type in PRIMITIVE_DATA_TYPES:
            return True
        return isinstance(data_type, ObjectDataType) and not data_type.is_array()

    if method in (AggregationMethod.INTERSECTION, AggregationMethod.UNION):
        # Everything except map types is accepted.
        return not isinstance(data_type, MapDataType)

    return False
def _get_aggregated_data_type(
    source_field: Field,
    aggregation_method: AggregationMethod,
) -> BaseDataType:
    """
    Work out the data type produced by aggregating *source_field* with *aggregation_method*.

    Args:
        source_field: The field whose values are aggregated.
        aggregation_method: The aggregation method applied to the field.

    Returns:
        The data type of the aggregated result. Methods without a dedicated rule keep the
        source field's data type.

    Raises:
        ValueError: If *aggregation_method* is not valid for the source field's data type.
    """
    source_type: BaseDataType = source_field.data_type

    if not _is_valid_aggregation(source_type, aggregation_method):
        message = (
            f"Aggregation method {aggregation_method.name} is not valid for data type "
            f"{source_type}"
        )
        raise ValueError(message)

    # Methods whose result type does not depend on the source type.
    if aggregation_method == AggregationMethod.COUNT:
        return DataType.INTEGER
    if aggregation_method == AggregationMethod.AVERAGE:
        return DataType.DECIMAL
    if aggregation_method == AggregationMethod.REFERENCE:
        return ObjectDataType(source_field.table, is_array=True)

    # Array-producing methods; map types fall through and keep their own type.
    if aggregation_method == AggregationMethod.ARRAY:
        if not isinstance(source_type, MapDataType):
            return source_type.to_array()
    elif aggregation_method in (AggregationMethod.INTERSECTION, AggregationMethod.UNION):
        if not isinstance(source_type, MapDataType):
            if source_type.is_array():
                return source_type
            return source_type.to_array()

    return source_type
@typechecked
class DerivedTable(Table):
    """
    Represents a derived table that is based on another source table and supports grouping,
    sorting, and filtering.
    """

    def __init__(
        self,
        id: str,
        source_table: Table,
        group_by: list[Field] | None = None,
        filter_field: Field | None = None,
    ):
        """
        Create a derived table on top of *source_table*.

        Args:
            id (str): Identifier of the derived table.
            source_table (Table): The table this one is derived from.
            group_by (list[Field], optional): Fields to group by. When given, a grouping
                configuration is created; one is required before aggregated fields can be
                added.
            filter_field (Field, optional): A ``BOOLEAN`` field whose id is recorded as the
                table's filter field.

        Raises:
            ValueError: If any field in `group_by` belongs to a tracking group.
            ValueError: If `filter_field` does not have data type ``BOOLEAN``.
        """
        super().__init__(id)
        self._source_table = source_table
        self.source_table_id: str = source_table.id
        self.grouping_configuration: DerivedTable._GroupingConfiguration | None = None
        self.filter_field: str | None = None
        self._filter_field_ref: Field | None = None
        self.sort_keys: list[DerivedTable._SortKey] = []

        if group_by is not None:
            for field in group_by:
                if field.tracking_group is not None:
                    raise ValueError("Currently do not support grouping by tracked fields")
            self.grouping_configuration = DerivedTable._GroupingConfiguration(group_by)

        if filter_field is not None:
            if filter_field.data_type != DataType.BOOLEAN:
                raise ValueError(f"Cannot filter on field with data type: {filter_field.data_type}")
            self._filter_field_ref = filter_field
            self.filter_field = filter_field.id

    def add_source_fields(  # noqa: PLR0912
        self, source_fields: list[Field] | None = None, include_validators: bool = False
    ):
        """
        Adds fields to this table, copied from the source table or from the group-by fields.

        The available fields are the group-by fields when a grouping configuration is present,
        otherwise all field definitions of the source table. If `source_fields` is `None`,
        every available field is added. If `source_fields` is provided, each entry is
        validated against the available fields and only the listed fields are added (an empty
        list therefore adds nothing). For every listed field that belongs to a tracking group,
        the source-table field identified by its `tracking_id` is added as well.

        Fields whose id contains ``__invalid__`` or ``__message__`` are treated as validator
        fields and are skipped unless `include_validators` is set.

        Args:
            source_fields (list[Field]): A list of `Field` objects to be added. Defaults to
                `None`, in which case all available fields will be added.
            include_validators (bool): Whether to propagate validators (and, when
                `source_fields` is given, the matching validator fields) from the source
                fields to the derived table. Defaults to `False`.

        Raises:
            ValueError: If any field in `source_fields` is not found in the available fields.

        Note:
            If `self.grouping_configuration` is `None`, fields are taken from
            `self._source_table.field_definitions`.
            If `self.grouping_configuration` is not `None`, fields are taken from the `Field`
            objects the table was grouped by.
        """
        fields = (
            self.grouping_configuration._raw_fields
            if self.grouping_configuration
            else list(self._source_table.field_definitions.values())
        )

        # Only ``None`` means "all available fields". Testing truthiness here would make an
        # explicit empty list silently add every field.
        field_list = list(fields if source_fields is None else source_fields)

        if source_fields is not None:
            for field in source_fields:
                if field not in fields:
                    context = "grouped fields." if self.grouping_configuration else "source table."
                    raise ValueError(f"The field {field.id} does not appear in the {context}")
                if field.tracking_group is not None:
                    # Bring the companion tracking field along with a tracked field.
                    tracked_field = self._source_table.get_field(field.tracking_id)
                    if tracked_field not in field_list:
                        field_list.append(tracked_field)

        if include_validators and source_fields is not None:
            self._append_validator_fields(field_list)

        for field in field_list:
            is_validator_field = "__invalid__" in field.id or "__message__" in field.id
            if not include_validators and is_validator_field:
                continue

            data_field = DataField(field.id, self, field.data_type)
            if field.order_index is not None:
                data_field.set_order_index(field.order_index)
            if field.description is not None:
                data_field.set_description(field.description)
            if field.tracking_group is not None:
                data_field.set_tracking_group(field.tracking_group)
            self._add_field(data_field)

            if include_validators:
                for validator in field._validators:  # pylint: disable=protected-access
                    data_field._validators.append(validator)  # pylint: disable=protected-access

    def _append_validator_fields(self, field_list: list) -> None:
        """
        Append to *field_list* the source-table validator fields that belong to its fields.

        A validator field is a source-table field whose id contains ``__invalid__`` or
        ``__message__``. It is appended when it is not already listed and the part of its id
        before the marker equals the id of a non-validator field in *field_list*.

        Args:
            field_list: The list of fields to extend; modified in place.
        """
        existing_ids = {f.id for f in field_list}
        base_ids = {
            fid for fid in existing_ids if "__invalid__" not in fid and "__message__" not in fid
        }
        for source_field in self._source_table.field_definitions.values():
            fid = source_field.id
            if fid in existing_ids:
                continue
            if "__invalid__" not in fid and "__message__" not in fid:
                continue
            # Strip everything from the first marker onwards to recover the base field id.
            base_id = fid.split("__invalid__")[0].split("__message__")[0]
            if base_id in base_ids:
                field_list.append(source_field)

    def add_sort_key(self, field: Field, direction: SortDirection):
        """
        Adds a sort key to the derived table.

        Args:
            field (Field): The field to sort by.
            direction (SortDirection): The direction in which to sort.
        """
        self.sort_keys.append(DerivedTable._SortKey(field, direction))

    def add_aggregated_field(
        self,
        id: str,
        source_field: Field,
        aggregation_method: AggregationMethod,
        tracking_group: str | None = None,
    ) -> DataField:
        """
        Adds an aggregated field to the table.

        When `tracking_group` is given, a second aggregated field is added for the tracking
        field: its id is the new field's `tracking_id`, its source is the source-table field
        identified by `source_field.tracking_id`, and it uses the same aggregation method.

        Args:
            id (str): The unique identifier for the aggregated field.
            source_field (Field): The source field that will be aggregated.
            aggregation_method (AggregationMethod): The method used to aggregate the
                `source_field`.
            tracking_group (str, optional): Group identifier for change tracking.

        Returns:
            DataField: The newly added aggregated field (not the companion tracking field).

        Raises:
            ValueError: If no grouped fields are present in the table.
            ValueError: If the aggregation method is not valid for the source field's data type.
        """
        if self.grouping_configuration is None:
            raise ValueError(
                "Cannot add an aggregated field to a DerivedTable with no grouped fields"
            )

        # Validates the method/type pairing before any state is modified.
        data_type = _get_aggregated_data_type(source_field, aggregation_method)
        self.grouping_configuration.add_aggregated_field(id, source_field, aggregation_method)

        data_field = DataField(id, self, data_type)
        if tracking_group is not None:
            data_field.set_tracking_group(tracking_group)
        self._add_field(data_field)

        if tracking_group is not None:
            # The recursive call passes no tracking group, so it does not recurse further.
            self.add_aggregated_field(
                data_field.tracking_id,
                self._source_table.get_field(source_field.tracking_id),
                aggregation_method,
            )
        return data_field

    class _SortKey(Buildable):
        """Pairs the id of a field with the direction to sort it in."""

        def __init__(self, field: Field, direction: SortDirection):
            self.field = field.id
            self.direction = direction

    class _GroupingConfiguration(Buildable):
        """Holds the ids of the group-by fields and the aggregated fields defined on them."""

        def __init__(self, group_by_fields: list[Field]):
            self.group_by_fields = [f.id for f in group_by_fields]
            self.aggregated_fields: list[DerivedTable._GroupingConfiguration._AggregatedField] = []
            # The original Field objects, kept for membership checks in add_source_fields.
            self._raw_fields = group_by_fields

        # pylint: disable=missing-function-docstring
        def add_aggregated_field(
            self, id: str, source_field: Field, aggregation_method: AggregationMethod
        ):
            self.aggregated_fields.append(
                DerivedTable._GroupingConfiguration._AggregatedField(
                    id, source_field, aggregation_method
                )
            )

        class _AggregatedField(Buildable):
            """Records an aggregated field's id, its source field id and the method name."""

            def __init__(self, id: str, source_field: Field, aggregation_method: AggregationMethod):
                self.aggregated_field_id = id
                self.source_field_id = source_field.id
                self.aggregation_method = aggregation_method.name