
Commit 78f32c3

sbucaille and qubvel authored
[pipeline] Add Keypoint Matching pipeline (#39970)
* feat: keypoint-matcher pipeline
* docs: added keypoint-matcher pipeline in docs
* fix: added missing statements for repo consistency
* docs: updated SuperGlue, LightGlue and EfficientLoFTR docs
* Apply suggestions from code review
  Co-authored-by: Pavel Iakubovskii <[email protected]>
* test: fixed run_pipeline_test
* update pipeline typing and docs
* update tests
* update docs snippets
* Fix import error
* fix: pipeline init
* pt framework

---------

Co-authored-by: Pavel Iakubovskii <[email protected]>
1 parent 6451294 commit 78f32c3

File tree

9 files changed, +432 -3 lines changed


docs/source/en/main_classes/pipelines.md

Lines changed: 6 additions & 0 deletions
@@ -363,6 +363,12 @@ Pipelines available for computer vision tasks include the following.
     - __call__
     - all
 
+### KeypointMatchingPipeline
+
+[[autodoc]] KeypointMatchingPipeline
+    - __call__
+    - all
+
 ### ObjectDetectionPipeline
 
 [[autodoc]] ObjectDetectionPipeline

docs/source/en/model_doc/efficientloftr.md

Lines changed: 16 additions & 1 deletion
@@ -28,9 +28,24 @@ rendered properly in your Markdown viewer.
 >
 > Click on the EfficientLoFTR models in the right sidebar for more examples of how to apply EfficientLoFTR to different computer vision tasks.
 
-The example below demonstrates how to match keypoints between two images with the [`AutoModel`] class.
+The example below demonstrates how to match keypoints between two images with [`Pipeline`] or the [`AutoModel`] class.
 
 <hfoptions id="usage">
+<hfoption id="Pipeline">
+
+```py
+from transformers import pipeline
+
+keypoint_matcher = pipeline(task="keypoint-matching", model="zju-community/efficientloftr")
+
+url_0 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
+url_1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
+
+results = keypoint_matcher([url_0, url_1], threshold=0.9)
+print(results[0])
+# {'keypoint_image_0': {'x': ..., 'y': ...}, 'keypoint_image_1': {'x': ..., 'y': ...}, 'score': ...}
+```
+<hfoption id="AutoModel">
 <hfoption id="AutoModel">
 
 ```py

docs/source/en/model_doc/lightglue.md

Lines changed: 15 additions & 1 deletion
@@ -30,9 +30,23 @@ You can find all the original LightGlue checkpoints under the [ETH-CVG](https://
 >
 > Click on the LightGlue models in the right sidebar for more examples of how to apply LightGlue to different computer vision tasks.
 
-The example below demonstrates how to match keypoints between two images with the [`AutoModel`] class.
+The example below demonstrates how to match keypoints between two images with [`Pipeline`] or the [`AutoModel`] class.
 
 <hfoptions id="usage">
+<hfoption id="Pipeline">
+
+```py
+from transformers import pipeline
+
+keypoint_matcher = pipeline(task="keypoint-matching", model="ETH-CVG/lightglue_superpoint")
+
+url_0 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
+url_1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
+
+results = keypoint_matcher([url_0, url_1], threshold=0.9)
+print(results[0])
+# {'keypoint_image_0': {'x': ..., 'y': ...}, 'keypoint_image_1': {'x': ..., 'y': ...}, 'score': ...}
+```
 <hfoption id="AutoModel">
 
 ```py

docs/source/en/model_doc/superglue.md

Lines changed: 17 additions & 1 deletion
@@ -30,9 +30,25 @@ You can find all the original SuperGlue checkpoints under the [Magic Leap Commun
 >
 > Click on the SuperGlue models in the right sidebar for more examples of how to apply SuperGlue to different computer vision tasks.
 
-The example below demonstrates how to match keypoints between two images with the [`AutoModel`] class.
+The example below demonstrates how to match keypoints between two images with [`Pipeline`] or the [`AutoModel`] class.
 
 <hfoptions id="usage">
+<hfoption id="Pipeline">
+
+```py
+from transformers import pipeline
+
+keypoint_matcher = pipeline(task="keypoint-matching", model="magic-leap-community/superglue_outdoor")
+
+url_0 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
+url_1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
+
+results = keypoint_matcher([url_0, url_1], threshold=0.9)
+print(results[0])
+# {'keypoint_image_0': {'x': ..., 'y': ...}, 'keypoint_image_1': {'x': ..., 'y': ...}, 'score': ...}
+```
+
+</hfoption>
 <hfoption id="AutoModel">
 
 ```py

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -163,6 +163,7 @@
     "ImageToImagePipeline",
     "ImageToTextPipeline",
     "JsonPipelineDataFormat",
+    "KeypointMatchingPipeline",
     "MaskGenerationPipeline",
     "NerPipeline",
     "ObjectDetectionPipeline",
@@ -826,6 +827,7 @@
     from .pipelines import ImageToImagePipeline as ImageToImagePipeline
     from .pipelines import ImageToTextPipeline as ImageToTextPipeline
     from .pipelines import JsonPipelineDataFormat as JsonPipelineDataFormat
+    from .pipelines import KeypointMatchingPipeline as KeypointMatchingPipeline
     from .pipelines import MaskGenerationPipeline as MaskGenerationPipeline
     from .pipelines import NerPipeline as NerPipeline
     from .pipelines import ObjectDetectionPipeline as ObjectDetectionPipeline

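These two additions expose the new pipeline class from the top-level package. A trivial sketch of what that enables, assuming nothing beyond the exports above:

```py
# After this commit, the class is importable from the package root.
from transformers import KeypointMatchingPipeline

print(KeypointMatchingPipeline)
# <class 'transformers.pipelines.keypoint_matching.KeypointMatchingPipeline'>
```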
src/transformers/pipelines/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -70,6 +70,7 @@
 from .image_text_to_text import ImageTextToTextPipeline
 from .image_to_image import ImageToImagePipeline
 from .image_to_text import ImageToTextPipeline
+from .keypoint_matching import KeypointMatchingPipeline
 from .mask_generation import MaskGenerationPipeline
 from .object_detection import ObjectDetectionPipeline
 from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
@@ -121,6 +122,7 @@
     AutoModelForImageClassification,
     AutoModelForImageSegmentation,
     AutoModelForImageTextToText,
+    AutoModelForKeypointMatching,
     AutoModelForMaskedLM,
     AutoModelForMaskGeneration,
     AutoModelForObjectDetection,
@@ -439,6 +441,13 @@
         "default": {"model": {"pt": ("caidas/swin2SR-classical-sr-x2-64", "cee1c92")}},
         "type": "image",
     },
+    "keypoint-matching": {
+        "impl": KeypointMatchingPipeline,
+        "tf": (),
+        "pt": (AutoModelForKeypointMatching,) if is_torch_available() else (),
+        "default": {"model": {"pt": ("magic-leap-community/superglue_outdoor", "f4041f8")}},
+        "type": "image",
+    },
 }
 
 PIPELINE_REGISTRY = PipelineRegistry(supported_tasks=SUPPORTED_TASKS, task_aliases=TASK_ALIASES)
@@ -499,6 +508,7 @@ def check_task(task: str) -> tuple[str, dict, Any]:
             - `"image-segmentation"`
             - `"image-to-text"`
             - `"image-to-image"`
+            - `"keypoint-matching"`
             - `"object-detection"`
             - `"question-answering"`
             - `"summarization"`
@@ -581,6 +591,8 @@ def pipeline(task: Literal["image-to-image"], model: Optional[Union[str, "PreTra
 @overload
 def pipeline(task: Literal["image-to-text"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageToTextPipeline: ...
 @overload
+def pipeline(task: Literal["keypoint-matching"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> KeypointMatchingPipeline: ...
+@overload
 def pipeline(task: Literal["mask-generation"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> MaskGenerationPipeline: ...
 @overload
 def pipeline(task: Literal["object-detection"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ObjectDetectionPipeline: ...
@@ -675,6 +687,7 @@ def pipeline(
             - `"image-text-to-text"`: will return a [`ImageTextToTextPipeline`].
             - `"image-to-image"`: will return a [`ImageToImagePipeline`].
             - `"image-to-text"`: will return a [`ImageToTextPipeline`].
+            - `"keypoint-matching"`: will return a [`KeypointMatchingPipeline`].
             - `"mask-generation"`: will return a [`MaskGenerationPipeline`].
             - `"object-detection"`: will return a [`ObjectDetectionPipeline`].
             - `"question-answering"`: will return a [`QuestionAnsweringPipeline`].

src/transformers/pipelines/keypoint_matching.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Sequence, TypedDict, Union

from typing_extensions import TypeAlias, overload

from ..image_utils import is_pil_image
from ..utils import is_vision_available, requires_backends
from .base import Pipeline


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image


ImagePair: TypeAlias = Sequence[Union["Image.Image", str]]

Keypoint = TypedDict("Keypoint", {"x": float, "y": float})
Match = TypedDict("Match", {"keypoint_image_0": Keypoint, "keypoint_image_1": Keypoint, "score": float})


def validate_image_pairs(images: Any) -> Sequence[Sequence[ImagePair]]:
    error_message = (
        "Input images must be one of the following:",
        " - A pair of images.",
        " - A list of pairs of images.",
    )

    def _is_valid_image(image):
        """image is a PIL Image or a string."""
        return is_pil_image(image) or isinstance(image, str)

    if isinstance(images, Sequence):
        if len(images) == 2 and all(_is_valid_image(image) for image in images):
            return [images]
        if all(
            isinstance(image_pair, Sequence)
            and len(image_pair) == 2
            and all(_is_valid_image(image) for image in image_pair)
            for image_pair in images
        ):
            return images
    raise ValueError(error_message)


class KeypointMatchingPipeline(Pipeline):
    """
    Keypoint matching pipeline using any `AutoModelForKeypointMatching`. This pipeline matches keypoints between two images.
    """

    _load_processor = False
    _load_image_processor = True
    _load_feature_extractor = False
    _load_tokenizer = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        requires_backends(self, "vision")
        if self.framework != "pt":
            raise ValueError("Keypoint matching pipeline only supports PyTorch (framework='pt').")

    def _sanitize_parameters(self, threshold=None, timeout=None):
        preprocess_params = {}
        if timeout is not None:
            preprocess_params["timeout"] = timeout
        postprocess_params = {}
        if threshold is not None:
            postprocess_params["threshold"] = threshold
        return preprocess_params, {}, postprocess_params

    @overload
    def __call__(self, inputs: ImagePair, threshold: float = 0.0, **kwargs: Any) -> list[Match]: ...

    @overload
    def __call__(self, inputs: list[ImagePair], threshold: float = 0.0, **kwargs: Any) -> list[list[Match]]: ...

    def __call__(
        self,
        inputs: Union[list[ImagePair], ImagePair],
        threshold: float = 0.0,
        **kwargs: Any,
    ) -> Union[list[Match], list[list[Match]]]:
        """
        Find matches between keypoints in two images.

        Args:
            inputs (`str`, `list[str]`, `PIL.Image` or `list[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing an http link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single pair of images or a batch (list) of image pairs.
                Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
                images.

            threshold (`float`, *optional*, defaults to 0.0):
                The threshold to use for keypoint matching. Keypoints matched with a lower matching score will be filtered out.
                A value of 0 means that all matched keypoints will be returned.

            kwargs:
                timeout (`float`, *optional*, defaults to `None`):
                    The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                    the call may block forever.

        Return:
            Union[list[Match], list[list[Match]]]:
                A list of matches if a single image pair is provided, or a list of lists of matches if a batch
                of image pairs is provided. Each match is a dictionary containing the following keys:

                - **keypoint_image_0** (`Keypoint`): The keypoint in the first image (x, y coordinates).
                - **keypoint_image_1** (`Keypoint`): The keypoint in the second image (x, y coordinates).
                - **score** (`float`): The matching score between the two keypoints.
        """
        if inputs is None:
            raise ValueError("Cannot call the keypoint-matching pipeline without an inputs argument!")
        formatted_inputs = validate_image_pairs(inputs)
        outputs = super().__call__(formatted_inputs, threshold=threshold, **kwargs)
        if len(formatted_inputs) == 1:
            return outputs[0]
        return outputs

    def preprocess(self, images, timeout=None):
        images = [load_image(image, timeout=timeout) for image in images]
        model_inputs = self.image_processor(images=images, return_tensors=self.framework)
        model_inputs = model_inputs.to(self.torch_dtype)
        target_sizes = [image.size for image in images]
        preprocess_outputs = {"model_inputs": model_inputs, "target_sizes": target_sizes}
        return preprocess_outputs

    def _forward(self, preprocess_outputs):
        model_inputs = preprocess_outputs["model_inputs"]
        model_outputs = self.model(**model_inputs)
        forward_outputs = {"model_outputs": model_outputs, "target_sizes": [preprocess_outputs["target_sizes"]]}
        return forward_outputs

    def postprocess(self, forward_outputs, threshold=0.0) -> list[Match]:
        model_outputs = forward_outputs["model_outputs"]
        target_sizes = forward_outputs["target_sizes"]
        postprocess_outputs = self.image_processor.post_process_keypoint_matching(
            model_outputs, target_sizes=target_sizes, threshold=threshold
        )
        postprocess_outputs = postprocess_outputs[0]
        pair_result = []
        for kp_0, kp_1, score in zip(
            postprocess_outputs["keypoints0"],
            postprocess_outputs["keypoints1"],
            postprocess_outputs["matching_scores"],
        ):
            kp_0 = Keypoint(x=kp_0[0].item(), y=kp_0[1].item())
            kp_1 = Keypoint(x=kp_1[0].item(), y=kp_1[1].item())
            pair_result.append(Match(keypoint_image_0=kp_0, keypoint_image_1=kp_1, score=score.item()))
        pair_result = sorted(pair_result, key=lambda x: x["score"], reverse=True)
        return pair_result
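
A short usage sketch of the class defined above, based on its `__call__` docstring and `_sanitize_parameters`: passing a list of image pairs returns one list of matches per pair, each sorted by descending score in `postprocess`, and `timeout` is forwarded to image loading in `preprocess`. The checkpoint and URLs are the ones used in the doc examples of this commit.

```py
from transformers import pipeline

url_0 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
url_1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"

matcher = pipeline(task="keypoint-matching", model="magic-leap-community/superglue_outdoor")

# Batch form: a list of image pairs yields a list of lists of Match dicts.
batch_results = matcher([[url_0, url_1], [url_1, url_0]], threshold=0.5, timeout=30.0)
for pair_matches in batch_results:
    best = pair_matches[0]  # matches come back sorted by descending score
    print(best["keypoint_image_0"], best["keypoint_image_1"], best["score"])
```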
