mirror of
https://github.com/veekun/pokedex.git
synced 2024-08-20 18:16:34 +00:00
300 lines
10 KiB
Python
300 lines
10 KiB
Python
|
"""A pure-Python implementation of the A* search algorithm
|
||
|
"""
|
||
|
|
||
|
import heapq
|
||
|
|
||
|
class Node(object):
    """Node for the A* search algorithm.

    To get started, implement the `expand` method and call `find_path`
    (or `find_all_paths`).

    N.B. Node objects must be hashable.
    """

    def expand(self):
        """Return a list of (cost, transition, next_node) for next states

        "Next states" are those reachable from this node.

        May return any finite iterable.
        """
        raise NotImplementedError

    def estimate(self, goal):
        """Return an *optimistic* estimate of the cost to the given goal node.

        If there are multiple goal states, return the lowest estimate among all
        of them.

        The default implementation returns 0.
        """
        return 0

    def is_goal(self, goal):
        """Return true iff this is a goal node.

        The default implementation compares this node to `goal` for equality.
        """
        return self == goal

    def find_path(self, goal=None, **kwargs):
        """Return the best path to the goal, or None if there is no path

        Returns an iterator of (cost, transition, node) triples, in reverse
        order (i.e. the first element will have the total cost and goal node).

        `goal` will be passed to the `estimate` and `is_goal` methods.

        See a_star for the advanced keyword arguments, `notify` and
        `estimate_error_callback`.
        """
        paths = self.find_all_paths(goal=goal, **kwargs)
        # The builtin next() works on both Python 2.6+ and Python 3, unlike
        # the Python 2-only `.next()` method; the default covers "no path".
        return next(iter(paths), None)

    def find_all_paths(self, goal=None, **kwargs):
        """Yield the best path to each goal

        Returns an iterator of paths. See the `find_path` method for how
        paths look.

        `goal` is passed to the `estimate` and `is_goal` methods of the
        visited nodes. The default `is_goal` only accepts nodes equal to
        `goal`, so if you wish to find more than one path, reimplement
        `is_goal`.

        See a_star for the advanced keyword arguments, `notify` and
        `estimate_error_callback`.
        """
        return a_star(
            initial=self,
            expand=lambda s: s.expand(),
            estimate=lambda s: s.estimate(goal),
            is_goal=lambda s: s.is_goal(goal),
            **kwargs)
|
||
|
|
||
|
### The main algorithm
|
||
|
|
||
|
def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
           estimate_error_callback=None):
    """A* search algorithm for a consistent heuristic

    General background: http://en.wikipedia.org/wiki/A*_search_algorithm

    This algorithm will work in large or infinite search spaces.

    This version of the algorithm is modified for multiple possible goals:
    it does not end when it reaches a goal. Rather, it yields the best path
    for each goal.
    (Exhausting the iterator is of course not recommended for large search
    spaces.)

    Returns an iterable of paths, where each path is an iterable of
    (cumulative cost, transition, node) triples representing the path to
    the goal. The transition is the one leading to the corresponding node.
    The path is in reverse order, thus its first element will contain the
    total cost and the goal node.
    The initial node is not included in the returned path.

    Arguments:

    `initial`: the initial node

    `expand`: function yielding a (cost of transition, transition, next node)
        triple for each node reachable from its argument.
        The `transition` element is application data; it is not touched, only
        returned as part of the best path.
    `estimate`: function(x) returning optimistic estimate of cost from node x
        to a goal. If not given, 0 will be used for estimates.
    `is_goal`: function(x) returning true iff x is a goal node

    `notify`: If given, it is called at each step with four arguments:
        - current cost (with estimate). The cost to the next goal will not be
            smaller than this.
        - current node
        - open set cardinality: roughly, an estimate of the size of the
            boundary between "explored" and "unexplored" parts of node space
        - debug: stats that may be useful for debugging or tuning (in this
            implementation, this is the open heap size)
        The number of calls to notify or the current cost can be useful as
        stopping criteria; the other values may help in tuning estimators.

    `estimate_error_callback`: function handling cases where an estimate was
        detected not to be optimistic (as A* requires). The function is given a
        path (as would be returned by a_star, except it does not lead to a goal
        node). By default, nothing is done (indeed, an estimate that's not
        strictly optimistic can be useful, esp. if the optimal path is not
        required)
    """
    # g: best cumulative cost (from initial node) found so far
    # h: optimistic estimate of cost to goal
    # f: g + h
    closed = set()  # nodes we don't want to visit again
    est = estimate(initial)  # estimate total cost
    opened = _HeapDict()  # node -> (f, g, h)
    opened[initial] = (est, 0, est)
    came_from = {initial: None}  # node -> (prev_node, came_from[prev_node])
    while True:
        try:
            x, (f, g, h) = opened.pop()
        except StopIteration:
            # The open set is exhausted. A StopIteration must not escape the
            # generator body: since PEP 479 it would be turned into a
            # RuntimeError, so end the generator explicitly instead.
            return
        closed.add(x)

        if notify is not None:
            notify(f, x, len(opened.dict), len(opened.heap))

        if is_goal(x):
            yield _trace_path(came_from[x])

        for cost, transition, y in expand(x):
            if y in closed:
                continue
            tentative_g = g + cost

            # y_h is kept separate from the popped node's `h` above
            old_f, old_g, y_h = opened.get(y, (None, None, None))

            if old_f is None:
                # First time we see y: compute its estimate
                y_h = estimate(y)
            elif tentative_g > old_g:
                # The already known route to y is cheaper; keep it
                continue

            came_from[y] = ((tentative_g, transition, y), came_from[x])
            new_f = tentative_g + y_h

            opened[y] = new_f, tentative_g, y_h

            # f decreasing along a path means the estimate was not optimistic
            if estimate_error_callback is not None and new_f < f:
                estimate_error_callback(_trace_path(came_from[y]))
|
||
|
|
||
|
def _trace_path(cdr):
|
||
|
"""Backtrace an A* result"""
|
||
|
# Convert a lispy list to a pythony iterator
|
||
|
while cdr:
|
||
|
car, cdr = cdr
|
||
|
yield car
|
||
|
|
||
|
class _HeapDict(object):
|
||
|
"""A custom parallel heap/dict structure -- the best of both worlds.
|
||
|
|
||
|
This is NOT a general-purpose class; it only supports what a_star needs.
|
||
|
"""
|
||
|
# The dict has the definitive contents
|
||
|
# The heap has (value, key) pairs. It may have some extra elements.
|
||
|
def __init__(self):
|
||
|
self.dict = {}
|
||
|
self.heap = []
|
||
|
|
||
|
def __setitem__(self, key, value):
|
||
|
self.dict[key] = value
|
||
|
heapq.heappush(self.heap, (value, key))
|
||
|
|
||
|
def __delitem__(self, key):
|
||
|
del self.dict[key]
|
||
|
|
||
|
def get(self, key, default):
|
||
|
"""Return value for key, or default if not found
|
||
|
"""
|
||
|
return self.dict.get(key, default)
|
||
|
|
||
|
def pop(self):
|
||
|
"""Return (key, value) with the smallest value.
|
||
|
|
||
|
Raise StopIteration (!!) if empty
|
||
|
"""
|
||
|
while True:
|
||
|
try:
|
||
|
value, key = heapq.heappop(self.heap)
|
||
|
if value is self.dict[key]:
|
||
|
del self.dict[key]
|
||
|
return key, value
|
||
|
except KeyError:
|
||
|
# deleted from dict = not here
|
||
|
pass
|
||
|
except IndexError:
|
||
|
# nothing more to pop
|
||
|
raise StopIteration
|
||
|
|
||
|
|
||
|
### Example/test
|
||
|
|
||
|
|
||
|
def test_example_knights():
    """Test/example: the "knights" problem

    Definition and another solution may be found at:
    http://brandon.sternefamily.net/posts/2005/02/a-star-algorithm-in-python/
    """
    # Output goes through sys.stdout.write rather than print so that this
    # function parses and runs on both Python 2 and Python 3.
    import sys

    # Legal moves
    moves = {
        1: [4, 7],
        2: [8, 10],
        3: [9],
        4: [1, 6, 10],
        5: [7],
        6: [4],
        7: [1, 5],
        8: [2, 9],
        9: [8, 3],
        10: [2, 4],
    }

    class Positions(dict, Node):
        """Node class representing positions as a dictionary.

        Keys are unique piece names, values are (color, position) where color
        is True for white, False for black.
        """
        def expand(self):
            # Squares occupied in this state; a piece may only move to a
            # free square. Computed once, since self is never modified here.
            occupied = set(p for c, p in self.values())
            for piece, (color, position) in self.items():
                for new_position in moves[position]:
                    if new_position not in occupied:
                        new_node = Positions(self)
                        new_node.update({piece: (color, new_position)})
                        yield 1, None, new_node

        def estimate(self, goal):
            # Number of misplaced figures
            misplaced = 0
            for piece, (color, position) in self.items():
                if (color, position) not in goal.values():
                    misplaced += 1
            return misplaced

        def is_goal(self, goal):
            return self.estimate(goal) == 0

        def __hash__(self):
            # dicts are unhashable; hash the sorted items instead
            return hash(tuple(sorted(self.items())))

    initial = Positions({
        'White 1': (True, 1),
        'white 2': (True, 6),
        'Black 1': (False, 5),
        'black 2': (False, 7),
    })

    # Goal: colors should be switched
    goal = Positions((piece, (not color, position))
                     for piece, (color, position) in initial.items())

    def print_board(positions, linebreak='\n', extra=''):
        """Write the board, one character per square, followed by `extra`.

        Squares holding a piece show the first letter of the piece name,
        empty squares show '_'. `linebreak` is put between board rows.
        """
        board = dict((position, piece)
                     for piece, (color, position) in positions.items())
        rows = [[]]
        for i in range(1, 11):
            # line breaks
            if i in (2, 6, 9):
                rows.append([])
            rows[-1].append(board.get(i, '_')[0])
        if linebreak.endswith('\n'):
            separator = linebreak
        else:
            separator = ' %s ' % linebreak
        text = separator.join(' '.join(row) for row in rows)
        if extra:
            text = '%s %s' % (text, extra)
        sys.stdout.write(text + '\n')

    def notify(cost, state, b, c):
        # The board is printed on the same line, right after the cost
        sys.stdout.write('Looking at state with cost %s:' % cost + ' ')
        print_board(state, '|', '(%s; %s; %s)' % (state.estimate(goal), b, c))

    # Node provides `find_path`; it returns None when no path exists
    best_path = initial.find_path(goal, notify=notify)
    assert best_path is not None
    solution_path = list(best_path)

    sys.stdout.write('%s %s\n' % ('Step', 0))
    print_board(initial)
    # The path is in reverse order (goal first), so walk it backwards
    for i, (cost, transition, positions) in enumerate(reversed(solution_path)):
        sys.stdout.write('%s %s\n' % ('Step', i + 1))
        print_board(positions)

    # Check solution is correct
    cost, transition, positions = solution_path[0]
    assert set(positions.values()) == set(goal.values())
    assert cost == 40
|