Cache common stuff. Remove goal from a_star.

This commit is contained in:
Petr Viktorin 2011-04-27 14:21:24 +03:00
parent cb78144e8d
commit d746716575

View file

@ -6,6 +6,7 @@ import argparse
import itertools import itertools
import heapq import heapq
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
from weakref import WeakKeyDictionary
from sqlalchemy.orm import aliased from sqlalchemy.orm import aliased
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
@ -41,19 +42,22 @@ def powerset(iterable):
### ###
class MovesetSearch(object): class MovesetSearch(object):
_cache = WeakKeyDictionary()
def __init__(self, session, pokemon, version, moves, level=100, costs=None, def __init__(self, session, pokemon, version, moves, level=100, costs=None,
exclude_versions=(), exclude_pokemon=(), debug_level=False): exclude_versions=(), exclude_pokemon=(), debug_level=False):
self.generator = None self.session = session
if not moves:
raise NoMoves('No moves specified.')
elif len(moves) > 4:
raise NoMoves('Too many moves specified.')
self.debug_level = debug_level self.debug_level = debug_level
self.session = session try:
# Cache all the common stuff.
# Not that it takes a lot of time to load, but it can add up
# if repeated.
self.__dict__ = MovesetSearch._cache[session]
except KeyError:
self.generator = None
self.sketch = util.get(session, tables.Move, identifier=u'sketch').id self.sketch = util.get(session, tables.Move, identifier=u'sketch').id
self.unsketchable = set([ self.unsketchable = set([
@ -65,14 +69,23 @@ class MovesetSearch(object):
self.ditto_group = util.get(session, tables.EggGroup, self.ditto_group = util.get(session, tables.EggGroup,
identifier=u'ditto').id identifier=u'ditto').id
self.load_pokemon()
self.load_moves()
MovesetSearch._cache[session] = self.__dict__
self.debug_level = debug_level
if not moves:
raise NoMoves('No moves specified.')
elif len(moves) > 4:
raise NoMoves('Too many moves specified.')
if costs is None: if costs is None:
self.costs = default_costs self.costs = default_costs
else: else:
self.costs = costs self.costs = costs
self.load_pokemon()
self.load_moves()
self.excluded_families = frozenset(p.evolution_chain_id self.excluded_families = frozenset(p.evolution_chain_id
for p in exclude_pokemon) for p in exclude_pokemon)
@ -638,7 +651,7 @@ default_costs = {
class Node(object): class Node(object):
"""Node for the A* search algorithm. """Node for the A* search algorithm.
To get started, implement the `expand` method and call `search`. To get started, implement `expand` & `is_goal` and call `find_path`.
N.B. Node objects must be hashable. N.B. Node objects must be hashable.
""" """
@ -652,7 +665,12 @@ class Node(object):
""" """
raise NotImplementedError raise NotImplementedError
def estimate(self, goal): def is_goal(self):
"""Return true iff this is a goal node.
"""
raise NotImplementedError
def estimate(self):
"""Return an *optimistic* estimate of the cost to the given goal node. """Return an *optimistic* estimate of the cost to the given goal node.
If there are multiple goal states, return the lowest estimate among all If there are multiple goal states, return the lowest estimate among all
@ -660,11 +678,6 @@ class Node(object):
""" """
return 0 return 0
def is_goal(self, goal):
"""Return true iff this is a goal node.
"""
return self == goal
def find_path(self, goal=None, **kwargs): def find_path(self, goal=None, **kwargs):
"""Return the best path to the goal """Return the best path to the goal
@ -676,13 +689,13 @@ class Node(object):
See a_star for the advanced keyword arguments, `notify` and See a_star for the advanced keyword arguments, `notify` and
`estimate_error_callback`. `estimate_error_callback`.
""" """
paths = self.find_all_paths(goal=goal, **kwargs) paths = self.find_all_paths(**kwargs)
try: try:
return paths.next() return paths.next()
except StopIteration: except StopIteration:
return None return None
def find_all_paths(self, goal=None, **kwargs): def find_all_paths(self, **kwargs):
"""Yield the best path to each goal """Yield the best path to each goal
Returns an iterator of paths. See the `search` method for how paths Returns an iterator of paths. See the `search` method for how paths
@ -699,8 +712,8 @@ class Node(object):
return a_star( return a_star(
initial=self, initial=self,
expand=lambda s: s.expand(), expand=lambda s: s.expand(),
estimate=lambda s: s.estimate(goal), estimate=lambda s: s.estimate(),
is_goal=lambda s: s.is_goal(goal), is_goal=lambda s: s.is_goal(),
**kwargs) **kwargs)
def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None, def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
@ -748,11 +761,10 @@ def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
stopping criteria; the other values may help in tuning estimators. stopping criteria; the other values may help in tuning estimators.
`estimate_error_callback`: function handling cases where an estimate was `estimate_error_callback`: function handling cases where an estimate was
detected not to be optimistic (as A* requires). The function is given a detected not to be consistent. The function is given a path (as would
path (as would be returned by a_star, except it does not lead to a goal be returned by a_star, except it does not lead to a goal node). By
node). By default, nothing is done (indeed, an estimate that's not default, nothing is done (as an estimate that's not even optimistic can
strictly optimistic can be useful, esp. if the optimal path is not still be useful).
required)
""" """
# g: best cummulative cost (from initial node) found so far # g: best cummulative cost (from initial node) found so far
# h: optimistic estimate of cost to goal # h: optimistic estimate of cost to goal
@ -971,6 +983,9 @@ class InitialNode(Node, namedtuple('InitialNode', 'search')):
) )
yield 0, action, node yield 0, action, node
def is_goal(self):
return False
class PokemonNode(Node, Facade, namedtuple('PokemonNode', class PokemonNode(Node, Facade, namedtuple('PokemonNode',
'search pokemon_ level version_group_ new_level moves_')): 'search pokemon_ level version_group_ new_level moves_')):
@ -1176,7 +1191,7 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
new_level=False, moves_=frozenset(moves)) new_level=False, moves_=frozenset(moves))
return return
def estimate(self, g): def estimate(self):
# Given good estimates, A* finds solutions much faster. # Given good estimates, A* finds solutions much faster.
# However, here it seems we either have easy movesets, which # However, here it seems we either have easy movesets, which
# get found pretty easily by themselves, or hard ones, where # get found pretty easily by themselves, or hard ones, where
@ -1191,6 +1206,9 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
search.goal_version_group, search.costs['trade'] * 2) search.goal_version_group, search.costs['trade'] * 2)
return trade_cost return trade_cost
def is_goal(self):
return False
class BaseBreedNode(Node): class BaseBreedNode(Node):
"""Breed node """Breed node
This serves to prevent duplicate breeds, by storing only the needed info This serves to prevent duplicate breeds, by storing only the needed info
@ -1217,12 +1235,8 @@ class BaseBreedNode(Node):
search=self.search, pokemon_=baby, level=hatch_level, search=self.search, pokemon_=baby, level=hatch_level,
version_group_=vg, moves_=bred_moves, new_level=True) version_group_=vg, moves_=bred_moves, new_level=True)
@property def is_goal(self):
def pokemon(self): return False
return None
def estimate(self, g):
return 0
class BreedNode(BaseBreedNode, namedtuple('BreedNode', class BreedNode(BaseBreedNode, namedtuple('BreedNode',
'search dummy group_ version_group_ moves_')): 'search dummy group_ version_group_ moves_')):
@ -1248,8 +1262,9 @@ class GoalNode(PokemonNode):
def expand(self): def expand(self):
return () return ()
def is_goal(self, g): def is_goal(self):
return True return True
### ###
### CLI interface ### CLI interface
### ###
@ -1265,7 +1280,7 @@ def print_result(result, moves=()):
cost=cost, cost=cost,
action=action, action=action,
long='>' if (len(unicode(action)) > 45) else '', long='>' if (len(unicode(action)) > 45) else '',
est=node.estimate(None), est=node.estimate(),
pokemon=node.pokemon.name, pokemon=node.pokemon.name,
nl='.' if node.new_level else ' ', nl='.' if node.new_level else ' ',
level=node.level, level=node.level,