Reference: [[Make JETLS multithreaded]] (especially for `LWContainer` and `CASContainer` implementations)
Goal: Parallelize `report_package(::Module)`.
`analyze_and_report_package!` itself can be easily parallelized with the following diff:
> [!diff]- `git diff src/JETBase.jl`
> ```diff
> diff --git a/src/JETBase.jl b/src/JETBase.jl
> index 9ab03847..3a708c93 100644
> --- a/src/JETBase.jl
> +++ b/src/JETBase.jl
> @@ -327,6 +327,20 @@ Prints a report of the top-level error `report` to the given `io`.
> """
> function print_report end
>
> +mutable struct PackageAnalysisProgress
> + const reports::Vector{InferenceErrorReport}
> + const reports_lock::ReentrantLock
> + @atomic done::Int
> + @atomic analyzed::Int
> + @atomic cached::Int
> + const interval::Int
> + @atomic next_interval::Int
> + function PackageAnalysisProgress(n_sigs::Int)
> + interval = max(n_sigs ÷ 25, 1)
> + new(InferenceErrorReport[], ReentrantLock(), 0, 0, 0, interval, interval)
> + end
> +end
> +
> include("toplevel/virtualprocess.jl")
>
> # results
> @@ -968,6 +982,11 @@ struct SigAnalysisResult
> codeinst::CodeInstance
> end
>
> +struct SigWorkItem
> + siginfos::Vector{Revise.SigInfo}
> + index::Int
> +end
> +
> """
> analyze_and_report_package!(analyzer::AbstractAnalyzer, package::Module; jetconfigs...) -> JETToplevelResult
>
> @@ -991,7 +1010,6 @@ function analyze_and_report_package!(analyzer::AbstractAnalyzer, pkgmod::Module;
> end
>
> start = time()
> - counter, analyzed, cached = Ref(0), Ref(0), Ref(0)
> res = VirtualProcessResult(nothing)
> jetconfigs = set_if_missing(jetconfigs, :toplevel_logger, IOContext(stdout, JET_LOGGER_LEVEL => DEFAULT_LOGGER_LEVEL))
> config = ToplevelConfig(; jetconfigs...)
> @@ -1001,66 +1019,86 @@ function analyze_and_report_package!(analyzer::AbstractAnalyzer, pkgmod::Module;
> newstate = AnalyzerState(AnalyzerState(analyzer); world=Base.get_world_counter())
> analyzer = AbstractAnalyzer(analyzer, newstate)
>
> - n_sigs = 0
> - for fi in pkgdata.fileinfos, (_, exsigs) in fi.modexsigs, (_, siginfos) in exsigs
> - isnothing(siginfos) && continue
> - n_sigs += length(siginfos)
> - end
> + workitems = SigWorkItem[]
> for fi in pkgdata.fileinfos, (_, exsigs) in fi.modexsigs, (_, siginfos) in exsigs
> isnothing(siginfos) && continue
> for (i, siginfo) in enumerate(siginfos)
> - toplevel_logger(config) do @nospecialize(io::IO)
> - clearline(io)
> - end
> - counter[] += 1
> - inf_world = CC.get_inference_world(analyzer)
> + push!(workitems, SigWorkItem(siginfos, i))
> + end
> + end
> +
> + n_sigs = length(workitems)
> + progress = PackageAnalysisProgress(n_sigs)
> + inf_world = CC.get_inference_world(analyzer)
> +
> + toplevel_logger(config) do @nospecialize(io::IO)
> + print(io, "Analyzing top-level definitions (progress: 0/$n_sigs | interval: $(progress.interval))")
> + end
> +
> + tasks = map(workitems) do workitem
> + (; siginfos, index) = workitem
> + siginfo = siginfos[index]
> + Threads.@spawn :default try
> ext = Revise.get_extended_data(siginfo, :JET)
> + local reports::Vector{InferenceErrorReport}
> if ext !== nothing && ext.data isa SigAnalysisResult
> prev_result = ext.data::SigAnalysisResult
> if (CC.cache_owner(analyzer) === prev_result.codeinst.owner &&
> prev_result.codeinst.max_world ≥ inf_world ≥ prev_result.codeinst.min_world)
> - toplevel_logger(config) do @nospecialize(io::IO)
> - (counter[] == n_sigs ? println : print)(io, "Skipped analysis for cached definition ($(counter[])/$n_sigs)")
> - end
> - cached[] += 1
> + @atomic progress.cached += 1
> reports = prev_result.reports
> @goto gotreports
> end
> end
> + task_analyzer = AbstractAnalyzer(analyzer, AnalyzerState(AnalyzerState(analyzer)))
> match = Base._which(siginfo.sig;
> - method_table = CC.method_table(analyzer),
> + method_table = CC.method_table(task_analyzer),
> world = inf_world,
> raise = false)
> if match !== nothing
> - toplevel_logger(config; pre=clearline) do @nospecialize(io::IO)
> - if jet_logger_level(io) ≥ JET_LOGGER_LEVEL_DEBUG
> - print(io, "Analyzing top-level definition `")
> - Base.show_tuple_as_call(io, Symbol(""), siginfo.sig)
> - print(io, "` (progress: $(counter[])/$n_sigs)")
> - else
> - print(io, "Analyzing top-level definition (progress: $(counter[])/$n_sigs)")
> - end
> - end
> - result = analyze_method_signature!(analyzer,
> + result = analyze_method_signature!(task_analyzer,
> match.method, match.spec_types, match.sparams)
> - analyzed[] += 1
> - reports = get_reports(analyzer, result)
> - siginfos[i] = Revise.replace_extended_data(siginfo, :JET, SigAnalysisResult(reports, result.ci))
> - @label gotreports
> - append!(res.inference_error_reports, reports)
> + @atomic progress.analyzed += 1
> + reports = get_reports(task_analyzer, result)
> + siginfos[index] = Revise.replace_extended_data(siginfo, :JET, SigAnalysisResult(reports, result.ci))
> else
> toplevel_logger(config) do @nospecialize(io::IO)
> print(io, "Couldn't find a single matching method for the signature `")
> Base.show_tuple_as_call(io, Symbol(""), siginfo.sig)
> - println(io, "` (progress: $(counter[])/$n_sigs)")
> + println(io, "`")
> + end
> + reports = InferenceErrorReport[]
> + end
> + @label gotreports
> + isempty(reports) || @lock progress.reports_lock append!(progress.reports, reports)
> + catch err
> + @error "Error analyzing method signature" siginfo.sig
> + Base.showerror(stderr, err, catch_backtrace())
> + finally
> + done = (@atomic progress.done += 1)
> + current_next = @atomic progress.next_interval
> + if done >= current_next
> + @atomicreplace progress.next_interval current_next => current_next + progress.interval
> + toplevel_logger(config; pre=clearline) do @nospecialize(io::IO)
> + print(io, "Analyzing top-level definitions (progress: $done/$n_sigs)")
> end
> end
> end
> end
>
> + waitall(tasks)
> +
> + append!(res.inference_error_reports, progress.reports)
> +
> + toplevel_logger(config; pre=clearline) do @nospecialize(io::IO)
> + done = @atomic progress.done
> + print(io, "Analyzing top-level definitions (progress: $done/$n_sigs)")
> + end
> toplevel_logger(config; pre=println) do @nospecialize(io::IO)
> sec = round(time() - start; digits = 3)
> - println(io, "Analyzed all top-level definitions (all: $(counter[]) | analyzed: $(analyzed[]) | cached: $(cached[]) | took: $sec sec)")
> + analyzed = @atomic progress.analyzed
> + cached = @atomic progress.cached
> + println(io, "Analyzed all top-level definitions (all: $n_sigs | analyzed: $analyzed | cached: $cached | took: $sec sec)")
> end
>
> unique!(aggregation_policy(analyzer), res.inference_error_reports)
> ```
However, this diff alone is not safe. Without making `AbstractAnalyzer`'s state management (especially its cache-related state) thread-safe, data races can occur in `analyze_method_signature!(::AbstractAnalyzer, ...)`.
The state managed by `AbstractAnalyzer` can be categorized as follows:
1. State local to a single call graph inference (i.e., a single `analyze_method_signature!` call):
    - `report_stash`, `cache_target`: initialized in the `AnalyzerState` constructor
    - `concretized`, `binding_states`: only updated during top-level analysis, but top-level analysis with code loading should always be sequential, so there are no data race concerns here
    - `entry`: can be updated via `set_entry!`, but in practice it is updated only once, in the entry frame's `InferenceState` constructor (effectively linear/write-once)
2. State that can be shared across call graph inferences but is local to an `AbstractAnalyzer`:
    - `inf_cache::Vector{InferenceResult}`: cache local to call graph inference; read and written at various points in the inference routines
    - `analysis_results::IdDict{InferenceResult,AnalysisResult}`: maps each `InferenceResult` in `inf_cache` to its `AnalysisResult`; used to retrieve the final JET analysis results
3. Globally managed state:
    - `CodeInstance` cache: managed globally by inference. Each `AbstractAnalyzer` provides a unique `AnalysisToken` via the `CC.cache_owner` interface, and the `CodeInstance` cache is partitioned by that token. It can be read and written at various points in call graph inference, but is implemented thread-safely on the `CC` side.
    - Analyzer cache (a dict of `AnalysisToken`s): a global cache managed by each `AbstractAnalyzer`, storing the token used for that analyzer's `CodeInstance` cache
In summary, thread-safety needs to be introduced for the following components:
- `inf_cache::Vector{InferenceResult}`
- `analysis_results::IdDict{InferenceResult,AnalysisResult}`
- Analyzer cache
First, the analyzer cache is typically implemented as a simple `Dict{UInt,AnalysisToken}`, so it can be made thread-safe by using a `CASContainer` or `LWContainer` on each `AbstractAnalyzer`'s side.
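For illustration, a minimal lock-wrapped sketch in the spirit of `LWContainer`; the actual container implementations live in the note referenced above, so the wrapper type and the `get_analysis_token!` helper below are hypothetical, and a zero-argument `AnalysisToken()` constructor is assumed:
```julia
# Hypothetical lock-wrapped (`LWContainer`-style) analyzer cache; illustration only.
struct LockedAnalyzerCache
    dict::Dict{UInt,AnalysisToken}
    lock::ReentrantLock
end
LockedAnalyzerCache() = LockedAnalyzerCache(Dict{UInt,AnalysisToken}(), ReentrantLock())

# `get!` under the lock: even if multiple tasks race on the same analyzer
# configuration hash `h`, exactly one token is created and cached for it.
function get_analysis_token!(cache::LockedAnalyzerCache, h::UInt)
    return @lock cache.lock get!(AnalysisToken, cache.dict, h)  # assumes `AnalysisToken()` exists
end
```
A `CASContainer` variant would instead retry a compare-and-swap on an immutable snapshot, trading lock contention for occasional recomputation.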
For `inf_cache` and `analysis_results`, there are several approaches:
1. Manage `inf_cache` and `analysis_results` as local to a call graph inference, like `report_stash`/`cache_target`:
    - a. Manage them purely local to each signature inference, with no sharing between tasks
    - b. (Extension) Merge them in a thread-safe manner on the caller side (`analyze_method_signature!`) and reuse them across tasks
2. Make `inf_cache` and `analysis_results` themselves thread-safe, allowing all call graph inferences to share them
Approach 1 is the most straightforward thread-safe implementation.
Approach 2 is more complex, but could theoretically benefit from cache hits by immediately using results from other threads.
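For concreteness, a hedged sketch of how 1b's caller-side sharing could be structured; `SharedCaches` and `merge_caches!` are illustrative names (only `take_cache_snapshot!` is mentioned in the footnote below), and the identity-keyed `IdDict` corresponds to the IdDict-based 1b variant measured in the experiment:
```julia
# Sketch of approach 1b: each task snapshots the shared caches, runs inference
# against its private copy, then publishes its additions back under a lock.
struct SharedCaches
    lock::ReentrantLock
    inf_cache::IdDict{InferenceResult,Nothing}  # identity-keyed for O(1) membership
    analysis_results::IdDict{InferenceResult,AnalysisResult}
end
SharedCaches() = SharedCaches(ReentrantLock(),
    IdDict{InferenceResult,Nothing}(), IdDict{InferenceResult,AnalysisResult}())

# Copy the current shared state for task-local use. See footnote [^1] for a
# caveat: the snapshot may capture results whose inference is still in flight.
function take_cache_snapshot!(shared::SharedCaches)
    return @lock shared.lock (copy(shared.inf_cache), copy(shared.analysis_results))
end

# Merge a finished task's caches back so that later tasks can reuse them.
function merge_caches!(shared::SharedCaches, inf_cache, analysis_results)
    @lock shared.lock begin
        merge!(shared.inf_cache, inf_cache)
        merge!(shared.analysis_results, analysis_results)
    end
    return shared
end
```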
# TODO
- [x] Make analyzer cache thread-safe
- [x] Make `analyze_method_signature!` multithreaded
- [x] Try approach 1
    - [x] a. Make `inf_cache`/`analysis_results` purely local to signature analysis: 17.75s
    - [x] b. Merge `inf_cache`/`analysis_results` concurrently and reuse the caches
        - 309s (w/o inference local refactor)
        - 24.52s (w/ inference local refactor[^1])
- [-] Try approach 2
    - Cancelled. Would face the same cache bloat issues as 1b, plus the incomplete `InferenceResult` problem. Since 1b with an IdDict-based cache (24.52s) was still slower than 1a (17.75s), approach 2, with its additional thread-safe data structure overhead, would likely be even worse.
- [x] Make `analyze_from_definitions!` multithreaded
[^1]: This implementation has a bug: when `take_cache_snapshot!` copies the shared `inf_cache`, it may include `InferenceResult` objects that are still being inferred by other tasks. These incomplete results do not yet have `ci_as_edge` defined, causing `AssertionError: InferenceResult without ci_as_edge` in `const_prop_result`. To fix this, we would need to filter results by `isdefined(result, :ci_as_edge)` during the merge, but this adds complexity and the performance benefit is marginal.
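    A hedged sketch of that filter, assuming `snapshot_results` holds the copied `InferenceResult`s:

    ```julia
    # Keep only results whose inference has completed (`ci_as_edge` is set);
    # results still in flight on other tasks are dropped from the snapshot.
    completed = filter(result -> isdefined(result, :ci_as_edge), snapshot_results)
    ```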
## Performance experiment
> `julia --startup-file=no --threads=4,2 -e 'using JET; report_package(JET; target_modules=(JET,), sourceinfo=:compact);'`
| Approach | Description | Time |
| -------- | ----------------------------------------------------- | ------ |
| `master` | The current sequential implementation | 52.07s |
| 1a | Local caches, no sharing | 17.75s |
| 1b | Local caches + merge/reuse (Vector-based) | 309s |
| 1b | Local caches + merge/reuse (IdDict-based `inf_cache`) | 24.52s |
| 2 | Thread-safe shared caches (cancelled) | N/A |
> `julia --startup-file=no --threads=4,2 -e 'using JET; using Pkg; Pkg.activate(; temp=true); Pkg.add("CSV"); using CSV; report_package(CSV; target_modules=(CSV,), sourceinfo=:compact);'`
| Approach | Description | Time |
| -------- | ------------------------------------- | ------ |
| `master` | The current sequential implementation | 44.23s |
| 1a | Local caches, no sharing | 19.57s |
Notes:
- 1a is the simplest approach where each signature inference task has its own independent caches
- 1b with a Vector-based `inf_cache` suffered from O(n²) merge complexity
- 1b with an IdDict-based `inf_cache` (O(1) lookup) improved performance significantly but was still slower than 1a
- The overhead of copying and merging caches outweighs the benefit of cross-task cache hits
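To make the second note concrete, a self-contained illustration (not the actual implementation) of the two merge strategies:
```julia
# Illustration only: a Vector-based merge tests membership with a linear `===`
# scan, so merging n results into the shared cache costs O(n²) overall.
function merge_vector!(shared::Vector, task_results)
    for result in task_results
        any(cached -> cached === result, shared) || push!(shared, result)  # O(length(shared)) per test
    end
    return shared
end

# An IdDict hashes by object identity, making each lookup/insert O(1) expected,
# so the same merge is O(n) overall.
function merge_iddict!(shared::IdDict, task_results)
    foreach(result -> get!(shared, result, nothing), task_results)
    return shared
end
```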