예제 #1
0
 def __init__(self, w_1=None, w_2=None, w_3=None):
     """Store the three weight vectors, sampling fresh ones where omitted.

     Each of ``w_1``/``w_2``/``w_3`` defaults to the flattened weights of a
     newly constructed ``SimpleModel`` (so omitted vectors get independent
     random initializations).

     NOTE(review): only ``w_3`` is wrapped in ``nn.Parameter`` — presumably
     only the third vector is meant to be trainable; confirm against the
     enclosing class.
     """
     # `x is not None` is the idiomatic form (was `not x is None`).
     self.w_1 = w_1 if w_1 is not None else weight_vector(
         SimpleModel().parameters())
     self.w_2 = w_2 if w_2 is not None else weight_vector(
         SimpleModel().parameters())
     self.w_3 = nn.Parameter(w_3 if w_3 is not None else weight_vector(
         SimpleModel().parameters()))
예제 #2
0
    def start(self):
        """Train both endpoint trainers, then fit the elbow between them."""
        for endpoint_trainer in (self.trainer_a, self.trainer_b):
            endpoint_trainer.start()

        # Flattened weights of the two trained endpoint models.
        endpoint_a = weight_vector(self.trainer_a.model.parameters())
        endpoint_b = weight_vector(self.trainer_b.model.parameters())

        self.elbow_trainer.set_weights(endpoint_a, endpoint_b)
        self.elbow_trainer.start()
예제 #3
0
    def init_models(self):
        """Build ``self.model`` per the configured model name and move it to the device.

        Raises:
            NotImplementedError: if ``config.hp.model_name`` is unrecognized.
        """
        model_name = self.config.hp.model_name

        if model_name == 'conv':
            self.model = ConvModel(self.config.hp.conv_model_config)
        elif model_name == 'fast_resnet':
            base = FastResNet(n_classes=10, n_input_channels=3).nn
            self.model = convert_sequential_model_to_op(
                weight_vector(base.parameters()), base, detach=True)

            # Sanity check: conversion must preserve the parameter count.
            num_base = len(weight_vector(base.parameters()))
            num_converted = len(weight_vector(self.model.parameters()))
            assert num_base == num_converted
        else:
            raise NotImplementedError(
                f'Model {self.config.hp.model_name} is not supported')

        self.model = self.model.to(self.device_name)
예제 #4
0
    def __init__(self,
                 torch_model_cls,
                 num_models: int,
                 coords_init_strategy: str = 'isotropic_normal'):
        """Initialize per-model 2D coordinates and the three plane-defining weight vectors.

        Raises:
            AssertionError: if ``coords_init_strategy`` is not one of
                ``self.COORDS_INIT_STRATEGIES``.
        """
        super(PlaneEnsemble, self).__init__(torch_model_cls, num_models)

        assert coords_init_strategy in self.COORDS_INIT_STRATEGIES, \
            f"Unknown init strategy: {coords_init_strategy}"

        # One sampled coordinate pair per ensemble member, stacked into a tensor.
        sampled_coords = [
            self.sample_coords(coords_init_strategy) for _ in range(num_models)
        ]
        self.register_param('coords', torch.stack(sampled_coords))

        # Three independently initialized weight vectors span the plane.
        for plane_param in ('origin_param', 'right_param', 'up_param'):
            self.register_param(plane_param,
                                weight_vector(torch_model_cls().parameters()))
예제 #5
0
    def log_weight_stats(self):
        """Log weight/gradient norm scalars and coordinate histograms to TensorBoard.

        What gets logged depends on ``config.hp.ensemble_type``:
        - 'plane': per-parameter weight and grad l2 norms for the three
          plane-defining vectors.
        - 'mapping': aggregate weight and grad norms of the mapping network.
        - anything else: no norm stats.
        Coordinate histograms/grad norms are logged for both 'plane' and
        'mapping'.
        """
        plane_params = ('origin_param', 'right_param', 'up_param')

        if self.config.hp.ensemble_type == 'plane':
            # Weight l2 norms (same tags as before: Stats/norms/<name>).
            for name in plane_params:
                self.writer.add_scalar(f'Stats/norms/{name}',
                                       getattr(self.model, name).norm().item(),
                                       self.num_iters_done)

            # Grad norms (tags: Stats/grad_norms/<name>).
            for name in plane_params:
                self.writer.add_scalar(
                    f'Stats/grad_norms/{name}',
                    getattr(self.model, name).grad.norm().item(),
                    self.num_iters_done)
        elif self.config.hp.ensemble_type == 'mapping':
            mapping_weight_norm = weight_vector(
                self.model.mapping.parameters()).norm()
            # Grad norm over all mapping parameters, flattened and concatenated.
            mapping_grad_norm = torch.cat([
                p.grad.view(-1) for p in self.model.mapping.parameters()
            ]).norm()

            self.writer.add_scalar('Stats/norms/mapping',
                                   mapping_weight_norm.item(),
                                   self.num_iters_done)
            self.writer.add_scalar('Stats/grad_norms/mapping',
                                   mapping_grad_norm.item(),
                                   self.num_iters_done)

        if self.config.hp.ensemble_type in ('mapping', 'plane'):
            self.writer.add_histogram(
                'Coords/x', self.model.coords[:, 0].cpu().detach().numpy(),
                self.num_iters_done)
            self.writer.add_histogram(
                'Coords/y', self.model.coords[:, 1].cpu().detach().numpy(),
                self.num_iters_done)
            self.writer.add_scalar('Stats/grad_norms/coords',
                                   self.model.coords.grad.norm().item(),
                                   self.num_iters_done)
예제 #6
0
    def train_on_batch(self, batch):
        """Run one optimization step on a single (inputs, targets) batch and log metrics.

        Args:
            batch: an indexable pair where ``batch[0]`` is the input tensor
                and ``batch[1]`` the integer class targets.
        """
        x = batch[0].to(self.device_name)
        y = batch[1].to(self.device_name)

        preds = self.model(x)
        loss = self.criterion(preds, y).sum()
        acc = (preds.argmax(dim=1) == y).float().mean()
        # Norm is computed before the step, so it reflects the weights the
        # loss was evaluated with.
        norm = weight_vector(self.model.parameters()).norm()

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        self.writer.add_scalar('Train/loss', loss.item(), self.num_iters_done)
        self.writer.add_scalar('Train/acc', acc.item(), self.num_iters_done)
        self.writer.add_scalar('Stats/weights_norm', norm.item(),
                               self.num_iters_done)

        # `x is not None` is the idiomatic form (was `not x is None`).
        if self.scheduler is not None:
            self.scheduler.step()
            # NOTE(review): scheduler.get_lr() is deprecated in newer PyTorch
            # in favor of get_last_lr() — confirm the torch version before
            # switching, as get_lr() inside step-based schedulers can return
            # a different value.
            self.writer.add_scalar('Stats/lr',
                                   self.scheduler.get_lr()[0],
                                   self.num_iters_done)
예제 #7
0
    def __init__(self):
        """Initialize the two line endpoints and record per-layer parameter sizes.

        Each endpoint comes from a separate ``SimpleModel`` instance, so the
        two weight vectors get independent random initializations.
        """
        super(LineModel, self).__init__()

        endpoint_a = weight_vector(SimpleModel().parameters())
        endpoint_b = weight_vector(SimpleModel().parameters())

        self.w_1 = nn.Parameter(endpoint_a)
        self.w_2 = nn.Parameter(endpoint_b)
        self.param_sizes = param_sizes(SimpleModel().parameters())
예제 #8
0
    def __init__(self, torch_model_cls, num_models: int):
        """Register one independently initialized flat weight vector per ensemble member."""
        super(NormalEnsemble, self).__init__(torch_model_cls, num_models)

        for model_idx in range(num_models):
            # Fresh model instance => independent random initialization.
            init_weights = weight_vector(torch_model_cls().parameters())
            self.register_param(f'model_{model_idx}', init_weights)