Can't create flow from ColumnTransformer #825

@amueller

I'm trying to write a general ColumnTransformer pipeline that feeds a decision stump.
I first passed the feature masks as boolean arrays, which broke. Converting them to lists of integer indices gets a bit further, but creating the flow still fails:

import openml
import numpy as np
import pandas as pd
from sklearn.pipeline import make_pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.tree import DecisionTreeClassifier

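def get_frequent_task(did):
    # Supervised classification tasks (task_type_id=1) defined on this dataset.
    tasks = openml.tasks.list_tasks(task_type_id=1, data_id=did)
    # Pick the task with the most runs (idxmax, rather than the first index entry).
    most_frequent_task = pd.Series({task: len(openml.runs.list_runs(task=[task])) for task in tasks}).idxmax()
    return (most_frequent_task, tasks[most_frequent_task]['NumberOfInstances'])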

dataset_id = 61
dataset = openml.datasets.get_dataset(dataset_id)

X, y, categorical_ind, feature_names = dataset.get_data(target=dataset.default_target_attribute, dataset_format='array')
categorical_ind = np.array(categorical_ind)
cat_idx, = np.where(categorical_ind)
cont_idx, = np.where(~categorical_ind)

clf = make_pipeline(
    ColumnTransformer([
        ('cat', make_pipeline(SimpleImputer(strategy='most_frequent'), OneHotEncoder()), cat_idx.tolist()),
        ('cont', make_pipeline(SimpleImputer(strategy='median'), StandardScaler()), cont_idx.tolist()),
    ]),
    DecisionTreeClassifier(max_depth=1),
)
if not categorical_ind.any():
    clf[0].set_params(cat='drop')
if not (~categorical_ind).any():
    clf[0].set_params(cont='drop')
hotencoded = clf[:-1].fit_transform(X)

bla = openml.runs.run_model_on_task(model=clf, task=get_frequent_task(dataset_id), return_flow=True)

TypeError: Second item of tuple does not match assumptions. Expected OpenMLFlow, got <class 'str'>

I guess a ColumnTransformer whose step has been set to the string "drop" can't be converted to a flow, since the converter expects an estimator there and gets a str?
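To make the guess above concrete, here is a minimal sketch using only scikit-learn and NumPy (no OpenML calls). The first part shows that `set_params(cat="drop")` leaves a plain `str` in `ColumnTransformer.transformers`, which matches the `got <class 'str'>` part of the error. The second part is a possible workaround, not a confirmed fix: `build_column_transformer` is a hypothetical helper that only adds a step when its column group is non-empty, so no `"drop"` string is ever needed. Whether the resulting pipeline then converts to a flow has not been verified here.

```python
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler


def build_column_transformer(categorical_mask):
    """Hypothetical helper: add a step only if its column group is non-empty."""
    mask = np.asarray(categorical_mask, dtype=bool)
    steps = []
    if mask.any():
        cat_pipe = make_pipeline(SimpleImputer(strategy="most_frequent"), OneHotEncoder())
        steps.append(("cat", cat_pipe, np.flatnonzero(mask).tolist()))
    if (~mask).any():
        cont_pipe = make_pipeline(SimpleImputer(strategy="median"), StandardScaler())
        steps.append(("cont", cont_pipe, np.flatnonzero(~mask).tolist()))
    return ColumnTransformer(steps)


# 1. set_params(cat="drop") swaps the estimator for the plain string "drop".
ct = ColumnTransformer([
    ("cat", OneHotEncoder(), []),
    ("cont", StandardScaler(), [0, 1, 2, 3]),
])
ct.set_params(cat="drop")
kinds = {name: type(trans).__name__ for name, trans, _ in ct.transformers}
print(kinds)

# 2. For an all-numeric dataset, the helper builds a transformer with no "cat" step.
ct_numeric = build_column_transformer([False, False, False, False])
step_names = [name for name, _, _ in ct_numeric.transformers]
print(step_names)
```

With this approach the `set_params(...='drop')` calls in the reproducer would no longer be needed, because an empty column group simply produces no step.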
