purify
C++ Purify implementation with native circuit and BPP support
Loading...
Searching...
No Matches
expr.cpp
Go to the documentation of this file.
1// Copyright (c) 2026 Judica, Inc.
2// Distributed under the MIT software license, see the accompanying
3// file COPYING or https://opensource.org/license/mit/.
4
10#include "purify/expr.hpp"
11
12#include <algorithm>
13#include <array>
14#include <cassert>
15#include <charconv>
16#include <limits>
17#include <sstream>
18#include <string_view>
19#include <utility>
20
21namespace {
22
23int compare_field_elements(const purify::FieldElement& lhs, const purify::FieldElement& rhs) {
24 return lhs.to_uint256().compare(rhs.to_uint256());
25}
26
27int compare_symbols(const purify::Symbol& lhs, const purify::Symbol& rhs) {
28 const purify::SymbolLess less;
29 if (less(lhs, rhs)) {
30 return -1;
31 }
32 if (less(rhs, lhs)) {
33 return 1;
34 }
35 return 0;
36}
37
// Builds "<prefix><index in decimal><suffix>" without going through a stream.
// The scratch buffer holds digits10 + 1 characters, which is enough for every
// std::uint32_t value, so std::to_chars is only checked with an assert.
std::string make_symbol_string(std::string_view prefix, std::uint32_t index, std::string_view suffix = {}) {
    constexpr std::size_t kMaxDigits = std::numeric_limits<std::uint32_t>::digits10 + 1;
    std::array<char, kMaxDigits> buffer{};
    char* const first = buffer.data();
    const auto [last, ec] = std::to_chars(first, first + buffer.size(), index);
    assert(ec == std::errc() && "uint32_t index must format into a fixed-size decimal buffer");

    const auto digit_count = static_cast<std::size_t>(last - first);
    std::string text;
    text.reserve(prefix.size() + digit_count + suffix.size());
    text.append(prefix);
    text.append(first, digit_count);
    text.append(suffix);
    return text;
}
50
51} // namespace
52
53namespace purify {
54
55Symbol Symbol::witness(std::uint32_t index) {
56 return {SymbolKind::Witness, index};
57}
58
59Symbol Symbol::left(std::uint32_t index) {
60 return {SymbolKind::Left, index};
61}
62
63Symbol Symbol::right(std::uint32_t index) {
64 return {SymbolKind::Right, index};
65}
66
67Symbol Symbol::output(std::uint32_t index) {
68 return {SymbolKind::Output, index};
69}
70
71Symbol Symbol::commitment(std::uint32_t index) {
73}
74
75std::string Symbol::to_string() const {
76 switch (kind) {
78 return make_symbol_string("v[", index, "]");
80 return make_symbol_string("L", index);
82 return make_symbol_string("R", index);
84 return make_symbol_string("O", index);
86 return make_symbol_string("V", index);
87 }
88 assert(false && "unknown SymbolKind");
89 return "?";
90}
91
92Expr::Expr() : constant_(FieldElement::zero()) {}
93
94Expr::Expr(const FieldElement& value) : constant_(value) {}
95
96Expr::Expr(std::int64_t value) : constant_(FieldElement::from_int(value)) {}
97
99 Expr out;
100 out.linear_.push_back({symbol, FieldElement::one()});
101 return out;
102}
103
104std::string Expr::to_string() const {
105 std::ostringstream out;
106 bool first = true;
107 if (!constant_.is_zero() || linear_.empty()) {
108 out << constant_.to_decimal();
109 first = false;
110 }
111 for (const auto& term : linear_) {
112 if (!first) {
113 out << " + ";
114 }
115 if (term.second == FieldElement::one()) {
116 out << term.first.to_string();
117 } else {
118 out << term.second.to_decimal() << " * " << term.first.to_string();
119 }
120 first = false;
121 }
122 return out.str();
123}
124
125std::optional<FieldElement> Expr::evaluate(const WitnessAssignments& values) const {
126 FieldElement out = constant_;
127 for (const auto& term : linear_) {
128 if (term.first.kind != SymbolKind::Witness) {
129 return std::nullopt;
130 }
131 std::size_t index = term.first.index;
132 if (index >= values.size() || !values[index].has_value()) {
133 return std::nullopt;
134 }
135 out = out + (*values[index] * term.second);
136 }
137 return out;
138}
139
140std::pair<Expr, Expr> Expr::split() const {
141 Expr linear_expr(0);
142 linear_expr.linear_ = linear_;
143 return {Expr(constant_), linear_expr};
144}
145
146void Expr::push_term(const Term& term) {
147 if (term.second.is_zero()) {
148 return;
149 }
150 if (!linear_.empty() && linear_.back().first == term.first) {
151 linear_.back().second = linear_.back().second + term.second;
152 if (linear_.back().second.is_zero()) {
153 linear_.pop_back();
154 }
155 return;
156 }
157 linear_.push_back(term);
158}
159
161 ExprBuilder builder;
162 builder.terms_.reserve(terms);
163 return builder;
164}
165
167 terms_.reserve(terms);
168 return *this;
169}
170
172 constant_ = constant_ + value;
173 return *this;
174}
175
176ExprBuilder& ExprBuilder::add(std::int64_t value) {
177 return add(FieldElement::from_int(value));
178}
179
181 if (!scale.is_zero()) {
182 terms_.push_back({symbol, scale});
183 }
184 return *this;
185}
186
188 constant_ = constant_ + expr.constant();
189 if (expr.linear().empty()) {
190 return *this;
191 }
192 terms_.reserve(terms_.size() + expr.linear().size());
193 for (const auto& term : expr.linear()) {
194 terms_.push_back(term);
195 }
196 return *this;
197}
198
200 constant_ = constant_ - expr.constant();
201 if (expr.linear().empty()) {
202 return *this;
203 }
204 terms_.reserve(terms_.size() + expr.linear().size());
205 for (const auto& term : expr.linear()) {
206 FieldElement coeff = term.second.negate();
207 if (!coeff.is_zero()) {
208 terms_.push_back({term.first, coeff});
209 }
210 }
211 return *this;
212}
213
215 if (scale.is_zero()) {
216 return *this;
217 }
218 if (scale.is_one()) {
219 return add(expr);
220 }
221 if (scale == FieldElement::from_int(-1)) {
222 return subtract(expr);
223 }
224 constant_ = constant_ + expr.constant() * scale;
225 if (expr.linear().empty()) {
226 return *this;
227 }
228 terms_.reserve(terms_.size() + expr.linear().size());
229 for (const auto& term : expr.linear()) {
230 FieldElement coeff = term.second * scale;
231 if (!coeff.is_zero()) {
232 terms_.push_back({term.first, coeff});
233 }
234 }
235 return *this;
236}
237
238ExprBuilder& ExprBuilder::add_scaled(const Expr& expr, std::int64_t scale) {
239 return add_scaled(expr, FieldElement::from_int(scale));
240}
241
243 Expr out(constant_);
244 if (terms_.empty()) {
245 return out;
246 }
247 std::sort(terms_.begin(), terms_.end(), [](const Expr::Term& lhs, const Expr::Term& rhs) {
248 return SymbolLess{}(lhs.first, rhs.first);
249 });
250 auto& linear = out.linear();
251 linear.reserve(terms_.size());
252 for (const auto& term : terms_) {
253 if (!linear.empty() && linear.back().first == term.first) {
254 linear.back().second = linear.back().second + term.second;
255 if (linear.back().second.is_zero()) {
256 linear.pop_back();
257 }
258 } else if (!term.second.is_zero()) {
259 linear.push_back(term);
260 }
261 }
262 return out;
263}
264
265Expr operator+(const Expr& lhs, const Expr& rhs) {
266 Expr out(lhs.constant_ + rhs.constant_);
267 out.linear_.reserve(lhs.linear_.size() + rhs.linear_.size());
268 const SymbolLess less;
269 std::size_t i = 0;
270 std::size_t j = 0;
271 while (i < lhs.linear_.size() || j < rhs.linear_.size()) {
272 if (j == rhs.linear_.size() || (i < lhs.linear_.size() && less(lhs.linear_[i].first, rhs.linear_[j].first))) {
273 out.push_term(lhs.linear_[i]);
274 ++i;
275 } else if (i == lhs.linear_.size() || less(rhs.linear_[j].first, lhs.linear_[i].first)) {
276 out.push_term(rhs.linear_[j]);
277 ++j;
278 } else {
279 out.push_term({lhs.linear_[i].first, lhs.linear_[i].second + rhs.linear_[j].second});
280 ++i;
281 ++j;
282 }
283 }
284 return out;
285}
286
287Expr operator+(const Expr& lhs, std::int64_t rhs) {
288 return lhs + Expr(rhs);
289}
290
291Expr operator+(std::int64_t lhs, const Expr& rhs) {
292 return Expr(lhs) + rhs;
293}
294
295Expr operator-(const Expr& lhs, const Expr& rhs) {
296 return lhs + (-rhs);
297}
298
299Expr operator-(const Expr& lhs, std::int64_t rhs) {
300 return lhs - Expr(rhs);
301}
302
303Expr operator-(std::int64_t lhs, const Expr& rhs) {
304 return Expr(lhs) - rhs;
305}
306
307Expr operator-(const Expr& value) {
308 return value * FieldElement::from_int(-1);
309}
310
311Expr operator*(const Expr& expr, const FieldElement& scalar) {
312 if (scalar.is_zero()) {
313 return Expr(0);
314 }
315 Expr out(expr.constant_ * scalar);
316 out.linear_.reserve(expr.linear_.size());
317 for (const auto& term : expr.linear_) {
318 out.linear_.push_back({term.first, term.second * scalar});
319 }
320 return out;
321}
322
323Expr operator*(const FieldElement& scalar, const Expr& expr) {
324 return expr * scalar;
325}
326
327Expr operator*(const Expr& expr, std::int64_t scalar) {
328 return expr * FieldElement::from_int(scalar);
329}
330
331Expr operator*(std::int64_t scalar, const Expr& expr) {
332 return expr * scalar;
333}
334
335bool operator==(const Expr& lhs, const Expr& rhs) {
336 return lhs.constant_ == rhs.constant_ && lhs.linear_ == rhs.linear_;
337}
338
339bool ExprLess::operator()(const Expr& lhs, const Expr& rhs) const {
340 int constant_cmp = compare_field_elements(lhs.constant(), rhs.constant());
341 if (constant_cmp != 0) {
342 return constant_cmp < 0;
343 }
344 std::size_t common = std::min(lhs.linear().size(), rhs.linear().size());
345 for (std::size_t i = 0; i < common; ++i) {
346 int symbol_cmp = compare_symbols(lhs.linear()[i].first, rhs.linear()[i].first);
347 if (symbol_cmp != 0) {
348 return symbol_cmp < 0;
349 }
350 int coeff_cmp = compare_field_elements(lhs.linear()[i].second, rhs.linear()[i].second);
351 if (coeff_cmp != 0) {
352 return coeff_cmp < 0;
353 }
354 }
355 return lhs.linear().size() < rhs.linear().size();
356}
357
358bool ExprPairLess::operator()(const std::pair<Expr, Expr>& lhs, const std::pair<Expr, Expr>& rhs) const {
359 const ExprLess less;
360 if (less(lhs.first, rhs.first)) {
361 return true;
362 }
363 if (less(rhs.first, lhs.first)) {
364 return false;
365 }
366 return less(lhs.second, rhs.second);
367}
368
369bool operator<(const Expr& lhs, const Expr& rhs) {
370 return ExprLess{}(lhs, rhs);
371}
372
373std::ostream& operator<<(std::ostream& out, const Expr& expr) {
374 out << expr.to_string();
375 return out;
376}
377
378Expr Transcript::secret(const std::optional<FieldElement>& value) {
379 std::size_t index = varmap_.size();
380 assert(index <= static_cast<std::size_t>(std::numeric_limits<std::uint32_t>::max())
381 && "Transcript::secret() witness index must fit in uint32_t");
382 varmap_.push_back(value);
383 return Expr::variable(Symbol::witness(static_cast<std::uint32_t>(index)));
384}
385
386Expr Transcript::mul(const Expr& lhs, const Expr& rhs) {
387 auto direct = std::make_pair(lhs, rhs);
388 auto reverse = std::make_pair(rhs, lhs);
389 auto it = mul_cache_.find(direct);
390 if (it != mul_cache_.end()) {
391 return it->second;
392 }
393 it = mul_cache_.find(reverse);
394 if (it != mul_cache_.end()) {
395 return it->second;
396 }
397 std::optional<FieldElement> lhs_val = lhs.evaluate(varmap_);
398 std::optional<FieldElement> rhs_val = rhs.evaluate(varmap_);
399 std::optional<FieldElement> value;
400 if (lhs_val.has_value() && rhs_val.has_value()) {
401 value = *lhs_val * *rhs_val;
402 }
403 Expr out = secret(value);
404 mul_cache_[direct] = out;
405 muls_.push_back({lhs, rhs, out});
406 return out;
407}
408
409Expr Transcript::div(const Expr& lhs, const Expr& rhs) {
410 auto direct = std::make_pair(lhs, rhs);
411 auto it = div_cache_.find(direct);
412 if (it != div_cache_.end()) {
413 return it->second;
414 }
415 std::optional<FieldElement> lhs_val = lhs.evaluate(varmap_);
416 std::optional<FieldElement> rhs_val = rhs.evaluate(varmap_);
417 assert((!rhs_val.has_value() || !rhs_val->is_zero()) && "Transcript::div() requires a non-zero known divisor");
418 std::optional<FieldElement> value;
419 if (lhs_val.has_value() && rhs_val.has_value()) {
420 value = *lhs_val * rhs_val->inverse();
421 }
422 Expr out = secret(value);
423 div_cache_[direct] = out;
424 muls_.push_back({out, rhs, lhs});
425 return out;
426}
427
428Expr Transcript::boolean(const Expr& expr) {
429 if (bool_cache_.count(expr) != 0) {
430 return expr;
431 }
432#ifndef NDEBUG
433 std::optional<FieldElement> value = expr.evaluate(varmap_);
434 assert((!value.has_value() || *value == FieldElement::zero() || *value == FieldElement::one())
435 && "Transcript::boolean() requires a known value to be 0 or 1");
436#endif
437 bool_cache_.insert(expr);
438 muls_.push_back({expr, expr - 1, Expr(0)});
439 return expr;
440}
441
442void Transcript::equal(const Expr& lhs, const Expr& rhs) {
443 Expr diff = lhs - rhs;
444#ifndef NDEBUG
445 std::optional<FieldElement> value = diff.evaluate(varmap_);
446 assert((!value.has_value() || value->is_zero()) && "Transcript::equal() requires known values to match");
447#endif
448 eqs_.push_back(diff);
449}
450
451std::optional<FieldElement> Transcript::evaluate(const Expr& expr) const {
452 return expr.evaluate(varmap_);
453}
454
455} // namespace purify
Small runtime builder that flattens affine combinations into one expression.
Definition expr.hpp:170
ExprBuilder & add(const FieldElement &value)
Adds a constant field term to the pending affine expression.
Definition expr.cpp:171
ExprBuilder & add_term(Symbol symbol, const FieldElement &scale)
Adds one scaled symbolic variable term.
Definition expr.cpp:180
static ExprBuilder reserved(std::size_t terms)
Returns a builder with storage reserved for approximately terms linear slots.
Definition expr.cpp:160
Expr build()
Materializes the flattened affine expression.
Definition expr.cpp:242
ExprBuilder & add_scaled(const Expr &expr, const FieldElement &scale)
Adds an existing expression scaled by a field element.
Definition expr.cpp:214
ExprBuilder & subtract(const Expr &expr)
Subtracts an existing expression with implicit coefficient minus one.
Definition expr.cpp:199
ExprBuilder & reserve(std::size_t terms)
Reserves storage for approximately terms linear slots.
Definition expr.cpp:166
Symbolic affine expression over indexed variables and field coefficients.
Definition expr.hpp:71
const FieldElement & constant() const
Returns the constant term of the affine expression.
Definition expr.hpp:86
std::string to_string() const
Formats the expression in a stable human-readable form used for debugging and serialization.
Definition expr.cpp:104
std::vector< Term > & linear()
Returns mutable access to the sorted linear term list.
Definition expr.hpp:91
Expr()
Constructs the zero expression.
Definition expr.cpp:92
std::pair< Symbol, FieldElement > Term
Definition expr.hpp:73
static Expr variable(Symbol symbol)
Returns a single-variable expression with coefficient one.
Definition expr.cpp:98
std::optional< FieldElement > evaluate(const WitnessAssignments &values) const
Evaluates the expression against a possibly partial transcript witness assignment.
Definition expr.cpp:125
std::pair< Expr, Expr > split() const
Splits the expression into a pure constant and a pure linear component.
Definition expr.cpp:140
Field element modulo the backend scalar field used by this implementation.
Definition numeric.hpp:815
bool is_one() const
Returns true when the element is one.
Definition numeric.cpp:108
std::string to_decimal() const
Formats the field element as an unsigned decimal string.
Definition numeric.cpp:100
static FieldElement one()
Returns the multiplicative identity of the scalar field.
Definition numeric.cpp:36
FieldElement negate() const
Returns the additive inverse modulo the field prime.
Definition numeric.cpp:121
static FieldElement from_int(std::int64_t value)
Constructs a field element from a signed integer, reducing negatives modulo the field.
Definition numeric.cpp:46
UInt256 to_uint256() const
Exports the field element as a canonical 256-bit unsigned integer.
Definition numeric.cpp:79
bool is_zero() const
Returns true when the element is zero.
Definition numeric.cpp:104
Symbolic expression and transcript machinery used to derive Purify circuits.
Definition api.hpp:21
bool operator<(const Symbol &lhs, const Symbol &rhs) noexcept
Definition expr.hpp:58
Expr operator*(const Expr &expr, const FieldElement &scalar)
Definition expr.cpp:311
std::ostream & operator<<(std::ostream &out, const Expr &expr)
Streams the human-readable expression form to an output stream.
Definition expr.cpp:373
bool operator==(const Expr &lhs, const Expr &rhs)
Definition expr.cpp:335
Expr operator-(const Expr &lhs, const Expr &rhs)
Definition expr.cpp:295
std::vector< std::optional< FieldElement > > WitnessAssignments
Partial witness assignment vector indexed by transcript witness id.
Definition expr.hpp:63
Bytes operator+(Bytes lhs, const Bytes &rhs)
Concatenates two byte vectors.
Definition curve.cpp:167
Scalar32 scalar
Definition bppp.cpp:119
int compare(const BigUInt &other) const
Compares two fixed-width integers using unsigned ordering.
Definition numeric.hpp:301
Compact symbolic variable identifier used inside expressions and transcripts.
Definition expr.hpp:35
static Symbol witness(std::uint32_t index)
Definition expr.cpp:55
static Symbol left(std::uint32_t index)
Definition expr.cpp:59
static Symbol output(std::uint32_t index)
Definition expr.cpp:67
std::uint32_t index
Definition expr.hpp:37
std::string to_string() const
Definition expr.cpp:75
static Symbol commitment(std::uint32_t index)
Definition expr.cpp:71
static Symbol right(std::uint32_t index)
Definition expr.cpp:63
SymbolKind kind
Definition expr.hpp:36