예제 #1
0
    def __init__(
        self,
        total_batches,
        batch_size,
        num_layers=8,
        image_width=512,
        loss_coef=100,
        theta_initial=None,
        theta_hidden=None,
    ):
        """Store the training settings and build the wrapped SIREN model.

        Args:
            total_batches: Stored as ``self.total_batches``.
            batch_size: Stored as ``self.batch_size``.
            num_layers: Forwarded to ``SirenNet`` as ``num_layers``.
            image_width: Stored as ``self.image_width`` and passed to
                ``SirenWrapper`` as both ``image_width`` and ``image_height``.
            loss_coef: Stored as ``self.loss_coef``.
            theta_initial: Routed through ``default(theta_initial, 30.)`` and
                passed to ``SirenNet`` as ``w0_initial``.
            theta_hidden: Routed through ``default(theta_hidden, 30.)`` and
                passed to ``SirenNet`` as ``w0``.
        """
        super().__init__()

        # Batch bookkeeping; the processed-batch counter starts at zero.
        self.total_batches = total_batches
        self.batch_size = batch_size
        self.num_batches_processed = 0

        self.image_width = image_width
        self.loss_coef = loss_coef

        # NOTE(review): `default` presumably yields 30. when the theta
        # argument is None -- confirm against the helper's definition.
        hidden_freq = default(theta_hidden, 30.)
        first_layer_freq = default(theta_initial, 30.)

        # 2 input features and 3 output features, 256 hidden units per layer.
        siren_kwargs = {
            "dim_in": 2,
            "dim_hidden": 256,
            "num_layers": num_layers,
            "dim_out": 3,
            "use_bias": True,
            "w0": hidden_freq,
            "w0_initial": first_layer_freq,
        }
        network = SirenNet(**siren_kwargs)

        # The height reuses image_width, so the wrapper is configured square.
        self.model = SirenWrapper(
            network,
            image_width=image_width,
            image_height=image_width,
        )

        # Invoked last, once every attribute above has been assigned.
        self.generate_size_schedule()
예제 #2
0
    def __init__(
        self,
        total_batches,
        batch_size,
        num_layers=8,
        image_width=512,
        loss_coef=100,
    ):
        """Store the training settings and build the wrapped SIREN model.

        Args:
            total_batches: Stored as ``self.total_batches``.
            batch_size: Stored as ``self.batch_size``.
            num_layers: Forwarded to ``SirenNet`` as ``num_layers``.
            image_width: Stored as ``self.image_width`` and passed to
                ``SirenWrapper`` as both ``image_width`` and ``image_height``.
            loss_coef: Stored as ``self.loss_coef``.
        """
        super().__init__()

        # Batch bookkeeping; the processed-batch counter starts at zero.
        self.total_batches = total_batches
        self.batch_size = batch_size
        self.num_batches_processed = 0

        self.image_width = image_width
        self.loss_coef = loss_coef

        # 2 input features and 3 output features, 256 hidden units per layer.
        siren_kwargs = {
            "dim_in": 2,
            "dim_hidden": 256,
            "num_layers": num_layers,
            "dim_out": 3,
            "use_bias": True,
        }
        network = SirenNet(**siren_kwargs)

        # The height reuses image_width, so the wrapper is configured square.
        self.model = SirenWrapper(
            network,
            image_width=image_width,
            image_height=image_width,
        )

        # Invoked last, once every attribute above has been assigned.
        self.generate_size_schedule()
예제 #3
0
    def __init__(
        self,
        total_batches,
        batch_size,
        num_layers=8,
        image_width=512,
        loss_coef=100,
        theta_initial=None,
        theta_hidden=None,
        lower_bound_cutout=0.1,
        upper_bound_cutout=1.0,
        saturate_bound=False,
    ):
        """Store training and cutout settings and build the wrapped SIREN model.

        Args:
            total_batches: Stored as ``self.total_batches``.
            batch_size: Stored as ``self.batch_size``.
            num_layers: Forwarded to ``SirenNet`` as ``num_layers``.
            image_width: Stored as ``self.image_width`` and passed to
                ``SirenWrapper`` as both ``image_width`` and ``image_height``.
            loss_coef: Stored as ``self.loss_coef``.
            theta_initial: Routed through ``default(theta_initial, 30.)`` and
                passed to ``SirenNet`` as ``w0_initial``.
            theta_hidden: Routed through ``default(theta_hidden, 30.)`` and
                passed to ``SirenNet`` as ``w0``.
            lower_bound_cutout: Stored as ``self.lower_bound_cutout``. Should
                be smaller than 0.8 (not enforced here).
            upper_bound_cutout: Stored as ``self.upper_bound_cutout``.
            saturate_bound: Stored as ``self.saturate_bound``.
        """
        super().__init__()

        # Batch bookkeeping; the processed-batch counter starts at zero.
        self.total_batches = total_batches
        self.batch_size = batch_size
        self.num_batches_processed = 0

        self.image_width = image_width
        self.loss_coef = loss_coef

        # NOTE(review): `default` presumably yields 30. when the theta
        # argument is None -- confirm against the helper's definition.
        hidden_freq = default(theta_hidden, 30.)
        first_layer_freq = default(theta_initial, 30.)

        # 2 input features and 3 output features, 256 hidden units per layer.
        siren_kwargs = {
            "dim_in": 2,
            "dim_hidden": 256,
            "num_layers": num_layers,
            "dim_out": 3,
            "use_bias": True,
            "w0": hidden_freq,
            "w0_initial": first_layer_freq,
        }
        network = SirenNet(**siren_kwargs)

        # The height reuses image_width, so the wrapper is configured square.
        self.model = SirenWrapper(
            network,
            image_width=image_width,
            image_height=image_width,
        )

        # Cutout configuration.
        self.lower_bound_cutout = lower_bound_cutout
        self.upper_bound_cutout = upper_bound_cutout
        self.saturate_bound = saturate_bound
        # Fixed limit: cutouts above this value lead to destabilization.
        self.saturate_limit = 0.75
예제 #4
0
    def __init__(
        self,
        clip_perceptor,
        clip_norm,
        input_res,
        total_batches,
        batch_size,
        num_layers=8,
        image_width=512,
        loss_coef=100,
        theta_initial=None,
        theta_hidden=None,
        lower_bound_cutout=0.1,
        upper_bound_cutout=1.0,
        saturate_bound=False,
        gauss_sampling=False,
        gauss_mean=0.6,
        gauss_std=0.2,
        do_cutout=True,
        center_bias=False,
        center_focus=2,
        hidden_size=256,
        averaging_weight=0.3,
    ):
        """Store CLIP handles and sampling settings, and build the SIREN model.

        Args:
            clip_perceptor: Stored as ``self.perceptor``.
            clip_norm: Stored as ``self.normalize_image``.
            input_res: Stored as ``self.input_resolution``.
            total_batches: Stored as ``self.total_batches``.
            batch_size: Stored as ``self.batch_size``.
            num_layers: Forwarded to ``SirenNet`` as ``num_layers``.
            image_width: Stored as ``self.image_width`` and passed to
                ``SirenWrapper`` as both ``image_width`` and ``image_height``.
            loss_coef: Stored as ``self.loss_coef``.
            theta_initial: Routed through ``default(theta_initial, 30.)`` and
                passed to ``SirenNet`` as ``w0_initial``.
            theta_hidden: Routed through ``default(theta_hidden, 30.)`` and
                passed to ``SirenNet`` as ``w0``.
            lower_bound_cutout: Stored as ``self.lower_bound_cutout``. Should
                be smaller than 0.8 (not enforced here).
            upper_bound_cutout: Stored as ``self.upper_bound_cutout``.
            saturate_bound: Stored as ``self.saturate_bound``.
            gauss_sampling: Stored as ``self.gauss_sampling``.
            gauss_mean: Stored as ``self.gauss_mean``.
            gauss_std: Stored as ``self.gauss_std``.
            do_cutout: Stored as ``self.do_cutout``.
            center_bias: Stored as ``self.center_bias``.
            center_focus: Stored as ``self.center_focus``.
            hidden_size: Forwarded to ``SirenNet`` as ``dim_hidden``.
            averaging_weight: Stored as ``self.averaging_weight``.
        """
        super().__init__()

        # Externally supplied CLIP objects; nothing is loaded here.
        self.perceptor = clip_perceptor
        self.normalize_image = clip_norm
        self.input_resolution = input_res

        # Batch bookkeeping; the processed-batch counter starts at zero.
        self.total_batches = total_batches
        self.batch_size = batch_size
        self.num_batches_processed = 0

        self.image_width = image_width
        self.loss_coef = loss_coef

        # NOTE(review): `default` presumably yields 30. when the theta
        # argument is None -- confirm against the helper's definition.
        hidden_freq = default(theta_hidden, 30.)
        first_layer_freq = default(theta_initial, 30.)

        # 2 input features and a single output feature per sample.
        siren_kwargs = {
            "dim_in": 2,
            "dim_hidden": hidden_size,
            "num_layers": num_layers,
            "dim_out": 1,
            "use_bias": True,
            "w0": hidden_freq,
            "w0_initial": first_layer_freq,
        }
        network = SirenNet(**siren_kwargs)

        # The height reuses image_width, so the wrapper is configured square.
        self.model = SirenWrapper(
            network,
            image_width=image_width,
            image_height=image_width,
        )

        # Cutout bounds.
        self.lower_bound_cutout = lower_bound_cutout
        self.upper_bound_cutout = upper_bound_cutout
        self.saturate_bound = saturate_bound
        # Fixed limit: cutouts above this value lead to destabilization.
        self.saturate_limit = 0.75

        # Gaussian sampling settings.
        self.gauss_sampling = gauss_sampling
        self.gauss_mean = gauss_mean
        self.gauss_std = gauss_std

        # Remaining cutout / centring / averaging options.
        self.do_cutout = do_cutout
        self.center_bias = center_bias
        self.center_focus = center_focus
        self.averaging_weight = averaging_weight