Commit 422cf329 authored by PIOLLE

revised internal encoding

parent 71b9f482
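In short: instead of stripping most of a variable's encoding when a file is opened, each variable's source encoding (dtype, `_FillValue`, `scale_factor`, `add_offset`) is now stashed under a `cerbere_src_encoding` key and can be re-applied when saving. A minimal sketch of the intended round trip (the `NCDataset` class and the file names are illustrative, not part of this diff):

```python
from cerbere.dataset.ncdataset import NCDataset  # hypothetical concrete subclass

ds = NCDataset('sst_in.nc')      # opening now calls _save_encoding(), stashing
                                 # each variable's dtype/_FillValue/scaling
ds.save(dest='sst_out.nc',       # written through the new _to_netcdf() helper
        keep_src_encoding=True)  # re-apply the stashed source encoding
```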
@@ -34,18 +34,18 @@ def default_fill_value(obj):
         raise TypeError("Unexpected object type: ", type(obj), obj)
     if dtype.name == 'int16':
-        return -32768
+        return numpy.int16(-32768)
     elif dtype.name == 'uint16':
-        return 65535
+        return numpy.uint16(65535)
     elif dtype.name == 'int8':
-        return -128
+        return numpy.int8(-128)
     elif dtype.name == 'uint8':
-        return 255
+        return numpy.uint8(255)
     else:
         return numpy.ma.default_fill_value(dtype)


-def get_masked_values(data, fill_value, silent=False):
+def get_masked_values(fieldname, data, fill_value, silent=False):
     """fix masked values. Required as xarray data can't store masked values
     or nan for non-float types"""
     if fill_value is None and not isinstance(data, numpy.ma.core.MaskedArray):
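These fill values now come back as typed numpy scalars instead of plain Python ints, so the dtype survives into the encoding logic used at save time. A quick check (import path assumed; the `dtype.name` tests above suggest a numpy dtype is an accepted argument):

```python
import numpy
from cerbere.cfconvention import default_fill_value  # import path assumed

fv = default_fill_value(numpy.dtype('int16'))
assert fv == -32768
assert fv.dtype == numpy.dtype('int16')  # typed numpy scalar, not a plain int
```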
@@ -65,8 +65,8 @@ def get_masked_values(data, fill_value, silent=False):
         data = numpy.ma.masked_equal(data, fill_value, copy=False)
     elif not silent:
         logging.warning(
-            'values equal to {} are marked as missing values'
-            .format(fill_value))
+            'values equal to {} are marked as missing values in {}'
+            .format(fill_value, fieldname))
     return data
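The new leading `fieldname` argument exists so this warning can name the offending field. A usage sketch (import path assumed, values illustrative):

```python
import numpy
from cerbere.cfconvention import get_masked_values  # import path assumed

raw = numpy.array([12, -32768, 15], dtype=numpy.int16)
sst = get_masked_values('sea_surface_temperature', raw, numpy.int16(-32768))
# -32768 is masked in the result; a warning, when emitted, now reads:
# "values equal to -32768 are marked as missing values in sea_surface_temperature"
```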
...
@@ -26,6 +26,7 @@ from scipy.ndimage.morphology import binary_dilation
 import shapely.geometry
 import xarray as xr

+import cerbere.cfconvention
 from ..cfconvention import (
     default_profile, default_fill_value, CF_AUTHORITY, DEFAULT_TIME_UNITS,
     get_masked_values
@@ -43,6 +44,26 @@ class OpenMode(Enum):
     READ_WRITE: str = 'r+'


+# dict name for saving the source file's encoding
+S_ENCODING = 'cerbere_src_encoding'
+
+
+class Encoding(Enum):
+    """attributes for saving the encoding of a source file"""
+    # attribute for marking variables with no fill value (such as masks)
+    UNMASKED: str = 'no_fillvalue'
+    # source file's saved encoding attributes
+    FILLVALUE = '_FillValue'
+    SCALE = 'scale_factor'
+    OFFSET = 'add_offset'
+    DTYPE = 'dtype'
+
+
+# attribute for overriding the dtype
+C_DTYPE = 'cerbere_dtype'
+
+
 # standard geolocation coordinates
 GEOCOORDINATES = [u'time', u'lat', u'lon', u'z', u'depth', u'height']
 REQUIRED_GEOCOORDINATES = [u'time', u'lat', u'lon']
@@ -602,7 +623,7 @@ class Dataset(ABC):
         return dimname

     @property
-    def url(self) -> str:
+    def url(self) -> Path:
         """Return the url of the file storing the dataset"""
         if isinstance(self.dataset, Dataset):
             return self.dataset.url
@@ -662,6 +683,22 @@ class Dataset(ABC):
         """The date the dataset file was generated"""
         return datetime.datetime.fromtimestamp(os.path.getctime(self.url))

+    def _save_encoding(self):
+        # save original encoding
+        for v in self._std_dataset.variables.values():
+            encoding = {}
+            # mark variables with no fill value
+            if '_FillValue' not in v.encoding:
+                encoding[Encoding.UNMASKED] = True
+            encoding[Encoding.FILLVALUE] = self._xr_fillvalue(v)
+            encoding[Encoding.DTYPE] = v.encoding.get('dtype', None)
+            encoding[Encoding.OFFSET] = v.encoding.get('add_offset', None)
+            encoding[Encoding.SCALE] = v.encoding.get('scale_factor', None)
+            v.encoding[S_ENCODING] = encoding
+
     def _open_dataset(self, **kwargs) -> 'xr.Dataset':
         """
         Open a file (netCDF, ZArr,...) and returns its content as a xarray_
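For reference, this is roughly what `_save_encoding()` leaves on each variable, sketched for a packed int16 variable (values are made up; a variable without a `_FillValue` would additionally carry `Encoding.UNMASKED: True`):

```python
import numpy

# illustrative content of v.encoding['cerbere_src_encoding']:
src_encoding = {
    Encoding.FILLVALUE: numpy.int16(-32768),
    Encoding.DTYPE: 'int16',
    Encoding.OFFSET: 273.15,
    Encoding.SCALE: 0.01,
}
```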
@@ -715,15 +752,8 @@ class Dataset(ABC):
                 **self._native2std_dim}
             )

-        # remove original encoding
-        for v in self.dataset.variables.values():
-            attrs = []
-            for attr in v.encoding:
-                if attr in ['units', '_FillValue']:
-                    continue
-                attrs.append(attr)
-            for attr in attrs:
-                v.encoding.pop(attr)
+        # save original encoding
+        self._save_encoding()

     def _transform(self):
         """apply some transformation to original dataset to make it more
@@ -888,8 +918,7 @@ class Dataset(ABC):
     @property
     def attrs(self) -> MutableMapping[Hashable, Any]:
-        """Mapping from global attribute names to value.
-        """
+        """Mapping from global attribute names to value."""
         return self._std_dataset.attrs

     @attrs.setter
@@ -1116,8 +1145,8 @@ class Dataset(ABC):
         return self.get_field(fieldname).fill_value

     def get_values(self,
-                   fieldname: Hashable,
-                   index: Mapping[Hashable, slice] = None,
+                   fieldname: str,
+                   index: Mapping[str, slice] = None,
                    as_masked_array: bool = True,
                    expand: bool = False,
                    expand_dims: List[str] = None,
@@ -1200,14 +1229,13 @@ class Dataset(ABC):
             values = values.transpose(
                 *(list(rearranged_dims)), transpose_coords=True)

-        if not as_masked_array:
-            return get_masked_values(
-                values,
-                self.get_field_fillvalue(fieldname))
-        else:
-            return get_masked_values(
-                values.to_masked_array(copy=False),
-                self.get_field_fillvalue(fieldname))
+        if as_masked_array:
+            values = values.to_masked_array(copy=False)
+
+        return get_masked_values(
+            fieldname,
+            values,
+            self.get_field_fillvalue(fieldname))

     def set_values(
             self,
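Both return paths now funnel through `get_masked_values`, so fill values are masked whether a numpy masked array or an xarray DataArray is requested. Illustrative calls (`ds` and the field name are made up):

```python
arr = ds.get_values('sst')                         # numpy masked array
da = ds.get_values('sst', as_masked_array=False)   # xarray DataArray, same masking
```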
@@ -1672,7 +1700,7 @@ class Dataset(ABC):
         return loc, geoloc,

-    def get_closest_spatial_location(self, lon, lat):
+    def closest_spatial_location(self, lon, lat):
         """Get closest dataset lat/lon location to given coordinates.

         Use pythagorian differences on lat/lon values so take the result with
@@ -1797,6 +1825,7 @@ class Dataset(ABC):
         if times.count() == 0:
             logging.warning('No valid time in dataset.')
             return
+
         self.attrs['time_coverage_start'] = pd.Timestamp(
             times[~np.isnat(times)].min()
         ).to_pydatetime()
@@ -1920,12 +1949,138 @@ class Dataset(ABC):
         """Close file"""
         self._std_dataset.close()

+    @staticmethod
+    def _xr_fillvalue(xvar, default=None):
+        """Return the fill value of an xarray
+
+        It can be defined in either the attrs or the encoding attributes.
+        """
+        return xvar.encoding.get(
+            '_FillValue', xvar.attrs.get('_FillValue', default))
+
+    @staticmethod
+    def _xr_units(xvar, default=None):
+        """Return the units of an xarray
+
+        They can be defined in either the attrs or the encoding attributes.
+        """
+        return xvar.encoding.get('units', xvar.attrs.get('units', default))
+
+    def _to_netcdf(self, saved_dataset, keep_src_encoding=False):
+        # ensure proper type in output attributes for the considered format
+        self._format_nc_attrs(saved_dataset)
+
+        def has_scaling(attrs):
+            return any([_ in attrs for _ in ['add_offset', 'scale_factor']])
+
+        def has_fillvalue(encoding, dtype) -> bool:
+            # xarray enforces _FillValue for floats
+            if dtype in [np.dtype(np.float32), np.dtype(np.float64)]:
+                return True
+            if S_ENCODING not in encoding:
+                return '_FillValue' in encoding
+            return not encoding[S_ENCODING].get(
+                Encoding.UNMASKED, False)
+
+        # ensure original or overriding encoding
+        for v in saved_dataset.variables:
+            svar = saved_dataset[v]
+            encoding = svar.encoding
+
+            if 'zlib' not in encoding:
+                encoding['zlib'] = True
+            if 'complevel' not in encoding:
+                encoding['complevel'] = 4
+
+            if np.issubdtype(svar.dtype, np.datetime64) \
+                    and self._xr_units(svar) is None:
+                svar.encoding['units'] = cerbere.cfconvention.DEFAULT_TIME_UNITS
+
+            # save in the original data type if not overridden by the output
+            # format profile
+            if keep_src_encoding:
+                if S_ENCODING not in encoding:
+                    raise ValueError(
+                        'No original encoding for {}. Were these data '
+                        'read from a file?'.format(v))
+                for att in Encoding:
+                    if encoding[S_ENCODING].get(att) is not None:
+                        encoding[att] = encoding[S_ENCODING][att]

+            if has_scaling(encoding):
+                dtype = svar.dtype
+            else:
+                # profile > source file (S_ENCODING) > data array dtype
+                if S_ENCODING in encoding:
+                    dtype = np.dtype(encoding.get(
+                        C_DTYPE,
+                        encoding[S_ENCODING].get(Encoding.DTYPE, svar.dtype)))
+                else:
+                    dtype = np.dtype(encoding.get(C_DTYPE, svar.dtype))
+            if np.issubdtype(dtype, np.datetime64):
+                dtype = np.dtype(np.float64)
+            encoding['dtype'] = dtype
+
+            # save a _FillValue matching the encoding data type
+            if dtype != np.object and has_fillvalue(encoding, dtype):
+                # profile > source file > data array fill value
+                default_fv = default_fill_value(dtype)
+                if S_ENCODING in encoding:
+                    default_fv = encoding[S_ENCODING].get(
+                        Encoding.FILLVALUE, default_fv)
+                fillv = self._xr_fillvalue(svar, default_fv)
+
+                if np.issubdtype(type(fillv), np.datetime64):
+                    # xarray might automatically change the _FillValue
+                    # encoding to NaT, overwriting the previously encoded
+                    # value
+                    fillv = default_fv
+                elif np.dtype(type(fillv)) != dtype:
+                    logging.debug(
+                        '_FillValue changed from {}({}) to {}({}) when '
+                        'saving {}'
+                        .format(fillv, np.dtype(type(fillv)),
+                                default_fill_value(dtype), dtype,
+                                svar.name))
+                    fillv = default_fill_value(dtype)
+                encoding['_FillValue'] = fillv
+            elif '_FillValue' in encoding:
+                # no fill value permitted by xarray for object type
+                encoding.pop('_FillValue')
+
+            # adjust missing value attribute types if packing is applied
+            for matt in ['valid_min', 'valid_max', 'valid_range']:
+                if matt in svar.attrs:
+                    svar.attrs[matt] = np.array(
+                        svar.attrs[matt],
+                        dtype=dtype)
+
+            # remove cerbere specifics
+            if S_ENCODING in encoding:
+                encoding.pop(S_ENCODING)
+
+            saved_dataset[v].encoding = encoding
+
+        saved_dataset.to_netcdf(
+            path=self._url,
+            mode={
+                OpenMode.READ_WRITE: 'a',
+                OpenMode.WRITE_NEW: 'w'
+            }[self._mode],
+            format=self._format,
+            engine='netcdf4'
+        )
+
     def save(self,
              dest: Union[str, 'Dataset', None] = None,
              format: str = 'NETCDF4',
              profile: str = 'default_saving_profile.yaml',
-             force_profile: bool = False
-             ):
+             force_profile: bool = False,
+             keep_src_encoding: bool = False):
         """
         Args:
             dest (str, optional): save to a new file, whose path is provided in
@@ -1934,6 +2089,9 @@ class Dataset(ABC):
                 apply before saving (or default formatting profile is used).
             force_profile (bool, optional): force profile attribute values to
                 supersede existing ones in dataset attributes.
+            keep_src_encoding (bool): keep the original dtype, _FillValue
+                and scaling (`add_offset` or `scale_factor` attributes),
+                if any, as in the source data.
         """
         if isinstance(self.dataset, Dataset):
             return self.dataset.save(
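A short usage sketch of the extended `save` (`ds` and the paths are illustrative): with `keep_src_encoding=False`, `_to_netcdf()` resolves each output dtype with the precedence profile (`cerbere_dtype`) > saved source encoding > array dtype.

```python
# restore the source packing (dtype, _FillValue, scale/offset) on write
ds.save(dest='repacked.nc', keep_src_encoding=True)

# default behaviour: the saving profile, when it defines a dtype, wins
ds.save(dest='as_profile.nc', profile='default_saving_profile.yaml')
```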
@@ -1988,55 +2146,7 @@ class Dataset(ABC):
         # save to chosen format
         if 'NETCDF' in self._format:
-
-            # ensure proper type in output attributes for the considered format
-            self._format_nc_attrs(saved_dataset)
-
-            for v in saved_dataset.variables:
-                encoding = saved_dataset[v].encoding
-                if 'zlib' not in encoding:
-                    encoding['zlib'] = True
-                if 'complevel' not in encoding:
-                    encoding['complevel'] = 4
-
-                if saved_dataset[v].dtype != np.object:
-                    if '_FillValue' in saved_dataset[v].attrs:
-                        fillv = saved_dataset[v].attrs.pop('_FillValue')
-                    elif '_FillValue' in encoding:
-                        fillv = encoding.pop('_FillValue')
-                    else:
-                        fillv = default_fill_value(saved_dataset[v].dtype)
-                    if (('_no_missing_value' not in encoding)
-                            or not encoding['_no_missing_value']):
-                        encoding['_FillValue'] = fillv
-                    saved_dataset[v].encoding.update(encoding)
-                else:
-                    # no fill value permitted by xarray for object type
-                    saved_dataset[v].encoding.pop('_FillValue')
-
-                # adjust missing value attribute types if packing is applied
-                for matt in ['valid_min', 'valid_max', 'valid_range']:
-                    if 'dtype' not in saved_dataset[v].encoding:
-                        continue
-                    if matt in saved_dataset[v].attrs:
-                        saved_dataset[v].attrs[matt] = np.array(
-                            saved_dataset[v].attrs[matt],
-                            dtype=saved_dataset[v].encoding['dtype'])
-
-            saved_dataset.to_netcdf(
-                path=self._url,
-                mode={
-                    OpenMode.READ_WRITE: 'a',
-                    OpenMode.WRITE_NEW: 'w'
-                }[self._mode],
-                format=self._format,
-                engine='netcdf4'
-            )
+            self._to_netcdf(saved_dataset, keep_src_encoding)
         else:
             logging.error('Unknown output format : {}'.format(self._format))
@@ -2112,6 +2222,8 @@ class Dataset(ABC):
                 for att in attrs[v]:
                     if att not in dataset.variables[v].encoding or force_profile:
                         value = attrs[v][att]
+                        if att == 'dtype':
+                            att = C_DTYPE
                         if value is None:
                             continue
                         dataset.variables[v].encoding[att] = value
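These two lines defer profile-specified dtypes: a `dtype` entry from the saving profile is stored under `C_DTYPE` ('cerbere_dtype') so that xarray does not repack the data before write time, when `_to_netcdf()` resolves it. Sketch, assuming the profile yields a per-variable attribute dict:

```python
# hypothetical saving-profile entry for variable 'sst'
attrs = {'sst': {'dtype': 'int16', 'scale_factor': 0.01, 'add_offset': 273.15}}
# after the loop above:
#   dataset['sst'].encoding['cerbere_dtype'] == 'int16'  # applied in _to_netcdf
#   dataset['sst'].encoding['scale_factor'] == 0.01
#   dataset['sst'].encoding['add_offset'] == 273.15
```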
@@ -2179,7 +2291,8 @@ class Dataset(ABC):
             return attrval.wkt
         return attrval

-    def get_collection_id(self) -> str:
+    @property
+    def collection_id(self) -> str:
         """return the identifier of the product collection"""
         raise NotImplementedError
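`get_collection_id()` becomes the read-only `collection_id` property, so callers switch from a method call to attribute access:

```python
# before: cid = ds.get_collection_id()
cid = ds.collection_id  # still raises NotImplementedError on the base class
```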
...
@@ -249,7 +249,7 @@ class Field(object):
         """Return the field values a xarray DataArray"""
         if self.dataset is None:
             return cf.get_masked_values(
-                self._array, self.fill_value, silent=silent)
+                self._array.name, self._array, self.fill_value, silent=silent)
         else:
             return self.dataset.get_values(
                 self._array.name,
@@ -356,7 +356,7 @@ class Field(object):
     @property
     def dimnames(self) -> Tuple[str]:
         """Tuple of the field's dimension names"""
-        return tuple(self.dims.keys())
+        return self.dims

     def get_dimsize(self, dimname) -> int:
         """Return the size of a field dimension"""
@@ -874,6 +874,8 @@ class Field(object):
         # detach from any dataset
         new_field._array.encoding['_attached_dataset'] = None

+        #new_field._array.encoding['cerbere_src_encoding'] =
+        #    self._array.encoding['cerbere_src_encoding']
+
         if prefix is not None:
             new_field.set_name(prefix + new_field.name)
...