Fixed csvimport to load in table dependency order.

This commit is contained in:
Eevee 2009-05-28 21:16:18 -07:00
parent 0af6b1c8ab
commit 15ee3fcccf
1 changed files with 8 additions and 16 deletions

View File

@ -29,17 +29,11 @@ def csvimport(engine_uri, directory='.'):
metadata.create_all() metadata.create_all()
# Oh, mysql-chan.
# TODO try to insert data in preorder so we don't need this hack and won't
# break similarly on other engines
if 'mysql' in engine_uri:
session.execute('SET FOREIGN_KEY_CHECKS = 0')
# SQLAlchemy is retarded and there is no way for me to get a list of ORM # SQLAlchemy is retarded and there is no way for me to get a list of ORM
# classes besides to inspect the module they all happen to live in for # classes besides to inspect the module they all happen to live in for
# things that look right. # things that look right.
table_base = tables_module.TableBase table_base = tables_module.TableBase
orm_classes = {} orm_classes = {} # table object => table class
for name in dir(tables_module): for name in dir(tables_module):
# dir() returns strings! How /convenient/. # dir() returns strings! How /convenient/.
@ -56,10 +50,13 @@ def csvimport(engine_uri, directory='.'):
continue continue
# thingy is definitely a table class! Hallelujah. # thingy is definitely a table class! Hallelujah.
orm_classes[thingy.__table__.name] = thingy orm_classes[thingy.__table__] = thingy
# Okay, run through the tables and actually load the data now # Okay, run through the tables and actually load the data now
for table_name, table in sorted(orm_classes.items()): for table_obj in metadata.sorted_tables:
table_class = orm_classes[table_obj]
table_name = table_obj.name
# Print the table name but leave the cursor in a fixed column # Print the table name but leave the cursor in a fixed column
print table_name + '...', ' ' * (40 - len(table_name)), print table_name + '...', ' ' * (40 - len(table_name)),
@ -74,10 +71,10 @@ def csvimport(engine_uri, directory='.'):
column_names = [unicode(column) for column in reader.next()] column_names = [unicode(column) for column in reader.next()]
for csvs in reader: for csvs in reader:
row = table() row = table_class()
for column_name, value in zip(column_names, csvs): for column_name, value in zip(column_names, csvs):
column = table.__table__.c[column_name] column = table_obj.c[column_name]
if column.nullable and value == '': if column.nullable and value == '':
# Empty string in a nullable column really means NULL # Empty string in a nullable column really means NULL
value = None value = None
@ -99,11 +96,6 @@ def csvimport(engine_uri, directory='.'):
session.commit() session.commit()
print 'loaded' print 'loaded'
# Shouldn't matter since this is usually the end of the program and thus
# the connection too, but let's change this back just in case
if 'mysql' in engine_uri:
session.execute('SET FOREIGN_KEY_CHECKS = 1')
def csvexport(engine_uri, directory='.'): def csvexport(engine_uri, directory='.'):
import csv import csv