source: trunk/misc/simulators/hashbasedsig.py

Last change on this file was 4711c9f, checked in by Itamar Turner-Trauring <itamar@…>, at 2021-05-07T13:44:44Z

More flake fixes.

  • Property mode set to 100644
File size: 13.9 KB
Line 
1#!python
2
3from __future__ import print_function
4
# range of hash output lengths
range_L_hash = [128]

lg_M = 53                   # lg(required number of signatures before losing security)

limit_bytes = 480000        # limit on signature length
limit_cost = 500            # limit on Mcycles_Sig + weight_ver*Mcycles_ver
weight_ver = 1              # how important verification cost is relative to signature cost
                            # (note: setting this too high will just exclude useful candidates)

L_block = 512               # bitlength of hash input blocks
L_pad   = 64                # bitlength of hash padding overhead (for M-D hashes)
L_label = 80                # bitlength of hash position label
L_prf   = 256               # bitlength of hash output when used as a PRF
cycles_per_byte = 15.8      # cost of hash

# cost, in megacycles, of one compression-function evaluation over a full block
Mcycles_per_block = cycles_per_byte * L_block / (8 * 1000000.0)
22
23
24from math import floor, ceil, log, log1p, pow, e
25from sys import stderr
26from gc import collect
27
def lg(x):
    """Return the base-2 logarithm of x."""
    return log(x) / log(2)
def ln(x):
    """Return the natural (base-e) logarithm of x.

    Uses one-argument log(x) directly; the original two-argument form
    log(x, e) computed log(x)/log(e), an unnecessary division by 1.0.
    """
    return log(x)
def ceil_log(x, B):
    """Return ceil(log base B of x) as an int."""
    exponent = log(x, B)
    return int(ceil(exponent))
def ceil_div(x, y):
    """Return ceil(x / y) as an int.

    Uses -(-x // y) instead of float division so that integer inputs are
    handled exactly (the old float round-trip loses precision above 2**53).
    For float inputs, -floor(-x/y) equals ceil(x/y) since IEEE negation is
    exact, so results are unchanged from the previous implementation.
    """
    return int(-(-x // y))
def floor_div(x, y):
    """Return floor(x / y) as an int.

    Uses // directly so that integer inputs are handled exactly (the old
    float round-trip loses precision above 2**53); for floats, x // y is
    floor(x / y), matching the previous behaviour.
    """
    return int(x // y)
38
39# number of compression function evaluations to hash k bits
40# we assume that there is a label in each block
41def compressions(k):
42    return ceil_div(k + L_pad, L_block - L_label)
43
# sum of power series sum([pow(p, i) for i in range(n)])
def sum_powers(p, n):
    """Closed form of the geometric series 1 + p + p^2 + ... + p^(n-1).

    Note: `pow` here is math.pow (see module imports), so the non-trivial
    branch returns a float.
    """
    if p == 1:
        return n
    numerator = pow(p, n) - 1
    return numerator / (p - 1)
48
49
def make_candidate(B, K, K1, K2, q, T, T_min, L_hash, lg_N, sig_bytes, c_sign, c_ver, c_ver_pm):
    """Build a one-element list holding the record for this parameter set,
    or an empty list when the signature size or combined signing/verification
    cost exceeds the configured global limits."""
    Mcycles_sign   = c_sign   * Mcycles_per_block
    Mcycles_ver    = c_ver    * Mcycles_per_block
    Mcycles_ver_pm = c_ver_pm * Mcycles_per_block
    cost = Mcycles_sign + weight_ver*Mcycles_ver

    too_large  = sig_bytes >= limit_bytes
    too_costly = cost > limit_cost
    if too_large or too_costly:
        return []

    candidate = {
        'B': B, 'K': K, 'K1': K1, 'K2': K2, 'q': q, 'T': T,
        'T_min': T_min,
        'L_hash': L_hash,
        'lg_N': lg_N,
        'sig_bytes': sig_bytes,
        'c_sign': c_sign,
        'Mcycles_sign': Mcycles_sign,
        'c_ver': c_ver,
        'c_ver_pm': c_ver_pm,
        'Mcycles_ver': Mcycles_ver,
        'Mcycles_ver_pm': Mcycles_ver_pm,
        'cost': cost,
    }
    return [candidate]
73
74
# K1 = size of root Merkle tree
# K  = size of middle Merkle trees
# K2 = size of leaf Merkle trees
# q  = number of revealed private keys per signed message

# Winternitz with B < 4 is never optimal. For example, going from B=4 to B=2 halves the
# chain depth, but that is cancelled out by doubling (roughly) the number of digits.
range_B = range(4, 33)

# required number of signatures before losing security (math.pow, so a float;
# 2^53 is exactly representable as a double)
M = pow(2, lg_M)
85
def calculate(K, K1, K2, q_max, L_hash, trees):
    """Enumerate candidate GMSS+HORS parameter sets for the given tree
    arities (K1 = root, K = middle, K2 = leaf) and hash length L_hash.

    For each certification-tree height T (starting from the minimum),
    find the smallest q giving L_hash bits of forgery resistance, then
    try each Winternitz base B and collect the candidates that pass the
    global size/cost limits (via make_candidate).

    trees maps a tree size to (height, compressions, (dau, tri)) as
    precomputed in search().
    """
    candidates = []
    lg_K  = lg(K)
    lg_K1 = lg(K1)
    lg_K2 = lg(K2)

    # We want the optimal combination of q and T. That takes too much time and memory
    # to search for directly, so we start by calculating the lowest possible value of T
    # for any q. Then for potential values of T, we calculate the smallest q such that we
    # will have at least L_hash bits of security against forgery using revealed private keys
    # (i.e. this method of forgery is no easier than finding a hash preimage), provided
    # that fewer than 2^lg_S_min messages are signed.

    # min height of certification tree (excluding root and bottom layer)
    T_min = ceil_div(lg_M - lg_K1, lg_K)

    last_q = None
    for T in range(T_min, T_min+21):
        # lg(total number of leaf private keys)
        lg_S = lg_K1 + lg_K*T
        lg_N = lg_S + lg_K2

        # Suppose that m signatures have been made. The number of times X that a given bucket has
        # been chosen follows a binomial distribution B(m, p) where p = 1/S and S is the number of
        # buckets. I.e. Pr(X = x) = C(m, x) * p^x * (1-p)^(m-x).
        #
        # If an attacker picks a random seed and message that falls into a bucket that has been
        # chosen x times, then at most q*x private values in that bucket have been revealed, so
        # (ignoring the possibility of guessing private keys, which is negligable) the attacker's
        # success probability for a forgery using the revealed values is at most min(1, q*x / K2)^q.
        #
        # Let j = floor(K2/q). Conditioning on x, we have
        #
        # Pr(forgery) = sum_{x = 0..j}(Pr(X = x) * (q*x / K2)^q) + Pr(x > j)
        #             = sum_{x = 1..j}(Pr(X = x) * (q*x / K2)^q) + Pr(x > j)
        #
        # We lose nothing by approximating (q*x / K2)^q as 1 for x > 4, i.e. ignoring the resistance
        # of the HORS scheme to forgery when a bucket has been chosen 5 or more times.
        #
        # Pr(forgery) < sum_{x = 1..4}(Pr(X = x) * (q*x / K2)^q) + Pr(x > 4)
        #
        # where Pr(x > 4) = 1 - sum_{x = 0..4}(Pr(X = x))
        #
        # We use log arithmetic here because values very close to 1 cannot be represented accurately
        # in floating point, but their logarithms can (provided we use appropriate functions such as
        # log1p).

        lg_p = -lg_S
        lg_1_p = log1p(-pow(2, lg_p))/ln(2)        # lg(1-p), computed accurately
        j = 5
        # lg_px[x] approximates lg(Pr(X = x)); entry 0 is lg(Pr(X = 0)) = M*lg(1-p)
        lg_px = [lg_1_p * M]*j

        # We approximate lg(M-x) as lg(M)
        lg_px_step = lg_M + lg_p - lg_1_p
        for x in range(1, j):
            lg_px[x] = lg_px[x-1] - lg(x) + lg_px_step

        q = None
        # Find the minimum acceptable value of q.
        for q_cand in range(1, q_max+1):
            lg_q = lg(q_cand)
            lg_pforge = [lg_px[x] + (lg_q*x - lg_K2)*q_cand for x in range(1, j)]
            if max(lg_pforge) < -L_hash + lg(j) and lg_px[j-1] + 1.0 < -L_hash:
                #print("K = %d, K1 = %d, K2 = %d, L_hash = %d, lg_K2 = %.3f, q = %d, lg_pforge_1 = %.3f, lg_pforge_2 = %.3f, lg_pforge_3 = %.3f"
                #      % (K, K1, K2, L_hash, lg_K2, q, lg_pforge_1, lg_pforge_2, lg_pforge_3))
                q = q_cand
                break

        if q is None or q == last_q:
            # if q hasn't decreased, this will be strictly worse than the previous candidate
            continue
        last_q = q

        # number of compressions to compute the Merkle hashes
        (h_M,  c_M,  _) = trees[K]
        (h_M1, c_M1, _) = trees[K1]
        (h_M2, c_M2, (dau, tri)) = trees[K2]   # NOTE: (dau, tri) is unpacked but unused below

        # B = generalized Winternitz base
        for B in range_B:
            # n is the number of digits needed to sign the message representative and checksum.
            # The representation is base-B, except that we allow the most significant digit
            # to be up to 2B-1.
            n_L = ceil_div(L_hash-1, lg(B))
            firstL_max = floor_div(pow(2, L_hash)-1, pow(B, n_L-1))
            C_max = firstL_max + (n_L-1)*(B-1)
            n_C = ceil_log(ceil_div(C_max, 2), B)
            n = n_L + n_C
            firstC_max = floor_div(C_max, pow(B, n_C-1))

            # Total depth of Winternitz hash chains. The chains for the most significant
            # digit of the message representative and of the checksum may be a different
            # length to those for the other digits.
            c_D = (n-2)*(B-1) + firstL_max + firstC_max

            # number of compressions to hash a Winternitz public key
            c_W = compressions(n*L_hash)

            # bitlength of a single Winternitz signature and authentication path
            L_MW  = (n + h_M ) * L_hash
            L_MW1 = (n + h_M1) * L_hash

            # bitlength of the HORS signature and authentication paths
            # For all but one of the q authentication paths, one of the sibling elements in
            # another path is made redundant where they intersect. This cancels out the hash
            # that would otherwise be needed at the bottom of the path, making the total
            # length of the signature q*h_M2 + 1 hashes, rather than q*(h_M2 + 1).
            L_leaf = (q*h_M2 + 1) * L_hash

            # length of the overall GMSS+HORS signature and seeds
            sig_bytes = ceil_div(L_MW1 + T*L_MW + L_leaf + L_prf + ceil(lg_N), 8)

            # compressions to build one Winternitz tree (middle / root respectively)
            c_MW  = K *(c_D + c_W) + c_M  + ceil_div(K *n*L_hash, L_prf)
            c_MW1 = K1*(c_D + c_W) + c_M1 + ceil_div(K1*n*L_hash, L_prf)

            # For simplicity, c_sign and c_ver don't take into account compressions saved
            # as a result of intersecting authentication paths in the HORS signature, so
            # are slight overestimates.

            c_sign = c_MW1 + T*c_MW + q*(c_M2 + 1) + ceil_div(K2*L_hash, L_prf)

            # *expected* number of compressions to verify a signature
            c_ver = c_D/2.0 + c_W + c_M1 + T*(c_D/2.0 + c_W + c_M) + q*(c_M2 + 1)
            c_ver_pm = (1 + T)*c_D/2.0

            candidates += make_candidate(B, K, K1, K2, q, T, T_min, L_hash, lg_N, sig_bytes, c_sign, c_ver, c_ver_pm)

    return candidates
214
def search():
    """For each hash output length in range_L_hash: precompute optimal
    Merkle tree shapes, enumerate candidate parameter sets via calculate(),
    bucket the candidates by cost, and print the best (smallest-signature)
    candidate per cost bucket.  Progress is written to stderr; results to
    stdout.  Returns early if no candidates are found."""
    for L_hash in range_L_hash:
        print("collecting...   \r", end=' ', file=stderr)
        collect()

        print("precomputing... \r", end=' ', file=stderr)

        """
        # d/dq (lg(q+1) + L_hash/q) = 1/(ln(2)*(q+1)) - L_hash/q^2
        # Therefore lg(q+1) + L_hash/q is at a minimum when 1/(ln(2)*(q+1)) = L_hash/q^2.
        # Let alpha = L_hash*ln(2), then from the quadratic formula, the integer q that
        # minimizes lg(q+1) + L_hash/q is the floor or ceiling of (alpha + sqrt(alpha^2 - 4*alpha))/2.
        # (We don't want the other solution near 0.)

        alpha = floor(L_hash*ln(2))  # float
        q = floor((alpha + sqrt(alpha*(alpha-4)))/2)
        if lg(q+2) + L_hash/(q+1) < lg(q+1) + L_hash/q:
            q += 1
        lg_S_margin = lg(q+1) + L_hash/q
        q_max = int(q)

        q = floor(L_hash*ln(2))  # float
        if lg(q+1) + L_hash/(q+1) < lg(q) + L_hash/q:
            q += 1
        lg_S_margin = lg(q) + L_hash/q
        q_max = int(q)
        """
        q_max = 4000

        # find optimal Merkle tree shapes for this L_hash and each K
        # trees[y] = (height, compressions, (dau, tri)) for the cheapest tree
        # with y leaves, built from dau levels of arity 2 and tri levels of arity 3
        trees = {}
        K_max = 50
        c2 = compressions(2*L_hash)
        c3 = compressions(3*L_hash)
        for dau in range(0, 10):
            a = pow(2, dau)
            for tri in range(0, ceil_log(30-dau, 3)):
                x = int(a*pow(3, tri))
                h = dau + 2*tri
                c_x = int(sum_powers(2, dau)*c2 + a*sum_powers(3, tri)*c3)
                for y in range(1, x+1):
                    if tri > 0:
                        # If the bottom level has arity 3, then for every 2 nodes by which the tree is
                        # imperfect, we can save c3 compressions by pruning 3 leaves back to their parent.
                        # If the tree is imperfect by an odd number of nodes, we can prune one extra leaf,
                        # possibly saving a compression if c2 < c3.
                        c_y = c_x - floor_div(x-y, 2)*c3 - ((x-y) % 2)*(c3-c2)
                    else:
                        # If the bottom level has arity 2, then for each node by which the tree is
                        # imperfect, we can save c2 compressions by pruning 2 leaves back to their parent.
                        c_y = c_x - (x-y)*c2

                    if y not in trees or (h, c_y, (dau, tri)) < trees[y]:
                        trees[y] = (h, c_y, (dau, tri))

        #for x in range(1, K_max+1):
        #    print(x, trees[x])

        candidates = []
        progress = 0
        fuzz = 0
        complete = (K_max-1)*(2200-200)/100
        for K in range(2, K_max+1):
            for K2 in range(200, 2200, 100):
                for K1 in range(max(2, K-fuzz), min(K_max, K+fuzz)+1):
                    candidates += calculate(K, K1, K2, q_max, L_hash, trees)
                progress += 1
                print("searching: %3d %% \r" % (100.0 * progress / complete,), end=' ', file=stderr)

        print("filtering...    \r", end=' ', file=stderr)
        # bucket candidates into cost bins of width `step` Mcycles
        step = 2.0
        bins = {}
        limit = floor_div(limit_cost, step)
        for bin in range(0, limit+2):
            bins[bin] = []

        for c in candidates:
            bin = floor_div(c['cost'], step)
            bins[bin] += [c]

        del candidates

        # For each in a range of signing times, find the best candidate.
        best = []
        for bin in range(0, limit):
            # consider a sliding window of three adjacent cost bins
            candidates = bins[bin] + bins[bin+1] + bins[bin+2]
            if len(candidates) > 0:
                best += [min(candidates, key=lambda c: c['sig_bytes'])]

        def format_candidate(candidate):
            # one fixed-width output row per candidate dict
            return ("%(B)3d  %(K)3d  %(K1)3d  %(K2)5d %(q)4d %(T)4d  "
                    "%(L_hash)4d   %(lg_N)5.1f  %(sig_bytes)7d   "
                    "%(c_sign)7d (%(Mcycles_sign)7.2f) "
                    "%(c_ver)7d +/-%(c_ver_pm)5d (%(Mcycles_ver)5.2f +/-%(Mcycles_ver_pm)5.2f)   "
                   ) % candidate

        print("                \r", end=' ', file=stderr)
        if len(best) > 0:
            print("  B    K   K1     K2    q    T  L_hash  lg_N  sig_bytes  c_sign (Mcycles)        c_ver     (    Mcycles   )")
            print("---- ---- ---- ------ ---- ---- ------ ------ --------- ------------------ --------------------------------")

            # print only rows that improve on the previous row in signing or
            # verification cost (best is sorted by size, then cost)
            best.sort(key=lambda c: (c['sig_bytes'], c['cost']))
            last_sign = None
            last_ver = None
            for c in best:
                if last_sign is None or c['c_sign'] < last_sign or c['c_ver'] < last_ver:
                    print(format_candidate(c))
                    last_sign = c['c_sign']
                    last_ver = c['c_ver']

            print()
        else:
            print("No candidates found for L_hash = %d or higher." % (L_hash))
            return

        del bins
        del best
332
# Report the configured limits and hash parameters, then run the search.
print("Maximum signature size: %d bytes" % (limit_bytes,))
print("Maximum (signing + %d*verification) cost: %.1f Mcycles" % (weight_ver, limit_cost))
print("Hash parameters: %d-bit blocks with %d-bit padding and %d-bit labels, %.2f cycles per byte" \
      % (L_block, L_pad, L_label, cycles_per_byte))
print("PRF output size: %d bits" % (L_prf,))
print("Security level given by L_hash is maintained for up to 2^%d signatures.\n" % (lg_M,))

search()
Note: See TracBrowser for help on using the repository browser.