1515import argparse
1616import glob
1717import importlib
18+ import multiprocessing as mp
1819import os
1920import re
2021import subprocess
@@ -1226,7 +1227,8 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
12261227 if m .matches (node .module , m .Attribute ()):
12271228 for imported_ in node .names :
12281229 _import = re .search (
1229- rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({ self .match_patterns } )_.*" , import_statement
1230+ rf"(?:transformers\.models\.)|(?:\.\.\.models\.)|(?:\.\.)\w+\.({ self .match_patterns } )_.*" ,
1231+ import_statement ,
12301232 )
12311233 if _import :
12321234 source = _import .group (1 )
@@ -1688,7 +1690,8 @@ def run_ruff(code, check=False):
16881690 return stdout .decode ()
16891691
16901692
1691- def convert_modular_file (modular_file ):
1693+ def convert_modular_file (modular_file : str ) -> dict [str , str ]:
1694+ """Convert a `modular_file` into all the different model-specific files it depicts."""
16921695 pattern = re .search (r"modular_(.*)(?=\.py$)" , modular_file )
16931696 output = {}
16941697 if pattern is not None :
@@ -1712,34 +1715,30 @@ def convert_modular_file(modular_file):
17121715 )
17131716 ruffed_code = run_ruff (header + module .code , True )
17141717 formatted_code = run_ruff (ruffed_code , False )
1715- output [file ] = [ formatted_code , ruffed_code ]
1718+ output [file ] = formatted_code
17161719 return output
17171720 else :
17181721 print (f"modular pattern not found in { modular_file } , exiting" )
17191722 return {}
17201723
17211724
def save_modeling_files(modular_file: str, converted_files: dict[str, str]):
    """Save all the `converted_files` from the `modular_file`.

    Every generated file is written in the same directory as `modular_file`. Its name is derived from the modular
    file name: `modular_` is replaced by `{prefix}_` and `{suffix}` is inserted right before the `.py` extension,
    where `prefix` and `suffix` are the parts of the file type found before the first and after the last `*`
    (the suffix is empty when the file type contains no `*`).

    Args:
        modular_file (`str`):
            Path to the `modular_xxx.py` file the code was generated from.
        converted_files (`dict[str, str]`):
            Mapping from file type to the source code to write for this file type.
    """
    # Only the file name is rewritten. Replacing on the full path would also alter any parent directory whose name
    # happens to contain `modular_` or `.py`, and the file would end up in the wrong (likely missing) folder.
    directory, modular_name = os.path.split(modular_file)
    for file_type, code in converted_files.items():
        file_name_prefix = file_type.split("*")[0]
        file_name_suffix = file_type.split("*")[-1] if "*" in file_type else ""
        new_name = modular_name.replace("modular_", f"{file_name_prefix}_").replace(
            ".py", f"{file_name_suffix}.py"
        )
        new_file_name = os.path.join(directory, new_name)
        with open(new_file_name, "w", encoding="utf-8") as f:
            f.write(code)
1735+
1736+
def run_converter(modular_file: str):
    """Generate every file described by `modular_file`, then write the results to disk."""
    print(f"Converting {modular_file} to a single model single file format")
    # The conversion output is handed straight to the writer, no intermediate state is kept
    save_modeling_files(modular_file, convert_modular_file(modular_file))
17431742
17441743
17451744if __name__ == "__main__" :
@@ -1759,9 +1758,17 @@ def save_modeling_file(modular_file, converted_file):
17591758 nargs = "+" ,
17601759 help = "A list of `modular_xxxx` files that should be converted to single model file" ,
17611760 )
1761+ parser .add_argument (
1762+ "--num_workers" ,
1763+ "-w" ,
1764+ default = - 1 ,
1765+ type = int ,
1766+ help = "The number of workers to use. Default is -1, which means the number of CPU cores." ,
1767+ )
17621768 args = parser .parse_args ()
17631769 # Both arg represent the same data, but as positional and optional
17641770 files_to_parse = args .files if len (args .files ) > 0 else args .files_to_parse
1771+ num_workers = mp .cpu_count () if args .num_workers == - 1 else args .num_workers
17651772
17661773 if files_to_parse == ["all" ]:
17671774 files_to_parse = glob .glob ("src/transformers/models/**/modular_*.py" , recursive = True )
@@ -1779,12 +1786,17 @@ def save_modeling_file(modular_file, converted_file):
17791786 raise ValueError (f"Cannot find a modular file for { model_name } . Please provide the full path." )
17801787 files_to_parse [i ] = full_path
17811788
1782- priority_list , _ = find_priority_list (files_to_parse )
1783- priority_list = [item for sublist in priority_list for item in sublist ] # flatten the list of lists
1784- assert len (priority_list ) == len (files_to_parse ), "Some files will not be converted"
1789+ # This finds the correct order in which we should convert the modular files, so that a model relying on another one
1790+ # is necessarily converted after its dependencies
1791+ ordered_files , _ = find_priority_list (files_to_parse )
1792+ if sum (len (level_files ) for level_files in ordered_files ) != len (files_to_parse ):
1793+ raise ValueError (
1794+ "Some files will not be converted because they do not appear in the dependency graph."
1795+ "This usually means that at least one modular file does not import any model-specific class"
1796+ )
17851797
1786- for file_name in priority_list :
1787- print ( f"Converting { file_name } to a single model single file format" )
1788- module_path = file_name . replace ( "/" , "." ). replace ( ".py" , "" ). replace ( "src." , "" )
1789- converted_files = convert_modular_file ( file_name )
1790- converter = save_modeling_file ( file_name , converted_files )
1798+ for dependency_level_files in ordered_files :
1799+ # Convert all the files of the current dependency level in parallel
1800+ workers = min ( num_workers , len ( dependency_level_files ) )
1801+ with mp . Pool ( workers ) as pool :
1802+ pool . map ( run_converter , dependency_level_files )
0 commit comments