32 changes: 32 additions & 0 deletions pyiceberg/partitioning.py
@@ -32,6 +32,7 @@
model_validator,
)

from pyiceberg.exceptions import ValidationError
from pyiceberg.schema import Schema
from pyiceberg.transforms import (
BucketTransform,
@@ -249,6 +250,37 @@ def partition_to_path(self, data: Record, schema: Schema) -> str:
path = "/".join([field_str + "=" + value_str for field_str, value_str in zip(field_strs, value_strs, strict=True)])
return path

def check_compatible(self, schema: Schema, allow_missing_fields: bool = False) -> None:
Contributor: nit: doesn't look like allow_missing_fields is used anywhere; should we still keep it?

# if the underlying source field is dropped, we cannot check that they are compatible -- continue
schema_fields = schema._lazy_id_to_field
parents = schema._lazy_id_to_parent

for field in self.fields:
source_field = schema_fields.get(field.source_id)

if allow_missing_fields and source_field is None:
continue

if isinstance(field.transform, VoidTransform):
continue

if not source_field:
raise ValidationError(f"Cannot find source column for partition field: {field}")

source_type = source_field.field_type
if not source_type.is_primitive:
raise ValidationError(f"Cannot partition by non-primitive source field: {source_type}")
if not field.transform.can_transform(source_type):
raise ValidationError(f"Invalid source type {source_type} for transform: {field.transform}")

# The only valid parent types for a PartitionField are StructTypes. This must be checked recursively
parent_id = parents.get(field.source_id)
while parent_id:
parent_type = schema.find_type(parent_id)
if not parent_type.is_struct:
raise ValidationError(f"Invalid partition field parent: {parent_type}")
parent_id = parents.get(parent_id)


UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0)

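A minimal usage sketch, not part of this diff, assuming the check_compatible method added above: a partition source nested inside a struct passes the recursive parent check, because every ancestor it walks up through is a StructType. The field ids and names below are illustrative only.

from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import IntegerType, NestedField, StructType

schema = Schema(
    NestedField(1, "payload", StructType(NestedField(2, "vendor_id", IntegerType()))),
)
spec = PartitionSpec(PartitionField(2, 1000, IdentityTransform(), "vendor_id"))

# Field 2 is primitive, identity can transform an int, and its only parent
# (field 1) is a struct, so the parent walk terminates without raising.
spec.check_compatible(schema)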
11 changes: 11 additions & 0 deletions pyiceberg/table/sorting.py
@@ -27,6 +27,7 @@
model_validator,
)

from pyiceberg.exceptions import ValidationError
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform, Transform, parse_transform
from pyiceberg.typedef import IcebergBaseModel
@@ -169,6 +170,16 @@ def __repr__(self) -> str:
fields = f"{', '.join(repr(column) for column in self.fields)}, " if self.fields else ""
return f"SortOrder({fields}order_id={self.order_id})"

def check_compatible(self, schema: Schema) -> None:
for field in self.fields:
source_field = schema._lazy_id_to_field.get(field.source_id)
if source_field is None:
raise ValidationError(f"Cannot find source column for sort field: {field}")
if not source_field.field_type.is_primitive:
raise ValidationError(f"Cannot sort by non-primitive source field: {source_field}")
if not field.transform.can_transform(source_field.field_type):
raise ValidationError(f"Invalid source type {source_field.field_type} for transform: {field.transform}")
Contributor: nit: should we include source_field in the msg here? Or do you think there's enough context already?



UNSORTED_SORT_ORDER_ID = 0
UNSORTED_SORT_ORDER = SortOrder(order_id=UNSORTED_SORT_ORDER_ID)
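A self-contained sketch of one possible answer to the nit above, naming the resolved source field in the message as well as its type. The wording is an assumption for illustration, not something this diff proposes.

from pyiceberg.exceptions import ValidationError
from pyiceberg.table.sorting import NullOrder, SortField
from pyiceberg.transforms import YearTransform
from pyiceberg.types import IntegerType, NestedField

field = SortField(source_id=1, transform=YearTransform(), null_order=NullOrder.NULLS_FIRST)
source_field = NestedField(1, "foo", IntegerType())

try:
    # Same check as above, but the hypothetical message also names the source field.
    if not field.transform.can_transform(source_field.field_type):
        raise ValidationError(
            f"Invalid source type {source_field.field_type} of source field {source_field} "
            f"for transform: {field.transform}"
        )
except ValidationError as exc:
    print(exc)  # e.g. "Invalid source type int of source field 1: foo: optional int for transform: year"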
6 changes: 6 additions & 0 deletions pyiceberg/table/update/__init__.py
@@ -706,6 +706,12 @@ def update_table_metadata(
if base_metadata.last_updated_ms == new_metadata.last_updated_ms:
new_metadata = new_metadata.model_copy(update={"last_updated_ms": datetime_to_millis(datetime.now().astimezone())})

# Check that the partition spec and sort order are compatible with the schema
new_metadata.spec().check_compatible(new_metadata.schema())

if sort_order := new_metadata.sort_order_by_id(new_metadata.default_sort_order_id):
Contributor: nit: I think we can add a sort_order() function to the TableMetadata class; I notice it's only available on Table, but we can refactor later. Then we can do:

new_metadata.sort_order().check_compatible(new_metadata.schema())

sort_order.check_compatible(new_metadata.schema())

if enforce_validation:
return TableMetadataUtil.parse_obj(new_metadata.model_dump())
else:
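A sketch of the TableMetadata.sort_order() accessor the reviewer suggests above. It is hypothetical, not part of this diff; the fallback to UNSORTED_SORT_ORDER when no default order is set is an assumption, and it reuses the existing sort_order_by_id() lookup and default_sort_order_id field.

from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder


# Hypothetical method on pyiceberg.table.metadata.TableMetadata (illustrative only).
def sort_order(self) -> SortOrder:
    """Return the table's default sort order, or the unsorted order if none is set."""
    return self.sort_order_by_id(self.default_sort_order_id) or UNSORTED_SORT_ORDER

With that in place, the call above could read new_metadata.sort_order().check_compatible(new_metadata.schema()), since checking an unsorted order with no fields is a no-op.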
6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -73,6 +73,7 @@
from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import FileScanTask, Table
from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2, TableMetadataV3
from pyiceberg.table.sorting import NullOrder, SortField, SortOrder
from pyiceberg.transforms import DayTransform, IdentityTransform
from pyiceberg.typedef import Identifier
from pyiceberg.types import (
@@ -1887,6 +1888,11 @@ def test_partition_spec() -> PartitionSpec:
)


@pytest.fixture(scope="session")
def test_sort_order() -> SortOrder:
return SortOrder(SortField(source_id=1, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST))


@pytest.fixture(scope="session")
def generated_manifest_entry_file(
avro_schema_manifest_entry: dict[str, Any], test_schema: Schema, test_partition_spec: PartitionSpec
54 changes: 54 additions & 0 deletions tests/integration/test_catalog.py
@@ -34,6 +34,7 @@
NoSuchNamespaceError,
NoSuchTableError,
TableAlreadyExistsError,
ValidationError,
)
from pyiceberg.io import WAREHOUSE
from pyiceberg.partitioning import PartitionField, PartitionSpec
@@ -635,3 +636,56 @@ def test_rest_custom_namespace_separator(rest_catalog: RestCatalog, table_schema

loaded_table = rest_catalog.load_table(identifier=full_table_identifier_tuple)
assert loaded_table.name() == full_table_identifier_tuple


@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_incompatible_partitioned_schema_evolution(
test_catalog: Catalog, test_schema: Schema, test_partition_spec: PartitionSpec, database_name: str, table_name: str
) -> None:
if isinstance(test_catalog, HiveCatalog):
pytest.skip("HiveCatalog does not support schema evolution")

identifier = (database_name, table_name)
test_catalog.create_namespace(database_name)
table = test_catalog.create_table(identifier, test_schema, partition_spec=test_partition_spec)
assert test_catalog.table_exists(identifier)

with pytest.raises(ValidationError):
with table.update_schema() as update:
update.delete_column("VendorID")

# Assert column was not dropped
assert "VendorID" in table.schema().column_names

with table.transaction() as transaction:
with transaction.update_spec() as spec_update:
spec_update.remove_field("VendorID")

with transaction.update_schema() as schema_update:
schema_update.delete_column("VendorID")

assert table.spec() == PartitionSpec(PartitionField(2, 1001, DayTransform(), "tpep_pickup_day"), spec_id=1)
assert table.schema() == Schema(NestedField(2, "tpep_pickup_datetime", TimestampType(), False))


@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_incompatible_sorted_schema_evolution(
test_catalog: Catalog, test_schema: Schema, test_sort_order: SortOrder, database_name: str, table_name: str
) -> None:
if isinstance(test_catalog, HiveCatalog):
pytest.skip("HiveCatalog does not support schema evolution")

identifier = (database_name, table_name)
test_catalog.create_namespace(database_name)
table = test_catalog.create_table(identifier, test_schema, sort_order=test_sort_order)
assert test_catalog.table_exists(identifier)

with pytest.raises(ValidationError):
with table.update_schema() as update:
update.delete_column("VendorID")

assert table.schema() == Schema(
NestedField(1, "VendorID", IntegerType(), False), NestedField(2, "tpep_pickup_datetime", TimestampType(), False)
)
34 changes: 34 additions & 0 deletions tests/table/test_partitioning.py
@@ -21,6 +21,7 @@

import pytest

from pyiceberg.exceptions import ValidationError
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import (
@@ -259,3 +260,36 @@ def test_deserialize_partition_field_v3() -> None:

field = PartitionField.model_validate_json(json_partition_spec)
assert field == PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=19), name="str_truncate")


def test_incompatible_source_column_not_found() -> None:
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

spec = PartitionSpec(PartitionField(3, 1000, IdentityTransform(), "some_partition"))

with pytest.raises(ValidationError) as exc:
spec.check_compatible(schema)

assert "Cannot find source column for partition field: 1000: some_partition: identity(3)" in str(exc.value)


def test_incompatible_non_primitive_type() -> None:
schema = Schema(NestedField(1, "foo", StructType()), NestedField(2, "bar", IntegerType()))

spec = PartitionSpec(PartitionField(1, 1000, IdentityTransform(), "some_partition"))

with pytest.raises(ValidationError) as exc:
spec.check_compatible(schema)

assert "Cannot partition by non-primitive source field: struct<>" in str(exc.value)


def test_incompatible_transform_source_type() -> None:
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

spec = PartitionSpec(PartitionField(1, 1000, YearTransform(), "some_partition"))

with pytest.raises(ValidationError) as exc:
spec.check_compatible(schema)

assert "Invalid source type int for transform: year" in str(exc.value)
38 changes: 37 additions & 1 deletion tests/table/test_sorting.py
@@ -20,6 +20,8 @@

import pytest

from pyiceberg.exceptions import ValidationError
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import TableMetadataUtil
from pyiceberg.table.sorting import (
UNSORTED_SORT_ORDER,
@@ -28,7 +30,8 @@
SortField,
SortOrder,
)
from pyiceberg.transforms import BucketTransform, IdentityTransform, VoidTransform
from pyiceberg.transforms import BucketTransform, IdentityTransform, VoidTransform, YearTransform
from pyiceberg.types import IntegerType, NestedField, StructType


@pytest.fixture
@@ -114,3 +117,36 @@ def test_serialize_sort_field_v3() -> None:
expected = SortField(source_id=19, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST)
payload = '{"source-ids":[19],"transform":"identity","direction":"asc","null-order":"nulls-first"}'
assert SortField.model_validate_json(payload) == expected


def test_incompatible_source_column_not_found(sort_order: SortOrder) -> None:
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

with pytest.raises(ValidationError) as exc:
sort_order.check_compatible(schema)

assert "Cannot find source column for sort field: 19 ASC NULLS FIRST" in str(exc.value)


def test_incompatible_non_primitive_type() -> None:
schema = Schema(NestedField(1, "foo", StructType()), NestedField(2, "bar", IntegerType()))

sort_order = SortOrder(SortField(source_id=1, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST))

with pytest.raises(ValidationError) as exc:
sort_order.check_compatible(schema)

assert "Cannot sort by non-primitive source field: 1: foo: optional struct<>" in str(exc.value)


def test_incompatible_transform_source_type() -> None:
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

sort_order = SortOrder(
SortField(source_id=1, transform=YearTransform(), null_order=NullOrder.NULLS_FIRST),
)

with pytest.raises(ValidationError) as exc:
sort_order.check_compatible(schema)

assert "Invalid source type int for transform: year" in str(exc.value)