Share session across tests

This commit is contained in:
Petr Viktorin 2011-04-27 14:30:11 +03:00
parent d746716575
commit e4c0d0b16b
2 changed files with 7 additions and 4 deletions

View file

@ -1,9 +1,11 @@
from pokedex.db import connect
from pokedex.util.movesets import main from pokedex.util.movesets import main
result_map = {'OK': True, 'NO': False} result_map = {'OK': True, 'NO': False}
def test_cases(): def test_cases():
session = connect()
for argstring in u""" for argstring in u"""
NO muk NO muk
NO beedrill rage pursuit agility endeavor toxic NO beedrill rage pursuit agility endeavor toxic
@ -43,7 +45,7 @@ def test_cases():
""".strip().splitlines(): """.strip().splitlines():
def run_test(argstring): def run_test(argstring):
args = argstring.split() args = argstring.split()
assert main(args[1:]) == result_map[args[0]] assert main(args[1:], session=session) == result_map[args[0]]
run_test.description = 'Moveset checker test: ' + argstring.strip() run_test.description = 'Moveset checker test: ' + argstring.strip()
yield run_test, argstring.strip() yield run_test, argstring.strip()
@ -55,6 +57,6 @@ if __name__ == '__main__':
def header(str): def header(str):
print print
print str print str
cProfile.runctx("[(header(arg), f(arg)) for f, arg in test_cases()]", cProfile.runctx("[(header(argv), f(argv)) for f, argv in test_cases()]",
globals(), locals(), filename=filename) globals(), locals(), filename=filename)
print 'Profile stats saved to', filename print 'Profile stats saved to', filename

View file

@ -1289,7 +1289,7 @@ def print_result(result, moves=()):
''.join(m.name[0].lower() for m in node.moves if m not in moves), ''.join(m.name[0].lower() for m in node.moves if m not in moves),
) )
def main(argv): def main(argv, session=None):
parser = argparse.ArgumentParser(description= parser = argparse.ArgumentParser(description=
'Find out if the specified moveset is valid, and provide a suggestion ' 'Find out if the specified moveset is valid, and provide a suggestion '
'on how to obtain it.') 'on how to obtain it.')
@ -1328,7 +1328,8 @@ def main(argv):
if args.debug: if args.debug:
print 'Connecting' print 'Connecting'
session = connect(engine_args={'echo': args.debug > 2}) if session is None:
session = connect(engine_args={'echo': args.debug > 2})
if args.debug: if args.debug:
print 'Parsing arguments' print 'Parsing arguments'