Skip to content

perf: optimize instruction_lookups::generate_witness() with improved parallelism #634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 79 additions & 47 deletions jolt-core/src/jolt/vm/instruction_lookups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ pub struct InstructionLookupStuff<T: CanonicalSerialize + CanonicalDeserialize>
v_init_final: VerifierComputedOpening<Vec<T>>,
}

/// Note –– F: JoltField bound is not enforced.
/// Note –– F: JoltField bound is not enforced.
///
/// See issue #112792 <https://github.com/rust-lang/rust/issues/112792>.
/// Adding #![feature(lazy_type_alias)] to the crate attributes seem to break
/// `alloy_sol_types`.
pub type InstructionLookupPolynomials<F: JoltField> =
InstructionLookupStuff<MultilinearPolynomial<F>>;
/// Note –– F: JoltField bound is not enforced.
/// Note –– F: JoltField bound is not enforced.
///
/// See issue #112792 <https://github.com/rust-lang/rust/issues/112792>.
/// Adding #![feature(lazy_type_alias)] to the crate attributes seem to break
/// `alloy_sol_types`.
pub type InstructionLookupOpenings<F: JoltField> = InstructionLookupStuff<F>;
/// Note –– PCS: CommitmentScheme bound is not enforced.
/// Note –– PCS: CommitmentScheme bound is not enforced.
///
/// See issue #112792 <https://github.com/rust-lang/rust/issues/112792>.
/// Adding #![feature(lazy_type_alias)] to the crate attributes seem to break
Expand Down Expand Up @@ -868,8 +868,10 @@ where
) -> InstructionLookupPolynomials<F> {
let m: usize = ops.len().next_power_of_two();

// Calculate lookup indices in parallel mode
let subtable_lookup_indices: Vec<Vec<u16>> = Self::subtable_lookup_indices(ops);

// Compute polynomials for each memory in parallel
let polys: Vec<(
MultilinearPolynomial<F>,
MultilinearPolynomial<F>,
Expand Down Expand Up @@ -910,40 +912,55 @@ where
})
.collect();

// Vec<(MultilinearPolynomial<F>, MultilinearPolynomial<F>, MultilinearPolynomial<F>)> ->
// (Vec<MultilinearPolynomial<F>>, Vec<MultilinearPolynomial<F>>, Vec<MultilinearPolynomial<F>>)
// Unpack results into three separate vectors of polynomials
// Use pre-allocation for greater efficiency
let (read_cts, final_cts, E_polys): (
Vec<MultilinearPolynomial<F>>,
Vec<MultilinearPolynomial<F>>,
Vec<MultilinearPolynomial<F>>,
) = polys.into_iter().fold(
(Vec::new(), Vec::new(), Vec::new()),
|(mut read_acc, mut final_acc, mut E_acc), (read, f, E)| {
) = {
let num_memories = preprocessing.num_memories;
let mut read_acc = Vec::with_capacity(num_memories);
let mut final_acc = Vec::with_capacity(num_memories);
let mut E_acc = Vec::with_capacity(num_memories);

for (read, f, E) in polys {
read_acc.push(read);
final_acc.push(f);
E_acc.push(E);
(read_acc, final_acc, E_acc)
},
);
}

(read_acc, final_acc, E_acc)
};

// Convert indices to polynomials in parallel
let dim: Vec<MultilinearPolynomial<F>> = subtable_lookup_indices
.into_par_iter()
.map(MultilinearPolynomial::from)
.collect();

let mut instruction_flag_bitvectors: Vec<Vec<u8>> =
vec![vec![0; m]; Self::NUM_INSTRUCTIONS];
for (j, op) in ops.iter().enumerate() {
if let Some(instr) = &op.instruction_lookup {
instruction_flag_bitvectors[InstructionSet::enum_index(instr)][j] = 1;
// Create flag vectors for each instruction in parallel
let instruction_flag_bitvectors: Vec<Vec<u8>> = {
let mut bitvectors = vec![vec![0; m]; Self::NUM_INSTRUCTIONS];

// Process operations sequentially to avoid mutable borrow issues
for (j, op) in ops.iter().enumerate() {
if let Some(instr) = &op.instruction_lookup {
let instr_idx = InstructionSet::enum_index(instr);
bitvectors[instr_idx][j] = 1;
}
}
}

bitvectors
};

// Convert flag vectors to polynomials in parallel
let instruction_flag_polys: Vec<MultilinearPolynomial<F>> = instruction_flag_bitvectors
.into_par_iter()
.map(MultilinearPolynomial::from)
.collect();

// Compute lookup outputs in parallel
let mut lookup_outputs = Self::compute_lookup_outputs(ops);
lookup_outputs.resize(m, 0);
let lookup_outputs = MultilinearPolynomial::from(lookup_outputs);
Expand All @@ -960,6 +977,51 @@ where
}
}

/// Converts each instruction in `ops` into its corresponding subtable lookup indices.
/// The output is `C` vectors, each of length `m`.
fn subtable_lookup_indices(ops: &[JoltTraceStep<InstructionSet>]) -> Vec<Vec<u16>> {
let m = ops.len().next_power_of_two();
let log_M = M.log_2();

// Process instructions in parallel to create indices
let chunked_indices: Vec<Vec<u16>> = ops
.par_iter()
.map(|op| {
if let Some(instr) = &op.instruction_lookup {
instr
.to_indices(C, log_M)
.iter()
.map(|i| *i as u16)
.collect()
} else {
vec![0; C]
}
})
.collect();

// Create lookup indices for each dimension C in parallel
(0..C)
.into_par_iter()
.map(|i| {
let mut access_sequence: Vec<u16> = chunked_indices
.iter()
.map(|chunks| chunks[i])
.collect();
access_sequence.resize(m, 0);
access_sequence
})
.collect()
}

/// Computes the lookup output value for a single instruction
fn compute_lookup_output(
_preprocessing: &InstructionLookupsPreprocessing<C, F>,
instruction: &InstructionSet,
_indices: &[u16],
) -> u32 {
instruction.lookup_entry() as u32
}

/// Prove Jolt primary sumcheck including instruction collation.
///
/// Computes \sum{ eq(r,x) * [ flags_0(x) * g_0(E(x)) + flags_1(x) * g_1(E(x)) + ... + flags_{NUM_INSTRUCTIONS}(E(x)) * g_{NUM_INSTRUCTIONS}(E(x)) ]}
Expand Down Expand Up @@ -1222,36 +1284,6 @@ where
+ 2 // eq and flag
}

/// Converts each instruction in `ops` into its corresponding subtable lookup indices.
/// The output is `C` vectors, each of length `m`.
fn subtable_lookup_indices(ops: &[JoltTraceStep<InstructionSet>]) -> Vec<Vec<u16>> {
let m = ops.len().next_power_of_two();
let log_M = M.log_2();
let chunked_indices: Vec<Vec<u16>> = ops
.iter()
.map(|op| {
if let Some(instr) = &op.instruction_lookup {
instr
.to_indices(C, log_M)
.iter()
.map(|i| *i as u16)
.collect()
} else {
vec![0; C]
}
})
.collect();

let mut subtable_lookup_indices: Vec<Vec<u16>> = Vec::with_capacity(C);
for i in 0..C {
let mut access_sequence: Vec<u16> =
chunked_indices.iter().map(|chunks| chunks[i]).collect();
access_sequence.resize(m, 0);
subtable_lookup_indices.push(access_sequence);
}
subtable_lookup_indices
}

#[tracing::instrument(skip_all, name = "InstructionLookupsProof::compute_lookup_outputs")]
fn compute_lookup_outputs(instructions: &Vec<JoltTraceStep<InstructionSet>>) -> Vec<u32> {
instructions
Expand Down