Example #1
File: greedy.py Project: ketanbj/castnet
 def replicate(self, content, source, dest):
     print 'Greedy: replicate file %s from %s to %s' % (content, source,
                                                         dest)
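     # if source and dest are the same server, only the replica count is
     # updated; otherwise the content is actually copied via util.replicate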
     if source == dest:
         if dest not in self.replica_map[content]:
             self.replica_map[content][dest] = 0
         self.replica_map[content][dest] += 1
     else:
         util.replicate(content, source, dest)
Example #2
 def replicate(self, content, source, dest):
     print 'Greedy: replicate file %s from %s to %s' % (content, source,
                                                         dest)
     if source == dest:
         # a server can have at most one replica, so we replicate to the second-nearest server
         dest_simulation_ip = util.convert_to_simulation_ip(dest)
         candidate_servers = self.server_set - set(
             self.replica_map[content])
         dest = util.find_closest_servers_with_ip(
             dest_simulation_ip, candidate_servers)[0]['server']
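     # copy the content to the (possibly re-chosen) destination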
     util.replicate(content, source, dest)
Example #3
File: greedy.py Project: ketanbj/castnet
 def add_replica(self, request_delta, replica_delta):
     I = []
     for c in self.content_set:
         if c in self.access_map:
             for a in self.access_map[c].keys():
                 # add a small amount of requests for content c from client a
                 self.access_map[c][a] += request_delta
                 # test whether the current replicas can handle that many requests
                 if not self.enough_replica():
                     I.append((a, c))
                 # backtracking: undo the added requests
                 self.access_map[c][a] -= request_delta
     max_satisfied_num = 0
     best_c = None
     best_s = None
     # find the server s to replicate content c so that
     # maximum number of starved clients can be satisfied
     for a, c in I:
         for s in self.server_set:
             satisfied_num = 0
             self.access_map[c][a] += request_delta
             if s not in self.replica_map[c]:
                 self.replica_map[c][s] = 0
             self.replica_map[c][s] += replica_delta
             if self.enough_replica():
                 satisfied_num += 1
             self.access_map[c][a] -= request_delta
             self.replica_map[c][s] -= replica_delta
             if self.replica_map[c][s] == 0:
                 self.replica_map[c].pop(s)
             if (satisfied_num > max_satisfied_num):
                 max_satisfied_num = satisfied_num
                 best_c = c
                 best_s = s
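     # if a single replication helps, perform it; otherwise replicate every
     # under-provisioned content to all servers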
     if max_satisfied_num > 0:
         source = self.replica_map[best_c].iterkeys().next()
         if source == best_s:
             # can't hold more than 1 replica, replicate to a random other server
             best_s = random.sample(self.server_set - set([source]), 1)[0]
         self.replicate(best_c, source, best_s)
     else:
         # replicate everything
         print 'replicate to all servers'
         for content in self.content_set:
             if not self.enough_replica_for_content(content):
                 if content not in self.replica_map:
                     continue
                 source = self.replica_map[content].iterkeys().next()
                 # select the first non-zero replica
                 for server in self.server_set:
                     print 'replicate ' + 'content: ' + content + ' from: ' + source + ' to ' + server
                     util.replicate(content, source, server)
Example #4
File: loadcfg.py Project: areinecke/cylc
def load_combined( FILE1, descr1,
                   FILE2, descr2,
                   SPEC, upgrader=None,
                   do_expand=False, verbose=True):
    """
    Parse, upgrade, validate, combine/override, and expand two parsec config files.
    """
    cfg1 = load_single( FILE1, SPEC, descr1, upgrader, False, verbose )
    cfg2 = load_single( FILE2, SPEC, descr2, upgrader, False, verbose )

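    # entries from the second config override/extend those from the first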
    if cfg2:
        replicate( cfg1, cfg2 )
    if do_expand:
        cfg = expand( cfg1, SPEC )
    else:
        cfg = cfg1
    return cfg
Example #5
File: loadcfg.py Project: areinecke/cylc
def load_combined(FILE1,
                  descr1,
                  FILE2,
                  descr2,
                  SPEC,
                  upgrader=None,
                  do_expand=False,
                  verbose=True):
    """
    Parse, upgrade, validate, combine/override, and expand two parsec config files.
    """
    cfg1 = load_single(FILE1, SPEC, descr1, upgrader, False, verbose)
    cfg2 = load_single(FILE2, SPEC, descr2, upgrader, False, verbose)

    if cfg2:
        replicate(cfg1, cfg2)
    if do_expand:
        cfg = expand(cfg1, SPEC)
    else:
        cfg = cfg1
    return cfg
Example #6
    def train(self, key, data_loader, model, start_it, checkpoint_iters=5000, save_hook=None):

        losses = []

        # Copy the model state
        replicated_model_state = util.replicate((self.n_gpus,), model.state)
        replicated_opt_state   = util.replicate((self.n_gpus,), self.opt_state)

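        # scale factor used below to convert the loss from nats to bits per dimension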
        bits_per_dim_scale = jnp.prod(model.x_shape)*jnp.log(2)

        pbar = tqdm(np.arange(start_it, self.max_iterations), initial=start_it)
        for i in pbar:

            # Save a checkpoint
            if(i%checkpoint_iters == 0):

                # Update the model and the optimizer state
                model_state    = util.unreplicate(replicated_model_state)
                opt_state      = util.unreplicate(replicated_opt_state)
                model.state    = model_state
                self.opt_state = opt_state

                # Checkpoint the current model.  Also reset the losses array so that it doesn't get too big
                losses = save_hook(i, key, losses)

            # Make sure we do this after the checkpoint
            key, *keys = fast_split(key, 2)

            # Take a gradient step
            val, replicated_model_state, replicated_opt_state = self.train_step(keys[0], i, replicated_model_state, replicated_opt_state, data_loader)

            # Convert to bits per dimension and save
            val /= bits_per_dim_scale
            losses.append(val)

            progress_str = f'Bits/Dim: {val:.3f}'
            pbar.set_description(progress_str)
Example #7
def train(opts):

    run = u.DTS()
    logging.info("starting run %s", run)

    # init w & b
    wandb_enabled = opts.group is not None
    if wandb_enabled and u.primary_host():
        wandb.init(project='ensemble_net',
                   group=opts.group,
                   name=run,
                   reinit=True)
        # save group again explicitly to work around sync bug that drops
        # group when 'wandb off'
        wandb.config.group = opts.group
        wandb.config.seed = opts.seed
        wandb.config.max_conv_size = opts.max_conv_size
        wandb.config.dense_kernel_size = opts.dense_kernel_size
        wandb.config.models_per_device = opts.models_per_device
        wandb.config.learning_rate = opts.learning_rate
        wandb.config.batch_size = opts.batch_size
        wandb.config.steps_per_batch = opts.steps_per_batch
    else:
        logging.info("not using wandb and/or not primary host")

    logging.info("build_model")
    model = models.build_model(opts)

    num_devices = len(jax.local_devices())
    num_models = num_devices * opts.models_per_device

    # we make two rngs; one that is distinct per host and one
    # that will be common across the pod
    host_rng = jax.random.PRNGKey(opts.seed ^ jax.host_id())
    pod_rng = jax.random.PRNGKey(opts.seed * 2)  # o_O

    logging.info("init models")
    keys = jax.random.split(host_rng, num_models)
    logging.debug("model keys %s" % list(keys))
    representative_input = jnp.zeros((1, 64, 64, 3))
    params = vmap(lambda k: model.init(k, representative_input))(keys)

    logging.info("init optimisers")
    opt = optax.adam(opts.learning_rate)
    opt_states = vmap(opt.init)(params)

    def reshape_for_devices_and_shard(p):
        return u.shard(
            u.reshape_leading_axis(p, (num_devices, opts.models_per_device)))

    logging.info("treemap reshape params")
    params = tree_map(reshape_for_devices_and_shard, params)
    opt_states = tree_map(reshape_for_devices_and_shard, opt_states)

    # -----------------------------------
    # prepare loss / training functions

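    # ensemble loss: the logits of the sub-model selected on each device are
    # summed across devices before the softmax cross-entropy is taken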
    def mean_ensemble_xent(params, x, y_true):
        logits = model.apply(params, x)
        logits = psum(logits, axis_name='device')
        return jnp.mean(softmax_cross_entropy(logits, y_true))

    def update(params, opt_state, sub_model_idx, x, y_true):
        # select the sub model & corresponding optimiser state to use
        sub_params = tree_map(lambda v: v[sub_model_idx], params)
        sub_opt_state = tree_map(lambda v: v[sub_model_idx], opt_state)
        # calculate loss and gradients; summing logits over all selected models
        losses, grads = value_and_grad(mean_ensemble_xent)(sub_params, x,
                                                           y_true)
        # apply optimiser
        updates, sub_opt_state = opt.update(grads, sub_opt_state)
        sub_params = optax.apply_updates(sub_params, updates)

        # assign updated values back into params and optimiser state
        def update_sub_model(values, update_value):
            return jax.ops.index_update(values, sub_model_idx, update_value)

        params = tree_multimap(update_sub_model, params, sub_params)
        opt_state = tree_multimap(update_sub_model, opt_state, sub_opt_state)
        # return
        return params, opt_state, losses

    logging.info("compile pmap update")
    p_update = pmap(update, in_axes=(0, 0, 0, 0, 0), axis_name='device')

    # -----------------------------------
    # prepare evaluation functions

    # plumb batch dimension for models_per_device
    all_models_apply = vmap(model.apply, in_axes=(0, None))
    # plumb batch dimension for num_devices
    all_models_apply = vmap(all_models_apply, in_axes=(0, None))

    def ensemble_logits(params, imgs):
        logits = all_models_apply(params, imgs)
        batch_size = logits.shape[-2]  # since last batch may be smaller
        num_classes = 10
        logits = logits.reshape((-1, batch_size, num_classes))  # (M, B, 10)
        ensemble_logits = jnp.sum(logits, axis=0)  # (B, 10)
        return ensemble_logits

    @jit
    def total_ensemble_xent_loss(params, x, y_true):
        y_pred_logits = ensemble_logits(params, x)
        return jnp.sum(softmax_cross_entropy(y_pred_logits, y_true))

    # --------------------------------
    # run training loop

    for epoch in range(opts.epochs):

        # train for one epoch
        logging.info("data.training_dataset: epoch %d", epoch)

        total_training_loss = 0
        training_num_examples = 0

        # split out a new shuffle seed for this epoch common
        # across pod
        pod_rng, shuffle_seed = jax.random.split(pod_rng)

        # create dataset
        train_ds = data.training_dataset(batch_size=opts.batch_size,
                                         shuffle_seed=shuffle_seed[0],
                                         num_inputs=1,
                                         sample_data=opts.sample_data)

        for imgs, labels in train_ds:

            logging.debug("labels %s" % labels)

            # replicate batch across M devices
            # (M, B, H, W, 3)
            imgs = u.replicate(imgs, replicas=num_devices)
            labels = u.replicate(labels, replicas=num_devices)  # (M, B)

            # run across all the 4 rotations
            # for k in range(4):
            #   rotated_imgs = rot90_imgs(imgs, k)

            # run some steps for this set, each with a different set of
            # dropout idxs
            for _ in range(opts.steps_per_batch):
                host_rng, dropout_key = jax.random.split(host_rng)
                logging.debug("dropout_key %s" % dropout_key[0])
                sub_model_idxs = jax.random.randint(
                    dropout_key,
                    minval=0,
                    maxval=opts.models_per_device,
                    shape=(num_devices, ))
                logging.debug("sub_model_idxs %s" % sub_model_idxs)
                params, opt_states, losses = p_update(params, opt_states,
                                                      sub_model_idxs, imgs,
                                                      labels)
                logging.debug("losses %s" % losses)

                total_training_loss += jnp.sum(losses)
                training_num_examples += len(losses)

        mean_training_loss = float(total_training_loss / training_num_examples)
        logging.info("mean training loss %f", mean_training_loss)

        # post epoch stats collection and housekeeping on primary host only
        if u.primary_host():
            # checkpoint model
            ckpt_file = f"saved_models/{run}/ckpt_{epoch:04d}"
            u.ensure_dir_exists_for_file(ckpt_file)
            with open(ckpt_file, "wb") as f:
                pickle.dump(params, f)

            # run validation
            total_validation_loss = 0
            validation_num_examples = 0
            validation_data = data.validation_dataset(
                batch_size=opts.batch_size, sample_data=opts.sample_data)
            for imgs, labels in validation_data:
                total_validation_loss += total_ensemble_xent_loss(
                    params, imgs, labels)
                validation_num_examples += len(labels)
            mean_validation_loss = float(total_validation_loss /
                                         validation_num_examples)
            logging.info("mean validation loss %f", mean_validation_loss)

            if wandb_enabled:
                wandb.log({'training_loss': mean_training_loss}, step=epoch)
                wandb.log({'validation_loss': mean_validation_loss},
                          step=epoch)

    # close out wandb run
    if u.primary_host():
        if wandb_enabled:
            wandb.log({'final_validation_loss': mean_validation_loss},
                      step=opts.epochs)
            wandb.join()
        else:
            logging.info("finished %s final validation_loss %f" %
                         (run, mean_validation_loss))
        # return validation loss to ax
        return mean_validation_loss
    else:
        return None
Example #8
File: config.py Project: dmanubens-zz/cylc
            try:
                self.validate(sparse)
            except Exception, x:
                if strict:
                    raise
                if cylc.flags.verbose:
                    print >> sys.stderr, x
                    print >> sys.stderr, "WARNING: " + title + " validation failed"

            else:
                if not self.sparse:
                    self.sparse = sparse
                else:
                    # already loaded, this must be an override
                    replicate(self.sparse, sparse)

    def validate(self, sparse):
        "Validate sparse config against the file spec."
        validate(sparse, self.spec)
        check_compulsory(sparse, self.spec)

    def expand(self):
        "Flesh out undefined items with defaults, if any, from the spec."
        if not self.dense:
            self.dense = expand(self.sparse, self.spec)

    def get(self, keys=[], sparse=False):
        """
        Retrieve items or sections, sparse or dense, by list of keys:
        [sec1,sec2,item] =>
Example #9
File: config.py Project: raghu330/cylc
            try:
                self.validate( sparse )
            except Exception, x:
                if strict:
                    raise
                if cylc.flags.verbose:
                    print >> sys.stderr, x
                    print >> sys.stderr, "WARNING: " + title + " validation failed"

            else:
                if not self.sparse:
                    self.sparse = sparse
                else:
                    # already loaded, this must be an override
                    replicate( self.sparse, sparse )

    def validate( self, sparse ):
        "Validate sparse config against the file spec."
        validate( sparse, self.spec )
        check_compulsory( sparse, self.spec )

    def expand( self ):
        "Flesh out undefined items with defaults, if any, from the spec."
        if not self.dense:
            self.dense = expand( self.sparse, self.spec )

    def get( self, keys=[], sparse=False ):
        """
        Retrieve items or sections, sparse or dense, by list of keys:
        [sec1,sec2,item] =>
Example #10
# construct and init optimiser

if opts.optimiser == 'adam':
    opt = optax.adam(learning_rate=opts.learning_rate)
elif opts.optimiser == 'lamb':
    opt = optax.lamb(learning_rate=opts.learning_rate,
                     weight_decay=opts.weight_decay)
else:  # sgd
    opt = optax.sgd(learning_rate=opts.learning_rate,
                    momentum=opts.momentum)

opt_state = opt.init(params)

# replicate both model params and optimiser state across devices

params = u.replicate(params)
opt_state = u.replicate(opt_state)

# define training loops and some validation functions


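# per-example cross-entropy: one-hot encode the integer labels and take the
# negative log-softmax of the logits at those labels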
def softmax_cross_entropy(logits, labels):
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)


def mean_cross_entropy(params, x, y_true):
    logits = model.apply(params, x)
    return jnp.mean(softmax_cross_entropy(logits, y_true))