Esempio n. 1
0
    def train(self):
        """
        Runs the epoch loop until `cur_epoch` reaches `final_epoch`.

        Every iteration performs, in order:

        1. A training epoch followed by a validation epoch (both skipped when
           `only_test` is set).
        2. A test epoch, with printing disabled.
        3. A call to `__save_current_state` with the validation
           reconstruction accuracy (skipped when `only_test` is set).

        The epoch counter is then advanced, and the parity model and loss
        function are passed through `try_cuda` again.
        """
        # Keep the base model in evaluation mode for the whole run.
        self.base_model.eval()

        while self.cur_epoch < self.final_epoch:
            if not self.only_test:
                # Optimization pass over the training set
                self.__epoch(self.train_dataloader, do_step=True)

                # Evaluation pass over the validation set. Only the second
                # returned value is needed later, for checkpointing.
                val_results = self.__epoch(self.val_dataloader, do_step=False)
                val_recon_acc = val_results[1]

            # Evaluation pass over the test set; its results are not used
            # here.
            self.__epoch(self.test_dataloader, do_step=False, do_print=False)

            if not self.only_test:
                self.__save_current_state(val_recon_acc)

            self.cur_epoch += 1

            # Re-place the model and loss on the GPU, if necessary.
            # NOTE(review): presumably needed because saving state can move
            # them off the GPU -- confirm against `__save_current_state`.
            self.parity_model = try_cuda(self.parity_model)
            self.loss_fn = try_cuda(self.loss_fn)
Esempio n. 2
0
    def __epoch(self, data_loader, do_step=False, do_print=True):
        """
        Performs a single epoch of either training or validation.

        The parity model is put in training mode only when the dataset behind
        `data_loader` is named "train"; otherwise it is put in evaluation
        mode. Per-epoch averages are written to files under `save_dir` via
        `util.util.write_vals_dict`.

        Parameters
        ----------
        data_loader:
            The data loader to use for this epoch. Its `dataset.name` is used
            as the label for progress output and for output file names.
        do_step: bool
            Whether to make optimization steps using this data loader. When
            false, the forward pass runs with autograd disabled.
        do_print: bool
            Whether to show a progress bar with running accuracies and loss
            for this epoch.

        Returns
        -------
        tuple
            ``(epoch_loss, reconstruction_top1, overall_top1)`` where the
            last two are read from the accuracy map returned by
            ``stats.averages()``.
        """
        # Imported locally so this method does not rely on the module-level
        # import list.
        import torch

        stats = util.stats.StatsTracker()
        label = data_loader.dataset.name

        if label == "train":
            self.parity_model.train()
        else:
            self.parity_model.eval()

        if do_print:
            data_loader = tqdm(data_loader, ascii=True,
                               desc="Epoch {}. {}".format(self.cur_epoch, label))

        for mb_data, mb_labels, mb_true_labels in data_loader:
            # Regroup the flat minibatch into groups of `ec_k` samples.
            mb_data = try_cuda(mb_data.view(-1, self.ec_k, mb_data.size(1)))
            mb_labels = try_cuda(
                mb_labels.view(-1, self.ec_k, mb_labels.size(1)))
            mb_true_labels = try_cuda(mb_true_labels.view(-1, self.ec_k))

            if do_step:
                if self.train_parity_model:
                    self.parity_model_opt.zero_grad()
                if self.train_encoder:
                    self.encoder_opt.zero_grad()
                if self.train_decoder:
                    self.decoder_opt.zero_grad()

            # Autograd bookkeeping is only needed when we are going to
            # backpropagate; skipping it during evaluation saves memory and
            # time without changing the computed values.
            with torch.set_grad_enabled(bool(do_step)):
                loss = self.__forward(mb_data, mb_labels, mb_true_labels,
                                      stats)

            if do_step:
                loss.backward()
                if self.train_parity_model:
                    self.parity_model_opt.step()
                if self.train_encoder:
                    self.encoder_opt.step()
                if self.train_decoder:
                    self.decoder_opt.step()

            if do_print:
                rloss, rtop1, rtop5 = stats.running_averages()
                data_loader.set_description(
                    "Epoch {}. {}. Top-1={:.4f}, Top-5={:.4f}, Loss={:.4f}".format(
                        self.cur_epoch, label, rtop1, rtop5, rloss))

        epoch_loss, epoch_acc_map = stats.averages()

        # The loss is written out together with the accuracies, so it is
        # added to the same map before writing.
        epoch_acc_map["loss"] = epoch_loss
        outfile_fmt = os.path.join(self.save_dir, label + "_{}.txt")
        util.util.write_vals_dict(outfile_fmt, epoch_acc_map)

        top_recon = epoch_acc_map["reconstruction_top1"]
        top_overall = epoch_acc_map["overall_top1"]
        return epoch_loss, top_recon, top_overall
    def __epoch(self, data_loader, do_step=False, do_print=True):
        """
        Performs a single epoch of either training or validation.

        The encoder and decoder are put in training mode only when the
        dataset behind `data_loader` is named "train"; otherwise they are put
        in evaluation mode. Epoch averages are written to files under
        `save_dir` via `util.util.write_vals`.

        Parameters
        ----------
        data_loader:
            The data loader to use for this epoch. Its `dataset.name` is used
            as the label for printed output and for output file names.
        do_step: bool
            Whether to make optimization steps using this data loader. When
            false, the forward pass runs with autograd disabled.
        do_print: bool
            Whether to print loss and accuracies for this epoch.

        Returns
        -------
        tuple
            ``(epoch_loss, epoch_recon_acc, epoch_overall_acc)`` as returned
            by ``stats.averages()``.
        """
        # Imported locally so this method does not rely on the module-level
        # import list.
        import torch

        stats = util.stats.StatsTracker()
        label = data_loader.dataset.name

        if label == "train":
            self.enc_model.train()
            self.dec_model.train()
        else:
            self.enc_model.eval()
            self.dec_model.eval()

        if do_print:
            print("--------------- EPOCH {} : {} ---------------".format(
                self.cur_epoch, label))

        for mb_data, mb_labels, mb_true_labels in data_loader:
            # Regroup the flat minibatch into groups of `ec_k` samples.
            mb_data = try_cuda(mb_data.view(-1, self.ec_k, mb_data.size(1)))
            mb_labels = try_cuda(
                mb_labels.view(-1, self.ec_k, mb_labels.size(1)))
            mb_true_labels = try_cuda(mb_true_labels.view(-1, self.ec_k))

            if do_step:
                self.enc_opt.zero_grad()
                self.dec_opt.zero_grad()

            # Autograd bookkeeping is only needed when we are going to
            # backpropagate; skipping it during evaluation saves memory and
            # time without changing the computed values.
            with torch.set_grad_enabled(bool(do_step)):
                loss = self.__forward(mb_data, mb_labels, mb_true_labels,
                                      stats)

            if do_step:
                loss.backward()
                self.enc_opt.step()
                self.dec_opt.step()

        epoch_loss, epoch_recon_acc, epoch_overall_acc = stats.averages()

        if do_print:
            print("loss:", epoch_loss)
            print("reconstruction-acc:", epoch_recon_acc)
            print("overall-acc:", epoch_overall_acc)

        outfile_fmt = os.path.join(self.save_dir, label + "_{}.txt")
        vals = [epoch_loss, epoch_recon_acc, epoch_overall_acc]
        names = ["loss", "reconstruction_accuracy", "overall_accuracy"]
        util.util.write_vals(outfile_fmt, vals, names)

        return epoch_loss, epoch_recon_acc, epoch_overall_acc
Esempio n. 4
0
    def __gen_masks(self):
        """Builds the masks used to simulate erasure and to score accuracy.

        Only single-erasure scenarios are generated: scenario ``i`` erases
        data unit ``i``, for each of the ``ec_k`` data units. This matches the
        assumption that ``ec_r`` is 1.

        Returns
        -------
            erase_mask:
                All ones, except that ``erase_mask[i, i, :]`` is zero, i.e.
                the row of the unit erased in scenario ``i`` is zeroed.
                Dimensions: (ec_k, ec_k + ec_r, base_model_output_dim)
            acc_mask:
                Byte mask that is one only at ``acc_mask[i, i]``, marking the
                unit erased in scenario ``i``.
                Dimensions: (ec_k, ec_k)

            Both are passed through ``try_cuda`` before being returned.
        """
        out_dim = self.val_dataloader.dataset.decoder_in_dim()

        # One scenario per data unit: exactly one of the `ec_k` data units is
        # erased at a time.
        num_scenarios = self.ec_k

        erase_mask = torch.ones(
            (num_scenarios, self.ec_k + self.ec_r, out_dim))
        acc_mask = torch.zeros((num_scenarios, self.ec_k)).byte()

        # In scenario `unit`, data unit `unit` is the one that is erased.
        for unit in range(num_scenarios):
            erase_mask[unit, unit, :] = 0.
            acc_mask[unit, unit] = 1

        return try_cuda(erase_mask), try_cuda(acc_mask)
Esempio n. 5
0
    def __init_from_config_map(self, config_map):
        """
        Initializes state for training based on the contents of `config_map`.

        Keys read from `config_map`: "continue_from_file" (optional),
        "ec_k", "ec_r", "batch_size", "BaseModel", "base_model_file",
        "base_model_input_size", "Dataset", "Loss", "Encoder", "Decoder",
        "ParityModel", "Optimizer", "final_epoch", "only_test" and
        "save_dir".

        Raises
        ------
        Exception
            If `ec_r` is not 1, or if "only_test" is set without a previous
            state having been loaded from "continue_from_file".
        """
        # If "continue_from_file" is set, we load previous state for training
        # from the associated value.
        prev_state = None
        if "continue_from_file" in config_map and config_map[
                "continue_from_file"] is not None:
            prev_state = util.util.load_state(config_map["continue_from_file"])

        self.ec_k = config_map["ec_k"]
        self.ec_r = config_map["ec_r"]
        if self.ec_r != 1:
            raise Exception("Currently support only `ec_r` = 1")
        self.batch_size = config_map["batch_size"]

        # Base models are wrapped around a thin class that ensures that inputs
        # to base models are of correct size prior to performing a forward
        # pass. We place the base model in "eval" mode so as to not trigger
        # training-specific operations.
        underlying_base_model = construct(config_map["BaseModel"])
        # Deserialize onto the CPU so that a checkpoint saved from a GPU can
        # still be loaded on a machine without one; `load_state_dict` copies
        # the values into the model's own parameters either way.
        underlying_base_model.load_state_dict(
            torch.load(config_map["base_model_file"], map_location="cpu"))
        underlying_base_model.eval()
        base_model_input_size = config_map["base_model_input_size"]
        self.base_model = BaseModelWrapper(underlying_base_model,
                                           base_model_input_size)
        self.base_model = try_cuda(self.base_model)
        self.base_model.eval()

        trdl, vdl, tsdl = get_dataloaders(config_map["Dataset"],
                                          self.base_model, self.ec_k,
                                          self.batch_size)
        self.train_dataloader = trdl
        self.val_dataloader = vdl
        self.test_dataloader = tsdl

        self.loss_fn = construct(config_map["Loss"])

        encoder_in_dim = self.val_dataloader.dataset.encoder_in_dim()
        decoder_in_dim = self.val_dataloader.dataset.decoder_in_dim()
        self.enc_model = construct(config_map["Encoder"], {
            "ec_k": self.ec_k,
            "ec_r": self.ec_r,
            "in_dim": encoder_in_dim
        })

        self.dec_model = construct(config_map["Decoder"], {
            "ec_k": self.ec_k,
            "ec_r": self.ec_r,
            "in_dim": decoder_in_dim
        })

        # Move our encoder, decoder, and loss functions to GPU, if available.
        # The encoder and decoder are kept in "eval" mode; no optimizer is
        # created for them here.
        self.enc_model = try_cuda(self.enc_model)
        self.dec_model = try_cuda(self.dec_model)
        self.enc_model.eval()
        self.dec_model.eval()
        self.loss_fn = try_cuda(self.loss_fn)

        underlying_parity_model = construct(config_map["ParityModel"])
        util.util.init_weights(underlying_parity_model)
        self.parity_model = BaseModelWrapper(underlying_parity_model,
                                             base_model_input_size)

        # Move the parity model to its final device *before* building the
        # optimizer. PyTorch's optimizer documentation recommends this order,
        # and optimizers that allocate per-parameter state at construction
        # time would otherwise allocate it on the wrong device.
        self.parity_model = try_cuda(self.parity_model)
        self.opt = construct(config_map["Optimizer"],
                             {"params": self.parity_model.parameters()})

        self.cur_epoch = 0
        self.best_recon_accuracy = 0.0
        self.final_epoch = config_map["final_epoch"]

        # If we are loading from a previous state, restore the parity model,
        # the optimizer, and the training progress so that we can continue.
        if prev_state is not None:
            self.parity_model.load_state_dict(prev_state["parity_model"])
            self.cur_epoch = prev_state["epoch"]
            self.best_recon_accuracy = prev_state["best_val_acc"]
            self.opt.load_state_dict(prev_state["opt"])

        self.only_test = config_map["only_test"]
        if self.only_test:
            if prev_state is None:
                raise Exception(
                    "only_test cannot be set unless --continue_from_file is set"
                )
            # Overrides the configured value: test-only runs use
            # `final_epoch` = 1.
            self.final_epoch = 1

        # Directory to save stats and checkpoints to
        self.save_dir = config_map["save_dir"]
        if not os.path.isdir(self.save_dir):
            os.makedirs(self.save_dir)
Esempio n. 6
0
    def __init__(self,
                 name,
                 base_model,
                 num_classes,
                 base_dataset,
                 ec_k,
                 code_dataset=None,
                 put_gpu=True):
        """
        Caches every sample of `base_dataset`, the base model's output for
        it, and its true label, and computes the base model's accuracy.

        Parameters
        ----------
        name: str
            One of {"train", "val", "test"}
        base_model: ``torch.nn.Module``
            Base model on which inference is being performed and over which a
            code imparts resilience.
        num_classes: int
            The number of classes in the underlying dataset.
        base_dataset: ``torchvision.datasets.Dataset``
            A dataset from the datasets provided by torchvision. Base model
            outputs are generated from the samples of this dataset.
        ec_k: int
            Number of samples from ``base_dataset`` that will be encoded
            together.
        code_dataset: ``torchvision.datasets.Dataset``
            Dataset containing a set of transforms to apply to samples prior
            to encoding. These transforms may differ from those in
            `base_dataset` as one may wish to include transformations such as
            random cropping and rotating of images so as to reduce
            overfitting. Such transformations would not be included in
            `base_dataset` as they could lead to noisy labels being generated.
        put_gpu: bool
            Whether to put data and labels on GPU. This is untenable for large
            datasets.

        Raises
        ------
        AssertionError
            If samples have more than one channel but not exactly three.
        """
        self.name = name
        self.base_model = base_model
        self.ec_k = ec_k
        self.dataset = base_dataset

        # Since we are not directly calling this DataLoader when we perform
        # iterations when training a code, it is OK not to shuffle the
        # underlying dataset.
        dataloader = data.DataLoader(self.dataset,
                                     batch_size=32,
                                     shuffle=False)

        # Fetch the first sample once; indexing the dataset can be costly.
        first_sample = self.dataset[0][0]
        in_size = first_sample.view(-1).size(0)
        self.num_channels = first_sample.size(0)
        # Raised explicitly (rather than via `assert`) so that the check is
        # not stripped when Python runs with `-O`.
        if self.num_channels > 1 and self.num_channels != 3:
            raise AssertionError(
                "Only currently support 3 channels for multi-channel input")

        # Prepare data, outputs from base model, and the true labels for
        # samples. We will populate these tensors so that we can later access
        # them without pulling PIL images from the underlying dataset.
        self.data = torch.zeros(len(self.dataset), in_size)
        self.outputs = torch.zeros(len(self.dataset), num_classes)
        self.true_labels = torch.zeros(len(self.dataset))

        # Only the values of the base model's outputs are needed, so run
        # inference without recording autograd history.
        cur = 0
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs = try_cuda(inputs.squeeze(1).view(inputs.size(0), -1))
                x = self.base_model(inputs)
                last = cur + inputs.size(0)
                self.data[cur:last, :] = inputs
                self.outputs[cur:last, :] = x
                self.true_labels[cur:last] = targets
                cur = last

        # Calculate the accuracy of the base model with respect to this dataset.
        base_model_preds = torch.max(self.outputs, dim=1)[1]
        correct_preds = (base_model_preds == self.true_labels.long())
        base_model_num_correct = torch.sum(correct_preds).item()
        base_model_num_tried = self.outputs.size(0)
        base_model_accuracy = base_model_num_correct / base_model_num_tried

        # We don't print the accuracy for the validation dataset because we
        # only split the training set into a training and validation set
        # after getting all inference results from the training dataset.
        # Printing accuracy for the validation dataset can lead to confusion.
        if name != "val":
            print("Base model", name, "accuracy is", base_model_num_correct,
                  "/", base_model_num_tried, "=", base_model_accuracy)

        self.true_labels = self.true_labels.long()
        if put_gpu:
            # Move data, outputs, and true labels to GPU for fast access.
            self.data = try_cuda(self.data)
            self.outputs = try_cuda(self.outputs)
            self.true_labels = try_cuda(self.true_labels)

        # If a dataset with extra transformations is passed, switch to it so
        # that a caller can pull new, transformed samples with calls to
        # `__getitem__`.
        if code_dataset is not None:
            self.dataset = code_dataset
            self.extra_transforms = True
        else:
            self.extra_transforms = False
    def __init__(self,
                 name,
                 base_model,
                 num_classes,
                 base_dataset,
                 base_dataset_dir,
                 ec_k,
                 base_transform=None,
                 code_transform=None):
        """
        Builds the underlying dataset, caches every sample, the base model's
        output for it, and its true label, and prints the base model's
        accuracy.

        Parameters
        ----------
        name: str
            One of {"train", "val", "test"}
        base_model: ``torch.nn.Module``
            Base model on which inference is being performed and over which a
            code imparts resilience.
        num_classes: int
            The number of classes in the underlying dataset.
        base_dataset: ``torchvision.datasets.Dataset``
            A dataset from the datasets provided by torchvision. It is called
            with ``root``, ``train``, ``download`` and ``transform`` keyword
            arguments.
        base_dataset_dir: str
            Location where ``base_dataset`` has been or will be saved. This
            avoids re-downloading the dataset.
        ec_k: int
            Number of samples from ``base_dataset`` that will be encoded
            together.
        base_transform: ``torchvision.transforms.Transform``
            Set of transforms to apply to samples when generating base model
            outputs that will (potentially) be used as labels. Defaults to
            ``transforms.ToTensor()``.
        code_transform: ``torchvision.transforms.Transform``
            Set of transforms to apply to samples prior to encoding. These
            transforms may differ from those in `base_transform` as one
            may wish to include transformations such as random cropping and
            rotating of images so as to reduce overfitting. Such
            transformations would not be included in `base_transform` as they
            could lead to noisy labels being generated. Only used when `name`
            is "train".

        Raises
        ------
        AssertionError
            If samples have more than one channel but not exactly three.
        """
        self.name = name
        self.base_model = base_model
        self.ec_k = ec_k

        if base_transform is None:
            base_transform = transforms.ToTensor()

        # Draw from the torchvisions "train" datasets for training and
        # validation datasets
        is_train = (name != "test")

        # Create the datasets from the underlying `base_dataset`.
        # When generating outputs from running samples through the base model,
        # we do apply `base_transform`.
        self.dataset = base_dataset(root=base_dataset_dir,
                                    train=is_train,
                                    download=True,
                                    transform=base_transform)

        # Since we are not directly calling this DataLoader when we perform
        # iterations when training a code, it is OK not to shuffle the
        # underlying dataset.
        dataloader = data.DataLoader(self.dataset,
                                     batch_size=32,
                                     shuffle=False)

        # Fetch the first sample once; indexing the dataset can be costly.
        first_sample = self.dataset[0][0]
        in_size = first_sample.view(-1).size(0)
        self.num_channels = first_sample.size(0)
        # Raised explicitly (rather than via `assert`) so that the check is
        # not stripped when Python runs with `-O`.
        if self.num_channels > 1 and self.num_channels != 3:
            raise AssertionError(
                "Only currently support 3 channels for multi-channel input")

        # Prepare data, outputs from base model, and the true labels for
        # samples. We will populate these tensors so that we can later access
        # them without pulling PIL images from the underlying dataset.
        self.data = torch.zeros(len(self.dataset), in_size)
        self.outputs = torch.zeros(len(self.dataset), num_classes)
        self.true_labels = torch.zeros(len(self.dataset))

        # Only the values of the base model's outputs are needed, so run
        # inference without recording autograd history.
        cur = 0
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs = try_cuda(inputs.squeeze(1).view(inputs.size(0), -1))
                x = self.base_model(inputs)
                last = cur + inputs.size(0)
                self.data[cur:last, :] = inputs
                self.outputs[cur:last, :] = x
                self.true_labels[cur:last] = targets
                cur = last

        # Calculate the accuracy of the base model with respect to this dataset.
        base_model_preds = torch.max(self.outputs, dim=1)[1]
        correct_preds = (base_model_preds == self.true_labels.long())
        base_model_num_correct = torch.sum(correct_preds).item()
        base_model_num_tried = self.outputs.size(0)
        base_model_accuracy = base_model_num_correct / base_model_num_tried

        print("Base model", name, "accuracy is", base_model_num_correct, "/",
              base_model_num_tried, "=", base_model_accuracy)

        # Move data, outputs, and true labels to GPU for fast access.
        self.data = try_cuda(self.data)
        self.outputs = try_cuda(self.outputs)
        self.true_labels = try_cuda(self.true_labels.long())

        # If extra transformations are passed, create a new dataset containing
        # these so that a caller can pull new, transformed samples with calls
        # to `__getitem__`.
        if name == "train" and code_transform is not None:
            self.dataset = base_dataset(root=base_dataset_dir,
                                        train=is_train,
                                        download=True,
                                        transform=code_transform)
            self.extra_transforms = True
        else:
            self.extra_transforms = False
    def __init_from_config_map(self, config_map):
        """
        Initializes state for training based on the contents of `config_map`.

        Keys read from `config_map`: "continue_from_file" (optional),
        "ec_k", "ec_r", "batch_size", "BaseModel", "base_model_file",
        "base_model_input_size", "Dataset", "Loss" (including its "class"
        entry), "Encoder", "Decoder", "EncoderOptimizer",
        "DecoderOptimizer", "final_epoch" and "save_dir".

        Raises
        ------
        AssertionError
            If `ec_r` is not 1.
        """
        # If "continue_from_file" is set, we load previous state for training
        # from the associated value.
        prev_state = None
        if "continue_from_file" in config_map and config_map[
                "continue_from_file"] is not None:
            prev_state = util.util.load_state(config_map["continue_from_file"])

        self.ec_k = config_map["ec_k"]
        self.ec_r = config_map["ec_r"]
        # Raised explicitly (rather than via `assert`) so that the check is
        # not stripped when Python runs with `-O`. The exception type and
        # message are the same as an `assert` would produce.
        if self.ec_r != 1:
            raise AssertionError("Currently only support `ec_r` being one")
        self.batch_size = config_map["batch_size"]

        # Base models are wrapped around a thin class that ensures that inputs
        # to base models are of correct size prior to performing a forward
        # pass. We place the base model in "eval" mode so as to not trigger
        # training-specific operations.
        underlying_base_model = construct(config_map["BaseModel"])
        # Deserialize onto the CPU so that a checkpoint saved from a GPU can
        # still be loaded on a machine without one; `load_state_dict` copies
        # the values into the model's own parameters either way.
        underlying_base_model.load_state_dict(
            torch.load(config_map["base_model_file"], map_location="cpu"))
        underlying_base_model.eval()
        base_model_input_size = config_map["base_model_input_size"]
        self.base_model = BaseModelWrapper(underlying_base_model,
                                           base_model_input_size)
        self.base_model = try_cuda(self.base_model)
        self.base_model.eval()

        trdl, vdl, tsdl = get_dataloaders(config_map["Dataset"],
                                          self.base_model, self.ec_k,
                                          self.batch_size)
        self.train_dataloader = trdl
        self.val_dataloader = vdl
        self.test_dataloader = tsdl

        # Loss functions are wrapped by a thin class that masks loss prior to
        # summing it for performing a backward pass. Using masks enables us to
        # perform loss calculations over all unavailability scenarios for a
        # given minibatch of data in a vectorized fashion, rather than using
        # for-loops. We find this to be faster.
        self.loss_fn = MaskedLoss(base_loss=config_map["Loss"])

        encoder_in_dim = self.val_dataloader.dataset.encoder_in_dim()
        decoder_in_dim = self.val_dataloader.dataset.decoder_in_dim()
        self.enc_model = construct(config_map["Encoder"], {
            "ec_k": self.ec_k,
            "ec_r": self.ec_r,
            "in_dim": encoder_in_dim
        })
        util.util.init_weights(self.enc_model)

        self.dec_model = construct(config_map["Decoder"], {
            "ec_k": self.ec_k,
            "ec_r": self.ec_r,
            "in_dim": decoder_in_dim
        })
        util.util.init_weights(self.dec_model)

        # Move our encoder, decoder, and loss functions to GPU, if available.
        # This is done before the optimizers are constructed so that they are
        # built over parameters that already live on their final device.
        self.enc_model = try_cuda(self.enc_model)
        self.dec_model = try_cuda(self.dec_model)
        self.loss_fn = try_cuda(self.loss_fn)

        # We use two separate optimizers, one for the encoder and one for the
        # decoder.
        self.enc_opt = construct(config_map["EncoderOptimizer"],
                                 {"params": self.enc_model.parameters()})
        self.dec_opt = construct(config_map["DecoderOptimizer"],
                                 {"params": self.dec_model.parameters()})

        self.cur_epoch = 0
        self.best_recon_accuracy = 0.0
        self.final_epoch = config_map["final_epoch"]

        # If we are loading from a previous state, update our encoder, decoder,
        # optimizers, and current status of training so that we can continue.
        if prev_state is not None:
            self.cur_epoch = prev_state["epoch"]
            self.best_recon_accuracy = prev_state["best_val_acc"]
            self.enc_model.load_state_dict(prev_state["enc_model"])
            self.dec_model.load_state_dict(prev_state["dec_model"])
            self.enc_opt.load_state_dict(prev_state["enc_opt"])
            self.dec_opt.load_state_dict(prev_state["dec_opt"])

        # Whether or not loss should be calculated with respect to true labels
        # for the underlying dataset. This is the case exactly when the
        # configured loss class name contains "CrossEntropyLoss".
        self.from_true = "CrossEntropyLoss" in config_map["Loss"]["class"]

        # Directory to save stats and checkpoints to
        self.save_dir = config_map["save_dir"]
        if not os.path.isdir(self.save_dir):
            os.makedirs(self.save_dir)