mirror of
https://github.com/veekun/pokedex.git
synced 2024-08-20 18:16:34 +00:00
Fixed csvimport to load in table dependency order.
This commit is contained in:
parent
0af6b1c8ab
commit
15ee3fcccf
1 changed files with 8 additions and 16 deletions
|
@ -29,17 +29,11 @@ def csvimport(engine_uri, directory='.'):
|
||||||
|
|
||||||
metadata.create_all()
|
metadata.create_all()
|
||||||
|
|
||||||
# Oh, mysql-chan.
|
|
||||||
# TODO try to insert data in preorder so we don't need this hack and won't
|
|
||||||
# break similarly on other engines
|
|
||||||
if 'mysql' in engine_uri:
|
|
||||||
session.execute('SET FOREIGN_KEY_CHECKS = 0')
|
|
||||||
|
|
||||||
# SQLAlchemy is retarded and there is no way for me to get a list of ORM
|
# SQLAlchemy is retarded and there is no way for me to get a list of ORM
|
||||||
# classes besides to inspect the module they all happen to live in for
|
# classes besides to inspect the module they all happen to live in for
|
||||||
# things that look right.
|
# things that look right.
|
||||||
table_base = tables_module.TableBase
|
table_base = tables_module.TableBase
|
||||||
orm_classes = {}
|
orm_classes = {} # table object => table class
|
||||||
|
|
||||||
for name in dir(tables_module):
|
for name in dir(tables_module):
|
||||||
# dir() returns strings! How /convenient/.
|
# dir() returns strings! How /convenient/.
|
||||||
|
@ -56,10 +50,13 @@ def csvimport(engine_uri, directory='.'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# thingy is definitely a table class! Hallelujah.
|
# thingy is definitely a table class! Hallelujah.
|
||||||
orm_classes[thingy.__table__.name] = thingy
|
orm_classes[thingy.__table__] = thingy
|
||||||
|
|
||||||
# Okay, run through the tables and actually load the data now
|
# Okay, run through the tables and actually load the data now
|
||||||
for table_name, table in sorted(orm_classes.items()):
|
for table_obj in metadata.sorted_tables:
|
||||||
|
table_class = orm_classes[table_obj]
|
||||||
|
table_name = table_obj.name
|
||||||
|
|
||||||
# Print the table name but leave the cursor in a fixed column
|
# Print the table name but leave the cursor in a fixed column
|
||||||
print table_name + '...', ' ' * (40 - len(table_name)),
|
print table_name + '...', ' ' * (40 - len(table_name)),
|
||||||
|
|
||||||
|
@ -74,10 +71,10 @@ def csvimport(engine_uri, directory='.'):
|
||||||
column_names = [unicode(column) for column in reader.next()]
|
column_names = [unicode(column) for column in reader.next()]
|
||||||
|
|
||||||
for csvs in reader:
|
for csvs in reader:
|
||||||
row = table()
|
row = table_class()
|
||||||
|
|
||||||
for column_name, value in zip(column_names, csvs):
|
for column_name, value in zip(column_names, csvs):
|
||||||
column = table.__table__.c[column_name]
|
column = table_obj.c[column_name]
|
||||||
if column.nullable and value == '':
|
if column.nullable and value == '':
|
||||||
# Empty string in a nullable column really means NULL
|
# Empty string in a nullable column really means NULL
|
||||||
value = None
|
value = None
|
||||||
|
@ -99,11 +96,6 @@ def csvimport(engine_uri, directory='.'):
|
||||||
session.commit()
|
session.commit()
|
||||||
print 'loaded'
|
print 'loaded'
|
||||||
|
|
||||||
# Shouldn't matter since this is usually the end of the program and thus
|
|
||||||
# the connection too, but let's change this back just in case
|
|
||||||
if 'mysql' in engine_uri:
|
|
||||||
session.execute('SET FOREIGN_KEY_CHECKS = 1')
|
|
||||||
|
|
||||||
|
|
||||||
def csvexport(engine_uri, directory='.'):
|
def csvexport(engine_uri, directory='.'):
|
||||||
import csv
|
import csv
|
||||||
|
|
Loading…
Reference in a new issue