Commit 08091942 authored by ARCHER

on conflict do update + sql transaction

parent d1dd8d95
......@@ -4,7 +4,9 @@ import shapely as shp
import shapely.geometry as shpg
from sqlalchemy.orm import sessionmaker
from sqlalchemy.dialects import postgresql
from sqlalchemy import text;
from sqlalchemy import text, MetaData;
from geoalchemy2 import select
import csv
from io import StringIO
import pandas as pd
......@@ -32,9 +34,13 @@ class TemporaryPandasTable(pd.io.sql.SQLTable):
self.table = self.table.tometadata(self.pd_sql.meta)
# allow creation of temporary tables
self.table._prefixes.append('TEMPORARY')
if 'TEMPORARY' not in self.table._prefixes:
self.table._prefixes.append('TEMPORARY')
self.table.create()
# methods that use a temporary table wrapper if the table already exists
temporary_table_methods=['copy']
......@@ -106,30 +112,38 @@ def to_sql(df,*args,**kwargs):
table=args[0]
con=args[1]
con_or_engine=args[1]
if hasattr(con_or_engine, 'connect'):
con=con_or_engine.connect()
else:
con=con_or_engine
# get method
if 'method' not in kwargs:
logger.debug('setting "copy" as default method for to_sql')
kwargs['method'] = 'copy'
Session = sessionmaker(bind=con)
# use a session connection
session = Session()
con1=session.connection().connect()
metadata = MetaData(con,reflect=True)
tmp_table = None
new_table = False
if 'if_exists' in kwargs and con1.dialect.has_table(con1, table):
if 'if_exists' in kwargs and table in metadata.tables:
if kwargs['method'] in temporary_table_methods:
# will use a temporary table as a buffer ( for upsert copy )
logger.info("Temporary table used for fast update")
pandas_engine = pd.io.sql.pandasSQL_builder(con1)
pandas_engine = pd.io.sql.pandasSQL_builder(con,meta=metadata)
tmp_table = "_%s" % table
tmp_table_omr = TemporaryPandasTable("'_%s'" % table, pandas_engine, frame=df, if_exists="replace")
tmp_table_omr.create()
if 'dtype' in kwargs:
dtype=kwargs['dtype']
else:
dtype=None
tmp_table_pd = TemporaryPandasTable("_%s" % table, pandas_engine, frame=df, if_exists="replace", dtype=dtype)
#tmp_table_pd.create()
else:
new_table=True
......@@ -137,31 +151,43 @@ def to_sql(df,*args,**kwargs):
if kwargs['method'] in to_sql_methods:
kwargs['method']=to_sql_methods[kwargs['method']]
try:
# use connection session
args=list(args)
if tmp_table is not None:
args[0]=tmp_table
args[1]=con1
args=tuple(args)
ret=df.to_sql_legacy(*args,**kwargs)
if new_table:
# this was probably already done if it's an append ...
logger.debug('defining primary key from pandas index')
con1.execute('''alter table "%s" add constraint "%s_pk" primary key( %s )''' % (table,table,",".join([ '"%s"' %q for q in df.index.names])))
# use connection session
args=list(args)
if tmp_table is not None:
args[0]=tmp_table
#args[1]=con
args=tuple(args)
ret=df.to_sql_legacy(*args,**kwargs)
if new_table:
logger.debug('defining primary key from pandas index')
con.execute('''alter table "%s" add constraint "%s_pk" primary key( %s )''' % (table,table,",".join([ '"%s"' %q for q in df.index.names])))
if tmp_table is not None:
logger.debug("merging tmp table")
# https://stackoverflow.com/questions/41724658/how-to-do-a-proper-upsert-using-sqlalchemy-on-postgresql
# get existing table orm
table_orm=metadata.tables[table]
tmp_table_orm = tmp_table_pd.table
stmt = postgresql.insert(table_orm).from_select(tmp_table_orm.columns,select([tmp_table_orm]))
#stmt = postgresql.insert(table_orm).from_select(tmp_table_orm.columns,select=session.query(tmp_table_orm))
on_conflict_stmt = stmt.on_conflict_do_update(
index_elements=table_orm.primary_key.columns,
set_={
k: getattr(stmt.excluded, k) for k in
[ c.name for c in table_orm.columns
if c not in list(table_orm.primary_key.columns)
]
}
)
# on_conflict_stmt = stmt.on_conflict_do_nothing()
con.execute(on_conflict_stmt)
tmp_table_orm.drop(con)
#con1.execute('''insert into "%s" select * from "%s" on conflict do nothing''' % (table,tmp_table))
#con1.execute('''drop table "%s"''' % tmp_table)
if tmp_table is not None:
logger.debug("merging tmp table")
con1.execute('''insert into "%s" select * from "%s" on conflict do nothing''' % (table,tmp_table))
con1.execute('''drop table "%s"''' % tmp_table)
session.commit()
except:
session.rollback()
raise
finally:
session.close()
return ret
......@@ -182,8 +208,15 @@ def to_postgis(gdf,*args,**kwargs):
logger.debug("using to_postgis extension")
table=args[0]
con=args[1]
con_or_engine=args[1]
if hasattr(con_or_engine, 'connect'):
con=con_or_engine.connect()
else:
con=con_or_engine
metadata = MetaData(con_or_engine,reflect=True)
srid=get_srid(gdf.crs)
# if no srid => Geography
......@@ -202,46 +235,42 @@ def to_postgis(gdf,*args,**kwargs):
geom_type=geom_types[0].upper()
kwargs['dtype'][gdf.geometry.name] = Geo(geom_type) #, srid=srid)
Session = sessionmaker(bind=con)
# use a session connection
session = Session()
con1=session.connection().connect()
if 'if_exists' in kwargs and con1.dialect.has_table(con1, table):
if kwargs['if_exists'] not in ['fail','replace']:
# table will be modified
# get actual column srid
column_srid=con1.execute('''SELECT Find_SRID('', '%s', '%s');''' % (table , gdf.geometry.name)).fetchone()[0]
with con.begin():
if 'if_exists' in kwargs and table in metadata.tables:
if kwargs['if_exists'] not in ['fail','replace']:
# table will be modified
# get actual column srid
with con.begin_nested() as trans:
try:
column_srid=int(con.execute('''SELECT Find_SRID('', '%s', '%s');''' % (table , gdf.geometry.name)).fetchone()[0])
except:
logger.warning('no column srid on existing table %s. assuming 0' % table)
column_srid=0
trans.rollback()
if column_srid != srid:
raise ValueError("geopandas srid %s doesn't match postgis Find_SRID %s" % (srid , column_srid))
if column_srid != srid:
raise ValueError("geopandas srid %s doesn't match postgis Find_SRID %s" % (srid , column_srid))
if column_srid != 0:
# overwrite column srid so we don't need to give a srid for each wkt (it will be restored at the end)
logger.debug('temporary removing SRID')
con.execute('''SELECT UpdateGeometrySRID('%s','%s',0);''' % (table , gdf.geometry.name))
if column_srid != 0:
# overwrite column srid so we don't need to give a srid for each wkt (it will be restored at the end)
con1.execute('''SELECT UpdateGeometrySRID('%s','%s',0);''' % (table , gdf.geometry.name))
ret=None
try:
# use connection session
args=list(args)
args[1]=con1
args[1]=con
args=tuple(args)
ret=gdf.to_sql(*args,**kwargs)
if srid != 0:
logger.debug('restoring SRID %s' % srid)
con.execute('''SELECT UpdateGeometrySRID('%s','%s',%s);''' % (table , gdf.geometry.name , srid))
con1.execute('''SELECT UpdateGeometrySRID('%s','%s',%s);''' % (table , gdf.geometry.name , srid))
session.commit()
except:
session.rollback()
raise
finally:
session.close()
return ret
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment