diff --git a/pokedex/util/astar.py b/pokedex/util/astar.py deleted file mode 100644 index 0e54080..0000000 --- a/pokedex/util/astar.py +++ /dev/null @@ -1,299 +0,0 @@ -"""A pure-Python implementation of the A* search algorithm -""" - -import heapq - -class Node(object): - """Node for the A* search algorithm. - - To get started, implement the `expand` method and call `search`. - - N.B. Node object must be hashable. - """ - - def expand(self): - """Return a list of (costs, transition, next_node) for next states - - "Next states" are those reachable from this node. - - May return any finite iterable. - """ - raise NotImplementedError - - def estimate(self, goal): - """Return an *optimistic* estimate of the cost to the given goal node. - - If there are multiple goal states, return the lowest estimate among all - of them. - """ - return 0 - - def is_goal(self, goal): - """Return true iff this is a goal node. - """ - return self == goal - - def find_path(self, goal=None, **kwargs): - """Return the best path to the goal - - Returns an iterator of (cost, transition, node) triples, in reverse - order (i.e. the first element will have the total cost and goal node). - - If `goal` will be passed to the `estimate` and `is_goal` methods. - - See a_star for the advanced keyword arguments, `notify` and - `estimate_error_callback`. - """ - paths = self.find_all_paths(goal=goal, **kwargs) - try: - return paths.next() - except StopIteration: - return None - - def find_all_paths(self, goal=None, **kwargs): - """Yield the best path to each goal - - Returns an iterator of paths. See the `search` method for how paths - look. - - Giving the `goal` argument will cause it to search for that goal, - instead of consulting the `is_goal` method. - This means that if you wish to find more than one path, you must not - pass a `goal` to this method, and instead reimplament `is_goal`. - - See a_star for the advanced keyword arguments, `notify` and - `estimate_error_callback`. - """ - return a_star( - initial=self, - expand=lambda s: s.expand(), - estimate=lambda s: s.estimate(goal), - is_goal=lambda s: s.is_goal(goal), - **kwargs) - -### The main algorithm - -def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None, - estimate_error_callback=None): - """A* search algorithm for a consistent heuristic - - General background: http://en.wikipedia.org/wiki/A*_search_algorithm - - This algorithm will work in large or infinite search spaces. - - This version of the algorithm is modified for multiple possible goals: - it does not end when it reaches a goal. Rather, it yields the best path - for each goal. - (Exhausting the iterator is of course not recommended for large search - spaces.) - - Returns an iterable of paths, where each path is an iterable of - (cummulative cost, transition, node) triples representing the path to - the goal. The transition is the one leading to the corresponding node. - The path is in reverse order, thus its first element will contain the - total cost and the goal node. - The initial node is not included in the returned path. - - Arguments: - - `initial`: the initial node - - `expand`: function yielding a (cost of transition, transition, next node) - triple for each node reachable from its argument. - The `transition` element is application data; it is not touched, only - returned as part of the best path. - `estimate`: function(x) returning optimistic estimate of cost from node x - to a goal. If not given, 0 will be used for estimates. - `is_goal`: function(x) returning true iff x is a goal node - - `notify`: If given, if is called at each step with three arguments: - - current cost (with estimate). The cost to the next goal will not be - smaller than this. - - current node - - open set cardinality: roughly, an estimate of the size of the - boundary between "explored" and "unexplored" parts of node space - - debug: stats that be useful for debugging or tuning (in this - implementation, this is the open heap size) - The number of calls to notify or the current cost can be useful as - stopping criteria; the other values may help in tuning estimators. - - `estimate_error_callback`: function handling cases where an estimate was - detected not to be optimistic (as A* requires). The function is given a - path (as would be returned by a_star, except it does not lead to a goal - node). By default, nothing is done (indeed, an estimate that's not - strictly optimistic can be useful, esp. if the optimal path is not - required) - """ - # g: best cummulative cost (from initial node) found so far - # h: optimistic estimate of cost to goal - # f: g + h - closed = set() # nodes we don't want to visit again - est = estimate(initial) # estimate total cost - opened = _HeapDict() # node -> (f, g, h) - opened[initial] = (est, 0, est) - came_from = {initial: None} # node -> (prev_node, came_from[prev_node]) - while True: # _HeapDict will raise StopIteration for us - x, (f, g, h) = opened.pop() - closed.add(x) - - if notify is not None: - notify(f, x, len(opened.dict), len(opened.heap)) - - if is_goal(x): - yield _trace_path(came_from[x]) - - for cost, transition, y in expand(x): - if y in closed: - continue - tentative_g = g + cost - - old_f, old_g, h = opened.get(y, (None, None, None)) - - if old_f is None: - h = estimate(y) - elif tentative_g > old_g: - continue - - came_from[y] = ((tentative_g, transition, y), came_from[x]) - new_f = tentative_g + h - - opened[y] = new_f, tentative_g, h - - if estimate_error_callback is not None and new_f < f: - estimate_error_callback(_trace_path(came_from[y])) - -def _trace_path(cdr): - """Backtrace an A* result""" - # Convert a lispy list to a pythony iterator - while cdr: - car, cdr = cdr - yield car - -class _HeapDict(object): - """A custom parallel heap/dict structure -- the best of both worlds. - - This is NOT a general-purpose class; it only supports what a_star needs. - """ - # The dict has the definitive contents - # The heap has (value, key) pairs. It may have some extra elements. - def __init__(self): - self.dict = {} - self.heap = [] - - def __setitem__(self, key, value): - self.dict[key] = value - heapq.heappush(self.heap, (value, key)) - - def __delitem__(self, key): - del self.dict[key] - - def get(self, key, default): - """Return value for key, or default if not found - """ - return self.dict.get(key, default) - - def pop(self): - """Return (key, value) with the smallest value. - - Raise StopIteration (!!) if empty - """ - while True: - try: - value, key = heapq.heappop(self.heap) - if value is self.dict[key]: - del self.dict[key] - return key, value - except KeyError: - # deleted from dict = not here - pass - except IndexError: - # nothing more to pop - raise StopIteration - - -### Example/test - - -def test_example_knights(): - """Test/example: the "knights" problem - - Definition and another solution may be found at: - http://brandon.sternefamily.net/posts/2005/02/a-star-algorithm-in-python/ - """ - # Legal moves - moves = { 1: [4, 7], - 2: [8, 10], - 3: [9], - 4: [1, 6, 10], - 5: [7], - 6: [4], - 7: [1, 5], - 8: [2, 9], - 9: [8, 3], - 10: [2, 4] } - - class Positions(dict, Node): - """Node class representing positions as a dictionary. - - Keys are unique piece names, values are (color, position) where color - is True for white, False for black. - """ - def expand(self): - for piece, (color, position) in self.items(): - for new_position in moves[position]: - if new_position not in (p for c, p in self.values()): - new_node = Positions(self) - new_node.update({piece: (color, new_position)}) - yield 1, None, new_node - - def estimate(self, goal): - # Number of misplaced figures - misplaced = 0 - for piece, (color, position) in self.items(): - if (color, position) not in goal.values(): - misplaced += 1 - return misplaced - - def is_goal(self, goal): - return self.estimate(goal) == 0 - - def __hash__(self): - return hash(tuple(sorted(self.items()))) - - initial = Positions({ - 'White 1': (True, 1), - 'white 2': (True, 6), - 'Black 1': (False, 5), - 'black 2': (False, 7), - }) - - # Goal: colors should be switched - goal = Positions((piece, (not color, position)) - for piece, (color, position) in initial.items()) - - def print_board(positions, linebreak='\n', extra=''): - board = dict((position, piece) - for piece, (color, position) in positions.items()) - for i in range(1, 11): - # line breaks - if i in (2, 6, 9): - print linebreak, - print board.get(i, '_')[0], - print extra - - def notify(cost, state, b, c): - print 'Looking at state with cost %s:' % cost, - print_board(state, '|', '(%s; %s; %s)' % (state.estimate(goal), b, c)) - - solution_path = list(initial.search(goal, notify=notify)) - - print 'Step', 0 - print_board(initial) - for i, (cost, transition, positions) in enumerate(reversed(solution_path)): - print 'Step', i + 1 - print_board(positions) - - # Check solution is correct - cost, transition, positions = solution_path[0] - assert set(positions.values()) == set(goal.values()) - assert cost == 40 diff --git a/pokedex/util/movesets.py b/pokedex/util/movesets.py index 347c9c1..acc516d 100755 --- a/pokedex/util/movesets.py +++ b/pokedex/util/movesets.py @@ -4,6 +4,7 @@ import sys import argparse import itertools +import heapq from collections import defaultdict, namedtuple from sqlalchemy.orm import aliased @@ -12,9 +13,6 @@ from sqlalchemy.sql.expression import not_, and_, or_ from pokedex.db import connect, tables, util -from pokedex.util import querytimer -from pokedex.util.astar import a_star, Node - ### ### Illegal Moveset exceptions ### @@ -44,7 +42,7 @@ def powerset(iterable): class MovesetSearch(object): def __init__(self, session, pokemon, version, moves, level=100, costs=None, - exclude_versions=(), exclude_pokemon=(), debug=False): + exclude_versions=(), exclude_pokemon=(), debug_level=False): self.generator = None @@ -53,7 +51,7 @@ class MovesetSearch(object): elif len(moves) > 4: raise NoMoves('Too many moves specified.') - self.debug = debug + self.debug_level = debug_level self.session = session @@ -78,7 +76,7 @@ class MovesetSearch(object): self.excluded_families = frozenset(p.evolution_chain_id for p in exclude_pokemon) - if debug: + if debug_level > 1: print 'Specified moves:', [move.id for move in moves] self.goal_pokemon = pokemon.id @@ -124,9 +122,10 @@ class MovesetSearch(object): self.output_objects = dict() kwargs = dict() - if debug: + if debug_level: self._astar_debug_notify_counter = 0 kwargs['notify'] = self.astar_debug_notify + kwargs['estimate_error_callback'] = self.astar_estimate_error self.generator = InitialNode(self).find_all_paths(**kwargs) def load_version_groups(self, version, excluded): @@ -151,7 +150,7 @@ class MovesetSearch(object): filtered_map[version] = ( self.generation_id_by_version_group[version]) self.generation_id_by_version_group = filtered_map - if self.debug: + if self.debug_level > 1: print 'Excluded version groups:', excluded print 'Trade cost table:' print '%03s' % '', @@ -178,7 +177,7 @@ class MovesetSearch(object): non_egg_moves is a set of moves that don't require breeding Otherwise, these are empty sets. """ - if self.debug: + if self.debug_level > 1: print 'Loading pokemon moves, %s %s' % (evolution_chain, selection) query = self.session.query( tables.PokemonMove.pokemon_id, @@ -252,10 +251,10 @@ class MovesetSearch(object): continue cost = 1 self.pokemon_moves[pokemon][vg][move][method].append((level, cost)) - if self.debug and selection == 'family': + if self.debug_level > 1 and selection == 'family': print 'Easy moves:', sorted(easy_moves) print 'Non-egg moves:', sorted(non_egg_moves) - if self.debug: + if self.debug_level > 1: print 'Smeargle families:', sorted(self.smeargle_families) return easy_moves, non_egg_moves @@ -364,7 +363,7 @@ class MovesetSearch(object): if move: self.evolution_moves[self.evolution_chains[child]] = move - if self.debug: + if self.debug_level > 1: print 'Loaded %s pokemon: %s evo; %s families: %s breedable' % ( len(self.evolution_chains), len(self.pokemon_by_evolution_chain), @@ -396,7 +395,7 @@ class MovesetSearch(object): ) self.move_generations = dict(query) - if self.debug: + if self.debug_level > 1: print 'Loaded %s moves' % len(self.move_generations) def construct_breed_graph(self): @@ -443,7 +442,7 @@ class MovesetSearch(object): if len(groups) >= 2: eg2_movepools[groups].update(pool) - if self.debug: + if self.debug_level > 1: print 'Egg group summary:' for group in sorted(all_groups): print "%2s can pass: %s" % (group, sorted(eg1_movepools[group])) @@ -496,7 +495,7 @@ class MovesetSearch(object): breeds_required[group][frozenset(moves)] = 1 self.breeds_required = breeds_required - if self.debug: + if self.debug_level > 1: for group, movesetlist in breeds_required.items(): print 'From egg group', group for moveset, cost in movesetlist.items(): @@ -530,7 +529,7 @@ class MovesetSearch(object): last_moves = moves last_gen = gen - if self.debug: + if self.debug_level > 1: print 'Deduplicated %s version groups' % counter def astar_debug_notify(self, cost, node, setsize, heapsize): @@ -538,8 +537,13 @@ class MovesetSearch(object): if counter % 100 == 0: print 'A* iteration %s, cost %s; remaining: %s (%s) \r' % ( counter, cost, setsize, heapsize), + sys.stdout.flush() self._astar_debug_notify_counter += 1 + def astar_estimate_error(self, result): + print '**warning: bad A* estimate**' + print_result(result) + def __iter__(self): return self.generator @@ -585,7 +589,7 @@ default_costs = { # For technical reasons, 'sketch' is also used for learning Sketch and # by normal means, if it isn't included in the target moveset. # So the actual cost of a sketched move will be double this number. - 'sketch': 100, # Cheap. Exclude Smeargle if you think it's too cheap. + 'sketch': 1, # Cheap. Exclude Smeargle if you think it's too cheap. # Gimmick moves – we need to use this method to learn the move anyway, # so make a big-ish dent in the score if missing @@ -604,7 +608,7 @@ default_costs = { 'evolution-delayed': 50, # *in addition* to evolution. Who wants to mash B on every level. 'breed': 400, # Breeding's a pain. 'trade': 200, # Trading's a pain, but not as much as breeding. - 'transfer': 200, # *in addition* to trade. For one-way cross-generation transfers + 'transfer': 150, # *in addition* to trade. Keep it below 'trade'. 'forget': 300, # Deleting a move. (Not needed unless deleting an evolution move.) 'relearn': 150, # Also a pain, though not as big as breeding. 'per-level': 1, # Prefer less grinding. This is for all lv-ups but the final “grow” @@ -620,6 +624,215 @@ default_costs = { 'breed-penalty': 100, } +### +### A* +### + +class Node(object): + """Node for the A* search algorithm. + + To get started, implement the `expand` method and call `search`. + + N.B. Node objects must be hashable. + """ + + def expand(self): + """Return a list of (costs, transition, next_node) for next states + + "Next states" are those reachable from this node. + + May return any finite iterable. + """ + raise NotImplementedError + + def estimate(self, goal): + """Return an *optimistic* estimate of the cost to the given goal node. + + If there are multiple goal states, return the lowest estimate among all + of them. + """ + return 0 + + def is_goal(self, goal): + """Return true iff this is a goal node. + """ + return self == goal + + def find_path(self, goal=None, **kwargs): + """Return the best path to the goal + + Returns an iterator of (cost, transition, node) triples, in reverse + order (i.e. the first element will have the total cost and goal node). + + If `goal` will be passed to the `estimate` and `is_goal` methods. + + See a_star for the advanced keyword arguments, `notify` and + `estimate_error_callback`. + """ + paths = self.find_all_paths(goal=goal, **kwargs) + try: + return paths.next() + except StopIteration: + return None + + def find_all_paths(self, goal=None, **kwargs): + """Yield the best path to each goal + + Returns an iterator of paths. See the `search` method for how paths + look. + + Giving the `goal` argument will cause it to search for that goal, + instead of consulting the `is_goal` method. + This means that if you wish to find more than one path, you must not + pass a `goal` to this method, and instead reimplament `is_goal`. + + See a_star for the advanced keyword arguments, `notify` and + `estimate_error_callback`. + """ + return a_star( + initial=self, + expand=lambda s: s.expand(), + estimate=lambda s: s.estimate(goal), + is_goal=lambda s: s.is_goal(goal), + **kwargs) + +def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None, + estimate_error_callback=None): + """A* search algorithm for a consistent heuristic + + General background: http://en.wikipedia.org/wiki/A*_search_algorithm + + This algorithm will work in large or infinite search spaces. + + This version of the algorithm is modified for multiple possible goals: + it does not end when it reaches a goal. Rather, it yields the best path + for each goal. + (Exhausting the iterator is of course not recommended for large search + spaces.) + + Returns an iterable of paths, where each path is an iterable of + (cummulative cost, transition, node) triples representing the path to + the goal. The transition is the one leading to the corresponding node. + The path is in reverse order, thus its first element will contain the + total cost and the goal node. + The initial node is not included in the returned path. + + Arguments: + + `initial`: the initial node + + `expand`: function yielding a (cost of transition, transition, next node) + triple for each node reachable from its argument. + The `transition` element is application data; it is not touched, only + returned as part of the best path. + `estimate`: function(x) returning optimistic estimate of cost from node x + to a goal. If not given, 0 will be used for estimates. + `is_goal`: function(x) returning true iff x is a goal node + + `notify`: If given, if is called at each step with three arguments: + - current cost (with estimate). The cost to the next goal will not be + smaller than this. + - current node + - open set cardinality: roughly, an estimate of the size of the + boundary between "explored" and "unexplored" parts of node space + - debug: stats that be useful for debugging or tuning (in this + implementation, this is the open heap size) + The number of calls to notify or the current cost can be useful as + stopping criteria; the other values may help in tuning estimators. + + `estimate_error_callback`: function handling cases where an estimate was + detected not to be optimistic (as A* requires). The function is given a + path (as would be returned by a_star, except it does not lead to a goal + node). By default, nothing is done (indeed, an estimate that's not + strictly optimistic can be useful, esp. if the optimal path is not + required) + """ + # g: best cummulative cost (from initial node) found so far + # h: optimistic estimate of cost to goal + # f: g + h + closed = set() # nodes we don't want to visit again + est = estimate(initial) # estimate total cost + opened = _HeapDict() # node -> (f, g, h) + opened[initial] = (est, 0, est) + came_from = {initial: None} # node -> (prev_node, came_from[prev_node]) + while True: # _HeapDict will raise StopIteration for us + x, (f, g, h) = opened.pop() + closed.add(x) + + if notify is not None: + notify(f, x, len(opened.dict), len(opened.heap)) + + if is_goal(x): + yield _trace_path(came_from[x]) + + for cost, transition, y in expand(x): + if y in closed: + continue + tentative_g = g + cost + + old_f, old_g, h = opened.get(y, (None, None, None)) + + if old_f is None: + h = estimate(y) + elif tentative_g > old_g: + continue + + came_from[y] = ((tentative_g, transition, y), came_from[x]) + new_f = tentative_g + h + + opened[y] = new_f, tentative_g, h + + if estimate_error_callback is not None and new_f < f: + estimate_error_callback(_trace_path(came_from[y])) + +def _trace_path(cdr): + """Backtrace an A* result""" + # Convert a lispy list to a pythony iterator + while cdr: + car, cdr = cdr + yield car + +class _HeapDict(object): + """A custom parallel heap/dict structure -- the best of both worlds. + + This is NOT a general-purpose class; it only supports what a_star needs. + """ + # The dict has the definitive contents + # The heap has (value, key) pairs. It may have some extra elements. + def __init__(self): + self.dict = {} + self.heap = [] + + def __setitem__(self, key, value): + self.dict[key] = value + heapq.heappush(self.heap, (value, key)) + + def __delitem__(self, key): + del self.dict[key] + + def get(self, key, default): + """Return value for key, or default if not found + """ + return self.dict.get(key, default) + + def pop(self): + """Return (key, value) with the smallest value. + + Raise StopIteration (!!) if empty + """ + while True: + try: + value, key = heapq.heappop(self.heap) + if value is self.dict[key]: + del self.dict[key] + return key, value + except KeyError: + # deleted from dict = not here + pass + except IndexError: + # nothing more to pop + raise StopIteration + ### ### Result objects ### @@ -828,8 +1041,9 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode', def expand_forget(self): cost = self.search.costs['forget'] for move in self.moves_: - yield cost, ForgetAction(self.search, move), self._replace( - moves_=self.moves_.difference([move]), new_level=False) + if move not in self.search.goal_moves: + yield cost, ForgetAction(self.search, move), self._replace( + moves_=self.moves_.difference([move]), new_level=False) def expand_trade(self): search = self.search @@ -923,7 +1137,7 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode', moves = self.moves_ for sketch in moves: if sketch == self.search.sketch: - for sketched in self.search.goal_moves: + for sketched in sorted(self.search.goal_moves): if sketched in self.search.unsketchable: continue if sketched not in moves: @@ -936,6 +1150,26 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode', new_level=False, moves_=frozenset(moves)) return + def estimate(self, g): + # Given good estimates, A* finds solutions much faster. + # However, here it seems we either have easy movesets, which + # get found pretty easily by themselves, or hard ones, where + # heuristics don't help too much, or impossible ones where they + # don't matter at all. + # So, keep the computations here to a minimum. + search = self.search + trade_cost = search.trade_cost(self.version_group_, + search.goal_version_group) + if trade_cost is None: + trade_cost = search.costs['trade'] * 2 + return trade_cost + evo_chain = search.evolution_chains[self.pokemon_] + if evo_chain == search.goal_evolution_chain: + breed_cost = 0 + else: + breed_cost = search.costs['breed'] + return trade_cost + breed_cost + class BaseBreedNode(Node): """Breed node This serves to prevent duplicate breeds, by storing only the needed info @@ -955,8 +1189,8 @@ class BaseBreedNode(Node): continue if len(bred_moves) < 4: for move, methods in moves.items(): - if 'light-ball-pichu' in methods: - bred_moves.add(move) + if 'light-ball-egg' in methods: + bred_moves = bred_moves.union([move]) cost = search.costs['per-hatch-counter'] * search.hatch_counters[baby] yield 0, BreedAction(self.search, baby, bred_moves), PokemonNode( search=self.search, pokemon_=baby, level=hatch_level, @@ -995,11 +1229,30 @@ class GoalNode(PokemonNode): def is_goal(self, g): return True - ### ### CLI interface ### +def print_result(result, moves=()): + template = u"{cost:4} {est:4} {action:45.45}{long:1} {pokemon:10}{level:>3}{nl:1}{versions:2} {moves}" + print template.format(cost='Cost', est='Est.', action='Action', pokemon='Pokemon', + long='', level='Lv.', nl='V', versions='er', + moves=''.join(m.name[0].lower() for m in moves)) + for cost, action, node in reversed(list(result)): + if action: + print template.format( + cost=cost, + action=action, + long='>' if (len(unicode(action)) > 45) else '', + est=node.estimate(None), + pokemon=node.pokemon.name, + nl='.' if node.new_level else ' ', + level=node.level, + versions=''.join(v.name[0] for v in node.versions), + moves=''.join('.' if m in node.moves else ' ' for m in moves) + + ''.join(m.name[0].lower() for m in node.moves if m not in moves), + ) + def main(argv): parser = argparse.ArgumentParser(description= 'Find out if the specified moveset is valid, and provide a suggestion ' @@ -1039,7 +1292,7 @@ def main(argv): if args.debug: print 'Connecting' - session = connect(engine_args={'echo': args.debug > 1}) + session = connect(engine_args={'echo': args.debug > 2}) if args.debug: print 'Parsing arguments' @@ -1068,37 +1321,20 @@ def main(argv): try: search = MovesetSearch(session, pokemon, version, moves, args.level, exclude_versions=excl_versions, exclude_pokemon=excl_pokemon, - debug=args.debug) + debug_level=args.debug) except IllegalMoveCombination, e: print 'Error:', e else: if args.debug: print 'Setup done' - template = u"{cost:4} {action:50.50}{long:1} {pokemon:10}{level:>3}{nl:1}{versions:2} {moves}" for result in search: + if args.debug and search.output_objects: + print '**warning: search looked up output objects**' + no_results = False print '-' * 79 - if no_results: - if search.output_objects: - print '**warning: search looked up output objects**' - no_results = False - print template.format(cost='Cost', action='Action', pokemon='Pokemon', - long='',level='Lv.', nl='V', versions='er', - moves=''.join(m.name[0].lower() for m in moves)) - for cost, action, node in reversed(list(result)): - if action: - print template.format( - cost=cost, - action=action, - long='>' if len(str(action)) > 50 else '', - pokemon=node.pokemon.name, - nl='.' if node.new_level else ' ', - level=node.level, - versions=''.join(v.name[0] for v in node.versions), - moves=''.join('.' if m in node.moves else ' ' for m in moves) + - ''.join(m.name[0].lower() for m in node.moves if m not in moves), - ) - # XXX: Support more results + print_result(result, moves=moves) + # XXX: Support more than one result break if args.debug: