mirror of
https://github.com/veekun/pokedex.git
synced 2024-08-20 18:16:34 +00:00
Cache common stuff. Remove goal from a_star.
This commit is contained in:
parent
cb78144e8d
commit
d746716575
1 changed files with 58 additions and 43 deletions
|
@ -6,6 +6,7 @@ import argparse
|
|||
import itertools
|
||||
import heapq
|
||||
from collections import defaultdict, namedtuple
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
@ -41,19 +42,22 @@ def powerset(iterable):
|
|||
###
|
||||
|
||||
class MovesetSearch(object):
|
||||
_cache = WeakKeyDictionary()
|
||||
|
||||
def __init__(self, session, pokemon, version, moves, level=100, costs=None,
|
||||
exclude_versions=(), exclude_pokemon=(), debug_level=False):
|
||||
|
||||
self.generator = None
|
||||
|
||||
if not moves:
|
||||
raise NoMoves('No moves specified.')
|
||||
elif len(moves) > 4:
|
||||
raise NoMoves('Too many moves specified.')
|
||||
self.session = session
|
||||
|
||||
self.debug_level = debug_level
|
||||
|
||||
self.session = session
|
||||
try:
|
||||
# Cache all the common stuff.
|
||||
# Not that it takes a lot of time to load, but it can add up
|
||||
# if repeated.
|
||||
self.__dict__ = MovesetSearch._cache[session]
|
||||
except KeyError:
|
||||
self.generator = None
|
||||
|
||||
self.sketch = util.get(session, tables.Move, identifier=u'sketch').id
|
||||
self.unsketchable = set([
|
||||
|
@ -65,14 +69,23 @@ class MovesetSearch(object):
|
|||
self.ditto_group = util.get(session, tables.EggGroup,
|
||||
identifier=u'ditto').id
|
||||
|
||||
self.load_pokemon()
|
||||
self.load_moves()
|
||||
|
||||
MovesetSearch._cache[session] = self.__dict__
|
||||
|
||||
self.debug_level = debug_level
|
||||
|
||||
if not moves:
|
||||
raise NoMoves('No moves specified.')
|
||||
elif len(moves) > 4:
|
||||
raise NoMoves('Too many moves specified.')
|
||||
|
||||
if costs is None:
|
||||
self.costs = default_costs
|
||||
else:
|
||||
self.costs = costs
|
||||
|
||||
self.load_pokemon()
|
||||
self.load_moves()
|
||||
|
||||
self.excluded_families = frozenset(p.evolution_chain_id
|
||||
for p in exclude_pokemon)
|
||||
|
||||
|
@ -638,7 +651,7 @@ default_costs = {
|
|||
class Node(object):
|
||||
"""Node for the A* search algorithm.
|
||||
|
||||
To get started, implement the `expand` method and call `search`.
|
||||
To get started, implement `expand` & `is_goal` and call `find_path`.
|
||||
|
||||
N.B. Node objects must be hashable.
|
||||
"""
|
||||
|
@ -652,7 +665,12 @@ class Node(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def estimate(self, goal):
|
||||
def is_goal(self):
|
||||
"""Return true iff this is a goal node.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def estimate(self):
|
||||
"""Return an *optimistic* estimate of the cost to the given goal node.
|
||||
|
||||
If there are multiple goal states, return the lowest estimate among all
|
||||
|
@ -660,11 +678,6 @@ class Node(object):
|
|||
"""
|
||||
return 0
|
||||
|
||||
def is_goal(self, goal):
|
||||
"""Return true iff this is a goal node.
|
||||
"""
|
||||
return self == goal
|
||||
|
||||
def find_path(self, goal=None, **kwargs):
|
||||
"""Return the best path to the goal
|
||||
|
||||
|
@ -676,13 +689,13 @@ class Node(object):
|
|||
See a_star for the advanced keyword arguments, `notify` and
|
||||
`estimate_error_callback`.
|
||||
"""
|
||||
paths = self.find_all_paths(goal=goal, **kwargs)
|
||||
paths = self.find_all_paths(**kwargs)
|
||||
try:
|
||||
return paths.next()
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def find_all_paths(self, goal=None, **kwargs):
|
||||
def find_all_paths(self, **kwargs):
|
||||
"""Yield the best path to each goal
|
||||
|
||||
Returns an iterator of paths. See the `search` method for how paths
|
||||
|
@ -699,8 +712,8 @@ class Node(object):
|
|||
return a_star(
|
||||
initial=self,
|
||||
expand=lambda s: s.expand(),
|
||||
estimate=lambda s: s.estimate(goal),
|
||||
is_goal=lambda s: s.is_goal(goal),
|
||||
estimate=lambda s: s.estimate(),
|
||||
is_goal=lambda s: s.is_goal(),
|
||||
**kwargs)
|
||||
|
||||
def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
|
||||
|
@ -748,11 +761,10 @@ def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
|
|||
stopping criteria; the other values may help in tuning estimators.
|
||||
|
||||
`estimate_error_callback`: function handling cases where an estimate was
|
||||
detected not to be optimistic (as A* requires). The function is given a
|
||||
path (as would be returned by a_star, except it does not lead to a goal
|
||||
node). By default, nothing is done (indeed, an estimate that's not
|
||||
strictly optimistic can be useful, esp. if the optimal path is not
|
||||
required)
|
||||
detected not to be consistent. The function is given a path (as would
|
||||
be returned by a_star, except it does not lead to a goal node). By
|
||||
default, nothing is done (as an estimate that's not even optimistic can
|
||||
still be useful).
|
||||
"""
|
||||
# g: best cummulative cost (from initial node) found so far
|
||||
# h: optimistic estimate of cost to goal
|
||||
|
@ -971,6 +983,9 @@ class InitialNode(Node, namedtuple('InitialNode', 'search')):
|
|||
)
|
||||
yield 0, action, node
|
||||
|
||||
def is_goal(self):
|
||||
return False
|
||||
|
||||
class PokemonNode(Node, Facade, namedtuple('PokemonNode',
|
||||
'search pokemon_ level version_group_ new_level moves_')):
|
||||
|
||||
|
@ -1176,7 +1191,7 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
|
|||
new_level=False, moves_=frozenset(moves))
|
||||
return
|
||||
|
||||
def estimate(self, g):
|
||||
def estimate(self):
|
||||
# Given good estimates, A* finds solutions much faster.
|
||||
# However, here it seems we either have easy movesets, which
|
||||
# get found pretty easily by themselves, or hard ones, where
|
||||
|
@ -1191,6 +1206,9 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
|
|||
search.goal_version_group, search.costs['trade'] * 2)
|
||||
return trade_cost
|
||||
|
||||
def is_goal(self):
|
||||
return False
|
||||
|
||||
class BaseBreedNode(Node):
|
||||
"""Breed node
|
||||
This serves to prevent duplicate breeds, by storing only the needed info
|
||||
|
@ -1217,12 +1235,8 @@ class BaseBreedNode(Node):
|
|||
search=self.search, pokemon_=baby, level=hatch_level,
|
||||
version_group_=vg, moves_=bred_moves, new_level=True)
|
||||
|
||||
@property
|
||||
def pokemon(self):
|
||||
return None
|
||||
|
||||
def estimate(self, g):
|
||||
return 0
|
||||
def is_goal(self):
|
||||
return False
|
||||
|
||||
class BreedNode(BaseBreedNode, namedtuple('BreedNode',
|
||||
'search dummy group_ version_group_ moves_')):
|
||||
|
@ -1248,8 +1262,9 @@ class GoalNode(PokemonNode):
|
|||
def expand(self):
|
||||
return ()
|
||||
|
||||
def is_goal(self, g):
|
||||
def is_goal(self):
|
||||
return True
|
||||
|
||||
###
|
||||
### CLI interface
|
||||
###
|
||||
|
@ -1265,7 +1280,7 @@ def print_result(result, moves=()):
|
|||
cost=cost,
|
||||
action=action,
|
||||
long='>' if (len(unicode(action)) > 45) else '',
|
||||
est=node.estimate(None),
|
||||
est=node.estimate(),
|
||||
pokemon=node.pokemon.name,
|
||||
nl='.' if node.new_level else ' ',
|
||||
level=node.level,
|
||||
|
|
Loading…
Reference in a new issue