Merge A* in. General improvements.

This commit is contained in:
Petr Viktorin 2011-04-27 04:13:54 +03:00
parent 88b9a216c7
commit b2e54ca1ea
2 changed files with 284 additions and 347 deletions

View file

@ -1,299 +0,0 @@
"""A pure-Python implementation of the A* search algorithm
"""
import heapq
class Node(object):
"""Node for the A* search algorithm.
To get started, implement the `expand` method and call `search`.
N.B. Node object must be hashable.
"""
def expand(self):
"""Return a list of (costs, transition, next_node) for next states
"Next states" are those reachable from this node.
May return any finite iterable.
"""
raise NotImplementedError
def estimate(self, goal):
"""Return an *optimistic* estimate of the cost to the given goal node.
If there are multiple goal states, return the lowest estimate among all
of them.
"""
return 0
def is_goal(self, goal):
"""Return true iff this is a goal node.
"""
return self == goal
def find_path(self, goal=None, **kwargs):
"""Return the best path to the goal
Returns an iterator of (cost, transition, node) triples, in reverse
order (i.e. the first element will have the total cost and goal node).
If `goal` will be passed to the `estimate` and `is_goal` methods.
See a_star for the advanced keyword arguments, `notify` and
`estimate_error_callback`.
"""
paths = self.find_all_paths(goal=goal, **kwargs)
try:
return paths.next()
except StopIteration:
return None
def find_all_paths(self, goal=None, **kwargs):
"""Yield the best path to each goal
Returns an iterator of paths. See the `search` method for how paths
look.
Giving the `goal` argument will cause it to search for that goal,
instead of consulting the `is_goal` method.
This means that if you wish to find more than one path, you must not
pass a `goal` to this method, and instead reimplament `is_goal`.
See a_star for the advanced keyword arguments, `notify` and
`estimate_error_callback`.
"""
return a_star(
initial=self,
expand=lambda s: s.expand(),
estimate=lambda s: s.estimate(goal),
is_goal=lambda s: s.is_goal(goal),
**kwargs)
### The main algorithm
def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
estimate_error_callback=None):
"""A* search algorithm for a consistent heuristic
General background: http://en.wikipedia.org/wiki/A*_search_algorithm
This algorithm will work in large or infinite search spaces.
This version of the algorithm is modified for multiple possible goals:
it does not end when it reaches a goal. Rather, it yields the best path
for each goal.
(Exhausting the iterator is of course not recommended for large search
spaces.)
Returns an iterable of paths, where each path is an iterable of
(cummulative cost, transition, node) triples representing the path to
the goal. The transition is the one leading to the corresponding node.
The path is in reverse order, thus its first element will contain the
total cost and the goal node.
The initial node is not included in the returned path.
Arguments:
`initial`: the initial node
`expand`: function yielding a (cost of transition, transition, next node)
triple for each node reachable from its argument.
The `transition` element is application data; it is not touched, only
returned as part of the best path.
`estimate`: function(x) returning optimistic estimate of cost from node x
to a goal. If not given, 0 will be used for estimates.
`is_goal`: function(x) returning true iff x is a goal node
`notify`: If given, if is called at each step with three arguments:
- current cost (with estimate). The cost to the next goal will not be
smaller than this.
- current node
- open set cardinality: roughly, an estimate of the size of the
boundary between "explored" and "unexplored" parts of node space
- debug: stats that be useful for debugging or tuning (in this
implementation, this is the open heap size)
The number of calls to notify or the current cost can be useful as
stopping criteria; the other values may help in tuning estimators.
`estimate_error_callback`: function handling cases where an estimate was
detected not to be optimistic (as A* requires). The function is given a
path (as would be returned by a_star, except it does not lead to a goal
node). By default, nothing is done (indeed, an estimate that's not
strictly optimistic can be useful, esp. if the optimal path is not
required)
"""
# g: best cummulative cost (from initial node) found so far
# h: optimistic estimate of cost to goal
# f: g + h
closed = set() # nodes we don't want to visit again
est = estimate(initial) # estimate total cost
opened = _HeapDict() # node -> (f, g, h)
opened[initial] = (est, 0, est)
came_from = {initial: None} # node -> (prev_node, came_from[prev_node])
while True: # _HeapDict will raise StopIteration for us
x, (f, g, h) = opened.pop()
closed.add(x)
if notify is not None:
notify(f, x, len(opened.dict), len(opened.heap))
if is_goal(x):
yield _trace_path(came_from[x])
for cost, transition, y in expand(x):
if y in closed:
continue
tentative_g = g + cost
old_f, old_g, h = opened.get(y, (None, None, None))
if old_f is None:
h = estimate(y)
elif tentative_g > old_g:
continue
came_from[y] = ((tentative_g, transition, y), came_from[x])
new_f = tentative_g + h
opened[y] = new_f, tentative_g, h
if estimate_error_callback is not None and new_f < f:
estimate_error_callback(_trace_path(came_from[y]))
def _trace_path(cdr):
"""Backtrace an A* result"""
# Convert a lispy list to a pythony iterator
while cdr:
car, cdr = cdr
yield car
class _HeapDict(object):
"""A custom parallel heap/dict structure -- the best of both worlds.
This is NOT a general-purpose class; it only supports what a_star needs.
"""
# The dict has the definitive contents
# The heap has (value, key) pairs. It may have some extra elements.
def __init__(self):
self.dict = {}
self.heap = []
def __setitem__(self, key, value):
self.dict[key] = value
heapq.heappush(self.heap, (value, key))
def __delitem__(self, key):
del self.dict[key]
def get(self, key, default):
"""Return value for key, or default if not found
"""
return self.dict.get(key, default)
def pop(self):
"""Return (key, value) with the smallest value.
Raise StopIteration (!!) if empty
"""
while True:
try:
value, key = heapq.heappop(self.heap)
if value is self.dict[key]:
del self.dict[key]
return key, value
except KeyError:
# deleted from dict = not here
pass
except IndexError:
# nothing more to pop
raise StopIteration
### Example/test
def test_example_knights():
"""Test/example: the "knights" problem
Definition and another solution may be found at:
http://brandon.sternefamily.net/posts/2005/02/a-star-algorithm-in-python/
"""
# Legal moves
moves = { 1: [4, 7],
2: [8, 10],
3: [9],
4: [1, 6, 10],
5: [7],
6: [4],
7: [1, 5],
8: [2, 9],
9: [8, 3],
10: [2, 4] }
class Positions(dict, Node):
"""Node class representing positions as a dictionary.
Keys are unique piece names, values are (color, position) where color
is True for white, False for black.
"""
def expand(self):
for piece, (color, position) in self.items():
for new_position in moves[position]:
if new_position not in (p for c, p in self.values()):
new_node = Positions(self)
new_node.update({piece: (color, new_position)})
yield 1, None, new_node
def estimate(self, goal):
# Number of misplaced figures
misplaced = 0
for piece, (color, position) in self.items():
if (color, position) not in goal.values():
misplaced += 1
return misplaced
def is_goal(self, goal):
return self.estimate(goal) == 0
def __hash__(self):
return hash(tuple(sorted(self.items())))
initial = Positions({
'White 1': (True, 1),
'white 2': (True, 6),
'Black 1': (False, 5),
'black 2': (False, 7),
})
# Goal: colors should be switched
goal = Positions((piece, (not color, position))
for piece, (color, position) in initial.items())
def print_board(positions, linebreak='\n', extra=''):
board = dict((position, piece)
for piece, (color, position) in positions.items())
for i in range(1, 11):
# line breaks
if i in (2, 6, 9):
print linebreak,
print board.get(i, '_')[0],
print extra
def notify(cost, state, b, c):
print 'Looking at state with cost %s:' % cost,
print_board(state, '|', '(%s; %s; %s)' % (state.estimate(goal), b, c))
solution_path = list(initial.search(goal, notify=notify))
print 'Step', 0
print_board(initial)
for i, (cost, transition, positions) in enumerate(reversed(solution_path)):
print 'Step', i + 1
print_board(positions)
# Check solution is correct
cost, transition, positions = solution_path[0]
assert set(positions.values()) == set(goal.values())
assert cost == 40

View file

@ -4,6 +4,7 @@
import sys import sys
import argparse import argparse
import itertools import itertools
import heapq
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
from sqlalchemy.orm import aliased from sqlalchemy.orm import aliased
@ -12,9 +13,6 @@ from sqlalchemy.sql.expression import not_, and_, or_
from pokedex.db import connect, tables, util from pokedex.db import connect, tables, util
from pokedex.util import querytimer
from pokedex.util.astar import a_star, Node
### ###
### Illegal Moveset exceptions ### Illegal Moveset exceptions
### ###
@ -44,7 +42,7 @@ def powerset(iterable):
class MovesetSearch(object): class MovesetSearch(object):
def __init__(self, session, pokemon, version, moves, level=100, costs=None, def __init__(self, session, pokemon, version, moves, level=100, costs=None,
exclude_versions=(), exclude_pokemon=(), debug=False): exclude_versions=(), exclude_pokemon=(), debug_level=False):
self.generator = None self.generator = None
@ -53,7 +51,7 @@ class MovesetSearch(object):
elif len(moves) > 4: elif len(moves) > 4:
raise NoMoves('Too many moves specified.') raise NoMoves('Too many moves specified.')
self.debug = debug self.debug_level = debug_level
self.session = session self.session = session
@ -78,7 +76,7 @@ class MovesetSearch(object):
self.excluded_families = frozenset(p.evolution_chain_id self.excluded_families = frozenset(p.evolution_chain_id
for p in exclude_pokemon) for p in exclude_pokemon)
if debug: if debug_level > 1:
print 'Specified moves:', [move.id for move in moves] print 'Specified moves:', [move.id for move in moves]
self.goal_pokemon = pokemon.id self.goal_pokemon = pokemon.id
@ -124,9 +122,10 @@ class MovesetSearch(object):
self.output_objects = dict() self.output_objects = dict()
kwargs = dict() kwargs = dict()
if debug: if debug_level:
self._astar_debug_notify_counter = 0 self._astar_debug_notify_counter = 0
kwargs['notify'] = self.astar_debug_notify kwargs['notify'] = self.astar_debug_notify
kwargs['estimate_error_callback'] = self.astar_estimate_error
self.generator = InitialNode(self).find_all_paths(**kwargs) self.generator = InitialNode(self).find_all_paths(**kwargs)
def load_version_groups(self, version, excluded): def load_version_groups(self, version, excluded):
@ -151,7 +150,7 @@ class MovesetSearch(object):
filtered_map[version] = ( filtered_map[version] = (
self.generation_id_by_version_group[version]) self.generation_id_by_version_group[version])
self.generation_id_by_version_group = filtered_map self.generation_id_by_version_group = filtered_map
if self.debug: if self.debug_level > 1:
print 'Excluded version groups:', excluded print 'Excluded version groups:', excluded
print 'Trade cost table:' print 'Trade cost table:'
print '%03s' % '', print '%03s' % '',
@ -178,7 +177,7 @@ class MovesetSearch(object):
non_egg_moves is a set of moves that don't require breeding non_egg_moves is a set of moves that don't require breeding
Otherwise, these are empty sets. Otherwise, these are empty sets.
""" """
if self.debug: if self.debug_level > 1:
print 'Loading pokemon moves, %s %s' % (evolution_chain, selection) print 'Loading pokemon moves, %s %s' % (evolution_chain, selection)
query = self.session.query( query = self.session.query(
tables.PokemonMove.pokemon_id, tables.PokemonMove.pokemon_id,
@ -252,10 +251,10 @@ class MovesetSearch(object):
continue continue
cost = 1 cost = 1
self.pokemon_moves[pokemon][vg][move][method].append((level, cost)) self.pokemon_moves[pokemon][vg][move][method].append((level, cost))
if self.debug and selection == 'family': if self.debug_level > 1 and selection == 'family':
print 'Easy moves:', sorted(easy_moves) print 'Easy moves:', sorted(easy_moves)
print 'Non-egg moves:', sorted(non_egg_moves) print 'Non-egg moves:', sorted(non_egg_moves)
if self.debug: if self.debug_level > 1:
print 'Smeargle families:', sorted(self.smeargle_families) print 'Smeargle families:', sorted(self.smeargle_families)
return easy_moves, non_egg_moves return easy_moves, non_egg_moves
@ -364,7 +363,7 @@ class MovesetSearch(object):
if move: if move:
self.evolution_moves[self.evolution_chains[child]] = move self.evolution_moves[self.evolution_chains[child]] = move
if self.debug: if self.debug_level > 1:
print 'Loaded %s pokemon: %s evo; %s families: %s breedable' % ( print 'Loaded %s pokemon: %s evo; %s families: %s breedable' % (
len(self.evolution_chains), len(self.evolution_chains),
len(self.pokemon_by_evolution_chain), len(self.pokemon_by_evolution_chain),
@ -396,7 +395,7 @@ class MovesetSearch(object):
) )
self.move_generations = dict(query) self.move_generations = dict(query)
if self.debug: if self.debug_level > 1:
print 'Loaded %s moves' % len(self.move_generations) print 'Loaded %s moves' % len(self.move_generations)
def construct_breed_graph(self): def construct_breed_graph(self):
@ -443,7 +442,7 @@ class MovesetSearch(object):
if len(groups) >= 2: if len(groups) >= 2:
eg2_movepools[groups].update(pool) eg2_movepools[groups].update(pool)
if self.debug: if self.debug_level > 1:
print 'Egg group summary:' print 'Egg group summary:'
for group in sorted(all_groups): for group in sorted(all_groups):
print "%2s can pass: %s" % (group, sorted(eg1_movepools[group])) print "%2s can pass: %s" % (group, sorted(eg1_movepools[group]))
@ -496,7 +495,7 @@ class MovesetSearch(object):
breeds_required[group][frozenset(moves)] = 1 breeds_required[group][frozenset(moves)] = 1
self.breeds_required = breeds_required self.breeds_required = breeds_required
if self.debug: if self.debug_level > 1:
for group, movesetlist in breeds_required.items(): for group, movesetlist in breeds_required.items():
print 'From egg group', group print 'From egg group', group
for moveset, cost in movesetlist.items(): for moveset, cost in movesetlist.items():
@ -530,7 +529,7 @@ class MovesetSearch(object):
last_moves = moves last_moves = moves
last_gen = gen last_gen = gen
if self.debug: if self.debug_level > 1:
print 'Deduplicated %s version groups' % counter print 'Deduplicated %s version groups' % counter
def astar_debug_notify(self, cost, node, setsize, heapsize): def astar_debug_notify(self, cost, node, setsize, heapsize):
@ -538,8 +537,13 @@ class MovesetSearch(object):
if counter % 100 == 0: if counter % 100 == 0:
print 'A* iteration %s, cost %s; remaining: %s (%s) \r' % ( print 'A* iteration %s, cost %s; remaining: %s (%s) \r' % (
counter, cost, setsize, heapsize), counter, cost, setsize, heapsize),
sys.stdout.flush()
self._astar_debug_notify_counter += 1 self._astar_debug_notify_counter += 1
def astar_estimate_error(self, result):
print '**warning: bad A* estimate**'
print_result(result)
def __iter__(self): def __iter__(self):
return self.generator return self.generator
@ -585,7 +589,7 @@ default_costs = {
# For technical reasons, 'sketch' is also used for learning Sketch and # For technical reasons, 'sketch' is also used for learning Sketch and
# by normal means, if it isn't included in the target moveset. # by normal means, if it isn't included in the target moveset.
# So the actual cost of a sketched move will be double this number. # So the actual cost of a sketched move will be double this number.
'sketch': 100, # Cheap. Exclude Smeargle if you think it's too cheap. 'sketch': 1, # Cheap. Exclude Smeargle if you think it's too cheap.
# Gimmick moves we need to use this method to learn the move anyway, # Gimmick moves we need to use this method to learn the move anyway,
# so make a big-ish dent in the score if missing # so make a big-ish dent in the score if missing
@ -604,7 +608,7 @@ default_costs = {
'evolution-delayed': 50, # *in addition* to evolution. Who wants to mash B on every level. 'evolution-delayed': 50, # *in addition* to evolution. Who wants to mash B on every level.
'breed': 400, # Breeding's a pain. 'breed': 400, # Breeding's a pain.
'trade': 200, # Trading's a pain, but not as much as breeding. 'trade': 200, # Trading's a pain, but not as much as breeding.
'transfer': 200, # *in addition* to trade. For one-way cross-generation transfers 'transfer': 150, # *in addition* to trade. Keep it below 'trade'.
'forget': 300, # Deleting a move. (Not needed unless deleting an evolution move.) 'forget': 300, # Deleting a move. (Not needed unless deleting an evolution move.)
'relearn': 150, # Also a pain, though not as big as breeding. 'relearn': 150, # Also a pain, though not as big as breeding.
'per-level': 1, # Prefer less grinding. This is for all lv-ups but the final “grow” 'per-level': 1, # Prefer less grinding. This is for all lv-ups but the final “grow”
@ -620,6 +624,215 @@ default_costs = {
'breed-penalty': 100, 'breed-penalty': 100,
} }
###
### A*
###
class Node(object):
"""Node for the A* search algorithm.
To get started, implement the `expand` method and call `search`.
N.B. Node objects must be hashable.
"""
def expand(self):
"""Return a list of (costs, transition, next_node) for next states
"Next states" are those reachable from this node.
May return any finite iterable.
"""
raise NotImplementedError
def estimate(self, goal):
"""Return an *optimistic* estimate of the cost to the given goal node.
If there are multiple goal states, return the lowest estimate among all
of them.
"""
return 0
def is_goal(self, goal):
"""Return true iff this is a goal node.
"""
return self == goal
def find_path(self, goal=None, **kwargs):
"""Return the best path to the goal
Returns an iterator of (cost, transition, node) triples, in reverse
order (i.e. the first element will have the total cost and goal node).
If `goal` will be passed to the `estimate` and `is_goal` methods.
See a_star for the advanced keyword arguments, `notify` and
`estimate_error_callback`.
"""
paths = self.find_all_paths(goal=goal, **kwargs)
try:
return paths.next()
except StopIteration:
return None
def find_all_paths(self, goal=None, **kwargs):
"""Yield the best path to each goal
Returns an iterator of paths. See the `search` method for how paths
look.
Giving the `goal` argument will cause it to search for that goal,
instead of consulting the `is_goal` method.
This means that if you wish to find more than one path, you must not
pass a `goal` to this method, and instead reimplament `is_goal`.
See a_star for the advanced keyword arguments, `notify` and
`estimate_error_callback`.
"""
return a_star(
initial=self,
expand=lambda s: s.expand(),
estimate=lambda s: s.estimate(goal),
is_goal=lambda s: s.is_goal(goal),
**kwargs)
def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
estimate_error_callback=None):
"""A* search algorithm for a consistent heuristic
General background: http://en.wikipedia.org/wiki/A*_search_algorithm
This algorithm will work in large or infinite search spaces.
This version of the algorithm is modified for multiple possible goals:
it does not end when it reaches a goal. Rather, it yields the best path
for each goal.
(Exhausting the iterator is of course not recommended for large search
spaces.)
Returns an iterable of paths, where each path is an iterable of
(cummulative cost, transition, node) triples representing the path to
the goal. The transition is the one leading to the corresponding node.
The path is in reverse order, thus its first element will contain the
total cost and the goal node.
The initial node is not included in the returned path.
Arguments:
`initial`: the initial node
`expand`: function yielding a (cost of transition, transition, next node)
triple for each node reachable from its argument.
The `transition` element is application data; it is not touched, only
returned as part of the best path.
`estimate`: function(x) returning optimistic estimate of cost from node x
to a goal. If not given, 0 will be used for estimates.
`is_goal`: function(x) returning true iff x is a goal node
`notify`: If given, if is called at each step with three arguments:
- current cost (with estimate). The cost to the next goal will not be
smaller than this.
- current node
- open set cardinality: roughly, an estimate of the size of the
boundary between "explored" and "unexplored" parts of node space
- debug: stats that be useful for debugging or tuning (in this
implementation, this is the open heap size)
The number of calls to notify or the current cost can be useful as
stopping criteria; the other values may help in tuning estimators.
`estimate_error_callback`: function handling cases where an estimate was
detected not to be optimistic (as A* requires). The function is given a
path (as would be returned by a_star, except it does not lead to a goal
node). By default, nothing is done (indeed, an estimate that's not
strictly optimistic can be useful, esp. if the optimal path is not
required)
"""
# g: best cummulative cost (from initial node) found so far
# h: optimistic estimate of cost to goal
# f: g + h
closed = set() # nodes we don't want to visit again
est = estimate(initial) # estimate total cost
opened = _HeapDict() # node -> (f, g, h)
opened[initial] = (est, 0, est)
came_from = {initial: None} # node -> (prev_node, came_from[prev_node])
while True: # _HeapDict will raise StopIteration for us
x, (f, g, h) = opened.pop()
closed.add(x)
if notify is not None:
notify(f, x, len(opened.dict), len(opened.heap))
if is_goal(x):
yield _trace_path(came_from[x])
for cost, transition, y in expand(x):
if y in closed:
continue
tentative_g = g + cost
old_f, old_g, h = opened.get(y, (None, None, None))
if old_f is None:
h = estimate(y)
elif tentative_g > old_g:
continue
came_from[y] = ((tentative_g, transition, y), came_from[x])
new_f = tentative_g + h
opened[y] = new_f, tentative_g, h
if estimate_error_callback is not None and new_f < f:
estimate_error_callback(_trace_path(came_from[y]))
def _trace_path(cdr):
"""Backtrace an A* result"""
# Convert a lispy list to a pythony iterator
while cdr:
car, cdr = cdr
yield car
class _HeapDict(object):
"""A custom parallel heap/dict structure -- the best of both worlds.
This is NOT a general-purpose class; it only supports what a_star needs.
"""
# The dict has the definitive contents
# The heap has (value, key) pairs. It may have some extra elements.
def __init__(self):
self.dict = {}
self.heap = []
def __setitem__(self, key, value):
self.dict[key] = value
heapq.heappush(self.heap, (value, key))
def __delitem__(self, key):
del self.dict[key]
def get(self, key, default):
"""Return value for key, or default if not found
"""
return self.dict.get(key, default)
def pop(self):
"""Return (key, value) with the smallest value.
Raise StopIteration (!!) if empty
"""
while True:
try:
value, key = heapq.heappop(self.heap)
if value is self.dict[key]:
del self.dict[key]
return key, value
except KeyError:
# deleted from dict = not here
pass
except IndexError:
# nothing more to pop
raise StopIteration
### ###
### Result objects ### Result objects
### ###
@ -828,8 +1041,9 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
def expand_forget(self): def expand_forget(self):
cost = self.search.costs['forget'] cost = self.search.costs['forget']
for move in self.moves_: for move in self.moves_:
yield cost, ForgetAction(self.search, move), self._replace( if move not in self.search.goal_moves:
moves_=self.moves_.difference([move]), new_level=False) yield cost, ForgetAction(self.search, move), self._replace(
moves_=self.moves_.difference([move]), new_level=False)
def expand_trade(self): def expand_trade(self):
search = self.search search = self.search
@ -923,7 +1137,7 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
moves = self.moves_ moves = self.moves_
for sketch in moves: for sketch in moves:
if sketch == self.search.sketch: if sketch == self.search.sketch:
for sketched in self.search.goal_moves: for sketched in sorted(self.search.goal_moves):
if sketched in self.search.unsketchable: if sketched in self.search.unsketchable:
continue continue
if sketched not in moves: if sketched not in moves:
@ -936,6 +1150,26 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
new_level=False, moves_=frozenset(moves)) new_level=False, moves_=frozenset(moves))
return return
def estimate(self, g):
# Given good estimates, A* finds solutions much faster.
# However, here it seems we either have easy movesets, which
# get found pretty easily by themselves, or hard ones, where
# heuristics don't help too much, or impossible ones where they
# don't matter at all.
# So, keep the computations here to a minimum.
search = self.search
trade_cost = search.trade_cost(self.version_group_,
search.goal_version_group)
if trade_cost is None:
trade_cost = search.costs['trade'] * 2
return trade_cost
evo_chain = search.evolution_chains[self.pokemon_]
if evo_chain == search.goal_evolution_chain:
breed_cost = 0
else:
breed_cost = search.costs['breed']
return trade_cost + breed_cost
class BaseBreedNode(Node): class BaseBreedNode(Node):
"""Breed node """Breed node
This serves to prevent duplicate breeds, by storing only the needed info This serves to prevent duplicate breeds, by storing only the needed info
@ -955,8 +1189,8 @@ class BaseBreedNode(Node):
continue continue
if len(bred_moves) < 4: if len(bred_moves) < 4:
for move, methods in moves.items(): for move, methods in moves.items():
if 'light-ball-pichu' in methods: if 'light-ball-egg' in methods:
bred_moves.add(move) bred_moves = bred_moves.union([move])
cost = search.costs['per-hatch-counter'] * search.hatch_counters[baby] cost = search.costs['per-hatch-counter'] * search.hatch_counters[baby]
yield 0, BreedAction(self.search, baby, bred_moves), PokemonNode( yield 0, BreedAction(self.search, baby, bred_moves), PokemonNode(
search=self.search, pokemon_=baby, level=hatch_level, search=self.search, pokemon_=baby, level=hatch_level,
@ -995,11 +1229,30 @@ class GoalNode(PokemonNode):
def is_goal(self, g): def is_goal(self, g):
return True return True
### ###
### CLI interface ### CLI interface
### ###
def print_result(result, moves=()):
template = u"{cost:4} {est:4} {action:45.45}{long:1} {pokemon:10}{level:>3}{nl:1}{versions:2} {moves}"
print template.format(cost='Cost', est='Est.', action='Action', pokemon='Pokemon',
long='', level='Lv.', nl='V', versions='er',
moves=''.join(m.name[0].lower() for m in moves))
for cost, action, node in reversed(list(result)):
if action:
print template.format(
cost=cost,
action=action,
long='>' if (len(unicode(action)) > 45) else '',
est=node.estimate(None),
pokemon=node.pokemon.name,
nl='.' if node.new_level else ' ',
level=node.level,
versions=''.join(v.name[0] for v in node.versions),
moves=''.join('.' if m in node.moves else ' ' for m in moves) +
''.join(m.name[0].lower() for m in node.moves if m not in moves),
)
def main(argv): def main(argv):
parser = argparse.ArgumentParser(description= parser = argparse.ArgumentParser(description=
'Find out if the specified moveset is valid, and provide a suggestion ' 'Find out if the specified moveset is valid, and provide a suggestion '
@ -1039,7 +1292,7 @@ def main(argv):
if args.debug: if args.debug:
print 'Connecting' print 'Connecting'
session = connect(engine_args={'echo': args.debug > 1}) session = connect(engine_args={'echo': args.debug > 2})
if args.debug: if args.debug:
print 'Parsing arguments' print 'Parsing arguments'
@ -1068,37 +1321,20 @@ def main(argv):
try: try:
search = MovesetSearch(session, pokemon, version, moves, args.level, search = MovesetSearch(session, pokemon, version, moves, args.level,
exclude_versions=excl_versions, exclude_pokemon=excl_pokemon, exclude_versions=excl_versions, exclude_pokemon=excl_pokemon,
debug=args.debug) debug_level=args.debug)
except IllegalMoveCombination, e: except IllegalMoveCombination, e:
print 'Error:', e print 'Error:', e
else: else:
if args.debug: if args.debug:
print 'Setup done' print 'Setup done'
template = u"{cost:4} {action:50.50}{long:1} {pokemon:10}{level:>3}{nl:1}{versions:2} {moves}"
for result in search: for result in search:
if args.debug and search.output_objects:
print '**warning: search looked up output objects**'
no_results = False
print '-' * 79 print '-' * 79
if no_results: print_result(result, moves=moves)
if search.output_objects: # XXX: Support more than one result
print '**warning: search looked up output objects**'
no_results = False
print template.format(cost='Cost', action='Action', pokemon='Pokemon',
long='',level='Lv.', nl='V', versions='er',
moves=''.join(m.name[0].lower() for m in moves))
for cost, action, node in reversed(list(result)):
if action:
print template.format(
cost=cost,
action=action,
long='>' if len(str(action)) > 50 else '',
pokemon=node.pokemon.name,
nl='.' if node.new_level else ' ',
level=node.level,
versions=''.join(v.name[0] for v in node.versions),
moves=''.join('.' if m in node.moves else ' ' for m in moves) +
''.join(m.name[0].lower() for m in node.moves if m not in moves),
)
# XXX: Support more results
break break
if args.debug: if args.debug: