def __init__(self, w_1=None, w_2=None, w_3=None):
    # Endpoints w_1/w_2 default to freshly initialized SimpleModel weights.
    # Only w_3 is wrapped in nn.Parameter, so only the bend point is trainable.
    self.w_1 = w_1 if w_1 is not None else weight_vector(SimpleModel().parameters())
    self.w_2 = w_2 if w_2 is not None else weight_vector(SimpleModel().parameters())
    self.w_3 = nn.Parameter(w_3 if w_3 is not None else weight_vector(SimpleModel().parameters()))
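# `weight_vector` and `param_sizes` are used throughout but not shown here.
# A minimal sketch of their assumed behavior (flatten a module's parameters
# into one 1-D tensor, and record per-parameter shapes for the inverse);
# the repo's actual versions may differ, e.g. in whether they detach:
import torch

def weight_vector(parameters) -> torch.Tensor:
    # Concatenate all parameters into a single flat (detached) tensor.
    return torch.cat([p.detach().view(-1) for p in parameters])

def param_sizes(parameters):
    # Remember each parameter's shape so a flat vector can be split back.
    return [p.size() for p in parameters]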
def start(self):
    # Train both endpoint models first, then fit the elbow between them.
    self.trainer_a.start()
    self.trainer_b.start()

    w_1 = weight_vector(self.trainer_a.model.parameters())
    w_2 = weight_vector(self.trainer_b.model.parameters())

    self.elbow_trainer.set_weights(w_1, w_2)
    self.elbow_trainer.start()
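# How the elbow model is assumed to trace a path between the two trained
# endpoints: a one-bend polyline w_1 -> w_3 -> w_2 with only the bend w_3
# trainable, in the spirit of mode-connectivity curve finding. This is an
# illustrative sketch; point_on_elbow is a hypothetical name, not from the
# source:
import torch

def point_on_elbow(w_1: torch.Tensor, w_2: torch.Tensor, w_3: torch.Tensor, t: float) -> torch.Tensor:
    # t in [0, 1]: t=0 gives w_1, t=0.5 gives the bend w_3, t=1 gives w_2.
    if t <= 0.5:
        return w_1 + 2.0 * t * (w_3 - w_1)
    return w_3 + (2.0 * t - 1.0) * (w_2 - w_3)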
def init_models(self):
    if self.config.hp.model_name == 'conv':
        self.model = ConvModel(self.config.hp.conv_model_config)
    elif self.config.hp.model_name == 'fast_resnet':
        model = FastResNet(n_classes=10, n_input_channels=3).nn
        self.model = convert_sequential_model_to_op(
            weight_vector(model.parameters()), model, detach=True)
        # Sanity check: the conversion must preserve the number of weights.
        assert len(weight_vector(model.parameters())) == \
            len(weight_vector(self.model.parameters()))
    else:
        raise NotImplementedError(
            f'Model {self.config.hp.model_name} is not supported')

    self.model = self.model.to(self.device_name)
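# `convert_sequential_model_to_op` is not shown. A rough sketch of the
# assumed idea, treating a fixed architecture as a pure function of a flat
# weight vector, using torch.func.functional_call. This is an illustration
# of the technique, not the repo's implementation:
import torch
from torch.func import functional_call

def model_as_op(model: torch.nn.Module):
    names = [n for n, _ in model.named_parameters()]
    shapes = [p.shape for _, p in model.named_parameters()]
    numels = [p.numel() for _, p in model.named_parameters()]

    def op(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Split the flat vector back into named parameters, then call the
        # module statelessly with those parameters.
        chunks = w.split(numels)
        params = {n: c.view(s) for n, c, s in zip(names, chunks, shapes)}
        return functional_call(model, params, (x,))

    return op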
def __init__(self, torch_model_cls, num_models: int,
             coords_init_strategy: str = 'isotropic_normal'):
    super(PlaneEnsemble, self).__init__(torch_model_cls, num_models)

    assert coords_init_strategy in self.COORDS_INIT_STRATEGIES, \
        f"Unknown init strategy: {coords_init_strategy}"

    # 2D coordinates of each ensemble member on the plane.
    self.register_param('coords', torch.stack([
        self.sample_coords(coords_init_strategy) for _ in range(num_models)]))
    # The plane itself: an origin and two direction vectors in weight space.
    self.register_param('origin_param', weight_vector(torch_model_cls().parameters()))
    self.register_param('right_param', weight_vector(torch_model_cls().parameters()))
    self.register_param('up_param', weight_vector(torch_model_cls().parameters()))
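# PlaneEnsemble keeps every member as a 2D coordinate on a trainable plane
# in weight space. A minimal sketch of how a member's flat weight vector
# would be reconstructed (member_weights is an assumed helper, not shown in
# the source):
import torch

def member_weights(self, i: int) -> torch.Tensor:
    # Affine combination of the plane parameters: origin + x*right + y*up.
    x, y = self.coords[i]
    return self.origin_param + x * self.right_param + y * self.up_param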
def log_weight_stats(self):
    if self.config.hp.ensemble_type == 'plane':
        # Weight l2 norms
        self.writer.add_scalar('Stats/norms/origin_param', self.model.origin_param.norm().item(), self.num_iters_done)
        self.writer.add_scalar('Stats/norms/right_param', self.model.right_param.norm().item(), self.num_iters_done)
        self.writer.add_scalar('Stats/norms/up_param', self.model.up_param.norm().item(), self.num_iters_done)

        # Grad norms
        self.writer.add_scalar('Stats/grad_norms/origin_param', self.model.origin_param.grad.norm().item(), self.num_iters_done)
        self.writer.add_scalar('Stats/grad_norms/right_param', self.model.right_param.grad.norm().item(), self.num_iters_done)
        self.writer.add_scalar('Stats/grad_norms/up_param', self.model.up_param.grad.norm().item(), self.num_iters_done)
    elif self.config.hp.ensemble_type == 'mapping':
        mapping_weight_norm = weight_vector(self.model.mapping.parameters()).norm()
        mapping_grad_norm = torch.cat([p.grad.view(-1) for p in self.model.mapping.parameters()]).norm()

        self.writer.add_scalar('Stats/norms/mapping', mapping_weight_norm.item(), self.num_iters_done)
        self.writer.add_scalar('Stats/grad_norms/mapping', mapping_grad_norm.item(), self.num_iters_done)

    # Both plane and mapping ensembles expose per-member 2D coordinates.
    if self.config.hp.ensemble_type in ('mapping', 'plane'):
        self.writer.add_histogram('Coords/x', self.model.coords[:, 0].cpu().detach().numpy(), self.num_iters_done)
        self.writer.add_histogram('Coords/y', self.model.coords[:, 1].cpu().detach().numpy(), self.num_iters_done)
        self.writer.add_scalar('Stats/grad_norms/coords', self.model.coords.grad.norm().item(), self.num_iters_done)
def train_on_batch(self, batch):
    x = batch[0].to(self.device_name)
    y = batch[1].to(self.device_name)

    preds = self.model(x)
    loss = self.criterion(preds, y).sum()
    acc = (preds.argmax(dim=1) == y).float().mean()
    norm = weight_vector(self.model.parameters()).norm()

    self.optim.zero_grad()
    loss.backward()
    self.optim.step()

    self.writer.add_scalar('Train/loss', loss.item(), self.num_iters_done)
    self.writer.add_scalar('Train/acc', acc.item(), self.num_iters_done)
    self.writer.add_scalar('Stats/weights_norm', norm.item(), self.num_iters_done)

    if self.scheduler is not None:
        self.scheduler.step()
        # get_last_lr() is the non-deprecated way to read the current LR
        # after scheduler.step() (get_lr() is deprecated for this use).
        self.writer.add_scalar('Stats/lr', self.scheduler.get_last_lr()[0], self.num_iters_done)
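# `loss = self.criterion(preds, y).sum()` only makes sense if the criterion
# returns per-sample losses. A minimal setup sketch, assuming cross-entropy:
import torch.nn as nn

# reduction='none' yields one loss per sample, so .sum() aggregates over the
# batch (with the default reduction='mean', .sum() would be a no-op).
criterion = nn.CrossEntropyLoss(reduction='none')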
def __init__(self):
    super(LineModel, self).__init__()

    # Both segment endpoints are trainable; param_sizes remembers the
    # per-layer shapes so the flat vectors can be unflattened later.
    self.w_1 = nn.Parameter(weight_vector(SimpleModel().parameters()))
    self.w_2 = nn.Parameter(weight_vector(SimpleModel().parameters()))
    self.param_sizes = param_sizes(SimpleModel().parameters())
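# A sketch of how LineModel is assumed to use these pieces: interpolate a
# point on the segment between w_1 and w_2, then split the flat vector back
# into per-layer tensors via param_sizes (point_on_line and unflatten are
# illustrative names, not from the source):
import torch

def point_on_line(w_1: torch.Tensor, w_2: torch.Tensor, t: float) -> torch.Tensor:
    # Linear interpolation: t=0 gives w_1, t=1 gives w_2.
    return (1.0 - t) * w_1 + t * w_2

def unflatten(w: torch.Tensor, sizes):
    # Split the flat vector into chunks and reshape each to its layer shape.
    chunks = w.split([s.numel() for s in sizes])
    return [chunk.view(size) for chunk, size in zip(chunks, sizes)]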
def __init__(self, torch_model_cls, num_models: int):
    super(NormalEnsemble, self).__init__(torch_model_cls, num_models)

    # Each ensemble member is an independent flat weight vector.
    for i in range(num_models):
        self.register_param(f'model_{i}', weight_vector(torch_model_cls().parameters()))
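# `register_param` is shared by both ensembles but not defined here. A
# minimal sketch of its assumed behavior if the base class is an nn.Module
# (the actual base class may register parameters differently):
import torch
import torch.nn as nn

def register_param(self, name: str, value: torch.Tensor):
    # Wrapping in nn.Parameter and assigning to the module makes the tensor
    # show up in .parameters() and receive gradients.
    setattr(self, name, nn.Parameter(value))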