Commit beba88a2 authored by ARCHER's avatar ARCHER

mutliple geometry columns support

parent 08091942
......@@ -22,6 +22,7 @@ shapely_handled_types=[
shpg.Polygon,
shpg.Point,
shpg.LinearRing,
shpg.LineString,
shpg.MultiPoint,
shpg.MultiLineString,
shpg.MultiPolygon]
......@@ -41,9 +42,6 @@ class TemporaryPandasTable(pd.io.sql.SQLTable):
# methods that use a temporary table wrapper if the table allreaydy exists
temporary_table_methods=['copy']
def get_srid(crs):
"""
......@@ -124,24 +122,29 @@ def to_sql(df,*args,**kwargs):
if 'method' not in kwargs:
logger.debug('setting "copy" as default method for to_sql')
kwargs['method'] = 'copy'
if 'if_exists' not in kwargs:
kwargs['if_exists'] = 'fail'
metadata = MetaData(con,reflect=True)
tmp_table = None
new_table = False
if 'if_exists' in kwargs and table in metadata.tables:
if (kwargs['if_exists'] == 'append') and (table in metadata.tables):
if kwargs['method'] in temporary_table_methods:
# will use a temporary table as a buffer ( for upsert copy )
logger.info("Temporary table used for fast update")
pandas_engine = pd.io.sql.pandasSQL_builder(con,meta=metadata)
tmp_table = "_%s" % table
if 'dtype' in kwargs:
dtype=kwargs['dtype']
else:
dtype=None
tmp_table_pd = TemporaryPandasTable("_%s" % table, pandas_engine, frame=df, if_exists="replace", dtype=dtype)
#tmp_table_pd.create()
# will use a temporary table as a buffer ( for upsert copy )
logger.info("Temporary table used for fast update")
pandas_engine = pd.io.sql.pandasSQL_builder(con,meta=metadata)
tmp_table = "_%s" % table
if 'dtype' in kwargs:
dtype=kwargs['dtype']
else:
dtype=None
tmp_table_pd = TemporaryPandasTable("_%s" % table, pandas_engine, frame=df, if_exists="replace", dtype=dtype)
#tmp_table_pd.create()
else:
......@@ -156,7 +159,6 @@ def to_sql(df,*args,**kwargs):
args=list(args)
if tmp_table is not None:
args[0]=tmp_table
#args[1]=con
args=tuple(args)
ret=df.to_sql_legacy(*args,**kwargs)
......@@ -182,11 +184,13 @@ def to_sql(df,*args,**kwargs):
]
}
)
# on_conflict_stmt = stmt.on_conflict_do_nothing()
con.execute(on_conflict_stmt)
try:
con.execute(on_conflict_stmt)
except:
logger.error("couln't merge tables. Did they have same columns ?")
raise
tmp_table_orm.drop(con)
#con1.execute('''insert into "%s" select * from "%s" on conflict do nothing''' % (table,tmp_table))
#con1.execute('''drop table "%s"''' % tmp_table)
return ret
......@@ -194,8 +198,14 @@ def to_sql(df,*args,**kwargs):
def get_all_geometry_names(gdf):
""" return all geoemetry columns names (ie those not only the one defined by gdf.geometry """
all=[]
for c in gdf.keys():
if hasattr(gdf[c][0], '__geom__'):
all.append(c)
return all
def to_postgis(gdf,*args,**kwargs):
"""
......@@ -229,11 +239,15 @@ def to_postgis(gdf,*args,**kwargs):
# set dtype converter for geometry columun
if 'dtype' not in kwargs:
kwargs['dtype']={}
geom_types=gdf.geometry.geom_type.unique().tolist()
if len(geom_types) != 1:
raise TypeError("mixed types not implemented : %s" % ",".join(geom_types))
geom_type=geom_types[0].upper()
kwargs['dtype'][gdf.geometry.name] = Geo(geom_type) #, srid=srid)
geom_cols = gdf.get_all_geometry_names()
for geom_col in geom_cols :
geom_types=gdf[geom_col].apply(lambda x: x.__geo_interface__['type']).unique().tolist()
if len(geom_types) != 1:
raise TypeError("mixed types not implemented : %s" % ",".join(geom_types))
geom_type=geom_types[0].upper()
logger.debug("%s type : %s" % (geom_col , geom_type))
kwargs['dtype'][geom_col] = Geo(geom_type) #, srid=srid)
with con.begin():
......@@ -243,12 +257,13 @@ def to_postgis(gdf,*args,**kwargs):
# table will be modified
# get actual column srid
with con.begin_nested() as trans:
try:
column_srid=int(con.execute('''SELECT Find_SRID('', '%s', '%s');''' % (table , gdf.geometry.name)).fetchone()[0])
except:
logger.warning('no column srid on existing table %s. assuming 0' % table)
column_srid=0
trans.rollback()
for geom_col in geom_cols :
try:
column_srid=int(con.execute('''SELECT Find_SRID('', '%s', '%s');''' % (table , geom_col)).fetchone()[0])
except:
logger.warning('no column srid on existing %s.%s assuming 0' % (table, geom_col))
column_srid=0
trans.rollback()
if column_srid != srid:
raise ValueError("geopandas srid %s doesn't match postgis Find_SRID %s" % (srid , column_srid))
......@@ -256,7 +271,8 @@ def to_postgis(gdf,*args,**kwargs):
if column_srid != 0:
# overwrite column srid so we don't need to give a srid for each wkt (we be done at the end)
logger.debug('temporary removing SRID')
con.execute('''SELECT UpdateGeometrySRID('%s','%s',0);''' % (table , gdf.geometry.name))
for geom_col in geom_cols :
con.execute('''SELECT UpdateGeometrySRID('%s','%s',0);''' % (table , geom_col))
args=list(args)
......@@ -266,7 +282,8 @@ def to_postgis(gdf,*args,**kwargs):
if srid != 0:
logger.debug('restoring SRID %s' % srid)
con.execute('''SELECT UpdateGeometrySRID('%s','%s',%s);''' % (table , gdf.geometry.name , srid))
for geom_col in geom_cols :
con.execute('''SELECT UpdateGeometrySRID('%s','%s',%s);''' % (table , geom_col , srid))
......@@ -289,6 +306,7 @@ def add_postgis(gpd):
gpd.GeoDataFrame.to_postgis=to_postgis
gpd.GeoDataFrame.to_sql_legacy = gpd.GeoDataFrame.to_sql
gpd.GeoDataFrame.to_sql = to_sql
gpd.GeoDataFrame.get_all_geometry_names = get_all_geometry_names
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment