Cache common stuff. Remove goal from a_star.

This commit is contained in:
Petr Viktorin 2011-04-27 14:21:24 +03:00
parent cb78144e8d
commit d746716575

View file

@ -6,6 +6,7 @@ import argparse
import itertools
import heapq
from collections import defaultdict, namedtuple
from weakref import WeakKeyDictionary
from sqlalchemy.orm import aliased
from sqlalchemy.orm.exc import NoResultFound
@ -41,38 +42,50 @@ def powerset(iterable):
###
class MovesetSearch(object):
_cache = WeakKeyDictionary()
def __init__(self, session, pokemon, version, moves, level=100, costs=None,
exclude_versions=(), exclude_pokemon=(), debug_level=False):
self.generator = None
self.session = session
self.debug_level = debug_level
try:
# Cache all the common stuff.
# Not that it takes a lot of time to load, but it can add up
# if repeated.
self.__dict__ = MovesetSearch._cache[session]
except KeyError:
self.generator = None
self.sketch = util.get(session, tables.Move, identifier=u'sketch').id
self.unsketchable = set([
util.get(session, tables.Move, identifier=u'struggle').id,
util.get(session, tables.Move, identifier=u'chatter').id,
])
self.no_eggs_group = util.get(session, tables.EggGroup,
identifier=u'no-eggs').id
self.ditto_group = util.get(session, tables.EggGroup,
identifier=u'ditto').id
self.load_pokemon()
self.load_moves()
MovesetSearch._cache[session] = self.__dict__
self.debug_level = debug_level
if not moves:
raise NoMoves('No moves specified.')
elif len(moves) > 4:
raise NoMoves('Too many moves specified.')
self.debug_level = debug_level
self.session = session
self.sketch = util.get(session, tables.Move, identifier=u'sketch').id
self.unsketchable = set([
util.get(session, tables.Move, identifier=u'struggle').id,
util.get(session, tables.Move, identifier=u'chatter').id,
])
self.no_eggs_group = util.get(session, tables.EggGroup,
identifier=u'no-eggs').id
self.ditto_group = util.get(session, tables.EggGroup,
identifier=u'ditto').id
if costs is None:
self.costs = default_costs
else:
self.costs = costs
self.load_pokemon()
self.load_moves()
self.excluded_families = frozenset(p.evolution_chain_id
for p in exclude_pokemon)
@ -638,7 +651,7 @@ default_costs = {
class Node(object):
"""Node for the A* search algorithm.
To get started, implement the `expand` method and call `search`.
To get started, implement `expand` & `is_goal` and call `find_path`.
N.B. Node objects must be hashable.
"""
@ -652,7 +665,12 @@ class Node(object):
"""
raise NotImplementedError
def estimate(self, goal):
def is_goal(self):
"""Return true iff this is a goal node.
"""
raise NotImplementedError
def estimate(self):
"""Return an *optimistic* estimate of the cost to the given goal node.
If there are multiple goal states, return the lowest estimate among all
@ -660,11 +678,6 @@ class Node(object):
"""
return 0
def is_goal(self, goal):
"""Return true iff this is a goal node.
"""
return self == goal
def find_path(self, goal=None, **kwargs):
"""Return the best path to the goal
@ -676,13 +689,13 @@ class Node(object):
See a_star for the advanced keyword arguments, `notify` and
`estimate_error_callback`.
"""
paths = self.find_all_paths(goal=goal, **kwargs)
paths = self.find_all_paths(**kwargs)
try:
return paths.next()
except StopIteration:
return None
def find_all_paths(self, goal=None, **kwargs):
def find_all_paths(self, **kwargs):
"""Yield the best path to each goal
Returns an iterator of paths. See the `search` method for how paths
@ -699,8 +712,8 @@ class Node(object):
return a_star(
initial=self,
expand=lambda s: s.expand(),
estimate=lambda s: s.estimate(goal),
is_goal=lambda s: s.is_goal(goal),
estimate=lambda s: s.estimate(),
is_goal=lambda s: s.is_goal(),
**kwargs)
def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
@ -748,11 +761,10 @@ def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
stopping criteria; the other values may help in tuning estimators.
`estimate_error_callback`: function handling cases where an estimate was
detected not to be optimistic (as A* requires). The function is given a
path (as would be returned by a_star, except it does not lead to a goal
node). By default, nothing is done (indeed, an estimate that's not
strictly optimistic can be useful, esp. if the optimal path is not
required)
detected not to be consistent. The function is given a path (as would
be returned by a_star, except it does not lead to a goal node). By
default, nothing is done (as an estimate that's not even optimistic can
still be useful).
"""
# g: best cummulative cost (from initial node) found so far
# h: optimistic estimate of cost to goal
@ -971,6 +983,9 @@ class InitialNode(Node, namedtuple('InitialNode', 'search')):
)
yield 0, action, node
def is_goal(self):
return False
class PokemonNode(Node, Facade, namedtuple('PokemonNode',
'search pokemon_ level version_group_ new_level moves_')):
@ -1176,7 +1191,7 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
new_level=False, moves_=frozenset(moves))
return
def estimate(self, g):
def estimate(self):
# Given good estimates, A* finds solutions much faster.
# However, here it seems we either have easy movesets, which
# get found pretty easily by themselves, or hard ones, where
@ -1191,6 +1206,9 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
search.goal_version_group, search.costs['trade'] * 2)
return trade_cost
def is_goal(self):
return False
class BaseBreedNode(Node):
"""Breed node
This serves to prevent duplicate breeds, by storing only the needed info
@ -1217,12 +1235,8 @@ class BaseBreedNode(Node):
search=self.search, pokemon_=baby, level=hatch_level,
version_group_=vg, moves_=bred_moves, new_level=True)
@property
def pokemon(self):
return None
def estimate(self, g):
return 0
def is_goal(self):
return False
class BreedNode(BaseBreedNode, namedtuple('BreedNode',
'search dummy group_ version_group_ moves_')):
@ -1248,8 +1262,9 @@ class GoalNode(PokemonNode):
def expand(self):
return ()
def is_goal(self, g):
def is_goal(self):
return True
###
### CLI interface
###
@ -1265,7 +1280,7 @@ def print_result(result, moves=()):
cost=cost,
action=action,
long='>' if (len(unicode(action)) > 45) else '',
est=node.estimate(None),
est=node.estimate(),
pokemon=node.pokemon.name,
nl='.' if node.new_level else ' ',
level=node.level,