Example #1
def init_distributed(args, cfg, num_gpus):
    # single-node rendezvous: all ranks connect to localhost on a configurable port
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(args.master_port) if args.master_port else '12356'

    # initialize the process group
    timeout = timedelta(seconds=25)
    dist.init_process_group("nccl", rank=args.local_rank, world_size=num_gpus, timeout=timeout)

    try:
        start(args, cfg)
    except InterruptException:
        print("Training session was interrupted")

    dist_utils.synchronize()
    dist.destroy_process_group()
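
Example #1 expects to be called once per GPU, with args.local_rank identifying the calling process within the group. A minimal launcher along these lines could spawn the per-GPU workers with torch.multiprocessing.spawn; the _worker helper and the way local_rank is written back into args are assumptions for illustration, not part of the original code.

import torch
import torch.multiprocessing as mp

def _worker(local_rank, args, cfg, num_gpus):
    # each spawned process owns exactly one GPU (assumed convention)
    args.local_rank = local_rank
    torch.cuda.set_device(local_rank)
    init_distributed(args, cfg, num_gpus)

def launch(args, cfg):
    num_gpus = torch.cuda.device_count()
    # one training process per visible GPU
    mp.spawn(_worker, args=(args, cfg, num_gpus), nprocs=num_gpus, join=True)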
Example #2
def start(args, cfg):
    # suppress Python warnings from sub-processes to prevent duplicate warnings being printed to console
    if dist_utils.get_rank() > 0:
        warnings.filterwarnings("ignore")

    logger = create_logger(args)
    model_save_dir = os.path.join(ModelPaths.checkpoint_base_dir(), cfg.MODE,
                                  args.model_dir)

    if dist_utils.is_main_process():
        os.makedirs(model_save_dir, exist_ok=True)

    # check if a checkpoint already exists in the model save directory. If it does, and the 'no_resume' flag is not set,
    # training should resume from the last pre-existing checkpoint.
    existing_ckpts = sorted(glob(os.path.join(model_save_dir, "*.pth")))
    if existing_ckpts and not args.no_resume:
        args.restore_session = existing_ckpts[-1]
        args.initial_ckpt = None  # when jobs auto-restart on the cluster, this might be set,
        # however we want to use the latest checkpoint instead

    # backup config to model directory
    if dist_utils.is_main_process():
        with open(os.path.join(model_save_dir, 'config.yaml'),
                  'w') as writefile:
            yaml.dump(global_cfg.d(), writefile)

    trainer = Trainer(cfg, model_save_dir, args, logger)

    try:
        trainer.start(args)
    except InterruptException:
        if dist_utils.is_main_process():
            print("Interrupt signal received. Saving checkpoint...")
            trainer.backup_session()
            dist_utils.synchronize()
        exit(1)
    except Exception as err:
        if dist_utils.is_main_process():
            print("Exception occurred. Saving checkpoint...")
            print(err)
            trainer.backup_session()
            if dist_utils.is_distributed():
                dist.destroy_process_group()
        raise err
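
Both examples catch a project-specific InterruptException, and Example #3 below polls self.interrupt_detector between iterations so the loop can stop at a safe point. A minimal sketch of how such a detector could be wired to SIGINT/SIGTERM is shown here; the class layout is an assumption for illustration, not the project's actual implementation.

import signal

class InterruptException(Exception):
    """Raised when training should stop and checkpoint."""

class InterruptDetector:
    def __init__(self):
        self.is_interrupted = False

    def start(self):
        # record the signal instead of terminating immediately, so the
        # training loop can save a checkpoint before shutting down
        signal.signal(signal.SIGINT, self._handler)
        signal.signal(signal.SIGTERM, self._handler)

    def _handler(self, signum, frame):
        self.is_interrupted = True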
Example #3
    def start(self, opts):
        max_samples_per_gpu = self.cfg.MAX_SAMPLES_PER_GPU
        batch_size = self.cfg.BATCH_SIZE
        accumulate_gradients = self.cfg.ACCUMULATE_GRADIENTS

        dataset = create_training_dataset(self.total_iterations * batch_size,
                                          print_fn=self.console_logger.info)

        if accumulate_gradients:
            assert batch_size >= self.num_gpus, "Batch size ({}) must be >= number of GPUs ({})".format(
                batch_size, self.num_gpus)

            optimizer_step_interval = int(
                batch_size / (max_samples_per_gpu * self.num_gpus))
            assert batch_size % max_samples_per_gpu == 0, \
                "Batch size ({}) must be divisible by number of samples per GPU ({})".format(
                    batch_size, max_samples_per_gpu)

            if self.is_main_process:
                self.console_logger.info(
                    "Optimizer will be run every {} iterations".format(
                        optimizer_step_interval))
        else:
            if batch_size > max_samples_per_gpu:
                raise ValueError(
                    "A batch size of {} cannot be processed. Max samples per GPU = {}"
                    .format(batch_size, max_samples_per_gpu))

            max_samples_per_gpu = batch_size
            optimizer_step_interval = 1

        if self.is_main_process:
            n_trainable_params = sum(p.numel()
                                     for p in self.model.parameters()
                                     if p.requires_grad)
            self.console_logger.info(
                "Commencing/resuming training with the following settings:\n"
                "- Elapsed iterations: %d\n"
                "- Total iterations: %d\n"
                "- Batch size: %d\n"
                "- Optimizer step interval: %d\n"
                "- Model save directory: %s\n"
                "- Save interval: %d\n"
                "- Trainable parameters: %d" %
                (self.elapsed_iterations, self.total_iterations, batch_size,
                 optimizer_step_interval, self.model_save_dir,
                 opts.save_interval, n_trainable_params))

            self.logger.total_iterations = self.total_iterations
            self.logger.start_timer()

        output_manager = ModelOutputManager(optimizer_step_interval)

        data_loader = create_training_data_loader(dataset, max_samples_per_gpu,
                                                  True, collate_fn,
                                                  opts.num_cpu_workers,
                                                  self.elapsed_iterations)

        self.interrupt_detector.start()

        sub_iter_idx = 0

        for image_seqs, targets, meta_info in data_loader:
            model_output = self.model(
                image_seqs.to(device=self.local_device),
                tensor_struct_to(targets, device=self.local_device))

            dist_utils.synchronize()
            if self.interrupt_detector.is_interrupted:
                raise InterruptException()

            optim_loss = output_manager(model_output)

            if self.cfg.MIXED_PRECISION:
                with amp.scale_loss(optim_loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                optim_loss.backward()

            sub_iter_idx += 1
            if sub_iter_idx < optimizer_step_interval:
                continue

            sub_iter_idx = 0

            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()
            self.elapsed_iterations += 1

            logging_vars, _ = output_manager.reset()
            logging_vars = dist_utils.reduce_dict(logging_vars, average=True)
            logging_vars = {k: v.item() for k, v in logging_vars.items()}

            if self.is_main_process:
                add_to_summary = self.elapsed_iterations % opts.summary_interval == 0
                self.logger.add_training_point(self.elapsed_iterations,
                                               add_to_summary, **logging_vars)

                if hasattr(self.lr_scheduler,
                           "get_last_lr"):  # PyTorch versions > 1.5
                    logging_vars['lr'] = self.lr_scheduler.get_last_lr()[0]
                else:
                    logging_vars['lr'] = self.lr_scheduler.get_lr()[0]

                if self.elapsed_iterations % opts.display_interval == 0:
                    log_func = self.console_logger.info
                else:
                    log_func = self.console_logger.debug

                eta, avg_time_per_iter = self.logger.compute_eta(
                    as_string=True)
                log_func(
                    "It: {:05d} - {:s} - ETA: {:s} - sec/it: {:.3f}".format(
                        self.elapsed_iterations, var_keys_to_str(logging_vars),
                        eta, avg_time_per_iter))

            if self.elapsed_iterations % opts.save_interval == 0:
                if self.is_main_process:
                    # remove outdated checkpoints
                    checkpoints = sorted(
                        glob(os.path.join(self.model_save_dir, '*.pth')))
                    if len(checkpoints) > opts.ckpts_to_keep:
                        for ckpt_path in checkpoints[:-opts.ckpts_to_keep]:
                            os.remove(ckpt_path)

                    self.backup_session()

                dist_utils.synchronize()

        self.console_logger.info("Training complete\n"
                                 "Model(s) saved to: %s\n"
                                 "Log file(s) saved to: %s\n" %
                                 (self.model_save_dir, self.log_dir))
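
With ACCUMULATE_GRADIENTS enabled, the effective batch is split across GPUs and sub-iterations: optimizer_step_interval = BATCH_SIZE / (MAX_SAMPLES_PER_GPU * num_gpus), and the optimizer only steps after that many backward passes. A minimal sketch of what ModelOutputManager could do under that scheme is shown below; treating model_output as a dict of scalar loss tensors, and the exact scaling, are assumptions for illustration rather than details taken from the project.

class ModelOutputManager:
    """Accumulates per-loss values and scales the loss for gradient accumulation (illustrative sketch)."""

    def __init__(self, optimizer_step_interval):
        self.optimizer_step_interval = optimizer_step_interval
        self._accumulated = {}

    def __call__(self, model_output):
        # assumption: model_output is a dict of scalar loss tensors
        total_loss = sum(model_output.values())
        for name, value in model_output.items():
            self._accumulated[name] = self._accumulated.get(name, 0.0) + value.detach()
        # scale so that gradients summed over optimizer_step_interval backward
        # passes match a single backward pass over the full batch
        return total_loss / self.optimizer_step_interval

    def reset(self):
        # average the logged losses over the sub-iterations and clear the buffer
        logging_vars = {k: v / self.optimizer_step_interval
                        for k, v in self._accumulated.items()}
        self._accumulated = {}
        return logging_vars, None

In the loop above, reset() is only called right before an optimizer step, so averaging by optimizer_step_interval matches the number of accumulated sub-iterations.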