[m-rev.] for review: Implement simple_tc algorithm.

Peter Wang novalazy at gmail.com
Tue Jan 17 17:34:19 AEDT 2023


Implement transitive closure using the simple_tc algorithm from
Esko Nuutila's doctoral thesis.

Based on some runs on randomly generated graphs of 100 to 3000 vertices
(see tests/hard_coded/digraph_tc.m), the simple_tc implementation was
about 1.75 to 2.8 times as fast as the old implementation on my machine.
(It would be many times faster if we did not have to maintain the
predecessor maps required by the digraph representation.)

library/digraph.m:
    Rename digraph.tc and digraph.rtc to digraph.old_tc and
    digraph.old_rtc. They are kept around for benchmarking,
    and will be deleted soon.

    Use the simple_tc algorithm to implement digraph.tc.

    Use digraph.tc to implement digraph.rtc, and move the rc helper
    next to rtc. Mark slow_rtc as to be deleted soon.

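    Let key_set_map_add call sparse_bitset.insert_new instead of
    sparse_bitset.contains followed by sparse_bitset.insert.

    Note that return_sccs_in_to_from_order actually implements
    Kosaraju's algorithm rather than Tarjan's.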

tests/hard_coded/digraph_tc.m:
    Add code to benchmark the new and old TC implementations.
---
 library/digraph.m             | 291 +++++++++++++++++++++++++++++++---
 tests/hard_coded/digraph_tc.m |  52 +++++-
 2 files changed, 320 insertions(+), 23 deletions(-)

diff --git a/library/digraph.m b/library/digraph.m
index cf2e8e896..fb831e84e 100644
--- a/library/digraph.m
+++ b/library/digraph.m
@@ -2,7 +2,7 @@
 % vim: ft=mercury ts=4 sw=4 et
 %---------------------------------------------------------------------------%
 % Copyright (C) 1995-1999,2002-2007,2010-2012 The University of Melbourne.
-% Copyright (C) 2014-2018, 2022 The Mercury team.
+% Copyright (C) 2014-2018, 2022-2023 The Mercury team.
 % This file is distributed under the terms specified in COPYING.LIB.
 %---------------------------------------------------------------------------%
 %
@@ -339,6 +339,10 @@
 :- func tc(digraph(T)) = digraph(T).
 :- pred tc(digraph(T)::in, digraph(T)::out) is det.
 
+    % Old implementation of tc, kept for benchmarking; to be deleted soon.
+    %
+:- pred old_tc(digraph(T)::in, digraph(T)::out) is det.
+
     % rtc(G, RTC) is true if RTC is the reflexive transitive closure of G.
     %
     % RTC is the reflexive closure of the transitive closure of G,
@@ -347,6 +351,10 @@
 :- func rtc(digraph(T)) = digraph(T).
 :- pred rtc(digraph(T)::in, digraph(T)::out) is det.
 
+    % Old implementation of rtc, used by old_tc; to be deleted soon.
+    %
+:- pred old_rtc(digraph(T)::in, digraph(T)::out) is det.
+
     % traverse(G, ProcessVertex, ProcessEdge, !Acc) will traverse the digraph G
     % - calling ProcessVertex for each vertex in the digraph, and
     % - calling ProcessEdge for each edge in the digraph.
@@ -375,6 +383,7 @@
 :- pred slow_tc(digraph(T)::in, digraph(T)::out) is det.
 
     % Straightforward implementation of rtc for debugging.
+    % This will be deleted soon.
     %
 :- pred slow_rtc(digraph(T)::in, digraph(T)::out) is det.
 
@@ -384,8 +393,8 @@
 :- implementation.
 
 :- import_module bimap.
-:- import_module uint.
 :- import_module require.
+:- import_module uint.
 
 %---------------------------------------------------------------------------%
 
@@ -425,11 +434,10 @@
 
 key_set_map_add(XI, Y, Map0, Map) :-
     ( if map.search(Map0, XI, SuccXs0) then
-        ( if sparse_bitset.contains(SuccXs0, Y) then
-            Map = Map0
-        else
-            sparse_bitset.insert(Y, SuccXs0, SuccXs),
+        ( if sparse_bitset.insert_new(Y, SuccXs0, SuccXs) then
             map.det_update(XI, SuccXs, Map0, Map)
+        else
+            Map = Map0
         )
     else
         SuccXs = sparse_bitset.make_singleton_set(Y),
@@ -1030,6 +1038,10 @@ digraph.return_sccs_in_from_to_order(G) = ATsort :-
 digraph.return_sccs_in_to_from_order(G) = ATsort :-
     % The algorithm used is described in R.E. Tarjan, "Depth-first search
     % and linear graph algorithms", SIAM Journal on Computing, 1, 2 (1972).
+    %
+    % Strictly speaking, this is Kosaraju's algorithm. Tarjan's algorithm
+    % improves upon it by performing one traversal of the input graph
+    % instead of two.
     digraph.dfsrev(G, DfsRev),
     digraph.inverse(G, GInv),
     sparse_bitset.init(Vis),
@@ -1066,6 +1078,235 @@ tc(G) = Tc :-
     digraph.tc(G, Tc).
 
 tc(G, Tc) :-
+    simple_tc_main(G, Tc).
+
+%---------------------------------------------------------------------------%
+
+% This implements the simple_tc algorithm from Esko Nuutila's thesis
+% "Efficient Transitive Closure Computation in Large Digraphs", p 49.
+% <http://www.cs.hut.fi/~enu/thesis.html>
+
+:- type simple_tc_visit(T)
+    --->    simple_tc_visit(
+                visit_counter   :: uint,    % The next visit number to use.
+                visit_map       :: map(digraph_key(T), uint)  % Visit numbers.
+            ).
+
+:- type simple_tc_state(T)
+    --->    simple_tc_state(
+                % A map from a vertex to the candidate root of the component
+                % that will include the vertex.
+                root_map        :: map(digraph_key(T), digraph_key(T)),
+
+                % A vertex is included in comp once the component containing
+                % the vertex has been determined.
+                comp            :: digraph_key_set(T),
+
+                % Stack of visited vertices whose component is not yet known.
+                stack           :: list(digraph_key(T)),
+
+                % The successors and predecessors of each vertex in the graph
+                % we are building.
+                succ_map        :: key_set_map(T),
+                pred_map        :: key_set_map(T)
+            ).
+
+    % Compute the transitive closure Tc of G using simple_tc.
+:- pred simple_tc_main(digraph(T)::in, digraph(T)::out) is det.
+
+simple_tc_main(G, Tc) :-
+    G = digraph(NextKey, VMap, FwdMap0, BwdMap0),
+    Visit0 = simple_tc_visit(0u, map.init),
+    State0 = simple_tc_state(map.init, sparse_bitset.init, [],
+        FwdMap0, BwdMap0),
+    % Start a search from each vertex not reached by an earlier search.
+    bimap.foldl2(simple_tc_main_loop(FwdMap0), VMap,
+        Visit0, _Visit, State0, State),
+    State = simple_tc_state(_RootMap, _Comp, _Stack, FwdMap, BwdMap),
+    Tc = digraph(NextKey, VMap, FwdMap, BwdMap).
+
+:- pred simple_tc_main_loop(key_set_map(T)::in, T::in, digraph_key(T)::in,
+    simple_tc_visit(T)::in, simple_tc_visit(T)::out,
+    simple_tc_state(T)::in, simple_tc_state(T)::out) is det.
+
+simple_tc_main_loop(Edges, _Vertex, Key, !Visit, !State) :-
+    ( if simple_tc_new_visit(Key, !Visit) then
+        simple_tc(Edges, Key, !Visit, !State)
+    else
+        true    % An earlier search already reached Key.
+    ).
+
+:- pred simple_tc_new_visit(digraph_key(T)::in,
+    simple_tc_visit(T)::in, simple_tc_visit(T)::out) is semidet.
+
+simple_tc_new_visit(V, !Visit) :-
+    % Fail if V has been visited before. Otherwise give V the current
+    % visit number, so V comes next in visit order, and advance the
+    % counter.
+    !.Visit = simple_tc_visit(Order, OrderMap0),
+    map.insert(V, Order, OrderMap0, OrderMap),
+    !:Visit = simple_tc_visit(Order + 1u, OrderMap).
+
+:- pred simple_tc(key_set_map(T)::in, digraph_key(T)::in,  % Search from V.
+    simple_tc_visit(T)::in, simple_tc_visit(T)::out,
+    simple_tc_state(T)::in, simple_tc_state(T)::out) is det.
+
+simple_tc(OrigEdges, V, !Visit, !State) :-
+    some [!RootMap, !Stack] (
+        !:RootMap = !.State ^ root_map,
+        !:Stack = !.State ^ stack,
+        % V starts out as its own candidate root, and goes on the stack.
+        map.det_insert(V, V, !RootMap),
+        !:Stack = [V | !.Stack],
+
+        !State ^ root_map := !.RootMap,
+        !State ^ stack := !.Stack
+    ),
+    % Follow each edge V -> W of the original graph (OrigEdges).
+    get_successors(OrigEdges, V, OrigSuccV),
+    sparse_bitset.foldl2(simple_tc_for_v_w(OrigEdges, V), OrigSuccV,
+        !Visit, !State),
+    % V is the root of its component iff V is its own candidate root.
+    RootMap = !.State ^ root_map,
+    ( if map.search(RootMap, V, V) then
+        some [!Stack, !Comp, !SuccMap, !PredMap] (
+            !:Stack = !.State ^ stack,
+            !:Comp = !.State ^ comp,
+            !:SuccMap = !.State ^ succ_map,
+            !:PredMap = !.State ^ pred_map,
+
+            % V is the root of a component that also contains Ws.
+            pop_component(V, Ws, !Stack),
+            sparse_bitset.insert(V, !Comp),
+            sparse_bitset.insert_list(Ws, !Comp),
+
+            % Distribute successors from the root V to other vertices in the
+            % component.
+            get_successors(!.SuccMap, V, SuccV),
+            list.foldl(add_successors(SuccV), Ws, !SuccMap),
+
+            % Maintain the predecessor map from the (new) successors back to
+            % each vertex in the component. This ends up dominating the time
+            % spent computing the transitive closure, even though the user may
+            % not make use of the predecessor map at all.
+            add_predecessors(SuccV, V, !PredMap),
+            list.foldl(add_predecessors(SuccV), Ws, !PredMap),
+
+            !State ^ stack := !.Stack,
+            !State ^ comp := !.Comp,
+            !State ^ succ_map := !.SuccMap,
+            !State ^ pred_map := !.PredMap
+        )
+    else
+        % V is not the root of a component so it remains on the stack.
+        true
+    ).
+
+:- pred simple_tc_for_v_w(key_set_map(T)::in,  % Handle the edge V -> W.
+    digraph_key(T)::in, digraph_key(T)::in,
+    simple_tc_visit(T)::in, simple_tc_visit(T)::out,
+    simple_tc_state(T)::in, simple_tc_state(T)::out) is det.
+
+simple_tc_for_v_w(OrigEdges, V, W, !Visit, !State) :-
+    ( if simple_tc_new_visit(W, !Visit) then
+        simple_tc(OrigEdges, W, !Visit, !State)
+    else
+        true
+    ),
+
+    Comp = !.State ^ comp,
+    ( if sparse_bitset.contains(Comp, W) then
+        % We already determined the component that contains W.
+        true
+    else
+        % Otherwise, update the candidate that will become the root of the
+        % component that contains W.
+        RootMap0 = !.State ^ root_map,
+        map.lookup(RootMap0, V, RootV),
+        map.lookup(RootMap0, W, RootW),
+        MinRoot = min_by_visit_order(!.Visit, RootV, RootW),
+        map.det_update(V, MinRoot, RootMap0, RootMap),
+        !State ^ root_map := RootMap
+    ),
+    % W is already in Succ(V); add the rest of Succ(W) to Succ(V).
+    SuccMap0 = !.State ^ succ_map,
+    get_successors(SuccMap0, V, SuccV),
+    get_successors(SuccMap0, W, SuccW),
+    sparse_bitset.union(SuccV, SuccW, Union),
+    V = digraph_key(VI),
+    map.set(VI, Union, SuccMap0, SuccMap),
+    !State ^ succ_map := SuccMap.
+
+:- func min_by_visit_order(simple_tc_visit(T), digraph_key(T), digraph_key(T))
+    = digraph_key(T).
+
+min_by_visit_order(Visit, X, Y) = Min :-
+    % Whichever of X and Y was visited first; X if they are the same.
+    map.lookup(Visit ^ visit_map, X, OrderX),
+    map.lookup(Visit ^ visit_map, Y, OrderY),
+    ( if OrderY < OrderX then
+        Min = Y
+    else
+        Min = X
+    ).
+
+:- pred pop_component(digraph_key(T)::in, list(digraph_key(T))::out,
+    list(digraph_key(T))::in, list(digraph_key(T))::out) is det.
+
+pop_component(Root, NonRoots, !Stack) :-
+    (
+        !.Stack = [],
+        unexpected($pred, "empty stack")    % Root must be on the stack.
+    ;
+        !.Stack = [Top | !:Stack],
+        ( if Top = Root then
+            NonRoots = []                   % Root itself is not returned.
+        else
+            pop_component(Root, NonRootsTail, !Stack),
+            NonRoots = [Top | NonRootsTail]
+        )
+    ).
+
+:- pred get_successors(key_set_map(T)::in, digraph_key(T)::in,
+    digraph_key_set(T)::out) is det.
+
+get_successors(SuccMap, digraph_key(VI), SuccV) :-
+    % A vertex with no entry in SuccMap has no successors.
+    ( if map.search(SuccMap, VI, Found) then
+        SuccV = Found
+    else
+        SuccV = sparse_bitset.init
+    ).
+
+:- pred add_successors(digraph_key_set(T)::in, digraph_key(T)::in,
+    key_set_map(T)::in, key_set_map(T)::out) is det.
+
+add_successors(Successors, digraph_key(VI), !SuccMap) :-
+    % Union Successors into the successor set of the vertex,
+    % creating that set if the vertex does not have one yet.
+    ( if map.search(!.SuccMap, VI, OldSuccs) then
+        sparse_bitset.union(Successors, OldSuccs, NewSuccs),
+        map.det_update(VI, NewSuccs, !SuccMap)
+    else
+        map.det_insert(VI, Successors, !SuccMap)
+    ).
+
+:- pred add_predecessors(digraph_key_set(T)::in, digraph_key(T)::in,
+    key_set_map(T)::in, key_set_map(T)::out) is det.
+
+add_predecessors(Successors, V, !PredMap) :-
+    % Record V as a predecessor of each vertex in Successors.
+    sparse_bitset.foldl(add_predecessor(V), Successors, !PredMap).
+
+:- pred add_predecessor(digraph_key(T)::in, digraph_key(T)::in,
+    key_set_map(T)::in, key_set_map(T)::out) is det.
+
+add_predecessor(V, digraph_key(SuccI), !PredMap) :-
+    key_set_map_add(SuccI, V, !PredMap).
+
+%---------------------------------------------------------------------------%
+
+old_tc(G, Tc) :-
     % digraph.tc returns the transitive closure of a digraph.
     % We use this procedure:
     %
@@ -1086,7 +1327,7 @@ tc(G, Tc) :-
     % benefit. We should implement TC using a known efficient algorithm,
     % then RTC can be implemented trivially on top of TC.
     %
-    digraph.rtc(G, Rtc),
+    digraph.old_rtc(G, Rtc),
 
     % Find the fake reflexives.
     digraph.keys(G, Keys),
@@ -1116,7 +1357,27 @@ detect_fake_reflexives(G, Rtc, [X | Xs], !Fakes) :-
 rtc(G) = Rtc :-
     digraph.rtc(G, Rtc).
 
-rtc(G, !:Rtc) :-
+rtc(G, Rtc) :-
+    % Take the reflexive closure of the transitive closure of G.
+    rc(digraph.tc(G), Rtc).
+
+    % rc(G, RC) is true if RC is G plus an edge from every vertex to itself.
+    %
+:- pred rc(digraph(T)::in, digraph(T)::out) is det.
+
+rc(G0, G) :-
+    digraph.keys(G0, Keys),
+    list.foldl(add_reflexive, Keys, G0, G).
+
+:- pred add_reflexive(digraph_key(T)::in,
+    digraph(T)::in, digraph(T)::out) is det.
+
+add_reflexive(Key, !Digraph) :-
+    digraph.add_edge(Key, Key, !Digraph).
+
+%---------------------------------------------------------------------------%
+
+old_rtc(G, !:Rtc) :-
     % digraph.rtc returns the reflexive transitive closure of a digraph.
     %
     % Note: This is not the most efficient algorithm (in the sense of minimal
@@ -1250,20 +1511,6 @@ slow_rtc(G, RTC) :-
     slow_tc(G, TC),
     rc(TC, RTC).
 
-    % Reflexive closure.
-    %
-:- pred rc(digraph(T)::in, digraph(T)::out) is det.
-
-rc(G, RC) :-
-    digraph.keys(G, Keys),
-    list.foldl(add_reflexive, Keys, G, RC).
-
-:- pred add_reflexive(digraph_key(T)::in,
-    digraph(T)::in, digraph(T)::out) is det.
-
-add_reflexive(X, !G) :-
-    add_edge(X, X, !G).
-
 %---------------------------------------------------------------------------%
 :- end_module digraph.
 %---------------------------------------------------------------------------%
diff --git a/tests/hard_coded/digraph_tc.m b/tests/hard_coded/digraph_tc.m
index 6c2d395a7..2f0a1087b 100644
--- a/tests/hard_coded/digraph_tc.m
+++ b/tests/hard_coded/digraph_tc.m
@@ -11,7 +11,7 @@
 
 :- import_module io.
 
-:- pred main(io::di, io::uo) is det.
+:- pred main(io::di, io::uo) is cc_multi.  % cc_multi as run_benchmark is.
 
 %---------------------------------------------------------------------------%
 %---------------------------------------------------------------------------%
@@ -19,8 +19,10 @@
 :- implementation.
 
 :- import_module array.
+:- import_module benchmarking.
 :- import_module bool.
 :- import_module digraph.
+:- import_module float.
 :- import_module int.
 :- import_module list.
 :- import_module maybe.
@@ -34,6 +36,31 @@
 
 main(!IO) :-
     io.command_line_arguments(Args, !IO),
+    ( if
+        Args = ["benchmark", NumVerticesStr, RepeatsStr],  % Benchmarking.
+        string.to_int(NumVerticesStr, NumVertices),
+        NumVertices > 1,
+        string.to_int(RepeatsStr, Repeats),
+        Repeats > 0
+    then
+        init_random(MaybeRNG, !IO),
+        (
+            MaybeRNG = error(Msg),
+            io.write_string(Msg, !IO),
+            io.nl(!IO),
+            io.set_exit_status(1, !IO)
+        ;
+            MaybeRNG = ok(RNG0),
+            generate_graph(NumVertices, G, RNG0, _RNG),
+            run_benchmark(NumVertices, G, Repeats, !IO)
+        )
+    else
+        main_2(Args, !IO)
+    ).
+
+:- pred main_2(list(string)::in, io::di, io::uo) is det.
+
+main_2(Args, !IO) :-
     ( if Args = [] then
         load_graph("digraph_tc.inp", LoadRes, !IO),
         Verbose = yes
@@ -223,6 +250,8 @@ same_graph(A, B) :-
     sort(PairsB, SortedPairsB),
     SortedPairsA = SortedPairsB.
 
+%---------------------------------------------------------------------------%
+
 :- pred write_graph(digraph(string)::in, io::di, io::uo) is det.
 
 write_graph(G, !IO) :-
@@ -238,3 +267,35 @@ write_edge(A - B, !IO) :-
     io.format("  %s -> %s;\n", [s(A), s(B)], !IO).
 
 %---------------------------------------------------------------------------%
+
+    % run_benchmark(Size, G, Repeat, !IO) times tc and old_tc on G,
+    % Repeat times each, and reports the average time per call and
+    % how many times as fast tc is as old_tc. Size is only reported.
+    %
+:- pred run_benchmark(int::in, digraph(string)::in, int::in, io::di, io::uo)
+    is cc_multi.
+
+run_benchmark(Size, G, Repeat, !IO) :-
+    NumEdges = length(to_assoc_list(G)),
+    io.format("vertices:   %d\n", [i(Size)], !IO),
+    io.format("edges:      %d\n", [i(NumEdges)], !IO),
+
+    % benchmark_det returns the total time for all Repeat calls, in ms.
+    benchmark_det(tc, G, _TC, Repeat, TimeTC),
+    AvgTimeTC = float(TimeTC) / float(Repeat),
+    io.format("tc avg:     %f ms\n", [f(AvgTimeTC)], !IO),
+
+    benchmark_det(old_tc, G, _OldTC, Repeat, OldTimeTC),
+    AvgOldTimeTC = float(OldTimeTC) / float(Repeat),
+    io.format("old_tc avg: %f ms\n", [f(AvgOldTimeTC)], !IO),
+
+    % For small graphs TimeTC can be zero, and float division by zero
+    % throws an exception, so only report the ratio when it is defined.
+    ( if TimeTC > 0 then
+        F = float(OldTimeTC) / float(TimeTC),
+        io.format("%f times as fast\n\n", [f(F)], !IO)
+    else
+        io.write_string("tc time too small to compare\n\n", !IO)
+    ).
+
+%---------------------------------------------------------------------------%
-- 
2.39.0



More information about the reviews mailing list