Beispiel #1
0
        rocks = batch['rocks'].cuda()
        trees = batch['trees'].cuda()

        gt_masks = torch.cat((skier, flags, rocks, trees), dim=1)

        pred_masks, latent = model(rgb)

        all_images = model.make_visuals(rgb, gt_masks, pred_masks)
    pass


if __name__ == '__main__':
    # Script entry point: parse the command-line options, then restore an
    # AutoEncoder from a checkpoint file found in --model_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir',
                        type=str,
                        default='data/20210413_182405')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--max_epochs', type=int, default=2)
    parser.add_argument(
        '--model_dir',
        type=str,
        default=
        '/home/aaron/workspace/vlr/vlr-project/checkpoints/20210423_184757')
    args = parser.parse_args()

    # sorted() consumes any iterable, so the glob result needs no list().
    ckpts = sorted(Path(args.model_dir).glob('*.ckpt'))
    if not ckpts:
        # Exit with a usage-style message rather than letting ckpts[-1]
        # raise an opaque IndexError when nothing matches the pattern.
        parser.error(f'no *.ckpt files found in {args.model_dir!r}')

    # The last entry in path sort order is loaded; that is the greatest
    # filename, which is not necessarily the most recently written file.
    model = AutoEncoder.load_from_checkpoint(str(ckpts[-1]))

    # The evaluation call below is currently disabled.
    #evaluate_on_dataset(model)
Beispiel #2
0
# Directory names under the autoencoder checkpoints folder, one per entry
# (the names look like run timestamps).
encoder_ids = [
    '20210505_165831',
    '20210505_165834',
    '20210505_165837',
    '20210505_165841',
    '20210505_165845',
    '20210505_165848',
    '20210505_165852',
    '20210505_165856',
]

# Index into encoder_ids of the checkpoint directory to load.
ID = 1

# Alternative checkpoint location, kept for reference only:
#   /home/aaronhua/vlr/dqn-pong/autoencoder/checkpoints/{encoder_ids[ID]}/epoch=19.ckpt
encoder_path = (
    "/home/aaron/workspace/vlr/dqn-pong/autoencoder/checkpoints/"
    f"{encoder_ids[ID]}/epoch=19.ckpt"
)

# Restore the autoencoder from the checkpoint and move it to `device`.
auto_encoder = AutoEncoder.load_from_checkpoint(encoder_path)
auto_encoder = auto_encoder.to(device)

# Expose both halves; only the encoder is switched to evaluation mode here.
encoder, decoder = auto_encoder.encoder, auto_encoder.decoder
encoder.eval()

# Record type with four fields: state, action, next_state, reward.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)

# Replay memory built from MEMORY_SIZE (presumably its capacity; confirm
# against the ReplayMemory definition).
memory = ReplayMemory(MEMORY_SIZE)

# Instantiate the policy network first and the target network second, each
# moved to `device`; the target then takes over the policy's state dict so
# both start from identical parameters.
policy_net, target_net = (DQNBase(n_actions=4).to(device) for _ in range(2))
target_net.load_state_dict(policy_net.state_dict())