Commit 9e9f3639 authored by PIOLLE's avatar PIOLLE
Browse files

improved fill values in float fields + unit test

parent 8d2d06bb
......@@ -48,19 +48,25 @@ def default_fill_value(obj):
def get_masked_values(data, fill_value, silent=False):
"""fix masked values. Required as xarray data can't store masked values
or nan for non-float types"""
if (fill_value is not None
and data.dtype.name not in [
'float16', 'float32', 'float64', 'complex64', 'complex128',
' datetime64']
and not numpy.issubdtype(data.dtype, numpy.datetime64)):
if (isinstance(data, (numpy.ma.core.MaskedArray, numpy.ndarray))
or data is None):
if fill_value is None and not isinstance(data, numpy.ma.core.MaskedArray):
# no masked data
return data
if isinstance(data, (numpy.ma.core.MaskedArray, numpy.ndarray)):
if data.dtype.name in [
'float16', 'float32', 'float64', 'complex64', 'complex128']:
if fill_value is not None:
data.set_fill_value(fill_value)
return data
elif numpy.issubdtype(data.dtype, numpy.datetime64):
return data
else:
# mask fill values for int types
data = numpy.ma.masked_equal(data, fill_value, copy=False)
elif not silent:
logging.warning(
'values equal to {} are marked as missing values'
.format(fill_value))
elif not silent:
logging.warning(
'values equal to {} are marked as missing values'
.format(fill_value))
return data
......@@ -68,14 +74,22 @@ def get_masked_values(data, fill_value, silent=False):
def set_masked_values(data, fill_value):
"""replace masked values with fill value. Required as xarray data can't
store masked values or nan for non-float types"""
if (fill_value is not None
and data.dtype.name not in [
'float16', 'float32', 'float64', 'complex64', 'complex128',
' datetime64']
and not numpy.issubdtype(data.dtype, numpy.datetime64)):
if isinstance(data, numpy.ma.core.MaskedArray):
# mask fill values for int types
return data.filled(fill_value)
if fill_value is None and not isinstance(data, numpy.ma.core.MaskedArray):
# no masked data
return data
if isinstance(data, numpy.ma.core.MaskedArray):
if data.dtype.name in [
'float16', 'float32', 'float64', 'complex64', 'complex128']:
fill_value = numpy.nan
elif numpy.issubdtype(data.dtype, numpy.datetime64):
fill_value = numpy.datetime64('NaT')
else:
if fill_value is None:
fill_value = data.fill_value
return data.filled(fill_value)
return data
......
......@@ -698,13 +698,14 @@ class Field(object):
else:
self.dataset.set_values(self.name, values, index=index)
@classmethod
def _set_xrvalues(
cls,
self,
xrdata,
values,
index=None
):
values = cf.set_masked_values(values, self.fill_value)
if index is None:
xrdata.values[:] = values
......
from datetime import datetime
import os
import unittest
import numpy as np
import xarray as xr
from cerbere.dataset.ncdataset import NCDataset
from cerbere.dataset.field import Field
N_VALUES = 10
class TestMaskedArray(unittest.TestCase):
"""Test class for checking fill values"""
def test_field_ma_float32_auto_fillvalue(self):
data = np.ma.ones((N_VALUES,), dtype=np.float32)
field = Field(
data=data,
dims=['time'],
name='myvar')
self.assertEqual(field.get_values().count(), N_VALUES)
data = np.ma.masked_where(
(np.arange(N_VALUES) >= N_VALUES/2), data, copy=True)
field.set_values(data)
self.assertEqual(field.get_values().count(), N_VALUES/2)
self.assertEqual(field.get_values().dtype.name, 'float32')
def test_field_ma_float32(self):
data = np.ma.ones((N_VALUES,), dtype=np.float32)
field = Field(
data=data,
dims=['time'],
fillvalue=1e5,
name='myvar')
self.assertEqual(field.get_values().count(), N_VALUES)
data = np.ma.masked_where(
(np.arange(N_VALUES) >= N_VALUES/2), data, copy=True)
field.set_values(data)
self.assertEqual(field.get_values().count(), N_VALUES/2)
self.assertEqual(field.get_values().fill_value, 1e5)
self.assertEqual(field.fill_value, 1e5)
self.assertEqual(field.get_values().dtype.name, 'float32')
def test_field_ma_int32_auto_fillvalue(self):
data = np.ma.ones((N_VALUES,), dtype=np.int32)
field = Field(
data=data,
dims=['time'],
name='myvar')
self.assertEqual(field.get_values().count(), N_VALUES)
data = np.ma.masked_where(
(np.arange(N_VALUES) >= N_VALUES / 2), data, copy=True)
field.set_values(data)
self.assertEqual(field.get_values().count(), N_VALUES / 2)
self.assertEqual(field.get_values().dtype.name, 'int32')
def test_field_ma_int32(self):
data = np.ma.ones((N_VALUES,), dtype=np.int32)
field = Field(
data=data,
fillvalue=-1,
dims=['time'],
name='myvar')
self.assertEqual(field.get_values().count(), N_VALUES)
data = np.ma.masked_where(
(np.arange(N_VALUES) >= N_VALUES / 2), data, copy=True)
field.set_values(data)
self.assertEqual(field.get_values().count(), N_VALUES / 2)
self.assertEqual(field.get_values().fill_value, -1)
self.assertEqual(field.fill_value, -1)
self.assertEqual(field.get_values().dtype.name, 'int32')
def test_field_ma_datetime64_auto_fillvalue(self):
data = np.ma.array(
[np.datetime64("2010-01-01")] * N_VALUES, dtype=np.datetime64)
field = Field(
data=data,
dims=['time'],
name='myvar')
self.assertEqual(field.get_values().count(), N_VALUES)
data = np.ma.masked_where(
(np.arange(N_VALUES) >= N_VALUES / 2), data, copy=True)
field.set_values(data)
self.assertEqual(field.get_values().count(), N_VALUES / 2)
self.assertEqual(field.get_values().dtype.name, 'datetime64[ns]')
def test_field_da_float32_auto_fillvalue(self):
data = np.ma.ones((N_VALUES,), dtype=np.float32)
da = xr.DataArray(
data=data,
dims=['time'],
name='myvar'
)
field = Field(
data=da,
dims=['time'],
name='myvar')
self.assertEqual(field.get_values().count(), N_VALUES)
data = np.ma.masked_where(
(np.arange(N_VALUES) >= N_VALUES/2), data, copy=True)
field.set_values(data)
print(field.get_values())
self.assertEqual(field.get_values().count(), N_VALUES/2)
self.assertEqual(field.get_values().dtype.name, 'float32')
self.assertEqual(field.get_values().fill_value, np.float32(1e20))
def test_field_da_float32(self):
data = np.ma.ones((N_VALUES,), dtype=np.float32)
da = xr.DataArray(
data=data,
dims=['time'],
name='myvar'
)
field = Field(
data=da,
dims=['time'],
fillvalue=1e5,
name='myvar')
self.assertEqual(field.get_values().count(), N_VALUES)
data = np.ma.masked_where(
(np.arange(N_VALUES) >= N_VALUES/2), data, copy=True)
field.set_values(data)
print(field.get_values())
self.assertEqual(field.get_values().count(), N_VALUES/2)
self.assertEqual(field.get_values().fill_value, 1e5)
self.assertEqual(field.fill_value, 1e5)
self.assertEqual(field.get_values().dtype.name, 'float32')
def test_field_da_int32_auto_fillvalue(self):
data = np.ma.ones((N_VALUES,), dtype=np.int32)
da = xr.DataArray(
data=data,
dims=['time'],
name='myvar'
)
field = Field(
data=da,
dims=['time'],
name='myvar')
self.assertEqual(field.get_values().count(), N_VALUES)
data = np.ma.masked_where(
(np.arange(N_VALUES) >= N_VALUES/2), data, copy=True)
field.set_values(data)
print(field.get_values())
self.assertEqual(field.get_values().count(), N_VALUES/2)
self.assertEqual(field.get_values().dtype.name, 'int32')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment