示例#1
0
 def __init__(self, train_x, train_y, likelihood):
     """Set up a zero-mean GP whose covariance is an additive grid-interpolated kernel.

     The scaled ARD RBF kernel is kept on ``self.base_covar_module``; it is
     wrapped in a ``GridInterpolationKernel`` (``grid_size=100``,
     ``num_dims=1``) which in turn is wrapped in an
     ``AdditiveStructureKernel`` over two dimensions.
     """
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ZeroMean()

     rbf = RBFKernel(ard_num_dims=2)
     self.base_covar_module = ScaleKernel(rbf)

     interpolated = GridInterpolationKernel(
         self.base_covar_module,
         grid_size=100,
         num_dims=1,
     )
     self.covar_module = AdditiveStructureKernel(interpolated, num_dims=2)
示例#2
0
 def __init__(self, train_x, train_y, likelihood):
     """Set up a zero-mean GP with a prior-constrained additive grid kernel.

     The RBF kernel receives a ``SmoothedBoxPrior`` between ``exp(-3)`` and
     ``exp(3)`` through its ``log_lengthscale_prior`` argument. The scaled
     kernel is kept on ``self.base_covar_module`` and wrapped first in a
     ``GridInterpolationKernel`` (``grid_size=100``, ``num_dims=2``) and then
     in an ``AdditiveStructureKernel`` over two dimensions.
     """
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ZeroMean()

     box_prior = SmoothedBoxPrior(exp(-3), exp(3), sigma=0.1, log_transform=True)
     rbf = RBFKernel(ard_num_dims=2, log_lengthscale_prior=box_prior)
     self.base_covar_module = ScaleKernel(rbf)

     interpolated = GridInterpolationKernel(
         self.base_covar_module,
         grid_size=100,
         num_dims=2,
     )
     self.covar_module = AdditiveStructureKernel(interpolated, num_dims=2)
    def test_forward(self):
        """MemoryEfficientGamKernel must match an additive scaled-RBF kernel.

        Both kernels are evaluated on the same 2x3 input and the resulting
        matrices are compared with an absolute tolerance of 1e-6.
        """
        efficient_kernel = MemoryEfficientGamKernel()
        inputs = torch.tensor([[1., 2., 3.], [1.1, 2.2, 3.3]])
        actual = efficient_kernel(inputs, inputs).evaluate().detach().numpy()

        # Reference: scaled RBF (outputscale fixed to 1) under additive structure.
        scaled_rbf = ScaleKernel(RBFKernel())
        scaled_rbf.initialize(outputscale=1.)
        reference_kernel = AdditiveStructureKernel(scaled_rbf, 2)
        expected = reference_kernel(inputs, inputs).evaluate().detach().numpy()

        np.testing.assert_allclose(actual, expected, atol=1e-6)
    def test_postscale(self):
        """ScaledProjectionKernel with ``prescale=False`` equals 3x a 1-D RBF.

        The projection is the identity and the kernel lengthscales are
        ``[1, 2, 3]``. Column ``j`` of the input is ``(j + 1)`` times column 0,
        so presumably each rescaled dimension coincides with column 0 and the
        additive kernel sums three identical RBF terms -- the assertion below
        checks exactly that equality.
        """
        inputs = torch.tensor([[1., 2., 3.], [1.1, 2.2, 3.3]])

        unit_rbf = RBFKernel()
        unit_rbf.initialize(lengthscale=torch.tensor([1.]))
        additive = AdditiveStructureKernel(unit_rbf, 3)

        identity_proj = torch.nn.Linear(3, 3, bias=False)
        identity_proj.weight.data = torch.eye(3, dtype=torch.float)

        projected = ScaledProjectionKernel(
            identity_proj,
            additive,
            prescale=False,
            ard_num_dims=3,
        )
        projected.initialize(lengthscale=torch.tensor([1., 2., 3.]))

        reference = RBFKernel()
        reference.initialize(lengthscale=torch.tensor([1.]))
        first_column = inputs[:, 0:1]

        with torch.no_grad():
            actual = projected(inputs, inputs).evaluate()
            expected = 3 * reference(first_column, first_column).evaluate()

        np.testing.assert_allclose(actual.numpy(), expected.numpy())
示例#5
0
    def sample_arch(self, START_BO, g, steps, hyperparams, og_flops, full_val_loss, target_flops=0):
        """Choose the next parameterization and per-layer budget to evaluate.

        Depending on ``args.slim`` and on how ``g`` compares with ``START_BO``,
        the parameterization is drawn at random, set to a uniform value, or
        obtained by optimizing an acquisition function that combines two
        Gaussian processes (one fitted on the loss column of ``self.Y``, one on
        the cost column).

        Reads the module-level ``args`` and ``writer`` objects and the
        attributes ``self.X``, ``self.Y``, ``self.population_data``,
        ``self.mask_pruner`` and ``self.sampling_weights``. The GP branch
        assigns ``self.mobo_obj`` and ``self.mobo_bounds``, overwrites
        ``self.sampling_weights`` when ``args.pas`` is set, and moves tensors
        and modules to CUDA.

        Args:
            START_BO: Threshold on ``g``. Below it a uniform random width is
                used, at it an all-ones parameterization is used, above it the
                GP-based search runs.
            g: Current iteration counter compared against ``START_BO``.
            steps: Step value passed to ``writer.add_scalar``.
            hyperparams: Object providing ``random_sample()``, ``get_dim()``
                and ``get_layer_budget_from_parameterization()``.
            og_flops: Divisor turning simulated FLOPs into a ratio.
            full_val_loss: Loss value attached to the ``args.upper_flops``
                anchor appended to the Pareto front.
            target_flops: ``0`` means "no fixed target". Any other value forces
                ``args.lower_channel`` in the non-GP branches and is used as
                the FLOPs-ratio target of the binary search.

        Returns:
            Tuple ``(layer_budget, parameterization, weights)`` where
            ``weights`` is ``self.sampling_weights`` divided by its sum.

        Note:
            ``self.sampling_weights`` is only assigned inside the ``args.pas``
            branch, so it must already exist for every other path because the
            return statement always reads it.
        """
        if args.slim:
            if target_flops == 0:
                parameterization = hyperparams.random_sample()
                layer_budget = hyperparams.get_layer_budget_from_parameterization(parameterization, self.mask_pruner)
            else:
                # A fixed target was requested: use the lower channel bound everywhere.
                parameterization = np.ones(hyperparams.get_dim()) * args.lower_channel
                layer_budget = hyperparams.get_layer_budget_from_parameterization(parameterization, self.mask_pruner)
        else:
            # random sample to warmup history for MOBO
            if g < START_BO:
                if target_flops == 0:
                    # One scalar drawn uniformly in [lower_channel, upper_channel),
                    # shared by every dimension.
                    f = np.random.rand(1) * (args.upper_channel-args.lower_channel) + args.lower_channel
                else:
                    f = args.lower_channel
                parameterization = np.ones(hyperparams.get_dim()) * f
                layer_budget = hyperparams.get_layer_budget_from_parameterization(parameterization, self.mask_pruner)
            # put the largest model into the history
            elif g == START_BO:
                if target_flops == 0:
                    parameterization = np.ones(hyperparams.get_dim())
                else:
                    f = args.lower_channel
                    parameterization = np.ones(hyperparams.get_dim()) * f
                layer_budget = hyperparams.get_layer_budget_from_parameterization(parameterization, self.mask_pruner)
            # MOBO
            else:
                # this is the scalarization (lambda_{FLOPs})
                rand = torch.rand(1).cuda()

                # standardize data for building Gaussian Processes
                # Column 0 of self.Y feeds the loss GP, column 1 the cost GP.
                train_X = torch.FloatTensor(self.X).cuda()
                train_Y_loss = torch.FloatTensor(np.array(self.Y)[:, 0].reshape(-1, 1)).cuda()
                train_Y_loss = standardize(train_Y_loss)

                train_Y_cost = torch.FloatTensor(np.array(self.Y)[:, 1].reshape(-1, 1)).cuda()
                train_Y_cost = standardize(train_Y_cost)

                new_train_X = train_X
                # GP for the cross entropy loss
                gp_loss = SingleTaskGP(new_train_X, train_Y_loss)
                mll = ExactMarginalLogLikelihood(gp_loss.likelihood, gp_loss)
                mll = mll.to('cuda')
                fit_gpytorch_model(mll)


                # GP for FLOPs
                # we use add-gp since FLOPs has additive structure (not exactly though)
                # the parameters for ScaleKernel and MaternKernel simply follow the default
                covar_module = AdditiveStructureKernel(
                    ScaleKernel(
                        MaternKernel(
                            nu=2.5,
                            lengthscale_prior=GammaPrior(3.0, 6.0),
                            num_dims=1
                        ),
                        outputscale_prior=GammaPrior(2.0, 0.15),
                    ),
                    num_dims=train_X.shape[1]
                )
                gp_cost = SingleTaskGP(new_train_X, train_Y_cost, covar_module=covar_module)
                mll = ExactMarginalLogLikelihood(gp_cost.likelihood, gp_cost)
                mll = mll.to('cuda')
                fit_gpytorch_model(mll)

                # Build acquisition functions
                UCB_loss = UpperConfidenceBound(gp_loss, beta=0.1).cuda()
                UCB_cost = UpperConfidenceBound(gp_cost, beta=0.1).cuda()

                # Combine them via augmented Tchebyshev scalarization
                self.mobo_obj = RandAcquisition(UCB_loss).cuda()
                self.mobo_obj.setup(UCB_loss, UCB_cost, rand)

                # Bounds for the optimization variable (alpha)
                lower = torch.ones(new_train_X.shape[1])*args.lower_channel
                upper = torch.ones(new_train_X.shape[1])*args.upper_channel
                self.mobo_bounds = torch.stack([lower, upper]).cuda()

                # Pareto-aware sampling
                if args.pas:
                    # Generate approximate Pareto front first
                    costs = []
                    for i in range(len(self.population_data)):
                        costs.append([self.population_data[i]['loss'], self.population_data[i]['ratio']])
                    costs = np.array(costs)
                    efficient_mask = is_pareto_efficient(costs)
                    costs = costs[efficient_mask]
                    loss = costs[:, 0]
                    flops = costs[:, 1]
                    # Order the front by increasing FLOPs ratio.
                    sorted_idx = np.argsort(flops)
                    loss = loss[sorted_idx]
                    flops = flops[sorted_idx]
                    # Prepend an anchor at args.lower_flops when the front starts above it.
                    # NOTE(review): the paired loss value 8 is hard-coded; presumably a
                    # stand-in for the loss of the smallest network -- confirm.
                    if flops[0] > args.lower_flops:
                        flops = np.concatenate([[args.lower_flops], flops.reshape(-1)])
                        loss = np.concatenate([[8], loss.reshape(-1)])
                    else:
                        flops = flops.reshape(-1)
                        loss = loss.reshape(-1)

                    # Append an anchor at args.upper_flops with full_val_loss when the
                    # front ends below it and its last loss is still above full_val_loss.
                    if flops[-1] < args.upper_flops and (loss[-1] > full_val_loss):
                        flops = np.concatenate([flops.reshape(-1), [args.upper_flops]])
                        loss = np.concatenate([loss.reshape(-1), [full_val_loss]])
                    else:
                        flops = flops.reshape(-1)
                        loss = loss.reshape(-1)

                    # Equation (4) in paper
                    # Per segment: (FLOPs gap) * (loss drop) between consecutive front points.
                    areas = (flops[1:]-flops[:-1])*(loss[:-1]-loss[1:])

                    # Quantize into 50 bins to sample from multinomial
                    self.sampling_weights = np.zeros(50)
                    # k indexes the current front segment; skip points below lower_flops.
                    k = 0
                    while k < len(flops) and flops[k] < args.lower_flops:
                        k+=1
                    for i in range(50):
                        # Bin i covers [i/50, (i+1)/50); these names shadow the tensors
                        # above, which were already consumed by self.mobo_bounds.
                        lower = i/50.
                        upper = (i+1)/50.
                        # Bins starting below lower_flops or above upper_flops keep weight 0.
                        if upper < args.lower_flops or lower > args.upper_flops or lower < args.lower_flops:
                            continue
                        cnt = 1
                        # Add every segment whose right end lies below this bin's upper
                        # edge, advancing k; k carries over to the next bin.
                        while ((k+1) < len(flops)) and upper > flops[k+1]:
                            self.sampling_weights[i] += areas[k]
                            cnt += 1
                            k += 1
                        # Also add the segment that reaches past the bin's upper edge.
                        if k < len(areas):
                            self.sampling_weights[i] += areas[k]
                        # Average over the number of segments accumulated.
                        self.sampling_weights[i] /= cnt
                    # Fall back to uniform weights when every bin is zero.
                    if np.sum(self.sampling_weights) == 0:
                        self.sampling_weights = np.ones(50)
                        
                    if target_flops == 0:
                        # Candidate targets are the 50 bin centres 0.01, 0.03, ..., 0.99.
                        val = np.arange(0.01, 1, 0.02)
                        chosen_target_flops = np.random.choice(val, p=(self.sampling_weights/np.sum(self.sampling_weights)))
                    else:
                        chosen_target_flops = target_flops
                    
                    # Binary search is here
                    # Search lmda in [0, 1] (at most 10 acquisition optimizations) until the
                    # simulated FLOPs ratio is within 0.02 of the chosen target.
                    lower_bnd, upper_bnd = 0, 1
                    lmda = 0.5
                    for i in range(10):
                        self.mobo_obj.rand = lmda

                        parameterization, acq_value = optimize_acqf(
                            self.mobo_obj, bounds=self.mobo_bounds, q=1, num_restarts=5, raw_samples=1000,
                        )

                        parameterization = parameterization[0].cpu().numpy()
                        layer_budget = hyperparams.get_layer_budget_from_parameterization(parameterization, self.mask_pruner)
                        sim_flops = self.mask_pruner.simulate_and_count_flops(layer_budget)
                        ratio = sim_flops/og_flops

                        if np.abs(ratio - chosen_target_flops) <= 0.02:
                            break
                        # The update direction flips with args.baseline: when it is > 0 a
                        # ratio below target raises lmda, otherwise it lowers lmda.
                        if args.baseline > 0:
                            if ratio < chosen_target_flops:
                                lower_bnd = lmda
                                lmda = (lmda + upper_bnd) / 2
                            elif ratio > chosen_target_flops:
                                upper_bnd = lmda
                                lmda = (lmda + lower_bnd) / 2
                        else:
                            if ratio < chosen_target_flops:
                                upper_bnd = lmda
                                lmda = (lmda + lower_bnd) / 2
                            elif ratio > chosen_target_flops:
                                lower_bnd = lmda
                                lmda = (lmda + upper_bnd) / 2
                    # If all 10 iterations ran without a break, lmda here has been updated
                    # once more than the value that produced `parameterization`.
                    rand[0] = lmda
                    # Log the loop index at which the search stopped.
                    writer.add_scalar('Binary search trials', i, steps)

                else:
                    # No Pareto-aware sampling: optimize once with the random scalarization.
                    parameterization, acq_value = optimize_acqf(
                        self.mobo_obj, bounds=self.mobo_bounds, q=1, num_restarts=5, raw_samples=1000,
                    )
                    parameterization = parameterization[0].cpu().numpy()

                # Recomputed for both MOBO paths from the final parameterization.
                layer_budget = hyperparams.get_layer_budget_from_parameterization(parameterization, self.mask_pruner)
        return layer_budget, parameterization, self.sampling_weights/np.sum(self.sampling_weights)
示例#6
0
    def sample_arch(self, START_BO, g, hyperparams, og_flops, empty_val_loss, full_val_loss, target_flops=0):
        """Choose the next parameterization and per-layer budget to evaluate.

        For ``g < START_BO`` a single uniform width is used, for
        ``g == START_BO`` an all-ones parameterization is used, and afterwards
        two Gaussian processes (loss and cost columns of ``self.Y``) are fitted
        and an acquisition function combining them is optimized.

        Reads the module-level ``args`` and ``writer`` objects and the
        attributes ``self.X``, ``self.Y``, ``self.mask_pruner``,
        ``self.use_mem`` and ``self.sampling_weights``. The GP branch assigns
        ``self.mobo_obj`` and ``self.mobo_bounds`` and moves tensors and
        modules to CUDA. ``self.sampling_weights`` is never assigned here, so
        it must already be set by the caller.

        Args:
            START_BO: Threshold on ``g`` separating warm-up from the GP search.
            g: Current iteration counter; also the step passed to
                ``writer.add_scalar`` and compared with 128 to enable the
                grid-interpolated kernel when ``args.ski`` is set.
            hyperparams: Object providing ``get_dim()`` and
                ``get_layer_budget_from_parameterization()``.
            og_flops: Divisor turning simulated FLOPs into a ratio.
            empty_val_loss: Not referenced in this body.
            full_val_loss: Not referenced in this body.
            target_flops: ``0`` means "no fixed target"; any other value forces
                ``args.lower_channel`` in the warm-up branches. It is not
                consulted in the GP branch.

        Returns:
            Tuple ``(layer_budget, parameterization, weights)`` where
            ``weights`` is ``self.sampling_weights`` divided by its sum.
        """
        if g < START_BO:
            if target_flops == 0:
                # One scalar drawn uniformly in [lower_channel, upper_channel),
                # shared by every dimension.
                f = np.random.rand(1) * (args.upper_channel-args.lower_channel) + args.lower_channel
            else:
                f = args.lower_channel
            parameterization = np.ones(hyperparams.get_dim()) * f
            layer_budget = hyperparams.get_layer_budget_from_parameterization(parameterization, self.mask_pruner)
        elif g == START_BO:
            if target_flops == 0:
                parameterization = np.ones(hyperparams.get_dim())
            else:
                f = args.lower_channel
                parameterization = np.ones(hyperparams.get_dim()) * f
            layer_budget = hyperparams.get_layer_budget_from_parameterization(parameterization, self.mask_pruner)
        else:
            # Random scalarization weight handed to the acquisition function.
            rand = torch.rand(1).cuda()

            # Column 0 of self.Y feeds the loss GP, column 1 the cost GP;
            # both targets are standardized.
            train_X = torch.FloatTensor(self.X).cuda()
            train_Y_loss = torch.FloatTensor(np.array(self.Y)[:, 0].reshape(-1, 1)).cuda()
            train_Y_loss = standardize(train_Y_loss)

            train_Y_cost = torch.FloatTensor(np.array(self.Y)[:, 1].reshape(-1, 1)).cuda()
            train_Y_cost = standardize(train_Y_cost)

            # Kernel for the loss GP, selected by args.ski (only once g > 128)
            # and args.additive.
            covar_module = None
            if args.ski and g > 128:
                if args.additive:
                    # Additive structure over a 1-D grid-interpolated Matern kernel.
                    covar_module = AdditiveStructureKernel(
                        ScaleKernel(
                            GridInterpolationKernel(
                                MaternKernel(
                                    nu=2.5,
                                    lengthscale_prior=GammaPrior(3.0, 6.0),
                                ),
                                grid_size=128, num_dims=1, grid_bounds=[(0, 1)]
                            ),
                            outputscale_prior=GammaPrior(2.0, 0.15),
                        ), 
                        num_dims=train_X.shape[1]
                    )
                else:
                    # Grid-interpolated Matern kernel over all input dimensions jointly.
                    covar_module = ScaleKernel(
                        GridInterpolationKernel(
                            MaternKernel(
                                nu=2.5,
                                lengthscale_prior=GammaPrior(3.0, 6.0),
                            ),
                            grid_size=128, num_dims=train_X.shape[1], grid_bounds=[(0, 1) for _ in range(train_X.shape[1])]
                        ),
                        outputscale_prior=GammaPrior(2.0, 0.15),
                    )
            else:
                if args.additive:
                    covar_module = AdditiveStructureKernel(
                        ScaleKernel(
                            MaternKernel(
                                nu=2.5,
                                lengthscale_prior=GammaPrior(3.0, 6.0),
                                num_dims=1
                            ),
                            outputscale_prior=GammaPrior(2.0, 0.15),
                        ),
                        num_dims=train_X.shape[1]
                    )
                else:
                    covar_module = ScaleKernel(
                        MaternKernel(
                            nu=2.5,
                            lengthscale_prior=GammaPrior(3.0, 6.0),
                            num_dims=train_X.shape[1]
                        ),
                        outputscale_prior=GammaPrior(2.0, 0.15),
                    )

            new_train_X = train_X
            # GP for the loss column, using the kernel chosen above.
            gp_loss = SingleTaskGP(new_train_X, train_Y_loss, covar_module=covar_module)
            mll = ExactMarginalLogLikelihood(gp_loss.likelihood, gp_loss)
            mll = mll.to('cuda')
            fit_gpytorch_model(mll)


            # Use add-gp for cost
            covar_module = AdditiveStructureKernel(
                ScaleKernel(
                    MaternKernel(
                        nu=2.5,
                        lengthscale_prior=GammaPrior(3.0, 6.0),
                        num_dims=1
                    ),
                    outputscale_prior=GammaPrior(2.0, 0.15),
                ),
                num_dims=train_X.shape[1]
            )
            gp_cost = SingleTaskGP(new_train_X, train_Y_cost, covar_module=covar_module)
            mll = ExactMarginalLogLikelihood(gp_cost.likelihood, gp_cost)
            mll = mll.to('cuda')
            fit_gpytorch_model(mll)

            # One UCB acquisition per objective, combined by RandAcquisition.
            UCB_loss = UpperConfidenceBound(gp_loss, beta=args.beta).cuda()
            UCB_cost = UpperConfidenceBound(gp_cost, beta=args.beta).cuda()
            self.mobo_obj = RandAcquisition(UCB_loss).cuda()
            self.mobo_obj.setup(UCB_loss, UCB_cost, rand)

            # Box bounds [lower_channel, upper_channel] for every input dimension.
            lower = torch.ones(new_train_X.shape[1])*args.lower_channel
            upper = torch.ones(new_train_X.shape[1])*args.upper_channel
            self.mobo_bounds = torch.stack([lower, upper]).cuda()

            if args.pas:
                # Draw the target ratio from 50 evenly spaced values in
                # [args.lower_flops, 1], weighted by the existing self.sampling_weights.
                val = np.linspace(args.lower_flops, 1, 50)
                chosen_target_flops = np.random.choice(val, p=(self.sampling_weights/np.sum(self.sampling_weights)))
                
                # Binary search over lmda in [0, 1] (at most 10 acquisition
                # optimizations) until the simulated ratio is within 0.02 of the target.
                lower_bnd, upper_bnd = 0, 1
                lmda = 0.5
                for i in range(10):
                    self.mobo_obj.rand = lmda

                    parameterization, acq_value = optimize_acqf(
                        self.mobo_obj, bounds=self.mobo_bounds, q=1, num_restarts=5, raw_samples=1000,
                    )

                    parameterization = parameterization[0].cpu().numpy()
                    layer_budget = hyperparams.get_layer_budget_from_parameterization(parameterization, self.mask_pruner)
                    sim_flops = self.mask_pruner.simulate_and_count_flops(layer_budget, self.use_mem)
                    ratio = sim_flops/og_flops

                    if np.abs(ratio - chosen_target_flops) <= 0.02:
                        break
                    # The update direction flips with args.baseline: when it is > 0 a
                    # ratio below target raises lmda, otherwise it lowers lmda.
                    if args.baseline > 0:
                        if ratio < chosen_target_flops:
                            lower_bnd = lmda
                            lmda = (lmda + upper_bnd) / 2
                        elif ratio > chosen_target_flops:
                            upper_bnd = lmda
                            lmda = (lmda + lower_bnd) / 2
                    else:
                        if ratio < chosen_target_flops:
                            upper_bnd = lmda
                            lmda = (lmda + lower_bnd) / 2
                        elif ratio > chosen_target_flops:
                            lower_bnd = lmda
                            lmda = (lmda + upper_bnd) / 2
                # If all 10 iterations ran without a break, lmda here has been updated
                # once more than the value that produced `parameterization`.
                rand[0] = lmda
                # Log the loop index at which the search stopped.
                writer.add_scalar('Binary search trials', i, g)

            else:
                # No Pareto-aware sampling: optimize once with the random scalarization.
                parameterization, acq_value = optimize_acqf(
                    self.mobo_obj, bounds=self.mobo_bounds, q=1, num_restarts=5, raw_samples=1000,
                )
                parameterization = parameterization[0].cpu().numpy()

            # Recomputed for both GP paths from the final parameterization.
            layer_budget = hyperparams.get_layer_budget_from_parameterization(parameterization, self.mask_pruner)
        return layer_budget, parameterization, self.sampling_weights/np.sum(self.sampling_weights)
    def test_gradients(self):
        """One Adam step must update only the parameters meant to be learned.

        Without ``learn_proj`` the inner RBF lengthscale and the projection
        weight stay at their initial values while the outer lengthscale moves.
        With ``learn_proj=True`` the projection weight moves as well.
        """
        x = torch.tensor([[1., 2., 3.], [1.1, 2.2, 3.3]])
        y = torch.sin(x).sum(dim=1)

        def take_one_adam_step(kernel):
            # Fit an exact GP with `kernel` and apply a single optimizer update.
            gp = ExactGPModel(x, y, gpytorch.likelihoods.GaussianLikelihood(),
                              kernel)
            objective = gpytorch.mlls.ExactMarginalLogLikelihood(
                gp.likelihood, gp)
            adam = torch.optim.Adam(gp.parameters(), lr=0.1)
            adam.zero_grad()

            negative_mll = -objective(gp(x), y)
            negative_mll.backward()

            adam.step()

        initial_lengthscale = torch.tensor([1., 2., 3.])
        identity = torch.eye(3, dtype=torch.float)

        inner_rbf = RBFKernel()
        inner_rbf.initialize(lengthscale=torch.tensor([1.]))
        base_kernel = AdditiveStructureKernel(inner_rbf, 3)

        # Case 1: projection is fixed.
        fixed_proj = torch.nn.Linear(3, 3, bias=False)
        fixed_proj.weight.data = torch.eye(3, dtype=torch.float)
        fixed_kernel = ScaledProjectionKernel(fixed_proj,
                                              base_kernel,
                                              prescale=True,
                                              ard_num_dims=3)
        fixed_kernel.initialize(lengthscale=torch.tensor([1., 2., 3.]))

        take_one_adam_step(fixed_kernel)

        np.testing.assert_allclose(
            fixed_kernel.base_kernel.base_kernel.lengthscale.numpy(),
            torch.tensor([[1.]]).numpy())
        np.testing.assert_allclose(
            fixed_kernel.projection_module.weight.numpy(),
            identity.numpy())
        lengthscale_unchanged = np.allclose(
            fixed_kernel.lengthscale.detach().numpy(),
            initial_lengthscale.numpy())
        self.assertFalse(lengthscale_unchanged)

        # Case 2: projection is learned; the same additive base kernel is reused.
        learned_proj = torch.nn.Linear(3, 3, bias=False)
        learned_proj.weight.data = torch.eye(3, dtype=torch.float)
        learned_kernel = ScaledProjectionKernel(learned_proj,
                                                base_kernel,
                                                prescale=True,
                                                ard_num_dims=3,
                                                learn_proj=True)
        learned_kernel.initialize(lengthscale=torch.tensor([1., 2., 3.]))

        take_one_adam_step(learned_kernel)

        np.testing.assert_allclose(
            learned_kernel.base_kernel.base_kernel.lengthscale.numpy(),
            torch.tensor([[1.]]).numpy())
        weight_unchanged = np.allclose(
            learned_kernel.projection_module.weight.detach().numpy(),
            identity.numpy())
        self.assertFalse(weight_unchanged)
        lengthscale_unchanged = np.allclose(
            learned_kernel.lengthscale.detach().numpy(),
            initial_lengthscale.numpy())
        self.assertFalse(lengthscale_unchanged)
示例#8
0
    def sample_arch(self,
                    START_BO,
                    g,
                    hyperparams,
                    og_flops,
                    empty_val_loss,
                    full_val_loss,
                    target_flops=0):
        """Choose the next parameterization and per-layer budget to evaluate.

        For ``g < START_BO`` a single uniform width is used, for
        ``g == START_BO`` an all-ones parameterization is used, and afterwards
        two Gaussian processes (loss and cost columns of ``self.Y``) are fitted
        and an acquisition function combining them is optimized.

        Reads the module-level ``args`` and ``writer`` objects and the
        attributes ``self.X``, ``self.Y``, ``self.population_data``,
        ``self.mask_pruner`` and ``self.sampling_weights``. The GP branch
        assigns ``self.mobo_obj`` and ``self.mobo_bounds``, overwrites
        ``self.sampling_weights`` when ``args.pas`` is set, and moves tensors
        and modules to CUDA.

        Args:
            START_BO: Threshold on ``g`` separating warm-up from the GP search.
            g: Current iteration counter; also the step passed to
                ``writer.add_scalar``.
            hyperparams: Object providing ``get_dim()`` and
                ``get_layer_budget_from_parameterization()``.
            og_flops: Divisor turning simulated FLOPs into a ratio.
            empty_val_loss: Loss value attached to the ``args.lower_flops``
                anchor prepended to the Pareto front.
            full_val_loss: Loss value attached to the ``args.upper_flops``
                anchor appended to the Pareto front.
            target_flops: ``0`` means "no fixed target". Any other value forces
                ``args.lower_channel`` in the warm-up branches and is used as
                the FLOPs-ratio target of the binary search.

        Returns:
            Tuple ``(layer_budget, parameterization, weights)`` where
            ``weights`` is ``self.sampling_weights`` divided by its sum.

        Note:
            ``self.sampling_weights`` is only assigned inside the ``args.pas``
            branch, so it must already exist for every other path because the
            return statement always reads it.
        """
        # Warming up the history with a single width-multiplier
        if g < START_BO:
            if target_flops == 0:
                # One scalar drawn uniformly in [lower_channel, upper_channel).
                f = np.random.rand(1) * (args.upper_channel - args.
                                         lower_channel) + args.lower_channel
            else:
                f = args.lower_channel
            parameterization = np.ones(hyperparams.get_dim()) * f
            layer_budget = hyperparams.get_layer_budget_from_parameterization(
                parameterization, self.mask_pruner)
        # Put largest model into the history
        elif g == START_BO:
            if target_flops == 0:
                parameterization = np.ones(hyperparams.get_dim())
            else:
                f = args.lower_channel
                parameterization = np.ones(hyperparams.get_dim()) * f
            layer_budget = hyperparams.get_layer_budget_from_parameterization(
                parameterization, self.mask_pruner)
        # MOBO-RS
        else:
            # Random scalarization weight handed to the acquisition function.
            rand = torch.rand(1).cuda()

            # Column 0 of self.Y feeds the loss GP, column 1 the cost GP;
            # both targets are standardized.
            train_X = torch.FloatTensor(self.X).cuda()
            train_Y_loss = torch.FloatTensor(
                np.array(self.Y)[:, 0].reshape(-1, 1)).cuda()
            train_Y_loss = standardize(train_Y_loss)

            train_Y_cost = torch.FloatTensor(
                np.array(self.Y)[:, 1].reshape(-1, 1)).cuda()
            train_Y_cost = standardize(train_Y_cost)

            new_train_X = train_X
            # GP for the loss column, with SingleTaskGP's default kernel.
            gp_loss = SingleTaskGP(new_train_X, train_Y_loss)
            mll = ExactMarginalLogLikelihood(gp_loss.likelihood, gp_loss)
            mll = mll.to('cuda')
            fit_gpytorch_model(mll)

            # Use add-gp for cost
            covar_module = AdditiveStructureKernel(ScaleKernel(
                MaternKernel(nu=2.5,
                             lengthscale_prior=GammaPrior(3.0, 6.0),
                             num_dims=1),
                outputscale_prior=GammaPrior(2.0, 0.15),
            ),
                                                   num_dims=train_X.shape[1])
            gp_cost = SingleTaskGP(new_train_X,
                                   train_Y_cost,
                                   covar_module=covar_module)
            mll = ExactMarginalLogLikelihood(gp_cost.likelihood, gp_cost)
            mll = mll.to('cuda')
            fit_gpytorch_model(mll)

            # NOTE(review): no `beta` is passed here. If this is BoTorch's
            # UpperConfidenceBound, `beta` is a required argument -- confirm
            # which class is imported.
            UCB_loss = UpperConfidenceBound(gp_loss).cuda()
            UCB_cost = UpperConfidenceBound(gp_cost).cuda()
            self.mobo_obj = RandAcquisition(UCB_loss).cuda()
            self.mobo_obj.setup(UCB_loss, UCB_cost, rand)

            # Box bounds [lower_channel, upper_channel] for every input dimension.
            lower = torch.ones(new_train_X.shape[1]) * args.lower_channel
            upper = torch.ones(new_train_X.shape[1]) * args.upper_channel
            self.mobo_bounds = torch.stack([lower, upper]).cuda()

            if args.pas:
                # Build the (loss, ratio) Pareto front of the evaluated population.
                costs = []
                for i in range(len(self.population_data)):
                    costs.append([
                        self.population_data[i]['loss'],
                        self.population_data[i]['ratio']
                    ])
                costs = np.array(costs)
                efficient_mask = is_pareto_efficient(costs)
                costs = costs[efficient_mask]
                loss = costs[:, 0]
                flops = costs[:, 1]
                # Order the front by increasing FLOPs ratio.
                sorted_idx = np.argsort(flops)
                loss = loss[sorted_idx]
                flops = flops[sorted_idx]
                # Prepend an anchor (lower_flops, empty_val_loss) when the front
                # starts above args.lower_flops.
                if flops[0] > args.lower_flops:
                    flops = np.concatenate([[args.lower_flops],
                                            flops.reshape(-1)])
                    loss = np.concatenate([[empty_val_loss], loss.reshape(-1)])
                else:
                    flops = flops.reshape(-1)
                    loss = loss.reshape(-1)

                # Append an anchor (upper_flops, full_val_loss) when the front ends
                # below args.upper_flops and its last loss is above full_val_loss.
                if flops[-1] < args.upper_flops and (loss[-1] > full_val_loss):
                    flops = np.concatenate(
                        [flops.reshape(-1), [args.upper_flops]])
                    loss = np.concatenate([loss.reshape(-1), [full_val_loss]])
                else:
                    flops = flops.reshape(-1)
                    loss = loss.reshape(-1)

                # Per segment: (FLOPs gap) * (loss drop) between consecutive points.
                areas = (flops[1:] - flops[:-1]) * (loss[:-1] - loss[1:])

                # Spread the segment areas over 50 bins of width 0.02.
                self.sampling_weights = np.zeros(50)
                # k indexes the current front segment; skip points below lower_flops.
                k = 0
                while k < len(flops) and flops[k] < args.lower_flops:
                    k += 1
                for i in range(50):
                    # Bin i covers [i/50, (i+1)/50); these names shadow the tensors
                    # above, which were already consumed by self.mobo_bounds.
                    lower = i / 50.
                    upper = (i + 1) / 50.
                    # Bins starting below lower_flops or above upper_flops keep weight 0.
                    if upper < args.lower_flops or lower > args.upper_flops or lower < args.lower_flops:
                        continue
                    cnt = 1
                    # Add every segment whose right end lies below this bin's upper
                    # edge, advancing k; k carries over to the next bin.
                    while ((k + 1) < len(flops)) and upper > flops[k + 1]:
                        self.sampling_weights[i] += areas[k]
                        cnt += 1
                        k += 1
                    # Also add the segment that reaches past the bin's upper edge.
                    if k < len(areas):
                        self.sampling_weights[i] += areas[k]
                    # Average over the number of segments accumulated.
                    self.sampling_weights[i] /= cnt
                # Fall back to uniform weights when every bin is zero.
                if np.sum(self.sampling_weights) == 0:
                    self.sampling_weights = np.ones(50)

                if target_flops == 0:
                    # Candidate targets are the 50 bin centres 0.01, 0.03, ..., 0.99.
                    val = np.arange(0.01, 1, 0.02)
                    chosen_target_flops = np.random.choice(
                        val,
                        p=(self.sampling_weights /
                           np.sum(self.sampling_weights)))
                else:
                    chosen_target_flops = target_flops

                # Binary search over lmda in [0, 1] (at most 10 acquisition
                # optimizations) until the simulated ratio is within 0.02 of the target.
                lower_bnd, upper_bnd = 0, 1
                lmda = 0.5
                for i in range(10):
                    self.mobo_obj.rand = lmda

                    parameterization, acq_value = optimize_acqf(
                        self.mobo_obj,
                        bounds=self.mobo_bounds,
                        q=1,
                        num_restarts=5,
                        raw_samples=1000,
                    )

                    parameterization = parameterization[0].cpu().numpy()
                    layer_budget = hyperparams.get_layer_budget_from_parameterization(
                        parameterization, self.mask_pruner)
                    sim_flops = self.mask_pruner.simulate_and_count_flops(
                        layer_budget)
                    ratio = sim_flops / og_flops

                    if np.abs(ratio - chosen_target_flops) <= 0.02:
                        break
                    # The update direction flips with args.baseline: when it is > 0 a
                    # ratio below target raises lmda, otherwise it lowers lmda.
                    if args.baseline > 0:
                        if ratio < chosen_target_flops:
                            lower_bnd = lmda
                            lmda = (lmda + upper_bnd) / 2
                        elif ratio > chosen_target_flops:
                            upper_bnd = lmda
                            lmda = (lmda + lower_bnd) / 2
                    else:
                        if ratio < chosen_target_flops:
                            upper_bnd = lmda
                            lmda = (lmda + lower_bnd) / 2
                        elif ratio > chosen_target_flops:
                            lower_bnd = lmda
                            lmda = (lmda + upper_bnd) / 2
                # If all 10 iterations ran without a break, lmda here has been updated
                # once more than the value that produced `parameterization`.
                rand[0] = lmda
                # Log the loop index at which the search stopped.
                writer.add_scalar('Binary search trials', i, g)

            else:
                # No Pareto-aware sampling: optimize once with the random scalarization.
                parameterization, acq_value = optimize_acqf(
                    self.mobo_obj,
                    bounds=self.mobo_bounds,
                    q=1,
                    num_restarts=5,
                    raw_samples=1000,
                )
                parameterization = parameterization[0].cpu().numpy()

            # Recomputed for both GP paths from the final parameterization.
            layer_budget = hyperparams.get_layer_budget_from_parameterization(
                parameterization, self.mask_pruner)
        return layer_budget, parameterization, self.sampling_weights / np.sum(
            self.sampling_weights)