diff --git a/pokedex/tests/test_movesets.py b/pokedex/tests/test_movesets.py index 31b6139..af71cde 100644 --- a/pokedex/tests/test_movesets.py +++ b/pokedex/tests/test_movesets.py @@ -1,9 +1,11 @@ +from pokedex.db import connect from pokedex.util.movesets import main result_map = {'OK': True, 'NO': False} def test_cases(): + session = connect() for argstring in u""" NO muk NO beedrill rage pursuit agility endeavor toxic @@ -43,7 +45,7 @@ def test_cases(): """.strip().splitlines(): def run_test(argstring): args = argstring.split() - assert main(args[1:]) == result_map[args[0]] + assert main(args[1:], session=session) == result_map[args[0]] run_test.description = 'Moveset checker test: ' + argstring.strip() yield run_test, argstring.strip() @@ -55,6 +57,6 @@ if __name__ == '__main__': def header(str): print print str - cProfile.runctx("[(header(arg), f(arg)) for f, arg in test_cases()]", + cProfile.runctx("[(header(argv), f(argv)) for f, argv in test_cases()]", globals(), locals(), filename=filename) print 'Profile stats saved to', filename diff --git a/pokedex/util/movesets.py b/pokedex/util/movesets.py index b6ee2ce..d2a191f 100755 --- a/pokedex/util/movesets.py +++ b/pokedex/util/movesets.py @@ -1289,7 +1289,7 @@ def print_result(result, moves=()): ''.join(m.name[0].lower() for m in node.moves if m not in moves), ) -def main(argv): +def main(argv, session=None): parser = argparse.ArgumentParser(description= 'Find out if the specified moveset is valid, and provide a suggestion ' 'on how to obtain it.') @@ -1328,7 +1328,8 @@ def main(argv): if args.debug: print 'Connecting' - session = connect(engine_args={'echo': args.debug > 2}) + if session is None: + session = connect(engine_args={'echo': args.debug > 2}) if args.debug: print 'Parsing arguments'