source: trunk/src/allmydata/test/test_encode.py

Last change on this file was 168532d, checked in by Alexandre Detiste <alexandre.detiste@…>, at 2024-12-21T18:03:57Z: finish removing "future"

1"""
2Ported to Python 3.
3"""
4
5from zope.interface import implementer
6from twisted.trial import unittest
7from twisted.internet import defer
8from twisted.python.failure import Failure
9from foolscap.api import fireEventually
10from allmydata import uri
11from allmydata.immutable import encode, upload, checker
12from allmydata.util import hashutil
13from allmydata.util.assertutil import _assert
14from allmydata.util.consumer import download_to_data
15from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader
16from allmydata.test.no_network import GridTestMixin
17
18class LostPeerError(Exception):
19    pass
20
21def byteschr(x):
22    return bytes([x])
23
def flip_bit(good): # flips the last bit
    # on Python 3, indexing bytes already yields an int, so no ord() here
    return good[:-1] + byteschr(good[-1] ^ 0x01)
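# Example (illustrative): flip_bit(b"\x00\x02") == b"\x00\x03", since only
# the low bit of the final byte is inverted.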

@implementer(IStorageBucketWriter, IStorageBucketReader)
class FakeBucketReaderWriterProxy(object):
    # these are used for both reading and writing
    def __init__(self, mode="good", peerid="peer"):
        self.mode = mode
        self.blocks = {}
        self.plaintext_hashes = []
        self.crypttext_hashes = []
        self.block_hashes = None
        self.share_hashes = None
        self.closed = False
        self.peerid = peerid

    def get_peerid(self):
        return self.peerid

    def _start(self):
        if self.mode == "lost-early":
            f = Failure(LostPeerError("I went away early"))
            return fireEventually(f)
        return defer.succeed(self)

    def put_header(self):
        return self._start()

    def put_block(self, segmentnum, data):
        if self.mode == "lost-early":
            f = Failure(LostPeerError("I went away early"))
            return fireEventually(f)
        def _try():
            assert not self.closed
            assert segmentnum not in self.blocks
            if self.mode == "lost" and segmentnum >= 1:
                raise LostPeerError("I'm going away now")
            self.blocks[segmentnum] = data
        return defer.maybeDeferred(_try)
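
    # (Each mutator here follows the same pattern: defer.maybeDeferred runs
    # _try synchronously and wraps its return value, or any exception it
    # raises, in an already-fired Deferred.)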

    def put_crypttext_hashes(self, hashes):
        def _try():
            assert not self.closed
            assert not self.crypttext_hashes
            self.crypttext_hashes = hashes
        return defer.maybeDeferred(_try)

    def put_block_hashes(self, blockhashes):
        def _try():
            assert not self.closed
            assert self.block_hashes is None
            self.block_hashes = blockhashes
        return defer.maybeDeferred(_try)

    def put_share_hashes(self, sharehashes):
        def _try():
            assert not self.closed
            assert self.share_hashes is None
            self.share_hashes = sharehashes
        return defer.maybeDeferred(_try)

    def put_uri_extension(self, uri_extension):
        def _try():
            assert not self.closed
            self.uri_extension = uri_extension
        return defer.maybeDeferred(_try)

    def close(self):
        def _try():
            assert not self.closed
            self.closed = True
        return defer.maybeDeferred(_try)

    def abort(self):
        return defer.succeed(None)

    def get_block_data(self, blocknum, blocksize, size):
        d = self._start()
        def _try(unused=None):
            assert isinstance(blocknum, int)
            if self.mode == "bad block":
                return flip_bit(self.blocks[blocknum])
            return self.blocks[blocknum]
        d.addCallback(_try)
        return d

    def get_plaintext_hashes(self):
        d = self._start()
        def _try(unused=None):
            hashes = self.plaintext_hashes[:]
            return hashes
        d.addCallback(_try)
        return d

    def get_crypttext_hashes(self):
        d = self._start()
        def _try(unused=None):
            hashes = self.crypttext_hashes[:]
            if self.mode == "bad crypttext hashroot":
                hashes[0] = flip_bit(hashes[0])
            if self.mode == "bad crypttext hash":
                hashes[1] = flip_bit(hashes[1])
            return hashes
        d.addCallback(_try)
        return d

    def get_block_hashes(self, at_least_these=()):
        d = self._start()
        def _try(unused=None):
            if self.mode == "bad blockhash":
                hashes = self.block_hashes[:]
                hashes[1] = flip_bit(hashes[1])
                return hashes
            return self.block_hashes
        d.addCallback(_try)
        return d

    def get_share_hashes(self, at_least_these=()):
        d = self._start()
        def _try(unused=None):
            if self.mode == "bad sharehash":
                hashes = self.share_hashes[:]
                hashes[1] = (hashes[1][0], flip_bit(hashes[1][1]))
                return hashes
            if self.mode == "missing sharehash":
                # one sneaky attack would be to pretend we don't know our own
                # sharehash, which could manage to frame someone else.
                # download.py is supposed to guard against this case.
                return []
            return self.share_hashes
        d.addCallback(_try)
        return d

    def get_uri_extension(self):
        d = self._start()
        def _try(unused=None):
            if self.mode == "bad uri_extension":
                return flip_bit(self.uri_extension)
            return self.uri_extension
        d.addCallback(_try)
        return d

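# A minimal usage sketch (illustrative only, mirroring what Encode.do_encode()
# does below): each share number maps to one fake proxy, and the encoder is
# handed the whole collection. A mode like "lost" makes the proxy raise
# LostPeerError after the first segment, simulating a server that vanishes
# mid-upload.
#
#   proxy = FakeBucketReaderWriterProxy(mode="lost")
#   shareholders = {0: proxy}
#   servermap = {0: set([proxy.get_peerid()])}
#   encoder.set_shareholders(shareholders, servermap)
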
def make_data(length):
    data = b"happy happy joy joy" * 100
    assert length <= len(data)
    return data[:length]

class ValidatedExtendedURIProxy(unittest.TestCase):
    K = 4
    M = 10
    SIZE = 200
    SEGSIZE = 72
    _TMP = SIZE%SEGSIZE
    if _TMP == 0:
        _TMP = SEGSIZE
    if _TMP % K != 0:
        _TMP += (K - (_TMP % K))
    TAIL_SEGSIZE = _TMP
    _TMP = SIZE // SEGSIZE
    if SIZE % SEGSIZE != 0:
        _TMP += 1
    NUM_SEGMENTS = _TMP
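    # Worked example for the constants above: SIZE % SEGSIZE = 200 % 72 = 56,
    # which is already a multiple of K=4, so TAIL_SEGSIZE = 56; and
    # NUM_SEGMENTS = ceil(200 / 72) = 3.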
    mindict = { 'segment_size': SEGSIZE,
                'crypttext_root_hash': b'0'*hashutil.CRYPTO_VAL_SIZE,
                'share_root_hash': b'1'*hashutil.CRYPTO_VAL_SIZE }
    optional_consistent = { 'crypttext_hash': b'2'*hashutil.CRYPTO_VAL_SIZE,
                            'codec_name': b"crs",
                            'codec_params': b"%d-%d-%d" % (SEGSIZE, K, M),
                            'tail_codec_params': b"%d-%d-%d" % (TAIL_SEGSIZE, K, M),
                            'num_segments': NUM_SEGMENTS,
                            'size': SIZE,
                            'needed_shares': K,
                            'total_shares': M,
                            'plaintext_hash': b"anything",
                            'plaintext_root_hash': b"anything", }
    # optional_inconsistent = { 'crypttext_hash': ('2'*(hashutil.CRYPTO_VAL_SIZE-1), "", 77),
    optional_inconsistent = { 'crypttext_hash': (77,),
                              'codec_name': (b"digital fountain", b""),
                              'codec_params': (b"%d-%d-%d" % (SEGSIZE, K-1, M),
                                               b"%d-%d-%d" % (SEGSIZE-1, K, M),
                                               b"%d-%d-%d" % (SEGSIZE, K, M-1)),
                              'tail_codec_params': (b"%d-%d-%d" % (TAIL_SEGSIZE, K-1, M),
                                               b"%d-%d-%d" % (TAIL_SEGSIZE-1, K, M),
                                               b"%d-%d-%d" % (TAIL_SEGSIZE, K, M-1)),
                              'num_segments': (NUM_SEGMENTS-1,),
                              'size': (SIZE-1,),
                              'needed_shares': (K-1,),
                              'total_shares': (M-1,), }

    def _test(self, uebdict):
        uebstring = uri.pack_extension(uebdict)
        uebhash = hashutil.uri_extension_hash(uebstring)
        fb = FakeBucketReaderWriterProxy()
        fb.put_uri_extension(uebstring)
        verifycap = uri.CHKFileVerifierURI(storage_index=b'x'*16, uri_extension_hash=uebhash, needed_shares=self.K, total_shares=self.M, size=self.SIZE)
        vup = checker.ValidatedExtendedURIProxy(fb, verifycap)
        return vup.start()
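
    # (Flow of _test above: uri.pack_extension() serializes uebdict, the fake
    # bucket stores those bytes, the verify cap carries their digest from
    # hashutil.uri_extension_hash(), and vup.start() reads the bytes back and
    # validates them against that digest.)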

    def _test_accept(self, uebdict):
        return self._test(uebdict)

    def _should_fail(self, res, expected_failures):
        if isinstance(res, Failure):
            res.trap(*expected_failures)
        else:
            self.fail("was supposed to raise %s, not get '%s'" % (expected_failures, res))

    def _test_reject(self, uebdict):
        d = self._test(uebdict)
        d.addBoth(self._should_fail, (KeyError, checker.BadURIExtension))
        return d

    def test_accept_minimal(self):
        return self._test_accept(self.mindict)

    def test_reject_insufficient(self):
        dl = []
        for k in self.mindict.keys():
            insuffdict = self.mindict.copy()
            del insuffdict[k]
            d = self._test_reject(insuffdict)
            dl.append(d)  # collect every deferred, not just the last one
        return defer.DeferredList(dl)

    def test_accept_optional(self):
        dl = []
        for k in self.optional_consistent.keys():
            mydict = self.mindict.copy()
            mydict[k] = self.optional_consistent[k]
            d = self._test_accept(mydict)
            dl.append(d)  # collect every deferred, not just the last one
        return defer.DeferredList(dl)

    def test_reject_optional(self):
        dl = []
        for k in self.optional_inconsistent.keys():
            for v in self.optional_inconsistent[k]:
                mydict = self.mindict.copy()
                mydict[k] = v
                d = self._test_reject(mydict)
                dl.append(d)
        return defer.DeferredList(dl)

class Encode(unittest.TestCase):
    def do_encode(self, max_segment_size, datalen, NUM_SHARES, NUM_SEGMENTS,
                  expected_block_hashes, expected_share_hashes):
        data = make_data(datalen)
        # force use of multiple segments
        e = encode.Encoder()
        u = upload.Data(data, convergence=b"some convergence string")
        u.set_default_encoding_parameters({'max_segment_size': max_segment_size,
                                           'k': 25, 'happy': 75, 'n': 100})
        eu = upload.EncryptAnUploadable(u)
        d = e.set_encrypted_uploadable(eu)

        all_shareholders = []
        def _ready(res):
            k,happy,n = e.get_param("share_counts")
            _assert(n == NUM_SHARES) # else we'll be completely confused
            numsegs = e.get_param("num_segments")
            _assert(numsegs == NUM_SEGMENTS, numsegs, NUM_SEGMENTS)
            segsize = e.get_param("segment_size")
            _assert( (NUM_SEGMENTS-1)*segsize < len(data) <= NUM_SEGMENTS*segsize,
                     NUM_SEGMENTS, segsize,
                     (NUM_SEGMENTS-1)*segsize, len(data), NUM_SEGMENTS*segsize)

            shareholders = {}
            servermap = {}
            for shnum in range(NUM_SHARES):
                peer = FakeBucketReaderWriterProxy()
                shareholders[shnum] = peer
                servermap.setdefault(shnum, set()).add(peer.get_peerid())
                all_shareholders.append(peer)
            e.set_shareholders(shareholders, servermap)
            return e.start()
        d.addCallback(_ready)

        def _check(res):
            verifycap = res
            self.failUnless(isinstance(verifycap.uri_extension_hash, bytes))
            self.failUnlessEqual(len(verifycap.uri_extension_hash), 32)
            for i,peer in enumerate(all_shareholders):
                self.failUnless(peer.closed)
                self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
                # each peer gets a full tree of block hashes. For 3 or 4
                # segments, that's 7 hashes. For 5 segments it's 15 hashes.
                self.failUnlessEqual(len(peer.block_hashes),
                                     expected_block_hashes)
                for h in peer.block_hashes:
                    self.failUnlessEqual(len(h), 32)
                # each peer also gets their necessary chain of share hashes.
                # For 100 shares (rounded up to 128 leaves), that's 8 hashes
                self.failUnlessEqual(len(peer.share_hashes),
                                     expected_share_hashes)
                for (hashnum, h) in peer.share_hashes:
                    self.failUnless(isinstance(hashnum, int))
                    self.failUnlessEqual(len(h), 32)
        d.addCallback(_check)

        return d

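    # How the expected hash counts are derived (for reference): the block
    # hash tree is a full binary Merkle tree over NUM_SEGMENTS leaves rounded
    # up to a power of 2, and a full tree over L leaves holds 2*L-1 hashes:
    # 3 or 4 segments -> 4 leaves -> 7 hashes, 5 segments -> 8 leaves -> 15.
    # The 100 shares round up to 128 leaves, whose tree gives each bucket an
    # 8-element share hash chain.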
    def test_send_74(self):
        # 3 segments (25, 25, 24)
        return self.do_encode(25, 74, 100, 3, 7, 8)
    def test_send_75(self):
        # 3 segments (25, 25, 25)
        return self.do_encode(25, 75, 100, 3, 7, 8)
    def test_send_51(self):
        # 3 segments (25, 25, 1)
        return self.do_encode(25, 51, 100, 3, 7, 8)

    def test_send_76(self):
        # encode a 76 byte file (in 4 segments: 25,25,25,1) to 100 shares
        return self.do_encode(25, 76, 100, 4, 7, 8)
    def test_send_99(self):
        # 4 segments: 25,25,25,24
        return self.do_encode(25, 99, 100, 4, 7, 8)
    def test_send_100(self):
        # 4 segments: 25,25,25,25
        return self.do_encode(25, 100, 100, 4, 7, 8)

    def test_send_124(self):
        # 5 segments: 25, 25, 25, 25, 24
        return self.do_encode(25, 124, 100, 5, 15, 8)
    def test_send_125(self):
        # 5 segments: 25, 25, 25, 25, 25
        return self.do_encode(25, 125, 100, 5, 15, 8)
    def test_send_101(self):
        # 5 segments: 25, 25, 25, 25, 1
        return self.do_encode(25, 101, 100, 5, 15, 8)


class Roundtrip(GridTestMixin, unittest.TestCase):

    # a series of 3*3 tests to check out edge conditions. One axis is how the
    # plaintext is divided into segments: kn+(-1,0,1). Another way to express
    # this is n%k == -1 or 0 or 1. For example, for 25-byte segments, we
    # might test 74 bytes, 75 bytes, and 76 bytes.

    # on the other axis is how many leaves in the block hash tree we wind up
    # with, relative to a power of 2, so 2^a+(-1,0,1). Each segment turns
    # into a single leaf. So we'd like to check out, e.g., 3 segments, 4
    # segments, and 5 segments.

    # that results in the following series of data lengths:
    #  3 segs: 74, 75, 51
    #  4 segs: 99, 100, 76
    #  5 segs: 124, 125, 101

    # all tests encode to 100 shares, which means the share hash tree will
    # have 128 leaves, which means that buckets will be given an 8-long share
    # hash chain

    # all 3-segment files will have a 4-leaf blockhashtree, and thus expect
    # to get 7 blockhashes. 4-segment files will also get 4-leaf block hash
    # trees and 7 blockhashes. 5-segment files will get 8-leaf block hash
    # trees, which gets 15 blockhashes.
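
    # Spelled out for 25-byte segments, the kn+(-1,0,1) lengths above are:
    #   74 = 3*25 - 1, 75 = 3*25, 51 = 2*25 + 1   (all need 3 segments)
    #   99 = 4*25 - 1, 100 = 4*25, 76 = 3*25 + 1  (all need 4 segments)
    #   124 = 5*25 - 1, 125 = 5*25, 101 = 4*25 + 1 (all need 5 segments)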

    def test_74(self): return self.do_test_size(74)
    def test_75(self): return self.do_test_size(75)
    def test_51(self): return self.do_test_size(51)
    def test_99(self): return self.do_test_size(99)
    def test_100(self): return self.do_test_size(100)
    def test_76(self): return self.do_test_size(76)
    def test_124(self): return self.do_test_size(124)
    def test_125(self): return self.do_test_size(125)
    def test_101(self): return self.do_test_size(101)

    def upload(self, data):
        u = upload.Data(data, None)
        u.max_segment_size = 25
        u.encoding_param_k = 25
        u.encoding_param_happy = 1
        u.encoding_param_n = 100
        d = self.c0.upload(u)
        d.addCallback(lambda ur: self.c0.create_node_from_uri(ur.get_uri()))
        # returns a FileNode
        return d
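
    # Note (illustrative, based on how zfec-style CRS coding behaves): with
    # k=25 and 25-byte segments, each full segment is split into 25 one-byte
    # chunks and expanded to n=100 one-byte blocks, any 25 of which suffice
    # to rebuild the segment.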

    def do_test_size(self, size):
        self.basedir = self.mktemp()
        self.set_up_grid()
        self.c0 = self.g.clients[0]
        DATA = b"p"*size
        d = self.upload(DATA)
        d.addCallback(lambda n: download_to_data(n))
        def _downloaded(newdata):
            self.failUnlessEqual(newdata, DATA)
        d.addCallback(_downloaded)
        return d