def get_program_generator(args):
  # Variant with multi-GPU support via DataParallel.
  vocab = utils.load_vocab(args.vocab_json)
  if args.program_generator_start_from is not None:
    # Resume from a checkpoint, expanding the question vocabulary if it grew.
    pg, kwargs = utils.load_program_generator(
      args.program_generator_start_from, model_type=args.model_type)
    cur_vocab_size = pg.encoder_embed.weight.size(0)
    if cur_vocab_size != len(vocab['question_token_to_idx']):
      print('Expanding vocabulary of program generator')
      pg.expand_encoder_vocab(vocab['question_token_to_idx'])
      kwargs['encoder_vocab_size'] = len(vocab['question_token_to_idx'])
  else:
    # Build a new program generator from scratch.
    kwargs = {
      'encoder_vocab_size': len(vocab['question_token_to_idx']),
      'decoder_vocab_size': len(vocab['program_token_to_idx']),
      'wordvec_dim': args.rnn_wordvec_dim,
      'hidden_dim': args.rnn_hidden_dim,
      'rnn_num_layers': args.rnn_num_layers,
      'rnn_dropout': args.rnn_dropout,
    }
    if args.model_type == 'FiLM':
      kwargs['parameter_efficient'] = args.program_generator_parameter_efficient == 1
      kwargs['output_batchnorm'] = args.rnn_output_batchnorm == 1
      kwargs['bidirectional'] = args.bidirectional == 1
      kwargs['encoder_type'] = args.encoder_type
      kwargs['decoder_type'] = args.decoder_type
      kwargs['gamma_option'] = args.gamma_option
      kwargs['gamma_baseline'] = args.gamma_baseline
      kwargs['num_modules'] = args.num_modules
      kwargs['module_num_layers'] = args.module_num_layers
      kwargs['module_dim'] = args.module_dim
      kwargs['debug_every'] = args.debug_every
      pg = FiLMGen(**kwargs)
    else:
      pg = Seq2Seq(**kwargs)
  pg.cuda()
  pg.encoder_rnn.flatten_parameters()
  if args.gpu_devices:
    # Wrap in DataParallel when a list of GPU ids is given.
    gpu_id = parse_int_list(args.gpu_devices)
    pg = DataParallel(pg, device_ids=gpu_id)
    pg.module.encoder_rnn.flatten_parameters()
  pg.train()
  return pg, kwargs
def get_program_generator(args):
  # Variant that falls back to CPU when CUDA is unavailable.
  vocab = utils.load_vocab(args.vocab_json)
  if args.program_generator_start_from is not None:
    # Resume from a checkpoint, expanding the question vocabulary if it grew.
    pg, kwargs = utils.load_program_generator(
      args.program_generator_start_from, model_type=args.model_type)
    cur_vocab_size = pg.encoder_embed.weight.size(0)
    if cur_vocab_size != len(vocab['question_token_to_idx']):
      print('Expanding vocabulary of program generator')
      pg.expand_encoder_vocab(vocab['question_token_to_idx'])
      kwargs['encoder_vocab_size'] = len(vocab['question_token_to_idx'])
  else:
    # Build a new program generator from scratch.
    kwargs = {
      'encoder_vocab_size': len(vocab['question_token_to_idx']),
      'decoder_vocab_size': len(vocab['program_token_to_idx']),
      'wordvec_dim': args.rnn_wordvec_dim,
      'hidden_dim': args.rnn_hidden_dim,
      'rnn_num_layers': args.rnn_num_layers,
      'rnn_dropout': args.rnn_dropout,
    }
    if args.model_type.startswith('FiLM'):
      kwargs['parameter_efficient'] = args.program_generator_parameter_efficient == 1
      kwargs['output_batchnorm'] = args.rnn_output_batchnorm == 1
      kwargs['bidirectional'] = args.bidirectional == 1
      kwargs['encoder_type'] = args.encoder_type
      kwargs['decoder_type'] = args.decoder_type
      kwargs['gamma_option'] = args.gamma_option
      kwargs['gamma_baseline'] = args.gamma_baseline
      kwargs['num_modules'] = args.num_modules
      kwargs['module_num_layers'] = args.module_num_layers
      kwargs['module_dim'] = args.module_dim
      kwargs['debug_every'] = args.debug_every
      if args.model_type == 'FiLM+BoW':
        kwargs['encoder_type'] = 'bow'
      pg = FiLMGen(**kwargs)
    else:
      pg = Seq2Seq(**kwargs)
  if torch.cuda.is_available():
    pg.cuda()
  else:
    pg.cpu()
  pg.train()
  return pg, kwargs
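# Minimal usage sketch (an assumption-laden example, not part of the training
# script): it shows how get_program_generator can be driven by a plain
# argparse-style namespace. The attribute names mirror the flags read above;
# the concrete values and the 'data/vocab.json' path are illustrative only.
from argparse import Namespace

example_args = Namespace(
    vocab_json='data/vocab.json',        # hypothetical vocab file
    program_generator_start_from=None,   # build a fresh model instead of resuming
    model_type='FiLM',
    rnn_wordvec_dim=200,
    rnn_hidden_dim=1024,
    rnn_num_layers=1,
    rnn_dropout=0,
    program_generator_parameter_efficient=1,
    rnn_output_batchnorm=0,
    bidirectional=0,
    encoder_type='gru',
    decoder_type='linear',
    gamma_option='linear',
    gamma_baseline=1,
    num_modules=4,
    module_num_layers=1,
    module_dim=128,
    debug_every=float('inf'),
)
pg, pg_kwargs = get_program_generator(example_args)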