class GitClient(object):
    """Git smart server client.

    Talks to a git smart server through a ``Protocol`` object built from
    the ``read``/``write`` callbacks passed to the constructor.
    """

    # Placeholder SHA used for a ref the remote does not have yet (it is
    # sent as the "old" SHA of such a ref and never offered as a "have").
    _ZERO_SHA = "0" * 40

    def __init__(self, can_read, read, write, thin_packs=True,
                 report_activity=None):
        """Create a new GitClient instance.

        :param can_read: Function that returns True if there is data
            available to be read.
        :param read: Callback for reading data, takes number of bytes to read
        :param write: Callback for writing data
        :param thin_packs: Whether or not thin packs should be retrieved.
            NOTE: currently ignored -- the "thin-pack" capability is never
            added to the capabilities this client sends.
        :param report_activity: Optional callback for reporting transport
            activity.
        """
        self.proto = Protocol(read, write, report_activity)
        self._can_read = can_read
        self._capabilities = list(CAPABILITIES)
        # TODO: requesting "thin-pack" (appending it to self._capabilities
        # when thin_packs is true) was deliberately disabled upstream; the
        # reason is not recorded here -- confirm before re-enabling.
        self._thin_packs = thin_packs

    def capabilities(self):
        """Return this client's capabilities as a space-separated string."""
        return " ".join(self._capabilities)

    def read_refs(self):
        """Read the refs advertised by the server.

        Each packet is ``"<sha> <refname>"``; the server capabilities are
        split off the first ref name with ``extract_capabilities``.

        :return: Tuple ``(refs, server_capabilities)``. ``refs`` maps ref
            name to SHA string; ``server_capabilities`` is what
            ``extract_capabilities`` returned, or None if no refs were sent.
        """
        server_capabilities = None
        refs = {}
        for pkt in self.proto.read_pkt_seq():
            (sha, ref) = pkt.rstrip("\n").split(" ", 1)
            if server_capabilities is None:
                (ref, server_capabilities) = extract_capabilities(ref)
            refs[ref] = sha
        return refs, server_capabilities

    def send_pack(self, path, determine_wants, generate_pack_contents):
        """Upload a pack to a remote repository.

        :param path: Repository path (not used by this method)
        :param determine_wants: Callback that takes the dict of remote refs
            and returns a dict mapping ref names to the new SHAs to set.
        :param generate_pack_contents: Function that takes ``(want, have)``
            lists of SHAs and returns the objects to upload.
        :return: The dict returned by ``determine_wants``, or ``{}`` if it
            returned nothing to change.
        """
        refs, _ = self.read_refs()
        changed_refs = determine_wants(refs)
        if not changed_refs:
            # Nothing to update: just terminate the command list.
            self.proto.write_pkt_line(None)
            return {}
        want = []
        have = []
        sent_capabilities = False
        for changed_ref, new_sha1 in changed_refs.items():
            old_sha1 = refs.get(changed_ref, self._ZERO_SHA)
            if sent_capabilities:
                self.proto.write_pkt_line(
                    "%s %s %s" % (old_sha1, new_sha1, changed_ref))
            else:
                # Capabilities ride, NUL-separated, on the first line only.
                self.proto.write_pkt_line(
                    "%s %s %s\0%s" % (old_sha1, new_sha1, changed_ref,
                                      self.capabilities()))
                sent_capabilities = True
            want.append(new_sha1)
            if old_sha1 != self._ZERO_SHA:
                have.append(old_sha1)
        self.proto.write_pkt_line(None)
        objects = generate_pack_contents(want, have)
        (_, sha) = write_pack_data(self.proto.write_file(), objects,
                                   len(objects))
        self.proto.write(sha)
        # Read (and currently discard) the final confirmation SHA.
        # TODO: compare it with ``sha`` and warn on mismatch without
        # breaking existing servers.
        self.proto.read(20)
        return changed_refs

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress):
        """Retrieve a pack from a git smart server.

        :param path: Repository path (not used by this method)
        :param determine_wants: Callback that takes the dict of remote refs
            and returns the list of commit SHAs to fetch
        :param graph_walker: Object with next() and ack(). next() returns
            the next SHA to offer as "have", or a false value when done;
            ack() is called with each SHA the server acknowledges.
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        :return: The dict of refs advertised by the server.
        :raise AssertionError: On an ACK during negotiation that is not
            ``"ACK <sha> continue"``, or on an unknown side-band channel.
        """
        (refs, _) = self.read_refs()
        wants = determine_wants(refs)
        if not wants:
            self.proto.write_pkt_line(None)
            return refs
        self.proto.write_pkt_line(
            "want %s %s\n" % (wants[0], self.capabilities()))
        for want in wants[1:]:
            self.proto.write_pkt_line("want %s\n" % want)
        self.proto.write_pkt_line(None)
        have = graph_walker.next()
        while have:
            self.proto.write_pkt_line("have %s\n" % have)
            if self._can_read():
                pkt = self.proto.read_pkt_line()
                parts = pkt.rstrip("\n").split(" ")
                if parts[0] == "ACK":
                    graph_walker.ack(parts[1])
                    # Explicit check (not ``assert``) so it survives -O and
                    # reports a 2-token ACK instead of an IndexError.
                    if len(parts) < 3 or parts[2] != "continue":
                        raise AssertionError(
                            "Expected 'ACK <sha> continue', got %r" % pkt)
            have = graph_walker.next()
        self.proto.write_pkt_line("done\n")
        pkt = self.proto.read_pkt_line()
        while pkt:
            parts = pkt.rstrip("\n").split(" ")
            if parts[0] == "ACK":
                # Use the newline-stripped token; splitting the raw packet
                # would hand "<sha>\n" to the walker for a bare ACK.
                graph_walker.ack(parts[1])
            if len(parts) < 3 or parts[2] != "continue":
                break
            pkt = self.proto.read_pkt_line()
        self._read_side_band(pack_data, progress)
        return refs

    def _read_side_band(self, pack_data, progress):
        """Dispatch side-band packets: channel 1 to pack_data, 2 to progress."""
        for pkt in self.proto.read_pkt_seq():
            channel = ord(pkt[0])
            payload = pkt[1:]
            if channel == 1:
                pack_data(payload)
            elif channel == 2:
                progress(payload)
            else:
                raise AssertionError("Invalid sideband channel %d" % channel)