Make the trade table more sparse

This commit is contained in:
Petr Viktorin 2011-04-27 13:16:39 +03:00
parent 15b92851ed
commit 0f28e1ebda

View file

@ -129,7 +129,7 @@ class MovesetSearch(object):
self.generator = InitialNode(self).find_all_paths(**kwargs) self.generator = InitialNode(self).find_all_paths(**kwargs)
def load_version_groups(self, version, excluded): def load_version_groups(self, version, excluded):
"""Load generation_id_by_version_group """Load generation_id_by_version_group & trade_costs
""" """
query = self.session.query(tables.VersionGroup.id, query = self.session.query(tables.VersionGroup.id,
tables.VersionGroup.generation_id) tables.VersionGroup.generation_id)
@ -139,7 +139,7 @@ class MovesetSearch(object):
self.generation_id_by_version_group = dict(query) self.generation_id_by_version_group = dict(query)
def expand(v2): def expand(v2):
for v1 in self.generation_id_by_version_group: for v1 in self.generation_id_by_version_group:
if self.trade_cost(v1, v2): if self.get_trade_cost(v1, v2):
yield 0, None, v1 yield 0, None, v1
def is_goal(v): def is_goal(v):
return True return True
@ -150,6 +150,12 @@ class MovesetSearch(object):
filtered_map[version] = ( filtered_map[version] = (
self.generation_id_by_version_group[version]) self.generation_id_by_version_group[version])
self.generation_id_by_version_group = filtered_map self.generation_id_by_version_group = filtered_map
self.trade_costs = defaultdict(dict)
for g1 in self.generation_id_by_version_group:
for g2 in self.generation_id_by_version_group:
cost = self.get_trade_cost(g1, g2)
if cost:
self.trade_costs[g1][g2] = cost
if self.debug_level > 1: if self.debug_level > 1:
print 'Excluded version groups:', excluded print 'Excluded version groups:', excluded
print 'Trade cost table:' print 'Trade cost table:'
@ -160,9 +166,37 @@ class MovesetSearch(object):
for g1 in sorted(self.generation_id_by_version_group): for g1 in sorted(self.generation_id_by_version_group):
print '%03s' % g1, print '%03s' % g1,
for g2 in sorted(self.generation_id_by_version_group): for g2 in sorted(self.generation_id_by_version_group):
print '%03s' % (self.trade_cost(g1, g2) or '---'), print '%03s' % self.trade_costs[g1].get(g2, '---'),
print print
def get_trade_cost(self, version_group_from, version_group_to):
"""Return cost of trading between versions, None if impossibble
The generation of traded moves/items should also be checked, if
trading to gen. 1.
"""
# XXX: this ignores HM transfer restrictions
if version_group_from == version_group_to:
# No reason to trade
return None
gen_from = self.generation_id_by_version_group[version_group_from]
gen_to = self.generation_id_by_version_group[version_group_to]
if gen_from == gen_to:
return self.costs['trade']
elif gen_from in (1, 2):
if gen_to in (1, 2):
return self.costs['trade']
else:
return None
elif gen_to in (1, 2):
return None
elif gen_from > gen_to:
return None
elif gen_from < gen_to - 1:
return None
else:
return self.costs['trade'] + self.costs['transfer']
def load_pokemon_moves(self, evolution_chain, selection): def load_pokemon_moves(self, evolution_chain, selection):
"""Load pokemon_moves, movepools, learnpools, smeargle_families """Load pokemon_moves, movepools, learnpools, smeargle_families
@ -258,33 +292,6 @@ class MovesetSearch(object):
print 'Smeargle families:', sorted(self.smeargle_families) print 'Smeargle families:', sorted(self.smeargle_families)
return easy_moves, non_egg_moves return easy_moves, non_egg_moves
def trade_cost(self, version_group_from, version_group_to, max_generation=None):
"""Return cost of trading between versions, None if impossibble
`max_generation` should be the maximum generation of the moves traded.
(also of pokemon, if those aren't checked another way)
"""
# XXX: this ignores HM transfer restrictions
gen_from = self.generation_id_by_version_group[version_group_from]
gen_to = self.generation_id_by_version_group[version_group_to]
if gen_from == gen_to:
return self.costs['trade']
elif gen_from in (1, 2):
if max_generation and max_generation > gen_to:
return None
elif gen_to in (1, 2):
return self.costs['trade']
else:
return None
elif gen_to in (1, 2):
return None
elif gen_from > gen_to:
return None
elif gen_from < gen_to - 1:
return None
else:
return self.costs['trade'] + self.costs['transfer']
def load_pokemon(self): def load_pokemon(self):
"""Load pokemon breed groups and evolutions """Load pokemon breed groups and evolutions
@ -1064,14 +1071,16 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
def expand_trade(self): def expand_trade(self):
search = self.search search = self.search
target_vgs = set(search.pokemon_moves[self.pokemon_]) target_vgs = search.trade_costs[self.version_group_]
target_vgs.add(search.goal_version_group) if not target_vgs:
target_vgs.discard(self.version_group_) return
max_generation = max(search.move_generations[m] for m in self.moves_) gen_from = search.generation_id_by_version_group[self.version_group_]
for version_group in target_vgs: if gen_from == 2:
cost = search.trade_cost(self.version_group_, version_group, max_gen = max(search.move_generations[m] for m in self.moves_)
max_generation) for version_group, cost in target_vgs.items():
if cost is not None: if (gen_from == 2 and max_gen == 2 and
search.generation_id_by_version_group[version_group] == 1):
continue
yield cost, TradeAction(search, version_group), self._replace( yield cost, TradeAction(search, version_group), self._replace(
version_group_=version_group, new_level=False) version_group_=version_group, new_level=False)
@ -1175,10 +1184,11 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode',
# don't matter at all. # don't matter at all.
# So, keep the computations here to a minimum. # So, keep the computations here to a minimum.
search = self.search search = self.search
trade_cost = search.trade_cost(self.version_group_, if self.version_group_ == search.goal_version_group:
search.goal_version_group) trade_cost = 0
if trade_cost is None: else:
trade_cost = search.costs['trade'] * 2 trade_cost = search.trade_costs[self.version_group_].get(
search.goal_version_group, search.costs['trade'] * 2)
return trade_cost return trade_cost
class BaseBreedNode(Node): class BaseBreedNode(Node):