diff --git a/pokedex/util/movesets.py b/pokedex/util/movesets.py index 1f7252f..b6ee2ce 100755 --- a/pokedex/util/movesets.py +++ b/pokedex/util/movesets.py @@ -6,6 +6,7 @@ import argparse import itertools import heapq from collections import defaultdict, namedtuple +from weakref import WeakKeyDictionary from sqlalchemy.orm import aliased from sqlalchemy.orm.exc import NoResultFound @@ -41,38 +42,50 @@ def powerset(iterable): ### class MovesetSearch(object): + _cache = WeakKeyDictionary() + def __init__(self, session, pokemon, version, moves, level=100, costs=None, exclude_versions=(), exclude_pokemon=(), debug_level=False): - self.generator = None + self.session = session + + self.debug_level = debug_level + + try: + # Cache all the common stuff. + # Not that it takes a lot of time to load, but it can add up + # if repeated. + self.__dict__ = MovesetSearch._cache[session] + except KeyError: + self.generator = None + + self.sketch = util.get(session, tables.Move, identifier=u'sketch').id + self.unsketchable = set([ + util.get(session, tables.Move, identifier=u'struggle').id, + util.get(session, tables.Move, identifier=u'chatter').id, + ]) + self.no_eggs_group = util.get(session, tables.EggGroup, + identifier=u'no-eggs').id + self.ditto_group = util.get(session, tables.EggGroup, + identifier=u'ditto').id + + self.load_pokemon() + self.load_moves() + + MovesetSearch._cache[session] = self.__dict__ + + self.debug_level = debug_level if not moves: raise NoMoves('No moves specified.') elif len(moves) > 4: raise NoMoves('Too many moves specified.') - self.debug_level = debug_level - - self.session = session - - self.sketch = util.get(session, tables.Move, identifier=u'sketch').id - self.unsketchable = set([ - util.get(session, tables.Move, identifier=u'struggle').id, - util.get(session, tables.Move, identifier=u'chatter').id, - ]) - self.no_eggs_group = util.get(session, tables.EggGroup, - identifier=u'no-eggs').id - self.ditto_group = util.get(session, tables.EggGroup, - identifier=u'ditto').id - if costs is None: self.costs = default_costs else: self.costs = costs - self.load_pokemon() - self.load_moves() - self.excluded_families = frozenset(p.evolution_chain_id for p in exclude_pokemon) @@ -638,7 +651,7 @@ default_costs = { class Node(object): """Node for the A* search algorithm. - To get started, implement the `expand` method and call `search`. + To get started, implement `expand` & `is_goal` and call `find_path`. N.B. Node objects must be hashable. """ @@ -652,7 +665,12 @@ class Node(object): """ raise NotImplementedError - def estimate(self, goal): + def is_goal(self): + """Return true iff this is a goal node. + """ + raise NotImplementedError + + def estimate(self): """Return an *optimistic* estimate of the cost to the given goal node. If there are multiple goal states, return the lowest estimate among all @@ -660,11 +678,6 @@ class Node(object): """ return 0 - def is_goal(self, goal): - """Return true iff this is a goal node. - """ - return self == goal - def find_path(self, goal=None, **kwargs): """Return the best path to the goal @@ -676,13 +689,13 @@ class Node(object): See a_star for the advanced keyword arguments, `notify` and `estimate_error_callback`. """ - paths = self.find_all_paths(goal=goal, **kwargs) + paths = self.find_all_paths(**kwargs) try: return paths.next() except StopIteration: return None - def find_all_paths(self, goal=None, **kwargs): + def find_all_paths(self, **kwargs): """Yield the best path to each goal Returns an iterator of paths. See the `search` method for how paths @@ -699,8 +712,8 @@ class Node(object): return a_star( initial=self, expand=lambda s: s.expand(), - estimate=lambda s: s.estimate(goal), - is_goal=lambda s: s.is_goal(goal), + estimate=lambda s: s.estimate(), + is_goal=lambda s: s.is_goal(), **kwargs) def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None, @@ -748,11 +761,10 @@ def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None, stopping criteria; the other values may help in tuning estimators. `estimate_error_callback`: function handling cases where an estimate was - detected not to be optimistic (as A* requires). The function is given a - path (as would be returned by a_star, except it does not lead to a goal - node). By default, nothing is done (indeed, an estimate that's not - strictly optimistic can be useful, esp. if the optimal path is not - required) + detected not to be consistent. The function is given a path (as would + be returned by a_star, except it does not lead to a goal node). By + default, nothing is done (as an estimate that's not even optimistic can + still be useful). """ # g: best cummulative cost (from initial node) found so far # h: optimistic estimate of cost to goal @@ -971,6 +983,9 @@ class InitialNode(Node, namedtuple('InitialNode', 'search')): ) yield 0, action, node + def is_goal(self): + return False + class PokemonNode(Node, Facade, namedtuple('PokemonNode', 'search pokemon_ level version_group_ new_level moves_')): @@ -1176,7 +1191,7 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode', new_level=False, moves_=frozenset(moves)) return - def estimate(self, g): + def estimate(self): # Given good estimates, A* finds solutions much faster. # However, here it seems we either have easy movesets, which # get found pretty easily by themselves, or hard ones, where @@ -1191,6 +1206,9 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode', search.goal_version_group, search.costs['trade'] * 2) return trade_cost + def is_goal(self): + return False + class BaseBreedNode(Node): """Breed node This serves to prevent duplicate breeds, by storing only the needed info @@ -1217,12 +1235,8 @@ class BaseBreedNode(Node): search=self.search, pokemon_=baby, level=hatch_level, version_group_=vg, moves_=bred_moves, new_level=True) - @property - def pokemon(self): - return None - - def estimate(self, g): - return 0 + def is_goal(self): + return False class BreedNode(BaseBreedNode, namedtuple('BreedNode', 'search dummy group_ version_group_ moves_')): @@ -1248,8 +1262,9 @@ class GoalNode(PokemonNode): def expand(self): return () - def is_goal(self, g): + def is_goal(self): return True + ### ### CLI interface ### @@ -1265,7 +1280,7 @@ def print_result(result, moves=()): cost=cost, action=action, long='>' if (len(unicode(action)) > 45) else '', - est=node.estimate(None), + est=node.estimate(), pokemon=node.pokemon.name, nl='.' if node.new_level else ' ', level=node.level,