diff --git a/pokedex/db/load.py b/pokedex/db/load.py
index d968fdb..37d7529 100644
--- a/pokedex/db/load.py
+++ b/pokedex/db/load.py
@@ -144,22 +144,23 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
     table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
 
+    engine = session.get_bind()
+
     # Limit table names to 30 characters for Oracle
-    oracle = (session.connection().dialect.name == 'oracle')
+    oracle = (engine.dialect.name == 'oracle')
     if oracle:
         rewrite_long_table_names()
 
     # SQLite speed tweaks
-    if not safe and session.connection().dialect.name == 'sqlite':
-        session.connection().execute("PRAGMA synchronous=OFF")
-        session.connection().execute("PRAGMA journal_mode=OFF")
+    if not safe and engine.dialect.name == 'sqlite':
+        session.execute("PRAGMA synchronous=OFF")
+        session.execute("PRAGMA journal_mode=OFF")
 
     # Drop all tables if requested
     if drop_tables:
-        bind = session.get_bind()
         print_start('Dropping tables')
         for n, table in enumerate(reversed(table_objs)):
-            table.drop(checkfirst=True)
+            table.drop(bind=engine, checkfirst=True)
 
             # Drop columns' types if appropriate; needed for enums in
             # postgresql
             for column in table.c:
@@ -169,7 +170,7 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
                 except AttributeError:
                     pass
                 else:
-                    drop(bind=bind, checkfirst=True)
+                    drop(bind=engine, checkfirst=True)
             print_status('%s/%s' % (n, len(table_objs)))
         print_done()
 
@@ -179,7 +180,6 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
         table.create()
         print_status('%s/%s' % (n, len(table_objs)))
     print_done()
-    connection = session.connection()
 
     # Okay, run through the tables and actually load the data now
     for table_obj in table_objs:
@@ -205,35 +205,34 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
         reader = csv.reader(csvfile, lineterminator='\n')
         column_names = [unicode(column) for column in reader.next()]
 
-        if not safe and session.connection().dialect.name == 'postgresql':
-            """
-            Postgres' CSV dialect works with our data, if we mark the not-null
-            columns with FORCE NOT NULL.
-            COPY is only allowed for DB superusers. If you're not one, use safe
-            loading (pokedex load -S).
-            """
-            session.commit()
+        if not safe and engine.dialect.name == 'postgresql':
+            # Postgres' CSV dialect works with our data, if we mark the not-null
+            # columns with FORCE NOT NULL.
             not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
             if not_null_cols:
                 force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
             else:
                 force_not_null = ''
-            command = "COPY %(table_name)s (%(columns)s) FROM '%(csvpath)s' CSV HEADER %(force_not_null)s"
-            session.connection().execute(
+
+            # Grab the underlying psycopg2 cursor so we can use COPY FROM STDIN
+            raw_conn = engine.raw_connection()
+            command = "COPY %(table_name)s (%(columns)s) FROM STDIN CSV HEADER %(force_not_null)s"
+            csvfile.seek(0)
+            raw_conn.cursor().copy_expert(
                 command % dict(
                     table_name=table_name,
-                    csvpath=csvpath,
                     columns=','.join('"%s"' % c for c in column_names),
                     force_not_null=force_not_null,
-                )
+                ),
+                csvfile,
             )
-            session.commit()
+            raw_conn.commit()
            print_done()
             continue
 
         # Self-referential tables may contain rows with foreign keys of other
-        # rows in the same table that do not yet exist.  Pull these out and add
-        # them to the session last
+        # rows in the same table that do not yet exist.  Pull these out and
+        # insert them last
         # ASSUMPTION: Self-referential tables have a single PK called "id"
         deferred_rows = []  # ( row referring to id, [foreign ids we need] )
         seen_ids = set()    # primary keys we've seen
@@ -248,7 +247,7 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
         def insert_and_commit():
             if not new_rows:
                 return
-            session.connection().execute(insert_stmt, new_rows)
+            session.execute(insert_stmt, new_rows)
             session.commit()
             new_rows[:] = []
 
@@ -316,12 +315,12 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
                     raise ValueError("Too many levels of self-reference!  "
                                      "Row was: " + str(row))
 
-            session.connection().execute(
+            session.execute(
                 insert_stmt.values(**row_data)
            )
             seen_ids.add(row_data['id'])
-            session.commit()
 
+        session.commit()
         print_done()
 
@@ -333,18 +332,17 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
         table_obj = translation_class.__table__
         if table_obj in table_objs:
             insert_stmt = table_obj.insert()
-            session.connection().execute(insert_stmt, rows)
+            session.execute(insert_stmt, rows)
             session.commit()
             # We don't have a total, but at least show some increasing
             # number
             new_row_count += len(rows)
             print_status(str(new_row_count))
 
-    print_done()
-
     # SQLite check
-    if session.connection().dialect.name == 'sqlite':
-        session.connection().execute("PRAGMA integrity_check")
+    if engine.dialect.name == 'sqlite':
+        session.execute("PRAGMA integrity_check")
+    print_done()
 
 
 def dump(session, tables=[], directory=None, verbose=False, langs=None):
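
A note for reviewers on the COPY change: `COPY ... FROM '/path'` reads a file on the database server and is restricted to superusers, which is why the old code needed server-side CSVs (and offered `pokedex load -S` as the workaround). `COPY ... FROM STDIN` instead streams the data over the client connection via psycopg2's `copy_expert`, which SQLAlchemy does not wrap, hence `engine.raw_connection()`. A minimal standalone sketch of the pattern, with a hypothetical database URL, table, and CSV file:

```python
import sqlalchemy

# Hypothetical URL, table, and file name; only the pattern matters.
engine = sqlalchemy.create_engine('postgresql:///pokedex')

with open('pokemon.csv', 'rb') as csvfile:
    # raw_connection() returns the underlying DBAPI (psycopg2) connection;
    # COPY is not part of the DBAPI, so SQLAlchemy cannot run it for us.
    raw_conn = engine.raw_connection()
    try:
        cursor = raw_conn.cursor()
        # copy_expert feeds the file object to the COPY command as STDIN,
        # so neither superuser rights nor server-side files are needed.
        cursor.copy_expert("COPY pokemon FROM STDIN CSV HEADER", csvfile)
        raw_conn.commit()
    finally:
        raw_conn.close()
```

This also explains the `csvfile.seek(0)` in the diff: the loader has already consumed the header row through `csv.reader`, so it rewinds before handing the file to `copy_expert`, and `CSV HEADER` tells the server to skip that first line again.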
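
The self-referential deferral that the reworded comment describes predates this diff: rows whose foreign keys point at not-yet-inserted rows of the same table are held back and inserted in a second pass. A toy sketch of that bookkeeping, with a hypothetical `evolves_from_id` column and made-up data:

```python
# Toy illustration of the deferral strategy; column name is hypothetical.
rows = [
    {'id': 1, 'evolves_from_id': None},
    {'id': 3, 'evolves_from_id': 2},   # target row not inserted yet
    {'id': 2, 'evolves_from_id': 1},
]

seen_ids = set()
deferred_rows = []

for row in rows:
    target = row['evolves_from_id']
    if target is not None and target not in seen_ids:
        deferred_rows.append(row)  # hold back until the target exists
        continue
    seen_ids.add(row['id'])        # stand-in for the actual INSERT

# A single extra pass suffices for one level of self-reference
for row in deferred_rows:
    if row['evolves_from_id'] not in seen_ids:
        raise ValueError("Too many levels of self-reference!")
    seen_ids.add(row['id'])
```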