mirror of
https://github.com/veekun/pokedex.git
synced 2024-08-20 18:16:34 +00:00
Merge A* in. General improvements.
This commit is contained in:
parent
88b9a216c7
commit
b2e54ca1ea
2 changed files with 284 additions and 347 deletions
|
@ -1,299 +0,0 @@
|
||||||
"""A pure-Python implementation of the A* search algorithm
|
|
||||||
"""
|
|
||||||
|
|
||||||
import heapq
|
|
||||||
|
|
||||||
class Node(object):
|
|
||||||
"""Node for the A* search algorithm.
|
|
||||||
|
|
||||||
To get started, implement the `expand` method and call `search`.
|
|
||||||
|
|
||||||
N.B. Node object must be hashable.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def expand(self):
|
|
||||||
"""Return a list of (costs, transition, next_node) for next states
|
|
||||||
|
|
||||||
"Next states" are those reachable from this node.
|
|
||||||
|
|
||||||
May return any finite iterable.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def estimate(self, goal):
|
|
||||||
"""Return an *optimistic* estimate of the cost to the given goal node.
|
|
||||||
|
|
||||||
If there are multiple goal states, return the lowest estimate among all
|
|
||||||
of them.
|
|
||||||
"""
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def is_goal(self, goal):
|
|
||||||
"""Return true iff this is a goal node.
|
|
||||||
"""
|
|
||||||
return self == goal
|
|
||||||
|
|
||||||
def find_path(self, goal=None, **kwargs):
|
|
||||||
"""Return the best path to the goal
|
|
||||||
|
|
||||||
Returns an iterator of (cost, transition, node) triples, in reverse
|
|
||||||
order (i.e. the first element will have the total cost and goal node).
|
|
||||||
|
|
||||||
If `goal` will be passed to the `estimate` and `is_goal` methods.
|
|
||||||
|
|
||||||
See a_star for the advanced keyword arguments, `notify` and
|
|
||||||
`estimate_error_callback`.
|
|
||||||
"""
|
|
||||||
paths = self.find_all_paths(goal=goal, **kwargs)
|
|
||||||
try:
|
|
||||||
return paths.next()
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def find_all_paths(self, goal=None, **kwargs):
|
|
||||||
"""Yield the best path to each goal
|
|
||||||
|
|
||||||
Returns an iterator of paths. See the `search` method for how paths
|
|
||||||
look.
|
|
||||||
|
|
||||||
Giving the `goal` argument will cause it to search for that goal,
|
|
||||||
instead of consulting the `is_goal` method.
|
|
||||||
This means that if you wish to find more than one path, you must not
|
|
||||||
pass a `goal` to this method, and instead reimplament `is_goal`.
|
|
||||||
|
|
||||||
See a_star for the advanced keyword arguments, `notify` and
|
|
||||||
`estimate_error_callback`.
|
|
||||||
"""
|
|
||||||
return a_star(
|
|
||||||
initial=self,
|
|
||||||
expand=lambda s: s.expand(),
|
|
||||||
estimate=lambda s: s.estimate(goal),
|
|
||||||
is_goal=lambda s: s.is_goal(goal),
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
### The main algorithm
|
|
||||||
|
|
||||||
def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
|
|
||||||
estimate_error_callback=None):
|
|
||||||
"""A* search algorithm for a consistent heuristic
|
|
||||||
|
|
||||||
General background: http://en.wikipedia.org/wiki/A*_search_algorithm
|
|
||||||
|
|
||||||
This algorithm will work in large or infinite search spaces.
|
|
||||||
|
|
||||||
This version of the algorithm is modified for multiple possible goals:
|
|
||||||
it does not end when it reaches a goal. Rather, it yields the best path
|
|
||||||
for each goal.
|
|
||||||
(Exhausting the iterator is of course not recommended for large search
|
|
||||||
spaces.)
|
|
||||||
|
|
||||||
Returns an iterable of paths, where each path is an iterable of
|
|
||||||
(cummulative cost, transition, node) triples representing the path to
|
|
||||||
the goal. The transition is the one leading to the corresponding node.
|
|
||||||
The path is in reverse order, thus its first element will contain the
|
|
||||||
total cost and the goal node.
|
|
||||||
The initial node is not included in the returned path.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
|
|
||||||
`initial`: the initial node
|
|
||||||
|
|
||||||
`expand`: function yielding a (cost of transition, transition, next node)
|
|
||||||
triple for each node reachable from its argument.
|
|
||||||
The `transition` element is application data; it is not touched, only
|
|
||||||
returned as part of the best path.
|
|
||||||
`estimate`: function(x) returning optimistic estimate of cost from node x
|
|
||||||
to a goal. If not given, 0 will be used for estimates.
|
|
||||||
`is_goal`: function(x) returning true iff x is a goal node
|
|
||||||
|
|
||||||
`notify`: If given, if is called at each step with three arguments:
|
|
||||||
- current cost (with estimate). The cost to the next goal will not be
|
|
||||||
smaller than this.
|
|
||||||
- current node
|
|
||||||
- open set cardinality: roughly, an estimate of the size of the
|
|
||||||
boundary between "explored" and "unexplored" parts of node space
|
|
||||||
- debug: stats that be useful for debugging or tuning (in this
|
|
||||||
implementation, this is the open heap size)
|
|
||||||
The number of calls to notify or the current cost can be useful as
|
|
||||||
stopping criteria; the other values may help in tuning estimators.
|
|
||||||
|
|
||||||
`estimate_error_callback`: function handling cases where an estimate was
|
|
||||||
detected not to be optimistic (as A* requires). The function is given a
|
|
||||||
path (as would be returned by a_star, except it does not lead to a goal
|
|
||||||
node). By default, nothing is done (indeed, an estimate that's not
|
|
||||||
strictly optimistic can be useful, esp. if the optimal path is not
|
|
||||||
required)
|
|
||||||
"""
|
|
||||||
# g: best cummulative cost (from initial node) found so far
|
|
||||||
# h: optimistic estimate of cost to goal
|
|
||||||
# f: g + h
|
|
||||||
closed = set() # nodes we don't want to visit again
|
|
||||||
est = estimate(initial) # estimate total cost
|
|
||||||
opened = _HeapDict() # node -> (f, g, h)
|
|
||||||
opened[initial] = (est, 0, est)
|
|
||||||
came_from = {initial: None} # node -> (prev_node, came_from[prev_node])
|
|
||||||
while True: # _HeapDict will raise StopIteration for us
|
|
||||||
x, (f, g, h) = opened.pop()
|
|
||||||
closed.add(x)
|
|
||||||
|
|
||||||
if notify is not None:
|
|
||||||
notify(f, x, len(opened.dict), len(opened.heap))
|
|
||||||
|
|
||||||
if is_goal(x):
|
|
||||||
yield _trace_path(came_from[x])
|
|
||||||
|
|
||||||
for cost, transition, y in expand(x):
|
|
||||||
if y in closed:
|
|
||||||
continue
|
|
||||||
tentative_g = g + cost
|
|
||||||
|
|
||||||
old_f, old_g, h = opened.get(y, (None, None, None))
|
|
||||||
|
|
||||||
if old_f is None:
|
|
||||||
h = estimate(y)
|
|
||||||
elif tentative_g > old_g:
|
|
||||||
continue
|
|
||||||
|
|
||||||
came_from[y] = ((tentative_g, transition, y), came_from[x])
|
|
||||||
new_f = tentative_g + h
|
|
||||||
|
|
||||||
opened[y] = new_f, tentative_g, h
|
|
||||||
|
|
||||||
if estimate_error_callback is not None and new_f < f:
|
|
||||||
estimate_error_callback(_trace_path(came_from[y]))
|
|
||||||
|
|
||||||
def _trace_path(cdr):
|
|
||||||
"""Backtrace an A* result"""
|
|
||||||
# Convert a lispy list to a pythony iterator
|
|
||||||
while cdr:
|
|
||||||
car, cdr = cdr
|
|
||||||
yield car
|
|
||||||
|
|
||||||
class _HeapDict(object):
|
|
||||||
"""A custom parallel heap/dict structure -- the best of both worlds.
|
|
||||||
|
|
||||||
This is NOT a general-purpose class; it only supports what a_star needs.
|
|
||||||
"""
|
|
||||||
# The dict has the definitive contents
|
|
||||||
# The heap has (value, key) pairs. It may have some extra elements.
|
|
||||||
def __init__(self):
|
|
||||||
self.dict = {}
|
|
||||||
self.heap = []
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
|
||||||
self.dict[key] = value
|
|
||||||
heapq.heappush(self.heap, (value, key))
|
|
||||||
|
|
||||||
def __delitem__(self, key):
|
|
||||||
del self.dict[key]
|
|
||||||
|
|
||||||
def get(self, key, default):
|
|
||||||
"""Return value for key, or default if not found
|
|
||||||
"""
|
|
||||||
return self.dict.get(key, default)
|
|
||||||
|
|
||||||
def pop(self):
|
|
||||||
"""Return (key, value) with the smallest value.
|
|
||||||
|
|
||||||
Raise StopIteration (!!) if empty
|
|
||||||
"""
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
value, key = heapq.heappop(self.heap)
|
|
||||||
if value is self.dict[key]:
|
|
||||||
del self.dict[key]
|
|
||||||
return key, value
|
|
||||||
except KeyError:
|
|
||||||
# deleted from dict = not here
|
|
||||||
pass
|
|
||||||
except IndexError:
|
|
||||||
# nothing more to pop
|
|
||||||
raise StopIteration
|
|
||||||
|
|
||||||
|
|
||||||
### Example/test
|
|
||||||
|
|
||||||
|
|
||||||
def test_example_knights():
|
|
||||||
"""Test/example: the "knights" problem
|
|
||||||
|
|
||||||
Definition and another solution may be found at:
|
|
||||||
http://brandon.sternefamily.net/posts/2005/02/a-star-algorithm-in-python/
|
|
||||||
"""
|
|
||||||
# Legal moves
|
|
||||||
moves = { 1: [4, 7],
|
|
||||||
2: [8, 10],
|
|
||||||
3: [9],
|
|
||||||
4: [1, 6, 10],
|
|
||||||
5: [7],
|
|
||||||
6: [4],
|
|
||||||
7: [1, 5],
|
|
||||||
8: [2, 9],
|
|
||||||
9: [8, 3],
|
|
||||||
10: [2, 4] }
|
|
||||||
|
|
||||||
class Positions(dict, Node):
|
|
||||||
"""Node class representing positions as a dictionary.
|
|
||||||
|
|
||||||
Keys are unique piece names, values are (color, position) where color
|
|
||||||
is True for white, False for black.
|
|
||||||
"""
|
|
||||||
def expand(self):
|
|
||||||
for piece, (color, position) in self.items():
|
|
||||||
for new_position in moves[position]:
|
|
||||||
if new_position not in (p for c, p in self.values()):
|
|
||||||
new_node = Positions(self)
|
|
||||||
new_node.update({piece: (color, new_position)})
|
|
||||||
yield 1, None, new_node
|
|
||||||
|
|
||||||
def estimate(self, goal):
|
|
||||||
# Number of misplaced figures
|
|
||||||
misplaced = 0
|
|
||||||
for piece, (color, position) in self.items():
|
|
||||||
if (color, position) not in goal.values():
|
|
||||||
misplaced += 1
|
|
||||||
return misplaced
|
|
||||||
|
|
||||||
def is_goal(self, goal):
|
|
||||||
return self.estimate(goal) == 0
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash(tuple(sorted(self.items())))
|
|
||||||
|
|
||||||
initial = Positions({
|
|
||||||
'White 1': (True, 1),
|
|
||||||
'white 2': (True, 6),
|
|
||||||
'Black 1': (False, 5),
|
|
||||||
'black 2': (False, 7),
|
|
||||||
})
|
|
||||||
|
|
||||||
# Goal: colors should be switched
|
|
||||||
goal = Positions((piece, (not color, position))
|
|
||||||
for piece, (color, position) in initial.items())
|
|
||||||
|
|
||||||
def print_board(positions, linebreak='\n', extra=''):
|
|
||||||
board = dict((position, piece)
|
|
||||||
for piece, (color, position) in positions.items())
|
|
||||||
for i in range(1, 11):
|
|
||||||
# line breaks
|
|
||||||
if i in (2, 6, 9):
|
|
||||||
print linebreak,
|
|
||||||
print board.get(i, '_')[0],
|
|
||||||
print extra
|
|
||||||
|
|
||||||
def notify(cost, state, b, c):
|
|
||||||
print 'Looking at state with cost %s:' % cost,
|
|
||||||
print_board(state, '|', '(%s; %s; %s)' % (state.estimate(goal), b, c))
|
|
||||||
|
|
||||||
solution_path = list(initial.search(goal, notify=notify))
|
|
||||||
|
|
||||||
print 'Step', 0
|
|
||||||
print_board(initial)
|
|
||||||
for i, (cost, transition, positions) in enumerate(reversed(solution_path)):
|
|
||||||
print 'Step', i + 1
|
|
||||||
print_board(positions)
|
|
||||||
|
|
||||||
# Check solution is correct
|
|
||||||
cost, transition, positions = solution_path[0]
|
|
||||||
assert set(positions.values()) == set(goal.values())
|
|
||||||
assert cost == 40
|
|
|
@ -4,6 +4,7 @@
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
import itertools
|
import itertools
|
||||||
|
import heapq
|
||||||
from collections import defaultdict, namedtuple
|
from collections import defaultdict, namedtuple
|
||||||
|
|
||||||
from sqlalchemy.orm import aliased
|
from sqlalchemy.orm import aliased
|
||||||
|
@ -12,9 +13,6 @@ from sqlalchemy.sql.expression import not_, and_, or_
|
||||||
|
|
||||||
from pokedex.db import connect, tables, util
|
from pokedex.db import connect, tables, util
|
||||||
|
|
||||||
from pokedex.util import querytimer
|
|
||||||
from pokedex.util.astar import a_star, Node
|
|
||||||
|
|
||||||
###
|
###
|
||||||
### Illegal Moveset exceptions
|
### Illegal Moveset exceptions
|
||||||
###
|
###
|
||||||
|
@ -44,7 +42,7 @@ def powerset(iterable):
|
||||||
|
|
||||||
class MovesetSearch(object):
|
class MovesetSearch(object):
|
||||||
def __init__(self, session, pokemon, version, moves, level=100, costs=None,
|
def __init__(self, session, pokemon, version, moves, level=100, costs=None,
|
||||||
exclude_versions=(), exclude_pokemon=(), debug=False):
|
exclude_versions=(), exclude_pokemon=(), debug_level=False):
|
||||||
|
|
||||||
self.generator = None
|
self.generator = None
|
||||||
|
|
||||||
|
@ -53,7 +51,7 @@ class MovesetSearch(object):
|
||||||
elif len(moves) > 4:
|
elif len(moves) > 4:
|
||||||
raise NoMoves('Too many moves specified.')
|
raise NoMoves('Too many moves specified.')
|
||||||
|
|
||||||
self.debug = debug
|
self.debug_level = debug_level
|
||||||
|
|
||||||
self.session = session
|
self.session = session
|
||||||
|
|
||||||
|
@ -78,7 +76,7 @@ class MovesetSearch(object):
|
||||||
self.excluded_families = frozenset(p.evolution_chain_id
|
self.excluded_families = frozenset(p.evolution_chain_id
|
||||||
for p in exclude_pokemon)
|
for p in exclude_pokemon)
|
||||||
|
|
||||||
if debug:
|
if debug_level > 1:
|
||||||
print 'Specified moves:', [move.id for move in moves]
|
print 'Specified moves:', [move.id for move in moves]
|
||||||
|
|
||||||
self.goal_pokemon = pokemon.id
|
self.goal_pokemon = pokemon.id
|
||||||
|
@ -124,9 +122,10 @@ class MovesetSearch(object):
|
||||||
self.output_objects = dict()
|
self.output_objects = dict()
|
||||||
|
|
||||||
kwargs = dict()
|
kwargs = dict()
|
||||||
if debug:
|
if debug_level:
|
||||||
self._astar_debug_notify_counter = 0
|
self._astar_debug_notify_counter = 0
|
||||||
kwargs['notify'] = self.astar_debug_notify
|
kwargs['notify'] = self.astar_debug_notify
|
||||||
|
kwargs['estimate_error_callback'] = self.astar_estimate_error
|
||||||
self.generator = InitialNode(self).find_all_paths(**kwargs)
|
self.generator = InitialNode(self).find_all_paths(**kwargs)
|
||||||
|
|
||||||
def load_version_groups(self, version, excluded):
|
def load_version_groups(self, version, excluded):
|
||||||
|
@ -151,7 +150,7 @@ class MovesetSearch(object):
|
||||||
filtered_map[version] = (
|
filtered_map[version] = (
|
||||||
self.generation_id_by_version_group[version])
|
self.generation_id_by_version_group[version])
|
||||||
self.generation_id_by_version_group = filtered_map
|
self.generation_id_by_version_group = filtered_map
|
||||||
if self.debug:
|
if self.debug_level > 1:
|
||||||
print 'Excluded version groups:', excluded
|
print 'Excluded version groups:', excluded
|
||||||
print 'Trade cost table:'
|
print 'Trade cost table:'
|
||||||
print '%03s' % '',
|
print '%03s' % '',
|
||||||
|
@ -178,7 +177,7 @@ class MovesetSearch(object):
|
||||||
non_egg_moves is a set of moves that don't require breeding
|
non_egg_moves is a set of moves that don't require breeding
|
||||||
Otherwise, these are empty sets.
|
Otherwise, these are empty sets.
|
||||||
"""
|
"""
|
||||||
if self.debug:
|
if self.debug_level > 1:
|
||||||
print 'Loading pokemon moves, %s %s' % (evolution_chain, selection)
|
print 'Loading pokemon moves, %s %s' % (evolution_chain, selection)
|
||||||
query = self.session.query(
|
query = self.session.query(
|
||||||
tables.PokemonMove.pokemon_id,
|
tables.PokemonMove.pokemon_id,
|
||||||
|
@ -252,10 +251,10 @@ class MovesetSearch(object):
|
||||||
continue
|
continue
|
||||||
cost = 1
|
cost = 1
|
||||||
self.pokemon_moves[pokemon][vg][move][method].append((level, cost))
|
self.pokemon_moves[pokemon][vg][move][method].append((level, cost))
|
||||||
if self.debug and selection == 'family':
|
if self.debug_level > 1 and selection == 'family':
|
||||||
print 'Easy moves:', sorted(easy_moves)
|
print 'Easy moves:', sorted(easy_moves)
|
||||||
print 'Non-egg moves:', sorted(non_egg_moves)
|
print 'Non-egg moves:', sorted(non_egg_moves)
|
||||||
if self.debug:
|
if self.debug_level > 1:
|
||||||
print 'Smeargle families:', sorted(self.smeargle_families)
|
print 'Smeargle families:', sorted(self.smeargle_families)
|
||||||
return easy_moves, non_egg_moves
|
return easy_moves, non_egg_moves
|
||||||
|
|
||||||
|
@ -364,7 +363,7 @@ class MovesetSearch(object):
|
||||||
if move:
|
if move:
|
||||||
self.evolution_moves[self.evolution_chains[child]] = move
|
self.evolution_moves[self.evolution_chains[child]] = move
|
||||||
|
|
||||||
if self.debug:
|
if self.debug_level > 1:
|
||||||
print 'Loaded %s pokemon: %s evo; %s families: %s breedable' % (
|
print 'Loaded %s pokemon: %s evo; %s families: %s breedable' % (
|
||||||
len(self.evolution_chains),
|
len(self.evolution_chains),
|
||||||
len(self.pokemon_by_evolution_chain),
|
len(self.pokemon_by_evolution_chain),
|
||||||
|
@ -396,7 +395,7 @@ class MovesetSearch(object):
|
||||||
)
|
)
|
||||||
self.move_generations = dict(query)
|
self.move_generations = dict(query)
|
||||||
|
|
||||||
if self.debug:
|
if self.debug_level > 1:
|
||||||
print 'Loaded %s moves' % len(self.move_generations)
|
print 'Loaded %s moves' % len(self.move_generations)
|
||||||
|
|
||||||
def construct_breed_graph(self):
|
def construct_breed_graph(self):
|
||||||
|
@ -443,7 +442,7 @@ class MovesetSearch(object):
|
||||||
if len(groups) >= 2:
|
if len(groups) >= 2:
|
||||||
eg2_movepools[groups].update(pool)
|
eg2_movepools[groups].update(pool)
|
||||||
|
|
||||||
if self.debug:
|
if self.debug_level > 1:
|
||||||
print 'Egg group summary:'
|
print 'Egg group summary:'
|
||||||
for group in sorted(all_groups):
|
for group in sorted(all_groups):
|
||||||
print "%2s can pass: %s" % (group, sorted(eg1_movepools[group]))
|
print "%2s can pass: %s" % (group, sorted(eg1_movepools[group]))
|
||||||
|
@ -496,7 +495,7 @@ class MovesetSearch(object):
|
||||||
breeds_required[group][frozenset(moves)] = 1
|
breeds_required[group][frozenset(moves)] = 1
|
||||||
self.breeds_required = breeds_required
|
self.breeds_required = breeds_required
|
||||||
|
|
||||||
if self.debug:
|
if self.debug_level > 1:
|
||||||
for group, movesetlist in breeds_required.items():
|
for group, movesetlist in breeds_required.items():
|
||||||
print 'From egg group', group
|
print 'From egg group', group
|
||||||
for moveset, cost in movesetlist.items():
|
for moveset, cost in movesetlist.items():
|
||||||
|
@ -530,7 +529,7 @@ class MovesetSearch(object):
|
||||||
last_moves = moves
|
last_moves = moves
|
||||||
last_gen = gen
|
last_gen = gen
|
||||||
|
|
||||||
if self.debug:
|
if self.debug_level > 1:
|
||||||
print 'Deduplicated %s version groups' % counter
|
print 'Deduplicated %s version groups' % counter
|
||||||
|
|
||||||
def astar_debug_notify(self, cost, node, setsize, heapsize):
|
def astar_debug_notify(self, cost, node, setsize, heapsize):
|
||||||
|
@ -538,8 +537,13 @@ class MovesetSearch(object):
|
||||||
if counter % 100 == 0:
|
if counter % 100 == 0:
|
||||||
print 'A* iteration %s, cost %s; remaining: %s (%s) \r' % (
|
print 'A* iteration %s, cost %s; remaining: %s (%s) \r' % (
|
||||||
counter, cost, setsize, heapsize),
|
counter, cost, setsize, heapsize),
|
||||||
|
sys.stdout.flush()
|
||||||
self._astar_debug_notify_counter += 1
|
self._astar_debug_notify_counter += 1
|
||||||
|
|
||||||
|
def astar_estimate_error(self, result):
|
||||||
|
print '**warning: bad A* estimate**'
|
||||||
|
print_result(result)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return self.generator
|
return self.generator
|
||||||
|
|
||||||
|
@ -585,7 +589,7 @@ default_costs = {
|
||||||
# For technical reasons, 'sketch' is also used for learning Sketch and
|
# For technical reasons, 'sketch' is also used for learning Sketch and
|
||||||
# by normal means, if it isn't included in the target moveset.
|
# by normal means, if it isn't included in the target moveset.
|
||||||
# So the actual cost of a sketched move will be double this number.
|
# So the actual cost of a sketched move will be double this number.
|
||||||
'sketch': 100, # Cheap. Exclude Smeargle if you think it's too cheap.
|
'sketch': 1, # Cheap. Exclude Smeargle if you think it's too cheap.
|
||||||
|
|
||||||
# Gimmick moves – we need to use this method to learn the move anyway,
|
# Gimmick moves – we need to use this method to learn the move anyway,
|
||||||
# so make a big-ish dent in the score if missing
|
# so make a big-ish dent in the score if missing
|
||||||
|
@ -604,7 +608,7 @@ default_costs = {
|
||||||
'evolution-delayed': 50, # *in addition* to evolution. Who wants to mash B on every level.
|
'evolution-delayed': 50, # *in addition* to evolution. Who wants to mash B on every level.
|
||||||
'breed': 400, # Breeding's a pain.
|
'breed': 400, # Breeding's a pain.
|
||||||
'trade': 200, # Trading's a pain, but not as much as breeding.
|
'trade': 200, # Trading's a pain, but not as much as breeding.
|
||||||
'transfer': 200, # *in addition* to trade. For one-way cross-generation transfers
|
'transfer': 150, # *in addition* to trade. Keep it below 'trade'.
|
||||||
'forget': 300, # Deleting a move. (Not needed unless deleting an evolution move.)
|
'forget': 300, # Deleting a move. (Not needed unless deleting an evolution move.)
|
||||||
'relearn': 150, # Also a pain, though not as big as breeding.
|
'relearn': 150, # Also a pain, though not as big as breeding.
|
||||||
'per-level': 1, # Prefer less grinding. This is for all lv-ups but the final “grow”
|
'per-level': 1, # Prefer less grinding. This is for all lv-ups but the final “grow”
|
||||||
|
@ -620,6 +624,215 @@ default_costs = {
|
||||||
'breed-penalty': 100,
|
'breed-penalty': 100,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
###
|
||||||
|
### A*
|
||||||
|
###
|
||||||
|
|
||||||
|
class Node(object):
|
||||||
|
"""Node for the A* search algorithm.
|
||||||
|
|
||||||
|
To get started, implement the `expand` method and call `search`.
|
||||||
|
|
||||||
|
N.B. Node objects must be hashable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def expand(self):
|
||||||
|
"""Return a list of (costs, transition, next_node) for next states
|
||||||
|
|
||||||
|
"Next states" are those reachable from this node.
|
||||||
|
|
||||||
|
May return any finite iterable.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def estimate(self, goal):
|
||||||
|
"""Return an *optimistic* estimate of the cost to the given goal node.
|
||||||
|
|
||||||
|
If there are multiple goal states, return the lowest estimate among all
|
||||||
|
of them.
|
||||||
|
"""
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def is_goal(self, goal):
|
||||||
|
"""Return true iff this is a goal node.
|
||||||
|
"""
|
||||||
|
return self == goal
|
||||||
|
|
||||||
|
def find_path(self, goal=None, **kwargs):
|
||||||
|
"""Return the best path to the goal
|
||||||
|
|
||||||
|
Returns an iterator of (cost, transition, node) triples, in reverse
|
||||||
|
order (i.e. the first element will have the total cost and goal node).
|
||||||
|
|
||||||
|
If `goal` will be passed to the `estimate` and `is_goal` methods.
|
||||||
|
|
||||||
|
See a_star for the advanced keyword arguments, `notify` and
|
||||||
|
`estimate_error_callback`.
|
||||||
|
"""
|
||||||
|
paths = self.find_all_paths(goal=goal, **kwargs)
|
||||||
|
try:
|
||||||
|
return paths.next()
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_all_paths(self, goal=None, **kwargs):
|
||||||
|
"""Yield the best path to each goal
|
||||||
|
|
||||||
|
Returns an iterator of paths. See the `search` method for how paths
|
||||||
|
look.
|
||||||
|
|
||||||
|
Giving the `goal` argument will cause it to search for that goal,
|
||||||
|
instead of consulting the `is_goal` method.
|
||||||
|
This means that if you wish to find more than one path, you must not
|
||||||
|
pass a `goal` to this method, and instead reimplament `is_goal`.
|
||||||
|
|
||||||
|
See a_star for the advanced keyword arguments, `notify` and
|
||||||
|
`estimate_error_callback`.
|
||||||
|
"""
|
||||||
|
return a_star(
|
||||||
|
initial=self,
|
||||||
|
expand=lambda s: s.expand(),
|
||||||
|
estimate=lambda s: s.estimate(goal),
|
||||||
|
is_goal=lambda s: s.is_goal(goal),
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
|
||||||
|
estimate_error_callback=None):
|
||||||
|
"""A* search algorithm for a consistent heuristic
|
||||||
|
|
||||||
|
General background: http://en.wikipedia.org/wiki/A*_search_algorithm
|
||||||
|
|
||||||
|
This algorithm will work in large or infinite search spaces.
|
||||||
|
|
||||||
|
This version of the algorithm is modified for multiple possible goals:
|
||||||
|
it does not end when it reaches a goal. Rather, it yields the best path
|
||||||
|
for each goal.
|
||||||
|
(Exhausting the iterator is of course not recommended for large search
|
||||||
|
spaces.)
|
||||||
|
|
||||||
|
Returns an iterable of paths, where each path is an iterable of
|
||||||
|
(cummulative cost, transition, node) triples representing the path to
|
||||||
|
the goal. The transition is the one leading to the corresponding node.
|
||||||
|
The path is in reverse order, thus its first element will contain the
|
||||||
|
total cost and the goal node.
|
||||||
|
The initial node is not included in the returned path.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
`initial`: the initial node
|
||||||
|
|
||||||
|
`expand`: function yielding a (cost of transition, transition, next node)
|
||||||
|
triple for each node reachable from its argument.
|
||||||
|
The `transition` element is application data; it is not touched, only
|
||||||
|
returned as part of the best path.
|
||||||
|
`estimate`: function(x) returning optimistic estimate of cost from node x
|
||||||
|
to a goal. If not given, 0 will be used for estimates.
|
||||||
|
`is_goal`: function(x) returning true iff x is a goal node
|
||||||
|
|
||||||
|
`notify`: If given, if is called at each step with three arguments:
|
||||||
|
- current cost (with estimate). The cost to the next goal will not be
|
||||||
|
smaller than this.
|
||||||
|
- current node
|
||||||
|
- open set cardinality: roughly, an estimate of the size of the
|
||||||
|
boundary between "explored" and "unexplored" parts of node space
|
||||||
|
- debug: stats that be useful for debugging or tuning (in this
|
||||||
|
implementation, this is the open heap size)
|
||||||
|
The number of calls to notify or the current cost can be useful as
|
||||||
|
stopping criteria; the other values may help in tuning estimators.
|
||||||
|
|
||||||
|
`estimate_error_callback`: function handling cases where an estimate was
|
||||||
|
detected not to be optimistic (as A* requires). The function is given a
|
||||||
|
path (as would be returned by a_star, except it does not lead to a goal
|
||||||
|
node). By default, nothing is done (indeed, an estimate that's not
|
||||||
|
strictly optimistic can be useful, esp. if the optimal path is not
|
||||||
|
required)
|
||||||
|
"""
|
||||||
|
# g: best cummulative cost (from initial node) found so far
|
||||||
|
# h: optimistic estimate of cost to goal
|
||||||
|
# f: g + h
|
||||||
|
closed = set() # nodes we don't want to visit again
|
||||||
|
est = estimate(initial) # estimate total cost
|
||||||
|
opened = _HeapDict() # node -> (f, g, h)
|
||||||
|
opened[initial] = (est, 0, est)
|
||||||
|
came_from = {initial: None} # node -> (prev_node, came_from[prev_node])
|
||||||
|
while True: # _HeapDict will raise StopIteration for us
|
||||||
|
x, (f, g, h) = opened.pop()
|
||||||
|
closed.add(x)
|
||||||
|
|
||||||
|
if notify is not None:
|
||||||
|
notify(f, x, len(opened.dict), len(opened.heap))
|
||||||
|
|
||||||
|
if is_goal(x):
|
||||||
|
yield _trace_path(came_from[x])
|
||||||
|
|
||||||
|
for cost, transition, y in expand(x):
|
||||||
|
if y in closed:
|
||||||
|
continue
|
||||||
|
tentative_g = g + cost
|
||||||
|
|
||||||
|
old_f, old_g, h = opened.get(y, (None, None, None))
|
||||||
|
|
||||||
|
if old_f is None:
|
||||||
|
h = estimate(y)
|
||||||
|
elif tentative_g > old_g:
|
||||||
|
continue
|
||||||
|
|
||||||
|
came_from[y] = ((tentative_g, transition, y), came_from[x])
|
||||||
|
new_f = tentative_g + h
|
||||||
|
|
||||||
|
opened[y] = new_f, tentative_g, h
|
||||||
|
|
||||||
|
if estimate_error_callback is not None and new_f < f:
|
||||||
|
estimate_error_callback(_trace_path(came_from[y]))
|
||||||
|
|
||||||
|
def _trace_path(cdr):
|
||||||
|
"""Backtrace an A* result"""
|
||||||
|
# Convert a lispy list to a pythony iterator
|
||||||
|
while cdr:
|
||||||
|
car, cdr = cdr
|
||||||
|
yield car
|
||||||
|
|
||||||
|
class _HeapDict(object):
|
||||||
|
"""A custom parallel heap/dict structure -- the best of both worlds.
|
||||||
|
|
||||||
|
This is NOT a general-purpose class; it only supports what a_star needs.
|
||||||
|
"""
|
||||||
|
# The dict has the definitive contents
|
||||||
|
# The heap has (value, key) pairs. It may have some extra elements.
|
||||||
|
def __init__(self):
|
||||||
|
self.dict = {}
|
||||||
|
self.heap = []
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
self.dict[key] = value
|
||||||
|
heapq.heappush(self.heap, (value, key))
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
del self.dict[key]
|
||||||
|
|
||||||
|
def get(self, key, default):
|
||||||
|
"""Return value for key, or default if not found
|
||||||
|
"""
|
||||||
|
return self.dict.get(key, default)
|
||||||
|
|
||||||
|
def pop(self):
|
||||||
|
"""Return (key, value) with the smallest value.
|
||||||
|
|
||||||
|
Raise StopIteration (!!) if empty
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
value, key = heapq.heappop(self.heap)
|
||||||
|
if value is self.dict[key]:
|
||||||
|
del self.dict[key]
|
||||||
|
return key, value
|
||||||
|
except KeyError:
|
||||||
|
# deleted from dict = not here
|
||||||
|
pass
|
||||||
|
except IndexError:
|
||||||
|
# nothing more to pop
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
###
|
###
|
||||||
### Result objects
|
### Result objects
|
||||||
###
|
###
|
||||||
|
@ -828,6 +1041,7 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
|
||||||
def expand_forget(self):
|
def expand_forget(self):
|
||||||
cost = self.search.costs['forget']
|
cost = self.search.costs['forget']
|
||||||
for move in self.moves_:
|
for move in self.moves_:
|
||||||
|
if move not in self.search.goal_moves:
|
||||||
yield cost, ForgetAction(self.search, move), self._replace(
|
yield cost, ForgetAction(self.search, move), self._replace(
|
||||||
moves_=self.moves_.difference([move]), new_level=False)
|
moves_=self.moves_.difference([move]), new_level=False)
|
||||||
|
|
||||||
|
@ -923,7 +1137,7 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
|
||||||
moves = self.moves_
|
moves = self.moves_
|
||||||
for sketch in moves:
|
for sketch in moves:
|
||||||
if sketch == self.search.sketch:
|
if sketch == self.search.sketch:
|
||||||
for sketched in self.search.goal_moves:
|
for sketched in sorted(self.search.goal_moves):
|
||||||
if sketched in self.search.unsketchable:
|
if sketched in self.search.unsketchable:
|
||||||
continue
|
continue
|
||||||
if sketched not in moves:
|
if sketched not in moves:
|
||||||
|
@ -936,6 +1150,26 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
|
||||||
new_level=False, moves_=frozenset(moves))
|
new_level=False, moves_=frozenset(moves))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def estimate(self, g):
|
||||||
|
# Given good estimates, A* finds solutions much faster.
|
||||||
|
# However, here it seems we either have easy movesets, which
|
||||||
|
# get found pretty easily by themselves, or hard ones, where
|
||||||
|
# heuristics don't help too much, or impossible ones where they
|
||||||
|
# don't matter at all.
|
||||||
|
# So, keep the computations here to a minimum.
|
||||||
|
search = self.search
|
||||||
|
trade_cost = search.trade_cost(self.version_group_,
|
||||||
|
search.goal_version_group)
|
||||||
|
if trade_cost is None:
|
||||||
|
trade_cost = search.costs['trade'] * 2
|
||||||
|
return trade_cost
|
||||||
|
evo_chain = search.evolution_chains[self.pokemon_]
|
||||||
|
if evo_chain == search.goal_evolution_chain:
|
||||||
|
breed_cost = 0
|
||||||
|
else:
|
||||||
|
breed_cost = search.costs['breed']
|
||||||
|
return trade_cost + breed_cost
|
||||||
|
|
||||||
class BaseBreedNode(Node):
|
class BaseBreedNode(Node):
|
||||||
"""Breed node
|
"""Breed node
|
||||||
This serves to prevent duplicate breeds, by storing only the needed info
|
This serves to prevent duplicate breeds, by storing only the needed info
|
||||||
|
@ -955,8 +1189,8 @@ class BaseBreedNode(Node):
|
||||||
continue
|
continue
|
||||||
if len(bred_moves) < 4:
|
if len(bred_moves) < 4:
|
||||||
for move, methods in moves.items():
|
for move, methods in moves.items():
|
||||||
if 'light-ball-pichu' in methods:
|
if 'light-ball-egg' in methods:
|
||||||
bred_moves.add(move)
|
bred_moves = bred_moves.union([move])
|
||||||
cost = search.costs['per-hatch-counter'] * search.hatch_counters[baby]
|
cost = search.costs['per-hatch-counter'] * search.hatch_counters[baby]
|
||||||
yield 0, BreedAction(self.search, baby, bred_moves), PokemonNode(
|
yield 0, BreedAction(self.search, baby, bred_moves), PokemonNode(
|
||||||
search=self.search, pokemon_=baby, level=hatch_level,
|
search=self.search, pokemon_=baby, level=hatch_level,
|
||||||
|
@ -995,11 +1229,30 @@ class GoalNode(PokemonNode):
|
||||||
|
|
||||||
def is_goal(self, g):
|
def is_goal(self, g):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
###
|
###
|
||||||
### CLI interface
|
### CLI interface
|
||||||
###
|
###
|
||||||
|
|
||||||
|
def print_result(result, moves=()):
|
||||||
|
template = u"{cost:4} {est:4} {action:45.45}{long:1} {pokemon:10}{level:>3}{nl:1}{versions:2} {moves}"
|
||||||
|
print template.format(cost='Cost', est='Est.', action='Action', pokemon='Pokemon',
|
||||||
|
long='', level='Lv.', nl='V', versions='er',
|
||||||
|
moves=''.join(m.name[0].lower() for m in moves))
|
||||||
|
for cost, action, node in reversed(list(result)):
|
||||||
|
if action:
|
||||||
|
print template.format(
|
||||||
|
cost=cost,
|
||||||
|
action=action,
|
||||||
|
long='>' if (len(unicode(action)) > 45) else '',
|
||||||
|
est=node.estimate(None),
|
||||||
|
pokemon=node.pokemon.name,
|
||||||
|
nl='.' if node.new_level else ' ',
|
||||||
|
level=node.level,
|
||||||
|
versions=''.join(v.name[0] for v in node.versions),
|
||||||
|
moves=''.join('.' if m in node.moves else ' ' for m in moves) +
|
||||||
|
''.join(m.name[0].lower() for m in node.moves if m not in moves),
|
||||||
|
)
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
parser = argparse.ArgumentParser(description=
|
parser = argparse.ArgumentParser(description=
|
||||||
'Find out if the specified moveset is valid, and provide a suggestion '
|
'Find out if the specified moveset is valid, and provide a suggestion '
|
||||||
|
@ -1039,7 +1292,7 @@ def main(argv):
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print 'Connecting'
|
print 'Connecting'
|
||||||
|
|
||||||
session = connect(engine_args={'echo': args.debug > 1})
|
session = connect(engine_args={'echo': args.debug > 2})
|
||||||
|
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print 'Parsing arguments'
|
print 'Parsing arguments'
|
||||||
|
@ -1068,37 +1321,20 @@ def main(argv):
|
||||||
try:
|
try:
|
||||||
search = MovesetSearch(session, pokemon, version, moves, args.level,
|
search = MovesetSearch(session, pokemon, version, moves, args.level,
|
||||||
exclude_versions=excl_versions, exclude_pokemon=excl_pokemon,
|
exclude_versions=excl_versions, exclude_pokemon=excl_pokemon,
|
||||||
debug=args.debug)
|
debug_level=args.debug)
|
||||||
except IllegalMoveCombination, e:
|
except IllegalMoveCombination, e:
|
||||||
print 'Error:', e
|
print 'Error:', e
|
||||||
else:
|
else:
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print 'Setup done'
|
print 'Setup done'
|
||||||
|
|
||||||
template = u"{cost:4} {action:50.50}{long:1} {pokemon:10}{level:>3}{nl:1}{versions:2} {moves}"
|
|
||||||
for result in search:
|
for result in search:
|
||||||
print '-' * 79
|
if args.debug and search.output_objects:
|
||||||
if no_results:
|
|
||||||
if search.output_objects:
|
|
||||||
print '**warning: search looked up output objects**'
|
print '**warning: search looked up output objects**'
|
||||||
no_results = False
|
no_results = False
|
||||||
print template.format(cost='Cost', action='Action', pokemon='Pokemon',
|
print '-' * 79
|
||||||
long='',level='Lv.', nl='V', versions='er',
|
print_result(result, moves=moves)
|
||||||
moves=''.join(m.name[0].lower() for m in moves))
|
# XXX: Support more than one result
|
||||||
for cost, action, node in reversed(list(result)):
|
|
||||||
if action:
|
|
||||||
print template.format(
|
|
||||||
cost=cost,
|
|
||||||
action=action,
|
|
||||||
long='>' if len(str(action)) > 50 else '',
|
|
||||||
pokemon=node.pokemon.name,
|
|
||||||
nl='.' if node.new_level else ' ',
|
|
||||||
level=node.level,
|
|
||||||
versions=''.join(v.name[0] for v in node.versions),
|
|
||||||
moves=''.join('.' if m in node.moves else ' ' for m in moves) +
|
|
||||||
''.join(m.name[0].lower() for m in node.moves if m not in moves),
|
|
||||||
)
|
|
||||||
# XXX: Support more results
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if args.debug:
|
if args.debug:
|
||||||
|
|
Loading…
Reference in a new issue