Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 140 additions & 60 deletions commitizen/providers/cargo_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,41 @@
import fnmatch
import glob
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast

from tomlkit import TOMLDocument, dumps, parse
from tomlkit.exceptions import NonExistentKey

from commitizen.providers.base_provider import TomlProvider

if TYPE_CHECKING:
from collections.abc import Iterable

from tomlkit.items import AoT


class CargoProvider(TomlProvider):
"""
Cargo version management
DictLike = dict[str, Any]

With support for `workspaces`
"""

class CargoProvider(TomlProvider):
"""Cargo version management for virtual workspace manifests + version.workspace=true members."""

filename = "Cargo.toml"
lock_filename = "Cargo.lock"

@property
def lock_file(self) -> Path:
return Path() / self.lock_filename
return Path(self.lock_filename)

def get(self, document: TOMLDocument) -> str:
out = _try_get_workspace(document)["package"]["version"]
if TYPE_CHECKING:
assert isinstance(out, str)
return out
t = _root_version_table(document)
v = t.get("version")
if not isinstance(v, str):
raise TypeError("expected root version to be a string")
return v

def set(self, document: TOMLDocument, version: str) -> None:
_try_get_workspace(document)["package"]["version"] = version
_root_version_table(document)["version"] = version

def set_version(self, version: str) -> None:
super().set_version(version)
Expand All @@ -50,56 +52,134 @@ def set_lock_version(self, version: str) -> None:
if TYPE_CHECKING:
assert isinstance(packages, AoT)

try:
cargo_package_name = cargo_toml_content["package"]["name"] # type: ignore[index]
if TYPE_CHECKING:
assert isinstance(cargo_package_name, str)
for i, package in enumerate(packages):
if package["name"] == cargo_package_name:
cargo_lock_content["package"][i]["version"] = version # type: ignore[index]
break
except NonExistentKey:
workspace = cargo_toml_content.get("workspace", {})
if TYPE_CHECKING:
assert isinstance(workspace, dict)
workspace_members = workspace.get("members", [])
excluded_workspace_members = workspace.get("exclude", [])
members_inheriting: list[str] = []

for member in workspace_members:
for path in glob.glob(member, recursive=True):
if any(
fnmatch.fnmatch(path, pattern)
for pattern in excluded_workspace_members
):
continue

cargo_file = Path(path) / "Cargo.toml"
package_content = parse(cargo_file.read_text()).get("package", {})
if TYPE_CHECKING:
assert isinstance(package_content, dict)
try:
version_workspace = package_content["version"]["workspace"]
if version_workspace is True:
package_name = package_content["name"]
if TYPE_CHECKING:
assert isinstance(package_name, str)
members_inheriting.append(package_name)
except NonExistentKey:
pass

for i, package in enumerate(packages):
if package["name"] in members_inheriting:
cargo_lock_content["package"][i]["version"] = version # type: ignore[index]

root_pkg = _table_get(cargo_toml_content, "package")
if root_pkg is not None:
name = root_pkg.get("name")
if isinstance(name, str):
_lock_set_versions(packages, {name}, version)
self.lock_file.write_text(dumps(cargo_lock_content))
return

ws = _table_get(cargo_toml_content, "workspace") or {}
member_globs = cast("list[str]", ws.get("members", []) or [])
exclude_globs = cast("list[str]", ws.get("exclude", []) or [])
inheriting = _workspace_inheriting_member_names(member_globs, exclude_globs)
_lock_set_versions(packages, inheriting, version)
self.lock_file.write_text(dumps(cargo_lock_content))


def _try_get_workspace(document: TOMLDocument) -> dict:
def _table_get(doc: TOMLDocument, key: str) -> DictLike | None:
"""Get a TOML table by key as a dict-like object.

Returns:
The value at `doc[key]` cast to a dict-like table (supports `.get`) if it
exists and is table/container-like; otherwise returns None.

Rationale:
tomlkit returns loosely-typed Container/Table objects; using a small
helper keeps call sites readable and makes type-checkers happier.
"""
try:
workspace = document["workspace"]
if TYPE_CHECKING:
assert isinstance(workspace, dict)
return workspace
value = doc[key]
except NonExistentKey:
return document
return None
return cast("DictLike", value) if hasattr(value, "get") else None


def _root_version_table(doc: TOMLDocument) -> DictLike:
"""Return the table that owns the "root" version field.

This provider supports two layouts:

1) Workspace virtual manifests:
[workspace.package]
version = "x.y.z"

2) Regular crate(non-workspace root manifest):
[package]
version = "x.y.z"

The selected table is where `get()` reads from and `set()` writes to.
"""
workspace_table = _table_get(doc, "workspace")
if workspace_table is not None:
workspace_package_table = workspace_table.get("package")
if hasattr(workspace_package_table, "get"):
return cast("DictLike", workspace_package_table)

package_table = _table_get(doc, "package")
if package_table is None:
raise NonExistentKey("expected either [workspace.package] or [package]")
return package_table


def _is_workspace_inherited_version(v: Any) -> bool:
return hasattr(v, "get") and cast("DictLike", v).get("workspace") is True


def _iter_member_dirs(
member_globs: Iterable[str], exclude_globs: Iterable[str]
) -> Iterable[Path]:
"""Yield workspace member directories matched by `member_globs`, excluding `exclude_globs`.

Cargo workspaces define members/exclude as glob patterns (e.g. "crates/*").
This helper expands those patterns and yields the corresponding directories
as `Path` objects, skipping any matches that satisfy an exclude glob.

Kept as a helper to make call sites read as domain logic ("iterate member dirs")
rather than glob/filter plumbing.
"""
for member_glob in member_globs:
for match in glob.glob(member_glob, recursive=True):
if any(fnmatch.fnmatch(match, ex) for ex in exclude_globs):
continue
yield Path(match)


def _workspace_inheriting_member_names(
members: Iterable[str], excludes: Iterable[str]
) -> set[str]:
"""Return workspace member crate names that inherit the workspace version.

A member is considered "inheriting" when its Cargo.toml has:
[package]
version.workspace = true

This scans `members` globs (respecting `excludes`) and returns the set of
`[package].name` values for matching crates. Missing/invalid Cargo.toml files
are ignored.
"""
inheriting_member_names: set[str] = set()
for d in _iter_member_dirs(members, excludes):
cargo_file = d / "Cargo.toml"
if not cargo_file.exists():
continue
pkg = parse(cargo_file.read_text()).get("package")
if not hasattr(pkg, "get"):
continue
pkgd = cast("DictLike", pkg)
if _is_workspace_inherited_version(pkgd.get("version")):
name = pkgd.get("name")
if isinstance(name, str):
inheriting_member_names.add(name)
return inheriting_member_names


def _lock_set_versions(packages: Any, package_names: set[str], version: str) -> None:
"""Update Cargo.lock package entries in-place.

Args:
packages: `Cargo.lock` parsed TOML "package" array (AoT-like). Mutated in-place.
package_names: Set of package names whose `version` field should be updated.
version: New version string to write.

Notes:
We use `enumerate` + index assignment because tomlkit AoT entries may be
Container-like and direct mutation patterns vary; indexed assignment is
reliable for updating the underlying document.
"""
if not package_names:
return
for i, pkg_entry in enumerate(packages):
if getattr(pkg_entry, "get", None) and pkg_entry.get("name") in package_names:
packages[i]["version"] = version
143 changes: 143 additions & 0 deletions tests/providers/test_cargo_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING

import pytest
from tomlkit.exceptions import NonExistentKey

from commitizen.providers import get_provider
from commitizen.providers.cargo_provider import CargoProvider
Expand Down Expand Up @@ -451,3 +452,145 @@ def test_cargo_provider_workspace_member_without_workspace_key(
assert file.read_text() == dedent(expected_workspace_toml)
# The lock file should remain unchanged since the member doesn't inherit workspace version
assert lock_file.read_text() == dedent(expected_lock_content)


def test_cargo_provider_get_raises_when_version_is_not_string(
config: BaseConfig,
chdir: Path,
) -> None:
"""CargoProvider.get should raise when root version is not a string."""
(chdir / CargoProvider.filename).write_text(
dedent(
"""\
[package]
name = "whatever"
version = 1
"""
)
)
config.settings["version_provider"] = "cargo"

provider = get_provider(config)
assert isinstance(provider, CargoProvider)

with pytest.raises(TypeError, match=r"expected root version to be a string"):
provider.get_version()


def test_cargo_provider_get_raises_when_no_package_tables(
config: BaseConfig,
chdir: Path,
) -> None:
"""_root_version_table should raise when neither [workspace.package] nor [package] exists."""
(chdir / CargoProvider.filename).write_text(
dedent(
"""\
[workspace]
members = []
"""
)
)
config.settings["version_provider"] = "cargo"

provider = get_provider(config)
assert isinstance(provider, CargoProvider)

with pytest.raises(
NonExistentKey, match=r"expected either \[workspace\.package\] or \[package\]"
):
provider.get_version()


def test_workspace_member_dir_without_cargo_toml_is_ignored(
config: BaseConfig,
chdir: Path,
) -> None:
"""Cover: if not cargo_file.exists(): continue"""
workspace_toml = """\
[workspace]
members = ["missing_manifest"]

[workspace.package]
version = "0.1.0"
"""
lock_content = """\
[[package]]
name = "missing_manifest"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "123abc"
"""
expected_workspace_toml = """\
[workspace]
members = ["missing_manifest"]

[workspace.package]
version = "42.1"
"""

(chdir / CargoProvider.filename).write_text(dedent(workspace_toml))
os.mkdir(chdir / "missing_manifest") # directory exists, but Cargo.toml does NOT

(chdir / CargoProvider.lock_filename).write_text(dedent(lock_content))

config.settings["version_provider"] = "cargo"
provider = get_provider(config)
assert isinstance(provider, CargoProvider)

provider.set_version("42.1")
assert (chdir / CargoProvider.filename).read_text() == dedent(
expected_workspace_toml
)
# lock should remain unchanged since member cannot be inspected => not inheriting
assert (chdir / CargoProvider.lock_filename).read_text() == dedent(lock_content)


def test_workspace_member_with_non_table_package_is_ignored(
config: BaseConfig,
chdir: Path,
) -> None:
"""Cover: if not hasattr(pkg, "get"): continue"""
workspace_toml = """\
[workspace]
members = ["bad_package"]

[workspace.package]
version = "0.1.0"
"""
member_toml = """\
package = "oops"
"""
lock_content = """\
[[package]]
name = "bad_package"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "123abc"
"""
expected_workspace_toml = """\
[workspace]
members = ["bad_package"]

[workspace.package]
version = "42.1"
"""

(chdir / CargoProvider.filename).write_text(dedent(workspace_toml))

os.mkdir(chdir / "bad_package")
(chdir / "bad_package" / "Cargo.toml").write_text(
dedent(member_toml)
) # package is str, not table

(chdir / CargoProvider.lock_filename).write_text(dedent(lock_content))

config.settings["version_provider"] = "cargo"
provider = get_provider(config)
assert isinstance(provider, CargoProvider)

provider.set_version("42.1")
assert (chdir / CargoProvider.filename).read_text() == dedent(
expected_workspace_toml
)
# lock should remain unchanged since package is not a table => not inheriting
assert (chdir / CargoProvider.lock_filename).read_text() == dedent(lock_content)