veekun_pokedex/pokedex/util/astar.py

300 lines
10 KiB
Python
Raw Normal View History

2011-04-25 22:58:38 +00:00
"""A pure-Python implementation of the A* search algorithm
"""
import heapq
class Node(object):
    """Node for the A* search algorithm.

    To get started, implement the `expand` method and call `find_path`
    (or `find_all_paths` when there may be several goals).

    N.B. Node objects must be hashable.
    """

    def expand(self):
        """Return a list of (cost, transition, next_node) for next states

        "Next states" are those reachable from this node.

        May return any finite iterable.
        """
        raise NotImplementedError

    def estimate(self, goal):
        """Return an *optimistic* estimate of the cost to the given goal node.

        If there are multiple goal states, return the lowest estimate among
        all of them.

        The default implementation always returns 0.
        """
        return 0

    def is_goal(self, goal):
        """Return true iff this is a goal node.

        The default implementation compares this node to `goal` with `==`.
        """
        return self == goal

    def find_path(self, goal=None, **kwargs):
        """Return the best path to the goal, or None if no path is found

        The path is an iterator of (cost, transition, node) triples, in
        reverse order (i.e. the first element will have the total cost and
        goal node).

        `goal` will be passed to the `estimate` and `is_goal` methods.

        See a_star for the advanced keyword arguments, `notify` and
        `estimate_error_callback`.
        """
        paths = self.find_all_paths(goal=goal, **kwargs)
        # The next() builtin (rather than the Python-2-only .next() method)
        # works on both Python 2.6+ and Python 3; the default stands in for
        # "the search found no path at all".
        return next(paths, None)

    def find_all_paths(self, goal=None, **kwargs):
        """Yield the best path to each goal

        Returns an iterator of paths. See the `find_path` method for how
        paths look.

        `goal` is handed to the `estimate` and `is_goal` methods of every
        node examined. With the default `is_goal` (plain equality with
        `goal`), only nodes equal to `goal` count as goals; if you wish to
        find paths to several different goals, do not rely on a single `goal`
        and reimplement `is_goal` instead.

        See a_star for the advanced keyword arguments, `notify` and
        `estimate_error_callback`.
        """
        return a_star(
            initial=self,
            expand=lambda s: s.expand(),
            estimate=lambda s: s.estimate(goal),
            is_goal=lambda s: s.is_goal(goal),
            **kwargs)
### The main algorithm
def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
        estimate_error_callback=None):
    """A* search algorithm for a consistent heuristic

    General background: http://en.wikipedia.org/wiki/A*_search_algorithm

    This algorithm will work in large or infinite search spaces.

    This version of the algorithm is modified for multiple possible goals:
    it does not end when it reaches a goal. Rather, it yields the best path
    for each goal.
    (Exhausting the iterator is of course not recommended for large search
    spaces.)

    Returns an iterable of paths, where each path is an iterable of
    (cumulative cost, transition, node) triples representing the path to
    the goal. The transition is the one leading to the corresponding node.
    The path is in reverse order, thus its first element will contain the
    total cost and the goal node.
    The initial node is not included in the returned path (so if the initial
    node itself is a goal, the path yielded for it is empty).

    Arguments:

    `initial`: the initial node

    `expand`: function yielding a (cost of transition, transition, next node)
        triple for each node reachable from its argument.
        The `transition` element is application data; it is not touched, only
        returned as part of the best path.
    `estimate`: function(x) returning optimistic estimate of cost from node x
        to a goal. If not given, 0 will be used for estimates.
    `is_goal`: function(x) returning true iff x is a goal node

    `notify`: If given, it is called at each step with four arguments:
        - current cost (with estimate). The cost to the next goal will not be
            smaller than this.
        - current node
        - open set cardinality: roughly, an estimate of the size of the
            boundary between "explored" and "unexplored" parts of node space
        - debug: stats that may be useful for debugging or tuning (in this
            implementation, this is the open heap size)
        The number of calls to notify or the current cost can be useful as
        stopping criteria; the other values may help in tuning estimators.

    `estimate_error_callback`: function handling cases where an estimate was
        detected not to be optimistic (as A* requires). The function is given
        a path (as would be returned by a_star, except it does not lead to a
        goal node). By default, nothing is done (indeed, an estimate that's
        not strictly optimistic can be useful, esp. if the optimal path is
        not required)
    """
    # g: best cumulative cost (from initial node) found so far
    # h: optimistic estimate of cost to goal
    # f: g + h
    closed = set()  # nodes we don't want to visit again
    est = estimate(initial)  # estimated total cost
    opened = _HeapDict()  # node -> (f, g, h)
    opened[initial] = (est, 0, est)
    # node -> lispy linked list of path steps leading to it:
    # ((g, transition, node), <list for the previous node>); None for initial
    came_from = {initial: None}
    while True:
        try:
            x, (f, g, h) = opened.pop()
        except StopIteration:
            # The open set is exhausted. End the generator explicitly: since
            # PEP 479 a StopIteration escaping a generator body is turned
            # into RuntimeError instead of silently finishing the iteration.
            return
        closed.add(x)

        if notify is not None:
            notify(f, x, len(opened.dict), len(opened.heap))

        if is_goal(x):
            yield _trace_path(came_from[x])

        for cost, transition, y in expand(x):
            if y in closed:
                continue
            tentative_g = g + cost

            # h_y is the neighbour's own estimate; it is kept apart from the
            # current node's h so the two cannot be confused.
            old_f, old_g, h_y = opened.get(y, (None, None, None))

            if old_f is None:
                # First time we see y: its estimate is not cached yet
                h_y = estimate(y)
            elif tentative_g > old_g:
                # The already-known way to y is strictly cheaper
                continue

            came_from[y] = ((tentative_g, transition, y), came_from[x])
            new_f = tentative_g + h_y

            opened[y] = new_f, tentative_g, h_y

            # A successor with a lower f than its parent means the estimate
            # for the parent was too high, i.e. not optimistic.
            if estimate_error_callback is not None and new_f < f:
                estimate_error_callback(_trace_path(came_from[y]))
def _trace_path(cdr):
    """Walk a lisp-style linked list, yielding each element in turn.

    `cdr` is either a false value (end of list) or a `(head, rest)` pair
    where `rest` has the same shape. For an A* result this produces the path
    steps starting from the goal end.
    """
    remaining = cdr
    while True:
        if not remaining:
            return
        step, remaining = remaining
        yield step
class _HeapDict(object):
    """A custom parallel heap/dict structure -- the best of both worlds.

    This is NOT a general-purpose class; it only supports what a_star needs.
    """
    # The dict has the definitive contents.
    # The heap has (value, insertion number, key) triples. It may have some
    # extra (stale) elements: entries whose key was since deleted from the
    # dict or re-assigned a different value object.
    # The insertion number breaks ties between equal values, so keys
    # themselves are never compared (they need not be orderable).

    def __init__(self):
        self.dict = {}
        self.heap = []
        self._pushes = 0  # number of heap pushes so far; used as tie-breaker

    def __setitem__(self, key, value):
        """Set (or replace) the value for key.

        Any older heap entry for the same key is left in place and skipped
        lazily by `pop`.
        """
        self.dict[key] = value
        heapq.heappush(self.heap, (value, self._pushes, key))
        self._pushes += 1

    def __delitem__(self, key):
        """Remove key; its heap entries become stale and are ignored."""
        del self.dict[key]

    def get(self, key, default):
        """Return value for key, or default if not found
        """
        return self.dict.get(key, default)

    def pop(self):
        """Return (key, value) with the smallest value.

        Among entries with equal values, the one inserted first is returned.

        Raise StopIteration (!!) if empty
        """
        while True:
            try:
                value, _, key = heapq.heappop(self.heap)
            except IndexError:
                # nothing more to pop
                raise StopIteration
            try:
                current = self.dict[key]
            except KeyError:
                # deleted from dict = not here
                continue
            # Identity, not equality: only the exact object that was last
            # stored for this key marks a live heap entry.
            if value is current:
                del self.dict[key]
                return key, value
### Example/test
def test_example_knights():
    """Test/example: the "knights" problem

    Definition and another solution may be found at:
    http://brandon.sternefamily.net/posts/2005/02/a-star-algorithm-in-python/

    Progress and the solution are printed to standard output; the final
    assertions check that the goal placement was reached in 40 moves.
    """
    # Legal moves
    moves = {
        1: [4, 7],
        2: [8, 10],
        3: [9],
        4: [1, 6, 10],
        5: [7],
        6: [4],
        7: [1, 5],
        8: [2, 9],
        9: [8, 3],
        10: [2, 4],
    }

    class Positions(dict, Node):
        """Node class representing positions as a dictionary.

        Keys are unique piece names, values are (color, position) where color
        is True for white, False for black.
        """
        def expand(self):
            # Occupied squares do not change while this node is expanded
            # (successors are copies), so compute them once.
            occupied = set(position for color, position in self.values())
            for piece, (color, position) in self.items():
                for new_position in moves[position]:
                    if new_position not in occupied:
                        new_node = Positions(self)
                        new_node[piece] = (color, new_position)
                        yield 1, None, new_node

        def estimate(self, goal):
            # Number of misplaced figures
            misplaced = 0
            for piece, (color, position) in self.items():
                if (color, position) not in goal.values():
                    misplaced += 1
            return misplaced

        def is_goal(self, goal):
            return self.estimate(goal) == 0

        def __hash__(self):
            return hash(tuple(sorted(self.items())))

    initial = Positions({
        'White 1': (True, 1),
        'white 2': (True, 6),
        'Black 1': (False, 5),
        'black 2': (False, 7),
    })

    # Goal: colors should be switched
    goal = Positions((piece, (not color, position))
            for piece, (color, position) in initial.items())

    def format_board(positions, linebreak='\n', extra=''):
        """Return the board as text: one character per square, first letter
        of the piece name or '_' for an empty square.
        """
        board = dict((position, piece)
                for piece, (color, position) in positions.items())
        tokens = []
        for i in range(1, 11):
            # line breaks
            if i in (2, 6, 9):
                tokens.append(linebreak)
            tokens.append(board.get(i, '_')[0])
        tokens.append(extra)
        # Separate tokens by a space, except right after a newline; this is
        # the layout the Python 2 `print a, b,` statements used to produce.
        pieces = []
        for token in tokens:
            if pieces and not pieces[-1].endswith('\n'):
                pieces.append(' ')
            pieces.append(token)
        return ''.join(pieces)

    def print_board(positions, linebreak='\n', extra=''):
        # Single-argument print(...) behaves the same on Python 2 and 3
        print(format_board(positions, linebreak, extra))

    def notify(cost, state, b, c):
        summary = '(%s; %s; %s)' % (state.estimate(goal), b, c)
        print('Looking at state with cost %s: %s' % (
            cost, format_board(state, '|', summary)))

    # Node has no `search` method; `find_path` returns the best path (an
    # iterator in reverse order) or None when there is none.
    found = initial.find_path(goal, notify=notify)
    assert found is not None
    solution_path = list(found)

    print('Step %s' % 0)
    print_board(initial)
    for i, (cost, transition, positions) in enumerate(reversed(solution_path)):
        print('Step %s' % (i + 1))
        print_board(positions)

    # Check solution is correct
    cost, transition, positions = solution_path[0]
    assert set(positions.values()) == set(goal.values())
    assert cost == 40