Commit d67a5ba0 authored by PIOLLE's avatar PIOLLE
Browse files

improved internal encoding + more unit tests

parent a2cea61a
...@@ -22,77 +22,6 @@ DEFAULT_TIME_UNITS = 'seconds since 1981-01-01 00:00:00' ...@@ -22,77 +22,6 @@ DEFAULT_TIME_UNITS = 'seconds since 1981-01-01 00:00:00'
CF_AUTHORITY = 'CF-1.7' CF_AUTHORITY = 'CF-1.7'
def default_fill_value(obj):
    """Return the default fill value matching the type of ``obj``.

    Args:
        obj: a numpy dtype, a dtype name or type object, or a numpy array.

    Returns:
        the conventional fill value for that dtype: the extreme bound of
        the range for small integer types, numpy's default otherwise.

    Raises:
        TypeError: if ``obj`` is none of the accepted kinds.
    """
    if isinstance(obj, numpy.dtype):
        dtype = obj
    elif isinstance(obj, (str, type)):
        dtype = numpy.dtype(obj)
    elif isinstance(obj, numpy.ndarray):
        dtype = obj.dtype
    else:
        raise TypeError("Unexpected object type: ", type(obj), obj)

    # conventional fill values for small integer types
    special = {
        'int16': numpy.int16(-32768),
        'uint16': numpy.uint16(65535),
        'int8': numpy.int8(-128),
        'uint8': numpy.uint8(255),
    }
    try:
        return special[dtype.name]
    except KeyError:
        return numpy.ma.default_fill_value(dtype)
def get_masked_values(fieldname, data, fill_value, silent=False):
    """fix masked values. Required as xarray data can't store masked values
    or nan for non-float types.

    Args:
        fieldname: name of the field (only used in the warning message)
        data: array of values to fix
        fill_value: value marking missing data in ``data``
        silent: if True, no warning is logged for unhandled data types

    Returns:
        for integer types, a masked array where values equal to
        ``fill_value`` are masked; for float/complex masked arrays, the
        same array with ``fill_value`` attached; other inputs unchanged.
    """
    if fill_value is None and not isinstance(data, numpy.ma.core.MaskedArray):
        # no masked data
        return data
    if isinstance(data, (numpy.ma.core.MaskedArray, numpy.ndarray)):
        if data.dtype.name in [
                'float16', 'float32', 'float64', 'complex64', 'complex128']:
            # only masked arrays carry a fill value: a plain ndarray has
            # no set_fill_value method (previously raised AttributeError)
            if fill_value is not None and isinstance(
                    data, numpy.ma.core.MaskedArray):
                data.set_fill_value(fill_value)
            return data
        elif numpy.issubdtype(data.dtype, numpy.datetime64):
            # NaT already marks missing datetimes
            return data
        else:
            # mask fill values for int types
            data = numpy.ma.masked_equal(data, fill_value, copy=False)
    elif not silent:
        logging.warning(
            'values equal to {} are marked as missing values in {}'
            .format(fill_value, fieldname))
    return data
def set_masked_values(data, fill_value):
    """replace masked values with fill value. Required as xarray data can't
    store masked values or nan for non-float types.

    Floats get NaN, datetimes get NaT; for any other dtype the provided
    ``fill_value`` (or the array's intrinsic one) is used. Non-masked
    inputs are returned unchanged.
    """
    if fill_value is None and not isinstance(data, numpy.ma.core.MaskedArray):
        # nothing to replace
        return data
    if not isinstance(data, numpy.ma.core.MaskedArray):
        return data
    if data.dtype.name in (
            'float16', 'float32', 'float64', 'complex64', 'complex128'):
        # floats: NaN is a valid missing value marker for xarray
        filler = numpy.nan
    elif numpy.issubdtype(data.dtype, numpy.datetime64):
        # datetimes: NaT plays the same role
        filler = numpy.datetime64('NaT')
    else:
        # int and other types: use provided or intrinsic fill value
        filler = fill_value if fill_value is not None else data.fill_value
    return data.filled(filler)
def default_profile( def default_profile(
profile: str='default_saving_profile.yaml') -> Mapping[str, Any]: profile: str='default_saving_profile.yaml') -> Mapping[str, Any]:
"""Returns a list of default settings for storing data and metadata """Returns a list of default settings for storing data and metadata
......
...@@ -10,3 +10,249 @@ file formats and conventions. ...@@ -10,3 +10,249 @@ file formats and conventions.
ncdataset ncdataset
ghrsstncdataset ghrsstncdataset
""" """
from enum import Enum
import logging
from typing import Union
import numpy as np
import xarray as xr
# cerbere internals for encoding/decoding data to file or memory
ENCODING = 'cerbere'
class Encoding(Enum):
    """attributes for saving the encoding of a source file.

    The enum members are used as keys in the cerbere-specific dict stored
    in a DataArray ``encoding`` property (under the ``ENCODING`` key).
    """
    # attribute for marking variables with no fill value (like masks)
    UNMASKED: str = 'no_fillvalue'
    # parent dataset the field belongs to
    DATASET: str = '_attached_dataset'
    # source file's saved encoding attributes (_FillValue, scale_factor,
    # add_offset and dtype as read from the file)
    IO_FILLVALUE: str = 'io__FillValue'
    IO_SCALE: str = 'io_scale_factor'
    IO_OFFSET: str = 'io_add_offset'
    IO_DTYPE: str = 'io_dtype'
    # in memory numpy encoding
    # NOTE(review): value 'm_dype' looks like a typo for 'm_dtype' — confirm
    # before changing, as members are mostly compared by identity and the
    # string may already be persisted somewhere
    M_DTYPE: str = 'm_dype'
    # mask of missing values (like a numpy MaskedArray mask); stored value
    # may be a boolean array, a scalar bool, or None for unmasked variables
    M_MASK: Union[np.ndarray, bool] = 'mask'
    # individual components for vectorial fields
    COMPONENTS: str = 'components'
    # field status wrt to its copy on file
    STATUS: str = 'status'
def default_fill_value(obj):
    """Returns the default fill value for a specific type.

    ``obj`` may be a numpy dtype, a dtype name or type object, or a
    numpy array; raises TypeError for anything else.
    """
    if isinstance(obj, np.dtype):
        dtype = obj
    elif isinstance(obj, (str, type)):
        dtype = np.dtype(obj)
    elif isinstance(obj, np.ndarray):
        dtype = obj.dtype
    else:
        raise TypeError("Unexpected object type: ", type(obj), obj)

    # small integer types have conventional full-range fill values
    small_ints = {
        'int16': np.int16(-32768),
        'uint16': np.uint16(65535),
        'int8': np.int8(-128),
        'uint8': np.uint8(255),
    }
    if dtype.name in small_ints:
        return small_ints[dtype.name]
    return np.ma.default_fill_value(dtype)
def infer_fillvalue(fill_value, data, dtype):
    """Return the fill value to associate with ``data``.

    If ``fill_value`` is provided it is returned unchanged. Otherwise a
    value xarray can natively store is preferred (NaN for float/complex
    data, NaT for datetime64); for other types the intrinsic fill value
    of the array is used when available (masked arrays), falling back to
    the default fill value for ``dtype``.
    """
    if fill_value is not None:
        return fill_value
    # use valid xarray missing values if possible (float or datetime64)
    if np.issubdtype(data.dtype, np.floating) or \
            np.issubdtype(data.dtype, np.complexfloating):
        return np.nan
    if np.issubdtype(dtype, np.datetime64):
        return np.datetime64('NaT')
    try:
        # numpy masked arrays carry their own fill value
        return data.fill_value
    except AttributeError:
        # plain arrays do not: derive a default from the scientific dtype
        return default_fill_value(dtype)
def to_cerbere_dataarray(
        data: Union[xr.DataArray, np.ndarray, np.ma.MaskedArray],
        fill_value: float = None,
        dtype: np.dtype = None,
        no_missing_value: bool = False,
        **kwargs) -> xr.DataArray:
    """Converts to a cerbere DataArray.

    A cerbere DataArray adds specific internal attributes into the
    `encoding` property of a DataArray, such as:

    * the mask of missing values (like for a numpy MaskedArray),
      which is not supported by DataArray (which replaces masked values
      with NaN)
    * the scientific dtype of the data (which may be different for the in
      memory encoding since xarray will transform the dtype of MaskedArray
      into floats.

    Used for internal data encoding as xarray data can't store masked
    values or nan for non-float types and transforms masked arrays to float
    arrays, which leads to a changed dtype of an array. When using cerbere
    `get_values` method, the data returned to the users will be in the
    internal DataArray dtype (with NaN for fill values when any), unless
    `decoding` keyword is set to True (NaN are then replaced be the fill_value
    attribute and the dtype of the returned values is the scientific one). The
    mask can be retrieved with the `mask` property of a Field object.

    Args:
        data: Array to be converted to masked DataArray
        fill_value: Fill value. Must be of the same type as dtype
        dtype: array's scientific data type. It can be different
            from `dtype` of `data` array. Data returned with `get_values`
            will be forced to this type.
        no_missing_value: if True, the variable is declared as having no
            missing values at all (e.g. a mask variable); incompatible
            with providing a ``fill_value``.
        **kwargs: other DataArray creation arguments

    Raises:
        ValueError: if ``no_missing_value`` is True while a ``fill_value``
            is provided, or while ``data`` contains masked values.
    """
    if dtype is None:
        dtype = data.dtype

    if dtype != data.dtype:
        logging.debug('Data internally stored as {} but tht scientific dtype '
                      'is {}'.format(data.dtype, dtype))

    if no_missing_value and fill_value is not None:
        raise ValueError('no_missing_value is set to True but a fill_value was '
                         'also provided')

    if fill_value is None and data.dtype != dtype:
        # There is no reason here for xarray to have a different data type
        # than the scientific one.
        logging.warning('internal data type {} should be the same than '
                        'scientific data type {}'.format(data.dtype, dtype))

    if isinstance(data, np.ma.core.MaskedArray):
        if no_missing_value:
            if data.count() != data.size:
                raise ValueError('no_missing_value set to True but data '
                                 'contains masked values')
            logging.warning("data was provided as a masked array with "
                            "no_missing_value set to True. data will be "
                            "forced to non masked array.")
            mask = None
        else:
            # keep the original mask as-is (may be an array or False)
            mask = data.mask
    else:
        # mask - case ndarray, Dask array DataArray
        try:
            # numpy array type
            mask = np.isnan(data) | (data == fill_value)
        # NOTE(review): bare except — presumably meant to catch TypeError
        # from np.isnan on non-numeric/xarray data; confirm and narrow
        except:
            # xarray array type
            mask = data.isnull() | (data == fill_value)
        # collapse all-True / all-False masks to a scalar bool
        if np.all(mask):
            mask = True
        elif np.all(~mask):
            mask = False
        if fill_value is None:
            # no explicit fill values, no nan in data : this is a unmasked
            # variable
            # NOTE(review): this overrides any NaN-derived mask computed
            # above — confirm NaN values should be ignored when no
            # fill_value was provided
            mask = None

    # wrap into xarray DataArray
    data = xr.DataArray(data=data, **kwargs)
    if ENCODING not in data.encoding:
        data.encoding[ENCODING] = {}
    # record the scientific dtype under the cerbere encoding dict
    data.encoding[ENCODING][Encoding.M_DTYPE] = dtype

    # no mask in the variable
    data.encoding[ENCODING][Encoding.M_MASK] = mask

    # ensure there is a defined fill value
    if mask is not None:
        data.encoding['_FillValue'] = infer_fillvalue(
            fill_value, data, dtype)

    return data
def io_encoding_dtype(data: xr.DataArray) -> np.dtype:
    """guess the scientific dtype from data read on file.

    xarray may change the in-memory dtype of a variable (for instance
    promoting to float when fill values are replaced with NaN). The
    intended scientific dtype is then inferred from the dtype of the
    scaling attributes (``scale_factor``/``add_offset``) when present,
    or from the dtype of the ``_FillValue`` attribute.
    """
    # scaling attributes, when present, are scalar values whose dtype is
    # the scientific dtype (not a dtype themselves)
    scaling = data.encoding.get(
        'scale_factor', data.encoding.get('add_offset', None))
    fillv = data.encoding.get(
        '_FillValue', data.attrs.get('_FillValue', None))
    if scaling is None:
        # no scaling of data -> dtype should be unchanged
        if fillv is None:
            return data.dtype
        # np.dtype(type(...)) maps a scalar value to its dtype; same
        # pattern as used elsewhere in this package for fill values
        if np.dtype(type(fillv)) != data.dtype:
            # no scaling was applied but xarray may have changed the dtype for
            # instance if fill values were replaced with NaN. Returns the
            # intended scientific dtype indicated by the _FillValue attr.
            return np.dtype(type(fillv))
        return data.dtype
    sc_dtype = np.dtype(type(scaling))
    if sc_dtype != data.dtype:
        # xarray may have changed the dtype for instance if fill values
        # were replaced with NaN. Returns the intended scientific dtype
        # indicated by the scaling attributes
        return sc_dtype
    return data.dtype
def from_cerbere_dataarray(
        data,
        as_masked_array: bool = True,
        decoding: bool = False) -> Union[xr.DataArray, np.ma.MaskedArray]:
    """Return the values of a cerbere DataArray in the requested form.

    Args:
        data: the DataArray to convert; must be a xarray DataArray,
            either cerbere-encoded (see ``to_cerbere_dataarray``) or read
            directly from a file / external source.
        as_masked_array: if True, return a numpy MaskedArray in the
            scientific dtype; takes precedence over ``decoding``.
        decoding: if True (and ``as_masked_array`` is False), return a
            DataArray in the scientific dtype with NaN replaced by the
            fill value.

    Raises:
        TypeError: if ``data`` is not a xarray DataArray.
    """
    if not isinstance(data, xr.DataArray):
        raise TypeError('unexpected data type: {}'.format(type(data)))

    if ENCODING not in data.encoding or \
            Encoding.M_DTYPE not in data.encoding[ENCODING]:
        # the data were not yet internally encoded or were read from an
        # externally defined DataArray or from a file => use the native
        # encoding (_FillValue attribute and inferred scientific dtype)
        fillv = data.encoding.get('_FillValue', data.attrs.get(
            '_FillValue', None))
        if as_masked_array:
            dtype = io_encoding_dtype(data)
            data = data.astype(dtype, copy=False).to_masked_array(copy=False)
            data.set_fill_value(fillv)
        elif decoding:
            dtype = io_encoding_dtype(data)
            data = data.astype(dtype, copy=False)
        return data

    if decoding and not as_masked_array:
        # return the data in their scientific dtype, replacing NaN with fill
        # values
        # NOTE(review): assumes '_FillValue' is set in encoding — confirm
        # unmasked variables (mask is None) cannot reach this branch
        return data.fillna(data.encoding['_FillValue']).astype(
            data.encoding[ENCODING][Encoding.M_DTYPE], copy=False)

    if as_masked_array:
        # rebuild a MaskedArray from the cerbere-stored dtype and mask
        mdata = data.astype(
            data.encoding[ENCODING][Encoding.M_DTYPE],
            copy=False).to_masked_array(copy=False)
        mdata.set_fill_value(data.encoding['_FillValue'])
        mdata.mask = data.encoding[ENCODING][Encoding.M_MASK]
        return mdata

    # return the data in their internal dtype
    return data
...@@ -27,10 +27,8 @@ import shapely.geometry ...@@ -27,10 +27,8 @@ import shapely.geometry
import xarray as xr import xarray as xr
import cerbere.cfconvention import cerbere.cfconvention
from ..cfconvention import ( from ..cfconvention import default_profile, CF_AUTHORITY, DEFAULT_TIME_UNITS
default_profile, default_fill_value, CF_AUTHORITY, DEFAULT_TIME_UNITS, import cerbere.dataset as internals
get_masked_values
)
from .field import Field from .field import Field
...@@ -43,24 +41,6 @@ class OpenMode(Enum): ...@@ -43,24 +41,6 @@ class OpenMode(Enum):
WRITE_NEW: str = 'w' WRITE_NEW: str = 'w'
READ_WRITE: str = 'r+' READ_WRITE: str = 'r+'
# dict name for saving source file's encoding
S_ENCODING = 'cerbere_src_encoding'
class Encoding(Enum):
"""attributes for saving the encoding of a source file"""
# attribute for marking variables with no fill value (like masks)
UNMASKED: str = 'no_fillvalue'
# source file's saved encoding attributes
FILLVALUE = '_FillValue'
SCALE = 'scale_factor'
OFFSET = 'add_offset'
DTYPE = 'dtype'
# attribute for override dtype # attribute for override dtype
C_DTYPE = 'cerbere_dtype' C_DTYPE = 'cerbere_dtype'
...@@ -484,16 +464,16 @@ class Dataset(ABC): ...@@ -484,16 +464,16 @@ class Dataset(ABC):
for var in data.keys(): for var in data.keys():
if isinstance(data[var], Field): if isinstance(data[var], Field):
data[var] = to_dict(data[var].to_dataarray()) data[var] = to_dict(data[var].to_xarray())
if 'coords' in data.keys(): if 'coords' in data.keys():
for var, value in data['coords'].items(): for var, value in data['coords'].items():
if isinstance(value, Field): if isinstance(value, Field):
data[var] = to_dict(value.to_dataarray()) data[var] = to_dict(value.to_xarray())
if 'data_vars' in data.keys(): if 'data_vars' in data.keys():
for var, value in data['data_vars'].items(): for var, value in data['data_vars'].items():
if isinstance(value, Field): if isinstance(value, Field):
data[var] = to_dict(value.to_dataarray()) data[var] = to_dict(value.to_xarray())
# create a dataset # create a dataset
self.dataset = xr.Dataset.from_dict(data) self.dataset = xr.Dataset.from_dict(data)
...@@ -684,20 +664,27 @@ class Dataset(ABC): ...@@ -684,20 +664,27 @@ class Dataset(ABC):
return datetime.datetime.fromtimestamp(os.path.getctime(self.url)) return datetime.datetime.fromtimestamp(os.path.getctime(self.url))
def _save_encoding(self): def _save_encoding(self):
"""store in cerbere encoding the I/O encoding of a variable"""
# save original encoding # save original encoding
for v in self._std_dataset.variables.values(): for v in self._std_dataset.variables.values():
encoding = {} encoding = {}
# mark variables with no fill value # mark variables with no fill value
if '_FillValue' not in v.encoding: if '_FillValue' not in v.encoding:
encoding[Encoding.UNMASKED] = True encoding[internals.Encoding.UNMASKED] = True
encoding[Encoding.FILLVALUE] = self._xr_fillvalue(v) encoding[internals.Encoding.IO_FILLVALUE] = self._xr_fillvalue(v)
encoding[Encoding.DTYPE] = v.encoding.get('dtype', None) encoding[internals.Encoding.IO_DTYPE] = v.encoding.get(
encoding[Encoding.OFFSET] = v.encoding.get('add_offset', None) 'dtype', None)
encoding[Encoding.SCALE] = v.encoding.get('scale_factor', None) encoding[internals.Encoding.IO_OFFSET] = v.encoding.get(
'add_offset', None)
encoding[internals.Encoding.IO_SCALE] = v.encoding.get(
'scale_factor', None)
if internals.ENCODING not in v.encoding:
logging.warning("Cerbere internal attribute should be there")
v.encoding[internals.ENCODING] = {}
v.encoding[S_ENCODING] = encoding v.encoding[internals.ENCODING].update(encoding)
def _open_dataset(self, **kwargs) -> 'xr.Dataset': def _open_dataset(self, **kwargs) -> 'xr.Dataset':
""" """
...@@ -835,7 +822,7 @@ class Dataset(ABC): ...@@ -835,7 +822,7 @@ class Dataset(ABC):
* a :class:`xarray.DataArray` * a :class:`xarray.DataArray`
""" """
if isinstance(values, Field): if isinstance(values, Field):
xrdata = values.to_dataarray() xrdata = values.to_xarray()
elif isinstance(values, xr.DataArray): elif isinstance(values, xr.DataArray):
xrdata = values xrdata = values
...@@ -1047,7 +1034,8 @@ class Dataset(ABC): ...@@ -1047,7 +1034,8 @@ class Dataset(ABC):
) )
try: try:
dataarr = field.to_dataarray(silent=True) dataarr = field.to_xarray(silent=True, decoding=False)
# if some indexes are existing in the dataset, ensure the values are # if some indexes are existing in the dataset, ensure the values are
# the same # the same
for idx in dataarr.indexes: for idx in dataarr.indexes:
...@@ -1229,13 +1217,10 @@ class Dataset(ABC): ...@@ -1229,13 +1217,10 @@ class Dataset(ABC):
values = values.transpose( values = values.transpose(
*(list(rearranged_dims)), transpose_coords=True) *(list(rearranged_dims)), transpose_coords=True)
if as_masked_array: values = internals.from_cerbere_dataarray(
values = values.to_masked_array(copy=False) values, as_masked_array=as_masked_array)
return get_masked_values( return values
fieldname,
values,
self.get_field_fillvalue(fieldname))
def set_values( def set_values(
self, self,
...@@ -1977,10 +1962,10 @@ class Dataset(ABC): ...@@ -1977,10 +1962,10 @@ class Dataset(ABC):
# xarray enforces _FillValue for floats # xarray enforces _FillValue for floats
if dtype in [np.dtype(np.float32), np.dtype(np.float64)]: if dtype in [np.dtype(np.float32), np.dtype(np.float64)]:
return True return True
if S_ENCODING not in encoding: if internals.ENCODING not in encoding:
return '_FillValue' in encoding return '_FillValue' in encoding
return not encoding[S_ENCODING].get( return not encoding[internals.ENCODING].get(
Encoding.UNMASKED, False) internals.Encoding.UNMASKED, False)
# ensure original or overriding encoding # ensure original or overriding encoding
for v in saved_dataset.variables: for v in saved_dataset.variables:
...@@ -1998,22 +1983,23 @@ class Dataset(ABC): ...@@ -1998,22 +1983,23 @@ class Dataset(ABC):
# save in original data type if not overriden by output # save in original data type if not overriden by output
# format profile # format profile
if keep_src_encoding: if keep_src_encoding:
if S_ENCODING not in encoding: if internals.ENCODING not in encoding:
raise ValueError( raise ValueError(
'No original encoding for {}. Where these data ' 'No original encoding for {}. Where these data '
'read from a file?'.format(v)) 'read from a file?'.format(v))
for att in Encoding: for att in internals.Encoding:
if encoding[S_ENCODING][att] is not None: if encoding[internals.ENCODING][att] is not None:
encoding[att] = encoding[S_ENCODING][att] encoding[att] = encoding[internals.ENCODING][att]
if has_scaling(encoding): if has_scaling(encoding):
dtype = svar.dtype dtype = svar.dtype
else: else:
# profile > source file (S_ENCODING) > data array dtype # profile > source file (S_ENCODING) > data array dtype
if S_ENCODING in encoding: if internals.ENCODING in encoding:
dtype = np.dtype(encoding.get( dtype = np.dtype(encoding.get(
C_DTYPE, C_DTYPE,
encoding[S_ENCODING].get(Encoding.DTYPE, svar.dtype))) encoding[internals.ENCODING].get(
internals.Encoding.M_DTYPE, svar.dtype)))
else: else:
dtype = np.dtype(encoding.get(C_DTYPE, svar.dtype)) dtype = np.dtype(encoding.get(C_DTYPE, svar.dtype))
...@@ -2026,10 +2012,10 @@ class Dataset(ABC): ...@@ -2026,10 +2012,10 @@ class Dataset(ABC):
if dtype != np.object and has_fillvalue(encoding, dtype): if dtype != np.object and has_fillvalue(encoding, dtype):
# profile > source file > data array fill value # profile > source file > data array fill value
default_fv = default_fill_value(dtype) default_fv = internals.default_fill_value(dtype)
if S_ENCODING in encoding: if internals.ENCODING in encoding:
default_fv = encoding[S_ENCODING].get( default_fv = encoding[internals.ENCODING].get(
Encoding.FILLVALUE, default_fv) internals.Encoding.IO_FILLVALUE, default_fv)
fillv = self._xr_fillvalue(svar, default_fv) fillv = self._xr_fillvalue(svar, default_fv)
if np.issubdtype(type(fillv), np.datetime64): if np.issubdtype(type(fillv), np.datetime64):
...@@ -2040,11 +2026,11 @@ class Dataset(ABC): ...@@ -2040,11 +2026,11 @@ class Dataset(ABC):
elif np.dtype(type(fillv)) != dtype: elif np.dtype(type(fillv)) != dtype:
logging.debug( logging.debug(
'_FillValue changed from {}({}) to {}({}) when ' '_FillValue changed from {}({}) to {}({}) when '
'saving {}' 'saving {}'.format(
.format(fillv, np.dtype(type(fillv)), fillv, np.dtype(type(fillv)),
default_fill_value(dtype), dtype, internals.default_fill_value(dtype), dtype,
svar.name)) svar.name))
fillv = default_fill_value(dtype) fillv = internals.default_fill_value(dtype)
encoding['_FillValue'] = fillv encoding['_FillValue'] = fillv
...@@ -2060,8 +2046,8 @@ class Dataset(ABC): ...@@ -2060,8 +2046,8 @@ class Dataset(ABC):
dtype=dtype) dtype=dtype)
# remove cerbere specifics # remove cerbere specifics
if S_ENCODING in encoding: if internals.ENCODING in encoding:
encoding.pop(S_ENCODING) encoding.pop(internals.