Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions openml/tasks/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _get_estimation_procedure_list():
procs_dict = xmltodict.parse(xml_string)
# Minimalistic check if the XML is useful
if "oml:estimationprocedures" not in procs_dict:
raise ValueError("Error in return XML, does not contain tag " "oml:estimationprocedures.")
raise ValueError("Error in return XML, does not contain tag oml:estimationprocedures.")
elif "@xmlns:oml" not in procs_dict["oml:estimationprocedures"]:
raise ValueError(
"Error in return XML, does not contain tag "
Expand All @@ -106,10 +106,19 @@ def _get_estimation_procedure_list():

procs = []
for proc_ in procs_dict["oml:estimationprocedures"]["oml:estimationprocedure"]:
task_type_int = int(proc_["oml:ttid"])
try:
task_type_id = TaskType(task_type_int)
except ValueError as e:
warnings.warn(
f"Could not create task type id for {task_type_int} due to error {e}",
RuntimeWarning,
)
continue
procs.append(
{
"id": int(proc_["oml:id"]),
"task_type_id": TaskType(int(proc_["oml:ttid"])),
"task_type_id": task_type_id,
"name": proc_["oml:name"],
"type": proc_["oml:type"],
}
Expand All @@ -124,7 +133,7 @@ def list_tasks(
size: Optional[int] = None,
tag: Optional[str] = None,
output_format: str = "dict",
**kwargs
**kwargs,
) -> Union[Dict, pd.DataFrame]:
"""
Return a number of tasks having the given tag and task_type
Expand Down Expand Up @@ -175,7 +184,7 @@ def list_tasks(
offset=offset,
size=size,
tag=tag,
**kwargs
**kwargs,
)


Expand Down Expand Up @@ -240,9 +249,18 @@ def __list_tasks(api_call, output_format="dict"):
tid = None
try:
tid = int(task_["oml:task_id"])
task_type_int = int(task_["oml:task_type_id"])
try:
task_type_id = TaskType(task_type_int)
except ValueError as e:
warnings.warn(
f"Could not create task type id for {task_type_int} due to error {e}",
RuntimeWarning,
)
continue
task = {
"tid": tid,
"ttid": TaskType(int(task_["oml:task_type_id"])),
"ttid": task_type_id,
"did": int(task_["oml:did"]),
"name": task_["oml:name"],
"task_type": task_["oml:task_type"],
Expand Down Expand Up @@ -330,7 +348,10 @@ def get_task(
task
"""
if not isinstance(task_id, int):
warnings.warn("Task id must be specified as `int` from 0.14.0 onwards.", DeprecationWarning)
warnings.warn(
"Task id must be specified as `int` from 0.14.0 onwards.",
DeprecationWarning,
)

try:
task_id = int(task_id)
Expand Down Expand Up @@ -466,9 +487,12 @@ def create_task(
estimation_procedure_id: int,
target_name: Optional[str] = None,
evaluation_measure: Optional[str] = None,
**kwargs
**kwargs,
) -> Union[
OpenMLClassificationTask, OpenMLRegressionTask, OpenMLLearningCurveTask, OpenMLClusteringTask
OpenMLClassificationTask,
OpenMLRegressionTask,
OpenMLLearningCurveTask,
OpenMLClusteringTask,
]:
"""Create a task based on different given attributes.

Expand Down Expand Up @@ -519,5 +543,5 @@ def create_task(
target_name=target_name,
estimation_procedure_id=estimation_procedure_id,
evaluation_measure=evaluation_measure,
**kwargs
**kwargs,
)