Example #1
def main():
    args = parse_arguments()
    config = get_config(args.config_file, is_test=args.test)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    config.use_gpu = config.use_gpu and torch.cuda.is_available()

    # log info
    log_file = os.path.join(config.save_dir, "log_exp_{}.txt".format(config.run_id))
    logger = setup_logging(args.log_level, log_file)
    logger.info("Writing log file to {}".format(log_file))
    logger.info("Exp instance id = {}".format(config.run_id))
    logger.info("Exp comment = {}".format(args.comment))
    logger.info("Config =")
    print(">" * 80)
    pprint(config)
    print("<" * 80)

    # Run the experiment
    try:
        # Instantiate the runner class named in config.runner
        runner = eval(config.runner)(config)
        if not args.test:
            runner.train()
        else:
            runner.test()
    except Exception:
        logger.error(traceback.format_exc())

    sys.exit(0)
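
The three seeding calls at the top of this example are the same pattern that Example #2 below invokes through its seed_everything helper. A minimal sketch of such a helper, assuming only the standard random, numpy, and torch modules (the exact helper used by these projects may differ):

import random

import numpy as np
import torch


def seed_everything(seed):
    # Fix all common sources of randomness so repeated runs are reproducible
    random.seed(seed)                 # Python's built-in RNG
    np.random.seed(seed)              # NumPy RNG
    torch.manual_seed(seed)           # PyTorch CPU RNG
    torch.cuda.manual_seed_all(seed)  # all CUDA devices (deferred no-op without a GPU)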
Example #2
    print(model)
    print('number of parameters : {}'.format(
        sum([np.prod(x.shape) for x in model.parameters()])))
    if not args.test:
        trained_model = train_graph_generation(args, config, train_loader,
                                               val_loader, test_loader, model)
    else:
        test_model = test(args, config, model, dataset)


if __name__ == '__main__':
    """
    Process command-line arguments, then call main()
    """

    args = parse_arguments()
    config = get_config(args.config_file, is_test=args.test)
    p_name = utils.project_name(config.dataset.name)

    # Load the wandb API key from settings.json (if present) before it is used below
    if os.path.isfile("settings.json"):
        with open('settings.json') as f:
            data = json.load(f)
        args.wandb_apikey = data.get("wandbapikey")

    if args.wandb:
        os.environ['WANDB_API_KEY'] = args.wandb_apikey
        wandb.init(project='{}'.format(p_name),
                   name='{}-{}'.format(args.namestr, args.model_name))

    # Fix random seed
    seed_everything(args.seed)
Example #3
    )
    optimizer = optim.Adam(model.parameters(),
                           lr=config.train.lr_init,
                           betas=(0.9, 0.999),
                           eps=1e-8,
                           weight_decay=config.train.weight_decay)
    fit(model,
        optimizer,
        mc_sampler,
        train_dl,
        max_node_number=config.dataset.max_node_num,
        max_epoch=config.train.max_epoch,
        config=config,
        save_interval=config.train.save_interval,
        sample_interval=config.train.sample_interval,
        sigma_list=config.train.sigmas,
        sample_from_sigma_delta=0.0,
        test_dl=test_dl)

    sample_main(config, args)


if __name__ == "__main__":
    # torch.autograd.set_detect_anomaly(True)
    args = parse_arguments('train_ego_small.yaml')
    ori_config_dict = get_config(args)
    config_dict = edict(ori_config_dict.copy())
    process_config(config_dict)
    print(config_dict)
    train_main(config_dict, args)
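
The edict(...) wrapper above turns the loaded config dict into an object whose nested keys are readable as attributes, which is what makes accesses like config.train.lr_init work. A short sketch of that access pattern, assuming the easydict package's EasyDict class:

from easydict import EasyDict as edict

# Nested dicts become attribute-accessible objects
cfg = edict({'train': {'lr_init': 1e-3, 'weight_decay': 0.0}})
assert cfg.train.lr_init == 1e-3
assert cfg['train']['weight_decay'] == 0.0  # plain dict-style access still works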
Example #4
def main():
    args = parse_arguments()
    config = get_bo_config(args.config_file)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    config.use_gpu = config.use_gpu and torch.cuda.is_available()
    device = torch.device('cuda' if config.use_gpu else 'cpu')

    # log info
    log_file = os.path.join(config.save_dir,
                            "log_exp_{}.txt".format(config.run_id))
    logger = setup_logging(args.log_level, log_file)
    logger.info("Writing log file to {}".format(log_file))
    logger.info("Exp instance id = {}".format(config.run_id))
    logger.info("Exp comment = {}".format(args.comment))
    logger.info("Config =")
    print(">" * 80)
    pprint(config)
    print("<" * 80)

    # Load the pretrained model named in the config
    model = eval(config.model.name)(config.model)
    model_snapshot = torch.load(config.model.pretrained_model,
                                map_location=device)
    model.load_state_dict(model_snapshot["model"], strict=True)
    model.to(device)
    if config.use_gpu:
        model = nn.DataParallel(model, device_ids=config.gpus).cuda()
    # Run the experiment
    results_list = bq_loop(config.bq, model)
    if config.bq.is_GPY:
        if config.bq.is_sparseGP:
            result_file = (config.bq.name + '_sparseGP_init_p' +
                           str(config.bq.init_num_data) + '_inducing_p' +
                           str(config.bq.num_inducing_pts) + '_results.p')
        else:
            result_file = (config.bq.name + '_fullGP_init_p' +
                           str(config.bq.init_num_data) + '_results.p')
    else:
        if config.bq.is_ai:
            result_file = (config.bq.name + '_ai_init_p' +
                           str(config.bq.init_num_data) + '_results.p')
        else:
            result_file = (config.bq.name + '_opt_init_p' +
                           str(config.bq.init_num_data) + 'iter' +
                           str(config.bq.opt_iter) + 'lr' +
                           str(config.bq.opt_lr) + '_results.p')
    # Build the file name per GP variant, then dump once and close the handle properly
    with open(os.path.join(config.bq.save_dir, result_file), 'wb') as f:
        pickle.dump(results_list, f)
    sys.exit(0)
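
The results file written above can be read back with the standard pickle API. A brief usage sketch, where results_path is a hypothetical path to one of the dumped *_results.p files (not a file name taken from this code):

import pickle

results_path = 'bq_output/model_fullGP_init_p10_results.p'  # hypothetical example path
with open(results_path, 'rb') as f:
    results_list = pickle.load(f)
print('loaded {} result entries'.format(len(results_list)))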
Example #5
                    list(best_config['best_batch_mmd'].values())):
                best_config = {
                    'step_size_ratio': step_size_ratio,
                    'eps': eps,
                    'best_batch_mmd': valid_result_dict
                }
    logging.info(f'best_config {file} iter: {best_config}')
    _, gen_graph_list = run_sample(
        target_graph_list=test_graph_list,
        step_size_ratio=best_config['step_size_ratio'],
        eps=best_config['eps'],
        validation=False,
        eval_len=1024)
    print(best_config)
    sample_dir = os.path.join(config.save_dir, 'sample_data')
    os.makedirs(sample_dir, exist_ok=True)
    with open(
            os.path.join(
                sample_dir, file + f"_{best_config['step_size_ratio']}"
                f"_{best_config['eps']}_sample.pkl"), 'wb') as f:
        pickle.dump(obj=gen_graph_list,
                    file=f,
                    protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    args = parse_arguments('sample_com_small.yaml')
    config_dict = get_config(args)
    sample_main(config_dict, args)