예제 #1
0
class QuasiSiameseNetwork(object):
    """Training / evaluation driver for a two-input ``SiameseNetwork``.

    The wrapped model is called as ``model(image1, image2)`` and is trained
    with an MSE loss against float labels.  "Accuracy" is the fraction of
    samples whose prediction lies within ``accuracy_threshold`` of the label.
    """

    def __init__(self, args):
        """Build the model, loss, per-split transforms, optimizer and scheduler.

        Args:
            args: Object exposing ``inputSize``, ``runName`` and
                ``learningRate`` attributes.
        """
        input_size = (args.inputSize, args.inputSize)

        self.run_name = args.runName
        self.input_size = input_size
        self.lr = args.learningRate

        self.criterion = nnloss.MSELoss()

        self.transforms = {}

        self.model = SiameseNetwork()

        if torch.cuda.device_count() > 1:
            logger.info('Using {} GPUs'.format(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model)

        for s in ('train', 'validation', 'test'):
            self.transforms[s] = get_pretrained_iv3_transforms(s)

        # Counts parameter tensors, not individual scalar weights.
        logger.debug('Num params: {}'.format(
            len(list(self.model.parameters()))))

        self.optimizer = Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = ReduceLROnPlateau(self.optimizer,
                                              factor=0.1,
                                              patience=10,
                                              min_lr=1e-5,
                                              verbose=True)

    def run_epoch(self,
                  epoch,
                  loader,
                  device,
                  phase='train',
                  accuracy_threshold=0.1):
        """Run one pass over ``loader`` and return ``(loss, accuracy)``.

        In the ``'train'`` phase the optimizer is stepped after every batch.
        In the other phases gradients are disabled and every prediction
        (clamped to ``[0, 1]``) is written to
        ``<loader.dataset.directory>/<run_name>_epoch_<NNN>_predictions.txt``.

        Args:
            epoch: Epoch number, used in log lines and the prediction file name.
            loader: Iterable of ``(filename, image1, image2, labels)`` batches
                that also supports ``len()``; outside the train phase it must
                expose ``dataset.directory``.
            device: Device the model and each batch are moved to.
            phase: One of ``'train'``, ``'validation'`` or ``'test'``.
            accuracy_threshold: Maximum absolute error for a prediction to be
                counted as correct.

        Returns:
            Tuple ``(epoch_loss, epoch_accuracy)``: the sample-weighted mean
            loss as a float and the accuracy as a 0-dim double tensor.
        """
        assert phase in ('train', 'validation', 'test')
        is_training = phase == 'train'

        self.model = self.model.to(device)

        if is_training:
            self.model.train()  # Set model to training mode
        else:
            self.model.eval()

        running_loss = 0.0
        running_corrects = 0
        running_n = 0.0

        prediction_file = None
        if not is_training:
            prediction_file = open(
                os.path.join(
                    loader.dataset.directory,
                    '{}_epoch_{:03d}_predictions.txt'.format(
                        self.run_name, epoch)), 'w+')

        # try/finally guarantees the prediction file is closed even when a
        # batch raises part-way through the epoch.
        try:
            if prediction_file is not None:
                prediction_file.write('filename label prediction\n')

            for idx, (filename, image1, image2,
                      labels) in enumerate(loader, 1):
                image1 = image1.to(device)
                image2 = image2.to(device)
                labels = labels.float().to(device)

                if is_training:
                    # zero the parameter gradients
                    self.optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    outputs = self.model(image1, image2).squeeze()
                    if outputs.dim() == 0:
                        # A single-sample batch squeezes down to a scalar;
                        # restore the batch axis so the loss does not
                        # broadcast and the per-sample zip() below works.
                        outputs = outputs.unsqueeze(0)
                    loss = self.criterion(outputs, labels)

                    if prediction_file is not None:
                        prediction_file.writelines([
                            '{} {} {}\n'.format(*line)
                            for line in zip(filename, labels.tolist(),
                                            outputs.clamp(0, 1).tolist())
                        ])

                    if is_training:
                        loss.backward()
                        self.optimizer.step()

                batch_size = image1.size(0)
                # The criterion returns a batch mean, so weight it by the
                # batch size to get a correct average over uneven batches.
                running_loss += loss.item() * batch_size
                running_corrects += (
                    outputs - labels.data).abs().le(accuracy_threshold).sum()
                running_n += batch_size

                # Running statistics are logged after every batch.
                logger.debug(
                    'Epoch: {:03d} Phase: {:10s} Batch {:04d}/{:04d}: Loss: {:.4f} Accuracy: {:.4f}'
                    .format(epoch, phase, idx, len(loader),
                            running_loss / running_n,
                            running_corrects.double() / running_n))

            epoch_loss = running_loss / running_n
            epoch_accuracy = running_corrects.double() / running_n

            if prediction_file is not None:
                prediction_file.write(
                    'Epoch {:03d} Accuracy: {:.4f}\n'.format(
                        epoch, epoch_accuracy))
        finally:
            if prediction_file is not None:
                prediction_file.close()

        logger.info(
            'Epoch {:03d} Phase: {:10s} Loss: {:.4f} Accuracy: {:.4f}'.format(
                epoch, phase, epoch_loss, epoch_accuracy))

        return epoch_loss, epoch_accuracy

    def train(self, n_epochs, datasets, device, save_path):
        """Train for ``n_epochs`` and checkpoint the best validation accuracy.

        After each epoch the LR scheduler is stepped on the validation loss.
        Whenever validation accuracy strictly exceeds the best value seen so
        far (initially 0.0) the model ``state_dict`` is saved to ``save_path``.

        Args:
            n_epochs: Number of epochs; epochs are numbered from 1.
            datasets: Object whose ``load(split)`` returns a
                ``(dataset, loader)`` pair.
            device: Device to run on.
            save_path: Destination for the best ``state_dict``.
        """
        _, train_loader = datasets.load('train')
        _, validation_loader = datasets.load('validation')

        best_accuracy = 0.0

        start_time = time.time()
        for epoch in range(1, n_epochs + 1):
            # train network
            self.run_epoch(epoch, train_loader, device, phase='train')

            # eval on validation
            validation_loss, validation_accuracy = self.run_epoch(
                epoch, validation_loader, device, phase='validation')

            self.lr_scheduler.step(validation_loss)

            if validation_accuracy > best_accuracy:
                best_accuracy = validation_accuracy
                # Snapshot the weights so later updates cannot alter them.
                best_model_wts = copy.deepcopy(self.model.state_dict())

                logger.info('Epoch {:03d} Checkpoint: Saving to {}'.format(
                    epoch, save_path))
                torch.save(best_model_wts, save_path)

        time_elapsed = time.time() - start_time
        logger.info('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))

        logger.info('Best validation Accuracy: {:4f}.'.format(best_accuracy))

    def test(self, datasets, device, load_path):
        """Load weights from ``load_path`` and evaluate on the test split.

        Args:
            datasets: Object whose ``load('test')`` returns a
                ``(dataset, loader)`` pair.
            device: Device to run on; also used as the checkpoint's
                ``map_location``.
            load_path: Path of a ``state_dict`` saved with ``torch.save``.
        """
        # map_location lets a checkpoint written on a GPU be restored on
        # whatever device this run uses (e.g. a CPU-only host).
        # NOTE(review): a state_dict saved from an nn.DataParallel-wrapped
        # model presumably needs the same wrapping at load time - confirm.
        self.model.load_state_dict(
            torch.load(load_path, map_location=device))
        _, test_loader = datasets.load('test')
        self.run_epoch(1, test_loader, device, phase='test')
예제 #2
0
class QuasiSiameseNetwork(object):
    """Training / evaluation driver for a two-input classification network.

    The network is either ``SiameseNetwork`` (``"pre-trained"``) or the
    result of ``build_net`` (``"full"``), and is trained with
    ``BCEWithLogitsLoss`` (``"soft-targets"``, one output) or
    ``CrossEntropyLoss`` (``"softmax"``, four outputs).
    """

    def __init__(self, args):
        """Build the model, loss, per-split transforms, optimizer and scheduler.

        Args:
            args: Object exposing ``outputType``, ``networkType``,
                ``numFreeze``, ``inputSize`` and ``learningRate`` attributes.
        """
        train_config = args.outputType
        net_config = args.networkType
        n_freeze = args.numFreeze
        input_size = (args.inputSize, args.inputSize)

        assert train_config in ("soft-targets", "softmax")
        assert net_config in ("pre-trained", "full")
        self.train_config = train_config
        self.input_size = input_size
        self.lr = args.learningRate

        if train_config == "soft-targets":
            self.n_classes = 1
            self.criterion = nnloss.BCEWithLogitsLoss()
        else:
            # TODO: weights
            self.n_classes = 4
            self.criterion = nnloss.CrossEntropyLoss()

        self.transforms = {}
        if net_config == "pre-trained":
            self.model = SiameseNetwork(self.n_classes, n_freeze=n_freeze)

            for s in ("train", "val", "test"):
                self.transforms[s] = get_pretrained_iv3_transforms(s)

        else:
            self.model = build_net(input_size, self.n_classes)
            assert input_size[0] == input_size[1]
            for s in ("train", "val", "test"):
                self.transforms[s] = get_transforms(s, input_size[0])

        # Counts parameter tensors, not individual scalar weights.
        log.debug("Num params: {}".format(
            len(list(self.model.parameters()))))

        self.optimizer = Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = ReduceLROnPlateau(self.optimizer,
                                              factor=0.1,
                                              patience=10,
                                              min_lr=1e-5,
                                              verbose=True)

    def run_epoch(self, epoch, loader, device, phase="train"):
        """Run one pass over ``loader`` and return ``(loss, accuracy, f1)``.

        In the ``"train"`` phase the optimizer is stepped after every batch;
        otherwise gradients are disabled.  Predictions and labels are both
        reduced to class indices with an argmax over dimension 1.

        Args:
            epoch: Epoch number, used only for logging.
            loader: Iterable of ``(image1, image2, labels)`` batches.
            device: Device the model and each batch are moved to.
            phase: One of ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            Tuple ``(epoch_loss, epoch_acc, epoch_f1)``: the sample-weighted
            mean loss as a float, the accuracy as a 0-dim double tensor, and
            whatever ``RollingEval.f1_score()`` returns.
        """
        assert phase in ("train", "val", "test")
        is_training = phase == "train"

        self.model = self.model.to(device)

        log.info("Phase: {}, Epoch: {}".format(phase, epoch))

        if is_training:
            self.model.train()  # Set model to training mode
        else:
            self.model.eval()

        running_loss = 0.0
        running_corrects = 0
        running_n = 0.0

        rolling_eval = RollingEval()

        for idx, (image1, image2, labels) in enumerate(loader):
            image1 = image1.to(device)
            image2 = image2.to(device)
            labels = labels.to(device)

            if is_training:
                # zero the parameter gradients
                self.optimizer.zero_grad()

            with torch.set_grad_enabled(is_training):
                outputs = self.model(image1, image2)
                _, preds = torch.max(outputs, 1)
                # NOTE(review): labels are always collapsed to class indices
                # here, which is the target format CrossEntropyLoss takes.
                # With the "soft-targets" BCEWithLogitsLoss criterion this
                # looks incompatible - confirm that path is exercised.
                _, labels = torch.max(labels, 1)
                loss = self.criterion(outputs, labels)

                if is_training:
                    loss.backward()
                    self.optimizer.step()

                rolling_eval.add(labels, preds)

            batch_size = image1.size(0)
            # The criterion returns a batch mean, so weight it by the batch
            # size to get a correct average over uneven batches.
            running_loss += loss.item() * batch_size
            running_corrects += torch.sum(preds == labels.data)
            running_n += batch_size

            # Running statistics are logged after every batch.
            log.info(
                "\tBatch {}: Loss: {:.4f} Acc: {:.4f} F1: {:.4f} Recall: {:.4f}"
                .format(idx, running_loss / running_n,
                        running_corrects.double() / running_n,
                        rolling_eval.f1_score(), rolling_eval.recall()))

        epoch_loss = running_loss / running_n
        epoch_acc = running_corrects.double() / running_n
        epoch_f1 = rolling_eval.f1_score()

        log.info('{}: Loss: {:.4f} \nReport: {}'.format(
            phase, epoch_loss, rolling_eval.every_measure()))

        return epoch_loss, epoch_acc, epoch_f1

    def train(self, n_epochs, datasets, device, save_path):
        """Train for ``n_epochs`` and checkpoint the best validation F1.

        After each epoch the LR scheduler is stepped on the validation loss.
        Whenever validation F1 strictly exceeds the best value seen so far
        (initially 0.0) the model ``state_dict`` is saved to ``save_path``.

        Args:
            n_epochs: Number of epochs; epochs are numbered from 0.
            datasets: Object whose ``load(split)`` returns a
                ``(dataset, loader)`` pair.
            device: Device to run on.
            save_path: Destination for the best ``state_dict``.
        """
        _, train_loader = datasets.load("train")
        _, val_loader = datasets.load("val")

        best_f1 = 0.0

        start_time = time.time()
        for epoch in range(n_epochs):
            # train network
            self.run_epoch(epoch, train_loader, device, phase="train")

            # eval on validation
            val_loss, _, val_f1 = self.run_epoch(epoch,
                                                 val_loader,
                                                 device,
                                                 phase="val")

            self.lr_scheduler.step(val_loss)

            if val_f1 > best_f1:
                best_f1 = val_f1
                # Snapshot the weights so later updates cannot alter them.
                best_model_wts = copy.deepcopy(self.model.state_dict())

                log.info("Checkpoint: Saving to {}".format(save_path))
                torch.save(best_model_wts, save_path)

        time_elapsed = time.time() - start_time
        log.info('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))

        log.info('Best val F1: {:4f}.'.format(best_f1))

    def test(self, datasets, device, load_path):
        """Load weights from ``load_path`` and evaluate on the test split.

        Args:
            datasets: Object whose ``load("test")`` returns a
                ``(dataset, loader)`` pair.
            device: Device to run on; also used as the checkpoint's
                ``map_location``.
            load_path: Path of a ``state_dict`` saved with ``torch.save``.
        """
        # map_location lets a checkpoint written on a GPU be restored on
        # whatever device this run uses (e.g. a CPU-only host).
        self.model.load_state_dict(
            torch.load(load_path, map_location=device))
        _, test_loader = datasets.load("test")
        self.run_epoch(0, test_loader, device, phase="test")