Ejemplo n.º 1
0
def load(name):
    """Construct and return the built-in dataset selected by ``name``.

    Args:
        name: Dataset identifier. ``"BlogCatalog"`` and ``"ArXiv"`` are the
            only recognized values.

    Returns:
        A new ``data_loader.BlogCatalogDataset`` or
        ``data_loader.ArXivDataset`` instance, matching ``name``.

    Raises:
        ValueError: If ``name`` is a string that is not a recognized
            identifier.
    """
    # Guard clauses: each known identifier returns its dataset right away,
    # so only unknown names fall through to the error below.
    if name == "BlogCatalog":
        return data_loader.BlogCatalogDataset()
    if name == "ArXiv":
        return data_loader.ArXivDataset()
    raise ValueError(name + " dataset doesn't exists")
Ejemplo n.º 2
0
def _load_train_graph(args):
    # Pick the training graph: built-in dataset, edge file, or fake graph.
    dataset_name = args.dataset
    if dataset_name is not None:
        if dataset_name == "BlogCatalog":
            graph = data_loader.BlogCatalogDataset().graph
        elif dataset_name == "ArXiv":
            graph = data_loader.ArXivDataset().graph
        else:
            raise ValueError(dataset_name + " dataset doesn't exists")
        # Report the dataset that was actually loaded (the message used to
        # say "BlogCatalog" unconditionally, even for ArXiv).
        log.info("Load buildin %s dataset done." % dataset_name)
        return graph

    # No built-in dataset requested. Without walk-path files (the value may
    # be the literal string "None") the graph is built from the edge file.
    if args.walkpath_files is None or args.walkpath_files == "None":
        graph = build_graph(args.num_nodes, args.edge_path)
        log.info("Load graph from '%s' done." % args.edge_path)
        return graph

    # Walk-path files were given, so only a fake graph is built here.
    graph = build_fake_graph(args.num_nodes)
    log.info("Load fake graph done.")
    return graph


def train(args):
    """Build the DeepWalk model and run this process as fleet server/worker.

    The function builds a ``DeepwalkModel``, initializes the fleet role,
    configures the optimizer, and then, depending on what ``fleet`` reports,
    starts the server and/or runs the worker training path.

    Args:
        args: Parsed options object. The attributes read here are
            ``num_nodes``, ``hidden_size``, ``neg_num``, ``is_sparse``,
            ``is_distributed``, ``epoch``, ``batch_size``, ``optimizer``,
            ``lr``, ``walk_len``, ``win_size``, ``warm_start_from_dir``,
            ``dataset``, ``walkpath_files`` and ``edge_path``. ``args`` is
            also forwarded to ``build_gen_func`` and ``train_prog``.

    Raises:
        ValueError: If ``args.dataset`` is set to a name other than
            ``"BlogCatalog"`` or ``"ArXiv"`` (worker path only).

    Note:
        ``args.lr`` is scaled in place when ``args.optimizer == "sgd"``.
        The environment variables ``PADDLE_TRAINERS_NUM`` and ``CPU_NUM``
        are read to compute the number of training steps.
    """
    import logging
    log.setLevel(logging.DEBUG)
    log.info("start")

    # Both counts are divisors of the step computation below. The original
    # default of "0" trainers raised ZeroDivisionError whenever
    # PADDLE_TRAINERS_NUM was unset, so default to one trainer and clamp
    # both values to at least 1.
    worker_num = max(1, int(os.getenv("PADDLE_TRAINERS_NUM", "1")))
    num_devices = max(1, int(os.getenv("CPU_NUM", "10")))

    model = DeepwalkModel(args.num_nodes, args.hidden_size, args.neg_num,
                          args.is_sparse, args.is_distributed, 1.)
    pyreader = model.pyreader
    loss = model.forward()

    # init fleet
    init_role()

    # Total work (nodes * epochs) divided by batch size, devices and workers.
    train_steps = math.ceil(1. * args.num_nodes * args.epoch /
                            args.batch_size / num_devices / worker_num)
    log.info("Train step: %s" % train_steps)

    # For plain SGD the learning rate is scaled by
    # batch_size * walk_len * win_size; this mutates args.lr.
    if args.optimizer == "sgd":
        args.lr *= args.batch_size * args.walk_len * args.win_size
    optimization(args.lr, loss, train_steps, args.optimizer)

    # init and run server or worker
    if fleet.is_server():
        fleet.init_server(args.warm_start_from_dir)
        # NOTE(review): presumably blocks while serving - confirm against
        # the fleet API.
        fleet.run_server()

    if fleet.is_worker():
        log.info("start init worker done")
        fleet.init_worker()
        # just the worker, load the sample
        log.info("init worker done")

        exe = F.Executor(F.CPUPlace())
        exe.run(fleet.startup_program)
        log.info("Startup done")

        graph = _load_train_graph(args)

        # bind gen
        gen_func = build_gen_func(args, graph)

        pyreader.decorate_tensor_provider(gen_func)
        pyreader.start()

        compiled_prog = build_complied_prog(fleet.main_program, loss)
        train_prog(exe, compiled_prog, loss, pyreader, args, train_steps)