load: Use COPY FROM STDIN on PostgreSQL.

COPY FROM FILE requires database superuser permissions,
because it makes the server process read the file from its
own filesystem, with obvious security implications.

COPY FROM STDIN has no such restriction.

Also do some cleanup while we're here.
Andrew Ekstedt 2015-05-30 22:48:20 -07:00
parent 33fab44d0d
commit 93988d966c
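
For illustration, a rough sketch of the distinction in plain psycopg2
(the DSN, table, columns, and CSV file below are made-up examples, not
part of this commit):

    import psycopg2

    conn = psycopg2.connect('dbname=pokedex')  # hypothetical DSN
    cur = conn.cursor()

    # COPY ... FROM '/path/file.csv' tells the *server* to read the file
    # from its own filesystem, so PostgreSQL restricts it to superusers.
    # COPY ... FROM STDIN streams the data over the client connection
    # instead; any user with INSERT privilege on the table may run it.
    with open('pokemon.csv') as csvfile:  # hypothetical CSV file
        cur.copy_expert(
            'COPY pokemon (id, identifier) FROM STDIN CSV HEADER',
            csvfile,
        )
    conn.commit()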


@@ -144,22 +144,23 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
     table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
 
+    engine = session.get_bind()
+
     # Limit table names to 30 characters for Oracle
-    oracle = (session.connection().dialect.name == 'oracle')
+    oracle = (engine.dialect.name == 'oracle')
     if oracle:
         rewrite_long_table_names()
 
     # SQLite speed tweaks
-    if not safe and session.connection().dialect.name == 'sqlite':
-        session.connection().execute("PRAGMA synchronous=OFF")
-        session.connection().execute("PRAGMA journal_mode=OFF")
+    if not safe and engine.dialect.name == 'sqlite':
+        session.execute("PRAGMA synchronous=OFF")
+        session.execute("PRAGMA journal_mode=OFF")
 
     # Drop all tables if requested
     if drop_tables:
-        bind = session.get_bind()
         print_start('Dropping tables')
         for n, table in enumerate(reversed(table_objs)):
-            table.drop(checkfirst=True)
+            table.drop(bind=engine, checkfirst=True)
 
             # Drop columns' types if appropriate; needed for enums in
             # postgresql
@@ -169,7 +170,7 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
                 except AttributeError:
                     pass
                 else:
-                    drop(bind=bind, checkfirst=True)
+                    drop(bind=engine, checkfirst=True)
             print_status('%s/%s' % (n, len(table_objs)))
         print_done()
@@ -179,7 +180,6 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
             table.create()
             print_status('%s/%s' % (n, len(table_objs)))
         print_done()
-    connection = session.connection()
 
     # Okay, run through the tables and actually load the data now
     for table_obj in table_objs:
@@ -205,35 +205,34 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
         reader = csv.reader(csvfile, lineterminator='\n')
         column_names = [unicode(column) for column in reader.next()]
 
-        if not safe and session.connection().dialect.name == 'postgresql':
-            """
-            Postgres' CSV dialect works with our data, if we mark the not-null
-            columns with FORCE NOT NULL.
-            COPY is only allowed for DB superusers. If you're not one, use safe
-            loading (pokedex load -S).
-            """
-            session.commit()
+        if not safe and engine.dialect.name == 'postgresql':
+            # Postgres' CSV dialect works with our data, if we mark the not-null
+            # columns with FORCE NOT NULL.
             not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
             if not_null_cols:
                 force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
             else:
                 force_not_null = ''
-            command = "COPY %(table_name)s (%(columns)s) FROM '%(csvpath)s' CSV HEADER %(force_not_null)s"
-            session.connection().execute(
+
+            # Grab the underlying psycopg2 cursor so we can use COPY FROM STDIN
+            raw_conn = engine.raw_connection()
+            command = "COPY %(table_name)s (%(columns)s) FROM STDIN CSV HEADER %(force_not_null)s"
+            csvfile.seek(0)
+            raw_conn.cursor().copy_expert(
                 command % dict(
                     table_name=table_name,
-                    csvpath=csvpath,
                     columns=','.join('"%s"' % c for c in column_names),
                     force_not_null=force_not_null,
-                )
+                ),
+                csvfile,
             )
-            session.commit()
+            raw_conn.commit()
             print_done()
             continue
 
         # Self-referential tables may contain rows with foreign keys of other
-        # rows in the same table that do not yet exist. Pull these out and add
-        # them to the session last
+        # rows in the same table that do not yet exist. Pull these out and
+        # insert them last
         # ASSUMPTION: Self-referential tables have a single PK called "id"
         deferred_rows = []  # ( row referring to id, [foreign ids we need] )
         seen_ids = set()    # primary keys we've seen
@@ -248,7 +247,7 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
         def insert_and_commit():
             if not new_rows:
                 return
-            session.connection().execute(insert_stmt, new_rows)
+            session.execute(insert_stmt, new_rows)
             session.commit()
             new_rows[:] = []
@@ -316,12 +315,12 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
                     raise ValueError("Too many levels of self-reference! "
                                      "Row was: " + str(row))
 
-                session.connection().execute(
+                session.execute(
                     insert_stmt.values(**row_data)
                 )
                 seen_ids.add(row_data['id'])
-                session.commit()
+            session.commit()
 
         print_done()
@@ -333,18 +332,17 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
             table_obj = translation_class.__table__
             if table_obj in table_objs:
                 insert_stmt = table_obj.insert()
-                session.connection().execute(insert_stmt, rows)
+                session.execute(insert_stmt, rows)
                 session.commit()
                 # We don't have a total, but at least show some increasing number
                 new_row_count += len(rows)
                 print_status(str(new_row_count))
-        print_done()
 
     # SQLite check
-    if session.connection().dialect.name == 'sqlite':
-        session.connection().execute("PRAGMA integrity_check")
+    if engine.dialect.name == 'sqlite':
+        session.execute("PRAGMA integrity_check")
+
+    print_done()
 
 
 def dump(session, tables=[], directory=None, verbose=False, langs=None):
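
For reference, the new code path boils down to the pattern below. This
is a standalone sketch under stated assumptions, not code from the
commit: the engine URL, function name, and call arguments are made up,
and it reads the CSV from a path itself rather than reusing an open
file object as the loader does.

    import sqlalchemy

    def pg_bulk_load(engine, table_name, csvpath, not_null_cols=()):
        """Stream a CSV into PostgreSQL via COPY FROM STDIN (no superuser needed)."""
        force_not_null = ''
        if not_null_cols:
            force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
        command = 'COPY %s FROM STDIN CSV HEADER %s' % (table_name, force_not_null)
        # raw_connection() returns the underlying DBAPI (psycopg2) connection,
        # which is what exposes copy_expert(); an ordinary SQL execute() call
        # has no way to feed a file object to COPY.
        raw_conn = engine.raw_connection()
        try:
            with open(csvpath) as csvfile:
                raw_conn.cursor().copy_expert(command, csvfile)
            raw_conn.commit()
        finally:
            raw_conn.close()

    # Hypothetical usage:
    engine = sqlalchemy.create_engine('postgresql:///pokedex')
    pg_bulk_load(engine, 'pokemon', 'pokemon.csv', not_null_cols=('id', 'identifier'))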