def discretize_test():
    """Round-trip check: discretizing a pose then undiscretizing it must recover it."""
    net = MapNet(cnn=lambda i: i, embedding_size=3, map_size=7,
                 aggregator='avg', hardmax=True, improved_padding=True)

    # exhaustive grid of positions (x, y in [-3, 3]) and the 4 cardinal orientations
    grid = t.meshgrid(t.arange(7, dtype=t.float) - 3,
                      t.arange(7, dtype=t.float) - 3,
                      t.arange(4, dtype=t.float) * math.pi / 2)
    poses = Rigid2D(*grid).apply(t.Tensor.flatten)

    tol = 1e-4

    # discretize, then map the bins back to continuous coordinates
    (bx, by, bang, invalid) = net.discretize_pose(poses)
    (rx, ry, rang) = net.undiscretize_pose(bx, by, bang)
    assert (rx - poses.x).abs().max().item() < tol
    assert (ry - poses.y).abs().max().item() < tol
    assert (rang - poses.ang).abs().max().item() < tol
    assert invalid.sum().item() < tol

    # the same round trip must survive conversion to/from flat indexes
    shape = [net.orientations, net.map_size, net.map_size]
    flat_idx = sub2ind([bang, by, bx], shape, check_bounds=True)
    (iang, iy, ix) = ind2sub(flat_idx, shape)
    (rx, ry, rang) = net.undiscretize_pose(ix, iy, iang)
    assert (rx - poses.x).abs().max().item() < tol
    assert (ry - poses.y).abs().max().item() < tol
    assert (rang - poses.ang).abs().max().item() < tol
    assert invalid.sum().item() < tol
def visualization_test(vectorization=False):
    """Run MapNet on three fixed observations and display them for manual inspection.

    Args:
        vectorization: forwarded to MapNet's debug_vectorization flag.
    """
    net = MapNet(cnn=lambda x: x, embedding_size=3, map_size=5,
                 aggregator='avg', hardmax=True, improved_padding=True,
                 debug_vectorization=vectorization)

    # three hand-drawn local views
    views = (""".#. .*# ...""",
             """.*# .#. .#.""",
             """#.. *#. ...""")

    # shape = (batch=1, time, channels=1, height, width)
    frames = t.stack([parse_map(v) for v in views], dim=0)
    frames = frames.unsqueeze(dim=0)

    # run mapnet and show what it produced
    result = net(frames, debug_output=True)
    show_result(None, frames, result)
def discretize_center_test():
    """Pose (0, 0, 0) and its close neighborhood must land in the map's center bin."""
    net = MapNet(cnn=lambda i: i, embedding_size=3, map_size=7,
                 aggregator='avg', hardmax=True, improved_padding=True)
    center = (net.map_size - 1) // 2

    # poses strictly inside the center cell (cell boundaries excluded)
    pos_range = t.linspace(-0.5, 0.5, 20)[1:-1]
    ang_range = t.linspace(-math.pi / 4, math.pi / 4, 20)[1:-1]
    inside = Rigid2D(*t.meshgrid(pos_range, pos_range, ang_range)).apply(t.Tensor.flatten)

    (bx, by, bang, invalid) = net.discretize_pose(inside)
    assert (bx == center).all() and (by == center).all()
    assert (bang == 0).all() and not invalid.any()

    # poses just beyond the center cell must fall into different bins
    outside = Rigid2D(*t.meshgrid(t.tensor([-0.6, 0.6]),
                                  t.tensor([-0.6, 0.6]),
                                  t.tensor([-0.26 * math.pi, 0.26 * math.pi])))
    outside = outside.apply(t.Tensor.flatten)
    (ox, oy, oang, invalid) = net.discretize_pose(outside)
    assert (ox != center).all() and (oy != center).all()
    assert (oang != 0).all() and not invalid.any()

    # and the center bin must map back exactly to the origin pose
    (cx, cy, cang) = net.undiscretize_pose(t.tensor(center), t.tensor(center), t.tensor(0))
    assert cx == 0 and cy == 0 and cang == 0
def full_test(exhaustive=True, flip=False, vectorization=False):
    """Test MapNet end-to-end on a toy map: registered poses must match ground truth.

    Args:
        exhaustive: test every valid (x, y, orientation) pose instead of a short
            hand-picked sequence.
        flip: rotate the map 180 degrees to exercise a non-default starting direction.
        vectorization: forwarded to MapNet's debug_vectorization flag.
    """
    # unambiguous map with only 2 identifiable tiles (allows triangulation)
    map = parse_map("""... *.. ..#""")

    # enlarge map by 0-padding
    pad = 3
    map = t.nn.functional.pad(map, [pad] * 4, value=0)

    if flip:  # rotates the map 180 degrees
        map = map.flip(dims=[1, 2])

    if not exhaustive:
        # hand-crafted sequence of poses (x, y, angle)
        poses = [
            (2, 1, 1),  # near the center
            (0, 2, 2),  # bottom-left
            (2, 2, 0),  # bottom-right
            (2, 0, 1),  # top-right
        ]
    else:
        # exhaustive test of all valid poses
        poses = [(x, y, ang) for x in range(0, 3) for y in range(0, 3) for ang in range(4)]

    # start around center, to build initial map
    poses.insert(0, (2, 1, 1))

    if flip:
        # replace initial direction so it points the other way
        poses[0] = (poses[0][0], poses[0][1], 2)

    # account for map padding in pose coordinates
    poses = [(x + pad, y + pad, ang) for (x, y, ang) in poses]

    # get local observations
    obs = [extract_view(map, x, y, ang, view_range=2) for (x, y, ang) in poses]
    obs = t.stack(obs, dim=0)

    # batch of size 2, same samples
    obs = t.stack((obs, obs), dim=0)

    # run mapnet
    mapnet = MapNet(cnn=lambda i: i, embedding_size=3, map_size=map.shape[-1],
                    aggregator='avg', hardmax=True, improved_padding=True,
                    debug_vectorization=vectorization)
    out = mapnet(obs)

    # show results: ground-truth poses vs. the argmax of the predicted pose distribution
    print(t.tensor(poses)[1:, :])  # (x, y, angle)
    print((out['softmax_poses'] > 0.5).nonzero()[:, (4, 3, 2)])
    show_result(map, obs, out)

    visualize_poses(poses, obs, map_sz=map.shape[-1], title="Ground truth observations")
    # predicted pose per step, as (x, y, ang) lists (nonzero returns (ang, y, x), hence the flip)
    pred_poses = [out['softmax_poses'][0, step, :, :, :].nonzero()[0, :].flip(dims=(0,)).tolist()
                  for step in range(len(poses) - 1)]
    pred_poses.insert(0, [1 + pad, 1 + pad, 0])  # insert map-agnostic starting pose (centered facing right)
    visualize_poses(pred_poses, obs, map_sz=map.shape[-1],
                    title="Observations registered wrt predicted poses")

    # compare to ground truth
    for (step, (x, y, ang)) in enumerate(poses[1:]):
        # place the ground truth in the same coordinate-frame as the map, which is
        # created considering that the first frame is at the center looking right.
        # also move from/to discretized poses.
        gt_pose = Rigid2D(*mapnet.undiscretize_pose(t.tensor(x), t.tensor(y), t.tensor(ang)))
        initial_gt_pose = Rigid2D(*mapnet.undiscretize_pose(*[t.tensor(v) for v in poses[0]]))
        (x, y, ang, invalid) = mapnet.discretize_pose(gt_pose - initial_gt_pose)
        assert (x >= 2 and x <= map.shape[-1] - 2 and y >= 2 and y <= map.shape[-1] - 2
                and ang >= 0 and ang < 4), "GT poses going too much outside of bounds"

        # probability of each pose, shape = (orientations, height, width)
        p = out['softmax_poses'][0, step, :, :, :]
        assert p[ang, y, x].item() > 0.5  # peak at correct location
        assert p.sum().item() < 1.5  # no other peak elsewhere
        assert (p >= 0).all().item()  # all positive
def main():
    """Train MapNet on the maze dataset; all hyper-parameters come from the CLI."""
    # parse command line options
    parser = argparse.ArgumentParser()
    parser.add_argument("experiment", nargs='?', default="", help="Experiment name (sub-folder for this particular run). Default: test")
    parser.add_argument("-data-dir", default='data/maze/', help="Directory where maze data is located")
    parser.add_argument("-output-dir", default='data/mapnet', help="Output directory where results will be stored (point OverBoard to this location)")
    parser.add_argument("-device", default="cuda:0", help="Device, cpu or cuda")
    parser.add_argument("-data-loaders", default=8, type=int, help="Number of asynchronous worker threads for data loading")
    parser.add_argument("-epochs", default=40, type=int, help="Number of training epochs")
    parser.add_argument("-bs", default=100, type=int, help="Batch size")
    parser.add_argument("-lr", default=1e-3, type=float, help="Learning rate")
    parser.add_argument("--no-bn", dest="bn", action="store_false", help="Disable batch normalization")
    parser.add_argument("-seq-length", default=5, type=int, help="Sequence length for unrolled RNN (longer creates more long-term maps)")
    parser.add_argument("-map-size", default=15, type=int, help="Spatial size of map memory (always square)")
    parser.add_argument("-embedding", default=16, type=int, help="Size of map embedding (vector stored in each map cell)")
    parser.add_argument("--no-improved-padding", dest="improved_padding", action="store_false", help="Disable improved padding, which ensures softmax is only over valid locations and not edges")
    parser.add_argument("-lstm-forget-bias", default=1.0, type=float, help="Initial value for LSTM forget gate")
    parser.add_argument("-max-speed", default=0, type=int, help="If non-zero, only samples trajectories with this maximum spatial distance between steps")
    parser.add_argument("--spawn", action="store_true", help="Use spawn multiprocessing method, to work around problem with some debuggers (e.g. VSCode)")
    parser.set_defaults(bn=True, improved_padding=True)
    args = parser.parse_args()

    if not t.cuda.is_available():
        args.device = 'cpu'

    if args.spawn:
        # workaround for vscode debugging
        import torch.multiprocessing as multiprocessing
        multiprocessing.set_start_method('spawn', True)

    if not args.experiment:
        args.experiment = 'test'

    # complete directory with experiment name
    args.output_dir = (args.output_dir + '/' + args.experiment)
    if os.path.isdir(args.output_dir):
        input('Directory already exists. Press Enter to overwrite or Ctrl+C to cancel.')

    # repeatable random sequences hopefully
    random.seed(0)
    t.manual_seed(0)

    # initialize dataset; last 5000 sequences are held out for validation
    env_size = (21, 21)
    full_set = Mazes(args.data_dir + '/mazes-10-10-100000.txt', env_size,
                     seq_length=args.seq_length, max_speed=args.max_speed)
    (train_set, val_set) = t.utils.data.random_split(full_set, (len(full_set) - 5000, 5000))
    val_loader = DataLoader(val_set, batch_size=10 * args.bs, shuffle=False,
                            num_workers=args.data_loaders)

    # create base CNN and MapNet
    cnn = get_two_layers_cnn(args)
    mapnet = MapNet(cnn=cnn, embedding_size=args.embedding, map_size=args.map_size,
                    lstm_forget_bias=args.lstm_forget_bias,
                    improved_padding=args.improved_padding, orientations=4)

    # use GPU if needed
    device = t.device(args.device)
    mapnet.to(device)

    # create optimizer
    optimizer = t.optim.Adam(mapnet.parameters(), lr=args.lr)

    with Logger(args.output_dir, meta=args) as logger:
        for epoch in range(args.epochs):
            # refresh subset of mazes every epoch (10000 random samples out of the first 95000)
            train_sampler = BatchSampler(
                RandomSampler(SequentialSampler(range(95000)), num_samples=10000, replacement=True),
                batch_size=args.bs, drop_last=True)
            train_loader = DataLoader(train_set, batch_sampler=train_sampler,
                                      num_workers=args.data_loaders)

            # training phase
            mapnet.train()
            for inputs in train_loader:
                optimizer.zero_grad()
                loss = batch_forward(inputs, mapnet, 'train', device, args, logger)
                loss.backward()
                optimizer.step()
                logger.print(prefix='train', line_prefix=f"ep {epoch + 1} ")

            # validation phase
            mapnet.eval()
            with t.no_grad():
                for inputs in val_loader:
                    loss = batch_forward(inputs, mapnet, 'val', device, args, logger)
                    logger.print(prefix='val', line_prefix=f"ep {epoch + 1} ")

            logger.append()

            # save state, rotating the previous checkpoint so an interrupted save
            # never destroys the only good copy
            state = {'epoch': epoch, 'state_dict': mapnet.state_dict(),
                     'optimizer': optimizer.state_dict()}
            try:
                os.replace(args.output_dir + "/state.pt", args.output_dir + "/prev_state.pt")
            except OSError:
                pass  # no previous checkpoint yet (e.g. first epoch)
            t.save(state, args.output_dir + "/state.pt")