Example #1
    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        self.set_parameters(parameters)
        cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)
        return self.get_parameters(), len(self.trainloader), {}
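For reference, a fit method like this belongs to a Flower NumPyClient. A minimal sketch of how such a client might be connected to a server, assuming a Flower 0.x-era API and a hypothetical CifarClient class wrapping the method above (the class name and server address are assumptions):

import flwr as fl

# CifarClient is a hypothetical NumPyClient subclass exposing fit() as shown above
client = CifarClient(model, trainloader)
# Connect to a Flower server; the address is a placeholder
fl.client.start_numpy_client(server_address="[::]:8080", client=client)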
Example #2
    def fit(self, ins: FitIns) -> FitRes:
        # Set the seed so we are sure to generate the same global batches
        # indices across all clients
        np.random.seed(123)

        print(f"Client {self.cid}: fit")

        weights: Weights = fl.common.parameters_to_weights(ins[0])
        config = ins[1]
        fit_begin = timeit.default_timer()

        # Get training config
        epochs = int(config["epochs"])
        batch_size = int(config["batch_size"])

        # Set model parameters
        #self.model.set_weights(weights)
        set_weights(self.model, weights)
        """
        # Get the data corresponding to this client
        dataset_size = len(self.trainset)
        nb_samples_per_clients = dataset_size // self.nb_clients
        dataset_indices = list(range(dataset_size))
        np.random.shuffle(dataset_indices)

        #Get starting and ending indices w.r.t cid
        start_ind = int(self.cid) * nb_samples_per_clients
        end_ind = (int(self.cid) * nb_samples_per_clients) + nb_samples_per_clients
        train_sampler = torch.utils.data.SubsetRandomSampler(dataset_indices[start_ind:end_ind])

        print(f"Client {self.cid}: sampler {len(train_sampler)}")
        """
        # load IID partitioned dataset
        trainset = dataset_afterpartition(train=True, client_id=int(self.cid))
        testset = dataset_afterpartition(train=False, client_id=int(self.cid))

        # Train model
        #trainloader = torch.utils.data.DataLoader(self.trainset, batch_size=batch_size, shuffle=False, sampler=train_sampler)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=batch_size,
                                                  shuffle=False)

        cifar.train(self.model, trainloader, epochs=epochs, device=DEVICE)

        # Return the refined weights and the number of examples used for training
        #weights_prime: Weights = self.model.get_weights()
        weights_prime: Weights = get_weights(self.model)
        params_prime = fl.common.weights_to_parameters(weights_prime)
        num_examples_train = len(trainset)
        fit_duration = timeit.default_timer() - fit_begin
        return params_prime, num_examples_train, num_examples_train, fit_duration
Example #3
    def fit(self, ins: FitIns) -> FitRes:
        # Set the seed so we are sure to generate the same global batches
        # indices across all clients
        np.random.seed(123)

        print(f"Client {self.cid}: fit")

        weights: Weights = fl.common.parameters_to_weights(ins[0])
        config = ins[1]
        fit_begin = timeit.default_timer()

        # Get training config
        epochs = int(config["epochs"])
        batch_size = int(config["batch_size"])
        rnd = int(config["epoch_global"])
        # Exponentially decay the learning rate with the global round
        lr = 0.001 * 0.98 ** rnd

        # Set model parameters
        #self.model.set_weights(weights)
        set_weights(self.model, weights)

        # get IID dataset
        trainset = dataset_afterpartition(
            train=True,
            client_id=int(self.cid),
            num_partitions=self.nb_clients,
            xy_train_partitions=self.xy_train_partitions,
            xy_test_partitions=self.xy_test_partitions)

        # Train model
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=batch_size,
                                                  shuffle=False)

        cifar.train(self.model,
                    trainloader,
                    epochs=epochs,
                    lr=lr,
                    device=DEVICE)

        # Return the refined weights and the number of examples used for training
        #weights_prime: Weights = self.model.get_weights()
        weights_prime: Weights = get_weights(self.model)
        params_prime = fl.common.weights_to_parameters(weights_prime)
        num_examples_train = len(trainset)
        fit_duration = timeit.default_timer() - fit_begin
        return params_prime, num_examples_train, num_examples_train, fit_duration
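The "epochs", "batch_size", and "epoch_global" entries read from config above are typically sent by the server each round. A minimal sketch of how a strategy could supply them, assuming a Flower 0.x-era FedAvg with on_fit_config_fn (the concrete values are placeholders):

import flwr as fl

def fit_config(rnd: int) -> dict:
    # Per-round config sent to every client; values are placeholders
    return {"epochs": str(1), "batch_size": str(32), "epoch_global": str(rnd)}

strategy = fl.server.strategy.FedAvg(on_fit_config_fn=fit_config)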
Example #4
    def evaluate(weights: fl.common.Weights) -> Optional[Tuple[float, float]]:
        # `config` and `trainset` are expected to come from an enclosing scope
        epochs = int(config["epochs"])
        model = models.resnet18()
        set_weights(model, weights)
        model.to(DEVICE)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=32,
                                                  shuffle=False)
        return cifar.train(model, trainloader, epochs, device=DEVICE)
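A function with this signature is usually registered for centralized (server-side) evaluation. A minimal sketch, assuming a Flower 0.x-era FedAvg strategy; the server address and round count are placeholders:

import flwr as fl

strategy = fl.server.strategy.FedAvg(eval_fn=evaluate)
fl.server.start_server(server_address="[::]:8080",
                       config={"num_rounds": 3},
                       strategy=strategy)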
Example #5
def train(dataset_no):
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        images, labels = cifar.distorted_inputs(dataset_no)

        logits = cifar.inference(images)

        loss = cifar.loss(logits, labels)

        train_op = cifar.train(loss, global_step)

        saver = tf.train.Saver(tf.all_variables())

        summary_op = tf.merge_all_summaries()

        init = tf.initialize_all_variables()

        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

        for step in range(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))

            if step % 100 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
Example #6
def train():
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = cifar.distorted_inputs(
                is_cifar10=FLAGS.dataset == 'cifar-10')

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar.inference(images)

        # Calculate loss.
        loss = cifar.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)
        util.tf.gpu_config(config=config, allow_growth=True)

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                save_checkpoint_secs=30,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=config) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
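For completeness, a minimal sketch of how a train() loop like this is typically launched in TF 1.x-style scripts; the checkpoint-directory handling mirrors the FLAGS.train_dir usage above and is an assumption:

def main(argv=None):
    # Start from an empty checkpoint/summary directory (assumption)
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    train()


if __name__ == '__main__':
    tf.app.run()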