source file: /home/buildslave/tahoe/edgy/build/src/allmydata/mutable/retrieve.py
file stats: 363 lines, 351 executed: 96.7% covered
coverage versus previous test: 0 lines added, 0 lines removed
    1. 
    2. import struct, time
    3. from itertools import count
    4. from zope.interface import implements
    5. from twisted.internet import defer
    6. from twisted.python import failure
    7. from foolscap.api import DeadReferenceError, eventually, fireEventually
    8. from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError
    9. from allmydata.util import hashutil, idlib, log
   10. from allmydata import hashtree, codec
   11. from allmydata.storage.server import si_b2a
   12. from pycryptopp.cipher.aes import AES
   13. from pycryptopp.publickey import rsa
   14. 
   15. from common import DictOfSets, CorruptShareError, UncoordinatedWriteError
   16. from layout import SIGNED_PREFIX, unpack_share_data
   17. 
   18. class RetrieveStatus:
   19.     implements(IRetrieveStatus)
   20.     statusid_counter = count(0)
   21.     def __init__(self):
   22.         self.timings = {}
   23.         self.timings["fetch_per_server"] = {}
   24.         self.timings["cumulative_verify"] = 0.0
   25.         self.problems = {}
   26.         self.active = True
   27.         self.storage_index = None
   28.         self.helper = False
   29.         self.encoding = ("?","?")
   30.         self.size = None
   31.         self.status = "Not started"
   32.         self.progress = 0.0
   33.         self.counter = self.statusid_counter.next()
   34.         self.started = time.time()
   35. 
   36.     def get_started(self):
   37.         return self.started
   38.     def get_storage_index(self):
   39.         return self.storage_index
   40.     def get_encoding(self):
   41.         return self.encoding
   42.     def using_helper(self):
   43.         return self.helper
   44.     def get_size(self):
   45.         return self.size
   46.     def get_status(self):
   47.         return self.status
   48.     def get_progress(self):
   49.         return self.progress
   50.     def get_active(self):
   51.         return self.active
   52.     def get_counter(self):
   53.         return self.counter
   54. 
   55.     def add_fetch_timing(self, peerid, elapsed):
   56.         if peerid not in self.timings["fetch_per_server"]:
   57.             self.timings["fetch_per_server"][peerid] = []
   58.         self.timings["fetch_per_server"][peerid].append(elapsed)
   59.     def set_storage_index(self, si):
   60.         self.storage_index = si
   61.     def set_helper(self, helper):
   62.         self.helper = helper
   63.     def set_encoding(self, k, n):
   64.         self.encoding = (k, n)
   65.     def set_size(self, size):
   66.         self.size = size
   67.     def set_status(self, status):
   68.         self.status = status
   69.     def set_progress(self, value):
   70.         self.progress = value
   71.     def set_active(self, value):
   72.         self.active = value
   73. 
   74. class Marker:
           # Empty class. Retrieve.get_data() creates one instance per query
           # and uses it as that query's key in Retrieve._outstanding_queries.
   75.     pass
   76. 
   77. class Retrieve:
   78.     # this class is currently single-use. Eventually (in MDMF) we will make
   79.     # it multi-use, in which case you can call download(range) multiple
   80.     # times, and each will have a separate response chain. However the
   81.     # Retrieve object will remain tied to a specific version of the file, and
   82.     # will use a single ServerMap instance.
   83. 
   84.     def __init__(self, filenode, servermap, verinfo, fetch_privkey=False):
               # Prepare to retrieve the version described by 'verinfo' of
               # 'filenode', using 'servermap' to locate shares. Nothing is sent
               # here. verinfo is unpacked below as the 9-tuple (seqnum,
               # root_hash, IV, segsize, datalength, k, N, prefix, offsets_tuple).
               # fetch_privkey asks get_data() to also read the encrypted
               # private key; it is overridden to False when the node already
               # has a privkey. The node must already have _pubkey and _readkey
               # (asserted).
   85.         self._node = filenode
   86.         assert self._node._pubkey
   87.         self._storage_index = filenode.get_storage_index()
   88.         assert self._node._readkey
   89.         self._last_failure = None
   90.         prefix = si_b2a(self._storage_index)[:5]
   91.         self._log_number = log.msg("Retrieve(%s): starting" % prefix)
   92.         self._outstanding_queries = {} # maps Marker to (peerid, shnum, start_time)
   93.         self._running = True
   94.         self._decoding = False
   95.         self._bad_shares = set()
   96. 
   97.         self.servermap = servermap
   98.         assert self._node._pubkey
   99.         self.verinfo = verinfo
  100.         # during repair, we may be called upon to grab the private key, since
  101.         # it wasn't picked up during a verify=False checker run, and we'll
  102.         # need it for repair to generate a new version.
  103.         self._need_privkey = fetch_privkey
  104.         if self._node._privkey:
  105.             self._need_privkey = False
  106. 
  107.         self._status = RetrieveStatus()
  108.         self._status.set_storage_index(self._storage_index)
  109.         self._status.set_helper(False)
  110.         self._status.set_progress(0.0)
  111.         self._status.set_active(True)
               # note: this unpack rebinds the local 'prefix' used for the log
               # message above; only datalength, k and N are used from it here
  112.         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
  113.          offsets_tuple) = self.verinfo
  114.         self._status.set_size(datalength)
  115.         self._status.set_encoding(k, N)
  116. 
  117.     def get_status(self):
               # Return the RetrieveStatus object created in __init__.
  118.         return self._status
  119. 
  120.     def log(self, *args, **kwargs):
  121.         if "parent" not in kwargs:
  122.             kwargs["parent"] = self._log_number
  123.         if "facility" not in kwargs:
  124.             kwargs["facility"] = "tahoe.mutable.retrieve"
  125.         return log.msg(*args, **kwargs)
  126. 
  127.     def download(self):
               # Start the retrieval and return self._done_deferred, which
               # _done() later fires with its 'res' argument (the new contents
               # or a Failure). Sends one get_data() query for each of the k
               # lowest-numbered shares that the servermap lists for
               # self.verinfo; everything after that is driven by the query
               # callbacks.
  128.         self._done_deferred = defer.Deferred()
  129.         self._started = time.time()
  130.         self._status.set_status("Retrieving Shares")
  131. 
  132.         # first, which servers can we use?
  133.         versionmap = self.servermap.make_versionmap()
  134.         shares = versionmap[self.verinfo]
  135.         # this sharemap is consumed as we decide to send requests
  136.         self.remaining_sharemap = DictOfSets()
  137.         for (shnum, peerid, timestamp) in shares:
  138.             self.remaining_sharemap.add(shnum, peerid)
  139. 
  140.         self.shares = {} # maps shnum to validated blocks
  141. 
  142.         # how many shares do we need?
  143.         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
  144.          offsets_tuple) = self.verinfo
               # the caller must pass a verinfo for which the servermap knows at
               # least k distinct share numbers
  145.         assert len(self.remaining_sharemap) >= k
  146.         # we start with the lowest shnums we have available, since FEC is
  147.         # faster if we're using "primary shares"
  148.         self.active_shnums = set(sorted(self.remaining_sharemap.keys())[:k])
  149.         for shnum in self.active_shnums:
  150.             # we use an arbitrary peer who has the share. If shares are
  151.             # doubled up (more than one share per peer), we could make this
  152.             # run faster by spreading the load among multiple peers. But the
  153.             # algorithm to do that is more complicated than I want to write
  154.             # right now, and a well-provisioned grid shouldn't have multiple
  155.             # shares per peer.
  156.             peerid = list(self.remaining_sharemap[shnum])[0]
  157.             self.get_data(shnum, peerid)
  158. 
  159.         # control flow beyond this point: state machine. Receiving responses
  160.         # from queries is the input. We might send out more queries, or we
  161.         # might produce a result.
  162. 
  163.         return self._done_deferred
  164. 
  165.     def get_data(self, shnum, peerid):
               # Request share 'shnum' from 'peerid' and wire up the response
               # handling. The read vector always holds the signed prefix and
               # the span from share_hash_chain up to enc_privkey; the encrypted
               # private key is added when self._need_privkey is set. The data
               # comes from the node's cache when every range is cached,
               # otherwise from _do_read(). Returns the Deferred (for tests).
  166.         self.log(format="sending sh#%(shnum)d request to [%(peerid)s]",
  167.                  shnum=shnum,
  168.                  peerid=idlib.shortnodeid_b2a(peerid),
  169.                  level=log.NOISY)
  170.         ss = self.servermap.connections[peerid]
  171.         started = time.time()
  172.         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
  173.          offsets_tuple) = self.verinfo
  174.         offsets = dict(offsets_tuple)
  175. 
  176.         # we read the checkstring, to make sure that the data we grab is from
  177.         # the right version.
  178.         readv = [ (0, struct.calcsize(SIGNED_PREFIX)) ]
  179. 
  180.         # We also read the data, and the hashes necessary to validate them
  181.         # (share_hash_chain, block_hash_tree, share_data). We don't read the
  182.         # signature or the pubkey, since that was handled during the
  183.         # servermap phase, and we'll be comparing the share hash chain
  184.         # against the roothash that was validated back then.
  185. 
  186.         readv.append( (offsets['share_hash_chain'],
  187.                        offsets['enc_privkey'] - offsets['share_hash_chain'] ) )
  188. 
  189.         # if we need the private key (for repair), we also fetch that
  190.         if self._need_privkey:
  191.             readv.append( (offsets['enc_privkey'],
  192.                            offsets['EOF'] - offsets['enc_privkey']) )
  193. 
               # a fresh Marker is the key for this query; the entry is popped
               # again in _got_results / _query_failed
  194.         m = Marker()
  195.         self._outstanding_queries[m] = (peerid, shnum, started)
  196. 
  197.         # ask the cache first
  198.         got_from_cache = False
  199.         datavs = []
  200.         for (offset, length) in readv:
  201.             (data, timestamp) = self._node._cache.read(self.verinfo, shnum,
  202.                                                        offset, length)
  203.             if data is not None:
  204.                 datavs.append(data)
               # the cache is used only if every range hit; a partial hit is
               # thrown away and the whole readv goes to the server
  205.         if len(datavs) == len(readv):
  206.             self.log("got data from cache")
  207.             got_from_cache = True
  208.             d = fireEventually({shnum: datavs})
  209.             # the result is a dict mapping shnum to a list of strings, one
                   # per readv entry
  210.         else:
  211.             d = self._do_read(ss, peerid, self._storage_index, [shnum], readv)
  212.         self.remaining_sharemap.discard(shnum, peerid)
  213. 
  214.         d.addCallback(self._got_results, m, peerid, started, got_from_cache)
  215.         d.addErrback(self._query_failed, m, peerid)
  216.         # errors that aren't handled by _query_failed (and errors caused by
  217.         # _query_failed) get logged, but we still want to check for doneness.
  218.         def _oops(f):
  219.             self.log(format="problem in _query_failed for sh#%(shnum)d to %(peerid)s",
  220.                      shnum=shnum,
  221.                      peerid=idlib.shortnodeid_b2a(peerid),
  222.                      failure=f,
  223.                      level=log.WEIRD, umid="W0xnQA")
  224.         d.addErrback(_oops)
  225.         d.addBoth(self._check_for_done)
  226.         # any error during _check_for_done means the download fails. If the
  227.         # download is successful, _check_for_done will fire _done by itself.
  228.         d.addErrback(self._done)
  229.         d.addErrback(log.err)
  230.         return d # purely for testing convenience
  231. 
  232.     def _do_read(self, ss, peerid, storage_index, shnums, readv):
  233.         # isolate the callRemote to a separate method, so tests can subclass
  234.         # Retrieve and override it
               # (peerid is not used here; 'ss' is the object the remote call
               # is made on). Returns whatever ss.callRemote returns.
  235.         d = ss.callRemote("slot_readv", storage_index, shnums, readv)
  236.         return d
  237. 
  238.     def remove_peer(self, peerid):
               # Drop 'peerid' as a candidate for every share number still in
               # self.remaining_sharemap. The keys are copied into a list
               # before the loop because discard() is called on the same
               # mapping while we iterate.
  239.         for shnum in list(self.remaining_sharemap.keys()):
  240.             self.remaining_sharemap.discard(shnum, peerid)
  241. 
  242.     def _got_results(self, datavs, marker, peerid, started, got_from_cache):
               # Callback for one query. 'datavs' maps shnum to a sequence whose
               # first two items are (prefix, hash_and_data); a third item, when
               # present, is treated as the encrypted private key. Records the
               # fetch time (unless served from cache), retires the query's
               # marker, validates each share, and on CorruptShareError reports
               # and blacklists the share and its peer. Exceptions other than
               # CorruptShareError (e.g. the UncoordinatedWriteError raised by
               # _got_results_one_share) are not caught here and propagate to
               # the caller.
  243.         now = time.time()
  244.         elapsed = now - started
  245.         if not got_from_cache:
  246.             self._status.add_fetch_timing(peerid, elapsed)
  247.         self.log(format="got results (%(shares)d shares) from [%(peerid)s]",
  248.                  shares=len(datavs),
  249.                  peerid=idlib.shortnodeid_b2a(peerid),
  250.                  level=log.NOISY)
  251.         self._outstanding_queries.pop(marker, None)
               # a response that arrives after _done() has run is ignored
  252.         if not self._running:
  253.             return
  254. 
  255.         # note that we only ask for a single share per query, so we only
  256.         # expect a single share back. On the other hand, we use the extra
  257.         # shares if we get them.. seems better than an assert().
  258. 
  259.         for shnum,datav in datavs.items():
  260.             (prefix, hash_and_data) = datav[:2]
  261.             try:
  262.                 self._got_results_one_share(shnum, peerid,
  263.                                             prefix, hash_and_data)
  264.             except CorruptShareError, e:
  265.                 # log it and give the other shares a chance to be processed
  266.                 f = failure.Failure()
  267.                 self.log(format="bad share: %(f_value)s",
  268.                          f_value=str(f.value), failure=f,
  269.                          level=log.WEIRD, umid="7fzWZw")
  270.                 self.notify_server_corruption(peerid, shnum, str(e))
  271.                 self.remove_peer(peerid)
  272.                 self.servermap.mark_bad_share(peerid, shnum, prefix)
  273.                 self._bad_shares.add( (peerid, shnum) )
  274.                 self._status.problems[peerid] = f
  275.                 self._last_failure = f
  276.                 pass
                   # the privkey check below runs whether or not the share
                   # data validated above
  277.             if self._need_privkey and len(datav) > 2:
  278.                 lp = None
  279.                 self._try_to_validate_privkey(datav[2], peerid, shnum, lp)
  280.         # all done!
  281. 
  282.     def notify_server_corruption(self, peerid, shnum, reason):
               # Send an "advise_corrupt_share" message for share 'shnum' of
               # this storage index to the connection recorded for 'peerid'.
               # Nothing is returned and no result is waited for here.
  283.         ss = self.servermap.connections[peerid]
  284.         ss.callRemoteOnly("advise_corrupt_share",
  285.                           "mutable", self._storage_index, shnum, reason)
  286. 
  287.     def _got_results_one_share(self, shnum, peerid,
  288.                                got_prefix, got_hash_and_data):
               # Validate one share and, if it checks out, store its data in
               # self.shares[shnum]. Raises UncoordinatedWriteError when the
               # prefix differs from the one in self.verinfo, and
               # CorruptShareError when the block hash tree or the share hash
               # chain does not verify.
  289.         self.log("_got_results: got shnum #%d from peerid %s"
  290.                  % (shnum, idlib.shortnodeid_b2a(peerid)))
  291.         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
  292.          offsets_tuple) = self.verinfo
  293.         assert len(got_prefix) == len(prefix), (len(got_prefix), len(prefix))
  294.         if got_prefix != prefix:
  295.             msg = "someone wrote to the data since we read the servermap: prefix changed"
  296.             raise UncoordinatedWriteError(msg)
  297.         (share_hash_chain, block_hash_tree,
  298.          share_data) = unpack_share_data(self.verinfo, got_hash_and_data)
  299. 
  300.         assert isinstance(share_data, str)
  301.         # build the block hash tree. SDMF has only one leaf.
  302.         leaves = [hashutil.block_hash(share_data)]
  303.         t = hashtree.HashTree(leaves)
  304.         if list(t) != block_hash_tree:
  305.             raise CorruptShareError(peerid, shnum, "block hash tree failure")
               # the root of the block hash tree is this share's leaf in the
               # share hash tree
  306.         share_hash_leaf = t[0]
  307.         t2 = hashtree.IncompleteHashTree(N)
  308.         # root_hash was checked by the signature
  309.         t2.set_hashes({0: root_hash})
  310.         try:
  311.             t2.set_hashes(hashes=share_hash_chain,
  312.                           leaves={shnum: share_hash_leaf})
  313.         except (hashtree.BadHashError, hashtree.NotEnoughHashesError,
  314.                 IndexError), e:
  315.             msg = "corrupt hashes: %s" % (e,)
  316.             raise CorruptShareError(peerid, shnum, msg)
  317.         self.log(" data valid! len=%d" % len(share_data))
  318.         # each query comes down to this: placing validated share data into
  319.         # self.shares
  320.         self.shares[shnum] = share_data
  321. 
  322.     def _try_to_validate_privkey(self, enc_privkey, peerid, shnum, lp):
  323. 
  324.         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
  325.         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
  326.         if alleged_writekey != self._node.get_writekey():
  327.             self.log("invalid privkey from %s shnum %d" %
  328.                      (idlib.nodeid_b2a(peerid)[:8], shnum),
  329.                      parent=lp, level=log.WEIRD, umid="YIw4tA")
  330.             return
  331. 
  332.         # it's good
  333.         self.log("got valid privkey from shnum %d on peerid %s" %
  334.                  (shnum, idlib.shortnodeid_b2a(peerid)),
  335.                  parent=lp)
  336.         privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
  337.         self._node._populate_encprivkey(enc_privkey)
  338.         self._node._populate_privkey(privkey)
  339.         self._need_privkey = False
  340. 
  341.     def _query_failed(self, f, marker, peerid):
               # Errback for one query. Records the Failure 'f' against the
               # peer in the status object and retires the query's marker. If
               # the download is still running, it also remembers f as the last
               # failure, stops using the peer, and logs the error (at UNUSUAL
               # for DeadReferenceError, otherwise WEIRD). Returns None, so the
               # failure is not passed further down the Deferred chain.
  342.         self.log(format="query to [%(peerid)s] failed",
  343.                  peerid=idlib.shortnodeid_b2a(peerid),
  344.                  level=log.NOISY)
  345.         self._status.problems[peerid] = f
  346.         self._outstanding_queries.pop(marker, None)
  347.         if not self._running:
  348.             return
  349.         self._last_failure = f
  350.         self.remove_peer(peerid)
  351.         level = log.WEIRD
  352.         if f.check(DeadReferenceError):
  353.             level = log.UNUSUAL
  354.         self.log(format="error during query: %(f_value)s",
  355.                  f_value=str(f.value), failure=f, level=level, umid="gOJB5g")
  356. 
  357.     def _check_for_done(self, res):
               # Called after every query completes (success or failure); 'res'
               # is not used. Decides whether to wait, send more queries, or
               # start decoding.
  358.         # exit paths:
  359.         #  return : keep waiting, no new queries
  360.         #  return self._maybe_send_more_queries(k) : send some more queries
  361.         #  fire self._done(plaintext) : download successful
  362.         #  raise exception : download fails
  363. 
  364.         self.log(format="_check_for_done: running=%(running)s, decoding=%(decoding)s",
  365.                  running=self._running, decoding=self._decoding,
  366.                  level=log.NOISY)
  367.         if not self._running:
  368.             return
  369.         if self._decoding:
  370.             return
  371.         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
  372.          offsets_tuple) = self.verinfo
  373. 
  374.         if len(self.shares) < k:
  375.             # we don't have enough shares yet
  376.             return self._maybe_send_more_queries(k)
  377.         if self._need_privkey:
  378.             # we got k shares, but none of them had a valid privkey. TODO:
  379.             # look further. Adding code to do this is a bit complicated, and
  380.             # I want to avoid that complication, and this should be pretty
  381.             # rare (k shares with bitflips in the enc_privkey but not in the
  382.             # data blocks). If we actually do get here, the subsequent repair
  383.             # will fail for lack of a privkey.
  384.             self.log("got k shares but still need_privkey, bummer",
  385.                      level=log.WEIRD, umid="MdRHPA")
  386. 
  387.         # we have enough to finish. All the shares have had their hashes
  388.         # checked, so if something fails at this point, we don't know how
  389.         # to fix it, so the download will fail.
  390. 
  391.         self._decoding = True # avoid reentrancy
  392.         self._status.set_status("decoding")
  393.         now = time.time()
  394.         elapsed = now - self._started
  395.         self._status.timings["fetch"] = elapsed
  396. 
               # decode, then decrypt, then hand the plaintext or the Failure
               # to _done
  397.         d = defer.maybeDeferred(self._decode)
  398.         d.addCallback(self._decrypt, IV, self._node._readkey)
  399.         d.addBoth(self._done)
  400.         return d # purely for test convenience
  401. 
  402.     def _maybe_send_more_queries(self, k):
  403.         # we don't have enough shares yet. Should we send out more queries?
  404.         # There are some number of queries outstanding, each for a single
  405.         # share. If we can generate 'needed_shares' additional queries, we do
  406.         # so. If we can't, then we know this file is a goner, and we raise
  407.         # NotEnoughSharesError.
  408.         self.log(format=("_maybe_send_more_queries, have=%(have)d, k=%(k)d, "
  409.                          "outstanding=%(outstanding)d"),
  410.                  have=len(self.shares), k=k,
  411.                  outstanding=len(self._outstanding_queries),
  412.                  level=log.NOISY)
  413. 
  414.         remaining_shares = k - len(self.shares)
  415.         needed = remaining_shares - len(self._outstanding_queries)
  416.         if not needed:
  417.             # we have enough queries in flight already
  418. 
  419.             # TODO: but if they've been in flight for a long time, and we
  420.             # have reason to believe that new queries might respond faster
  421.             # (i.e. we've seen other queries come back faster, then consider
  422.             # sending out new queries. This could help with peers which have
  423.             # silently gone away since the servermap was updated, for which
  424.             # we're still waiting for the 15-minute TCP disconnect to happen.
  425.             self.log("enough queries are in flight, no more are needed",
  426.                      level=log.NOISY)
  427.             return
  428. 
  429.         outstanding_shnums = set([shnum
  430.                                   for (peerid, shnum, started)
  431.                                   in self._outstanding_queries.values()])
  432.         # prefer low-numbered shares, they are more likely to be primary
  433.         available_shnums = sorted(self.remaining_sharemap.keys())
  434.         for shnum in available_shnums:
  435.             if shnum in outstanding_shnums:
  436.                 # skip ones that are already in transit
  437.                 continue
  438.             if shnum not in self.remaining_sharemap:
  439.                 # no servers for that shnum. note that DictOfSets removes
  440.                 # empty sets from the dict for us.
  441.                 continue
  442.             peerid = list(self.remaining_sharemap[shnum])[0]
  443.             # get_data will remove that peerid from the sharemap, and add the
  444.             # query to self._outstanding_queries
  445.             self._status.set_status("Retrieving More Shares")
  446.             self.get_data(shnum, peerid)
  447.             needed -= 1
  448.             if not needed:
  449.                 break
  450. 
  451.         # at this point, we have as many outstanding queries as we can. If
  452.         # needed!=0 then we might not have enough to recover the file.
  453.         if needed:
  454.             format = ("ran out of peers: "
  455.                       "have %(have)d shares (k=%(k)d), "
  456.                       "%(outstanding)d queries in flight, "
  457.                       "need %(need)d more, "
  458.                       "found %(bad)d bad shares")
  459.             args = {"have": len(self.shares),
  460.                     "k": k,
  461.                     "outstanding": len(self._outstanding_queries),
  462.                     "need": needed,
  463.                     "bad": len(self._bad_shares),
  464.                     }
  465.             self.log(format=format,
  466.                      level=log.WEIRD, umid="ezTfjw", **args)
  467.             err = NotEnoughSharesError("%s, last failure: %s" %
  468.                                       (format % args, self._last_failure))
  469.             if self._bad_shares:
  470.                 self.log("We found some bad shares this pass. You should "
  471.                          "update the servermap and try again to check "
  472.                          "more peers",
  473.                          level=log.WEIRD, umid="EFkOlA")
  474.                 err.servermap = self.servermap
  475.             raise err
  476. 
  477.         return
  478. 
  479.     def _decode(self):
               # Feed k of the validated shares in self.shares to a
               # codec.CRSDecoder and return a Deferred that fires with the
               # decoded buffers joined into one string and cut to 'datalength'
               # bytes. The decode time is stored in timings["decode"]. A decode
               # failure is logged and passed on unchanged.
  480.         started = time.time()
  481.         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
  482.          offsets_tuple) = self.verinfo
  483. 
  484.         # shares_dict is a dict mapping shnum to share data, but the codec
  485.         # wants two lists.
  486.         shareids = []; shares = []
  487.         for shareid, share in self.shares.items():
  488.             shareids.append(shareid)
  489.             shares.append(share)
  490. 
  491.         assert len(shareids) >= k, len(shareids)
  492.         # zfec really doesn't want extra shares
               # (the two lists were built in lockstep, so slicing both keeps
               # each shareid paired with its share)
  493.         shareids = shareids[:k]
  494.         shares = shares[:k]
  495. 
  496.         fec = codec.CRSDecoder()
  497.         fec.set_params(segsize, k, N)
  498. 
  499.         self.log("params %s, we have %d shares" % ((segsize, k, N), len(shares)))
  500.         self.log("about to decode, shareids=%s" % (shareids,))
  501.         d = defer.maybeDeferred(fec.decode, shares, shareids)
  502.         def _done(buffers):
  503.             self._status.timings["decode"] = time.time() - started
  504.             self.log(" decode done, %d buffers" % len(buffers))
  505.             segment = "".join(buffers)
  506.             self.log(" joined length %d, datalength %d" %
  507.                      (len(segment), datalength))
                   # drop whatever lies beyond the recorded data length
  508.             segment = segment[:datalength]
  509.             self.log(" segment len=%d" % len(segment))
  510.             return segment
  511.         def _err(f):
  512.             self.log(" decode failed: %s" % f)
  513.             return f
  514.         d.addCallback(_done)
  515.         d.addErrback(_err)
  516.         return d
  517. 
  518.     def _decrypt(self, crypttext, IV, readkey):
  519.         self._status.set_status("decrypting")
  520.         started = time.time()
  521.         key = hashutil.ssk_readkey_data_hash(IV, readkey)
  522.         decryptor = AES(key)
  523.         plaintext = decryptor.process(crypttext)
  524.         self._status.timings["decrypt"] = time.time() - started
  525.         return plaintext
  526. 
  527.     def _done(self, res):
               # Finish the download exactly once. The first call clears
               # _running, records the total time, updates the status object,
               # and schedules self._done_deferred.callback(res) through
               # eventually(). Any later call returns immediately. On success
               # the encoding parameters k and N are also stored on the node.
  528.         if not self._running:
  529.             return
  530.         self._running = False
  531.         self._status.set_active(False)
  532.         self._status.timings["total"] = time.time() - self._started
  533.         # res is either the new contents, or a Failure
  534.         if isinstance(res, failure.Failure):
  535.             self.log("Retrieve done, with failure", failure=res,
  536.                      level=log.UNUSUAL)
  537.             self._status.set_status("Failed")
  538.         else:
  539.             self.log("Retrieve done, success!")
  540.             self._status.set_status("Done")
  541.             self._status.set_progress(1.0)
  542.             # remember the encoding parameters, use them again next time
  543.             (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
  544.              offsets_tuple) = self.verinfo
  545.             self._node._populate_required_shares(k)
  546.             self._node._populate_total_shares(N)
               # note: .callback() is used for both outcomes, so a Failure is
               # handed to callback() rather than errback()
  547.         eventually(self._done_deferred.callback, res)
  548.