6 changes: 6 additions & 0 deletions docs/source/en/main_classes/pipelines.md
@@ -363,6 +363,12 @@ Pipelines available for computer vision tasks include the following.
- __call__
- all

### KeypointMatchingPipeline

[[autodoc]] KeypointMatchingPipeline
- __call__
- all

### ObjectDetectionPipeline

[[autodoc]] ObjectDetectionPipeline
17 changes: 16 additions & 1 deletion docs/source/en/model_doc/efficientloftr.md
@@ -28,9 +28,24 @@ rendered properly in your Markdown viewer.
>
> Click on the EfficientLoFTR models in the right sidebar for more examples of how to apply EfficientLoFTR to different computer vision tasks.

The example below demonstrates how to match keypoints between two images with the [`AutoModel`] class.
The example below demonstrates how to match keypoints between two images with [`Pipeline`] or the [`AutoModel`] class.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
from transformers import pipeline

keypoint_matcher = pipeline(task="keypoint-matching", model="zju-community/efficientloftr")

url_0 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
url_1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"

results = keypoint_matcher([url_0, url_1], threshold=0.9)
print(results[0])
# {'keypoint_image_0': {'x': ..., 'y': ...}, 'keypoint_image_1': {'x': ..., 'y': ...}, 'score': ...}
```
<hfoption id="AutoModel">
<hfoption id="AutoModel">

```py
16 changes: 15 additions & 1 deletion docs/source/en/model_doc/lightglue.md
@@ -30,9 +30,23 @@ You can find all the original LightGlue checkpoints under the [ETH-CVG](https://
>
> Click on the LightGlue models in the right sidebar for more examples of how to apply LightGlue to different computer vision tasks.

The example below demonstrates how to match keypoints between two images with the [`AutoModel`] class.
The example below demonstrates how to match keypoints between two images with [`Pipeline`] or the [`AutoModel`] class.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
from transformers import pipeline

keypoint_matcher = pipeline(task="keypoint-matching", model="ETH-CVG/lightglue_superpoint")

url_0 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
url_1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"

results = keypoint_matcher([url_0, url_1], threshold=0.9)
print(results[0])
# {'keypoint_image_0': {'x': ..., 'y': ...}, 'keypoint_image_1': {'x': ..., 'y': ...}, 'score': ...}
```
<hfoption id="AutoModel">

```py
18 changes: 17 additions & 1 deletion docs/source/en/model_doc/superglue.md
@@ -30,9 +30,25 @@ You can find all the original SuperGlue checkpoints under the [Magic Leap Commun
>
> Click on the SuperGlue models in the right sidebar for more examples of how to apply SuperGlue to different computer vision tasks.

The example below demonstrates how to match keypoints between two images with the [`AutoModel`] class.
The example below demonstrates how to match keypoints between two images with [`Pipeline`] or the [`AutoModel`] class.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
from transformers import pipeline

keypoint_matcher = pipeline(task="keypoint-matching", model="magic-leap-community/superglue_outdoor")

url_0 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
url_1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"

results = keypoint_matcher([url_0, url_1], threshold=0.9)
print(results[0])
# {'keypoint_image_0': {'x': ..., 'y': ...}, 'keypoint_image_1': {'x': ..., 'y': ...}, 'score': ...}
```

</hfoption>
<hfoption id="AutoModel">

```py
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -163,6 +163,7 @@
"ImageToImagePipeline",
"ImageToTextPipeline",
"JsonPipelineDataFormat",
"KeypointMatchingPipeline",
"MaskGenerationPipeline",
"NerPipeline",
"ObjectDetectionPipeline",
@@ -826,6 +827,7 @@
    from .pipelines import ImageToImagePipeline as ImageToImagePipeline
    from .pipelines import ImageToTextPipeline as ImageToTextPipeline
    from .pipelines import JsonPipelineDataFormat as JsonPipelineDataFormat
    from .pipelines import KeypointMatchingPipeline as KeypointMatchingPipeline
    from .pipelines import MaskGenerationPipeline as MaskGenerationPipeline
    from .pipelines import NerPipeline as NerPipeline
    from .pipelines import ObjectDetectionPipeline as ObjectDetectionPipeline
13 changes: 13 additions & 0 deletions src/transformers/pipelines/__init__.py
@@ -70,6 +70,7 @@
from .image_text_to_text import ImageTextToTextPipeline
from .image_to_image import ImageToImagePipeline
from .image_to_text import ImageToTextPipeline
from .keypoint_matching import KeypointMatchingPipeline
from .mask_generation import MaskGenerationPipeline
from .object_detection import ObjectDetectionPipeline
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
@@ -121,6 +122,7 @@
        AutoModelForImageClassification,
        AutoModelForImageSegmentation,
        AutoModelForImageTextToText,
        AutoModelForKeypointMatching,
        AutoModelForMaskedLM,
        AutoModelForMaskGeneration,
        AutoModelForObjectDetection,
@@ -439,6 +441,13 @@
"default": {"model": {"pt": ("caidas/swin2SR-classical-sr-x2-64", "cee1c92")}},
"type": "image",
},
"keypoint-matching": {
"impl": KeypointMatchingPipeline,
"tf": (),
"pt": (AutoModelForKeypointMatching,) if is_torch_available() else (),
"default": {"model": {"pt": ("magic-leap-community/superglue_outdoor", "f4041f8")}},
"type": "image",
},
}

PIPELINE_REGISTRY = PipelineRegistry(supported_tasks=SUPPORTED_TASKS, task_aliases=TASK_ALIASES)
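# Illustration only, not part of this diff: with the "keypoint-matching" entry registered above, building
# the pipeline without an explicit model resolves to the default checkpoint
# "magic-leap-community/superglue_outdoor" pinned at revision "f4041f8":
#   keypoint_matcher = pipeline("keypoint-matching")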
@@ -499,6 +508,7 @@ def check_task(task: str) -> tuple[str, dict, Any]:
- `"image-segmentation"`
- `"image-to-text"`
- `"image-to-image"`
- `"keypoint-matching"`
- `"object-detection"`
- `"question-answering"`
- `"summarization"`
@@ -581,6 +591,8 @@ def pipeline(task: Literal["image-to-image"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageToImagePipeline: ...
@overload
def pipeline(task: Literal["image-to-text"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageToTextPipeline: ...
@overload
def pipeline(task: Literal["keypoint-matching"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> KeypointMatchingPipeline: ...
@overload
def pipeline(task: Literal["mask-generation"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> MaskGenerationPipeline: ...
@overload
def pipeline(task: Literal["object-detection"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ObjectDetectionPipeline: ...
@@ -675,6 +687,7 @@ def pipeline(
- `"image-text-to-text"`: will return a [`ImageTextToTextPipeline`].
- `"image-to-image"`: will return a [`ImageToImagePipeline`].
- `"image-to-text"`: will return a [`ImageToTextPipeline`].
- `"keypoint-matching"`: will return a [`KeypointMatchingPipeline`].
- `"mask-generation"`: will return a [`MaskGenerationPipeline`].
- `"object-detection"`: will return a [`ObjectDetectionPipeline`].
- `"question-answering"`: will return a [`QuestionAnsweringPipeline`].
169 changes: 169 additions & 0 deletions src/transformers/pipelines/keypoint_matching.py
@@ -0,0 +1,169 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Sequence, TypedDict, Union

from typing_extensions import TypeAlias, overload

from ..image_utils import is_pil_image
from ..utils import is_vision_available, requires_backends
from .base import Pipeline


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image


ImagePair: TypeAlias = Sequence[Union["Image.Image", str]]

Keypoint = TypedDict("Keypoint", {"x": float, "y": float})
Match = TypedDict("Match", {"keypoint_image_0": Keypoint, "keypoint_image_1": Keypoint, "score": float})


def validate_image_pairs(images: Any) -> Sequence[ImagePair]:
    error_message = (
        "Input images must be one of the following:\n"
        " - A pair of images.\n"
        " - A list of pairs of images."
    )

    def _is_valid_image(image):
        """Return True if `image` is a PIL Image or a string."""
        return is_pil_image(image) or isinstance(image, str)

    if isinstance(images, Sequence):
        if len(images) == 2 and all(_is_valid_image(image) for image in images):
            return [images]
        if all(
            isinstance(image_pair, Sequence)
            and len(image_pair) == 2
            and all(_is_valid_image(image) for image in image_pair)
            for image_pair in images
        ):
            return images
    raise ValueError(error_message)
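# Illustration only, not part of the module: validate_image_pairs wraps a single pair into a one-element
# batch and passes a list of pairs through unchanged (the file names below are hypothetical):
#   validate_image_pairs(["left.jpg", "right.jpg"])                 -> [["left.jpg", "right.jpg"]]
#   validate_image_pairs([["a.jpg", "b.jpg"], ["c.jpg", "d.jpg"]])  -> [["a.jpg", "b.jpg"], ["c.jpg", "d.jpg"]]
#   validate_image_pairs("left.jpg")                                -> raises ValueError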


class KeypointMatchingPipeline(Pipeline):
    """
    Keypoint matching pipeline using any `AutoModelForKeypointMatching`. This pipeline matches keypoints between two images.
    """

    _load_processor = False
    _load_image_processor = True
    _load_feature_extractor = False
    _load_tokenizer = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        requires_backends(self, "vision")
        if self.framework != "pt":
            raise ValueError("Keypoint matching pipeline only supports PyTorch (framework='pt').")

    def _sanitize_parameters(self, threshold=None, timeout=None):
        preprocess_params = {}
        if timeout is not None:
            preprocess_params["timeout"] = timeout
        postprocess_params = {}
        if threshold is not None:
            postprocess_params["threshold"] = threshold
        return preprocess_params, {}, postprocess_params
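    # Illustration only, not part of the module: _sanitize_parameters routes call-time keyword arguments to
    # the pipeline stages, e.g. matcher(pair, threshold=0.9, timeout=5.0) produces
    #   preprocess_params={"timeout": 5.0}, forward_params={}, postprocess_params={"threshold": 0.9}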

    @overload
    def __call__(self, inputs: ImagePair, threshold: float = 0.0, **kwargs: Any) -> list[Match]: ...

    @overload
    def __call__(self, inputs: list[ImagePair], threshold: float = 0.0, **kwargs: Any) -> list[list[Match]]: ...

    def __call__(
        self,
        inputs: Union[list[ImagePair], ImagePair],
        threshold: float = 0.0,
        **kwargs: Any,
    ) -> Union[list[Match], list[list[Match]]]:
"""
Find matches between keypoints in two images.

Args:
inputs (`str`, `list[str]`, `PIL.Image` or `list[PIL.Image]`):
The pipeline handles three types of images:

- A string containing a http link pointing to an image
- A string containing a local path to an image
- An image loaded in PIL directly

The pipeline accepts either a single pair of images or a batch of image pairs, which must then be passed as a string.
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
images.

threshold (`float`, *optional*, defaults to 0.0):
The threshold to use for keypoint matching. Keypoints matched with a lower matching score will be filtered out.
A value of 0 means that all matched keypoints will be returned.

kwargs:
`timeout (`float`, *optional*, defaults to None)`
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
the call may block forever.

Return:
Union[list[Match], list[list[Match]]]:
A list of matches or a list if a single image pair is provided, or of lists of matches if a batch
of image pairs is provided. Each match is a dictionary containing the following keys:

- **keypoint_image_0** (`Keypoint`): The keypoint in the first image (x, y coordinates).
- **keypoint_image_1** (`Keypoint`): The keypoint in the second image (x, y coordinates).
- **score** (`float`): The matching score between the two keypoints.
"""
        if inputs is None:
            raise ValueError("Cannot call the keypoint-matching pipeline without an inputs argument!")
        formatted_inputs = validate_image_pairs(inputs)
        outputs = super().__call__(formatted_inputs, threshold=threshold, **kwargs)
        if len(formatted_inputs) == 1:
            return outputs[0]
        return outputs

    def preprocess(self, images, timeout=None):
        images = [load_image(image, timeout=timeout) for image in images]
        model_inputs = self.image_processor(images=images, return_tensors=self.framework)
        model_inputs = model_inputs.to(self.torch_dtype)
        target_sizes = [image.size for image in images]
        preprocess_outputs = {"model_inputs": model_inputs, "target_sizes": target_sizes}
        return preprocess_outputs

    def _forward(self, preprocess_outputs):
        model_inputs = preprocess_outputs["model_inputs"]
        model_outputs = self.model(**model_inputs)
        forward_outputs = {"model_outputs": model_outputs, "target_sizes": [preprocess_outputs["target_sizes"]]}
        return forward_outputs

    def postprocess(self, forward_outputs, threshold=0.0) -> list[Match]:
        model_outputs = forward_outputs["model_outputs"]
        target_sizes = forward_outputs["target_sizes"]
        postprocess_outputs = self.image_processor.post_process_keypoint_matching(
            model_outputs, target_sizes=target_sizes, threshold=threshold
        )
        postprocess_outputs = postprocess_outputs[0]
        pair_result = []
        for kp_0, kp_1, score in zip(
            postprocess_outputs["keypoints0"],
            postprocess_outputs["keypoints1"],
            postprocess_outputs["matching_scores"],
        ):
            kp_0 = Keypoint(x=kp_0[0].item(), y=kp_0[1].item())
            kp_1 = Keypoint(x=kp_1[0].item(), y=kp_1[1].item())
            pair_result.append(Match(keypoint_image_0=kp_0, keypoint_image_1=kp_1, score=score.item()))
        pair_result = sorted(pair_result, key=lambda x: x["score"], reverse=True)
        return pair_result