From 0ff24b4dc8b80cbe67f7484b32f5ceb318f30099 Mon Sep 17 00:00:00 2001 From: "Eevee (Alex Munroe)" Date: Mon, 5 Oct 2015 16:29:21 -0700 Subject: [PATCH] Fix the CLI in py3 --- pokedex/db/load.py | 22 +++++++++++++++------- pokedex/db/translations.py | 11 ++++++----- pokedex/main.py | 2 +- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/pokedex/db/load.py b/pokedex/db/load.py index da3130f..c87304a 100644 --- a/pokedex/db/load.py +++ b/pokedex/db/load.py @@ -3,9 +3,11 @@ from __future__ import print_function import csv import fnmatch +import io import os.path import sys +import six import sqlalchemy.sql.util import sqlalchemy.types @@ -196,16 +198,20 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s try: csvpath = "%s/%s.csv" % (directory, table_name) - csvfile = open(csvpath, 'rb') + csvfile = open(csvpath, 'r') except IOError: # File doesn't exist; don't load anything! print_done('missing?') continue - csvsize = os.stat(csvpath).st_size + # XXX This is wrong for files with multi-line fields, but Python 3 + # doesn't allow .tell() on a file that's currently being iterated + # (because the result is completely bogus). Oh well. + csvsize = sum(1 for line in csvfile) + csvfile.seek(0) reader = csv.reader(csvfile, lineterminator='\n') - column_names = [unicode(column) for column in reader.next()] + column_names = [six.text_type(column) for column in next(reader)] if not safe and engine.dialect.name == 'postgresql': # Postgres' CSV dialect works with our data, if we mark the not-null @@ -253,10 +259,12 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s session.commit() new_rows[:] = [] - progress = "%d%%" % (100 * csvfile.tell() // csvsize) + progress = "%d%%" % (100 * csvpos // csvsize) print_status(progress) + csvpos = 0 for csvs in reader: + csvpos += 1 row_data = {} for column_name, value in zip(column_names, csvs): @@ -271,7 +279,7 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s value = False else: value = True - else: + elif isinstance(value, bytes): # Otherwise, unflatten from bytes value = value.decode('utf-8') @@ -394,7 +402,7 @@ def dump(session, tables=[], directory=None, verbose=False, langs=None): else: filename = '%s/%s.csv' % (directory, table_name) - writer = csv.writer(open(filename, 'wb'), lineterminator='\n') + writer = csv.writer(io.open(filename, 'w', newline=''), lineterminator='\n') columns = [col.name for col in table.columns] @@ -434,7 +442,7 @@ def dump(session, tables=[], directory=None, verbose=False, langs=None): elif val == False: val = '0' else: - val = unicode(val).encode('utf-8') + val = six.text_type(val).encode('utf-8') csvs.append(val) diff --git a/pokedex/db/translations.py b/pokedex/db/translations.py index 39dc07c..05a04fa 100755 --- a/pokedex/db/translations.py +++ b/pokedex/db/translations.py @@ -25,6 +25,7 @@ from __future__ import print_function import binascii import csv +import io import os import re from collections import defaultdict @@ -257,11 +258,11 @@ class Translations(object): def reader_for_class(self, cls, reader_class=csv.reader): tablename = cls.__table__.name csvpath = os.path.join(self.csv_directory, tablename + '.csv') - return reader_class(open(csvpath, 'rb'), lineterminator='\n') + return reader_class(open(csvpath, 'r'), lineterminator='\n') def writer_for_lang(self, lang): csvpath = os.path.join(self.translation_directory, '%s.csv' % lang) - return csv.writer(open(csvpath, 'wb'), lineterminator='\n') + return csv.writer(io.open(csvpath, 'w', newline=''), lineterminator='\n') def yield_source_messages(self, language_id=None): """Yield all messages from source CSV files @@ -302,7 +303,7 @@ class Translations(object): """ path = os.path.join(self.csv_directory, 'translations', '%s.csv' % lang) try: - file = open(path, 'rb') + file = open(path, 'r') except IOError: return () return yield_translation_csv_messages(file) @@ -353,11 +354,11 @@ class Translations(object): count += 1 if count > 1000: for translation_class, key_data in everything.items(): - yield translation_class, key_data.values() + yield translation_class, list(key_data.values()) count = 0 everything.clear() for translation_class, data_dict in everything.items(): - yield translation_class, data_dict.values() + yield translation_class, list(data_dict.values()) def group_by_object(stream): """Group stream by object diff --git a/pokedex/main.py b/pokedex/main.py index 180248a..a8171a0 100644 --- a/pokedex/main.py +++ b/pokedex/main.py @@ -21,7 +21,7 @@ def main(*argv): # XXX there must be a better way to get Unicode argv # XXX this doesn't work on Windows durp enc = sys.stdin.encoding or 'utf8' - args = [_.decode(enc) for _ in args] + args = [_.decode(enc) if isinstance(_, bytes) else _ for _ in args] # Find the command as a function in this file func = globals().get("command_%s" % command, None)