[python-package] highlight the path a sample takes through a tree in plot_tree and create_tree_digraph (fixes #4784)
#5119
jameslamb
left a comment
Thanks for this! I think this is a really powerful feature, and I support adding it by adding new arguments to create_tree_digraph() and plot_tree().
I agree with @StrikerRUS's suggestion to use color on the arrows and edges. I think it looks great.
A few other recommendations:
- Instead of x, would you consider calling this argument example_case or example_observation?
  - I think the word sample should be avoided, since in other parts of LightGBM that's used to mean "randomly choose".
- Could you please add some tests on this new behavior? See the existing tests in test_plotting.py for reference, for example: def test_create_tree_digraph(breast_cancer_split):
- What will the behavior of this code be if x has more than one row in it? Would you consider adding some validation that raises an exception with an informative error if x has more than one row?
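The validation asked for in the last bullet could look roughly like the sketch below. This is only an illustration of the idea, not the code that ended up in the PR: ensure_single_row is a hypothetical helper name, and the sketch assumes the input either exposes a shape attribute (as NumPy arrays and pandas DataFrames do) or is a plain list of rows.

```python
def ensure_single_row(example_case):
    """Return example_case unchanged if it holds exactly one row.

    Hypothetical helper for illustration only. It reads the row count from
    a ``shape`` attribute when one exists, and falls back to ``len()`` for
    plain sequences such as a list of rows.
    """
    shape = getattr(example_case, "shape", None)
    n_rows = shape[0] if shape else len(example_case)
    if n_rows != 1:
        raise ValueError(
            f"example_case must have exactly one row, got {n_rows}."
        )
    return example_case


# One row passes through; two rows raise an informative error.
ensure_single_row([[5.1, 3.5, 1.4]])
try:
    ensure_single_row([[5.1, 3.5, 1.4], [4.9, 3.0, 1.4]])
except ValueError as err:
    print(err)
```

Failing early with a clear message seems preferable to silently highlighting the path of the first row only.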
And some questions (some of which might be about these plotting functions generally and outside the scope of this PR, I'm not sure):
- Will this work with categorical features? If I remember correctly, decision_type can include "or" rules like ||.
- Will this work with categorical features stored as pandas categorical types?
  - I think it might not, since the proposed code does a direct comparison to the values in x, so if the data is a pd.Series it isn't passed through the _data_from_pandas() logic.

LightGBM/python-package/lightgbm/basic.py, line 785 in 7820746:

    data = _data_from_pandas(data, None, None, self.pandas_categorical)[0]

LightGBM/python-package/lightgbm/basic.py, lines 549 to 554 in 7820746:

    for col, category in zip(cat_cols, pandas_categorical):
        if list(data[col].cat.categories) != list(category):
            data[col] = data[col].cat.set_categories(category)
    if len(cat_cols):  # cat_cols is list
        data = data.copy()  # not alter origin DataFrame
        data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan})
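To make the concern about pandas categoricals concrete: categorical split thresholds are expressed as integer category codes joined by ||, so a raw label compared directly against them cannot match. The plain-Python sketch below deliberately avoids pandas. The helper encode_with_training_categories, the category list, and the threshold string are all made up for illustration; the real conversion is the _data_from_pandas() logic quoted above.

```python
import math


def encode_with_training_categories(value, training_categories):
    """Map a raw label to its position in the training-time category list.

    Labels that were not seen during training become NaN, which loosely
    mirrors replacing the -1 category code with NaN.
    """
    try:
        return float(training_categories.index(value))
    except ValueError:
        return math.nan


training_categories = ["B", "C", "D"]  # hypothetical training-time categories
threshold = "0||2"                     # hypothetical split: codes 0 and 2 go left
left_codes = {int(t) for t in threshold.split("||")}

raw_value = "D"
code = encode_with_training_categories(raw_value, training_categories)

print(raw_value in left_codes)  # the raw label never matches an integer code
print(code in left_codes)       # the encoded value does
```

The point of the sketch is only that the example row has to go through the same encoding as the training data before any comparison with the thresholds is meaningful.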
jameslamb
left a comment
Thanks for adding additional tests and categorical support! Really nice work!
I left one small suggestion. I'd still like to test this locally a little bit more, will do that tomorrow. But overall I'm really excited about this 🤩
jameslamb
left a comment
VERY nice work @jmoralez ! I tested tonight in more depth and found everything worked really well. These plots are awesome 🤩 .
I was looking for an example real-world dataset that had informative categorical features, and found that you can filter the UCI Machine Learning Repository by "contains categorical features": https://archive.ics.uci.edu/ml/datasets.php?format=&task=&att=cat&area=&numAtt=&numIns=&type=&sort=nameUp&view=table.
Found that this one worked well for my investigation: https://archive.ics.uci.edu/ml/machine-learning-databases/solar-flare/.
That dataset has ONLY categorical features, so it's useful for testing categorical-specific stuff. I think I'll turn to it in the future when experimenting with lightgbm and other ML libraries.
Example code:

    import lightgbm as lgb
    import pandas as pd

    data_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/solar-flare/flare.data2"
    df = pd.read_csv(
        filepath_or_buffer=data_url,
        sep=" ",
        header=1,
        skip_blank_lines=True,
        names=[
            "class_code",
            "largest_spot_size_code",
            "spot_distribution_code",
            "activity",
            "evolution",
            "prev_flare_24h",
            "historically_complex",
            "became_complex_on_this_pass",
            "area",
            "largest_spot_area",
            "c_class_flare_count",
            "m_class_flare_count",
            "x_class_flare_count"
        ]
    )

    for col in df.columns:
        if pd.api.types.is_object_dtype(df[col]):
            df[col] = df[col].astype("category")

    y = (df[["c_class_flare_count"]] > 0).astype("int").values.ravel()
    feature_names = [c for c in df.columns if not c.endswith("flare_count")]
    X = df[feature_names]

    dtrain = lgb.Dataset(
        X,
        y,
        feature_name=feature_names,
        categorical_feature="auto",
        params={
            "min_data_in_bin": 5,
            "min_data_per_group": 1
        }
    )

    bst = lgb.train(
        params={
            'num_leaves': 7,
            'objective': 'binary',
            'min_data_in_leaf': 1,
            'verbose': 1
        },
        train_set=dtrain,
        num_boost_round=3
    )

    example_case = X[:1]
    pd.set_option('display.width', 1000)
    print(example_case)
    print("---")
    print(X["class_code"].cat.categories)
    print("---")
    lgb.create_tree_digraph(bst, example_case=example_case, tree_index=1)

Tried with a few records, and can see that the records go the correct way in the split, including the handling of || for categoricals!
In a future PR, could you please add an example using this new function to https://github.com/microsoft/LightGBM/blob/1cc9f9dcee15586aefdd8271a1c04ad010d3c53a/examples/python-guide/plot_example.py?
Sure. Thanks for the great suggestions on your review, as always!
StrikerRUS
left a comment
Looks great!
But I'm afraid that this PR doesn't cover all possible scenarios of params and values in example_case (see inline comment below).
python-package/lightgbm/plotting.py
Outdated
    if root['decision_type'] == '==':
        thresholds = {int(x) for x in root['threshold'].split('||')}
        if example_case[split_feature] in thresholds:
            direction = 'left'
        else:
            direction = 'right'
    else:
        direction = 'left' if example_case[split_feature] <= root['threshold'] else 'right'
Does this code cover all combinations of use_missing, zero_as_missing params and NaN values in example_case?
Lines 520 to 560 in 11110c5
    std::string Tree::NumericalDecisionIfElse(int node) const {
      std::stringstream str_buf;
      Common::C_stringstream(str_buf);
      str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
      uint8_t missing_type = GetMissingType(decision_type_[node]);
      bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
      if (missing_type == MissingType::None
          || (missing_type == MissingType::Zero && default_left && kZeroThreshold < threshold_[node])) {
        str_buf << "if (fval <= " << threshold_[node] << ") {";
      } else if (missing_type == MissingType::Zero) {
        if (default_left) {
          str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
        } else {
          str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
        }
      } else {
        if (default_left) {
          str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
        } else {
          str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
        }
      }
      return str_buf.str();
    }

    std::string Tree::CategoricalDecisionIfElse(int node) const {
      uint8_t missing_type = GetMissingType(decision_type_[node]);
      std::stringstream str_buf;
      Common::C_stringstream(str_buf);
      if (missing_type == MissingType::NaN) {
        str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
      } else {
        str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
      }
      int cat_idx = static_cast<int>(threshold_[node]);
      str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
      str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
      str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
      str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
      return str_buf.str();
    }
Refer to #2921.
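For readers less familiar with the bit arithmetic that CategoricalDecisionIfElse emits, here is a rough Python sketch of the same membership test: a category code goes left when its bit is set in a list of 32-bit words. The names make_bitset and in_bitset are hypothetical, and the sketch ignores the cat_boundaries_ offset by treating each split as having its own word list.

```python
def make_bitset(codes):
    """Pack non-negative integer category codes into a list of 32-bit words.

    Hypothetical helper for the demonstration; expects a non-empty list.
    """
    words = [0] * (max(codes) // 32 + 1)
    for code in codes:
        words[code // 32] |= 1 << (code & 31)
    return words


def in_bitset(words, code):
    """Return True if the bit for ``code`` is set.

    Follows the generated condition: the code must be non-negative and
    within range, then word ``code // 32`` is shifted right by
    ``code & 31`` and its lowest bit is tested.
    """
    if code < 0 or code >= 32 * len(words):
        return False
    return ((words[code // 32] >> (code & 31)) & 1) == 1


words = make_bitset([0, 2, 35])
for code in (0, 1, 2, 35, 64, -1):
    print(code, "left" if in_bitset(words, code) else "right")
```

Note that a code of -1 can never pass the range check, so in this scheme it always ends up on the right-hand side.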
StrikerRUS
left a comment
Awesome contribution!
Just some minor comments below:
python-package/lightgbm/plotting.py
Outdated
    def _determine_direction_for_categorical_split(fval: float, thresholds: str, missing_type: str) -> str:
        if missing_type == 'None':
            int_fval = -1 if math.isnan(fval) else int(fval)
        else:
            int_fval = 0 if math.isnan(fval) else int(fval)
        int_thresholds = {int(t) for t in thresholds.split('||')}
        return 'left' if int_fval in int_thresholds else 'right'
Looks like NaN values are treated here properly, right?
Refer to #4468.
I updated this in 1e6f95d to do the same as #4468. -1 values are replaced by NaNs in LightGBM/python-package/lightgbm/basic.py, line 559 in df14e60:

    data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan})
Thanks for doing this!
Do we need to update the following code as well (in a separate PR)?
Lines 549 to 553 in fb37e50
    if (missing_type == MissingType::NaN) {
      str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
    } else {
      str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
    }
python-package/lightgbm/plotting.py
Outdated
    def _determine_direction_for_numeric_split(fval: float, threshold: float, missing_type: str, default_left: bool) -> str:
        le_threshold = fval <= threshold
        if missing_type == _MissingType.NONE or (missing_type == _MissingType.ZERO and default_left and ZERO_THRESHOLD < threshold):
You cannot compare a string with an Enum this way. Such a comparison will always return False.

    missing_type = 'Zero'
    missing_type == _MissingType.ZERO  # False

missing_type should either be converted to an Enum first, or you should compare in the following way: missing_type == _MissingType.ZERO.value.
Also, this incorrect comparison unfortunately shows that the added tests are unreliable 😢 They should fail, but all CI is green right now.
I updated the numerical split definition to the one here:
LightGBM/include/LightGBM/tree.h, lines 335 to 353 in f94050a:

    inline int NumericalDecision(double fval, int node) const {
      uint8_t missing_type = GetMissingType(decision_type_[node]);
      if (std::isnan(fval) && missing_type != MissingType::NaN) {
        fval = 0.0f;
      }
      if ((missing_type == MissingType::Zero && IsZero(fval))
          || (missing_type == MissingType::NaN && std::isnan(fval))) {
        if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) {
          return left_child_[node];
        } else {
          return right_child_[node];
        }
      }
      if (fval <= threshold_[node]) {
        return left_child_[node];
      } else {
        return right_child_[node];
      }
    }
Now the test fails if I comment out the enum conversion. Let me know what you think, it looks different than the previous one.
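For reference, a line-by-line Python rendering of the NumericalDecision logic quoted above might look like the sketch below. It is not the code from the PR: the names MissingType, is_zero and numeric_direction are made up here, and the value of K_ZERO_THRESHOLD is an assumption about LightGBM's kZeroThreshold constant.

```python
import math
from enum import Enum

K_ZERO_THRESHOLD = 1e-35  # assumption: stands in for LightGBM's kZeroThreshold


class MissingType(Enum):
    NONE = "None"
    ZERO = "Zero"
    NAN = "NaN"


def is_zero(fval):
    """Treat values within +/- K_ZERO_THRESHOLD of zero as zero."""
    return -K_ZERO_THRESHOLD <= fval <= K_ZERO_THRESHOLD


def numeric_direction(fval, threshold, missing_type, default_left):
    """Return 'left' or 'right' for a numerical split."""
    # NaN is only kept as NaN when the node treats NaN as missing.
    if math.isnan(fval) and missing_type is not MissingType.NAN:
        fval = 0.0
    # Missing values follow the node's default direction.
    if (missing_type is MissingType.ZERO and is_zero(fval)) or (
        missing_type is MissingType.NAN and math.isnan(fval)
    ):
        return "left" if default_left else "right"
    # Everything else is a plain threshold comparison.
    return "left" if fval <= threshold else "right"


print(numeric_direction(math.nan, 0.5, MissingType.NAN, default_left=False))
print(numeric_direction(math.nan, 0.5, MissingType.NONE, default_left=False))
print(numeric_direction(0.0, -1.0, MissingType.ZERO, default_left=True))
```

A small table of cases like these (NaN and zero inputs crossed with each missing type and both default directions) is one way to get tests that actually fail when the Enum conversion is removed.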
Thanks for doing this!
Do we need to sync the code in if/else dump as well (in a separate PR)?
Lines 520 to 560 in 11110c5 (the same Tree::NumericalDecisionIfElse and Tree::CategoricalDecisionIfElse snippet quoted earlier in this conversation).
Yes, I think that code may need to be updated. I'll take a look.
Kindly ping @jmoralez for a possible follow-up PR.
It seems the Azure pipeline is broken 😢, but I don't have permission to fix it now. @shiyu1994, can you take a look?
Co-authored-by: Nikita Titov <[email protected]>
@guolinke I'm working on this and will fix this soon. Sorry for the delay.
StrikerRUS
left a comment
Great work! Thank you very much!

This adds an argument x (or we can maybe call it sample) that takes a single sample and highlights the path that sample takes through a tree in the tree plotting functions. The path is highlighted by making the borders of the nodes as well as the edges blue and bold. Here's an example for different tree sizes.
Closes #4784.