Merged
34 changes: 34 additions & 0 deletions pyiceberg/partitioning.py
@@ -32,6 +32,7 @@
model_validator,
)

from pyiceberg.exceptions import ValidationError
from pyiceberg.schema import Schema
from pyiceberg.transforms import (
BucketTransform,
@@ -249,6 +250,39 @@ def partition_to_path(self, data: Record, schema: Schema) -> str:
path = "/".join([field_str + "=" + value_str for field_str, value_str in zip(field_strs, value_strs, strict=True)])
return path

def check_compatible(self, schema: Schema, allow_missing_fields: bool = False) -> None:
Contributor: nit: it doesn't look like allow_missing_fields is used anywhere; should we still keep it?

Author: In Java, check_compatible is only used with allowMissingField=true (ctx) when reading metadata tables. Since here we only use this check (for now) on the write path, I think we could remove this field.

# if the underlying field is dropped, we cannot check whether they are compatible -- continue
schema_fields = schema._lazy_id_to_field
parents = schema._lazy_id_to_parent

for field in self.fields:
source_field = schema_fields.get(field.source_id)

if allow_missing_fields and source_field is None:
continue

if isinstance(field.transform, VoidTransform):
continue

if not source_field:
raise ValidationError(f"Cannot find source column for partition field: {field}")

source_type = source_field.field_type
if not source_type.is_primitive:
raise ValidationError(f"Cannot partition by non-primitive source field: {source_field}")
if not field.transform.can_transform(source_type):
raise ValidationError(
f"Invalid source field {source_field.name} with type {source_type} " + f"for transform: {field.transform}"
)

# The only valid parent types for a PartitionField are StructTypes. This must be checked recursively
parent_id = parents.get(field.source_id)
while parent_id:
parent_type = schema.find_type(parent_id)
if not parent_type.is_struct:
raise ValidationError(f"Invalid partition field parent: {parent_type}")
Contributor: maybe a bit more info in this error message too.

parent_id = parents.get(parent_id)


UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0)

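As context for the allow_missing_fields discussion above, here is a minimal sketch of how the two modes of check_compatible would behave against a schema whose partition source column has been dropped, assuming the method lands as written in this diff; the column names and the partition name are made up for illustration.

from pyiceberg.exceptions import ValidationError
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import IntegerType, NestedField

# The schema no longer contains the column the spec partitions by (source_id=3).
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))
spec = PartitionSpec(PartitionField(3, 1000, IdentityTransform(), "dropped_partition"))

# Write path (the default): a missing source column is an error.
try:
    spec.check_compatible(schema)
except ValidationError as err:
    print(err)  # Cannot find source column for partition field: ...

# Metadata-read path as discussed in the thread above: tolerate the dropped source column.
spec.check_compatible(schema, allow_missing_fields=True)  # completes without raising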
4 changes: 4 additions & 0 deletions pyiceberg/table/metadata.py
@@ -318,6 +318,10 @@ def current_snapshot(self) -> Snapshot | None:
def next_sequence_number(self) -> int:
return self.last_sequence_number + 1 if self.format_version > 1 else INITIAL_SEQUENCE_NUMBER

def sort_order(self) -> SortOrder:
"""Get the current sort order for this table, or UNSORTED_SORT_ORDER if there is no sort order."""
return self.sort_order_by_id(self.default_sort_order_id) or UNSORTED_SORT_ORDER

def sort_order_by_id(self, sort_order_id: int) -> SortOrder | None:
"""Get the sort order by sort_order_id."""
return next((sort_order for sort_order in self.sort_orders if sort_order.order_id == sort_order_id), None)
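A short sketch of the new sort_order() helper on TableMetadata, assuming pyiceberg's existing new_table_metadata factory to build a metadata object; the schema and location below are made up for illustration.

from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER
from pyiceberg.types import IntegerType, NestedField

metadata = new_table_metadata(
    schema=Schema(NestedField(1, "foo", IntegerType(), required=False)),
    partition_spec=UNPARTITIONED_PARTITION_SPEC,
    sort_order=UNSORTED_SORT_ORDER,
    location="s3://bucket/example",  # placeholder location
)

assert metadata.sort_order_by_id(999) is None        # unknown id: still returns None
assert metadata.sort_order() == UNSORTED_SORT_ORDER  # the new helper resolves the current order, falling back to UNSORTED_SORT_ORDER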
14 changes: 14 additions & 0 deletions pyiceberg/table/sorting.py
@@ -27,6 +27,7 @@
model_validator,
)

from pyiceberg.exceptions import ValidationError
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform, Transform, parse_transform
from pyiceberg.typedef import IcebergBaseModel
@@ -170,6 +171,19 @@ def __repr__(self) -> str:
fields = f"{', '.join(repr(column) for column in self.fields)}, " if self.fields else ""
return f"SortOrder({fields}order_id={self.order_id})"

def check_compatible(self, schema: Schema) -> None:
for field in self.fields:
source_field = schema._lazy_id_to_field.get(field.source_id)
if source_field is None:
raise ValidationError(f"Cannot find source column for sort field: {field}")
if not source_field.field_type.is_primitive:
raise ValidationError(f"Cannot sort by non-primitive source field: {source_field}")
if not field.transform.can_transform(source_field.field_type):
raise ValidationError(
f"Invalid source field {source_field.name} with type {source_field.field_type} "
+ f"for transform: {field.transform}"
)


UNSORTED_SORT_ORDER_ID = 0
UNSORTED_SORT_ORDER = SortOrder(order_id=UNSORTED_SORT_ORDER_ID)
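To complement the failure-path unit tests later in this diff, a minimal happy-path sketch of the new SortOrder.check_compatible: a sort field over a primitive column with a compatible transform passes silently.

from pyiceberg.schema import Schema
from pyiceberg.table.sorting import NullOrder, SortField, SortOrder
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import IntegerType, NestedField

schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))
sort_order = SortOrder(SortField(source_id=1, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST))

# Compatible: returns None. An unknown column, a non-primitive type, or an invalid
# transform would raise ValidationError instead (see tests/table/test_sorting.py below).
sort_order.check_compatible(schema)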
4 changes: 4 additions & 0 deletions pyiceberg/table/update/__init__.py
@@ -708,6 +708,10 @@ def update_table_metadata(
if base_metadata.last_updated_ms == new_metadata.last_updated_ms:
new_metadata = new_metadata.model_copy(update={"last_updated_ms": datetime_to_millis(datetime.now().astimezone())})

# Check correctness of the partition spec and sort order
new_metadata.spec().check_compatible(new_metadata.schema())
new_metadata.sort_order().check_compatible(new_metadata.schema())

if enforce_validation:
return TableMetadataUtil.parse_obj(new_metadata.model_dump())
else:
6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -73,6 +73,7 @@
from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import FileScanTask, Table
from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2, TableMetadataV3
from pyiceberg.table.sorting import NullOrder, SortField, SortOrder
from pyiceberg.transforms import DayTransform, IdentityTransform
from pyiceberg.typedef import Identifier
from pyiceberg.types import (
@@ -1893,6 +1894,11 @@ def test_partition_spec() -> PartitionSpec:
)


@pytest.fixture(scope="session")
def test_sort_order() -> SortOrder:
return SortOrder(SortField(source_id=1, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST))


@pytest.fixture(scope="session")
def generated_manifest_entry_file(
avro_schema_manifest_entry: dict[str, Any], test_schema: Schema, test_partition_spec: PartitionSpec
54 changes: 54 additions & 0 deletions tests/integration/test_catalog.py
@@ -34,6 +34,7 @@
NoSuchNamespaceError,
NoSuchTableError,
TableAlreadyExistsError,
ValidationError,
)
from pyiceberg.io import WAREHOUSE
from pyiceberg.partitioning import PartitionField, PartitionSpec
@@ -635,3 +636,56 @@ def test_rest_custom_namespace_separator(rest_catalog: RestCatalog, table_schema

loaded_table = rest_catalog.load_table(identifier=full_table_identifier_tuple)
assert loaded_table.name() == full_table_identifier_tuple


@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_incompatible_partitioned_schema_evolution(
test_catalog: Catalog, test_schema: Schema, test_partition_spec: PartitionSpec, database_name: str, table_name: str
) -> None:
if isinstance(test_catalog, HiveCatalog):
pytest.skip("HiveCatalog does not support schema evolution")

identifier = (database_name, table_name)
test_catalog.create_namespace(database_name)
table = test_catalog.create_table(identifier, test_schema, partition_spec=test_partition_spec)
assert test_catalog.table_exists(identifier)

with pytest.raises(ValidationError):
with table.update_schema() as update:
update.delete_column("VendorID")

# Assert column was not dropped
assert "VendorID" in table.schema().column_names

with table.transaction() as transaction:
with transaction.update_spec() as spec_update:
spec_update.remove_field("VendorID")

with transaction.update_schema() as schema_update:
schema_update.delete_column("VendorID")

assert table.spec() == PartitionSpec(PartitionField(2, 1001, DayTransform(), "tpep_pickup_day"), spec_id=1)
assert table.schema() == Schema(NestedField(2, "tpep_pickup_datetime", TimestampType(), False))


@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_incompatible_sorted_schema_evolution(
test_catalog: Catalog, test_schema: Schema, test_sort_order: SortOrder, database_name: str, table_name: str
) -> None:
if isinstance(test_catalog, HiveCatalog):
pytest.skip("HiveCatalog does not support schema evolution")

identifier = (database_name, table_name)
test_catalog.create_namespace(database_name)
table = test_catalog.create_table(identifier, test_schema, sort_order=test_sort_order)
assert test_catalog.table_exists(identifier)

with pytest.raises(ValidationError):
with table.update_schema() as update:
update.delete_column("VendorID")

assert table.schema() == Schema(
NestedField(1, "VendorID", IntegerType(), False), NestedField(2, "tpep_pickup_datetime", TimestampType(), False)
)
34 changes: 34 additions & 0 deletions tests/table/test_partitioning.py
@@ -21,6 +21,7 @@

import pytest

from pyiceberg.exceptions import ValidationError
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import (
@@ -264,3 +265,36 @@ def test_deserialize_partition_field_v3() -> None:

field = PartitionField.model_validate_json(json_partition_spec)
assert field == PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=19), name="str_truncate")


def test_incompatible_source_column_not_found() -> None:
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

spec = PartitionSpec(PartitionField(3, 1000, IdentityTransform(), "some_partition"))

with pytest.raises(ValidationError) as exc:
spec.check_compatible(schema)

assert "Cannot find source column for partition field: 1000: some_partition: identity(3)" in str(exc.value)


def test_incompatible_non_primitive_type() -> None:
schema = Schema(NestedField(1, "foo", StructType()), NestedField(2, "bar", IntegerType()))

spec = PartitionSpec(PartitionField(1, 1000, IdentityTransform(), "some_partition"))

with pytest.raises(ValidationError) as exc:
spec.check_compatible(schema)

assert "Cannot partition by non-primitive source field: 1: foo: optional struct<>" in str(exc.value)


def test_incompatible_transform_source_type() -> None:
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

spec = PartitionSpec(PartitionField(1, 1000, YearTransform(), "some_partition"))

with pytest.raises(ValidationError) as exc:
spec.check_compatible(schema)

assert "Invalid source field foo with type int for transform: year" in str(exc.value)
38 changes: 37 additions & 1 deletion tests/table/test_sorting.py
@@ -20,6 +20,8 @@

import pytest

from pyiceberg.exceptions import ValidationError
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import TableMetadataUtil
from pyiceberg.table.sorting import (
UNSORTED_SORT_ORDER,
@@ -28,7 +30,8 @@
SortField,
SortOrder,
)
from pyiceberg.transforms import BucketTransform, IdentityTransform, VoidTransform
from pyiceberg.transforms import BucketTransform, IdentityTransform, VoidTransform, YearTransform
from pyiceberg.types import IntegerType, NestedField, StructType


@pytest.fixture
@@ -133,3 +136,36 @@ def test_serialize_sort_field_v3() -> None:
expected = SortField(source_id=19, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST)
payload = '{"source-ids":[19],"transform":"identity","direction":"asc","null-order":"nulls-first"}'
assert SortField.model_validate_json(payload) == expected


def test_incompatible_source_column_not_found(sort_order: SortOrder) -> None:
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

with pytest.raises(ValidationError) as exc:
sort_order.check_compatible(schema)

assert "Cannot find source column for sort field: 19 ASC NULLS FIRST" in str(exc.value)


def test_incompatible_non_primitive_type() -> None:
schema = Schema(NestedField(1, "foo", StructType()), NestedField(2, "bar", IntegerType()))

sort_order = SortOrder(SortField(source_id=1, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST))

with pytest.raises(ValidationError) as exc:
sort_order.check_compatible(schema)

assert "Cannot sort by non-primitive source field: 1: foo: optional struct<>" in str(exc.value)


def test_incompatible_transform_source_type() -> None:
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

sort_order = SortOrder(
SortField(source_id=1, transform=YearTransform(), null_order=NullOrder.NULLS_FIRST),
)

with pytest.raises(ValidationError) as exc:
sort_order.check_compatible(schema)

assert "Invalid source field foo with type int for transform: year" in str(exc.value)