ビームサーチをライブラリ化する【基礎編】
7/30 追記:一部の実装を変更して高速化しました。(ビームサーチ内で毎回配列を定義するコストが無視できなかったため、一度だけ定義して毎回中身を全て消去する形にしました)
ヒューリスティックコンテストでよく用いられるビームサーチという手法があります。山登り/焼きなまし法と並んで必ずと言っていいほど紹介されるビームサーチですが、実装が大変と感じている方も多いと思います。
この記事ではビームサーチを既に知っている方向けに、ライブラリ化の指針を提供したいと考えています。ビームサーチ自体については簡単に説明する程度ですのでご了承ください*1。
今回は基礎編と題して、一般に広く用いられていると思われる実装方法を紹介します。応用編*2ではrhooさんの記事などで解説されている他の実装方法を紹介したいと考えています。本記事中のソースコードは自由に利用していただいて構いませんが、万一何らかの不利益が生じても筆者は責任を負えません。
説明にあたり具体的な例があるとわかりやすいと思うので、ビームサーチの練習問題としてよく取り上げられている「ゲーム実況者Xの挑戦」を題材にしてみたいと思います。
atcoder.jp
問題の概要としては、$N$ 個あるグリッド状のマップから $K$ 個を選び、全てのマップ上で同じ動き方(上下左右のいずれか)をしてなるべく多くのコインを取得することを目指します。時系列が重要なターン制の問題なのでビームサーチが有効そうな気がしますね。
ビームサーチの挙動を考える
全ての行動パターンを探索できれば厳密な最適解が得られるわけですが、毎ターン4つの選択肢があって2500ターンも続くので当然それは不可能です。そこでビームサーチでは、一部の評価が高い状態だけを残して次に進むことを繰り返すのでした*3。この挙動はどのような問題にも共通していますから、うまくライブラリ化できないか考えてみます。
実装例
実装の方針
ビームサーチで用いる各状態のことをノードと呼ぶことにします。上述の挙動を素直に実装するなら以下のようになるでしょうか。
- ノード構造体を定義する。残すノードを決められるように評価関数を用意する。
- 初期ノードを優先度付きキューに入れる。
- 優先度付きキューから評価が高い一定個数を取り出し、各ノードについて全ての遷移(今回は4通り)を試して次の優先度付きキューに全て入れる。
- 3. を最後まで繰り返す。
既に通過したマスで再度コインを取得してしまわないように、各マップについてどのマスを訪問済みかの情報をビット列で保持します。ではC++での実装例を見てみましょう。
ゲームの進行状況を表す構造体を用意する
まずは現在のゲームの進行状況を表すState構造体を定義し、それを各ノードにもたせることにします。その際、操作を表すOperation構造体やプレイヤーの情報をもつPlayer構造体を用意しています。
State構造体はビームサーチでコピーする情報を保持するので、なるべくコピーコストが小さくなるようにデータの持ち方を工夫したいです。ここが問題に固有の部分で、他はほとんど使い回せるような実装にします。
using ull = unsigned long long; struct State { int now_turn; int coins; int penalty; // 罠を踏んだ時にペナルティを加算する ull zobrist_hash; // 盤面の重複を削除するためのハッシュ値 vector<Player> players; vector<bitset<2500>> visited; string move_history; State(); int score() const; ull hash() const; void apply_move(const Operation& op); }; State::State() { // 初期化(初期盤面のハッシュ値の計算、プレイヤーの位置情報の取得など) } int State::score() const { // 評価関数を設定する } ull State::hash() const { // ハッシュ値、今回はzobrist hash } void State::apply_move(const Operation& op) { // ここにターンを進める処理を書く // (スコアとハッシュ値の計算、プレイヤーや訪問済み情報の更新など) }
これまでの移動の履歴をmove_historyという文字列で保持して、最後に操作を復元できるようにします。
ノード構造体を用意する
続いてビームサーチで用いるノード構造体を定義します。
struct Node { State state; Node(State& state); int get_score() const; ull get_hash() const; void advance(const Operation& op); bool operator< (const Node& node) const; }; Node::Node(State& state) : state(state) {} int Node::get_score() const { return state.score(); } ull Node::get_hash() const { return state.hash(); } void Node::advance(const Operation& op) { state.apply_move(op); } bool Node::operator< (const Node& node) const { return get_score() < node.get_score(); }
ノードをコピーしてadvance関数を呼ぶことで状態が更新されます。ノードの実装はどの問題でも基本的に変更する必要がないようにしています。
ビームサーチを実装する
準備ができたので、いよいよビームサーチの実装です。各ノードからの遷移を書きやすくするために、ありうる操作を列挙したvalid_operationsというOperation型の配列を用意しています。
Node BeamSearch(State& init_state, const int max_depth, const int beam_width) { priority_queue<Node> nodes; nodes.emplace(init_state); // 初期状態を優先度付きキューに入れる unordered_set<ull> fields; // 重複除去用 for(int turn = 1; turn <= max_depth; turn++) { priority_queue<Node> next_nodes; // 上位beam_width個だけ取り出す for(int i = 0; i < beam_width; i++) { if(nodes.empty()) break; Node node = nodes.top(); nodes.pop(); // 可能な全ての遷移を試す for(auto& op : valid_operations) { Node next_node = node; next_node.advance(op); if(fields.count(next_node.get_hash())) continue; // 過去との重複をチェック next_nodes.emplace(next_node); } } while(!nodes.empty()) nodes.pop(); // 重複を除去しながら上位beam_width個を採用する while(!next_nodes.empty() && (int)nodes.size() < beam_width) { if(fields.count(next_nodes.top().get_hash())) { // このターン内の重複をチェック next_nodes.pop(); continue; } fields.insert(next_nodes.top().get_hash()); nodes.emplace(next_nodes.top()); next_nodes.pop(); } } return nodes.top(); }
ビームサーチ自体もほとんど変更を加えず使い回せると思います。前述のように、問題に固有なのはコピーすべき状態をもつState構造体や各操作を表すOperation構造体になります。
この問題は実行制限時間が4秒で、ここまでの実装でビーム幅を200程度は確保できると思います。ですが実は無駄な処理をしていて高速化の余地が残っています。
高速化するには
スコアだけ先に計算して上位を決定する
ビームサーチで実行時間のボトルネックになるのは、多くの場合ノードのコピーでしょう。上の実装では毎ターン最大で beam_width * 操作の種類数 だけコピーが発生しますが、よく考えると結局採用されないノードが多数あり無駄なコピーが発生しています。
これを回避するには、ノードをコピーせずにスコア(とハッシュ値)だけ先に計算して上位を決定すれば良いです。今までState構造体には実際に操作をして状態を更新するapply_moveという関数しかありませんでしたが、状態を更新せずにスコアだけ計算するtry_moveという関数を用意します。
struct State { // ...他の部分は変更なし pair<int,ull> try_move(const Operation& op) const; // 状態を更新しないのでconstをつけておく } pair<int,ull> State::try_move(const Operation& op) const { // 一手進めた場合のスコアとハッシュ値を返す、ただし状態を更新はしない }
ノード構造体にも対応するcalculateという関数を用意します。
struct Node { // ...他の部分は変更なし pair<int,ull> calculate(const Operation& op) const; } pair<int,ull> Node::calculate(const Operation& op) const { return state.try_move(op); }
そしてビームサーチではノードで比較するのではなく、スコアなどの情報を保持した一時ノードを用意して比較します。盤面の情報をもたないのでコピーコストが軽くなっています。
random_device rnd; mt19937 engine(rnd()); uniform_real_distribution<> randR(0.0, 1.0); // スコアだけ計算して上位を選ぶために用いる仮ノード struct TemporaryNode { int score; ull hash; int node_index; Operation op; double rand; // タイブレーク用 TemporaryNode(int score, ull hash, int node_index, Operation& op); }; TemporaryNode::TemporaryNode(int score, ull hash, int node_index, Operation& op) : score(score), hash(hash), node_index(node_index), op(op) { rand = randR(engine); }
スコアが同じ状態に差をつけるために乱数を入れておきました。
続いてビームサーチ本体です。先ほどは優先度付きキューで上位を取り出しましたが、C++には指定した個数だけ配列の上位を並び替えてくれるnth_elementという便利な関数があるのでそれを活用してみます。
Node BeamSearch(State& init_state, const int max_depth, const int beam_width) { vector<Node> nodes, next_nodes; nodes.emplace_back(init_state); nodes.back().move_history = Stack{nullptr}; vector<TemporaryNode> temp_nodes; // スコア比較用の仮ノードを保管 unordered_set<ull> fields; // 重複除去用 for(int turn = 1; turn <= max_depth; turn++) { temp_nodes.clear(); fields.clear(); for(int i = 0; i < (int)nodes.size(); i++) { // 可能な全ての遷移を試す for(auto& op : valid_operations) { auto [next_score, next_hash] = nodes[i].calculate(op); temp_nodes.emplace_back(next_score, next_hash, i, op); // 必要なら重複除去 if(fields.count(temp_nodes.back().hash)) { temp_nodes.pop_back(); } else { fields.insert(temp_nodes.back().hash); } } } int node_size = temp_nodes.size(); // 候補がビーム幅より多いなら上位beam_width個を選ぶ if(node_size > beam_width) { nth_element(temp_nodes.begin(), temp_nodes.begin() + beam_width, temp_nodes.end(), [](TemporaryNode& n1, TemporaryNode& n2) { if(n1.score == n2.score) { return n1.rand > n2.rand; } return n1.score > n2.score; }); } // 仮ノードの情報から実際にノードを更新する for(int i = 0; i < min(beam_width, node_size); i++) { int index = temp_nodes[i].node_index; next_nodes.emplace_back(nodes[index]); next_nodes.back().advance(temp_nodes[i].op); // 必要ならスコアとハッシュ値を確認 // assert(next_nodes.back().score == temp_nodes[i].score); // assert(next_nodes.back().hash == temp_nodes[i].hash); } swap(nodes, next_nodes); next_nodes.clear(); } int arg_best = -1, best_score = 0; for(int i = 0; i < (int)nodes.size(); i++) { if(nodes[i].get_score() > best_score) { arg_best = i; best_score = nodes[i].get_score(); } } return nodes[arg_best]; }
これもほとんど変更せずに使い回せると思います。一時ノードには比較関数を定義していませんが、ラムダ式でnth_elementに直接渡しています。これで毎ターン最大beam_width個のノードのコピーで済むようになりました。
発展:永続スタックを用いた行動履歴の復元
記事のタイトルに基礎編と書きましたが、最後に一つ発展的な内容を紹介します。
今までは移動の履歴を文字列で保持していましたが、終盤になると文字列のサイズが大きくなりコピーコストも無視できません。そこで文字列を毎回コピーして直接保持するのではなく効率的に復元する方法を考えます。
各ノードに親ノードへのポインタをもたせておきたくなりますが、過去のノードは破棄されてしまいますし、かといって全て保存しておいたらメモリが大変なことになります。そこで永続データ構造と呼ばれるものを利用して履歴を保存してみます。詳しくはnoshi91さんの記事をご覧ください*4。ここではスマートポインタを用いてメモリを管理する実装方法を選択しました。もう参照されることがない履歴は自動的にメモリを解放してくれるので便利です。
// 行動を復元する永続stack struct History { Operation op; shared_ptr<History> parent; History(const Operation& op, shared_ptr<History>& parent) : op(op), parent(parent) {} }; struct Stack { shared_ptr<History> head; Operation top(); Stack push(const Operation& op); Stack pop(); }; Operation Stack::top() { return head->op; } Stack Stack::push(const Operation& op) { return Stack({make_shared<History>(op, head)}); } Stack Stack::pop() { return Stack({head->parent}); }
ノードには新たにスタックをもたせます。
struct Node { State state; Stack move_history; // これが追加された Node(State& state); int get_score() const; ull get_hash() const; pair<int,ull> calculate(const Operation& op) const; void advance(const Operation& op); };
ビームサーチでは各ノードでスタックに操作を追加します。
Node BeamSearch(State& init_state, const int max_depth, const int beam_width) { vector<Node> nodes, next_nodes; nodes.emplace_back(init_state); nodes.back().move_history = Stack{nullptr}; vector<TemporaryNode> temp_nodes; // スコア比較用の仮ノードを保管 unordered_set<ull> fields; // 重複除去用 for(int turn = 1; turn <= max_depth; turn++) { // (略) // 仮ノードの情報から実際にノードを更新する vector<Node> next_nodes; for(int i = 0; i < min(beam_width, node_size); i++) { int index = temp_nodes[i].node_index; next_nodes.emplace_back(nodes[index]); next_nodes.back().advance(temp_nodes[i].op); // 必要ならスコアとハッシュ値を確認 // assert(next_nodes.back().score == temp_nodes[i].score); // assert(next_nodes.back().hash == temp_nodes[i].hash); fields.insert(next_nodes.back().get_hash()); // 親ノードのスタックに操作を追加して新しいノードのスタックを作成する next_nodes.back().move_history = nodes[index].move_history.push(temp_nodes[i].op); } swap(nodes, next_nodes); next_nodes.clear(); } // (略) }
実際に操作を復元する際は以下のように用います。
State init_state; constexpr int max_depth = 2500, beam_width = 1000; Node result = BeamSearch(init_state, max_depth, beam_width); vector<Operation> moves; Stack move_history = result.move_history; while(move_history.head) { Operation op = move_history.top(); moves.emplace_back(op); move_history = move_history.pop(); } reverse(moves.begin(), moves.end());
最終的な実装例
以上をまとめて次のようなライブラリになりました。この実装で今回の問題ではビーム幅1000以上を達成できています。まだまだ改良の余地はあると思いますので、皆さんも最強のビームサーチライブラリを自作しちゃいましょう。長くなってしまいましたが、最後までお読みいただきありがとうございました。
random_device rnd; mt19937 engine(rnd()); uniform_real_distribution<> randR(0.0, 1.0); using ull = unsigned long long; struct State { // コピーすべき情報をここに書く State(); int score() const; ull hash() const; pair<int,ull> try_move(const Operation& op) const; void apply_move(const Operation& op); }; State::State() { } int State::score() const { } ull State::hash() const { } // 一手進めた場合のスコアとハッシュ値を返す、更新はしない pair<int,ull> State::try_move(const Operation& op) const { } // 更新する void State::apply_move(const Operation& op) { } // 行動を復元する永続stack struct History { Operation op; shared_ptr<History> parent; History(const Operation& op, shared_ptr<History>& parent) : op(op), parent(parent) {} }; struct Stack { shared_ptr<History> head; Operation top(); Stack push(const Operation& op); Stack pop(); }; Operation Stack::top() { return head->op; } Stack Stack::push(const Operation& op) { return Stack({make_shared<History>(op, head)}); } Stack Stack::pop() { return Stack({head->parent}); } struct Node { State state; Stack move_history; Node(State& state); int get_score() const; ull get_hash() const; pair<int,ull> calculate(const Operation& op) const; void advance(const Operation& op); }; Node::Node(State& state) : state(state) {} int Node::get_score() const { return state.score(); } ull Node::get_hash() const { return state.hash(); } pair<int,ull> Node::calculate(const Operation& op) const { return state.try_move(op); } void Node::advance(const Operation& op) { state.apply_move(op); } // スコアだけ計算して上位を選ぶために用いる仮ノード struct TemporaryNode { int score; ull hash; int node_index; Operation op; double rand; // タイブレーク用 TemporaryNode(int score, ull hash, int node_index, Operation& op); }; TemporaryNode::TemporaryNode(int score, ull hash, int node_index, Operation& op) : score(score), hash(hash), node_index(node_index), op(op) { rand = randR(engine); } Node BeamSearch(State& init_state, const int max_depth, const int beam_width) { vector<Node> nodes, next_nodes; nodes.emplace_back(init_state); nodes.back().move_history = Stack{nullptr}; vector<TemporaryNode> temp_nodes; // スコア比較用の仮ノードを保管 unordered_set<ull> fields; // 重複除去用 for(int turn = 1; turn <= max_depth; turn++) { temp_nodes.clear(); fields.clear(); for(int i = 0; i < (int)nodes.size(); i++) { // 可能な全ての遷移を試す for(auto& op : valid_operations) { auto [next_score, next_hash] = nodes[i].calculate(op); temp_nodes.emplace_back(next_score, next_hash, i, op); // 必要なら重複除去 if(fields.count(temp_nodes.back().hash)) { temp_nodes.pop_back(); } else { fields.insert(temp_nodes.back().hash); } } } int node_size = temp_nodes.size(); // 候補がビーム幅より多いなら上位beam_width個を選ぶ if(node_size > beam_width) { nth_element(temp_nodes.begin(), temp_nodes.begin() + beam_width, temp_nodes.end(), [](TemporaryNode& n1, TemporaryNode& n2) { if(n1.score == n2.score) { return n1.rand > n2.rand; } return n1.score > n2.score; }); } // 仮ノードの情報から実際にノードを更新する for(int i = 0; i < min(beam_width, node_size); i++) { int index = temp_nodes[i].node_index; next_nodes.emplace_back(nodes[index]); next_nodes.back().advance(temp_nodes[i].op); // 必要ならスコアとハッシュ値を確認 // assert(next_nodes.back().score == temp_nodes[i].score); // assert(next_nodes.back().hash == temp_nodes[i].hash); // 親ノードのスタックに操作を追加して新しいノードのスタックを作成する next_nodes.back().move_history = nodes[index].move_history.push(temp_nodes[i].op); } swap(nodes, next_nodes); next_nodes.clear(); } int arg_best = -1, best_score = 0; for(int i = 0; i < (int)nodes.size(); i++) { if(nodes[i].get_score() > best_score) { arg_best = i; best_score = nodes[i].get_score(); } } return nodes[arg_best]; }