Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aggregator optimization: update .gitignore, aggregator/aggregator.go,… #19

Draft
wants to merge 3 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 188 additions & 1 deletion aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@
aggLayerClient: aggLayerClient,
sequencerPrivateKey: sequencerPrivateKey,
witnessRetrievalChan: make(chan state.DBBatch),
rpcClient: rpc.NewBatchEndpoints(cfg.RPCURL, cfg.RPCTimeout.Duration),

Check failure on line 161 in aggregator/aggregator.go

View workflow job for this annotation

GitHub Actions / lint

cannot use rpc.NewBatchEndpoints(cfg.RPCURL, cfg.RPCTimeout.Duration) (value of type *"github.com/0xPolygon/cdk/rpc".BatchEndpoints) as RPCInterface value in struct literal: *"github.com/0xPolygon/cdk/rpc".BatchEndpoints does not implement RPCInterface (missing method GetLatestSequence)) (typecheck)

Check failure on line 161 in aggregator/aggregator.go

View workflow job for this annotation

GitHub Actions / lint

cannot use rpc.NewBatchEndpoints(cfg.RPCURL, cfg.RPCTimeout.Duration) (value of type *"github.com/0xPolygon/cdk/rpc".BatchEndpoints) as RPCInterface value in struct literal: *"github.com/0xPolygon/cdk/rpc".BatchEndpoints does not implement RPCInterface (missing method GetLatestSequence)

Check failure on line 161 in aggregator/aggregator.go

View workflow job for this annotation

GitHub Actions / lint

cannot use rpc.NewBatchEndpoints(cfg.RPCURL, cfg.RPCTimeout.Duration) (value of type *"github.com/0xPolygon/cdk/rpc".BatchEndpoints) as RPCInterface value in struct literal: *"github.com/0xPolygon/cdk/rpc".BatchEndpoints does not implement RPCInterface (missing method GetLatestSequence)) (typecheck)
}

if a.ctx == nil {
Expand Down Expand Up @@ -440,7 +440,8 @@
}

if !proofGenerated {
proofGenerated, err = a.tryGenerateBatchProof(ctx, prover)
// proofGenerated, err = a.tryGenerateBatchProof(ctx, prover)
proofGenerated, err = a.tryGenerateBatchProof_FromRpcBatch(ctx, prover)
if err != nil {
tmpLogger.Errorf("Error trying to generate proof: %v", err)
}
Expand Down Expand Up @@ -1262,6 +1263,101 @@
return stateBatch, witness, proof, nil
}

func (a *Aggregator) getAndProveBatchFromRPC(
ctx context.Context,
prover ProverInterface,
) (*state.Batch, []byte, *state.Proof, error) {
proverID := prover.ID()
proverName := prover.Name()

tmpLogger := a.logger.WithFields(
"prover", proverName,
"proverId", proverID,
"proverAddr", prover.Addr(),
)

a.storageMutex.Lock()
defer a.storageMutex.Unlock()

// Step 1: Get latest sequence from RPC
sequence, err := a.rpcClient.GetLatestSequence()
if err != nil {
tmpLogger.Errorf("Error getting latest sequence from RPC: %v", err)
return nil, nil, nil, fmt.Errorf("failed to get latest sequence: %w", err)
}

// Step 2: Fetch all batches in the sequence
batches := make([]*state.Batch, 0)
for bn := sequence.FromBatchNumber; bn <= sequence.ToBatchNumber; bn++ {
batch, err := a.rpcClient.GetBatch(bn)
if err != nil {
tmpLogger.Errorf("Error fetching batch %d in sequence: %v", bn, err)
return nil, nil, nil, fmt.Errorf("failed to get batch %d: %w", bn, err)
}
batches = append(batches, batch.ToStateBatch(a.cfg.ChainID, a.cfg.ForkId))
}

// Step 3: Merge batch data and calculate aggregated parameters
mergedL2Data := mergeBatchL2Data(batches)
l1InfoRoot, err := calculateSequenceL1InfoRoot(batches)
if err != nil {
tmpLogger.Errorf("Error calculating L1InfoRoot: %v", err)
return nil, nil, nil, fmt.Errorf("L1InfoRoot calculation failed: %w", err)
}

// Step 4: Build virtual batch with merged data
stateBatch := &state.Batch{
BatchNumber: sequence.ToBatchNumber, // Use last batch number as sequence identifier
Coinbase: batches[0].Coinbase,
BatchL2Data: mergedL2Data,
StateRoot: batches[0].StateRoot,
LocalExitRoot: batches[0].LocalExitRoot,
AccInputHash: batches[0].AccInputHash,
L1InfoTreeIndex: batches[0].L1InfoTreeIndex,
L1InfoRoot: l1InfoRoot,
Timestamp: batches[0].Timestamp,
GlobalExitRoot: batches[0].GlobalExitRoot,
ChainID: a.cfg.ChainID,
ForkID: a.cfg.ForkId,
}

// Step 5: Generate witness and proof
witness, err := a.rpcClient.GetWitness(stateBatch.BatchNumber, a.cfg.UseFullWitness)
if err != nil {
tmpLogger.Errorf("Error getting witness for sequence: %v", err)
return nil, nil, nil, fmt.Errorf("witness generation failed: %w", err)
}

// Step 6: Create proof object
now := time.Now().Round(time.Microsecond)
proof := &state.Proof{
BatchNumber: stateBatch.BatchNumber,
BatchNumberFinal: stateBatch.BatchNumber,
Prover: &proverName,
ProverID: &proverID,
GeneratingSince: &now,
}

return stateBatch, witness, proof, nil
}

// Helper function to merge L2 data from multiple batches
func mergeBatchL2Data(batches []*state.Batch) []byte {
var merged []byte
for _, b := range batches {
merged = append(merged, b.BatchL2Data...)
}
return merged
}

// Helper function to compute L1InfoRoot for sequence
func calculateSequenceL1InfoRoot(batches []*state.Batch) (common.Hash, error) {
if len(batches) == 0 {
return common.Hash{}, errors.New("empty batch list")
}
return batches[0].L1InfoRoot, nil
}

func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover ProverInterface) (bool, error) {
tmpLogger := a.logger.WithFields(
"prover", prover.Name(),
Expand Down Expand Up @@ -1360,6 +1456,97 @@
return true, nil
}

// tryGenerateBatchProof_FromRpcBatch attempts to generate a proof for the latest batch from RPC
func (a *Aggregator) tryGenerateBatchProof_FromRpcBatch(ctx context.Context, prover ProverInterface) (bool, error) {
tmpLogger := a.logger.WithFields(
"prover", prover.Name(),
"proverId", prover.ID(),
"proverAddr", prover.Addr(),
)
tmpLogger.Debug("tryGenerateBatchProof_FromRpcBatch start")

// Get batch data from RPC
batchToProve, witness, proof, err := a.getAndProveBatchFromRPC(ctx, prover)
if err != nil {
return false, err
}

tmpLogger = tmpLogger.WithFields("batch", batchToProve.BatchNumber)

var genProofID *string

// Clean up on error
defer func() {
if err != nil {
tmpLogger.Debug("Deleting proof in progress")
err2 := a.storage.DeleteGeneratedProofs(ctx, proof.BatchNumber, proof.BatchNumberFinal, nil)
if err2 != nil {
tmpLogger.Errorf("Failed to delete proof in progress, err: %v", err2)
}
}
tmpLogger.Debug("tryGenerateBatchProof_FromRpcBatch end")
}()

// Build input for prover
tmpLogger.Infof("Sending zki + batch to the prover, batchNumber [%d]", batchToProve.BatchNumber)
inputProver, err := a.buildInputProver(ctx, batchToProve, witness)
if err != nil {
err = fmt.Errorf("failed to build input prover, %w", err)
tmpLogger.Error(FirstToUpper(err.Error()))
return false, err
}

// Send batch to prover
tmpLogger.Infof("Sending a batch to the prover. OldAccInputHash [%#x], L1InfoRoot [%#x]",
inputProver.PublicInputs.OldAccInputHash, inputProver.PublicInputs.L1InfoRoot)

genProofID, err = prover.BatchProof(inputProver)
if err != nil {
err = fmt.Errorf("failed to get batch proof id, %w", err)
tmpLogger.Error(FirstToUpper(err.Error()))
return false, err
}

proof.ProofID = genProofID

// Wait for the proof to be generated
tmpLogger = tmpLogger.WithFields("proofId", *proof.ProofID)

resGetProof, stateRoot, accInputHash, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
if err != nil {
err = fmt.Errorf("failed to get proof from prover, %w", err)
tmpLogger.Error(FirstToUpper(err.Error()))
return false, err
}

tmpLogger.Info("Batch proof generated")

if a.cfg.BatchProofSanityCheckEnabled {
a.performSanityChecks(tmpLogger, stateRoot, accInputHash, batchToProve)
}
proof.Proof = resGetProof

// Attempt to build the final proof
finalProofBuilt, finalProofErr := a.tryBuildFinalProof(ctx, prover, proof)
if finalProofErr != nil {
tmpLogger.Errorf("Error trying to build final proof: %v", finalProofErr)
}

if !finalProofBuilt {
proof.GeneratingSince = nil

// Update the batch proof if the final proof has not been generated
err := a.storage.UpdateGeneratedProof(a.ctx, proof, nil)
if err != nil {
err = fmt.Errorf("failed to store batch proof result, %w", err)
tmpLogger.Error(FirstToUpper(err.Error()))
return false, err
}
}

return true, nil
}

func (a *Aggregator) performSanityChecks(tmpLogger *log.Logger, stateRoot, accInputHash common.Hash,
batchToProve *state.Batch) {
// Sanity Check: state root from the proof must match the one from the batch
Expand Down
1 change: 1 addition & 0 deletions aggregator/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
type RPCInterface interface {
GetBatch(batchNumber uint64) (*types.RPCBatch, error)
GetWitness(batchNumber uint64, fullWitness bool) ([]byte, error)
GetLatestSequence() (*state.Sequence, error)
}

type ProverInterface interface {
Expand Down
17 changes: 17 additions & 0 deletions aggregator/mocks/mock_rpc.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 45 additions & 0 deletions rpc/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"errors"
"fmt"
"math/big"
"strconv"
"strings"
"time"

"github.com/0xPolygon/cdk-rpc/rpc"
Expand Down Expand Up @@ -159,3 +161,46 @@ func (b *BatchEndpoints) GetWitness(batchNumber uint64, fullWitness bool) ([]byt

return common.FromHex(witness), nil
}

// GetLatestBatch retrieves the most recent batch from the RPC endpoint
func (b *BatchEndpoints) GetLatestBatch() (*types.RPCBatch, error) {
log.Infof("GetLatestBatch called")

// Create context with timeout
ctx, cancel := context.WithTimeout(context.Background(), b.readTimeout)
defer cancel()

// Call RPC method to get the latest batch number
log.Infof("Calling RPC method to get the latest batch number")
response, err := rpc.JSONRPCCallWithContext(ctx, b.url, "zkevm_getLatestBatchNumber")
if err != nil {
log.Infof("Error getting latest batch number: %v", err)
return nil, fmt.Errorf("error getting latest batch number: %w", err)
}

// Check if the response is nil
if response.Result == nil {
log.Infof("Response result is nil, returning ErrNotFound")
return nil, state.ErrNotFound
}

// Unmarshal the response.Result into a hex string (e.g., "0x71")
var batchNumberHex string
err = json.Unmarshal(response.Result, &batchNumberHex)
if err != nil {
log.Infof("Error unmarshalling latest batch number hex string: %v", err)
return nil, fmt.Errorf("error unmarshalling latest batch number hex string: %w", err)
}

// Convert the hex string to uint64 by trimming the "0x" prefix and parsing in base 16
batchNumber, err := strconv.ParseUint(strings.TrimPrefix(batchNumberHex, "0x"), 16, 64)
if err != nil {
log.Infof("Error parsing hex batch number: %v", err)
return nil, fmt.Errorf("error parsing hex batch number: %w", err)
}

log.Infof("Latest batch number obtained: %d", batchNumber)

// Get the batch data using the latest batch number
return b.GetBatch(batchNumber)
}
19 changes: 19 additions & 0 deletions rpc/types/rpcbatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package types

import (
"fmt"
"time"

"github.com/0xPolygon/cdk/sequencesender/seqsendertypes"
"github.com/0xPolygon/cdk/state"
"github.com/ethereum/go-ethereum/common"
)

Expand Down Expand Up @@ -155,3 +157,20 @@ func (b *RPCBatch) String() string {
func (b *RPCBatch) IsClosed() bool {
return b.closed
}

func (b *RPCBatch) ToStateBatch(chainID uint64, forkID uint64) *state.Batch {
return &state.Batch{
BatchNumber: b.BatchNumber(),
GlobalExitRoot: b.GlobalExitRoot(),
BatchL2Data: b.L2Data(),
Timestamp: time.Unix(int64(b.LastL2BLockTimestamp()), 0),
Coinbase: b.LastCoinbase(),
StateRoot: b.StateRoot(),
LocalExitRoot: b.LocalExitRoot(),
AccInputHash: b.AccInputHash(),
L1InfoTreeIndex: b.L1InfoTreeIndex(),
L1InfoRoot: common.Hash{},
ChainID: chainID,
ForkID: forkID,
}
}
Loading