Пример #1
0
def main():
    """Script entry point: resolve config, build loaders and model, then run.

    If ``config.resume`` points at a previous run directory, the saved
    ``config.json`` from that directory replaces the fresh config (with the
    ``resume`` path re-injected so checkpoint restoration still works).
    Dispatches to ``train`` or ``test`` according to ``config.is_train``.
    """
    config = get_config()
    if config.resume:
        # Restore the exact configuration the checkpoint was trained with.
        # Fix: open the file in a context manager instead of leaking the
        # handle from ``json.load(open(...))``.
        with open(config.resume + '/config.json', 'r') as f:
            json_config = json.load(f)
        json_config['resume'] = config.resume
        config = edict(json_config)

    if config.is_cuda and not torch.cuda.is_available():
        raise Exception("No GPU found")
    device = get_torch_device(config.is_cuda)

    # Dump the full configuration so every run is reproducible from its log.
    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    if config.test_original_pointcloud:
        if not DatasetClass.IS_FULL_POINTCLOUD_EVAL:
            raise ValueError(
                'This dataset does not support full pointcloud evaluation.')

    if config.evaluate_original_pointcloud:
        if not config.return_transformation:
            raise ValueError(
                'Pointcloud evaluation requires config.return_transformation=true.'
            )

    # XOR: the two flags must be enabled (or disabled) together.
    if (config.return_transformation ^ config.evaluate_original_pointcloud):
        raise ValueError(
            'Rotation evaluation requires config.evaluate_original_pointcloud=true and '
            'config.return_transformation=true.')

    logging.info('===> Initializing dataloader')
    if config.is_train:
        # Training loader: augmented, shuffled, repeating, point-count limited.
        train_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            phase=config.train_phase,
            threads=config.threads,
            augment_data=True,
            shuffle=True,
            repeat=True,
            batch_size=config.batch_size,
            limit_numpoints=config.train_limit_numpoints)

        # Validation loader: no augmentation, single pass over the data.
        val_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            threads=config.val_threads,
            phase=config.val_phase,
            augment_data=False,
            shuffle=True,
            repeat=False,
            batch_size=config.val_batch_size,
            limit_numpoints=False)

        # Fall back to 3 input channels (RGB) when the dataset does not
        # declare its own channel count.
        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = train_data_loader.dataset.NUM_LABELS
    else:
        # Test loader: deterministic order, no augmentation, single pass.
        test_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            threads=config.threads,
            phase=config.test_phase,
            augment_data=False,
            shuffle=False,
            repeat=False,
            batch_size=config.test_batch_size,
            limit_numpoints=False)
        if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = test_data_loader.dataset.NUM_LABELS

    logging.info('===> Building model')
    NetClass = load_model(config.model)
    if config.wrapper_type == 'None':
        model = NetClass(num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            NetClass.__name__, count_parameters(model)))
    else:
        # Optionally wrap the network class (e.g. for composite models).
        wrapper = load_wrapper(config.wrapper_type)
        model = wrapper(NetClass, num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            wrapper.__name__ + NetClass.__name__, count_parameters(model)))

    logging.info(model)
    model = model.to(device)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()

    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        if config.weights_for_inner_model:
            # Checkpoint was saved for the inner (unwrapped) model.
            model.model.load_state_dict(state['state_dict'])
        else:
            if config.lenient_weight_loading:
                # Keep only checkpoint tensors whose shapes match the model.
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(state['state_dict'])

    if config.is_train:
        train(model, train_data_loader, val_data_loader, config)
    else:
        test(model, test_data_loader, config)
Пример #2
0
    net.to(device)
    agent = ptan.agent.ActorCriticAgent(net, device=device, apply_softmax=True)

    exp_src = ptan.experience.ExperienceSourceFirstLast(envs, agent, params.gamma,steps_count=params.steps)
    generator = utils.BatchGenerator(exp_src, params)
    mean_monitor = utils.MeanRewardsMonitor(envs[0], net, 'A2C', params.solve_rewards)

    writer = SummaryWriter(logdir=mean_monitor.runs_dir,comment=params.frame_stack)

    optimizer = torch.optim.Adam(net.parameters(), lr=params.lr)
    
    # lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.75, patience=20000,
    #                                                        cooldown=20000, verbose=True, min_lr=params.min_lr)

    
    print('# Parameters: ', utils.count_parameters(net))
    print(net)
    print('*'*10, ' Start Training ',
          envs[0].game, ' {} '.format(device), '*'*10)
    

    frame = 0
    episode = 0
    with ptan.common.utils.RewardTracker(writer) as tracker:
        for batch in generator:
            reward = generator.pop_total_rewards()
            if reward:
                episode += 1
                mean = tracker.reward(
                    reward[0], generator.frame)
                if mean_monitor(mean):
    d      = 1
    alph   = args.alph
    nt     = args.nt
    nt_val = args.nt_val
    nTh    = args.nTh
    m      = args.m
    net = Phi(nTh=nTh, m=args.m, d=d, alph=alph)
    net = net.to(prec).to(device)

    optim = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay ) # lr=0.04 good

    logger.info(net)
    logger.info("-------------------------")
    logger.info("DIMENSION={:}  m={:}  nTh={:}   alpha={:}".format(d,m,nTh,alph))
    logger.info("nt={:}   nt_val={:}".format(nt,nt_val))
    logger.info("Number of trainable parameters: {}".format(count_parameters(net)))
    logger.info("-------------------------")
    logger.info(str(optim)) # optimizer info
    logger.info("data={:} batch_size={:} gpu={:}".format(args.data, args.batch_size, args.gpu))
    logger.info("maxIters={:} val_freq={:} viz_freq={:}".format(args.niters, args.val_freq, args.viz_freq))
    logger.info("saveLocation = {:}".format(args.save))
    logger.info("-------------------------\n")

    end = time.time()
    best_loss = float('inf')
    bestParams = None

    # x0 = toy_data.inf_train_gen(args.data, batch_size=args.batch_size)
    x0=torch.randn(args.batch_size,1)-3+6*((torch.rand(args.batch_size,1))>0.5).float()
    x0 = cvt(x0)
Пример #4
0
def infer_with_cfg(args, cfg):
    """Run relation-feature inference with pretrained vision/language models.

    Loads the language checkpoint (``args.lckpt``) and vision checkpoint
    (``args.vckpt``), computes predicate embeddings, opens the pre-extracted
    object h5 files under ``args.objects_dir``, and delegates the actual
    inference to ``infer_rel_only``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    out_dir = os.path.join(args.out_dir, "lsvrd_features", args.cfg_name)
    os.makedirs(out_dir, exist_ok=True)
    torch.backends.cudnn.benchmark = True

    print("model configs:")
    print(json.dumps(cfg, indent=2))
    print()

    print("run args:")
    for arg in vars(args):
        print("%10s: %s" % (arg, str(getattr(args, arg))))
    print()

    word_dict = SymbolDictionary.load_from_file(cfg.word_dict)
    pred_dict = SymbolDictionary.load_from_file(cfg.pred_dict)

    print("building language model")
    word_emb = WordEmbedding.build_from_config(cfg.language_model, word_dict).cuda()
    word_emb.init_embedding(cfg.language_model.word_emb_init)
    word_emb.freeze()
    language_model = LanguageModel.build_from_config(cfg.language_model)
    language_model = language_model.cuda()
    lckpt = torch.load(args.lckpt)
    language_model.load_state_dict(lckpt)
    # train(False) and eval() are equivalent; both kept from the original.
    language_model.train(False)
    language_model.eval()
    n_l_params = count_parameters(language_model)
    print("language model: {:,} parameters".format(n_l_params))

    print("obtaining predicate embeddings")
    pred_emb = get_sym_emb(word_emb, language_model, word_dict, pred_dict, cfg.language_model.tokens_length)

    print("building vision model")
    with torch.no_grad():
        vision_model = VisionModel.build_from_config(cfg.vision_model)
        vision_model = vision_model.cuda()
        vckpt = torch.load(args.vckpt)
        vision_model.load_state_dict(vckpt)
        vision_model.train(False)
        vision_model.eval()
    n_v_params = count_parameters(vision_model)
    print("vision model: {:,} parameters".format(n_v_params))

    # Load objects h5 metadata.
    # Fix: close the info file via a context manager instead of leaking the
    # handle from ``json.load(open(...))``.
    info_path = os.path.join(args.objects_dir, "info.json")
    with open(info_path) as f:
        info = json.load(f)
    h5_names = sorted(name for name in os.listdir(args.objects_dir)
                      if name.endswith('.h5'))
    h5_paths = [os.path.join(args.objects_dir, name) for name in h5_names]
    # Open read-write: inference writes relation features back into them.
    h5s = [h5py.File(path, "r+") for path in h5_paths]

    # Describe the pre-extracted feature field backing the H5 loader.
    print("creating h5 loader")
    fields = [{ "name": "features",
                "shape": [cfg.vision_model.feature_dim, cfg.vision_model.feature_height, cfg.vision_model.feature_width],
                "dtype": "float32",
                "preload": False }]
    # Deduplicate image ids in a single pass (the intermediate list copy of
    # the keys in the original was redundant).
    image_ids = list(set(info["indices"].keys()))
    loader = H5DataLoader.load_from_directory(cfg.vision_model.cache_dir, fields, image_ids)

    print("inference started")
    try:
        infer_rel_only(vision_model, pred_emb, loader, h5s, info, args, cfg)
    finally:
        # Fix: close every opened h5 file even if inference raises.
        for h5 in h5s:
            h5.close()
Пример #5
0
def train_with_config(args, cfg):
    """Train the vision/language models described by *cfg* with runtime
    options taken from *args*.

    Builds the symbol dictionaries, models and losses, wraps the train/val
    triples datasets in DataLoaders, and hands everything to ``train``.
    """

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    out_dir = os.path.join(args.out_dir, args.cfg_name)

    # Seed both CPU and GPU RNGs; let cudnn autotune convolution kernels.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = True

    print("model configs:")
    print(json.dumps(cfg, indent=2))
    print()

    print("run args:")
    for arg in vars(args):
        print("%10s: %s" % (arg, str(getattr(args, arg))))
    print()

    print("parsing dictionaries")
    # Load the four symbol dictionaries in a single pass.
    word_dict, ent_dict, attr_dict, pred_dict = [
        SymbolDictionary.load_from_file(path)
        for path in (cfg.word_dict, cfg.ent_dict, cfg.attr_dict, cfg.pred_dict)
    ]

    print("building model")
    word_emb = WordEmbedding.build_from_config(cfg.language_model,
                                               word_dict).cuda()
    word_emb.init_embedding(cfg.language_model.word_emb_init)
    word_emb.freeze()  # embeddings stay fixed during training
    vision_model = VisionModel.build_from_config(cfg.vision_model).cuda()
    language_model = LanguageModel.build_from_config(cfg.language_model).cuda()
    ent_loss = LossModel.build_from_config(cfg.ent_loss).cuda()
    rel_loss = LossModel.build_from_config(cfg.rel_loss).cuda()

    vision_param_count = count_parameters(vision_model)
    language_param_count = count_parameters(language_model)
    print("vision model: {:,} parameters".format(vision_param_count))
    print("language model: {:,} parameters".format(language_param_count))
    print()

    def _triples_loader(triples_path, mode, batch_size):
        # Build a GQATriplesDataset for *triples_path* and wrap it in a
        # shuffling DataLoader.
        dataset = GQATriplesDataset.create(cfg,
                                           word_dict,
                                           ent_dict,
                                           attr_dict,
                                           pred_dict,
                                           triples_path,
                                           mode=mode,
                                           preload=args.preload)
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=args.n_workers)

    print("loading train data...")
    train_loader = _triples_loader(cfg.train.triples_path, "train",
                                   cfg.train.batch_size)

    print("loading val data...")
    val_loader = _triples_loader(cfg.val.triples_path, "eval",
                                 cfg.val.batch_size)

    print("training started...")
    train(word_emb, vision_model, language_model, ent_loss, rel_loss,
          train_loader, val_loader, word_dict, ent_dict, pred_dict,
          args.n_epochs, args.val_freq, out_dir, cfg, args.grad_freq)
Пример #6
0
def main(config, init_distributed=False):
    """Single-process entry point (optionally one rank of a distributed job).

    Requires CUDA. Seeds the RNGs, optionally initializes the distributed
    process group, builds data loaders and the model, loads weights
    (restored to CPU first), wraps the model in DistributedDataParallel
    when ``config.distributed_world_size > 1``, then trains or tests
    according to ``config.is_train``.
    """

    if not torch.cuda.is_available():
        raise Exception('No GPUs FOUND.')

    # setup initial seed
    torch.cuda.set_device(config.device_id)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)

    device = config.device_id
    distributed = config.distributed_world_size > 1

    if init_distributed:
        config.distributed_rank = distributed_utils.distributed_init(config)

    setup_logging(config)

    # Log the full configuration so every run is reproducible from its log.
    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    if config.test_original_pointcloud:
        if not DatasetClass.IS_FULL_POINTCLOUD_EVAL:
            raise ValueError(
                'This dataset does not support full pointcloud evaluation.')

    if config.evaluate_original_pointcloud:
        if not config.return_transformation:
            raise ValueError(
                'Pointcloud evaluation requires config.return_transformation=true.'
            )

    # XOR: the two flags must be enabled (or disabled) together.
    if (config.return_transformation ^ config.evaluate_original_pointcloud):
        raise ValueError(
            'Rotation evaluation requires config.evaluate_original_pointcloud=true and '
            'config.return_transformation=true.')

    logging.info('===> Initializing dataloader')
    if config.is_train:
        # Training loader: augmented, shuffled, repeating, point-count limited.
        train_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            phase=config.train_phase,
            num_workers=config.num_workers,
            augment_data=True,
            shuffle=True,
            repeat=True,
            batch_size=config.batch_size,
            limit_numpoints=config.train_limit_numpoints)

        # Validation loader: no augmentation, single pass over the data.
        val_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            num_workers=config.num_val_workers,
            phase=config.val_phase,
            augment_data=False,
            shuffle=True,
            repeat=False,
            batch_size=config.val_batch_size,
            limit_numpoints=False)

        # Fall back to 3 input channels (RGB) when the dataset does not
        # declare its own channel count.
        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = train_data_loader.dataset.NUM_LABELS

    else:

        # Test loader: deterministic order, no augmentation, single pass.
        test_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            num_workers=config.num_workers,
            phase=config.test_phase,
            augment_data=False,
            shuffle=False,
            repeat=False,
            batch_size=config.test_batch_size,
            limit_numpoints=False)

        if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = test_data_loader.dataset.NUM_LABELS

    logging.info('===> Building model')
    NetClass = load_model(config.model)
    if config.wrapper_type == 'None':
        model = NetClass(num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            NetClass.__name__, count_parameters(model)))
    else:
        # Optionally wrap the network class (e.g. for composite models).
        wrapper = load_wrapper(config.wrapper_type)
        model = wrapper(NetClass, num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            wrapper.__name__ + NetClass.__name__, count_parameters(model)))

    logging.info(model)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()

    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        # state = torch.load(config.weights)
        # Restore checkpoint tensors to CPU first; the model is moved to the
        # GPU (and possibly wrapped in DDP) below.
        state = torch.load(
            config.weights,
            map_location=lambda s, l: default_restore_location(s, 'cpu'))

        if config.weights_for_inner_model:
            # Checkpoint was saved for the inner (unwrapped) model.
            model.model.load_state_dict(state['state_dict'])
        else:
            if config.lenient_weight_loading:
                # Keep only checkpoint tensors whose shapes match the model.
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(state['state_dict'])

    model = model.cuda()
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            module=model,
            device_ids=[device],
            output_device=device,
            broadcast_buffers=False,
            bucket_cap_mb=config.bucket_cap_mb)

    if config.is_train:
        train(model, train_data_loader, val_data_loader, config)
    else:
        test(model, test_data_loader, config)
Пример #7
0
    y         = cvt(torch.randn(nSamples,d))

    net.eval()
    with torch.no_grad():

        test_loss, test_cs = compute_loss(net, p_samples, args.nt)

        # sample_fn, density_fn = get_transforms(model)
        modelFx     = integrate(p_samples[:, 0:d], net, [0.0, 1.0], args.nt, stepper="rk4", alph=net.alph)
        modelFinvfx = integrate(modelFx[:, 0:d]  , net, [1.0, 0.0], args.nt, stepper="rk4", alph=net.alph)
        modelGen    = integrate(y[:, 0:d]        , net, [1.0, 0.0], args.nt, stepper="rk4", alph=net.alph)

        print("          {:9s}  {:9s}  {:11s}  {:9s}".format( "loss", "L (L_2)", "C (loss)", "R (HJB)"))
        print("[TEST]:   {:9.3e}  {:9.3e}  {:11.5e}  {:9.3e}".format(test_loss, test_cs[0], test_cs[1], test_cs[2]))

        print("Using ", utils.count_parameters(net), " parameters")
        invErr = np.mean(np.linalg.norm(p_samples.detach().cpu().numpy() - modelFinvfx[:,:d].detach().cpu().numpy(), ord=2, axis=1))
        # invErr = (torch.norm(p_samples-modelFinvfx[:,:d]) / p_samples.size(0)).item()
        print("inv error: ", invErr )

        modelGen = modelGen[:, 0:d].detach().cpu().numpy()
        p_samples = p_samples.detach().cpu().numpy()

        nBins = 80
        LOW = -4
        HIGH = 4
        extent = [[LOW, HIGH], [LOW, HIGH]]

        d1 = 0
        d2 = 1
Пример #8
0
    if args.val_freq == 0:
        # if val_freq set to 0, then validate after every epoch....assume mnist train 50000
        args.val_freq = math.ceil(50000 / args.batch_size)

    # ADAM optimizer
    optim = torch.optim.Adam(net.parameters(),
                             lr=args.lr,
                             weight_decay=args.weight_decay)

    logger.info(net)
    logger.info("-------------------------")
    logger.info("DIMENSION={:}  m={:}  nTh={:}   alpha={:}".format(
        d, m, nTh, net.alph))
    logger.info("nt={:}   nt_val={:}".format(nt, nt_val))
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(net)))
    logger.info("-------------------------")
    logger.info(str(optim))  # optimizer info
    logger.info("data={:} batch_size={:} gpu={:}".format(
        args.data, args.batch_size, args.gpu))
    logger.info("maxIters={:} val_freq={:} viz_freq={:}".format(
        args.niters, args.val_freq, args.viz_freq))
    logger.info("saveLocation = {:}".format(args.save))
    logger.info("-------------------------\n")

    begin = time.time()
    end = begin
    best_loss = float('inf')
    best_costs = [0.0] * 3
    best_params = None
Пример #9
0
                                              params.eps_final,
                                              params.eps_frames)

    exp_src = ptan.experience.ExperienceSourceFirstLast(
        envs, agent, gamma=params.gamma, steps_count=params.steps)
    buffer = ptan.experience.ExperienceReplayBuffer(exp_src,
                                                    params.buffer_size)

    mean_monitor = utils.MeanRewardsMonitor(env, net, ALGORITHM,
                                            params.solve_rewards)

    writer = SummaryWriter(logdir=mean_monitor.runs_dir,
                           comment=str(params.n_envs))
    writer.add_text(ALGORITHM + ' HParams', str(vars(params)))
    writer.add_text('Number of Trainable Parameters',
                    str(utils.count_parameters(net)))

    optimizer = torch.optim.Adam(
        net.parameters(),
        lr=params.lr,
    )

    print(net)
    print('*' * 10, ' Start Training ', env.game, ' {} '.format(device),
          '*' * 10)

    frame = 0
    episode = 0
    with ptan.common.utils.RewardTracker(writer) as tracker:
        while True:
            frame += params.n_envs
def main_worker(gpu, ngpus_per_node, config):
    """Per-GPU worker for (multiprocessing) distributed training.

    ``gpu`` is this worker's local GPU index on the node; ``ngpus_per_node``
    is used to derive the global rank and to split batch size and worker
    counts across GPUs. Builds loaders and the model, loads weights, wraps
    the model in DistributedDataParallel when ``config.distributed``, then
    trains or tests according to ``config.is_train``.
    """
    config.gpu = gpu

    #if config.is_cuda and not torch.cuda.is_available():
    #  raise Exception("No GPU found")
    if config.gpu is not None:
        print("Use GPU: {} for training".format(config.gpu))
    device = get_torch_device(config.is_cuda)

    if config.distributed:
        # When launched via env:// rendezvous, the rank comes from the
        # environment unless explicitly provided.
        if config.dist_url == "env://" and config.rank == -1:
            config.rank = int(os.environ["RANK"])
        if config.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            config.rank = config.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=config.dist_backend,
                                init_method=config.dist_url,
                                world_size=config.world_size,
                                rank=config.rank)

    # Log the full configuration so every run is reproducible from its log.
    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    if config.test_original_pointcloud:
        if not DatasetClass.IS_FULL_POINTCLOUD_EVAL:
            raise ValueError(
                'This dataset does not support full pointcloud evaluation.')

    if config.evaluate_original_pointcloud:
        if not config.return_transformation:
            raise ValueError(
                'Pointcloud evaluation requires config.return_transformation=true.'
            )

    # XOR: the two flags must be enabled (or disabled) together.
    if (config.return_transformation ^ config.evaluate_original_pointcloud):
        raise ValueError(
            'Rotation evaluation requires config.evaluate_original_pointcloud=true and '
            'config.return_transformation=true.')

    logging.info('===> Initializing dataloader')
    # NOTE: unlike the single-process variants, this initialize_data_loader
    # also returns a sampler (for DistributedSampler-style epoch shuffling).
    if config.is_train:
        train_data_loader, train_sampler = initialize_data_loader(
            DatasetClass,
            config,
            phase=config.train_phase,
            num_workers=config.num_workers,
            augment_data=True,
            shuffle=True,
            repeat=True,
            batch_size=config.batch_size,
            limit_numpoints=config.train_limit_numpoints)

        val_data_loader, val_sampler = initialize_data_loader(
            DatasetClass,
            config,
            num_workers=config.num_val_workers,
            phase=config.val_phase,
            augment_data=False,
            shuffle=True,
            repeat=False,
            batch_size=config.val_batch_size,
            limit_numpoints=False)
        # Fall back to 3 input channels (RGB) when the dataset does not
        # declare its own channel count.
        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = train_data_loader.dataset.NUM_LABELS
    else:
        test_data_loader, val_sampler = initialize_data_loader(
            DatasetClass,
            config,
            num_workers=config.num_workers,
            phase=config.test_phase,
            augment_data=False,
            shuffle=False,
            repeat=False,
            batch_size=config.test_batch_size,
            limit_numpoints=False)
        if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = test_data_loader.dataset.NUM_LABELS

    logging.info('===> Building model')
    NetClass = load_model(config.model)
    if config.wrapper_type == 'None':
        model = NetClass(num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            NetClass.__name__, count_parameters(model)))
    else:
        # Optionally wrap the network class (e.g. for composite models).
        wrapper = load_wrapper(config.wrapper_type)
        model = wrapper(NetClass, num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            wrapper.__name__ + NetClass.__name__, count_parameters(model)))

    logging.info(model)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        if config.weights_for_inner_model:
            # Checkpoint was saved for the inner (unwrapped) model.
            model.model.load_state_dict(state['state_dict'])
        else:
            if config.lenient_weight_loading:
                # Keep only checkpoint tensors whose shapes match the model.
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                # NOTE(review): this branch uses init_model_from_weights
                # while the sibling entry points call load_state_dict
                # directly — presumably intentional here; confirm.
                init_model_from_weights(model, state, freeze_bb=False)

    if config.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if config.gpu is not None:
            torch.cuda.set_device(config.gpu)
            model.cuda(config.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            config.batch_size = int(config.batch_size / ngpus_per_node)
            config.num_workers = int(
                (config.num_workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[config.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)

    if config.is_train:
        train(model,
              train_data_loader,
              val_data_loader,
              config,
              train_sampler=train_sampler,
              ngpus_per_node=ngpus_per_node)
    else:
        test(model, test_data_loader, config)
Пример #11
0
    # if specified precision supplied, override the loaded precision
    if args.prec != 'None':
        if args.prec == 'single':
            argPrec = torch.float32 
        if args.prec == 'double':
            argPrec = torch.float64 
        net = net.to(argPrec)

    cvt = lambda x: x.type(argPrec).to(device, non_blocking=True)

    logger.info(net)
    logger.info("----------TESTING---------------")
    logger.info("DIMENSION={:}  m={:}  nTh={:}   alpha={:}".format(d,m,nTh,net.alph))
    logger.info("nt_test={:}".format(nt_test))
    logger.info("Number of trainable parameters: {}".format(count_parameters(net)))
    logger.info("Number of testing examples: {}".format(nex))
    logger.info("-------------------------")
    logger.info("data={:} batch_size={:} gpu={:}".format(args.data, args.batch_size, args.gpu))
    logger.info("saveLocation = {:}".format(args.save))
    logger.info("-------------------------\n")

    end = time.time()

    log_msg = (
        '{:4s}        {:9s}  {:9s}  {:11s}  {:9s}'.format(
            'itr', 'loss', 'L (L_2)', 'C (loss)', 'R (HJB)'
        )
    )
    logger.info(log_msg)
Пример #12
0
def main():
    """Entry point: resolve config, build loaders and model, then train/test.

    ``config.test_config`` reloads a saved config for evaluation (forcing
    ``is_train=False`` and re-injecting the requested weights), while
    ``config.resume`` restores the configuration of a previous training run.
    """
    config = get_config()

    if config.test_config:
        # Evaluation mode: reload the training-time config but force
        # inference settings. Fix: open the file in a context manager
        # instead of leaking the handle from ``json.load(open(...))``.
        with open(config.test_config, 'r') as f:
            json_config = json.load(f)
        json_config['is_train'] = False
        json_config['weights'] = config.weights
        config = edict(json_config)
    elif config.resume:
        # Resume: restore the exact configuration of the interrupted run.
        # Same file-handle fix as above.
        with open(config.resume + '/config.json', 'r') as f:
            json_config = json.load(f)
        json_config['resume'] = config.resume
        config = edict(json_config)

    if config.is_cuda and not torch.cuda.is_available():
        raise Exception("No GPU found")
    device = get_torch_device(config.is_cuda)

    # torch.set_num_threads(config.threads)
    # torch.manual_seed(config.seed)
    # if config.is_cuda:
    #   torch.cuda.manual_seed(config.seed)

    # Dump the full configuration so every run is reproducible from its log.
    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    logging.info('===> Initializing dataloader')

    if config.is_train:
        setup_seed(2021)

        # NOTE(review): several loader arguments below are hard-coded debug
        # overrides of the config values (threads=4, shuffle=False,
        # repeat=False) — confirm before using for a real training run.
        train_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            phase=config.train_phase,
            # threads=config.threads,
            threads=4,
            augment_data=True,
            elastic_distortion=config.train_elastic_distortion,
            # elastic_distortion=False,
            # shuffle=True,
            shuffle=False,
            # repeat=True,
            repeat=False,
            batch_size=config.batch_size,
            # batch_size=8,
            limit_numpoints=config.train_limit_numpoints)

        # dat = iter(train_data_loader).__next__()
        # import ipdb; ipdb.set_trace()

        # NOTE(review): val batch size is hard-coded to 8, ignoring
        # config.val_batch_size — confirm this override is intended.
        val_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            # threads=0,
            threads=config.val_threads,
            phase=config.val_phase,
            augment_data=False,
            elastic_distortion=config.test_elastic_distortion,
            shuffle=False,
            repeat=False,
            # batch_size=config.val_batch_size,
            batch_size=8,
            limit_numpoints=False)

        # dat = iter(val_data_loader).__next__()
        # import ipdb; ipdb.set_trace()

        # Fall back to 3 input channels when the dataset does not declare
        # its own channel count.
        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3

        num_labels = train_data_loader.dataset.NUM_LABELS
    else:
        # Test loader: deterministic order, no augmentation, single pass.
        test_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            threads=config.threads,
            phase=config.test_phase,
            augment_data=False,
            elastic_distortion=config.test_elastic_distortion,
            shuffle=False,
            repeat=False,
            batch_size=config.test_batch_size,
            limit_numpoints=False)
        if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3

        num_labels = test_data_loader.dataset.NUM_LABELS

    logging.info('===> Building model')
    NetClass = load_model(config.model)
    model = NetClass(num_in_channel, num_labels, config)
    logging.info('===> Number of trainable parameters: {}: {}'.format(
        NetClass.__name__, count_parameters(model)))
    logging.info(model)

    # Set the number of threads
    # ME.initialize_nthreads(12, D=3)

    model = model.to(device)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        if config.weights_for_inner_model:
            # Checkpoint was saved for the inner (unwrapped) model.
            model.model.load_state_dict(state['state_dict'])
        else:
            if config.lenient_weight_loading:
                # Keep only checkpoint tensors whose shapes match the model.
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(state['state_dict'])

    if config.is_train:
        train(model, train_data_loader, val_data_loader, config)
    else:
        test(model, test_data_loader, config)
Пример #13
0
def main():
    """Entry point for training / evaluation / feature export.

    Workflow:
      1. Parse the config and set up logging (stdout + <log_dir>/model.log).
      2. Optionally reload a test config, or resume a previous run.
      3. Back up source files into the log dir (fresh training runs only).
      4. Build train/val dataloaders for the selected dataset.
      5. Build the model and optionally load pretrained weights.
      6. Dispatch to train / multiprocess train / point-based train,
         export (run test on train+val and save predictions), or test.

    Supported datasets: ScannetSparseVoxelizationDataset, ScannetDataset
    (point-based), SemanticKITTI, S3DIS, Nuscenes.
    """
    config = get_config()

    # --- Logging: mirror messages to stdout and to <log_dir>/model.log ---
    ch = logging.StreamHandler(sys.stdout)
    logging.getLogger().setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(
        os.path.join(config.log_dir, 'model.log'))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logging.basicConfig(format=os.uname()[1].split('.')[0] +
                        ' %(asctime)s %(message)s',
                        datefmt='%m/%d %H:%M:%S',
                        handlers=[ch, file_handler])

    if config.test_config:
        # When using test_config, reload and overwrite the config, keeping
        # the few command-line options that must survive the overwrite.
        val_bs = config.val_batch_size
        is_export = config.is_export

        with open(config.test_config, 'r') as f:
            json_config = json.load(f)
        json_config['is_train'] = False
        json_config['weights'] = config.weights
        json_config['multiprocess'] = False
        json_config['log_dir'] = config.log_dir
        json_config['val_threads'] = config.val_threads
        json_config['submit'] = config.submit
        config = edict(json_config)

        config.val_batch_size = val_bs
        config.is_export = is_export
        config.is_train = False
        sys.path.append(config.log_dir)
        # from local_models import load_model
    else:
        # Back up source files into the log dir for reproducibility.
        os.makedirs(os.path.join(config.log_dir, 'models'), exist_ok=True)
        for filename in os.listdir('./models'):
            # Copy .py sources only (the previous substring test
            # `".py" in filename` also matched compiled .pyc files).
            if filename.endswith('.py'):
                shutil.copy(os.path.join('./models', filename),
                            os.path.join(config.log_dir, 'models'))
            elif 'modules' in filename:
                # Copy the modules folder as well, replacing any stale copy.
                dst = os.path.join(config.log_dir, 'models/modules')
                if os.path.exists(dst):
                    shutil.rmtree(dst)
                shutil.copytree(os.path.join('./models', filename), dst)

        shutil.copy('./main.py', config.log_dir)
        shutil.copy('./config.py', config.log_dir)
        shutil.copy('./lib/train.py', config.log_dir)
        shutil.copy('./lib/test.py', config.log_dir)

    if config.resume == 'True':
        # Resume from log_dir: reload the saved config but keep the
        # command-line iteration budget and batch size.
        new_iter_size = config.max_iter
        new_bs = config.batch_size
        config.resume = config.log_dir
        with open(config.resume + '/config.json', 'r') as f:
            json_config = json.load(f)
        json_config['resume'] = config.resume
        config = edict(json_config)
        config.weights = os.path.join(
            config.log_dir, 'weights.pth')  # use the pre-trained weights
        logging.info('==== resuming from {}, Total {} ======'.format(
            config.max_iter, new_iter_size))
        config.max_iter = new_iter_size
        config.batch_size = new_bs
    else:
        config.resume = None

    if config.is_cuda and not torch.cuda.is_available():
        raise Exception("No GPU found")
    device = get_torch_device(config.is_cuda)

    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('      {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    logging.info('===> Initializing dataloader')

    setup_seed(2021)

    # ---- Setting up train, val dataloaders ----
    # `point_scannet` selects the point-based (block-sampled) ScanNet
    # pipeline, which has its own train/test entry points.
    point_scannet = False
    if config.is_train:

        if config.dataset == 'ScannetSparseVoxelizationDataset':
            point_scannet = False
            train_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                phase=config.train_phase,
                threads=config.threads,
                augment_data=True,
                elastic_distortion=config.train_elastic_distortion,
                shuffle=True,
                repeat=True,
                batch_size=config.batch_size,
                limit_numpoints=config.train_limit_numpoints)

            val_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                threads=config.val_threads,
                phase=config.val_phase,
                augment_data=False,
                elastic_distortion=config.test_elastic_distortion,
                shuffle=False,
                repeat=False,
                batch_size=config.val_batch_size,
                limit_numpoints=False)

        elif config.dataset == 'ScannetDataset':
            val_DatasetClass = load_dataset(
                'ScannetDatasetWholeScene_evaluation')
            point_scannet = True

            trainset = DatasetClass(
                root='/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                npoints=config.num_points,
                split='train',
                with_norm=False,
            )
            # Raw point input: no collate_fn needed (fixed-size blocks).
            train_data_loader = torch.utils.data.DataLoader(
                dataset=trainset,
                num_workers=config.threads,
                batch_size=config.batch_size,
                worker_init_fn=_init_fn,
                sampler=InfSampler(trainset, True))  # shuffle=True

            valset = val_DatasetClass(
                root='/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                scene_list_dir='/data/eva_share_users/zhaotianchen/scannet/raw/metadata',
                split='eval',
                block_points=config.num_points,
                with_norm=False,
                delta=1.0,
            )
            # Single worker: the whole-scene dataset loads a large pth file,
            # so multi-process loading is avoided.
            val_data_loader = torch.utils.data.DataLoader(
                dataset=valset,
                num_workers=0,
                batch_size=config.val_batch_size,
                worker_init_fn=_init_fn)

        elif config.dataset == "SemanticKITTI":
            point_scannet = False
            dataset = SemanticKITTI(root=config.semantic_kitti_path,
                                    num_points=None,
                                    voxel_size=config.voxel_size,
                                    sample_stride=config.sample_stride,
                                    submit=False)
            train_data_loader = torch.utils.data.DataLoader(
                dataset['train'],
                batch_size=config.batch_size,
                sampler=InfSampler(dataset['train'],
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(
                    config.train_limit_numpoints))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                dataset['test'],
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
        elif config.dataset == "S3DIS":
            trainset = S3DIS(
                config,
                train=True,
            )
            valset = S3DIS(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset,
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(
                    config.train_limit_numpoints))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
        elif config.dataset == 'Nuscenes':
            # Nuscenes features already include coordinates; disable the
            # extra xyz input channels.
            config.xyz_input = False
            trainset = Nuscenes(
                config,
                train=True,
            )
            valset = Nuscenes(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset,
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
        else:
            # BUGFIX: was `print('...').format(...)`, which raises
            # AttributeError on print()'s None return value.
            print('Dataset {} not supported'.format(config.dataset))
            raise NotImplementedError

        # Setting up num_in_channel and num_labels from the train loader.
        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3

        num_labels = train_data_loader.dataset.NUM_LABELS

    else:  # not config.is_train: evaluation / export

        val_DatasetClass = load_dataset('ScannetDatasetWholeScene_evaluation')

        if config.dataset == 'ScannetSparseVoxelizationDataset':

            if config.is_export:  # when exporting, export train results too
                train_data_loader = initialize_data_loader(
                    DatasetClass,
                    config,
                    phase=config.train_phase,
                    threads=config.threads,
                    augment_data=True,
                    elastic_distortion=config.train_elastic_distortion,  # DEBUG: not sure about this
                    shuffle=False,
                    repeat=False,
                    batch_size=config.batch_size,
                    limit_numpoints=config.train_limit_numpoints)

            val_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                threads=config.val_threads,
                phase=config.val_phase,
                augment_data=False,
                elastic_distortion=config.test_elastic_distortion,
                shuffle=False,
                repeat=False,
                batch_size=config.val_batch_size,
                limit_numpoints=False)

            if val_data_loader.dataset.NUM_IN_CHANNEL is not None:
                num_in_channel = val_data_loader.dataset.NUM_IN_CHANNEL
            else:
                num_in_channel = 3

            num_labels = val_data_loader.dataset.NUM_LABELS

        elif config.dataset == 'ScannetDataset':
            # When using scannet-point, evaluate on val instead of test.
            point_scannet = True
            valset = val_DatasetClass(
                root='/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                scene_list_dir='/data/eva_share_users/zhaotianchen/scannet/raw/metadata',
                split='eval',
                block_points=config.num_points,
                delta=1.0,
                with_norm=False,
            )
            # Single worker for the big pth file; no collate_fn for points.
            val_data_loader = torch.utils.data.DataLoader(
                dataset=valset,
                num_workers=0,
                batch_size=config.val_batch_size,
                worker_init_fn=_init_fn,
            )

            num_labels = val_data_loader.dataset.NUM_LABELS
            num_in_channel = 3

        elif config.dataset == "SemanticKITTI":
            dataset = SemanticKITTI(root=config.semantic_kitti_path,
                                    num_points=None,
                                    voxel_size=config.voxel_size,
                                    submit=config.submit)
            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                dataset['test'],
                batch_size=config.val_batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 4
            num_labels = 19

        elif config.dataset == 'S3DIS':
            config.xyz_input = False

            trainset = S3DIS(
                config,
                train=True,
            )
            valset = S3DIS(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset,
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(
                    config.train_limit_numpoints))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 9
            num_labels = 13
        elif config.dataset == 'Nuscenes':
            config.xyz_input = False
            trainset = Nuscenes(
                config,
                train=True,
            )
            valset = Nuscenes(
                config,
                train=False,  # BUGFIX: was `train - False` (a TypeError)
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset,
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 5
            num_labels = 16
        else:
            # BUGFIX: was `print('...').format(...)` (AttributeError on None).
            print('Dataset {} not supported'.format(config.dataset))
            raise NotImplementedError

    logging.info('===> Building model')

    if config.model == 'PointTransformer':
        config.pure_point = True

    NetClass = load_model(config.model)
    if config.pure_point:
        model = NetClass(config,
                         num_class=num_labels,
                         N=config.num_points,
                         normal_channel=num_in_channel)
    else:
        if config.model == 'MixedTransformer':
            model = NetClass(config,
                             num_class=num_labels,
                             N=config.num_points,
                             normal_channel=num_in_channel)
        elif config.model in ('MinkowskiVoxelTransformer',
                              'MinkowskiTransformerNet'):
            model = NetClass(config, num_in_channel, num_labels)
        else:
            # Res-style and all remaining sparse models share this signature
            # (the original duplicated identical `elif "Res"`/`else` arms).
            model = NetClass(num_in_channel, num_labels, config)

    logging.info('===> Number of trainable parameters: {}: {}M'.format(
        NetClass.__name__,
        count_parameters(model) / 1e6))
    # Introspect transformer hyper-params for logging, when present.
    if hasattr(model, "block1") and hasattr(model.block1[0], 'h'):
        h = model.block1[0].h
        vec_dim = model.block1[0].vec_dim
    else:
        h = None
        vec_dim = None
    # logging.info('===> Model Args:\n PLANES: {} \n LAYERS: {}\n HEADS: {}\n Vec-dim: {}\n'.format(model.PLANES, model.LAYERS, h, vec_dim))
    logging.info(model)

    model = model.to(device)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        # Drop keys containing '_map' (e.g. 'map_qk' buffers that raise size
        # mismatches on load) and strip any DataParallel 'module.' prefix.
        d = {
            k.replace('module.', ''): v
            for k, v in state['state_dict'].items() if '_map' not in k
        }

        if config.weights_for_inner_model:
            model.model.load_state_dict(d)
        else:
            if config.lenient_weight_loading:
                # Keep only weights whose shapes match the current model.
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(d, strict=True)

    # --- Dispatch ---
    if config.is_debug:
        check_data(model, train_data_loader, val_data_loader, config)
        return None
    elif config.is_train:
        if hasattr(config, 'distill') and config.distill:
            # Distillation only supports the whole-scene pipeline.
            # NOTE(review): execution falls through to regular training
            # afterwards — confirm whether a `return` is intended here.
            assert point_scannet is not True
            train_distill(model, train_data_loader, val_data_loader, config)
        if config.multiprocess:
            if point_scannet:
                raise NotImplementedError
            else:
                train_mp(NetClass, train_data_loader, val_data_loader, config)
        else:
            if point_scannet:
                train_point(model, train_data_loader, val_data_loader, config)
            else:
                train(model, train_data_loader, val_data_loader, config)
    elif config.is_export:
        if point_scannet:
            raise NotImplementedError
        else:  # only supports the whole-scene style for now
            test(model,
                 train_data_loader,
                 config,
                 save_pred=True,
                 split='train')
            test(model, val_data_loader, config, save_pred=True, split='val')
    else:
        assert not config.multiprocess
        # If testing for submission, make a submit directory under cwd.
        submit_dir = os.path.join(os.getcwd(), 'submit', 'sequences')
        if config.submit and not os.path.exists(submit_dir):
            os.makedirs(submit_dir)
            print("Made submission directory: " + submit_dir)
        if point_scannet:
            test_points(model, val_data_loader, config)
        else:
            test(model, val_data_loader, config, submit_dir=submit_dir)
Пример #14
0
def infer_with_cfg(args, cfg):
    """Extract LSVRD relation/object features for a whole dataset.

    Loads pretrained language and vision models, reads per-image detection
    boxes from the dataset's object HDF5 files ("gqa" or "vqa2"), then runs
    inference and streams the resulting feature tensors into a new HDF5
    dataset under ``args.out_dir``.

    Args:
        args: runtime arguments — gpu_id, out_dir, cfg_name, dataset,
            objects_dir, cache_dir, n_obj, lckpt (language checkpoint),
            vckpt (vision checkpoint).
        cfg: model configuration with ``language_model``, ``vision_model``,
            ``word_dict`` and ``pred_dict`` entries.
    """
    # Pin this process to the requested GPU before any CUDA context is made.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    out_dir = os.path.join(args.out_dir, "lsvrd_features", args.cfg_name)
    os.makedirs(out_dir, exist_ok=True)
    torch.backends.cudnn.benchmark = True

    print("model configs:")
    print(json.dumps(cfg, indent=2))
    print()

    print("run args:")
    for arg in vars(args):
        print("%10s: %s" % (arg, str(getattr(args, arg))))
    print()

    word_dict = SymbolDictionary.load_from_file(cfg.word_dict)
    pred_dict = SymbolDictionary.load_from_file(cfg.pred_dict)

    # --- Language model: frozen word embeddings + checkpointed encoder ---
    print("building language model")
    word_emb = WordEmbedding.build_from_config(cfg.language_model, word_dict).cuda()
    word_emb.init_embedding(cfg.language_model.word_emb_init)
    word_emb.freeze()
    language_model = LanguageModel.build_from_config(cfg.language_model)
    language_model = language_model.cuda()
    lckpt = torch.load(args.lckpt)
    language_model.load_state_dict(lckpt)
    language_model.train(False)
    language_model.eval()
    n_l_params = count_parameters(language_model)
    print("language model: {:,} parameters".format(n_l_params))

    # Pre-compute one embedding per predicate symbol for relation scoring.
    print("obtaining predicate embeddings")
    pred_emb = get_sym_emb(word_emb, language_model, word_dict, pred_dict, cfg.language_model.tokens_length)

    # --- Vision model: built under no_grad since it is inference-only ---
    print("building vision model")
    with torch.no_grad():
        vision_model = VisionModel.build_from_config(cfg.vision_model)
        vision_model = vision_model.cuda()
        vckpt = torch.load(args.vckpt)
        vision_model.load_state_dict(vckpt)
        vision_model.train(False)
        vision_model.eval()
    n_v_params = count_parameters(vision_model)
    print("vision model: {:,} parameters".format(n_v_params))

    # --- Collect per-image boxes, normalized to [0, 1] (y1, x1, y2, x2) ---
    print("getting boxes from objects_dir")

    if args.dataset == "gqa":
        # GQA ships 16 object HDF5 shards plus a JSON index mapping each
        # image id to (file shard, row index).
        info_path = os.path.join(args.objects_dir, "gqa_objects_info.json")
        indices = json.load(open(info_path))
        h5_paths = [ os.path.join(args.objects_dir, "gqa_objects_%d.h5" % i) for i in range(16) ]
        h5s = [ h5py.File(h5_path) for h5_path in h5_paths ]
        h5_boxes = [ h5["bboxes"] for h5 in h5s ]
        h5_features = [ h5["features"] for h5 in h5s]
        all_boxes = {}
        rearange_inds = np.argsort([ 1, 0, 3, 2 ]) # (x1, y1, x2, y2) -> (y1, x1, y2, x2)
        for image_id, meta in tqdm(indices.items()):
            file_idx = meta["file"]
            idx = meta["idx"]
            # Cap at args.n_obj objects per image.
            n_use = min(meta["objectsNum"], args.n_obj)
            width = float(meta["width"])
            height = float(meta["height"])
            boxes = h5_boxes[file_idx][idx, :n_use, :]
            boxes = boxes[:, rearange_inds] # (x1, y1, x2, y2) -> (y1, x1, y2, x2)
            # Normalize pixel coordinates by image size.
            boxes = boxes / np.array([height, width, height, width])
            all_boxes[image_id] = boxes
        n_entries = len(all_boxes)
        indices = { key: (val["file"], val["idx"]) for key, val in indices.items() }
    elif args.dataset == "vqa2":
        # VQA2 layout: one info.json with 'indices' (image -> block/idx)
        # and 'meta' (image -> height/width), plus 16 data shards.
        info_path = os.path.join(args.objects_dir, "info.json")
        info = json.load(open(info_path))
        indices = info['indices']
        meta = info['meta']
        h5_paths = [os.path.join(args.objects_dir, "data_%d.h5" % i) for i in range(16)]
        h5s = [h5py.File(h5_path) for h5_path in h5_paths]
        h5_boxes = [h5["boxes"] for h5 in h5s]
        h5_features = [h5["v_objs"] for h5 in h5s]
        all_boxes = {}
        rearange_inds = np.argsort([1, 0, 3, 2])  # (x1, y1, x2, y2) -> (y1, x1, y2, x2)
        for image_id, idx_pair in tqdm(indices.items()):
            block_idx = idx_pair["block"]
            idx = idx_pair["idx"]
            height, width = meta[image_id]['height'], meta[image_id]['width']
            boxes = h5_boxes[block_idx][idx]
            # Drop all-zero padding rows (no detection).
            keep = np.where(np.sum(boxes, axis=1) > 0)
            boxes = boxes[keep]
            boxes = boxes[:, rearange_inds]  # (x1, y1, x2, y2) -> (y1, x1, y2, x2)
            boxes = boxes / np.array([height, width, height, width])
            all_boxes[image_id] = boxes
        n_entries = len(all_boxes)
        indices = {key: (val["block"], val["idx"]) for key, val in indices.items()}
    else:
        raise NotImplementedError

    # --- Reader over pre-extracted image feature maps in cache_dir ---
    print("creating h5 loader for pre-extracted feature")
    fields = [{ "name": "features",
                "shape": [cfg.vision_model.feature_dim, cfg.vision_model.feature_height, cfg.vision_model.feature_width],
                "dtype": "float32",
                "preload": False }]
    image_ids = [image_id for image_id in all_boxes.keys()]
    image_ids = list(set(image_ids))
    loader = H5DataLoader.load_from_directory(args.cache_dir, fields, image_ids)

    # --- Writer for the extracted features (16 shards, n_entries rows) ---
    print("creating h5 writer")
    n_obj = args.n_obj
    emb_dim = cfg.vision_model.emb_dim
    fields = [
        { "name": "objects", "shape": [n_obj, 2048], "dtype": "float32"},
        { "name": "boxes", "shape": [n_obj, 4], "dtype": "float32" },
        # { "name": "entities", "shape": [ n_obj, emb_dim ], "dtype": "float32" },
        { "name": "relations", "shape": [ n_obj, n_obj, emb_dim ], "dtype": "float32" },
        { "name": "rel_mat", "shape": [ n_obj, n_obj ], "dtype": "int32" }
    ]
    writer = H5DataWriter(out_dir, "gqa_lsvrd_features", n_entries, 16, fields)

    print("inference started")
    infer(vision_model, all_boxes, pred_emb, loader, writer, h5_features, indices, args, cfg)
    writer.close()