def build_vg_dsets(args):
    """Build the Visual Genome training and validation datasets.

    Args:
        args: Parsed options object. Reads ``vocab_json``, ``train_h5``,
            ``val_h5``, ``vg_image_dir``, ``image_size``,
            ``num_train_samples``, ``max_objects_per_image``,
            ``vg_use_orphaned_objects``, ``include_relationships`` and
            ``batch_size``.

    Returns:
        A ``(vocab, train_dset, val_dset)`` tuple, where ``vocab`` is the
        object decoded from ``args.vocab_json`` and both datasets are
        ``VgSceneGraphDataset`` instances.
    """
    # JSON text is UTF-8 by spec (RFC 8259); don't depend on the locale's
    # default encoding when reading the vocabulary.
    with open(args.vocab_json, 'r', encoding='utf-8') as f:
        vocab = json.load(f)

    dset_kwargs = {
        'vocab': vocab,
        'h5_path': args.train_h5,
        'image_dir': args.vg_image_dir,
        'image_size': args.image_size,
        'max_samples': args.num_train_samples,
        'max_objects': args.max_objects_per_image,
        'use_orphaned_objects': args.vg_use_orphaned_objects,
        'include_relationships': args.include_relationships,
    }
    train_dset = VgSceneGraphDataset(**dset_kwargs)
    iter_per_epoch = len(train_dset) // args.batch_size
    print('There are %d iterations per epoch' % iter_per_epoch)

    # Validation shares every setting except the HDF5 split path; the
    # max_samples cap is omitted so the dataset's own default applies.
    # Build a fresh dict rather than mutating the train kwargs in place.
    val_kwargs = {k: v for k, v in dset_kwargs.items() if k != 'max_samples'}
    val_kwargs['h5_path'] = args.val_h5
    val_dset = VgSceneGraphDataset(**val_kwargs)
    return vocab, train_dset, val_dset
def build_vg_dset(args, checkpoint):
    """Build a single Visual Genome dataset matching a saved checkpoint.

    The vocabulary and the object-related options come from the checkpoint,
    so the dataset is configured the same way the model was trained. Paths
    and sampling options come from the current ``args``.

    Args:
        args: Parsed options object. Reads ``vg_h5``, ``vg_image_dir``,
            ``image_size`` and ``num_samples``.
        checkpoint: Mapping with a ``'model_kwargs'`` entry holding
            ``'vocab'`` and an ``'args'`` entry holding
            ``'max_objects_per_image'`` and ``'vg_use_orphaned_objects'``.

    Returns:
        The constructed ``VgSceneGraphDataset``.
    """
    saved_vocab = checkpoint['model_kwargs']['vocab']
    # Keyword arguments are evaluated left to right, matching the lookup
    # order of the original dict literal.
    options = dict(
        vocab=saved_vocab,
        h5_path=args.vg_h5,
        image_dir=args.vg_image_dir,
        image_size=args.image_size,
        max_samples=args.num_samples,
        max_objects=checkpoint['args']['max_objects_per_image'],
        use_orphaned_objects=checkpoint['args']['vg_use_orphaned_objects'],
    )
    return VgSceneGraphDataset(**options)