"""Metadata and data loading model classes."""
import datetime as dt
import math
from copy import copy
from dataclasses import astuple, dataclass, field, replace
from typing import (
Any,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)
from odc.geo import CRS, Geometry, MaybeCRS
from odc.geo.geobox import GeoBox
from odc.geo.types import Unset
from .loader.types import RasterBandMetadata, RasterSource, norm_band_metadata
BandKey = Tuple[str, int]
"""Pair of (asset name, band index inside that asset); the index starts at 1."""
BandQuery = Union[str, Sequence[str], None]
"""Band selection: one band name, a sequence of band names, or ``None`` for all bands."""
@dataclass(eq=True, frozen=True)
class ParsedItem(Mapping[Union[BandKey, str], RasterSource]):
    """
    Captures the essential parts of a STAC Item needed for data loading.

    Only includes raster bands of interest. Behaves as a read-only mapping
    from band key to :py:class:`RasterSource`. Lookups accept either a
    ``(asset, index)`` tuple or a band name/alias string; strings are resolved
    through ``self.collection.band_key``. Iteration yields the tuple keys only.
    """

    id: str
    """Item id copied from STAC."""
    collection: RasterCollectionMetadata
    """Collection this Item is part of."""
    bands: Dict[BandKey, RasterSource]
    """Raster bands."""
    geometry: Optional[Geometry] = None
    """Footprint of the dataset."""
    datetime: Optional[dt.datetime] = None
    """Nominal timestamp."""
    datetime_range: Tuple[Optional[dt.datetime], Optional[dt.datetime]] = None, None
    """Time period covered."""
    href: Optional[str] = None
    """Self link from stac item."""

    def geoboxes(self, bands: BandQuery = None) -> Tuple[GeoBox, ...]:
        """
        Unique ``GeoBox`` s, highest resolution first.

        :param bands: which bands to consider, default is all
        :return: De-duplicated geoboxes of the selected bands, ordered by their
                 smallest absolute pixel size (finest first). Bands absent from
                 this item, or without a geobox, are skipped.
        """
        bands = self.collection.normalize_band_query(bands)

        def _resolution(g: GeoBox) -> float:
            # Finest of the x/y pixel sizes; sign is dropped because pixel
            # sizes may be negative.
            return min(g.resolution.map(abs).xy)  # type: ignore

        gbx: Set[GeoBox] = set()
        for name in bands:
            b = self.bands.get(self.collection.band_key(name), None)
            if b is not None and b.geobox is not None:
                gbx.add(b.geobox)
        return tuple(sorted(gbx, key=_resolution))

    def crs(self, bands: BandQuery = None) -> Optional[CRS]:
        """
        First non-null CRS across assets.

        Geoboxes are visited in :py:meth:`geoboxes` order (highest resolution
        first).

        :param bands: which bands to consider, default is all
        :return: The CRS, or ``None`` when no selected band has one.
        """
        for gbox in self.geoboxes(bands):
            if gbox.crs is not None:
                return gbox.crs
        return None

    def image_geometry(
        self,
        crs: MaybeCRS = Unset(),
        bands: BandQuery = None,
    ) -> Optional[Geometry]:
        """
        Extract footprint of a given band(s) from proj metadata in a given projection.

        Uses the first (highest resolution) geobox that has a CRS.

        :param crs: Output projection. When unset or ``None``, the native
                    extent of that geobox is returned.
        :param bands: which bands to consider, default is all
        :return: Footprint geometry, or ``None`` when no selected band has a
                 geobox with a CRS.
        """
        if isinstance(crs, Unset):
            crs = None
        for gbox in self.geoboxes(bands):
            if gbox.crs is not None:
                if crs is None or crs == gbox.crs:
                    return gbox.extent
                return gbox.footprint(crs)
        return None

    def safe_geometry(
        self,
        crs: MaybeCRS = Unset(),
        bands: BandQuery = None,
    ) -> Optional[Geometry]:
        """
        Get item geometry footprint in desired projection or native.

        1. Use full-image footprint if proj data is available
        2. Fall back to item geometry if not

        :param crs: Output projection; unset or ``None`` keeps the native one.
        :param bands: which bands to consider, default is all
        :return: Footprint geometry, or ``None`` when neither proj data nor
                 item geometry is available.
        """
        img_geom = self.image_geometry(crs, bands=bands)
        if img_geom is not None:
            return img_geom
        if self.geometry is None:
            return None
        if crs is None or isinstance(crs, Unset):
            return self.geometry

        N = 100  # minimum number of points along perimeter we desire
        # sqrt(area) * 4 is the perimeter of a square of equal area; dividing
        # by N gives a sample spacing that yields roughly N perimeter points.
        min_sample_distance = math.sqrt(self.geometry.area) * 4 / N
        return self.geometry.to_crs(
            crs,
            min_sample_distance,
            check_and_fix=True,
        ).dropna()

    def resolve_bands(
        self, bands: BandQuery = None
    ) -> Dict[str, Optional[RasterSource]]:
        """
        Query bands taking care of aliases.

        :param bands: which bands to resolve, default is all
        :return: Mapping from each requested name to its source, or ``None``
                 when this item lacks that band.
        """
        bands = self.collection.normalize_band_query(bands)
        canon = self.collection.band_key
        return {k: self.bands.get(canon(k), None) for k in bands}

    def __getitem__(self, band: Union[str, BandKey]) -> RasterSource:
        """
        Query band taking care of aliases.

        :raises: :py:class:`KeyError`
        """
        if isinstance(band, str):
            band = self.collection.band_key(band)
        return self.bands[band]

    def __len__(self) -> int:
        return len(self.bands)

    def __iter__(self) -> Iterator[BandKey]:
        yield from self.bands

    def __contains__(self, k: object) -> bool:
        """
        Membership test for band keys or band names/aliases.

        Strings that ``band_key`` rejects with ``ValueError`` are reported as
        absent rather than raising.
        """
        if isinstance(k, str):
            try:
                return self.collection.band_key(k) in self.bands
            except ValueError:
                return False
        if isinstance(k, tuple):
            return k in self.bands
        return False

    @property
    def nominal_datetime(self) -> dt.datetime:
        """
        Resolve timestamp to a single value.

        - datetime if set
        - start_datetime if set
        - end_datetime if set
        - ``raise ValueError`` otherwise
        """
        for ts in [self.datetime, *self.datetime_range]:
            if ts is not None:
                return ts
        raise ValueError("Timestamp was not populated.")

    @property
    def mid_longitude(self) -> Optional[float]:
        """
        Return longitude of the center point.

        The centroid is computed in the geometry's own CRS and then projected
        to EPSG:4326. Used for "solar day" computation. ``None`` when the
        item has no geometry.
        """
        if self.geometry is None:
            return None
        ((lon, _),) = self.geometry.centroid.to_crs("epsg:4326").points
        return lon

    @property
    def solar_date(self) -> dt.datetime:
        """
        Nominal datetime adjusted by longitude.

        Uses :py:attr:`mid_longitude`. Falls back to the unadjusted
        :py:attr:`nominal_datetime` when there is no geometry.
        """
        lon = self.mid_longitude
        if lon is None:
            return self.nominal_datetime
        return _convert_to_solar_time(self.nominal_datetime, lon)

    def solar_date_at(self, lon: float) -> dt.datetime:
        """
        Nominal datetime adjusted by the supplied longitude.

        :param lon: Longitude in degrees used to compute the hour offset.
        """
        return _convert_to_solar_time(self.nominal_datetime, lon)

    def strip(self) -> "ParsedItem":
        """
        Copy of self but with stripped bands.

        Every band is replaced by its ``strip()`` result; other fields are
        shared with the original.
        """
        return replace(self, bands={k: band.strip() for k, band in self.bands.items()})

    def assets(self) -> Dict[str, List[RasterSource]]:
        """
        Extract bands grouped by asset they belong to.

        :return: Asset name -> band sources ordered by band index within the
                 asset.
        """
        assets: Dict[str, List[Tuple[int, RasterSource]]] = {}
        for (asset, idx), src in self.bands.items():
            assets.setdefault(asset, []).append((idx, src))
        return {
            k: [src for _, src in sorted(srcs, key=lambda x: x[0])]
            for k, srcs in assets.items()
        }

    def __hash__(self) -> int:
        # Hashes only the item id and collection name, not every field.
        # NOTE(review): relies on equal collections having equal names to
        # stay consistent with the field-wise dataclass ``__eq__``.
        return hash((self.id, self.collection.name))

    def __dask_tokenize__(self):
        # Token used by dask's ``tokenize``. Note: ``geometry`` is not included.
        return (
            self.id,
            self.collection,
            self.bands,
            self.href,
            self.datetime,
            self.datetime_range,
        )
@dataclass(frozen=True)
class MDParseConfig:
    """
    Item parsing config.

    - ``band_defaults``: band metadata built from the ``"*"`` entry of the
      ``"assets"`` config (presumably the fallback for bands without a
      specific entry; confirm with the parser)
    - ``band_cfg``: per-entry band metadata, keyed by the names in
      ``"assets"``. Each entry is normalised against ``band_defaults``.
    - ``aliases``: alias name -> ``(asset, band_index)`` key
    - ``ignore_proj``: value of the ``"ignore_proj"`` setting (presumably
      makes parsing skip projection metadata; verify against the parser)
    """

    band_defaults: RasterBandMetadata = field(default_factory=RasterBandMetadata)
    band_cfg: Dict[str, RasterBandMetadata] = field(default_factory=dict)
    aliases: Dict[str, BandKey] = field(default_factory=dict)
    ignore_proj: bool = False

    @staticmethod
    def from_dict(
        collection_id: str, cfg: Optional[Dict[str, Any]] = None
    ) -> "MDParseConfig":
        """
        Build the parsing config for one collection from a user-supplied dict.

        ``cfg`` maps collection id to settings; its ``"*"`` entry supplies
        defaults. The ``"*"`` entry and the ``collection_id`` entry are merged
        shallowly. A top-level key set for the collection (e.g. ``"assets"``)
        therefore replaces the default value as a whole.

        Recognised settings are ``"assets"``, ``"aliases"`` and ``"ignore_proj"``.
        A string alias target ``"name"`` is normalised to ``("name", 1)``.

        :param collection_id: Collection whose settings are extracted.
        :param cfg: Configuration dict. ``None`` (the default) means no
                    configuration, so every setting takes its default.
        """
        # Treat a missing config as empty. Before this fix, the default value
        # was the ``typing.Dict[...]`` object itself, which has no ``.get``.
        if cfg is None:
            cfg = {}

        # Shallow copy so that updating with collection-specific settings
        # does not mutate the caller's "*" entry.
        _cfg = copy(cfg.get("*", {}))
        _cfg.update(cfg.get(collection_id, {}))

        band_defaults, band_cfg = _norm_band_cfg(_cfg.get("assets", {}))
        aliases = {
            alias: ((band, 1) if isinstance(band, str) else band)
            for alias, band in _cfg.get("aliases", {}).items()
        }
        ignore_proj: bool = _cfg.get("ignore_proj", False)
        return MDParseConfig(
            band_defaults=band_defaults,
            band_cfg=band_cfg,
            ignore_proj=ignore_proj,
            aliases=aliases,
        )
def _norm_band_cfg(
    cfg: Dict[str, Any]
) -> Tuple[RasterBandMetadata, Dict[str, RasterBandMetadata]]:
    """
    Normalise the ``assets`` section of a parsing config.

    The ``"*"`` entry (or an empty dict when absent) is normalised first and
    becomes the fallback. Every other entry is then normalised against that
    fallback, in the mapping's iteration order.

    :return: ``(fallback, {name: normalised metadata})``; the ``"*"`` key is
             excluded from the second element.
    """
    defaults = norm_band_metadata(cfg.get("*", {}))
    per_band = {}
    for name, band_cfg in cfg.items():
        if name == "*":
            continue
        per_band[name] = norm_band_metadata(band_cfg, defaults)
    return defaults, per_band
def _convert_to_solar_time(utc: dt.datetime, longitude: float) -> dt.datetime:
    """
    Shift ``utc`` by a whole-hour offset derived from ``longitude``.

    One hour corresponds to 15 degrees of longitude (24 h / 360 deg). The
    hour count is truncated toward zero by ``int`` rather than rounded.
    """
    whole_hours = int(longitude / 15)
    return utc + dt.timedelta(hours=whole_hours)
def norm_key(k: Union[str, BandKey]) -> BandKey:
    """
    Normalise a band identifier to an ``(asset_name, band_index)`` pair.

    - ``("band", i)`` -> ``("band", i)`` (non-strings are returned as-is)
    - ``"band"``      -> ``("band", 1)``
    - ``"band.3"``    -> ``("band", 3)``

    For strings, the text after the last ``.`` is parsed with ``int``, so a
    non-integer suffix such as ``"band.x"`` raises :py:class:`ValueError`.
    """
    if not isinstance(k, str):
        return k
    name, sep, suffix = k.rpartition(".")
    if not sep:
        return (k, 1)
    return name, int(suffix)