示例#1
0
def train_partition(idx, iterator):
    """Launch a training server for one partition and request gradient updates.

    This is a generator: nothing below runs until the caller starts
    iterating over the result.

    Args:
        idx: Integer partition index. It selects the server port
            (``50000 + idx % 256``) and the dataset file name
            (``dset-<idx>.hdf5`` under the SparkFiles root directory).
        iterator: Iterable of steps. Its items are ignored; one
            ``barista.GRAD_UPDATE`` request is sent per item.

    Yields:
        Whatever ``DummyClient.recv()`` returns for each request.

    Raises:
        RuntimeError: If the launched subprocess exits with a non-zero
            status before the ready-flag file shows up.
    """
    # Imported locally so this block does not rely on a file-level import.
    import time

    # NOTE: because of the modulo, partitions whose indices differ by a
    # multiple of 256 map to the same port.
    port = 50000 + idx % 256
    main = SparkFiles.get("main.py")
    architecture = SparkFiles.get("train_val.prototxt")
    model = SparkFiles.get("deepq16.caffemodel")
    solver = SparkFiles.get("solver.prototxt")
    root = SparkFiles.getRootDirectory()
    dset = os.path.join(root, "dset-%02d.hdf5" % idx)

    # Clear a stale flag so the wait below only reacts to a flag written
    # after this launch. The path is relative to the current working dir.
    flag_file = "flags/__BARISTA_READY__.%d" % port
    if os.path.isfile(flag_file):
        os.remove(flag_file)

    server = subprocess.Popen(["python", main, architecture, model,
                               "--dataset", dset,
                               "--solver", solver,
                               "--dset-size", "30000",
                               "--initial-replay", "20000",
                               "--debug",
                               "--overwrite",
                               "--port", str(port)])

    # Wait for the flag file (presumably written by main.py once it is
    # ready to accept connections -- confirm against main.py). Sleep between
    # checks instead of spinning, and stop waiting if the child has failed.
    while not os.path.isfile(flag_file):
        exit_code = server.poll()
        if exit_code:  # None: still running; 0: clean exit, keep waiting
            raise RuntimeError(
                "barista subprocess on port %d exited with code %d before "
                "becoming ready" % (port, exit_code))
        time.sleep(0.1)

    for _ in iterator:
        # A fresh client connection is opened for every step.
        dc = DummyClient("127.0.0.1", port)
        dc.send(barista.GRAD_UPDATE)
        yield dc.recv()
示例#2
0
def train_partition(idx, iterator):
    """Spawn the per-partition training server, then yield one reply per step.

    The function is a generator, so the server is only started once the
    caller begins consuming the result.

    Args:
        idx: Integer partition index; determines the port
            (``50000 + idx % 256``) and the ``dset-<idx>.hdf5`` file name.
        iterator: Iterable whose length sets how many
            ``barista.GRAD_UPDATE`` requests are sent. Item values are unused.

    Yields:
        The value returned by ``DummyClient.recv()`` after each request.

    Raises:
        RuntimeError: If the spawned process terminates with a non-zero
            exit code while the ready flag is still missing.
    """
    # Function-scope import keeps this block self-contained.
    import time

    poll_interval = 0.1  # seconds between ready-flag checks

    # NOTE: indices that are congruent modulo 256 share a port.
    port = 50000 + idx % 256
    main = SparkFiles.get("main.py")
    architecture = SparkFiles.get("train_val.prototxt")
    model = SparkFiles.get("deepq16.caffemodel")
    solver = SparkFiles.get("solver.prototxt")
    root = SparkFiles.getRootDirectory()
    dset = os.path.join(root, "dset-%02d.hdf5" % idx)

    # Remove any flag left by an earlier run; otherwise the readiness wait
    # below would pass immediately. Path is relative to the working dir.
    flag_file = "flags/__BARISTA_READY__.%d" % port
    if os.path.isfile(flag_file):
        os.remove(flag_file)

    proc = subprocess.Popen([
        "python", main, architecture, model, "--dataset", dset, "--solver",
        solver, "--dset-size", "30000", "--initial-replay", "20000", "--debug",
        "--overwrite", "--port",
        str(port)
    ])

    # Poll for the flag file without burning a core. The flag is presumably
    # created by main.py when it is ready -- verify against main.py.
    while not os.path.isfile(flag_file):
        status = proc.poll()
        # poll() gives None while running; a non-zero code means the child
        # failed and the flag will never appear, so waiting on would hang.
        if status is not None and status != 0:
            raise RuntimeError(
                "barista subprocess on port %d exited with code %d before "
                "becoming ready" % (port, status))
        time.sleep(poll_interval)

    for _ in iterator:
        # Each step uses its own client connection.
        dc = DummyClient("127.0.0.1", port)
        dc.send(barista.GRAD_UPDATE)
        yield dc.recv()
示例#3
0
def sgd_step(step_num):
    """Request one gradient update from the server at 127.0.0.1:50001.

    Args:
        step_num: Unused. Presumably present so the function can be mapped
            over a sequence of step indices -- confirm with the caller.

    Returns:
        The value returned by ``DummyClient.recv()`` for the request.
    """
    client = DummyClient("127.0.0.1", 50001)
    client.send(barista.GRAD_UPDATE)
    return client.recv()
示例#4
0
def sgd_step(step_num):
    """Send a single ``barista.GRAD_UPDATE`` message and return the reply.

    A new ``DummyClient`` is connected to 127.0.0.1 on port 50001 for every
    call.

    Args:
        step_num: Ignored by the body; NOTE(review): looks like a placeholder
            for the step index supplied by the caller.

    Returns:
        Whatever ``DummyClient.recv()`` produces after the message is sent.
    """
    connection = DummyClient("127.0.0.1", 50001)
    connection.send(barista.GRAD_UPDATE)
    reply = connection.recv()
    return reply