def main(args):

    dataset = "data/shapes_bg_deterministic_train.h5"
    batch_size = 1024

    bisim_model = make_pairwise_encoder()
    bisim_model.load_state_dict(torch.load(args.model_path))
    bisim_model.to(args.device)

    dataset = utils.StateTransitionsDatasetStateIds(hdf5_file=dataset_path)

    train_loader = data.DataLoader(dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=4)

    for batch_idx, data_batch in enumerate(train_loader):

        data_batch = [tensor.to(args.device) for tensor in data_batch]
        obs, action, next_obs, state_ids, next_state_ids = data_batch

        batch_size = obs.size(0)
        # permute the batch so that each observation is paired with a random
        # other observation from the same batch (a negative pair)
        perm = np.random.permutation(batch_size)
        neg_obs = obs[perm]

        # the pairwise encoder consumes two observations stacked along the
        # channel dimension and predicts their bisimulation distance
        stack = torch.cat([obs, neg_obs], dim=1)
        dists = bisim_model(stack)[:, 0].detach()
        print(dists)

        if batch_idx > 10:
            break
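
`make_pairwise_encoder` is used throughout these examples but never shown. Below is a minimal sketch of what such a pairwise encoder could look like, assuming RGB observations stacked along the channel dimension as above; the layer sizes and the Softplus head are illustrative assumptions, not the repository's actual architecture:

import torch.nn as nn

def make_pairwise_encoder():
    # hypothetical architecture: a small CNN mapping a channel-stacked
    # image pair (2 x 3 = 6 channels) to one non-negative distance per pair
    return nn.Sequential(
        nn.Conv2d(6, 32, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(64, 1),   # output is indexed as [:, 0] by the callers
        nn.Softplus(),      # keep predicted distances non-negative
    )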
Example #2
def main(args):

    env = BlockPushing(render_type="shapes", background=BlockPushing.BACKGROUND_DETERMINISTIC)

    bisim_model = make_pairwise_encoder()
    bisim_model.load_state_dict(torch.load(args.model_path))
    bisim_model.to(args.device)

    for _ in range(10):

        env.reset()
        base_image = env.render()

        # render the same state under four alternative background colors
        color_images = []
        for i in range(4):
            env.background_index = i + 1
            color_images.append(env.render())

        # render ten random states with random backgrounds
        random_images = []
        for i in range(10):
            env.reset()
            env.background_index = np.random.randint(5)
            random_images.append(env.render())

        all_other_images = [base_image] + color_images + random_images

        if args.plot:
            plot_all_images(all_other_images)

        pairs_first = pair(base_image, all_other_images, base_first=True)
        pairs_second = pair(base_image, all_other_images, base_first=False)

        pairs_first = stack_to_pt(pairs_first, args.device)
        pairs_second = stack_to_pt(pairs_second, args.device)

        dists_first = bisim_model(pairs_first)[:, 0].detach().cpu().numpy()
        dists_second = bisim_model(pairs_second)[:, 0].detach().cpu().numpy()

        print("[base_image, base_image]: {:s}".format(str(dists_first[0])))

        print("[base_image, other_image]")
        print("different colors: {:s}".format(str(dists_first[1:5])))
        print("different states: {:s}".format(str(dists_first[5:])))

        print("[other_image, base_image]")
        print("different colors: {:s}".format(str(dists_second[1:5])))
        print("different states: {:s}".format(str(dists_second[5:])))
Example #3
def evaluate(args, args_eval, model_file):

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    ex = None
    if args_eval.sacred:
        from sacred import Experiment
        from sacred.observers import MongoObserver
        ex = Experiment(args_eval.sacred_name)
        ex.observers.append(
            MongoObserver(url=constants.MONGO_URI, db_name=constants.DB_NAME))
        ex.add_config({
            "batch_size": args.batch_size,
            "epochs": args.epochs,
            "learning_rate": args.learning_rate,
            "encoder": args.encoder,
            "num_objects": args.num_objects,
            "custom_neg": args.custom_neg,
            "in_ep_prob": args.in_ep_prob,
            "seed": args.seed,
            "dataset": args.dataset,
            "save_folder": args.save_folder,
            "eval_dataset": args_eval.dataset,
            "num_steps": args_eval.num_steps,
            "use_action_attention": args.use_action_attention
        })

    device = torch.device('cuda' if args.cuda else 'cpu')

    dataset = utils.PathDatasetStateIds(hdf5_file=args.dataset,
                                        path_length=args_eval.num_steps)
    eval_loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=4)

    # Get data sample
    obs = next(iter(eval_loader))[0]
    input_shape = obs[0][0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        split_mlp=args.split_mlp,
        same_ep_neg=args.same_ep_neg,
        only_same_ep_neg=args.only_same_ep_neg,
        immovable_bit=args.immovable_bit,
        split_gnn=args.split_gnn,
        no_loss_first_two=args.no_loss_first_two,
        bisim_model=make_pairwise_encoder() if args.bisim_model_path else None,
        encoder=args.encoder,
        use_action_attention=args.use_action_attention).to(device)

    model.load_state_dict(torch.load(model_file))
    model.eval()

    # topk = [1, 5, 10]
    topk = [1]
    hits_at = defaultdict(int)
    num_samples = 0
    rr_sum = 0

    pred_states = []
    next_states = []
    next_ids = []

    with torch.no_grad():

        for batch_idx, data_batch in enumerate(eval_loader):
            data_batch = [[t.to(device) for t in tensor]
                          for tensor in data_batch]
            observations, actions, state_ids = data_batch

            if observations[0].size(0) != args.batch_size:
                continue

            obs = observations[0]
            next_obs = observations[-1]
            next_id = state_ids[-1]

            state = model.obj_encoder(model.obj_extractor(obs))
            next_state = model.obj_encoder(model.obj_extractor(next_obs))

            # roll the prediction forward through the learned transition model
            pred_state = state
            for i in range(args_eval.num_steps):
                pred_state = model.forward_transition(pred_state, actions[i])

            pred_states.append(pred_state.cpu())
            next_states.append(next_state.cpu())
            next_ids.append(next_id.cpu().numpy())

        pred_state_cat = torch.cat(pred_states, dim=0)
        next_state_cat = torch.cat(next_states, dim=0)
        next_ids_cat = np.concatenate(next_ids, axis=0)

        full_size = pred_state_cat.size(0)

        # Flatten object/feature dimensions
        next_state_flat = next_state_cat.view(full_size, -1)
        pred_state_flat = pred_state_cat.view(full_size, -1)

        dist_matrix = utils.pairwise_distance_matrix(next_state_flat,
                                                     pred_state_flat)

        #num_digits = 1
        #dist_matrix = (dist_matrix * 10 ** num_digits).round() / (10 ** num_digits)
        #dist_matrix = dist_matrix.float()

        # prepend the diagonal (distance to the true next state) as column 0;
        # the rank of column 0 in each row is the rank of the correct answer
        dist_matrix_diag = torch.diag(dist_matrix).unsqueeze(-1)
        dist_matrix_augmented = torch.cat([dist_matrix_diag, dist_matrix],
                                          dim=1)

        # Workaround to get a stable sort in numpy: lexsort uses its last key
        # as the primary one, so each row is sorted by distance with the
        # original index as tie-breaker.
        dist_np = dist_matrix_augmented.numpy()
        indices = []
        for row in dist_np:
            keys = (np.arange(len(row)), row)
            indices.append(np.lexsort(keys))
        indices = np.stack(indices, axis=0)

        if args_eval.dedup:
            # forgive "mistakes" where the retrieved state has the same
            # underlying state id as the target (duplicate observations);
            # the -1 undoes the offset introduced by the prepended diagonal
            mask_mistakes = indices[:, 0] != 0
            closest_next_ids = next_ids_cat[indices[:, 0] - 1]

            if len(next_ids_cat.shape) == 2:
                equal_mask = np.all(closest_next_ids == next_ids_cat, axis=1)
            else:
                equal_mask = closest_next_ids == next_ids_cat

            indices[:, 0][np.logical_and(equal_mask, mask_mistakes)] = 0

        indices = torch.from_numpy(indices).long()

        #print('Processed {} batches of size {}'.format(
        #    batch_idx + 1, args.batch_size))

        labels = torch.zeros(indices.size(0),
                             device=indices.device,
                             dtype=torch.int64).unsqueeze(-1)

        num_samples += full_size
        #print('Size of current topk evaluation batch: {}'.format(
        #    full_size))

        for k in topk:
            match = indices[:, :k] == labels
            num_matches = match.sum()
            hits_at[k] += num_matches.item()

        match = indices == labels
        _, ranks = match.max(1)

        reciprocal_ranks = torch.reciprocal(ranks.double() + 1)
        rr_sum += reciprocal_ranks.sum().item()

        pred_states = []
        next_states = []
        next_ids = []

    hits = hits_at[topk[0]] / float(num_samples)
    mrr = rr_sum / float(num_samples)

    if ex is not None:
        # ugly hack
        @ex.main
        def sacred_main():
            ex.log_scalar("hits", hits)
            ex.log_scalar("mrr", mrr)

        ex.run()

    print('Hits @ {}: {}'.format(topk[0], hits))
    print('MRR: {}'.format(mrr))

    return hits, mrr
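
`utils.pairwise_distance_matrix` is not defined in this snippet. A minimal sketch consistent with its use above, where entry (i, j) holds the distance between target i and prediction j so that correct pairs sit on the diagonal; whether the repository uses Euclidean distance (as here) or its square is an assumption:

import torch

def pairwise_distance_matrix(x, y):
    # x: [N, D], y: [N, D] -> [N, N] with D[i, j] = ||x_i - y_j||_2
    return torch.cdist(x, y)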
Example #4
def evaluate(args, args_eval, model_file):

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    ex = None
    if args_eval.sacred:
        from sacred import Experiment
        from sacred.observers import MongoObserver
        ex = Experiment(args_eval.sacred_name)
        ex.observers.append(MongoObserver(url=constants.MONGO_URI, db_name=constants.DB_NAME))
        ex.add_config({
            "batch_size": args.batch_size,
            "epochs": args.epochs,
            "learning_rate": args.learning_rate,
            "encoder": args.encoder,
            "num_objects": args.num_objects,
            "custom_neg": args.custom_neg,
            "in_ep_prob": args.in_ep_prob,
            "seed": args.seed,
            "dataset": args.dataset,
            "save_folder": args.save_folder,
            "eval_dataset": args_eval.dataset,
            "num_steps": args_eval.num_steps,
            "use_action_attention": args.use_action_attention
        })

    device = torch.device('cuda' if args.cuda else 'cpu')

    dataset = utils.PathDatasetStateIds(
        hdf5_file=args.dataset, path_length=10)
    # note: this variant hardcodes the evaluation path length and batch size
    eval_batch_size = 100
    eval_loader = data.DataLoader(
        dataset, batch_size=eval_batch_size, shuffle=False, num_workers=4)

    # Get data sample
    obs = next(iter(eval_loader))[0]
    input_shape = obs[0][0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        split_mlp=args.split_mlp,
        same_ep_neg=args.same_ep_neg,
        only_same_ep_neg=args.only_same_ep_neg,
        immovable_bit=args.immovable_bit,
        split_gnn=args.split_gnn,
        no_loss_first_two=args.no_loss_first_two,
        bisim_model=make_pairwise_encoder() if args.bisim_model_path else None,
        encoder=args.encoder,
        use_action_attention=args.use_action_attention
    ).to(device)

    model.load_state_dict(torch.load(model_file))
    model.eval()

    hits_list = []

    with torch.no_grad():

        for batch_idx, data_batch in enumerate(eval_loader):

            data_batch = [[t.to(device) for t in tensor]
                          for tensor in data_batch]

            observations, actions, state_ids = data_batch

            if observations[0].size(0) != eval_batch_size:
                continue

            states = []
            for obs in observations:
                states.append(model.obj_encoder(model.obj_extractor(obs)))
            states = torch.stack(states, dim=0)
            state_ids = torch.stack(state_ids, dim=0)

            pred_state = states[0]
            if not args_eval.no_transition:
                for i in range(args_eval.num_steps):
                    pred_state = model.forward_transition(pred_state, actions[i])

            # pred_state: [100, |O|, D]
            # states: [10, 100, |O|, D]
            # pred_state_flat: [100, X]
            # states_flat: [10, 100, X]
            pred_state_flat = pred_state.reshape(pred_state.size(0), -1)
            states_flat = states.reshape(states.size(0), states.size(1), -1)

            # dist_matrix: [10, 100]
            dist_matrix = (states_flat - pred_state_flat[None]).pow(2).sum(2)
            # a hit: the closest state along the path is exactly num_steps ahead
            indices = torch.argmin(dist_matrix, dim=0)
            correct = indices == args_eval.num_steps

            # print(indices[0], args_eval.num_steps)
            # observations = torch.stack(observations, dim=0)
            # correct_obs = observations[args_eval.num_steps, 0]
            # pred_obs = observations[indices[0], 0]
            # import matplotlib
            # matplotlib.use("TkAgg")
            # import matplotlib.pyplot as plt
            # plt.subplot(1, 2, 1)
            # plt.imshow(correct_obs.cpu().numpy()[3:].transpose((1, 2, 0)))
            # plt.subplot(1, 2, 2)
            # plt.imshow(pred_obs.cpu().numpy()[3:].transpose((1, 2, 0)))
            # plt.show()

            # check for duplicates: count a miss as a hit when the retrieved
            # state shares its underlying state id with the target
            if args_eval.dedup:
                equal_mask = torch.all(
                    state_ids[indices, torch.arange(eval_batch_size, device=state_ids.device)]
                    == state_ids[args_eval.num_steps], dim=1)
                correct = correct | equal_mask

            # hits
            hits_list.append(correct.float().mean().item())

    hits = np.mean(hits_list)

    if ex is not None:
        # ugly hack
        @ex.main
        def sacred_main():
            ex.log_scalar("hits", hits)
            ex.log_scalar("mrr", 0.)

        ex.run()

    print('Hits @ 1: {}'.format(hits))

    return hits, 0.
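
The `dedup` flag in both evaluators addresses the same subtlety: distinct dataset indices can correspond to the same underlying environment state, so retrieving the "wrong" index is still a hit when the state ids match. A tiny standalone illustration with made-up data:

import numpy as np

next_ids = np.array([7, 3, 3, 9])   # underlying state ids of the 4 targets
retrieved = np.array([0, 2, 1, 0])  # nearest-neighbor index per query

correct = retrieved == np.arange(4)          # naive hits: [ True False False False]
equal_ids = next_ids[retrieved] == next_ids  # same state id: [ True True True False]
correct = correct | equal_ids                # queries 1 and 2 are duplicates -> hits
print(correct)                               # [ True  True  True False]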
Example #5
# Get data sample
obs = next(iter(train_loader))[0]
input_shape = obs[0].size()

# optionally load the bisim metric and turn it into a torch tensor on the selected device
bisim_metric = None
if args.bisim_metric_path is not None:
    bisim_metric = torch.tensor(np.load(args.bisim_metric_path),
                                dtype=torch.float32,
                                device=device)

# optionally load the bisim model
bisim_model = None
if args.bisim_model_path is not None:
    bisim_model = make_pairwise_encoder()
    bisim_model.load_state_dict(torch.load(args.bisim_model_path))

model = modules.ContrastiveSWM(
    embedding_dim=args.embedding_dim,
    hidden_dim=args.hidden_dim,
    action_dim=args.action_dim,
    input_dims=input_shape,
    num_objects=args.num_objects,
    sigma=args.sigma,
    hinge=args.hinge,
    ignore_action=args.ignore_action,
    copy_action=args.copy_action,
    split_mlp=args.split_mlp,
    same_ep_neg=args.same_ep_neg,
    only_same_ep_neg=args.only_same_ep_neg,
Example #6
model = modules.ContrastiveSWM(
    embedding_dim=args.embedding_dim,
    hidden_dim=args.hidden_dim,
    action_dim=args.action_dim,
    input_dims=input_shape,
    num_objects=args.num_objects,
    sigma=args.sigma,
    hinge=args.hinge,
    ignore_action=args.ignore_action,
    copy_action=args.copy_action,
    split_mlp=args.split_mlp,
    same_ep_neg=args.same_ep_neg,
    only_same_ep_neg=args.only_same_ep_neg,
    immovable_bit=args.immovable_bit,
    split_gnn=args.split_gnn,
    no_loss_first_two=args.no_loss_first_two,
    bisim_model=make_pairwise_encoder() if args.bisim_model_path else None,
    encoder=args.encoder,
    use_coord_grid=args.coord_grid,
    use_action_attention=args.use_action_attention).to(device)

model.load_state_dict(torch.load(model_file))
model.eval()

# topk = [1, 5, 10]
topk = [1]
hits_at = defaultdict(int)
num_samples = 0
rr_sum = 0

pred_states = []
next_states = []
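
Example #6 is cut off here, but the accumulators it sets up are the same ones Example #3 consumes. For reference, a compact sketch of how Hits@k and MRR fall out of the ranked indices, assuming `indices` is an [N, K] LongTensor of stably sorted columns in which value 0 marks the correct (prepended diagonal) answer, as in Example #3:

import torch

def hits_and_mrr(indices, topk=(1,)):
    # value 0 marks the correct column in each stably sorted row
    labels = torch.zeros(indices.size(0), 1, dtype=torch.int64)

    # Hits@k: how often the correct answer appears in the first k positions
    hits_at = {k: (indices[:, :k] == labels).any(dim=1).float().mean().item()
               for k in topk}

    # MRR: mean reciprocal of the (1-based) rank of the correct answer
    _, ranks = (indices == labels).max(dim=1)
    mrr = torch.reciprocal(ranks.double() + 1).mean().item()
    return hits_at, mrr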