def __init__(self, *args, **kwargs):
        """Build frozen per-seed-task scores and per-task mixing coefficients.

        Positional and keyword arguments are forwarded unchanged to the parent
        constructor. Configuration is read from the module-level ``pargs``:

        * ``num_seed_tasks_learned`` -- number of score tensors, and the length
          of every alpha vector in ``self.basis_alphas``;
        * ``num_tasks`` -- number of alpha vectors when ``start_at_optimal`` is
          false (presumably >= ``num_seed_tasks_learned``; TODO confirm, since
          otherwise the ``start_at_optimal`` branch yields more vectors than
          tasks);
        * ``train_weight_tasks`` -- ``self.weight`` is frozen when this is 0;
        * ``start_at_optimal`` -- selects the alpha initialisation (see below);
        * ``sparsity`` -- copied to ``self.sparsity``.
        """
        super().__init__(*args, **kwargs)

        num_seeds = pargs.num_seed_tasks_learned

        self.scores = nn.ParameterList([
            nn.Parameter(module_util.mask_init(self))
            for _ in range(num_seeds)
        ])
        # Freeze every score tensor. nn.Module.requires_grad_ recurses into the
        # contained parameters; assigning `.requires_grad` on the ParameterList
        # container itself would only set an unused Python attribute.
        self.scores.requires_grad_(False)
        if pargs.train_weight_tasks == 0:
            self.weight.requires_grad = False

        def uniform_alpha():
            # Equal coefficient 1/num_seeds for every seed task.
            return nn.Parameter(torch.ones(num_seeds) / num_seeds)

        if pargs.start_at_optimal:
            # Seed task i starts as the one-hot vector e_i; every later task
            # starts uniform. `.clone()` gives each row its own storage, so no
            # parameter is a view that keeps a full num_seeds x num_seeds
            # identity matrix alive (torch.save serializes whole storages).
            identity = torch.eye(num_seeds)
            alphas = [nn.Parameter(identity[i].clone()) for i in range(num_seeds)]
            alphas += [uniform_alpha() for _ in range(num_seeds, pargs.num_tasks)]
        else:
            alphas = [uniform_alpha() for _ in range(pargs.num_tasks)]
        self.basis_alphas = nn.ParameterList(alphas)
        self.sparsity = pargs.sparsity
Beispiel #2
0
    def __init__(self, *args, **kwargs):
        """Attach one score tensor to the layer and freeze its weight.

        All arguments are passed straight through to the parent constructor.
        ``self.scores`` is a single ``nn.Parameter`` whose initial value comes
        from ``module_util.mask_init(self)``; ``self.weight`` is always frozen.
        """
        super().__init__(*args, **kwargs)

        initial_scores = module_util.mask_init(self)
        self.scores = nn.Parameter(initial_scores)

        # Weight never receives gradients (presumably only the scores train).
        self.weight.requires_grad_(False)
    def __init__(self, *args, **kwargs):
        """Attach one score tensor per task and optionally freeze the weight.

        All arguments are passed straight through to the parent constructor.
        ``self.scores`` is an ``nn.ParameterList`` of ``pargs.num_tasks``
        independently initialised tensors (``module_util.mask_init(self)``).
        ``self.weight`` is frozen only when ``pargs.train_weight_tasks == 0``.
        """
        super().__init__(*args, **kwargs)

        per_task_scores = []
        for _ in range(pargs.num_tasks):
            per_task_scores.append(nn.Parameter(module_util.mask_init(self)))
        self.scores = nn.ParameterList(per_task_scores)

        freeze_weight = pargs.train_weight_tasks == 0
        if freeze_weight:
            self.weight.requires_grad_(False)
Beispiel #4
0
    def __init__(self, *args, **kwargs):
        """Attach a dict of score tensors, one per entry of ``pargs.set``.

        All arguments are passed straight through to the parent constructor.
        A single ``module_util.mask_init(self)`` draw is made, and every entry
        of ``self.scores`` (an ``nn.ParameterDict``) starts as an independent
        copy of it. An extra ``'INIT'`` entry holds another copy; if
        ``pargs.set`` already contains ``'INIT'``, that entry is replaced by
        this one. ``self.weight`` is always frozen.

        NOTE(review): ParameterDict keys must be strings, so this assumes the
        entries of ``pargs.set`` are str -- confirm against the caller.
        """
        super().__init__(*args, **kwargs)

        initial_mask = module_util.mask_init(self)
        # Initialize the scores. `.clone()` so each entry owns its storage.
        # `set_name` avoids shadowing the builtin `set`.
        scores = {
            set_name: nn.Parameter(initial_mask.clone()) for set_name in pargs.set
        }
        # Presumably a snapshot of the starting scores; TODO confirm usage.
        scores['INIT'] = nn.Parameter(initial_mask.clone())
        self.scores = nn.ParameterDict(scores)

        # Turn the gradient on the weights off
        self.weight.requires_grad = False
    def __init__(self, *args, **kwargs):
        """Attach one score tensor, optionally freeze the weight, store sparsity.

        All arguments are passed straight through to the parent constructor.
        ``self.scores`` is initialised from ``module_util.mask_init(self)``.
        ``self.weight`` is frozen only when ``pargs.train_weight_tasks == 0``.
        ``self.sparsity`` is copied from ``pargs.sparsity`` as the default.
        """
        super().__init__(*args, **kwargs)

        initial_scores = module_util.mask_init(self)
        self.scores = nn.Parameter(initial_scores)

        # The weight is frozen conditionally, not unconditionally.
        freeze_weight = pargs.train_weight_tasks == 0
        if freeze_weight:
            self.weight.requires_grad_(False)

        self.sparsity = pargs.sparsity  # default sparsity