예제 #1
0
def get_generator(checkpoint, evalArgs):
    """Rebuild a ``TrajectoryGenerator`` from a saved checkpoint.

    Args:
        checkpoint: Mapping that provides ``'args'`` (the hyperparameters
            the generator is constructed with) and ``'g_state'`` (the
            state dict loaded into the generator).
        evalArgs: Object exposing ``use_gpu``; it is forwarded to the
            constructor and also selects ``.cuda()`` versus ``.cpu()``.

    Returns:
        The constructed generator with the checkpoint state loaded, after
        ``.train()`` has been called on it.
    """
    hparams = AttrDict(checkpoint['args'])

    # Constructor keywords that share their name with the stored argument.
    same_named = (
        'obs_len', 'pred_len', 'embedding_dim',
        'mlp_dim', 'num_layers',
        'noise_dim', 'noise_type', 'noise_mix_type',
        'pooling_type', 'pool_every_timestep',
        'dropout', 'bottleneck_dim',
        'neighborhood_size', 'grid_size', 'batch_norm',
    )
    settings = {key: getattr(hparams, key) for key in same_named}

    # The hidden sizes are stored with a ``_g`` suffix in the checkpoint.
    settings['encoder_h_dim'] = hparams.encoder_h_dim_g
    settings['decoder_h_dim'] = hparams.decoder_h_dim_g
    settings['use_gpu'] = evalArgs.use_gpu

    generator = TrajectoryGenerator(**settings)
    generator.load_state_dict(checkpoint['g_state'])

    # Pick the device placement call based on the evaluation flag.
    place = generator.cuda if evalArgs.use_gpu else generator.cpu
    place()

    # NOTE(review): the model is left in train mode, not eval mode --
    # presumably intentional for sampling; confirm with the caller.
    generator.train()
    return generator
예제 #2
0
 def get_generator(self, checkpoint):
     """Rebuild a CPU-resident ``TrajectoryGenerator`` from a checkpoint.

     Args:
         checkpoint: Mapping that provides ``'args'`` (the hyperparameters
             the generator is constructed with) and ``'g_state'`` (the
             state dict loaded into the generator).

     Returns:
         The constructed generator with the checkpoint state loaded, after
         ``.cpu()`` and ``.train()`` have been called on it.
     """
     # Runs before the sgan import below; presumably it makes the sgan
     # package importable -- verify against import_modules.
     self.import_modules()
     from sgan.models import TrajectoryGenerator

     hparams = attrdict.AttrDict(checkpoint['args'])

     # Constructor keywords that share their name with the stored argument.
     same_named = (
         'obs_len', 'pred_len', 'embedding_dim',
         'mlp_dim', 'num_layers',
         'noise_dim', 'noise_type', 'noise_mix_type',
         'pooling_type', 'pool_every_timestep',
         'dropout', 'bottleneck_dim',
         'neighborhood_size', 'grid_size', 'batch_norm',
     )
     settings = {key: getattr(hparams, key) for key in same_named}

     # The hidden sizes are stored with a ``_g`` suffix in the checkpoint.
     settings['encoder_h_dim'] = hparams.encoder_h_dim_g
     settings['decoder_h_dim'] = hparams.decoder_h_dim_g

     generator = TrajectoryGenerator(**settings)
     generator.load_state_dict(checkpoint['g_state'])

     # This variant always places the model on the CPU.
     generator.cpu()
     # NOTE(review): the model is left in train mode, not eval mode --
     # presumably intentional for sampling; confirm with the caller.
     generator.train()
     return generator