From 93988d966cb3855bf44136eecce5b50d70b24366 Mon Sep 17 00:00:00 2001
From: Andrew Ekstedt
Date: Sat, 30 May 2015 22:48:20 -0700
Subject: [PATCH] load: Use COPY FROM STDIN on PostgreSQL.

COPY FROM FILE requires database superuser permissions, because of the
obvious security implications: the server process itself reads the named
file. COPY FROM STDIN has no such restriction, since the data is streamed
over the client connection.

Also do some cleanup while we're here.
---
 pokedex/db/load.py | 60 ++++++++++++++++++++++------------------------
 1 file changed, 29 insertions(+), 31 deletions(-)

diff --git a/pokedex/db/load.py b/pokedex/db/load.py
index d968fdb..37d7529 100644
--- a/pokedex/db/load.py
+++ b/pokedex/db/load.py
@@ -144,22 +144,23 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
     table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
 
+    engine = session.get_bind()
+
     # Limit table names to 30 characters for Oracle
-    oracle = (session.connection().dialect.name == 'oracle')
+    oracle = (engine.dialect.name == 'oracle')
     if oracle:
         rewrite_long_table_names()
 
     # SQLite speed tweaks
-    if not safe and session.connection().dialect.name == 'sqlite':
-        session.connection().execute("PRAGMA synchronous=OFF")
-        session.connection().execute("PRAGMA journal_mode=OFF")
+    if not safe and engine.dialect.name == 'sqlite':
+        session.execute("PRAGMA synchronous=OFF")
+        session.execute("PRAGMA journal_mode=OFF")
 
     # Drop all tables if requested
     if drop_tables:
-        bind = session.get_bind()
         print_start('Dropping tables')
 
         for n, table in enumerate(reversed(table_objs)):
-            table.drop(checkfirst=True)
+            table.drop(bind=engine, checkfirst=True)
 
             # Drop columns' types if appropriate; needed for enums in
             # postgresql
@@ -169,7 +170,7 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
             except AttributeError:
                 pass
             else:
-                drop(bind=bind, checkfirst=True)
+                drop(bind=engine, checkfirst=True)
 
             print_status('%s/%s' % (n, len(table_objs)))
         print_done()
@@ -179,7 +180,6 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
             table.create()
             print_status('%s/%s' % (n, len(table_objs)))
         print_done()
-    connection = session.connection()
 
     # Okay, run through the tables and actually load the data now
     for table_obj in table_objs:
@@ -205,35 +205,34 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
         reader = csv.reader(csvfile, lineterminator='\n')
         column_names = [unicode(column) for column in reader.next()]
 
-        if not safe and session.connection().dialect.name == 'postgresql':
-            """
-            Postgres' CSV dialect works with our data, if we mark the not-null
-            columns with FORCE NOT NULL.
-            COPY is only allowed for DB superusers. If you're not one, use safe
-            loading (pokedex load -S).
-            """
-            session.commit()
+        if not safe and engine.dialect.name == 'postgresql':
+            # Postgres' CSV dialect works with our data, if we mark the not-null
+            # columns with FORCE NOT NULL.
             not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
             if not_null_cols:
                 force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
             else:
                 force_not_null = ''
-            command = "COPY %(table_name)s (%(columns)s) FROM '%(csvpath)s' CSV HEADER %(force_not_null)s"
-            session.connection().execute(
+
+            # Grab the underlying psycopg2 cursor so we can use COPY FROM STDIN
+            raw_conn = engine.raw_connection()
+            command = "COPY %(table_name)s (%(columns)s) FROM STDIN CSV HEADER %(force_not_null)s"
+            csvfile.seek(0)
+            raw_conn.cursor().copy_expert(
                 command % dict(
                     table_name=table_name,
-                    csvpath=csvpath,
                     columns=','.join('"%s"' % c for c in column_names),
                     force_not_null=force_not_null,
-                )
+                ),
+                csvfile,
             )
-            session.commit()
+            raw_conn.commit()
             print_done()
             continue
 
         # Self-referential tables may contain rows with foreign keys of other
-        # rows in the same table that do not yet exist. Pull these out and add
-        # them to the session last
+        # rows in the same table that do not yet exist. Pull these out and
+        # insert them last
         # ASSUMPTION: Self-referential tables have a single PK called "id"
         deferred_rows = []  # ( row referring to id, [foreign ids we need] )
         seen_ids = set()    # primary keys we've seen
@@ -248,7 +247,7 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
         def insert_and_commit():
             if not new_rows:
                 return
-            session.connection().execute(insert_stmt, new_rows)
+            session.execute(insert_stmt, new_rows)
             session.commit()
             new_rows[:] = []
 
@@ -316,12 +315,12 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
                 raise ValueError("Too many levels of self-reference! "
                                  "Row was: " + str(row))
 
-            session.connection().execute(
+            session.execute(
                 insert_stmt.values(**row_data)
             )
             seen_ids.add(row_data['id'])
 
-            session.commit()
+        session.commit()
 
         print_done()
 
@@ -333,18 +332,17 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
             table_obj = translation_class.__table__
             if table_obj in table_objs:
                 insert_stmt = table_obj.insert()
-                session.connection().execute(insert_stmt, rows)
+                session.execute(insert_stmt, rows)
                 session.commit()
                 # We don't have a total, but at least show some increasing
                 # number
                 new_row_count += len(rows)
                 print_status(str(new_row_count))
 
-        print_done()
 
     # SQLite check
-    if session.connection().dialect.name == 'sqlite':
-        session.connection().execute("PRAGMA integrity_check")
-
+    if engine.dialect.name == 'sqlite':
+        session.execute("PRAGMA integrity_check")
+    print_done()
 
 def dump(session, tables=[], directory=None, verbose=False, langs=None):
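
For reviewers, a minimal self-contained sketch of the COPY FROM STDIN
technique the patch adopts. It assumes psycopg2 is installed and is written
for Python 3 for brevity (the patched loader itself still targets Python 2);
the connection string and the table and column names (example_table, id,
identifier) are illustrative only, not part of pokedex. The patch reaches
the same kind of cursor via SQLAlchemy's engine.raw_connection().

    import io

    import psycopg2

    # Illustrative connection; point this at a scratch database.
    conn = psycopg2.connect("dbname=pokedex")

    # An in-memory CSV standing in for one of the data files.
    # copy_expert() accepts any file-like object with a read() method.
    csv_data = io.StringIO(
        "id,identifier\n"
        "1,example-row\n"
    )

    cur = conn.cursor()
    # The rows are streamed to the server over the client connection, so
    # the server never reads a file from its own filesystem and the
    # connecting role does not need to be a superuser.
    cur.copy_expert(
        "COPY example_table (id, identifier) FROM STDIN CSV HEADER",
        csv_data,
    )
    conn.commit()
    cur.close()
    conn.close()

copy_expert() simply takes the full COPY ... FROM STDIN statement as a
string plus a readable file object, so the FORCE NOT NULL clause the patch
builds can be appended to the statement unchanged, and the loader's real
csvfile (seek(0) first, as above) works the same way as the StringIO here.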