Code example #1
def test_balanced_split_errors(self):
    for N in -1, 0:
        with self.assertRaises(ValueError):
            list(common.balanced_split([42], N))
    with self.assertRaises(ValueError):
        list(common.balanced_split([42], 2))
    with self.assertRaises(ValueError):
        list(common.balanced_split([], 1))
Code example #2
def run_mapred(model, input_dirs, output_dir, nmaps, log_level, collate=False):
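    # Package the project sources as a zip, write one opaque input split per
    # map task to HDFS, then configure and submit the job via PydoopSubmitter;
    # optionally collate the map output at the end.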
    wd = tempfile.mkdtemp(prefix="pydeep_")
    zip_fn = os.path.join(wd, "{}.zip".format(PACKAGE))
    shutil.make_archive(*zip_fn.rsplit(".", 1), base_dir=PACKAGE)
    if nmaps > len(input_dirs):
        nmaps = len(input_dirs)
        LOGGER.warning("Not enough input dirs, will only do %d splits", nmaps)
    splits = common.balanced_split(input_dirs, nmaps)
    splits_uri = "pydoop_splits_%s" % uuid.uuid4().hex
    with hdfs.open(splits_uri, 'wb') as f:
        write_opaques([OpaqueInputSplit(1, _) for _ in splits], f)
    submitter = PydoopSubmitter()
    properties = {
        common.GRAPH_ARCH_KEY: model.name,
        common.LOG_LEVEL_KEY: log_level,
        common.NUM_MAPS_KEY: nmaps,
        common.PYDOOP_EXTERNALSPLITS_URI_KEY: splits_uri,
    }
    submitter.set_args(
        argparse.Namespace(
            D=list(properties.items()),
            avro_input=None,
            avro_output=None,
            cache_archive=None,
            cache_file=None,
            disable_property_name_conversion=True,
            do_not_use_java_record_reader=True,
            do_not_use_java_record_writer=True,
            entry_point="__main__",
            hadoop_conf=None,
            input=input_dirs[0],  # does it matter?
            input_format=None,
            job_conf=None,
            job_name="dump_weights",
            keep_wd=False,
            libjars=None,
            log_level=log_level,
            module=os.path.splitext(os.path.basename(__file__))[0],
            no_override_env=False,
            no_override_home=False,
            no_override_ld_path=False,
            no_override_path=False,
            no_override_pypath=False,
            num_reducers=0,
            output=output_dir,
            output_format=None,
            pretend=False,
            pstats_dir=None,
            python_program=sys.executable,
            python_zip=[zip_fn],
            set_env=None,
            upload_archive_to_cache=None,
            upload_file_to_cache=[__file__],
        ))
    submitter.run()
    hdfs.rmr(splits_uri)
    if collate:
        collate_mapred_output(output_dir)
    shutil.rmtree(wd)
Code example #3
def test_balanced_split(self):
    for seq_len in 15, 16, 17, 18, 19, 20:
        seq, N = list(range(seq_len)), 4
        groups = list(common.balanced_split(seq, N))
        self.assertEqual(len(groups), N)
        self.assertEqual(sum(groups, []), seq)
        sg = sorted(groups, key=len)
        self.assertTrue(len(sg[-1]) - len(sg[0]) <= 1)
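
Code examples #1 and #3 above pin down the contract of common.balanced_split: it must yield exactly N groups whose concatenation reproduces the input sequence, with group lengths differing by at most one, and it must raise ValueError when N is not between 1 and len(seq). A minimal sketch consistent with that contract (an illustration only, not the actual pydeep implementation) could look like this:

def balanced_split(seq, N):
    # Reject the group counts the error tests above treat as invalid.
    if N <= 0 or N > len(seq):
        raise ValueError("N must be between 1 and len(seq)")
    base, extra = divmod(len(seq), N)
    start = 0
    for i in range(N):
        # The first `extra` groups get one extra item, so lengths differ by at most 1.
        size = base + (1 if i < extra else 0)
        yield list(seq[start:start + size])
        start += size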
Code example #4
def generate_input_splits(N, bneck_map, splits_path):
    """\
    Assign to each split a chunk of bottlenecks across all classes.
    """
    for locs in bneck_map.values():
        random.shuffle(locs)
    bneck_map = {d: list(common.balanced_split(locs, N))
                 for d, locs in bneck_map.items()}
    splits = [{d: seq[i] for d, seq in bneck_map.items()} for i in range(N)]
    LOGGER.debug("saving input splits to: %s", splits_path)
    with hdfs.open(splits_path, 'wb') as f:
        write_opaques([OpaqueInputSplit(1, _) for _ in splits], f)
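
For illustration only (the names and values below are made up, not taken from pydeep), the bneck_map passed to generate_input_splits maps each class directory to the locations of its bottleneck files, and every one of the N output splits receives one chunk from every class:

# Hypothetical per-class bottleneck locations.
bneck_map = {
    "daisy": ["d0", "d1", "d2", "d3"],
    "roses": ["r0", "r1", "r2", "r3", "r4"],
}
# With N = 2 (and ignoring the shuffle), the resulting splits could be:
#   {"daisy": ["d0", "d1"], "roses": ["r0", "r1", "r2"]}
#   {"daisy": ["d2", "d3"], "roses": ["r3", "r4"]}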
Code example #5
def main(argv=None):
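    # Same overall flow as run_mapred in code example #2: package the sources,
    # write opaque input splits to HDFS, then submit the job via PydoopSubmitter.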

    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    wd = tempfile.mkdtemp(prefix="pydeep_")
    zip_fn = os.path.join(wd, "{}.zip".format(PACKAGE))
    shutil.make_archive(*zip_fn.rsplit(".", 1), base_dir=PACKAGE)

    parser = make_parser()
    args, unknown_args = parser.parse_known_args(argv)
    args.job_name = WORKER
    args.module = WORKER
    args.upload_file_to_cache = ['%s.py' % WORKER]
    args.python_zip = [zip_fn]
    args.do_not_use_java_record_reader = True
    args.do_not_use_java_record_writer = True
    args.num_reducers = 0

    LOGGER.setLevel(args.log_level)
    model = get_model_info(args.architecture)
    get_graph(model, log_level=args.log_level)

    images = list_images(args.input)
    splits = common.balanced_split(images, args.num_maps)
    uri = os.path.join(args.input, '_' + uuid.uuid4().hex)
    LOGGER.debug("saving input splits to: %s", uri)
    with hdfs.open(uri, 'wb') as f:
        write_opaques([OpaqueInputSplit(1, _) for _ in splits], f)
    submitter = PydoopSubmitter()
    submitter.set_args(args, [] if unknown_args is None else unknown_args)
    submitter.properties.update({
        common.NUM_MAPS_KEY: args.num_maps,
        common.GRAPH_ARCH_KEY: args.architecture,
        common.PYDOOP_EXTERNALSPLITS_URI_KEY: uri,
    })
    submitter.run()
    hdfs.rmr(uri)
    shutil.rmtree(wd)