    def __init__(self, objective, dual_objective, accuracy_fn1, value_fn,
                 concat_states, key_state, compute_min_grad_fn,
                 compute_grad_fn, hparams, delta_s, pred_state,
                 pred_prev_state, counter, dataset_tuple):
        self.concat_states = concat_states
        self._state = None
        self._bparam = None
        self.opt = OptimizerCreator(
            opt_string=hparams["meta"]["optimizer"],
            learning_rate=hparams["descent_lr"]).get_optimizer()
        self.objective = objective
        self.dual_objective = dual_objective
        self._lagrange_multiplier = hparams["lagrange_init"]
        self._state_secant_vector = None
        self._state_secant_c2 = None
        self.delta_s = delta_s
        self.descent_period = hparams["descent_period"]
        self.max_norm_state = hparams["max_bounds"]
        self.hparams = hparams
        self.compute_min_grad_fn = compute_min_grad_fn
        self.compute_grad_fn = compute_grad_fn
        self._assign_states()
        self._parc_vec = None
        self.state_stack = dict()
        self.key_state = key_state
        self.pred_state = pred_state
        self.pred_prev_state = pred_prev_state
        self.sphere_radius = hparams["sphere_radius"]
        self.counter = counter
        self.value_fn = value_fn
        self.accuracy_fn1 = accuracy_fn1
        self.dataset_tuple = dataset_tuple
        if hparams["meta"]["dataset"] == "mnist":
            (self.train_images, self.train_labels, self.test_images,
             self.test_labels) = dataset_tuple

            if hparams["continuation_config"] == 'data':
                # data continuation
                self.data_loader = iter(
                    get_mnist_batch_alter(self.train_images,
                                          self.train_labels,
                                          self.test_images,
                                          self.test_labels,
                                          alter=self._bparam,
                                          batch_size=hparams["batch_size"],
                                          resize=hparams["resize_to_small"],
                                          filter=hparams["filter"]))
            else:
                # model continuation
                self.data_loader = iter(
                    get_mnist_data(batch_size=hparams["batch_size"],
                                   resize=hparams["resize_to_small"],
                                   filter=hparams["filter"]))
            self.num_batches = meta_mnist(hparams["batch_size"],
                                          hparams["filter"])["num_batches"]
        else:
            self.data_loader = None
            self.num_batches = 1
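
The data-vs-model continuation branch above can be read as a small loader factory. The sketch below is added for illustration only and is not code from the repository: the helper name build_mnist_loader is hypothetical, while get_mnist_batch_alter, get_mnist_data, and meta_mnist are called with exactly the keyword arguments used in the constructor above.

def build_mnist_loader(hparams, dataset_tuple, bparam):
    """Hypothetical helper (sketch) mirroring the branch in __init__ above."""
    train_images, train_labels, test_images, test_labels = dataset_tuple
    if hparams["continuation_config"] == 'data':
        # data continuation: batches are altered as a function of bparam
        loader = iter(
            get_mnist_batch_alter(train_images, train_labels,
                                  test_images, test_labels,
                                  alter=bparam,
                                  batch_size=hparams["batch_size"],
                                  resize=hparams["resize_to_small"],
                                  filter=hparams["filter"]))
    else:
        # model continuation: plain MNIST batches
        loader = iter(
            get_mnist_data(batch_size=hparams["batch_size"],
                           resize=hparams["resize_to_small"],
                           filter=hparams["filter"]))
    num_batches = meta_mnist(hparams["batch_size"],
                             hparams["filter"])["num_batches"]
    return loader, num_batches
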
Example No. 2
 def __init__(
     self,
     objective,
     dual_objective,
     value_fn,
     concat_states,
     key_state,
     compute_min_grad_fn,
     compute_grad_fn,
     hparams,
     pred_state,
     pred_prev_state,
     counter,
 ):
     self.concat_states = concat_states
     self._state = None
     self._bparam = None
     self.opt = OptimizerCreator(
         opt_string=hparams["meta"]["optimizer"], learning_rate=hparams["descent_lr"]
     ).get_optimizer()
     self.objective = objective
     self.dual_objective = dual_objective
     self._lagrange_multiplier = hparams["lagrange_init"]
     self._state_secant_vector = None
     self._state_secant_c2 = None
     self.delta_s = hparams["delta_s"]
     self.descent_period = hparams["descent_period"]
     self.max_norm_state = hparams["max_bounds"]
     self.hparams = hparams
     self.compute_min_grad_fn = compute_min_grad_fn
     self.compute_grad_fn = compute_grad_fn
     self._assign_states()
     self._parc_vec = None
     self.state_stack = dict()
     self.key_state = key_state
     self.pred_state = pred_state
     self.pred_prev_state = pred_prev_state
     self.sphere_radius = hparams["sphere_radius"]
     self.counter = counter
     self.value_fn = value_fn
     # self.data_loader = iter(get_data(dataset=hparams["meta"]['dataset'],
     #                             batch_size=hparams['batch_size'],
     #                             num_workers=hparams['data_workers'],
     #                             train_only=True, test_only=False))
     if hparams["meta"]["dataset"] == "mnist":
         self.data_loader = iter(
             get_mnist_data(
                 batch_size=hparams["batch_size"], resize=hparams["resize_to_small"]
             )
         )
         self.num_batches = meta_mnist(hparams["batch_size"])["num_batches"]
     else:
         self.data_loader = None
         self.num_batches = 1
Example No. 3
    def __init__(self, objective, concat_states, grad_fn, value_fn,
                 accuracy_fn, hparams, dataset_tuple):
        self.concat_states = concat_states
        self._state = None
        self._bparam = None
        self.opt = OptimizerCreator(
            opt_string=hparams["meta"]["optimizer"],
            learning_rate=hparams["natural_lr"]).get_optimizer()
        self.objective = objective
        self.accuracy_fn = accuracy_fn
        self.warmup_period = hparams["warmup_period"]
        self.hparams = hparams
        self.grad_fn = grad_fn
        self.value_fn = value_fn
        self._assign_states()
        if hparams["meta"]["dataset"] == "mnist":
            (self.train_images, self.train_labels, self.test_images,
             self.test_labels) = dataset_tuple
            if hparams["continuation_config"] == 'data':
                # data continuation
                self.data_loader = iter(
                    get_mnist_batch_alter(self.train_images,
                                          self.train_labels,
                                          self.test_images,
                                          self.test_labels,
                                          alter=self._bparam,
                                          batch_size=hparams["batch_size"],
                                          resize=hparams["resize_to_small"],
                                          filter=hparams["filter"]))
            else:
                # model continuation
                self.data_loader = iter(
                    get_mnist_data(batch_size=hparams["batch_size"],
                                   resize=hparams["resize_to_small"],
                                   filter=hparams["filter"])

                    # get_preload_mnist_data(self.train_images, ## TODO: better way to prefetch mnist
                    #                        self.train_labels,
                    #                        self.test_images,
                    #                        self.test_labels,
                    #                          batch_size = hparams["batch_size"],
                    #                          resize = hparams["resize_to_small"],
                    #                         filter = hparams["filter"])
                )

            self.num_batches = meta_mnist(hparams["batch_size"],
                                          hparams["filter"])["num_batches"]
        else:
            self.data_loader = None
            self.num_batches = 1
Example No. 4
        artifact_uri = mlflow.get_artifact_uri()
        print("Artifact uri: {}".format(artifact_uri))

        # log an empty placeholder file so the "output/" artifact directory
        # exists before its URI is resolved below
        mlflow.log_text("", artifact_file="output/_touch.txt")
        artifact_uri2 = mlflow.get_artifact_uri("output/")
        print("Artifact uri: {}".format(artifact_uri2))
        hparams["meta"]["output_dir"] = artifact_uri2
        file_name = f"{artifact_uri2}/version.jsonl"

        sw = StateWriter(file_name=file_name)

        data_loader = iter(
            get_mnist_data(batch_size=hparams["batch_size"],
                           resize=True,
                           filter=hparams['filter']))
        num_batches = meta_mnist(batch_size=hparams["batch_size"],
                                 filter=hparams['filter'])["num_batches"]
        print(f"num of bathces: {num_batches}")
        compute_grad_fn = jit(grad(problem.objective, [0]))

        opt = OptimizerCreator(
            hparams["meta"]["optimizer"],
            learning_rate=hparams["natural_lr"]).get_optimizer()
        ma_loss = []
        for epoch in range(hparams["warmup_period"]):
            for b_j in range(num_batches):
                batch = next(data_loader)
                ae_grads = compute_grad_fn(ae_params, batch)
                ae_params = opt.update_params(ae_params,
                                              ae_grads[0],
                                              step_index=epoch)
                loss = problem.objective(ae_params, batch)
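                # Hedged continuation, not in the original snippet: track the
                # batch loss the same way the later examples do.
                ma_loss.append(loss)
                print(f"warmup loss:{loss}")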
Example No. 5
        bparams = [bparam, bparam_1]
        return states, bparams


def exp_decay(epoch, initial_lrate):
    k = 0.02
    lrate = initial_lrate * np.exp(-k * epoch)
    return lrate
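
A quick numeric check of this schedule, added for illustration and not part of the original script:
# With k = 0.02 the rate halves roughly every ln(2)/0.02 ≈ 35 epochs, e.g.
#   exp_decay(0,   1e-3) -> 1.00e-03
#   exp_decay(35,  1e-3) -> ~4.97e-04  (about half)
#   exp_decay(100, 1e-3) -> ~1.35e-04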


if __name__ == "__main__":
    problem = DataTopologyAE()
    ae_params, bparam = problem.initial_value()
    bparam = pytree_element_add(bparam, 0.99)
    data_loader = iter(get_mnist_data(batch_size=25000, resize=True))
    num_batches = meta_mnist(batch_size=25000)["num_batches"]
    print(f"num of bathces: {num_batches}")
    compute_grad_fn = jit(grad(problem.objective, [0]))

    with open(problem.HPARAMS_PATH, "r") as hfile:
        hparams = json.load(hfile)
    opt = AdamOptimizer(learning_rate=hparams["descent_lr"])
    ma_loss = []
    for epoch in range(500):
        for b_j in range(num_batches):
            batch = next(data_loader)
            grads = compute_grad_fn(ae_params, bparam, batch)
            ae_params = opt.update_params(ae_params,
                                          grads[0],
                                          step_index=epoch)
            loss = problem.objective(ae_params, bparam, batch)
Example No. 6
def exp_decay(epoch, initial_lrate):
    k = 0.02
    lrate = initial_lrate * np.exp(-k * epoch)
    return lrate


if __name__ == "__main__":
    problem = DataTopologyAE()
    ae_params, bparam = problem.initial_value()
    bparam = pytree_element_add(bparam, 0.99)
    with open(problem.HPARAMS_PATH, "r") as hfile:
        hparams = json.load(hfile)
    data_loader = iter(
        get_mnist_data(batch_size=hparams["batch_size"], resize=True))
    num_batches = meta_mnist(batch_size=hparams["batch_size"])["num_batches"]
    print(f"num of bathces: {num_batches}")
    compute_grad_fn = jit(grad(problem.objective, [0]))

    opt = AdamOptimizer(learning_rate=hparams["descent_lr"])
    ma_loss = []
    for epoch in range(500):
        for b_j in range(num_batches):
            batch = next(data_loader)
            grads = compute_grad_fn(ae_params, bparam, batch)
            ae_params = opt.update_params(ae_params,
                                          grads[0],
                                          step_index=epoch)
            loss = problem.objective(ae_params, bparam, batch)
            ma_loss.append(loss)
            print(f"loss:{loss}  norm:{l2_norm(grads)}")