Loading pokemon

Just stuffing them all in memory works fine.
This commit is contained in:
Petr Viktorin 2011-04-26 02:50:09 +03:00
parent bca84867c8
commit 02deb8c06a

View file

@ -5,6 +5,7 @@ import sys
import argparse import argparse
from collections import defaultdict from collections import defaultdict
from sqlalchemy.orm import aliased
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.sql.expression import not_, and_, or_ from sqlalchemy.sql.expression import not_, and_, or_
@ -37,6 +38,12 @@ class MovesetSearch(object):
self.session = session self.session = session
self.sketch = util.get(session, tables.Move, identifier=u'sketch').id
self.no_eggs_group = util.get(session, tables.EggGroup,
identifier=u'no-eggs').id
self.ditto_group = util.get(session, tables.EggGroup,
identifier=u'ditto').id
if costs is None: if costs is None:
self.costs = default_costs self.costs = default_costs
else: else:
@ -59,6 +66,8 @@ class MovesetSearch(object):
self.goal_moves = frozenset(move.id for move in moves) self.goal_moves = frozenset(move.id for move in moves)
self.goal_version_group = version.version_group_id self.goal_version_group = version.version_group_id
self.load_pokemon()
# Fill self.generation_id_by_version_group # Fill self.generation_id_by_version_group
self.load_version_groups(version.version_group_id, self.load_version_groups(version.version_group_id,
[v.version_group_id for v in exclude_versions]) [v.version_group_id for v in exclude_versions])
@ -69,7 +78,7 @@ class MovesetSearch(object):
lambda: defaultdict( # key: method lambda: defaultdict( # key: method
list)))) # list of (level, cost) list)))) # list of (level, cost)
self.movepools = defaultdict(dict) # evo chain -> move -> best cost self.movepools = defaultdict(dict) # evo chain -> move -> best cost
self.learnpools = defaultdict(set) # as above, but not egg moves self.learnpools = defaultdict(set) # evo chain -> move, w/o egg moves
easy_moves, non_egg_moves = self.load_pokemon_moves( easy_moves, non_egg_moves = self.load_pokemon_moves(
self.goal_evolution_chain, 'family') self.goal_evolution_chain, 'family')
@ -81,6 +90,8 @@ class MovesetSearch(object):
self.load_pokemon_moves(self.goal_evolution_chain, 'others') self.load_pokemon_moves(self.goal_evolution_chain, 'others')
def load_version_groups(self, version, excluded): def load_version_groups(self, version, excluded):
"""Load generation_id_by_version_group
"""
query = self.session.query(tables.VersionGroup.id, query = self.session.query(tables.VersionGroup.id,
tables.VersionGroup.generation_id) tables.VersionGroup.generation_id)
query = query.join(tables.Version.version_group) query = query.join(tables.Version.version_group)
@ -145,6 +156,9 @@ class MovesetSearch(object):
query = query.filter(or_( query = query.filter(or_(
tables.PokemonMove.level > 100, # XXX: Chaff? tables.PokemonMove.level > 100, # XXX: Chaff?
tables.PokemonMove.move_id.in_(self.goal_moves), tables.PokemonMove.move_id.in_(self.goal_moves),
tables.PokemonMove.move_id == self.sketch,
tables.PokemonMove.move_id.in_(
self.evolution_moves.values()),
)) ))
if self.excluded_families: if self.excluded_families:
query = query.filter(not_(tables.Pokemon.evolution_chain_id.in_( query = query.filter(not_(tables.Pokemon.evolution_chain_id.in_(
@ -160,7 +174,7 @@ class MovesetSearch(object):
easy_moves = set() easy_moves = set()
non_egg_moves = set() non_egg_moves = set()
for pokemon, move, vg, method, level, chain in query: for pokemon, move, vg, method, level, chain in query:
if move in self.goal_moves: if move in self.goal_moves or move == self.sketch:
cost = self.learn_cost(method, vg) cost = self.learn_cost(method, vg)
self.movepools[chain][move] = min( self.movepools[chain][move] = min(
self.movepools[chain].get(move, cost), cost) self.movepools[chain].get(move, cost), cost)
@ -170,7 +184,7 @@ class MovesetSearch(object):
if cost < self.costs['breed']: if cost < self.costs['breed']:
easy_moves.add(move) easy_moves.add(move)
else: else:
cost = 0 cost = -1
self.pokemon_moves[pokemon][move][vg][method].append((level, cost)) self.pokemon_moves[pokemon][move][vg][method].append((level, cost))
if self.debug and selection == 'family': if self.debug and selection == 'family':
print 'Easy moves:', sorted(easy_moves) print 'Easy moves:', sorted(easy_moves)
@ -219,6 +233,74 @@ class MovesetSearch(object):
else: else:
return self.costs['trade'] + self.costs['transfer'] return self.costs['trade'] + self.costs['transfer']
def load_pokemon(self):
"""Load pokemon breed groups and evolutions
self.egg_groups: maps evolution chains to their egg groups
self.evolution_chains: maps pokemon to their evolution chains
self.pokemon_by_evolution_chain: maps evolution chains to their pokemon
self.unbreedable: set of unbreedable pokemon
self.evolution_parents: maps pokemon to their pre-evolved form
self.evolutions: maps pokemon to lists of (trigger, move, level, child)
self.evolution_moves: maps evolution_chains to their evolution moves
"""
eg1 = tables.PokemonEggGroup
eg2 = aliased(tables.PokemonEggGroup)
query = self.session.query(
tables.Pokemon.id,
tables.Pokemon.evolution_chain_id,
tables.Pokemon.evolves_from_pokemon_id,
eg1.egg_group_id,
eg2.egg_group_id,
)
query = query.join((eg1, eg1.pokemon_id == tables.Pokemon.id))
query = query.outerjoin((eg2, and_(
eg2.pokemon_id == tables.Pokemon.id,
eg1.egg_group_id < eg2.egg_group_id,
)))
query = query.order_by(eg1.egg_group_id != None)
bad_groups = (self.no_eggs_group, self.ditto_group)
unbreedable = set()
self.evolution_parents = dict()
self.egg_groups = defaultdict(tuple)
self.evolution_chains = dict()
self.pokemon_by_evolution_chain = defaultdict(set)
for pokemon, evolution_chain, parent, g1, g2 in query:
if g1 in bad_groups:
unbreedable.add(pokemon)
else:
self.egg_groups[evolution_chain] = (g1, g2) if g2 else (g1, )
self.evolution_chains[pokemon] = evolution_chain
self.pokemon_by_evolution_chain[evolution_chain].add(pokemon)
if parent:
self.evolution_parents[pokemon] = parent
self.unbreedable = frozenset(unbreedable)
self.evolutions = defaultdict(set)
self.evolution_moves = dict()
query = self.session.query(
tables.PokemonEvolution.evolved_pokemon_id,
tables.EvolutionTrigger.identifier,
tables.PokemonEvolution.known_move_id,
tables.PokemonEvolution.minimum_level,
)
query = query.join(tables.PokemonEvolution.trigger)
for child, trigger, move, level in query:
self.evolutions[self.evolution_parents[child]].add(
(trigger, move, level, child))
if move:
self.evolution_moves[self.evolution_chains[child]] = move
if self.debug:
print 'Loaded %s pokemon: %s evo; %s families: %s breedable' % (
len(self.evolution_chains),
len(self.pokemon_by_evolution_chain),
len(self.egg_groups),
len(self.evolutions),
)
print 'Evolution moves: %s' % self.evolution_moves
default_costs = { default_costs = {
# Costs for learning a move in verious ways # Costs for learning a move in verious ways
'level-up': 20, # The normal way 'level-up': 20, # The normal way
@ -323,5 +405,8 @@ def main(argv):
print 'Error:', search.error print 'Error:', search.error
return 1 return 1
if args.debug:
print 'Done'
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(main(sys.argv[1:])) sys.exit(main(sys.argv[1:]))