diff --git a/pokedex/tests/test_movesets.py b/pokedex/tests/test_movesets.py index 2d9eca5..11a7283 100644 --- a/pokedex/tests/test_movesets.py +++ b/pokedex/tests/test_movesets.py @@ -30,12 +30,10 @@ def test_cases(): OK gyarados flail thrash iron-head outrage OK drifblim memento gust thunderbolt pain-split OK crobat nasty-plot brave-bird - OK crobat brave-bird hypnosis NO crobat nasty-plot hypnosis OK garchomp double-edge thrash outrage OK nidoking counter disable amnesia head-smash OK aggron stomp smellingsalt screech fire-punch - NO aggron endeavor body-slam OK tyranitar dragon-dance outrage thunder-wave surf NO butterfree morning-sun harden OK pikachu reversal bide nasty-plot discharge diff --git a/pokedex/util/movesets.py b/pokedex/util/movesets.py index a6533ca..347c9c1 100755 --- a/pokedex/util/movesets.py +++ b/pokedex/util/movesets.py @@ -259,11 +259,11 @@ class MovesetSearch(object): print 'Smeargle families:', sorted(self.smeargle_families) return easy_moves, non_egg_moves - def trade_cost(self, version_group_from, version_group_to, *thing_generations): + def trade_cost(self, version_group_from, version_group_to, max_generation=None): """Return cost of trading between versions, None if impossibble - `thing_generations` should be the generation IDs of the pokemon and - moves being traded. + `max_generation` should be the maximum generation of the moves traded. + (also of pokemon, if those aren't checked another way) """ # XXX: this ignores HM transfer restrictions gen_from = self.generation_id_by_version_group[version_group_from] @@ -271,7 +271,7 @@ class MovesetSearch(object): if gen_from == gen_to: return self.costs['trade'] elif gen_from in (1, 2): - if any(gen > gen_to for gen in thing_generations): + if max_generation and max_generation > gen_to: return None elif gen_to in (1, 2): return self.costs['trade'] @@ -836,10 +836,10 @@ class PokemonNode(Node, Facade, namedtuple('PokemonNode', target_vgs = set(search.pokemon_moves[self.pokemon_]) target_vgs.add(search.goal_version_group) target_vgs.discard(self.version_group_) + max_generation = max(search.move_generations[m] for m in self.moves_) for version_group in target_vgs: cost = search.trade_cost(self.version_group_, version_group, - *(search.move_generations[m] for m in self.moves_) - ) + max_generation) if cost is not None: yield cost, TradeAction(search, version_group), self._replace( version_group_=version_group, new_level=False) @@ -1098,6 +1098,8 @@ def main(argv): moves=''.join('.' if m in node.moves else ' ' for m in moves) + ''.join(m.name[0].lower() for m in node.moves if m not in moves), ) + # XXX: Support more results + break if args.debug: print