Example #1
File: debug.py Project: kzka/kronos
def main(args):
    log_dir = osp.join(CONFIG.DIRS.LOG_DIR, "debug")

    # Initialize experiment
    config, device = experiment_utils.init_experiment(
        log_dir, CONFIG, args.config_path, transient=True)

    # Load model, data loaders and trainer.
    debug = {
        "sample_sequential": not args.shuffle,
        "augment": args.augment,
        "num_workers": 0,
    }
    model, optimizer, loaders, trainer, _ = experiment_utils.get_factories(
        config, device, debug=debug)

    num_ctx_frames = config.SAMPLING.NUM_CONTEXT_FRAMES
    num_frames = config.SAMPLING.NUM_FRAMES_PER_SEQUENCE
    try:
        loader = loaders["pretrain_train"]
        for batch_idx, batch in enumerate(loader):
            if batch_idx > 4:
                break
            print(f"Batch #{batch_idx}")
            frames = batch["frames"]
            b, _, c, h, w = frames.shape
            frames = frames.view(b, num_frames, num_ctx_frames, c, h, w)

            # # To visualize baseline batch.
            # for i in range(frames.shape[1]):
            #     grid_img = torchvision.utils.make_grid(frames[:, i, -1], nrow=5)
            #     plt.imshow(grid_img.permute(1, 2, 0))
            #     plt.show()

            for n in range(frames.shape[0]):  # avoid shadowing the batch size b above
                print(f"\tBatch Item {n}")
                if args.viz_context:
                    fig, axes = plt.subplots(
                        num_ctx_frames, num_frames, constrained_layout=True
                    )
                    for i in range(num_frames):
                        for j in range(num_ctx_frames):
                            axes[j, i].imshow(
                                UnNormalize()(frames[n, i, j]).permute(1, 2, 0)
                            )
                    for ax in axes.flatten():
                        ax.axis("off")
                    plt.show()
                else:
                    grid_img = torchvision.utils.make_grid(
                        frames[n, :, -1], nrow=5)
                    # grid_img = UnNormalize()(grid_img)
                    plt.imshow(grid_img.permute(1, 2, 0))
                    plt.show()
    except KeyboardInterrupt:
        sys.exit()
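
The snippet above calls UnNormalize() to undo the dataloader's input normalization before plotting. The class itself is not shown; a minimal sketch, assuming standard ImageNet mean/std (the actual statistics live in the kronos config), might look like:

import torch

class UnNormalize:
    """Sketch of an inverse of torchvision-style normalization; ImageNet stats assumed."""

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def __call__(self, tensor):
        # tensor: (C, H, W) or (..., C, H, W) image(s) normalized as (x - mean) / std.
        return tensor * self.std + self.mean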
Example #2
def setup_embedder(log_dir, config_path):
    checkpoint_dir = os.path.join(log_dir, 'checkpoints')
    config, device = experiment_utils.init_experiment(log_dir, CONFIG,
                                                      config_path)
    embedder, _, _, _, _ = experiment_utils.get_factories(
        config, device, debug={'augment': False})
    checkpoint_manager = checkpoint.CheckpointManager(
        checkpoint.Checkpoint(embedder), checkpoint_dir, device)
    global_step = checkpoint_manager.restore_or_initialize()
    print(f"Restored model from checkpoint {global_step}")
    return embedder
Example #3
def setup_embedder(log_dir, config_path):
    """Load the latest embedder checkpoint and dataloaders."""
    checkpoint_dir = os.path.join(log_dir, 'checkpoints')
    config, device = experiment_utils.init_experiment(log_dir, CONFIG,
                                                      config_path)
    embedder, _, loaders, _, _ = experiment_utils.get_factories(
        config, device, debug={'augment': False})
    checkpoint_manager = checkpoint.CheckpointManager(
        checkpoint.Checkpoint(embedder), checkpoint_dir, device)
    global_step = checkpoint_manager.restore_or_initialize()
    l2_normalize = config.LOSS.L2_NORMALIZE_EMBEDDINGS
    print(f"Restored model from checkpoint {global_step}")
    return embedder, loaders, l2_normalize
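
A usage sketch for the function above, assuming the model returns a dict with an "embs" key as in Example #6; the log dir, config path, and loader key are illustrative:

import torch

# Hypothetical usage of setup_embedder; paths are made up.
embedder, loaders, l2_normalize = setup_embedder("logs/exp0", "configs/base.yml")
embedder.eval()
batch = next(iter(loaders["pretrain_valid"]))
with torch.no_grad():
    embs = embedder(batch["frames"])["embs"]  # output convention from Example #6
    if l2_normalize:
        embs = torch.nn.functional.normalize(embs, dim=-1)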
Example #4
def main(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logging.info("Using GPU {}.".format(
            torch.cuda.get_device_name(device)))
    else:
        logging.info("No GPU found. Falling back to CPU.")
        device = torch.device("cpu")

    # initialize experiment
    opts = [
        "SAMPLING.STRIDE_ALL_SAMPLER",
        args.stride,
        "ACTION_CLASS",
        [],
    ]
    config_path = osp.join(args.logdir, "config.yml")
    config, device = experiment_utils.init_experiment(
        args.logdir,
        CONFIG,
        config_path,
        opts,
    )

    # load model and data loaders
    debug = {"sample_sequential": False, "augment": False, "labeled": "both"}
    tcc_model, _, loaders, _, _ = experiment_utils.get_factories(config,
                                                                 device,
                                                                 debug=debug)

    # load model checkpoint
    checkpoint.CheckpointManager.load_latest_checkpoint(
        checkpoint.Checkpoint(tcc_model),
        osp.join(config.DIRS.CKPT_DIR,
                 osp.basename(osp.normpath(args.logdir))),
        device,
    )
    tcc_model.to(device)
    tcc_model.featurizer_net.eval()
    tcc_model.encoder_net.train()
    freeze_model(tcc_model.featurizer_net, True, True)

    # for name, param in tcc_model.named_parameters():
    #     if param.requires_grad:
    #         print(name)

    # create LSTM model
    lstm = LSTMModel(
        in_dim=config.MODEL.EMBEDDER.EMBEDDING_SIZE,
        lstm_dim=32,
        num_layers=1,
    )
    lstm.to(device).train()

    optimizer = torch.optim.Adam(
        [
            {
                "params": lstm.parameters()
            },
            {
                "params": tcc_model.parameters(),
                "lr": 1e-4
            },
        ],
        lr=1e-3,
    )

    # Compute the largest batch size no greater than 128 that is a multiple
    # of the number of context frames, so that long videos with many frames
    # can be chunked evenly (e.g. 3 context frames -> floor(128 / 3) * 3 = 126).
    num_ctx = tcc_model.num_ctx_frames
    max_batch_size = math.floor(128 / num_ctx) * num_ctx

    # test on query
    rand_corr, rand_num = test_query(
        device,
        lstm,
        tcc_model,
        max_batch_size,
        (
            loaders["downstream_train"]["rms"],
            loaders["downstream_valid"]["rms"],
        ),
        args.l2_normalize,
        args.batch_size,
    )
    rand_acc = rand_corr / rand_num
    print("Randomly initialized LSTM accuracy: {} ({}/{})".format(
        rand_acc, rand_corr, rand_num))

    global_step = 0
    complete = False
    try:
        while not complete:
            accs = []
            for action_name, loader in loaders["downstream_train"].items():
                tcc_model.encoder_net.train()
                lstm.train()
                if action_name != "rms":
                    batch_embs, batch_labels = [], []
                    correct, num_x = 0, 0
                    for batch_idx, batch in enumerate(loader):
                        if batch_idx > 2:
                            break
                        if len(batch_embs) < args.batch_size:
                            frames = batch["frames"]
                            embs = embed(
                                tcc_model,
                                frames,
                                device,
                                max_batch_size,
                                args.l2_normalize,
                            )
                            batch_embs.append(embs[0])
                            batch_labels.extend(batch["success"])

                        if len(batch_embs) == args.batch_size or batch_idx == (
                                len(loader) - 1):
                            # sort list of embeddings by sequence length
                            # in descending order
                            idxs_sorted = np.argsort(
                                [-len(x) for x in batch_embs])
                            batch_embs = [batch_embs[i] for i in idxs_sorted]
                            batch_labels = torch.stack(
                                [batch_labels[i] for i in idxs_sorted])
                            batch_labels = (
                                batch_labels.unsqueeze(1).float().to(device))

                            # forward through lstm and compute loss
                            out = lstm(batch_embs)
                            loss = F.binary_cross_entropy_with_logits(
                                out, batch_labels)

                            # backprop
                            optimizer.zero_grad()
                            loss.backward()
                            torch.nn.utils.clip_grad_norm_(
                                lstm.parameters(), 1.0)
                            optimizer.step()

                            # compute accuracy
                            pred = torch.sigmoid(out) > 0.5
                            correct += (pred.eq(
                                batch_labels.view_as(pred)).sum().item())
                            num_x += len(pred)
                            logging.info("{}: Loss: {:.6f}".format(
                                global_step, loss.item()))

                            global_step += 1
                            batch_embs, batch_labels = [], []

                        # exit if complete
                        if global_step >= args.max_iters:
                            complete = True
                            break

                    if num_x > 0:
                        acc = correct / num_x
                        logging.info("{} accuracy: {:.2f} ({}/{})".format(
                            action_name, acc, correct, num_x))
                        accs.append(acc)
            print("Mean embodiment accuracy: {}".format(np.mean(accs)))

            query_corr, query_num = test_query(
                device,
                lstm,
                tcc_model,
                max_batch_size,
                (
                    loaders["downstream_train"]["rms"],
                    loaders["downstream_valid"]["rms"],
                ),
                args.l2_normalize,
                args.batch_size,
            )
            query_acc = query_corr / query_num
            print("Learner accuracy: {} ({}/{})".format(
                query_acc, query_corr, query_num))

    except KeyboardInterrupt:
        logging.info("Caught keyboard interrupt. Exiting.")
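
The training loop sorts each batch of variable-length embedding sequences by descending length before the forward pass, which is exactly the ordering torch.nn.utils.rnn.pack_sequence expects with its default enforce_sorted=True. LSTMModel itself is not shown; a minimal sketch consistent with how it is called here (a list of (T_i, in_dim) tensors in, one logit per sequence out) might be:

import torch.nn as nn
from torch.nn.utils.rnn import pack_sequence

class LSTMModel(nn.Module):
    """Sketch of an LSTM classifier over variable-length sequences; API assumed from usage."""

    def __init__(self, in_dim, lstm_dim, out_dim=1, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, lstm_dim, num_layers, batch_first=True)
        self.head = nn.Linear(lstm_dim, out_dim)

    def forward(self, seqs):
        # seqs: list of (T_i, in_dim) tensors, sorted by length in descending order.
        packed = pack_sequence(seqs)  # enforce_sorted=True by default
        _, (h_n, _) = self.lstm(packed)
        return self.head(h_n[-1])  # (B, out_dim) logits for BCE-with-logits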
Example #5
def main(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logging.info("Using GPU {}.".format(
            torch.cuda.get_device_name(device)))
    else:
        logging.info("No GPU found. Falling back to CPU.")
        device = torch.device("cpu")

    # initialize experiment
    opts = [
        "SAMPLING.STRIDE_ALL_SAMPLER",
        args.stride,
        "ACTION_CLASS",
        [],
        # "IMAGE_SIZE",
        # (216, 224),
        "DATASET",
        "embodied_glasses",
    ]
    config_path = osp.join(args.logdir, "config.yml")
    config, device = experiment_utils.init_experiment(
        args.logdir,
        CONFIG,
        config_path,
        opts,
    )

    # load TCC model and checkpoint
    debug = {
        "sample_sequential": True,
        "augment": False,
        "labeled": "both",
    }
    tcc_model, _, loaders, _, _ = experiment_utils.get_factories(config,
                                                                 device,
                                                                 debug=debug)
    # Restore the latest checkpoint. restore_or_initialize already loads it,
    # so a separate load_latest_checkpoint call beforehand is redundant.
    checkpoint_manager = checkpoint.CheckpointManager(
        checkpoint.Checkpoint(tcc_model),
        osp.join(config.DIRS.CKPT_DIR,
                 osp.basename(osp.normpath(args.logdir))),
        device,
    )
    init_step = checkpoint_manager.restore_or_initialize()
    # tcc_model.to(device).eval()
    tcc_model.to(device)
    tcc_model.featurizer_net.eval()
    tcc_model.encoder_net.train()
    freeze_model(tcc_model.featurizer_net, True, True)

    # create LSTM model
    lstm = LSTMModel(
        in_dim=config.MODEL.EMBEDDER.EMBEDDING_SIZE,
        lstm_dim=512,
        out_dim=1,
        num_layers=1,
    )
    # lstm = MLPModel(
    #     in_dim=config.MODEL.EMBEDDER.EMBEDDING_SIZE,
    #     hidden_dim=128,
    #     out_dim=1,
    # )
    lstm.to(device).train()

    # optimizer = torch.optim.Adam(lstm.parameters(), lr=1e-3, weight_decay=0)
    optimizer = torch.optim.Adam(
        [
            {
                "params": lstm.parameters(),
                "weight_decay": 1e-5
            },
            {
                "params": tcc_model.parameters(),
                "lr": 1e-4
            },
        ],
        lr=1e-3,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[100, 250],
                                                     gamma=0.1)

    # Compute the largest batch size no greater than 128 that is a multiple
    # of the number of context frames, so that long videos with many frames
    # can be chunked evenly.
    num_ctx = tcc_model.num_ctx_frames
    max_batch_size = math.floor(128 / num_ctx) * num_ctx

    query_metric = test_query(
        device,
        lstm,
        tcc_model,
        max_batch_size,
        (
            loaders["downstream_train"][args.query],
            loaders["downstream_valid"][args.query],
        ),
        args.l2_normalize,
        args.batch_size,
    )
    conf_matrix = query_metric["confusion_matrix"]
    num_correct = conf_matrix[0, 0] + conf_matrix[1, 1]
    num_total = conf_matrix.ravel().sum()
    print(
        f"Random LSTM accuracy: {query_metric['accuracy']} ({num_correct}/{num_total})"
    )

    global_step = 0
    max_acc = 0
    complete = False
    losses, expert_acc, learner_acc = [], [], []
    try:
        while not complete:
            accs = []
            lstm.train()
            tcc_model.encoder_net.train()
            for action_name, loader in loaders["downstream_train"].items():
                if action_name != args.query:
                    batch_embs, batch_labels, batch_names = [], [], []
                    correct, num_x = 0, 0
                    for batch_idx, batch in enumerate(loader):
                        if len(batch_embs) < args.batch_size:
                            frames = batch["frames"]
                            # b, t, c, h, w = frames.shape
                            # frames = frames.view(b, t // tcc_model.num_ctx_frames, tcc_model.num_ctx_frames, c, h, w)
                            # img = UnNormalize()(frames[:, -1, -1])[0].permute(1, 2, 0).cpu().numpy()
                            # plt.imshow(img)
                            # plt.show()
                            # set_trace()
                            # append, not extend: extend() would add the
                            # name string character by character
                            batch_names.append(batch["video_name"][0])
                            embs = embed(
                                tcc_model,
                                frames,
                                device,
                                max_batch_size,
                                args.l2_normalize,
                            )
                            batch_embs.append(embs[0])
                            batch_labels.extend(batch["success"])

                        if len(batch_embs) == args.batch_size or batch_idx == (
                                len(loader) - 1):
                            # batch_embs, batch_labels = augment_batch(batch_embs, batch_labels)

                            # sort list of embeddings by sequence length
                            # in descending order
                            idxs_sorted = np.argsort(
                                [-len(x) for x in batch_embs])
                            batch_embs = [batch_embs[i] for i in idxs_sorted]
                            batch_labels = torch.stack(
                                [batch_labels[i] for i in idxs_sorted])
                            batch_labels = (
                                batch_labels.unsqueeze(1).float().to(device))
                            batch_names = [batch_names[i] for i in idxs_sorted]

                            # forward through lstm and compute loss
                            out = lstm(batch_embs)
                            loss = F.binary_cross_entropy_with_logits(
                                out, batch_labels)
                            losses.append(loss.item())

                            # backprop
                            optimizer.zero_grad()
                            loss.backward()
                            # torch.nn.utils.clip_grad_norm_(lstm.parameters(), 1.0)
                            optimizer.step()
                            scheduler.step()
                            for param_group in optimizer.param_groups:
                                print(global_step, param_group["lr"])

                            # compute accuracy
                            with torch.no_grad():
                                probs = torch.sigmoid(out)
                                pred = probs > 0.5
                                # corr = pred.eq(batch_labels)
                                # pred_np = corr.flatten().cpu().numpy()
                                # mistakes = [batch_names[i] for i in np.argwhere(pred_np == False).flatten().tolist()]
                                # if global_step > 100:
                                #     for m in mistakes:
                                #         print(m)
                                correct += (pred.eq(
                                    batch_labels.view_as(pred)).sum().item())
                                num_x += len(pred)
                                logging.info("{}: Loss: {:.6f}".format(
                                    global_step, loss.item()))

                            global_step += 1
                            batch_embs, batch_labels, batch_names = [], [], []

                        # exit if complete
                        if global_step >= args.max_iters:
                            complete = True
                            break

                    if num_x > 0:
                        acc = correct / num_x
                        logging.info("{} accuracy: {:.2f} ({}/{})".format(
                            action_name, acc, correct, num_x))
                        accs.append(acc)
            mean_acc = np.mean(accs)
            expert_acc.append(mean_acc)
            print("Mean embodiment accuracy: {}".format(mean_acc))

            plot_loss(losses, osp.join(args.logdir, "lstm_loss.png"))
            query_metric = test_query(
                device,
                lstm,
                tcc_model,
                max_batch_size,
                (
                    loaders["downstream_train"][args.query],
                    loaders["downstream_valid"][args.query],
                ),
                args.l2_normalize,
                args.batch_size,
            )
            learner_acc.append(query_metric["accuracy"])
            plot_acc(expert_acc, learner_acc,
                     osp.join(args.logdir, "lstm_acc.png"))
            if query_metric["accuracy"] > max_acc:
                max_acc = query_metric["accuracy"]
            conf_matrix = query_metric["confusion_matrix"]
            num_correct = conf_matrix[0, 0] + conf_matrix[1, 1]
            num_total = conf_matrix.ravel().sum()
            print(
                f"Learner accuracy: {query_metric['accuracy']} ({num_correct}/{num_total})"
            )

    except KeyboardInterrupt:
        logging.info(
            "Caught keyboard interrupt. Saving model before quitting.")
    finally:
        conf_matrix = query_metric["confusion_matrix"]
        np.savetxt(osp.join(args.logdir, "confusion_matrix_lstm.txt"),
                   conf_matrix)
        plot_auc(query_metric, osp.join(args.logdir, "auc_curve_lstm.png"))
        learner_acc.append(query_metric["accuracy"])
        plot_acc(expert_acc, learner_acc, osp.join(args.logdir,
                                                   "lstm_acc.png"))
        plot_loss(losses, osp.join(args.logdir, "lstm_loss.png"))
        checkpoint_manager.save(init_step + global_step)
        print("best acc achieved: {}".format(max_acc))
Example #6
def main(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logging.info("Using GPU {}.".format(
            torch.cuda.get_device_name(device)))
    else:
        logging.info("No GPU found. Falling back to CPU.")
        device = torch.device("cpu")

    # initialize experiment
    opts = [
        "SAMPLING.STRIDE_ALL_SAMPLER",
        args.stride,
    ]
    config_path = osp.join(args.logdir, "config.yml")
    config, device = experiment_utils.init_experiment(
        args.logdir,
        CONFIG,
        config_path,
        opts,
    )

    # load model and data loaders
    debug = {
        "sample_sequential": not args.shuffle,
        "augment": False,
        "labeled": None,  # "both" if args.negative else "pos",
    }
    model, _, loaders, _, _ = experiment_utils.get_factories(config,
                                                             device,
                                                             debug=debug)
    loaders = (loaders["downstream_valid"]
               if args.split == "valid" else loaders["downstream_train"])

    # load model checkpoint
    if args.model_ckpt is not None:
        checkpoint.Checkpoint(model).restore(args.model_ckpt, device)
    else:
        checkpoint.CheckpointManager.load_latest_checkpoint(
            checkpoint.Checkpoint(model),
            osp.join(config.DIRS.CKPT_DIR,
                     osp.basename(osp.normpath(args.logdir))),
            device,
        )
    model.to(device).eval()

    # Compute the largest batch size no greater than 128 that is a multiple
    # of the number of context frames, so that long videos with many frames
    # can be chunked evenly.
    num_ctx = model.num_ctx_frames
    max_batch_size = math.floor(128 / num_ctx) * num_ctx

    # create save folder
    save_path = osp.join(
        config.DIRS.DIR,
        args.save_path,
        osp.basename(osp.normpath(args.logdir)),
        "embs",
    )
    file_utils.mkdir(save_path)

    # iterate over every class action
    pbar = tqdm.tqdm(loaders.items(), leave=False)
    for action_name, loader in pbar:
        msg = "embedding {}".format(action_name)
        pbar.set_description(msg)

        (
            embeddings,
            seq_lens,
            steps,
            vid_frames,
            names,
            labels,
            phase_labels,
        ) = ([] for _ in range(7))
        for batch_idx, batch in enumerate(loader):
            if args.max_embs != -1 and batch_idx >= args.max_embs:
                break

            # unpack batch data
            frames = batch["frames"]
            chosen_steps = batch["frame_idxs"].to(device)
            seq_len = batch["video_len"].to(device)
            name = batch["video_name"][0]
            # label = batch["success"][0]
            # phase_label = None
            # if "phase_labels" in batch:
            #     phase_label = batch["phase_labels"].to(device)

            # forward through model to compute embeddings
            with torch.no_grad():
                if frames.shape[1] > max_batch_size:
                    embs = []
                    for i in range(math.ceil(frames.shape[1] /
                                             max_batch_size)):
                        sub_frames = frames[:, i * max_batch_size:(i + 1) *
                                            max_batch_size].to(device)
                        sub_embs = model(sub_frames)["embs"]
                        embs.append(sub_embs.cpu())
                    embs = torch.cat(embs, dim=1)
                else:
                    embs = model(frames.to(device))["embs"]

            # store
            embeddings.append(embs.cpu().squeeze().numpy())
            seq_lens.append(seq_len.cpu().squeeze().numpy())
            steps.append(chosen_steps.cpu().squeeze().numpy())
            # if phase_label is not None:
            # phase_labels.append(phase_label.cpu().squeeze().numpy())
            names.append(name)
            # labels.append(label.item())
            if args.keep_frames:
                frames = frames[0]
                frames = UnNormalize()(frames)
                frames = frames.view(
                    embs.shape[1],
                    config.SAMPLING.NUM_CONTEXT_FRAMES,
                    *frames.shape[1:],
                )
                vid_frames.append(
                    frames.cpu().squeeze().numpy()[:, -1].transpose(0, 2, 3, 1))

        data = {
            "embs": embeddings,
            "seq_lens": seq_lens,
            "steps": steps,
            "names": names,
            # "labels": labels,
        }
        if args.keep_frames:
            data["frames"] = vid_frames
        # if phase_labels:
        # data["phase_labels"] = phase_labels
        np.save(osp.join(save_path, action_name), data)
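
Because np.save pickles the Python dict, reading the embeddings back requires allow_pickle=True and unwrapping the resulting 0-d object array with .item(); for example (path hypothetical):

import numpy as np

# Hypothetical read-back of a file written by the loop above.
data = np.load("embs/action_name.npy", allow_pickle=True).item()
print(len(data["embs"]), data["names"][0], data["embs"][0].shape)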
Example #7
def main(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logging.info("Using GPU {}.".format(
            torch.cuda.get_device_name(device)))
    else:
        logging.info("No GPU found. Falling back to CPU.")
        device = torch.device("cpu")

    # initialize experiment
    opts = [
        "SAMPLING.STRIDE_ALL_SAMPLER",
        args.stride,
        "ACTION_CLASS",
        [],
        "DATASET",
        "embodied_glasses",
    ]
    config_path = osp.join(args.logdir, "config.yml")
    config, device = experiment_utils.init_experiment(
        args.logdir,
        CONFIG,
        config_path,
        opts,
    )

    # determine the number of context frames
    num_ctx_frames = config.SAMPLING.NUM_CONTEXT_FRAMES
    print(f"ctx frames: {num_ctx_frames}")

    # load TCC model and checkpoint
    debug = {
        "sample_sequential": True,
        "augment": False,
        "labeled": "both",
    }
    _, _, loaders, _, _ = experiment_utils.get_factories(config,
                                                         device,
                                                         debug=debug)

    # create model
    model = Model(in_dim=3, out_dim=1)
    model.to(device).train()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[100, 250],
                                                     gamma=0.1)

    query_metric = test_query(
        device,
        model,
        (
            loaders["downstream_train"][args.query],
            loaders["downstream_valid"][args.query],
        ),
        args.batch_size,
        num_ctx_frames,
    )
    conf_matrix = query_metric["confusion_matrix"]
    num_correct = conf_matrix[0, 0] + conf_matrix[1, 1]
    num_total = conf_matrix.ravel().sum()
    print(
        f"Initial pretrained model accuracy: {query_metric['accuracy']} ({num_correct}/{num_total})"
    )

    global_step = 0
    max_acc = 0
    complete = False
    losses, expert_acc, learner_acc = [], [], []
    try:
        while not complete:
            accs = []
            model.train()
            for action_name, loader in loaders["downstream_train"].items():
                if action_name != args.query:
                    batch_embs, batch_labels, batch_names = [], [], []
                    correct, num_x = 0, 0
                    for batch_idx, batch in enumerate(loader):
                        if len(batch_embs) < args.batch_size:
                            frames = batch["frames"]
                            b, t, c, h, w = frames.shape
                            frames = frames.view(b, t // num_ctx_frames,
                                                 num_ctx_frames, c, h, w)
                            frame = frames[:, -1, -1]
                            # img = UnNormalize()(frame)[0].permute(1, 2, 0).cpu().numpy()
                            # plt.imshow(img)
                            # plt.show()
                            # set_trace()
                            # append, not extend: extend() would add the
                            # name string character by character
                            batch_names.append(batch["video_name"][0])
                            batch_embs.append(frame)
                            batch_labels.extend(batch["success"])
                        if len(batch_embs) == args.batch_size or batch_idx == (
                                len(loader) - 1):
                            batch_embs = torch.cat(batch_embs,
                                                   dim=0).to(device)
                            batch_labels = (torch.stack(
                                batch_labels).unsqueeze(1).float().to(device))

                            # forward through model and compute loss
                            out = model(batch_embs)
                            loss = F.binary_cross_entropy_with_logits(
                                out, batch_labels)
                            losses.append(loss.item())

                            # backprop
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()
                            scheduler.step()
                            for param_group in optimizer.param_groups:
                                print(global_step, param_group["lr"])

                            # compute accuracy
                            with torch.no_grad():
                                probs = torch.sigmoid(out)
                                pred = probs > 0.5
                                correct += (pred.eq(
                                    batch_labels.view_as(pred)).sum().item())
                                num_x += len(pred)
                                logging.info("{}: Loss: {:.6f}".format(
                                    global_step, loss.item()))

                            global_step += 1
                            batch_embs, batch_labels, batch_names = [], [], []

                        # exit if complete
                        if global_step >= args.max_iters:
                            complete = True
                            break

                    if num_x > 0:
                        acc = correct / num_x
                        logging.info("{} accuracy: {:.2f} ({}/{})".format(
                            action_name, acc, correct, num_x))
                        accs.append(acc)
            mean_acc = np.mean(accs)
            expert_acc.append(mean_acc)
            print("Mean embodiment accuracy: {}".format(mean_acc))

            plot_loss(losses, osp.join(args.logdir, "conv_loss.png"))
            query_metric = test_query(
                device,
                model,
                (
                    loaders["downstream_train"][args.query],
                    loaders["downstream_valid"][args.query],
                ),
                args.batch_size,
                num_ctx_frames,
            )
            learner_acc.append(query_metric["accuracy"])
            plot_acc(expert_acc, learner_acc,
                     osp.join(args.logdir, "conv_acc.png"))
            if query_metric["accuracy"] > max_acc:
                max_acc = query_metric["accuracy"]
            conf_matrix = query_metric["confusion_matrix"]
            num_correct = conf_matrix[0, 0] + conf_matrix[1, 1]
            num_total = conf_matrix.ravel().sum()
            print(
                f"Learner accuracy: {query_metric['accuracy']} ({num_correct}/{num_total})"
            )

    except KeyboardInterrupt:
        logging.info(
            "Caught keyboard interrupt. Saving model before quitting.")
    finally:
        conf_matrix = query_metric["confusion_matrix"]
        np.savetxt(osp.join(args.logdir, "confusion_matrix_conv.txt"),
                   conf_matrix)
        plot_auc(query_metric, osp.join(args.logdir, "auc_curve_conv.png"))
        learner_acc.append(query_metric["accuracy"])
        plot_acc(expert_acc, learner_acc, osp.join(args.logdir,
                                                   "conv_acc.png"))
        plot_loss(losses, osp.join(args.logdir, "conv_loss.png"))
        print("best acc achieved: {}".format(max_acc))
Example #8
def main(args):
    log_dir = osp.join(CONFIG.DIRS.LOG_DIR, args.experiment_name)

    # initialize experiment
    config, device = experiment_utils.init_experiment(log_dir, CONFIG,
                                                      args.config_path)

    # setup logger
    logger = Logger(log_dir, args.resume)

    # load model, data loaders and trainer
    debug = {"sample_sequential": True, "num_workers": 4}
    (
        model,
        optimizer,
        loaders,
        trainer,
        evaluators,
    ) = experiment_utils.get_factories(config, device, debug=debug)

    # create checkpoint manager
    checkpoint_manager = checkpoint.CheckpointManager(
        checkpoint.Checkpoint(model, optimizer),
        osp.join(config.DIRS.CKPT_DIR, args.experiment_name),
        device,
    )

    global_step = checkpoint_manager.restore_or_initialize()
    epoch = global_step // len(loaders["pretrain_train"])
    complete = False
    stopwatch = experiment_utils.Stopwatch()
    try:
        while not complete:
            logger.log_learning_rate(optimizer, global_step)
            for batch_idx, batch in enumerate(loaders["pretrain_train"]):
                if batch_idx >= args.num_samples:
                    break

                # train one iteration
                loss = trainer.train_one_iter(batch, global_step)

                # save model checkpoint
                if not global_step % config.CHECKPOINT.SAVE_INTERVAL:
                    checkpoint_manager.save(global_step)

                if not global_step % config.LOGGING.REPORT_INTERVAL:
                    # train loss with model in eval mode
                    train_eval_loss = trainer.eval_num_iters(
                        loaders["pretrain_train"],
                        (args.num_samples // config.BATCH_SIZE),
                    )
                    losses = {"train": loss, "train_eval": train_eval_loss}
                    logger.log_loss(losses, global_step)
                    logging.info(
                        "Iter[{}/{}] (Epoch {}), Train (eval) Loss: {:.3f}".
                        format(
                            global_step,
                            config.TRAIN.MAX_ITERS,
                            epoch,
                            train_eval_loss.item(),
                        ))

                if not global_step % config.LOGGING.EVAL_INTERVAL:
                    for eval_name, evaluator in evaluators.items():
                        eval_metric_train = evaluator.evaluate(
                            global_step,
                            model,
                            loaders["downstream_train"],
                            device,
                            args.num_samples,
                            msg="train",
                        )
                        metrics = {"train": eval_metric_train}
                        logger.log_metric(metrics, global_step, eval_name)
                        logging.info(
                            "Iter[{}/{}] (Epoch {}) - {} train: {:.4f}".format(
                                global_step,
                                config.TRAIN.MAX_ITERS,
                                epoch,
                                eval_name,
                                eval_metric_train["scalar"],
                            ))

                # exit if complete
                global_step += 1
                if global_step >= config.TRAIN.MAX_ITERS:
                    complete = True
                    break

                time_per_iter = stopwatch.elapsed()
                logging.info(
                    "Iter[{}/{}] (Epoch {}), {:.1f}s/iter, Loss: {:.3f}".
                    format(
                        global_step,
                        config.TRAIN.MAX_ITERS,
                        epoch,
                        time_per_iter,
                        loss.item(),
                    ))
                stopwatch.reset()
            epoch += 1

    except KeyboardInterrupt:
        logging.info(
            "Caught keyboard interrupt. Saving model before quitting.")

    finally:
        checkpoint_manager.save(global_step)
        logger.close()
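
experiment_utils.Stopwatch only needs elapsed() and reset() here; a minimal wall-clock sketch matching that usage (assumed, not the project's actual implementation) is:

import time

class Stopwatch:
    """Sketch of a timer matching the elapsed()/reset() calls above."""

    def __init__(self):
        self.reset()

    def elapsed(self):
        # Seconds since construction or the last reset().
        return time.time() - self._start

    def reset(self):
        self._start = time.time()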
Example #9
def main(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logging.info(
            "Using GPU {}.".format(torch.cuda.get_device_name(device))
        )
    else:
        logging.info("No GPU found. Falling back to CPU.")
        device = torch.device("cpu")

    # Initialize experiment.
    opts = [
        # "SAMPLING.STRIDE_ALL_SAMPLER",
        # args.stride,
        "IMAGE_SIZE",
        (64, 64),
        "ACTION_CLASS",
        [args.query],
        "DATASET",
        "demo",
    ]
    config_path = osp.join(args.logdir, "config.yml")
    config, device = experiment_utils.init_experiment(
        args.logdir, CONFIG, config_path, opts,
    )

    # Load TCC model and checkpoint.
    debug = {
        "sample_sequential": False,
        "augment": False,
        "labeled": None,
    }
    _, _, loaders, _, _ = experiment_utils.get_factories(
        config, device, debug=debug
    )
    # checkpoint.CheckpointManager.load_latest_checkpoint(
    #     checkpoint.Checkpoint(tcc_model),
    #     osp.join(
    #         config.DIRS.CKPT_DIR, osp.basename(osp.normpath(args.logdir))
    #     ),
    #     device,
    # )

    # Freeze TCC weights.
    # tcc_model.to(device)
    # tcc_model.featurizer_net.eval()
    # freeze_model(tcc_model.featurizer_net, False, False)
    # tcc_model.encoder_net.train()

    probe = Net(out_dim=5)
    # probe = LinearProbe(in_dim=config.MODEL.EMBEDDER.EMBEDDING_SIZE, out_dim=5)
    probe.to(device).train()

    optimizer = torch.optim.Adam(
        [
            {"params": probe.parameters(), "weight_decay": 0},
            # {"params": tcc_model.parameters()},
        ],
        lr=1e-3,
    )

    # Compute the largest batch size no greater than 32 that is a multiple of
    # the number of context frames, so long videos with many frames can be
    # chunked evenly.
    num_ctx = 1  # tcc_model.num_ctx_frames
    max_batch_size = math.floor(32 / num_ctx) * num_ctx

    # Grab dataloader.
    train_loader = loaders["pretrain_train"]  # [args.query]
    test_loader = loaders["pretrain_valid"]  # [args.query]

    global_step = 0
    complete = False
    try:
        while not complete:
            probe.train()
            # tcc_model.encoder_net.train()
            for batch_idx, batch in enumerate(train_loader):
                frames = batch["frames"].to(device)
                target = batch["debris_nums"].to(device)

                # Forward pass.
                loss, acc = embed_with_grad(
                    None,  # tcc_model,
                    probe,
                    frames,
                    target,
                    device,
                    max_batch_size,
                    args.l2_normalize,
                    optimizer,
                )

                print(f"{global_step}: Loss: {loss} - Acc: {100 * acc}")
                global_step += 1

            # Get test accuracy.
            test_probe(
                test_loader,
                None,  # tcc_model,
                probe,
                device,
                max_batch_size,
                args.l2_normalize,
                global_step,
            )

            # Exit if complete.
            if global_step >= args.max_iters:
                complete = True
                break

    except KeyboardInterrupt:
        logging.info(
            "Caught keyboard interrupt. Saving model before quitting."
        )
    finally:
        print("Exiting.")