Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 36 additions & 104 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6252,84 +6252,6 @@ def get_expr_name() -> str:
else:
self.fail(message_registry.TYPE_ALWAYS_TRUE.format(format_expr_type()), expr)

def find_type_equals_check(
self, node: ComparisonExpr, expr_indices: list[int]
) -> tuple[TypeMap, TypeMap]:
"""Narrow types based on any checks of the type ``type(x) == T``

Args:
node: The node that might contain the comparison
expr_indices: The list of indices of expressions in ``node`` that are being
compared
"""
# exprs that are being passed into type
exprs_in_type_calls: list[Expression] = []

for index in expr_indices:
expr = node.operands[index]
if isinstance(expr, CallExpr) and is_type_call(expr):
exprs_in_type_calls.append(expr.args[0])

if not exprs_in_type_calls:
return {}, {}

# type that is being compared to type(expr)
type_being_compared: list[TypeRange] | None = None
# whether the type being compared to is final
is_final = False

for index in expr_indices:
expr = node.operands[index]
if isinstance(expr, CallExpr) and is_type_call(expr):
continue
current_type = self.get_isinstance_type(expr)
if current_type is None:
continue
if type_being_compared is not None:
# It doesn't really make sense to have several types being
# compared to the output of type (like type(x) == int == str)
# because whether that's true is solely dependent on what the
# types being compared are, so we don't try to narrow types any
# further because we can't really get any information about the
# type of x from that check
return {}, {}
if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo):
is_final = expr.node.is_final
type_being_compared = current_type

if_maps: list[TypeMap] = []
else_maps: list[TypeMap] = []
for expr in exprs_in_type_calls:
current_if_type, current_else_type = self.conditional_types_with_intersection(
self.lookup_type(expr), type_being_compared, expr
)
current_if_map, current_else_map = conditional_types_to_typemaps(
expr, current_if_type, current_else_type
)
if_maps.append(current_if_map)
else_maps.append(current_else_map)

def combine_maps(list_maps: list[TypeMap]) -> TypeMap:
"""Combine all typemaps in list_maps into one typemap"""
if all(m is None for m in list_maps):
return None
result_map = {}
for d in list_maps:
if d is not None:
result_map.update(d)
return result_map

if_map = combine_maps(if_maps)
# type(x) == T is only true when x has the same type as T, meaning
# that it can be false if x is an instance of a subclass of T. That means
# we can't do any narrowing in the else case unless T is final, in which
# case T can't be subclassed
if is_final:
else_map = combine_maps(else_maps)
else:
else_map = {}
return if_map, else_map

def find_isinstance_check(
self, node: Expression, *, in_boolean_context: bool = True
) -> tuple[TypeMap, TypeMap]:
Expand Down Expand Up @@ -6603,8 +6525,7 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
else_map: TypeMap

if operator in {"is", "is not", "==", "!="}:
if_map, else_map = self.equality_type_narrowing_helper(
node,
if_map, else_map = self.narrow_type_by_identity_equality(
operator,
operands,
operand_types,
Expand All @@ -6623,7 +6544,7 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
if left_index in narrowable_operand_index_to_hash:
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
if collection_item_type is not None:
if_map, else_map = self.narrow_type_by_equality(
if_map, else_map = self.narrow_type_by_identity_equality(
"==",
operands=[operands[left_index], operands[right_index]],
operand_types=[item_type, collection_item_type],
Expand Down Expand Up @@ -6677,28 +6598,7 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
# like `if 1 < len(x) < 4: ...`
return reduce_conditional_maps(self.find_tuple_len_narrowing(node), use_meet=True)

def equality_type_narrowing_helper(
self,
node: ComparisonExpr,
operator: str,
operands: list[Expression],
operand_types: list[Type],
expr_indices: list[int],
narrowable_indices: AbstractSet[int],
) -> tuple[TypeMap, TypeMap]:
"""Calculate type maps for '==', '!=', 'is' or 'is not' expression."""
# If we haven't been able to narrow types yet, we might be dealing with a
# explicit type(x) == some_type check
if_map, else_map = self.narrow_type_by_equality(
operator, operands, operand_types, expr_indices, narrowable_indices=narrowable_indices
)
if node is not None:
type_if_map, type_else_map = self.find_type_equals_check(node, expr_indices)
if_map = and_conditional_maps(if_map, type_if_map)
else_map = and_conditional_maps(else_map, type_else_map)
return if_map, else_map

def narrow_type_by_equality(
def narrow_type_by_identity_equality(
self,
operator: str,
operands: list[Expression],
Expand Down Expand Up @@ -6826,6 +6726,38 @@ def narrow_type_by_equality(
else_map = {} # this is the big difference compared to the above
partial_type_maps.append((if_map, else_map))

exprs_in_type_calls = []
for i in expr_indices:
expr = operands[i]
if isinstance(expr, CallExpr) and is_type_call(expr):
exprs_in_type_calls.append(expr.args[0])

if exprs_in_type_calls:
for expr_in_type_call in exprs_in_type_calls:
for i in expr_indices:
expr = operands[i]
if isinstance(expr, CallExpr) and is_type_call(expr):
continue

current_type_range = self.get_isinstance_type(expr)
if_map, else_map = conditional_types_to_typemaps(
expr_in_type_call,
*self.conditional_types_with_intersection(
self.lookup_type(expr_in_type_call),
current_type_range,
expr_in_type_call,
),
)

is_final = (
expr.node.is_final
if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo)
else False
)
if not is_final:
else_map = {}
partial_type_maps.append((if_map, else_map))

# We will not have duplicate entries in our type maps if we only have two operands,
# so we can skip running meets on the intersections
return reduce_conditional_maps(partial_type_maps, use_meet=len(operands) > 2)
Expand Down Expand Up @@ -7021,7 +6953,7 @@ def refine_away_none_in_comparison(
"""Produces conditional type maps refining away None in an identity/equality chain.

For more details about what the different arguments mean, see the
docstring of 'narrow_type_by_equality' up above.
docstring of 'narrow_type_by_identity_equality' up above.
"""

non_optional_types = []
Expand Down
2 changes: 1 addition & 1 deletion mypy/checker_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def conditional_types_with_intersection(
raise NotImplementedError

@abstractmethod
def narrow_type_by_equality(
def narrow_type_by_identity_equality(
self,
operator: str,
operands: list[Expression],
Expand Down
2 changes: 1 addition & 1 deletion mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def visit_value_pattern(self, o: ValuePattern) -> PatternType:
node = TempNode(current_type)
# Value patterns are essentially a syntactic sugar on top of `if x == Value`.
# They should be treated equivalently.
ok_map, rest_map = self.chk.narrow_type_by_equality(
ok_map, rest_map = self.chk.narrow_type_by_identity_equality(
"==", [node, TempNode(typ)], [current_type, typ], [0, 1], {0}
)
ok_type = ok_map.get(node, current_type) if ok_map is not None else UninhabitedType()
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-isinstance.test
Original file line number Diff line number Diff line change
Expand Up @@ -2767,7 +2767,7 @@ from typing import Union

x: Union[int, str]
if type(x) == int == str:
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str"
reveal_type(x) # E: Statement is unreachable
else:
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str"

Expand Down