def test(self, model, test_set, l_loss, m_loss):
        """Evaluate ``model`` on ``test_set`` and log test metrics.

        Switches the model to eval mode, runs one pass over the test set
        accumulating classification loss/accuracy and (when
        ``self.use_mloss``) the summed per-head segmentation loss, logs a
        summary line, then restores train mode.

        :param model: network returning (classification, segmentation) outputs
        :param test_set: iterable of (images, segments, labels) batches
        :param l_loss: classification loss callable
        :param m_loss: segmentation loss callable
        :return: (mean classification loss, mean classification accuracy)
        """
        model.train(mode=False)
        loss_classification_sum = 0
        loss_segmentation_sum = 0
        accuracy_classification_sum = 0
        batch_count = 0
        for images, segments, labels in test_set:
            labels, segments = model_utils.reduce_to_class_number(self.left_class_number, self.right_class_number,
                                                                  labels,
                                                                  segments)
            images, labels, segments = self.convert_data_and_label(images, labels, segments)
            # one pulled segmentation target per puller (one per model output head)
            segments_list = [puller(segments) for puller in self.puller]
            model_classification, model_segmentation = model_utils.wait_while_can_execute(model, images)

            classification_loss = l_loss(model_classification, labels)
            if self.use_mloss:
                sum_segm_loss = None
                for ms, sl in zip(model_segmentation, segments_list):
                    # BUGFIX: honour the ``m_loss`` parameter instead of
                    # ``self.m_loss``. Callers pass ``self.m_loss`` anyway,
                    # so behaviour is unchanged, but the argument is no
                    # longer silently ignored.
                    segmentation_loss = m_loss(ms, sl)
                    if sum_segm_loss is None:
                        sum_segm_loss = segmentation_loss
                    else:
                        # out-of-place add: ``+=`` would mutate the first
                        # loss tensor that ``sum_segm_loss`` aliases
                        sum_segm_loss = sum_segm_loss + segmentation_loss

            output_probability, output_cl, cl_acc = self.calculate_accuracy(labels, model_classification,
                                                                            labels.size(0))

            self.save_test_data(labels, output_cl, output_probability)

            # accumulate information
            accuracy_classification_sum += model_utils.scalar(cl_acc.sum())
            loss_classification_sum += model_utils.scalar(classification_loss.sum())
            if self.use_mloss:
                loss_segmentation_sum += model_utils.scalar(sum_segm_loss.sum())
            batch_count += 1

        f_1_score_text, recall_score_text, precision_score_text = metrics_processor.calculate_metric(self.classes,
                                                                                                     self.test_trust_answers,
                                                                                                     self.test_model_answers)

        # p.EPS guards the division when the test set yielded no batches
        loss_classification_sum /= batch_count + p.EPS
        accuracy_classification_sum /= batch_count + p.EPS
        loss_segmentation_sum /= batch_count + p.EPS
        text = 'TEST={} Loss_CL={:.5f} Loss_M={:.5f} Accuracy_CL={:.5f} {} {} {} '.format(self.current_epoch,
                                                                                          loss_classification_sum,
                                                                                          loss_segmentation_sum,
                                                                                          accuracy_classification_sum,
                                                                                          f_1_score_text,
                                                                                          recall_score_text,
                                                                                          precision_score_text)
        p.write_to_log(text)
        model.train(mode=True)
        return loss_classification_sum, accuracy_classification_sum
# --- Ejemplo n.º 2 (page-extraction artifact separating code samples) ---
    def train_segments(self, model, l_loss, m_loss,
                       optimizer: torch.optim.Adam, train_set):
        """Train only the segmentation head of ``model`` for one epoch.

        The classification loss is computed purely for monitoring; only the
        segmentation loss is back-propagated and stepped.

        :param model: network returning (classification, segmentation) outputs
        :param l_loss: classification loss callable (monitoring only)
        :param m_loss: segmentation loss callable (drives the update)
        :param optimizer: optimizer over the model's parameters
        :param train_set: iterable of (images, segments, labels) batches
        :return: per-batch means: (accuracy, segmentation loss,
            l1 loss — always 0 here, classification loss)
        """
        model.train(mode=True)
        acc_sum = 0
        segm_loss_sum = 0
        l1_loss_sum = 0  # no L1 term in this phase; reported as 0
        cl_loss_sum = 0
        batches = 0

        for images, segments, labels in train_set:
            labels, segments = model_utils.reduce_to_class_number(
                self.left_class_number, self.right_class_number, labels,
                segments)
            images, labels, segments = self.convert_data_and_label(
                images, labels, segments)
            segments = self.puller(segments)

            optimizer.zero_grad()
            model_classification, model_segmentation = \
                model_utils.wait_while_can_execute(model, images)

            classification_loss = l_loss(model_classification, labels)
            segmentation_loss = m_loss(model_segmentation, segments)

            # only the segmentation loss drives the parameter update
            segmentation_loss.backward()
            optimizer.step()

            output_probability, output_cl, cl_acc = self.calculate_accuracy(
                labels, model_classification, labels.size(0))
            self.save_train_data(labels, output_cl, output_probability)

            # accumulate per-epoch statistics
            acc_sum += model_utils.scalar(cl_acc.sum())
            segm_loss_sum += model_utils.scalar(segmentation_loss.sum())
            cl_loss_sum += model_utils.scalar(classification_loss.sum())
            batches += 1

        model.train(mode=False)
        denom = batches + p.EPS  # EPS guards an empty train set
        return (acc_sum / denom, segm_loss_sum / denom,
                l1_loss_sum / denom, cl_loss_sum / denom)
    def train(self):
        """Joint training loop for classification + segmentation.

        Optimises all parameters of ``self.am_model`` with a single Adam
        optimizer. Each epoch iterates ``self.train_segments_set``,
        back-propagating the classification loss plus (when
        ``self.use_mloss``) the sum of per-head segmentation losses, logs
        aggregate metrics, and runs :meth:`test` every
        ``self.test_each_epoch`` epochs. Runs until ``self.current_epoch``
        exceeds ``self.train_epochs``.
        """
        optimizer = torch.optim.Adam(self.am_model.parameters(), self.classifier_learning_rate)

        while self.current_epoch <= self.train_epochs:

            loss_m_sum = 0   # never updated below; logged as a constant 0
            loss_l1_sum = 0  # never updated below; logged as a constant 0

            loss_classification_sum = 0
            loss_segmentation_sum = 0
            accuracy_sum = 0
            batch_count = 0
            self.am_model.train(mode=True)
            for images, segments, labels in self.train_segments_set:
                labels, segments = model_utils.reduce_to_class_number(self.left_class_number, self.right_class_number,
                                                                      labels,
                                                                      segments)
                images, labels, segments = self.convert_data_and_label(images, labels, segments)
                # one pulled segmentation target per puller (one per model output head)
                segments_list = [puller(segments) for puller in self.puller]

                # calculate and optimize model
                optimizer.zero_grad()

                model_classification, model_segmentation = model_utils.wait_while_can_execute(self.am_model, images)
                classification_loss = self.l_loss(model_classification, labels)
                total_loss = classification_loss

                if self.use_mloss:
                    sum_segm_loss = None
                    for ms, sl in zip(model_segmentation, segments_list):
                        segmentation_loss = self.m_loss(ms, sl)
                        # BUGFIX: out-of-place addition. The original
                        # ``total_loss += segmentation_loss`` added in place
                        # while ``total_loss`` still aliased
                        # ``classification_loss``, so the classification loss
                        # logged below silently included every segmentation
                        # loss (and in-place ops on autograd-graph tensors can
                        # raise during backward).
                        total_loss = total_loss + segmentation_loss
                        if sum_segm_loss is None:
                            sum_segm_loss = segmentation_loss
                        else:
                            # out-of-place for the same aliasing reason
                            sum_segm_loss = sum_segm_loss + segmentation_loss
                total_loss.backward()
                optimizer.step()

                output_probability, output_cl, cl_acc = self.calculate_accuracy(labels, model_classification,
                                                                                labels.size(0))

                # drop leftover gradients before the next batch's forward pass
                optimizer.zero_grad()

                self.save_train_data(labels, output_cl, output_probability)

                # accumulate per-epoch statistics
                accuracy_sum += model_utils.scalar(cl_acc.sum())
                loss_classification_sum += model_utils.scalar(classification_loss.sum())
                if self.use_mloss:
                    loss_segmentation_sum += model_utils.scalar(sum_segm_loss.sum())
                batch_count += 1

            # p.EPS guards the divisions when the epoch yielded no batches
            loss_classification_sum = loss_classification_sum / (batch_count + p.EPS)
            accuracy_sum = accuracy_sum / (batch_count + p.EPS)
            loss_segmentation_sum = loss_segmentation_sum / (batch_count + p.EPS)
            loss_total = loss_classification_sum + loss_m_sum + loss_segmentation_sum
            prefix = "TRAIN"
            f_1_score_text, recall_score_text, precision_score_text = metrics_processor.calculate_metric(self.classes,
                                                                                                         self.train_trust_answers,
                                                                                                         self.train_model_answers)

            text = "{}={} Loss_CL={:.5f} Loss_M={:.5f} Loss_L1={:.5f} Loss_Total={:.5f} Accuracy_CL={:.5f} " \
                   "{} {} {} ".format(prefix, self.current_epoch, loss_classification_sum,
                                      loss_m_sum,
                                      loss_l1_sum,
                                      loss_total,
                                      accuracy_sum,
                                      f_1_score_text,
                                      recall_score_text,
                                      precision_score_text)

            # NOTE(review): uppercase ``P`` here vs lowercase ``p`` elsewhere
            # (e.g. ``p.EPS`` above) — confirm both names are imported.
            P.write_to_log(text)

            if self.current_epoch % self.test_each_epoch == 0:
                test_loss, _ = self.test(self.am_model, self.test_set, self.l_loss, self.m_loss)

            self.clear_temp_metrics()
            self.current_epoch += 1
# --- Ejemplo n.º 4 (page-extraction artifact separating code samples) ---
    def take_snapshot(self, data_set, model, snapshot_name: str = None):
        """Save a comparison figure of model vs. ground-truth segmentations.

        Collects samples from ``data_set`` until at least
        ``self.snapshot_elements_count`` are gathered, then writes a
        matplotlib grid to ``self.snapshot_dir/snapshot_name``: one row per
        sample, with the input image followed by, per class, the raw model
        mask, a min-max normalised model mask, and the trusted
        (ground-truth) mask. Per-mask statistics are written to the log.

        :param data_set: iterable of (images, segments, labels) batches
        :param model: network whose second output is the segmentation
        :param snapshot_name: file name (within ``self.snapshot_dir``) for
            the saved figure
        """
        cnt = 0
        model_segments_list = []
        trust_segments_list = []
        images_list = []

        for images, segments, labels in data_set:

            # keep only the class channels this trainer is responsible for
            segments = segments[:, self.left_class_number:self.
                                right_class_number, :, :]

            images, labels, segments = self.convert_data_and_label(
                images, labels, segments)
            segments = self.puller(segments)
            _, model_segmentation = model_utils.wait_while_can_execute(
                model, images)

            cnt += segments.size(0)
            images, _, segments = self.de_convert_data_and_label(
                images, labels, segments)
            # move predictions to CPU so they can be plotted with matplotlib
            model_segmentation = model_segmentation.cpu()
            for idx in range(segments.size(0)):
                images_list.append(images[idx])
                model_segments_list.append(model_segmentation[idx])
                trust_segments_list.append(segments[idx])

            # stop once enough samples have been collected
            if cnt >= self.snapshot_elements_count:
                break
        # one row per sample; per class 3 columns (model / normed / trust)
        # plus one leading column for the input image.
        # NOTE(review): assumes at least two rows are collected so ``axes``
        # is 2-D (with a single row, plt.subplots returns a 1-D array) —
        # confirm snapshot_elements_count > 1.
        fig, axes = plt.subplots(len(images_list),
                                 model_segments_list[0].size(0) * 3 + 1,
                                 figsize=(50, 100))
        fig.tight_layout()
        for idx, img in enumerate(images_list):
            # CHW tensor -> HWC numpy image for imshow
            axes[idx][0].imshow(np.transpose(img.numpy(), (1, 2, 0)))

        for idx, (trist_answer, model_answer) in enumerate(
                zip(trust_segments_list, model_segments_list)):
            for class_number in range(trist_answer.size(0)):
                # raw model mask, replicated to 3 channels for display
                a = model_answer[class_number].detach().numpy()
                a = np.array([a] * 3)
                axes[idx][1 + class_number * 3].imshow(
                    np.transpose(a, (1, 2, 0)))
                p.write_to_log(
                    "model        idx={}, class={}, sum={}, max={}, min={}".
                    format(idx, class_number, np.sum(a), np.max(a), np.min(a)))
                # min-max normalise for display.
                # NOTE(review): divides by (max - min) — yields NaN for a
                # constant mask; confirm inputs always vary.
                a = (a - np.min(a)) / (np.max(a) - np.min(a))
                axes[idx][1 + class_number * 3 + 1].imshow(
                    np.transpose(a, (1, 2, 0)))
                p.write_to_log(
                    "model normed idx={}, class={}, sum={}, max={}, min={}".
                    format(idx, class_number, np.sum(a), np.max(a), np.min(a)))

                # ground-truth mask, replicated to 3 channels for display
                a = trist_answer[class_number].detach().numpy()
                a = np.array([a] * 3)
                axes[idx][1 + class_number * 3 + 2].imshow(
                    np.transpose(a, (1, 2, 0)))
                p.write_to_log(
                    "trust        idx={}, class={}, sum={}, max={}, min={}".
                    format(idx, class_number, np.sum(a), np.max(a), np.min(a)))

                p.write_to_log("=" * 50)

                axes[idx][1 + class_number * 3].set(
                    xlabel='model answer class: {}'.format(class_number))
                axes[idx][1 + class_number * 3 +
                          1].set(xlabel='model normed answer class: {}'.format(
                              class_number))
                axes[idx][1 + class_number * 3 + 2].set(
                    xlabel='trust answer class: {}'.format(class_number))
        print("=" * 50)
        print("=" * 50)
        print("=" * 50)
        print("=" * 50)
        print("=" * 50)
        plt.savefig(os.path.join(self.snapshot_dir, snapshot_name))
        plt.close(fig)
# --- Ejemplo n.º 5 (page-extraction artifact separating code samples) ---
    def train(self):
        """Two-optimizer training loop for the attention model.

        Registers two Adam optimizers over disjoint weight groups of
        ``self.am_model`` — "classifier" and "attention" — via ``gr`` (VGG
        models) or ``rgr`` (others). Each epoch back-propagates the
        classification loss (with ``retain_graph=True``) and the
        segmentation loss over the shared forward graph, steps both
        optimizers, logs metrics, runs :meth:`test` every
        ``self.test_each_epoch`` epochs, and saves snapshots every 200
        epochs. Loops until ``self.current_epoch`` exceeds
        ``self.train_epochs``.
        """
        if self.is_vgg_model:
            classifier_optimizer = torch.optim.Adam(gr.register_weights("classifier", self.am_model),
                                                    self.classifier_learning_rate)
            attention_module_optimizer = torch.optim.Adam(gr.register_weights("attention", self.am_model),
                                                          lr=self.attention_module_learning_rate)
        else:
            classifier_optimizer = torch.optim.Adam(rgr.register_weights("classifier", self.am_model),
                                                    self.classifier_learning_rate)
            attention_module_optimizer = torch.optim.Adam(rgr.register_weights("attention", self.am_model),
                                                          lr=self.attention_module_learning_rate)

        while self.current_epoch <= self.train_epochs:

            loss_m_sum = 0   # never updated below; logged as a constant 0
            loss_l1_sum = 0  # never updated below; logged as a constant 0

            loss_classification_sum = 0
            loss_segmentation_sum = 0
            accuracy_sum = 0
            batch_count = 0
            self.am_model.train(mode=True)
            for images, segments, labels in self.train_segments_set:
                labels, segments = model_utils.reduce_to_class_number(self.left_class_number, self.right_class_number,
                                                                      labels,
                                                                      segments)
                images, labels, segments = self.convert_data_and_label(images, labels, segments)
                segments = self.puller(segments)

                # calculate and optimize model
                classifier_optimizer.zero_grad()
                attention_module_optimizer.zero_grad()

                model_classification, model_segmentation = model_utils.wait_while_can_execute(self.am_model, images)
                segmentation_loss = self.m_loss(model_segmentation, segments)
                classification_loss = self.l_loss(model_classification, labels)
                # torch.cuda.empty_cache()
                # retain_graph=True keeps the shared forward graph alive so
                # segmentation_loss.backward() can reuse it below
                classification_loss.backward(retain_graph=True)
                segmentation_loss.backward()

                # gradients from both losses accumulate before either step
                classifier_optimizer.step()
                attention_module_optimizer.step()

                output_probability, output_cl, cl_acc = self.calculate_accuracy(labels, model_classification,
                                                                                labels.size(0))

                # drop leftover gradients before the next batch's forward pass
                classifier_optimizer.zero_grad()
                attention_module_optimizer.zero_grad()

                self.save_train_data(labels, output_cl, output_probability)

                # accumulate information
                accuracy_sum += model_utils.scalar(cl_acc.sum())
                loss_classification_sum += model_utils.scalar(classification_loss.sum())
                loss_segmentation_sum += model_utils.scalar(segmentation_loss.sum())
                batch_count += 1
                # self.de_convert_data_and_label(images, segments, labels)
                # torch.cuda.empty_cache()

            # p.EPS guards the divisions when the epoch yielded no batches
            loss_classification_sum = loss_classification_sum / (batch_count + p.EPS)
            accuracy_sum = accuracy_sum / (batch_count + p.EPS)
            loss_segmentation_sum = loss_segmentation_sum / (batch_count + p.EPS)
            loss_total = loss_classification_sum + loss_m_sum + loss_segmentation_sum
            prefix = "PRETRAIN" if self.current_epoch <= self.pre_train_epochs else "TRAIN"
            f_1_score_text, recall_score_text, precision_score_text = metrics_processor.calculate_metric(self.classes,
                                                                                                         self.train_trust_answers,
                                                                                                         self.train_model_answers)

            text = "{}={} Loss_CL={:.5f} Loss_M={:.5f} Loss_L1={:.5f} Loss_Total={:.5f} Accuracy_CL={:.5f} " \
                   "{} {} {} ".format(prefix, self.current_epoch, loss_classification_sum,
                                      loss_m_sum,
                                      loss_l1_sum,
                                      loss_total,
                                      accuracy_sum,
                                      f_1_score_text,
                                      recall_score_text,
                                      precision_score_text)

            # NOTE(review): uppercase ``P`` here vs lowercase ``p`` elsewhere
            # (e.g. ``p.EPS`` above) — confirm both names are imported.
            P.write_to_log(text)
            self.am_model.train(mode=False)
            if self.current_epoch % self.test_each_epoch == 0:
                test_loss, _ = self.test(self.am_model, self.test_set, self.l_loss, self.m_loss)
            if self.current_epoch % 200 == 0:
                self.take_snapshot(self.train_segments_set, self.am_model, "TRAIN_{}".format(self.current_epoch))
                self.take_snapshot(self.test_set, self.am_model, "TEST_{}".format(self.current_epoch))

            self.clear_temp_metrics()
            self.current_epoch += 1