# Imports assumed from the FiLM codebase's `vr` package; adjust to your project layout.
import torch
from torch.nn import DataParallel

import vr.utils as utils
from vr.models import FiLMGen, FiLMedNet, Seq2Seq


def get_program_generator(args):
    vocab = utils.load_vocab(args.vocab_json)
    if args.program_generator_start_from is not None:  # None when training from scratch
        pg, kwargs = utils.load_program_generator(
            args.program_generator_start_from, model_type=args.model_type)
        cur_vocab_size = pg.encoder_embed.weight.size(0)
        if cur_vocab_size != len(vocab['question_token_to_idx']):
            print('Expanding vocabulary of program generator')
            pg.expand_encoder_vocab(vocab['question_token_to_idx'])
            kwargs['encoder_vocab_size'] = len(vocab['question_token_to_idx'])
    else:
        kwargs = {
            'encoder_vocab_size': len(vocab['question_token_to_idx']),
            'decoder_vocab_size': len(vocab['program_token_to_idx']),
            'wordvec_dim': args.rnn_wordvec_dim,
            'hidden_dim': args.rnn_hidden_dim,
            'rnn_num_layers': args.rnn_num_layers,
            'rnn_dropout': args.rnn_dropout,  # default: 0
        }
        if args.model_type == 'FiLM':
            kwargs['parameter_efficient'] = args.program_generator_parameter_efficient == 1
            kwargs['output_batchnorm'] = args.rnn_output_batchnorm == 1
            kwargs['bidirectional'] = args.bidirectional == 1
            kwargs['encoder_type'] = args.encoder_type
            kwargs['decoder_type'] = args.decoder_type
            kwargs['gamma_option'] = args.gamma_option
            kwargs['gamma_baseline'] = args.gamma_baseline
            kwargs['num_modules'] = args.num_modules
            kwargs['module_num_layers'] = args.module_num_layers
            kwargs['module_dim'] = args.module_dim
            kwargs['debug_every'] = args.debug_every
            pg = FiLMGen(**kwargs)
        else:
            pg = Seq2Seq(**kwargs)
    pg.cuda()
    pg.encoder_rnn.flatten_parameters()
    if args.gpu_devices:
        # Multi-GPU variant: parse the comma-separated device-id string and wrap the
        # generator in DataParallel; the underlying model is then reached via pg.module.
        gpu_ids = parse_int_list(args.gpu_devices)
        pg = DataParallel(pg, device_ids=gpu_ids)
        pg.module.encoder_rnn.flatten_parameters()
    pg.train()
    return pg, kwargs
def get_program_generator(args):
    vocab = utils.load_vocab(args.vocab_json)
    if args.program_generator_start_from is not None:
        pg, kwargs = utils.load_program_generator(
            args.program_generator_start_from, model_type=args.model_type)
        cur_vocab_size = pg.encoder_embed.weight.size(0)
        if cur_vocab_size != len(vocab['question_token_to_idx']):
            print('Expanding vocabulary of program generator')
            pg.expand_encoder_vocab(vocab['question_token_to_idx'])
            kwargs['encoder_vocab_size'] = len(vocab['question_token_to_idx'])
    else:
        kwargs = {
            'encoder_vocab_size': len(vocab['question_token_to_idx']),
            'decoder_vocab_size': len(vocab['program_token_to_idx']),
            'wordvec_dim': args.rnn_wordvec_dim,
            'hidden_dim': args.rnn_hidden_dim,
            'rnn_num_layers': args.rnn_num_layers,
            'rnn_dropout': args.rnn_dropout,
        }
        if args.model_type.startswith('FiLM'):
            kwargs['parameter_efficient'] = args.program_generator_parameter_efficient == 1
            kwargs['output_batchnorm'] = args.rnn_output_batchnorm == 1
            kwargs['bidirectional'] = args.bidirectional == 1
            kwargs['encoder_type'] = args.encoder_type
            kwargs['decoder_type'] = args.decoder_type
            kwargs['gamma_option'] = args.gamma_option
            kwargs['gamma_baseline'] = args.gamma_baseline
            kwargs['num_modules'] = args.num_modules
            kwargs['module_num_layers'] = args.module_num_layers
            kwargs['module_dim'] = args.module_dim
            kwargs['debug_every'] = args.debug_every
            if args.model_type == 'FiLM+BoW':
                # FiLM+BoW swaps the RNN question encoder for a bag-of-words encoder.
                kwargs['encoder_type'] = 'bow'
            pg = FiLMGen(**kwargs)
        else:
            pg = Seq2Seq(**kwargs)
    # Fall back to CPU when no GPU is available.
    if torch.cuda.is_available():
        pg.cuda()
    else:
        pg.cpu()
    pg.train()
    return pg, kwargs
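# Minimal usage sketch (not part of the original script). get_program_generator
# only reads attributes off `args`, so any namespace carrying those fields works;
# the path and hyperparameter values below are hypothetical, for illustration only.
from types import SimpleNamespace

example_args = SimpleNamespace(
    vocab_json='data/vocab.json',            # hypothetical path
    program_generator_start_from=None,       # build a fresh model
    model_type='FiLM',
    rnn_wordvec_dim=200,
    rnn_hidden_dim=512,
    rnn_num_layers=1,
    rnn_dropout=0,
    program_generator_parameter_efficient=1,
    rnn_output_batchnorm=0,
    bidirectional=0,
    encoder_type='gru',
    decoder_type='linear',
    gamma_option='linear',
    gamma_baseline=1,
    num_modules=4,
    module_num_layers=1,
    module_dim=128,
    debug_every=float('inf'),
    gpu_devices='',                          # only read by the DataParallel variant above
)
program_generator, pg_kwargs = get_program_generator(example_args)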
film_gen = FiLMGen(
    encoder_vocab_size=len(vocab['question_token_to_idx']),
    wordvec_dim=args.rnn_wordvec_dim,
    hidden_dim=args.rnn_hidden_dim,
    rnn_num_layers=args.rnn_num_layers,
    rnn_dropout=0,
    output_batchnorm=False,
    bidirectional=False,
    encoder_type=args.encoder_type,
    decoder_type=args.decoder_type,
    gamma_option=args.gamma_option,
    gamma_baseline=1,
    num_modules=args.num_modules,
    module_num_layers=args.module_num_layers,
    module_dim=args.module_dim,
    parameter_efficient=True)
film_gen = film_gen.cuda()

filmed_net = FiLMedNet(
    vocab,
    feature_dim=(1024, 14, 14),
    stem_num_layers=args.module_stem_num_layers,
    stem_batchnorm=args.module_stem_batchnorm,
    stem_kernel_size=args.module_stem_kernel_size,
    stem_stride=1,
    stem_padding=None,
    num_modules=args.num_modules,
    module_num_layers=args.module_num_layers,
    module_dim=args.module_dim,
    module_residual=args.module_residual,
    module_batchnorm=args.module_batchnorm,
    module_batchnorm_affine=args.module_batchnorm_affine,
    module_dropout=args.module_dropout,
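# Sketch of how the two halves are typically wired together, assuming the forward
# signatures of the FiLM codebase (FiLMGen maps tokenized questions to FiLM
# parameters; FiLMedNet consumes image features plus those parameters) and that
# the FiLMedNet construction above is completed. Batch size, question length,
# and the feature shape are illustrative.
questions = torch.randint(0, len(vocab['question_token_to_idx']), (64, 46)).cuda()
feats = torch.randn(64, 1024, 14, 14).cuda()   # precomputed CNN features
answers = torch.randint(0, len(vocab['answer_token_to_idx']), (64,)).cuda()

film_params = film_gen(questions)              # per-module FiLM parameters (gammas, betas)
scores = filmed_net(feats, film_params)        # answer logits
loss = torch.nn.functional.cross_entropy(scores, answers)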