
Conversation

@jmoralez
Collaborator

@jmoralez jmoralez commented Apr 2, 2022

This adds an argument x (or we could maybe call it sample) that takes a single sample and highlights the path that sample takes through a tree in the tree plotting functions. The path is highlighted by making the borders of the nodes, as well as the edges, blue and bold. Here's an example for different tree sizes.

Sample script
import lightgbm as lgb
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression

X, y = make_regression(1_000, n_features=4, n_informative=2, random_state=0)
ds = lgb.Dataset(X, y)

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(16, 6))
leaves = [7, 15, 31, 63]
for axi, num_leaves in zip(ax.flat, leaves):
    bst = lgb.train({'num_leaves': num_leaves, 'verbose': -1}, ds, num_boost_round=5)
    lgb.plot_tree(bst, x=X[-1], ax=axi)  # highlight the path taken by the last sample

(image: highlighted decision paths for trees with 7, 15, 31, and 63 leaves)

Closes #4784.

@jmoralez jmoralez changed the title [python-package] highlight the path a sample takes through a tree in plot_tree and create_tree_digraph [python-package] highlight the path a sample takes through a tree in plot_tree and create_tree_digraph (fixes #4784) Apr 2, 2022
Collaborator

@jameslamb jameslamb left a comment

Thanks for this! I think this is a really powerful feature, and I support adding it via new arguments to create_tree_digraph() and plot_tree().

I agree with @StrikerRUS's suggestion to use color on the arrows and edges; I think it looks great.

A few other recommendations:

  • Instead of x, would you consider calling this argument example_case or example_observation?
    • I think the word sample should be avoided, since in other parts of LightGBM that's used to mean "randomly choose".
  • could you please add some tests on this new behavior? See the existing tests in test_plotting.py for reference.
  • what will the behavior of this code be if x has more than one row in it? Would you consider adding some validation that raises an exception with an informative error if x has more than one row? (A rough sketch of what I mean follows this list.)
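For illustration only, a minimal sketch of that kind of validation; the helper name and exact error wording are hypothetical, not part of this PR:

import numpy as np

def _to_single_row(example_case):
    # Accept either a 1-D array-like or a 2-D structure with exactly one row;
    # raise an informative error otherwise.
    arr = np.asarray(example_case)
    if arr.ndim == 2:
        if arr.shape[0] != 1:
            raise ValueError(f"example_case must have a single row, found {arr.shape[0]}.")
        arr = arr[0]
    elif arr.ndim != 1:
        raise ValueError("example_case must be a 1-D array or a 2-D structure with one row.")
    return arr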

And some questions (some of which might be about these plotting functions generally and outside the scope of this PR, I'm not sure)

  • will this work with categorical features? If I remember correctly, decision_type can include "or" rules like ||
  • will this work with categorical features stored as pandas categorical types?
    • I think it might not, since the proposed code does a direct comparison to the values in x, so if the data is a pd.Series it isn't passed through the _data_from_pandas() logic (a rough sketch of that conversion follows this list)
    • data = _data_from_pandas(data, None, None, self.pandas_categorical)[0]
    • for col, category in zip(cat_cols, pandas_categorical):
          if list(data[col].cat.categories) != list(category):
              data[col] = data[col].cat.set_categories(category)
      if len(cat_cols):  # cat_cols is list
          data = data.copy()  # not alter origin DataFrame
          data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan})
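As a hedged sketch of what routing a single observation through that pandas handling could look like (names here are illustrative, not the PR's actual code, and all non-categorical columns are assumed numeric):

import numpy as np
import pandas as pd

def _example_case_to_codes(row: pd.DataFrame) -> np.ndarray:
    # Mirror the snippet above for a one-row frame: replace pandas categorical
    # columns with their integer codes, mapping the -1 code (missing) to NaN.
    row = row.copy()
    for col in row.columns:
        if isinstance(row[col].dtype, pd.CategoricalDtype):
            row[col] = row[col].cat.codes.replace({-1: np.nan})
    return row.to_numpy(dtype=float).ravel()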

@jmoralez jmoralez requested a review from jameslamb April 26, 2022 02:48
Collaborator

@jameslamb jameslamb left a comment

Thanks for adding additional tests and categorical support! Really nice work!

I left one small suggestion. I'd still like to test this locally a little bit more, will do that tomorrow. But overall I'm really excited about this 🤩

Collaborator

@jameslamb jameslamb left a comment

VERY nice work @jmoralez ! I tested tonight in more depth and found everything worked really well. These plots are awesome 🤩 .

I was looking for an example real-world dataset that had informative categorical features, and found that you can filter the UCI Machine Learning Repository by "contains categorical features": https://archive.ics.uci.edu/ml/datasets.php?format=&task=&att=cat&area=&numAtt=&numIns=&type=&sort=nameUp&view=table.

Found that this one worked well for my investigation: https://archive.ics.uci.edu/ml/machine-learning-databases/solar-flare/.

That dataset has ONLY categorical features, so it's useful for testing categorical-specific stuff. I think I'll turn to it in the future when experimenting with lightgbm and other ML libraries.

example code (click me)
import lightgbm as lgb
import pandas as pd
data_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/solar-flare/flare.data2"
df = pd.read_csv(
    filepath_or_buffer=data_url,
    sep=" ",
    header=1,
    skip_blank_lines=True,
    names=[
        "class_code",
        "largest_spot_size_code",
        "spot_distribution_code",
        "activity",
        "evolution",
        "prev_flare_24h",
        "historically_complex",
        "became_complex_on_this_pass",
        "area",
        "largest_spot_area",
        "c_class_flare_count",
        "m_class_flare_count",
        "x_class_flare_count"
    ]
)

for col in df.columns:
    if pd.api.types.is_object_dtype(df[col]):
        df[col] = df[col].astype("category")

y = (df[["c_class_flare_count"]] > 0).astype("int").values.ravel()
feature_names = [c for c in df.columns if not c.endswith("flare_count")]
X = df[feature_names]

dtrain = lgb.Dataset(
    X,
    y,
    feature_name=feature_names,
    categorical_feature="auto",
    params={
        "min_data_in_bin": 5,
        "min_data_per_group": 1
    }
)
bst = lgb.train(
    params={
        'num_leaves': 7,
        'objective': 'binary',
        'min_data_in_leaf': 1,
        'verbose': 1
    },
    train_set=dtrain,
    num_boost_round=3
)

example_case = X[:1]
pd.set_option('display.width', 1000)
print(example_case)
print("---")
print(X["class_code"].cat.categories)
print("---")
lgb.create_tree_digraph(bst, example_case=example_case, tree_index=1)
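If this is run as a plain script rather than in a notebook, the returned graphviz.Digraph can also be written to disk; for example (the output filename here is arbitrary):

graph = lgb.create_tree_digraph(bst, example_case=example_case, tree_index=1)
graph.render(filename='tree_1', format='png', cleanup=True)  # writes tree_1.png next to the script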

I tried with a few records and can see that they go the correct way at each split, including the handling of || for categorical features!

(image: create_tree_digraph output with the example_case path highlighted)


In a future PR, could you please add an example using this new function to https://github.com/microsoft/LightGBM/blob/1cc9f9dcee15586aefdd8271a1c04ad010d3c53a/examples/python-guide/plot_example.py?

@jameslamb jameslamb removed request for hzy46 and tongwu-sh June 5, 2022 05:45
@jmoralez
Collaborator Author

jmoralez commented Jun 6, 2022

In a future PR, could you please add an example using this new function

Sure.

Thanks for the great suggestions on your review, as always!

Collaborator

@StrikerRUS StrikerRUS left a comment

Looks great!
But I'm afraid that this PR doesn't cover all possible scenarios of params and values in example_case (see the inline comment below).

Comment on lines 465 to 472
if root['decision_type'] == '==':
    thresholds = {int(x) for x in root['threshold'].split('||')}
    if example_case[split_feature] in thresholds:
        direction = 'left'
    else:
        direction = 'right'
else:
    direction = 'left' if example_case[split_feature] <= root['threshold'] else 'right'
Collaborator

@StrikerRUS StrikerRUS Jun 12, 2022

Does this code cover all combinations of use_missing, zero_as_missing params and NaN values in example_case?

LightGBM/src/io/tree.cpp

Lines 520 to 560 in 11110c5

std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
  Common::C_stringstream(str_buf);
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
  if (missing_type == MissingType::None
      || (missing_type == MissingType::Zero && default_left && kZeroThreshold < threshold_[node])) {
    str_buf << "if (fval <= " << threshold_[node] << ") {";
  } else if (missing_type == MissingType::Zero) {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
  Common::C_stringstream(str_buf);
  if (missing_type == MissingType::NaN) {
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
  int cat_idx = static_cast<int>(threshold_[node]);
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
  return str_buf.str();
}

Refer to #2921.

Collaborator Author

I added the decisions for numerical splits in 4c19af5 and for categorical splits in 9f65d63.

Collaborator

@StrikerRUS StrikerRUS left a comment

Awesome contribution!
Just some minor comments below:

Comment on lines 435 to 441
def _determine_direction_for_categorical_split(fval: float, thresholds: str, missing_type: str) -> str:
    if missing_type == 'None':
        int_fval = -1 if math.isnan(fval) else int(fval)
    else:
        int_fval = 0 if math.isnan(fval) else int(fval)
    int_thresholds = {int(t) for t in thresholds.split('||')}
    return 'left' if int_fval in int_thresholds else 'right'
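To illustrate the behavior under review, a short usage sketch (the threshold string and feature values are made up; it assumes the function above and import math are in scope):

print(_determine_direction_for_categorical_split(3.0, '0||3||7', 'None'))           # left  (3 is in {0, 3, 7})
print(_determine_direction_for_categorical_split(5.0, '0||3||7', 'None'))           # right (5 is not in the set)
print(_determine_direction_for_categorical_split(float('nan'), '0||3||7', 'None'))  # right (NaN maps to -1 when missing_type is 'None')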
Collaborator

Looks like NaN values are treated here properly, right?
Refer to #4468.

Collaborator Author

I updated this in 1e6f95d to do the same as #4468. -1 codes are replaced by NaN in

data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan})

so we only get NaN or a non-negative integer.

Collaborator

@StrikerRUS StrikerRUS Jul 3, 2022

Thanks for doing this!

Do we need to update the following code as well (in a separate PR)?

LightGBM/src/io/tree.cpp

Lines 549 to 553 in fb37e50

if (missing_type == MissingType::NaN) {
  str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
} else {
  str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
}


def _determine_direction_for_numeric_split(fval: float, threshold: float, missing_type: str, default_left: bool) -> str:
    le_threshold = fval <= threshold
    if missing_type == _MissingType.NONE or (missing_type == _MissingType.ZERO and default_left and ZERO_THRESHOLD < threshold):
Collaborator

You cannot compare a string with an Enum member this way; such a comparison will always return False.

missing_type = 'Zero'
missing_type == _MissingType.ZERO  # False

missing_type should either be converted to the Enum first, or you should compare like this: missing_type == _MissingType.ZERO.value.

Also, this incorrect comparison shows that the added tests are unfortunately unreliable 😢 They should fail, but all CI is green right now.
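A small self-contained illustration of the pitfall (the _MissingType definition below is assumed for this example and may differ from the one in the PR):

from enum import Enum

class _MissingType(Enum):
    NONE = 'None'
    ZERO = 'Zero'
    NAN = 'NaN'

missing_type = 'Zero'
print(missing_type == _MissingType.ZERO)                 # False: a str never equals a plain Enum member
print(missing_type == _MissingType.ZERO.value)           # True: compare against the member's value
print(_MissingType(missing_type) == _MissingType.ZERO)   # True: convert the string to the Enum first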

Collaborator Author

I updated the numerical split definition to the one here:

inline int NumericalDecision(double fval, int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  if (std::isnan(fval) && missing_type != MissingType::NaN) {
    fval = 0.0f;
  }
  if ((missing_type == MissingType::Zero && IsZero(fval))
      || (missing_type == MissingType::NaN && std::isnan(fval))) {
    if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) {
      return left_child_[node];
    } else {
      return right_child_[node];
    }
  }
  if (fval <= threshold_[node]) {
    return left_child_[node];
  } else {
    return right_child_[node];
  }
}
Now the test fails if I comment out the enum conversion. Let me know what you think; it looks different from the previous one.
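For reference, a rough Python transcription of that C++ decision (the helper name, the string-valued missing_type, and the ZERO_THRESHOLD constant are assumptions for this sketch, not necessarily the merged code):

import math

ZERO_THRESHOLD = 1e-35  # plays the role of LightGBM's kZeroThreshold (value assumed here)

def _numerical_decision(fval: float, threshold: float, missing_type: str, default_left: bool) -> str:
    # NaN is treated as zero unless the split explicitly uses NaN as its missing value
    if math.isnan(fval) and missing_type != 'NaN':
        fval = 0.0
    is_zero = -ZERO_THRESHOLD <= fval <= ZERO_THRESHOLD
    if (missing_type == 'Zero' and is_zero) or (missing_type == 'NaN' and math.isnan(fval)):
        return 'left' if default_left else 'right'
    return 'left' if fval <= threshold else 'right'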

Collaborator

Thanks for doing this!

Do we need to sync the code in the if/else dump as well (in a separate PR)?

LightGBM/src/io/tree.cpp

Lines 520 to 560 in 11110c5 (the same NumericalDecisionIfElse / CategoricalDecisionIfElse snippet quoted earlier)

Collaborator Author

Yes, I think that code may need to be updated. I'll take a look.

Collaborator

Kindly pinging @jmoralez about a possible follow-up PR.

@guolinke
Collaborator

It seems the Azure pipeline is broken 😢, but I don't have permission to fix it right now. @shiyu1994, can you take a look?

@shiyu1994
Collaborator

@guolinke I'm working on this and will fix this soon. Sorry for the delay.

Collaborator

@StrikerRUS StrikerRUS left a comment

Great work! Thank you very much!

@github-actions
Contributor

This pull request has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Aug 19, 2023

Development

Successfully merging this pull request may close these issues.

Add a function to plot tree with a case

6 participants