Example #1
File: graph.py, Project: NeilGirdhar/tjax
 def ty_to_state_dict(graph: T) -> Dict[str, Any]:
     edge_dict_of_dicts = defaultdict[Any, Dict[Any, Any]](dict)
     for (source, target), edge_dict in dict(graph.edges).items():
         edge_dict_of_dicts[source][target] = edge_dict
     return {
         'nodes': to_state_dict(dict(graph.nodes)),
         'edges': to_state_dict(dict(edge_dict_of_dicts))
     }
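A custom hook like ty_to_state_dict is normally paired with an inverse hook and registered through serialization.register_serialization_state, so that Flax's to_state_dict/from_state_dict recurse into the graph type. The restore hook below is only an illustrative sketch (it assumes a networkx-style DiGraph, not the actual tjax implementation):

from typing import Any, Dict

import networkx as nx
from flax import serialization

def ty_from_state_dict(graph: nx.DiGraph, state_dict: Dict[str, Any]) -> nx.DiGraph:
    # Illustrative inverse of ty_to_state_dict: rebuild nodes and edges from
    # the nested dicts produced above (networkx-style API assumed).
    restored = graph.__class__()
    for node, node_dict in state_dict['nodes'].items():
        restored.add_node(node, **node_dict)
    for source, targets in state_dict['edges'].items():
        for target, edge_dict in targets.items():
            restored.add_edge(source, target, **edge_dict)
    return restored

# Register both hooks so serialization handles the graph class automatically.
serialization.register_serialization_state(nx.DiGraph, ty_to_state_dict, ty_from_state_dict)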
Example #2
def serialize_MCState(vstate):
    state_dict = {
        "variables": serialization.to_state_dict(vstate.variables),
        "sampler_state": serialization.to_state_dict(vstate.sampler_state),
        "n_samples": vstate.n_samples,
        "n_discard": vstate.n_discard,
    }
    return state_dict
Example #3
def serialize_MCMixedState(vstate):
    state_dict = {
        "variables": serialization.to_state_dict(vstate.variables),
        "sampler_state": serialization.to_state_dict(vstate.sampler_state),
        "diagonal": serialization.to_state_dict(vstate.diagonal),
        "n_samples": vstate.n_samples,
        "n_discard_per_chain": vstate.n_discard_per_chain,
    }
    return state_dict
Example #4
def _load_optimizer(optimizer, ckpt, allow_missing=False):
    """Loads the optimizer from the state dict."""
    init_keys = set(dict(tree.flatten_with_path(ckpt["target"])))
    model_keys = set(dict(tree.flatten_with_path(optimizer.target)))
    missing_in_model = init_keys.difference(model_keys)
    missing_in_init = model_keys.difference(init_keys)
    missing = model_keys.symmetric_difference(init_keys)
    print("init - model keys: %s", str(missing_in_model))
    print("model - init keys: %s", str(missing_in_init))
    print("difference: %s", str(missing))

    if not allow_missing:
        if missing_in_init:
            raise ValueError(
                "Checkpoints must match exactly if `allow_missing=False`. "
                "Checkpoint missing %s" % str(missing_in_init))

    for param_path in missing_in_init:

        def get_path(d, path):
            print(path)
            print("get")
            for k in path:
                print(k)
                d = d[k]
            return d

        def set_path(d, path, value):
            print("set")
            for k in path[:-1]:
                if k not in d:
                    d[k] = dict()
                d = d[k]
            k = path[-1]
            if k in d:
                if value.shape != d[k].shape:
                    raise ValueError("Shape mismatch: %s" % str(
                        (k, value.shape, d[k].shape)))
            d[k] = value
            return d

        target_param = get_path(optimizer.target, param_path)
        set_path(ckpt["target"], param_path, target_param)

        try:
            target_opt_state = get_path(optimizer.state.param_states,
                                        param_path)
            target_opt_state = serialization.to_state_dict(target_opt_state)
            set_path(ckpt["state"]["param_states"], param_path,
                     target_opt_state)
        except TypeError:
            print(
                f"unable to restore state for {param_path}. Resetting state.")
            ckpt["state"] = serialization.to_state_dict(optimizer.state)

    return serialization.from_state_dict(optimizer, ckpt)
Example #5
 def to_state_dict(x):
     state_dict = {
         name: serialization.to_state_dict(getattr(x, name))
         for name in data_fields
         if name not in skip_serialize_fields
     }
     return state_dict
Example #6
 def test_optimizer_serialization(self):
     rng = random.PRNGKey(0)
     module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
     _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
     model = nn.Model(module, initial_params)
     optim_def = optim.Momentum(learning_rate=1.)
     optimizer = optim_def.create(model)
     state = serialization.to_state_dict(optimizer)
     expected_state = {
         'target': {
             'params': {
                 'kernel': onp.ones((1, 1)),
                 'bias': onp.zeros((1, )),
             }
         },
         'state': {
             'step': 0,
             'param_states': {
                 'params': {
                     'kernel': {
                         'momentum': onp.zeros((1, 1))
                     },
                     'bias': {
                         'momentum': onp.zeros((1, ))
                     },
                 }
             }
         },
     }
     self.assertEqual(state, expected_state)
     state = jax.tree_map(lambda x: x + 1, expected_state)
     restored_optimizer = serialization.from_state_dict(optimizer, state)
     optimizer_plus1 = jax.tree_map(lambda x: x + 1, optimizer)
     self.assertEqual(restored_optimizer, optimizer_plus1)
Example #7
 def test_statedict(self):
   d = {'a': jnp.array([1.0]),
        'b': {'c': jnp.array([2.0]),
              'd': jnp.array([3.0])}}
   dg = DotGetter(d)
   ser = serialization.to_state_dict(dg)
   deser = serialization.from_state_dict(dg, ser)
   self.assertEqual(d, deser)
Example #8
def get_suffix_module_pairs(module_tree) -> List[Tuple[str, Type["Module"]]]:
    """Helper for naming pytrees of submodules."""
    if isinstance(module_tree, Module):
        return [('', module_tree)]
    else:
        flat_tree = traverse_util.flatten_dict(
            serialization.to_state_dict(module_tree))
        return [('_' + '_'.join(k), v) for k, v in flat_tree.items()]
Example #9
def _get_suffix_value_pairs(
        tree_or_leaf: Any) -> List[Tuple[str, Type["Module"]]]:
    """Helper for naming pytrees of submodules."""
    dict_or_leaf = serialization.to_state_dict(tree_or_leaf)
    if dict_or_leaf == {} or not isinstance(dict_or_leaf, dict):
        return [('', tree_or_leaf)]
    else:
        flat_dict = traverse_util.flatten_dict(dict_or_leaf)
        return [('_' + '_'.join(k), v) for k, v in flat_dict.items()]
Example #10
def _map_over_modules_in_tree(fn, tree_or_leaf):
  """Helper for mapping function over submodules."""
  dict_or_leaf = serialization.to_state_dict(tree_or_leaf)
  if not isinstance(dict_or_leaf, dict) or dict_or_leaf == {}:
    return fn('', tree_or_leaf)
  else:
    flat_dict = traverse_util.flatten_dict(dict_or_leaf, keep_empty_nodes=True)
    mapped_flat_dict = {k: fn('_' + '_'.join(k), v)
                        for k, v in _sorted_items(flat_dict)}
    return serialization.from_state_dict(
        tree_or_leaf, traverse_util.unflatten_dict(mapped_flat_dict))
Example #11
    def test_dataclass_serialization(self):
        p = Point(x=1, y=2, meta={'dummy': True})
        state_dict = serialization.to_state_dict(p)
        self.assertEqual(state_dict, {
            'x': 1,
            'y': 2,
        })
        restored_p = serialization.from_state_dict(p, {'x': 3, 'y': 4})
        expected_p = Point(x=3, y=4, meta={'dummy': True})
        self.assertEqual(restored_p, expected_p)

        with self.assertRaises(ValueError):  # invalid field
            serialization.from_state_dict(p, {'z': 3})
        with self.assertRaises(ValueError):  # missing field
            serialization.from_state_dict(p, {'x': 3})
Example #12
 def test_model_serialization(self):
     rng = random.PRNGKey(0)
     module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
     _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
     model = nn.Model(module, initial_params)
     state = serialization.to_state_dict(model)
     self.assertEqual(state, {
         'params': {
             'kernel': onp.ones((1, 1)),
             'bias': onp.zeros((1, )),
         }
     })
     state = {
         'params': {
             'kernel': onp.zeros((1, 1)),
             'bias': onp.zeros((1, )),
         }
     }
     restored_model = serialization.from_state_dict(model, state)
     self.assertEqual(restored_model.params, state['params'])
Example #13
  def restore_state(self, state_dict):
    """Restore parameter and optimizer state from state dictionary.

    Adapted from
    https://github.com/google-research/t5x/blob/main/t5x/optimizers.py. Includes
    support to handle `optax.EmptyState`.

    Args:
      state_dict: Contains desired new parameters and optimizer state

    Returns:
      Updated train state.
    """
    params = serialization.from_state_dict(self.params, state_dict["params"])

    # Get all the possible keys in the reference optimizer state.
    flat_ref_opt_state_dict = traverse_util.flatten_dict(
        serialization.to_state_dict(self.opt_state),
        keep_empty_nodes=True,
        sep="/")

    flat_src_opt_state_dict = dict(
        traverse_util.flatten_dict(state_dict["opt_state"], sep="/"))
    # Adding the empty paths back to flat_src_opt_state_dict.
    for k, v in flat_ref_opt_state_dict.items():
      if k in flat_src_opt_state_dict:
        continue
      # The key is not in the input state dict, presumably because it
      # corresponds to an empty dict.
      if v != traverse_util.empty_node:
        raise ValueError(
            f"Failed to restore optimizer state, path {k} is not present "
            "in the input optimizer state dict.")
      flat_src_opt_state_dict[k] = v

    # Restore state from the enhanced state dict.
    opt_state = serialization.from_state_dict(
        self.opt_state,
        traverse_util.unflatten_dict(flat_src_opt_state_dict, sep="/"))
    return self.replace(params=params, opt_state=opt_state)
Example #14
    def test_collection_serialization(self):
        @struct.dataclass
        class DummyDataClass:
            x: float

            @classmethod
            def initializer(cls, key, shape):
                del shape, key
                return cls(x=0.)

        class StatefulModule(nn.Module):
            def apply(self):
                state = self.state('state', (), DummyDataClass.initializer)
                state.value = state.value.replace(x=state.value.x + 1.)

        # use stateful
        with nn.stateful() as state:
            self.assertEqual(state.as_dict(), {})
            StatefulModule.init(random.PRNGKey(0))
        self.assertEqual(state.as_dict(),
                         {'/': {
                             'state': DummyDataClass(x=1.)
                         }})
        with nn.stateful(state) as new_state:
            StatefulModule.call({})
        self.assertEqual(new_state.as_dict(),
                         {'/': {
                             'state': DummyDataClass(x=2.)
                         }})
        serialized_state_dict = serialization.to_state_dict(new_state)
        self.assertEqual(serialized_state_dict, {'/': {'state': {'x': 2.}}})
        deserialized_state = serialization.from_state_dict(
            state, serialized_state_dict)
        self.assertEqual(state.as_dict(),
                         {'/': {
                             'state': DummyDataClass(x=1.)
                         }})
        self.assertEqual(new_state.as_dict(), deserialized_state.as_dict())
Example #15
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):

    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    #     --------------------------------------
    #     Data
    n_atoms = 4
    batch_size = 768  #1024#768#512#256#128#64#32
    Dtr, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xt, gXt, gXct, yt = Dt
    print(Xtr.shape, gXtr.shape, gXctr.shape, ytr.shape)
    # --------------------------------
    #     BATCHES

    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx], ytr[
                    batch_idx]

    batches = data_stream()
    # --------------------------------

    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(
        N, n_atoms, l, i0),
          file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    #     --------------------------------------
    #     initialize NN

    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]  #         x = jnp.ones((1,Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

#     Initialize parameters

    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)

        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
    f.close()
    init_params = params

    #     --------------------------------------
    #     Phys functions

    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

    '''
#     WRONG
    @jit
    def f_nac_coup_i(gH_diab,eigvect_): #for a single cartesian dimension
        temp = jnp.dot(gH_diab,eigvect_[:,0])
        return jnp.vdot(eigvect_[:,1],temp)
    @jit
    def f_nac_coup(params,x):
        eigval_, eigvect_ = f_adiab(params,x)
        gy_diab = jac_nn_diab(params,x)
        gy_diab = jnp.reshape(gy_diab.T,(12,2,2))
        g_coup = vmap(f_nac_coup_i,(0,None))(gy_diab,eigvect_)
        return g_coup
    '''

    #     --------------------------------------
    #     Validation loss functions

    @jit
    def f_validation(params):
        y_pred = nn_adiab(params, Xt)
        diff_y = y_pred - yt
        z = jnp.linalg.norm(diff_y)
        return z

    @jit
    def f_jac_validation(params):
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, Xt)
        diff_y = gX_pred - gXt
        z = jnp.linalg.norm(diff_y)
        return z

    '''
    @jit
    def f_nac_validation(params):
        g_nac_coup = vmap(f_nac_coup,(None,0))(params,Xt)
        diff_y = g_nac_coup - gXct
        z = jnp.linalg.norm(diff_y)
        return z 
    '''
    #     --------------------------------------
    #    training loss functions
    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  #Ha2cm*
        loss = jnp.linalg.norm(diff_y)
        return loss

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        return jnp.linalg.norm(diff_g_X)

    '''    
    @jit
    def f_loss_nac(params,batch):
        X_inputs, _,gXc_inputs,y_true = batch
        g_nac_coup = vmap(f_nac_coup,(None,0))(params,x)
        diff_y = g_nac_coup - gXc_inputs
        z = jnp.linalg.norm(diff_y)
        return z 
    '''
    #     ------
    @jit
    def f_loss(params, batch):
        loss_ad_energy = f_loss_ad_energy(params, batch)
        #         loss_jac_energy = f_loss_jac(params,batch)
        loss = loss_ad_energy  #+ rho_g*loss_jac_energy
        return loss


#     --------------------------------------
#     Optimization  and Training

#     Perform a single training step.

    @jit
    def train_step(optimizer, batch):  #, learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, batch)
        optimizer = optimizer.apply_gradient(grad)  #, {"learning_rate": lr}
        return optimizer, loss

    optimizer = optim.Adam(learning_rate=lr,
                           weight_decay=w_decay).create(init_params)
    optimizer = jax.device_put(optimizer)

    loss0 = 1E16
    loss0_tot = 1E16
    itercount = itertools.count()
    f_params = init_params
    for epoch in range(n_epochs):
        for _ in range(n_batches):
            optimizer, loss = train_step(optimizer, next(batches))

        params = optimizer.target
        loss_tot = f_validation(params)

        if epoch % 10 == 0:
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss < loss0:
            loss0 = loss
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss_tot < loss0_tot:
            loss0_tot = loss_tot
            f_params = params
            dict_output = serialization.to_state_dict(params)
            jnp.save(f_w_nn, dict_output)  #unfreeze()

    f = open(f_out, 'a+')
    print('---------------------------------', file=f)
    print('Training time =  %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()

    #     --------------------------------------
    #     Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g),
          file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points  = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)

    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    print('RMSE(tr) = {} [cm-1]'.format(loss0), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss0,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g
    }

    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time =  %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()
Example #16
File: utils.py, Project: weiningwei/jaxfg
 def _to_state_dict(x: T):
     state_dict = {
         name: serialization.to_state_dict(getattr(x, name))
         for name in field_names
     }
     return state_dict
Example #17
def _frozen_dict_state_dict(xs):
  return {key: serialization.to_state_dict(value) for key, value in xs.items()}
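_frozen_dict_state_dict only covers the serialize direction; a matching restore helper (an illustrative sketch, not necessarily Flax's internal implementation) restores each entry with from_state_dict and re-freezes the result:

from flax import serialization
from flax.core import freeze

def _frozen_dict_restore_state_dict(xs, state_dict):
  # Restore every value against its reference entry in xs, then wrap the
  # plain dict back into an immutable FrozenDict.
  return freeze({
      key: serialization.from_state_dict(value, state_dict[key])
      for key, value in xs.items()
  })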
Example #18
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):

    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    #     --------------------------------------
    #     Data
    n_atoms = 4
    batch_size = 768  #1024#768#512#256#128#64#32
    Dtr, Dval, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xval, gXval, gXcval, yval = Dval
    Xt, gXt, gXct, yt = Dt
    print(Xtr.shape, gXtr.shape, gXctr.shape, ytr.shape)
    # --------------------------------
    #     BATCHES

    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx], ytr[
                    batch_idx]

    batches = data_stream()
    # --------------------------------

    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(
        N, n_atoms, l, i0),
          file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    #     --------------------------------------
    #     initialize NN

    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]  #         x = jnp.ones((1,Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

#     Initialize parameters

    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)

        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
#         print(params)

    f.close()
    init_params = params

    #     --------------------------------------
    #     Phys functions

    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

#     --------------------------------------
#    training loss functions

    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  #Ha2cm*
        return jnp.linalg.norm(diff_y, axis=0)

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        # jnp.linalg.norm(diff_g_X,axis=0)

        diff_g_X0 = diff_g_X[:, 0, :]
        diff_g_X1 = diff_g_X[:, 1, :]
        l0 = jnp.linalg.norm(diff_g_X0)
        l1 = jnp.linalg.norm(diff_g_X1)
        return jnp.stack([l0, l1])

#     ------

    @jit
    def f_loss(params, rho_g, batch):
        rho_g = jnp.exp(rho_g)
        loss_ad_energy = f_loss_ad_energy(params, batch)
        loss_jac_energy = f_loss_jac(params, batch)
        loss = jnp.vdot(jnp.ones_like(loss_ad_energy),
                        loss_ad_energy) + jnp.vdot(rho_g, loss_jac_energy)
        return loss
#     --------------------------------------
#     Optimization  and Training

#     Perform a single training step.

    @jit
    def train_step(optimizer, rho_g, batch):  #, learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, rho_g, batch)
        optimizer = optimizer.apply_gradient(grad)  #, {"learning_rate": lr}
        return optimizer, (loss, grad)

#     @jit

    def train(rho_g, nn_params):
        optimizer = optim.Adam(learning_rate=lr,
                               weight_decay=w_decay).create(nn_params)
        optimizer = jax.device_put(optimizer)

        train_loss = []
        loss0 = 1E16
        loss0_tot = 1E16
        itercount = itertools.count()
        f_params = init_params
        for epoch in range(n_epochs):
            for _ in range(n_batches):
                optimizer, loss_and_grad = train_step(optimizer, rho_g,
                                                      next(batches))
                loss, grad = loss_and_grad

#             f = open(f_out,'a+')
#             print(i,loss,file=f)
#             f.close()

            train_loss.append(loss)
#             params = optimizer.target
#             loss_tot = f_validation(params)

        nn_params = optimizer.target

        return nn_params, loss_and_grad, train_loss

    @jit
    def val_step(optimizer, nn_params):  #, learning_rate_fn, model

        rho_g_prev = optimizer.target
        nn_params, loss_and_grad_train, train_loss_iter = train(
            rho_g_prev, nn_params)
        loss_train, grad_loss_train = loss_and_grad_train

        grad_fn_val = jax.value_and_grad(f_loss, argnums=1)
        loss_val, grad_val = grad_fn_val(nn_params, optimizer.target, Dval)
        optimizer = optimizer.apply_gradient(
            grad_val)  #, {"learning_rate": lr}
        return optimizer, nn_params, (loss_val, loss_train,
                                      train_loss_iter), (grad_loss_train,
                                                         grad_val)

#     Initialize rho_G

    rng = random.PRNGKey(0)
    rng, subkey = jax.random.split(rng)

    rho_G0 = random.uniform(subkey, shape=(2, ), minval=5E-4, maxval=0.025)
    rho_G0 = jnp.log(rho_G0)
    print('Initial lambdas', rho_G0)
    init_G = rho_G0  #

    optimizer_out = optim.Adam(learning_rate=2E-4,
                               weight_decay=0.).create(init_G)
    optimizer_out = jax.device_put(optimizer_out)

    f_params = init_params

    for i in range(50000):
        start_va_time = time.time()
        optimizer_out, f_params, loss_all, grad_all = val_step(
            optimizer_out, f_params)

        rho_g = optimizer_out.target
        loss_val, loss_train, train_loss_iter = loss_all
        grad_loss_train, grad_val = grad_all

        loss0_tot = f_loss(f_params, rho_g, Dt)

        dict_output = serialization.to_state_dict(f_params)
        jnp.save(f_w_nn, dict_output)  #unfreeze()

        f = open(f_out, 'a+')
        #         print(i,rho_g, loss0, loss0_tot, (time.time() - start_va_time),file=f)
        print(i, loss_val, loss_train, (time.time() - start_va_time), file=f)
        print(jnp.exp(rho_g), file=f)
        print(grad_val, file=f)
        #         print(train_loss_iter ,file=f)
        #         print(grad_val,file=f)
        #         print(grad_loss_train,file=f)
        f.close()


#     --------------------------------------
#     Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g),
          file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points  = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)

    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    print('RMSE(tr) = {} [cm-1]'.format(loss_train), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss_train,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g
    }

    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time =  %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()
Example #19
 def state_dict(self):
     return serialization.to_state_dict(
         {'state': serialization.to_state_dict(self.state)})
Example #20
def train(config, model_def, device_batch_size, eval_ds, num_steps,
          steps_per_epoch, steps_per_eval, train_ds, image_size, data_source,
          workdir):
  """Train model."""

  make_lr_fn = schedulers.get_make_lr_fn(config)
  make_temp_fn = schedulers.get_make_temp_fn(config)
  make_step_size_fn = schedulers.get_make_step_size_fn(config)
  if jax.host_count() > 1:
    raise ValueError('CIFAR10 example should not be run on '
                     'more than 1 host due to preconditioner updating.')

  initial_step = 0  # TODO(basv): load from checkpoint.
  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes comparison of runs in TensorBoard easier.
  # with writer.summary_writer.as_default():
  writer.write_hparams(dict(config))

  rng = random.PRNGKey(config.seed)
  rng, opt_rng, init_key, sampler_rng = jax.random.split(rng, 4)

  base_learning_rate = config.learning_rate

  # Create the model.
  model, state = create_model(rng, device_batch_size, image_size, model_def)
  parameter_overview.log_parameter_overview(model.params)
  state = jax_utils.replicate(state)

  train_size = data_source.TRAIN_IMAGES

  with flax.deprecated.nn.stochastic(init_key):
    optimizer = create_optimizer(config, model, base_learning_rate, train_size,
                                 sampler_rng)
  del model  # Don't keep a copy of the initial model.

  # Learning rate schedule
  learning_rate_fn = make_lr_fn(base_learning_rate, steps_per_epoch)
  temperature_fn = make_temp_fn(config.base_temp, steps_per_epoch)
  step_size_fn = make_step_size_fn(steps_per_epoch)

  p_eval_step, _, p_train_step, p_update_grad_vars = make_step_functions(
      config, config.l2_reg, learning_rate_fn, train_size, temperature_fn,
      step_size_fn)

  # Create dataset batch iterators.
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Gather metrics.
  train_metrics = []
  epoch = 0

  # Ensemble.
  ensemble = []
  ensemble_logits = []
  ensemble_labels = []
  ensemble_probs = []

  def ensemble_add_step(step):
    if config.lr_schedule == 'cosine':
      # Add if learning rate jumps up again in the next step.
      increase = step_size_fn(step) < step_size_fn(step + 1) - 1e-8
      _, temp_end = ast.literal_eval(config.temp_ramp)
      past_burn_in = step >= steps_per_epoch * temp_end
      return increase and past_burn_in

    elif config.lr_schedule == 'constant':
      if (step + 1) % steps_per_epoch == 0:
        return True
    return False

  logging.info('Starting training loop at step %d.', initial_step)

  for step in range(initial_step, num_steps):
    if config.optimizer in ['sym_euler'] and (step) % steps_per_epoch == 0:
      optimizer, rng = update_preconditioner(config, optimizer,
                                             p_update_grad_vars, rng, state,
                                             train_iter)
    # Generate a PRNG key that will be rolled into the batch
    step_key = jax.random.fold_in(rng, step)
    opt_step_rng = jax.random.fold_in(opt_rng, step)

    # Load and shard the TF batch
    batch = next(train_iter)
    batch = input_pipeline.load_and_shard_tf_batch(config, batch)
    if not config.debug_run:
      # Shard the step PRNG key
      # Don't shard the optimizer rng, as it should be equal among all machines.
      sharded_keys = common_utils.shard_prng_key(step_key)
    else:
      sharded_keys = step_key

    # Train step
    optimizer, state, metrics = p_train_step(optimizer, state, batch,
                                             sharded_keys, opt_step_rng)
    train_metrics.append(metrics)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
    if step == initial_step:
      initial_train_metrics = get_metrics(config, train_metrics)
      train_summary = jax.tree_map(lambda x: x.mean(), initial_train_metrics)
      train_summary = {'train_' + k: v for k, v in train_summary.items()}
      logging.log(logging.INFO, 'initial metrics = %s',
                  str(train_summary.items()))

    if (step + 1) % steps_per_epoch == 0:
      # We've finished an epoch
      # Save model params/state.

      train_metrics = get_metrics(config, train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)

      train_summary = {'train_' + k: v for k, v in train_summary.items()}

      writer.write_scalars(epoch, train_summary)
      # Reset train metrics
      train_metrics = []

      # Evaluation
      if config.do_eval:
        eval_metrics = []
        eval_logits = []
        eval_labels = []
        for _ in range(steps_per_eval):
          eval_batch = next(eval_iter)
          # Load and shard the TF batch
          eval_batch = input_pipeline.load_and_shard_tf_batch(
              config, eval_batch)
          # Step
          logits, labels, metrics = p_eval_step(optimizer.target, state,
                                                eval_batch)
          eval_metrics.append(metrics)
          eval_logits.append(logits)
          eval_labels.append(labels)
        eval_metrics = get_metrics(config, eval_metrics)
        # Get eval epoch summary for logging
        eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
        eval_summary = {'eval_' + k: v for k, v in eval_summary.items()}
        writer.write_scalars(epoch, eval_summary)

      if config.algorithm == 'sgmcmc' and ensemble_add_step(step):
        ensemble.append((serialization.to_state_dict(optimizer.target), state))

      if config.algorithm == 'sgmcmc' and ensemble_add_step(
          step) and len(ensemble) >= 1:
        # Gather predictions for this ensemble sample.
        eval_logits = jnp.concatenate(eval_logits, axis=0)
        eval_probs = jax.nn.softmax(eval_logits, axis=-1)
        eval_labels = jnp.concatenate(eval_labels, axis=0)
        # Ensure that labels are consistent between predict runs.
        if ensemble_labels:
          assert jnp.allclose(
              eval_labels,
              ensemble_labels[0]), 'Labels unordered between eval runs.'

        ensemble_logits.append(eval_logits)
        ensemble_probs.append(eval_probs)
        ensemble_labels.append(eval_labels)

        # Compute ensemble predictions over last config.ensemble_size samples.
        ensemble_last_probs = jnp.mean(
            jnp.array(ensemble_probs[-config.ensemble_size:]), axis=0)
        ensemble_metrics = train_functions.compute_metrics_probs(
            ensemble_last_probs, ensemble_labels[0])
        ensemble_summary = jax.tree_map(lambda x: x.mean(), ensemble_metrics)
        ensemble_summary = {'ens_' + k: v for k, v in ensemble_summary.items()}
        ensemble_summary['ensemble_size'] = min(config.ensemble_size,
                                                len(ensemble_probs))
        writer.write_scalars(epoch, ensemble_summary)

      epoch += 1

  return ensemble, optimizer
Example #21
 def save_wordvecs(self, fname):
     # generate a word -> vec dict
     state = ser.to_state_dict(self)
     with open(fname, 'wb') as f:
         pickle.dump(state, f)
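The corresponding loader is not shown; a minimal sketch (the method name load_wordvecs is assumed, not taken from the project) unpickles the file and rebuilds the object with ser.from_state_dict:

 def load_wordvecs(self, fname):
     # Illustrative counterpart to save_wordvecs: read back the pickled
     # state dict and restore it onto this object via ser.from_state_dict.
     with open(fname, 'rb') as f:
         state = pickle.load(f)
     return ser.from_state_dict(self, state)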
Example #22
def _ckpt_state_dict(checkpoint_state):
  return serialization.to_state_dict({
      'pytree': serialization.to_state_dict(
          checkpoint_state.pytree),
      'pystate': serialization.to_state_dict(checkpoint_state.pystate),
  })
Example #23
def serialize_ExactState(vstate):
    state_dict = {
        "variables": serialization.to_state_dict(vstate.variables),
    }
    return state_dict