Example #1
def train_epoch(train_examples,
                train_queue,
                model,
                optimizer: optim.Optimizer,
                regularizer: Regularizer,
                batch_size: int,
                verbose: bool = True):
    loss = nn.CrossEntropyLoss(reduction='mean')
    with tqdm.tqdm(total=train_examples.shape[0],
                   unit='ex',
                   disable=not verbose) as bar:
        bar.set_description('train loss')
        for step, input in enumerate(train_queue):
            model.train()

            input_var = Variable(input, requires_grad=False).cuda()

            predictions, factors = model.forward(input_var)
            truth = input_var[:, 2]

            l_fit = loss(predictions, truth)
            l_reg = regularizer.forward(factors)
            l = l_fit + l_reg

            optimizer.zero_grad()
            l.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()

            bar.update(input_var.shape[0])
            bar.set_postfix(loss=f'{l.item():.0f}')
Example #2
def train_epoch(examples: torch.LongTensor,
                model,
                optimizer: optim.Optimizer,
                regularizer: Regularizer,
                batch_size: int,
                verbose: bool = True):
    actual_examples = examples[torch.randperm(examples.shape[0]), :]
    loss = nn.CrossEntropyLoss(reduction='mean')
    with tqdm.tqdm(total=examples.shape[0], unit='ex',
                   disable=not verbose) as bar:
        bar.set_description('train loss')
        b_begin = 0
        while b_begin < examples.shape[0]:
            #set current batch
            input_batch = actual_examples[b_begin:b_begin + batch_size].cuda()

            #compute predictions, ground truth
            predictions, factors = model.forward(input_batch)
            truth = input_batch[:, 2]

            #evaluate loss
            l_fit = loss(predictions, truth)
            l_reg = regularizer.forward(factors)
            l = l_fit + l_reg

            #optimise
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            b_begin += batch_size

            #progress bar
            bar.update(input_batch.shape[0])
            bar.set_postfix(loss=f'{l.item():.0f}')
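Example #2 above is the most self-contained of these variants, so a minimal driver is sketched below. ToyModel and NoRegularizer are hypothetical stand-ins, not part of the original project: any model whose forward returns (scores, factors) and any object with a forward(factors) penalty would do. The example's own imports (torch, tqdm) and a CUDA device are assumed.

# Hypothetical driver for train_epoch above; ToyModel and NoRegularizer are
# illustrative stand-ins, not part of the original project.
import torch
from torch import nn, optim

class ToyModel(nn.Module):
    def __init__(self, n_entities=100, n_relations=20, rank=16):
        super().__init__()
        self.entity = nn.Embedding(n_entities, rank)
        self.relation = nn.Embedding(n_relations, rank)

    def forward(self, triples):
        lhs = self.entity(triples[:, 0])
        rel = self.relation(triples[:, 1])
        # score every entity as a candidate tail for each (head, relation)
        scores = (lhs * rel) @ self.entity.weight.t()
        return scores, (lhs, rel)

class NoRegularizer:
    def forward(self, factors):
        return 0.0  # no penalty

# random (head, relation, tail) triples
examples = torch.stack([torch.randint(0, 100, (1024,)),
                        torch.randint(0, 20, (1024,)),
                        torch.randint(0, 100, (1024,))], dim=1)
model = ToyModel().cuda()
optimizer = optim.Adagrad(model.parameters(), lr=0.1)
train_epoch(examples, model, optimizer, NoRegularizer(), batch_size=256)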
Example #3
def train_epoch(train_examples,
                train_queue,
                valid_queue,
                model,
                architect,
                criterion,
                optimizer: optim.Optimizer,
                regularizer: Regularizer,
                batch_size: int,
                lr,
                verbose: bool = True):
    loss = nn.CrossEntropyLoss(reduction='mean')
    print('avg entity embedding norm',
          torch.norm(model.embeddings[0].weight, dim=1).mean())
    print('avg relation embedding norm',
          torch.norm(model.embeddings[1].weight, dim=1).mean())
    with tqdm.tqdm(total=train_examples.shape[0],
                   unit='ex',
                   disable=not verbose) as bar:
        bar.set_description('train loss')
        for step, input in enumerate(train_queue):

            model.train()

            input_var = Variable(input, requires_grad=False).cuda()
            target_var = Variable(input[:, 2], requires_grad=False).cuda()

            # draw a fresh validation batch for the architecture update
            input_search = next(iter(valid_queue))
            input_search_var = Variable(input_search,
                                        requires_grad=False).cuda()
            target_search_var = Variable(input_search[:, 2],
                                         requires_grad=False).cuda()

            architect.step(input_var,
                           target_var,
                           input_search_var,
                           target_search_var,
                           lr,
                           optimizer,
                           unrolled=args.unrolled)
            optimizer.zero_grad()

            predictions, factors = model.forward(input_var)

            l_fit = loss(predictions, target_var)
            l_reg = regularizer.forward(factors)
            l = l_fit + l_reg

            l.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()

            bar.update(input_var.shape[0])
            bar.set_postfix(loss=f'{l.item():.0f}')
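Example #3 interleaves a DARTS-style architecture update with the weight update: at each step, architect.step adjusts the architecture parameters against a validation batch before the model weights are updated on the training batch. A simplified, first-order sketch of such a step (assumed semantics, not the project's actual Architect class) might look like:

# Simplified first-order architect step: update architecture parameters on
# the validation loss only, leaving model weights to the outer optimizer.
def architect_step_first_order(model, arch_optimizer, criterion,
                               input_valid, target_valid):
    arch_optimizer.zero_grad()
    predictions, _ = model.forward(input_valid)
    val_loss = criterion(predictions, target_valid)
    val_loss.backward()   # gradients flow to architecture parameters
    arch_optimizer.step()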
Example #5
class TorchTomographySolver:
    def __init__(self, **kwargs):
        """
		Creating tomography solver object.
		Required Args:
			shape: shape of the object in [y, x, z]
			voxel_size: size of voxel in [y, x, z]
			wavelength: wavelength of probing wave, scalar
			sigma: sigma used in calculating transmittance function (exp(1i * sigma * object)), scalar
			tilt_angles: an array of sample rotation angles
			defocus_list: an array of defocus values

		Optional Args [default]
			amplitude_measurements: measurements for reconstruction, not needed for forward evaluation of the model only [None]
			numerical_aperture: numerical aperture of the system, scalar [1.0]
			binning_factor: bins the number of slices together to save computation, scalar [1]
			pad_size: padding reconstruction from measurements in [dy,dx], final size will be measurement.shape + 2*[dy, dx], [0, 0]
			shuffle: random shuffle of measurements, boolean [True]
			pupil: initial value for the pupil function [None]
			maxitr: maximum number of iterations [100]
			step_size: step_size for each gradient update [0.1]
			momentum: momentum for the gradient update, scalar, not yet implemented [0.0]


			-- transform alignment parameters (currently only support rigid body transform alignment) -- 
			transform_align: whether to turn on transform alignment, boolean, [False]
			ta_method: transform alignment method, currently only "turboreg", string, ["turboreg"]
			ta_start_iteration: alignment process will not start until then, int, [0]
			ta_iterations: iterations during which the alignment process will be on, [0, max_itr]

			-- Shift alignment parameters -- 
			shift_align: whether to turn on shift alignment, boolean, [False]
			sa_method: shift alignment method, can be "gradient", "hybrid_correlation", "cross_correlation", or "phase_correlation", string, ["gradient"]
			sa_step_size: step_size of shift parameters, float, [0.1]
			sa_start_iteration: alignment process will not start until then, int, [0]
			sa_iterations: iterations during which the alignment process will be on, [0, max_itr]

			-- Defocus refinement parameters -- 
			defocus_refine: whether to turn on defocus refinement for each measurement, boolean, [False]
			dr_method: defocus refinement method, can be "gradient", string, ["gradient"]
			dr_step_size: step_size of defocus refinement parameters, float, [0.1]
			dr_start_iteration: refinement process will not start until then, int, [0]
			dr_iterations: iterations during which the defocus refinement process will be on, [0, max_itr]

			-- regularizer parameters --
			regularizer_total_variation: boolean [False]
			regularizer_total_variation_gpu: boolean [False]
			regularizer_total_variation_parameter: controls amount of total variation, scalar or vector of length maxitr. [scalar 1.0]
			regularizer_total_variation_maxitr: number of iterations for total variation, integer [15]
			regularizer_total_variation_order: differential order, scalar [1], higher order not yet implemented
			regularizer_pure_real: boolean [False]
			regularizer_pure_imag: boolean [False]
			regularizer_pure_amplitude: boolean [False]
			regularizer_pure_phase: boolean [False]
			regularizer_positivity_real: boolean [False]
			regularizer_positivity_imag: boolean [False]
			regularizer_negativity_real: boolean [False]
			regularizer_negativity_imag: boolean [False]
			regularizer_dtype: torch dtype class [torch.float32]
		"""

        self.shape = kwargs.get("shape")

        self.shuffle = kwargs.get("shuffle", True)
        self.optim_max_itr = kwargs.get("maxitr", 100)
        self.optim_step_size = kwargs.get("step_size", 0.1)
        self.optim_momentum = kwargs.get("momentum", 0.0)

        self.obj_update_iterations = kwargs.get("obj_update_iterations",
                                                np.arange(self.optim_max_itr))

        #parameters for transform alignment
        self.transform_align = kwargs.get("transform_align", False)
        self.ta_method = kwargs.get("ta_method", "turboreg")
        self.ta_start_iteration = kwargs.get("ta_start_iteration", 0)
        self.ta_iterations = kwargs.get("ta_iterations", None)
        if self.ta_iterations is None:
            self.ta_iterations = np.arange(self.ta_start_iteration,
                                           self.optim_max_itr)

        #parameters for shift alignment
        self.shift_align = kwargs.get("shift_align", False)
        self.sa_method = kwargs.get("sa_method", "gradient")
        self.sa_step_size = kwargs.get("sa_step_size", 0.1)
        self.sa_start_iteration = kwargs.get("sa_start_iteration", 0)
        self.sa_iterations = kwargs.get("sa_iterations", None)
        if self.sa_iterations is None:
            self.sa_iterations = np.arange(self.sa_start_iteration,
                                           self.optim_max_itr)

        #parameters for defocus refinement
        self.defocus_refine = kwargs.get("defocus_refine", False)
        self.dr_method = kwargs.get("dr_method", "gradient")
        self.dr_step_size = kwargs.get("dr_step_size", 0.1)
        self.dr_start_iteration = kwargs.get("dr_start_iteration", 0)
        self.dr_iterations = kwargs.get("dr_iterations", None)
        if self.dr_iterations is None:
            self.dr_iterations = np.arange(self.dr_start_iteration,
                                           self.optim_max_itr)

        if not shift.is_valid_method(self.sa_method):
            raise ValueError('Shift alignment method not valid.')
        if self.shift_align and shift.is_correlation_method(self.sa_method):
            self.shift_obj = shift.ImageShiftCorrelationBased(
                kwargs["amplitude_measurements"].shape[0:2],
                upsample_factor=10,
                method=self.sa_method,
                device=torch.device('cpu'))

        if self.transform_align:
            self.transform_obj = transform.ImageTransformOpticalFlow(
                kwargs["amplitude_measurements"].shape[0:2],
                method=self.ta_method)

        self.dataset = AETDataset(**kwargs)
        self.num_defocus = self.dataset.get_all_defocus_lists().shape[0]
        self.num_rotation = len(self.dataset.tilt_angles)
        self.tomography_obj = PhaseContrastScattering(**kwargs)
        reg_temp_param = kwargs.get("regularizer_total_variation_parameter",
                                    None)
        if reg_temp_param is not None and not np.isscalar(reg_temp_param):
            assert self.optim_max_itr == len(reg_temp_param)
        self.regularizer_obj = Regularizer(**kwargs)
        self.rotation_obj = utilities.ImageRotation(self.shape, axis=0)

        self.cost_function = nn.MSELoss(reduction='sum')

    def run(self, obj_init=None, forward_only=False, callback=None):
        """
		run tomography solver
		Args:
		forward_only: True  -- only runs forward model on estimated object
					  False -- runs reconstruction
		"""
        if forward_only:
            assert obj_init is not None
            self.shuffle = False
            amplitude_list = []

        self.dataloader = DataLoader(self.dataset,
                                     batch_size=1,
                                     shuffle=self.shuffle)

        error = []
        #initialize object
        self.obj = obj_init
        if self.obj is None:
            self.obj = op.r2c(torch.zeros(self.shape).cuda())
        else:
            if not self.obj.is_cuda:
                self.obj = self.obj.cuda()
            if len(self.obj.shape) == 3:
                self.obj = op.r2c(self.obj)

        #initialize shift parameters
        self.yx_shifts = None
        if self.shift_align:
            self.sa_pixel_count = []
            self.yx_shift_all = []
            self.yx_shifts = torch.zeros(
                (2, self.num_defocus, self.num_rotation))

        if self.transform_align:
            self.xy_transform_all = []
            self.xy_transforms = torch.zeros(
                (6, self.num_defocus, self.num_rotation))
        ref_rot_idx = None
        #begin iteration
        for itr_idx in range(self.optim_max_itr):
            sys.stdout.flush()
            running_cost = 0.0
            if self.shift_align and itr_idx in self.sa_iterations:
                running_sa_pixel_count = 0.0
            for data_idx, data in enumerate(self.dataloader, 0):
                #parse data
                if not forward_only:
                    amplitudes, rotation_angle, defocus_list, rotation_idx = data
                    if ref_rot_idx is None and abs(rotation_angle -
                                                   0.0) < 1e-2:
                        ref_rot_idx = rotation_idx
                        print("reference index is:", ref_rot_idx)
                    amplitudes = torch.squeeze(amplitudes)
                    if len(amplitudes.shape) < 3:
                        amplitudes = amplitudes.unsqueeze(-1)

                else:
                    rotation_angle, defocus_list, rotation_idx = data[-3:]
                #prepare tilt specific parameters
                defocus_list = torch.flatten(defocus_list).cuda()
                rotation_angle = rotation_angle.item()
                yx_shift = None
                if self.shift_align and self.sa_method == "gradient" and itr_idx in self.sa_iterations:
                    yx_shift = self.yx_shifts[:, :, rotation_idx]
                    yx_shift = yx_shift.cuda()
                    yx_shift.requires_grad_()
                if self.defocus_refine and self.dr_method == "gradient" and itr_idx in self.dr_iterations:
                    defocus_list.requires_grad_()
                #rotate object
                if data_idx == 0:
                    self.obj = self.rotation_obj.forward(
                        self.obj, rotation_angle)
                else:
                    if abs(rotation_angle - previous_angle) > 90:
                        self.obj = self.rotation_obj.forward(
                            self.obj, -1 * previous_angle)
                        self.obj = self.rotation_obj.forward(
                            self.obj, rotation_angle)
                    else:
                        self.obj = self.rotation_obj.forward(
                            self.obj, rotation_angle - previous_angle)
                if not forward_only:
                    #define optimizer
                    optimizer_params = []

                    if itr_idx in self.obj_update_iterations:
                        self.obj.requires_grad_()
                        optimizer_params.append({
                            'params': self.obj,
                            'lr': self.optim_step_size
                        })
                    if self.shift_align and self.sa_method == "gradient" and itr_idx in self.sa_iterations:
                        optimizer_params.append({
                            'params': yx_shift,
                            'lr': self.sa_step_size
                        })
                    if self.defocus_refine and self.dr_method == "gradient" and itr_idx in self.dr_iterations:
                        optimizer_params.append({
                            'params': defocus_list,
                            'lr': self.dr_step_size
                        })
                    optimizer = optim.SGD(optimizer_params)

                #forward scattering
                estimated_amplitudes = self.tomography_obj(
                    self.obj, defocus_list, yx_shift)
                #in-plane rotation estimation
                if not forward_only:
                    if self.transform_align and itr_idx in self.ta_iterations:
                        if rotation_idx != ref_rot_idx:
                            amplitudes, xy_transform = self.transform_obj.estimate(
                                estimated_amplitudes, amplitudes)
                            xy_transform = xy_transform.unsqueeze(-1)
                    #correlation-based shift estimation
                    if self.shift_align and shift.is_correlation_method(
                            self.sa_method) and itr_idx in self.sa_iterations:
                        if rotation_idx != ref_rot_idx:
                            amplitudes, yx_shift, _ = self.shift_obj.estimate(
                                estimated_amplitudes, amplitudes)
                            yx_shift = yx_shift.unsqueeze(-1)


                    if itr_idx == self.optim_max_itr - 1:
                        print("Last iteration: updated amplitudes")
                        self.dataset.update_amplitudes(amplitudes,
                                                       rotation_idx)

                    #compute cost
                    cost = self.cost_function(estimated_amplitudes,
                                              amplitudes.cuda())
                    running_cost += cost.item()

                    #backpropagation
                    cost.backward()
                    #update object
                    optimizer.step()
                    optimizer.zero_grad()
                    del cost
                else:
                    #store measurement
                    amplitude_list.append(estimated_amplitudes.cpu().detach())
                del estimated_amplitudes
                self.obj.requires_grad = False
                if not forward_only:
                    #keep track of shift alignment for the tilt
                    if self.shift_align and itr_idx in self.sa_iterations:
                        if yx_shift is not None:
                            yx_shift.requires_grad = False
                            if rotation_idx != ref_rot_idx:
                                self.yx_shifts[:, :, rotation_idx] = yx_shift[:].cpu()
                                running_sa_pixel_count += torch.sum(
                                    torch.abs(yx_shift.cpu().flatten()))

                    #keep track of transform alignment for the tilt
                    if self.transform_align and itr_idx in self.ta_iterations:
                        if rotation_idx != ref_rot_idx:
                            self.xy_transforms[
                                ..., rotation_idx] = xy_transform[:].cpu()

                    #keep track of defocus alignment for the tilt
                    if self.defocus_refine and itr_idx in self.dr_iterations:
                        defocus_list.requires_grad = False
                        self.dataset.update_defocus_list(
                            defocus_list[:].cpu().detach(), rotation_idx)

                previous_angle = rotation_angle

                #rotate object back
                if data_idx == (self.dataset.__len__() - 1):
                    previous_angle = 0.0
                    self.obj = self.rotation_obj.forward(
                        self.obj, -1.0 * rotation_angle)
                print("Rotation {:03d}/{:03d}.".format(data_idx + 1,
                                                       self.dataset.__len__()),
                      end="\r")

            #apply regularization
            amplitudes = None
            torch.cuda.empty_cache()
            if not forward_only:
                if itr_idx in self.obj_update_iterations:
                    self.obj = self.regularizer_obj.apply(self.obj)
            error.append(running_cost)

            #keep track of shift alignment results
            if self.shift_align and itr_idx in self.sa_iterations:
                self.sa_pixel_count.append(running_sa_pixel_count)
                self.yx_shift_all.append(np.array(self.yx_shifts).copy())

            #keep track of transform alignment results
            if self.transform_align and itr_idx in self.ta_iterations:
                self.xy_transform_all.append(
                    np.array(self.xy_transforms).copy())

            if callback is not None:
                callback(self.obj.cpu().detach(), error)
            if forward_only and itr_idx == 0:
                return torch.cat(
                    [torch.unsqueeze(a, -1) for a in amplitude_list], dim=-1)
            print("Iteration {:03d}/{:03d}. Error: {:03f}".format(
                itr_idx + 1, self.optim_max_itr, np.log10(running_cost)))

        self.defocus_list = self.dataset.get_all_defocus_lists()
        return self.obj.cpu().detach(), error
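For orientation, the sketch below assembles a construction-and-run call purely from the docstring's kwargs. Every numeric value and the measurements array are invented placeholders, and the exact measurement layout is whatever AETDataset expects.

# Hypothetical usage sketch built from the docstring above; all values are
# placeholders and `measurements` must match AETDataset's expected layout.
import numpy as np

measurements = np.ones((256, 256, 1, 61), dtype=np.float32)  # placeholder
solver = TorchTomographySolver(
    shape=[256, 256, 256],              # object size in [y, x, z] voxels
    voxel_size=[1.0, 1.0, 1.0],
    wavelength=0.02,                    # same length unit as voxel_size
    sigma=1.0,
    tilt_angles=np.arange(-60, 61, 2),  # one angle per tilt
    defocus_list=np.array([0.0]),
    amplitude_measurements=measurements,
    maxitr=50,
    step_size=0.1,
    regularizer_total_variation=True)

obj, error = solver.run()               # reconstruction
# amplitudes = solver.run(obj_init=obj, forward_only=True)  # forward model only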
Example #6
### User-defined and Multiple regularizers
import projSplitFit as ps  # module name assumed from the "ps" alias used below
from regularizers import Regularizer, L1


# prox of the indicator of the nonnegative orthant: projection onto z >= 0
def prox_g(z, sigma):
    return (z >= 0) * z


def value_g(x):
    if any(x < -1e-7):
        return float('Inf')
    return 0.0


# A (data matrix) and r (response vector) are assumed to be defined elsewhere
regObjNonneg = Regularizer(prox=prox_g, value=value_g)
gamma = 1.0
projSplit = ps.ProjSplitFit(gamma)
lam1 = 0.1
projSplit.addData(A, r, loss=2, intercept=False, normalize=False)
regObj = L1(scaling=lam1)
projSplit.addRegularizer(regObj)
projSplit.addRegularizer(regObjNonneg)
projSplit.run(verbose=True)
optimalVal = projSplit.getObjective()
z = projSplit.getSolution()
print(f"Objective value = {optimalVal}")
Example #7
def test_user_defined_embedded(processor, testNumber):
    def val1(x):
        return 0.5 * np.linalg.norm(x, 2)**2

    def prox1(x, scale):
        return (1 + scale)**(-1) * x

    def val2(x):
        return np.linalg.norm(x, 2)

    def prox2(x, scale):
        normx = np.linalg.norm(x, 2)
        if normx <= scale:
            return 0 * x
        else:
            return (normx - scale) * x / normx

    tau = 0.2

    def val3(x):
        if ((x <= tau) & (x >= -tau)).all():
            return 0
        else:
            return float('inf')

    def prox3(x, scale):
        # projection onto the box [-tau, tau]; scale is unused
        # (the original expression double-counted the boundary x == ±tau)
        return np.clip(x, -tau, tau)

    m = 40
    d = 10
    if getNewOptVals and (testNumber == 0):
        A, y = getLSdata(m, d)
        cache['Aembed'] = A
        cache['yembed'] = y
    else:
        A = cache['Aembed']
        y = cache['yembed']

    projSplit = ps.ProjSplitFit()

    gamma = 1e0
    projSplit.setDualScaling(gamma)

    # getScale() is expected to raise here, since no scaling has been set
    try:
        scaling = projSplit.getScale()
        exceptMade = False
    except Exception:
        exceptMade = True
    if not exceptMade:
        raise Exception("getScale() should have raised")

    regObj = []
    nu = [0.01, 0.03, 0.1]
    step = [1.0, 1.0, 1.0]

    regObj.append(Regularizer(prox1, val1, nu[0], step[0]))
    regObj.append(Regularizer(prox2, val2, nu[1], step[1]))
    regObj.append(Regularizer(prox3, val3, nu[2], step[2]))

    projSplit.addData(A,
                      y,
                      2,
                      processor,
                      normalize=False,
                      intercept=True,
                      embed=regObj[2])
    projSplit.addRegularizer(regObj[0])
    projSplit.addRegularizer(regObj[1])

    projSplit.run(maxIterations=1000,
                  keepHistory=True,
                  nblocks=5,
                  resetIterate=True)

    if getNewOptVals and (testNumber == 0):
        AwithIntercept = np.zeros((m, d + 1))
        AwithIntercept[:, 0] = np.ones(m)
        AwithIntercept[:, 1:(d + 1)] = A

        (m, d) = AwithIntercept.shape
        x_cvx = cvx.Variable(d)
        f = (1 / (2 * m)) * cvx.sum_squares(AwithIntercept @ x_cvx - y)

        constraints = [-tau <= x_cvx[1:d], x_cvx[1:d] <= tau]

        f += 0.5 * nu[0] * cvx.norm(x_cvx[1:d], 2)**2
        f += nu[1] * cvx.norm(x_cvx[1:d], 2)

        obj = cvx.Minimize(f)
        prob = cvx.Problem(obj, constraints)
        prob.solve(verbose=False)
        #opt = prob.value
        xopt = x_cvx.value
        xopt = np.squeeze(np.array(xopt))
        cache['xoptembedded'] = xopt
    else:
        xopt = cache['xoptembedded']

    xps = projSplit.getSolution()
    print("Norm error = {}".format(np.linalg.norm(xopt - xps, 2)))
    assert (np.linalg.norm(xopt - xps, 2) < 1e-2)
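The three prox/value pairs in this test are standard objects: prox1 is the proximal map of 0.5*||x||^2, prox2 is block soft-thresholding (the prox of the Euclidean norm), and prox3 is projection onto the box [-tau, tau]. As an illustrative check, prox1 satisfies the first-order optimality condition of its prox subproblem:

# For f(u) = 0.5*||u||^2, the minimizer p of scale*f(u) + 0.5*||u - x||^2
# must satisfy scale*p + (p - x) = 0; prox1 gives p = x / (1 + scale).
import numpy as np

x = np.random.randn(5)
scale = 0.7
p = x / (1.0 + scale)                          # = prox1(x, scale)
print(np.allclose(scale * p + (p - x), 0.0))   # True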
Example #8
def test_user_defined(processor, testNumber):
    def val1(x):
        return 0.5 * np.linalg.norm(x, 2)**2

    def prox1(x, scale):
        return (1 + scale)**(-1) * x

    def val2(x):
        return np.linalg.norm(x, 2)

    def prox2(x, scale):
        normx = np.linalg.norm(x, 2)
        if normx <= scale:
            return 0 * x
        else:
            return (normx - scale) * x / normx

    tau = 0.2

    def val3(x):
        if ((x <= tau) & (x >= -tau)).all():
            return 0
        else:
            return float('inf')

    def prox3(x, scale):
        # projection onto the box [-tau, tau]; scale is unused
        # (the original expression double-counted the boundary x == ±tau)
        return np.clip(x, -tau, tau)

    funcList = [(val3, prox3), (val1, prox1), (val2, prox2)]

    i = 0
    m = 40
    d = 10
    if getNewOptVals and (testNumber == 0):
        A, y = getLSdata(m, d)
        cache['Auser'] = A
        cache['yuser'] = y
    else:
        A = cache['Auser']
        y = cache['yuser']

    for (val, prox) in funcList:

        projSplit = ps.ProjSplitFit()

        gamma = 1e0
        projSplit.setDualScaling(gamma)
        projSplit.addData(A, y, 2, processor, normalize=False, intercept=False)
        nu = 5.5
        step = 1e0
        regObj = Regularizer(prox, val, scaling=nu, step=step)
        projSplit.addRegularizer(regObj)
        projSplit.run(maxIterations=1000,
                      keepHistory=True,
                      nblocks=1,
                      resetIterate=True,
                      primalTol=1e-12,
                      dualTol=1e-12)
        ps_val = projSplit.getObjective()

        (m, d) = A.shape
        if getNewOptVals and (testNumber == 0):
            x_cvx = cvx.Variable(d)
            f = (1 / (2 * m)) * cvx.sum_squares(A @ x_cvx - y)

            if i == 0:
                constraints = [-tau <= x_cvx, x_cvx <= tau]
            elif i == 1:
                f += 0.5 * nu * cvx.norm(x_cvx, 2)**2
                constraints = []
            elif i == 2:
                f += nu * cvx.norm(x_cvx, 2)
                constraints = []

            obj = cvx.Minimize(f)
            prob = cvx.Problem(obj, constraints)
            prob.solve(verbose=True)
            opt = prob.value
            xopt = x_cvx.value
            xopt = np.squeeze(np.array(xopt))
            cache[(i, 'optuser')] = opt
            cache[(i, 'xuser')] = xopt
        else:
            opt = cache[(i, 'optuser')]
            xopt = cache[(i, 'xuser')]

        if i == 0:
            xps = projSplit.getSolution()
            print(np.linalg.norm(xopt - xps, 2))
            assert (np.linalg.norm(xopt - xps, 2) < 1e-2)
        else:
            print('cvx opt val = {}'.format(opt))
            print('ps opt val = {}'.format(ps_val))
            assert abs(ps_val - opt) < 1e-2
        i += 1

    # test combined
    m = 40
    d = 10
    if getNewOptVals and (testNumber == 0):
        A, y = getLSdata(m, d)
        cache['Acombined'] = A
        cache['ycombined'] = y
    else:
        A = cache['Acombined']
        y = cache['ycombined']

    projSplit = ps.ProjSplitFit()

    projSplit.setDualScaling(gamma)
    projSplit.addData(A, y, 2, processor, normalize=False, intercept=False)
    nu1 = 0.01
    step = 1e0
    regObj = Regularizer(prox1, val1, scaling=nu1, step=step)
    projSplit.addRegularizer(regObj)
    nu2 = 0.05
    step = 1e0
    regObj = Regularizer(prox2, val2, scaling=nu2, step=step)
    projSplit.addRegularizer(regObj)
    step = 1e0
    regObj = Regularizer(prox3, val3, step=step)
    projSplit.addRegularizer(regObj)
    projSplit.run(maxIterations=1000,
                  keepHistory=True,
                  nblocks=1,
                  resetIterate=True,
                  primalTol=1e-12,
                  dualTol=1e-12)
    ps_val = projSplit.getObjective()
    xps = projSplit.getSolution()

    if getNewOptVals and (testNumber == 0):
        x_cvx = cvx.Variable(d)
        f = (1 / (2 * m)) * cvx.sum_squares(A @ x_cvx - y)

        constraints = [-tau <= x_cvx, x_cvx <= tau]

        f += 0.5 * nu1 * cvx.norm(x_cvx, 2)**2
        f += nu2 * cvx.norm(x_cvx, 2)

        obj = cvx.Minimize(f)
        prob = cvx.Problem(obj, constraints)
        prob.solve(verbose=True)
        opt = prob.value
        xopt = x_cvx.value
        xopt = np.squeeze(np.array(xopt))
        cache['optcombined'] = opt
        cache['xcombined'] = xopt
    else:
        opt = cache['optcombined']
        xopt = cache['xcombined']

    assert (np.linalg.norm(xopt - xps, 2) < 1e-2)