Skip to content

Commit a3afebb

Browse files
authored
[modular] Use multi-processing + fix model import issue (#40481)
* add mp and simplify a bit * improve * fix * fix imports * nit
1 parent 75d6f17 commit a3afebb

File tree

2 files changed

+41
-29
lines changed

2 files changed

+41
-29
lines changed

utils/check_modular_conversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def process_file(
3636
# Read the actual modeling file
3737
with open(file_path, "r", encoding="utf-8") as modeling_file:
3838
content = modeling_file.read()
39-
output_buffer = StringIO(generated_modeling_content[file_type][0])
39+
output_buffer = StringIO(generated_modeling_content[file_type])
4040
output_buffer.seek(0)
4141
output_content = output_buffer.read()
4242
diff = difflib.unified_diff(
@@ -54,7 +54,7 @@ def process_file(
5454
shutil.copy(file_path, file_path + BACKUP_EXT)
5555
# we always save the generated content, to be able to update dependent files
5656
with open(file_path, "w", encoding="utf-8", newline="\n") as modeling_file:
57-
modeling_file.write(generated_modeling_content[file_type][0])
57+
modeling_file.write(generated_modeling_content[file_type])
5858
console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
5959
if show_diff:
6060
console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")

utils/modular_model_converter.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import argparse
1616
import glob
1717
import importlib
18+
import multiprocessing as mp
1819
import os
1920
import re
2021
import subprocess
@@ -1226,7 +1227,8 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
12261227
if m.matches(node.module, m.Attribute()):
12271228
for imported_ in node.names:
12281229
_import = re.search(
1229-
rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement
1230+
rf"(?:transformers\.models\.)|(?:\.\.\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*",
1231+
import_statement,
12301232
)
12311233
if _import:
12321234
source = _import.group(1)
@@ -1688,7 +1690,8 @@ def run_ruff(code, check=False):
16881690
return stdout.decode()
16891691

16901692

1691-
def convert_modular_file(modular_file):
1693+
def convert_modular_file(modular_file: str) -> dict[str, str]:
1694+
"""Convert a `modular_file` into all the different model-specific files it depicts."""
16921695
pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
16931696
output = {}
16941697
if pattern is not None:
@@ -1712,34 +1715,30 @@ def convert_modular_file(modular_file):
17121715
)
17131716
ruffed_code = run_ruff(header + module.code, True)
17141717
formatted_code = run_ruff(ruffed_code, False)
1715-
output[file] = [formatted_code, ruffed_code]
1718+
output[file] = formatted_code
17161719
return output
17171720
else:
17181721
print(f"modular pattern not found in {modular_file}, exiting")
17191722
return {}
17201723

17211724

1722-
def save_modeling_file(modular_file, converted_file):
1723-
for file_type in converted_file:
1725+
def save_modeling_files(modular_file: str, converted_files: dict[str, str]):
1726+
"""Save all the `converted_files` from the `modular_file`."""
1727+
for file_type in converted_files:
17241728
file_name_prefix = file_type.split("*")[0]
17251729
file_name_suffix = file_type.split("*")[-1] if "*" in file_type else ""
17261730
new_file_name = modular_file.replace("modular_", f"{file_name_prefix}_").replace(
17271731
".py", f"{file_name_suffix}.py"
17281732
)
1729-
non_comment_lines = len(
1730-
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
1731-
)
1732-
if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0:
1733-
with open(new_file_name, "w", encoding="utf-8") as f:
1734-
f.write(converted_file[file_type][0])
1735-
else:
1736-
non_comment_lines = len(
1737-
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
1738-
)
1739-
if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0:
1740-
logger.warning("The modeling code contains errors, it's written without formatting")
1741-
with open(new_file_name, "w", encoding="utf-8") as f:
1742-
f.write(converted_file[file_type][1])
1733+
with open(new_file_name, "w", encoding="utf-8") as f:
1734+
f.write(converted_files[file_type])
1735+
1736+
1737+
def run_converter(modular_file: str):
1738+
"""Convert a modular file, and save resulting files."""
1739+
print(f"Converting {modular_file} to a single model single file format")
1740+
converted_files = convert_modular_file(modular_file)
1741+
save_modeling_files(modular_file, converted_files)
17431742

17441743

17451744
if __name__ == "__main__":
@@ -1759,9 +1758,17 @@ def save_modeling_file(modular_file, converted_file):
17591758
nargs="+",
17601759
help="A list of `modular_xxxx` files that should be converted to single model file",
17611760
)
1761+
parser.add_argument(
1762+
"--num_workers",
1763+
"-w",
1764+
default=-1,
1765+
type=int,
1766+
help="The number of workers to use. Default is -1, which means the number of CPU cores.",
1767+
)
17621768
args = parser.parse_args()
17631769
# Both args represent the same data, one as a positional argument and the other as an optional one
17641770
files_to_parse = args.files if len(args.files) > 0 else args.files_to_parse
1771+
num_workers = mp.cpu_count() if args.num_workers == -1 else args.num_workers
17651772

17661773
if files_to_parse == ["all"]:
17671774
files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
@@ -1779,12 +1786,17 @@ def save_modeling_file(modular_file, converted_file):
17791786
raise ValueError(f"Cannot find a modular file for {model_name}. Please provide the full path.")
17801787
files_to_parse[i] = full_path
17811788

1782-
priority_list, _ = find_priority_list(files_to_parse)
1783-
priority_list = [item for sublist in priority_list for item in sublist] # flatten the list of lists
1784-
assert len(priority_list) == len(files_to_parse), "Some files will not be converted"
1789+
# This finds the correct order in which we should convert the modular files, so that a model relying on another one
1790+
# is necessarily converted after its dependencies
1791+
ordered_files, _ = find_priority_list(files_to_parse)
1792+
if sum(len(level_files) for level_files in ordered_files) != len(files_to_parse):
1793+
raise ValueError(
1794+
"Some files will not be converted because they do not appear in the dependency graph. "
1795+
"This usually means that at least one modular file does not import any model-specific class"
1796+
)
17851797

1786-
for file_name in priority_list:
1787-
print(f"Converting {file_name} to a single model single file format")
1788-
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
1789-
converted_files = convert_modular_file(file_name)
1790-
converter = save_modeling_file(file_name, converted_files)
1798+
for dependency_level_files in ordered_files:
1799+
# Convert all the files of the current dependency level in parallel
1800+
workers = min(num_workers, len(dependency_level_files))
1801+
with mp.Pool(workers) as pool:
1802+
pool.map(run_converter, dependency_level_files)

0 commit comments

Comments
 (0)