"""stac.load - dc.load from STAC Items."""

from __future__ import annotations

import dataclasses
import functools
import itertools
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)

import numpy as np
import pystac
import pystac.item
import xarray as xr
from dask.utils import ndeepmap
from numpy.typing import DTypeLike
from odc.geo import CRS, MaybeCRS, SomeResolution
from odc.geo.geobox import GeoBox, GeoboxAnchor, GeoboxTiles
from odc.geo.types import Unset

from ._mdtools import ConversionConfig, output_geobox, parse_items
from .loader import (
    DaskGraphBuilder,
    direct_chunked_load,
    mk_dataset,
    reader_driver,
    resolve_chunk_shape,
    resolve_load_cfg,
)
from .model import BandQuery, ParsedItem, RasterCollectionMetadata

DEFAULT_CHUNK_FOR_LOAD = 2048
"""Used to partition load when not using Dask."""

GroupbyCallback = Callable[[pystac.item.Item, ParsedItem, int], Any]

Groupby = Union[str, GroupbyCallback]
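

# A minimal sketch of a custom ``groupby`` callback with the
# ``(pystac.Item, ParsedItem, index) -> Any`` signature that
# :py:func:`load` accepts. Grouping items by calendar month is an
# illustrative choice, not something this module ships.
def _example_groupby_month(
    item: pystac.item.Item, parsed: ParsedItem, idx: int
) -> Any:
    # items falling in the same (year, month) end up on the same pixel plane
    return (parsed.nominal_datetime.year, parsed.nominal_datetime.month)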


def _collection(items: Iterable[ParsedItem]) -> RasterCollectionMetadata:
    for item in items:
        return item.collection
    raise ValueError("Can't load empty sequence")


def patch_urls(
    item: ParsedItem, edit: Callable[[str], str], bands: BandQuery = None
) -> ParsedItem:
    """
    Map function over dataset measurement urls.

    :param ds: Dataset to edit in place
    :param edit: Function that returns modified url from input url
    :param bands: Only edit specified bands, default is to edit all
    :return: Input dataset
    """

    if bands is None:
        _bands = {
            k: dataclasses.replace(src, uri=edit(src.uri))
            for k, src in item.bands.items()
        }
    else:
        _to_edit = set(map(item.collection.band_key, bands))
        _bands = {
            k: dataclasses.replace(src, uri=edit(src.uri) if k in _to_edit else src.uri)
            for k, src in item.bands.items()
        }

    return dataclasses.replace(item, bands=_bands)
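

# A minimal usage sketch for ``patch_urls``: rewrite band URIs before load,
# e.g. to sign them. ``_sign`` below is a hypothetical stand-in for a real
# URL signer such as ``planetary_computer.sign``.
def _example_sign_items(parsed_items: List[ParsedItem]) -> List[ParsedItem]:
    def _sign(url: str) -> str:
        return url + "?sig=..."  # hypothetical signature token

    # ``patch_urls`` returns updated copies; the inputs are not modified
    return [patch_urls(item, edit=_sign) for item in parsed_items]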


# pylint: disable=too-many-arguments,too-many-locals,too-many-statements
def load(
    items: Iterable[pystac.item.Item],
    bands: Optional[Union[str, Sequence[str]]] = None,
    *,
    groupby: Optional[Groupby] = "time",
    resampling: Optional[Union[str, Dict[str, str]]] = None,
    dtype: Union[DTypeLike, Dict[str, DTypeLike], None] = None,
    chunks: Optional[Dict[str, int | Literal["auto"]]] = None,
    pool: Union[ThreadPoolExecutor, int, None] = None,
    # Geo selection
    crs: MaybeCRS = Unset(),
    resolution: Optional[SomeResolution] = None,
    anchor: Optional[GeoboxAnchor] = None,
    geobox: Optional[GeoBox] = None,
    bbox: Optional[Tuple[float, float, float, float]] = None,
    lon: Optional[Tuple[float, float]] = None,
    lat: Optional[Tuple[float, float]] = None,
    x: Optional[Tuple[float, float]] = None,
    y: Optional[Tuple[float, float]] = None,
    like: Optional[Any] = None,
    geopolygon: Optional[Any] = None,
    # UI
    progress: Optional[Any] = None,
    fail_on_error: bool = True,
    # stac related
    stac_cfg: Optional[ConversionConfig] = None,
    patch_url: Optional[Callable[[str], str]] = None,
    preserve_original_order: bool = False,
    **kw,
) -> xr.Dataset:
    """
    STAC :class:`~pystac.item.Item` to :class:`xarray.Dataset`.

    Load several STAC :class:`~pystac.item.Item` objects (from the same or similar
    collections) as an :class:`xarray.Dataset`.

    This method can load pixel data directly on a local machine or construct a Dask
    graph that can be processed on a remote cluster.

    .. code-block:: python

       catalog = pystac.Client.open(...)
       query = catalog.search(...)
       xx = odc.stac.load(
           query.get_items(),
           bands=["red", "green", "blue"],
       )
       xx.red.plot.imshow(col="time")

    :param items:
       Iterable of STAC :class:`~pystac.item.Item` to load

    :param bands:
       List of band names to load, defaults to all. Also accepts a single band
       name as input

    .. rubric:: Common Options

    :param groupby:
       Controls what items get placed into the same pixel plane.

       The following have special meaning:

       * "time" items with exactly the same timestamp are grouped together
       * "solar_day" items captured on the same day adjusted for solar time
       * "id" every item is loaded separately

       Any other string is assumed to be a key in Item's properties dictionary.

       You can also supply a custom key function; it should take 3 arguments
       ``(pystac.Item, ParsedItem, index:int) -> Any``

    :param preserve_original_order:
       By default items are sorted by `time, id` within each group to make pixel
       fusing order deterministic. Setting this flag to ``True`` will instead keep
       items within each group in the same order as supplied, so that one can
       implement arbitrary priority for pixel overlap cases.

    :param resampling:
       Controls resampling strategy, can be specified per band

    :param dtype:
       Force output dtype, can be specified per band

    :param chunks:
       Rather than loading pixel data directly, construct Dask backed arrays.
       ``chunks={'x': 2048, 'y': 2048}``

    :param progress:
       Pass in ``tqdm`` progress bar or similar, only used in non-Dask load.

    :param fail_on_error:
       Set this to ``False`` to skip over load failures.

    :param pool:
       Use thread pool to perform load locally, only used in non-Dask load.

    .. rubric:: Control Pixel Grid of Output

    There are many ways to control the footprint and resolution of returned data.
    The most precise way is to use :py:class:`~odc.geo.geobox.GeoBox`,
    ``geobox=GeoBox(..)``. Similarly one can use ``like=xx`` to match the pixel
    grid to previously loaded data (``xx = odc.stac.load(...)``).

    Another common way is to configure crs and resolution only

    .. code-block:: python

       xx = odc.stac.load(...
           crs="EPSG:3857",
           resolution=10)

       # resolution units must match CRS
       # here we assume 1 degree == 111km to load at roughly
       # the same 10m resolution as statement above.
       yy = odc.stac.load(...
           crs="EPSG:4326",
           resolution=0.00009009)

    By default :py:func:`odc.stac.load` loads all available pixels in the requested
    projection and resolution. To limit the extent of loaded data you have to supply
    bounds via either ``geobox=`` or ``like=`` parameters (these also select
    projection and resolution). Alternatively use a pair of ``x, y`` or ``lon, lat``
    parameters. ``x, y`` allows you to specify bounds in the output projection,
    while ``lon, lat`` operate in degrees. You can also use ``bbox`` which is
    equivalent to ``lon, lat``.

    It should be noted that returned data is likely to reach outside of the
    requested bounds by a fraction of a pixel when using ``bbox``, ``x, y`` or
    ``lon, lat`` mechanisms. This is due to pixel grid "snapping". Pixel edges will
    still start at ``N*pixel_size`` where ``N is int`` regardless of the requested
    bounding box.

    :param crs:
       Load data in a given CRS. The special name ``"utm"`` is also understood, in
       which case an appropriate UTM projection will be picked based on the output
       bounding box.

    :param resolution:
       Set resolution of output in units of the output CRS. This can be a single
       float, in which case pixels are assumed to be square with the ``Y`` axis
       flipped. To specify non-square or non-flipped pixels use
       :py:func:`odc.geo.resxy_` or :py:func:`odc.geo.resyx_`. ``resolution=10`` is
       equivalent to ``resolution=odc.geo.resxy_(10, -10)``.

       Resolution must be supplied in the units of the output CRS. Units are
       commonly meters for *Projected* and degrees for *Geographic* CRSs.

    :param bbox:
       Specify bounding box in Lon/Lat. ``[min(lon), min(lat), max(lon), max(lat)]``

    :param lon:
       Define output bounds in Lon/Lat

    :param lat:
       Define output bounds in Lon/Lat

    :param x:
       Define output bounds in output projection coordinate units

    :param y:
       Define output bounds in output projection coordinate units

    :param anchor:
       Controls pixel snapping, default is to align the pixel grid to the
       ``X``/``Y`` axis such that pixel edges align with ``x=0, y=0``. Another
       common option is to align pixel centers to ``0,0`` rather than edges.

    :param geobox:
       Allows one to specify the exact region/resolution/projection using a
       :class:`~odc.geo.geobox.GeoBox` object

    :param like:
       Match the output grid to previously loaded data.

    :param geopolygon:
       Limit returned result to a bounding box of a given geometry. This could be
       an instance of :class:`~odc.geo.geom.Geometry`, a GeoJSON dictionary, a
       GeoPandas DataFrame, or any object implementing ``__geo_interface__``. We
       assume ``EPSG:4326`` projection for dictionary and Shapely inputs. CRS
       information available on GeoPandas inputs is understood correctly.

    .. rubric:: STAC Related Options

    :param stac_cfg:
       Controls interpretation of :py:class:`pystac.Item`. Mostly used to specify
       "missing" metadata like pixel data types.

    :param patch_url:
       Optionally transform the url of every band before loading

    :return:
       :py:class:`xarray.Dataset` with requested bands populated

    .. rubric:: Complete Example Code

    .. code-block:: python

       import planetary_computer as pc
       from pystac_client import Client

       from odc import stac

       catalog = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")
       query = catalog.search(
           collections=["sentinel-2-l2a"],
           datetime="2019-06-06",
           query={"s2:mgrs_tile": dict(eq="06VVN")},
       )

       xx = stac.load(
           query.get_items(),
           bands=["red", "green", "blue"],
           resolution=100,  # 1/10 of the native 10m resolution
           patch_url=pc.sign,
       )
       xx.red.plot.imshow(col="time", size=8, aspect=1)

    .. rubric:: Example Optional Configuration

    Sometimes a data source might be missing some optional STAC extensions. With
    the ``stac_cfg=`` parameter one can supply that information at load time.
    Configuration is per collection, per asset. You can provide information like
    pixel data type, the ``nodata`` value used, the ``unit`` attribute, and band
    aliases you would like to use.

    Sample ``stac_cfg={..}`` parameter:

    .. code-block:: yaml

       sentinel-2-l2a:  # < name of the collection, i.e. ``.collection_id``
         assets:
           "*":  # Band named "*" contains band info for "most" bands
             data_type: uint16
             nodata: 0
             unit: "1"
           SCL:  # Those bands that are different than "most"
             data_type: uint8
             nodata: 0
             unit: "1"
         aliases:  #< unique alias -> canonical map
           rededge: B05
           rededge1: B05
           rededge2: B06
           rededge3: B07

       some-other-collection:
         assets:
         #...

       "*":  # Applies to all collections if not defined on a collection
         warnings: ignore  # ignore|all (default all)
    """
    # pylint: disable=unused-argument,too-many-branches
    if bands is None:
        # dc.load name for bands is measurements
        bands = kw.pop("measurements", None)

    # normalize args
    # dc.load compatible name for crs is `output_crs`
    if isinstance(crs, Unset) or crs is None:
        crs = cast(MaybeCRS, kw.pop("output_crs", None))

    if groupby is None:
        groupby = "id"

    items = list(items)
    _parsed = list(parse_items(items, cfg=stac_cfg))

    gbox = output_geobox(
        _parsed,
        bands=bands,
        crs=crs,
        resolution=resolution,
        anchor=anchor,
        align=kw.get("align", None),
        geobox=geobox,
        like=like,
        geopolygon=geopolygon,
        bbox=bbox,
        lon=lon,
        lat=lat,
        x=x,
        y=y,
    )

    if gbox is None:
        raise ValueError("Failed to auto-guess CRS/resolution.")

    debug = kw.get("debug", False)

    # Check we have all the bands of interest
    # will raise ValueError if no such band/alias
    collection = _collection(_parsed)
    bands_to_load = collection.resolve_bands(bands)
    bands = list(bands_to_load)

    load_cfg = resolve_load_cfg(
        bands_to_load,
        resampling,
        dtype=dtype,
        use_overviews=kw.get("use_overviews", True),
        nodata=kw.get("nodata", None),
        fail_on_error=fail_on_error,
    )
    if dtype is None:
        _dtypes = sorted(
            set(cfg.dtype for cfg in load_cfg.values() if cfg.dtype is not None),
            key=lambda x: np.dtype(x).itemsize,
            reverse=True,
        )
        dtype = "uint16" if len(_dtypes) == 0 else _dtypes[0]

    if patch_url is not None:
        _parsed = [patch_urls(item, edit=patch_url, bands=bands) for item in _parsed]

    # Time dimension
    ((mid_lon, _),) = gbox.extent.centroid.to_crs("epsg:4326").points

    _grouped_idx = _group_items(
        items,
        _parsed,
        groupby,
        mid_lon,
        preserve_original_order=preserve_original_order,
    )
    tss = _extract_timestamps(ndeepmap(2, lambda idx: _parsed[idx], _grouped_idx))

    if chunks is not None:
        chunk_shape = resolve_chunk_shape(len(tss), gbox, chunks, dtype)
    else:
        chunk_shape = resolve_chunk_shape(
            len(tss),
            gbox,
            {dim: DEFAULT_CHUNK_FOR_LOAD for dim in gbox.dimensions},
            dtype,
        )

    # Spatio-temporal binning
    assert isinstance(gbox.crs, CRS)
    gbt = GeoboxTiles(gbox, chunk_shape[1:])
    tyx_bins = dict(_tyx_bins(_grouped_idx, _parsed, gbt))
    _parsed = [item.strip() for item in _parsed]

    def _with_debug_info(ds: xr.Dataset, **kw) -> xr.Dataset:
        # expose data for debugging
        if not debug:
            return ds

        from types import SimpleNamespace  # pylint: disable=import-outside-toplevel

        ds.encoding.update(
            debug=SimpleNamespace(
                gbt=gbt,
                mid_lon=mid_lon,
                parsed=_parsed,
                grouped_idx=_grouped_idx,
                tyx_bins=tyx_bins,
                bands_to_load=bands_to_load,
                load_cfg=load_cfg,
                **kw,
            )
        )
        return ds

    rdr = reader_driver(load_cfg)
    rdr_env = rdr.capture_env()

    if chunks is not None:
        _loader = DaskGraphBuilder(
            load_cfg,
            _parsed,
            tyx_bins,
            gbt,
            rdr_env,
            rdr,
            time_chunks=chunk_shape[0],
        )
        return _with_debug_info(mk_dataset(gbox, tss, load_cfg, _loader))

    return _with_debug_info(
        direct_chunked_load(
            load_cfg,
            _parsed,
            tyx_bins,
            gbt,
            tss,
            rdr_env,
            rdr,
            pool=pool,
            progress=progress,
        )
    )
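

# A minimal usage sketch of the Dask path above (illustrative only):
# passing ``chunks=`` makes ``load`` return lazy, Dask-backed arrays built
# via ``DaskGraphBuilder``; pixels are fetched only when computed.
def _example_lazy_load(some_items: Iterable[pystac.item.Item]) -> xr.Dataset:
    xx = load(some_items, bands=["red"], chunks={"x": 2048, "y": 2048})
    return xx.compute()  # triggers the actual IO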


def _extract_timestamps(grouped: List[List[ParsedItem]]) -> List[datetime]:
    def _ts(group: List[ParsedItem]) -> datetime:
        assert len(group) > 0
        return group[0].nominal_datetime.replace(tzinfo=None)

    return list(map(_ts, grouped))


# pylint: disable=unused-argument
def _groupby_solar_day(
    item: pystac.item.Item,
    parsed: ParsedItem,
    idx: int,
    lon: Optional[float] = None,
):
    if lon is None:
        return parsed.solar_date.date()
    return parsed.solar_date_at(lon).date()


def _groupby_time(
    item: pystac.item.Item,
    parsed: ParsedItem,
    idx: int,
):
    return parsed.nominal_datetime


def _groupby_id(
    item: pystac.item.Item,
    parsed: ParsedItem,
    idx: int,
):
    return idx


def _groupby_property(
    item: pystac.item.Item,
    parsed: ParsedItem,
    idx: int,
    key: str = "",
):
    return item.properties.get(key, None)


def _resolve_groupby(groupby: Groupby, lon: Optional[float] = None) -> GroupbyCallback:
    if not isinstance(groupby, str):
        return groupby
    if groupby == "time":
        return _groupby_time
    if groupby == "solar_day":
        return functools.partial(_groupby_solar_day, lon=lon)
    if groupby == "id":
        return _groupby_id

    return functools.partial(_groupby_property, key=groupby)


def _group_items(
    items: List[pystac.item.Item],
    parsed: List[ParsedItem],
    groupby: Groupby,
    lon: Optional[float] = None,
    preserve_original_order=False,
) -> List[List[int]]:
    assert len(items) == len(parsed)
    group_key = _resolve_groupby(groupby, lon=lon)

    def _sorter(idx: int):
        _group = group_key(items[idx], parsed[idx], idx)

        if preserve_original_order:
            # Sort by group_key but keeping original item order within each group
            return (_group, idx)

        # Sort by group_key, but then time,id within each group
        return (_group, parsed[idx].nominal_datetime, parsed[idx].id)

    ii = sorted(range(len(parsed)), key=_sorter)

    return [
        list(group)
        for _, group in itertools.groupby(
            ii, lambda idx: group_key(items[idx], parsed[idx], idx)
        )
    ]


def _tiles(item: ParsedItem, gbt: GeoboxTiles) -> Iterator[Tuple[int, int]]:
    geom = item.safe_geometry(gbt.base.crs)
    if geom is None:
        raise ValueError("Can not process items without defined footprint")
    yield from gbt.tiles(geom)


def _tyx_bins(
    grouped: List[List[int]],
    items: List[ParsedItem],
    gbt: GeoboxTiles,
) -> Iterator[Tuple[Tuple[int, int, int], List[int]]]:
    for t_idx, group in enumerate(grouped):
        _yx: Dict[Tuple[int, int], List[int]] = {}

        for item_idx in group:
            for idx in _tiles(items[item_idx], gbt):
                _yx.setdefault(idx, []).append(item_idx)

        yield from (((t_idx, *idx), ii_item) for idx, ii_item in _yx.items())
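

# An illustrative consumer of ``_tyx_bins`` output (hypothetical helper,
# not part of the module): it maps each ``(t, y, x)`` chunk key to the
# number of items overlapping that tile, the same structure the loaders
# above use to schedule per-chunk reads.
def _example_describe_bins(
    grouped: List[List[int]],
    parsed: List[ParsedItem],
    gbt: GeoboxTiles,
) -> Dict[Tuple[int, int, int], int]:
    return {tyx: len(ii) for tyx, ii in _tyx_bins(grouped, parsed, gbt)}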