39 commits
9d2048e
exclude gnome for full downloads if needed
tschaume Mar 5, 2025
505ddfe
query s3 for trajectories
tsmathis Oct 23, 2025
aee0f8c
add deltalake query support
tsmathis Oct 23, 2025
d5a25b1
linting + mistaken sed replace on 'where'
tsmathis Oct 23, 2025
2de051d
return trajectory as pmg dict
tsmathis Oct 23, 2025
7d0b8b7
update trajectory test
tsmathis Oct 23, 2025
7195adf
correct docstrs
tsmathis Oct 23, 2025
33b787f
Merge branch 'main' into deltalake
tschaume Oct 24, 2025
2664fcd
get access controlled batch ids from heartbeat
tsmathis Nov 3, 2025
b498a76
refactor
tsmathis Nov 4, 2025
7da6984
Merge branch 'main' into deltalake
tschaume Nov 4, 2025
948c108
auto dependency upgrades
invalid-email-address Nov 5, 2025
b0aed4f
Update testing.yml
tschaume Nov 5, 2025
a35bcb7
rm overlooked access of removed settings param
tsmathis Nov 5, 2025
9460601
refactor: consolidate requests to heartbeat for meta info
tsmathis Nov 5, 2025
05f1d0e
lint
tsmathis Nov 5, 2025
e685445
fix incomplete docstr
tsmathis Nov 5, 2025
bb0b238
typo
tsmathis Nov 5, 2025
dc0c949
Merge branch 'main' into deltalake
tsmathis Nov 10, 2025
fb84d73
revert testing endpoint
tsmathis Nov 10, 2025
5bdacf5
no parallel on batch_id_neq_any
tsmathis Nov 10, 2025
7ee5515
more resilient dataset path expansion
tsmathis Nov 12, 2025
ae7674d
missed field annotation update
tsmathis Nov 12, 2025
5538c74
coerce Path to str for deltalake lib
tsmathis Nov 12, 2025
f39c0d3
flush based on bytes
tsmathis Nov 14, 2025
a965255
iterate over individual rows for local dataset
tsmathis Nov 14, 2025
03b38e7
missed bounds check for updated iteration behavior
tsmathis Nov 14, 2025
3a44b4f
opt for module level logging over warnings lib
tsmathis Nov 14, 2025
b2a832f
lint
tsmathis Nov 14, 2025
4b4af48
Merge branch 'main' into deltalake
tsmathis Feb 9, 2026
9cf0713
missed during merge-conflict resolution
tsmathis Feb 9, 2026
ff17bea
bump deltalake
tsmathis Feb 9, 2026
cd6e4a4
explicit casts for arrow types for data read from delta
tsmathis Feb 9, 2026
0cf6a40
auto dependency upgrades
invalid-email-address Feb 9, 2026
7284d74
raise warnings for pythonic usage of MPDatasets
tsmathis Feb 9, 2026
961e21c
Automated dependency upgrades (#1058)
tsmathis Feb 9, 2026
e09fd48
incomplete docstr for MPDataset
tsmathis Feb 9, 2026
92f88ac
fix get_trajectory helper func + test
tsmathis Feb 9, 2026
551e448
missed passing mpdataset kwargs to lazy subresters on init
tsmathis Feb 10, 2026
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
@@ -56,7 +56,7 @@ jobs:
- name: Test with pytest
env:
MP_API_KEY: ${{ secrets[env.API_KEY_NAME] }}
#MP_API_ENDPOINT: https://api-preview.materialsproject.org/
# MP_API_ENDPOINT: https://api-preview.materialsproject.org/
run: |
pytest -n auto -x --cov=mp_api --cov-report=xml
- uses: codecov/codecov-action@v1
249 changes: 231 additions & 18 deletions mp_api/client/core/client.py
@@ -8,8 +8,10 @@
import gzip
import inspect
import itertools
import logging
import os
import platform
import shutil
import sys
import warnings
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
@@ -21,13 +23,17 @@
from json import JSONDecodeError
from math import ceil
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
from urllib.parse import quote
from urllib.parse import quote, urljoin

import boto3
import pyarrow as pa
import pyarrow.dataset as ds
import requests
from botocore import UNSIGNED
from botocore.config import Config
from botocore.exceptions import ClientError
from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
from emmet.core.arrow import arrowize
from emmet.core.utils import jsanitize
from pydantic import BaseModel, create_model
from requests.adapters import HTTPAdapter
@@ -38,6 +44,7 @@
from mp_api.client.core.exceptions import MPRestError
from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
from mp_api.client.core.utils import (
MPDataset,
load_json,
validate_api_key,
validate_endpoint,
@@ -62,6 +69,15 @@
__version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION")


hdlr = logging.StreamHandler()
fmt = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
hdlr.setFormatter(fmt)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(hdlr)


class _DictLikeAccess(BaseModel):
"""Define a pydantic mix-in which permits dict-like access to model fields."""

@@ -85,6 +101,7 @@ class BaseRester:
suffix: str = ""
document_model: type[BaseModel] | None = None
primary_key: str = "material_id"
delta_backed: bool = False

def __init__(
self,
@@ -98,6 +115,10 @@ def __init__(
timeout: int = 20,
headers: dict | None = None,
mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS,
local_dataset_cache: (
str | os.PathLike
) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE,
force_renew: bool = False,
**kwargs,
):
"""Initialize the REST API helper class.
@@ -129,6 +150,9 @@ def __init__(
timeout: Time in seconds to wait until a request timeout error is thrown
headers: Custom headers for localhost connections.
mute_progress_bars: Whether to disable progress bars.
local_dataset_cache: Target directory for downloading full datasets. Defaults
to 'mp_datasets' in the user's home directory.
force_renew: Option to overwrite an existing local dataset.
**kwargs: access to legacy kwargs that may be in the process of being deprecated
"""
self.api_key = validate_api_key(api_key)
Expand All @@ -141,7 +165,14 @@ def __init__(
self.timeout = timeout
self.headers = headers or {}
self.mute_progress_bars = mute_progress_bars
self.db_version = BaseRester._get_database_version(self.base_endpoint)

(
self.db_version,
self.access_controlled_batch_ids,
) = BaseRester._get_heartbeat_info(self.base_endpoint)

self.local_dataset_cache = local_dataset_cache
self.force_renew = force_renew

self._session = session
self._s3_client = s3_client
@@ -209,8 +240,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover

@staticmethod
@cache
def _get_database_version(endpoint):
"""The Materials Project database is periodically updated and has a
def _get_heartbeat_info(endpoint) -> tuple[str, str]:
"""DB version:
The Materials Project database is periodically updated and has a
database version associated with it. When the database is updated,
consolidated data (information about "a material") may and does
change, while calculation data about a specific calculation task
@@ -220,9 +252,24 @@ def _get_database_version(endpoint):
where "_DD" may be optional. An additional numerical or `postN` suffix
might be added if multiple releases happen on the same day.

Returns: database version as a string
Access Controlled Datasets:
Certain contributions to the Materials Project have access
control restrictions that require explicit agreement to the
Terms of Use for the respective datasets prior to access being
granted.

A full list of the Terms of Use for all contributions in the
Materials Project is available at:

https://next-gen.materialsproject.org/about/terms

Returns:
tuple with database version as a string and a comma separated
string with all calculation batch identifiers that have access
restrictions
"""
return requests.get(url=endpoint + "heartbeat").json()["db_version"]
response = requests.get(url=endpoint + "heartbeat").json()
return response["db_version"], response["access_controlled_batch_ids"]

def _post_resource(
self,
@@ -353,10 +400,7 @@ def _patch_resource(
raise MPRestError(str(ex))

def _query_open_data(
self,
bucket: str,
key: str,
decoder: Callable | None = None,
self, bucket: str, key: str, decoder: Callable | None = None
) -> tuple[list[dict] | list[bytes], int]:
"""Query and deserialize Materials Project AWS open data s3 buckets.

@@ -460,6 +504,12 @@ def _query_resource(
url = validate_endpoint(self.endpoint, suffix=suburl)

if query_s3:
pbar_message = ( # type: ignore
f"Retrieving {self.document_model.__name__} documents" # type: ignore
if self.document_model is not None
else "Retrieving documents"
)

if "/" not in self.suffix:
suffix = self.suffix
elif self.suffix == "molecules/summary":
@@ -469,15 +519,181 @@ def _query_resource(
suffix = infix if suffix == "core" else suffix
suffix = suffix.replace("_", "-")

# Paginate over all entries in the bucket.
# TODO: change when a subset of entries needed from DB
# Check if user has access to GNoMe
# temp suppress tqdm
re_enable = not self.mute_progress_bars
self.mute_progress_bars = True
has_gnome_access = bool(
self._submit_requests(
url=urljoin(self.base_endpoint, "materials/summary/"),
criteria={
"batch_id": "gnome_r2scan_statics",
"_fields": "material_id",
},
use_document_model=False,
num_chunks=1,
chunk_size=1,
timeout=timeout,
)
.get("meta", {})
.get("total_doc", 0)
)
self.mute_progress_bars = not re_enable

if "tasks" in suffix:
bucket_suffix, prefix = "parsed", "tasks_atomate2"
bucket_suffix, prefix = ("parsed", "core/tasks/")
else:
bucket_suffix = "build"
prefix = f"collections/{self.db_version.replace('.', '-')}/{suffix}"

bucket = f"materialsproject-{bucket_suffix}"

if self.delta_backed:
target_path = str(
self.local_dataset_cache.joinpath(f"{bucket_suffix}/{prefix}")
)
os.makedirs(target_path, exist_ok=True)

if DeltaTable.is_deltatable(target_path):
if self.force_renew:
shutil.rmtree(target_path)
logger.warning(
f"Regenerating {suffix} dataset at {target_path}..."
)
os.makedirs(target_path, exist_ok=True)
else:
logger.warning(
f"Dataset for {suffix} already exists at {target_path}, returning existing dataset."
)
logger.info(
"Delete or move existing dataset or re-run search query with MPRester(force_renew=True) "
"to refresh local dataset.",
)

return {
"data": MPDataset(
path=target_path,
document_model=self.document_model,
use_document_model=self.use_document_model,
)
}

tbl = DeltaTable(
f"s3a://{bucket}/{prefix}",
storage_options={
"AWS_SKIP_SIGNATURE": "true",
"AWS_REGION": "us-east-1",
},
)

controlled_batch_str = ",".join(
[f"'{tag}'" for tag in self.access_controlled_batch_ids]
)

predicate = (
" WHERE batch_id NOT IN (" # don't delete leading space
+ controlled_batch_str
+ ")"
if not has_gnome_access
else ""
)

builder = QueryBuilder().register("tbl", tbl)

# Setup progress bar
num_docs_needed = pa.table(
builder.execute("SELECT COUNT(*) FROM tbl").read_all()
)[0][0].as_py()

if not has_gnome_access:
num_docs_needed = self.count(
{"batch_id_neq_any": self.access_controlled_batch_ids}
)

pbar = (
tqdm(
desc=pbar_message,
total=num_docs_needed,
)
if not self.mute_progress_bars
else None
)

iterator = builder.execute("SELECT * FROM tbl" + predicate)

file_options = ds.ParquetFileFormat().make_write_options(
compression="zstd"
)

def _flush(
accumulator: list[pa.RecordBatch], group: int, schema: pa.Schema
):
# somewhere post datafusion 51.0.0 and arrow-rs 57.0.0
# casts to *View types began, need to cast back to base schema
# -> pyarrow is behind on implementation support for *View types
tbl = (
pa.Table.from_batches(accumulator)
.select(schema.names)
.cast(target_schema=schema)
)

ds.write_dataset(
tbl,
base_dir=target_path,
format="parquet",
basename_template=f"group-{group}-"
+ "part-{i}.zstd.parquet",
existing_data_behavior="overwrite_or_ignore",
max_rows_per_group=1024,
file_options=file_options,
)

group = 1
size = 0
accumulator = []
schema = pa.schema(arrowize(self.document_model))
for page in iterator:
# arro3 rb to pyarrow rb for compat w/ pyarrow ds writer
rg = pa.record_batch(page)
accumulator.append(rg)
page_size = page.num_rows
size += rg.get_total_buffer_size()

if pbar is not None:
pbar.update(page_size)

if size >= MAPI_CLIENT_SETTINGS.DATASET_FLUSH_THRESHOLD:
_flush(accumulator, group, schema)
group += 1
size = 0
accumulator.clear()

if accumulator:
_flush(accumulator, group + 1, schema)

if pbar is not None:
pbar.close()

logger.info(f"Dataset for {suffix} written to {target_path}")
logger.info("Converting to DeltaTable...")

convert_to_deltalake(target_path)

logger.info(
"Consult the delta-rs and pyarrow documentation for advanced usage: "
"delta-io.github.io/delta-rs, arrow.apache.org/docs/python"
)

return {
"data": MPDataset(
path=target_path,
document_model=self.document_model,
use_document_model=self.use_document_model,
)
}

# Paginate over all entries in the bucket.
# TODO: change when a subset of entries needed from DB
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

@@ -514,11 +730,6 @@ def _query_resource(
}

# Setup progress bar
pbar_message = ( # type: ignore
f"Retrieving {self.document_model.__name__} documents" # type: ignore
if self.document_model is not None
else "Retrieving documents"
)
num_docs_needed = int(self.count())
pbar = (
tqdm(
@@ -1365,6 +1576,8 @@ def __getattr__(self, v: str):
use_document_model=self.use_document_model,
headers=self.headers,
mute_progress_bars=self.mute_progress_bars,
local_dataset_cache=self.local_dataset_cache,
force_renew=self.force_renew,
)
return self.sub_resters[v]

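
Note on usage: the new keyword arguments surface through the rester hierarchy, so a full, delta-backed download can be driven from a single client instance. Below is a minimal sketch, assuming the top-level MPRester forwards local_dataset_cache and force_renew to its sub-resters the way BaseRester.__getattr__ does here, and that a criteria-free search() on a delta-backed rester is what routes through the new DeltaTable/QueryBuilder path; the API key, cache directory, and triggering call are illustrative, not confirmed by this diff.

    from pathlib import Path

    from mp_api.client import MPRester

    # Hypothetical full-download flow; the exact trigger for the delta-backed path may differ.
    with MPRester(
        api_key="your-api-key",                            # placeholder
        local_dataset_cache=Path.home() / "mp_datasets",   # where parquet/Delta files are written
        force_renew=False,                                 # True rebuilds an existing local dataset
    ) as mpr:
        dataset = mpr.materials.tasks.search()             # returns an MPDataset backed by local files

        # MPDataset is meant to be consumed lazily; direct indexing/iteration works but
        # raises advisory warnings (see mp_api/client/core/exceptions.py below).
        first_doc = dataset[0]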
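
Once convert_to_deltalake has run, the cache directory is an ordinary Delta table, so it can also be inspected directly with the same deltalake and pyarrow libraries the client imports. A short sketch, assuming the layout built above (<local_dataset_cache>/<bucket_suffix>/<prefix>, e.g. parsed/core/tasks/ for task documents); the path is illustrative.

    from pathlib import Path

    import pyarrow as pa
    from deltalake import DeltaTable, QueryBuilder

    table_path = Path.home() / "mp_datasets" / "parsed" / "core" / "tasks"  # illustrative

    dt = DeltaTable(str(table_path))  # deltalake expects a string path

    # SQL over the local table, mirroring the QueryBuilder usage in _query_resource.
    count = pa.table(
        QueryBuilder().register("tbl", dt).execute("SELECT COUNT(*) FROM tbl").read_all()
    )[0][0].as_py()

    # Or hand the table to pyarrow for columnar scans and filtering.
    preview = dt.to_pyarrow_dataset().head(5)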
13 changes: 13 additions & 0 deletions mp_api/client/core/exceptions.py
@@ -1,4 +1,5 @@
"""Define custom exceptions and warnings for the client."""

from __future__ import annotations


@@ -8,3 +9,15 @@ class MPRestError(Exception):

class MPRestWarning(Warning):
"""Raised when a query is malformed but interpretable."""


class MPDatasetIndexingWarning(Warning):
"""Raised during sub-optimal indexing of MPDatasets."""


class MPDatasetSlicingWarning(Warning):
"""Raised during sub-optimal slicing of MPDatasets."""


class MPDatasetIterationWarning(Warning):
"""Raised during sub-optimal iteration of MPDatasets."""
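
These are ordinary Warning subclasses, so users who knowingly accept the cost of pythonic access to an MPDataset can silence them with the standard library. A minimal sketch; which operations emit which warning is up to the MPDataset implementation and is not shown in this diff.

    import warnings

    from mp_api.client.core.exceptions import (
        MPDatasetIndexingWarning,
        MPDatasetIterationWarning,
        MPDatasetSlicingWarning,
    )

    # Silence the advisory warnings about sub-optimal MPDataset access patterns.
    for category in (
        MPDatasetIndexingWarning,
        MPDatasetSlicingWarning,
        MPDatasetIterationWarning,
    ):
        warnings.filterwarnings("ignore", category=category)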