Example #1
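Loss selection and model loading from a SWAG segmentation evaluation script: it either resumes a SWAG checkpoint (and draws the SWA mean) or loads a plain model checkpoint. The snippet assumes args, losses, model_cfg, num_classes, and loaders are defined by the surrounding script, and that SWAG and bn_update come from the accompanying SWAG implementation.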
# pick the segmentation loss: standard cross-entropy, or the aleatoric
# variant for networks that also predict an uncertainty term
if args.loss == "cross_entropy":
    criterion = losses.seg_cross_entropy
else:
    criterion = losses.seg_ale_cross_entropy

# construct and load model
if args.swa_resume is not None:
    checkpoint = torch.load(args.swa_resume)
    model = SWAG(
        model_cfg.base,
        no_cov_mat=False,
        max_num_models=20,
        num_classes=num_classes,
        use_aleatoric=args.loss == "aleatoric",
    )
    model.cuda()
    model.load_state_dict(checkpoint["state_dict"])

    # scale=0.0 draws the SWA mean of the collected weights; BatchNorm
    # running statistics are then recomputed for the averaged model
    model.sample(0.0)
    bn_update(loaders["fine_tune"], model)
else:
    model = model_cfg.base(num_classes=num_classes,
                           use_aleatoric=args.loss == "aleatoric").cuda()
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint["epoch"]
    print("checkpoint epoch:", start_epoch)
    model.load_state_dict(checkpoint["state_dict"])

print("test batches:", len(loaders["test"]))
if args.use_test:
    print("Using test dataset")
Example #2
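An sklearn-style regression runner: it trains a base network with SGD under an SWA-style learning-rate schedule and, optionally, maintains a SWAG subspace posterior for uncertainty estimates. The snippet assumes torch, numpy as np, DataLoader and TensorDataset from torch.utils.data, and the repository's SWAG, utils, adjust_learning_rate, and RegressionModel are importable.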
class RegressionRunner(RegressionModel):
    def __init__(self,
                 base,
                 epochs,
                 criterion,
                 batch_size=50,
                 lr_init=1e-2,
                 momentum=0.9,
                 wd=1e-4,
                 swag_lr=1e-3,
                 swag_freq=1,
                 swag_start=50,
                 subspace_type='pca',
                 subspace_kwargs=None,
                 use_cuda=False,
                 use_swag=False,
                 double_bias_lr=False,
                 model_variance=True,
                 num_samples=30,
                 scale=0.5,
                 const_lr=False,
                 *args,
                 **kwargs):

        # guard against the shared mutable-default-argument pitfall
        if subspace_kwargs is None:
            subspace_kwargs = {'max_rank': 20}

        self.base = base
        self.model = base(*args, **kwargs)
        num_pars = sum(p.numel() for p in self.model.parameters())
        print('number of parameters:', num_pars)

        if use_cuda:
            self.model.cuda()

        if use_swag:
            self.swag_model = SWAG(base,
                                   subspace_type=subspace_type,
                                   subspace_kwargs=subspace_kwargs,
                                   *args,
                                   **kwargs)
            if use_cuda:
                self.swag_model.cuda()
        else:
            self.swag_model = None

        self.use_cuda = use_cuda

        if not double_bias_lr:
            pars = self.model.parameters()
        else:
            # give bias parameters twice the base learning rate
            pars = []
            for name, param in self.model.named_parameters():
                if 'bias' in name:
                    print('Doubling lr of', name)
                    pars.append({'params': param, 'lr': 2.0 * lr_init})
                else:
                    pars.append({'params': param, 'lr': lr_init})

        self.optimizer = torch.optim.SGD(pars,
                                         lr=lr_init,
                                         momentum=momentum,
                                         weight_decay=wd)

        self.const_lr = const_lr
        self.batch_size = batch_size

        # TODO: set up criterions better for classification
        # model_variance=True: the network predicts its own noise variance
        # (heteroscedastic); otherwise a fixed unit noise variance is used
        if model_variance:
            self.criterion = criterion(noise_var=None)
        else:
            self.criterion = criterion(noise_var=1.0)

        if self.criterion.noise_var is not None:
            self.var = self.criterion.noise_var

        self.epochs = epochs

        self.lr_init = lr_init

        self.use_swag = use_swag
        self.swag_start = swag_start
        self.swag_lr = swag_lr
        self.swag_freq = swag_freq

        self.num_samples = num_samples
        self.scale = scale

    def train(self,
              model,
              loader,
              optimizer,
              criterion,
              lr_init=1e-2,
              epochs=3000,
              swag_model=None,
              swag=False,
              swag_start=2000,
              swag_freq=50,
              swag_lr=1e-3,
              print_freq=100,
              use_cuda=False,
              const_lr=False):
        # copied from Pavel's regression notebook
        if const_lr:
            lr = lr_init

        train_res_list = []
        for epoch in range(epochs):
            if not const_lr:
                # SWA-style schedule (relative to swag_start when SWAG is on,
                # else to the total epochs): constant for the first half,
                # linear decay until 90%, then flat at the final lr
                # (swag_lr with SWAG, 5% of lr_init otherwise)
                t = (epoch + 1) / swag_start if swag else (epoch + 1) / epochs
                lr_ratio = swag_lr / lr_init if swag else 0.05

                if t <= 0.5:
                    factor = 1.0
                elif t <= 0.9:
                    factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
                else:
                    factor = lr_ratio

                lr = factor * lr_init
                # pass the computed lr, not the bare factor, so the
                # optimizer's learning rate actually tracks the schedule
                adjust_learning_rate(optimizer, lr)
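                # e.g. with the defaults lr_init=1e-2, swag_lr=1e-3 and
                # swag_start=2000: epochs 0-999 run at 1e-2, epochs 1000-1799
                # decay linearly toward 1e-3, and later epochs stay at 1e-3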

            train_res = utils.train_epoch(loader,
                                          model,
                                          criterion,
                                          optimizer,
                                          cuda=use_cuda,
                                          regression=True)
            train_res_list.append(train_res)
            # collect a weight snapshot once past swag_start (note that
            # swag_freq is accepted by this method but not consulted here)
            if swag and epoch > swag_start:
                swag_model.collect_model(model)

            if (epoch % print_freq == 0 or epoch == epochs - 1):
                print('Epoch %d. LR: %g. Loss: %.4f' %
                      (epoch, lr, train_res['loss']))

        return train_res_list

    def fit(self, features, labels):
        self.features = torch.FloatTensor(features)
        self.labels = torch.FloatTensor(labels)

        # construct data loader
        self.data_loader = DataLoader(
            TensorDataset(self.features, self.labels),
            batch_size=self.batch_size)

        # now train with pre-specified options
        result = self.train(model=self.model,
                            loader=self.data_loader,
                            optimizer=self.optimizer,
                            criterion=self.criterion,
                            lr_init=self.lr_init,
                            swag_model=self.swag_model,
                            swag=self.use_swag,
                            swag_start=self.swag_start,
                            swag_freq=self.swag_freq,
                            swag_lr=self.swag_lr,
                            use_cuda=self.use_cuda,
                            epochs=self.epochs,
                            const_lr=self.const_lr)

        if self.criterion.noise_var is not None:
            # another forward pass through the network to estimate the noise
            # variance from the training residuals
            preds, targets = utils.predictions(model=self.model,
                                               test_loader=self.data_loader,
                                               regression=True,
                                               cuda=self.use_cuda)
            self.var = np.power(np.linalg.norm(preds - targets),
                                2.0) / targets.shape[0]
            print('estimated noise variance:', self.var)

        return result

    def predict(self, features, swag_model=None):
        """
        default prediction method is to use built in Low rank Gaussian
        SWA: scale = 0.0, num_samples = 1
        """
        swag_model = swag_model if swag_model is not None else self.swag_model

        device = torch.device('cuda' if self.use_cuda else 'cpu')

        with torch.no_grad():

            if swag_model is None:
                self.model.eval()
                preds = self.model(
                    torch.FloatTensor(features).to(device)).data.cpu()

                if preds.size(1) == 1:
                    # homoscedastic network: use the fitted noise variance
                    var = torch.ones_like(preds[:, 0]).unsqueeze(1) * self.var
                else:
                    # heteroscedastic network: second column is the variance
                    var = preds[:, 1].view(-1, 1)
                    preds = preds[:, 0].view(-1, 1)

                print('mean predictive variance:', var.mean())

            else:
                # Monte Carlo BMA: draw num_samples weight settings from the
                # SWAG posterior and accumulate first and second moments
                prediction = 0
                sq_prediction = 0
                for _ in range(self.num_samples):
                    swag_model.sample(scale=self.scale)
                    current_prediction = swag_model(
                        torch.FloatTensor(features).to(device)).data.cpu()
                    prediction += current_prediction
                    if current_prediction.size(1) == 2:
                        # convert the variance column to a standard deviation
                        # so squaring below accumulates sigma^2, not sigma^4
                        current_prediction[:, 1] = current_prediction[:, 1]**0.5

                    sq_prediction += current_prediction**2.0
                # posterior predictive mean: mu* = (1/M) sum_i mu_i(x)
                preds = (prediction[:, 0] / self.num_samples).view(-1, 1)

                # predictive variance:
                # (1/M) sum_i (sigma_i^2(x) + mu_i^2(x)) - (mu*)^2
                var = (torch.sum(sq_prediction, 1, keepdim=True) /
                       self.num_samples - preds.pow(2.0))

                # for a homoscedastic network, add the fitted noise variance
                if prediction.size(1) == 1:
                    var = var + self.var

            return preds.numpy(), var.numpy()
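
A hypothetical usage sketch (MLP and GaussianLikelihood are stand-in names, not part of the snippet; the criterion class is assumed to accept a noise_var keyword, as the constructor above requires):

import numpy as np

# toy 1-D regression problem
X = np.random.uniform(-4, 4, size=(200, 1)).astype(np.float32)
y = (np.sin(X) + 0.1 * np.random.randn(200, 1)).astype(np.float32)

runner = RegressionRunner(base=MLP,                      # hypothetical nn.Module
                          epochs=3000,
                          criterion=GaussianLikelihood,  # hypothetical loss class
                          use_swag=True,
                          swag_start=2000,
                          num_samples=30,
                          scale=0.5)
runner.fit(X, y)
mean, var = runner.predict(X)  # per-point predictive mean and variance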