def main():
    import argparse
    from util import makedir
    parser = argparse.ArgumentParser()
    parser.add_argument('--path')
    parser.add_argument('--output_dir')
    args = parser.parse_args()
    makedir(args.output_dir)
    doc = load_doc(args.path)
    copying_data = []
    copying_columns = ('id', 'length', 'doc_ids')
    text_data = {}
    for i, chunk in enumerate(doc.find_all("chunk")):
        texts = extract_chunk_info(chunk)
        copying_data.append(
            (i, int(chunk['length']), [t['id'] for t in texts])
        )
        for t in texts:
            text_data[t['id']] = t
    cp_df = pd.DataFrame(copying_data, columns=copying_columns)
    text_df = pd.DataFrame.from_records(text_data.values(), index='id')
    cp_df.to_pickle('{}/copying_dataframe.pkl'.format(args.output_dir))
    text_df.to_pickle('{}/text_dataframe.pkl'.format(args.output_dir))
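# Note: every snippet in this section calls a small `makedir`/`util.makedir`
# helper whose implementation is not shown here. A minimal sketch of the
# idempotent directory-creation behaviour these call sites assume (create the
# directory, parents included, and do nothing if it already exists) might look
# like the following. This is an assumption, not the actual helper from any of
# these projects; the Mercurial variants further below additionally take a
# `notindexed` flag.
import os


def makedir(path):
    # Create `path` (including parents) unless it already exists.
    if path and not os.path.isdir(path):
        os.makedirs(path)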
def move_best_result(best_version, this_dataset):
    best_mat_version, best_data_version = best_version.split(" ")
    path = result_path + "/" + best_data_version + "/" + this_dataset
    out_dir = result_path + "/best/row_results/"
    makedir(out_dir)
    new_path = out_dir + this_dataset
    for algo in algos:
        file_name = path + "_" + best_mat_version + "_" + algo + ".txt"
        new_file_name = new_path + "_" + algo + ".txt"
        copyfile(file_name, new_file_name)
def run(feature_dir, selected_feature_dir, config):
    ff = FeatureFilter(config)
    util.makedir(selected_feature_dir)
    for name in os.listdir(feature_dir):
        util.write('begin %s' % name)
        feature_path = os.path.join(feature_dir, name)
        selected_path = os.path.join(selected_feature_dir, name)
        ff.filter_and_dump(feature_path, selected_path)
    return None
def write_body(content, path, logger):
    _dir = os.path.dirname(path)
    util.makedir(_dir)
    with open(path, 'wb') as out:
        try:
            out.write(content.encode('utf-8'))
        except:
            logger.error('!!!!!! Invalid Encode with %s !!!!!!!' % path)
            out.write('Invalid HUO Encode')
    return None
def to_vector(name, feature_dict, matrix_dir):
    util.makedir(matrix_dir)
    pc = PersonCorpus(name)
    fm = FeatureMapper()
    for rank, feat in feature_dict.items():
        vec = FeatureVector(rank, feat, fm)
        pc.add_vector(vec)
    pc.compute_matrix()
    matrix_path = os.path.join(matrix_dir, '%s.matrix' % name)
    pc.dump_matrix(matrix_path)
    return None
def run(webpages_dir, body_text_dir, log_path):
    logger = util.init_log('TextExtract', log_path, console=False)
    util.makedir(body_text_dir)
    for name in os.listdir(webpages_dir):
        logger.info('begin extract body text of %s' % name)
        a_person_dir = os.path.join(webpages_dir, '%s/' % name)
        for file_name in os.listdir(a_person_dir):
            rank = file_name.split('.')[0]
            html_path = os.path.join(a_person_dir, file_name)
            content = text_extract(html_path, logger)
            body_path = os.path.join(body_text_dir, name, '%s.txt' % rank)
            write_body(content, body_path, logger)
    return None
def install_jbrowse(install_dir, root_url, rename_to=None, conf_file=None):
    if not conf_file:
        conf_file = config.get_default_conf_file()
    install_dir = util.abspath(install_dir)
    pkg_data_dir = config.get_pkg_data_dir()
    jbrowse_zip = glob(pjoin(pkg_data_dir, "JBrowse*"))
    assert len(jbrowse_zip) == 1, "Expected a single JBrowse archive " + \
        "stored within the package data dir: %s" % (pkg_data_dir,)
    jbrowse_zip = jbrowse_zip[0]
    util.makedir(install_dir)
    # context manager only works for zipfile in Python 2.7
    f = zipfile.ZipFile(jbrowse_zip, 'r')
    try:
        install_name = os.path.dirname(f.namelist()[0])
        assert (install_name not in (".", "..")) and \
            os.path.dirname(install_name) == "", \
            "Unsafe path detected in JBrowse archive: {}".format(f.namelist()[0])
        install_home = pjoin(install_dir, install_name)
        # The JBrowse setup script will not install Perl modules
        # if it is executed in a directory where it was run before,
        # even unsuccessfully.
        # Whack the existing directory:
        if os.path.exists(install_home):
            shutil.rmtree(install_home)
        # somehow the zipfile module whacks executable bits
        # f.extractall(path=install_dir)
        # unsafe:
        check_call(["unzip", "-q", "-o", jbrowse_zip], cwd=install_dir)
    finally:
        f.close()
    if rename_to:
        install_home_new = pjoin(install_dir, rename_to)
        if os.path.exists(install_home_new):
            shutil.rmtree(install_home_new)
        os.rename(install_home, install_home_new)
        install_home = install_home_new
    check_call(["./setup.sh"], cwd=install_home)
    for line in fileinput.input(pjoin(install_home, "index.html"), inplace=True):
        # The Galaxy Web server intercepts 'data' in URL params, so we need to use another name
        print line.replace('queryParams.data', 'queryParams.jbrowse_data'),
    conf = util.load_config_json(conf_file)
    conf["jbrowse_bin_dir"] = util.abspath(pjoin(install_home, "bin"))
    conf["jbrowse_url"] = util.urljoin_path(root_url, os.path.basename(install_home))
    util.save_config_json(conf, conf_file)
    return conf_file
def run(matrix_dir, cosine_dir, similarity_method):
    util.makedir(cosine_dir)
    count = 0
    for file_name in os.listdir(matrix_dir):
        name = file_name.split('.')[0]
        count += 1
        util.write('begin %s: %s' % (count, name))
        file_path = os.path.join(matrix_dir, file_name)
        matrix = util.load_matrix(file_path)
        # sim_matrix = cosine(matrix)
        sim_matrix = compute_similarity(matrix, similarity_method)
        cosine_path = os.path.join(cosine_dir, '%s.matrix' % name)
        util.dump_matrix(sim_matrix, cosine_path)
    return None
def run_extra(body_text_dir, extra_feature_dir, config):
    flt = FeatureExtractor(config)
    util.makedir(extra_feature_dir)
    c = 0
    for name in os.listdir(body_text_dir):
        name_body_dir = os.path.join(body_text_dir, name)
        extra_features = {}
        print 'begin %s' % name
        extra_features = flt.extra_extract(name, name_body_dir)
        features_pickle_path = os.path.join(extra_feature_dir, '%s.json' % name)
        with open(features_pickle_path, 'wb') as fp:
            json.dump(extra_features, fp)
        # c += 1
        # if(c==1):
        #     break
    return None
def run(args):
    basedir = os.path.join(args.directory, "intermediate_results")
    util.makedir(basedir)
    fin_bam = pysam.AlignmentFile(args.bam, "rb")
    fout = open(os.path.join(basedir, "spec.txt"), "w")
    if args.mk_file:
        fout_mk = open(os.path.join(basedir, "inc.mk"), "w")
        fout_mk.write("ALN_ROOTDIR=\"{}\"\n".format(basedir))
        fout_mk.write("ALN_SOURCES=")
    ref_ids = ManyFiles.ManyWriteFiles(args.max_nfiles - 2, args.cache_size)
    ref_lengths = fin_bam.lengths
    for aln in fin_bam.fetch(until_eof=True):
        # N.B. the bam file generated by bwa-mem could have reference_id == -1.
        # Before finding a better way to resolve this issue, skip this alignment for now.
        if aln.reference_id < 0 or aln.reference_id >= len(ref_lengths):
            continue
        fout_aln = ref_ids.get(aln.reference_id, None)
        if not fout_aln:
            composed_dir = util.int2path(aln.reference_id)
            directory_name = os.path.join(basedir, composed_dir)
            util.makedir(directory_name)
            aln_filename = os.path.join(directory_name,
                                        str(aln.reference_id) + ".txt")
            fout_aln = ref_ids.open(aln_filename, aln.reference_id)
            fout.write("{} {} {} {}\n".format(aln.reference_id,
                                              aln.reference_name,
                                              ref_lengths[aln.reference_id],
                                              directory_name))
            if args.mk_file:
                fout_mk.write("{} ".format(
                    os.path.join(composed_dir, str(aln.reference_id))))
        ref_ids.write(
            fout_aln, "{} {} {} {} {} {} {} {} {}\n".format(
                aln.reference_start, aln.reference_end,
                aln.query_alignment_start, aln.query_alignment_end,
                aln.query_name, aln.reference_length, aln.query_length,
                aln.mapping_quality if aln.mapping_quality != "*" else 0,
                'R' if aln.is_reverse else 'F'))
    ref_ids.close()
    fout.close()
    if args.mk_file:
        if ref_ids:
            fout_mk.write("ALN_RUN=1\n")
        else:
            fout_mk.write("ALN_RUN=0\n")
        fout_mk.close()
    fin_bam.close()
def cp(self, mapped_doc_dir):
    id_mapper = self.id_mapper
    util.makedir(mapped_doc_dir)
    for file_name in os.listdir(self.doc_dir):
        if self.version == '2007test':
            source_path = os.path.join(self.doc_dir, file_name, 'index.html')
        elif self.version == '2008test':
            source_path = os.path.join(self.doc_dir, file_name)
        print source_path
        doc_id = file_name.split('.')[0].lstrip('0').zfill(1)
        mapped_doc_id = id_mapper[doc_id]
        mapped_file_name = '%d.html' % mapped_doc_id
        target_path = os.path.join(mapped_doc_dir, mapped_file_name)
        cmd = ['cp', source_path, target_path]
        util.write('Copy file from %s to %s' % (source_path, target_path))
        subprocess.call(cmd)
    return None
def main():
    args = parse_args()
    if args.steps:
        args.run_steps = set(args.steps)
    else:
        args.run_steps = set(all_steps[all_steps_key[args.start_from]:])
    riginv_rootdir = os.path.normpath(
        os.path.join(
            os.path.dirname(os.path.abspath(os.path.realpath(__file__))),
            ".."))
    args.aux_dir = os.path.join(riginv_rootdir, "libexec", "bin")
    args.makefile_dir = os.path.join(riginv_rootdir, "libexec", "makefiles")
    args.pylib_dir = os.path.join(riginv_rootdir, "lib", "python", "riginv_lib")
    sys.path.append(args.pylib_dir)
    log_levels = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "error": logging.ERROR,
        "critical": logging.CRITICAL
    }
    if args.log:
        import util
        util.makedir(args.working_dir)
        logging.basicConfig(
            filename=os.path.join(args.working_dir, "riginv.log"),
            format="[%(asctime)s] [%(levelname)s]: %(message)s",
            level=log_levels[args.log_level])
    else:
        logging.basicConfig(
            format="[%(asctime)s] [%(levelname)s]: %(message)s",
            level=log_levels[args.log_level])
    args.logger = logging.getLogger()
    args.logger.info("Start")
    args.logger.debug("Run steps: {}".format(", ".join(args.run_steps)))
    run(args)
    args.logger.info("Done!")
def save_best_dataset(best_version, this_dataset):
    # Copy files in the "best" folder
    best_mat_version, best_data_version = best_version.split(" ")
    path = dataset_path + "/" + best_data_version + "/" + this_dataset
    out_dir = dataset_path + "/best/"
    makedir(out_dir)
    new_path = out_dir + this_dataset
    files = [
        "_preprocessed.csv",
        "_preprocessed_" + best_mat_version + ".mat",
        "_preprocessed_vocabulary.csv"
    ]
    out_files = [
        "_preprocessed.csv",
        "_preprocessed.mat",
        "_preprocessed_vocabulary.csv"
    ]
    for i in range(len(files)):
        file_name = files[i]
        print(path + file_name)
        copyfile(path + file_name, new_path + out_files[i])
def share(ui, source, dest=None, update=True):
    '''create a shared repository'''

    if not islocal(source):
        raise util.Abort(_('can only share local repositories'))

    if not dest:
        dest = defaultdest(source)
    else:
        dest = ui.expandpath(dest)

    if isinstance(source, str):
        origsource = ui.expandpath(source)
        source, branches = parseurl(origsource)
        srcrepo = repository(ui, source)
        rev, checkout = addbranchrevs(srcrepo, srcrepo, branches, None)
    else:
        srcrepo = source
        origsource = source = srcrepo.url()
        checkout = None

    sharedpath = srcrepo.sharedpath # if our source is already sharing

    root = os.path.realpath(dest)
    roothg = os.path.join(root, '.hg')

    if os.path.exists(roothg):
        raise util.Abort(_('destination already exists'))

    if not os.path.isdir(root):
        os.mkdir(root)
    util.makedir(roothg, notindexed=True)

    requirements = ''
    try:
        requirements = srcrepo.opener.read('requires')
    except IOError, inst:
        if inst.errno != errno.ENOENT:
            raise
def run(selected_feature_dir, extra_feature_dir, matrix_dir):
    feature_dir = selected_feature_dir
    util.makedir(matrix_dir)
    for file_name in os.listdir(feature_dir):
        name = file_name.split('.')[0]
        feature_path = os.path.join(feature_dir, file_name)
        feature_dict = util.load_pickle(feature_path, typ='json')
        # word features matrix
        word_matrix_dir = os.path.join(matrix_dir, 'word/')
        to_vector(name, feature_dict, word_matrix_dir)
        # extra features matrix
        extra_feature_path = os.path.join(extra_feature_dir, file_name)
        print extra_feature_path
        extra_feature_dict = util.load_pickle(extra_feature_path, typ='json')
        for cls in get_extra_feature_class(extra_feature_dict):
            cls_matrix_dir = os.path.join(matrix_dir, cls)
            cls_feature_dict = {}
            for rank in extra_feature_dict:
                cls_feature_dict[rank] = extra_feature_dict[rank][cls]
            to_vector(name, cls_feature_dict, cls_matrix_dir)
    return None
def run(body_text_dir, feature_dir, config):
    flt = FeatureExtractor(config)
    util.makedir(feature_dir)
    c = 0
    for name in os.listdir(body_text_dir):
        name_dir = os.path.join(body_text_dir, name)
        features = {}
        print 'begin %s' % name
        name_list = name.lower().split('_')
        for rank_file_name in os.listdir(name_dir):
            rank = rank_file_name.split('.')[0]
            print 'start %s' % rank_file_name
            with open(os.path.join(name_dir, rank_file_name)) as rank_file:
                text = rank_file.read()
            features[rank] = flt.extract(name, name_list, text)
        features_pickle_path = os.path.join(feature_dir, '%s.json' % name)
        with open(features_pickle_path, 'wb') as fp:
            json.dump(features, fp)
        # c += 1
        # if(c==1):
        #     break
        break
    return None
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True,
                        help='[KTH Actions|UCF-101|HMDB-51]')
    parser.add_argument('--results_dir', type=str, default='results/')
    parser.add_argument('--output_dir', type=str, default='paper_imgs')
    opt = parser.parse_args()

    exp_dict = dict(np.load('records/finished_exp.npy').item())
    f = open(videos[opt.dataset], 'r')
    cases = f.readlines()
    test_names = []
    baseline = None
    makedir(opt.output_dir)
    for model in models:
        try:
            test_name = exp_dict[opt.dataset][model][input_output[opt.dataset]][-1]
            test_names.append(test_name)
            if model == 'MC-Net':
                baseline = test_name
        except:
            print('exp with {%s, %s, %s} does not exist' %
                  (opt.dataset, model, input_output[opt.dataset]))
            continue
    img_dir = os.path.join(opt.results_dir, 'images', nick_name[opt.dataset])
    name_prefix = os.path.join(opt.output_dir, '%s_vis' % (opt.dataset))
    slide_selected_videos(input_output[opt.dataset], cases, test_names,
                          img_dir, name_prefix,
                          skip=skip[opt.dataset],
                          start=start[opt.dataset],
                          baseline=baseline)
        self._files.clear()

    def _safe_open(self, fp):
        if len(self._open_queue) >= self._max_nfile:
            self._files[self._open_queue.pop()].close()
        fp.open()
        self._open_queue.add(fp.fid)


if __name__ == "__main__":
    nfiles = 128
    n_tests = 1024
    files = ManyWriteFiles(nfiles)
    import util
    for i in xrange(n_tests):
        directory = os.path.join("working_dir", util.int2path(i))
        util.makedir(directory)
        fp = files.open(os.path.join(directory, str(i) + ".txt"), i)
        files.write(fp, "hello, world! {}\n".format(i))
    for i in xrange(n_tests):
        files.write(files[i], "hello, world again! {}\n".format(i))
    files.close()
    print "Write down! Do check"
    for i in xrange(n_tests):
        with open(
                os.path.join("working_dir", util.int2path(i),
                             str(i) + ".txt")) as fin:
            assert (fin.read() ==
                    "hello, world! {}\nhello, world again! {}\n".format(i, i))
    print "Done!"
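# Note: both the BAM-splitting run() earlier in this section and the test
# harness above rely on `util.int2path` to spread many per-id files over
# nested subdirectories. Its implementation is not included here; the sketch
# below is a plausible, hypothetical version (an assumption, not the
# project's actual code) that chops an integer id into base-`fanout` path
# components.
import os


def int2path(n, fanout=1000):
    # e.g. int2path(1234567) -> '1/234'; ids 0..fanout-1 all map to '0',
    # so at most `fanout` leaf files end up sharing one directory.
    parts = []
    n //= fanout
    while n > 0:
        parts.append(str(n % fanout))
        n //= fanout
    if not parts:
        parts.append("0")
    return os.path.join(*reversed(parts))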
def clone(ui, peeropts, source, dest=None, pull=False, rev=None,
          update=True, stream=False, branch=None):
    """Make a copy of an existing repository.

    Create a copy of an existing repository in a new directory.  The
    source and destination are URLs, as passed to the repository
    function.  Returns a pair of repository peers, the source and
    newly created destination.

    The location of the source is added to the new repository's
    .hg/hgrc file, as the default to be used for future pulls and
    pushes.

    If an exception is raised, the partly cloned/updated destination
    repository will be deleted.

    Arguments:

    source: repository object or URL

    dest: URL of destination repository to create (defaults to base
    name of source repository)

    pull: always pull from source repository, even in local case or if the
    server prefers streaming

    stream: stream raw data uncompressed from repository (fast over
    LAN, slow over WAN)

    rev: revision to clone up to (implies pull=True)

    update: update working directory after clone completes, if
    destination is local repository (True means update to default rev,
    anything else is treated as a revision)

    branch: branches to clone
    """

    if isinstance(source, str):
        origsource = ui.expandpath(source)
        source, branch = parseurl(origsource, branch)
        srcpeer = peer(ui, peeropts, source)
    else:
        srcpeer = source.peer() # in case we were called with a localrepo
        branch = (None, branch or [])
        origsource = source = srcpeer.url()
    rev, checkout = addbranchrevs(srcpeer, srcpeer, branch, rev)

    if dest is None:
        dest = defaultdest(source)
        if dest:
            ui.status(_("destination directory: %s\n") % dest)
    else:
        dest = ui.expandpath(dest)

    dest = util.urllocalpath(dest)
    source = util.urllocalpath(source)

    if not dest:
        raise util.Abort(_("empty destination path is not valid"))

    destvfs = scmutil.vfs(dest, expandpath=True)
    if destvfs.lexists():
        if not destvfs.isdir():
            raise util.Abort(_("destination '%s' already exists") % dest)
        elif destvfs.listdir():
            raise util.Abort(_("destination '%s' is not empty") % dest)

    srclock = destlock = cleandir = None
    srcrepo = srcpeer.local()
    try:
        abspath = origsource
        if islocal(origsource):
            abspath = os.path.abspath(util.urllocalpath(origsource))

        if islocal(dest):
            cleandir = dest

        copy = False
        if (srcrepo and srcrepo.cancopy() and islocal(dest)
            and not phases.hassecret(srcrepo)):
            copy = not pull and not rev

        if copy:
            try:
                # we use a lock here because if we race with commit, we
                # can end up with extra data in the cloned revlogs that's
                # not pointed to by changesets, thus causing verify to
                # fail
                srclock = srcrepo.lock(wait=False)
            except error.LockError:
                copy = False

        if copy:
            srcrepo.hook('preoutgoing', throw=True, source='clone')
            hgdir = os.path.realpath(os.path.join(dest, ".hg"))
            if not os.path.exists(dest):
                os.mkdir(dest)
            else:
                # only clean up directories we create ourselves
                cleandir = hgdir
            try:
                destpath = hgdir
                util.makedir(destpath, notindexed=True)
            except OSError as inst:
                if inst.errno == errno.EEXIST:
                    cleandir = None
                    raise util.Abort(_("destination '%s' already exists")
                                     % dest)
                raise

            destlock = copystore(ui, srcrepo, destpath)
            # copy bookmarks over
            srcbookmarks = srcrepo.join('bookmarks')
            dstbookmarks = os.path.join(destpath, 'bookmarks')
            if os.path.exists(srcbookmarks):
                util.copyfile(srcbookmarks, dstbookmarks)

            # Recomputing branch cache might be slow on big repos,
            # so just copy it
            def copybranchcache(fname):
                srcbranchcache = srcrepo.join('cache/%s' % fname)
                dstbranchcache = os.path.join(dstcachedir, fname)
                if os.path.exists(srcbranchcache):
                    if not os.path.exists(dstcachedir):
                        os.mkdir(dstcachedir)
                    util.copyfile(srcbranchcache, dstbranchcache)

            dstcachedir = os.path.join(destpath, 'cache')
            # In local clones we're copying all nodes, not just served
            # ones. Therefore copy all branch caches over.
            copybranchcache('branch2')
            for cachename in repoview.filtertable:
                copybranchcache('branch2-%s' % cachename)

            # we need to re-init the repo after manually copying the data
            # into it
            destpeer = peer(srcrepo, peeropts, dest)
            srcrepo.hook('outgoing', source='clone',
                         node=node.hex(node.nullid))
        else:
            try:
                destpeer = peer(srcrepo or ui, peeropts, dest, create=True)
                                # only pass ui when no srcrepo
            except OSError as inst:
                if inst.errno == errno.EEXIST:
                    cleandir = None
                    raise util.Abort(_("destination '%s' already exists")
                                     % dest)
                raise

        revs = None
        if rev:
            if not srcpeer.capable('lookup'):
                raise util.Abort(_("src repository does not support "
                                   "revision lookup and so doesn't "
                                   "support clone by revision"))
            revs = [srcpeer.lookup(r) for r in rev]
            checkout = revs[0]
        if destpeer.local():
            if not stream:
                if pull:
                    stream = False
                else:
                    stream = None
            destpeer.local().clone(srcpeer, heads=revs, stream=stream)
        elif srcrepo:
            exchange.push(srcrepo, destpeer, revs=revs,
                          bookmarks=srcrepo._bookmarks.keys())
        else:
            raise util.Abort(_("clone from remote to remote not supported"))

        cleandir = None

        destrepo = destpeer.local()
        if destrepo:
            template = uimod.samplehgrcs['cloned']
            fp = destrepo.vfs("hgrc", "w", text=True)
            u = util.url(abspath)
            u.passwd = None
            defaulturl = str(u)
            fp.write(template % defaulturl)
            fp.close()

            destrepo.ui.setconfig('paths', 'default', defaulturl, 'clone')

            if update:
                if update is not True:
                    checkout = srcpeer.lookup(update)
                uprev = None
                status = None
                if checkout is not None:
                    try:
                        uprev = destrepo.lookup(checkout)
                    except error.RepoLookupError:
                        pass
                if uprev is None:
                    try:
                        uprev = destrepo._bookmarks['@']
                        update = '@'
                        bn = destrepo[uprev].branch()
                        if bn == 'default':
                            status = _("updating to bookmark @\n")
                        else:
                            status = (_("updating to bookmark @ on branch %s\n")
                                      % bn)
                    except KeyError:
                        try:
                            uprev = destrepo.branchtip('default')
                        except error.RepoLookupError:
                            uprev = destrepo.lookup('tip')
                if not status:
                    bn = destrepo[uprev].branch()
                    status = _("updating to branch %s\n") % bn
                destrepo.ui.status(status)
                _update(destrepo, uprev)
                if update in destrepo._bookmarks:
                    bookmarks.activate(destrepo, update)
    finally:
        release(srclock, destlock)
        if cleandir is not None:
            shutil.rmtree(cleandir, True)
        if srcpeer is not None:
            srcpeer.close()
    return srcpeer, destpeer
mat = io.loadmat(mat_file)['X']
print(mat.shape)
no_cluster = len(np.unique(y))
print(no_cluster)

algo_pipeline = []
algo_pipeline.append((CoclustInfo(n_row_clusters=no_cluster,
                                  n_col_clusters=no_cluster,
                                  n_init=10, max_iter=200), "CoclustInfo"))
algo_pipeline.append((CoclustMod(n_clusters=no_cluster,
                                 n_init=10, max_iter=200), "CoclustMod"))
algo_pipeline.append((CoclustSpecMod(n_clusters=no_cluster,
                                     n_init=10, max_iter=200), "CoclustSpecMod"))

for model, model_name in algo_pipeline:
    res_nmi, res_ari, res_acc = execute_algo(model, model_name, mat, y)

    # Save results
    out_dir = result_path + "/" + data_version + "/"
    makedir(out_dir)
    out_file = out_dir + dataset + "_" + mat_version + "_" + model_name + ".txt"
    content = str(res_nmi) + ", " + str(res_ari) + ", " + str(res_acc) + "\n"
    myfile = open(out_file, "a")
    myfile.write(content)
    myfile.close()
def train():
    conf = load_conf()

    no_replay = conf['noreplay']
    solver = conf['solver']
    n = conf['n']
    eps = conf['eps']

    savedir = conf['dirname']
    makedir(savedir)
    tmpdir = os.path.join(savedir, 'tmp')
    makedir(tmpdir)

    np.random.seed(conf['seed'])
    logfile = os.path.join(savedir, 'log')

    ave = 0
    aves = []
    ma = 0
    global_ma = 0

    channels = [10, 100, 500, n * (n - 1) // 2]
    if 'channels' in conf:
        channels = conf['channels']
        channels.append(n * (n - 1) // 2)
    bias = -np.log(1.0 / conf['p'] - 1)
    net = MLP(channels, bias)

    if conf['gpu'] != -1:
        chainer.cuda.get_device_from_id(conf['gpu']).use()
        net.to_gpu()

    if conf['opt'] == 'SGD':
        opt = chainer.optimizers.SGD(lr=conf['lr'])
    elif conf['opt'] == 'Adam':
        opt = chainer.optimizers.Adam(alpha=conf['lr'])
    opt.setup(net)

    stop = 0
    pool_size = 10
    start_training = 20
    r_bests = []
    edges_bests = []
    z_bests = []

    if no_replay:
        pool_size = 1
        start_training = 1e9

    iteration = 0
    from_restart = 0
    start_time = time.time()
    while True:
        iteration += 1
        from_restart += 1

        z = net.z(1)
        x = net(z)[0]
        edges_li, edges, lp = gen_edges(n, x, net.xp)
        r = calc_reward(n, edges, solver, tmpdir)
        entropy = F.mean(x * F.log(x + 1e-6) + (1 - x) * F.log(1 - x + 1e-6))

        if no_replay:
            loss = -r * lp
            net.cleargrads()
            loss.backward()
            opt.update()

        if r > ma:
            ma = r
            stop = 0
        else:
            stop += 1

        if r > global_ma:
            global_ma = r
            output_graph(os.path.join(savedir, 'output_{}.txt'.format(r)),
                         n, edges)
            output_distribution(
                os.path.join(savedir, 'distribution_{}.txt'.format(r)), n, x.data)
            chainer.serializers.save_npz(
                os.path.join(savedir, 'snapshot_at_reward_{}'.format(r)), net)

        elapsed = time.time() - start_time
        ave = ave * (1 - conf['eps']) + r * conf['eps']
        aves.append(ave)
        with open(logfile, 'a') as f:
            print(savedir, iteration, elapsed, r, len(edges), entropy.data,
                  global_ma, ma, ave, flush=True)
            print(iteration, elapsed, r, len(edges), entropy.data,
                  global_ma, ma, ave, flush=True, file=f)

        f = False
        for es in edges_bests:
            if (es == edges_li).all():
                f = True
        if not f:
            r_bests.append(r)
            edges_bests.append(edges_li)
            z_bests.append(z)
            while len(r_bests) > pool_size:
                mi = 0
                for j in range(len(r_bests)):
                    if r_bests[j] < r_bests[mi]:
                        mi = j
                r_bests.pop(mi)
                edges_bests.pop(mi)
                z_bests.pop(mi)

        if from_restart >= start_training:
            ind = np.random.randint(len(r_bests))
            x = net(z_bests[ind])[0]
            lp = calc_lp(n, x, edges_bests[ind], net.xp)
            loss = -r_bests[ind] * lp
            net.cleargrads()
            loss.backward()
            opt.update()

        if stop >= conf['restart']:
            stop = 0
            ma = 0
            r_bests = []
            edges_bests = []
            z_bests = []
            from_restart = 0
            net = MLP(channels, bias)
            if conf['gpu'] != -1:
                chainer.cuda.get_device_from_id(conf['gpu']).use()
                net.to_gpu()
            if conf['opt'] == 'SGD':
                opt = chainer.optimizers.SGD(lr=conf['lr'])
            elif conf['opt'] == 'Adam':
                opt = chainer.optimizers.Adam(alpha=conf['lr'])
            opt.setup(net)
            continue

        if iteration % 100 == 0:
            plt.clf()
            plt.plot(range(len(aves)), aves)
            plt.savefig(os.path.join(savedir, 'graph.png'))
        if iteration % 1000 == 0:
            plt.savefig(os.path.join(savedir, 'graph_{}.png'.format(iteration)))
            plt.savefig(os.path.join(savedir, 'graph_{}.eps'.format(iteration)))
            chainer.serializers.save_npz(
                os.path.join(savedir, 'snapshot_{}'.format(iteration)), net)
            chainer.serializers.save_npz(
                os.path.join(savedir, 'opt_{}'.format(iteration)), opt)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--which_plot', type=str, required=True,
                        choices=['fig4', 'fig7(a)', 'fig8', 'fig9', 'fig24'])
    parser.add_argument('--results_dir', type=str, default='results/')
    parser.add_argument('--output_dir', type=str, default='paper_plots')
    opt = parser.parse_args()

    datasets, metrics, models, i_os = name2keys(opt.which_plot)
    exp_dict = dict(np.load('records/finished_exp.npy').item())
    makedir(opt.output_dir)

    if opt.which_plot != 'fig8':
        for metric in metrics:
            if metric == 'ssim':
                ylabel = 'SSIM'
            else:
                ylabel = 'PSNR'
            for dataset in datasets:
                for i_o in i_os:
                    fig, ax = plt.subplots(1, 1)
                    for model in models:
                        try:
                            test_name = exp_dict[dataset][model][i_o][-1]
                        except:
                            print('exp with {%s, %s, %s} does not exist' %
                                  (dataset, model, i_o))
                            continue
                        quant_exp_dir = os.path.join(opt.results_dir,
                                                     'quantitative',
                                                     nick_name[dataset],
                                                     test_name)
                        metrics_dict = dict(
                            np.load(os.path.join(quant_exp_dir, 'results.npz')))
                        quant_result = metrics_dict[metric]
                        mask = np.isinf(quant_result)
                        quant_result = np.ma.array(quant_result, mask=mask)
                        avg_err = quant_result.mean(axis=0)
                        T = quant_result.shape[1]
                        x = np.arange(1, T + 1)
                        ax.plot(x, avg_err.data, linestyle=ls_dict[model],
                                color=color_dict[model], linewidth=2)
                    axes = plt.gca()
                    axes.set_ylim(ylims[metric][dataset][opt.which_plot])
                    x_lim = [1, T]
                    axes.set_xlim(x_lim)
                    ax.set_xlabel('time steps')
                    ax.set_ylabel(ylabel)
                    ax.set_xticks(x)
                    ax.set_title(dataset)
                    plt.legend(models, loc='upper center',
                               bbox_to_anchor=(0.5, -0.2), fancybox=True,
                               ncol=min(4, len(models)))
                    plt.grid()
                    plot_name = os.path.join(
                        opt.output_dir,
                        '%s_%s_%s_%s.eps' % (opt.which_plot, nick_name[dataset],
                                             metric, i_o))
                    fig.set_size_inches(7, 5 * 0.75)
                    plt.savefig(plot_name, format='eps', bbox_inches='tight')
    else:
        for metric in metrics:
            if metric == 'ssim':
                ylabel = 'SSIM'
            else:
                ylabel = 'PSNR'
            for dataset in datasets:
                for model in models:
                    fig, ax = plt.subplots(1, 1)
                    legends = []
                    for i_o in i_os:
                        try:
                            test_name = exp_dict[dataset][model][i_o][-1]
                        except:
                            print('exp with {%s, %s, %s} does not exist' %
                                  (dataset, model, i_o))
                            continue
                        in_num = int(i_o.split('_')[0])
                        legends.append('%s(input %d frames)' % (model, in_num))
                        quant_exp_dir = os.path.join(opt.results_dir,
                                                     'quantitative',
                                                     nick_name[dataset],
                                                     test_name)
                        metrics_dict = dict(
                            np.load(os.path.join(quant_exp_dir, 'results.npz')))
                        quant_result = metrics_dict[metric]
                        mask = np.isinf(quant_result)
                        quant_result = np.ma.array(quant_result, mask=mask)
                        avg_err = quant_result.mean(axis=0)
                        T = quant_result.shape[1]
                        x = np.arange(1, T + 1)
                        ax.plot(x, avg_err.data, linewidth=2)
                    axes = plt.gca()
                    axes.set_ylim(ylims[metric][dataset][opt.which_plot])
                    x_lim = [1, T]
                    axes.set_xlim(x_lim)
                    ax.set_xlabel('time steps')
                    ax.set_ylabel(ylabel)
                    ax.set_xticks(x)
                    ax.set_title(dataset)
                    plt.legend(models, loc='upper center',
                               bbox_to_anchor=(0.5, -0.2), fancybox=True,
                               ncol=min(4, len(models)))
                    plt.grid()
                    plot_name = os.path.join(
                        opt.output_dir,
                        '%s_%s_%s.eps' % (opt.which_plot, nick_name[dataset],
                                          metric))
                    fig.set_size_inches(7, 5 * 0.75)
                    plt.savefig(plot_name, format='eps', bbox_inches='tight')
def run_discordant_type2(args):
    """Run discordant type2 analysis, partition them into connected
    components, and do clustering

    1. Discordant type2 analysis
       Input file:  each <id path>/<id>.sorted.txt
       Output file: each <id path>/<id>.type2.txt

    2. Partition into connected components
       Input file:  each <id path>/<id>.type2.txt
       Output files:
           <id_path>/<id>_type2_parts/
           |- spec.txt
           |- <part_id> paths ... /    (for each <part_id>)
              |- <part_id>.txt

    3. Clustering:
       Input file:  each <part_id_path>/<part_id>.txt
       Output file:
           <part_id_path>/
           |- <part_id>_sol/
           |  |- predictions.sol
           |- spec.txt (may exist)
           |- <sub_part_id> paths ... /
              |- <sub_part_id>.txt

    4. Refine the results
       Input file:      each <id>.concordant.txt
       Input directory: each <id>_type2_parts/
       Output file:     each <id>_type2.sol
    """
    args.logger.info("Find Type 2 inversions")
    import util
    for contig in id_path_iter(args):
        if "type2" in args.run_steps:
            # 1. Discordant type2 analysis
            args.logger.info(" Analyze contig {}".format(contig["id"]))
            args.logger.info(" Find all type 2 rectangles")
            subprocess.check_call([
                os.path.join(args.aux_dir, "discordant_type2"),
                os.path.join(contig["id_path"], contig["id"] + ".sorted.txt"),
                os.path.join(contig["id_path"], contig["id"] + ".type2.txt"),
                str(args.min_quality),
                str(args.t2_min_extension),
                str(args.t2_ksi),
                str(args.max_adj_distance)
            ])
        if "cluster2" in args.run_steps:
            # 2. Partition into connected components
            args.logger.info(" Partition into connected components")
            type2_parts_dir = os.path.join(contig["id_path"],
                                           contig["id"] + "_type2_parts")
            util.makedir(type2_parts_dir)
            subprocess.check_call([
                os.path.join(args.aux_dir, "partition_disconnected_rects"),
                os.path.join(contig["id_path"], contig["id"] + ".type2.txt"),
                os.path.join(type2_parts_dir),
                str(args.min_rectangles),
                str(args.min_rect_sides)
            ])
            # 3. Clustering
            args.logger.info(" Clustering")
            for component in part_id_path_iter(contig, "type2"):
                sol_dir = os.path.join(component["part_id_path"],
                                       component["part_id"] + "_sol")
                util.makedir(sol_dir)
                subprocess.check_call([
                    "rm", "-f",
                    os.path.join(sol_dir, "spec.txt"),
                    os.path.join(sol_dir, "predictions.sol")
                ])
                subprocess.check_call([
                    os.path.join(args.aux_dir, "cluster_by_maximal_coverage"),
                    os.path.join(component["part_id_path"],
                                 component["part_id"] + ".txt"),
                    sol_dir,
                    str(args.prob_contain_rect),
                    str(args.confidence),
                    str(args.min_rectangles),
                    str(args.min_rect_sides),
                    str(args.remove_portion),
                    boolTo01[args.compute_rect_first],
                    boolTo01[not args.keep_overlapping_predictions],
                    str(args.min_brace_coverage),
                    str(args.min_brace_imbalance_ratio)
                ])
            # 4. Refinement
            if args.t2_no_refine:
                args.logger.info(" Merge predictions to a single file")
                subprocess.check_call([
                    os.path.join(args.aux_dir, "merge_files"),
                    os.path.join(contig["id_path"],
                                 contig["id"] + "_type2_parts"),
                    os.path.join(contig["id_path"], contig["id"] + "_type2.sol"),
                    "2", "0", "1",
                    os.path.join("_sol", "predictions.sol")
                ])
            else:
                args.logger.info(" Refinement")
                subprocess.check_call([
                    os.path.join(args.aux_dir, "refine_type2"),
                    os.path.join(contig["id_path"],
                                 contig["id"] + ".concordant.txt"),
                    os.path.join(contig["id_path"],
                                 contig["id"] + "_type2_parts"),
                    os.path.join(contig["id_path"], contig["id"] + "_type2.sol")
                ])
def makedir(self, path=None, notindexed=True):
    return util.makedir(self.join(path), notindexed)
def clone(ui, peeropts, source, dest=None, pull=False, rev=None,
          update=True, stream=False, branch=None):
    """Make a copy of an existing repository.

    Create a copy of an existing repository in a new directory.  The
    source and destination are URLs, as passed to the repository
    function.  Returns a pair of repository objects, the source and
    newly created destination.

    The location of the source is added to the new repository's
    .hg/hgrc file, as the default to be used for future pulls and
    pushes.

    If an exception is raised, the partly cloned/updated destination
    repository will be deleted.

    Arguments:

    source: repository object or URL

    dest: URL of destination repository to create (defaults to base
    name of source repository)

    pull: always pull from source repository, even in local case

    stream: stream raw data uncompressed from repository (fast over
    LAN, slow over WAN)

    rev: revision to clone up to (implies pull=True)

    update: update working directory after clone completes, if
    destination is local repository (True means update to default rev,
    anything else is treated as a revision)

    branch: branches to clone
    """

    if isinstance(source, str):
        origsource = ui.expandpath(source)
        source, branch = parseurl(origsource, branch)
        srcrepo = repository(remoteui(ui, peeropts), source)
    else:
        srcrepo = source
        branch = (None, branch or [])
        origsource = source = srcrepo.url()
    rev, checkout = addbranchrevs(srcrepo, srcrepo, branch, rev)

    if dest is None:
        dest = defaultdest(source)
        ui.status(_("destination directory: %s\n") % dest)
    else:
        dest = ui.expandpath(dest)

    dest = util.urllocalpath(dest)
    source = util.urllocalpath(source)

    if os.path.exists(dest):
        if not os.path.isdir(dest):
            raise util.Abort(_("destination '%s' already exists") % dest)
        elif os.listdir(dest):
            raise util.Abort(_("destination '%s' is not empty") % dest)

    class DirCleanup(object):
        def __init__(self, dir_):
            self.rmtree = shutil.rmtree
            self.dir_ = dir_

        def close(self):
            self.dir_ = None

        def cleanup(self):
            if self.dir_:
                self.rmtree(self.dir_, True)

    srclock = destlock = dircleanup = None
    try:
        abspath = origsource
        if islocal(origsource):
            abspath = os.path.abspath(util.urllocalpath(origsource))

        if islocal(dest):
            dircleanup = DirCleanup(dest)

        copy = False
        if srcrepo.cancopy() and islocal(dest) and not srcrepo.revs("secret()"):
            copy = not pull and not rev

        if copy:
            try:
                # we use a lock here because if we race with commit, we
                # can end up with extra data in the cloned revlogs that's
                # not pointed to by changesets, thus causing verify to
                # fail
                srclock = srcrepo.lock(wait=False)
            except error.LockError:
                copy = False

        if copy:
            srcrepo.hook('preoutgoing', throw=True, source='clone')
            hgdir = os.path.realpath(os.path.join(dest, ".hg"))
            if not os.path.exists(dest):
                os.mkdir(dest)
            else:
                # only clean up directories we create ourselves
                dircleanup.dir_ = hgdir
            try:
                destpath = hgdir
                util.makedir(destpath, notindexed=True)
            except OSError, inst:
                if inst.errno == errno.EEXIST:
                    dircleanup.close()
                    raise util.Abort(_("destination '%s' already exists")
                                     % dest)
                raise

            destlock = copystore(ui, srcrepo, destpath)

            # we need to re-init the repo after manually copying the data
            # into it
            destrepo = repository(remoteui(ui, peeropts), dest)
            srcrepo.hook('outgoing', source='clone',
                         node=node.hex(node.nullid))
        else:
def clone(ui, peeropts, source, dest=None, pull=False, rev=None,
          update=True, stream=False, branch=None):
    """Make a copy of an existing repository.

    Create a copy of an existing repository in a new directory.  The
    source and destination are URLs, as passed to the repository
    function.  Returns a pair of repository peers, the source and
    newly created destination.

    The location of the source is added to the new repository's
    .hg/hgrc file, as the default to be used for future pulls and
    pushes.

    If an exception is raised, the partly cloned/updated destination
    repository will be deleted.

    Arguments:

    source: repository object or URL

    dest: URL of destination repository to create (defaults to base
    name of source repository)

    pull: always pull from source repository, even in local case

    stream: stream raw data uncompressed from repository (fast over
    LAN, slow over WAN)

    rev: revision to clone up to (implies pull=True)

    update: update working directory after clone completes, if
    destination is local repository (True means update to default rev,
    anything else is treated as a revision)

    branch: branches to clone
    """

    if isinstance(source, str):
        origsource = ui.expandpath(source)
        source, branch = parseurl(origsource, branch)
        srcpeer = peer(ui, peeropts, source)
    else:
        srcpeer = source.peer() # in case we were called with a localrepo
        branch = (None, branch or [])
        origsource = source = srcpeer.url()
    rev, checkout = addbranchrevs(srcpeer, srcpeer, branch, rev)

    if dest is None:
        dest = defaultdest(source)
        ui.status(_("destination directory: %s\n") % dest)
    else:
        dest = ui.expandpath(dest)

    dest = util.urllocalpath(dest)
    source = util.urllocalpath(source)

    if not dest:
        raise util.Abort(_("empty destination path is not valid"))
    if os.path.exists(dest):
        if not os.path.isdir(dest):
            raise util.Abort(_("destination '%s' already exists") % dest)
        elif os.listdir(dest):
            raise util.Abort(_("destination '%s' is not empty") % dest)

    srclock = destlock = cleandir = None
    srcrepo = srcpeer.local()
    try:
        abspath = origsource
        if islocal(origsource):
            abspath = os.path.abspath(util.urllocalpath(origsource))

        if islocal(dest):
            cleandir = dest

        copy = False
        if (srcrepo and srcrepo.cancopy() and islocal(dest)
            and not phases.hassecret(srcrepo)):
            copy = not pull and not rev

        if copy:
            try:
                # we use a lock here because if we race with commit, we
                # can end up with extra data in the cloned revlogs that's
                # not pointed to by changesets, thus causing verify to
                # fail
                srclock = srcrepo.lock(wait=False)
            except error.LockError:
                copy = False

        if copy:
            srcrepo.hook('preoutgoing', throw=True, source='clone')
            hgdir = os.path.realpath(os.path.join(dest, ".hg"))
            if not os.path.exists(dest):
                os.mkdir(dest)
            else:
                # only clean up directories we create ourselves
                cleandir = hgdir
            try:
                destpath = hgdir
                util.makedir(destpath, notindexed=True)
            except OSError, inst:
                if inst.errno == errno.EEXIST:
                    cleandir = None
                    raise util.Abort(_("destination '%s' already exists")
                                     % dest)
                raise

            destlock = copystore(ui, srcrepo, destpath)

            # Recomputing branch cache might be slow on big repos,
            # so just copy it
            dstcachedir = os.path.join(destpath, 'cache')
            srcbranchcache = srcrepo.sjoin('cache/branchheads')
            dstbranchcache = os.path.join(dstcachedir, 'branchheads')
            if os.path.exists(srcbranchcache):
                if not os.path.exists(dstcachedir):
                    os.mkdir(dstcachedir)
                util.copyfile(srcbranchcache, dstbranchcache)

            # we need to re-init the repo after manually copying the data
            # into it
            destpeer = peer(srcrepo, peeropts, dest)
            srcrepo.hook('outgoing', source='clone',
                         node=node.hex(node.nullid))
        else:
def clone(ui, peeropts, source, dest=None, pull=False, rev=None,
          update=True, stream=False, branch=None):
    """Make a copy of an existing repository.

    Create a copy of an existing repository in a new directory.  The
    source and destination are URLs, as passed to the repository
    function.  Returns a pair of repository peers, the source and
    newly created destination.

    The location of the source is added to the new repository's
    .hg/hgrc file, as the default to be used for future pulls and
    pushes.

    If an exception is raised, the partly cloned/updated destination
    repository will be deleted.

    Arguments:

    source: repository object or URL

    dest: URL of destination repository to create (defaults to base
    name of source repository)

    pull: always pull from source repository, even in local case or if the
    server prefers streaming

    stream: stream raw data uncompressed from repository (fast over
    LAN, slow over WAN)

    rev: revision to clone up to (implies pull=True)

    update: update working directory after clone completes, if
    destination is local repository (True means update to default rev,
    anything else is treated as a revision)

    branch: branches to clone
    """

    if isinstance(source, str):
        origsource = ui.expandpath(source)
        source, branch = parseurl(origsource, branch)
        srcpeer = peer(ui, peeropts, source)
    else:
        srcpeer = source.peer() # in case we were called with a localrepo
        branch = (None, branch or [])
        origsource = source = srcpeer.url()
    rev, checkout = addbranchrevs(srcpeer, srcpeer, branch, rev)

    if dest is None:
        dest = defaultdest(source)
        if dest:
            ui.status(_("destination directory: %s\n") % dest)
    else:
        dest = ui.expandpath(dest)

    dest = util.urllocalpath(dest)
    source = util.urllocalpath(source)

    if not dest:
        raise util.Abort(_("empty destination path is not valid"))

    destvfs = scmutil.vfs(dest, expandpath=True)
    if destvfs.lexists():
        if not destvfs.isdir():
            raise util.Abort(_("destination '%s' already exists") % dest)
        elif destvfs.listdir():
            raise util.Abort(_("destination '%s' is not empty") % dest)

    srclock = destlock = cleandir = None
    srcrepo = srcpeer.local()
    try:
        abspath = origsource
        if islocal(origsource):
            abspath = os.path.abspath(util.urllocalpath(origsource))

        if islocal(dest):
            cleandir = dest

        copy = False
        if (srcrepo and srcrepo.cancopy() and islocal(dest)
            and not phases.hassecret(srcrepo)):
            copy = not pull and not rev

        if copy:
            try:
                # we use a lock here because if we race with commit, we
                # can end up with extra data in the cloned revlogs that's
                # not pointed to by changesets, thus causing verify to
                # fail
                srclock = srcrepo.lock(wait=False)
            except error.LockError:
                copy = False

        if copy:
            srcrepo.hook('preoutgoing', throw=True, source='clone')
            hgdir = os.path.realpath(os.path.join(dest, ".hg"))
            if not os.path.exists(dest):
                os.mkdir(dest)
            else:
                # only clean up directories we create ourselves
                cleandir = hgdir
            try:
                destpath = hgdir
                util.makedir(destpath, notindexed=True)
            except OSError, inst:
                if inst.errno == errno.EEXIST:
                    cleandir = None
                    raise util.Abort(_("destination '%s' already exists")
                                     % dest)
                raise

            destlock = copystore(ui, srcrepo, destpath)
            # copy bookmarks over
            srcbookmarks = srcrepo.join('bookmarks')
            dstbookmarks = os.path.join(destpath, 'bookmarks')
            if os.path.exists(srcbookmarks):
                util.copyfile(srcbookmarks, dstbookmarks)

            # Recomputing branch cache might be slow on big repos,
            # so just copy it
            def copybranchcache(fname):
                srcbranchcache = srcrepo.join('cache/%s' % fname)
                dstbranchcache = os.path.join(dstcachedir, fname)
                if os.path.exists(srcbranchcache):
                    if not os.path.exists(dstcachedir):
                        os.mkdir(dstcachedir)
                    util.copyfile(srcbranchcache, dstbranchcache)

            dstcachedir = os.path.join(destpath, 'cache')
            # In local clones we're copying all nodes, not just served
            # ones. Therefore copy all branch caches over.
            copybranchcache('branch2')
            for cachename in repoview.filtertable:
                copybranchcache('branch2-%s' % cachename)

            # we need to re-init the repo after manually copying the data
            # into it
            destpeer = peer(srcrepo, peeropts, dest)
            srcrepo.hook('outgoing', source='clone',
                         node=node.hex(node.nullid))
        else: