Example #1
0
def main():
    """Train and/or evaluate a set auto-encoder (or set classifier) from the CLI.

    Parses command-line arguments, builds an ``SAE`` model from the encoder and
    decoder classes named on the command line, creates Polygons or MNIST-set
    data loaders, optionally restores weights from a previous log, then runs
    ``args.epochs`` epochs of training and/or evaluation. After every epoch the
    tracker data, model weights, arguments and current git commit are saved to
    ``logs/<args.name>``.

    ``net``, ``test_loader`` and ``scatter`` are also bound as module globals
    (presumably for inspection after the run - nothing here reads them back).
    """
    global net
    global test_loader
    global scatter
    parser = argparse.ArgumentParser()
    # generic params
    parser.add_argument('--name', default=datetime.now().strftime('%Y-%m-%d_%H:%M:%S'), help='Name to store the log file as')
    parser.add_argument('--resume', help='Path to log file to resume from')

    parser.add_argument('--encoder', default='FSEncoder', help='Encoder model')
    parser.add_argument('--decoder', default='FSDecoder', help='Decoder model')
    parser.add_argument('--cardinality', type=int, default=20, help='Size of set')
    parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train with')
    parser.add_argument('--latent', type=int, default=8, help='Dimensionality of latent space')
    parser.add_argument('--dim', type=int, default=64, help='Dimensionality of hidden layers')
    parser.add_argument('--lr', type=float, default=1e-2, help='Learning rate of model')
    parser.add_argument('--batch-size', type=int, default=32, help='Batch size to train with')
    parser.add_argument('--num-workers', type=int, default=4, help='Number of threads for data loader')
    parser.add_argument('--samples', type=int, default=2**14, help='Dataset size')
    parser.add_argument('--skip', action='store_true', help='Skip permutation use in decoder')
    parser.add_argument('--mnist', action='store_true', help='Use MNIST dataset')
    parser.add_argument('--masked', action='store_true', help='Use masked version of MNIST dataset')
    parser.add_argument('--no-cuda', action='store_true', help='Run on CPU instead of GPU (not recommended)')
    parser.add_argument('--train-only', action='store_true', help='Only run training, no evaluation')
    parser.add_argument('--eval-only', action='store_true', help='Only run evaluation, no training')
    parser.add_argument('--multi-gpu', action='store_true', help='Use multiple GPUs')
    parser.add_argument('--show', action='store_true', help='Show generated samples')
    parser.add_argument('--classify', action='store_true', help='Classifier version')
    parser.add_argument('--freeze-encoder', action='store_true', help='Freeze weights in encoder')

    parser.add_argument('--loss', choices=['direct', 'hungarian', 'chamfer'], default='direct', help='Type of loss used')

    parser.add_argument('--shift', action='store_true', help='')
    parser.add_argument('--rotate', action='store_true', help='')
    parser.add_argument('--scale', action='store_true', help='')
    parser.add_argument('--variable', action='store_true', help='')
    parser.add_argument('--noise', type=float, default=0, help='Standard deviation of noise')
    args = parser.parse_args()

    # With --mnist the set size is fixed to 342, overriding --cardinality.
    if args.mnist:
        args.cardinality = 342

    # One device for the model *and* every batch, so --no-cuda keeps the data
    # on the CPU as well (batches used to be moved to the GPU unconditionally).
    device = torch.device('cpu' if args.no_cuda else 'cuda')

    model_args = {
        'set_size': args.cardinality,
        'dim': args.dim,
        'skip': args.skip,
        'relaxed': not args.classify,  # usually relaxed, not relaxed when classifying
    }
    net = SAE(
        encoder=globals()[args.encoder],
        decoder=globals()[args.decoder],
        latent_dim=args.latent,
        encoder_args=model_args,
        decoder_args=model_args,
        classify=args.classify,
        input_channels=3 if args.mnist and args.masked else 2,
    )
    net = net.to(device)

    if args.multi_gpu:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.Adam([p for p in net.parameters() if p.requires_grad], lr=args.lr)

    dataset_settings = {
        'cardinality': args.cardinality,
        'shift': args.shift,
        'rotate': args.rotate,
        'scale': args.scale,
        'variable': args.variable,
    }
    if not args.mnist:
        dataset_train = data.Polygons(size=args.samples, **dataset_settings)
        dataset_test = data.Polygons(size=2**14, **dataset_settings)
    elif not args.masked:
        dataset_train = data.MNISTSet(train=True)
        dataset_test = data.MNISTSet(train=False)
    else:
        dataset_train = data.MNISTSetMasked(train=True)
        dataset_test = data.MNISTSetMasked(train=False)

    train_loader = data.get_loader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers)
    test_loader = data.get_loader(dataset_test, batch_size=args.batch_size, num_workers=args.num_workers)

    tracker = track.Tracker(
        train_mse=track.ExpMean(),
        train_cha=track.ExpMean(),
        train_loss=track.ExpMean(),
        train_acc=track.ExpMean(),

        test_mse=track.Mean(),
        test_cha=track.Mean(),
        test_loss=track.Mean(),
        test_acc=track.Mean(),
    )

    if args.resume:
        log = torch.load(args.resume)
        weights = log['weights']
        # This script saves the unwrapped module's state_dict, so load into the
        # unwrapped module as well.
        n = net.module if args.multi_gpu else net
        # Non-strict in classify mode: missing/unexpected keys are tolerated there.
        n.load_state_dict(weights, strict=not args.classify)
        # NOTE: --freeze-encoder only takes effect together with --resume.
        if args.freeze_encoder:
            for p in n.encoder.parameters():
                p.requires_grad = False

    def outer(a, b=None):
        """Expand ``a`` and ``b`` into a pairwise grid for element-wise comparison.

        For ``a`` of shape (..., c, s_a) and ``b`` of shape (..., c, s_b), returns
        views of shape (..., c, s_a, s_b) such that ``a[..., i, j] - b[..., i, j]``
        compares element ``i`` of ``a`` with element ``j`` of ``b``. ``b``
        defaults to ``a``. The two sets may have different sizes.
        """
        if b is None:
            b = a
        pairwise = tuple(a.size()) + (b.size(-1),)
        a = a.unsqueeze(dim=-1).expand(*pairwise)
        b = b.unsqueeze(dim=-2).expand(*pairwise)
        return a, b

    def hungarian_loss(predictions, targets, pool):
        """Mean squared error under an optimal one-to-one matching of set elements.

        predictions, targets :: (n, c, s). The assignment for each sample is
        computed on CPU by ``per_sample_hungarian_loss`` (mapped over ``pool``),
        which must return ``(row_idx, col_idx)``. The loss is then gathered from
        the differentiable error tensor so gradients flow through it.
        """
        predictions, targets = outer(predictions, targets)
        # squared_error shape :: (n, s, s)
        squared_error = (predictions - targets).pow(2).mean(1)

        squared_error_np = squared_error.detach().cpu().numpy()
        indices = pool.map(per_sample_hungarian_loss, squared_error_np)
        losses = [sample[row_idx, col_idx].mean() for sample, (row_idx, col_idx) in zip(squared_error, indices)]
        return torch.stack(losses).mean()

    def chamfer_loss(predictions, targets):
        """Symmetric chamfer distance between predicted and target sets.

        predictions :: (n, c, s_p), targets :: (n, c, s_t). Each element's squared
        error (averaged over the c channels) to its nearest neighbour in the
        other set is averaged per direction, and the two directions are summed.
        """
        predictions, targets = outer(predictions, targets)
        # squared_error shape :: (n, s_p, s_t)
        squared_error = (predictions - targets).pow(2).mean(1)
        # Averaging each direction separately equals the old mean of the sum when
        # s_p == s_t, and also works when the set sizes differ.
        return squared_error.min(1)[0].mean() + squared_error.min(2)[0].mean()

    def run(net, loader, optimizer, train=False, epoch=0, pool=None):
        """Run one epoch over ``loader``, optimising ``net`` only when ``train`` is set.

        Per-batch metrics go into ``tracker`` under the ``train_``/``test_`` prefix
        and are shown on the progress bar. Autograd is disabled during evaluation.
        ``pool`` is only used by ``--loss hungarian``.
        """
        if train:
            net.train()
            prefix = 'train'
        else:
            net.eval()
            prefix = 'test'

        loader = tqdm(loader, ncols=0, desc='{1} E{0:02d}'.format(epoch, 'train' if train else 'test '))
        # Only build autograd graphs when we are going to backpropagate; the
        # previous grad mode is restored on exit.
        with torch.set_grad_enabled(train):
            for sample in loader:
                points, labels, n_points = [x.to(device) for x in sample]

                if args.decoder != 'FSDecoder' and points.size(2) < args.cardinality:
                    # pad to fixed size
                    padding = points.new_zeros(points.size(0), points.size(1), args.cardinality - points.size(2))
                    points = torch.cat([points, padding], dim=2)

                if args.noise > 0:
                    input_points = points + torch.randn_like(points) * args.noise
                else:
                    input_points = points
                pred = net(input_points, n_points)

                # -1 marks metrics that do not apply in the current mode.
                mse, cha, acc = torch.FloatTensor([-1, -1, -1])
                if not args.classify:
                    mse = (pred - points).pow(2).mean()
                    cha = chamfer_loss(pred, points)
                    if args.loss == 'direct':
                        loss = mse
                    elif args.loss == 'chamfer':
                        loss = cha
                    elif args.loss == 'hungarian':
                        loss = hungarian_loss(pred, points, pool)
                    else:
                        raise NotImplementedError
                else:
                    loss = F.cross_entropy(pred, labels)
                    acc = (pred.max(dim=1)[1] == labels).float().mean()

                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                tracked_mse = tracker.update('{}_mse'.format(prefix), mse.item())
                tracked_cha = tracker.update('{}_cha'.format(prefix), cha.item())
                tracked_loss = tracker.update('{}_loss'.format(prefix), loss.item())
                tracked_acc = tracker.update('{}_acc'.format(prefix), acc.item())

                fmt = '{:.5f}'.format
                loader.set_postfix(
                    mse=fmt(tracked_mse),
                    cha=fmt(tracked_cha),
                    loss=fmt(tracked_loss),
                    acc=fmt(tracked_acc),
                )

                if args.show and not train:
                    scatter(pred, n_points, marker='x', transpose=args.mnist)
                    # plt.gca() reuses the current axes; plt.axes() would add a new one.
                    plt.gca().set_aspect('equal', 'datalim')
                    plt.show()

    def scatter(tensor, n_points, transpose=False, *args, **kwargs):
        """Scatter-plot the first set of a batch, limited to ``n_points[0]`` elements.

        ``tensor[0]`` must have exactly two rows (x and y). With ``transpose``
        the coordinates are swapped and the new y is negated. Remaining
        arguments are forwarded to ``plt.scatter``.
        """
        x, y = tensor[0].detach().cpu().numpy()
        n = n_points[0].detach().cpu().numpy()
        if transpose:
            x, y = y, -x
        plt.scatter(x[:n], y[:n], *args, **kwargs)

    import subprocess
    try:
        git_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
    except (OSError, subprocess.CalledProcessError):
        # Not inside a git checkout (or git unavailable): still save results.
        git_hash = None

    os.makedirs('logs', exist_ok=True)
    for epoch in range(args.epochs):
        tracker.new_epoch()
        with mp.Pool(4) as pool:
            if not args.eval_only:
                run(net, train_loader, optimizer, train=True, epoch=epoch, pool=pool)
            if not args.train_only:
                run(net, test_loader, optimizer, train=False, epoch=epoch, pool=pool)

        results = {
            'name': args.name,
            'tracker': tracker.data,
            'weights': net.module.state_dict() if args.multi_gpu else net.state_dict(),
            'args': vars(args),
            'hash': git_hash,
        }
        torch.save(results, os.path.join('logs', args.name))
        if args.eval_only:
            break
def main():
    """Train and/or evaluate a DSPN (or baseline) set predictor from the CLI.

    Parses command-line arguments, builds the network with ``model.build_net``,
    creates MNIST-set or CLEVR data loaders, optionally restores weights, then
    runs ``args.epochs`` epochs of training and/or evaluation. Training metrics
    and (with --show) predicted sets are logged to TensorBoard under
    ``runs/<args.name>``. With --export-dir, ground truths and predictions are
    written out as text files. After every epoch the tracker data, weights,
    arguments and git commit are saved to ``logs/<args.name>``.

    ``net`` and ``test_loader`` are also bound as module globals.
    """
    global net
    global test_loader
    parser = argparse.ArgumentParser()
    # generic params
    parser.add_argument(
        "--name",
        default=datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
        help="Name to store the log file as",
    )
    parser.add_argument("--resume", help="Path to log file to resume from")

    parser.add_argument("--encoder", default="FSEncoder", help="Encoder")
    parser.add_argument("--decoder", default="DSPN", help="Decoder")
    parser.add_argument("--epochs",
                        type=int,
                        default=10,
                        help="Number of epochs to train with")
    parser.add_argument("--latent",
                        type=int,
                        default=32,
                        help="Dimensionality of latent space")
    parser.add_argument("--dim",
                        type=int,
                        default=64,
                        help="Dimensionality of hidden layers")
    parser.add_argument("--lr",
                        type=float,
                        default=1e-2,
                        help="Outer learning rate of model")
    parser.add_argument("--batch-size",
                        type=int,
                        default=32,
                        help="Batch size to train with")
    parser.add_argument("--num-workers",
                        type=int,
                        default=4,
                        help="Number of threads for data loader")
    parser.add_argument(
        "--dataset",
        choices=["mnist", "clevr-box", "clevr-state"],
        help="Use MNIST dataset",
    )
    parser.add_argument(
        "--no-cuda",
        action="store_true",
        help="Run on CPU instead of GPU (not recommended)",
    )
    parser.add_argument("--train-only",
                        action="store_true",
                        help="Only run training, no evaluation")
    parser.add_argument("--eval-only",
                        action="store_true",
                        help="Only run evaluation, no training")
    parser.add_argument("--multi-gpu",
                        action="store_true",
                        help="Use multiple GPUs")
    parser.add_argument("--show",
                        action="store_true",
                        help="Plot generated samples in Tensorboard")

    parser.add_argument("--supervised", action="store_true", help="")
    parser.add_argument("--baseline",
                        action="store_true",
                        help="Use baseline model")

    parser.add_argument("--export-dir",
                        type=str,
                        help="Directory to output samples to")
    parser.add_argument("--export-n",
                        type=int,
                        default=10**9,
                        help="How many samples to output")
    parser.add_argument(
        "--export-progress",
        action="store_true",
        help="Output intermediate set predictions for DSPN?",
    )
    parser.add_argument(
        "--full-eval",
        action="store_true",
        help="Use full evaluation set (default: 1/10 of evaluation data)",
        # don't need full evaluation when training to save some time
    )
    parser.add_argument(
        "--mask-feature",
        action="store_true",
        help="Treat mask as a feature to compute loss with",
    )
    parser.add_argument(
        "--inner-lr",
        type=float,
        default=800,
        help="Learning rate of DSPN inner optimisation",
    )
    parser.add_argument(
        "--iters",
        type=int,
        default=10,
        help="How many DSPN inner optimisation iteration to take",
    )
    parser.add_argument(
        "--huber-repr",
        type=float,
        default=1,
        help="Scaling of representation loss term for DSPN supervised learning",
    )
    parser.add_argument(
        "--loss",
        choices=["hungarian", "chamfer"],
        default="hungarian",
        help="Type of loss used",
    )
    args = parser.parse_args()

    train_writer = SummaryWriter(f"runs/{args.name}", purge_step=0)

    # One device for the model *and* every batch, so --no-cuda keeps the data
    # on the CPU as well (batches used to be moved to the GPU unconditionally).
    device = torch.device("cpu" if args.no_cuda else "cuda")

    net = model.build_net(args)
    net = net.to(device)

    if args.multi_gpu:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.Adam(
        [p for p in net.parameters() if p.requires_grad], lr=args.lr)

    if args.dataset == "mnist":
        dataset_train = data.MNISTSet(train=True, full=args.full_eval)
        dataset_test = data.MNISTSet(train=False, full=args.full_eval)
    else:
        dataset_train = data.CLEVR("clevr",
                                   "train",
                                   box=args.dataset == "clevr-box",
                                   full=args.full_eval)
        dataset_test = data.CLEVR("clevr",
                                  "val",
                                  box=args.dataset == "clevr-box",
                                  full=args.full_eval)

    if not args.eval_only:
        train_loader = data.get_loader(dataset_train,
                                       batch_size=args.batch_size,
                                       num_workers=args.num_workers)
    if not args.train_only:
        test_loader = data.get_loader(
            dataset_test,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            shuffle=False,
        )

    tracker = track.Tracker(
        train_mae=track.ExpMean(),
        train_last=track.ExpMean(),
        train_loss=track.ExpMean(),
        test_mae=track.Mean(),
        test_last=track.Mean(),
        test_loss=track.Mean(),
    )

    if args.resume:
        log = torch.load(args.resume)
        weights = log["weights"]
        # This script saves the unwrapped module's state_dict, so load into the
        # unwrapped module as well.
        n = net.module if args.multi_gpu else net
        n.load_state_dict(weights, strict=True)

    def plot_sets(name, inputs, progress, masks, target_set, target_mask):
        """Log one sample's predicted sets to TensorBoard under tag ``name``.

        Logs every inner step in ``progress``/``masks``, then the last step with
        its mask thresholded at 0.5, then the target set (tag ``<name>-target``).
        Each entry uses its position as ``global_step``. For clevr-box the boxes
        are drawn on ``inputs[0]``. For MNIST the points are drawn as a scatter
        plot whose alpha is the mask. Nothing is logged for clevr-state.
        The lists passed in are not modified.
        """
        if args.dataset == "clevr-state":
            return
        sets = list(progress) + [progress[-1], target_set]
        set_masks = list(masks) + [(masks[-1] > 0.5).float(), target_mask]
        is_clevr = args.dataset.startswith("clevr")
        for j, (s, ms) in enumerate(zip(sets, set_masks)):
            s, ms = utils.scatter_masked(
                s,
                ms,
                binned=is_clevr,
                threshold=0.5 if is_clevr else None,
            )
            tag_name = f"{name}" if j != len(sets) - 1 else f"{name}-target"
            if args.dataset == "clevr-box":
                img = inputs[0].detach().cpu()
                train_writer.add_image_with_boxes(tag_name,
                                                  img,
                                                  s.transpose(0, 1),
                                                  global_step=j)
            else:  # mnist
                fig = plt.figure()
                y, x = s
                y = 1 - y
                ms = ms.numpy()
                rgba_colors = np.zeros((ms.size, 4))
                rgba_colors[:, 2] = 1.0
                rgba_colors[:, 3] = ms
                plt.scatter(x, y, color=rgba_colors)
                # plt.gca() reuses the figure's axes; plt.axes() would add a new one.
                plt.gca().set_aspect("equal")
                plt.xlim(0, 1)
                plt.ylim(0, 1)
                train_writer.add_figure(tag_name, fig, global_step=j)

    def export_predictions(true_export, pred_export):
        """Write the collected sets to ``args.export_dir``, one text file per sample.

        ``groundtruths/<i>.txt`` gets one ``box ...`` line per column of the
        target set, skipping all-zero columns (presumably padding).
        ``detections/<i>.txt`` gets one ``box ...`` line per predicted element of
        the last inner step: its mask value followed by its features. With
        --export-progress, every step is also written to
        ``detections/<i>-step<k>.txt``.
        """
        os.makedirs(f"{args.export_dir}/groundtruths", exist_ok=True)
        os.makedirs(f"{args.export_dir}/detections", exist_ok=True)
        for i, (gt, dets) in enumerate(zip(true_export, pred_export)):
            with open(f"{args.export_dir}/groundtruths/{i}.txt", "w") as fd:
                for box in gt.transpose(0, 1):
                    if (box == 0).all():
                        continue
                    fd.write("box " + " ".join(map(str, box.tolist())) + "\n")
            if args.export_progress:
                for step, det in enumerate(dets):
                    with open(f"{args.export_dir}/detections/{i}-step{step}.txt",
                              "w") as fd:
                        for sbox in det.transpose(0, 1):
                            fd.write("box " + " ".join(map(str, sbox.tolist())) + "\n")
            with open(f"{args.export_dir}/detections/{i}.txt", "w") as fd:
                for sbox in dets[-1].transpose(0, 1):
                    fd.write("box " + " ".join(map(str, sbox.tolist())) + "\n")

    def run(net, loader, optimizer, train=False, epoch=0, pool=None):
        """Run one epoch over ``loader``, optimising ``net`` only when ``train`` is set.

        Metrics go into ``tracker`` under the ``train_``/``test_`` prefix. During
        training they are also written to TensorBoard, indexed by a global step
        that continues across epochs. Autograd is disabled during evaluation.
        ``pool`` is handed to the Hungarian loss.
        """
        if train:
            net.train()
            prefix = "train"
        else:
            net.eval()
            prefix = "test"

        # Filled only when --export-dir is given; written out after the epoch.
        true_export = []
        pred_export = []

        iters_per_epoch = len(loader)
        loader = tqdm(
            loader,
            ncols=0,
            desc="{1} E{0:02d}".format(epoch, "train" if train else "test "),
        )
        # Only build autograd graphs when training; the previous grad mode is
        # restored on exit instead of leaking out of this function.
        with torch.set_grad_enabled(train):
            for i, sample in enumerate(loader, start=epoch * iters_per_epoch):
                # input is either a set or an image
                inputs, target_set, target_mask = [x.to(device) for x in sample]

                # forward evaluation through the network
                (progress, masks, evals,
                 gradn), (y_enc, y_label) = net(inputs, target_set, target_mask)

                # With --mask-feature the mask becomes an extra channel of both the
                # prediction and the target for the loss only. The raw sets are
                # kept for export and plotting.
                if args.mask_feature:
                    loss_target = torch.cat(
                        [target_set, target_mask.unsqueeze(dim=1)], dim=1)
                    loss_progress = [
                        torch.cat([p, m.unsqueeze(dim=1)], dim=1)
                        for p, m in zip(progress, masks)
                    ]
                else:
                    loss_target = target_set
                    loss_progress = progress

                if args.loss == "chamfer":
                    # dim 0 is over the inner iteration steps;
                    # the target set is broadcast over dim 0
                    set_loss = utils.chamfer_loss(torch.stack(loss_progress),
                                                  loss_target.unsqueeze(0))
                else:
                    # Hungarian matching is only applied to the last inner step;
                    # the extra leading dim keeps ``set_loss[-1]`` valid below.
                    set_loss = utils.hungarian_loss(loss_progress[-1],
                                                    loss_target,
                                                    thread_pool=pool).unsqueeze(0)

                # Only use representation loss with DSPN and when doing general
                # supervised prediction, not when auto-encoding
                if args.supervised and not args.baseline:
                    repr_loss = args.huber_repr * F.smooth_l1_loss(y_enc, y_label)
                    loss = set_loss.mean() + repr_loss.mean()
                else:
                    loss = set_loss.mean()

                # Outer optim step
                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Track metrics. ``.mean()`` makes the last-step value a scalar
                # whether the set loss is per-step or per-sample.
                last_step_loss = set_loss[-1].mean().item()
                tracked_last = tracker.update("{}_last".format(prefix),
                                              last_step_loss)
                tracked_loss = tracker.update("{}_loss".format(prefix),
                                              loss.item())
                if train:
                    train_writer.add_scalar("metric/set-loss",
                                            loss.item(),
                                            global_step=i)
                    train_writer.add_scalar("metric/set-last",
                                            last_step_loss,
                                            global_step=i)
                    if not args.baseline:
                        train_writer.add_scalar("metric/eval-first",
                                                evals[0].mean().item(),
                                                global_step=i)
                        train_writer.add_scalar("metric/eval-last",
                                                evals[-1].mean().item(),
                                                global_step=i)
                        grad_norms = [g.item() for g in gradn]
                        train_writer.add_scalar(
                            "metric/max-inner-grad-norm",
                            max(grad_norms),
                            global_step=i,
                        )
                        train_writer.add_scalar(
                            "metric/mean-inner-grad-norm",
                            sum(grad_norms) / len(grad_norms),
                            global_step=i,
                        )
                        if args.supervised:
                            train_writer.add_scalar("metric/repr_loss",
                                                    repr_loss.item(),
                                                    global_step=i)

                # Print current progress to progress bar
                fmt = "{:.6f}".format
                loader.set_postfix(
                    last=fmt(tracked_last),
                    loss=fmt(tracked_loss),
                    bad=fmt(evals[-1].detach().cpu().item() *
                            1000) if not args.baseline else 0,
                )

                # Store predictions to be exported (mask-free target sets)
                if args.export_dir and len(true_export) < args.export_n:
                    true_export.extend(t.detach().cpu() for t in target_set)
                    # Per inner step: one tensor per sample with the mask stacked
                    # on top of the predicted set.
                    progress_steps = [
                        [
                            torch.cat([m.unsqueeze(0), p], dim=0)
                            for p, m in zip(pro.cpu().detach(), mas.cpu().detach())
                        ]
                        for pro, mas in zip(progress, masks)
                    ]
                    # Regroup into one tuple of per-step tensors per sample.
                    pred_export.extend(zip(*progress_steps))

                # Plot predictions in Tensorboard
                if args.show and not train:
                    plot_sets(f"set/epoch-{epoch}/img-{i}", inputs, progress,
                              masks, target_set, target_mask)

        if args.export_dir:
            export_predictions(true_export, pred_export)

    import subprocess

    try:
        git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"])
    except (OSError, subprocess.CalledProcessError):
        # Not inside a git checkout (or git unavailable): still save results.
        git_hash = None

    torch.backends.cudnn.benchmark = True

    os.makedirs("logs", exist_ok=True)
    for epoch in range(args.epochs):
        tracker.new_epoch()
        with mp.Pool(10) as pool:
            if not args.eval_only:
                run(net,
                    train_loader,
                    optimizer,
                    train=True,
                    epoch=epoch,
                    pool=pool)
            if not args.train_only:
                run(net,
                    test_loader,
                    optimizer,
                    train=False,
                    epoch=epoch,
                    pool=pool)

        results = {
            "name": args.name,
            "tracker": tracker.data,
            "weights": net.module.state_dict() if args.multi_gpu else net.state_dict(),
            "args": vars(args),
            "hash": git_hash,
        }
        torch.save(results, os.path.join("logs", args.name))
        if args.eval_only:
            break

    # Flush any pending TensorBoard events.
    train_writer.close()
Example #3
0
def main():
    global net
    global test_loader
    global scatter
    parser = argparse.ArgumentParser()
    # generic params
    parser.add_argument(
        "--name",
        default=datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
        help="Name to store the log file as",
    )
    parser.add_argument("--resume", help="Path to log file to resume from")

    parser.add_argument("--encoder", default="FSEncoder", help="Encoder")
    parser.add_argument("--decoder", default="DSPN", help="Decoder")
    parser.add_argument("--epochs",
                        type=int,
                        default=10,
                        help="Number of epochs to train with")
    parser.add_argument("--latent",
                        type=int,
                        default=32,
                        help="Dimensionality of latent space")
    parser.add_argument("--dim",
                        type=int,
                        default=64,
                        help="Dimensionality of hidden layers")
    parser.add_argument("--lr",
                        type=float,
                        default=1e-2,
                        help="Outer learning rate of model")
    parser.add_argument("--batch-size",
                        type=int,
                        default=12,
                        help="Batch size to train with")
    parser.add_argument("--num-workers",
                        type=int,
                        default=0,
                        help="Number of threads for data loader")
    parser.add_argument(
        "--dataset",
        choices=[
            "mnist", "clevr-box", "clevr-state", "cats", "merged", "wflw"
        ],
        help="Which dataset to use",
    )
    parser.add_argument(
        "--no-cuda",
        action="store_true",
        help="Run on CPU instead of GPU (not recommended)",
    )
    parser.add_argument("--train-only",
                        action="store_true",
                        help="Only run training, no evaluation")
    parser.add_argument("--eval-only",
                        action="store_true",
                        help="Only run evaluation, no training")
    parser.add_argument("--multi-gpu",
                        action="store_true",
                        help="Use multiple GPUs")
    parser.add_argument("--show",
                        action="store_true",
                        help="Plot generated samples in Tensorboard")
    parser.add_argument(
        "--show-skip",
        type=int,
        default=1,
        help="Number of epochs to skip before exporting to Tensorboard")

    parser.add_argument(
        "--infer-name",
        action="store_true",
        help="Automatically name run based on dataset/run number")

    parser.add_argument("--supervised", action="store_true", help="")
    parser.add_argument("--baseline",
                        action="store_true",
                        help="Use baseline model")

    parser.add_argument("--export-dir",
                        type=str,
                        help="Directory to output samples to")
    parser.add_argument("--export-n",
                        type=int,
                        default=10**9,
                        help="How many samples to output")
    parser.add_argument(
        "--export-progress",
        action="store_true",
        help="Output intermediate set predictions for DSPN?",
    )
    parser.add_argument(
        "--full-eval",
        action="store_true",
        help="Use full evaluation set (default: 1/10 of evaluation data)",
        # don't need full evaluation when training to save some time
    )
    parser.add_argument(
        "--mask-feature",
        action="store_true",
        help="Treat mask as a feature to compute loss with",
    )
    parser.add_argument(
        "--inner-lr",
        type=float,
        default=800,
        help="Learning rate of DSPN inner optimisation",
    )
    parser.add_argument(
        "--iters",
        type=int,
        default=10,
        help="How many DSPN inner optimisation iteration to take",
    )
    parser.add_argument(
        "--huber-repr",
        type=float,
        default=1,
        help="Scaling of repr loss term for DSPN supervised learning",
    )
    parser.add_argument(
        "--loss",
        choices=["hungarian", "chamfer", "emd"],
        default="emd",
        help="Type of loss used",
    )
    parser.add_argument(
        "--export-csv",
        action="store_true",
        help="Only perform predictions, don't evaluate in any way")
    parser.add_argument("--eval-split", help="Overwrite split on test set")

    args = parser.parse_args()

    if args.infer_name:
        if args.baseline:
            prefix = "base"
        else:
            prefix = "dspn"

        used_nums = []

        if not os.path.exists("runs"):
            os.makedirs("runs")

        runs = os.listdir("runs")
        for run in runs:
            if args.dataset in run:
                used_nums.append(int(run.split("-")[-1]))

        num = 1
        while num in used_nums:
            num += 1
        name = f"{prefix}-{args.dataset}-{num}"
    else:
        name = args.name

    print(f"Saving run to runs/{name}")
    train_writer = SummaryWriter(f"runs/{name}", purge_step=0)

    net = model.build_net(args)

    if not args.no_cuda:
        net = net.cuda()

    if args.multi_gpu:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.Adam(
        [p for p in net.parameters() if p.requires_grad], lr=args.lr)

    print("Building dataloader")
    if args.dataset == "mnist":
        dataset_train = data.MNISTSet(train=True, full=args.full_eval)
        dataset_test = data.MNISTSet(train=False, full=args.full_eval)
    elif args.dataset in ["clevr-box", "clevr-state"]:
        dataset_train = data.CLEVR("clevr",
                                   "train",
                                   box=args.dataset == "clevr-box",
                                   full=args.full_eval)

        dataset_test = data.CLEVR("clevr",
                                  "val",
                                  box=args.dataset == "clevr-box",
                                  full=args.full_eval)
    elif args.dataset == "cats":
        dataset_train = data.Cats("cats", "train", 9, full=args.full_eval)
        dataset_test = data.Cats("cats", "val", 9, full=args.full_eval)
    elif args.dataset == "faces":
        dataset_train = data.Faces("faces", "train", 4, full=args.full_eval)
        dataset_test = data.Faces("faces", "val", 4, full=args.full_eval)
    elif args.dataset == "wflw":
        if args.eval_split:
            eval_split = f"test_{args.eval_split}"
        else:
            eval_split = "test"

        dataset_train = data.WFLW("wflw", "train", 7, full=args.full_eval)
        dataset_test = data.WFLW("wflw", eval_split, 7, full=args.full_eval)
    elif args.dataset == "merged":
        # merged cats and human faces
        dataset_train_cats = data.Cats("cats", "train", 9, full=args.full_eval)
        dataset_train_wflw = data.WFLW("wflw", "train", 9, full=args.full_eval)

        dataset_test_cats = data.Cats("cats", "val", 9, full=args.full_eval)
        dataset_test_wflw = data.WFLW("wflw", "test", 9, full=args.full_eval)

        dataset_train = data.MergedDataset(dataset_train_cats,
                                           dataset_train_wflw)

        dataset_test = data.MergedDataset(dataset_test_cats, dataset_test_wflw)

    if not args.eval_only:
        train_loader = data.get_loader(dataset_train,
                                       batch_size=args.batch_size,
                                       num_workers=args.num_workers)

    if not args.train_only:
        test_loader = data.get_loader(dataset_test,
                                      batch_size=args.batch_size,
                                      num_workers=args.num_workers,
                                      shuffle=False)

    tracker = track.Tracker(
        train_mae=track.ExpMean(),
        train_last=track.ExpMean(),
        train_loss=track.ExpMean(),
        test_mae=track.Mean(),
        test_last=track.Mean(),
        test_loss=track.Mean(),
    )

    if args.resume:
        log = torch.load(args.resume)
        weights = log["weights"]
        n = net
        if args.multi_gpu:
            n = n.module
        n.load_state_dict(weights, strict=True)

    if args.export_csv:
        names = []
        predictions = []
        export_targets = []

    def run(net, loader, optimizer, train=False, epoch=0, pool=None):
        """Run one epoch of training or evaluation over ``loader``.

        For every batch this computes the set loss selected by ``--loss``
        (plus the representation loss when ``--supervised`` is set without
        ``--baseline``), takes an outer optimiser step when ``train`` is true,
        updates ``tracker`` / Tensorboard, and optionally collects predictions
        for export (``--export-dir`` / ``--export-csv``) and plots them
        (``--show``, evaluation only).

        Note: this switches the global autograd mode with
        ``torch.set_grad_enabled`` and does not restore it on exit.

        Args:
            net: model called as ``net(input, target_set, target_mask)`` and
                returning ``((progress, masks, evals, gradn), (y_enc, y_label))``.
            loader: iterable yielding ``(input, target_set, target_mask)``.
            optimizer: outer optimiser, only stepped when ``train`` is true.
            train: training mode if true, evaluation mode otherwise.
            epoch: epoch index; offsets the Tensorboard global step and
                decides which epochs are plotted (``--show-skip``).
            pool: forwarded to ``utils.hungarian_loss`` as ``thread_pool``.
        """
        writer = train_writer
        if train:
            net.train()
            prefix = "train"
            torch.set_grad_enabled(True)
        else:
            net.eval()
            prefix = "test"
            torch.set_grad_enabled(False)

        if args.export_dir:
            true_export = []
            pred_export = []

        iters_per_epoch = len(loader)
        loader = tqdm(
            loader,
            ncols=0,
            desc="{1} E{0:02d}".format(epoch, "train" if train else "test "),
        )

        # ``i`` is a global step across epochs (used for Tensorboard); the
        # index within the current epoch is recovered where it is needed.
        for i, sample in enumerate(loader, start=epoch * iters_per_epoch):
            # input is either a set or an image. Only move it to the GPU when
            # --no-cuda is not set, mirroring how the model itself is placed.
            input, target_set, target_mask = (
                x if args.no_cuda else x.cuda() for x in sample)

            # forward evaluation through the network
            (progress, masks, evals,
             gradn), (y_enc, y_label) = net(input, target_set, target_mask)

            progress_only = progress

            # if using mask as feature, concat mask feature into progress
            if args.mask_feature:
                target_set = torch.cat(
                    [target_set, target_mask.unsqueeze(dim=1)], dim=1)
                progress = [
                    torch.cat([p, m.unsqueeze(dim=1)], dim=1)
                    for p, m in zip(progress, masks)
                ]

            if args.loss == "chamfer":
                # dim 0 is over the inner iteration steps
                # target set is broadcasted over dim 0
                set_loss = utils.chamfer_loss(torch.stack(progress),
                                              target_set.unsqueeze(0))
            elif args.loss == "hungarian":
                set_loss = utils.hungarian_loss(progress[-1],
                                                target_set,
                                                thread_pool=pool).unsqueeze(0)
            elif args.loss == "emd":
                set_loss = utils.emd(progress[-1], target_set).unsqueeze(0)

            # Only use representation loss with DSPN and when doing general
            # supervised prediction, not when auto-encoding
            if args.supervised and not args.baseline:
                repr_loss = args.huber_repr * F.smooth_l1_loss(y_enc, y_label)
                loss = set_loss.mean() + repr_loss.mean()
            else:
                loss = set_loss.mean()

            # restore progress variable to not contain masks for correct
            # exporting
            progress = progress_only

            # Outer optim step
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Tensorboard tracking of metrics for debugging
            tracked_last = tracker.update(f"{prefix}_last",
                                          set_loss[-1].item())
            tracked_loss = tracker.update(f"{prefix}_loss", loss.item())
            if train:
                writer.add_scalar("metric/set-loss",
                                  loss.item(),
                                  global_step=i)

                writer.add_scalar("metric/set-last",
                                  set_loss[-1].mean().item(),
                                  global_step=i)

                if not args.baseline:
                    writer.add_scalar("metric/eval-first",
                                      evals[0].mean().item(),
                                      global_step=i)

                    writer.add_scalar("metric/eval-last",
                                      evals[-1].mean().item(),
                                      global_step=i)

                    writer.add_scalar("metric/max-inner-grad-norm",
                                      max(g.item() for g in gradn),
                                      global_step=i)

                    writer.add_scalar("metric/mean-inner-grad-norm",
                                      sum(g.item()
                                          for g in gradn) / len(gradn),
                                      global_step=i)

                    if args.supervised:
                        writer.add_scalar("metric/repr_loss",
                                          repr_loss.item(),
                                          global_step=i)

            # Print current progress to progress bar
            fmt = "{:.6f}".format
            loader.set_postfix(last=fmt(tracked_last),
                               loss=fmt(tracked_loss),
                               bad=fmt(evals[-1].detach().cpu().item() *
                                       1000) if not args.baseline else 0)

            if args.export_dir:
                # export last inner optim of each input as csv
                # (one input per row)
                if args.export_csv:
                    # ``i`` includes the ``epoch * iters_per_epoch`` offset;
                    # dataset indices need the batch index within this epoch.
                    batch_idx = i - epoch * iters_per_epoch
                    # the second to last element are the last of the
                    # inner optim
                    for batch_i, p in enumerate(progress[-2]):
                        img_id = batch_idx * args.batch_size + batch_i

                        names.append(loader.iterable.dataset.get_fname(img_id))

                        m = masks[-2][batch_i]
                        m = m.cpu().detach().numpy().astype(bool)

                        p = p.cpu().detach().numpy()
                        p = p[:, m]

                        # interleave the two coordinate rows: p0[0], p1[0], ...
                        sample_preds = [
                            p[k % 2, k // 2] for k in range(p.shape[1] * 2)
                        ]
                        # remove values according to mask and add zeros to the
                        # end instead
                        sample_preds += [0] * (len(m) * 2 - len(sample_preds))
                        predictions.append(sample_preds)

                        true_mask = target_set[batch_i, 2, :].cpu().detach()
                        true_mask = true_mask.numpy().astype(bool)
                        trues = target_set[batch_i, :2, :]
                        trues = trues.cpu().detach().numpy()

                        t = trues[:, true_mask]

                        t = [t[k % 2, k // 2] for k in range(t.shape[1] * 2)]

                        t += [0] * (len(true_mask) * 2 - len(t))

                        export_targets.append(t)

                # Store predictions to be exported
                else:
                    if len(true_export) < args.export_n:
                        for p, m in zip(target_set, target_mask):
                            true_export.append(p.detach().cpu())
                        progress_steps = []
                        for pro, ms in zip(progress, masks):
                            # pro and ms are one step of the inner optim
                            # score boxes contains the list of predicted
                            # elements for one step
                            score_boxes = []
                            for p, m in zip(pro.cpu().detach(),
                                            ms.cpu().detach()):
                                score_box = torch.cat([m.unsqueeze(0), p],
                                                      dim=0)
                                score_boxes.append(score_box)
                            progress_steps.append(score_boxes)
                        for b in zip(*progress_steps):
                            pred_export.append(b)

            # Plot predictions in Tensorboard
            if args.show and epoch % args.show_skip == 0 and not train:
                name = f"set/epoch-{epoch}/img-{i}"
                # thresholded set
                progress.append(progress[-1])
                masks.append((masks[-1] > 0.5).float())
                # target set
                if args.mask_feature:
                    # target set is augmented with masks, so remove them
                    progress.append(target_set[:, :-1])
                else:
                    progress.append(target_set)
                masks.append(target_mask)
                # intermediate sets

                for j, (s, ms) in enumerate(zip(progress, masks)):
                    if args.dataset == "clevr-state":
                        continue

                    if args.dataset.startswith("clevr"):
                        threshold = 0.5
                    else:
                        threshold = None

                    s, ms = utils.scatter_masked(
                        s,
                        ms,
                        binned=args.dataset.startswith("clevr"),
                        threshold=threshold)

                    if j != len(progress) - 1:
                        tag_name = f"{name}"
                    else:
                        tag_name = f"{name}-target"

                    if args.dataset == "clevr-box":
                        img = input[0].detach().cpu()

                        writer.add_image_with_boxes(tag_name,
                                                    img,
                                                    s.transpose(0, 1),
                                                    global_step=j)
                    elif args.dataset == "cats" \
                            or args.dataset == "wflw" \
                            or args.dataset == "merged":

                        img = input[0].detach().cpu()

                        fig = plt.figure()
                        plt.scatter(s[0, :] * 128, s[1, :] * 128)

                        plt.imshow(np.transpose(img, (1, 2, 0)))

                        writer.add_figure(tag_name, fig, global_step=j)
                    else:  # mnist
                        fig = plt.figure()
                        y, x = s
                        y = 1 - y
                        ms = ms.numpy()
                        rgba_colors = np.zeros((ms.size, 4))
                        rgba_colors[:, 2] = 1.0
                        rgba_colors[:, 3] = ms
                        plt.scatter(x, y, color=rgba_colors)
                        # ``plt.axes()`` with no arguments adds a *new*
                        # full-figure Axes that hides the scatter above;
                        # ``gca()`` returns the Axes the points were drawn on.
                        plt.gca().set_aspect("equal")
                        plt.xlim(0, 1)
                        plt.ylim(0, 1)
                        writer.add_figure(tag_name, fig, global_step=j)

        # Export predictions
        if args.export_dir and not args.export_csv:
            os.makedirs(f"{args.export_dir}/groundtruths", exist_ok=True)
            os.makedirs(f"{args.export_dir}/detections", exist_ok=True)
            for i, (gt, dets) in enumerate(zip(true_export, pred_export)):
                export_groundtruths_path = os.path.join(
                    args.export_dir, "groundtruths", f"{i}.txt")

                with open(export_groundtruths_path, "w") as fd:
                    for box in gt.transpose(0, 1):
                        # all-zero columns are skipped (treated as padding)
                        if (box == 0).all():
                            continue
                        s = "box " + " ".join(map(str, box.tolist()))
                        fd.write(s + "\n")

                if args.export_progress:
                    for step, det in enumerate(dets):
                        export_progress_path = os.path.join(
                            args.export_dir, "detections",
                            f"{i}-step{step}.txt")

                        with open(export_progress_path, "w") as fd:
                            for sbox in det.transpose(0, 1):
                                s = "box " + " ".join(map(str, sbox.tolist()))
                                fd.write(s + "\n")

                export_path = os.path.join(args.export_dir, "detections",
                                           f"{i}.txt")
                with open(export_path, "w") as fd:
                    for sbox in dets[-1].transpose(0, 1):
                        s = "box " + " ".join(map(str, sbox.tolist()))
                        fd.write(s + "\n")

    import subprocess

    git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"])
    # git_hash = "483igtrfiuey46"

    torch.backends.cudnn.benchmark = True

    metrics = {}

    start = time.time()

    if args.eval_only:
        tracker.new_epoch()
        with mp.Pool(10) as pool:
            run(net, test_loader, optimizer, train=False, epoch=0, pool=pool)

        metrics["test_loss"] = np.mean(tracker.data["test_loss"][-1])
        metrics["test_set_loss"] = np.mean(tracker.data["test_last"][-1])
    else:
        best_test_loss = float("inf")

        for epoch in range(args.epochs):
            tracker.new_epoch()
            with mp.Pool(10) as pool:
                run(net,
                    train_loader,
                    optimizer,
                    train=True,
                    epoch=epoch,
                    pool=pool)
                if not args.train_only:
                    run(net,
                        test_loader,
                        optimizer,
                        train=False,
                        epoch=epoch,
                        pool=pool)

            epoch_test_loss = np.mean(tracker.data["test_loss"][-1])

            if epoch_test_loss < best_test_loss:
                print("new best loss")
                best_test_loss = epoch_test_loss
                # only save if the epoch has lower loss
                metrics["test_loss"] = epoch_test_loss
                metrics["train_loss"] = np.mean(tracker.data["train_loss"][-1])

                metrics["train_set_loss"] = np.mean(
                    tracker.data["train_last"][-1])
                metrics["test_set_loss"] = np.mean(
                    tracker.data["test_last"][-1])

                metrics["best_epoch"] = epoch

                results = {
                    "name":
                    name + "-best",
                    "tracker":
                    tracker.data,
                    "weights":
                    net.state_dict()
                    if not args.multi_gpu else net.module.state_dict(),
                    "args":
                    vars(args),
                    "hash":
                    git_hash,
                }

                torch.save(results, os.path.join("logs", name + "-best"))

        results = {
            "name":
            name + "-final",
            "tracker":
            tracker.data,
            "weights":
            net.state_dict()
            if not args.multi_gpu else net.module.state_dict(),
            "args":
            vars(args),
            "hash":
            git_hash,
        }
        torch.save(results, os.path.join("logs", name + "-final"))

    if args.export_csv and args.export_dir:
        path = os.path.join(args.export_dir, f'{args.name}-predictions.csv')
        pd.DataFrame(np.array(predictions), index=names).to_csv(path,
                                                                sep=',',
                                                                index=names,
                                                                header=False)

        path = os.path.join(args.export_dir, f'{args.name}-targets.csv')
        pd.DataFrame(np.array(export_targets),
                     index=names).to_csv(path,
                                         sep=',',
                                         index=names,
                                         header=False)

    took = time.time() - start
    print(f"Process took {took:.1f}s, avg {took/args.epochs:.1f} s/epoch.")

    # save hyper parameters to tensorboard for reference
    hparams = {k: v for k, v in vars(args).items() if v is not None}

    print(metrics)
    metrics = {"total_time": took, "avg_time_per_epoch": took / args.epochs}

    print("writing hparams")
    train_writer.add_hparams(hparams, metric_dict=metrics, name="hparams")
Exemple #4
0
                             lr=args.lr)

train_loader = data.get_loader(dataset_train,
                               batch_size=args.batch_size,
                               num_workers=args.num_workers)
test_loader = data.get_loader(dataset_test,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers)

tracker = track.Tracker(
    train_loss=track.ExpMean(),
    train_acc=track.ExpMean(),
    train_l1=track.ExpMean(),
    train_l2=track.ExpMean(),
    train_lr=track.Identity(),
    test_loss=track.Mean(),
    test_acc=track.Mean(),
    test_l1=track.Mean(),
    test_l2=track.Mean(),
    test_lr=track.Identity(),
)

if args.resume:
    log = torch.load(args.resume)
    weights = log['weights']
    n = net
    strict = True
    if args.multi_gpu:
        n = n.module
    if args.task == 'classify' and not args.vis:
        # we only want to load the classifier portion of the model, not the
Exemple #5
0
def main():
    """Plot MNIST set reconstructions of two models across noise levels.

    Parses the command line, then forces several options (MNIST, evaluation
    only, batch size 1, cardinality 342). It builds two ``SAE`` models:

    - ``net[0]`` uses ``--encoder`` / ``--decoder``.
    - ``net[1]`` uses ``SumEncoder`` / ``MLPDecoder``.

    Checkpoints passed via ``--resume`` are grouped by noise level. For each
    group, the weights are loaded, two test batches are reconstructed, and the
    results are drawn into a 3x12 grid of subplots. The grid is saved to
    ``mnist.pdf``.

    Checkpoint file names (the part after the last ``/``) must have the form
    ``<prefix>-<model>-<noise>-<num>``. Within one noise level, checkpoints
    are paired with ``net`` in sorted path order.

    ``net``, ``test_loader`` and ``scatter`` are published as module globals.
    """
    global net
    global test_loader
    global scatter
    parser = argparse.ArgumentParser()
    # generic params
    parser.add_argument('--name',
                        default=datetime.now().strftime('%Y-%m-%d_%H:%M:%S'),
                        help='Name to store the log file as')
    parser.add_argument('--resume',
                        nargs='+',
                        help='Path to log file to resume from')

    parser.add_argument('--encoder', default='FSEncoder', help='Encoder')
    parser.add_argument('--decoder', default='FSDecoder', help='Decoder')
    parser.add_argument('--cardinality',
                        type=int,
                        default=20,
                        help='Size of set')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        help='Number of epochs to train with')
    parser.add_argument('--latent',
                        type=int,
                        default=8,
                        help='Dimensionality of latent space')
    parser.add_argument('--dim',
                        type=int,
                        default=64,
                        help='Dimensionality of hidden layers')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-2,
                        help='Learning rate of model')
    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        help='Batch size to train with')
    parser.add_argument('--num-workers',
                        type=int,
                        default=4,
                        help='Number of threads for data loader')
    parser.add_argument('--samples',
                        type=int,
                        default=2**14,
                        help='Dataset size')
    parser.add_argument('--decay',
                        action='store_true',
                        help='Decay sort temperature')
    parser.add_argument('--skip',
                        action='store_true',
                        help='Skip permutation use in decoder')
    parser.add_argument('--mnist',
                        action='store_true',
                        help='Use MNIST dataset')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        help='Run on CPU instead of GPU (not recommended)')
    parser.add_argument('--train-only',
                        action='store_true',
                        help='Only run training, no evaluation')
    parser.add_argument('--eval-only',
                        action='store_true',
                        help='Only run evaluation, no training')
    parser.add_argument('--multi-gpu',
                        action='store_true',
                        help='Use multiple GPUs')
    parser.add_argument('--show',
                        action='store_true',
                        help='Show generated samples')

    parser.add_argument('--loss',
                        choices=['direct', 'hungarian', 'chamfer'],
                        default='direct',
                        help='Type of loss used')

    parser.add_argument('--shift', action='store_true', help='')
    parser.add_argument('--rotate', action='store_true', help='')
    parser.add_argument('--scale', action='store_true', help='')
    parser.add_argument('--variable', action='store_true', help='')
    parser.add_argument('--noise',
                        type=float,
                        default=0,
                        help='Standard deviation of noise')
    args = parser.parse_args()
    if not args.resume:
        # Without checkpoints there is nothing to plot; fail with a clear
        # message instead of a TypeError from ``sorted(None)`` below.
        parser.error('--resume is required: pass one or more checkpoint paths')

    # This script only renders the MNIST figure, so these options are forced
    # regardless of the command line. 342 is presumably the largest set size
    # in the MNIST point-set data -- TODO confirm.
    args.mnist = True
    args.eval_only = True
    args.show = True
    args.cardinality = 342
    args.batch_size = 1

    model_args = {
        'set_size': args.cardinality,
        'dim': args.dim,
        'skip': args.skip,
    }
    net_class = SAE
    # net[0]: the configured encoder/decoder; net[1]: the baseline.
    net = [
        net_class(
            encoder=globals()[args.encoder if k == 0 else 'SumEncoder'],
            decoder=globals()[args.decoder if k == 0 else 'MLPDecoder'],
            latent_dim=args.latent,
            encoder_args=model_args,
            decoder_args=model_args,
        ) for k in range(2)
    ]

    if not args.no_cuda:
        net = [n.cuda() for n in net]

    dataset_test = data.MNISTSet(train=False)
    test_loader = data.get_loader(dataset_test,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    def run(net, loader, optimizer, train=False, epoch=0, pool=None,
            num_batches=2):
        """Reconstruct the first ``num_batches`` batches drawn from ``loader``.

        Returns ``(preds, ns)``. ``preds[b]`` is
        ``[target, input, net[0] output, net[1] output]`` for the first
        element of batch ``b``, each as a numpy array. ``input`` is the target
        plus Gaussian noise of std ``args.noise`` when that is positive.
        ``ns[b]`` is the batch's ``n_points`` tensor.

        ``optimizer``, ``train``, ``epoch`` and ``pool`` are unused. They are
        kept only for signature compatibility.
        """
        for n in net:
            n.eval()

        preds = []
        ns = []
        for i, sample in enumerate(loader):
            # ``async`` is a reserved keyword since Python 3.7, so the old
            # ``cuda(async=True)`` spelling does not even parse; its
            # replacement is ``non_blocking``. Stay on the CPU with --no-cuda,
            # matching where the networks were placed.
            points, _labels, n_points = (
                x if args.no_cuda else x.cuda(non_blocking=True)
                for x in sample)

            if args.noise > 0:
                noise = torch.randn_like(points) * args.noise
                input_points = points + noise
            else:
                input_points = points

            # net[0] gets the set as-is; net[1] gets it zero-padded along
            # dim 2 up to ``args.cardinality`` elements.
            padding = torch.zeros(points.size(0), points.size(1),
                                  args.cardinality - points.size(2),
                                  device=points.device)
            padded_points = torch.cat([input_points, padding], dim=2)
            model_inputs = [input_points, padded_points]

            pred = [points, input_points] + [
                n(p, n_points) for n, p in zip(net, model_inputs)
            ]
            preds.append([p[0].detach().cpu().numpy() for p in pred])
            ns.append(n_points)
            if i + 1 >= num_batches:
                break
        # Returning after the loop also covers loaders that have fewer than
        # ``num_batches`` batches (previously this fell through to ``None``).
        return preds, ns

    def scatter(tensor, n_points, transpose=False, *plot_args, **plot_kwargs):
        """Scatter the first ``n_points`` entries of a two-row point array.

        With ``transpose`` the rows are swapped and the new y values are
        flipped (``y = 1 - y``). Extra arguments go to ``plt.scatter``.
        """
        x, y = tensor
        if transpose:
            x, y = y, 1 - x
        plt.scatter(x[:n_points], y[:n_points], *plot_args, **plot_kwargs)

    # group same noise levels together
    by_noise = {}
    for path in sorted(args.resume):
        fname = path.split('/')[-1]
        model, noise, _num = fname.split('-')[1:]
        by_noise.setdefault(float(noise), []).append((model, path))
    print(by_noise)

    # Load GPU-saved checkpoints onto the CPU when running with --no-cuda.
    map_location = 'cpu' if args.no_cuda else None

    plt.figure(figsize=(16, 3.9))
    for i, (noise, ms) in enumerate(by_noise.items()):
        print(i, noise, ms)
        for (_, path), n in zip(ms, net):
            weights = torch.load(path, map_location=map_location)['weights']
            print(path, type(n.encoder), type(n.decoder))
            n.load_state_dict(weights, strict=True)
        # ``run`` reads the noise level from ``args``.
        args.noise = noise

        points, n_points = run(net, test_loader, None)

        # ``sample_n`` is deliberately not called ``np``: assigning ``np``
        # anywhere in this function would shadow numpy for the whole body.
        for j, (sample_preds, sample_n) in enumerate(zip(points, n_points)):
            # target and noisy input share row 0; net[0] is row 1, net[1] row 2
            for p, row in zip(sample_preds, [0, 0, 1, 2]):
                ax = plt.subplot(3, 12, 12 * row + 1 + 2 * i + j)
                # net[1] was fed a padded set, so plot all of its points.
                n_plot = args.cardinality if row == 2 else sample_n
                scatter(p, n_plot, transpose=True, marker='o', s=8, alpha=0.5)
                plt.xlim(0, 1)
                plt.ylim(0, 1)
                ax.set_xticks([])
                ax.set_yticks([])
                if row == 0:
                    plt.title(r'$\sigma = {:.2f}$'.format(noise))
                if i == 0 and j == 0:
                    label = {
                        0: 'Input / Target',
                        1: 'Ours',
                        2: 'Baseline',
                    }[row]
                    plt.ylabel(label)

    plt.subplots_adjust(wspace=0.0, hspace=0.0)
    plt.savefig('mnist.pdf', bbox_inches='tight')
Exemple #6
0
def main():
    global net
    global test_loader
    global scatter
    parser = argparse.ArgumentParser()
    # generic params
    parser.add_argument(
        "--name",
        default=datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
        help="Name to store the log file as",
    )
    parser.add_argument("--resume", help="Path to log file to resume from")

    parser.add_argument("--encoder", default="FSEncoder", help="Encoder")
    parser.add_argument("--decoder", default="DSPN", help="Decoder")
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of epochs to train with"
    )
    parser.add_argument(
        "--latent", type=int, default=32, help="Dimensionality of latent space"
    )
    parser.add_argument(
        "--dim", type=int, default=64, help="Dimensionality of hidden layers"
    )
    parser.add_argument(
        "--lr", type=float, default=1e-2, help="Outer learning rate of model"
    )
    parser.add_argument(
        "--batch-size", type=int, default=32, help="Batch size to train with"
    )
    parser.add_argument(
        "--num-workers", type=int, default=2, help="Number of threads for data loader"
    )
    parser.add_argument(
        "--dataset",
        choices=["mnist", "clevr-box", "clevr-state", "lhc"],
        help="Use MNIST dataset",
    )
    parser.add_argument(
        "--no-cuda",
        action="store_true",
        help="Run on CPU instead of GPU (not recommended)",
    )
    parser.add_argument(
        "--train-only", action="store_true", help="Only run training, no evaluation"
    )
    parser.add_argument(
        "--eval-only", action="store_true", help="Only run evaluation, no training"
    )
    parser.add_argument("--multi-gpu", action="store_true", help="Use multiple GPUs")
    parser.add_argument(
        "--show", action="store_true", help="Plot generated samples in Tensorboard"
    )

    parser.add_argument("--supervised", action="store_true", help="")
    parser.add_argument("--baseline", action="store_true", help="Use baseline model")

    parser.add_argument("--export-dir", type=str, help="Directory to output samples to")
    parser.add_argument(
        "--export-n", type=int, default=10 ** 9, help="How many samples to output"
    )
    parser.add_argument(
        "--export-progress",
        action="store_true",
        help="Output intermediate set predictions for DSPN?",
    )
    parser.add_argument(
        "--full-eval",
        action="store_true",
        help="Use full evaluation set (default: 1/10 of evaluation data)",  # don't need full evaluation when training to save some time
    )
    parser.add_argument(
        "--mask-feature",
        action="store_true",
        help="Treat mask as a feature to compute loss with",
    )
    parser.add_argument(
        "--inner-lr",
        type=float,
        default=800,
        help="Learning rate of DSPN inner optimisation",
    )
    parser.add_argument(
        "--iters",
        type=int,
        default=10,
        help="How many DSPN inner optimisation iteration to take",
    )
    parser.add_argument(
        "--huber-repr",
        type=float,
        default=1,
        help="Scaling of representation loss term for DSPN supervised learning",
    )
    parser.add_argument(
        "--loss",
        choices=["hungarian", "chamfer"],
        default="hungarian",
        help="Type of loss used",
    )
    args = parser.parse_args()

    train_writer = SummaryWriter(f"runs/{args.name}", purge_step=0)

    net = model.build_net(args)

    if not args.no_cuda:
        net = net.cuda()

    if args.multi_gpu:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.Adam(
        [p for p in net.parameters() if p.requires_grad], lr=args.lr
    )

    if args.dataset == "mnist":
        dataset_train = data.MNISTSet(train=True, full=args.full_eval)
        dataset_test = data.MNISTSet(train=False, full=args.full_eval)
    elif args.dataset == "lhc":
        dataset_train = data.LHCSet(train=True)
        dataset_test = data.LHCSet(train=False)
    else:
        dataset_train = data.CLEVR(
            "clevr", "train", box=args.dataset == "clevr-box", full=args.full_eval
        )
        dataset_test = data.CLEVR(
            "clevr", "val", box=args.dataset == "clevr-box", full=args.full_eval
        )

    if not args.eval_only:
        train_loader = data.get_loader(
            dataset_train, batch_size=args.batch_size, num_workers=args.num_workers
        )
    if not args.train_only:
        test_loader = data.get_loader(
            dataset_test,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            shuffle=False,
        )

    tracker = track.Tracker(
        train_mae=track.ExpMean(),
        train_last=track.ExpMean(),
        train_loss=track.ExpMean(),
        test_mae=track.Mean(),
        test_last=track.Mean(),
        test_loss=track.Mean(),
    )

    if args.resume:
        log = torch.load(args.resume)
        weights = log["weights"]
        n = net
        if args.multi_gpu:
            n = n.module
        n.load_state_dict(weights, strict=True)

    def _plot_epoch_summary(encodings, steps, losses, epoch, train):
        """Save the diagnostic plots produced at the end of one ``run`` pass.

        Writes ``img-latent-epoch-{epoch}-{phase}`` (2-D t-SNE projection of the
        collected encodings) and ``img-losses-epoch-{epoch}-train-{phase}``
        (loss against global step) to the working directory, where ``phase``
        is "train" or "test".
        """
        phase = "train" if train else "test"

        # debug aid: shapes of the first few encoding batches
        print([t.shape for t in encodings[0:5]])

        # create scatter of encoded points, put through tsne
        # see: https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
        # torch.cat rejects an empty list, so skip this plot for an empty loader.
        if encodings:
            set_encoding_in_hidden_dim = torch.cat(encodings, dim=0)
            set_encoding_tsne = TSNE(n_components=2).fit_transform(
                set_encoding_in_hidden_dim
            )
            fig = plt.figure()
            plt.scatter(set_encoding_tsne[:, 0], set_encoding_tsne[:, 1])
            plt.savefig(f"img-latent-epoch-{epoch}-{phase}", dpi=300)
            plt.close(fig)

        # loss against global step
        fig = plt.figure()
        plt.scatter(steps, losses)
        plt.savefig(f"img-losses-epoch-{epoch}-train-{phase}", dpi=300)
        plt.close(fig)

    def _export_predictions(true_export, pred_export):
        """Write ground-truth and predicted sets as ``box ...`` text files.

        Files go to ``{args.export_dir}/groundtruths/{i}.txt`` and
        ``{args.export_dir}/detections/{i}.txt`` (one line per set element,
        elements are the columns of each stored tensor). With
        ``--export-progress`` every inner-optimisation step is also written
        as ``detections/{i}-step{step}.txt``.
        """
        os.makedirs(f"{args.export_dir}/groundtruths", exist_ok=True)
        os.makedirs(f"{args.export_dir}/detections", exist_ok=True)
        for i, (gt, dets) in enumerate(zip(true_export, pred_export)):
            with open(f"{args.export_dir}/groundtruths/{i}.txt", "w") as fd:
                for box in gt.transpose(0, 1):
                    # all-zero elements are skipped (presumably padding)
                    if (box == 0).all():
                        continue
                    fd.write("box " + " ".join(map(str, box.tolist())) + "\n")
            if args.export_progress:
                for step, det in enumerate(dets):
                    with open(
                        f"{args.export_dir}/detections/{i}-step{step}.txt", "w"
                    ) as fd:
                        for sbox in det.transpose(0, 1):
                            fd.write("box " + " ".join(map(str, sbox.tolist())) + "\n")
            with open(f"{args.export_dir}/detections/{i}.txt", "w") as fd:
                for sbox in dets[-1].transpose(0, 1):
                    fd.write("box " + " ".join(map(str, sbox.tolist())) + "\n")

    def run(net, loader, optimizer, train=False, epoch=0, pool=None):
        """Run one pass over ``loader``, training ``net`` when ``train`` is True.

        Args:
            net: set model, called as ``net(input, target_set, target_mask)``
                and unpacked as ``(progress, masks, evals, gradn), (y_enc, y_label)``.
            loader: iterable of ``(input, target_set, target_mask)`` batches.
            optimizer: stepped once per batch when ``train`` is True.
            train: training (backprop + TensorBoard scalars) vs. evaluation.
            epoch: epoch index; labels the progress bar and plot files and
                offsets the global step by ``epoch * len(loader)``.
            pool: forwarded as ``thread_pool`` to ``utils.hungarian_loss``.

        Returns:
            The mean of the per-batch total losses (``np.mean`` of an empty
            list, i.e. nan, if the loader yields nothing).

        Side effects: changes the global autograd mode via
        ``torch.set_grad_enabled`` (it is not restored on return), writes
        TensorBoard data, saves plots to the working directory and, with
        ``--export-dir``, writes prediction/ground-truth text files.
        """
        writer = train_writer
        # NOTE: set_grad_enabled is called as a function here, so the autograd
        # mode stays switched after this function returns.
        if train:
            net.train()
            prefix = "train"
            torch.set_grad_enabled(True)
        else:
            net.eval()
            prefix = "test"
            torch.set_grad_enabled(False)

        if args.export_dir:
            true_export = []
            pred_export = []

        # number of batches per epoch; used to make the step index global
        iters_per_epoch = len(loader)
        loader = tqdm(
            loader,
            ncols=0,
            desc="{1} E{0:02d}".format(epoch, "train" if train else "test "),
        )

        losses = []
        steps = []
        encodings = []
        for i, sample in enumerate(loader, start=epoch * iters_per_epoch):
            # input is either a set or an image. Honour --no-cuda: the model is
            # only moved to the GPU when it is not set, so the batch must match.
            if args.no_cuda:
                inputs, target_set, target_mask = sample
            else:
                inputs, target_set, target_mask = (x.cuda() for x in sample)

            # forward evaluation through the network
            (progress, masks, evals, gradn), (y_enc, y_label) = net(
                inputs, target_set, target_mask
            )

            # Keep a detached CPU copy for the t-SNE plot; storing the live
            # tensor would pin every batch in device memory (and possibly its
            # autograd history) until the end of the pass.
            encodings.append(y_label.detach().cpu())
            progress_only = progress

            # if using mask as feature, concat mask feature into progress
            if args.mask_feature:
                target_set = torch.cat(
                    [target_set, target_mask.unsqueeze(dim=1)], dim=1
                )
                progress = [
                    torch.cat([p, m.unsqueeze(dim=1)], dim=1)
                    for p, m in zip(progress, masks)
                ]

            if args.loss == "chamfer":
                # dim 0 is over the inner iteration steps;
                # the target set is broadcast over dim 0
                set_loss = utils.chamfer_loss(
                    torch.stack(progress), target_set.unsqueeze(0)
                )
            else:
                # Hungarian matching is only computed for the final inner step
                set_loss = utils.hungarian_loss(
                    progress[-1], target_set, thread_pool=pool
                ).unsqueeze(0)

            # Only use representation loss with DSPN and when doing general
            # supervised prediction, not when auto-encoding
            if args.supervised and not args.baseline:
                repr_loss = args.huber_repr * F.smooth_l1_loss(y_enc, y_label)
                loss = set_loss.mean() + repr_loss.mean()
            else:
                loss = set_loss.mean()

            # restore progress variable to not contain masks for correct exporting
            progress = progress_only

            # Outer optim step
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_value = loss.item()
            steps.append(i)
            losses.append(loss_value)

            # Tensorboard tracking of metrics for debugging
            tracked_last = tracker.update("{}_last".format(prefix), set_loss[-1].item())
            tracked_loss = tracker.update("{}_loss".format(prefix), loss_value)
            if train:
                writer.add_scalar("metric/set-loss", loss_value, global_step=i)
                writer.add_scalar(
                    "metric/set-last", set_loss[-1].mean().item(), global_step=i
                )
                if not args.baseline:
                    writer.add_scalar(
                        "metric/eval-first", evals[0].mean().item(), global_step=i
                    )
                    writer.add_scalar(
                        "metric/eval-last", evals[-1].mean().item(), global_step=i
                    )
                    writer.add_scalar(
                        "metric/max-inner-grad-norm",
                        max(g.item() for g in gradn),
                        global_step=i,
                    )
                    writer.add_scalar(
                        "metric/mean-inner-grad-norm",
                        sum(g.item() for g in gradn) / len(gradn),
                        global_step=i,
                    )
                    if args.supervised:
                        writer.add_scalar(
                            "metric/repr_loss", repr_loss.item(), global_step=i
                        )

            # Print current progress to progress bar
            fmt = "{:.6f}".format
            loader.set_postfix(
                last=fmt(tracked_last),
                loss=fmt(tracked_loss),
                bad=fmt(evals[-1].detach().cpu().item() * 1000)
                if not args.baseline
                else 0,
            )

            # Store predictions to be exported
            if args.export_dir and len(true_export) < args.export_n:
                for p in target_set:
                    true_export.append(p.detach().cpu())
                progress_steps = []
                for pro, mas in zip(progress, masks):
                    # pro and mas are one step of the inner optim; each
                    # predicted element gets its mask value prepended as the
                    # score column
                    score_boxes = [
                        torch.cat([m.unsqueeze(0), p], dim=0)
                        for p, m in zip(pro.cpu().detach(), mas.cpu().detach())
                    ]
                    progress_steps.append(score_boxes)
                # regroup per-step lists into per-sample tuples over steps
                pred_export.extend(zip(*progress_steps))

            # Plot predictions in Tensorboard
            if args.show and not train:
                name = f"set/epoch-{epoch}/img-{i}"
                # thresholded set
                progress.append(progress[-1])
                masks.append((masks[-1] > 0.5).float())
                # target set
                if args.mask_feature:
                    # target set is augmented with masks, so remove them
                    progress.append(target_set[:, :-1])
                else:
                    progress.append(target_set)
                masks.append(target_mask)
                # intermediate sets (nothing is plotted for clevr-state)
                if args.dataset != "clevr-state":
                    for j, (s, ms) in enumerate(zip(progress, masks)):
                        s, ms = utils.scatter_masked(
                            s,
                            ms,
                            binned=args.dataset.startswith("clevr"),
                            threshold=0.5 if args.dataset.startswith("clevr") else None,
                        )
                        tag_name = f"{name}" if j != len(progress) - 1 else f"{name}-target"
                        if args.dataset == "clevr-box":
                            img = inputs[0].detach().cpu()
                            writer.add_image_with_boxes(
                                tag_name, img, s.transpose(0, 1), global_step=j
                            )
                        else:  # mnist / lhc: point scatter
                            fig = plt.figure()
                            y, x = s
                            y = 1 - y
                            ms = ms.numpy()
                            rgba_colors = np.zeros((ms.size, 4))
                            rgba_colors[:, 2] = 1.0
                            rgba_colors[:, 3] = ms
                            plt.scatter(x, y, color=rgba_colors)
                            # gca(), not axes(): recent matplotlib makes axes()
                            # add a fresh empty axes on top of the scatter.
                            plt.gca().set_aspect("equal")
                            plt.xlim(0, 1)
                            plt.ylim(0, 1)
                            writer.add_figure(tag_name, fig, global_step=j)

        _plot_epoch_summary(encodings, steps, losses, epoch, train)

        # Export predictions
        if args.export_dir:
            _export_predictions(true_export, pred_export)

        return np.mean(losses)

    import subprocess

    git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"])

    torch.backends.cudnn.benchmark = True

    epoch_losses = []

    for epoch in range(args.epochs):
        tracker.new_epoch()
        with mp.Pool(10) as pool:
            if not args.eval_only:
                e_loss = run(net, train_loader, optimizer, train=True, epoch=epoch, pool=pool)
                epoch_losses.append(e_loss)
            if not args.train_only:
                e_loss = run(net, test_loader, optimizer, train=False, epoch=epoch, pool=pool)

        results = {
            "name": args.name,
            "tracker": tracker.data,
            "weights": net.state_dict()
            if not args.multi_gpu
            else net.module.state_dict(),
            "args": vars(args),
            "hash": git_hash,
        }
        torch.save(results, os.path.join("logs", args.name))
        if args.eval_only:
            break

    #plot encoding/decoding
    #train loader is bkdg points, test loader is signal points
    with mp.Pool(10) as pool:
        train_encodings = []
        for i, sample in enumerate(train_loader):
            # input is either a set or an image
            input, target_set, target_mask = map(lambda x: x.cuda(), sample)

            # forward evaluation through the network
            (progress, masks, evals, gradn), (y_enc, y_label) = net(
                input, target_set, target_mask
            )
            train_encodings.append(y_label)
        test_encodings = []
        for i, sample in enumerate(test_loader):
            # input is either a set or an image
            input, target_set, target_mask = map(lambda x: x.cuda(), sample)

            # forward evaluation through the network
            (progress, masks, evals, gradn), (y_enc, y_label) = net(
                input, target_set, target_mask
            )
            test_encodings.append(y_label)



    #recall: "train" is actually just bkgd and "test" is really just signal,
    #because we're trying to learn the background distribution and then encode signal points in the same way, and see what happens
    #so train an MLP and see if it can distinguish
    num_bkdg_val = len(train_encodings) // 3
    num_sign_val = len(test_encodings) // 3
    bkgd_train_encodings_tensor = torch.cat(train_encodings[:-num_bkdg_val], dim = 0)
    sign_train_encodings_tensor = torch.cat(test_encodings[:-num_sign_val], dim = 0)
    bkgd_train_encoding_label_tensor = torch.zeros((len(train_encodings) - num_bkdg_val) * args.batch_size, dtype=torch.long)
    sign_train_encoding_label_tensor = torch.ones((len(test_encodings) - num_sign_val) * args.batch_size, dtype=torch.long)

    train_encodings_tensor = torch.cat((bkgd_train_encodings_tensor, sign_train_encodings_tensor)).to('cuda:0')
    train_encodings_tensor_copy = torch.tensor(train_encodings_tensor, requires_grad=True).to('cuda:0')
    train_labels_tensor = torch.cat((bkgd_train_encoding_label_tensor, sign_train_encoding_label_tensor)).to('cuda:0')

    bkgd_test_encodings_tensor = torch.cat(train_encodings[-num_bkdg_val:])
    sign_test_encodings_tensor = torch.cat(test_encodings[-num_sign_val:])
    bkgd_test_encoding_label_tensor = torch.zeros((num_bkdg_val) * args.batch_size, dtype=torch.long)
    sign_test_encoding_label_tensor = torch.ones((num_sign_val) * args.batch_size, dtype=torch.long)

    test_encodings_tensor = torch.cat((bkgd_test_encodings_tensor, sign_test_encodings_tensor)).to('cuda:0')
    test_encodings_tensor_copy = torch.tensor(test_encodings_tensor, requires_grad=True).to('cuda:0')
    test_labels_tensor = torch.cat((bkgd_test_encoding_label_tensor, sign_test_encoding_label_tensor)).to('cuda:0')

    _, encoding_len = y_label.shape
    mlp = dummymlp.MLP(embedding_dim=encoding_len, hidden_dim=20, label_dim=2).to('cuda:0')
    loss_func = nn.CrossEntropyLoss().to('cuda:0')
    #mlp.train()

    epochs = 20
    for epoch in range(epochs):
        mlp.zero_grad()
        output = mlp(train_encodings_tensor_copy)
        loss = loss_func(output, train_labels_tensor)
        loss.backward()

    #test
    #mlp.eval()
    test_output = mlp(test_encodings_tensor_copy)
    argmaxed_test_output = torch.argmax(test_output, dim=1)
    total = argmaxed_test_output.size(0)
    correct = (argmaxed_test_output == test_labels_tensor).sum().item()

    print(f'acc = {correct / total}')



    #create scatter of encoded points, put through tsne
    # see: https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
    tsne = TSNE(n_components = 2)
    num_test = len(test_encodings) * args.batch_size

    fig = plt.figure()
    set_encoding_in_hidden_dim_train = torch.cat(train_encodings, dim = 0).detach().cpu()
    set_encoding_in_hidden_dim_test = torch.cat(test_encodings, dim = 0).detach().cpu()
    set_encoding_in_hidden_dim = torch.cat((set_encoding_in_hidden_dim_test, set_encoding_in_hidden_dim_train), dim = 0)

    set_encoding_tsne = tsne.fit_transform(set_encoding_in_hidden_dim)
    set_encoding_tsne_x_train = set_encoding_tsne[num_test:, 0]
    set_encoding_tsne_y_train = set_encoding_tsne[num_test:, 1]
    set_encoding_tsne_x_test = set_encoding_tsne[:num_test, 0]
    set_encoding_tsne_y_test = set_encoding_tsne[:num_test, 1]

    plt.scatter(set_encoding_tsne_x_train, set_encoding_tsne_y_train)
    plt.scatter(set_encoding_tsne_x_test, set_encoding_tsne_y_test)
    name = f'img-latent-NEW'
    plt.savefig(name, dpi=300)
    plt.close(fig)

    fig = plt.figure()
    plt.scatter(range(args.epochs), epoch_losses)
    name = f"img-losses-per-epoch-train-{'train' if not args.eval_only else 'test'}"
    plt.savefig(name, dpi=300)