Example #1
    def train_one_batch(self, performance_estimators, batch_idx, input_s, target_s, metadata):
        # outputs used to calculate the loss of the supervised model
        # must be done with the model prior to regularization:
        self.net.train()
        indel_weight = self.args.indel_weight_factor
        snp_weight = 1.0
        self.optimizer_training.zero_grad()
        self.net.zero_grad()
        output_s = self.net(input_s)
        output_s_p = self.get_p(output_s)
        _, target_index = torch.max(recode_as_multi_label(target_s), dim=1)
        supervised_loss = self.criterion_classifier(output_s, recode_as_multi_label(target_s))

        batch_weight = self.estimate_batch_weight(metadata, indel_weight=indel_weight,
                                                  snp_weight=snp_weight)

        weighted_supervised_loss = supervised_loss * batch_weight
        optimized_loss = weighted_supervised_loss
        optimized_loss.backward()
        self.optimizer_training.step()
        performance_estimators.set_metric(batch_idx, "supervised_loss", supervised_loss.item())
        performance_estimators.set_metric_with_outputs(batch_idx, "train_accuracy", supervised_loss.item(),
                                                       output_s_p, targets=target_index)
        if not self.args.no_progress:
            progress_bar(batch_idx * self.mini_batch_size,
                         self.max_training_examples,
                         performance_estimators.progress_message(
                             ["supervised_loss", "reconstruction_loss", "train_accuracy"]))
    def test_one_batch(self,
                       performance_estimators,
                       batch_idx,
                       input_s,
                       target_s,
                       metadata=None,
                       errors=None):
        if errors is None:
            errors = torch.zeros(target_s[0].size())
        self.net.eval()
        output_s = self.net(input_s)
        output_s_p = self.get_p(output_s)

        _, target_index = torch.max(recode_as_multi_label(target_s), dim=1)
        _, output_index = torch.max(recode_as_multi_label(output_s_p), dim=1)
        supervised_loss = self.criterion_classifier(output_s, target_s)
        #self.estimate_errors(errors, output_s_p, target_s)

        performance_estimators.set_metric(batch_idx, "test_supervised_loss",
                                          supervised_loss.item())
        performance_estimators.set_metric_with_outputs(batch_idx,
                                                       "test_accuracy",
                                                       supervised_loss.item(),
                                                       output_s_p,
                                                       targets=target_index)
        if not self.args.no_progress:
            progress_bar(
                batch_idx * self.mini_batch_size, self.max_validation_examples,
                performance_estimators.progress_message([
                    "test_supervised_loss", "test_reconstruction_loss",
                    "test_accuracy"
                ]))
    def test_one_batch(self, performance_estimators,
                       batch_idx, input_s, target_s, metadata=None, errors=None):
        # Estimate the reconstruction loss on validation examples:
        reconstruction_loss = self.net.get_reconstruction_loss(input_s)

        # now evaluate prediction of categories:
        categories_predicted, latent_code = self.net.encoder(input_s)
        categories_predicted_p = self.get_p(categories_predicted)
        categories_predicted_p[categories_predicted_p != categories_predicted_p] = 0.0
        _, target_index = torch.max(target_s, dim=1)
        _, output_index = torch.max(categories_predicted_p, dim=1)
        categories_loss = self.net.semisup_loss_criterion(categories_predicted, target_s)
        weight = 1
        indel_weight = self.args.indel_weight_factor
        snp_weight = 1.0
        if self.use_pdf:

            weight *= self.estimate_example_density_weight(latent_code)
        else:
            weight *= self.estimate_batch_weight(metadata, indel_weight=indel_weight,
                                                 snp_weight=snp_weight)

        self.cm.add(predicted=output_index.data, target=target_index.data)

        performance_estimators.set_metric(batch_idx, "reconstruction_loss", reconstruction_loss.item())
        performance_estimators.set_metric(batch_idx, "weight", weight)
        performance_estimators.set_metric_with_outputs(batch_idx, "test_accuracy", reconstruction_loss.item(),
                                                       categories_predicted_p, target_index)
        performance_estimators.set_metric_with_outputs(batch_idx, "test_loss", categories_loss.item() * weight,
                                                       categories_predicted_p, target_s)

        if not self.args.no_progress:
            progress_bar(batch_idx * self.mini_batch_size, self.max_validation_examples,
                         performance_estimators.progress_message(["test_loss", "test_accuracy", "reconstruction_loss"]))
Example #4
    def train_autoencoder(self, epoch, performance_estimators=None):

        if performance_estimators is None:
            performance_estimators = PerformanceList()
            performance_estimators += [LossHelper("train_loss")]
            performance_estimators += [FloatHelper("train_grad_norm")]
            print('\nTraining, epoch: %d' % epoch)

        self.net.train()
        supervised_grad_norm = 1.
        for performance_estimator in performance_estimators:
            performance_estimator.init_performance_metrics()

        unsupervised_loss_acc = 0
        num_batches = 0
        train_loader_subset = self.problem.train_loader_subset_range(
            0, self.args.num_training)

        for batch_idx, (_, data_dict) in enumerate(train_loader_subset):
            inputs = data_dict["input"].to(self.device)
            num_batches += 1

            inputs, targets = Variable(inputs), Variable(inputs,
                                                         requires_grad=False)
            # outputs used to calculate the loss of the supervised model
            # must be done with the model prior to regularization:
            self.net.train()
            self.optimizer_training.zero_grad()
            outputs = self.net(inputs)

            supervised_loss = self.criterion(outputs, targets)
            optimized_loss = supervised_loss
            optimized_loss.backward()
            self.optimizer_training.step()
            performance_estimators.set_metric_with_outputs(
                batch_idx, "train_loss", supervised_loss.item(), outputs,
                targets)

            supervised_grad_norm = grad_norm(self.net.parameters())
            performance_estimators.set_metric(batch_idx, "train_grad_norm",
                                              supervised_grad_norm)

            performance_estimators.set_metric_with_outputs(
                batch_idx, "optimized_loss", optimized_loss.item(), outputs,
                targets)

            progress_bar(
                batch_idx * self.mini_batch_size, self.max_training_examples,
                " ".join([
                    performance_estimator.progress_message()
                    for performance_estimator in performance_estimators
                ]))

            if (batch_idx +
                    1) * self.mini_batch_size > self.max_training_examples:
                break

        return performance_estimators
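grad_norm is not defined in this example; assuming it returns the global L2 norm of all parameter gradients, a minimal sketch could be:

import torch

def grad_norm(parameters):
    # accumulate the squared L2 norm of every available gradient, then take the root
    total = 0.0
    for p in parameters:
        if p.grad is not None:
            total += p.grad.data.norm(2).item() ** 2
    return total ** 0.5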
Example #5
    def test_somatic_classifer(self, epoch, performance_estimators=None):
        print('\nTesting, epoch: %d' % epoch)
        if performance_estimators is None:
            performance_estimators = PerformanceList()
            performance_estimators += [LossHelper("test_loss")]
            performance_estimators += [LossHelper("classification_loss")]
            performance_estimators += [LossHelper("frequency_loss")]

        self.net.eval()
        for performance_estimator in performance_estimators:
            performance_estimator.init_performance_metrics()
        cross_entropy_loss = CrossEntropyLoss()
        mse_loss = MSELoss()
        for batch_idx, (_, data_dict) in enumerate(
                self.problem.validation_loader_range(
                    0, self.args.num_validation)):
            inputs = data_dict["input"]
            is_mutated_base_target = data_dict["isBaseMutated"]
            # transform one-hot encoding into a class index:
            max, indices = is_mutated_base_target.max(dim=1)
            is_mutated_base_target = indices
            somatic_frequency_target = data_dict["somaticFrequency"]
            if self.use_cuda:
                inputs, is_mutated_base_target, somatic_frequency_target = inputs.cuda(), \
                                                                           is_mutated_base_target.cuda(), \
                                                                           somatic_frequency_target.cuda()

            inputs, mut_targets, freq_targets = Variable(inputs), Variable(is_mutated_base_target, volatile=True), \
                                                Variable(somatic_frequency_target, volatile=True)

            is_base_mutated, output_frequency = self.net(inputs)
            classification_loss = cross_entropy_loss(is_base_mutated,
                                                     mut_targets)
            frequency_loss = mse_loss(output_frequency, freq_targets)
            test_loss = classification_loss + frequency_loss

            performance_estimators.set_metric(batch_idx, "test_loss",
                                              test_loss.data[0])
            performance_estimators.set_metric(batch_idx, "classification_loss",
                                              classification_loss.data[0])
            performance_estimators.set_metric(batch_idx, "frequency_loss",
                                              frequency_loss.data[0])

            progress_bar(
                batch_idx * self.mini_batch_size, self.max_validation_examples,
                performance_estimators.progress_message(["test_loss"]))

            if ((batch_idx + 1) *
                    self.mini_batch_size) > self.max_validation_examples:
                break
        # print()

        # Apply learning rate schedule:
        test_accuracy = performance_estimators.get_metric("test_loss")
        assert test_accuracy is not None, "test_loss must be found among estimated performance metrics"
        if not self.args.constant_learning_rates:
            self.scheduler_train.step(test_accuracy, epoch)
        return performance_estimators
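The scheduler_train.step(test_accuracy, epoch) call above suggests a ReduceLROnPlateau-style schedule driven by the validation loss; a self-contained sketch of that assumed setup (the model and optimizer here are stand-ins):

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 2)                       # stand-in for self.net
optimizer_training = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler_train = ReduceLROnPlateau(optimizer_training, mode="min",
                                    factor=0.5, patience=10)

# after each validation epoch, pass the metric the schedule should track:
validation_loss = 0.42
scheduler_train.step(validation_loss)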
    def train_one_batch(self, performance_estimators, batch_idx, input_s_1,
                        target_s_1, metadata_1, input_u_2):
        self.net.train()
        self.num_classes = len(target_s_1[0])
        genotype_frequencies = self.class_frequencies["softmaxGenotype"]

        category_prior = (genotype_frequencies /
                          torch.sum(genotype_frequencies)).numpy()

        indel_weight = self.args.indel_weight_factor
        snp_weight = 1.0

        target_s_2 = self.dreamup_target_for(num_classes=self.num_classes,
                                             category_prior=category_prior,
                                             input=input_u_2)

        if self.use_cuda:
            target_s_2 = target_s_2.cuda()
        with self.lock:
            input_s_mixup, target_s_mixup = self._recreate_mixup_batch(
                input_s_1, input_u_2, target_s_1, target_s_2)

        self.optimizer_training.zero_grad()
        self.net.zero_grad()

        # outputs used to calculate the loss of the supervised model
        # must be done with the model prior to regularization:
        output_s = self.net(input_s_mixup)
        output_s_p = self.get_p(output_s)
        _, target_index = torch.max(target_s_mixup, dim=1)

        supervised_loss = self.criterion_classifier(output_s, target_s_mixup)
        # assume weight is the same for the two batches (we don't know metadata on the unlabeled batch):
        with self.lock:
            batch_weight = self.estimate_batch_weight(
                metadata_1, indel_weight=indel_weight, snp_weight=snp_weight)

        supervised_loss = supervised_loss * batch_weight
        supervised_loss.backward()

        self.optimizer_training.step()
        performance_estimators.set_metric(batch_idx, "supervised_loss",
                                          supervised_loss.data[0])
        performance_estimators.set_metric_with_outputs(batch_idx,
                                                       "train_accuracy",
                                                       supervised_loss.data[0],
                                                       output_s_p,
                                                       targets=target_index)
        if not self.args.no_progress:
            progress_bar(
                batch_idx * self.mini_batch_size, self.max_training_examples,
                performance_estimators.progress_message([
                    "supervised_loss", "reconstruction_loss", "train_accuracy"
                ]))
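_recreate_mixup_batch is not shown in this listing; assuming it implements standard mixup interpolation between the labeled and unlabeled batches, a sketch could look like this (the function name and alpha default are hypothetical):

import torch
from torch.distributions.beta import Beta

def recreate_mixup_batch(input_1, input_2, target_1, target_2, alpha=0.4):
    # draw one mixing coefficient and interpolate inputs and targets with it
    lam = Beta(alpha, alpha).sample().item()
    mixed_input = lam * input_1 + (1.0 - lam) * input_2
    mixed_target = lam * target_1 + (1.0 - lam) * target_2
    return mixed_input, mixed_target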
    def class_frequency(self,
                        recode_as_multi_label=False,
                        class_frequencies=None):
        """
        Estimate class frequencies for the output vectors of the problem and rebuild criterions
        with weights that correct class imbalances.
        """
        if class_frequencies is None:

            train_loader_subset = self.problem.train_loader_subset_range(
                0,
                min(self.args.num_estimate_class_frequencies,
                    min(100000, self.args.num_training)))
            data_provider = MultiThreadedCpuGpuDataProvider(
                iterator=zip(train_loader_subset),
                is_cuda=False,
                batch_names=["training"],
                volatile={"training": self.problem.get_vector_names()},
            )

            class_frequencies = {}
            done = False
            for batch_idx, (_, data_dict) in enumerate(data_provider):
                if batch_idx * self.mini_batch_size > self.args.num_estimate_class_frequencies:
                    break
                for output_name in self.problem.get_output_names():
                    target_s = data_dict["training"][output_name]
                    if output_name not in class_frequencies.keys():
                        class_frequencies[output_name] = torch.ones(
                            target_s[0].size())
                    cf = class_frequencies[output_name]
                    indices = torch.nonzero(target_s.data)
                    indices = indices[:, 1]
                    for index in range(indices.size(0)):
                        cf[indices[index]] += 1

                progress_bar(batch_idx * self.mini_batch_size,
                             self.max_training_examples, "Class frequencies")
        else:
            self.class_frequencies = class_frequencies
        for output_name in self.problem.get_output_names():

            class_frequencies_output = class_frequencies[output_name]
            # normalize with 1-f, where f is normalized frequency vector:
            weights = torch.ones(class_frequencies_output.size())
            weights -= class_frequencies_output / torch.norm(
                class_frequencies_output, p=1, dim=0)
            if self.use_cuda:
                weights = weights.cuda()

            self.rebuild_criterions(output_name=output_name, weights=weights)
        return class_frequencies
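rebuild_criterions is not shown here; the weights computed above (one minus the L1-normalized class frequency) are presumably handed to a weighted loss. A standalone illustration of that weighting, with hypothetical counts:

import torch

class_counts = torch.tensor([900.0, 80.0, 20.0])            # hypothetical frequencies
weights = torch.ones(class_counts.size())
weights -= class_counts / torch.norm(class_counts, p=1, dim=0)
criterion = torch.nn.CrossEntropyLoss(weight=weights)       # rare classes weigh more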
Example #8
    def test_semi_sup(self, epoch):
        print('\nTesting, epoch: %d' % epoch)

        performance_estimators = PerformanceList()
        performance_estimators += [LossHelper("test_supervised_loss")]
        performance_estimators += [LossHelper("test_reconstruction_loss")]
        performance_estimators += [AccuracyHelper("test_")]

        self.net.eval()
        for performance_estimator in performance_estimators:
            performance_estimator.init_performance_metrics()
        validation_loader_subset = self.problem.validation_loader_range(0, self.args.num_validation)
        data_provider = MultiThreadedCpuGpuDataProvider(iterator=zip(validation_loader_subset), is_cuda=self.use_cuda,
                                                        batch_names=["validation"],
                                                        requires_grad={"validation": []},
                                                        volatile={"validation": ["input", "softmaxGenotype"]})
        try:
            for batch_idx, (_, data_dict) in enumerate(data_provider):
                input_s = data_dict["validation"]["input"]
                target_s = data_dict["validation"]["softmaxGenotype"]
                # we need copies of the same tensors:
                input_u, target_u = Variable(input_s.data, volatile=True), Variable(input_s.data, volatile=True)

                output_s = self.net(input_s)
                output_u = self.net.autoencoder(input_u)
                output_s_p = self.get_p(output_s)

                _, target_index = torch.max(target_s, dim=1)

                supervised_loss = self.criterion_classifier(output_s, target_s)
                reconstruction_loss = self.criterion_autoencoder(output_u, target_u)

                performance_estimators.set_metric(batch_idx, "test_supervised_loss", supervised_loss.data[0])
                performance_estimators.set_metric(batch_idx, "test_reconstruction_loss", reconstruction_loss.data[0])
                performance_estimators.set_metric_with_outputs(batch_idx, "test_accuracy", supervised_loss.data[0],
                                                               output_s_p, targets=target_index)

                progress_bar(batch_idx * self.mini_batch_size, self.max_validation_examples,
                             performance_estimators.progress_message(["test_supervised_loss", "test_reconstruction_loss",
                                                                      "test_accuracy"]))

                if ((batch_idx + 1) * self.mini_batch_size) > self.max_validation_examples:
                    break
            # print()
        finally:
            data_provider.close()
        test_metric = performance_estimators.get_metric(self.get_test_metric_name())
        assert test_metric is not None, self.get_test_metric_name() + "must be found among estimated performance metrics"
        if not self.args.constant_learning_rates:
            self.scheduler_train.step(test_metric, epoch)
        return performance_estimators
    def predict(self, iterator, output_filename, max_examples=sys.maxsize):
        self.model.eval()
        if self.processing_type == "multithreaded":
            # Enable fake_GPU_on_CPU to debug on CPU
            data_provider = MultiThreadedCpuGpuDataProvider(
                iterator=zip(iterator),
                is_cuda=self.use_cuda,
                batch_names=["unlabeled"],
                volatile={"unlabeled": [self.input_name]},
                recode_functions=self.recode_fn,
                fake_gpu_on_cpu=False)

        elif self.processing_type == "sequential":
            data_provider = DataProvider(
                iterator=zip(iterator),
                is_cuda=self.use_cuda,
                batch_names=["unlabeled"],
                volatile={"unlabeled": [self.input_name]},
                recode_functions=self.recode_fn)
        else:
            raise Exception("Unrecognized processing type {}".format(
                self.processing_type))

        with VectorWriterBinary(sample_id=0,
                                path_with_basename=output_filename,
                                tensor_names=self.problem.get_output_names(),
                                domain_descriptor=self.domain_descriptor,
                                feature_mapper=self.feature_mapper,
                                samples=self.samples,
                                input_files=self.input_files,
                                problem=self.problem,
                                model=self.model) as writer:
            for batch_idx, (indices_dict,
                            data_dict) in enumerate(data_provider):
                input_u = data_dict["unlabeled"][self.input_name]
                idxs_u = indices_dict["unlabeled"]
                outputs = self.model(input_u)
                writer.append(list(idxs_u), outputs, inverse_logit=True)
                progress_bar(batch_idx * self.mini_batch_size, max_examples)

                if ((batch_idx + 1) * self.mini_batch_size) > max_examples:
                    break
        data_provider.close()
        print("Done")
Example #10
    def write_lines(self):
        num_vector_lines = 0
        try:
            while num_vector_lines < self.total_vector_lines:
                next_vector_line = self.vector_text_reader.get_next_vector_line(
                )
                next_vector_type = self.vector_reader_properties.get_vector_type_from_idx(
                    next_vector_line.line_vector_id)
                if next_vector_type == "float32":
                    packed_type = "f"
                else:
                    raise ValueError("Unknown data type to unpack: {}".format(
                        next_vector_type))
                fmt_string = ">IQII{}{}".format(
                    len(next_vector_line.line_vector_elements), packed_type)
                self.output_writer.write(
                    struct.pack(fmt_string, next_vector_line.line_sample_id,
                                next_vector_line.line_example_id,
                                next_vector_line.line_vector_id,
                                len(next_vector_line.line_vector_elements),
                                *next_vector_line.line_vector_elements))
                num_vector_lines += 1
                num_examples_written = num_vector_lines / self.num_vector_lines_per_example
                if (num_vector_lines % self.num_vector_lines_per_example
                        == 0) and (num_examples_written % 1000 == 1):
                    progress_bar(num_examples_written, self.max_records,
                                 "Caching " + self.path_basename)

        except StopIteration:
            pass
        self.output_writer.flush()
        self.output_writer.seek(0, 2)
        num_bytes_written = self.output_writer.tell()
        if num_bytes_written != self.expected_bytes:
            raise RuntimeError(
                "Number of bytes written {} differs from expected {}. "
                "Wrote {} vector lines, {} per example "
                "for {} records, {} stray lines.".format(
                    num_bytes_written, self.expected_bytes, num_vector_lines,
                    self.num_vector_lines_per_example,
                    num_vector_lines / self.num_vector_lines_per_example,
                    num_vector_lines % self.num_vector_lines_per_example))
        self.close()
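Each record written above is packed big-endian as sample id, example id, vector id, element count, then the float32 elements. Assuming the element count for a vector is known, a record can be read back like this:

import struct

def read_vector_line(fileobj, num_elements):
    fmt = ">IQII{}f".format(num_elements)
    raw = fileobj.read(struct.calcsize(fmt))
    sample_id, example_id, vector_id, length, *elements = struct.unpack(fmt, raw)
    assert length == num_elements
    return sample_id, example_id, vector_id, elements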
Example #11
    def test_autoencoder(self, epoch, performance_estimators=None):
        print('\nTesting, epoch: %d' % epoch)
        if performance_estimators is None:
            performance_estimators = PerformanceList()
            performance_estimators += [LossHelper("test_loss")]

        self.net.eval()
        for performance_estimator in performance_estimators:
            performance_estimator.init_performance_metrics()

        for batch_idx, (_, data_dict) in enumerate(
                self.problem.validation_loader_range(
                    0, self.args.num_validation)):
            inputs = data_dict["input"]
            if self.use_cuda:
                inputs = inputs.cuda()

            inputs, targets = Variable(inputs,
                                       volatile=True), Variable(inputs,
                                                                volatile=True)

            outputs = self.net(inputs)
            loss = self.criterion(outputs, targets)

            performance_estimators.set_metric_with_outputs(
                batch_idx, "test_loss", loss.item(), outputs, targets)

            progress_bar(
                batch_idx * self.mini_batch_size, self.max_validation_examples,
                performance_estimators.progress_message(["test_loss"]))

            if ((batch_idx + 1) *
                    self.mini_batch_size) > self.max_validation_examples:
                break
        # print()

        # Apply learning rate schedule:
        test_accuracy = performance_estimators.get_metric("test_loss")
        assert test_accuracy is not None, "test_loss must be found among estimated performance metrics"
        if not self.args.constant_learning_rates:
            self.scheduler_train.step(test_accuracy, epoch)
        return performance_estimators
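Variable(..., volatile=True) is pre-0.4 PyTorch; on current versions the same inference-only evaluation is usually expressed with torch.no_grad(). A minimal sketch of the equivalent autoencoder evaluation step:

import torch

def evaluate_autoencoder_batch(net, criterion, inputs):
    net.eval()
    with torch.no_grad():                  # disables autograd, as volatile=True once did
        outputs = net(inputs)
        loss = criterion(outputs, inputs)  # the autoencoder target is the input itself
    return loss.item()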
    def test_one_batch(self, performance_estimators, batch_idx,
                       input_supervised, target_s):
        self.src_encoder.eval()
        self.tgt_encoder.eval()
        self.critic.eval()

        output_s = self.net.forward_with_src_encoding(input_supervised)
        output_s_p = self.get_p(output_s)

        supervised_loss = self.criterion_classifier(output_s, target_s)

        _, target_index = torch.max(target_s, dim=1)
        _, output_index = torch.max(output_s_p, dim=1)
        performance_estimators.set_metric(batch_idx, "test_supervised_loss",
                                          supervised_loss.item())
        performance_estimators.set_metric_with_outputs(batch_idx,
                                                       "test_accuracy",
                                                       supervised_loss.item(),
                                                       output_s_p,
                                                       targets=target_index)
        # use target encoding:
        output_s = self.net(input_supervised)
        output_s_p = self.get_p(output_s)

        supervised_loss = self.criterion_classifier(output_s, target_s)

        _, target_index = torch.max(target_s, dim=1)
        _, output_index = torch.max(output_s_p, dim=1)
        performance_estimators.set_metric(batch_idx,
                                          "test_encoded_supervised_loss",
                                          supervised_loss.item())
        performance_estimators.set_metric_with_outputs(batch_idx,
                                                       "test_encoded_accuracy",
                                                       supervised_loss.item(),
                                                       output_s_p,
                                                       targets=target_index)
        if not self.args.no_progress:
            progress_bar(
                batch_idx * self.mini_batch_size, self.max_validation_examples,
                performance_estimators.progress_message(
                    ["test_accuracy", "test_encoded_accuracy"]))
Example #13
            if done:
                break
            for output_name in problem.get_output_names():
                target_s = data_dict["training"][output_name]
                if output_name not in class_frequencies.keys():
                    class_frequencies[output_name] = torch.ones(target_s.size(1))
                cf = class_frequencies[output_name]
                indices = torch.nonzero(target_s.data)
                indices = indices[:, 1]
                for index in range(indices.size(0)):
                    cf[indices[index]] += 1
                #for class_index in range(target_s.size(1)):
                #    cf[indices[:,1]] += 1

            progress_bar(batch_idx * args.mini_batch_size,
                         args.num_estimate_class_frequencies,
                         "Class frequencies")
    print("class frequencies: "+str(class_frequencies))
    del train_loader_subset

    # Initialize the trainers:
    global_lock = threading.Lock()
    for trainer_command_line in trainer_arguments:
        trainer_parser = define_train_auto_encoder_parser()
        trainer_args = trainer_parser.parse_args(trainer_command_line.split())

        if trainer_args.max_examples_per_epoch is None:
            trainer_args.max_examples_per_epoch = trainer_args.num_training
        trainer_args.num_training = args.num_training
        print("Executing " + trainer_args.checkpoint_key)
    def supervised_somatic(self, epoch, performance_estimators=None):

        if performance_estimators is None:
            performance_estimators = PerformanceList()
            performance_estimators += [LossHelper("train_loss")]
            performance_estimators += [LossHelper("classification_loss")]
            performance_estimators += [LossHelper("frequency_loss")]
            performance_estimators += [FloatHelper("train_grad_norm")]
            print('\nTraining, epoch: %d' % epoch)

        self.net.train()
        supervised_grad_norm = 1.
        for performance_estimator in performance_estimators:
            performance_estimator.init_performance_metrics()

        unsupervised_loss_acc = 0
        num_batches = 0
        train_loader_subset = self.problem.train_loader_subset_range(
            0, self.args.num_training)
        cross_entropy_loss = CrossEntropyLoss()
        mse_loss = MSELoss()
        self.net.train()

        for batch_idx, (_, data_dict) in enumerate(train_loader_subset):
            inputs = data_dict["input"].to(self.device)
            is_mutated_base_target = data_dict["isBaseMutated"].to(self.device)
            # transform one-hot encoding into a class index:
            max, indices = is_mutated_base_target.max(dim=1)
            is_mutated_base_target = indices
            somatic_frequency_target = data_dict["somaticFrequency"].to(
                self.device)
            num_batches += 1

            # outputs used to calculate the loss of the supervised model
            # must be done with the model prior to regularization:

            self.optimizer_training.zero_grad()
            output_mut, output_frequency = self.net(inputs)

            classification_loss = cross_entropy_loss(output_mut,
                                                     is_mutated_base_target)
            frequency_loss = mse_loss(output_frequency,
                                      somatic_frequency_target)
            optimized_loss = classification_loss + frequency_loss

            optimized_loss.backward()
            self.optimizer_training.step()
            performance_estimators.set_metric(batch_idx, "train_loss",
                                              optimized_loss.item())
            performance_estimators.set_metric(batch_idx, "classification_loss",
                                              classification_loss.item())
            performance_estimators.set_metric(batch_idx, "frequency_loss",
                                              frequency_loss.item())

            supervised_grad_norm = grad_norm(self.net.parameters())
            performance_estimators.set_metric(batch_idx, "train_grad_norm",
                                              supervised_grad_norm)

            progress_bar(
                batch_idx * self.mini_batch_size, self.max_training_examples,
                performance_estimators.progress_message(
                    ["classification_loss", "frequency_loss"]))

            if (batch_idx +
                    1) * self.mini_batch_size > self.max_training_examples:
                break

        return performance_estimators
    def train_one_batch(self, performance_estimators, batch_idx,
                        input_supervised, input_unlabeled):
        self.critic.train()
        self.tgt_encoder.train()
        self.src_encoder.eval()
        ###########################
        # 2.1 train discriminator #
        ###########################
        batch_size = input_unlabeled.size(0)
        # zero gradients for optimizer
        self.optimizer_critic.zero_grad()
        self.optimizer_tgt.zero_grad()
        self.tgt_encoder.zero_grad()
        self.critic.zero_grad()
        # extract and concat features
        feat_src = self.src_encoder(input_supervised)
        feat_tgt = self.tgt_encoder(input_unlabeled)
        feat_concat = torch.cat((feat_src, feat_tgt), 0)

        # predict on discriminator
        pred_concat = self.critic(feat_concat)

        # prepare real and fake label
        source_is_training_set = torch.ones(batch_size).long()
        source_is_unlabeled_set = torch.zeros(batch_size).long()

        label_src = make_variable(source_is_training_set, requires_grad=False)
        label_tgt = make_variable(source_is_unlabeled_set, requires_grad=False)
        label_concat = torch.cat((label_src, label_tgt), 0)

        # compute loss for critic
        loss_critic = self.criterion_nll(pred_concat, label_concat)
        loss_critic.backward()

        # optimize critic
        self.optimizer_critic.step()

        pred_cls = torch.squeeze(pred_concat.max(1)[1])
        accuracy = (pred_cls == label_concat).float().mean()

        ############################
        # 2.2 train target encoder #
        ############################

        # train to make unlabeled into training:

        source_is_training_set = source_is_training_set
        train_encoded_accuracy = self.train_encoder_with(
            batch_idx, performance_estimators, input_unlabeled,
            source_is_training_set)
        performance_estimators.set_metric(batch_idx, "train_encoded_accuracy",
                                          train_encoded_accuracy)

        if self.args.adda_pass_through:
            # train pass-through for training set examples:
            source_is_training_set = source_is_training_set
            train_encoded_accuracy = self.train_encoder_with(
                batch_idx, performance_estimators, input_supervised,
                source_is_training_set)
            performance_estimators.set_metric(batch_idx,
                                              "train_encoded_accuracy",
                                              train_encoded_accuracy)

        ratio = self.calculate_ratio(train_encoded_accuracy, accuracy.item())
        performance_estimators.set_metric(batch_idx, "ratio", ratio)

        performance_estimators.set_metric(batch_idx, "train_critic_loss",
                                          loss_critic.item())
        performance_estimators.set_metric(batch_idx, "train_accuracy",
                                          accuracy.item())

        if not self.args.no_progress:
            progress_bar(
                batch_idx * self.mini_batch_size, self.max_training_examples,
                performance_estimators.progress_message([
                    "train_critic_loss", "train_encoder_loss",
                    "train_accuracy", "train_encoded_accuracy", "ratio"
                ]))
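make_variable is not defined in this example; it is assumed to wrap a tensor for autograd and move it to the GPU when one is available, roughly:

import torch
from torch.autograd import Variable

def make_variable(tensor, requires_grad=False):
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return Variable(tensor, requires_grad=requires_grad)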
Example #16
    def train_semisup(self, epoch):
        performance_estimators = PerformanceList()
        performance_estimators += [FloatHelper("optimized_loss")]
        performance_estimators += [FloatHelper("supervised_loss")]
        performance_estimators += [FloatHelper("reconstruction_loss")]
        performance_estimators += [AccuracyHelper("train_")]

        print('\nTraining, epoch: %d' % epoch)

        self.net.train()

        for performance_estimator in performance_estimators:
            performance_estimator.init_performance_metrics()

        unsupervised_loss_acc = 0
        num_batches = 0
        train_loader_subset = self.problem.train_loader_subset_range(0, self.args.num_training)
        unlabeled_loader = self.problem.unlabeled_loader()
        data_provider = MultiThreadedCpuGpuDataProvider(iterator=zip(train_loader_subset, unlabeled_loader),is_cuda=self.use_cuda,
                                     batch_names=["training", "unlabeled"],
                                     requires_grad={"training": ["input"], "unlabeled": ["input"]},
                                     volatile={"training": ["metaData"], "unlabeled": []},
                                     recode_functions={"softmaxGenotype": lambda x: recode_for_label_smoothing(x,self.epsilon)})
        self.net.autoencoder.train()
        try:
            for batch_idx, (_, data_dict) in enumerate(data_provider):
                input_s = data_dict["training"]["input"]
                metadata = data_dict["training"]["metaData"]
                target_s = data_dict["training"]["softmaxGenotype"]
                input_u = data_dict["unlabeled"]["input"]
                num_batches += 1

                # need a copy of input_u and input_s as output:
                target_u = Variable(input_u.data, requires_grad=False)
                target_output_s = Variable(input_s.data, requires_grad=False)
                # outputs used to calculate the loss of the supervised model
                # must be done with the model prior to regularization:

                # Zero gradients:
                self.net.zero_grad()
                self.net.autoencoder.zero_grad()
                self.optimizer_training.zero_grad()

                output_s = self.net(input_s)
                output_u = self.net.autoencoder(input_u)
                input_output_s = self.net.autoencoder(input_s)
                output_s_p = self.get_p(output_s)

                _, target_index = torch.max(target_s, dim=1)
                supervised_loss = self.criterion_classifier(output_s, target_s)
                reconstruction_loss_unsup = self.criterion_autoencoder(output_u, target_u)
                reconstruction_loss_sup = self.criterion_autoencoder(input_output_s, target_output_s)
                reconstruction_loss = self.args.gamma * reconstruction_loss_unsup+reconstruction_loss_sup
                optimized_loss = supervised_loss + reconstruction_loss
                optimized_loss.backward()
                self.optimizer_training.step()
                performance_estimators.set_metric(batch_idx, "supervised_loss", supervised_loss.data[0])
                performance_estimators.set_metric(batch_idx, "reconstruction_loss", reconstruction_loss.data[0])
                performance_estimators.set_metric(batch_idx, "optimized_loss", optimized_loss.data[0])
                performance_estimators.set_metric_with_outputs(batch_idx, "train_accuracy", supervised_loss.data[0],
                                                               output_s_p, targets=target_index)

                progress_bar(batch_idx * self.mini_batch_size,
                             self.max_training_examples,
                             performance_estimators.progress_message(["supervised_loss", "reconstruction_loss",
                                                                      "train_accuracy"]))

                if (batch_idx + 1) * self.mini_batch_size > self.max_training_examples:
                    break
        finally:
            data_provider.close()

        return performance_estimators
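recode_for_label_smoothing is not shown in this listing; it is assumed to apply standard label smoothing to a one-hot (or probability) target batch, roughly:

def recode_for_label_smoothing(target, epsilon=0.1):
    # keep 1 - epsilon of the mass on the labeled classes, spread epsilon uniformly
    num_classes = target.size(1)
    return target * (1.0 - epsilon) + epsilon / num_classes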
Example #17
    def test_semisup_aae(self, epoch, performance_estimators=None):
        print('\nTesting, epoch: %d' % epoch)
        if performance_estimators is None:
            performance_estimators = PerformanceList()
            performance_estimators += [FloatHelper("reconstruction_loss")]
            performance_estimators += [LossHelper("test_loss")]
            performance_estimators += [AccuracyHelper("test_")]
            performance_estimators += [FloatHelper("weight")]

        self.net.eval()
        for performance_estimator in performance_estimators:
            performance_estimator.init_performance_metrics()
        validation_loader_subset = self.problem.validation_loader_range(
            0, self.args.num_validation)
        data_provider = MultiThreadedCpuGpuDataProvider(
            iterator=zip(validation_loader_subset),
            is_cuda=self.use_cuda,
            batch_names=["validation"],
            requires_grad={"validation": []},
            volatile={
                "validation": ["input", "softmaxGenotype"],
            },
            recode_functions={"input": self.normalize_inputs})
        self.net.eval()
        try:
            for batch_idx, (_, data_dict) in enumerate(data_provider):
                input_s = data_dict["validation"]["input"]
                target_s = data_dict["validation"]["softmaxGenotype"]

                # Estimate the reconstruction loss on validation examples:
                reconstruction_loss = self.net.get_crossconstruction_loss(
                    input_s, input_s, target_s)

                # now evaluate prediction of categories:
                categories_predicted, latent_code = self.net.encoder(input_s)
                #            categories_predicted+=self.net.latent_to_categories(latent_code)

                categories_predicted_p = self.get_p(categories_predicted)
                categories_predicted_p[
                    categories_predicted_p != categories_predicted_p] = 0.0
                _, target_index = torch.max(target_s, dim=1)
                categories_loss = self.net.semisup_loss_criterion(
                    categories_predicted, target_s)

                weight = self.estimate_example_density_weight(latent_code)
                performance_estimators.set_metric(batch_idx,
                                                  "reconstruction_loss",
                                                  reconstruction_loss.item())
                performance_estimators.set_metric(batch_idx, "weight", weight)
                performance_estimators.set_metric_with_outputs(
                    batch_idx, "test_accuracy", reconstruction_loss.item(),
                    categories_predicted_p, target_index)
                performance_estimators.set_metric_with_outputs(
                    batch_idx, "test_loss", categories_loss.item() * weight,
                    categories_predicted_p, target_s)

                if not self.args.no_progress:
                    progress_bar(
                        batch_idx * self.mini_batch_size,
                        self.max_validation_examples,
                        performance_estimators.progress_message([
                            "test_loss", "test_accuracy", "reconstruction_loss"
                        ]))

                if ((batch_idx + 1) *
                        self.mini_batch_size) > self.max_validation_examples:
                    break
            # print()
        finally:
            data_provider.close()
        # Apply learning rate schedules:
        test_metric = performance_estimators.get_metric(
            self.get_test_metric_name())
        assert test_metric is not None, (
            self.get_test_metric_name() +
            "must be found among estimated performance metrics")
        if not self.args.constant_learning_rates:
            for scheduler in self.schedulers:
                scheduler.step(test_metric, epoch)
        # Run the garbage collector to try to release memory we no longer need:
        import gc
        gc.collect()
        return performance_estimators
Example #18
    def train_semisup_aae(self, epoch, performance_estimators=None):
        if performance_estimators is None:
            performance_estimators = PerformanceList()
            performance_estimators += [FloatHelper("reconstruction_loss")]
            performance_estimators += [FloatHelper("discriminator_loss")]
            performance_estimators += [FloatHelper("generator_loss")]
            performance_estimators += [FloatHelper("supervised_loss")]
            performance_estimators += [FloatHelper("weight")]
            print('\nTraining, epoch: %d' % epoch)
        for performance_estimator in performance_estimators:
            performance_estimator.init_performance_metrics()

        self.net.train()
        supervised_grad_norm = 1.
        for performance_estimator in performance_estimators:
            performance_estimator.init_performance_metrics()

        unsupervised_loss_acc = 0
        num_batches = 0
        train_loader_subset1 = self.problem.train_loader_subset_range(
            0, self.args.num_training)
        train_loader_subset2 = self.problem.train_loader_subset_range(
            0, self.args.num_training)

        data_provider = MultiThreadedCpuGpuDataProvider(
            iterator=zip(train_loader_subset1, train_loader_subset2),
            is_cuda=self.use_cuda,
            batch_names=["training1", "training2"],
            requires_grad={
                "training1": ["input"],
                "training2": ["input"]
            },
            volatile={
                "training1": ["metaData"],
                "training2": ["metaData"]
            },
            recode_functions={
                "softmaxGenotype": recode_for_label_smoothing,
                "input": self.normalize_inputs
            })

        indel_weight = self.args.indel_weight_factor
        snp_weight = 1.0

        latent_codes = []
        try:
            for batch_idx, (_, data_dict) in enumerate(data_provider):
                input_s1 = data_dict["training1"]["input"]
                input_s2 = data_dict["training2"]["input"]
                target_s1 = data_dict["training1"]["softmaxGenotype"]
                target_s2 = data_dict["training2"]["softmaxGenotype"]

                meta_data1 = data_dict["training1"]["metaData"]
                meta_data2 = data_dict["training2"]["metaData"]
                num_batches += 1
                self.zero_grad_all_optimizers()

                # input_s=normalize_mean_std(input_s)
                # input_u=normalize_mean_std(input_u)
                # print(torch.mean(input_s,dim=0))
                # Train reconstruction phase:
                self.net.decoder.train()
                reconstruction_loss = self.net.get_crossconstruction_loss(
                    input_s1, input_s2, target_s2)
                reconstruction_loss.backward()
                for opt in [self.decoder_opt, self.encoder_reconstruction_opt]:
                    opt.step()

                # Train discriminators:
                self.net.encoder.train()
                self.net.discriminator_cat.train()
                self.net.discriminator_prior.train()
                self.zero_grad_all_optimizers()
                genotype_frequencies = self.class_frequencies[
                    "softmaxGenotype"]
                category_prior = (genotype_frequencies /
                                  torch.sum(genotype_frequencies)).numpy()
                discriminator_loss = self.net.get_discriminator_loss(
                    input_s1, category_prior=category_prior)
                discriminator_loss.backward()
                for opt in [
                        self.discriminator_cat_opt,
                        self.discriminator_prior_opt
                ]:
                    opt.step()
                self.zero_grad_all_optimizers()

                # Train generator:
                self.net.encoder.train()
                generator_loss = self.net.get_generator_loss(input_s1)
                generator_loss.backward()
                for opt in [self.encoder_generator_opt]:
                    opt.step()
                self.zero_grad_all_optimizers()

                if self.use_pdf:
                    self.net.encoder.train()
                    _, latent_code = self.net.encoder(input_s1)
                    weight = self.estimate_example_density_weight(latent_code)
                else:
                    weight = self.estimate_batch_weight(
                        meta_data1,
                        indel_weight=indel_weight,
                        snp_weight=snp_weight)
                self.net.encoder.train()
                supervised_loss = self.net.get_crossencoder_supervised_loss(
                    input_s1, target_s1) * weight
                supervised_loss.backward()

                for opt in [self.encoder_semisup_opt]:
                    opt.step()
                self.zero_grad_all_optimizers()

                performance_estimators.set_metric(batch_idx,
                                                  "reconstruction_loss",
                                                  reconstruction_loss.item())
                performance_estimators.set_metric(batch_idx,
                                                  "discriminator_loss",
                                                  discriminator_loss.item())
                performance_estimators.set_metric(batch_idx, "generator_loss",
                                                  generator_loss.item())
                performance_estimators.set_metric(batch_idx, "supervised_loss",
                                                  supervised_loss.item())
                performance_estimators.set_metric(batch_idx, "weight", weight)
                if not self.args.no_progress:
                    progress_bar(
                        batch_idx * self.mini_batch_size,
                        self.max_training_examples,
                        performance_estimators.progress_message([
                            "reconstruction_loss", "discriminator_loss",
                            "generator_loss", "semisup_loss"
                        ]))
                if ((batch_idx + 1) *
                        self.mini_batch_size) > self.max_training_examples:
                    break
        finally:
            data_provider.close()

        return performance_estimators
    def train_one_batch(self, performance_estimators, batch_idx, input_s, target_s, meta_data, input_u):

        self.zero_grad_all_optimizers()

        self.num_classes = len(target_s[0])
        # Train reconstruction phase:
        self.net.encoder.train()
        self.net.decoder.train()
        reconstruction_loss = self.net.get_reconstruction_loss(input_u)
        reconstruction_loss.backward()
        for opt in [self.decoder_opt, self.encoder_reconstruction_opt]:
            opt.step()

        # Train discriminators:
        self.net.encoder.train()
        self.net.discriminator_cat.train()
        self.net.discriminator_prior.train()
        self.zero_grad_all_optimizers()
        genotype_frequencies = self.class_frequencies["softmaxGenotype"]
        category_prior = (genotype_frequencies / torch.sum(genotype_frequencies)).numpy()
        discriminator_loss = self.net.get_discriminator_loss(common_trainer=self, model_input=input_u,
                                                             category_prior=category_prior,
                                                             recode_labels=lambda x: recode_for_label_smoothing(x,
                                                                                                                epsilon=self.epsilon))
        discriminator_loss.backward()
        for opt in [self.discriminator_cat_opt, self.discriminator_prior_opt]:
            opt.step()
        self.zero_grad_all_optimizers()

        # Train generator:
        self.net.encoder.train()
        generator_loss = self.net.get_generator_loss(input_u)
        generator_loss.backward()
        for opt in [self.encoder_generator_opt]:
            opt.step()
        self.zero_grad_all_optimizers()
        weight = 1
        if self.use_pdf:
            self.net.encoder.train()
            _, latent_code = self.net.encoder(input_s)
            weight *= self.estimate_example_density_weight(latent_code)
        indel_weight = self.args.indel_weight_factor
        snp_weight = 1.0

        weight *= self.estimate_batch_weight(meta_data, indel_weight=indel_weight,
                                             snp_weight=snp_weight)
        self.net.encoder.train()
        semisup_loss = self.net.get_semisup_loss(input_s, target_s) * weight
        semisup_loss.backward()

        for opt in [self.encoder_semisup_opt]:
            opt.step()
        self.zero_grad_all_optimizers()

        performance_estimators.set_metric(batch_idx, "reconstruction_loss", reconstruction_loss.item())
        performance_estimators.set_metric(batch_idx, "discriminator_loss", discriminator_loss.item())
        performance_estimators.set_metric(batch_idx, "generator_loss", generator_loss.item())
        performance_estimators.set_metric(batch_idx, "semisup_loss", semisup_loss.item())
        performance_estimators.set_metric(batch_idx, "weight", weight)

        if self.args.latent_code_output is not None:
            _, latent_code = self.net.encoder(input_u)
            # Randomly select n rows from the minibatch to keep track of the latent codes for
            idxs_to_sample = torch.randperm(latent_code.size()[0])[:self.args.latent_code_n_per_minibatch]
            for row_idx in idxs_to_sample:
                latent_code_row = latent_code[row_idx]
                self.gaussian_codes.append(torch.squeeze(draw_from_gaussian(latent_code_row.size()[0], 1)))
                self.latent_codes.append(latent_code_row)

        if not self.args.no_progress:
            progress_bar(batch_idx * self.mini_batch_size, self.max_training_examples,
                         performance_estimators.progress_message(
                             ["reconstruction_loss", "discriminator_loss", "generator_loss", "semisup_loss"]))