def main(args):
    """Set up CLEVR train/val data loaders and run ``train_loop``.

    Reads from ``args``: ``randomize_checkpoint_path``, ``checkpoint_path``,
    ``vocab_json``, ``use_local_copies``, ``cleanup_local_copies``,
    ``train_question_h5``, ``train_features_h5``, ``val_question_h5``,
    ``val_features_h5``, ``family_split_file``, ``batch_size``,
    ``shuffle_train_data``, ``num_train_samples``, ``num_val_samples`` and
    ``loader_num_workers``.

    Mutates ``args`` in place:
      * ``checkpoint_path`` gets a random numeric suffix when
        ``randomize_checkpoint_path == 1``;
      * the four HDF5 path attributes are redirected to copies under
        ``/tmp`` when ``use_local_copies == 1``.

    When both ``use_local_copies == 1`` and ``cleanup_local_copies == 1``,
    the ``/tmp`` copies are removed afterwards -- including when loading or
    training raises.
    """
    if args.randomize_checkpoint_path == 1:
        # Suffix the checkpoint name with a random number, presumably so that
        # separate runs do not overwrite one another's checkpoints.
        name, ext = os.path.splitext(args.checkpoint_path)
        num = random.randint(1, 1000000)
        args.checkpoint_path = '%s_%06d%s' % (name, num, ext)

    vocab = utils.load_vocab(args.vocab_json)

    # args attribute -> local path its file is copied to.
    local_copy_paths = {
        'train_question_h5': '/tmp/train_questions.h5',
        'train_features_h5': '/tmp/train_features.h5',
        'val_question_h5': '/tmp/val_questions.h5',
        'val_features_h5': '/tmp/val_features.h5',
    }
    use_local_copies = args.use_local_copies == 1

    try:
        if use_local_copies:
            # Copy every file first, then redirect args, so args is only
            # rewritten once all copies exist.
            for attr, local_path in local_copy_paths.items():
                shutil.copy(getattr(args, attr), local_path)
            for attr, local_path in local_copy_paths.items():
                setattr(args, attr, local_path)

        question_families = None
        if args.family_split_file is not None:
            with open(args.family_split_file, 'r') as f:
                question_families = json.load(f)

        train_loader_kwargs = {
            'question_h5': args.train_question_h5,
            'feature_h5': args.train_features_h5,
            'vocab': vocab,
            'batch_size': args.batch_size,
            'shuffle': args.shuffle_train_data == 1,
            'question_families': question_families,
            'max_samples': args.num_train_samples,
            'num_workers': args.loader_num_workers,
        }
        val_loader_kwargs = {
            'question_h5': args.val_question_h5,
            'feature_h5': args.val_features_h5,
            'vocab': vocab,
            'batch_size': args.batch_size,
            'question_families': question_families,
            'max_samples': args.num_val_samples,
            'num_workers': args.loader_num_workers,
        }

        with ClevrDataLoader(**train_loader_kwargs) as train_loader, \
                ClevrDataLoader(**val_loader_kwargs) as val_loader:
            train_loop(args, train_loader, val_loader)
    finally:
        # Runs on success AND failure so local copies never leak in /tmp.
        if use_local_copies and args.cleanup_local_copies == 1:
            for local_path in local_copy_paths.values():
                try:
                    os.remove(local_path)
                except FileNotFoundError:
                    # The copy may never have been made if an earlier step
                    # failed; nothing to clean up for this file.
                    pass
def main(args):
    """Load a trained model and run it on one example or on a batch of data.

    Model source (checked in this order):
      * ``args.baseline_model``: loaded with ``utils.load_baseline``; if
        ``args.vocab_json`` is given, the model's RNN vocabulary is expanded
        with that vocab's ``question_token_to_idx``.
      * ``args.program_generator`` together with ``args.execution_engine``:
        both are loaded, and the model is the tuple
        ``(program_generator, execution_engine)``; if ``args.vocab_json`` is
        given, the program generator's encoder vocabulary is expanded.
      * otherwise a usage message is printed and the function returns.

    If both ``args.question`` and ``args.image`` are set, a single example is
    run with ``run_single_example``. Otherwise a ``ClevrDataLoader`` is built
    from ``args.input_question_h5`` / ``args.input_features_h5`` and passed to
    ``run_batch``.
    """
    print()

    if args.baseline_model is not None:
        print('Loading baseline model from ', args.baseline_model)
        model, _ = utils.load_baseline(args.baseline_model)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            model.rnn.expand_vocab(new_vocab['question_token_to_idx'])
    elif args.program_generator is not None and args.execution_engine is not None:
        print('Loading program generator from ', args.program_generator)
        program_generator, _ = utils.load_program_generator(
            args.program_generator)
        print('Loading execution engine from ', args.execution_engine)
        execution_engine, _ = utils.load_execution_engine(
            args.execution_engine, verbose=False)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            program_generator.expand_encoder_vocab(
                new_vocab['question_token_to_idx'])
        model = (program_generator, execution_engine)
    else:
        print(
            'Must give either --baseline_model or --program_generator and --execution_engine'
        )
        return

    if args.question is not None and args.image is not None:
        run_single_example(args, model)
    else:
        vocab = load_vocab(args)
        loader_kwargs = {
            'question_h5': args.input_question_h5,
            'feature_h5': args.input_features_h5,
            'vocab': vocab,
            'batch_size': args.batch_size,
        }
        # Only cap the sample count when a positive limit was requested.
        if args.num_samples is not None and args.num_samples > 0:
            loader_kwargs['max_samples'] = args.num_samples
        if args.family_split_file is not None:
            with open(args.family_split_file, 'r') as f:
                loader_kwargs['question_families'] = json.load(f)
        with ClevrDataLoader(**loader_kwargs) as loader:
            run_batch(args, model, loader)
def main(args):
    """Build a model with ``init_model`` and evaluate it.

    Returns early (after printing a blank line) when ``init_model`` returns
    ``None``. With both ``args.question`` and ``args.image`` set, a single
    example is run via ``run_single_example``; otherwise the HDF5 inputs
    named by ``args.input_question_h5`` / ``args.input_features_h5`` are
    wrapped in a ``ClevrDataLoader`` and handed to ``run_batch``.
    """
    print()
    model = init_model(args)
    if model is None:
        return

    # Guard clause: the single question/image path needs no data loader.
    if args.question is not None and args.image is not None:
        run_single_example(args, model)
        return

    vocab = load_vocab(args)
    kwargs = dict(
        question_h5=args.input_question_h5,
        feature_h5=args.input_features_h5,
        vocab=vocab,
        batch_size=args.batch_size,
    )

    # A missing (None) or non-positive sample count means "no limit".
    if (args.num_samples or 0) > 0:
        kwargs['max_samples'] = args.num_samples

    split_path = args.family_split_file
    if split_path is not None:
        with open(split_path, 'r') as split_file:
            kwargs['question_families'] = json.load(split_file)

    with ClevrDataLoader(**kwargs) as data_loader:
        run_batch(args, model, data_loader)