Ticket #798: new-downloader-v1.diff

File new-downloader-v1.diff, 121.3 KB (added by warner at 2010-03-11T19:35:25Z)

Work-in-progress of the new downloader; maybe 80% complete.

  • new file src/allmydata/immutable/download2.py

    diff --git a/src/allmydata/immutable/download2.py b/src/allmydata/immutable/download2.py
    new file mode 100644
    index 0000000..440459c
    --- /dev/null
    +++ b/src/allmydata/immutable/download2.py

    import struct, binascii

    from zope.interface import implements
    from twisted.internet import defer
    from twisted.internet.interfaces import IPushProducer, IConsumer
    from twisted.python.failure import Failure
    from foolscap.api import eventually
    from allmydata import uri, codec
    from allmydata.interfaces import HASH_SIZE, NotEnoughSharesError
    from allmydata.uri import CHKFileURI, CHKFileVerifierURI
    from allmydata.util import base32, idlib, log, mathutil, hashutil
    from allmydata.util.dictutil import DictOfSets
    from allmydata.util.observer import OneShotObserverList
    from allmydata.util.hashtree import IncompleteHashTree, BadHashError, \
         NotEnoughHashesError
    # XXX assumed homes for the new helpers added elsewhere in this patch:
    # Spans/DataSpans/overlap in util.spans, Observer2 in util.observer
    from allmydata.util.spans import Spans, DataSpans, overlap
    from allmydata.util.observer import Observer2
    from pycryptopp.cipher.aes import AES

    KiB = 1024

    def incidentally(res, f, *args, **kwargs):
        """Add me to a Deferred chain like this:
         d.addBoth(incidentally, func, arg)
        """
        f(*args, **kwargs)
        return res

    (UNUSED, PENDING, OVERDUE, COMPLETE, CORRUPT, DEAD, BADSEGNUM) = \
     ("UNUSED", "PENDING", "OVERDUE", "COMPLETE", "CORRUPT", "DEAD", "BADSEGNUM")

    class BadSegmentNumberError(Exception):
        pass

    class Share:
        # this is a specific implementation of IShare for tahoe's native storage
        # servers. A different backend would use a different class.
        """I represent a single instance of a single share (e.g. I reference the
        shnum2 for share SI=abcde on server xy12t, not the one on server ab45q).
        I am associated with a CommonShare that remembers data that is held in
        common among e.g. SI=abcde/shnum2 across all servers. I am also
        associated with a CiphertextFileNode for e.g. SI=abcde (all shares, all
        servers).
        """

        def __init__(self, rref, verifycap, commonshare, node, peerid, shnum):
            self._rref = rref
            self._guess_offsets(verifycap, node.guessed_segment_size)
            self.actual_offsets = None
            self.actual_segment_size = None
            self._UEB_length = None
            self._commonshare = commonshare # holds block_hash_tree
            self._node = node # holds share_hash_tree and UEB
            self._peerid = peerid
            self._shnum = shnum

            self._wanted = Spans() # desired metadata
            self._wanted_blocks = Spans() # desired block data
            self._requested = Spans() # we've sent a request for this
            self._received = Spans() # we've received a response for this
            self._received_data = DataSpans() # the response contents, still unused
            self._requested_blocks = [] # (segnum, set(observer2..))
            ver = rref.version["http://allmydata.org/tahoe/protocols/storage/v1"]
            self._overrun_ok = ver["tolerates-immutable-read-overrun"]
            # If _overrun_ok and we guess the offsets correctly, we can get
            # everything in one RTT. If _overrun_ok and we guess wrong, we might
            # need two RTT (but we could get lucky and do it in one). If overrun
            # is *not* ok (tahoe-1.3.0 or earlier), we need four RTT: 1=version,
            # 2=offset table, 3=UEB_length and everything else (hashes, block),
            # 4=UEB.

        def _guess_offsets(self, verifycap, guessed_segment_size):
            self.guessed_segment_size = guessed_segment_size
            size = verifycap.size
            k = verifycap.needed_shares
            N = verifycap.total_shares
            offsets = {}
            for i,field in enumerate(['data',
                                      'plaintext_hash_tree', # UNUSED
                                      'crypttext_hash_tree',
                                      'block_hashes',
                                      'share_hashes',
                                      'uri_extension',
                                      ]):
                offsets[field] = i # bad guesses are easy :) # XXX stub
            self.guessed_offsets = offsets
            self._fieldsize = 4
            self._fieldstruct = "L" # prepend ">", repeat 6x for the table
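
        # XXX not used yet: a sketch of what a non-stub guess might look like,
        # assuming the v1 share layout (0x24-byte header+offset-table, then
        # data blocks, then three hash trees, then the share hash chain). The
        # chain length here is an upper bound; the real math may differ.
        def _guess_offsets_sketch(self, verifycap, guessed_segment_size):
            k = verifycap.needed_shares
            num_segments = mathutil.div_ceil(verifycap.size,
                                             guessed_segment_size)
            block_size = mathutil.div_ceil(guessed_segment_size, k)
            def tree_bytes(leaves):
                # a complete binary hash tree over next-power-of-2 leaves
                n = 1
                while n < leaves:
                    n *= 2
                return (2*n - 1) * HASH_SIZE
            offsets = {'data': 0x24}
            offsets['plaintext_hash_tree'] = (offsets['data']
                                              + num_segments * block_size)
            offsets['crypttext_hash_tree'] = (offsets['plaintext_hash_tree']
                                              + tree_bytes(num_segments))
            offsets['block_hashes'] = (offsets['crypttext_hash_tree']
                                       + tree_bytes(num_segments))
            offsets['share_hashes'] = (offsets['block_hashes']
                                       + tree_bytes(num_segments))
            offsets['uri_extension'] = (offsets['share_hashes']
                                        + (2+HASH_SIZE) * verifycap.total_shares)
            return offsets
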
        # called by our client, the SegmentFetcher
        def get_block(self, segnum):
            """Add a block number to the list of requests. This will eventually
            result in a fetch of the data necessary to validate the block, then
            the block itself. The fetch order is generally
            first-come-first-served, but requests may be answered out-of-order if
            data becomes available sooner.

            I return an Observer2, which has two uses. The first is to call
            o.subscribe(), which gives me a place to send state changes and
            eventually the data block. The second is o.cancel(), which removes
            the request (if it is still active).
            """
            o = Observer2()
            o.set_canceler(self._cancel_block_request)
            for (segnum0,observers) in self._requested_blocks:
                if segnum0 == segnum:
                    observers.add(o)
                    break
            else:
                self._requested_blocks.append( (segnum, set([o])) )
            eventually(self.loop)
            return o

        def _cancel_block_request(self, o):
            new_requests = []
            for e in self._requested_blocks:
                (segnum0, observers) = e
                observers.discard(o)
                if observers:
                    new_requests.append(e)
            self._requested_blocks = new_requests

        # internal methods
        def _active_segnum(self):
            if self._requested_blocks:
                return self._requested_blocks[0][0]
            return None

        def _active_segnum_and_observers(self):
            if self._requested_blocks:
                # we only retrieve information for one segment at a time, to
                # minimize alacrity (first come, first served)
                return self._requested_blocks[0]
            return None, []

        def loop(self):
            # TODO: if any exceptions occur here, kill the download

            # we are (eventually) called after all state transitions:
            #  new segments added to self._requested_blocks
            #  new data received from servers (responses to our read() calls)
            #  impatience timer fires (server appears slow)

            # First, consume all of the information that we currently have, for
            # all the segments people currently want.
            while self._get_satisfaction():
                pass

            # When we get no satisfaction (from the data we've received so far),
            # we determine what data we desire (to satisfy more requests). The
            # number of segments is finite, so I can't get no satisfaction
            # forever.
            self._desire()

            # finally send out requests for whatever we need (desire minus have).
            # You can't always get what you want, but, sometimes, you get what
            # you need.
            self._request_needed() # express desire

        def _get_satisfaction(self):
            # return True if we retired a data block, and should therefore be
            # called again. Return False if we don't retire a data block (even if
            # we do retire some other data, like hash chains).

            if self.actual_offsets is None:
                if not self._satisfy_offsets():
                    # can't even look at anything without the offset table
                    return False

            if self._node.UEB is None:
                if not self._satisfy_UEB():
                    # can't check any hashes without the UEB
                    return False

            segnum, observers = self._active_segnum_and_observers()
            if segnum is not None and segnum >= self._node.num_segments:
                for o in observers:
                    o.notify(state=BADSEGNUM)
                self._requested_blocks.pop(0)
                return True

            if self._node.share_hash_tree.needed_hashes(self._shnum):
                if not self._satisfy_share_hash_tree():
                    # can't check block_hash_tree without a root
                    return False

            if segnum is None:
                return False # we don't want any particular segment right now

            # block_hash_tree
            needed_hashes = self._commonshare.block_hash_tree.needed_hashes(segnum)
            if needed_hashes:
                if not self._satisfy_block_hash_tree(needed_hashes):
                    # can't check block without block_hash_tree
                    return False

            # data blocks
            return self._satisfy_data_block(segnum, observers)

        def _satisfy_offsets(self):
            version_s = self._received_data.get(0, 4)
            if version_s is None:
                return False
            (version,) = struct.unpack(">L", version_s)
            if version == 1:
                table_start = 0x0c
                self._fieldsize = 0x4
                self._fieldstruct = "L"
            else:
                table_start = 0x14
                self._fieldsize = 0x8
                self._fieldstruct = "Q"
            offset_table_size = 6 * self._fieldsize
            table_s = self._received_data.pop(table_start, offset_table_size)
            if table_s is None:
                return False
            fields = struct.unpack(">"+6*self._fieldstruct, table_s)
            offsets = {}
            for i,field in enumerate(['data',
                                      'plaintext_hash_tree', # UNUSED
                                      'crypttext_hash_tree',
                                      'block_hashes',
                                      'share_hashes',
                                      'uri_extension',
                                      ]):
                offsets[field] = fields[i]
            self.actual_offsets = offsets
            self._received_data.remove(0, 4) # don't need this anymore
            return True

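        # XXX worked example of the v1 header this parses: a 4-byte version,
        # two more 4-byte size fields, then six 4-byte offsets starting at
        # 0x0c (v2 has the same shape with 8-byte fields starting at 0x14);
        # the size-field values here are made up:
        #
        #   >>> hdr = struct.pack(">3L", 1, 4096, 1365)
        #   >>> hdr += struct.pack(">6L", 0x24, 0x1024, 0x2024,
        #   ...                    0x3024, 0x4024, 0x5024)
        #   >>> struct.unpack(">L", hdr[:4])
        #   (1,)
        #   >>> struct.unpack(">6L", hdr[0x0c:0x0c+6*4])
        #   (36, 4132, 8228, 12324, 16420, 20516)
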
        def _satisfy_UEB(self):
            o = self.actual_offsets
            fsize = self._fieldsize
            rdata = self._received_data
            UEB_length_s = rdata.get(o["uri_extension"], fsize)
            if not UEB_length_s:
                return False
            (UEB_length,) = struct.unpack(">"+self._fieldstruct, UEB_length_s)
            UEB_s = rdata.pop(o["uri_extension"]+fsize, UEB_length)
            if not UEB_s:
                return False
            rdata.remove(o["uri_extension"], fsize)
            try:
                self._node.validate_UEB(UEB_s) # stores in self._node.UEB # XXX
                self.actual_segment_size = self._node.segment_size
                assert self.actual_segment_size is not None
                return True
            except BadHashError:
                # TODO: if this UEB was bad, we'll keep trying to validate it
                # over and over again. Only log.err on the first one, or better
                # yet skip all but the first
                f = Failure()
                self._signal_corruption(f, o["uri_extension"], fsize+UEB_length)
                return False

        def _satisfy_share_hash_tree(self):
            # the share hash chain is stored as (hashnum,hash) tuples, so you
            # can't fetch just the pieces you need, because you don't know
            # exactly where they are. So fetch everything, and parse the results
            # later.
            o = self.actual_offsets
            rdata = self._received_data
            hashlen = o["uri_extension"] - o["share_hashes"]
            assert hashlen % (2+HASH_SIZE) == 0
            hashdata = rdata.get(o["share_hashes"], hashlen)
            if not hashdata:
                return False
            share_hashes = {}
            for i in range(0, hashlen, 2+HASH_SIZE):
                hashnum = struct.unpack(">H", hashdata[i:i+2])[0]
                hashvalue = hashdata[i+2:i+2+HASH_SIZE]
                share_hashes[hashnum] = hashvalue
            try:
                self._node.process_share_hashes(share_hashes)
                # adds to self._node.share_hash_tree
                rdata.remove(o["share_hashes"], hashlen)
                return True
            except (IndexError, BadHashError, NotEnoughHashesError):
                f = Failure()
                self._signal_corruption(f, o["share_hashes"], hashlen)
                return False

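        # XXX tiny example of that wire format: (2-byte hashnum, 32-byte hash)
        # pairs, back to back:
        #
        #   >>> chain = struct.pack(">H", 3) + "\x11"*HASH_SIZE \
        #   ...       + struct.pack(">H", 7) + "\x22"*HASH_SIZE
        #   >>> [struct.unpack(">H", chain[i:i+2])[0]
        #   ...  for i in range(0, len(chain), 2+HASH_SIZE)]
        #   [3, 7]
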
        def _signal_corruption(self, f, start, length):
            # there was corruption somewhere in the given range
            print f # XXX
            pass

        def _satisfy_block_hash_tree(self, needed_hashes):
            o = self.actual_offsets
            rdata = self._received_data
            block_hashes = {}
            for hashnum in needed_hashes:
                hashdata = rdata.get(o["block_hashes"]+hashnum*HASH_SIZE, HASH_SIZE)
                if hashdata:
                    block_hashes[hashnum] = hashdata
                else:
                    return False # missing some hashes
            # note that we don't submit any hashes to the block_hash_tree until
            # we've gotten them all, because the hash tree will throw an
            # exception if we only give it a partial set (which it therefore
            # cannot validate)
            ok = self._commonshare.process_block_hashes(block_hashes) # XXX
            if not ok:
                return False
            for hashnum in needed_hashes:
                rdata.remove(o["block_hashes"]+hashnum*HASH_SIZE, HASH_SIZE)
            return True

        def _satisfy_data_block(self, segnum, observers):
            o = self.actual_offsets
            rdata = self._received_data
            segsize = self._node.UEB["segment_size"]
            needed_shares = self._node.UEB["needed_shares"]
            sharesize = mathutil.div_ceil(self._node.UEB["size"],
                                          needed_shares)
            blocksize = mathutil.div_ceil(segsize, needed_shares) # XXX
            blockstart = o["data"] + segnum * blocksize
            if segnum < self._node.num_segments-1:
                blocklen = blocksize
            else:
                blocklen = sharesize % blocksize
                if blocklen == 0:
                    blocklen = blocksize
            block = rdata.pop(blockstart, blocklen)
            if not block:
                return False
            # this block is being retired, either as COMPLETE or CORRUPT, since
            # no further data reads will help
            assert self._requested_blocks[0][0] == segnum
            ok = self._commonshare.check_block(segnum, block)
            if ok:
                state = COMPLETE
            else:
                state = CORRUPT
            for ob in observers:
                # goes to SegmentFetcher._block_request_activity
                ob.notify(state=state, block=block)
            self._requested_blocks.pop(0) # retired
            return True # got satisfaction

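        # XXX worked example of the tail-block math above, for size=1000000,
        # segsize=131072, k=3: sharesize=333334, blocksize=43691, and the
        # last of the 8 blocks is 333334 % 43691 = 27497 bytes, not 43691:
        #
        #   >>> mathutil.div_ceil(1000000, 3)
        #   333334
        #   >>> mathutil.div_ceil(131072, 3)
        #   43691
        #   >>> 333334 - 7*43691
        #   27497
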
        def _desire(self):
            segnum, observers = self._active_segnum_and_observers()
            fsize = self._fieldsize
            rdata = self._received_data
            commonshare = self._commonshare

            if not self.actual_offsets:
                self._desire_offsets()

            # we can use guessed offsets as long as this server tolerates overrun
            if not self.actual_offsets and not self._overrun_ok:
                return # must wait for the offsets to arrive

            o = self.actual_offsets or self.guessed_offsets
            segsize = self.actual_segment_size or self.guessed_segment_size
            if self._node.UEB is None:
                self._desire_UEB(o)

            if self._node.share_hash_tree.needed_hashes(self._shnum):
                hashlen = o["uri_extension"] - o["share_hashes"]
                self._wanted.add(o["share_hashes"], hashlen)

            if segnum is None:
                return # only need block hashes or blocks for active segments

            # block hash chain
            for hashnum in commonshare.block_hash_tree.needed_hashes(segnum):
                self._wanted.add(o["block_hashes"]+hashnum*HASH_SIZE, HASH_SIZE)

            # data
            blockstart, blocklen = COMPUTE(segnum, segsize, etc) # XXX
            self._wanted_blocks.add(blockstart, blocklen)


        def _desire_offsets(self):
            if self._overrun_ok:
                # easy! this includes version number, sizes, and offsets
                self._wanted.add(0,1024)
                return

            # v1 has an offset table that lives [0x0,0x24). v2 lives [0x0,0x44).
            # To be conservative, only request the data that we know lives there,
            # even if that means more roundtrips.

            self._wanted.add(0,4)  # version number, always safe
            version_s = self._received_data.get(0, 4)
            if not version_s:
                return
            (version,) = struct.unpack(">L", version_s)
            if version == 1:
                table_start = 0x0c
                fieldsize = 0x4
            else:
                table_start = 0x14
                fieldsize = 0x8
            offset_table_size = 6 * fieldsize
            self._wanted.add(table_start, offset_table_size)

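        # XXX sanity check on those spans: v1's table of six 4-byte offsets
        # starts at 0x0c, so it ends at 0x0c + 6*4 = 0x24; v2's six 8-byte
        # offsets start at 0x14 and end at 0x14 + 6*8 = 0x44:
        #
        #   >>> hex(0x0c + 6*4), hex(0x14 + 6*8)
        #   ('0x24', '0x44')
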
        def _desire_UEB(self, o):
            # UEB data is stored as (length,data).
            rdata = self._received_data
            if self._overrun_ok:
                # We can pre-fetch 2kb, which should probably cover it. If it
                # turns out to be larger, we'll come back here later with a known
                # length and fetch the rest.
                self._wanted.add(o["uri_extension"], 2048)
                # now, while that is probably enough to fetch the whole UEB, it
                # might not be, so we need to do the next few steps as well. In
                # most cases, the following steps will not actually add anything
                # to self._wanted

            self._wanted.add(o["uri_extension"], self._fieldsize)
            # only use a length if we're sure it's correct, otherwise we'll
            # probably fetch a huge number
            if not self.actual_offsets:
                return
            UEB_length_s = rdata.get(o["uri_extension"], self._fieldsize)
            if UEB_length_s:
                (UEB_length,) = struct.unpack(">"+self._fieldstruct, UEB_length_s)
                # we know the length, so make sure we grab everything
                self._wanted.add(o["uri_extension"]+self._fieldsize, UEB_length)

        def _request_needed(self):
            # send requests for metadata first, to avoid hanging on to large data
            # blocks any longer than necessary.
            self._send_requests(self._wanted - self._received - self._requested)
            # then send requests for data blocks. All the hashes should arrive
            # before the blocks, so the blocks can be consumed and released in a
            # single turn.
            self._send_requests(self._wanted_blocks - self._received - self._requested)

        def _send_requests(self, needed):
            for (start, length) in needed:
                # TODO: quantize to reasonably-large blocks
                self._requested.add(start, length)
                d = self._send_request(start, length)
                d.addCallback(self._got_data, start, length)
                d.addErrback(self._got_error)
                d.addErrback(log.err, ...) # XXX

        def _send_request(self, start, length):
            return self._rref.callRemote("read", start, length)

        def _got_data(self, data, start, length):
            span = (start, length)
            assert span in self._requested
            self._requested.remove(start, length)
            self._received.add(start, length)
            self._received_data.add(start, data)
            eventually(self.loop)

        def _got_error(self, f): # XXX
            ...


    class CommonShare:
        """I hold data that is common across all instances of a single share,
        like sh2 on both servers A and B. This is just the block hash tree.
        """
        def __init__(self, numsegs):
            if numsegs is not None:
                self.block_hash_tree = IncompleteHashTree(numsegs)

        def got_numsegs(self, numsegs):
            self.block_hash_tree = IncompleteHashTree(numsegs)

        def process_block_hashes(self, block_hashes):
            self.block_hash_tree.set_hashes(block_hashes)
            return True
        def check_block(self, segnum, block):
            h = hashutil.block_hash(block)
            try:
                self.block_hash_tree.set_hashes(leaves={segnum: h})
            except (BadHashError, NotEnoughHashesError), le:
                LOG(...)
                return False
            return True

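    # XXX reminder of the IncompleteHashTree contract this relies on (assuming
    # the usual allmydata.util.hashtree API): needed_hashes() names the uncle
    # chain, set_hashes() either accepts the new nodes or raises:
    #
    #   >>> t = IncompleteHashTree(4)          # 4 leaves
    #   >>> needed = t.needed_hashes(2)        # internal-node numbers for leaf 2
    #   >>> t.set_hashes(chain, leaves={2: h}) # raises BadHashError on mismatch
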
    # all classes are also Services, and the rule is that you don't initiate more
    # work unless self.running

    # GC: decide whether each service is restartable or not. For non-restartable
    # services, stopService() should delete a lot of attributes to kill reference
    # cycles. The primary goal is to decref remote storage BucketReaders when a
    # download is complete.

    class SegmentFetcher:
        """I am responsible for acquiring blocks for a single segment. I will use
        the Share instances passed to my add_shares() method to locate, retrieve,
        and validate those blocks. I expect my parent node to call my
        no_more_shares() method when there are no more shares available. I will
        call my parent's want_more_shares() method when I want more: I expect to
        see at least one call to add_shares or no_more_shares afterwards.

        When I have enough validated blocks, I will call my parent's
        process_blocks() method with a dictionary that maps shnum to blockdata.
        If I am unable to provide enough blocks, I will call my parent's
        fetch_failed() method with (self, f). After either of these events, I
        will shut down and do no further work. My parent can also call my stop()
        method to have me shut down early."""

        def __init__(self, node, segnum, k):
            self._node = node # CiphertextFileNode
            self.segnum = segnum
            self._k = k
            self._shares = {} # maps non-dead Share instance to a state, one of
                              # (UNUSED, PENDING, OVERDUE, COMPLETE, CORRUPT).
                              # State transition map is:
                              #  UNUSED -(send-read)-> PENDING
                              #  PENDING -(timer)-> OVERDUE
                              #  PENDING -(rx)-> COMPLETE, CORRUPT, DEAD, BADSEGNUM
                              #  OVERDUE -(rx)-> COMPLETE, CORRUPT, DEAD, BADSEGNUM
                              # If a share becomes DEAD, it is removed from the
                              # dict. If it becomes BADSEGNUM, the whole fetch is
                              # terminated.
            self._share_observers = {} # maps Share to Observer2 for active ones
            self._shnums = DictOfSets() # maps shnum to the shares that provide it
            self._blocks = {} # maps shnum to validated block data
            self._no_more_shares = False
            self._bad_segnum = False
            self._running = True

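        # XXX the transition map above as data, handy for eyeballing (not yet
        # used by the code):
        _LEGAL_TRANSITIONS = {
            UNUSED:  set([PENDING]),
            PENDING: set([OVERDUE, COMPLETE, CORRUPT, DEAD, BADSEGNUM]),
            OVERDUE: set([COMPLETE, CORRUPT, DEAD, BADSEGNUM]),
        }
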
        def stop(self):
            self._cancel_all_requests()
            self._running = False
            del self._shares # let GC work # ???


        # called by our parent CiphertextFileNode

        def add_shares(self, shares):
            # called when ShareFinder locates a new share, and when a non-initial
            # segment fetch is started and we already know about shares from the
            # previous segment
            for s in shares:
                self._shares[s] = UNUSED
                self._shnums.add(s._shnum, s)
            eventually(self._loop)

        def no_more_shares(self):
            # ShareFinder tells us it's reached the end of its list
            self._no_more_shares = True

        # internal methods

        def _count_shnums(self, *states):
            """shnums for which at least one share is in the given states"""
            shnums = []
            for shnum,shares in self._shnums.iteritems():
                matches = [s for s in shares if self._shares[s] in states]
                if matches:
                    shnums.append(shnum)
            return len(shnums)

        def _loop(self):
            if not self._running:
                return
            if self._bad_segnum:
                # oops, we were asking for a segment number beyond the end of the
                # file. This is an error.
                self.stop()
                e = BadSegmentNumberError("%d > %d" % (self.segnum,
                                                       self._node.num_segments))
                f = Failure(e)
                self._node.fetch_failed(self, f)
                return

            # are we done?
            if self._count_shnums(COMPLETE) >= self._k:
                # yay!
                self.stop()
                self._node.process_blocks(self.segnum, self._blocks)
                return

            # we may have exhausted everything
            if (self._no_more_shares and
                self._count_shnums(UNUSED, PENDING, OVERDUE, COMPLETE) < self._k):
                # no more new shares are coming, and the remaining hopeful shares
                # aren't going to be enough. boo!
                self.stop()
                e = NotEnoughSharesError("...") # XXX
                f = Failure(e)
                self._node.fetch_failed(self, f)
                return

            # nope, not done. Are we "block-hungry" (i.e. do we want to send out
            # more read requests, or do we think we have enough in flight
            # already)?
            while self._count_shnums(PENDING, COMPLETE) < self._k:
                # we're hungry.. are there any unused shares?
                sent = self._send_new_request()
                if not sent:
                    break

            # ok, now are we "share-hungry" (i.e. do we have enough known shares
            # to make us happy, or should we ask the ShareFinder to get us more?)
            if self._count_shnums(UNUSED, PENDING, COMPLETE) < self._k:
                # we're hungry for more shares
                self._node.want_more_shares()
                # that will trigger the ShareFinder to keep looking

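        # XXX worked example of the two hunger tests, with k=3: if sh0 and sh1
        # are PENDING and sh2 is UNUSED, then count(PENDING,COMPLETE)=2 < 3,
        # so we are block-hungry and promote sh2 to PENDING. Once it is
        # PENDING, count(UNUSED,PENDING,COMPLETE)=3 >= 3, so we are not
        # share-hungry and leave the ShareFinder alone.
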
        def _send_new_request(self):
            for shnum,shares in self._shnums.iteritems():
                states = [self._shares[s] for s in shares]
                if COMPLETE in states or PENDING in states:
                    # don't send redundant requests
                    continue
                if UNUSED not in states:
                    # no candidates for this shnum, move on
                    continue
                # here's a candidate. Send a request.
                s = [s2 for s2 in shares if self._shares[s2] == UNUSED][0]
                # XXX could choose the fastest share instead of the first
                self._shares[s] = PENDING
                self._share_observers[s] = o = s.get_block(self.segnum)
                o.subscribe(self._block_request_activity, share=s, shnum=shnum)
                # TODO: build up a list of candidates, then walk through the
                # list, sending requests to the most desirable servers,
                # re-checking our block-hunger each time. For non-initial segment
                # fetches, this would let us stick with faster servers.
                return True
            # nothing was sent: don't call us again until you have more shares to
            # work with, or one of the existing shares has been declared OVERDUE
            return False

        def _cancel_all_requests(self):
            for o in self._share_observers.values():
                o.cancel()
            self._share_observers = {}

        def _block_request_activity(self, share, shnum, state, block=None):
            # called by Shares, in response to our s.get_block() calls.
            # COMPLETE, CORRUPT, DEAD, BADSEGNUM are terminal.
            if state in (COMPLETE, CORRUPT, DEAD, BADSEGNUM):
                del self._share_observers[share]
            if state is COMPLETE:
                # 'block' is fully validated
                self._shares[share] = COMPLETE
                self._blocks[shnum] = block
            elif state is OVERDUE:
                self._shares[share] = OVERDUE
                # OVERDUE is not terminal: it will eventually transition to
                # COMPLETE, CORRUPT, or DEAD.
            elif state is CORRUPT:
                self._shares[share] = CORRUPT
            elif state is DEAD:
                del self._shares[share]
                self._shnums[shnum].remove(share)
            elif state is BADSEGNUM:
                self._shares[share] = BADSEGNUM # ???
                self._bad_segnum = True
            eventually(self._loop)


    class RequestToken:
        def __init__(self, peerid):
            self.peerid = peerid

    class ShareFinder:
        def __init__(self, storage_broker, storage_index,
                     share_consumer, max_outstanding_requests=10):
            self.running = True
            self._storage_index = storage_index
            s = storage_broker.get_servers_for_index(storage_index)
            self._servers = iter(s)
            self.share_consumer = share_consumer
            self.max_outstanding = max_outstanding_requests

            self._hungry = False

            self._commonshares = {} # shnum to CommonShare instance
            self.undelivered_shares = []
            self.pending_requests = set()

            self._si_prefix = base32.b2a_l(storage_index[:8], 60)
            self._lp = log.msg(format="ShareFinder[si=%(si)s] starting",
                               si=self._si_prefix, level=log.NOISY, umid="2xjj2A")

            self._num_segments = None
            d = share_consumer.get_num_segments()
            d.addCallback(self._got_numsegs)
            d.addErrback(log.err, ...) # XXX

        def log(self, *args, **kwargs):
            if "parent" not in kwargs:
                kwargs["parent"] = self._lp
            return log.msg(*args, **kwargs)

        def stop(self):
            self.running = False

        def _got_numsegs(self, numsegs):
            for cs in self._commonshares.values():
                cs.got_numsegs(numsegs)
            self._num_segments = numsegs

        # called by our parent CiphertextDownloader
        def hungry(self):
            log.msg(format="ShareFinder[si=%(si)s] hungry",
                    si=self._si_prefix, level=log.NOISY, umid="NywYaQ")
            self._hungry = True
            eventually(self.loop)

        # internal methods
        def loop(self):
            log.msg(format="ShareFinder[si=%(si)s] loop: running=%(running)s"
                    " hungry=%(hungry)s, undelivered=%(undelivered)s,"
                    " pending=%(pending)s",
                    si=self._si_prefix, running=self.running, hungry=self._hungry,
                    undelivered=",".join(["sh%d@%s" % (s._shnum,
                                                       idlib.shortnodeid_b2a(s._peerid))
                                          for s in self.undelivered_shares]),
                    pending=",".join([idlib.shortnodeid_b2a(rt.peerid)
                                      for rt in self.pending_requests]), # sort?
                    level=log.NOISY, umid="kRtS4Q")
            if not self.running:
                return
            if not self._hungry:
                return
            if self.undelivered_shares:
                sh = self.undelivered_shares.pop(0)
                # they will call hungry() again if they want more
                self._hungry = False
                eventually(self.share_consumer.got_shares, [sh])
                return
            if len(self.pending_requests) >= self.max_outstanding:
                # cannot send more requests, must wait for some to retire
                return

            server = None
            try:
                if self._servers:
                    server = self._servers.next()
            except StopIteration:
                self._servers = None

            if server:
                self.send_request(server)
                return

            if self.pending_requests:
                # no server, but there are still requests in flight: maybe one of
                # them will make progress
                return

            # we've run out of servers (so we can't send any more requests), and
            # we have nothing in flight. No further progress can be made. They
            # are destined to remain hungry.
            self.share_consumer.no_more_shares()
            self.stop()


        def send_request(self, server):
            peerid, rref = server
            req = RequestToken(peerid)
            self.pending_requests.add(req)
            lp = self.log(format="sending DYHB to [%(peerid)s]",
                          peerid=idlib.shortnodeid_b2a(peerid),
                          level=log.NOISY, umid="Io7pyg")
            d = rref.callRemote("get_buckets", self._storage_index)
            d.addBoth(incidentally, self.pending_requests.discard, req)
            d.addCallbacks(self._got_response, self._got_error,
                           callbackArgs=(peerid, req, lp),
                           errbackArgs=(peerid, req, lp))
            d.addErrback(log.err, format="error in send_request",
                         level=log.WEIRD, parent=lp, umid="rpdV0w")
            d.addCallback(incidentally, eventually, self.loop)

        def _got_response(self, buckets, peerid, req, lp):
            if buckets:
                shnums_s = ",".join([str(shnum) for shnum in buckets])
                self.log(format="got shnums [%(shnums)s] from [%(peerid)s]",
                         shnums=shnums_s, peerid=idlib.shortnodeid_b2a(peerid),
                         level=log.NOISY, parent=lp, umid="0fcEZw")
            else:
                self.log(format="no shares from [%(peerid)s]",
                         peerid=idlib.shortnodeid_b2a(peerid),
                         level=log.NOISY, parent=lp, umid="U7d4JA")
            node = self.share_consumer # the CiphertextFileNode
            for shnum, bucket in buckets.iteritems():
                if shnum not in self._commonshares:
                    self._commonshares[shnum] = CommonShare(self._num_segments)
                cs = self._commonshares[shnum]
                s = Share(bucket, node.u, cs, node, peerid, shnum)
                self.undelivered_shares.append(s)

        def _got_error(self, f, peerid, req, lp):
            self.log(format="got error from [%(peerid)s]",
                     peerid=idlib.shortnodeid_b2a(peerid), failure=f,
                     level=log.UNUSUAL, parent=lp, umid="zUKdCw")


    class Segmentation:
        """I am responsible for a single offset+size read of the file. I handle
        segmentation: I figure out which segments are necessary, request them
        (from my CiphertextDownloader) in order, and trim the segments down to
        match the offset+size span. I use the Producer/Consumer interface to only
        request one segment at a time.
        """
        implements(IPushProducer)
        def __init__(self, node, offset, size, consumer):
            self._node = node
            self._hungry = True
            self._active_segnum = None
            self._cancel_segment_request = None
            # these are updated as we deliver data. At any given time, we still
            # want to download file[offset:offset+size]
            self._offset = offset
            self._size = size
            self._consumer = consumer

        def start(self):
            self._alive = True
            self._deferred = defer.Deferred()
            self._consumer.registerProducer(self, True) # XXX???
            self._maybe_fetch_next()
            return self._deferred

        def _maybe_fetch_next(self):
            if not self._alive or not self._hungry:
                return
            if self._active_segnum is not None:
                return
            self._fetch_next()

        def _fetch_next(self):
            if self._size == 0:
                # done!
                self._alive = False
                self._hungry = False
                self._consumer.unregisterProducer()
                self._deferred.callback(self._consumer)
                return
            n = self._node
            have_actual_segment_size = n.segment_size is not None
            segment_size = n.segment_size or n.guessed_segment_size
            if self._offset == 0:
                # great! we want segment0 for sure
                wanted_segnum = 0
            else:
                # this might be a guess
                wanted_segnum = self._offset // segment_size
            self._active_segnum = wanted_segnum
            d,c = self._node.get_segment(wanted_segnum)
            self._cancel_segment_request = c
            d.addBoth(self._request_retired)
            d.addCallback(self._got_segment, have_actual_segment_size)
            d.addErrback(self._retry_bad_segment, have_actual_segment_size)
            d.addErrback(self._error)

        def _request_retired(self, res):
            self._active_segnum = None
            self._cancel_segment_request = None
            return res

        def _got_segment(self, (segment_start,segment), had_actual_segment_size):
            self._active_segnum = None
            self._cancel_segment_request = None
            # we got file[segment_start:segment_start+len(segment)]
            # we want file[self._offset:self._offset+self._size]
            o = overlap(segment_start, len(segment), self._offset, self._size)
            # the overlap is file[o[0]:o[0]+o[1]]
            if not o or o[0] != self._offset:
                # we didn't get the first byte, so we can't use this segment
                if had_actual_segment_size:
                    # and we should have gotten it right. This is a big problem.
                    raise SOMETHING
                # we've wasted some bandwidth, but now we can grab the right one,
                # because we should know the segsize by now.
                assert self._node.segment_size is not None
                self._maybe_fetch_next()
                return
            offset_in_segment = self._offset - segment_start
            desired_data = segment[offset_in_segment:offset_in_segment+o[1]]

            self._offset += len(desired_data)
            self._size -= len(desired_data)
            self._consumer.write(desired_data)
            self._maybe_fetch_next()

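        # XXX worked example of that trimming, assuming overlap() returns the
        # intersection span: we hold file[1000:1500] (one segment) and want
        # file[1200:6200]; the overlap is file[1200:1500], so we deliver 300
        # bytes and advance offset to 1500, size to 4700:
        #
        #   >>> overlap(1000, 500, 1200, 5000)
        #   (1200, 300)
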
        def _retry_bad_segment(self, f, had_actual_segment_size):
            f.trap(BadSegmentNumberError) # guessed way wrong, off the end
            if had_actual_segment_size:
                # but we should have known better, so this is a real error
                return f
            # we didn't know better: try again with more information
            return self._maybe_fetch_next()

        def _error(self, f):
            self._alive = False
            self._hungry = False
            self._consumer.unregisterProducer()
            self._deferred.errback(f)

        def stopProducing(self):
            self._hungry = False
            self._alive = False
            # cancel any outstanding segment request
            if self._cancel_segment_request:
                self._cancel_segment_request()
                self._cancel_segment_request = None
        def pauseProducing(self):
            self._hungry = False
        def resumeProducing(self):
            self._hungry = True
            eventually(self._maybe_fetch_next)

    class Cancel:
        def __init__(self, f):
            self._f = f
            self.cancelled = False
        def cancel(self):
            if not self.cancelled:
                self.cancelled = True
                self._f(self)

    class CiphertextFileNode:
        # Share._node points to me
        def __init__(self, verifycap, storage_broker, secret_holder,
                     terminator, history):
            assert isinstance(verifycap, CHKFileVerifierURI)
            self.u = self._verifycap = verifycap
            self._size = verifycap.size
            storage_index = verifycap.storage_index
            self._needed_shares = verifycap.needed_shares
            self._total_shares = verifycap.total_shares
            self.running = True
            terminator.register(self) # calls self.stop() at stopService()
            # the rule is: only send network requests if you're active
            # (self.running is True). You can do eventual-sends any time. This
            # rule should mean that once stopService()+flushEventualQueue()
            # fires, everything will be done.
            self._secret_holder = secret_holder
            self._history = history

            self.share_hash_tree = IncompleteHashTree(self.u.total_shares)

            # we guess the segment size, so Segmentation can pull non-initial
            # segments in a single roundtrip
            k = verifycap.needed_shares
            max_segment_size = 128*KiB # TODO: pull from elsewhere, maybe the
                                       # same place as upload.BaseUploadable
            s = mathutil.next_multiple(min(verifycap.size, max_segment_size), k)
            self.guessed_segment_size = s

            # filled in when we parse a valid UEB
            self.have_UEB = False
            self.UEB = None # the raw UEB dict, once validated
            self.num_segments = None
            self.segment_size = None
            self.tail_data_size = None
            self.tail_segment_size = None
            self.block_size = None
            self.share_size = None
            self.ciphertext_hash_tree = None # size depends on num_segments
            self.ciphertext_hash = None # flat hash, optional

            # things to track callers that want data
            self._segsize_observers = OneShotObserverList()
            self._numsegs_observers = OneShotObserverList()
            # _segment_requests can have duplicates
            self._segment_requests = [] # (segnum, d, cancel_handle)
            self._active_segment = None # a SegmentFetcher, with .segnum
            self._no_more_shares = False

            self._sharefinder = ShareFinder(storage_broker, storage_index, self)
            self._shares = set()

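        # XXX example of the guess above: for a 1000000-byte file with k=3,
        # min(1000000, 128*KiB) = 131072, and next_multiple(131072, 3) =
        # 131073, so we guess 131073-byte segments (rounded up so FEC divides
        # evenly).
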
        def stop(self):
            # called by the Terminator at shutdown, mostly for tests
            if self._active_segment:
                self._active_segment.stop()
                self._active_segment = None
            self._sharefinder.stop()

        # things called by our client, either a filenode user or an
        # ImmutableFileNode wrapper

        def read(self, consumer, offset=0, size=None):
            """I am the main entry point, from which FileNode.read() can get
            data. I feed the consumer with the desired range of ciphertext. I
            return a Deferred that fires (with the consumer) when the read is
            finished."""
            # for concurrent operations: each gets its own Segmentation manager
            if size is None:
                size = self._size - offset
            s = Segmentation(self, offset, size, consumer)
            # this raises an interesting question: what segments to fetch? if
            # offset=0, always fetch the first segment, and then allow
            # Segmentation to be responsible for pulling the subsequent ones if
            # the first wasn't large enough. If offset>0, we're going to need an
            # extra roundtrip to get the UEB (and therefore the segment size)
            # before we can figure out which segment to get. TODO: allow the
            # offset-table-guessing code (which starts by guessing the segsize)
            # to assist the offset>0 process.
            d = s.start()
            return d

        def get_segment(self, segnum):
            """Begin downloading a segment. I return a tuple (d, c): 'd' is a
            Deferred that fires with (offset,data) when the desired segment is
            available, and c is an object on which c.cancel() can be called to
            disavow interest in the segment (after which 'd' will never fire).

            You probably need to know the segment size before calling this,
            unless you want the first few bytes of the file. If you ask for a
            segment number which turns out to be too large, the Deferred will
            errback with BadSegmentNumberError.

            The Deferred fires with the offset of the first byte of the data
            segment, so that you can call get_segment() before knowing the
            segment size, and still know which data you received.
            """
            d = defer.Deferred()
            c = Cancel(self._cancel_request)
            self._segment_requests.append( (segnum, d, c) )
            self._start_new_segment()
            return (d, c)

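        # XXX how a caller like Segmentation is expected to use the contract
        # above:
        #
        #   d, c = node.get_segment(segnum)
        #   d.addCallback(lambda (offset, data): ...) # fires at most once
        #   c.cancel()                                # ...or never, after this
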
        # things called by the Segmentation object used to transform
        # arbitrary-sized read() calls into quantized segment fetches

        def get_segment_size(self):
            """I return a Deferred that fires with the segment_size used by this
            file."""
            return self._segsize_observers.when_fired()
        def get_num_segments(self):
            """I return a Deferred that fires with the number of segments used by
            this file."""
            return self._numsegs_observers.when_fired()

        def _start_new_segment(self):
            if self._active_segment is None and self._segment_requests:
                segnum = self._segment_requests[0][0]
                self._active_segment = fetcher = SegmentFetcher(self, segnum,
                                                                self._needed_shares)
                active_shares = [s for s in self._shares if s.not_dead()]
                # XXX Share.not_dead() still needs to be written
                fetcher.add_shares(active_shares) # this triggers the loop


        # called by our child ShareFinder
        def got_shares(self, shares):
            self._shares.update(shares)
            if self._active_segment:
                self._active_segment.add_shares(shares)
        def no_more_shares(self):
            self._no_more_shares = True
            if self._active_segment:
                self._active_segment.no_more_shares()

        # things called by our Share instances

        def validate_UEB(self, UEB_s):
            h = hashutil.uri_extension_hash(UEB_s)
            if h != self._verifycap.uri_extension_hash:
                raise BadHashError
            UEB_dict = uri.unpack_extension(UEB_s)
            self._parse_UEB(UEB_dict) # sets self._stuff
            # TODO: a malformed (but authentic) UEB could throw an assertion in
            # _parse_UEB, and we should abandon the download.
            self.have_UEB = True
            self.UEB = UEB_dict
            self._segsize_observers.fire(self.segment_size)
            self._numsegs_observers.fire(self.num_segments)


        def _parse_UEB(self, d):
            # Note: the UEB contains needed_shares and total_shares. These are
            # redundant and inferior (the filecap contains the authoritative
            # values). However, because it is possible to encode the same file in
            # multiple ways, and the encoders might choose (poorly) to use the
            # same key for both (therefore getting the same SI), we might
            # encounter shares for both types. The UEB hashes will be different,
            # however, and we'll disregard the "other" encoding's shares as
            # corrupted.

            # therefore, we ignore d['total_shares'] and d['needed_shares'].

            self.share_size = mathutil.div_ceil(self._verifycap.size,
                                                self._needed_shares)

            self.segment_size = d['segment_size']

            self.block_size = mathutil.div_ceil(self.segment_size,
                                                self._needed_shares)
            self.num_segments = mathutil.div_ceil(self._size, self.segment_size)

            self.tail_data_size = self._size % self.segment_size
            if self.tail_data_size == 0:
                self.tail_data_size = self.segment_size
            # padding for erasure code
            self.tail_segment_size = mathutil.next_multiple(self.tail_data_size,
                                                            self._needed_shares)

            # zfec.Decode() instantiation is fast, but still, let's use the same
            # codec for anything we can. 3-of-10 takes 15us on my laptop,
            # 25-of-100 is 900us, 3-of-255 is 97us, 25-of-255 is 2.5ms,
            # worst-case 254-of-255 is 9.3ms
            self._codec = codec.CRSDecoder()
            self._codec.set_params(self.segment_size,
                                   self._needed_shares, self._total_shares)


            # Ciphertext hash tree root is mandatory, so that there is at most
            # one ciphertext that matches this read-cap or verify-cap. The
            # integrity check on the shares is not sufficient to prevent the
            # original encoder from creating some shares of file A and other
            # shares of file B.
            self.ciphertext_hash_tree = IncompleteHashTree(self.num_segments)
            self.ciphertext_hash_tree.set_hashes({0: d['crypttext_root_hash']})

            self.share_hash_tree.set_hashes({0: d['share_root_hash']})

            # crypttext_hash is optional. We only pull this from the first UEB
            # that we see.
            if 'crypttext_hash' in d:
                if len(d["crypttext_hash"]) == hashutil.CRYPTO_VAL_SIZE:
                    self.ciphertext_hash = d['crypttext_hash']
                else:
                    log.msg("ignoring bad-length UEB[crypttext_hash], "
                            "got %d bytes, want %d" % (len(d['crypttext_hash']),
                                                       hashutil.CRYPTO_VAL_SIZE),
                            umid="oZkGLA", level=log.WEIRD)

            # Our job is a fast download, not verification, so we ignore any
            # redundant fields. The Verifier uses a different code path which
            # does not ignore them.


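        # XXX worked example of the derived fields, for a 1000000-byte file
        # at k=3 with segment_size=131072: share_size=333334, block_size=43691,
        # num_segments=8, tail_data_size = 1000000 % 131072 = 82496, and
        # tail_segment_size = next_multiple(82496, 3) = 82497.
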
        def process_share_hashes(self, share_hashes):
            self.share_hash_tree.set_hashes(share_hashes)

        # called by our child SegmentFetcher

        def want_more_shares(self):
            self._sharefinder.hungry()

        def fetch_failed(self, sf, f):
            assert sf is self._active_segment
            sf.disownServiceParent()
            self._active_segment = None
            # deliver error upwards
            for (d,c) in self._extract_requests(sf.segnum):
                eventually(self._deliver_error, d, c, f)

        def _deliver_error(self, d, c, f):
            # this method exists to handle cancel() that occurs between
            # _got_segment and _deliver_error
            if not c.cancelled:
                d.errback(f)

        def process_blocks(self, segnum, blocks):
            decoder = self._codec
            if segnum == self.num_segments-1:
                decoder = codec.CRSDecoder()
                k, N = self._needed_shares, self._total_shares
                decoder.set_params(self.tail_segment_size, k, N)

            shares = []
            shareids = []
            for (shareid, share) in blocks.iteritems():
                shareids.append(shareid)
                shares.append(share)
            del blocks
            segment = decoder.decode(shares, shareids)
            del shares
            self._process_segment(segnum, segment)

     1140    def _process_segment(self, segnum, segment):
     1141        h = hashutil.crypttext_hash(segment)
     1142        try:
     1143            self.ciphertext_hash_tree.set_hashes(leaves={segnum, h})
     1144        except SOMETHING:
     1145            SOMETHING
     1146        assert self._active_segment.segnum == segnum
     1147        assert self.segment_size is not None
     1148        offset = segnum * self.segment_size
     1149        for (d,c) in self._extract_requests(segnum):
     1150            eventually(self._deliver, d, c, offset, segment)
     1151        self._active_segment = None
     1152        self._start_new_segment()
     1153
     1154    def _deliver(self, d, c, offset, segment):
     1155        # this method exists to handle cancel() that occurs between
     1156        # _got_segment and _deliver
     1157        if not c.cancelled:
     1158            d.callback((offset,segment))
     1159
     1160    def _extract_requests(self, segnum):
     1161        """Remove matching requests and return their (d,c) tuples so that the
     1162        caller can retire them."""
     1163        retire = [(d,c) for (segnum0, d, c) in self._segment_requests
     1164                  if segnum0 == segnum]
     1165        self._segment_requests = [t for t in self._segment_requests
     1166                                  if t[0] != segnum]
     1167        return retire
     1168
     1169    def _cancel_request(self, c):
     1170        self._segment_requests = [t for t in self._segment_requests
     1171                                  if t[2] != c]
      1172        segnums = [segnum for (segnum, d0, c0) in self._segment_requests]
      1173        if (self._active_segment
                      and self._active_segment.segnum not in segnums):
     1174            self._active_segment.stop()
     1175            self._active_segment = None
     1176            self._start_new_segment()
     1177
     1178class DecryptingConsumer:
     1179    """I sit between a CiphertextDownloader (which acts as a Producer) and
     1180    the real Consumer, decrypting everything that passes by. The real
     1181    Consumer sees the real Producer, but the Producer sees us instead of the
     1182    real consumer."""
     1183    implements(IConsumer)
     1184
     1185    def __init__(self, consumer, readkey, offset):
     1186        self._consumer = consumer
     1187        # TODO: pycryptopp CTR-mode needs random-access operations: I want
     1188        # either a=AES(readkey, offset) or better yet both of:
     1189        #  a=AES(readkey, offset=0)
     1190        #  a.process(ciphertext, offset=xyz)
     1191        # For now, we fake it with the existing iv= argument.
     1192        offset_big = offset // 16
     1193        offset_small = offset % 16
     1194        iv = binascii.unhexlify("%032x" % offset_big)
     1195        self._decryptor = AES(readkey, iv=iv)
     1196        self._decryptor.process("\x00"*offset_small)
     1197
     1198    def registerProducer(self, producer):
     1199        # this passes through, so the real consumer can flow-control the real
     1200        # producer. Therefore we don't need to provide any IPushProducer
     1201        # methods. We implement all the IConsumer methods as pass-throughs,
     1202        # and only intercept write() to perform decryption.
     1203        self._consumer.registerProducer(producer)
     1204    def unregisterProducer(self):
     1205        self._consumer.unregisterProducer()
     1206    def write(self, ciphertext):
     1207        plaintext = self._decryptor.process(ciphertext)
     1208        self._consumer.write(plaintext)
     1209
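# (editor's note, not part of the patch) a quick check of the iv= trick used
# by DecryptingConsumer above: starting the CTR counter at offset//16 and
# discarding offset%16 keystream bytes equals processing from byte 0.
# Assumes pycryptopp's AES (CTR mode; iv defaults to all zeros).
def _selftest_ctr_offset():
    key = "k"*16
    whole = AES(key).process("\x00"*100) # keystream from byte 0
    offset = 37
    a = AES(key, iv=binascii.unhexlify("%032x" % (offset // 16)))
    a.process("\x00" * (offset % 16)) # burn the partial block
    assert a.process("\x00" * (100-offset)) == whole[offset:]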
     1210class ImmutableFileNode:
     1211    # I wrap a CiphertextFileNode with a decryption key
     1212    def __init__(self, filecap, storage_broker, secret_holder, downloader,
     1213                 history):
      1214        assert isinstance(filecap, CHKFileURI)
      1215        verifycap = filecap.get_verify_cap()
      1216        self._cnode = CiphertextFileNode(verifycap, storage_broker,
      1217                                         secret_holder, downloader, history)
      1218        self.u = filecap
      1219        self._readkey = filecap.key # AES key consumed by read()
     1220
     1221    def read(self, consumer, offset=0, size=None):
     1222        decryptor = DecryptingConsumer(consumer, self._readkey, offset)
     1223        return self._cnode.read(decryptor, offset, size)
     1224
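# (editor's sketch, not part of the patch) a minimal IConsumer showing how
# ImmutableFileNode.read() is driven; 'Collector' is an illustrative name,
# and implements/IConsumer are assumed imported at the top of this file:
class Collector:
    implements(IConsumer)
    def __init__(self):
        self.chunks = []
    def registerProducer(self, producer):
        pass # the Reader paces itself, so nothing to do here
    def unregisterProducer(self):
        pass
    def write(self, data):
        self.chunks.append(data)
# usage:
#  c = Collector()
#  d = filenode.read(c, offset=0, size=None) # size=None means read to EOF
#  d.addCallback(lambda ign: "".join(c.chunks)) # plaintext, once done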
     1225
     1226# TODO: if server1 has all shares, and server2-10 have one each, make the
     1227# loop stall slightly before requesting all shares from the first server, to
     1228# give it a chance to learn about the other shares and get some diversity.
     1229# Or, don't bother, let the first block all come from one server, and take
     1230# comfort in the fact that we'll learn about the other servers by the time we
     1231# fetch the second block.
     1232#
     1233# davidsarah points out that we could use sequential (instead of parallel)
      1234# fetching of multiple blocks from a single server: by the time the first
     1235# block arrives, we'll hopefully have heard about other shares. This would
     1236# induce some RTT delays (i.e. lose pipelining) in the case that this server
     1237# has the only shares, but that seems tolerable. We could rig it to only use
     1238# sequential requests on the first segment.
     1239
     1240# as a query gets later, we're more willing to duplicate work.
     1241
     1242# should change server read protocol to allow small shares to be fetched in a
     1243# single RTT. Instead of get_buckets-then-read, just use read(shnums, readv),
     1244# where shnums=[] means all shares, and the return value is a dict of
      1245# shnum->data (like with mutable files). The DYHB query should also fetch the
     1246# offset table, since everything else can be located once we have that.
     1247
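# (editor's sketch) what that one-RTT call might look like; the method name
# 'read_shares' and the return shape are assumptions, not an existing API:
#
#  d = rref.callRemote("read_shares", storage_index,
#                      shnums=[],          # [] means "all shares"
#                      readv=[(0, 1024)])  # enough to cover an offset table
#  # fires with {shnum: [data..]}, one string per readv entry: a single
#  # round trip answers the DYHB question (the dict keys) and delivers
#  # each share's offset table (the values)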
     1248
     1249# ImmutableFileNode
     1250#    DecryptingConsumer
     1251#  CiphertextFileNode
     1252#    Segmentation
     1253#   ShareFinder
     1254#   SegmentFetcher[segnum] (one at a time)
     1255#   CommonShare[shnum]
     1256#   Share[shnum,server]
     1257
     1258# TODO: when we learn numsegs, any get_segment() calls for bad blocknumbers
     1259# should be failed with BadSegmentNumberError. But should this be the
     1260# responsibility of CiphertextFileNode, or SegmentFetcher? The knowledge will
     1261# first appear when a Share receives a valid UEB and calls
     1262# CiphertextFileNode.validate_UEB, then _parse_UEB. The SegmentFetcher is
     1263# expecting to hear from the Share, via the _block_request_activity observer.
     1264
     1265# make it the responsibility of the SegmentFetcher. Each Share that gets a
     1266# valid UEB will tell the SegmentFetcher BADSEGNUM (instead of COMPLETE or
      1267# CORRUPT). The SegmentFetcher is then responsible for shutting down, and
     1268# informing its parent (the CiphertextFileNode) of the BadSegmentNumberError,
     1269# which is then passed to the client of get_segment().
     1270
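# (editor's sketch of that scheme) inside SegmentFetcher's
# _block_request_activity handler; the exact body is an assumption:
#
#  if state is BADSEGNUM:
#      # a Share with a validated UEB says our segnum is out of range:
#      # shut down, and let the CiphertextFileNode errback the pending
#      # get_segment() callers
#      self._fail(Failure(BadSegmentNumberError("segnum=%d" % self.segnum)))
#      return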
     1271
     1272# TODO: if offset table is corrupt, attacker could cause us to fetch whole
     1273# (large) share
  • new file src/allmydata/immutable/download2_off.py

    diff --git a/src/allmydata/immutable/download2_off.py b/src/allmydata/immutable/download2_off.py
    new file mode 100755
    index 0000000..d2b8b99
    - +  
     1#! /usr/bin/python
     2
     3# known (shnum,Server) pairs are sorted into a list according to
     4# desireability. This sort is picking a winding path through a matrix of
     5# [shnum][server]. The goal is to get diversity of both shnum and server.
     6
     7# The initial order is:
     8#  find the lowest shnum on the first server, add it
     9#  look at the next server, find the lowest shnum that we don't already have
     10#   if any
     11#  next server, etc, until all known servers are checked
     12#  now look at servers that we skipped (because ...
     13
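# (editor's sketch) one way to build the first pass of that order; 'known'
# maps Server -> sorted list of shnums and is an assumed input. The
# skipped-servers pass is left out, matching the unfinished note above:
def _initial_order(known):
    order, seen = [], set()
    for server in known: # already in permuted order
        for shnum in known[server]:
            if shnum not in seen:
                order.append((shnum, server))
                seen.add(shnum)
                break # one share per server on the first pass
    return order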
     14# Keep track of which block requests are outstanding by (shnum,Server). Don't
     15# bother prioritizing "validated" shares: the overhead to pull the share hash
     16# chain is tiny (4 hashes = 128 bytes), and the overhead to pull a new block
     17# hash chain is also tiny (1GB file, 8192 segments of 128KiB each, 13 hashes,
     18# 832 bytes). Each time a block request is sent, also request any necessary
     19# hashes. Don't bother with a "ValidatedShare" class (as distinct from some
     20# other sort of Share). Don't bother avoiding duplicate hash-chain requests.
     21
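# (editor's arithmetic note) 1 GiB / 128 KiB = 8192 segments; a binary
# merkle tree over 8192 leaves is 13 levels deep, so one uncle chain is
# 13 hashes.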
     22# For each outstanding segread, walk the list and send requests (skipping
     23# outstanding shnums) until requests for k distinct shnums are in flight. If
     24# we can't do that, ask for more. If we get impatient on a request, find the
     25# first non-outstanding
     26
     27# start with the first Share in the list, and send a request. Then look at
     28# the next one. If we already have a pending request for the same shnum or
     29# server, push that Share down onto the fallback list and try the next one,
     30# etc. If we run out of non-fallback shares, use the fallback ones,
     31# preferring shnums that we don't have outstanding requests for (i.e. assume
     32# that all requests will complete). Do this by having a second fallback list.
     33
     34# hell, I'm reviving the Herder. But remember, we're still talking 3 objects
     35# per file, not thousands.
     36
     37# actually, don't bother sorting the initial list. Append Shares as the
     38# responses come back, that will put the fastest servers at the front of the
     39# list, and give a tiny preference to servers that are earlier in the
     40# permuted order.
     41
     42# more ideas:
     43#  sort shares by:
     44#   1: number of roundtrips needed to get some data
     45#   2: share number
     46#   3: ms of RTT delay
     47# maybe measure average time-to-completion of requests, compare completion
     48# time against that, much larger indicates congestion on the server side
     49# or the server's upstream speed is less than our downstream. Minimum
     50# time-to-completion indicates min(our-downstream,their-upstream). Could
     51# fetch shares one-at-a-time to measure that better.
     52
     53# when should we risk duplicate work and send a new request?
     54
     55def walk(self):
     56    shares = sorted(list)
     57    oldshares = copy(shares)
     58    outstanding = list()
     59    fallbacks = list()
     60    second_fallbacks = list()
     61    while len(outstanding.nonlate.shnums) < k: # need more requests
      62        while shares: # examine each known share once
     63            s = shares.pop(0)
     64            if s.server in outstanding.servers or s.shnum in outstanding.shnums:
     65                fallbacks.append(s)
     66                continue
     67            outstanding.append(s)
     68            send_request(s)
     69            break #'while need_more_requests'
     70        # must use fallback list. Ask for more servers while we're at it.
     71        ask_for_more_servers()
     72        while fallbacks:
     73            s = fallbacks.pop(0)
     74            if s.shnum in outstanding.shnums:
     75                # assume that the outstanding requests will complete, but
     76                # send new requests for other shnums to existing servers
     77                second_fallbacks.append(s)
     78                continue
     79            outstanding.append(s)
     80            send_request(s)
     81            break #'while need_more_requests'
     82        # if we get here, we're being forced to send out multiple queries per
     83        # share. We've already asked for more servers, which might help. If
     84        # there are no late outstanding queries, then duplicate shares won't
     85        # help. Don't send queries for duplicate shares until some of the
     86        # queries are late.
     87        if outstanding.late:
     88            # we're allowed to try any non-outstanding share
     89            while second_fallbacks:
      90                pass # TODO (unfinished sketch): pop one and send_request() it
      91    newshares = outstanding + fallbacks + second_fallbacks + shares # remainder
     92       
     93
     94class Server:
     95    """I represent an abstract Storage Server. One day, the StorageBroker
     96    will return instances of me. For now, the StorageBroker returns (peerid,
     97    RemoteReference) tuples, and this code wraps a Server instance around
     98    them.
     99    """
     100    def __init__(self, peerid, ss):
     101        self.peerid = peerid
     102        self.remote = ss
     103        self._remote_buckets = {} # maps shnum to RIBucketReader
     104        # TODO: release the bucket references on shares that we no longer
     105        # want. OTOH, why would we not want them? Corruption?
     106
     107    def send_query(self, storage_index):
     108        """I return a Deferred that fires with a set of shnums. If the server
     109        had shares available, I will retain the RemoteReferences to its
     110        buckets, so that get_data(shnum, range) can be called later."""
      111        d = self.remote.callRemote("get_buckets", storage_index)
     112        d.addCallback(self._got_response)
     113        return d
     114
     115    def _got_response(self, r):
     116        self._remote_buckets = r
     117        return set(r.keys())
     118
     119class ShareOnAServer:
     120    """I represent one instance of a share, known to live on a specific
     121    server. I am created every time a server responds affirmatively to a
     122    do-you-have-block query."""
     123
     124    def __init__(self, shnum, server):
     125        self._shnum = shnum
     126        self._server = server
     127        self._block_hash_tree = None
     128
     129    def cost(self, segnum):
     130        """I return a tuple of (roundtrips, bytes, rtt), indicating how
     131        expensive I think it would be to fetch the given segment. Roundtrips
     132        indicates how many roundtrips it is likely to take (one to get the
     133        data and hashes, plus one to get the offset table and UEB if this is
     134        the first segment we've ever fetched). 'bytes' is how many bytes we
     135        must fetch (estimated). 'rtt' is estimated round-trip time (float) in
     136        seconds for a trivial request. The downloading algorithm will compare
     137        costs to decide which shares should be used."""
     138        # the most significant factor here is roundtrips: a Share for which
      139        # we already have the offset table is better than a brand-new one
     140
     141    def max_bandwidth(self):
     142        """Return a float, indicating the highest plausible bytes-per-second
     143        that I've observed coming from this share. This will be based upon
      144        the maximum (bytes-per-fetch / time-per-fetch) ever observed. This
      145        can be used to estimate the server's upstream bandwidth. Clearly this
     146        is only accurate if a share is retrieved with no contention for
     147        either the upstream, downstream, or middle of the connection, but it
     148        may still serve as a useful metric for deciding which servers to pull
     149        from."""
     150
     151    def get_segment(self, segnum):
     152        """I return a Deferred that will fire with the segment data, or
     153        errback."""
     154
     155class NativeShareOnAServer(ShareOnAServer):
     156    """For tahoe native (foolscap) servers, I contain a RemoteReference to
     157    the RIBucketReader instance."""
     158    def __init__(self, shnum, server, rref):
     159        ShareOnAServer.__init__(self, shnum, server)
     160        self._rref = rref # RIBucketReader
     161
     162class Share:
     163    def __init__(self, shnum):
     164        self._shnum = shnum
     165        # _servers are the Server instances which appear to hold a copy of
     166        # this share. It is populated when the ValidShare is first created,
     167        # or when we receive a get_buckets() response for a shnum that
     168        # already has a ValidShare instance. When we lose the connection to a
     169        # server, we remove it.
     170        self._servers = set()
     171        # offsets, UEB, and share_hash_tree all live in the parent.
     172        # block_hash_tree lives here.
     173        self._block_hash_tree = None
     174
      175        # self._want: left unfinished in this draft
     176
     177    def get_servers(self):
     178        return self._servers
     179
     180
     181    def get_block(self, segnum):
     182        # read enough data to obtain a single validated block
     183        if not self.have_offsets:
     184            # we get the offsets in their own read, since they tell us where
     185            # everything else lives. We must fetch offsets for each share
     186            # separately, since they aren't directly covered by the UEB.
     187            pass
     188        if not self.parent.have_ueb:
     189            # use _guessed_segsize to make a guess about the layout, so we
     190            # can fetch both the offset table and the UEB in the same read.
     191            # This also requires making a guess about the presence or absence
     192            # of the plaintext_hash_tree. Oh, and also the version number. Oh
     193            # well.
     194            pass
     195
     196class CiphertextDownloader:
     197    """I manage all downloads for a single file. I operate a state machine
     198    with input events that are local read() requests, responses to my remote
     199    'get_bucket' and 'read_bucket' messages, and connection establishment and
      200    loss. My outbound events are connection-establishment requests and
      201    bucket-read request messages.
     202    """
     203    # eventually this will merge into the FileNode
     204    ServerClass = Server # for tests to override
     205
     206    def __init__(self, storage_index, ueb_hash, size, k, N, storage_broker,
     207                 shutdowner):
     208        # values we get from the filecap
     209        self._storage_index = si = storage_index
     210        self._ueb_hash = ueb_hash
     211        self._size = size
     212        self._needed_shares = k
     213        self._total_shares = N
     214        self._share_hash_tree = IncompleteHashTree(self._total_shares)
     215        # values we discover when we first fetch the UEB
     216        self._ueb = None # is dict after UEB fetch+validate
     217        self._segsize = None
     218        self._numsegs = None
     219        self._blocksize = None
     220        self._tail_segsize = None
     221        self._ciphertext_hash = None # optional
      222        # structures we create when we fetch the UEB, then continue to fill
      223        # as we download the file. (_share_hash_tree was already created
      224        # above: its size comes from the filecap, not the UEB.)
     225        self._ciphertext_hash_tree = None
     226
     227        # values we learn as we download the file
     228        self._offsets = {} # (shnum,Server) to offset table (dict)
     229        self._block_hash_tree = {} # shnum to IncompleteHashTree
     230        # other things which help us
     231        self._guessed_segsize = min(128*1024, size)
     232        self._active_share_readers = {} # maps shnum to Reader instance
     233        self._share_readers = [] # sorted by preference, best first
     234        self._readers = set() # set of Reader instances
     235        self._recent_horizon = 10 # seconds
     236
     237        # 'shutdowner' is a MultiService parent used to cancel all downloads
     238        # when the node is shutting down, to let tests have a clean reactor.
      239        self._storage_broker = storage_broker # used by _get_available_servers
                 self._shutdowner = shutdowner
                 self._scheduled = False # set by _reschedule, cleared by _loop

     240        self._init_available_servers()
     241        self._init_find_enough_shares()
     242
     243    # _available_servers is an iterator that provides us with Server
     244    # instances. Each time we pull out a Server, we immediately send it a
     245    # query, so we don't need to keep track of who we've sent queries to.
     246
     247    def _init_available_servers(self):
     248        self._available_servers = self._get_available_servers()
     249        self._no_more_available_servers = False
     250
     251    def _get_available_servers(self):
     252        """I am a generator of servers to use, sorted by the order in which
     253        we should query them. I make sure there are no duplicates in this
     254        list."""
     255        # TODO: make StorageBroker responsible for this non-duplication, and
     256        # replace this method with a simple iter(get_servers_for_index()),
     257        # plus a self._no_more_available_servers=True
     258        seen = set()
     259        sb = self._storage_broker
     260        for (peerid, ss) in sb.get_servers_for_index(self._storage_index):
     261            if peerid not in seen:
     262                yield self.ServerClass(peerid, ss) # Server(peerid, ss)
     263                seen.add(peerid)
     264        self._no_more_available_servers = True
     265
     266    # this block of code is responsible for having enough non-problematic
     267    # distinct shares/servers available and ready for download, and for
     268    # limiting the number of queries that are outstanding. The idea is that
     269    # we'll use the k fastest/best shares, and have the other ones in reserve
     270    # in case those servers stop responding or respond too slowly. We keep
     271    # track of all known shares, but we also keep track of problematic shares
     272    # (ones with hash failures or lost connections), so we can put them at
     273    # the bottom of the list.
     274
      275    def _init_find_enough_shares(self):
      276        # _sharemap maps shnum to set of Servers, and remembers where
      277        # viable (but not yet validated) shares are located. Each
      278        # get_bucket() response adds to this map; each hash failure or
      279        # disconnect removes from it. (TODO: if we disconnect but
      280        # reconnect later, we should be allowed to re-query.)
      281        self._sharemap = DictOfSets()

      288        # _problem_shares is a set of (shnum, Server) tuples, recording
                 # shares with hash failures or lost connections, so we can put
                 # them at the bottom of the list
                 self._problem_shares = set()

     290        # _queries_in_flight maps a Server to a timestamp, which remembers
     291        # which servers we've sent queries to (and when) but have not yet
     292        # heard a response. This lets us put a limit on the number of
     293        # outstanding queries, to limit the size of the work window (how much
     294        # extra work we ask servers to do in the hopes of keeping our own
     295        # pipeline filled). We remove a Server from _queries_in_flight when
     296        # we get an answer/error or we finally give up. If we ever switch to
     297        # a non-connection-oriented protocol (like UDP, or forwarded Chord
     298        # queries), we can use this information to retransmit any query that
     299        # has gone unanswered for too long.
     300        self._queries_in_flight = dict()
     301
     302    def _count_recent_queries_in_flight(self):
     303        now = time.time()
     304        recent = now - self._recent_horizon
     305        return len([s for (s,when) in self._queries_in_flight.items()
     306                    if when > recent])
     307
     308    def _find_enough_shares(self):
     309        # goal: have 2*k distinct not-invalid shares available for reading,
     310        # from 2*k distinct servers. Do not have more than 4*k "recent"
     311        # queries in flight at a time.
     312        if (len(self._sharemap) >= 2*self._needed_shares
      313            and len(self._sharemap.values()) >= 2*self._needed_shares):
     314            return
     315        num = self._count_recent_queries_in_flight()
     316        while num < 4*self._needed_shares:
     317            try:
     318                s = self._available_servers.next()
     319            except StopIteration:
     320                return # no more progress can be made
     321            self._queries_in_flight[s] = time.time()
     322            d = s.send_query(self._storage_index)
                # a dict has no .discard(): pop with a default tolerates a
                # missing key. The s=s defaults pin the loop variable for
                # each lambda.
      323            d.addBoth(incidentally, self._queries_in_flight.pop, s, None)
      324            d.addCallbacks(lambda shnums, s=s: [self._sharemap.add(shnum, s)
      325                                                for shnum in shnums],
      326                           lambda f, s=s: self._query_error(f, s))
     327            d.addErrback(self._error)
     328            d.addCallback(self._reschedule)
     329            num += 1
     330
     331    def _query_error(self, f, s):
     332        # a server returned an error, log it gently and ignore
     333        level = log.WEIRD
     334        if f.check(DeadReferenceError):
     335            level = log.UNUSUAL
      336        log.msg(format="Error during get_buckets to server=%(server)s",
      337                server=str(s), failure=f, level=level, umid="3uuBUQ")
     338
     339    # this block is responsible for turning known shares into usable shares,
     340    # by fetching enough data to validate their contents.
     341
     342    # UEB (from any share)
     343    # share hash chain, validated (from any share, for given shnum)
     344    # block hash (any share, given shnum)
     345
     346    def _got_ueb(self, ueb_data, share):
     347        if self._ueb is not None:
     348            return
     349        if hashutil.uri_extension_hash(ueb_data) != self._ueb_hash:
     350            share.error("UEB hash does not match")
     351            return
     352        d = uri.unpack_extension(ueb_data)
     353        self.share_size = mathutil.div_ceil(self._size, self._needed_shares)
     354
     355
     356        # There are several kinds of things that can be found in a UEB.
     357        # First, things that we really need to learn from the UEB in order to
     358        # do this download. Next: things which are optional but not redundant
     359        # -- if they are present in the UEB they will get used. Next, things
     360        # that are optional and redundant. These things are required to be
     361        # consistent: they don't have to be in the UEB, but if they are in
     362        # the UEB then they will be checked for consistency with the
     363        # already-known facts, and if they are inconsistent then an exception
     364        # will be raised. These things aren't actually used -- they are just
     365        # tested for consistency and ignored. Finally: things which are
     366        # deprecated -- they ought not be in the UEB at all, and if they are
     367        # present then a warning will be logged but they are otherwise
     368        # ignored.
     369
     370        # First, things that we really need to learn from the UEB:
     371        # segment_size, crypttext_root_hash, and share_root_hash.
     372        self._segsize = d['segment_size']
     373
     374        self._blocksize = mathutil.div_ceil(self._segsize, self._needed_shares)
     375        self._numsegs = mathutil.div_ceil(self._size, self._segsize)
     376
     377        self._tail_segsize = self._size % self._segsize
     378        if self._tail_segsize == 0:
     379            self._tail_segsize = self._segsize
     380        # padding for erasure code
     381        self._tail_segsize = mathutil.next_multiple(self._tail_segsize,
     382                                                    self._needed_shares)
     383
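                 # worked example (editor's note): size=1000000, segsize=131072,
                 # k=3 gives numsegs=8, a raw tail of 1000000 % 131072 = 82496
                 # bytes, padded up to next_multiple(82496, 3) = 82497 for the
                 # erasure coder.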
     384        # Ciphertext hash tree root is mandatory, so that there is at most
     385        # one ciphertext that matches this read-cap or verify-cap. The
     386        # integrity check on the shares is not sufficient to prevent the
     387        # original encoder from creating some shares of file A and other
     388        # shares of file B.
     389        self._ciphertext_hash_tree = IncompleteHashTree(self._numsegs)
     390        self._ciphertext_hash_tree.set_hashes({0: d['crypttext_root_hash']})
     391
     392        self._share_hash_tree.set_hashes({0: d['share_root_hash']})
     393
     394
     395        # Next: things that are optional and not redundant: crypttext_hash
     396        if 'crypttext_hash' in d:
      397            if len(d['crypttext_hash']) == hashutil.CRYPTO_VAL_SIZE:
     398                self._ciphertext_hash = d['crypttext_hash']
     399            else:
     400                log.msg("ignoring bad-length UEB[crypttext_hash], "
     401                        "got %d bytes, want %d" % (len(d['crypttext_hash']),
     402                                                   hashutil.CRYPTO_VAL_SIZE),
     403                        umid="oZkGLA", level=log.WEIRD)
     404
     405        # we ignore all of the redundant fields when downloading. The
     406        # Verifier uses a different code path which does not ignore them.
     407
     408        # finally, set self._ueb as a marker that we don't need to request it
     409        # anymore
     410        self._ueb = d
     411
     412    def _got_share_hashes(self, hashes, share):
     413        assert isinstance(hashes, dict)
     414        try:
     415            self._share_hash_tree.set_hashes(hashes)
     416        except (IndexError, BadHashError, NotEnoughHashesError), le:
     417            share.error("Bad or missing hashes")
     418            return
     419
     420    #def _got_block_hashes(
     421
     422    def _init_validate_enough_shares(self):
     423        # _valid_shares maps shnum to ValidatedShare instances, and is
     424        # populated once the block hash root has been fetched and validated
     425        # (which requires any valid copy of the UEB, and a valid copy of the
     426        # share hash chain for each shnum)
     427        self._valid_shares = {}
     428
     429        # _target_shares is an ordered list of ReadyShare instances, each of
     430        # which is a (shnum, server) tuple. It is sorted in order of
     431        # preference: we expect to get the fastest response from the
     432        # ReadyShares at the front of the list. It is also sorted to
     433        # distribute the shnums, so that fetching shares from
     434        # _target_shares[:k] is likely (but not guaranteed) to give us k
     435        # distinct shares. The rule is that we skip over entries for blocks
     436        # that we've already received, limit the number of recent queries for
     437        # the same block,
     438        self._target_shares = []
     439
     440    def _validate_enough_shares(self):
     441        # my goal is to have at least 2*k distinct validated shares from at
     442        # least 2*k distinct servers
     443        valid_share_servers = set()
     444        for vs in self._valid_shares.values():
     445            valid_share_servers.update(vs.get_servers())
     446        if (len(self._valid_shares) >= 2*self._needed_shares
      447            and len(valid_share_servers) >= 2*self._needed_shares):
     448            return
     449        #for
     450
     451    def _reschedule(self, _ign):
     452        # fire the loop again
     453        if not self._scheduled:
     454            self._scheduled = True
     455            eventually(self._loop)
     456
     457    def _loop(self):
     458        self._scheduled = False
     459        # what do we need?
     460
     461        self._find_enough_shares()
     462        self._validate_enough_shares()
     463
     464        if not self._ueb:
     465            # we always need a copy of the UEB
     466            pass
     467
     468    def _error(self, f):
     469        # this is an unexpected error: a coding bug
     470        log.err(f, level=log.UNUSUAL)
     471           
     472
     473
     474# using a single packed string (and an offset table) may be an artifact of
     475# our native storage server: other backends might allow cheap multi-part
     476# files (think S3, several buckets per share, one for each section).
     477
     478# find new names for:
     479#  data_holder
     480#  Share / Share2  (ShareInstance / Share? but the first is more useful)
     481
     482class IShare(Interface):
     483    """I represent a single instance of a single share (e.g. I reference the
     484    shnum2 for share SI=abcde on server xy12t, not the one on server ab45q).
     485    This interface is used by SegmentFetcher to retrieve validated blocks.
     486    """
     487    def get_block(segnum):
     488        """Return an Observer2, which will be notified with the following
     489        events:
     490         state=COMPLETE, block=data (terminal): validated block data
     491         state=OVERDUE (non-terminal): we have reason to believe that the
     492                                       request might have stalled, or we
     493                                       might just be impatient
     494         state=CORRUPT (terminal): the data we received was corrupt
     495         state=DEAD (terminal): the connection has failed
     496        """
     497
     498
     499# it'd be nice if we receive the hashes before the block, or just
     500# afterwards, so we aren't stuck holding on to unvalidated blocks
     501# that we can't process. If we guess the offsets right, we can
     502# accomplish this by sending the block request after the metadata
     503# requests (by keeping two separate requestlists), and have a one RTT
     504# pipeline like:
     505#  1a=metadata, 1b=block
     506#  1b->process+deliver : one RTT
     507
     508# But if we guess wrong, and fetch the wrong part of the block, we'll
     509# have a pipeline that looks like:
     510#  1a=wrong metadata, 1b=wrong block
     511#  1a->2a=right metadata,2b=right block
     512#  2b->process+deliver
     513# which means two RTT and buffering one block (which, since we'll
     514# guess the segsize wrong for everything, means buffering one
     515# segment)
     516
     517# if we start asking for multiple segments, we could get something
     518# worse:
     519#  1a=wrong metadata, 1b=wrong block0, 1c=wrong block1, ..
     520#  1a->2a=right metadata,2b=right block0,2c=right block1, .
     521#  2b->process+deliver
     522
     523# which means two RTT but fetching and buffering the whole file
     524# before delivering anything. However, since we don't know when the
     525# other shares are going to arrive, we need to avoid having more than
     526# one block in the pipeline anyways. So we shouldn't be able to get
     527# into this state.
     528
     529# it also means that, instead of handling all of
     530# self._requested_blocks at once, we should only be handling one
      531# block at a time: one of the requested blocks should be special
     532# (probably FIFO). But retire all we can.
     533
     534    # this might be better with a Deferred, using COMPLETE as the success
     535    # case and CORRUPT/DEAD in an errback, because that would let us hold the
     536    # 'share' and 'shnum' arguments locally (instead of roundtripping them
      537        # through Share.send_request). But OVERDUE is not terminal. So I
     538    # want a new sort of callback mechanism, with the extra-argument-passing
     539    # aspects of Deferred, but without being so one-shot. Is this a job for
     540    # Observer? No, it doesn't take extra arguments. So this uses Observer2.
     541
     542
     543class Reader:
     544    """I am responsible for a single offset+size read of the file. I handle
     545    segmentation: I figure out which segments are necessary, request them
     546    (from my CiphertextDownloader) in order, and trim the segments down to
     547    match the offset+size span. I use the Producer/Consumer interface to only
     548    request one segment at a time.
     549    """
     550    implements(IPushProducer)
     551    def __init__(self, consumer, offset, size):
     552        self._needed = []
     553        self._consumer = consumer
     554        self._hungry = False
     555        self._offset = offset
     556        self._size = size
     557        self._segsize = None
     558    def start(self):
     559        self._alive = True
     560        self._deferred = defer.Deferred()
     561        # the process doesn't actually start until set_segment_size()
     562        return self._deferred
     563
      564    def set_segment_size(self, segsize):
      565        if self._segsize is not None:
      566            return
      567        self._segsize = segsize
      568        # _compute_segnums is a generator: collect its output
                 for (segnum, off, size) in self._compute_segnums():
                     self.add_segment(segnum, off, size)
      569
      570    def _compute_segnums(self):
      571        # now that we know the file's segsize, what segments (and which
      572        # ranges of each) will we need?
      573        size = self._size
      574        offset = self._offset
      575        while size:
      576            assert size >= 0
      577            this_seg_num = offset // self._segsize
      578            this_seg_offset = offset - (this_seg_num*self._segsize)
      579            this_seg_size = min(size, self._segsize-this_seg_offset)
      580            size -= this_seg_size
      581            if size:
      582                offset += this_seg_size
      583            yield (this_seg_num, this_seg_offset, this_seg_size)
     584
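             # worked example (editor's note): with segsize=131072, a read at
             # offset=130000 for size=10000 needs (segnum=0, offset=130000,
             # size=1072) and then (segnum=1, offset=0, size=8928).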
     585    def get_needed_segments(self):
     586        return set([segnum for (segnum, off, size) in self._needed])
     587
     588
     589    def stopProducing(self):
     590        self._hungry = False
     591        self._alive = False
     592        # TODO: cancel the segment requests
     593    def pauseProducing(self):
     594        self._hungry = False
     595    def resumeProducing(self):
     596        self._hungry = True
     597    def add_segment(self, segnum, offset, size):
     598        self._needed.append( (segnum, offset, size) )
     599    def got_segment(self, segnum, segdata):
     600        """Return True if this schedule has more to go, or False if it is
     601        done."""
      602        assert self._needed[0][0] == segnum # FIFO: must be the next segment
     603        (_ign, offset, size) = self._needed.pop(0)
     604        data = segdata[offset:offset+size]
     605        self._consumer.write(data)
      606        if not self._needed:
      607            # we're done
      608            self._alive = False
      609            self._hungry = False
      610            self._consumer.unregisterProducer()
      611            self._deferred.callback(self._consumer)
                     return False
                 return True
     612    def error(self, f):
     613        self._alive = False
     614        self._hungry = False
     615        self._consumer.unregisterProducer()
     616        self._deferred.errback(f)
     617
     618
     619
     620class x:
     621    def OFFread(self, consumer, offset=0, size=None):
     622        """I am the main entry point, from which FileNode.read() can get
     623        data."""
     624        # tolerate concurrent operations: each gets its own Reader
     625        if size is None:
     626            size = self._size - offset
     627        r = Reader(consumer, offset, size)
     628        self._readers.add(r)
     629        d = r.start()
     630        if self.segment_size is not None:
     631            r.set_segment_size(self.segment_size)
     632            # TODO: if we can't find any segments, and thus never get a
     633            # segsize, tell the Readers to give up
     634        return d
  • new file src/allmydata/immutable/download2_util.py

    diff --git a/src/allmydata/immutable/download2_util.py b/src/allmydata/immutable/download2_util.py
    new file mode 100755
    index 0000000..48f2f0a
    - +  
     1
      2import weakref
       from twisted.application import service   # Terminator, below
       from foolscap.eventual import eventually  # Observer2._notify, below
     3
     4class Observer2:
     5    """A simple class to distribute multiple events to a single subscriber.
     6    It accepts arbitrary kwargs, but no posargs."""
     7    def __init__(self):
     8        self._watcher = None
     9        self._undelivered_results = []
     10        self._canceler = None
     11
      12    def set_canceler(self, c, methname):
      13        # we use a weakref to avoid creating a cycle between us and the thing
      14        # we're observing: they'll be holding a reference to us to compare
      15        # against the value we pass to their canceler function. A weakref to
                # a bound method would be dead on arrival, so we hold a weakref to
                # the object plus the method name instead.
      16        self._canceler = (weakref.ref(c), methname)
     17
     18    def subscribe(self, observer, **watcher_kwargs):
     19        self._watcher = (observer, watcher_kwargs)
     20        while self._undelivered_results:
     21            self._notify(self._undelivered_results.pop(0))
     22
     23    def notify(self, **result_kwargs):
     24        if self._watcher:
     25            self._notify(result_kwargs)
     26        else:
     27            self._undelivered_results.append(result_kwargs)
     28
     29    def _notify(self, result_kwargs):
     30        o, watcher_kwargs = self._watcher
     31        kwargs = dict(result_kwargs)
     32        kwargs.update(watcher_kwargs)
     33        eventually(o, **kwargs)
     34
      35    def cancel(self):
      36        if not self._canceler:
                    return
                (wr, methname) = self._canceler
                o = wr()
      37        if o:
      38            getattr(o, methname)(self)
     39
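# (editor's sketch) minimal Observer2 usage: subscribe-time kwargs are merged
# into every notify(), and delivery happens via eventually() on a later
# reactor turn. Names follow the patch's SegmentFetcher/Share plan:
#
#  o = share.get_block(segnum) # an Observer2
#  o.subscribe(fetcher._block_request_activity, share=share, shnum=shnum)
#  ... later, inside the Share ...
#  o.notify(state=COMPLETE, block=validated_block)
#  # next turn: fetcher._block_request_activity(state=COMPLETE,
#  #                block=validated_block, share=share, shnum=shnum)
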
      40class DictOfSets(dict):
      41    def add(self, key, value):
                # add 'value' to the set stored under 'key', creating the set
                # on first use
                if key in self:
                    self[key].add(value)
                else:
                    self[key] = set([value])
      42    def values(self): # return set that merges all value sets
      43        r = set()
      44        for key in self:
      45            r.update(self[key])
      46        return r
     47
     48
     49def incidentally(res, f, *args, **kwargs):
     50    """Add me to a Deferred chain like this:
     51     d.addBoth(incidentally, func, arg)
     52    and I'll behave as if you'd added the following function:
     53     def _(res):
     54         func(arg)
     55         return res
     56    This is useful if you want to execute an expression when the Deferred
     57    fires, but don't care about its value.
     58    """
     59    f(*args, **kwargs)
     60    return res
     61
     62
     64class Terminator(service.Service):
     65    def __init__(self):
     66        service.Service.__init__(self)
     67        self._clients = weakref.WeakKeyDictionary()
     68    def register(self, c):
     69        self._clients[c] = None
     70    def stopService(self):
     71        for c in self._clients:
     72            c.stop()
     73        return service.Service.stopService(self)
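# (editor's sketch) intended wiring, per the 'shutdowner' comment in
# CiphertextDownloader: the node owns one Terminator, and each download
# registers with it so test teardown gets a clean reactor:
#
#  terminator = Terminator()
#  terminator.setServiceParent(client) # stops when the node stops
#  terminator.register(downloader)     # downloader must have a .stop()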
  • src/allmydata/test/test_util.py

    diff --git a/src/allmydata/test/test_util.py b/src/allmydata/test/test_util.py
    index 6874655..b7537d7 100644
    a b from twisted.trial import unittest 
    77from twisted.internet import defer, reactor
    88from twisted.python.failure import Failure
    99from twisted.python import log
     10from hashlib import md5
    1011
    1112from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
    1213from allmydata.util import assertutil, fileutil, deferredutil, abbreviate
    1314from allmydata.util import limiter, time_format, pollmixin, cachedir
    1415from allmydata.util import statistics, dictutil, pipeline
    1516from allmydata.util import log as tahoe_log
     17from allmydata.util.spans import Spans, overlap, DataSpans
    1618
    1719class Base32(unittest.TestCase):
    1820    def test_b2a_matches_Pythons(self):
    class Log(unittest.TestCase): 
    15111513        tahoe_log.err(format="intentional sample error",
    15121514                      failure=f, level=tahoe_log.OPERATIONAL, umid="wO9UoQ")
    15131515        self.flushLoggedErrors(SampleError)
     1516
     1517
     1518class SimpleSpans:
     1519    # this is a simple+inefficient form of util.spans.Spans . We compare the
     1520    # behavior of this reference model against the real (efficient) form.
     1521
     1522    def __init__(self, _span_or_start=None, length=None):
     1523        self._have = set()
     1524        if length is not None:
     1525            for i in range(_span_or_start, _span_or_start+length):
     1526                self._have.add(i)
     1527        elif _span_or_start:
     1528            for (start,length) in _span_or_start:
     1529                self.add(start, length)
     1530
     1531    def add(self, start, length):
     1532        for i in range(start, start+length):
     1533            self._have.add(i)
     1534        return self
     1535
     1536    def remove(self, start, length):
     1537        for i in range(start, start+length):
     1538            self._have.discard(i)
     1539        return self
     1540
     1541    def each(self):
     1542        return sorted(self._have)
     1543
     1544    def __iter__(self):
     1545        items = sorted(self._have)
     1546        prevstart = None
     1547        prevend = None
     1548        for i in items:
     1549            if prevstart is None:
     1550                prevstart = prevend = i
     1551                continue
     1552            if i == prevend+1:
     1553                prevend = i
     1554                continue
     1555            yield (prevstart, prevend-prevstart+1)
     1556            prevstart = prevend = i
     1557        if prevstart is not None:
     1558            yield (prevstart, prevend-prevstart+1)
     1559
     1560    def __len__(self):
     1561        # this also gets us bool(s)
     1562        return len(self._have)
     1563
     1564    def __add__(self, other):
     1565        s = self.__class__(self)
     1566        for (start, length) in other:
     1567            s.add(start, length)
     1568        return s
     1569
     1570    def __sub__(self, other):
     1571        s = self.__class__(self)
     1572        for (start, length) in other:
     1573            s.remove(start, length)
     1574        return s
     1575
     1576    def __iadd__(self, other):
     1577        for (start, length) in other:
     1578            self.add(start, length)
     1579        return self
     1580
     1581    def __isub__(self, other):
     1582        for (start, length) in other:
     1583            self.remove(start, length)
     1584        return self
     1585
     1586    def __contains__(self, (start,length)):
     1587        for i in range(start, start+length):
     1588            if i not in self._have:
     1589                return False
     1590        return True
     1591
     1592class ByteSpans(unittest.TestCase):
     1593    def test_basic(self):
     1594        s = Spans()
     1595        self.failUnlessEqual(list(s), [])
     1596        self.failIf(s)
     1597        self.failIf((0,1) in s)
     1598        self.failUnlessEqual(len(s), 0)
     1599
     1600        s1 = Spans(3, 4) # 3,4,5,6
     1601        self._check1(s1)
     1602
     1603        s2 = Spans(s1)
     1604        self._check1(s2)
     1605
     1606        s2.add(10,2) # 10,11
     1607        self._check1(s1)
     1608        self.failUnless((10,1) in s2)
     1609        self.failIf((10,1) in s1)
     1610        self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11])
     1611        self.failUnlessEqual(len(s2), 6)
     1612
     1613        s2.add(15,2).add(20,2)
     1614        self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11,15,16,20,21])
     1615        self.failUnlessEqual(len(s2), 10)
     1616
     1617        s2.remove(4,3).remove(15,1)
     1618        self.failUnlessEqual(list(s2.each()), [3,10,11,16,20,21])
     1619        self.failUnlessEqual(len(s2), 6)
     1620
     1621    def _check1(self, s):
     1622        self.failUnlessEqual(list(s), [(3,4)])
     1623        self.failUnless(s)
     1624        self.failUnlessEqual(len(s), 4)
     1625        self.failIf((0,1) in s)
     1626        self.failUnless((3,4) in s)
     1627        self.failUnless((3,1) in s)
     1628        self.failUnless((5,2) in s)
     1629        self.failUnless((6,1) in s)
     1630        self.failIf((6,2) in s)
     1631        self.failIf((7,1) in s)
     1632        self.failUnlessEqual(list(s.each()), [3,4,5,6])
     1633
     1634    def test_math(self):
     1635        s1 = Spans(0, 10) # 0,1,2,3,4,5,6,7,8,9
     1636        s2 = Spans(5, 3) # 5,6,7
     1637        s3 = Spans(8, 4) # 8,9,10,11
     1638
     1639        s = s1 - s2
     1640        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
     1641        s = s1 - s3
     1642        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
     1643        s = s2 - s3
     1644        self.failUnlessEqual(list(s.each()), [5,6,7])
     1645
     1646        s = s1 + s2
     1647        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
     1648        s = s1 + s3
     1649        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
     1650        s = s2 + s3
     1651        self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])
     1652
     1653        s = Spans(s1)
     1654        s -= s2
     1655        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
     1656        s = Spans(s1)
     1657        s -= s3
     1658        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
     1659        s = Spans(s2)
     1660        s -= s3
     1661        self.failUnlessEqual(list(s.each()), [5,6,7])
     1662
     1663        s = Spans(s1)
     1664        s += s2
     1665        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
     1666        s = Spans(s1)
     1667        s += s3
     1668        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
     1669        s = Spans(s2)
     1670        s += s3
     1671        self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])
     1672
     1673    def test_random(self):
     1674        # attempt to increase coverage of corner cases by comparing behavior
     1675        # of a simple-but-slow model implementation against the
     1676        # complex-but-fast actual implementation, in a large number of random
     1677        # operations
     1678        S1 = SimpleSpans
     1679        S2 = Spans
     1680        s1 = S1(); s2 = S2()
     1681        seed = ""
     1682        def _create(subseed):
     1683            ns1 = S1(); ns2 = S2()
     1684            for i in range(10):
     1685                what = md5(subseed+str(i)).hexdigest()
     1686                start = int(what[2:4], 16)
     1687                length = max(1,int(what[5:6], 16))
     1688                ns1.add(start, length); ns2.add(start, length)
     1689            return ns1, ns2
     1690
     1691        #print
     1692        for i in range(1000):
     1693            what = md5(seed+str(i)).hexdigest()
     1694            op = what[0]
     1695            subop = what[1]
     1696            start = int(what[2:4], 16)
     1697            length = max(1,int(what[5:6], 16))
     1698            #print what
     1699            if op in "0":
     1700                if subop in "01234":
     1701                    s1 = S1(); s2 = S2()
     1702                elif subop in "5678":
     1703                    s1 = S1(start, length); s2 = S2(start, length)
     1704                else:
     1705                    s1 = S1(s1); s2 = S2(s2)
     1706                #print "s2 = %s" % s2.dump()
     1707            elif op in "123":
     1708                #print "s2.add(%d,%d)" % (start, length)
     1709                s1.add(start, length); s2.add(start, length)
     1710            elif op in "456":
     1711                #print "s2.remove(%d,%d)" % (start, length)
     1712                s1.remove(start, length); s2.remove(start, length)
     1713            elif op in "78":
     1714                ns1, ns2 = _create(what[7:11])
     1715                #print "s2 + %s" % ns2.dump()
     1716                s1 = s1 + ns1; s2 = s2 + ns2
     1717            elif op in "9a":
     1718                ns1, ns2 = _create(what[7:11])
     1719                #print "%s - %s" % (s2.dump(), ns2.dump())
     1720                s1 = s1 - ns1; s2 = s2 - ns2
     1721            elif op in "bc":
     1722                ns1, ns2 = _create(what[7:11])
     1723                #print "s2 += %s" % ns2.dump()
     1724                s1 += ns1; s2 += ns2
     1725            else:
     1726                ns1, ns2 = _create(what[7:11])
     1727                #print "%s -= %s" % (s2.dump(), ns2.dump())
     1728                s1 -= ns1; s2 -= ns2
     1729            #print "s2 now %s" % s2.dump()
     1730            self.failUnlessEqual(list(s1.each()), list(s2.each()))
     1731            self.failUnlessEqual(len(s1), len(s2))
     1732            self.failUnlessEqual(bool(s1), bool(s2))
     1733            self.failUnlessEqual(list(s1), list(s2))
     1734            for j in range(10):
     1735                what = md5(what[12:14]+str(j)).hexdigest()
     1736                start = int(what[2:4], 16)
     1737                length = max(1, int(what[5:6], 16))
     1738                span = (start, length)
     1739                self.failUnlessEqual(bool(span in s1), bool(span in s2))
     1740
     1741
     1742    # s()
     1743    # s(start,length)
     1744    # s(s0)
     1745    # s.add(start,length) : returns s
     1746    # s.remove(start,length)
     1747    # s.each() -> list of byte offsets, mostly for testing
     1748    # list(s) -> list of (start,length) tuples, one per span
     1749    # (start,length) in s -> True if (start..start+length-1) are all members
     1750    #  NOT equivalent to x in list(s)
     1751    # len(s) -> number of bytes, for testing, bool(), and accounting/limiting
     1752    # bool(s)  (__len__)
     1753    # s = s1+s2, s1-s2, +=s1, -=s1
     1754
     1755    def test_overlap(self):
     1756        for a in range(20):
     1757            for b in range(10):
     1758                for c in range(20):
     1759                    for d in range(10):
     1760                        self._test_overlap(a,b,c,d)
     1761
     1762    def _test_overlap(self, a, b, c, d):
     1763        s1 = set(range(a,a+b))
     1764        s2 = set(range(c,c+d))
     1765        #print "---"
     1766        #self._show_overlap(s1, "1")
     1767        #self._show_overlap(s2, "2")
     1768        o = overlap(a,b,c,d)
     1769        expected = s1.intersection(s2)
     1770        if not expected:
     1771            self.failUnlessEqual(o, None)
     1772        else:
     1773            start,length = o
     1774            so = set(range(start,start+length))
     1775            #self._show(so, "o")
     1776            self.failUnlessEqual(so, expected)
     1777
     1778    def _show_overlap(self, s, c):
     1779        import sys
     1780        out = sys.stdout
     1781        if s:
     1782            for i in range(max(s)):
     1783                if i in s:
     1784                    out.write(c)
     1785                else:
     1786                    out.write(" ")
     1787        out.write("\n")
     1788
     1789def extend(s, start, length, fill):
     1790    if len(s) >= start+length:
     1791        return s
     1792    assert len(fill) == 1
     1793    return s + fill*(start+length-len(s))
     1794
     1795def replace(s, start, data):
     1796    assert len(s) >= start+len(data)
     1797    return s[:start] + data + s[start+len(data):]
     1798
     1799class SimpleDataSpans:
     1800    def __init__(self, other=None):
     1801        self.missing = "" # "1" where missing, "0" where found
     1802        self.data = ""
     1803        if other:
     1804            for (start, data) in other.get_spans():
     1805                self.add(start, data)
     1806
     1807    def __len__(self):
     1808        return len(self.missing.translate(None, "1"))
     1809    def _dump(self):
     1810        return [i for (i,c) in enumerate(self.missing) if c == "0"]
     1811    def _have(self, start, length):
     1812        m = self.missing[start:start+length]
     1813        if not m or len(m)<length or int(m):
     1814            return False
     1815        return True
     1816    def get_spans(self):
     1817        for i in self._dump():
     1818            yield (i, self.data[i])
     1819    def get(self, start, length):
     1820        if self._have(start, length):
     1821            return self.data[start:start+length]
     1822        return None
     1823    def pop(self, start, length):
     1824        data = self.get(start, length)
     1825        if data:
     1826            self.remove(start, length)
     1827        return data
     1828    def remove(self, start, length):
     1829        self.missing = replace(extend(self.missing, start, length, "1"),
     1830                               start, "1"*length)
     1831    def add(self, start, data):
     1832        self.missing = replace(extend(self.missing, start, len(data), "1"),
     1833                               start, "0"*len(data))
     1834        self.data = replace(extend(self.data, start, len(data), " "),
     1835                            start, data)
     1836
     1837
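The model keeps two parallel strings, one marking which bytes are present and
one holding the bytes themselves; a hand-traced example of a single add():

    ds = SimpleDataSpans()
    ds.add(2, "four")
    assert ds.missing == "110000"  # "0" marks bytes we hold
    assert ds.data == "  four"
    assert ds.get(2, 4) == "four"
    assert ds.get(1, 2) is None    # byte 1 is still missing
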
     1838class StringSpans(unittest.TestCase):
     1839    def do_basic(self, klass):
     1840        ds = klass()
     1841        self.failUnlessEqual(len(ds), 0)
     1842        self.failUnlessEqual(list(ds._dump()), [])
     1843        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_spans()]), 0)
     1844        self.failUnlessEqual(ds.get(0, 4), None)
     1845        self.failUnlessEqual(ds.pop(0, 4), None)
     1846        ds.remove(0, 4)
     1847
     1848        ds.add(2, "four")
     1849        self.failUnlessEqual(len(ds), 4)
     1850        self.failUnlessEqual(list(ds._dump()), [2,3,4,5])
     1851        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_spans()]), 4)
     1852        self.failUnlessEqual(ds.get(0, 4), None)
     1853        self.failUnlessEqual(ds.pop(0, 4), None)
     1854        self.failUnlessEqual(ds.get(4, 4), None)
     1855
     1856        ds2 = klass(ds)
     1857        self.failUnlessEqual(len(ds2), 4)
     1858        self.failUnlessEqual(list(ds2._dump()), [2,3,4,5])
     1859        self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_spans()]), 4)
     1860        self.failUnlessEqual(ds2.get(0, 4), None)
     1861        self.failUnlessEqual(ds2.pop(0, 4), None)
     1862        self.failUnlessEqual(ds2.pop(2, 3), "fou")
     1863        self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_spans()]), 1)
     1864        self.failUnlessEqual(ds2.get(2, 3), None)
     1865        self.failUnlessEqual(ds2.get(5, 1), "r")
     1866        self.failUnlessEqual(ds.get(2, 3), "fou")
     1867        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_spans()]), 4)
     1868
     1869        ds.add(0, "23")
     1870        self.failUnlessEqual(len(ds), 6)
     1871        self.failUnlessEqual(list(ds._dump()), [0,1,2,3,4,5])
     1872        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_spans()]), 6)
     1873        self.failUnlessEqual(ds.get(0, 4), "23fo")
     1874        self.failUnlessEqual(ds.pop(0, 4), "23fo")
     1875        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_spans()]), 2)
     1876        self.failUnlessEqual(ds.get(0, 4), None)
     1877        self.failUnlessEqual(ds.pop(0, 4), None)
     1878
     1879        ds = klass()
     1880        ds.add(2, "four")
     1881        ds.add(3, "ea")
     1882        self.failUnlessEqual(ds.get(2, 4), "fear")
     1883
     1884    def do_scan(self, klass):
     1885        # do a test with gaps and spans of size 1 and 2
     1886        #  left=(1,11) * right=(1,11) * gapsize=(1,2)
     1887        # 111, 112, 121, 122, 211, 212, 221, 222
     1888        #    211
     1889        #      121
     1890        #         112
     1891        #            212
     1892        #               222
     1893        #                   221
     1894        #                      111
     1895        #                        122
     1896        #  11 1  1 11 11  11  1 1  111
     1897        # 0123456789012345678901234567
     1898        # abcdefghijklmnopqrstuvwxyz-=
     1899        pieces = [(1, "bc"),
     1900                  (4, "e"),
     1901                  (7, "h"),
     1902                  (9, "jk"),
     1903                  (12, "mn"),
     1904                  (16, "qr"),
     1905                  (20, "u"),
     1906                  (22, "w"),
     1907                  (25, "z-="),
     1908                  ]
     1909        p_elements = set([1,2,4,7,9,10,12,13,16,17,20,22,25,26,27])
     1910        S = "abcdefghijklmnopqrstuvwxyz-="
     1911        # TODO: when adding data, add capital letters, to make sure we aren't
     1912        # just leaving the old data in place
     1913        l = len(S)
     1914        def base():
     1915            ds = klass()
     1916            for start, data in pieces:
     1917                ds.add(start, data)
     1918            return ds
     1919        def dump(s):
     1920            p = set(s._dump())
     1921            # wow, this is the first time I've ever wanted ?: in python
     1922            # note: this requires python2.5
     1923            d = "".join([(S[i] if i in p else " ") for i in range(l)])
     1924            assert len(d) == l
     1925            return d
     1926        DEBUG = False
     1927        for start in range(0, l):
     1928            for end in range(start+1, l):
     1929                # add [start-end) to the baseline
     1930                which = "%d-%d" % (start, end-1)
     1931                p_added = set(range(start, end))
     1932                b = base()
     1933                if DEBUG:
     1934                    print
     1935                    print dump(b), which
     1936                    add = klass(); add.add(start, S[start:end])
     1937                    print dump(add)
     1938                b.add(start, S[start:end])
     1939                if DEBUG:
     1940                    print dump(b)
     1941                # check that the new span is there
     1942                d = b.get(start, end-start)
     1943                self.failUnlessEqual(d, S[start:end], which)
     1944                # check that all the original pieces are still there
     1945                for t_start, t_data in pieces:
     1946                    t_len = len(t_data)
     1947                    self.failUnlessEqual(b.get(t_start, t_len),
     1948                                         S[t_start:t_start+t_len],
     1949                                         "%s %d+%d" % (which, t_start, t_len))
     1950                # check that a lot of subspans are mostly correct
     1951                for t_start in range(l):
     1952                    for t_len in range(1,4):
     1953                        d = b.get(t_start, t_len)
     1954                        if d is not None:
     1955                            which2 = "%s+(%d-%d)" % (which, t_start,
     1956                                                     t_start+t_len-1)
     1957                            self.failUnlessEqual(d, S[t_start:t_start+t_len],
     1958                                                 which2)
     1959                        # check that removing a subspan gives the right value
     1960                        b2 = klass(b)
     1961                        b2.remove(t_start, t_len)
     1962                        removed = set(range(t_start, t_start+t_len))
     1963                        for i in range(l):
     1964                            exp = (((i in p_elements) or (i in p_added))
     1965                                   and (i not in removed))
     1966                            which2 = "%s-(%d-%d)" % (which, t_start,
     1967                                                     t_start+t_len-1)
     1968                            self.failUnlessEqual(bool(b2.get(i, 1)), exp,
     1969                                                 which2+" %d" % i)
     1970
     1971    def test_test(self):
     1972        self.do_basic(SimpleDataSpans)
     1973        self.do_scan(SimpleDataSpans)
     1974
     1975    def test_basic(self):
     1976        self.do_basic(DataSpans)
     1977        self.do_scan(DataSpans)
     1978
     1979    def test_random(self):
     1980        # attempt to increase coverage of corner cases by comparing behavior
     1981        # of a simple-but-slow model implementation against the
     1982        # complex-but-fast actual implementation, in a large number of random
     1983        # operations
     1984        S1 = SimpleDataSpans
     1985        S2 = DataSpans
     1986        s1 = S1(); s2 = S2()
     1987        seed = ""
     1988        def _randstr(length, seed):
     1989            created = 0
     1990            pieces = []
     1991            while created < length:
     1992                piece = md5(seed + str(created)).hexdigest()
     1993                pieces.append(piece)
     1994                created += len(piece)
     1995            return "".join(pieces)[:length]
     1996        def _create(subseed):
     1997            ns1 = S1(); ns2 = S2()
     1998            for i in range(10):
     1999                what = md5(subseed+str(i)).hexdigest()
     2000                start = int(what[2:4], 16)
     2001                length = max(1,int(what[5:6], 16))
     2002                ns1.add(start, _randstr(length, what[7:9]))
     2003                ns2.add(start, _randstr(length, what[7:9]))
     2004            return ns1, ns2
     2005
     2006        #print
     2007        for i in range(1000):
     2008            what = md5(seed+str(i)).hexdigest()
     2009            op = what[0]
     2010            subop = what[1]
     2011            start = int(what[2:4], 16)
     2012            length = max(1,int(what[5:6], 16))
     2013            #print what
     2014            if op in "0":
     2015                if subop in "0123456":
     2016                    s1 = S1(); s2 = S2()
     2017                else:
     2018                    s1, s2 = _create(what[7:11])
     2019                #print "s2 = %s" % list(s2._dump())
     2020            elif op in "123456":
     2021                #print "s2.add(%d,%d)" % (start, length)
     2022                s1.add(start, _randstr(length, what[7:9]))
     2023                s2.add(start, _randstr(length, what[7:9]))
     2024            elif op in "789abc":
     2025                #print "s2.remove(%d,%d)" % (start, length)
     2026                s1.remove(start, length); s2.remove(start, length)
     2027            else:
     2028                #print "s2.pop(%d,%d)" % (start, length)
     2029                d1 = s1.pop(start, length); d2 = s2.pop(start, length)
     2030                self.failUnlessEqual(d1, d2)
     2031            #print "s1 now %s" % list(s1._dump())
     2032            #print "s2 now %s" % list(s2._dump())
     2033            self.failUnlessEqual(len(s1), len(s2))
     2034            self.failUnlessEqual(list(s1._dump()), list(s2._dump()))
     2035            for j in range(100):
     2036                what = md5(what[12:14]+str(j)).hexdigest()
     2037                start = int(what[2:4], 16)
     2038                length = max(1, int(what[5:6], 16))
     2039                d1 = s1.get(start, length); d2 = s2.get(start, length)
     2040                self.failUnlessEqual(d1, d2, "%d+%d" % (start, length))
  • new file src/allmydata/util/spans.py

    diff --git a/src/allmydata/util/spans.py b/src/allmydata/util/spans.py
    new file mode 100755
    index 0000000..336fddf
    - +  
     1
     2class Spans:
     3    """I represent a compressed list of booleans, one per index (an integer).
     4    Typically, each index represents an offset into a large string, pointing
     5    to a specific byte of a share. In this context, True means that byte has
     6    been received, or has been requested.
     7
     8    Another way to look at this is maintaining a set of integers, optimized
     9    for operations on spans like 'add range to set' and 'is range in set?'.
     10
     11    This is a python equivalent of perl's Set::IntSpan module, frequently
     12    used to represent .newsrc contents.
     13
     14    Rather than storing an actual (large) list or dictionary, I represent my
     15    internal state as a sorted list of spans, each with a start and a length.
     16    My API is presented in terms of start+length pairs. I provide set
     17    arithmetic operators, to efficiently answer questions like 'I want bytes
     18    XYZ, I already requested bytes ABC, and I've already received bytes DEF:
     19    what bytes should I request now?'.
     20
     21    The new downloader will use it to keep track of which bytes we've requested
     22    or received already.
     23    """
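
The 'what bytes should I request now?' question in the docstring maps
directly onto the set-arithmetic operators defined below; a minimal sketch
with made-up byte ranges:

    wanted    = Spans(0, 1000)            # everything we want
    requested = Spans(0, 100)             # already asked a server for
    received  = Spans(0, 50).add(80, 20)  # responses seen so far
    ask_next = wanted - requested - received
    assert list(ask_next) == [(100, 900)] # nothing in flight past byte 99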
     24
     25    def __init__(self, _span_or_start=None, length=None):
     26        self._spans = list()
     27        if length is not None:
     28            self._spans.append( (_span_or_start, length) )
     29        elif _span_or_start:
     30            for (start,length) in _span_or_start:
     31                self.add(start, length)
     32        self._check()
     33
     34    def _check(self):
     35        assert sorted(self._spans) == self._spans
     36        prev_end = None
     37        try:
     38            for (start,length) in self._spans:
     39                if prev_end is not None:
     40                    assert start > prev_end
     41                prev_end = start+length
     42        except AssertionError:
     43            print "BAD:", self.dump()
     44            raise
     45
     46    def add(self, start, length):
     47        assert start >= 0
     48        assert length > 0
     49        #print " ADD [%d+%d -%d) to %s" % (start, length, start+length, self.dump())
     50        first_overlap = last_overlap = None
     51        for i,(s_start,s_length) in enumerate(self._spans):
     52            #print "  (%d+%d)-> overlap=%s adjacent=%s" % (s_start,s_length, overlap(s_start, s_length, start, length), adjacent(s_start, s_length, start, length))
     53            if (overlap(s_start, s_length, start, length)
     54                or adjacent(s_start, s_length, start, length)):
     55                last_overlap = i
     56                if first_overlap is None:
     57                    first_overlap = i
     58                continue
     59            # no overlap
     60            if first_overlap is not None:
     61                break
     62        #print "  first_overlap", first_overlap, last_overlap
     63        if first_overlap is None:
     64            # no overlap, so just insert the span and sort by starting
     65            # position.
     66            self._spans.insert(0, (start,length))
     67            self._spans.sort()
     68        else:
     69            # everything from [first_overlap] to [last_overlap] overlapped
     70            first_start,first_length = self._spans[first_overlap]
     71            last_start,last_length = self._spans[last_overlap]
     72            newspan_start = min(start, first_start)
     73            newspan_end = max(start+length, last_start+last_length)
     74            newspan_length = newspan_end - newspan_start
     75            newspan = (newspan_start, newspan_length)
     76            self._spans[first_overlap:last_overlap+1] = [newspan]
     77        #print "  ADD done: %s" % self.dump()
     78        self._check()
     79
     80        return self
     81
     82    def remove(self, start, length):
     83        assert start >= 0
     84        assert length > 0
     85        #print " REMOVE [%d+%d -%d) from %s" % (start, length, start+length, self.dump())
     86        first_complete_overlap = last_complete_overlap = None
     87        for i,(s_start,s_length) in enumerate(self._spans):
     88            s_end = s_start + s_length
     89            o = overlap(s_start, s_length, start, length)
     90            if o:
     91                o_start, o_length = o
     92                o_end = o_start+o_length
     93                if o_start == s_start and o_end == s_end:
     94                    # delete this span altogether
     95                    if first_complete_overlap is None:
     96                        first_complete_overlap = i
     97                    last_complete_overlap = i
     98                elif o_start == s_start:
     99                    # we only overlap the left side, so trim the start
     100                    #    1111
     101                    #  rrrr
     102                    #    oo
     103                    # ->   11
     104                    new_start = o_end
     105                    new_end = s_end
     106                    assert new_start > s_start
     107                    new_length = new_end - new_start
     108                    self._spans[i] = (new_start, new_length)
     109                elif o_end == s_end:
     110                    # we only overlap the right side
     111                    #    1111
     112                    #      rrrr
     113                    #      oo
     114                    # -> 11
     115                    new_start = s_start
     116                    new_end = o_start
     117                    assert new_end < s_end
     118                    new_length = new_end - new_start
     119                    self._spans[i] = (new_start, new_length)
     120                else:
     121                    # we overlap the middle, so create a new span. No need to
     122                    # examine any other spans.
     123                    #    111111
     124                    #      rr
     125                    #    LL  RR
     126                    left_start = s_start
     127                    left_end = o_start
     128                    left_length = left_end - left_start
     129                    right_start = o_end
     130                    right_end = s_end
     131                    right_length = right_end - right_start
     132                    self._spans[i] = (left_start, left_length)
     133                    self._spans.append( (right_start, right_length) )
     134                    self._spans.sort()
     135                    break
     136        if first_complete_overlap is not None:
     137            del self._spans[first_complete_overlap:last_complete_overlap+1]
     138        #print "  REMOVE done: %s" % self.dump()
     139        self._check()
     140        return self
     141
     142    def dump(self):
     143        return "len=%d: %s" % (len(self),
     144                               ",".join(["[%d-%d]" % (start,start+l-1)
     145                                         for (start,l) in self._spans]) )
     146
     147    def each(self):
     148        for start, length in self._spans:
     149            for i in range(start, start+length):
     150                yield i
     151
     152    def __iter__(self):
     153        for s in self._spans:
     154            yield s
     155
     156    def __len__(self):
     157        # this also gets us bool(s)
     158        return sum([length for start,length in self._spans])
     159
     160    def __add__(self, other):
     161        s = self.__class__(self)
     162        for (start, length) in other:
     163            s.add(start, length)
     164        return s
     165
     166    def __sub__(self, other):
     167        s = self.__class__(self)
     168        for (start, length) in other:
     169            s.remove(start, length)
     170        return s
     171
     172    def __iadd__(self, other):
     173        for (start, length) in other:
     174            self.add(start, length)
     175        return self
     176
     177    def __isub__(self, other):
     178        for (start, length) in other:
     179            self.remove(start, length)
     180        return self
     181
     182    def __contains__(self, (start,length)):
     183        for span_start,span_length in self._spans:
     184            o = overlap(start, length, span_start, span_length)
     185            if o:
     186                o_start,o_length = o
     187                if o_start == start and o_length == length:
     188                    return True
     189        return False
     190
     191def overlap(start0, length0, start1, length1):
     192    # return start2,length2 of the overlapping region, or None
     193    #  00      00   000   0000  00  00 000  00   00  00      00
     194    #     11    11   11    11   111 11 11  1111 111 11    11
     195    left = max(start0, start1)
     196    right = min(start0+length0, start1+length1)
     197    # if there is overlap, 'left' will be its start, and right-1 will
     198    # be the end
     199    if left < right:
     200        return (left, right-left)
     201    return None
     202
     203def adjacent(start0, length0, start1, length1):
     204    if (start0 < start1) and start0+length0 == start1:
     205        return True
     206    elif (start1 < start0) and start1+length1 == start0:
     207        return True
     208    return False
     209
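Both helpers are pure interval arithmetic; a few hand-computed sanity checks:

    assert overlap(10, 5, 12, 10) == (12, 3)  # [10-14] and [12-21] share [12-14]
    assert overlap(0, 5, 5, 5) is None        # touching is not overlapping...
    assert adjacent(0, 5, 5, 5)               # ...but it is adjacent
    assert not adjacent(0, 5, 6, 5)           # a one-byte gap is neither
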
     210class DataSpans:
     211    """I represent portions of a large string. Equivalently, I can be said to
     212    maintain a large array of characters (with gaps of empty elements). I can
     213    be used to manage access to a remote share, where some pieces have been
     214    retrieved, some have been requested, and others have not been read.
     215    """
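
A short hypothetical session showing how adds merge across a filled gap and
how get()/pop() behave at the edges:

    ds = DataSpans()
    ds.add(0, "hel")
    ds.add(5, "world")
    ds.add(3, "lo")                  # fills the gap; all three spans merge
    assert ds.get_spans() == [(0, "helloworld")]
    assert ds.get(8, 4) is None      # runs past the end of known data
    assert ds.pop(0, 5) == "hello"
    assert len(ds) == 5              # only "world" remains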
     216
     217    def __init__(self, other=None):
     218        self.spans = [] # (start, data) tuples, non-overlapping, merged
     219        if other:
     220            for (start, data) in other.get_spans():
     221                self.add(start, data)
     222
     223    def __len__(self):
     224        # return number of bytes we're holding
     225        return sum([len(data) for (start,data) in self.spans])
     226
     227    def _dump(self):
     228        # return iterator of sorted list of offsets, one per byte
     229        for (start,data) in self.spans:
     230            for i in range(start, start+len(data)):
     231                yield i
     232
     233    def get_spans(self):
     234        return list(self.spans)
     235
     236    def assert_invariants(self):
     237        if not self.spans:
     238            return
     239        prev_start = self.spans[0][0]
     240        prev_end = prev_start + len(self.spans[0][1])
     241        for start, data in self.spans[1:]:
     242            if not start > prev_end:
     243                # adjacent or overlapping: bad
     244                print "ASSERTION FAILED", self.spans; raise AssertionError
     245            prev_end = start + len(data)
     246
     247    def get(self, start, length):
     248        # returns a string of LENGTH, or None
     249        #print "get", start, length, self.spans
     250        end = start+length
     251        for (s_start,s_data) in self.spans:
     252            s_end = s_start+len(s_data)
     253            #print " ",s_start,s_end
     254            if s_start <= start < s_end:
     255                # we want some data from this span. Because we maintain
     256                # strictly merged and non-overlapping spans, everything we
     257                # want must be in this span.
     258                offset = start - s_start
     259                if offset + length > len(s_data):
     260                    #print " None, span falls short"
     261                    return None # span falls short
     262                #print " some", s_data[offset:offset+length]
     263                return s_data[offset:offset+length]
     264            if s_start >= end:
     265                # we've gone too far: no further spans will overlap
     266                #print " None, gone too far"
     267                return None
     268        #print " None, ran out of spans"
     269        return None
     270
     271    def add(self, start, data):
     272        # first: walk through existing spans, find overlap, modify-in-place
     273        #  create list of new spans
     274        #  add new spans
     275        #  sort
     276        #  merge adjacent spans
     277        #print "add", start, data, self.spans
     278        end = start + len(data)
     279        i = 0
     280        while len(data):
     281            #print " loop", start, data, i, len(self.spans), self.spans
     282            if i >= len(self.spans):
     283                #print " append and done"
     284                # append a last span
     285                self.spans.append( (start, data) )
     286                break
     287            (s_start,s_data) = self.spans[i]
     288            # five basic cases:
     289            #  a: OLD  b:OLDD  c1:OLD  c2:OLD   d1:OLDD  d2:OLD  e: OLLDD
     290            #    NEW     NEW      NEW     NEWW      NEW      NEW     NEW
     291            #
     292            # we handle A by inserting a new segment (with "N") and looping,
     293            # turning it into B or C. We handle B by replacing a prefix and
     294            # terminating. We handle C (both c1 and c2) by replacing the
     295            # segment (and, for c2, looping, turning it into A). We handle D
     296            # by replacing a suffix (and, for d2, looping, turning it into
     297            # A). We handle E by replacing the middle and terminating.
     298            if start < s_start:
     299                # case A: insert a new span, then loop with the remainder
     300                #print " insert new span"
     301                s_len = s_start-start
     302                self.spans.insert(i, (start, data[:s_len]))
     303                i += 1
     304                start = s_start
     305                data = data[s_len:]
     306                continue
     307            s_len = len(s_data)
     308            s_end = s_start+s_len
     309            if s_start <= start < s_end:
     310                #print " modify this span", s_start, start, s_end
     311                # we want to modify some data in this span: a prefix, a
     312                # suffix, or the whole thing
     313                if s_start == start:
     314                    if s_end <= end:
     315                        #print " replace whole segment"
     316                        # case C: replace this segment
     317                        self.spans[i] = (s_start, data[:s_len])
     318                        i += 1
     319                        start += s_len
     320                        data = data[s_len:]
     321                        # C2 is where len(data)>0
     322                        continue
     323                    # case B: modify the prefix, retain the suffix
     324                    #print " modify prefix"
     325                    self.spans[i] = (s_start, data + s_data[len(data):])
     326                    break
     327                if start > s_start and end < s_end:
     328                    # case E: modify the middle
     329                    #print " modify middle"
     330                    prefix_len = start - s_start # we retain this much
     331                    suffix_len = s_end - end # and retain this much
     332                    newdata = s_data[:prefix_len] + data + s_data[-suffix_len:]
     333                    self.spans[i] = (s_start, newdata)
     334                    break
     335                # case D: retain the prefix, modify the suffix
     336                #print " modify suffix"
     337                prefix_len = start - s_start # we retain this much
     338                suffix_len = s_len - prefix_len # we replace this much
     339                #print "  ", s_data, prefix_len, suffix_len, s_len, data
     340                self.spans[i] = (s_start,
     341                                 s_data[:prefix_len] + data[:suffix_len])
     342                i += 1
     343                start += suffix_len
     344                data = data[suffix_len:]
     345                #print "  now", start, data
     346                # D2 is where len(data)>0
     347                continue
     348            # else we're not there yet
     349            #print " still looking"
     350            i += 1
     351            continue
     352        # now merge adjacent spans
     353        #print " merging", self.spans
     354        newspans = []
     355        for (s_start,s_data) in self.spans:
     356            if newspans and adjacent(newspans[-1][0], len(newspans[-1][1]),
     357                                     s_start, len(s_data)):
     358                newspans[-1] = (newspans[-1][0], newspans[-1][1] + s_data)
     359            else:
     360                newspans.append( (s_start, s_data) )
     361        self.spans = newspans
     362        self.assert_invariants()
     363        #print " done", self.spans
     364
     365    def remove(self, start, length):
     366        i = 0
     367        end = start + length
     368        #print "remove", start, length, self.spans
     369        while i < len(self.spans):
     370            (s_start,s_data) = self.spans[i]
     371            if s_start >= end:
     372                # this segment is entirely right of the removed region, and
     373                # all further segments are even further right. We're done.
     374                break
     375            s_len = len(s_data)
     376            s_end = s_start + s_len
     377            o = overlap(start, length, s_start, s_len)
     378            if not o:
     379                i += 1
     380                continue
     381            o_start, o_len = o
     382            o_end = o_start + o_len
     383            if o_len == s_len:
     384                # remove the whole segment
     385                del self.spans[i]
     386                continue
     387            if o_start == s_start:
     388                # remove a prefix, leaving the suffix from o_end to s_end
     389                prefix_len = o_end - o_start
     390                self.spans[i] = (o_end, s_data[prefix_len:])
     391                i += 1
     392                continue
     393            elif o_end == s_end:
     394                # remove a suffix, leaving the prefix from s_start to o_start
     395                prefix_len = o_start - s_start
     396                self.spans[i] = (s_start, s_data[:prefix_len])
     397                i += 1
     398                continue
     399            # remove the middle, creating a new segment
     400            # left is s_start:o_start, right is o_end:s_end
     401            left_len = o_start - s_start
     402            left = s_data[:left_len]
     403            right_len = s_end - o_end
     404            right = s_data[-right_len:]
     405            self.spans[i] = (s_start, left)
     406            self.spans.insert(i+1, (o_end, right))
     407            break
     408        #print " done", self.spans
     409
     410    def pop(self, start, length):
     411        data = self.get(start, length)
     412        if data:
     413            self.remove(start, length)
     414        return data
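
Finally, a hand-traced sketch of remove()'s middle-split case feeding pop()
(hypothetical values):

    ds = DataSpans()
    ds.add(0, "abcdefghij")
    ds.remove(3, 4)                  # punch a hole in the middle
    assert ds.get_spans() == [(0, "abc"), (7, "hij")]
    assert ds.pop(7, 3) == "hij"
    assert ds.get_spans() == [(0, "abc")]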