Ejemplo n.º 1
0
class QuasiSiameseNetwork(object):
    """Bundle a two-input model with its loss, optimizer, scheduler and loops.

    The constructor reads its configuration from an ``args`` object and
    prepares everything that :meth:`train` and :meth:`test` need.
    """

    def __init__(self, args):
        """Create the model, criterion, transforms, optimizer and scheduler.

        Args:
            args: Object exposing ``outputType`` ("soft-targets" or
                "softmax"), ``networkType`` ("pre-trained" or "full"),
                ``numFreeze``, ``inputSize`` and ``learningRate``.

        Raises:
            AssertionError: If ``outputType`` or ``networkType`` is not one
                of the accepted values.
        """
        train_config = args.outputType
        net_config = args.networkType
        n_freeze = args.numFreeze
        input_size = (args.inputSize, args.inputSize)

        assert train_config in ("soft-targets", "softmax")
        assert net_config in ("pre-trained", "full")
        self.train_config = train_config
        self.input_size = input_size
        self.lr = args.learningRate

        if train_config == "soft-targets":
            # NOTE(review): run_epoch() always reduces the labels to class
            # indices with an argmax, which looks incompatible with
            # BCEWithLogitsLoss on a single logit -- confirm this
            # configuration is actually exercised before relying on it.
            self.n_classes = 1
            self.criterion = nnloss.BCEWithLogitsLoss()
        else:
            # TODO: weights
            self.n_classes = 4
            self.criterion = nnloss.CrossEntropyLoss()

        self.transforms = {}
        phases = ("train", "val", "test")
        if net_config == "pre-trained":
            self.model = SiameseNetwork(self.n_classes, n_freeze=n_freeze)
            for s in phases:
                self.transforms[s] = get_pretrained_iv3_transforms(s)
        else:
            self.model = build_net(input_size, self.n_classes)
            assert input_size[0] == input_size[1]
            for s in phases:
                self.transforms[s] = get_transforms(s, input_size[0])

        # This counts parameter tensors, not individual scalar weights.
        log.debug("Num params: {}".format(
            sum(1 for _ in self.model.parameters())))

        self.optimizer = Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = ReduceLROnPlateau(self.optimizer,
                                              factor=0.1,
                                              patience=10,
                                              min_lr=1e-5,
                                              verbose=True)

    def run_epoch(self, epoch, loader, device, phase="train"):
        """Run one pass over ``loader`` and log running metrics.

        In the "train" phase the model is put in training mode and the
        optimizer is stepped after every batch; in the other phases the model
        is put in eval mode and gradients are disabled.

        Args:
            epoch: Epoch number, used only for logging.
            loader: Iterable yielding ``(image1, image2, labels)`` batches.
            device: Device the model and the batches are moved to.
            phase: One of "train", "val" or "test".

        Returns:
            Tuple ``(epoch_loss, epoch_acc, epoch_f1)``: the sample-weighted
            mean loss (float), the accuracy (a tensor, because the correct
            count is accumulated with ``torch.sum``) and the value returned
            by ``RollingEval.f1_score()``.

        Raises:
            AssertionError: If ``phase`` is not an accepted value.
            ZeroDivisionError: If ``loader`` yields no batches.
        """
        assert phase in ("train", "val", "test")
        is_training = phase == "train"

        self.model = self.model.to(device)

        log.info("Phase: {}, Epoch: {}".format(phase, epoch))

        if is_training:
            self.model.train()  # Set model to training mode
        else:
            self.model.eval()

        running_loss = 0.0
        running_corrects = 0
        running_n = 0.0

        rolling_eval = RollingEval()

        for idx, (image1, image2, labels) in enumerate(loader):
            image1 = image1.to(device)
            image2 = image2.to(device)
            labels = labels.to(device)

            if is_training:
                # zero the parameter gradients
                self.optimizer.zero_grad()

            with torch.set_grad_enabled(is_training):
                outputs = self.model(image1, image2)
                # Both predictions and labels are reduced to the index of
                # their largest entry along dim 1.
                _, preds = torch.max(outputs, 1)
                _, labels = torch.max(labels, 1)
                loss = self.criterion(outputs, labels)

                if is_training:
                    loss.backward()
                    self.optimizer.step()

                rolling_eval.add(labels, preds)

            batch_size = image1.size(0)
            # loss.item() is a per-batch mean, so weight it by batch size to
            # obtain a per-sample average at the end.
            running_loss += loss.item() * batch_size
            running_corrects += torch.sum(preds == labels.data)
            running_n += batch_size

            log.info(
                "\tBatch {}: Loss: {:.4f} Acc: {:.4f} F1: {:.4f} Recall: {:.4f}"
                .format(idx, running_loss / running_n,
                        running_corrects.double() / running_n,
                        rolling_eval.f1_score(), rolling_eval.recall()))

        epoch_loss = running_loss / running_n
        epoch_acc = running_corrects.double() / running_n
        epoch_f1 = rolling_eval.f1_score()

        log.info('{}: Loss: {:.4f} \nReport: {}'.format(
            phase, epoch_loss, rolling_eval.every_measure()))

        return epoch_loss, epoch_acc, epoch_f1

    def train(self, n_epochs, datasets, device, save_path):
        """Train for ``n_epochs`` and checkpoint the best validation F1.

        After every epoch the scheduler is stepped with the validation loss,
        and the model's state dict is written to ``save_path`` whenever the
        validation F1 strictly exceeds the best value seen so far (starting
        from 0.0).

        Args:
            n_epochs: Number of epochs to run.
            datasets: Object whose ``load(name)`` returns a
                ``(dataset, loader)`` pair for "train" and "val".
            device: Device passed through to :meth:`run_epoch`.
            save_path: Destination of the checkpoint file.
        """
        _, train_loader = datasets.load("train")
        _, val_loader = datasets.load("val")

        best_f1 = 0.0

        start_time = time.time()
        for epoch in range(n_epochs):
            # train network
            self.run_epoch(epoch, train_loader, device, phase="train")

            # eval on validation
            val_loss, _, val_f1 = self.run_epoch(epoch,
                                                 val_loader,
                                                 device,
                                                 phase="val")

            self.lr_scheduler.step(val_loss)

            if val_f1 > best_f1:
                best_f1 = val_f1

                log.info("Checkpoint: Saving to {}".format(save_path))
                torch.save(self.model.state_dict(), save_path)

        time_elapsed = time.time() - start_time
        log.info('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))

        log.info('Best val F1: {:.4f}.'.format(best_f1))

    def test(self, datasets, device, load_path):
        """Load a checkpoint and run a single evaluation pass on "test".

        Args:
            datasets: Object whose ``load("test")`` returns a
                ``(dataset, loader)`` pair.
            device: Device to evaluate on; also used as ``map_location`` so a
                checkpoint saved on another device can still be loaded.
            load_path: Path of the state dict written by :meth:`train`.
        """
        self.model.load_state_dict(
            torch.load(load_path, map_location=device))
        _, test_loader = datasets.load("test")
        self.run_epoch(0, test_loader, device, phase="test")
Ejemplo n.º 2
0
class QuasiSiameseNetwork(object):
    """Regression-style wrapper around a two-input model trained with MSE.

    Holds the model, loss, optimizer and scheduler, and provides the
    train / validation / test loops. Non-training phases also write the
    per-sample predictions to a text file.
    """

    def __init__(self, args):
        """Create the model, criterion, transforms, optimizer and scheduler.

        Args:
            args: Object exposing ``inputSize``, ``runName`` and
                ``learningRate``.
        """
        input_size = (args.inputSize, args.inputSize)

        self.run_name = args.runName
        self.input_size = input_size
        self.lr = args.learningRate

        self.criterion = nnloss.MSELoss()

        self.transforms = {}

        self.model = SiameseNetwork()

        n_gpus = torch.cuda.device_count()
        if n_gpus > 1:
            logger.info('Using {} GPUs'.format(n_gpus))
            self.model = nn.DataParallel(self.model)

        for s in ('train', 'validation', 'test'):
            self.transforms[s] = get_pretrained_iv3_transforms(s)

        # This counts parameter tensors, not individual scalar weights.
        logger.debug('Num params: {}'.format(
            sum(1 for _ in self.model.parameters())))

        self.optimizer = Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = ReduceLROnPlateau(self.optimizer,
                                              factor=0.1,
                                              patience=10,
                                              min_lr=1e-5,
                                              verbose=True)

    def run_epoch(self,
                  epoch,
                  loader,
                  device,
                  phase='train',
                  accuracy_threshold=0.1):
        """Run one pass over ``loader``; train or evaluate depending on phase.

        For phases other than "train", a file named
        ``<run_name>_epoch_<NNN>_predictions.txt`` is written inside
        ``loader.dataset.directory`` with one ``filename label prediction``
        line per sample (predictions clamped to [0, 1]) followed by the
        epoch accuracy.

        Args:
            epoch: Epoch number, used for logging and the prediction file name.
            loader: Iterable yielding ``(filename, image1, image2, labels)``
                batches; must support ``len()``.
            device: Device the model and the batches are moved to.
            phase: One of "train", "validation" or "test".
            accuracy_threshold: A prediction counts as correct when its
                absolute error is at most this value.

        Returns:
            Tuple ``(epoch_loss, epoch_accuracy)``: the sample-weighted mean
            loss (float) and the accuracy (a tensor).

        Raises:
            AssertionError: If ``phase`` is not an accepted value.
            ZeroDivisionError: If ``loader`` yields no batches.
        """
        assert phase in ('train', 'validation', 'test')
        is_training = phase == 'train'

        self.model = self.model.to(device)

        if is_training:
            self.model.train()  # Set model to training mode
        else:
            self.model.eval()

        running_loss = 0.0
        running_corrects = 0
        running_n = 0.0

        prediction_file = None
        if not is_training:
            prediction_path = os.path.join(
                loader.dataset.directory,
                '{}_epoch_{:03d}_predictions.txt'.format(
                    self.run_name, epoch))
            prediction_file = open(prediction_path, 'w+')

        # try/finally guarantees the prediction file is closed even when a
        # batch raises part-way through the epoch.
        try:
            if prediction_file is not None:
                prediction_file.write('filename label prediction\n')

            for idx, (filename, image1, image2,
                      labels) in enumerate(loader, 1):
                image1 = image1.to(device)
                image2 = image2.to(device)
                labels = labels.float().to(device)

                if is_training:
                    # zero the parameter gradients
                    self.optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    outputs = self.model(image1, image2).squeeze()
                    if outputs.dim() == 0:
                        # A single-sample batch is squeezed down to a scalar;
                        # restore the batch axis so the loss shapes match and
                        # tolist() below still yields a list.
                        outputs = outputs.unsqueeze(0)
                    loss = self.criterion(outputs, labels)

                    if prediction_file is not None:
                        prediction_file.writelines([
                            '{} {} {}\n'.format(*line)
                            for line in zip(filename, labels.tolist(),
                                            outputs.clamp(0, 1).tolist())
                        ])

                    if is_training:
                        loss.backward()
                        self.optimizer.step()

                batch_size = image1.size(0)
                # loss.item() is a per-batch mean, so weight it by batch size
                # to obtain a per-sample average at the end.
                running_loss += loss.item() * batch_size
                running_corrects += (outputs.detach() - labels.data).abs().le(
                    accuracy_threshold).sum()
                running_n += batch_size

                logger.debug(
                    'Epoch: {:03d} Phase: {:10s} Batch {:04d}/{:04d}: Loss: {:.4f} Accuracy: {:.4f}'
                    .format(epoch, phase, idx, len(loader),
                            running_loss / running_n,
                            running_corrects.double() / running_n))

            epoch_loss = running_loss / running_n
            epoch_accuracy = running_corrects.double() / running_n

            if prediction_file is not None:
                prediction_file.write(
                    'Epoch {:03d} Accuracy: {:.4f}\n'.format(
                        epoch, epoch_accuracy))
        finally:
            if prediction_file is not None:
                prediction_file.close()

        logger.info(
            'Epoch {:03d} Phase: {:10s} Loss: {:.4f} Accuracy: {:.4f}'.format(
                epoch, phase, epoch_loss, epoch_accuracy))

        return epoch_loss, epoch_accuracy

    def train(self, n_epochs, datasets, device, save_path):
        """Train for ``n_epochs`` and checkpoint the best validation accuracy.

        Epochs are numbered from 1. After every epoch the scheduler is
        stepped with the validation loss, and the model's state dict is
        written to ``save_path`` whenever the validation accuracy strictly
        exceeds the best value seen so far (starting from 0.0).

        Args:
            n_epochs: Number of epochs to run.
            datasets: Object whose ``load(name)`` returns a
                ``(dataset, loader)`` pair for "train" and "validation".
            device: Device passed through to :meth:`run_epoch`.
            save_path: Destination of the checkpoint file.
        """
        _, train_loader = datasets.load('train')
        _, validation_loader = datasets.load('validation')

        best_accuracy = 0.0

        start_time = time.time()
        for epoch in range(1, n_epochs + 1):
            # train network
            self.run_epoch(epoch, train_loader, device, phase='train')

            # eval on validation
            validation_loss, validation_accuracy = self.run_epoch(
                epoch, validation_loader, device, phase='validation')

            self.lr_scheduler.step(validation_loss)

            if validation_accuracy > best_accuracy:
                best_accuracy = validation_accuracy

                logger.info('Epoch {:03d} Checkpoint: Saving to {}'.format(
                    epoch, save_path))
                torch.save(self.model.state_dict(), save_path)

        time_elapsed = time.time() - start_time
        logger.info('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))

        logger.info(
            'Best validation Accuracy: {:.4f}.'.format(best_accuracy))

    def test(self, datasets, device, load_path):
        """Load a checkpoint and run a single evaluation pass on "test".

        Args:
            datasets: Object whose ``load('test')`` returns a
                ``(dataset, loader)`` pair.
            device: Device to evaluate on; also used as ``map_location`` so a
                checkpoint saved on another device can still be loaded.
            load_path: Path of the state dict written by :meth:`train`.
        """
        # NOTE(review): if the checkpoint was saved from a DataParallel-wrapped
        # model, the current model presumably needs the same wrapping for the
        # state-dict keys to match -- confirm against the deployment setup.
        self.model.load_state_dict(
            torch.load(load_path, map_location=device))
        _, test_loader = datasets.load('test')
        self.run_epoch(1, test_loader, device, phase='test')
Ejemplo n.º 3
0
def main():
    """Train a siamese network on MNIST pairs and save the results.

    Parses the command-line options, trains ``SiameseNetwork`` with
    ``ContrastiveLoss`` and SGD on pairs produced by ``create_pairs`` from
    each MNIST batch, then writes ``model.pth``, ``history.pkl`` and
    ``loss.png`` into ``--log_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--margin',
                        type=float,
                        default=1.0,
                        help='margin for contrastive loss')
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--num_epochs',
                        type=int,
                        default=100,
                        help='number of epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=1000,
                        help='batch size')
    parser.add_argument('--log_dir', required=True, help='log directory')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of workers for data loading')
    opt = parser.parse_args()
    opt.use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if opt.use_gpu else 'cpu')

    if not os.path.exists(opt.log_dir):
        os.makedirs(opt.log_dir)

    # MNIST images have a single channel, so the normalisation statistics
    # must be one-element tuples; three-element tuples do not broadcast onto
    # a 1xHxW tensor.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))
    ])

    train_dataset = datasets.MNIST(root='./data',
                                   train=True,
                                   download=True,
                                   transform=transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=opt.num_workers)

    siamese_net = SiameseNetwork().to(device)

    # NOTE(review): --margin is parsed above but never forwarded to the loss;
    # confirm the ContrastiveLoss constructor signature before wiring it in.
    criterion = ContrastiveLoss()
    optimizer = torch.optim.SGD(siamese_net.parameters(),
                                lr=opt.lr,
                                momentum=opt.momentum)

    history = {'loss': []}

    # The number of batches per epoch does not change between epochs.
    num_itrs = len(train_loader)

    for epoch in range(opt.num_epochs):
        running_loss = 0
        # Count from 1 so the progress line ends at [num_itrs/num_itrs].
        for itr, (inputs, labels) in enumerate(train_loader, 1):
            optimizer.zero_grad()

            x1, x2, t = create_pairs(inputs, labels)
            x1, x2, t = x1.to(device), x2.to(device), t.to(device)

            y1, y2 = siamese_net(x1, x2)
            loss = criterion(y1, y2, t)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            # '\r\033[K' returns to column 0 and clears the line so the
            # progress indicator is redrawn in place.
            sys.stdout.write('\r\033[Kitr [{}/{}], loss: {:.4f}'.format(
                itr, num_itrs, batch_loss))
            sys.stdout.flush()

        mean_loss = running_loss / num_itrs
        history['loss'].append(mean_loss)
        sys.stdout.write('\r\033[Kepoch [{}/{}], loss: {:.4f}'.format(
            epoch + 1, opt.num_epochs, mean_loss))
        sys.stdout.write('\n')

    torch.save(siamese_net.state_dict(), os.path.join(opt.log_dir,
                                                      'model.pth'))

    with open(os.path.join(opt.log_dir, 'history.pkl'), 'wb') as f:
        pickle.dump(history, f)

    plt.plot(history['loss'])
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid()
    plt.savefig(os.path.join(opt.log_dir, 'loss.png'))