競技プログラミング学習

色々な問題について自分なりの方針などをまとめていきます。

ビームサーチをライブラリ化する【応用編】

前回の【基礎編】*1の記事では、ビームサーチの中でも一般に広く用いられていると思われる実装例を紹介しました。基本的にボトルネックとなるのは状態のコピーであり、コピーコストを軽くしたりスコアだけ先に計算してコピーを減らすなどして高速化を図りました。
今回紹介するのは、そもそも状態のコピーをやめてしまおうという実装方法です。一般にビームサーチにおけるノードは木構造を成していますから、毎回木を深さ優先探索(DFS)等で走査してスコアやハッシュ値を計算します。C++での実装例を3種類紹介しますが、いずれもポインタを駆使した実装になります。
本記事の内容はrhooさんによる解説記事 高速なビームサーチが欲しい!!! および 爆速ビームサーチライブラリを作る を参考にしていますので、こちらもぜひお読みください。
基礎編および応用編で紹介した実装はGitHub上でも公開しています。自由に使用していただいて構いませんが、結果に責任は負えませんことをご了承ください。

どんな問題で活躍するのか?

例えば前回題材とした「ゲーム実況者Xの挑戦」等のグリッド上を移動し続けるような問題では、グリッドの情報や各マスの訪問履歴などのコピーコストがかなり大きくなる場合があります。訪問履歴をビット列で保持する等コピーコストを小さくする工夫はできますが、それでもコストが大きそうな場合は今回紹介するコピーしない形式の実装も選択肢に入ります。
コピーコストが小さくなるように状態を設定できたら前回のような実装、それが難しいようなら今回のような実装が適しているでしょう。

実装する際の注意点

子ノードが全て不採用となった場合、そのノードはもはや不要となります。しかし不要となったノードを放置しておくと、その後のDFSにおいてもう調査する必要がない部分木を何度も行き来する可能性があります。そうなっては困るので、不採用になった葉から順に親をたどるなどして不要なノードは消去していく必要があります。これを可能にするために、各ノードには親および子ノードへのポインタをもたせることにします。

実装例1:生ポインタを用いる方法

実装方針

基礎編の実装例では各ノードに状態をコピーしていましたが、今回はノードは状態を保持しません。初期状態を一つ用意してDFSによって状態の更新と復元を繰り返しながら各ノードを調査していきます。以下のような形で実装してみます。

  1. 各ノードには親ノードへのポインタと子ノードへのポインタの配列を用意する。
  2. 初期状態を根として、子ノードへのポインタをたどりDFSで木を走査する。
  3. 子ノードへのポインタの配列が空ならばそのノードが葉であるはずなので、全ての操作を試して候補に追加する。
  4. 候補の中から上位beam_width個を採用し、葉から順に不要なノードを検出していく。
  5. 2. から 4. を最後まで繰り返す。

この大まかな方針は以降も変わりません。問題は不要なノードの検出方法とその扱いです。

実装例

ノード構造体に親ノードおよび子ノードへのポインタと、不要になったかどうかを表すexpiredという変数を用意します。

struct Node {
    Operation op;
    Node* parent;
    vector<Node*> children;
    bool expired;

    Node(Operation op, Node* parent);
};
Node::Node(Operation op, Node* parent) : op(op), parent(parent) {
    expired = false;
}

前回と同じように、スコアの比較は一時ノードで行うことにします。今回は本来のスコアraw_scoreと評価スコア(比較で使う値)eval_scoreをそれぞれ用意しました。

using ull = unsigned long long;

struct TemporaryNode {
    int raw_score, eval_score;
    ull hash;
    Operation op;
    Node* parent;

    TemporaryNode(int raw_score, int eval_score, ull hash, Operation op);
};
TemporaryNode::TemporaryNode(int raw_score, int eval_score, ull hash, Operation op) :
raw_score(raw_score), eval_score(eval_score), hash(hash), op(op) {
    parent = nullptr;
}

前回と同様に現在の状態を表すState構造体を用意します。ただし前回と違って状態を戻すという作業が加わるので、状態を更新した際に復元するための情報を返すようにします。

// 状態を復元するのに必要な情報をまとめた構造体
struct Restore {
    // ここに情報を書く
    Restore();
};

struct State {
    // (略)

    State();
    int score() const;
    ull hash() const;
    TemporaryNode try_move(const Operation& op) const;
    Restore apply_move(const Operation& op); // 状態を戻すための情報を返す

    void roll_back(const Restore& res, const Operation& op); // 状態を戻す関数が加わった
};

準備ができたのでDFSを実装します。現在の状態と根ノードへのポインタをもつTree構造体を用意し、その中にdfs関数を書きました。根ノードは操作や親ノードをもたないので、今回はそれぞれ-1とnullptrで初期化しました*2

struct Tree {
    State state;
    Node* root_node;

    Tree(State& state);
    void dfs(Node* node_ptr, vector<TemporaryNode>& temp_nodes, bool single);
};
Tree::Tree(State& state) : state(state) {
    root_node = new Node(Operation{-1}, nullptr);
}
void Tree::dfs(Node* node_ptr, vector<TemporaryNode>& temp_nodes, bool single) {
    // 子ノードが存在しないなら葉なので、候補を追加して終了
    if(node_ptr->children.empty()) {
        for(const auto& op : valid_operations) {
            temp_nodes.emplace_back(state.try_move(op));
            temp_nodes.back().parent = node_ptr;
        }
        node_ptr->expired = true;
        return;
    }

    // 使われない子ノードを削除
    node_ptr->children.erase(remove_if(node_ptr->children.begin(), node_ptr->children.end(),
    [](Node* child_ptr) { return child_ptr->expired; }), node_ptr->children.end());

    // 一本道なら状態を戻さない
    bool next_single = single && ((int)node_ptr->children.size() == 1);

    auto node_backup = node_ptr;
    // 残った子ノードを走査
    for(auto& child_ptr : node_ptr->children) {
        Restore res = state.apply_move(child_ptr->op);
        dfs(child_ptr, temp_nodes, next_single);
        if(!next_single) state.roll_back(res, child_ptr->op);
    }
    if(!next_single) root_node = node_backup;

    // ノードを不要としておき、必要なら後で復活させる
    node_ptr->expired = true;
}

実装について二つ補足します。もし根から途中まで一本道だったとすれば、初めて分岐が発生するノードまで状態を戻せば十分で、毎回初期状態まで戻す必要はありません。それをsingleという変数で判定しています。すなわち、ここまでが一本道でありかつ子が一つしかないならnext_singleもtrueとなるようにしています。
もう一つの補足は、不要になったノードをどのように判定しているかです。本来であれば不要になったノードは葉から順にdelete演算子で削除したいのですが、削除すべきノードは親からも参照されていることに注意してください。言い換えると、削除されなかったノードの一部は子を参照していたダングリングポインタ*3を保持したままになります。これは次にDFSを実行した際に問題を起こしますから、ノードをdeleteするのは得策ではなさそうです。仕方ないのでメモリの効率化は諦めて、まだ必要なノードだけ復活させることにしました。すなわち、DFSをしながら一旦全てのノードを不要としておき、採用されたノードから親をたどって順次ノードを復活させていきます。
やや強引ですが、これでDFSを実装できました。とはいえ不要なノードを残しておく意味はないですし、効率的に削除できればノードを復活させるといった作業もしなくて済みます。後で紹介する2つの実装はこの問題に対する解決策となっています。
最後にビームサーチ本体の実装です。基本的には前回の実装をベースにしています。候補の追加はDFSに任せて、採用すると決定したノードから親をたどってノードを復活させます。

vector<Operation> BeamSearch(const int max_depth, const int beam_width) {
    State init_state;
    Tree tree(init_state);

    vector<TemporaryNode> final_nodes;

    unordered_set<ull> fields;
    vector<TemporaryNode> temp_nodes;

    for(int turn = 1; turn <= max_depth; turn++) {
        fields.clear();
        temp_nodes.clear();

        tree.dfs(tree.root_node, temp_nodes, true);
        // 最後のターンなら一時ノードの情報を保存して終了
        if(turn == max_depth) {
            final_nodes = temp_nodes;
            break;
        }

        int node_size = temp_nodes.size();
        // 候補がビーム幅より多いなら上位beam_width個を選ぶ
        if(node_size > beam_width) {
            nth_element(temp_nodes.begin(), temp_nodes.begin() + beam_width, temp_nodes.end(),
            [](TemporaryNode& n1, TemporaryNode& n2) {
                return n1.eval_score > n2.eval_score;
            });
        }
        // 仮ノードの情報から実際にノードを更新する
        for(int i = 0; i < min(beam_width, node_size); i++) {
            if(fields.count(temp_nodes[i].hash)) continue;
            fields.insert(temp_nodes[i].hash);
            temp_nodes[i].parent->children.emplace_back(
                new Node(temp_nodes[i].op, temp_nodes[i].parent)
            );

            // 採用されたノードから親をたどりノードを復活させる
            Node* node_ptr = temp_nodes[i].parent;
            while(node_ptr && node_ptr->expired) {
                node_ptr->expired = false;
                node_ptr = node_ptr->parent;
            }
        }
    }

    // 最良の状態を選択
    int arg_best = -1;
    int best_score = 0;
    for(int i = 0; i < (int)final_nodes.size(); i++) {
        if(final_nodes[i].raw_score > best_score) {
            arg_best = i;
            best_score = final_nodes[i].raw_score;
        }
    }
    assert(arg_best != -1);
    Operation op = final_nodes[arg_best].op;
    auto ptr = final_nodes[arg_best].parent;

    vector<Operation> result{op};
    // 操作の復元
    while(ptr->parent) {
        result.emplace_back(ptr->op);
        ptr = ptr->parent;
    }
    reverse(result.begin(), result.end());
    return result;
}

実装例2:スマートポインタを用いる方法

スマートポインタを用いる利点

実装例1では、不正なポインタが残ってしまわないようにノードの消去をしませんでした。しかしスマートポインタを用いるとこの問題を巧妙に回避できます。今回は以下の二種類のポインタを利用します。

  • shared_ptr ・・・同じ参照先を複数のポインタで共有できる。最後のshared_ptrが破棄される時に参照先を自動で解放する。
  • weak_ptr ・・・shared_ptrの参照先のみを参照できる。weak_ptrが残っていても最後のshared_ptrは構わず参照先を解放する。

この性質の違いが重要です。そしてweak_ptrには参照先が解放されているかを返すexpiredという関数があるため、これを使えば不正なアクセスを避けることができます。
これをどう用いるのかと言うと、子ノードへのポインタはweak_ptrとし、逆に親ノードへのポインタはshared_ptrとします。各ノードは子からのshared_ptrによる参照があるため生きていますが、子が存在しなくなると親からのweak_ptrしか残らないため自動的に解放されます。挙動としては実装例1で述べたdelete演算子を用いたノードの消去に似ていますが、今回はweak_ptrを用いているため参照先が解放されているか簡単に判定できます。すなわち、次のDFSにおいて参照先が解放されているweak_ptrを削除するだけで済みます。

実装例

ノード構造体は以下のようになります。

struct Node {
    Operation op;
    shared_ptr<Node> parent;
    vector<weak_ptr<Node>> children;

    Node(Operation op, shared_ptr<Node> parent);
};
Node::Node(Operation op, shared_ptr<Node> parent) : op(op), parent(parent) {}

一時ノードにおける親ノードへのポインタも同様にshared_ptrに変更します。DFSの実装も、ノードのexpired変数をtrueに変更していた行を削除するだけで後は同じです。それではビームサーチ本体の実装を見てみます。

vector<Operation> BeamSearch(const int max_depth, const int beam_width) {
    State init_state;
    Tree tree(init_state);

    vector<shared_ptr<Node>> current_nodes;
    vector<TemporaryNode> final_nodes;

    unordered_set<ull> fields;
    vector<TemporaryNode> temp_nodes;

    for(int turn = 1; turn <= max_depth; turn++) {
        fields.clear();
        temp_nodes.clear();

        tree.dfs(tree.root_node, temp_nodes, true);
        // 最後のターンなら一時ノードの情報を保存して終了
        if(turn == max_depth) {
            final_nodes = temp_nodes;
            break;
        }

        int node_size = temp_nodes.size();
        // 候補がビーム幅より多いなら上位beam_width個を選ぶ
        if(node_size > beam_width) {
            nth_element(temp_nodes.begin(), temp_nodes.begin() + beam_width, temp_nodes.end(),
            [](TemporaryNode& n1, TemporaryNode& n2) {
                return n1.eval_score > n2.eval_score;
            });
        }
        // 仮ノードの情報から実際にノードを更新する
        current_nodes.clear();
        for(int i = 0; i < min(beam_width, node_size); i++) {
            if(fields.count(temp_nodes[i].hash)) continue;
            fields.insert(temp_nodes[i].hash);
            current_nodes.emplace_back(make_shared<Node>(temp_nodes[i].op, temp_nodes[i].parent));
            temp_nodes[i].parent->children.emplace_back(current_nodes.back());
        }
    }
    // (略)
}

実装例1とほとんど同じです。採用された葉のノードが、shared_ptrの配列current_nodesに保管されています。こうすることで一時ノードを保管しているtemp_nodesの中身が消去された時点で、不採用のノードはshared_ptrからの参照がなくなり自動的に解放されます。そして子が全て解放されたノードも自動的に解放される、といった要領で連鎖的に不要なノードが削除されていきます。

実装例3:二重連鎖木による高速化

これまでの実装の欠点と二重連鎖木を用いる利点

実装例1では不要なノードが残るため、メモリ使用量が大きくなったり必要なノードを復活させなければならないという欠点がありました。スマートポインタを用いた実装例2はこの問題点を解消できていますが、スマートポインタは丁寧にメモリを管理する分どうしても実行速度が下がってしまいます。生のポインタを使いながら不要なノードを消去する方法を考えたくなります。
実装例1では各ノードが全ての子ノードへのポインタを保持していたため、既に解放された子ノードへアクセスする危険がありました。これを回避するために、代表となる子ノードを設定しその子ノードへのポインタのみを保持することにします。そして各ノードには親ノードと子ノードに加えて前後の兄弟ノードへのポインタも用意します。言い換えると、今まで子へのポインタを並列に保持していたのを直列に変更します。このデータ構造を二重連鎖木と呼びます。
こうすると、不正なポインタが残ることを防げます。例えば削除するノードに前の兄弟がいないならそれは子の代表ですから、次の兄弟に代表を譲ってから削除します。もし前の兄弟も次の兄弟もいないなら自身が最後の子ですから、削除する時に親ノードも一緒に削除します。

実装例

ノード構造体は以下のようになります。高速化のためビームサーチの実装方法を少し変えたので、それに合わせてinitializeという初期化関数を用意しました。

struct Node {
    Operation op;
    Node* parent;
    Node* child;
    Node* prev_sibling;
    Node* next_sibling;

    Node();
    void initialize(Operation op, Node* parent);
};
Node::Node() {
    op = -1;
    parent = nullptr;
    child = nullptr;
    prev_sibling = nullptr;
    next_sibling = nullptr;
}
void Node::initialize(Operation op, Node* parent) {
    this->op = op;
    this->parent = parent;
    child = nullptr;
    prev_sibling = nullptr;
    next_sibling = nullptr;
}

一時ノードやDFSは今までと同様に実装できます。ビームサーチ本体は以下のようにしました。

vector<Operation> BeamSearch(const int max_depth, const int beam_width) {
    State init_state;
    Tree tree(init_state);

    vector<TemporaryNode> final_nodes, temp_nodes;
    vector<Node*> current_nodes, next_nodes;

    constexpr int max_nodes = 15000; // 要調整
    vector<Node> valid_nodes(max_nodes);
    tree.root_node = &valid_nodes[0];
    vector<Node*> vacant;
    for(int i = max_nodes - 1; i > 0; i--) {
        vacant.emplace_back(&valid_nodes[i]);
    }

    unordered_set<uint> fields;

    for(int turn = 1; turn <= max_depth; turn++) {
        // (略)

        // 仮ノードの情報から実際にノードを更新する
        for(int i = 0; i < min(beam_width, node_size); i++) {
            if(fields.count(temp_nodes[i].hash)) continue;
            fields.insert(temp_nodes[i].hash);

            Node* parent = temp_nodes[i].parent;
            /*
            // 必要ならばノード数が足りているか確認する
            if(vacant.empty()) {
                std::cerr << "max_nodes を大きくしてください" << std::endl;
            }
            assert(!vacant.empty());
            */
            // 既に子がいるなら新たに代表の子として挿入する
            if(parent->child) {
                parent->child->prev_sibling = vacant.back();
                vacant.back()->initialize(temp_nodes[i].op, parent);

                parent->child->prev_sibling->next_sibling = parent->child;
                parent->child = parent->child->prev_sibling;
            }
            // 子がいないなら代表の子とする
            else {
                parent->child = vacant.back();
                vacant.back()->initialize(temp_nodes[i].op, parent);
            }
            next_nodes.emplace_back(vacant.back());
            vacant.pop_back();
        }
        // 子がいないノードを再帰的に削除する
        for(auto ptr : current_nodes) {
            while(!ptr->child) {
                // 前も後ろも兄弟がいる場合
                if(ptr->prev_sibling && ptr->next_sibling) {
                    ptr->prev_sibling->next_sibling = ptr->next_sibling;
                    ptr->next_sibling->prev_sibling = ptr->prev_sibling;
                    vacant.emplace_back(ptr);
                    break;
                }
                // 前の兄弟だけいる場合
                else if(ptr->prev_sibling && !ptr->next_sibling) {
                    ptr->prev_sibling->next_sibling = nullptr;
                    vacant.emplace_back(ptr);
                    break;
                }
                // 後ろの兄弟だけいる場合
                else if(!ptr->prev_sibling && ptr->next_sibling) {
                    ptr->next_sibling->prev_sibling = nullptr;
                    ptr->parent->child = ptr->next_sibling;
                    vacant.emplace_back(ptr);
                    break;
                }
                // 両方いない場合はさらに親も削除する
                else {
                    if(!ptr->parent) {
                        vacant.emplace_back(ptr);
                        break;
                    }
                    else {
                        vacant.emplace_back(ptr);
                        ptr = ptr->parent;
                        ptr->child = nullptr;
                    }
                }
            }
        }

        swap(current_nodes, next_nodes);
        next_nodes.clear();
    }
    // (略)
}

必要になるたびにnew演算子でノードを生成するのではなく、あらかじめ一定数のノードを用意しておき、不要になったノードは後で上書きするという実装にしました*4

まとめ

今回はノードに状態をコピーしないビームサーチの実装方法を紹介しました。状態のコピーコストが大きくなってしまう場合はこちらの方が良い場合も多いかもしれません。
現在のところ、今回紹介した実装例の中では最後の二重連鎖木を用いた実装が優秀で一番ビーム幅を確保できるようです。ですが他にも良い実装方法があるかもしれませんので、皆さんもぜひ考えてみてください。実装例もあくまで一例ですので、内容を理解できた方は自分が使いやすいように一から実装してみるのも面白いと思います。

*1:本記事の内容は基礎編の内容がベースになっているために応用編としており、基礎編が簡単で応用編が難しいといった意図ではありません。

*2:より丁寧に実装するなら、操作の無効値を-1とするのではなく、std::optionalを用いてnulloptを代入することもできます。

*3:参照先のメモリが解放される等によって無効なメモリ領域を指しているポインタのことです。本来こうしたポインタにはnullptrを代入すべきですが、親ノードが保持しているポインタのうちどれが自身への参照なのかを判断する情報を子ノード側は保持していません。かといってその度に親ノードがもつ配列を走査していては時間がかかってしまいます。

*4:高速化のために色々と試している中、毎回new演算子でノードを生成すると少し遅い可能性があるとrhooさんから指摘をいただきました。ありがとうございます。