def predict(config):
    """Run the trained GNN over the MAG240M test split and save a submission.

    Restores model weights from ``config.output_path``, streams test-split
    mini-batches through the network, and writes the argmax class
    predictions via the OGB MAG240M evaluator.
    """
    model = getattr(models, config.model.name).GNNModel(
        **dict(config.model.items()))
    _create_if_not_exist(config.output_path)
    load_model(config.output_path, model)
    model.eval()

    dataset = MAG240M(config.data_dir, seed=123)
    evaluator = MAG240MEvaluator()
    dataset.prepare_data()

    # NOTE(review): test-time neighbour fan-out (160 per hop) and batch
    # size 16 are hard-coded here instead of read from config — confirm
    # this is intentional.
    test_iter = DataGenerator(
        dataset=dataset,
        samples=[160] * len(config.samples),
        batch_size=16,
        num_workers=config.num_workers,
        data_type="test")

    logits_chunks = []
    for batch in test_iter.generator():
        graph_list, x, y = batch
        x = paddle.to_tensor(x, dtype='float32')
        y = paddle.to_tensor(y, dtype='int64')
        # Each entry carries a graph at index 0 and node indices at index 2;
        # convert both to tensors for the model.
        graph_list = [(item[0].tensor(), paddle.to_tensor(item[2]))
                      for item in graph_list]
        out = model(graph_list, x)
        logits_chunks.append(out.numpy())

    all_logits = np.concatenate(logits_chunks, axis=0)
    y_pred = all_logits.argmax(axis=-1)
    evaluator.save_test_submission({'y_pred': y_pred}, 'results')
def infer(config, do_eval=False):
    """Evaluate a saved model on the eval and/or test splits.

    Initialises the distributed environment when more than one worker is
    present, restores weights from ``config.output_path``, then runs
    ``evaluate`` on the validation split (when ``do_eval``) and always on
    the test split.
    """
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    # NOTE(review): sibling entry points construct the dataset as
    # MAG240M(config.data_dir, seed=123); here the whole config object is
    # passed — confirm MAG240M accepts both call forms.
    dataset = MAG240M(config)
    evaluator = MAG240MEvaluator()
    dataset.prepare_data()

    model = getattr(models, config.model.name).GNNModel(
        **dict(config.model.items()))
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)
    loss_func = F.cross_entropy

    _create_if_not_exist(config.output_path)
    load_model(config.output_path, model)

    if paddle.distributed.get_rank() == 0:
        # Scratch memmap sized (num paper nodes, num classes) for MAG240M.
        # NOTE(review): allocated here ("sudo_label" in the original —
        # presumably "pseudo_label") but never written in this function;
        # verify it is filled elsewhere before removing.
        pseudo_label = np.memmap(
            'model_result_temp',
            dtype=np.float32,
            mode='w+',
            shape=(121751666, 153))

    if do_eval:
        valid_iter = DataGenerator(
            dataset=dataset,
            samples=[200] * len(config.samples),
            batch_size=64,
            num_workers=config.num_workers,
            data_type="eval")
        r = evaluate(valid_iter, model, loss_func, config, evaluator, dataset)
        log.info("finish eval")

    test_iter = DataGenerator(
        dataset=dataset,
        samples=[200] * len(config.samples),
        batch_size=64,
        num_workers=config.num_workers,
        data_type="test")
    r = evaluate(test_iter, model, loss_func, config, evaluator, dataset)
    log.info("finish test")
def train(config, do_eval=False):
    """Train the GNN on MAG240M, or run a one-off evaluation.

    When ``do_eval`` is true on rank 0, only evaluates the restored model on
    the validation split. Otherwise runs the full training loop, logging per
    epoch to TensorBoard and checkpointing whenever validation accuracy
    reaches a new best.

    Args:
        config: experiment config (model, sampler, optimizer settings).
        do_eval: if True (rank 0 only), skip training and just evaluate.
    """
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    dataset = MAG240M(config.data_dir, seed=123)
    evaluator = MAG240MEvaluator()
    dataset.prepare_data()

    train_iter = DataGenerator(
        dataset=dataset,
        samples=config.samples,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        data_type="train")
    valid_iter = DataGenerator(
        dataset=dataset,
        samples=config.samples,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        data_type="eval")

    model = getattr(models, config.model.name).GNNModel(
        **dict(config.model.items()))
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)
    loss_func = F.cross_entropy

    opt, lr_scheduler = optim.get_optimizer(
        parameters=model.parameters(),
        learning_rate=config.lr,
        max_steps=config.max_steps,
        weight_decay=config.weight_decay,
        warmup_proportion=config.warmup_proportion,
        clip=config.clip,
        use_lr_decay=config.use_lr_decay)

    _create_if_not_exist(config.output_path)
    load_model(config.output_path, model)
    swriter = SummaryWriter(os.path.join(config.output_path, 'log'))

    if do_eval and paddle.distributed.get_rank() == 0:
        # Evaluation-only path: bigger fan-out (160 per hop), small batch.
        valid_iter = DataGenerator(
            dataset=dataset,
            samples=[160] * len(config.samples),
            batch_size=16,
            num_workers=config.num_workers,
            data_type="eval")
        r = evaluate(valid_iter, model, loss_func, config, evaluator, dataset)
        log.info(dict(r))
    else:
        best_valid_acc = -1
        for e_id in range(config.epochs):
            loss_temp = []
            for batch in tqdm.tqdm(train_iter.generator()):
                loss = train_step(model, loss_func, batch, dataset)
                # Hoist the device->host sync: the original called
                # loss.numpy()[0] twice per step (once to log, once to
                # accumulate).
                loss_val = loss.numpy()[0]
                log.info(loss_val)
                loss.backward()
                opt.step()
                opt.clear_gradients()
                loss_temp.append(loss_val)
            # NOTE(review): the optimizer is built with max_steps but the
            # scheduler is stepped once per EPOCH, not per step — confirm
            # this matches the warmup schedule's intent.
            if lr_scheduler is not None:
                lr_scheduler.step()

            loss = np.mean(loss_temp)
            log.info("Epoch %s Train Loss: %s" % (e_id, loss))
            swriter.add_scalar('loss', loss, e_id)

            if e_id >= config.eval_step and e_id % config.eval_per_steps == 0 and \
                    paddle.distributed.get_rank() == 0:
                r = evaluate(valid_iter, model, loss_func, config,
                             evaluator, dataset)
                log.info(dict(r))
                for key, value in r.items():
                    swriter.add_scalar('eval/' + key, value, e_id)
                # Checkpoint on ties as well as strict improvements,
                # matching the original max()-then-equality bookkeeping.
                if r['acc'] >= best_valid_acc:
                    best_valid_acc = r['acc']
                    save_model(config.output_path, model, e_id, opt,
                               lr_scheduler)
    swriter.close()