Example #1
    def setup_model(self):
        torch_devices = self.args.pytorch_gpu_ids
        device = "cuda:" + str(torch_devices[0])
        args = copy.deepcopy(self.args)
        args.title = os.path.join(args.title, "VinceModel")
        args.tensorboard_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.checkpoint_dir.split(os.sep)[2:-1]), constants.TIME_STR)
        args.checkpoint_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.checkpoint_dir.split(os.sep)[2:]))
        args.long_save_checkpoint_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.long_save_checkpoint_dir.split(os.sep)[2:-1]),
            constants.TIME_STR)
        self.model = VinceModel(args)
        print(self.model)
        self.iteration = self.model.restore()
        self.model.to(device)
        self.queue_model = VinceQueueModel(args, self.model)
        self.queue_model.to(device)
        self.vince_queue = StorageQueue(args.vince_queue_size,
                                        args.vince_embedding_size,
                                        device=device)

        self.epoch = self.iteration // (self.args.iterations_per_epoch *
                                        self.args.batch_size)
        if self.iteration > 0:
            print("Resuming epoch", self.epoch)
        self.start_prefetch()
        self.fill_queue_repeat()
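
The os.sep surgery above re-roots an existing checkpoint path under a new base log directory and model-specific title: the first two components of the old path are dropped and the remainder is re-attached. A minimal illustration with hypothetical values:

import os

old_checkpoint_dir = os.path.join("logs", "run1", "checkpoints", "latest")
base_logdir = "new_logs"
title = os.path.join("run1", "VinceModel")

# Drop the first two components ("logs", "run1") and re-root the rest.
new_checkpoint_dir = os.path.join(
    base_logdir, title, *(old_checkpoint_dir.split(os.sep)[2:]))
print(new_checkpoint_dir)  # new_logs/run1/VinceModel/checkpoints/latest (POSIX separators)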
Example #2
    def setup_feature_extractor(self):
        args = copy.deepcopy(self.args)
        args.title = os.path.join(args.title, "VinceModel")
        args.checkpoint_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.checkpoint_dir.split(os.sep)[2:]))
        args.long_save_checkpoint_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.long_save_checkpoint_dir.split(os.sep)[2:-1]),
            constants.TIME_STR)
        args.tensorboard_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.checkpoint_dir.split(os.sep)[2:-1]), constants.TIME_STR)
        self.feature_extractor = VinceModel(args)
        print(self.feature_extractor)
        self.feature_extractor.restore()
        self.feature_extractor.to(self.device)
        if self.freeze_feature_extractor:
            self.feature_extractor.eval()
        else:
            self.feature_extractor.train()
Example #3
def main():
    with torch.no_grad():
        torch_devices = args.pytorch_gpu_ids
        device = "cuda:" + str(torch_devices[0])
        model = VinceModel(args)

        model.restore()
        model.eval()
        model.to(device)
        yt_dataset = R2V2Dataset(
            args, "val", transform=StandardVideoTransform(args.input_size, "val"), num_images_to_return=1
        )
        torch.manual_seed(0)
        random.seed(0)
        np.random.seed(0)
        data_loader = PersistentDataLoader(
            yt_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=R2V2Dataset.collate_fn,
            worker_init_fn=R2V2Dataset.worker_init_fn,
        )

        yt_features, yt_images = dataset_nn(model, data_loader)
        del data_loader

        draw_nns(yt_features, yt_images, "youtube")

        torch.manual_seed(0)
        random.seed(0)
        np.random.seed(0)
        data_subset = "val"  # assumed; the transform below also uses the "val" split
        valdir = os.path.join(args.imagenet_data_path, data_subset)
        transform = RepeatedImagenetTransform(args.input_height, data_subset="val", repeats=1)
        imagenet_dataset = datasets.ImageFolder(valdir, transform)
        data_loader = DataLoader(
            imagenet_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True
        )

        imagenet_features, imagenet_images = dataset_nn(model, data_loader)
        del data_loader

        draw_nns(imagenet_features, imagenet_images, "imagenet")
        draw_nns(imagenet_features, imagenet_images, "imagenet", yt_features, yt_images, "youtube")
        draw_nns(yt_features, yt_images, "youtube", imagenet_features, imagenet_images, "imagenet")
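
The three-way seeding (torch, random, numpy) is repeated before each dataset pass so both loaders see identical shuffling. A small helper capturing the idiom (a sketch, not part of the original code):

import random

import numpy as np
import torch

def seed_everything(seed: int = 0) -> None:
    # Seed every RNG the data pipeline touches.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)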
Example #4
class VinceSolver(BaseSolver):
    def __init__(self, args, train_logger=None, val_logger=None):
        # Declare these before the super().__init__ call, which populates them.
        self.num_frames = args.num_frames
        self.train_loaders = []
        self.train_batch_iterators = []
        self.train_loader_counts = []
        self.train_batch_fns = []
        self.train_data_names = []
        self.val_loaders = []
        self.val_batch_fns = []
        self.val_data_names = []
        self.vince_queue: StorageQueue = None
        self.queue_model: VinceQueueModel = None
        self.batch_count = 0
        self.batch_queue = Queue(2)
        self.prefetch_thread = None
        self.kill_thread = False
        self.cifar_dataset: NPZDataset = None
        self.drawn_this_epoch = False

        super(VinceSolver, self).__init__(args, train_logger, val_logger)

    def setup_dataloader(self):
        torch_devices = self.args.pytorch_gpu_ids
        device = torch_devices[0]

        # Can use multiple datasets at once.
        # Delayed start. Create data loaders to make the processes, then create the datasets so they don't overwhelm
        # shared memory.
        print("creating data processes")
        t_start = time.time()
        if self.args.use_imagenet:
            imagenet_train_loader = PersistentDataLoader(
                dataset=None,
                num_workers=min(self.args.num_workers, 40),
                pin_memory=True,
                device=device,
                never_ending=True,
            )
            batch_iterator = iter(imagenet_train_loader)
            batch_fn = self.get_imagenet_batch
            self.train_loaders.append(imagenet_train_loader)
            self.train_batch_iterators.append(batch_iterator)
            self.train_batch_fns.append(batch_fn)
            self.train_loader_counts.append(0)
            self.train_data_names.append("ImageNet")

            imagenet_val_loader = PersistentDataLoader(
                dataset=None,
                num_workers=min(self.args.num_workers, 10),
                pin_memory=False,
                device=device)
            self.val_loaders.append(imagenet_val_loader)
            self.val_batch_fns.append(self.process_imagenet_data)
            self.val_data_names.append("ImageNet")

        if not self.args.disable_dataloader and self.args.use_videos:
            video_train_loader = PersistentDataLoader(
                dataset=None,
                num_workers=min(self.args.num_workers, 40),
                pin_memory=True,
                device=device,
                never_ending=True,
            )
            batch_iterator = iter(video_train_loader)
            batch_fn = self.get_video_batch
            self.train_loaders.append(video_train_loader)
            self.train_batch_iterators.append(batch_iterator)
            self.train_batch_fns.append(batch_fn)
            self.train_loader_counts.append(0)
            self.train_data_names.append("R2V2")

            video_val_loader = PersistentDataLoader(dataset=None,
                                                    num_workers=min(
                                                        self.args.num_workers,
                                                        20),
                                                    pin_memory=False,
                                                    device=device)
            self.val_loaders.append(video_val_loader)
            self.val_batch_fns.append(self.process_video_data)
            self.val_data_names.append("R2V2")

        print("done in %.3f" % (time.time() - t_start))
        print("creating datasets")
        t_start = time.time()

        # Now create the actual datasets
        if not self.args.disable_dataloader and self.args.use_imagenet:
            imagenet_train_loader.set_dataset(
                datasets.ImageFolder(
                    os.path.join(self.args.imagenet_data_path, "train"),
                    transform=self.args.transform(self.input_size,
                                                  "train",
                                                  2 * self.num_frames,
                                                  stack=False),
                ),
                batch_size=self.args.batch_size // self.num_frames,
                shuffle=True,
            )
            print(
                "Loaded ImageNet train",
                len(imagenet_train_loader.dataset),
                "images",
                len(imagenet_train_loader),
                "batches",
            )
            imagenet_val_loader.set_dataset(
                datasets.ImageFolder(
                    os.path.join(self.args.imagenet_data_path, "val"),
                    transform=self.args.transform(self.input_size,
                                                  "val",
                                                  2 * self.num_frames,
                                                  stack=False),
                ),
                batch_size=self.args.batch_size // self.num_frames,
                shuffle=True,
            )
            print("Loaded ImageNet val", len(imagenet_val_loader.dataset),
                  "images", len(imagenet_val_loader), "batches")

        if not self.args.disable_dataloader and self.args.use_videos:
            video_train_loader.set_dataset(
                self.args.dataset(
                    self.args,
                    "train",
                    transform=self.args.transform(self.input_size, "train"),
                    num_images_to_return=self.num_frames,
                ),
                batch_size=self.args.batch_size // self.num_frames,
                shuffle=True,
                collate_fn=self.args.dataset.collate_fn,
                worker_init_fn=self.args.dataset.worker_init_fn,
                drop_last=True,
            )
            print("Loaded Video train", len(video_train_loader.dataset),
                  "images", len(video_train_loader), "batches")

            # Use train transform to make it equally hard.
            video_val_loader.set_dataset(
                self.args.dataset(
                    self.args,
                    "val",
                    transform=self.args.transform(self.input_size, "train"),
                    num_images_to_return=self.num_frames,
                ),
                batch_size=self.args.batch_size // self.num_frames,
                shuffle=True,
                collate_fn=self.args.dataset.collate_fn,
                worker_init_fn=self.args.dataset.worker_init_fn,
            )
            print("Loaded Video val", len(video_val_loader.dataset), "images",
                  len(video_val_loader), "batches")
        print("done in %.3f" % (time.time() - t_start))

    @property
    def iterations_per_epoch(self):
        return self.args.iterations_per_epoch

    def process_imagenet_data(self, data):
        images, labels = data
        data = images[:self.num_frames]
        queue_data = images[self.num_frames:]
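        # The transform yields 2 * num_frames views per image: the first num_frames
        # feed the online model, the rest feed the momentum queue. remove_dim then
        # presumably folds the frame dim into the batch: (B, F, ...) -> (B * F, ...).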
        if self.num_frames > 1:
            data = pt_util.remove_dim(torch.stack(data, dim=1), 1)
            queue_data = pt_util.remove_dim(torch.stack(queue_data, dim=1), 1)
            labels = labels.repeat_interleave(self.num_frames)
        else:
            data = data[0]
            queue_data = queue_data[0]
        batch = {
            "data": data,
            "queue_data": queue_data,
            "imagenet_labels": labels,
            "data_source": "IN",
            "num_frames": self.num_frames,
            "batch_type": "images",
            "batch_size": len(data),
        }
        return batch

    def get_imagenet_batch(self, loader_id):
        data = next(self.train_batch_iterators[loader_id])
        self.train_loader_counts[loader_id] += 1
        if self.train_loader_counts[loader_id] == (
                len(self.train_loaders[loader_id]) + 1):
            # Needed because the never-ending persistent dataloader never raises StopIteration.
            self.train_loader_counts[loader_id] = 0
            print("Hit ImageNet stop iteration. End of epoch.")
            return None
        return self.process_imagenet_data(data)

    def process_video_data(self, batch):
        data = pt_util.remove_dim(batch["data"], 1)
        queue_data = pt_util.remove_dim(batch["queue_data"], 1)
        batch = {
            "data": data,
            "queue_data": queue_data,
            "data_source": "YT",
            "batch_type": "video",
            "batch_size": len(data),
            "num_frames": self.num_frames,
            "imagenet_labels": torch.full((len(data), ), -1,
                                          dtype=torch.int64),
        }
        return batch

    def get_video_batch(self, loader_id):
        batch = next(self.train_batch_iterators[loader_id])
        self.train_loader_counts[loader_id] += 1
        if self.train_loader_counts[loader_id] == len(
                self.train_loaders[loader_id]) + 1:
            # Needed because the never-ending persistent dataloader never raises StopIteration.
            self.train_loader_counts[loader_id] = 0
            print("Hit Video stop iteration. End of epoch.")
            return None
        return self.process_video_data(batch)

    def setup_other(self):
        t_start = time.time()
        print("Loading CIFAR")
        if self.args.save or self.args.test_first:
            self.cifar_dataset = NPZDataset(
                self.args,
                os.path.join(os.path.dirname(__file__), os.pardir, "datasets",
                             "cifar_data", "cifar_{data_subset}.npz"),
                "train",
                10000,
            )
        else:
            print("Not loading CIFAR, probably in debug")
        print("CIFAR loaded in %.3f" % (time.time() - t_start))

    def setup_optimizer(self):
        base_lr = self.args.base_lr
        params = self.model.parameters()
        param_groups = [{"params": params, "initial_lr": base_lr}]
        optimizer = optim.SGD(param_groups,
                              lr=base_lr,
                              weight_decay=0.0001,
                              momentum=0.9)
        for param_group in optimizer.param_groups:
            if "initial_lr" not in param_group:
                param_group["initial_lr"] = base_lr
        if self.use_apex:
            (self.model, self.queue_model), optimizer = amp.initialize(
                [self.model, self.queue_model], optimizer, opt_level="O1")

        self.optimizer = optimizer
        self.print_optimizer()

    def setup_model(self):
        torch_devices = self.args.pytorch_gpu_ids
        device = "cuda:" + str(torch_devices[0])
        args = copy.deepcopy(self.args)
        args.title = os.path.join(args.title, "VinceModel")
        args.tensorboard_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.checkpoint_dir.split(os.sep)[2:-1]), constants.TIME_STR)
        args.checkpoint_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.checkpoint_dir.split(os.sep)[2:]))
        args.long_save_checkpoint_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.long_save_checkpoint_dir.split(os.sep)[2:-1]),
            constants.TIME_STR)
        self.model = VinceModel(args)
        print(self.model)
        self.iteration = self.model.restore()
        self.model.to(device)
        self.queue_model = VinceQueueModel(args, self.model)
        self.queue_model.to(device)
        self.vince_queue = StorageQueue(args.vince_queue_size,
                                        args.vince_embedding_size,
                                        device=device)

        self.epoch = self.iteration // (self.args.iterations_per_epoch *
                                        self.args.batch_size)
        if self.iteration > 0:
            print("Resuming epoch", self.epoch)
        self.start_prefetch()
        self.fill_queue_repeat()

    def fill_queue(self):
        # Fill queue with many different batches.
        # Sync parameters because might as well.
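        # (With a momentum of 0, param_update presumably copies the online weights
        # verbatim into the momentum encoder.)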
        self.queue_model.param_update(self.model, 0)
        num_added = 0
        self.vince_queue.clear()
        print("Filling queue")
        with torch.no_grad():
            pbar = tqdm.tqdm(total=self.vince_queue.maxsize)
            while num_added < self.vince_queue.maxsize:
                batches_concat, batches = self.get_batch()
                outputs = self.queue_model(batches_concat)
                for batch, output in zip(batches, outputs):
                    images = batch["queue_data_cpu"]
                    self.vince_queue.enqueue(output["queue_embeddings"],
                                             images, batch["data_source"])
                    num_added += len(images)
                    pbar.update(len(images))
                    if num_added >= self.vince_queue.maxsize:
                        break
            pbar.close()
        print("Queue filled")

    def fill_queue_repeat(self):
        # Fill queue with the same batch over and over.
        # Sync parameters because might as well.
        self.queue_model.param_update(self.model, 0)
        num_added = 0
        self.vince_queue.clear()
        with torch.no_grad():
            batches_concat, batches = self.get_batch()
            outputs = self.queue_model(batches_concat)
            while num_added < self.vince_queue.maxsize:
                for batch, output in zip(batches, outputs):
                    images = batch["queue_data_cpu"]
                    self.vince_queue.enqueue(output["queue_embeddings"],
                                             images, batch["data_source"])
                    num_added += len(images)
                    if num_added >= self.vince_queue.maxsize:
                        break
        self.vince_queue.current_tail = 0
        self.vince_queue.full = False
        print("Queue filled with repeats")

    def reset_epoch(self):
        super(VinceSolver, self).reset_epoch()
        self.queue_model.train()
        self.drawn_this_epoch = False

    def prefetch_batches(self):
        while not self.kill_thread:
            batches = []
            for ii in range(len(self.train_batch_fns)):
                loader_id = self.batch_count % len(self.train_batch_fns)
                batch = self.train_batch_fns[loader_id](loader_id)
                if batch is None:
                    self.batch_queue.put(None)
                    batches = None
                    break
                self.batch_count += 1
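                # Keep a pre-transfer (CPU) copy of the queue images; enqueue() later
                # stores them alongside the embeddings, presumably for visualization.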
                initial_images = batch["queue_data"]
                batch = {
                    key: (val.to(self.model.device) if isinstance(
                        val, torch.Tensor) else val)
                    for key, val in batch.items()
                }
                batch["queue_data_cpu"] = initial_images
                batches.append(batch)

            if batches is not None:
                if len(batches) == 1:
                    batches_concat = {
                        key: val if isinstance(val, torch.Tensor) else [val]
                        for key, val in batches[0].items()
                    }
                else:
                    batches_concat = pt_util.stack_dicts_in_list(batches,
                                                                 axis=0,
                                                                 concat=True)
                batches_concat["batch_types"] = batches_concat["batch_type"]
                del batches_concat["batch_type"]
                batches_concat["batch_sizes"] = batches_concat["batch_size"]
                del batches_concat["batch_size"]
                self.batch_queue.put((batches_concat, batches))

    def start_prefetch(self):
        self.prefetch_thread = Thread(target=self.prefetch_batches)
        self.prefetch_thread.start()

    def end(self):
        self.kill_thread = True

    def get_batch(self):
        batches = self.batch_queue.get()
        while batches is None:
            self.fill_queue_repeat()
            batches = self.batch_queue.get()
        return batches

    def run_train_iteration(self):
        total_t_start = time.time()
        t_start = time.time()
        image_batch_concat, image_batches = self.get_batch()

        t_end = time.time()
        self.time_meters["data_cache_time"].update(t_end - t_start)
        t_start = time.time()

        # Feed in whole batch together through encoder (the most computationally expensive part).
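        # With jigsaw enabled, randomly swap which encoder (online vs. momentum)
        # receives the jigsaw-permuted view so neither side always sees it.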

        if self.args.jigsaw:
            if random.random() < 0.5:
                queue_batches = self.queue_model(image_batch_concat,
                                                 jigsaw=True,
                                                 shuffle=True)
                outputs = self.model.get_embeddings(image_batch_concat,
                                                    jigsaw=False,
                                                    shuffle=True)
            else:
                queue_batches = self.queue_model(image_batch_concat,
                                                 jigsaw=False,
                                                 shuffle=True)
                outputs = self.model.get_embeddings(image_batch_concat,
                                                    jigsaw=True,
                                                    shuffle=True)
        else:
            queue_batches = self.queue_model(image_batch_concat, shuffle=True)
            outputs = self.model.get_embeddings(image_batch_concat,
                                                shuffle=True)

        t_end = time.time()
        self.time_meters["forward_time"].update(t_end - t_start)
        t_start = time.time()

        loss_list = []
        metrics_list = []
        # Feed in batch as separate mini-batches depending on batch type.
        image_batches = self.model.split_dict_by_type(
            image_batch_concat["batch_types"],
            image_batch_concat["batch_sizes"], image_batch_concat)

        for bb, (image_batch, queue_batch, output) in enumerate(
                zip(image_batches, queue_batches, outputs)):
            output.update(self.vince_queue.dequeue())
            output.update(image_batch)
            output.update(queue_batch)

            output.update(self.model(output))
            loss_dict = self.model.loss(output)
            metrics = self.model.get_metrics(output)
            loss_list.append(
                {key: val[0] * val[1]
                 for key, val in loss_dict.items()})
            metrics_list.append(metrics)

        loss_dict = pt_util.stack_dicts_in_list(loss_list)
        loss_dict = {key: val.mean() for key, val in loss_dict.items()}
        metrics = pt_util.stack_dicts_in_list(metrics_list)
        metrics = {key: val.mean() for key, val in metrics.items()}

        updated_loss_meters = set()
        try:
            total_loss = 0
            for key, weighted_loss in loss_dict.items():
                total_loss = total_loss + weighted_loss
                self.loss_meters[key].update(weighted_loss)
                updated_loss_meters.add(key)
            if "total_loss" in self.loss_meters:
                self.loss_meters["total_loss"].update(total_loss)
                updated_loss_meters.add("total_loss")
            loss = total_loss
            assert torch.isfinite(loss)
        except AssertionError as err:
            import pdb
            traceback.print_exc()
            pdb.set_trace()
            print("anomaly", err)
            raise err

        updated_metric_meters = set()
        for key, val in metrics.items():
            self.metric_meters[key].update(val)
            updated_metric_meters.add(key)

        t_end = time.time()
        self.time_meters["metrics_time"].update(t_end - t_start)

        t_start = time.time()
        self.optimizer.zero_grad()
        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optimizer.step()

        t_end = time.time()
        self.time_meters["backward_time"].update(t_end - t_start)

        for bb, (image_batch, output) in enumerate(zip(image_batches,
                                                       outputs)):
            if not self.drawn_this_epoch and self.vince_queue.full:
                # If we haven't drawn yet and the queue is full of many different batches, draw now.
                if bb == len(image_batches) - 1:
                    self.drawn_this_epoch = True
                print("Drawing Tensorboard images")
                image_output = self.model.get_image_output(output)
                if self.train_logger is not None:
                    for key, val in image_output.items():
                        if key.startswith("images"):
                            if isinstance(val, list):
                                for vv, item in enumerate(val):
                                    self.train_logger.image_summary(
                                        self.full_name + "_" +
                                        key[len("images/"):], item,
                                        self.iteration + vv, False)
                            else:
                                self.train_logger.image_summary(
                                    self.full_name + "_" +
                                    key[len("images/"):], val, self.iteration,
                                    False)

            # Must be done after image logger because images are reassigned in place.
            # update queue and queue params
            queue_images_cpu = image_batch["queue_data_cpu"]
            self.vince_queue.enqueue(output["queue_embeddings"],
                                     queue_images_cpu,
                                     image_batch["data_source"])

        self.queue_model.vince_update(self.model)
        if self.logger_iteration % self.args.save_frequency == 0:
            self.model.save(self.iteration, 5)

        if self.logger_iteration % self.args.log_frequency == 0:
            log_dict = {
                "times/%s/%s" % (self.full_name, key): val.val
                for key, val in self.time_meters.items()
            }
            log_dict.update({
                "losses/%s/%s" % (self.full_name, key):
                self.loss_meters[key].val
                for key in updated_loss_meters
            })
            log_dict.update({
                "metrics/%s/%s" % (self.full_name, key):
                self.metric_meters[key].val
                for key in updated_metric_meters
            })
            if self.train_logger is not None:
                self.train_logger.dict_log(log_dict, self.iteration)

        self.iteration += self.args.batch_size

        total_t_end = time.time()
        self.time_meters["total_time"].update(total_t_end - total_t_start)
        self.logger_iteration += 1

    def run_val(self):
        with torch.no_grad():
            self.model.eval()
            time_meters = dict(
                total_time=RollingAverageMeter(self.args.log_frequency),
                data_cache_time=RollingAverageMeter(self.args.log_frequency),
                forward_time=RollingAverageMeter(self.args.log_frequency),
                metrics_time=RollingAverageMeter(self.args.log_frequency),
            )
            loss_meters = {
                key: RollingAverageMeter(self.args.log_frequency)
                for key in self.model.loss(None).keys()
            }
            if len(loss_meters) > 1:
                loss_meters["total_loss"] = RollingAverageMeter(
                    self.args.log_frequency)
            metric_meters = {
                metric: RollingAverageMeter(self.args.log_frequency)
                for metric in self.model.get_metrics(None).keys()
            }

            epoch_loss_meters = {
                "epoch_" + key: AverageMeter()
                for key in loss_meters.keys()
            }
            epoch_metric_meters = {
                "epoch_" + key: AverageMeter()
                for key in metric_meters.keys()
            }

            updated_epoch_loss_meters = set()
            updated_epoch_metric_meters = set()

            step_on = self.iteration

            for val_name, val_loader, data_processor in zip(
                    self.val_data_names, self.val_loaders, self.val_batch_fns):
                print("Running val for", val_name)
                total_t_start = time.time()
                test_t_start = time.time()
                for ii, image_batch in enumerate(tqdm.tqdm(val_loader)):
                    if time.time() - test_t_start > 5 * 60:
                        # Break after 5 minutes.
                        break
                    image_batch = data_processor(image_batch)
                    image_batch["batch_types"] = [image_batch["batch_type"]]
                    del image_batch["batch_type"]
                    image_batch["batch_sizes"] = [image_batch["batch_size"]]
                    del image_batch["batch_size"]
                    image_batch = {
                        key: (val.to(self.model.device, non_blocking=True)
                              if isinstance(val, torch.Tensor) else val)
                        for key, val in image_batch.items()
                    }

                    batch_size = image_batch["data"].shape[0]

                    t_end = time.time()
                    time_meters["data_cache_time"].update(t_end -
                                                          total_t_start)
                    t_start = time.time()

                    image_batch.update(self.queue_model(image_batch)[0])
                    image_batch.update(
                        self.model.get_embeddings(image_batch)[0])
                    image_batch.update(self.vince_queue.dequeue())

                    image_batch.update(self.model(image_batch))
                    output = image_batch
                    loss_dict = self.model.loss(output)

                    t_end = time.time()
                    time_meters["forward_time"].update(t_end - t_start)
                    t_start = time.time()

                    metrics = self.model.get_metrics(output)
                    if ii % self.args.image_log_frequency == 0:
                        image_output = self.model.get_image_output(output)

                    updated_loss_meters = set()
                    total_loss = 0
                    for key, val in loss_dict.items():
                        weighted_loss = val[0] * val[1]
                        total_loss = total_loss + weighted_loss
                        loss_meters[key].update(weighted_loss)
                        epoch_loss_meters["epoch_" + key].update(
                            weighted_loss, batch_size)
                        updated_loss_meters.add(key)
                        updated_epoch_loss_meters.add("epoch_" + key)
                    if "total_loss" in loss_meters:
                        loss_meters["total_loss"].update(total_loss)
                        epoch_loss_meters["epoch_total_loss"].update(
                            total_loss, batch_size)
                        updated_loss_meters.add("total_loss")
                        updated_epoch_loss_meters.add("epoch_total_loss")
                    loss = total_loss

                    try:
                        assert torch.isfinite(loss)
                    except AssertionError:
                        # output = self.model.forward(image_batch)
                        print("NaN loss", loss_dict)

                    updated_metric_meters = set()
                    for key, val in metrics.items():
                        metric_meters[key].update(val)
                        updated_metric_meters.add(key)
                        epoch_metric_meters["epoch_" + key].update(
                            val, batch_size)
                        updated_epoch_metric_meters.add("epoch_" + key)

                    t_end = time.time()
                    time_meters["metrics_time"].update(t_end - t_start)

                    if ii % self.args.image_log_frequency == 0:
                        if self.val_logger is not None:
                            for key, val in image_output.items():
                                if isinstance(val, list):
                                    for vv, item in enumerate(val):
                                        self.val_logger.image_summary(
                                            self.full_name + "_" +
                                            key[len("images/"):], item,
                                            step_on + vv, False)
                                else:
                                    self.val_logger.image_summary(
                                        self.full_name + "_" +
                                        key[len("images/"):], val, step_on,
                                        False)

                    if ii % self.args.log_frequency == 0:
                        log_dict = {
                            "times/%s/%s" % (self.full_name, key): val.val
                            for key, val in time_meters.items()
                        }
                        log_dict.update({
                            "losses/%s/%s" % (self.full_name, key):
                            loss_meters[key].val
                            for key in updated_loss_meters
                        })
                        log_dict.update({
                            "metrics/%s/%s" % (self.full_name, key):
                            metric_meters[key].val
                            for key in updated_metric_meters
                        })
                        if self.val_logger is not None:
                            self.val_logger.dict_log(log_dict, step_on)

                    step_on += self.args.batch_size
                    total_t_end = time.time()
                    time_meters["total_time"].update(total_t_end -
                                                     total_t_start)
                    total_t_start = time.time()

            ##### CIFAR #####
            epoch_metric_meters["epoch_knn_cifar"] = AverageMeter()

            all_features = []
            imagenet_mean = pt_util.from_numpy(constants.IMAGENET_MEAN).to(
                self.model.device).view(1, -1, 1, 1)
            imagenet_std = pt_util.from_numpy(constants.IMAGENET_STD).to(
                self.model.device).view(1, -1, 1, 1)

            print("Running CIFAR")
            for start_ind in tqdm.tqdm(
                    range(0, len(self.cifar_dataset), self.args.batch_size)):
                data = self.cifar_dataset.data[
                    start_ind:min(len(self.cifar_dataset), start_ind +
                                  self.args.batch_size)]
                data = data.to(device=self.model.device, dtype=torch.float32)
                data = data - imagenet_mean
                data.div_(imagenet_std)
                features = self.model.get_embeddings({"data":
                                                      data})["embeddings"]
                all_features.append(pt_util.to_numpy(features))
            all_images = np.transpose(
                pt_util.to_numpy(self.cifar_dataset.data), (0, 2, 3, 1))
            labels = pt_util.to_numpy(self.cifar_dataset.labels)
            all_features = np.concatenate(all_features, axis=0)
            if len(all_features.shape) == 4:
                # all_features = pt_util.remove_dim(all_features, dim=(2, 3))
                all_features = np.mean(all_features, axis=(2, 3))

            if self.val_logger is not None:
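                # kNN probe: find each embedding's 10 nearest neighbors (dropping the
                # self-match) and score majority-vote label accuracy.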
                kdt = KDTree(all_features, leaf_size=40, metric="euclidean")
                neighbors = kdt.query(all_features, k=11)[1]
                # remove self match
                neighbors = neighbors[:, 1:]
                preds_all = labels[neighbors]
                preds = scipy.stats.mode(preds_all, axis=1)[0].squeeze(1)
                acc = np.mean(preds == labels)
                epoch_metric_meters["epoch_knn_cifar"].update(acc)
                updated_epoch_metric_meters.add("epoch_knn_cifar")

                nn_inds = kdt.query(all_features[0:100:10], k=10)[1]
                image = drawing.subplot(all_images[nn_inds.reshape(-1)],
                                        10,
                                        10,
                                        self.args.input_width,
                                        self.args.input_height,
                                        border=10)

                self.val_logger.image_summary(self.full_name + "_kNN/cifar",
                                              image,
                                              step_on,
                                              increment_counter=False,
                                              max_size=1000)

        log_dict = {
            "epoch/losses/%s/%s" % (self.full_name, key):
            epoch_loss_meters[key].avg
            for key in updated_epoch_loss_meters
        }
        log_dict.update({
            "epoch/metrics/%s/%s" % (self.full_name, key):
            epoch_metric_meters[key].avg
            for key in updated_epoch_metric_meters
        })
        if self.val_logger is not None:
            self.val_logger.dict_log(log_dict, step_on)
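
queue_model.param_update and vince_update are not shown in this listing; in MoCo-style training they are typically an exponential-moving-average copy of the online encoder. A minimal sketch of such an update, assuming two architecturally identical nn.Modules (the function name ema_update is illustrative, not from the source):

import torch

@torch.no_grad()
def ema_update(momentum_model: torch.nn.Module,
               online_model: torch.nn.Module,
               momentum: float = 0.999) -> None:
    # Blend online weights into the momentum encoder; momentum=0 is a full copy.
    for p_m, p_o in zip(momentum_model.parameters(), online_model.parameters()):
        p_m.mul_(momentum).add_(p_o, alpha=1.0 - momentum)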
Example #5
class EndTaskBaseSolver(BaseSolver, abc.ABC):
    def __init__(self, args, train_logger=None, val_logger=None):
        self.train_loader = None
        self.val_loader = None
        self.batch_iter = None
        self.feature_extractor: nn.Module = None

        super(EndTaskBaseSolver, self).__init__(args, train_logger, val_logger)
        self.train_batch_counter = 0

    def setup_dataloader(self):
        if not self.args.disable_dataloader:
            self.train_loader = PersistentDataLoader(
                dataset=None,
                num_workers=min(self.args.num_workers, 40),
                pin_memory=True,
                device=self.device,
                never_ending=True,
            )

            self.val_loader = PersistentDataLoader(
                dataset=None,
                num_workers=min(self.args.num_workers, 40),
                pin_memory=False,
                device=self.device,
                never_ending=True,
            )

            self.train_loader.set_dataset(
                self.args.dataset(self.args, "train"),
                batch_size=self.args.batch_size,
                shuffle=True,
                collate_fn=self.args.dataset.collate_fn,
                worker_init_fn=self.args.dataset.worker_init_fn,
            )
            self.batch_iter = iter(self.train_loader)
            self.val_loader.set_dataset(
                self.args.dataset(self.args, "val"),
                batch_size=self.args.batch_size,
                shuffle=True,
                collate_fn=self.args.dataset.collate_fn,
                worker_init_fn=self.args.dataset.worker_init_fn,
            )

    @property
    def iterations_per_epoch(self):
        return len(self.train_loader)

    def setup_model_param_groups(self) -> List[Dict]:
        raise NotImplementedError

    @staticmethod
    def create_optimizer(param_groups, base_lr):
        return torch.optim.Adam(param_groups, lr=base_lr, weight_decay=1e-4)

    def setup_optimizer(self):
        base_lr = self.args.base_lr
        param_groups = self.setup_model_param_groups()

        if not self.freeze_feature_extractor:
            param_group = {
                "params": self.feature_extractor.parameters(),
                "lr": base_lr,
                "weight_decay": 1e-4,
                "initial_lr": base_lr,
            }
            param_groups.append(param_group)

        # optimizer = torch.optim.SGD(param_groups, lr=base_lr, weight_decay=1e-4, momentum=0.9)
        optimizer = self.create_optimizer(param_groups, base_lr)

        for param_group in optimizer.param_groups:
            if "initial_lr" not in param_group:
                param_group["initial_lr"] = base_lr

        if self.use_apex:
            (self.feature_extractor, self.model), optimizer = amp.initialize(
                [self.feature_extractor, self.model],
                optimizer,
                opt_level="O1",
                max_loss_scale=65536)

        print("optimizer", optimizer)
        self.optimizer = optimizer
        self.print_optimizer()

    @property
    def solver_model_name(self):
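        # Hypothetical example: a subclass named "DepthSolver" yields "DepthModel".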
        return type(self).__name__[:-len("Solver")] + "Model"

    def setup_feature_extractor(self):
        args = copy.deepcopy(self.args)
        args.title = os.path.join(args.title, "VinceModel")
        args.checkpoint_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.checkpoint_dir.split(os.sep)[2:]))
        args.long_save_checkpoint_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.long_save_checkpoint_dir.split(os.sep)[2:-1]),
            constants.TIME_STR)
        args.tensorboard_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.checkpoint_dir.split(os.sep)[2:-1]), constants.TIME_STR)
        self.feature_extractor = VinceModel(args)
        print(self.feature_extractor)
        self.feature_extractor.restore()
        self.feature_extractor.to(self.device)
        if self.freeze_feature_extractor:
            self.feature_extractor.eval()
        else:
            self.feature_extractor.train()

    def make_decoder_network(self, args) -> torch.nn.Module:
        raise NotImplementedError

    def setup_model(self):
        self.setup_feature_extractor()
        args = copy.deepcopy(self.args)
        args.title = os.path.join(args.title, self.solver_model_name)
        args.checkpoint_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.checkpoint_dir.split(os.sep)[2:]))
        args.long_save_checkpoint_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.long_save_checkpoint_dir.split(os.sep)[2:-1]),
            constants.TIME_STR)
        args.tensorboard_dir = os.path.join(
            args.base_logdir, args.title,
            *(args.checkpoint_dir.split(os.sep)[2:-1]), constants.TIME_STR)
        self.model = self.make_decoder_network(args)

        self.iteration = self.model.restore()
        self.model.to(self.device)

        if self.train_loader is not None:
            self.epoch = self.iteration // len(self.train_loader.dataset)
        if self.freeze_feature_extractor:
            if self.train_logger is not None:
                self.train_logger.network_conv_summary(
                    self.feature_extractor.feature_extractor.module,
                    self.iteration)

    def reset_epoch(self):
        super(EndTaskBaseSolver, self).reset_epoch()
        if not self.freeze_feature_extractor and self.train_logger is not None:
            self.train_logger.network_conv_summary(
                self.feature_extractor.feature_extractor.module,
                self.iteration)
        if self.freeze_feature_extractor:
            self.feature_extractor.eval()
        else:
            self.feature_extractor.train()
        self.model.train()

    def convert_batch(self, batch, batch_type: str = "train") -> Dict:
        batch = {
            key: (val.to(self.model.device, non_blocking=True) if isinstance(
                val, torch.Tensor) else val)
            for key, val in batch.items()
        }
        return batch

    def get_batch(self):
        iter_output = next(self.batch_iter)
        self.train_batch_counter += 1
        if self.train_batch_counter == len(self.train_loader):
            print("Hit stop iteration. End of epoch.")
            self.train_logger.scalar_summary("metrics/%s/epoch" %
                                             self.full_name,
                                             self.epoch,
                                             step=self.iteration,
                                             increment_counter=False)
            self.train_logger.scalar_summary(
                "metrics/%s/lr" % self.full_name,
                self.optimizer.param_groups[0]["lr"],
                step=self.iteration,
                increment_counter=False,
            )
            self.train_batch_counter = 0
            raise StopIteration
        return self.convert_batch(iter_output, "train")

    def get_val_batch(self):
        # Useful with the never-ending persistent dataloader, which never raises StopIteration on its own.
        val_iter = iter(self.val_loader)
        for _ in range(len(self.val_loader)):
            iter_output = next(val_iter)
            yield self.convert_batch(iter_output, "val")
        return  # PEP 479: raising StopIteration inside a generator is a RuntimeError

    def forward(self, batch):
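        # eval() alone does not stop gradients; the frozen path therefore also wraps
        # the extractor in no_grad and detaches its output before the decoder sees it.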
        if self.freeze_feature_extractor:
            with torch.no_grad():
                features = self.feature_extractor.extract_features(
                    batch["data"])
                extracted_features = features["extracted_features"].to(
                    self.model.device).detach()
        else:
            features = self.feature_extractor.extract_features(batch["data"])
            extracted_features = features["extracted_features"].to(
                self.model.device)

        output = self.model(extracted_features)

        output.update(features)
        output.update(batch)
        return output

    def run_train_iteration(self):
        total_t_start = time.time()
        t_start = time.time()
        try:
            image_batch = self.get_batch()
        except StopIteration:
            return

        t_end = time.time()
        self.time_meters["data_cache_time"].update(t_end - t_start)
        t_start = time.time()

        output = self.forward(image_batch)
        loss_dict = self.model.loss(output)

        t_end = time.time()
        self.time_meters["forward_time"].update(t_end - t_start)
        t_start = time.time()

        metrics = self.model.get_metrics(output)

        total_loss = 0
        for key, val in loss_dict.items():
            weighted_loss = val[0] * val[1]
            total_loss = total_loss + weighted_loss
            self.loss_meters[key].update(weighted_loss)
        if "total_loss" in self.loss_meters:
            self.loss_meters["total_loss"].update(total_loss)
        loss = total_loss

        try:
            assert torch.isfinite(loss)
        except AssertionError as err:
            import pdb
            traceback.print_exc()
            pdb.set_trace()
            print("anomaly", err)
            raise err

        for key, val in metrics.items():
            self.metric_meters[key].update(val)

        t_end = time.time()
        self.time_meters["metrics_time"].update(t_end - t_start)

        t_start = time.time()

        self.optimizer.zero_grad()
        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optimizer.step()

        t_end = time.time()
        self.time_meters["backward_time"].update(t_end - t_start)

        if self.logger_iteration % self.args.image_log_frequency == 0:
            image_output = self.model.get_image_output(output)
            if self.train_logger is not None:
                for key, val in image_output.items():
                    if key.startswith("images"):
                        if isinstance(val, list):
                            for vv, item in enumerate(val):
                                self.train_logger.image_summary(
                                    self.full_name + "_" +
                                    key[len("images/"):], item,
                                    self.iteration + vv, False)
                        else:
                            self.train_logger.image_summary(
                                self.full_name + "_" + key[len("images/"):],
                                val, self.iteration, False)

        if self.logger_iteration % self.args.log_frequency == 0:
            log_dict = {
                "times/%s/%s" % (self.full_name, key): val.val
                for key, val in self.time_meters.items()
            }
            log_dict.update({
                "losses/%s/%s" % (self.full_name, key): val.val
                for key, val in self.loss_meters.items()
            })
            log_dict.update({
                "metrics/%s/%s" % (self.full_name, key): val.val
                for key, val in self.metric_meters.items()
            })
            if self.train_logger is not None:
                self.train_logger.dict_log(log_dict, self.iteration)

        self.iteration += self.args.batch_size

        if self.logger_iteration % self.args.save_frequency == 0:
            self.save(5)

        total_t_end = time.time()
        self.time_meters["total_time"].update(total_t_end - total_t_start)
        self.logger_iteration += 1

    def run_val(self):
        with torch.no_grad():
            self.feature_extractor.eval()
            self.model.eval()
            time_meters = dict(
                total_time=RollingAverageMeter(self.args.log_frequency),
                data_cache_time=RollingAverageMeter(self.args.log_frequency),
                forward_time=RollingAverageMeter(self.args.log_frequency),
                metrics_time=RollingAverageMeter(self.args.log_frequency),
            )
            loss_meters = {
                key: RollingAverageMeter(self.args.log_frequency)
                for key in self.model.loss(None).keys()
            }
            if len(loss_meters) > 1:
                loss_meters["total_loss"] = RollingAverageMeter(
                    self.args.log_frequency)
            metric_meters = {
                metric: RollingAverageMeter(self.args.log_frequency)
                for metric in self.model.get_metrics(None).keys()
            }

            epoch_loss_meters = {
                "epoch_" + key: AverageMeter()
                for key in loss_meters.keys()
            }
            epoch_metric_meters = {
                "epoch_" + key: AverageMeter()
                for key in metric_meters.keys()
            }

            step_on = self.iteration

            total_t_start = time.time()
            for ii, image_batch in enumerate(
                    tqdm.tqdm(self.get_val_batch(),
                              total=len(self.val_loader))):
                batch_size = image_batch["data"].shape[0]

                t_end = time.time()
                time_meters["data_cache_time"].update(t_end - total_t_start)
                t_start = time.time()
                output = self.forward(image_batch)
                loss_dict = self.model.loss(output)

                t_end = time.time()
                time_meters["forward_time"].update(t_end - t_start)
                t_start = time.time()

                metrics = self.model.get_metrics(output)
                if ii % self.args.image_log_frequency == 0:
                    image_output = self.model.get_image_output(output)

                total_loss = 0
                for key, val in loss_dict.items():
                    weighted_loss = val[0] * val[1]
                    total_loss = total_loss + weighted_loss
                    loss_meters[key].update(weighted_loss)
                    epoch_loss_meters["epoch_" + key].update(
                        weighted_loss, batch_size)
                if "total_loss" in loss_meters:
                    loss_meters["total_loss"].update(total_loss)
                    epoch_loss_meters["epoch_total_loss"].update(
                        total_loss, batch_size)
                loss = total_loss

                try:
                    assert torch.isfinite(loss)
                except AssertionError:
                    import pdb
                    pdb.set_trace()
                    print("Loss is infinite")

                for key, val in metrics.items():
                    metric_meters[key].update(val)
                    epoch_metric_meters["epoch_" + key].update(val, batch_size)

                t_end = time.time()
                time_meters["metrics_time"].update(t_end - t_start)

                if ii % self.args.image_log_frequency == 0:
                    if self.val_logger is not None:
                        for key, val in image_output.items():
                            if isinstance(val, list):
                                for vv, item in enumerate(val):
                                    self.val_logger.image_summary(
                                        self.full_name + "/" + key, item,
                                        step_on + vv, False)
                            else:
                                self.val_logger.image_summary(
                                    self.full_name + "/" + key, val, step_on,
                                    False)

                if ii % self.args.log_frequency == 0:
                    log_dict = {
                        "times/%s/%s" % (self.full_name, key): val.val
                        for key, val in time_meters.items()
                    }
                    log_dict.update({
                        "losses/%s/%s" % (self.full_name, key): val.val
                        for key, val in loss_meters.items()
                    })
                    log_dict.update({
                        "metrics/%s/%s" % (self.full_name, key): val.val
                        for key, val in metric_meters.items()
                    })
                    if self.val_logger is not None:
                        self.val_logger.dict_log(log_dict, step_on)

                step_on += self.args.batch_size
                total_t_end = time.time()
                time_meters["total_time"].update(total_t_end - total_t_start)
                total_t_start = time.time()

        log_dict = {
            "epoch/losses/%s/%s" % (self.full_name, key): val.avg
            for key, val in epoch_loss_meters.items()
        }
        log_dict.update({
            "epoch/metrics/%s/%s" % (self.full_name, key): val.avg
            for key, val in epoch_metric_meters.items()
        })
        if self.val_logger is not None:
            self.val_logger.dict_log(log_dict, step_on)

    def run_eval(self):
        self.val_loader = PersistentDataLoader(
            dataset=None,
            num_workers=min(self.args.num_workers, 40),
            pin_memory=True,
            device=self.device,
            never_ending=True,
        )
        self.val_loader.set_dataset(
            self.args.dataset(self.args, "val"),
            batch_size=self.args.batch_size,
            shuffle=True,
            collate_fn=self.args.dataset.collate_fn,
            worker_init_fn=self.args.dataset.worker_init_fn,
        )
        self.run_val()

    def save(self, num_to_keep=-1):
        self.model.save(self.iteration, num_to_keep)
        if not self.freeze_feature_extractor:
            self.feature_extractor.save(self.iteration, num_to_keep)
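
The "initial_lr" key added to every param group in setup_optimizer is what torch's LR schedulers expect when constructed with a non-default last_epoch (e.g. when resuming). A minimal sketch of the interaction, with hypothetical values:

import torch
from torch import nn, optim

model = nn.Linear(8, 2)
opt = optim.SGD([{"params": model.parameters(), "initial_lr": 0.1}], lr=0.1)
# Without "initial_lr" in the param group, resuming at last_epoch=10 raises KeyError.
sched = optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.5, last_epoch=10)
print(opt.param_groups[0]["lr"])  # decayed according to the resumed step count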
Example #6
print("starting TSNE")
dataset = R2V2Dataset(args, "val", transform=StandardVideoTransform(args.input_size, "val"), num_images_to_return=1)

data_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    pin_memory=True,
    collate_fn=R2V2Dataset.collate_fn,
)

with torch.no_grad():
    torch_devices = args.pytorch_gpu_ids
    device = "cuda:" + str(torch_devices[0])
    model = VinceModel(args)

    model.restore()
    model.eval()
    model.to(device)
    all_images = []
    all_features = []

    for batch in tqdm.tqdm(data_loader, total=NUM_IMAGES_IN_TSNE // args.batch_size):
        batch = process_video_data(batch)

        features = model.get_embeddings(batch)["embeddings"]
        images = to_uint8(batch["data"])
        for image, feature in zip(images, features):
            all_images.append(image)
            all_features.append(pt_util.to_numpy(feature))
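
The loop above only collects features; the projection itself is not shown. A minimal sketch of the step that presumably follows, using scikit-learn's TSNE (an assumption; the original projection code is not part of this snippet):

import numpy as np
from sklearn.manifold import TSNE

features = np.stack(all_features)  # (N, D) embedding matrix
coords = TSNE(n_components=2, init="pca").fit_transform(features)
# coords[i] is the 2-D location of all_images[i] in the embedding map.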