Commit c702cd72 authored by PIOLLE's avatar PIOLLE
Browse files

better support for fill values in non float dataarrays

parent c5fa3062
...@@ -24,7 +24,59 @@ CF_AUTHORITY = 'CF-1.7' ...@@ -24,7 +24,59 @@ CF_AUTHORITY = 'CF-1.7'
def default_fill_value(obj):
    """Returns the default fill value for a specific type.

    Args:
        obj: what to derive the data type from. Accepted inputs are a
            ``numpy.dtype``, a dtype name (``str``), a Python or numpy type,
            or any object exposing a numpy ``dtype`` attribute (array,
            masked array, numpy scalar, ...).

    Returns:
        The fill value for the resolved dtype: the dedicated value listed
        below for 8 and 16 bit integers, or whatever
        ``numpy.ma.default_fill_value`` returns for any other dtype.

    Raises:
        TypeError: if no numpy dtype can be derived from ``obj``.
    """
    if isinstance(obj, numpy.dtype):
        dtype = obj
    elif isinstance(obj, (str, type)):
        # checked before the generic ``dtype`` attribute lookup: numpy scalar
        # *classes* expose a ``dtype`` descriptor that is not a numpy.dtype
        dtype = numpy.dtype(obj)
    elif isinstance(getattr(obj, 'dtype', None), numpy.dtype):
        # arrays, masked arrays, numpy scalars and array-like wrappers
        dtype = obj.dtype
    else:
        raise TypeError(
            "Unexpected object type: {} ({!r})".format(type(obj), obj))

    # small integer types get a dedicated fill value instead of numpy's
    # generic integer default
    small_int_fill_values = {
        'int16': -32768,
        'uint16': 65535,
        'int8': -128,
        'uint8': 255,
    }
    if dtype.name in small_int_fill_values:
        return small_int_fill_values[dtype.name]
    return numpy.ma.default_fill_value(dtype)
def get_masked_values(data, fill_value, silent=False):
    """Mask the values of ``data`` that are equal to ``fill_value``.

    Required as xarray data can't store masked values or nan for non-float
    types: missing values are stored there as plain fill values and have to
    be turned back into masked values when read.

    Args:
        data: the values to fix. Only numpy arrays (masked or not) are
            converted; ``None`` and any other object are returned unchanged.
        fill_value: the value marking missing data. Nothing is done if None.
        silent: if True, no warning is logged when ``data`` is not a numpy
            array and therefore can not be masked.

    Returns:
        A masked array where the values equal to ``fill_value`` are masked
        if ``data`` is a numpy array of a type other than float, complex or
        datetime64; ``data`` itself otherwise.
    """
    # nothing to mask, or nothing to mask with
    if data is None or fill_value is None:
        return data

    # these types are left untouched: only the other (e.g. integer) types
    # need their fill values converted back to masked values
    untouched_types = (
        'float16', 'float32', 'float64', 'complex64', 'complex128')
    if (data.dtype.name in untouched_types
            or numpy.issubdtype(data.dtype, numpy.datetime64)):
        return data

    # numpy.ma.MaskedArray is a subclass of numpy.ndarray
    if isinstance(data, numpy.ndarray):
        # mask fill values for int types
        return numpy.ma.masked_equal(data, fill_value, copy=False)

    if not silent:
        logging.warning(
            'values equal to {} are marked as missing values'
            .format(fill_value))
    return data
def set_masked_values(data, fill_value):
    """Replace the masked values of ``data`` with ``fill_value``.

    Required as xarray data can't store masked values or nan for non-float
    types: masked values of such arrays are stored as plain fill values.

    Args:
        data: the values to fix; must expose a ``dtype`` attribute. Only
            numpy arrays (masked or not) are converted.
        fill_value: the value to store in place of masked values. Nothing is
            done if None.

    Returns:
        A plain numpy array where masked values are replaced with
        ``fill_value`` if ``data`` is a numpy array of a type other than
        float, complex or datetime64; ``data`` itself otherwise (including
        plain, non-masked numpy arrays, which have nothing to replace).
    """
    if fill_value is None:
        return data

    # these types are stored as they are
    untouched_types = (
        'float16', 'float32', 'float64', 'complex64', 'complex128')
    if (data.dtype.name in untouched_types
            or numpy.issubdtype(data.dtype, numpy.datetime64)):
        return data

    # numpy.ma.MaskedArray is a subclass of numpy.ndarray
    if isinstance(data, numpy.ndarray):
        # numpy.ma.filled (unlike the ``filled`` method, which only exists
        # on masked arrays) also accepts plain arrays and returns them as is
        return numpy.ma.filled(data, fill_value)
    return data
def default_profile( def default_profile(
......
...@@ -27,7 +27,8 @@ import shapely.geometry ...@@ -27,7 +27,8 @@ import shapely.geometry
import xarray as xr import xarray as xr
from ..cfconvention import ( from ..cfconvention import (
default_profile, default_fill_value, CF_AUTHORITY, DEFAULT_TIME_UNITS default_profile, default_fill_value, CF_AUTHORITY, DEFAULT_TIME_UNITS,
get_masked_values
) )
from .field import Field from .field import Field
...@@ -997,12 +998,16 @@ class Dataset(ABC): ...@@ -997,12 +998,16 @@ class Dataset(ABC):
return Field(self._std_dataset[fieldname], fieldname, dataset=self) return Field(self._std_dataset[fieldname], fieldname, dataset=self)
def add_field(self, field: 'Field') -> None: def add_field(self, field: 'Field', force_index: bool = True) -> None:
"""Add a field to the feature. """Add a field to the feature.
Args: Args:
field: the field is provided as a field: the field is provided as a
:class:`~cerbere.dataset.field.Field` object :class:`~cerbere.dataset.field.Field` object
force_index: if the added field contains an index coordinate with
the same name as the dataset, replace the values with those of
the dataset (otherwise only the field values for which the index
values of the field and the dataset are the same will be added).
""" """
if field.name in self.fieldnames: if field.name in self.fieldnames:
raise Exception( raise Exception(
...@@ -1011,8 +1016,15 @@ class Dataset(ABC): ...@@ -1011,8 +1016,15 @@ class Dataset(ABC):
) )
try: try:
dataarr = field.to_dataarray(silent=True)
# if some indexes are existing in the dataset, ensure the values are
# the same
for idx in dataarr.indexes:
if idx in self._std_dataset.indexes and force_index:
dataarr = dataarr.reset_index(idx, drop=True)
self._std_dataset = self._std_dataset.assign( self._std_dataset = self._std_dataset.assign(
{field.name: field.to_dataarray()} {field.name: dataarr}
) )
except ValueError: except ValueError:
# an error cas when for instance an index (like time) has masked # an error cas when for instance an index (like time) has masked
...@@ -1095,7 +1107,10 @@ class Dataset(ABC): ...@@ -1095,7 +1107,10 @@ class Dataset(ABC):
def get_field_fillvalue(self, fieldname: str) -> Any:
    """Returns the missing value of a field or coordinate.

    Args:
        fieldname: name of the field or coordinate. Names listed in
            ``coordnames`` are looked up with ``get_coord``, any other name
            with ``get_field``.

    Returns:
        The ``fill_value`` attribute of the matching object.
    """
    is_coord = fieldname in self.coordnames
    source = (
        self.get_coord(fieldname) if is_coord
        else self.get_field(fieldname))
    return source.fill_value
def get_values(self, def get_values(self,
fieldname: Hashable, fieldname: Hashable,
...@@ -1183,9 +1198,13 @@ class Dataset(ABC): ...@@ -1183,9 +1198,13 @@ class Dataset(ABC):
*(list(rearranged_dims)), transpose_coords=True) *(list(rearranged_dims)), transpose_coords=True)
if not as_masked_array: if not as_masked_array:
return values return get_masked_values(
values,
self.get_field_fillvalue(fieldname))
else: else:
return values.to_masked_array(copy=False) return get_masked_values(
values.to_masked_array(copy=False),
self.get_field_fillvalue(fieldname))
def set_values( def set_values(
self, self,
...@@ -1484,13 +1503,17 @@ class Dataset(ABC): ...@@ -1484,13 +1503,17 @@ class Dataset(ABC):
subset. subset.
""" """
if isinstance(self.dataset, Dataset): if isinstance(self.dataset, Dataset):
return self.dataset.extract( return self.dataset.extract(
index=index, fields=fields, padding=padding, index=index, fields=fields, padding=padding,
prefix=prefix, deep=deep, **kwargs prefix=prefix, deep=deep, **kwargs)
)
if fields is None: if fields is None:
fields = self._varnames fields = self._varnames
else:
fields.extend(self.geocoordnames)
# remove possible duplicates
fields = list(set(fields))
if index is None: if index is None:
subset = self.dataset[fields] subset = self.dataset[fields]
......
...@@ -120,8 +120,22 @@ class Field(object): ...@@ -120,8 +120,22 @@ class Field(object):
if name is not None and not isinstance(name, str): if name is not None and not isinstance(name, str):
raise TypeError('name must be a string') raise TypeError('name must be a string')
# dtype
if data is None:
if dtype is None:
raise ValueError(
"If you don't provide any data, you must at least "
"provide a datatype"
)
if dtype is None:
dtype = data.dtype
# fill value
if fillvalue is None:
fillvalue = cf.default_fill_value(dtype)
if isinstance(data, xr.DataArray): if isinstance(data, xr.DataArray):
self._array = data self._array = cf.set_masked_values(data, fillvalue)
else: else:
# create the DataArray from the provided information # create the DataArray from the provided information
...@@ -133,25 +147,24 @@ class Field(object): ...@@ -133,25 +147,24 @@ class Field(object):
if data is None: if data is None:
# create default array # create default array
if dtype is None:
raise ValueError(
"If you don't provide any data, you must at least "
"provide a datatype"
)
if not isinstance(dims, OrderedDict): if not isinstance(dims, OrderedDict):
raise TypeError( raise TypeError(
"dimensions should be provided with their size in a " "dimensions should be provided with their size in a "
"OrderedDict" "OrderedDict"
) )
data = numpy.ma.masked_all(
tuple(dims.values()), dtype) data = numpy.ma.masked_all(tuple(dims.values()), dtype)
else: data.set_fill_value(fillvalue)
data = data
# instantiate the xarray representation # instantiate the xarray representation
kwargs['dims'] = list(dims) kwargs['dims'] = list(dims)
kwargs['attrs'] = attrs kwargs['attrs'] = attrs
self._array = xr.DataArray(data, name=name, **kwargs)
# fix for xarray to keep the data type : replace masked values
# with fill values
self._array = xr.DataArray(
cf.set_masked_values(data, fillvalue),
name=name, **kwargs)
# Overrides DataArray object when conflicts with the superceding # Overrides DataArray object when conflicts with the superceding
# arguments # arguments
...@@ -201,10 +214,11 @@ class Field(object): ...@@ -201,10 +214,11 @@ class Field(object):
""" """
return Field(data=data) return Field(data=data)
def to_dataarray(self) -> 'xr.DataArray': def to_dataarray(self, silent=False) -> 'xr.DataArray':
"""Return the field values a xarray DataArray""" """Return the field values a xarray DataArray"""
if self.dataset is None: if self.dataset is None:
return self._array return cf.get_masked_values(
self._array, self.fill_value, silent=silent)
else: else:
return self.dataset.get_values( return self.dataset.get_values(
self._array.name, self._array.name,
...@@ -543,11 +557,12 @@ class Field(object): ...@@ -543,11 +557,12 @@ class Field(object):
**kwargs **kwargs
} }
if self.dataset is None: if self.dataset is None:
return numpy.ma.array( data = numpy.ma.array(
self._read_dataarray(self._array, **allkwargs) self._read_dataarray(self._array, **allkwargs))
)
else: else:
return self.dataset.get_values(self.name, **allkwargs) data = self.dataset.get_values(self.name, **allkwargs)
return cf.get_masked_values(data, self.fill_value)
@classmethod @classmethod
def _read_dataarray( def _read_dataarray(
......
...@@ -31,7 +31,7 @@ class NCDataset(Dataset): ...@@ -31,7 +31,7 @@ class NCDataset(Dataset):
*args, *args,
format=NETCDF4, format=NETCDF4,
**kwargs): **kwargs):
return super().__init__( super().__init__(
*args, *args,
format=format, format=format,
**kwargs **kwargs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment