|
54 | 54 | from collections import defaultdict |
55 | 55 |
|
56 | 56 | from typing import ( |
57 | | - List, Dict, Tuple, Iterable, Mapping, Optional, Set, cast, |
| 57 | + List, Dict, Tuple, Iterable, Mapping, Optional, Set, Union, cast, |
58 | 58 | ) |
59 | 59 | from typing_extensions import Final |
60 | 60 |
|
|
84 | 84 | from mypy.options import Options as MypyOptions |
85 | 85 | from mypy.types import ( |
86 | 86 | Type, TypeStrVisitor, CallableType, UnboundType, NoneType, TupleType, TypeList, Instance, |
87 | | - AnyType, get_proper_type |
| 87 | + AnyType, get_proper_type, OVERLOAD_NAMES |
88 | 88 | ) |
89 | 89 | from mypy.visitor import NodeVisitor |
90 | 90 | from mypy.find_sources import create_source_list, InvalidSourceList |
|
93 | 93 | from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression |
94 | 94 | from mypy.moduleinspect import ModuleInspect |
95 | 95 |
|
| 96 | +TYPING_MODULE_NAMES: Final = ( |
| 97 | + 'typing', |
| 98 | + 'typing_extensions', |
| 99 | +) |
96 | 100 |
|
97 | 101 | # Common ways of naming package containing vendored modules. |
98 | 102 | VENDOR_PACKAGES: Final = [ |
@@ -768,13 +772,15 @@ def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> Tup |
768 | 772 | self.add_decorator('property') |
769 | 773 | self.add_decorator('abc.abstractmethod') |
770 | 774 | is_abstract = True |
771 | | - elif self.refers_to_fullname(name, 'typing.overload'): |
| 775 | + elif self.refers_to_fullname(name, OVERLOAD_NAMES): |
772 | 776 | self.add_decorator(name) |
773 | 777 | self.add_typing_import('overload') |
774 | 778 | is_overload = True |
775 | 779 | return is_abstract, is_overload |
776 | 780 |
|
777 | | - def refers_to_fullname(self, name: str, fullname: str) -> bool: |
| 781 | + def refers_to_fullname(self, name: str, fullname: Union[str, Tuple[str, ...]]) -> bool: |
| 782 | + if isinstance(fullname, tuple): |
| 783 | + return any(self.refers_to_fullname(name, fname) for fname in fullname) |
778 | 784 | module, short = fullname.rsplit('.', 1) |
779 | 785 | return (self.import_tracker.module_for.get(name) == module and |
780 | 786 | (name == short or |
@@ -825,8 +831,8 @@ def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) -> |
825 | 831 | expr.expr.name + '.coroutine', |
826 | 832 | expr.expr.name) |
827 | 833 | elif (isinstance(expr.expr, NameExpr) and |
828 | | - (expr.expr.name == 'typing' or |
829 | | - self.import_tracker.reverse_alias.get(expr.expr.name) == 'typing') and |
| 834 | + (expr.expr.name in TYPING_MODULE_NAMES or |
| 835 | + self.import_tracker.reverse_alias.get(expr.expr.name) in TYPING_MODULE_NAMES) and |
830 | 836 | expr.name == 'overload'): |
831 | 837 | self.import_tracker.require_name(expr.expr.name) |
832 | 838 | self.add_decorator('%s.%s' % (expr.expr.name, 'overload')) |
@@ -1060,7 +1066,7 @@ def visit_import_from(self, o: ImportFrom) -> None: |
1060 | 1066 | and name not in self.referenced_names |
1061 | 1067 | and (not self._all_ or name in IGNORED_DUNDERS) |
1062 | 1068 | and not is_private |
1063 | | - and module not in ('abc', 'typing', 'asyncio')): |
| 1069 | + and module not in ('abc', *TYPING_MODULE_NAMES, 'asyncio')): |
1064 | 1070 | # An imported name that is never referenced in the module is assumed to be |
1065 | 1071 | # exported, unless there is an explicit __all__. Note that we need to special |
1066 | 1072 | # case 'abc' since some references are deleted during semantic analysis. |
@@ -1118,8 +1124,7 @@ def get_init(self, lvalue: str, rvalue: Expression, |
1118 | 1124 | typename = self.print_annotation(annotation) |
1119 | 1125 | if (isinstance(annotation, UnboundType) and not annotation.args and |
1120 | 1126 | annotation.name == 'Final' and |
1121 | | - self.import_tracker.module_for.get('Final') in ('typing', |
1122 | | - 'typing_extensions')): |
| 1127 | + self.import_tracker.module_for.get('Final') in TYPING_MODULE_NAMES): |
1123 | 1128 | # Final without type argument is invalid in stubs. |
1124 | 1129 | final_arg = self.get_str_type_of_node(rvalue) |
1125 | 1130 | typename += '[{}]'.format(final_arg) |
|
0 commit comments