Example #1
    def test_milestones(self):
        self.assertLrEquals(0.1)

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)

            fields = [
                'optimizer', 'lr_schedule', 'learning_rate', 'momentum',
                'weight_decay', 'lr_gamma', 'lr_milestone_steps',
                'lr_warmup_steps'
            ]
            params = ['SGD', '', 0.1, 0.5, 0.0, 0.1, '2ep,4ep,7ep,8ep', '']

            Config().trainer = namedtuple('trainer', fields)(*params)
            self.assertLrEquals(0.1)

            lrs = optimizers.get_lr_schedule(self.optimizer, 10)

            self.assertLrEquals(0.1)
            for _ in range(19):
                lrs.step()
            self.assertLrEquals(1e-1)

            for _ in range(1):
                lrs.step()
            self.assertLrEquals(1e-2)
            for _ in range(19):
                lrs.step()
            self.assertLrEquals(1e-2)

            for _ in range(1):
                lrs.step()
            self.assertLrEquals(1e-3)
            for _ in range(29):
                lrs.step()
            self.assertLrEquals(1e-3)

            for _ in range(1):
                lrs.step()
            self.assertLrEquals(1e-4)
            for _ in range(9):
                lrs.step()
            self.assertLrEquals(1e-4)

            for _ in range(1):
                lrs.step()
            self.assertLrEquals(1e-5)
            for _ in range(100):
                lrs.step()
            self.assertLrEquals(1e-5)
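
The milestone string '2ep,4ep,7ep,8ep' at 10 iterations per epoch corresponds to step boundaries 20, 40, 70 and 80, each multiplying the rate by lr_gamma = 0.1, which is exactly the decay pattern the assertions above walk through. Below is a minimal standalone sketch of the same schedule using plain PyTorch (MultiStepLR) rather than the project's get_lr_schedule helper; it is an illustration under those assumptions, not the implementation under test.

import torch

# Standalone sketch: reproduce the decay the test above asserts with
# torch.optim.lr_scheduler.MultiStepLR. Milestones '2ep,4ep,7ep,8ep' at
# 10 iterations per epoch become step boundaries [20, 40, 70, 80].
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[20, 40, 70, 80], gamma=0.1)

for step in range(80):
    optimizer.step()
    scheduler.step()

# Four decays of 0.1 each: 0.1 * 0.1**4 == 1e-5
print(optimizer.param_groups[0]['lr'])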
Example #2
    def test_warmup(self):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)

            fields = [
                'optimizer', 'lr_schedule', 'learning_rate', 'momentum',
                'weight_decay', 'lr_gamma', 'lr_milestone_steps',
                'lr_warmup_steps'
            ]
            params = ['SGD', '', 0.1, 0.5, 0.0, 0.0, '', '20it']

            Config().trainer = namedtuple('trainer', fields)(*params)

            lrs = optimizers.get_lr_schedule(self.optimizer, 10)

            for i in range(20):
                self.assertLrEquals(i / 20 * 0.1)
                lrs.step()
            self.assertLrEquals(0.1)
            for i in range(100):
                lrs.step()
                self.assertLrEquals(0.1)
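
For the '20it' warmup this test configures, the learning rate ramps linearly from 0 to the base rate over 20 iterations and then stays constant. The following is a minimal sketch of that ramp with a plain PyTorch LambdaLR, independent of the project's get_lr_schedule helper, assuming only the values the assertions above expect.

import torch

# Standalone sketch: linear warmup over 20 iterations, matching the values
# asserted above (i / 20 * 0.1 during warmup, then a constant 0.1).
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.5)
warmup = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(step / 20.0, 1.0))

for i in range(20):
    assert abs(optimizer.param_groups[0]['lr'] - i / 20 * 0.1) < 1e-9
    optimizer.step()
    warmup.step()

# After the warmup window the rate is capped at the base value.
assert optimizer.param_groups[0]['lr'] == 0.1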
Example #3
def main():
    """Main function"""

    # Initialization
    args = parse_args()
    dist = init_workers(args.distributed)
    config = load_config(args)
    os.makedirs(config['output_dir'], exist_ok=True)
    config_logging(verbose=args.verbose)
    logging.info('Initialized rank %i size %i local_rank %i local_size %i',
                 dist.rank, dist.size, dist.local_rank, dist.local_size)
    if dist.rank == 0:
        logging.info('Configuration: %s', config)

    # Setup MLPerf logging
    if args.mlperf:
        mllogger = configure_mllogger(config['output_dir'])
    if dist.rank == 0 and args.mlperf:
        mllogger.event(key=mllog.constants.CACHE_CLEAR)
        mllogger.start(key=mllog.constants.INIT_START)

    # Initialize Weights & Biases logging
    if args.wandb and dist.rank == 0:
        import wandb
        wandb.init(project='cosmoflow',
                   name=args.run_tag,
                   id=args.run_tag,
                   config=config,
                   resume=args.run_tag)

    # Device and session configuration
    gpu = dist.local_rank if args.rank_gpu else None
    if gpu is not None:
        logging.info('Taking gpu %i', gpu)
    configure_session(gpu=gpu,
                      intra_threads=args.intra_threads,
                      inter_threads=args.inter_threads,
                      kmp_blocktime=args.kmp_blocktime,
                      kmp_affinity=args.kmp_affinity,
                      omp_num_threads=args.omp_num_threads)

    # Mixed precision
    if args.amp:
        logging.info('Enabling mixed float16 precision')

        # Suggested bug workaround from https://github.com/tensorflow/tensorflow/issues/38516
        if tf.__version__.startswith('2.2.'):
            from tensorflow.python.keras.mixed_precision.experimental import device_compatibility_check
            device_compatibility_check.log_device_compatibility_check = lambda policy_name, skip_local: None
        tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
        # TF 2.3
        #tf.keras.mixed_precision.set_global_policy('mixed_float16')

    # Start MLPerf logging
    if dist.rank == 0 and args.mlperf:
        log_submission_info(**config.get('mlperf', {}))
        mllogger.end(key=mllog.constants.INIT_STOP)
        mllogger.start(key=mllog.constants.RUN_START)

    # Load the data
    data_config = config['data']
    if dist.rank == 0:
        logging.info('Loading data')
    datasets = get_datasets(dist=dist, **data_config)
    logging.debug('Datasets: %s', datasets)

    # Construct or reload the model
    if dist.rank == 0:
        logging.info('Building the model')
    train_config = config['train']
    initial_epoch = 0
    checkpoint_format = os.path.join(config['output_dir'],
                                     'checkpoint-{epoch:03d}.h5')
    if args.resume and os.path.exists(checkpoint_format.format(epoch=1)):
        # Reload model from last checkpoint
        initial_epoch, model = reload_last_checkpoint(
            checkpoint_format,
            data_config['n_epochs'],
            distributed=args.distributed)
    else:
        # Build a new model
        model = get_model(**config['model'])
        # Configure the optimizer
        opt = get_optimizer(distributed=args.distributed,
                            **config['optimizer'])
        # Compile the model
        model.compile(optimizer=opt,
                      loss=train_config['loss'],
                      metrics=train_config['metrics'])

    if dist.rank == 0:
        model.summary()

    # Save configuration to output directory
    if dist.rank == 0:
        config['n_ranks'] = dist.size
        save_config(config)

    # Prepare the callbacks
    if dist.rank == 0:
        logging.info('Preparing callbacks')
    callbacks = []
    if args.distributed:

        # Broadcast initial variable states from rank 0 to all processes.
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))

        # Average metrics across workers
        callbacks.append(hvd.callbacks.MetricAverageCallback())

    # Learning rate decay schedule
    if 'lr_schedule' in config:
        global_batch_size = data_config['batch_size'] * dist.size
        callbacks.append(
            tf.keras.callbacks.LearningRateScheduler(
                get_lr_schedule(global_batch_size=global_batch_size,
                                **config['lr_schedule'])))

    # Timing
    timing_callback = TimingCallback()
    callbacks.append(timing_callback)

    # Checkpointing and logging from rank 0 only
    if dist.rank == 0:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(checkpoint_format))
        callbacks.append(
            tf.keras.callbacks.CSVLogger(os.path.join(config['output_dir'],
                                                      'history.csv'),
                                         append=args.resume))
        if args.tensorboard:
            callbacks.append(
                tf.keras.callbacks.TensorBoard(
                    os.path.join(config['output_dir'], 'tensorboard')))
        if args.mlperf:
            callbacks.append(MLPerfLoggingCallback())
        if args.wandb:
            callbacks.append(wandb.keras.WandbCallback())

    # Early stopping
    patience = train_config.get('early_stopping_patience', None)
    if patience is not None:
        callbacks.append(
            tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                             min_delta=1e-5,
                                             patience=patience,
                                             verbose=1))

    # Stopping at specified target
    target_mae = train_config.get('target_mae', None)
    callbacks.append(StopAtTargetCallback(target_max=target_mae))

    if dist.rank == 0:
        logging.debug('Callbacks: %s', callbacks)

    # Train the model
    if dist.rank == 0:
        logging.info('Beginning training')
    fit_verbose = 1 if (args.verbose and dist.rank == 0) else 2
    model.fit(datasets['train_dataset'],
              steps_per_epoch=datasets['n_train_steps'],
              epochs=data_config['n_epochs'],
              validation_data=datasets['valid_dataset'],
              validation_steps=datasets['n_valid_steps'],
              callbacks=callbacks,
              initial_epoch=initial_epoch,
              verbose=fit_verbose)

    # Stop MLPerf timer
    if dist.rank == 0 and args.mlperf:
        mllogger.end(key=mllog.constants.RUN_STOP,
                     metadata={'status': 'success'})

    # Print training summary
    if dist.rank == 0:
        print_training_summary(config['output_dir'], args.print_fom)

    # Print GPU memory - not supported in TF 2.2?
    #if gpu is not None:
    #    device = tf.config.list_physical_devices('GPU')[gpu]
    #    #print(tf.config.experimental.get_memory_usage(device))
    #    #print(tf.config.experimental.get_memory_info(device))

    # Finalize
    if dist.rank == 0:
        logging.info('All done!')
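
main() only touches the configuration through a handful of top-level keys (output_dir, data, model, optimizer, train, lr_schedule, and optionally mlperf). The sketch below is a hypothetical reconstruction of that dictionary shape, inferred purely from the accesses in the code above; the real CosmoFlow YAML configs define the actual values and additional dataset-specific fields.

# Hypothetical configuration shape, reconstructed only from the keys main()
# reads above; every value here is a placeholder, not a recommended setting.
config = {
    'output_dir': 'results/cosmoflow-demo',
    'mlperf': {},                              # forwarded to log_submission_info
    'data': {'batch_size': 4, 'n_epochs': 8},  # forwarded to get_datasets
    'model': {},                               # forwarded to get_model
    'optimizer': {},                           # forwarded to get_optimizer
    'train': {
        'loss': 'mse',
        'metrics': ['mae'],
        'early_stopping_patience': None,       # optional
        'target_mae': None,                    # optional, used by StopAtTargetCallback
    },
    'lr_schedule': {},                         # forwarded to get_lr_schedule
}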
Example #4
    def train_process(rank, self, config, trainset, sampler, cut_layer=None):  # pylint: disable=unused-argument
        """The main training loop in a federated learning workload, run in
          a separate process with a new CUDA context, so that CUDA memory
          can be released after the training completes.

        Arguments:
        rank: Required by torch.multiprocessing to spawn processes. Unused.
        config: A dictionary of configuration parameters.
        trainset: The training dataset.
        sampler: The sampler that extracts a partition for this client.
        cut_layer (optional): The layer which training should start from.
        """
        run = wandb.init(project="plato",
                         group=str(config['run_id']),
                         reinit=True)

        custom_train = getattr(self, "train_model", None)

        if callable(custom_train):
            self.train_model(config, trainset, sampler.get(), cut_layer)
        else:
            log_interval = 10
            batch_size = config['batch_size']

            logging.info("[Client #%s] Loading the dataset.", self.client_id)
            _train_loader = getattr(self, "train_loader", None)

            if callable(_train_loader):
                train_loader = _train_loader(batch_size, trainset,
                                             sampler.get(), cut_layer)
            else:
                train_loader = torch.utils.data.DataLoader(
                    dataset=trainset,
                    shuffle=False,
                    batch_size=batch_size,
                    sampler=sampler.get())

            iterations_per_epoch = np.ceil(len(trainset) /
                                           batch_size).astype(int)
            epochs = config['epochs']

            # Sending the model to the device used for training
            self.model.to(self.device)
            self.model.train()

            # Initializing the loss criterion
            _loss_criterion = getattr(self, "loss_criterion", None)
            if callable(_loss_criterion):
                loss_criterion = _loss_criterion(self.model)
            else:
                loss_criterion = nn.CrossEntropyLoss()

            # Initializing the optimizer
            get_optimizer = getattr(self, "get_optimizer",
                                    optimizers.get_optimizer)
            optimizer = get_optimizer(self.model)

            # Initializing the learning rate schedule, if necessary
            if hasattr(Config().trainer, 'lr_schedule'):
                lr_schedule = optimizers.get_lr_schedule(
                    optimizer, iterations_per_epoch, train_loader)
            else:
                lr_schedule = None

            for epoch in range(1, epochs + 1):
                for batch_id, (examples, labels) in enumerate(train_loader):
                    examples, labels = examples.to(self.device), labels.to(
                        self.device)
                    optimizer.zero_grad()

                    if cut_layer is None:
                        outputs = self.model(examples)
                    else:
                        outputs = self.model.forward_from(examples, cut_layer)

                    loss = loss_criterion(outputs, labels)

                    loss.backward()

                    optimizer.step()

                    if lr_schedule is not None:
                        lr_schedule.step()

                    if batch_id % log_interval == 0:
                        if self.client_id == 0:
                            logging.info(
                                "[Server #{}] Epoch: [{}/{}][{}/{}]\tLoss: {:.6f}"
                                .format(os.getpid(), epoch, epochs, batch_id,
                                        len(train_loader), loss.data.item()))
                        else:
                            wandb.log({"batch loss": loss.data.item()})

                            logging.info(
                                "[Client #{}] Epoch: [{}/{}][{}/{}]\tLoss: {:.6f}"
                                .format(self.client_id, epoch, epochs,
                                        batch_id, len(train_loader),
                                        loss.data.item()))
                if hasattr(optimizer, "params_state_update"):
                    optimizer.params_state_update()

        self.model.cpu()

        model_type = Config().trainer.model_name
        filename = f"{model_type}_{self.client_id}_{config['run_id']}.pth"
        self.save_model(filename)

        run.finish()
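
The docstring above notes that train_process runs in a separate process so that its CUDA context, and therefore its GPU memory, is released when training finishes; that is also why the function takes an unused rank argument. The following is a minimal, self-contained sketch of that launch pattern with torch.multiprocessing.spawn, using a placeholder worker rather than the real trainer, so the exact spawning code in the surrounding project may differ.

import torch.multiprocessing as mp

def _worker(rank, message):
    # mp.spawn always passes the process index as the first argument,
    # which is why train_process above declares an unused `rank` parameter.
    print(f"child process {rank}: {message}")

if __name__ == '__main__':
    # The child gets a fresh CUDA context; when it exits, its GPU memory is
    # returned to the driver, which is the effect the docstring describes.
    mp.spawn(_worker, args=("hello from the parent",), nprocs=1, join=True)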
Example #5
def main():
    """Main function"""

    # Initialization
    args = parse_args()
    dist = init_workers(args.distributed)
    config = load_config(args)
    os.makedirs(config['output_dir'], exist_ok=True)
    config_logging(verbose=args.verbose)
    logging.info('Initialized rank %i size %i local_rank %i local_size %i',
                 dist.rank, dist.size, dist.local_rank, dist.local_size)
    if dist.rank == 0:
        logging.info('Configuration: %s', config)

    # Device and session configuration
    gpu = dist.local_rank if args.rank_gpu else None
    if gpu is not None:
        logging.info('Taking gpu %i', gpu)
    configure_session(gpu=gpu,
                      intra_threads=args.intra_threads,
                      inter_threads=args.inter_threads,
                      kmp_blocktime=args.kmp_blocktime,
                      kmp_affinity=args.kmp_affinity,
                      omp_num_threads=args.omp_num_threads)

    # Load the data
    data_config = config['data']
    if dist.rank == 0:
        logging.info('Loading data')
    datasets = get_datasets(dist=dist, **data_config)
    logging.debug('Datasets: %s', datasets)

    # Construct or reload the model
    if dist.rank == 0:
        logging.info('Building the model')
    train_config = config['train']
    initial_epoch = 0
    checkpoint_format = os.path.join(config['output_dir'],
                                     'checkpoint-{epoch:03d}.h5')
    if args.resume and os.path.exists(checkpoint_format.format(epoch=1)):
        # Reload model from last checkpoint
        initial_epoch, model = reload_last_checkpoint(
            checkpoint_format,
            data_config['n_epochs'],
            distributed=args.distributed)
    else:
        # Build a new model
        model = get_model(**config['model'])
        # Configure the optimizer
        opt = get_optimizer(distributed=args.distributed,
                            **config['optimizer'])
        # Compile the model
        model.compile(optimizer=opt,
                      loss=train_config['loss'],
                      metrics=train_config['metrics'])

    if dist.rank == 0:
        model.summary()

    # Save configuration to output directory
    if dist.rank == 0:
        config['n_ranks'] = dist.size
        save_config(config)

    # Prepare the callbacks
    if dist.rank == 0:
        logging.info('Preparing callbacks')
    callbacks = []
    if args.distributed:

        # Broadcast initial variable states from rank 0 to all processes.
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))

        # Average metrics across workers
        callbacks.append(hvd.callbacks.MetricAverageCallback())

    # Learning rate decay schedule
    if 'lr_schedule' in config:
        global_batch_size = data_config['batch_size'] * dist.size
        callbacks.append(
            tf.keras.callbacks.LearningRateScheduler(
                get_lr_schedule(global_batch_size=global_batch_size,
                                **config['lr_schedule'])))

    # Timing
    timing_callback = TimingCallback()
    callbacks.append(timing_callback)

    # Checkpointing and logging from rank 0 only
    if dist.rank == 0:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(checkpoint_format))
        callbacks.append(
            tf.keras.callbacks.CSVLogger(os.path.join(config['output_dir'],
                                                      'history.csv'),
                                         append=args.resume))
        if args.tensorboard:
            callbacks.append(
                tf.keras.callbacks.TensorBoard(
                    os.path.join(config['output_dir'], 'tensorboard')))

    # Early stopping
    patience = config.get('early_stopping_patience', None)
    if patience is not None:
        callbacks.append(
            tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                             min_delta=1e-5,
                                             patience=patience,
                                             verbose=1))

    if dist.rank == 0:
        logging.debug('Callbacks: %s', callbacks)

    # Train the model
    if dist.rank == 0:
        logging.info('Beginning training')
    fit_verbose = 1 if (args.verbose and dist.rank == 0) else 2
    model.fit(datasets['train_dataset'],
              steps_per_epoch=datasets['n_train_steps'],
              epochs=data_config['n_epochs'],
              validation_data=datasets['valid_dataset'],
              validation_steps=datasets['n_valid_steps'],
              callbacks=callbacks,
              initial_epoch=initial_epoch,
              verbose=fit_verbose)

    # Print training summary
    if dist.rank == 0:
        print_training_summary(config['output_dir'], args.print_fom)

    # Finalize
    if dist.rank == 0:
        logging.info('All done!')