Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion openml/datasets/data_feature.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# License: BSD 3-Clause

from typing import List


class OpenMLDataFeature(object):
"""
Expand All @@ -20,7 +22,14 @@ class OpenMLDataFeature(object):

LEGAL_DATA_TYPES = ["nominal", "numeric", "string", "date"]

def __init__(self, index, name, data_type, nominal_values, number_missing_values):
def __init__(
self,
index: int,
name: str,
data_type: str,
nominal_values: List[str],
number_missing_values: int,
):
if type(index) != int:
raise ValueError("Index is of wrong datatype")
if data_type not in self.LEGAL_DATA_TYPES:
Expand Down
143 changes: 96 additions & 47 deletions openml/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections import OrderedDict
import re
import gzip
import io
import logging
import os
import pickle
Expand All @@ -13,6 +12,7 @@
import numpy as np
import pandas as pd
import scipy.sparse
import xmltodict

from openml.base import OpenMLBase
from .data_feature import OpenMLDataFeature
Expand Down Expand Up @@ -125,8 +125,8 @@ def __init__(
update_comment=None,
md5_checksum=None,
data_file=None,
features=None,
qualities=None,
features_file: Optional[str] = None,
qualities_file: Optional[str] = None,
dataset=None,
):
def find_invalid_characters(string, pattern):
Expand Down Expand Up @@ -188,7 +188,7 @@ def find_invalid_characters(string, pattern):
self.default_target_attribute = default_target_attribute
self.row_id_attribute = row_id_attribute
if isinstance(ignore_attribute, str):
self.ignore_attribute = [ignore_attribute]
self.ignore_attribute = [ignore_attribute] # type: Optional[List[str]]
elif isinstance(ignore_attribute, list) or ignore_attribute is None:
self.ignore_attribute = ignore_attribute
else:
Expand All @@ -202,33 +202,25 @@ def find_invalid_characters(string, pattern):
self.update_comment = update_comment
self.md5_checksum = md5_checksum
self.data_file = data_file
self.features = None
self.qualities = None
self._dataset = dataset

if features is not None:
self.features = {}
for idx, xmlfeature in enumerate(features["oml:feature"]):
nr_missing = xmlfeature.get("oml:number_of_missing_values", 0)
feature = OpenMLDataFeature(
int(xmlfeature["oml:index"]),
xmlfeature["oml:name"],
xmlfeature["oml:data_type"],
xmlfeature.get("oml:nominal_value"),
int(nr_missing),
)
if idx != feature.index:
raise ValueError("Data features not provided " "in right order")
self.features[feature.index] = feature
if features_file is not None:
self.features = _read_features(
features_file
) # type: Optional[Dict[int, OpenMLDataFeature]]
else:
self.features = None

self.qualities = _check_qualities(qualities)
if qualities_file:
self.qualities = _read_qualities(qualities_file) # type: Optional[Dict[str, float]]
else:
self.qualities = None

if data_file is not None:
(
self.data_pickle_file,
self.data_feather_file,
self.feather_attribute_file,
) = self._create_pickle_in_cache(data_file)
rval = self._create_pickle_in_cache(data_file)
self.data_pickle_file = rval[0] # type: Optional[str]
self.data_feather_file = rval[1] # type: Optional[str]
self.feather_attribute_file = rval[2] # type: Optional[str]
else:
self.data_pickle_file, self.data_feather_file, self.feather_attribute_file = (
None,
Expand Down Expand Up @@ -357,7 +349,7 @@ def decode_arff(fh):
with gzip.open(filename) as fh:
return decode_arff(fh)
else:
with io.open(filename, encoding="utf8") as fh:
with open(filename, encoding="utf8") as fh:
return decode_arff(fh)

def _parse_data_from_arff(
Expand Down Expand Up @@ -405,12 +397,10 @@ def _parse_data_from_arff(
# can be encoded into integers
pd.factorize(type_)[0]
except ValueError:
raise ValueError(
"Categorical data needs to be numeric when " "using sparse ARFF."
)
raise ValueError("Categorical data needs to be numeric when using sparse ARFF.")
# string can only be supported with pandas DataFrame
elif type_ == "STRING" and self.format.lower() == "sparse_arff":
raise ValueError("Dataset containing strings is not supported " "with sparse ARFF.")
raise ValueError("Dataset containing strings is not supported with sparse ARFF.")

# infer the dtype from the ARFF header
if isinstance(type_, list):
Expand Down Expand Up @@ -743,7 +733,7 @@ def get_data(
to_exclude.extend(self.ignore_attribute)

if len(to_exclude) > 0:
logger.info("Going to remove the following attributes:" " %s" % to_exclude)
logger.info("Going to remove the following attributes: %s" % to_exclude)
keep = np.array(
[True if column not in to_exclude else False for column in attribute_names]
)
Expand Down Expand Up @@ -810,6 +800,10 @@ def retrieve_class_labels(self, target_name: str = "class") -> Union[None, List[
-------
list
"""
if self.features is None:
raise ValueError(
"retrieve_class_labels can only be called if feature information is available."
)
for feature in self.features.values():
if (feature.name == target_name) and (feature.data_type == "nominal"):
return feature.nominal_values
Expand Down Expand Up @@ -938,18 +932,73 @@ def _to_dict(self) -> "OrderedDict[str, OrderedDict]":
return data_container


def _check_qualities(qualities):
if qualities is not None:
qualities_ = {}
for xmlquality in qualities:
name = xmlquality["oml:name"]
if xmlquality.get("oml:value", None) is None:
value = float("NaN")
elif xmlquality["oml:value"] == "null":
value = float("NaN")
else:
value = float(xmlquality["oml:value"])
qualities_[name] = value
return qualities_
else:
return None
def _read_features(features_file: str) -> Dict[int, OpenMLDataFeature]:
features_pickle_file = _get_features_pickle_file(features_file)
try:
with open(features_pickle_file, "rb") as fh_binary:
features = pickle.load(fh_binary)
except: # noqa E722
with open(features_file, encoding="utf8") as fh:
features_xml_string = fh.read()
xml_dict = xmltodict.parse(
features_xml_string, force_list=("oml:feature", "oml:nominal_value")
)
features_xml = xml_dict["oml:data_features"]

features = {}
for idx, xmlfeature in enumerate(features_xml["oml:feature"]):
nr_missing = xmlfeature.get("oml:number_of_missing_values", 0)
feature = OpenMLDataFeature(
int(xmlfeature["oml:index"]),
xmlfeature["oml:name"],
xmlfeature["oml:data_type"],
xmlfeature.get("oml:nominal_value"),
int(nr_missing),
)
if idx != feature.index:
raise ValueError("Data features not provided in right order")
features[feature.index] = feature

with open(features_pickle_file, "wb") as fh_binary:
pickle.dump(features, fh_binary)
return features


def _get_features_pickle_file(features_file: str) -> str:
"""This function only exists so it can be mocked during unit testing"""
return features_file + ".pkl"


def _read_qualities(qualities_file: str) -> Dict[str, float]:
qualities_pickle_file = _get_qualities_pickle_file(qualities_file)
try:
with open(qualities_pickle_file, "rb") as fh_binary:
qualities = pickle.load(fh_binary)
except: # noqa E722
with open(qualities_file, encoding="utf8") as fh:
qualities_xml = fh.read()
xml_as_dict = xmltodict.parse(qualities_xml, force_list=("oml:quality",))
qualities = xml_as_dict["oml:data_qualities"]["oml:quality"]
qualities = _check_qualities(qualities)
with open(qualities_pickle_file, "wb") as fh_binary:
pickle.dump(qualities, fh_binary)
return qualities


def _get_qualities_pickle_file(qualities_file: str) -> str:
"""This function only exists so it can be mocked during unit testing"""
return qualities_file + ".pkl"


def _check_qualities(qualities: List[Dict[str, str]]) -> Dict[str, float]:
qualities_ = {}
for xmlquality in qualities:
name = xmlquality["oml:name"]
if xmlquality.get("oml:value", None) is None:
value = float("NaN")
elif xmlquality["oml:value"] == "null":
value = float("NaN")
else:
value = float(xmlquality["oml:value"])
qualities_[name] = value
return qualities_
Loading