def load(name):
    """Load a built-in dataset by name."""
    if name == "BlogCatalog":
        dataset = data_loader.BlogCatalogDataset()
    elif name == "ArXiv":
        dataset = data_loader.ArXivDataset()
    else:
        raise ValueError(name + " dataset doesn't exist")
    return dataset
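# Usage sketch (assumption: callers want the underlying graph, as `train`
# below does when it loads the same datasets directly):
#
#     graph = load("BlogCatalog").graph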
def train(args):
    import logging
    log.setLevel(logging.DEBUG)
    log.info("start")

    worker_num = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
    num_devices = int(os.getenv("CPU_NUM", 10))

    model = DeepwalkModel(args.num_nodes, args.hidden_size, args.neg_num,
                          args.is_sparse, args.is_distributed, 1.)
    pyreader = model.pyreader
    loss = model.forward()

    # init fleet
    init_role()

    train_steps = math.ceil(1. * args.num_nodes * args.epoch /
                            args.batch_size / num_devices / worker_num)
    log.info("Train step: %s" % train_steps)

    if args.optimizer == "sgd":
        args.lr *= args.batch_size * args.walk_len * args.win_size

    optimization(args.lr, loss, train_steps, args.optimizer)

    # init and run server or worker
    if fleet.is_server():
        fleet.init_server(args.warm_start_from_dir)
        fleet.run_server()

    if fleet.is_worker():
        log.info("start init worker")
        fleet.init_worker()
        # only the worker loads the samples
        log.info("init worker done")
        exe = F.Executor(F.CPUPlace())
        exe.run(fleet.startup_program)
        log.info("Startup done")

        if args.dataset is not None:
            if args.dataset == "BlogCatalog":
                graph = data_loader.BlogCatalogDataset().graph
            elif args.dataset == "ArXiv":
                graph = data_loader.ArXivDataset().graph
            else:
                raise ValueError(args.dataset + " dataset doesn't exist")
            log.info("Load built-in %s dataset done." % args.dataset)
        elif args.walkpath_files is None or args.walkpath_files == "None":
            graph = build_graph(args.num_nodes, args.edge_path)
            log.info("Load graph from '%s' done." % args.edge_path)
        else:
            graph = build_fake_graph(args.num_nodes)
            log.info("Load fake graph done.")

        # bind the sample generator to the pyreader
        gen_func = build_gen_func(args, graph)
        pyreader.decorate_tensor_provider(gen_func)
        pyreader.start()

        compiled_prog = build_complied_prog(fleet.main_program, loss)
        train_prog(exe, compiled_prog, loss, pyreader, args, train_steps)
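# A minimal command-line entry point, sketched under the assumption that the
# flag names mirror the attributes `train` reads from `args`; the real script
# may define different names and defaults.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Distributed DeepWalk training")
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--num_nodes", type=int, default=10000)
    parser.add_argument("--edge_path", type=str, default=None)
    parser.add_argument("--walkpath_files", type=str, default=None)
    parser.add_argument("--hidden_size", type=int, default=64)
    parser.add_argument("--neg_num", type=int, default=5)
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--walk_len", type=int, default=40)
    parser.add_argument("--win_size", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.025)
    parser.add_argument("--optimizer", type=str, default="sgd")
    parser.add_argument("--is_sparse", action="store_true")
    parser.add_argument("--is_distributed", action="store_true")
    parser.add_argument("--warm_start_from_dir", type=str, default=None)
    train(parser.parse_args())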