Example #1 (votes: 0)
File: utils.py — project: dschaehi/film
def load_program_generator(path, model_type="PG+EE"):
    """Restore a program generator from a checkpoint file.

    Args:
        path: Checkpoint path readable by ``load_cpu``.
        model_type: ``"FiLM"`` selects a ``FiLMGen``; any other value
            loads a ``Seq2Seq`` program generator.

    Returns:
        Tuple ``(model, kwargs)`` — the restored model and the constructor
        kwargs stored in the checkpoint.
    """
    checkpoint = load_cpu(path)
    gen_kwargs = checkpoint["program_generator_kwargs"]
    gen_state = checkpoint["program_generator_state"]
    if model_type != "FiLM":
        print("Loading PG from " + path)
        model = Seq2Seq(**gen_kwargs)
    else:
        print("Loading FiLMGen from " + path)
        # Filter the saved kwargs down to what FiLMGen's ctor accepts.
        gen_kwargs = get_updated_args(gen_kwargs, FiLMGen)
        model = FiLMGen(**gen_kwargs)
    model.load_state_dict(gen_state)
    return model, gen_kwargs
Example #2 (votes: 0)
def load_program_generator(path, model_type='PG+EE'):
    """Load a program generator (FiLMGen or Seq2Seq) plus its kwargs.

    The checkpoint at *path* must contain ``program_generator_kwargs``
    and ``program_generator_state`` entries.
    """
    checkpoint = load_cpu(path)
    saved_kwargs = checkpoint['program_generator_kwargs']
    saved_state = checkpoint['program_generator_state']
    use_film = (model_type == 'FiLM')
    if use_film:
        print('Loading FiLMGen from ' + path)
        # Keep only the kwargs that FiLMGen's constructor understands.
        saved_kwargs = get_updated_args(saved_kwargs, FiLMGen)
        model = FiLMGen(**saved_kwargs)
    else:
        print('Loading PG from ' + path)
        model = Seq2Seq(**saved_kwargs)
    model.load_state_dict(saved_state)
    return model, saved_kwargs
Example #3 (votes: 0)
def load_program_generator(path, model_type='PG+EE'):
    """Load a program generator from a checkpoint, re-keying its state dict.

    Every state-dict key has its first dotted component stripped (e.g. a
    ``module.`` prefix left behind by ``DataParallel``) before loading.

    Args:
        path: Checkpoint path readable by ``load_cpu``.
        model_type: ``'FiLM'`` builds a ``FiLMGen``; anything else a ``Seq2Seq``.

    Returns:
        Tuple ``(model, kwargs)``.
    """
    checkpoint = load_cpu(path)
    kwargs = checkpoint['program_generator_kwargs']
    state = checkpoint['program_generator_state']
    if model_type == 'FiLM':
        print('Loading FiLMGen from ' + path)
        kwargs = get_updated_args(kwargs, FiLMGen)
        model = FiLMGen(**kwargs)
    else:
        print('Loading PG from ' + path)
        model = Seq2Seq(**kwargs)
    # BUG FIX: dict.iteritems() is Python 2 only (AttributeError on Python 3);
    # use items(), as the rest of the codebase does.
    state_stemed = {'.'.join(k.split('.')[1:]): v for k, v in state.items()}
    model.load_state_dict(state_stemed)
    return model, kwargs
Example #4 (votes: 0)
def load_program_generator(path, model_type='PG+EE'):
    """Load a program generator, un-wrapping a ``DataParallel`` state dict.

    For the FiLM case the checkpoint's state-dict keys may carry a
    ``module.`` prefix (added by ``torch.nn.DataParallel``); it is removed
    before ``load_state_dict``.

    Args:
        path: Checkpoint path readable by ``load_cpu``.
        model_type: ``'FiLM'`` builds a ``FiLMGen``; anything else a ``Seq2Seq``.

    Returns:
        Tuple ``(model, kwargs)``.
    """
    checkpoint = load_cpu(path)
    kwargs = checkpoint['program_generator_kwargs']
    state = checkpoint['program_generator_state']
    if model_type == 'FiLM':
        print('Loading FiLMGen from ' + path)
        kwargs = get_updated_args(kwargs, FiLMGen)
        model = FiLMGen(**kwargs)
        new_state_dict = OrderedDict()
        prefix = 'module.'
        for k, v in state.items():
            # BUG FIX: the original sliced k[7:] unconditionally, mangling
            # keys of checkpoints saved WITHOUT the DataParallel wrapper.
            # Strip the prefix only when it is actually present.
            new_state_dict[k[len(prefix):] if k.startswith(prefix) else k] = v
        state = new_state_dict
    else:
        print('Loading PG from ' + path)
        model = Seq2Seq(**kwargs)
    model.load_state_dict(state)
    return model, kwargs
Example #5 (votes: 0)
def load_program_generator(path):
    """Load a program generator whose type is recorded in the checkpoint.

    ``checkpoint['args']['model_type']`` selects the architecture:
    FiLM/MAC/RelNet -> ``FiLMGen``; PG+EE -> ``Seq2SeqAtt`` or ``Seq2Seq``
    depending on the saved ``rnn_attention`` flag; otherwise no model.

    Args:
        path: Checkpoint path readable by ``load_cpu``.

    Returns:
        Tuple ``(model, kwargs)``; ``model`` is ``None`` for unknown types.
    """
    checkpoint = load_cpu(path)
    model_type = checkpoint['args']['model_type']
    kwargs = checkpoint['program_generator_kwargs']
    state = checkpoint['program_generator_state']
    if model_type in ['FiLM', 'MAC', 'RelNet']:
        kwargs = get_updated_args(kwargs, FiLMGen)
        model = FiLMGen(**kwargs)
    elif model_type == 'PG+EE':
        # BUG FIX: kwargs is a mapping (it is unpacked with **kwargs below),
        # so attribute access `kwargs.rnn_attention` raised AttributeError.
        if kwargs.get('rnn_attention'):
            model = Seq2SeqAtt(**kwargs)
        else:
            model = Seq2Seq(**kwargs)
    else:
        model = None
    if model is not None:
        model.load_state_dict(state)
    return model, kwargs
Example #6 (votes: 0)
def get_program_generator(args):
    """Build (or resume) a program generator and move it to the GPU(s).

    Either resumes from ``args.program_generator_start_from`` (expanding the
    encoder vocabulary if the current vocab grew) or constructs a fresh
    ``FiLMGen`` / ``Seq2Seq`` from ``args``. Wraps the model in
    ``DataParallel`` when ``args.gpu_devices`` is set.

    Args:
        args: Parsed command-line namespace.

    Returns:
        Tuple ``(pg, kwargs)`` — the (possibly DataParallel-wrapped) model
        in train mode and its constructor kwargs.
    """
    vocab = utils.load_vocab(args.vocab_json)
    if args.program_generator_start_from is not None:
        pg, kwargs = utils.load_program_generator(
            args.program_generator_start_from, model_type=args.model_type)
        cur_vocab_size = pg.encoder_embed.weight.size(0)
        if cur_vocab_size != len(vocab['question_token_to_idx']):
            print('Expanding vocabulary of program generator')
            pg.expand_encoder_vocab(vocab['question_token_to_idx'])
            kwargs['encoder_vocab_size'] = len(vocab['question_token_to_idx'])
    else:
        kwargs = {
            'encoder_vocab_size': len(vocab['question_token_to_idx']),
            'decoder_vocab_size': len(vocab['program_token_to_idx']),
            'wordvec_dim': args.rnn_wordvec_dim,
            'hidden_dim': args.rnn_hidden_dim,
            'rnn_num_layers': args.rnn_num_layers,
            'rnn_dropout': args.rnn_dropout,
        }
        if args.model_type == 'FiLM':
            kwargs[
                'parameter_efficient'] = args.program_generator_parameter_efficient == 1
            kwargs['output_batchnorm'] = args.rnn_output_batchnorm == 1
            kwargs['bidirectional'] = args.bidirectional == 1
            kwargs['encoder_type'] = args.encoder_type
            kwargs['decoder_type'] = args.decoder_type
            kwargs['gamma_option'] = args.gamma_option
            kwargs['gamma_baseline'] = args.gamma_baseline
            kwargs['num_modules'] = args.num_modules
            kwargs['module_num_layers'] = args.module_num_layers
            kwargs['module_dim'] = args.module_dim
            kwargs['debug_every'] = args.debug_every
            pg = FiLMGen(**kwargs)
        else:
            pg = Seq2Seq(**kwargs)
    pg.cuda()
    # Compact RNN weights into contiguous memory after the .cuda() move.
    pg.encoder_rnn.flatten_parameters()
    if args.gpu_devices:
        gpu_id = parse_int_list(args.gpu_devices)
        pg = DataParallel(pg, device_ids=gpu_id)
        # BUG FIX: this was done unconditionally after wrapping, but `pg`
        # only has a `.module` attribute once DataParallel wraps it; with
        # no gpu_devices the old code raised AttributeError.
        pg.module.encoder_rnn.flatten_parameters()
    pg.train()
    return pg, kwargs
Example #7 (votes: 0)
def get_program_generator(args):
  """Construct or resume a program generator selected by ``args.model_type``.

  Resumes from ``args.program_generator_start_from`` when given (growing the
  encoder vocabulary if needed); otherwise builds a fresh ``FiLMGen`` (for
  model types starting with 'FiLM', with a bag-of-words encoder for
  'FiLM+BoW') or a ``Seq2Seq``. The model is placed on GPU when available,
  else CPU, and returned in train mode together with its kwargs.
  """
  vocab = utils.load_vocab(args.vocab_json)
  resume_path = args.program_generator_start_from
  if resume_path is not None:
    pg, kwargs = utils.load_program_generator(
      resume_path, model_type=args.model_type)
    n_question_tokens = len(vocab['question_token_to_idx'])
    if pg.encoder_embed.weight.size(0) != n_question_tokens:
      print('Expanding vocabulary of program generator')
      pg.expand_encoder_vocab(vocab['question_token_to_idx'])
      kwargs['encoder_vocab_size'] = n_question_tokens
  else:
    kwargs = dict(
      encoder_vocab_size=len(vocab['question_token_to_idx']),
      decoder_vocab_size=len(vocab['program_token_to_idx']),
      wordvec_dim=args.rnn_wordvec_dim,
      hidden_dim=args.rnn_hidden_dim,
      rnn_num_layers=args.rnn_num_layers,
      rnn_dropout=args.rnn_dropout,
    )
    if not args.model_type.startswith('FiLM'):
      pg = Seq2Seq(**kwargs)
    else:
      kwargs.update(
        parameter_efficient=args.program_generator_parameter_efficient == 1,
        output_batchnorm=args.rnn_output_batchnorm == 1,
        bidirectional=args.bidirectional == 1,
        encoder_type=args.encoder_type,
        decoder_type=args.decoder_type,
        gamma_option=args.gamma_option,
        gamma_baseline=args.gamma_baseline,
        num_modules=args.num_modules,
        module_num_layers=args.module_num_layers,
        module_dim=args.module_dim,
        debug_every=args.debug_every,
      )
      # FiLM+BoW overrides the encoder with a bag-of-words variant.
      if args.model_type == 'FiLM+BoW':
        kwargs['encoder_type'] = 'bow'
      pg = FiLMGen(**kwargs)
  if torch.cuda.is_available():
    pg.cuda()
  else:
    pg.cpu()
  pg.train()
  return pg, kwargs