def load_program_generator(path):
    """Restore a Seq2Seq program generator from a saved checkpoint.

    Args:
        path: Filesystem path to a checkpoint readable by ``load_cpu``;
            the checkpoint must contain the keys
            ``'program_generator_kwargs'`` and ``'program_generator_state'``.

    Returns:
        A ``(model, kwargs)`` tuple: the rebuilt ``Seq2Seq`` model with its
        weights loaded, and the constructor kwargs it was built from (returned
        so callers can inspect/extend the configuration).
    """
    ckpt = load_cpu(path)
    ctor_kwargs = ckpt['program_generator_kwargs']
    state_dict = ckpt['program_generator_state']
    # Rebuild the architecture from its saved hyperparameters, then load weights.
    generator = Seq2Seq(**ctor_kwargs)
    generator.load_state_dict(state_dict)
    return generator, ctor_kwargs
def get_program_generator(args):
    """Build (or resume) a Seq2Seq program generator, ready for training.

    If ``args.program_generator_start_from`` is set, the model is restored
    from that checkpoint and its encoder vocabulary is expanded in place when
    the current refexp vocabulary has grown since the checkpoint was written.
    Otherwise a fresh ``Seq2Seq`` is constructed from the ``args`` RNN
    hyperparameters.

    Args:
        args: Parsed argument namespace; reads ``vocab_json``,
            ``program_generator_start_from``, ``rnn_wordvec_dim``,
            ``rnn_hidden_dim``, ``rnn_num_layers``, ``rnn_dropout``.

    Returns:
        A ``(pg, kwargs)`` tuple: the model in ``train()`` mode (moved to GPU
        when CUDA is available) and its constructor kwargs.
    """
    import torch  # local import: file-level imports are outside this view

    vocab = utils.load_vocab(args.vocab_json)
    if args.program_generator_start_from is not None:
        pg, kwargs = utils.load_program_generator(
            args.program_generator_start_from)
        # Embedding rows == vocab size the checkpoint was trained with.
        cur_vocab_size = pg.encoder_embed.weight.size(0)
        if cur_vocab_size != len(vocab['refexp_token_to_idx']):
            print('Expanding vocabulary of program generator')
            pg.expand_encoder_vocab(vocab['refexp_token_to_idx'])
            # Keep kwargs in sync so re-saving the model records the new size.
            kwargs['encoder_vocab_size'] = len(vocab['refexp_token_to_idx'])
    else:
        kwargs = {
            'encoder_vocab_size': len(vocab['refexp_token_to_idx']),
            'decoder_vocab_size': len(vocab['program_token_to_idx']),
            'wordvec_dim': args.rnn_wordvec_dim,
            'hidden_dim': args.rnn_hidden_dim,
            'rnn_num_layers': args.rnn_num_layers,
            'rnn_dropout': args.rnn_dropout,
        }
        pg = Seq2Seq(**kwargs)
    # BUGFIX: the original called pg.cuda() unconditionally, which raises on
    # CPU-only machines. Guarding keeps GPU behavior identical while allowing
    # CPU-only use.
    if torch.cuda.is_available():
        pg.cuda()
    pg.train()
    return pg, kwargs