load: Use COPY FROM STDIN on PostgreSQL.

COPY FROM FILE requires database superuser permissions,
because of the obvious security implications.

COPY FROM STDIN has no such restriction.

Also do some cleanup while we're here.
This commit is contained in:
Andrew Ekstedt 2015-05-30 22:48:20 -07:00
parent 33fab44d0d
commit 93988d966c

View file

@ -144,22 +144,23 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
engine = session.get_bind()
# Limit table names to 30 characters for Oracle
oracle = (session.connection().dialect.name == 'oracle')
oracle = (engine.dialect.name == 'oracle')
if oracle:
rewrite_long_table_names()
# SQLite speed tweaks
if not safe and session.connection().dialect.name == 'sqlite':
session.connection().execute("PRAGMA synchronous=OFF")
session.connection().execute("PRAGMA journal_mode=OFF")
if not safe and engine.dialect.name == 'sqlite':
session.execute("PRAGMA synchronous=OFF")
session.execute("PRAGMA journal_mode=OFF")
# Drop all tables if requested
if drop_tables:
bind = session.get_bind()
print_start('Dropping tables')
for n, table in enumerate(reversed(table_objs)):
table.drop(checkfirst=True)
table.drop(bind=engine, checkfirst=True)
# Drop columns' types if appropriate; needed for enums in
# postgresql
@ -169,7 +170,7 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
except AttributeError:
pass
else:
drop(bind=bind, checkfirst=True)
drop(bind=engine, checkfirst=True)
print_status('%s/%s' % (n, len(table_objs)))
print_done()
@ -179,7 +180,6 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
table.create()
print_status('%s/%s' % (n, len(table_objs)))
print_done()
connection = session.connection()
# Okay, run through the tables and actually load the data now
for table_obj in table_objs:
@ -205,35 +205,34 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
reader = csv.reader(csvfile, lineterminator='\n')
column_names = [unicode(column) for column in reader.next()]
if not safe and session.connection().dialect.name == 'postgresql':
"""
Postgres' CSV dialect works with our data, if we mark the not-null
columns with FORCE NOT NULL.
COPY is only allowed for DB superusers. If you're not one, use safe
loading (pokedex load -S).
"""
session.commit()
if not safe and engine.dialect.name == 'postgresql':
# Postgres' CSV dialect works with our data, if we mark the not-null
# columns with FORCE NOT NULL.
not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
if not_null_cols:
force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
else:
force_not_null = ''
command = "COPY %(table_name)s (%(columns)s) FROM '%(csvpath)s' CSV HEADER %(force_not_null)s"
session.connection().execute(
# Grab the underlying psycopg2 cursor so we can use COPY FROM STDIN
raw_conn = engine.raw_connection()
command = "COPY %(table_name)s (%(columns)s) FROM STDIN CSV HEADER %(force_not_null)s"
csvfile.seek(0)
raw_conn.cursor().copy_expert(
command % dict(
table_name=table_name,
csvpath=csvpath,
columns=','.join('"%s"' % c for c in column_names),
force_not_null=force_not_null,
),
csvfile,
)
)
session.commit()
raw_conn.commit()
print_done()
continue
# Self-referential tables may contain rows with foreign keys of other
# rows in the same table that do not yet exist. Pull these out and add
# them to the session last
# rows in the same table that do not yet exist. Pull these out and
# insert them last
# ASSUMPTION: Self-referential tables have a single PK called "id"
deferred_rows = [] # ( row referring to id, [foreign ids we need] )
seen_ids = set() # primary keys we've seen
@ -248,7 +247,7 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
def insert_and_commit():
if not new_rows:
return
session.connection().execute(insert_stmt, new_rows)
session.execute(insert_stmt, new_rows)
session.commit()
new_rows[:] = []
@ -316,12 +315,12 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
raise ValueError("Too many levels of self-reference! "
"Row was: " + str(row))
session.connection().execute(
session.execute(
insert_stmt.values(**row_data)
)
seen_ids.add(row_data['id'])
session.commit()
session.commit()
print_done()
@ -333,18 +332,17 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
table_obj = translation_class.__table__
if table_obj in table_objs:
insert_stmt = table_obj.insert()
session.connection().execute(insert_stmt, rows)
session.execute(insert_stmt, rows)
session.commit()
# We don't have a total, but at least show some increasing number
new_row_count += len(rows)
print_status(str(new_row_count))
print_done()
# SQLite check
if session.connection().dialect.name == 'sqlite':
session.connection().execute("PRAGMA integrity_check")
if engine.dialect.name == 'sqlite':
session.execute("PRAGMA integrity_check")
print_done()
def dump(session, tables=[], directory=None, verbose=False, langs=None):