diff --git a/pokedex/util/movesets.py b/pokedex/util/movesets.py index 6f90a06..fdba349 100755 --- a/pokedex/util/movesets.py +++ b/pokedex/util/movesets.py @@ -5,6 +5,7 @@ import sys import argparse from collections import defaultdict +from sqlalchemy.orm import aliased from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.sql.expression import not_, and_, or_ @@ -37,6 +38,12 @@ class MovesetSearch(object): self.session = session + self.sketch = util.get(session, tables.Move, identifier=u'sketch').id + self.no_eggs_group = util.get(session, tables.EggGroup, + identifier=u'no-eggs').id + self.ditto_group = util.get(session, tables.EggGroup, + identifier=u'ditto').id + if costs is None: self.costs = default_costs else: @@ -59,6 +66,8 @@ class MovesetSearch(object): self.goal_moves = frozenset(move.id for move in moves) self.goal_version_group = version.version_group_id + self.load_pokemon() + # Fill self.generation_id_by_version_group self.load_version_groups(version.version_group_id, [v.version_group_id for v in exclude_versions]) @@ -69,7 +78,7 @@ class MovesetSearch(object): lambda: defaultdict( # key: method list)))) # list of (level, cost) self.movepools = defaultdict(dict) # evo chain -> move -> best cost - self.learnpools = defaultdict(set) # as above, but not egg moves + self.learnpools = defaultdict(set) # evo chain -> move, w/o egg moves easy_moves, non_egg_moves = self.load_pokemon_moves( self.goal_evolution_chain, 'family') @@ -81,6 +90,8 @@ class MovesetSearch(object): self.load_pokemon_moves(self.goal_evolution_chain, 'others') def load_version_groups(self, version, excluded): + """Load generation_id_by_version_group + """ query = self.session.query(tables.VersionGroup.id, tables.VersionGroup.generation_id) query = query.join(tables.Version.version_group) @@ -145,6 +156,9 @@ class MovesetSearch(object): query = query.filter(or_( tables.PokemonMove.level > 100, # XXX: Chaff? tables.PokemonMove.move_id.in_(self.goal_moves), + tables.PokemonMove.move_id == self.sketch, + tables.PokemonMove.move_id.in_( + self.evolution_moves.values()), )) if self.excluded_families: query = query.filter(not_(tables.Pokemon.evolution_chain_id.in_( @@ -160,7 +174,7 @@ class MovesetSearch(object): easy_moves = set() non_egg_moves = set() for pokemon, move, vg, method, level, chain in query: - if move in self.goal_moves: + if move in self.goal_moves or move == self.sketch: cost = self.learn_cost(method, vg) self.movepools[chain][move] = min( self.movepools[chain].get(move, cost), cost) @@ -170,7 +184,7 @@ class MovesetSearch(object): if cost < self.costs['breed']: easy_moves.add(move) else: - cost = 0 + cost = -1 self.pokemon_moves[pokemon][move][vg][method].append((level, cost)) if self.debug and selection == 'family': print 'Easy moves:', sorted(easy_moves) @@ -219,6 +233,74 @@ class MovesetSearch(object): else: return self.costs['trade'] + self.costs['transfer'] + def load_pokemon(self): + """Load pokemon breed groups and evolutions + + self.egg_groups: maps evolution chains to their egg groups + self.evolution_chains: maps pokemon to their evolution chains + self.pokemon_by_evolution_chain: maps evolution chains to their pokemon + self.unbreedable: set of unbreedable pokemon + + self.evolution_parents: maps pokemon to their pre-evolved form + self.evolutions: maps pokemon to lists of (trigger, move, level, child) + self.evolution_moves: maps evolution_chains to their evolution moves + """ + eg1 = tables.PokemonEggGroup + eg2 = aliased(tables.PokemonEggGroup) + query = self.session.query( + tables.Pokemon.id, + tables.Pokemon.evolution_chain_id, + tables.Pokemon.evolves_from_pokemon_id, + eg1.egg_group_id, + eg2.egg_group_id, + ) + query = query.join((eg1, eg1.pokemon_id == tables.Pokemon.id)) + query = query.outerjoin((eg2, and_( + eg2.pokemon_id == tables.Pokemon.id, + eg1.egg_group_id < eg2.egg_group_id, + ))) + query = query.order_by(eg1.egg_group_id != None) + bad_groups = (self.no_eggs_group, self.ditto_group) + unbreedable = set() + self.evolution_parents = dict() + self.egg_groups = defaultdict(tuple) + self.evolution_chains = dict() + self.pokemon_by_evolution_chain = defaultdict(set) + for pokemon, evolution_chain, parent, g1, g2 in query: + if g1 in bad_groups: + unbreedable.add(pokemon) + else: + self.egg_groups[evolution_chain] = (g1, g2) if g2 else (g1, ) + self.evolution_chains[pokemon] = evolution_chain + self.pokemon_by_evolution_chain[evolution_chain].add(pokemon) + if parent: + self.evolution_parents[pokemon] = parent + self.unbreedable = frozenset(unbreedable) + + self.evolutions = defaultdict(set) + self.evolution_moves = dict() + query = self.session.query( + tables.PokemonEvolution.evolved_pokemon_id, + tables.EvolutionTrigger.identifier, + tables.PokemonEvolution.known_move_id, + tables.PokemonEvolution.minimum_level, + ) + query = query.join(tables.PokemonEvolution.trigger) + for child, trigger, move, level in query: + self.evolutions[self.evolution_parents[child]].add( + (trigger, move, level, child)) + if move: + self.evolution_moves[self.evolution_chains[child]] = move + + if self.debug: + print 'Loaded %s pokemon: %s evo; %s families: %s breedable' % ( + len(self.evolution_chains), + len(self.pokemon_by_evolution_chain), + len(self.egg_groups), + len(self.evolutions), + ) + print 'Evolution moves: %s' % self.evolution_moves + default_costs = { # Costs for learning a move in verious ways 'level-up': 20, # The normal way @@ -323,5 +405,8 @@ def main(argv): print 'Error:', search.error return 1 + if args.debug: + print 'Done' + if __name__ == '__main__': sys.exit(main(sys.argv[1:]))