General code tidying for this Oracle fix.

This commit is contained in:
Lynn "Zhorken" Vaughan 2014-02-21 16:45:47 -05:00
parent 5b759feaa2
commit f54a4caaca
3 changed files with 35 additions and 52 deletions

View file

@ -25,7 +25,7 @@ def connect(uri=None, session_args={}, engine_args={}, engine_prefix=''):
if uri is None: if uri is None:
uri = get_default_db_uri() uri = get_default_db_uri()
### Do some fixery for MySQL ### Do some fixery for specific RDBMSes
if uri.startswith('mysql:'): if uri.startswith('mysql:'):
# MySQL uses latin1 for connections by default even if the server is # MySQL uses latin1 for connections by default even if the server is
# otherwise oozing with utf8; charset fixes this # otherwise oozing with utf8; charset fixes this
@ -37,9 +37,7 @@ def connect(uri=None, session_args={}, engine_args={}, engine_prefix=''):
for table in metadata.tables.values(): for table in metadata.tables.values():
table.kwargs['mysql_engine'] = 'InnoDB' table.kwargs['mysql_engine'] = 'InnoDB'
table.kwargs['mysql_charset'] = 'utf8' table.kwargs['mysql_charset'] = 'utf8'
elif uri.startswith(('oracle:', 'oracle+cx_oracle:')):
### Do some fixery for Oracle
if uri.startswith('oracle:') or uri.startswith('oracle+cx_oracle:'):
# Oracle requires auto_setinputsizes=False (or at least a special # Oracle requires auto_setinputsizes=False (or at least a special
# set of exclusions from it, which I don't know) # set of exclusions from it, which I don't know)
if 'auto_setinputsizes' not in uri: if 'auto_setinputsizes' not in uri:

View file

@ -139,19 +139,16 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
table_names = _get_table_names(metadata, tables) table_names = _get_table_names(metadata, tables)
table_objs = [metadata.tables[name] for name in table_names] table_objs = [metadata.tables[name] for name in table_names]
# Oracle fixery, load needs short names
# flag for oracle stuff
oranames = (session.connection().dialect.name == 'oracle')
if oranames:
# Shorten table names, Oracle limits table and column names to 30 chars
# Make a dictionary to match old<->new names
oradict = rewrite_long_table_names()
if recursive: if recursive:
table_objs.extend(find_dependent_tables(table_objs)) table_objs.extend(find_dependent_tables(table_objs))
table_objs = sqlalchemy.sql.util.sort_tables(table_objs) table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
# Limit table names to 30 characters for Oracle
oracle = (session.connection().dialect.name == 'oracle')
if oracle:
rewrite_long_table_names()
# SQLite speed tweaks # SQLite speed tweaks
if not safe and session.connection().dialect.name == 'sqlite': if not safe and session.connection().dialect.name == 'sqlite':
session.connection().execute("PRAGMA synchronous=OFF") session.connection().execute("PRAGMA synchronous=OFF")
@ -186,15 +183,17 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
# Okay, run through the tables and actually load the data now # Okay, run through the tables and actually load the data now
for table_obj in table_objs: for table_obj in table_objs:
table_name = table_obj.name if oracle:
table_name = table_obj._original_name
else:
table_name = table_obj.name
insert_stmt = table_obj.insert() insert_stmt = table_obj.insert()
print_start(table_name) print_start(table_name)
try: try:
csvpath = "%s/%s.csv" % (directory, table_name) csvpath = "%s/%s.csv" % (directory, table_name)
# In oracle mode, use the original names instead of current
if oranames:
csvpath = "%s/%s.csv" % (directory, oradict[table_name])
csvfile = open(csvpath, 'rb') csvfile = open(csvpath, 'rb')
except IOError: except IOError:
# File doesn't exist; don't load anything! # File doesn't exist; don't load anything!
@ -380,22 +379,23 @@ def dump(session, tables=[], directory=None, verbose=False, langs=['en']):
table_names = _get_table_names(metadata, tables) table_names = _get_table_names(metadata, tables)
table_names.sort() table_names.sort()
# Oracle fixery : read from short table names, dump long names # Oracle needs to dump from tables with shortened names to csvs with the
oranames = (session.connection().dialect.name == 'oracle') # usual names
if oranames: oracle = (session.connection().dialect.name == 'oracle')
# Make a dictionary to match old<->new names if oracle:
oradict = rewrite_long_table_names() rewrite_long_table_names()
for table_name in table_names: for table_name in table_names:
print_start(table_name) print_start(table_name)
table = metadata.tables[table_name] table = metadata.tables[table_name]
writer = csv.writer(open("%s/%s.csv" % (directory, table_name), 'wb'), if oracle:
lineterminator='\n') filename = '%s/%s.csv' % (directory, table._original_name)
# In oracle mode, use the original names instead of current else:
if oranames: filename = '%s/%s.csv' % (directory, table_name)
writer = csv.writer(open("%s/%s.csv" % (directory, oradict[table_name]), 'wb'),
lineterminator='\n') writer = csv.writer(open(filename, 'wb'), lineterminator='\n')
columns = [col.name for col in table.columns] columns = [col.name for col in table.columns]
# For name tables, dump rows for official languages, as well as # For name tables, dump rows for official languages, as well as

View file

@ -8,34 +8,19 @@ def rewrite_long_table_names():
Returns a dictionary matching short names -> long names. Returns a dictionary matching short names -> long names.
""" """
# Load table names from metadata # Load tables from metadata
t_names = metadata.tables.keys() table_objs = metadata.tables.values()
table_names = list(t_names)
table_objs = [metadata.tables[name] for name in table_names]
# Prepare a dictionary to match old<->new names
dictionary = {}
# Shorten table names, Oracle limits table and column names to 30 chars # Shorten table names, Oracle limits table and column names to 30 chars
for table in table_objs: for table in table_objs:
table._orginal_name = table.name[:] table._original_name = table.name
dictionary[table.name]=table._orginal_name
if len(table._orginal_name) > 30:
for letter in ['a', 'e', 'i', 'o', 'u', 'y']:
table.name=table.name.replace(letter,'')
dictionary[table.name]=table._orginal_name
return dictionary
if len(table.name) > 30:
for letter in 'aeiouy':
table.name = table.name.replace(letter, '')
def restore_long_table_names(metadata,dictionary): def restore_long_table_names():
"""Modifies the table names to restore the long-naming. """Modifies the table names to restore the long-naming."""
`metadata`
The metadata to restore.
`dictionary`
The dictionary matching short name -> long name.
"""
for table in metadata.tables.values(): for table in metadata.tables.values():
table.name = dictionary[table.name] table.name = table._original_name
del table._original_name