import os
import io
import logging
import requests
import hashlib
import tempfile
import zipfile
import dateutil.parser
from pathlib import Path
from multiformats import multihash
import binascii
from lxml import etree
from datetime import datetime, timedelta, timezone
from osgeo import ogr, osr
from pyproj import Geod
import numpy as np
import spatialist
from spatialist.raster import Raster, rasterize
from spatialist.vector import bbox, intersect, boundary, vectorize, Vector, crsConvert
import pyroSAR
from pyroSAR.ancillary import Lock, LockCollection
from pyroSAR import identify_many
from collections import defaultdict
from typing import Callable, List, TypeVar
# module-level logger registered under the name 'cesard'
log = logging.getLogger('cesard')

# generic type variables used in the annotations of `group_by_attr`
T = TypeVar('T')  # any type
K = TypeVar('K')  # key
[docs]
def buffer_min_overlap(
        geom1: Vector,
        geom2: Vector,
        percent: int | float = 1,
        step: int | float | None = None
) -> Vector:
    """
    Buffer a rectangular geometry to a minimum overlap with a second geometry.

    The extent of `geom1` is iteratively enlarged by `step` on all four sides
    until the area of its intersection with `geom2` exceeds `percent` percent
    of the area of `geom2`. If the unbuffered extent already exceeds the
    threshold, a bounding box of the original extent is returned.

    Parameters
    ----------
    geom1:
        the geometry to be buffered; only its extent (bounding box) is used
    geom2:
        the reference geometry to intersect with
    percent:
        the overlap threshold in percent of the area of `geom2`.
        The overlap of the result is strictly larger than this value.
        Must be in the range [0, 100) because the overlap cannot exceed 100 %.
    step:
        the buffering step size in units of the CRS. If None, the step size
        is 0.1 % of the average edge length of the extent of `geom1`.

    Returns
    -------
    a new bounding box geometry in the CRS of `geom1`

    Raises
    ------
    ValueError
        if the geometries have different CRSs, `percent` is outside [0, 100),
        the step size is not positive or `geom2` has no area
    """
    # an overlap larger than 100 % can never be reached and a negative
    # threshold would skip the search entirely
    if not 0 <= percent < 100:
        raise ValueError("'percent' must be in the range [0, 100)")
    geom1_crs = geom1.getProjection(type='epsg')
    geom2_crs = geom2.getProjection(type='epsg')
    if geom1_crs != geom2_crs:
        raise ValueError('both geometries must have the same CRS')
    geom2_area = geom2.getArea()
    if geom2_area <= 0:
        raise ValueError("'geom2' must have an area larger than zero")
    ext = geom1.extent
    if step is None:
        xdist = ext['xmax'] - ext['xmin']
        ydist = ext['ymax'] - ext['ymin']
        step = (xdist + ydist) / 2 / 1000
    # a zero step would never enlarge the extent
    if step <= 0:
        raise ValueError('the buffering step size must be larger than zero')
    ext2 = ext.copy()
    buffer = 0
    while True:
        dist = buffer * step
        ext2['xmin'] = ext['xmin'] - dist
        ext2['xmax'] = ext['xmax'] + dist
        ext2['ymin'] = ext['ymin'] - dist
        ext2['ymax'] = ext['ymax'] + dist
        with bbox(coordinates=ext2, crs=geom1_crs) as geom3:
            ext3 = geom3.extent
            inter = intersect(obj1=geom2, obj2=geom3)
            if inter is None:
                overlap = 0
            else:
                # make sure the intersection is released even on failure
                try:
                    overlap = inter.getArea() / geom2_area * 100
                finally:
                    inter.close()
        if overlap > percent:
            break
        buffer += 1
    return bbox(coordinates=ext3, crs=geom1_crs)
[docs]
def buffer_time(
        start: str,
        stop: str,
        as_datetime: bool = False,
        str_format: str = '%Y%m%dT%H%M%S',
        **kwargs
) -> tuple[str | datetime, str | datetime]:
    """
    Widen a time range by a time delta on both ends.

    The delta is subtracted from `start` and added to `stop` after both
    have been converted to UTC via :func:`date_to_utc`.

    Parameters
    ----------
    start:
        the start time; timezone-unaware dates are interpreted as UTC.
    stop:
        the stop time; timezone-unaware dates are interpreted as UTC.
    as_datetime:
        return datetime objects instead of strings?
    str_format:
        the output string format (ignored if `as_datetime` is True)
    kwargs
        time arguments passed to :func:`datetime.timedelta`

    Returns
    -------
    the buffered start and stop time as string or datetime object
    """
    margin = timedelta(**kwargs)
    buffered = (
        date_to_utc(start, as_datetime=True) - margin,
        date_to_utc(stop, as_datetime=True) + margin,
    )
    if as_datetime:
        return buffered
    return tuple(moment.strftime(str_format) for moment in buffered)
[docs]
def check_scene_consistency(
        scenes: list[str | pyroSAR.drivers.ID]
) -> None:
    """
    Verify that all scenes of a selection share the same key attributes.

    The scenes are passed through :func:`pyroSAR.identify_many` and the
    following attributes of the resulting objects must have exactly one
    distinct value across the selection:

    - sensor
    - acquisition_mode
    - product
    - frameNumber (data take ID for Sentinel-1)

    Parameters
    ----------
    scenes:
        the scene selection

    Raises
    ------
    RuntimeError
        if the number of distinct values of an attribute is not one
    """
    identified = identify_many(scenes)
    for attr in ('sensor', 'acquisition_mode', 'product', 'frameNumber'):
        values = {getattr(scene, attr) for scene in identified}
        if len(values) != 1:
            msg = f"scene selection differs in attribute '{attr}': {values}"
            raise RuntimeError(msg)
[docs]
def check_spacing(
spacing: int | float
) -> None:
"""
Check whether the spacing fits into the MGRS tile boundaries.
Parameters
----------
spacing:
the target pixel spacing in meters
"""
# 109800 m is the edge length of one tile.
# the overlap between tiles is either 9780 or 9840 m.
overlap_edges = [9780, 9840, 109800]
options = []
for i in range(1, 61 * 10): # 60 is the largest spacing
if all([x % (i / 10) == 0 for x in overlap_edges]):
options.append(i / 10)
if spacing not in options:
raise RuntimeError(f'target spacing of {spacing} m does not align '
f'with the S2-MGRS tile size and overlaps.\n'
f'Options: {options}')
[docs]
def combine_polygons(
        vector: Vector | list[Vector],
        crs: int | str = 4326,
        multipolygon: bool = False,
        layer_name: str = 'combined'
) -> Vector:
    """
    Combine polygon vector objects into one.

    The output is a single in-memory vector object with the polygons either
    stored in separate features or combined into a single multipolygon
    geometry. Geometries are transformed to the target CRS if the CRS of
    their source object differs from it.

    Attribute fields are only carried over when `multipolygon=False`;
    the multipolygon feature is written without field values.

    Parameters
    ----------
    vector:
        the input vector object(s). Providing only one object only makes sense when `multipolygon=True`.
    crs:
        the target CRS. Default: EPSG:4326
    multipolygon:
        combine all polygons into one multipolygon?
        Default False: write each polygon into a separate feature.
    layer_name:
        the layer name of the output vector object.

    Returns
    -------
    the combined vector object

    Raises
    ------
    RuntimeError
        if any input geometry is not of type POLYGON
    """
    # accept a single object by wrapping it into a list
    if not isinstance(vector, list):
        vector = [vector]
    ##############################################################################
    # check geometry types
    geometry_names = []
    field_defs = []
    for item in vector:
        # collect the field definitions of all inputs for later re-creation
        field_defs.extend(item.fieldDefs)
        for feature in item.layer:
            geom = feature.GetGeometryRef()
            geometry_names.append(geom.GetGeometryName())
        # rewind the layer so that it can be iterated again below
        item.layer.ResetReading()
    # drop the reference to the last geometry
    # NOTE(review): presumably to avoid keeping an OGR geometry whose owning
    # feature is gone - confirm
    geom = None
    geometry_names = list(set(geometry_names))
    if not all(x == 'POLYGON' for x in geometry_names):
        raise RuntimeError('All geometries must be of type POLYGON')
    ##############################################################################
    vec = Vector(driver='Memory')
    srs_out = crsConvert(crs, 'osr')
    if multipolygon:
        # a single multipolygon container collecting all polygons
        geom_type = ogr.wkbMultiPolygon
        geom_out = [ogr.Geometry(geom_type)]
    else:
        # one output geometry per input feature
        geom_type = ogr.wkbPolygon
        geom_out = []
    # per-feature field values; only filled if `multipolygon` is False
    fields = []
    vec.addlayer(name=layer_name, srs=srs_out, geomType=geom_type)
    for item in vector:
        fieldnames = item.fieldnames
        # only set up a coordinate transformation if the CRSs differ
        if item.srs.IsSame(srs_out):
            coord_trans = None
        else:
            coord_trans = osr.CoordinateTransformation(item.srs, srs_out)
        for feature in item.layer:
            geom = feature.GetGeometryRef()
            if coord_trans is not None:
                geom.Transform(coord_trans)
            # clones are stored so that the collected geometries do not
            # reference the feature's own geometry object
            if multipolygon:
                geom_out[0].AddGeometry(geom.Clone())
            else:
                fields.append({x: feature.GetField(x) for x in fieldnames})
                geom_out.append(geom.Clone())
        item.layer.ResetReading()
    geom = None
    if multipolygon:
        # merge all collected polygons into one geometry and write it
        # as a single feature without attributes
        geom_out = geom_out[0].UnionCascaded()
        vec.addfeature(geom_out)
    else:
        # create each field once; definitions whose name already exists
        # in the output are skipped
        for field_def in field_defs:
            if field_def.GetName() not in vec.fieldnames:
                vec.layer.CreateField(field_def)
        # `fields` and `geom_out` were filled in lockstep, so index i
        # refers to the same input feature in both
        for i, geom in enumerate(geom_out):
            vec.addfeature(geometry=geom, fields=fields[i])
    geom_out = None
    return vec
[docs]
def compute_hash(
        file_path: str,
        algorithm: str = 'sha256',
        chunk_size: int = 8192,
        multihash_encode: bool = True
) -> str:
    """
    Hash the content of a file and return the result as hexadecimal string.

    The file is read in binary mode chunk by chunk so that it never has
    to be held in memory as a whole.

    Parameters
    ----------
    file_path:
        Path to the file.
    algorithm:
        Name of the :mod:`hashlib` algorithm; one of 'sha1', 'sha256'
        (default) and 'sha512'.
    chunk_size:
        Number of bytes read from the file per iteration (default is 8192).
    multihash_encode:
        Wrap the digest with :func:`multiformats.multihash.wrap` according
        to the
        `multihash specification <https://github.com/multiformats/multihash>`_
        (default is True)? If False, the plain `hashlib` hex digest is returned.

    Returns
    -------
    the hexadecimal (multi)hash string of the file.

    Raises
    ------
    ValueError
        if `algorithm` is not one of the supported names

    See Also
    --------
    :mod:`hashlib`
    :mod:`multiformats.multihash`
    """
    # hashlib name -> multihash name; to be extended if necessary
    algo_lookup = {
        'sha1': 'sha1',
        'sha256': 'sha2-256',
        'sha512': 'sha2-512',
    }
    if algorithm not in algo_lookup:
        raise ValueError(f'Hash algorithm must be one of {algo_lookup.keys()}')
    hasher = getattr(hashlib, algorithm)()
    with open(file_path, 'rb') as stream:
        # read until an empty bytes object signals the end of the file
        for block in iter(lambda: stream.read(chunk_size), b''):
            hasher.update(block)
    if not multihash_encode:
        return hasher.hexdigest()
    wrapped = multihash.wrap(hasher.digest(), algo_lookup[algorithm])
    return wrapped.hex()
[docs]
def datamask(
        measurement: str,
        dm_ras: str,
        dm_vec: str
) -> str | None:
    """
    Create data masks for a given image file.

    The created raster data mask does not contain a simple mask of nodata values.
    Rather, a boundary vector geometry containing all valid pixels is created and
    then rasterized. This boundary geometry is saved as `dm_vec`.
    In this case `dm_vec` is returned.
    If the input image only contains nodata values, no raster data mask is created,
    and an empty dummy vector mask (a zero-size file) is created. In this case the
    function will return `None`.

    Existing files are reused: if `dm_vec` already exists, nothing is computed and
    the return value only depends on whether the file is empty. If only `dm_ras`
    exists, the vector mask is derived from it instead of from `measurement`.

    Parameters
    ----------
    measurement:
        the binary image file
    dm_ras:
        the name of the raster data mask
    dm_vec:
        the name of the vector data mask

    Returns
    -------
    `dm_vec` if the vector data mask contains a geometry or None otherwise
    """
    
    def mask_from_array(arr, dm_vec, dm_ras, ref):
        """
        Create the vector (and, if missing, the raster) data mask from a mask array.

        Parameters
        ----------
        arr: np.ndarray
            the mask array; elements equal to 1 mark valid pixels
        dm_vec: str
            the name of the vector data mask to write
        dm_ras: str
            the name of the raster data mask; only written if it does not exist
        ref: spatialist.raster.Raster
            the reference raster passed to `vectorize` and `rasterize`

        Returns
        -------
        str or None
            `dm_vec` if a geometry was written, None if only an empty dummy file was created
        """
        # create a dummy vector mask if the mask only contains 0 values
        if len(arr[arr == 1]) == 0:
            # exist_ok=False: fail instead of silently reusing an existing file
            Path(dm_vec).touch(exist_ok=False)
            return None
        # vectorize the raster data mask
        with vectorize(target=arr, reference=ref) as vec:
            # compute a valid data boundary geometry (vector data mask)
            with boundary(vec, expression="value=1") as bounds:
                # rasterize the vector data mask
                if not os.path.isfile(dm_ras):
                    rasterize(vectorobject=bounds, reference=ref,
                              outname=dm_ras)
                # write the vector data mask
                bounds.write(outfile=dm_vec)
        return dm_vec
    
    # fast path without locking: both masks exist already;
    # a zero-size vector file is the dummy written for nodata-only images
    if os.path.isfile(dm_vec) and os.path.isfile(dm_ras):
        return None if os.path.getsize(dm_vec) == 0 else dm_vec
    # NOTE(review): presumably guards both files against concurrent
    # creation by other processes - confirm against pyroSAR's LockCollection
    with LockCollection([dm_vec, dm_ras]):
        # the existence checks are repeated inside the lock
        if not os.path.isfile(dm_vec):
            if not os.path.isfile(dm_ras):
                with Raster(measurement) as ras:
                    arr = ras.array()
                    # create a nodata mask
                    mask = ~np.isnan(arr)
                    # the image array is no longer needed once the mask exists
                    del arr
                    out = mask_from_array(arr=mask, dm_vec=dm_vec,
                                          dm_ras=dm_ras, ref=ras)
            else:
                # read the raster data mask
                with Raster(dm_ras) as ras:
                    mask = ras.array()
                    out = mask_from_array(arr=mask, dm_vec=dm_vec,
                                          dm_ras=dm_ras, ref=ras)
            del mask
        else:
            # the vector mask exists; an empty file means "no valid data"
            if os.path.getsize(dm_vec) == 0:
                out = None
            else:
                out = dm_vec
    return out
[docs]
def date_to_utc(
date: str | datetime | None,
as_datetime: bool = False,
str_format: str = '%Y%m%dT%H%M%S'
) -> str | datetime | None:
"""
convert a date object to a UTC date string or datetime object.
Parameters
----------
date:
the date object to convert; timezone-unaware dates are interpreted as UTC.
as_datetime:
return a datetime object instead of a string?
str_format:
the output string format (ignored if `as_datetime` is True)
Returns
-------
the date string or datetime object in UTC time zone
"""
if date is None:
return date
elif isinstance(date, str):
out = dateutil.parser.parse(date)
elif isinstance(date, datetime):
out = date
else:
raise TypeError('date must be a string, datetime object or None')
if out.tzinfo is None:
out = out.replace(tzinfo=timezone.utc)
else:
out = out.astimezone(timezone.utc)
if not as_datetime:
out = out.strftime(str_format)
return out
[docs]
def defaultdict_to_dict(
        d: defaultdict,
) -> dict:
    """
    Recursively turn a (nested) defaultdict into plain dictionaries.

    Every dictionary value is converted as well; objects that are not
    dictionaries (including lists containing dictionaries) are returned
    unchanged.

    Parameters
    ----------
    d:
        the defaultdict to convert

    Returns
    -------
    the converted dictionary
    """
    # defaultdict is a dict subclass, so one check covers both
    if not isinstance(d, dict):
        return d
    plain = {}
    for key, value in d.items():
        plain[key] = defaultdict_to_dict(value)
    return plain
[docs]
def generate_unique_id(
        encoded_str: bytes,
        length: int = 4
) -> str:
    """
    Derive a short hexadecimal product identifier from a byte string.

    The identifier is the CRC-16 checksum (CRC-CCITT with start value
    0xFFFF) of the input, reduced to its `length` lowest hexadecimal
    digits and zero-padded to that width.

    Parameters
    ----------
    encoded_str:
        A string that should be used to generate a unique id from.
        The string needs to be encoded; e.g.: `'abc'.encode()`.
    length:
        The desired length of the output string in hexadecimal
        characters. The value is clamped to the range 1 to 4 since
        CRC-16 only produces 16 bits.

    Returns
    -------
    The unique product identifier (upper-case hexadecimal string).
    """
    # a 16-bit CRC has at most 4 hex digits
    n_chars = max(1, min(length, 4))
    # each hex digit covers 4 bits; keep only the lowest n_chars digits
    low_bits = binascii.crc_hqx(encoded_str, 0xffff) & ((1 << (4 * n_chars)) - 1)
    return f'{low_bits:0{n_chars}X}'
[docs]
def get_kml() -> str:
    """
    Download the Sentinel-2 MGRS grid KML file. The target folder is ~/.cesard.

    The file is only downloaded if it does not yet exist locally.
    The download is guarded by a lock on the target file and the HTTP
    response is checked before unpacking so that a failed request is
    reported as such.

    Returns
    -------
    the path to the KML file

    Raises
    ------
    requests.HTTPError
        if the server responds with an error status
    """
    remote = ('https://sentiwiki.copernicus.eu/__attachments/1692737/'
              'S2A_OPER_GIP_TILPAR_MPC__20151209T095117_V20150622T000000_21000101T000000_B00.zip')
    local_path = os.path.join(os.path.expanduser('~'), '.cesard')
    os.makedirs(local_path, exist_ok=True)
    # the KML is expected to carry the name of the archive
    local = os.path.join(local_path, os.path.basename(remote).replace('.zip', '.kml'))
    if os.path.isfile(local):
        # NOTE(review): presumably the soft lock waits for a download that
        # is still in progress in another process - confirm against pyroSAR
        with Lock(local, soft=True):
            return local
    with Lock(local):
        # another process may have completed the download while this one
        # was waiting for the lock
        if not os.path.isfile(local):
            log.info(f'downloading MGRS grid KML file to {local_path}')
            r = requests.get(remote)
            # fail on HTTP errors instead of trying to unzip an error page
            r.raise_for_status()
            with zipfile.ZipFile(io.BytesIO(r.content)) as zf:
                zf.extractall(local_path)
    return local
[docs]
def get_max_ext(
        geometries: list[Vector],
        buffer: float | None = None,
        crs: str | int | None = None
) -> dict[str, float]:
    """
    Compute the common bounding extent of several geometries.

    Parameters
    ----------
    geometries:
        List of :class:`~spatialist.vector.Vector` geometries.
        All geometries must share the same CRS.
    buffer:
        The buffer in units of the geometries' CRS to add to the extent
        on all four sides.
    crs:
        The target CRS of the extent. If None (default) the extent is
        expressed in the CRS of the input geometries.

    Returns
    -------
    The maximum extent of the selected :class:`~spatialist.vector.Vector`
    geometries including the chosen buffer.

    Raises
    ------
    RuntimeError
        if the geometries are in different CRSs
    """
    max_ext = {}
    crs_list = []
    for geo in geometries:
        crs_list.append(f"EPSG:{geo.getProjection('epsg')}")
        current = geo.extent
        if not max_ext:
            # the first extent initializes the result
            max_ext = current
            continue
        # grow the result wherever the current extent reaches further
        for key in ('xmin', 'ymin'):
            max_ext[key] = min(max_ext[key], current[key])
        for key in ('xmax', 'ymax'):
            max_ext[key] = max(max_ext[key], current[key])
    
    crs_list = list(set(crs_list))
    if len(crs_list) > 1:
        raise RuntimeError(f'The input geometries are in different CRSs: {crs_list}')
    
    max_ext = dict(max_ext)
    if buffer is not None:
        for key in ('xmin', 'ymin'):
            max_ext[key] -= buffer
        for key in ('xmax', 'ymax'):
            max_ext[key] += buffer
    if crs is None:
        return max_ext
    # reproject the extent via a temporary bounding box geometry
    with bbox(coordinates=max_ext, crs=crs_list[0]) as box:
        box.reproject(projection=crs)
        max_ext = box.extent
    return max_ext
[docs]
def get_tmp_name(suffix: str) -> str:
    """
    Get the name of a temporary file with defined suffix.

    Files are placed in a subdirectory 'cesard' of the regular
    temporary directory so the latter is not flooded with too
    many files in case they are not properly deleted.
    The directory is created if it does not exist.

    Only a name is generated: the file that is briefly created to obtain
    it is closed and deleted again before the function returns, so the
    returned path does not exist and is not reserved.

    Parameters
    ----------
    suffix: str
        the file suffix/extension, e.g. '.tif'

    Returns
    -------
    the temporary file name
    """
    tmpdir = os.path.join(tempfile.gettempdir(), 'cesard')
    os.makedirs(tmpdir, exist_ok=True)
    # close (and thereby delete) the file deterministically instead of
    # relying on garbage collection of the unreferenced file object
    with tempfile.NamedTemporaryFile(suffix=suffix, dir=tmpdir) as tmp:
        return tmp.name
[docs]
def group_by_attr(
        items: List[T],
        key_fn: Callable[[T], K]
) -> List[List[T]]:
    """
    Partition a list into groups of items sharing the same key.

    The groups are ordered by the first occurrence of their key and the
    items within a group keep their input order.

    Parameters
    ----------
    items:
        The list of arbitrary items to group.
    key_fn:
        A function that extracts a (hashable) key from each item.

    Returns
    -------
    A list of groups, where each group is a list of items with the same key.

    Example
    -------
    >>> list_in = ['abc', 'axy', 'brt', 'btk']
    >>> print(group_by_attr(list_in, lambda x: x[0]))
    [['abc', 'axy'], ['brt', 'btk']]
    >>> list_in = [{'a': 1}, {'a': 2}, {'a': 1}, {'a': 2}]
    >>> print(group_by_attr(list_in, lambda x: x['a']))
    [[{'a': 1}, {'a': 1}], [{'a': 2}, {'a': 2}]]
    """
    buckets = {}
    for element in items:
        buckets.setdefault(key_fn(element), []).append(element)
    return list(buckets.values())
[docs]
def group_by_time(
        scenes: list[pyroSAR.drivers.ID | str],
        time: int | float = 3
) -> list[list[pyroSAR.drivers.ID]]:
    """
    Split a scene selection into groups of consecutive acquisitions.

    The scenes are identified and sorted by start time. A scene joins the
    group of its predecessor if the absolute difference between the
    predecessor's stop time and its own start time is at most `time`
    seconds; otherwise it opens a new group.

    Parameters
    ----------
    scenes:
        a list of image names
    time:
        a time difference in seconds by which to group the scenes.
        The default of 3 seconds incorporates the overlap between SLCs.

    Returns
    -------
    a list of sub-lists containing the identified scene objects of each group
    """
    time_format = '%Y%m%dT%H%M%S'
    # sort images by time stamp
    ordered = identify_many(scenes, sortkey='start')
    if len(ordered) < 2:
        return [ordered]
    groups = [[ordered[0]]]
    # walk over pairs of neighboring scenes
    for previous, current in zip(ordered, ordered[1:]):
        start = datetime.strptime(current.start, time_format)
        stop_pred = datetime.strptime(previous.stop, time_format)
        gap = abs((stop_pred - start).total_seconds())
        if gap <= time:
            groups[-1].append(current)
        else:
            groups.append([current])
    return groups
[docs]
def pixel_size_degrees(
        lon: float, lat: float,
        xres: float, yres: float
) -> tuple[float, float]:
    """
    Express a pixel size given in meters as a size in degrees.

    Starting at the given location, the WGS84 ellipsoid is followed for
    `yres` meters at azimuth 0 and for `xres` meters at azimuth 90.
    The resulting change in latitude and longitude, respectively, is
    the pixel size in degrees.

    Parameters
    ----------
    lon:
        longitude in degrees
    lat:
        latitude in degrees
    xres:
        x resolution in meters
    yres:
        y resolution in meters

    Returns
    -------
    the x and y resolution in degrees

    See Also
    --------
    pyproj.Geod.fwd
    """
    wgs84 = Geod(ellps="WGS84")
    # azimuth 0: only the resulting latitude is of interest
    _, lat_shifted, _ = wgs84.fwd(lon, lat, az=0, dist=yres)
    # azimuth 90: only the resulting longitude is of interest
    lon_shifted, _, _ = wgs84.fwd(lon, lat, az=90, dist=xres)
    return lon_shifted - lon, lat_shifted - lat
[docs]
def vrt_add_overviews(
        vrt: str,
        overviews: list[int],
        resampling: str = 'AVERAGE'
) -> None:
    """
    Write an overview definition into an existing VRT file.

    The `OverviewList` element directly below the XML root is created if
    it is missing; an existing one is overwritten. The file is modified
    in place.

    Parameters
    ----------
    vrt:
        the VRT file
    overviews:
        the overview levels
    resampling:
        the overview resampling method; written in lower case
    """
    tree = etree.parse(vrt)
    root = tree.getroot()
    node = root.find('OverviewList')
    if node is None:
        node = etree.SubElement(root, 'OverviewList')
    # the levels are stored as a space-separated list
    node.text = ' '.join(map(str, overviews))
    node.set('resampling', resampling.lower())
    etree.indent(root)
    tree.write(vrt, pretty_print=True, xml_declaration=False, encoding='utf-8')