Ejemplo n.º 1
0
    def inference_classification(self):
        """Evaluate the model on ``self.test_loader`` and save the accuracy.

        For every batch the accuracy (fraction of predictions equal to the
        labels) is appended to an ``AverageMeter``; the meter's ``avg()`` is
        shown on the progress bar and finally written to
        ``<self.logdir>/meanscores.csv`` as a single ``Accuracy: ...`` line.
        The file is overwritten on every call.
        """
        self.model.eval()
        # NOTE(review): mode 0 is presumably instance-only prediction (the
        # progress bar is labelled "instance") -- confirm against the model.
        self.model.module.mode = 0
        val_accuracy = AverageMeter()

        with torch.no_grad():
            final_itr = tqdm(self.test_loader, ncols=80, desc='Inference (instance) ...')

            for inputs, labels in final_itr:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                # Only the first element of the model output is used as logits.
                logits = self.model(inputs)[0]
                preds = self.model.module.pooling.predictions(logits)

                accuracy = (preds == labels).sum().item() / labels.shape[0]
                val_accuracy.append(accuracy)

                final_itr.set_description('--- (test) | Accuracy: {:.3f}  :'.format(
                    val_accuracy.avg())
                )

        mean_accuracy = val_accuracy.avg()
        # Context manager guarantees the file is closed even if the write fails.
        with open(os.path.join(self.logdir, 'meanscores.csv'), 'w') as fp:
            fp.write('Accuracy: {:.4f} \n'.format(mean_accuracy))
Ejemplo n.º 2
0
    def _train_epoch(self, epoch):
        """Train the model for one epoch and log the epoch averages.

        Each batch from ``self.train_loader`` yields ``(inputs, labels,
        all_labels)``. The model is run in mode 1 and returns instance logits,
        instance features, a bag embedding and bag logits. The total loss is
        the sum of the pooling loss on instance logits, the pooling loss on
        bag logits and the center loss on the bag embedding.

        Args:
            epoch: Zero-based epoch index; ``epoch + 1`` is used both for the
                learning-rate schedule and as the logging step.
        """
        logits_losses = AverageMeter()
        bag_losses = AverageMeter()
        center_losses = AverageMeter()
        train_accuracy = AverageMeter()
        bag_accuracy = AverageMeter()

        # Weight of the center loss in the total loss. The gradient of the
        # center-loss parameters is rescaled by its inverse further below, so
        # both places must use the same value.
        center_weight = 1.0

        self.center_loss.train()
        self.model.train()
        self.model.module.mode = 1  # combined mode (instance & bag prediction)

        self.adjust_lr_staircase(
            self.optimizer.param_groups,
            [0.001, 0.01],  # initial values for features and pooling
            epoch + 1,
            [10, 15, 17],  # set the steps to adjust accordingly
            0.1  # reduce by this value
        )
        pbar = tqdm(self.train_loader, ncols=160, desc=' ')
        for inputs, labels, all_labels in pbar:

            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            # Flatten the per-instance labels so they line up with inst_logits.
            all_labels = all_labels.view(-1).to(self.device).long()

            self.optimizer.zero_grad()
            self.optimizerpool.zero_grad()

            # get features and logits
            inst_logits, inst_feat, bag_embed, bag_logits = self.model(inputs)

            loss_soft = self.model.module.pooling.loss(inst_logits, all_labels)
            loss_bag = self.model.module.pooling.loss(bag_logits, labels)

            # Center loss clusters bags: it takes the bag embedding (third
            # model output) together with the bag labels.
            # NOTE(review): clustering instances instead would presumably use
            # inst_feat with all_labels -- confirm before switching.
            loss_center = self.center_loss(bag_embed, labels)
            # alpha, lambda and bag weight
            loss = 1.0 * loss_soft + loss_center * center_weight + loss_bag * 1.0

            preds_bag = self.model.module.pooling.predictions(bag_logits)
            preds = self.model.module.pooling.predictions(inst_logits)
            accuracy = (preds == all_labels).sum().item() / all_labels.shape[0]
            accuracy_bag = (preds_bag == labels).sum().item() / labels.shape[0]

            logits_losses.append(loss_soft.item())
            center_losses.append(loss_center.item())
            bag_losses.append(loss_bag.item())
            train_accuracy.append(accuracy)
            bag_accuracy.append(accuracy_bag)

            loss.backward()
            self.optimizer.step()
            for param in self.center_loss.parameters():
                # Undo the loss weighting on the center parameters' gradients
                # so their update does not depend on center_weight.
                param.grad.data *= (1. / center_weight)
            self.optimizerpool.step()

            pbar.set_description(
                '--- (train) | Loss(I): {:.4f} | Loss(C): {:.4f} | Loss(B): {:.4f} | ACC(I): {:.3f} | ACC(B): {:.3f} :'
                .format(logits_losses.avg(), center_losses.avg(),
                        bag_losses.avg(), train_accuracy.avg(),
                        bag_accuracy.avg()))

        step = epoch + 1
        self.writer.add_scalar('training/loss_i', logits_losses.avg(), step)
        self.writer.add_scalar('training/loss_c', center_losses.avg(), step)
        self.writer.add_scalar('training/loss_b', bag_losses.avg(), step)
        self.writer.add_scalar('training/accuracy', train_accuracy.avg(), step)
        self.writer.add_scalar('training/accuracy_bag', bag_accuracy.avg(),
                               step)
        print()