Commit a2cea61a authored by PIOLLE's avatar PIOLLE
Browse files

added method closest_spatial_location for Grid feature

parent 19d66e31
...@@ -9,6 +9,7 @@ from __future__ import print_function ...@@ -9,6 +9,7 @@ from __future__ import print_function
from typing import Tuple from typing import Tuple
import xarray as xr import xarray as xr
import numpy as np
from .feature import Feature from .feature import Feature
...@@ -118,119 +119,6 @@ class Grid(Feature): ...@@ -118,119 +119,6 @@ class Grid(Feature):
else: else:
return 'y', 'x' return 'y', 'x'
# def save(self, output, attrs=None, infer_cf_attrs=False, **kwargs):
# """
# Save the grid to a storage (file,...)
#
# Args:
# output (:class:`~cerbere.mapper.abstractmapper.AbstractMapper`):
# storage object which to save the feature data to.
#
# attrs (dict): the global metadata (attributes) of the feature, as
# a dictionary where keys are the attributes names.
# See STANDARD_ATTRIBUTES_VALUES in abstractmapper class to see
# a list of standard attributes
#
# infer_cf_attrs (bool, optional)
# """
# output.save(self)
# if output is None:
# mapper = self.get_mapper()
# else:
# mapper = output
# if mapper.is_writable():
# # creating dimensions
# mapper.create_dim('time', None)
# if self.projection.is_cylindrical():
# mapper.create_dim(
# 'lat',
# self.get_geolocation_field('lat').get_dimsize('y')
# )
# mapper.create_dim(
# 'lon',
# self.get_geolocation_field('lon').get_dimsize('x')
# )
# dim_translation = {'y': 'lat', 'x': 'lon'}
# else:
# mapper.create_dim(
# 'y',
# self.get_geolocation_field('lat').get_dimsize('y')
# )
# mapper.create_dim(
# 'x',
# self.get_geolocation_field('lon').get_dimsize('x')
# )
# dim_translation = None
# # create additional dimensions
# dims = ['y', 'x', 'lat', 'lon', 'time']
# for v in self._fields.keys():
# if v not in self._geocoordinates:
# for d in self._fields[v].dimensions:
# if d not in dims:
# mapper.create_dim(
# d, self._fields[v].get_dimsize(d))
# dims.append(d)
# # creating metadata
# if attrs:
# globalattr = attrs
# else:
# globalattr = {}
# if 'title' not in globalattr and self.title is not None:
# globalattr['title'] = self.title
# if 'summary' not in globalattr and self.description is not None:
# globalattr['summary'] = self.description
# if 'id' not in globalattr and self.identifier is not None:
# globalattr['id'] = self.identifier
# lonmin, latmin, lonmax, latmax = self.get_bbox()
# globalattr['geospatial_lat_min'] = latmin
# globalattr['geospatial_lon_min'] = lonmin
# globalattr['geospatial_lat_max'] = latmax
# globalattr['geospatial_lon_max'] = lonmax
# if self.projection.identifier == 'regular':
# if (len(self.get_lat()) > 1) and (len(self.get_lon()) > 1):
# globalattr['geospatial_lat_resolution'] \
# = self.get_lat()[1] - self.get_lat()[0]
# globalattr['geospatial_lon_resolution'] \
# = self.get_lon()[1] - self.get_lon()[0]
# tmptime = self.get_start_time()
# if tmptime is not None:
# globalattr['time_coverage_start'] = tmptime
# tmptime = self.get_end_time()
# if tmptime is not None:
# globalattr['time_coverage_end'] = tmptime
# globalattr['cdm_data_type'] = 'grid'
# mapper.write_global_attributes(globalattr)
# # creating records
# if self._geocoordinates['time'] is None:
# raise Exception('No time information defined')
# for geof in self._geocoordinates:
# if self._geocoordinates[geof] is None:
# logging.warning('Missing geolocation variable : %s', geof)
# else:
# mapper.create_field(self._geocoordinates[geof],
# dim_translation,
# feature='Grid'
# )
# for dataf in self._fields.keys():
# if dataf not in self._geocoordinates:
# mapper.create_field(self._fields[dataf],
# dim_translation,
# feature='Grid')
# else:
# raise Exception('Mapper object is not writable')
# # saving records
# for geof in self._geocoordinates:
# field = self._geocoordinates[geof]
# if field is not None and not field.is_saved():
# mapper.write_field(self._geocoordinates[geof])
# for dataf in self._fields.keys():
# if not self._fields[dataf].is_saved():
# mapper.write_field(self._fields[dataf])
# mapper.sync()
# return
def get_spatial_resolution(self): def get_spatial_resolution(self):
"""Return the spatial resolution of the feature, in degrees""" """Return the spatial resolution of the feature, in degrees"""
if self.spatial_resolution is None: if self.spatial_resolution is None:
...@@ -240,107 +128,6 @@ class Grid(Feature): ...@@ -240,107 +128,6 @@ class Grid(Feature):
else: else:
return self.spatial_resolution return self.spatial_resolution
# def extract_subset(
# self, boundaries=None, slices=None, fields=None, padding=False,
# prefix=None):
# """Extract a subset feature from the grid.
#
# The created subset is a new Grid object without any reference to
# the source.
#
# Args:
# boundaries (tuple): area of the subset to extract, defined as
# llcrnrlon, llcrnrlat, urcrnrlon, urcrnrlat.
#
# slices (dict): indices for /time/ dimension of the subset to
# extract from the source data. If None, the complete feature
# is extracted.
#
# fields (list): list of field names to extract. If None, all fields
# are extracted.
#
# padding (bool): Passed to extract_field method to ensure padding
# with _FillValues for points outside of the bounds of the this
# feature (used only in conjuncture with slices.
#
# prefix (str): add a prefix string to the field names of the
# extracted subset.
# """
# if boundaries and slices:
# raise Exception("Boundaries and slices can not be both provided.")
# if boundaries:
# # get corresponding slices
# llcrnrlon, llcrnrlat, urcrnrlon, urcrnrlat = boundaries
# slice1 = self.latlon2slice(llcrnrlat, llcrnrlon)
# slice2 = self.latlon2slice(urcrnrlat, urcrnrlon)
# lslices = collections.OrderedDict([])
# for dim in slice1.keys():
# lslices[dim] = slice(min(slice1[dim].start, slice2[dim].start),
# max(slice1[dim].stop, slice2[dim].stop),
# 1)
# else:
# lslices = slices
#
# if self.is_unique_grid_time():
# timefield = self.extract_field('time', prefix=prefix)
# else:
# timefield = self.extract_field('time',
# slices=lslices,
# padding=padding)
# subgrid = Grid(
# latitudes=self.extract_field('lat',
# slices=lslices,
# padding=padding),
# longitudes=self.extract_field('lon',
# slices=lslices,
# padding=padding),
# times=timefield,
# projection=self.projection,
# metadata=self.metadata,
# )
# if fields is None:
# fields = self.get_fieldnames()
# elif not type(fields) is list:
# raise Exception("fields must be a list")
# for field in fields:
# subgrid.add_field(
# self.extract_field(field, slices=lslices,
# padding=padding,
# prefix=prefix))
#
# return subgrid
#
# def extract_spatialsection(self, lat1, lon1, lat2, lon2, fieldnames=None):
# """
# """
# # TBD
# pass
#
# def latlon2slice(self, lat, lon):
# """Returns the slice corresponding to the provided lat/lon locations
#
# Args:
# lat (float) : latitude
# lon (float): longitude
#
# Returns:
# slice
# """
# lats = self.get_lat()
# lons = self.get_lon()
# if self.projection.is_cylindrical():
# # y = lats[numpy.abs(lats - lat).argmin()]
# # x = lons[numpy.abs(lons - lon).argmin()]
# x = numpy.abs(lons - lon).argmin()
# y = numpy.abs(lats - lat).argmin()
# # logging.debug('nearest grid lon %s lat %s',lons[x],lats[y])
# #ydim_name_in_attached_storage = self.get_mapper().get_matching_dimname('y')
# #xdim_name_in_attached_storage = self.get_mapper().get_matching_dimname('x')
# # logging.debug('latlon2slice dimension name found for y : %s',ydim_name_in_attached_storage)
# return collections.OrderedDict([('y', slice(y, y + 1, 1)), ('x', slice(x, x + 1, 1))])
# else:
# raise NotImplementedError
class CylindricalGrid(Grid): class CylindricalGrid(Grid):
def __init__(self, def __init__(self,
...@@ -376,3 +163,21 @@ class CylindricalGrid(Grid): ...@@ -376,3 +163,21 @@ class CylindricalGrid(Grid):
) )
else: else:
return ('lat', 'lon',) return ('lat', 'lon',)
def closest_spatial_location(self, lon, lat):
"""Get closest dataset lat/lon location to given coordinates.
Use pythagorian differences on lat/lon values so take the result with
caution.
Returns:
A tuple with the indices and lon/lat of the closest point found.
"""
idx_lon = np.abs(self.get_lon() - lon).argmin()
idx_lat = np.abs(self.get_lat() - lat).argmin()
loc = {'lat': idx_lat, 'lon': idx_lon}
geoloc = (self.get_lon().flat[idx_lon], self.get_lat().flat[idx_lat])
return loc, geoloc,
...@@ -17,6 +17,8 @@ from cerbere.dataset.field import Field ...@@ -17,6 +17,8 @@ from cerbere.dataset.field import Field
from cerbere.dataset.dataset import Dataset from cerbere.dataset.dataset import Dataset
from cerbere.dataset.ncdataset import NCDataset from cerbere.dataset.ncdataset import NCDataset
TEST_SUBSET = "test_subset.nc"
class Checker(): class Checker():
"""Checker for dataset classes """Checker for dataset classes
...@@ -236,7 +238,7 @@ class Checker(): ...@@ -236,7 +238,7 @@ class Checker():
attr) attr)
self.assertIsInstance( self.assertIsInstance(
datasetobj.get_attr(attr), datasetobj.get_attr(attr),
(str, int, datetime.datetime, numpy.int32, (str, int, datetime.datetime, numpy.int32, numpy.uint32,
numpy.int16, numpy.float32, numpy.float64, list), numpy.int16, numpy.float32, numpy.float64, list),
msg) msg)
datasetobj.close() datasetobj.close()
...@@ -276,11 +278,10 @@ class Checker(): ...@@ -276,11 +278,10 @@ class Checker():
width = min(min(rows // 2, cells // 2), 5) width = min(min(rows // 2, cells // 2), 5)
r0, r1 = rows // 2 - width, rows // 2 + width r0, r1 = rows // 2 - width, rows // 2 + width
c0, c1 = cells // 2 - width, cells // 2 + width c0, c1 = cells // 2 - width, cells // 2 + width
print("Subset ") slices = {'row': slice(r0, r1, 1), 'cell': slice(c0, c1, 1)}
print("row : ", r0, r1) print("Subset: ", slices)
print("cell: ", c0, c1) subset = featureobj.extract(index=slices)
subset = featureobj.extract(index={'row': slice(r0, r1, 1), subset.attrs['slices'] = str(slices)
'cell': slice(c0, c1, 1)})
elif featureobj.__class__.__name__ in ['Grid', 'GridTimeSeries']: elif featureobj.__class__.__name__ in ['Grid', 'GridTimeSeries']:
ni = featureobj.geodims['x'] ni = featureobj.geodims['x']
...@@ -319,9 +320,10 @@ class Checker(): ...@@ -319,9 +320,10 @@ class Checker():
subset = self.__extract_subset() subset = self.__extract_subset()
# test saving the subset in netCDF format # test saving the subset in netCDF format
fname = "test_subset.nc" fname = TEST_SUBSET
if os.path.exists(fname): if os.path.exists(fname):
os.remove(fname) os.remove(fname)
subsetfile = NCDataset(fname, mode='w') subsetfile = NCDataset(fname, mode='w')
subset.save(subsetfile) subset.save(subsetfile)
subsetfile.close() subsetfile.close()
......
...@@ -44,6 +44,8 @@ class TestFeature(unittest.TestCase): ...@@ -44,6 +44,8 @@ class TestFeature(unittest.TestCase):
def tearDown(self): def tearDown(self):
"""Cleaning up after the test""" """Cleaning up after the test"""
os.remove(TEST_FILE) os.remove(TEST_FILE)
if os.path.exists(TEST_SAVE):
os.remove(TEST_SAVE)
def test_create_feature_from_dict(self): def test_create_feature_from_dict(self):
basefeat = self.get_feature(self.define_base_feature()) basefeat = self.get_feature(self.define_base_feature())
...@@ -79,11 +81,10 @@ class TestFeature(unittest.TestCase): ...@@ -79,11 +81,10 @@ class TestFeature(unittest.TestCase):
feat = self.get_feature(ncf) feat = self.get_feature(ncf)
self.assertIsInstance(feat, self.get_feature_class()) self.assertIsInstance(feat, self.get_feature_class())
def test_save_grid(self): def test_save_feature(self):
ncf = NCDataset(dataset=TEST_SAVE, mode='w') ncf = NCDataset(dataset=TEST_SAVE, mode='w')
feat = self.get_feature(self.define_base_feature()) feat = self.get_feature(self.define_base_feature())
feat.save(ncf) feat.save(ncf)
os.remove(TEST_SAVE)
def test_get_lat(self): def test_get_lat(self):
feat = self.get_feature(self.define_base_feature()) feat = self.get_feature(self.define_base_feature())
...@@ -181,6 +182,12 @@ class TestFeature(unittest.TestCase): ...@@ -181,6 +182,12 @@ class TestFeature(unittest.TestCase):
# test append with unshared dimensions # test append with unshared dimensions
feat = self.get_feature(self.define_base_feature()) feat = self.get_feature(self.define_base_feature())
feat.append(feat2, prefix="v2_", as_new_dims=True) feat.append(feat2, prefix="v2_", add_coords=True, as_new_dims=True)
self.assertIn("v2_time", feat.dims) for fieldname in feat2.fieldnames:
self.assertIn("v2_time", feat.coordnames) self.assertIn("v2_"+fieldname, feat.fieldnames)
for d in feat2.get_field(fieldname).dims:
self.assertIn("v2_"+d, feat.dimnames)
self.assertIn("v2_"+d, feat.coordnames)
# test the appended feature can be save
feat.save(TEST_SAVE)
from datetime import datetime from datetime import datetime
import numpy as np import numpy as np
import shapely.geometry
import xarray as xr import xarray as xr
import netCDF4 as netcdf
from cerbere.feature.grid import CylindricalGrid from cerbere.feature.grid import CylindricalGrid
from .test_feature import TestFeature from .test_feature import TestFeature, TEST_SAVE
class TestCylindricalGridFeature(TestFeature): class TestCylindricalGridFeature(TestFeature):
...@@ -16,8 +18,8 @@ class TestCylindricalGridFeature(TestFeature): ...@@ -16,8 +18,8 @@ class TestCylindricalGridFeature(TestFeature):
def define_base_feature(self): def define_base_feature(self):
# creates a test xarray object # creates a test xarray object
lon = xr.DataArray(data=np.arange(-180, 180, 1), dims=['lon']) lon = xr.DataArray(data=np.arange(-180., 180., 1), dims=['lon'])
lat = xr.DataArray(data=np.arange(-80, 80, 1), dims=['lat']) lat = xr.DataArray(data=np.arange(-80., 80., 1), dims=['lat'])
time = xr.DataArray([datetime(2018, 1, 1)], dims='time') time = xr.DataArray([datetime(2018, 1, 1)], dims='time')
var = xr.DataArray( var = xr.DataArray(
data=np.ones(shape=(160, 360)), data=np.ones(shape=(160, 360)),
...@@ -29,6 +31,7 @@ class TestCylindricalGridFeature(TestFeature): ...@@ -29,6 +31,7 @@ class TestCylindricalGridFeature(TestFeature):
data_vars={'myvar': var}, data_vars={'myvar': var},
attrs={'gattr1': 'gattr1_val', 'gattr2': 'gattr2_val'} attrs={'gattr1': 'gattr1_val', 'gattr2': 'gattr2_val'}
) )
return CylindricalGrid(xrdataset) return CylindricalGrid(xrdataset)
def get_feature_dimnames(self): def get_feature_dimnames(self):
...@@ -241,4 +244,19 @@ class TestCylindricalGridFeature(TestFeature): ...@@ -241,4 +244,19 @@ class TestCylindricalGridFeature(TestFeature):
'time', 'time',
expand=True expand=True
) )
print("result: test_expanded_time ", res) print("result: test_expanded_time ", res)
\ No newline at end of file
def test_clip(self):
basefeat = self.define_base_feature()
subfeature = basefeat.clip(shapely.geometry.box(-180, 0, 0, 70))[0]
self.assertEqual(
dict(subfeature.dims),
dict([('lat', 71), ('lon', 181), ('time', 1)]))
def test_save_feature(self):
super(TestCylindricalGridFeature, self).test_save_feature()
dst = netcdf.Dataset(TEST_SAVE)
self.assertEqual(dst.variables['lat'].getncattr('_FillValue'),
np.float32(1e20))
self.assertEqual(dst.variables['lat'].dtype, np.dtype(np.float32))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment