Example #1
def run(args):
    distributed, device_ids = main_util.init_distributed_mode(
        args.world_size, args.dist_url)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        cudnn.benchmark = True

    print(args)
    config = yaml_util.load_yaml_file(args.config)
    input_shape = config['input_shape']
    ckpt_file_path = config['autoencoder']['ckpt']
    train_loader, valid_loader, test_loader = main_util.get_data_loaders(
        config, distributed)
    if not args.test_only:
        train(train_loader, valid_loader, input_shape, config, device,
              distributed, device_ids)

    autoencoder, _ = ae_util.get_autoencoder(config, device)
    resume_from_ckpt(ckpt_file_path, autoencoder)
    extended_model, model = ae_util.get_extended_model(autoencoder, config,
                                                       input_shape, device)
    if not args.extended_only:
        if device.type == 'cuda':
            model = DistributedDataParallel(model, device_ids=device_ids) if distributed \
                else DataParallel(model)
        evaluate(model, test_loader, device, title='[Original model]')

    if device.type == 'cuda':
        extended_model = DistributedDataParallel(extended_model, device_ids=device_ids) if distributed \
            else DataParallel(extended_model)

    evaluate(extended_model, test_loader, device, title='[Mimic model]')
Example #2
def run(rank, model, train_dataloader, num_epochs, precision, accelerator,
        tmpdir):
    os.environ["LOCAL_RANK"] = str(rank)
    if torch.distributed.is_available(
    ) and not torch.distributed.is_initialized():
        torch.distributed.init_process_group("gloo", rank=rank, world_size=2)

    to_device = partial(move_data_to_device, device=torch.device("cuda", rank))
    model = DistributedDataParallel(
        to_device(model),
        device_ids=[rank],
    )
    train_dataloader = DataLoader(
        train_dataloader.dataset,
        sampler=DistributedSampler(train_dataloader.dataset,
                                   rank=rank,
                                   num_replicas=2,
                                   seed=42,
                                   drop_last=False),
    )
    with precision_context(precision, accelerator):
        main(to_device, model, train_dataloader, num_epochs=num_epochs)

    if rank == 0:
        atomic_save(model.state_dict(), os.path.join(tmpdir, "model_spawn.pt"))
Example #3
def main(args):
    config = yaml_util.load_yaml_file(args.config)
    if args.json is not None:
        main_util.overwrite_config(config, args.json)

    distributed, device_ids = main_util.init_distributed_mode(
        args.world_size, args.dist_url)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    teacher_model = get_model(config['teacher_model'], device)
    module_util.freeze_module_params(teacher_model)
    student_model_config = config['student_model']
    student_model = get_model(student_model_config, device)
    freeze_modules(student_model, student_model_config)
    print('Updatable parameters: {}'.format(
        module_util.get_updatable_param_names(student_model)))
    distill_backbone_only = student_model_config['distill_backbone_only']
    train_config = config['train']
    train_sampler, train_data_loader, val_data_loader, test_data_loader = \
        data_util.get_coco_data_loaders(config['dataset'], train_config['batch_size'], distributed)
    if distributed:
        teacher_model = DataParallel(teacher_model, device_ids=device_ids)
        student_model = DistributedDataParallel(student_model,
                                                device_ids=device_ids)

    if args.distill:
        distill(teacher_model, student_model, train_sampler, train_data_loader,
                val_data_loader, device, distributed, distill_backbone_only,
                config, args)
        load_ckpt(
            config['student_model']['ckpt'],
            model=student_model.module if isinstance(
                student_model, DistributedDataParallel) else student_model)
    evaluate(teacher_model, student_model, test_data_loader, device,
             args.skip_teacher_eval, args.transform_bottleneck)
Example #4
    def configure_ddp(self):
        self.pre_configure_ddp()
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.model),
            **self._ddp_kwargs,
        )
        self._register_ddp_hooks()
Example #5
    def configure_ddp(self):
        self.pre_configure_ddp()
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.model),
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )
Example #6
def load_vaes(H, logprint):
    vae = VAE(H)
    if H.restore_path:
        logprint(f'Restoring vae from {H.restore_path}')
        restore_params(vae, H.restore_path, map_cpu=True, local_rank=H.local_rank, mpi_size=H.mpi_size)

    ema_vae = VAE(H)
    if H.restore_ema_path:
        logprint(f'Restoring ema vae from {H.restore_ema_path}')
        restore_params(ema_vae, H.restore_ema_path, map_cpu=True, local_rank=H.local_rank, mpi_size=H.mpi_size)
    else:
        ema_vae.load_state_dict(vae.state_dict())
    ema_vae.requires_grad_(False)

    vae = vae.cuda(H.local_rank)
    ema_vae = ema_vae.cuda(H.local_rank)

    vae = DistributedDataParallel(vae, device_ids=[H.local_rank], output_device=H.local_rank)

    if len(list(vae.named_parameters())) != len(list(vae.parameters())):
        raise ValueError('Some params are not named. Please name all params.')
    total_params = 0
    for name, p in vae.named_parameters():
        total_params += np.prod(p.shape)
    logprint(total_params=total_params, readable=f'{total_params:,}')
    return vae, ema_vae
Example #7
def train(train_loader, valid_loader, input_shape, config, device, distributed,
          device_ids):
    ae_without_ddp, ae_type = ae_util.get_autoencoder(config, device)
    head_model = ae_util.get_head_model(config, input_shape, device)
    module_util.freeze_module_params(head_model)
    ckpt_file_path = config['autoencoder']['ckpt']
    start_epoch, best_valid_acc = resume_from_ckpt(ckpt_file_path,
                                                   ae_without_ddp)
    if best_valid_acc is None:
        best_valid_acc = 0.0

    train_config = config['train']
    criterion_config = train_config['criterion']
    criterion = func_util.get_loss(criterion_config['type'],
                                   criterion_config['params'])
    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(ae_without_ddp, optim_config['type'],
                                        optim_config['params'])
    scheduler_config = train_config['scheduler']
    scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'],
                                        scheduler_config['params'])
    interval = train_config['interval']
    if interval <= 0:
        num_batches = len(train_loader)
        interval = num_batches // 20 if num_batches >= 20 else 1

    autoencoder = ae_without_ddp
    if distributed:
        autoencoder = DistributedDataParallel(ae_without_ddp,
                                              device_ids=device_ids)
        head_model = DataParallel(head_model, device_ids=device_ids)
    elif device.type == 'cuda':
        autoencoder = DataParallel(ae_without_ddp)
        head_model = DataParallel(head_model)

    end_epoch = start_epoch + train_config['epoch']
    start_time = time.time()
    for epoch in range(start_epoch, end_epoch):
        if distributed:
            train_loader.sampler.set_epoch(epoch)

        train_epoch(autoencoder, head_model, train_loader, optimizer,
                    criterion, epoch, device, interval)
        valid_acc = validate(ae_without_ddp, valid_loader, config, device,
                             distributed, device_ids)
        if valid_acc > best_valid_acc and main_util.is_main_process():
            print(
                'Updating ckpt (Best top1 accuracy: {:.4f} -> {:.4f})'.format(
                    best_valid_acc, valid_acc))
            best_valid_acc = valid_acc
            save_ckpt(ae_without_ddp, epoch, best_valid_acc, ckpt_file_path,
                      ae_type)
        scheduler.step()

    dist.barrier()
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    del head_model
Example #8
def wrap(model):
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            # find_unused_parameters=True,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())
Example #9
File: dqn.py  Project: Fengdan92/Horizon
    def __init__(self, fc_dqn):
        super().__init__()
        self.state_dim = fc_dqn.state_dim
        self.action_dim = fc_dqn.action_dim
        current_device = torch.cuda.current_device()
        self.data_parallel = DistributedDataParallel(
            fc_dqn.fc, device_ids=[current_device], output_device=current_device
        )
        self.fc_dqn = fc_dqn
Example #10
    def _setup_model(self, model: Module) -> DistributedDataParallel:
        """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
        device_ids = self.determine_ddp_device_ids()
        log.detail(
            f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}"
        )
        return DistributedDataParallel(module=model,
                                       device_ids=device_ids,
                                       **self._ddp_kwargs)
Example #11
    def __init__(self, seq2slate_net: Seq2SlateNet):
        super().__init__()

        current_device = torch.cuda.current_device()
        self.data_parallel = DistributedDataParallel(
            seq2slate_net.seq2slate,
            device_ids=[current_device],
            output_device=current_device,
        )
        self.seq2slate_net = seq2slate_net
Example #12
def main(rank: int,
         epochs: int,
         model: nn.Module,
         train_loader: DataLoader,
         test_loader: DataLoader) -> nn.Module:
    device = torch.device(f'cuda:{rank}')
    model = model.to(device)
    model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)

    # initialize optimizer and loss function
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    loss = nn.CrossEntropyLoss()

    # train the model
    for i in range(epochs):
        model.train()
        train_loader.sampler.set_epoch(i)

        epoch_loss = 0
        # train the model for one epoch
        pbar = tqdm(train_loader)
        for x, y in pbar:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            x = x.view(x.shape[0], -1)
            optimizer.zero_grad()
            y_hat = model(x)
            batch_loss = loss(y_hat, y)
            batch_loss.backward()
            optimizer.step()
            batch_loss_scalar = batch_loss.item()
            epoch_loss += batch_loss_scalar / x.shape[0]
            pbar.set_description(f'training batch_loss={batch_loss_scalar:.4f}')

        # calculate validation loss
        with torch.no_grad():
            model.eval()
            val_loss = 0
            pbar = tqdm(test_loader)
            for x, y in pbar:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                x = x.view(x.shape[0], -1)
                y_hat = model(x)
                batch_loss = loss(y_hat, y)
                batch_loss_scalar = batch_loss.item()

                val_loss += batch_loss_scalar / x.shape[0]
                pbar.set_description(f'validation batch_loss={batch_loss_scalar:.4f}')

        print(f"Epoch={i}, train_loss={epoch_loss:.4f}, val_loss={val_loss:.4f}")

    return model.module
Example #13
    def load_model(self):
        super(EnergyTrainer, self).load_model()

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=self.config["optim"].get("num_gpus", 1),
        )
        if distutils.initialized():
            self.model = DistributedDataParallel(self.model,
                                                 device_ids=[self.device])
Example #14
File: distrib.py  Project: yunzqq/svoice
def wrap(model):
    """wrap.
    Wrap a model with DDP if distributed training is enabled.
    """
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())
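
The wrap helpers in Examples #8 and #14 assume that the default process group is already initialized and that a module-level world_size is in scope. The sketch below shows the per-process setup such a helper relies on; the function name setup_and_wrap, the rendezvous address/port, and the choice of the nccl backend are illustrative assumptions, not part of the original snippets.

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def setup_and_wrap(model, rank, world_size):
    # Hypothetical helper: initialize the default process group (normally done by
    # the launcher, e.g. torchrun), pin this process to one GPU, then wrap with DDP.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    model = model.cuda(rank)
    if world_size == 1:
        return model
    return DistributedDataParallel(
        model,
        device_ids=[torch.cuda.current_device()],
        output_device=torch.cuda.current_device())
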
Example #15
    def load_model(self):
        super(DistributedEnergyTrainer, self).load_model()

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=self.config["optim"].get("num_gpus", 1),
        )
        self.model = DistributedDataParallel(self.model,
                                             device_ids=[self.device],
                                             find_unused_parameters=True)
Example #16
    def load_model(self):
        super(DistributedForcesTrainer, self).load_model()

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=1,
        )
        self.model = DistributedDataParallel(self.model,
                                             device_ids=[self.device],
                                             find_unused_parameters=True)
Example #17
def validate(ae_without_ddp, data_loader, config, device, distributed,
             device_ids):
    input_shape = config['input_shape']
    extended_model, _ = ae_util.get_extended_model(ae_without_ddp, config,
                                                   input_shape, device, True)
    if distributed:
        extended_model = DistributedDataParallel(extended_model,
                                                 device_ids=device_ids)
    return evaluate(extended_model,
                    data_loader,
                    device,
                    split_name='Validation')
Example #18
    def __init__(self, cem_planner: CEMPlanner):
        super().__init__()
        self.plan_horizon_length = cem_planner.plan_horizon_length
        self.state_dim = cem_planner.state_dim
        self.action_dim = cem_planner.action_dim
        self.discrete_action = cem_planner.discrete_action

        current_device = torch.cuda.current_device()  # type: ignore
        self.data_parallel = DistributedDataParallel(
            cem_planner.cem_planner_network,
            device_ids=[current_device],
            output_device=current_device,
        )
        self.cem_planner = cem_planner
Example #19
    def __init__(self, mem_net):
        super().__init__()
        self.num_hiddens = mem_net.num_hiddens
        self.num_hidden_layers = mem_net.num_hidden_layers
        self.state_dim = mem_net.state_dim
        self.action_dim = mem_net.action_dim
        self.num_gaussians = mem_net.num_gaussians

        current_device = torch.cuda.current_device()
        self.data_parallel = DistributedDataParallel(
            mem_net.mdnrnn,
            device_ids=[current_device],
            output_device=current_device)
        self.mem_net = mem_net
Example #20
def train(epochs: int,
          train_data_loader: DataLoader,
          valid_data_loader: DataLoader = None,
          rank=None):
    device = torch.device(f'cuda:{rank}')
    model = create_model(model_type).to(device)
    model = DistributedDataParallel(model,
                                    device_ids=[rank],
                                    output_device=rank)
    optimizer = AdamW(model.parameters(), lr=lr)
    tokenizer = BertTokenizer.from_pretrained(model_type)

    def update_weights(bi, di, num_batches, batch_loss):
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if bi % 100 == 0:
            logger.info(
                f'training: device={di}; batch={bi+1}/{num_batches}; batch_error={batch_loss.item()};'
            )

    def valid_loss_progress_log(bi, di, num_batches, batch_loss):
        if bi % 100 == 0:
            logger.info(
                f'validation: device={di}; batch={bi+1}/{num_batches}; val_batch_error={batch_loss.item()};'
            )

    for i in range(epochs):
        model.train()
        train_data_loader.sampler.set_epoch(i)
        if valid_data_loader is not None:
            valid_data_loader.sampler.set_epoch(i)

        train_loss = run(model, train_data_loader, tokenizer, device,
                         update_weights)

        if valid_data_loader is not None:
            with torch.no_grad():
                model.eval()
                val_loss = run(model, valid_data_loader, tokenizer, device,
                               valid_loss_progress_log)
        else:
            val_loss = 'N/A'

        logger.info(
            f'epoch={i}; device={rank}; train_error={train_loss};  valid_error={val_loss};'
        )

    return model.module
Example #21
    def load_model(self):
        # Build model
        if distutils.is_master():
            logging.info(f"Loading model: {self.config['model']}")

        # TODO(abhshkdz): Eventually move towards computing features on-the-fly
        # and remove dependence from `.edge_attr`.
        bond_feat_dim = None
        if self.config["task"]["dataset"] in [
            "trajectory_lmdb",
            "single_point_lmdb",
        ]:
            bond_feat_dim = self.config["model_attributes"].get(
                "num_gaussians", 50
            )
        else:
            raise NotImplementedError

        loader = self.train_loader or self.val_loader or self.test_loader
        self.model = registry.get_model_class(self.config["model"])(
            loader.dataset[0].x.shape[-1]
            if loader
            and hasattr(loader.dataset[0], "x")
            and loader.dataset[0].x is not None
            else None,
            bond_feat_dim,
            self.num_targets,
            **self.config["model_attributes"],
        ).to(self.device)

        if distutils.is_master():
            logging.info(
                f"Loaded {self.model.__class__.__name__} with "
                f"{self.model.num_params} parameters."
            )

        if self.logger is not None:
            self.logger.watch(self.model)

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=1 if not self.cpu else 0,
        )
        if distutils.initialized():
            self.model = DistributedDataParallel(
                self.model, device_ids=[self.device]
            )
Example #22
def train(model, train_loader, valid_loader, best_valid_acc, criterion, device,
          distributed, device_ids, train_config, num_epochs, start_epoch,
          init_lr, ckpt_file_path, model_type):
    model_without_ddp = model
    if distributed:
        model = DistributedDataParallel(model_without_ddp,
                                        device_ids=device_ids)
    elif device.type == 'cuda':
        model = DataParallel(model_without_ddp)

    optim_config = train_config['optimizer']
    if init_lr is not None:
        optim_config['params']['lr'] = init_lr

    optimizer = func_util.get_optimizer(model, optim_config['type'],
                                        optim_config['params'])
    scheduler_config = train_config['scheduler']
    scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'],
                                        scheduler_config['params'])
    interval = train_config['interval']
    if interval <= 0:
        num_batches = len(train_loader)
        interval = num_batches // 20 if num_batches >= 20 else 1

    end_epoch = start_epoch + train_config[
        'epoch'] if num_epochs is None else start_epoch + num_epochs
    start_time = time.time()
    for epoch in range(start_epoch, end_epoch):
        if distributed:
            train_loader.sampler.set_epoch(epoch)

        train_epoch(model, train_loader, optimizer, criterion, epoch, device,
                    interval)
        valid_acc = validate(model, valid_loader, device)
        if valid_acc > best_valid_acc and main_util.is_main_process():
            print(
                'Updating ckpt (Best top1 accuracy: {:.4f} -> {:.4f})'.format(
                    best_valid_acc, valid_acc))
            best_valid_acc = valid_acc
            save_ckpt(model_without_ddp, best_valid_acc, epoch, ckpt_file_path,
                      model_type)
        scheduler.step()

    dist.barrier()
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Example #23
def validate(student_model_without_ddp, data_loader, config, device,
             distributed, device_ids):
    teacher_model_config = config['teacher_model']
    org_model, teacher_model_type = mimic_util.get_org_model(
        teacher_model_config, device)
    mimic_model = mimic_util.get_mimic_model(
        config,
        org_model,
        teacher_model_type,
        teacher_model_config,
        device,
        head_model=student_model_without_ddp)
    mimic_model_without_dp = mimic_model.module if isinstance(
        mimic_model, DataParallel) else mimic_model
    if distributed:
        mimic_model = DistributedDataParallel(mimic_model_without_dp,
                                              device_ids=device_ids)
    return evaluate(mimic_model, data_loader, device, split_name='Validation')
Example #24
    def __init__(self, seq2slate_transformer_net: Seq2SlateTransformerNet):
        super().__init__()
        self.state_dim = seq2slate_transformer_net.state_dim
        self.candidate_dim = seq2slate_transformer_net.candidate_dim
        self.num_stacked_layers = seq2slate_transformer_net.num_stacked_layers
        self.num_heads = seq2slate_transformer_net.num_heads
        self.dim_model = seq2slate_transformer_net.dim_model
        self.dim_feedforward = seq2slate_transformer_net.dim_feedforward
        self.max_src_seq_len = seq2slate_transformer_net.max_src_seq_len
        self.max_tgt_seq_len = seq2slate_transformer_net.max_tgt_seq_len

        current_device = torch.cuda.current_device()  # type: ignore
        self.data_parallel = DistributedDataParallel(
            seq2slate_transformer_net.seq2slate_transformer,
            device_ids=[current_device],
            output_device=current_device,
        )
        self.seq2slate_transformer_net = seq2slate_transformer_net
Example #25
def run(args):
    distributed, device_ids = main_util.init_distributed_mode(
        args.world_size, args.dist_url)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        cudnn.benchmark = True

    print(args)
    config = yaml_util.load_yaml_file(args.config)
    dataset_config = config['dataset']
    input_shape = config['input_shape']
    train_config = config['train']
    test_config = config['test']
    train_loader, valid_loader, test_loader =\
        dataset_util.get_data_loaders(dataset_config, batch_size=train_config['batch_size'],
                                      rough_size=train_config['rough_size'], reshape_size=input_shape[1:3],
                                      test_batch_size=test_config['batch_size'], jpeg_quality=-1,
                                      distributed=distributed)
    teacher_model_config = config['teacher_model']
    if not args.test_only:
        distill(train_loader, valid_loader, input_shape, args.aux, config,
                device, distributed, device_ids)

    org_model, teacher_model_type = mimic_util.get_org_model(
        teacher_model_config, device)
    if not args.student_only:
        if distributed:
            org_model = DataParallel(org_model, device_ids=device_ids)
        evaluate(org_model, test_loader, device, title='[Original model]')

    mimic_model = mimic_util.get_mimic_model(config, org_model,
                                             teacher_model_type,
                                             teacher_model_config, device)
    mimic_model_without_dp = mimic_model.module if isinstance(
        mimic_model, DataParallel) else mimic_model
    file_util.save_pickle(mimic_model_without_dp,
                          config['mimic_model']['ckpt'])
    if distributed:
        mimic_model = DistributedDataParallel(mimic_model_without_dp,
                                              device_ids=device_ids)
    evaluate(mimic_model, test_loader, device, title='[Mimic model]')
Example #26
    def __init__(self, config: BaseConfig):
        self._config = config
        self._model = DeepLab(num_classes=9, output_stride=8,
                              sync_bn=False).to(self._config.device)
        self._border_loss = TotalLoss(self._config)
        self._direction_loss = CrossEntropyLoss()
        self._loaders = get_data_loaders(config)
        self._writer = SummaryWriter()
        self._optimizer = torch.optim.SGD(self._model.parameters(),
                                          lr=self._config.lr,
                                          weight_decay=1e-4,
                                          nesterov=True,
                                          momentum=0.9)
        self._scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self._optimizer, gamma=0.97)

        if self._config.parallel:
            self._model = DistributedDataParallel(self._model,
                                                  device_ids=[
                                                      self._config.device,
                                                  ])
Example #27
    def __init__(self,
                 model,
                 training_set,
                 batch_size,
                 learning_rate,
                 validation_set=None,
                 test_set=None,
                 checkpoint_dir=None):
        init_process_group('mpi')
        model = DistributedDataParallel(model)
        rank = get_rank()
        world_size = get_world_size()

        super().__init__(rank=rank,
                         world_size=world_size,
                         model=model,
                         training_set=training_set,
                         validation_set=validation_set,
                         test_set=test_set,
                         batch_size=batch_size,
                         learning_rate=learning_rate,
                         checkpoint_dir=checkpoint_dir)
Example #28
    def configure_ddp(self, model: LightningModule,
                      device_ids: List[int]) -> DistributedDataParallel:
        """
        Pass through all customizations from constructor to :class:`~torch.nn.parallel.DistributedDataParallel`.
        Override to define a custom DDP implementation.

        .. note:: This requires that your DDP implementation subclasses
            :class:`~torch.nn.parallel.DistributedDataParallel` and that
            the original LightningModule gets wrapped by
            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedModule`.

        The default implementation is::

            def configure_ddp(self, model, device_ids):
                model = DistributedDataParallel(
                    LightningDistributedModule(model),
                    device_ids=device_ids,
                    **self._ddp_kwargs,
                )
                return model

        Args:
            model: the LightningModule
            device_ids: the list of devices available

        Returns:
            the model wrapped in :class:`~torch.nn.parallel.DistributedDataParallel`

        """
        # if unset, default `find_unused_parameters` `True`
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
            "find_unused_parameters", True)
        model = DistributedDataParallel(
            module=LightningDistributedModule(model),
            device_ids=device_ids,
            **self._ddp_kwargs,
        )
        return model
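
The docstring above spells out the default wrapping, so customizing it comes down to subclassing the plugin and returning your own DistributedDataParallel instance. Below is a minimal sketch under the assumption of a pytorch_lightning release that exposes DDPPlugin with this configure_ddp(model, device_ids) hook (the exact import path differs between versions); the subclass name and the find_unused_parameters choice are illustrative.

from pytorch_lightning.overrides.data_parallel import LightningDistributedModule
from pytorch_lightning.plugins import DDPPlugin
from torch.nn.parallel import DistributedDataParallel


class NoUnusedParamsDDPPlugin(DDPPlugin):
    def configure_ddp(self, model, device_ids):
        # Override the default find_unused_parameters=True shown above; skipping
        # the unused-parameter search saves a graph traversal when every
        # parameter receives a gradient in each backward pass.
        return DistributedDataParallel(
            LightningDistributedModule(model),
            device_ids=device_ids,
            find_unused_parameters=False,
        )
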
Example #29
def main():
    parser = get_parser()
    args = parser.parse_args()
    name = get_name(parser, args)
    print(f"Experiment {name}")

    if args.musdb is None and args.rank == 0:
        print(
            "You must provide the path to the MusDB dataset with the --musdb flag. "
            "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
            file=sys.stderr)
        sys.exit(1)

    eval_folder = args.evals / name
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.logs.mkdir(exist_ok=True)
    metrics_path = args.logs / f"{name}.json"
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.checkpoints.mkdir(exist_ok=True, parents=True)
    args.models.mkdir(exist_ok=True, parents=True)

    if args.device is None:
        device = "cpu"
        if th.cuda.is_available():
            device = "cuda"
    else:
        device = args.device

    th.manual_seed(args.seed)
    # Prevents too many threads from being started when running `museval`, as that
    # can be quite inefficient on NUMA architectures.
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"

    if args.world_size > 1:
        if device != "cuda" and args.rank == 0:
            print(
                "Error: distributed training is only available with cuda device",
                file=sys.stderr)
            sys.exit(1)
        th.cuda.set_device(args.rank % th.cuda.device_count())
        distributed.init_process_group(backend="nccl",
                                       init_method="tcp://" + args.master,
                                       rank=args.rank,
                                       world_size=args.world_size)

    checkpoint = args.checkpoints / f"{name}.th"
    checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
    if args.restart and checkpoint.exists() and args.rank == 0:
        checkpoint.unlink()

    if args.test or args.test_pretrained:
        args.epochs = 1
        args.repeat = 0
        if args.test:
            model = load_model(args.models / args.test)
        else:
            model = load_pretrained(args.test_pretrained)
    elif args.tasnet:
        model = ConvTasNet(audio_channels=args.audio_channels,
                           samplerate=args.samplerate,
                           X=args.X,
                           segment_length=4 * args.samples)
    else:
        model = Demucs(
            audio_channels=args.audio_channels,
            channels=args.channels,
            context=args.context,
            depth=args.depth,
            glu=args.glu,
            growth=args.growth,
            kernel_size=args.kernel_size,
            lstm_layers=args.lstm_layers,
            rescale=args.rescale,
            rewrite=args.rewrite,
            stride=args.conv_stride,
            resample=args.resample,
            samplerate=args.samplerate,
            segment_length=4 * args.samples,
        )
    model.to(device)
    if args.init:
        model.load_state_dict(load_pretrained(args.init).state_dict())

    if args.show:
        print(model)
        size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
        print(f"Model size {size}")
        return

    try:
        saved = th.load(checkpoint, map_location='cpu')
    except IOError:
        saved = SavedState()

    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    quantizer = None
    quantizer = get_quantizer(model, args, optimizer)

    if saved.last_state is not None:
        model.load_state_dict(saved.last_state, strict=False)
    if saved.optimizer is not None:
        optimizer.load_state_dict(saved.optimizer)

    model_name = f"{name}.th"
    if args.save_model:
        if args.rank == 0:
            model.to("cpu")
            model.load_state_dict(saved.best_state)
            save_model(model, quantizer, args, args.models / model_name)
        return
    elif args.save_state:
        model_name = f"{args.save_state}.th"
        if args.rank == 0:
            model.to("cpu")
            model.load_state_dict(saved.best_state)
            state = get_state(model, quantizer)
            save_state(state, args.models / model_name)
        return

    if args.rank == 0:
        done = args.logs / f"{name}.done"
        if done.exists():
            done.unlink()

    augment = [Shift(args.data_stride)]
    if args.augment:
        augment += [
            FlipSign(),
            FlipChannels(),
            Scale(),
            Remix(group_size=args.remix_group_size)
        ]
    augment = nn.Sequential(*augment).to(device)
    print("Agumentation pipeline:", augment)

    if args.mse:
        criterion = nn.MSELoss()
    else:
        criterion = nn.L1Loss()

    # Set the number of samples so that all convolution windows are full.
    # Prevents a hard-to-debug mistake where the prediction is shifted relative
    # to the input mixture.
    samples = model.valid_length(args.samples)
    print(f"Number of training samples adjusted to {samples}")
    samples = samples + args.data_stride
    if args.repitch:
        # We need a few more audio samples to account for a potential
        # tempo change.
        samples = math.ceil(samples / (1 - 0.01 * args.max_tempo))

    sources = 4
    if args.raw:
        train_set = Rawset(args.raw / "train",
                           samples=samples,
                           channels=args.audio_channels,
                           streams=range(1, sources + 1),
                           stride=args.data_stride)

        valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
    else:
        if not args.metadata.is_file() and args.rank == 0:
            build_musdb_metadata(args.metadata, args.musdb, args.workers)
        if args.world_size > 1:
            distributed.barrier()
        metadata = json.load(open(args.metadata))
        duration = Fraction(samples, args.samplerate)
        stride = Fraction(args.data_stride, args.samplerate)
        train_set = StemsSet(get_musdb_tracks(args.musdb,
                                              subsets=["train"],
                                              split="train"),
                             metadata,
                             duration=duration,
                             stride=stride,
                             streams=range(1, sources + 1),
                             samplerate=args.samplerate,
                             channels=args.audio_channels)
        valid_set = StemsSet(get_musdb_tracks(args.musdb,
                                              subsets=["train"],
                                              split="valid"),
                             metadata,
                             samplerate=args.samplerate,
                             channels=args.audio_channels)
    if args.repitch:
        train_set = RepitchedWrapper(train_set,
                                     proba=args.repitch,
                                     max_tempo=args.max_tempo)

    best_loss = float("inf")
    for epoch, metrics in enumerate(saved.metrics):
        print(f"Epoch {epoch:03d}: "
              f"train={metrics['train']:.8f} "
              f"valid={metrics['valid']:.8f} "
              f"best={metrics['best']:.4f} "
              f"ms={metrics.get('true_model_size', 0):.2f}MB "
              f"cms={metrics.get('compressed_model_size', 0):.2f}MB "
              f"duration={human_seconds(metrics['duration'])}")
        best_loss = metrics['best']

    if args.world_size > 1:
        dmodel = DistributedDataParallel(
            model,
            device_ids=[th.cuda.current_device()],
            output_device=th.cuda.current_device())
    else:
        dmodel = model

    for epoch in range(len(saved.metrics), args.epochs):
        begin = time.time()
        model.train()
        train_loss, model_size = train_model(epoch,
                                             train_set,
                                             dmodel,
                                             criterion,
                                             optimizer,
                                             augment,
                                             quantizer=quantizer,
                                             batch_size=args.batch_size,
                                             device=device,
                                             repeat=args.repeat,
                                             seed=args.seed,
                                             diffq=args.diffq,
                                             workers=args.workers,
                                             world_size=args.world_size)
        model.eval()
        valid_loss = validate_model(epoch,
                                    valid_set,
                                    model,
                                    criterion,
                                    device=device,
                                    rank=args.rank,
                                    split=args.split_valid,
                                    overlap=args.overlap,
                                    world_size=args.world_size)

        ms = 0
        cms = 0
        if quantizer and args.rank == 0:
            ms = quantizer.true_model_size()
            cms = quantizer.compressed_model_size(
                num_workers=min(40, args.world_size * 10))

        duration = time.time() - begin
        if valid_loss < best_loss and ms <= args.ms_target:
            best_loss = valid_loss
            saved.best_state = {
                key: value.to("cpu").clone()
                for key, value in model.state_dict().items()
            }

        saved.metrics.append({
            "train": train_loss,
            "valid": valid_loss,
            "best": best_loss,
            "duration": duration,
            "model_size": model_size,
            "true_model_size": ms,
            "compressed_model_size": cms,
        })
        if args.rank == 0:
            json.dump(saved.metrics, open(metrics_path, "w"))

        saved.last_state = model.state_dict()
        saved.optimizer = optimizer.state_dict()
        if args.rank == 0 and not args.test:
            th.save(saved, checkpoint_tmp)
            checkpoint_tmp.rename(checkpoint)

        print(
            f"Epoch {epoch:03d}: "
            f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} ms={ms:.2f}MB "
            f"cms={cms:.2f}MB "
            f"duration={human_seconds(duration)}")

    del dmodel
    model.load_state_dict(saved.best_state)
    if args.eval_cpu:
        device = "cpu"
        model.to(device)
    model.eval()
    evaluate(model,
             args.musdb,
             eval_folder,
             rank=args.rank,
             world_size=args.world_size,
             device=device,
             save=args.save,
             split=args.split_valid,
             shifts=args.shifts,
             overlap=args.overlap,
             workers=args.eval_workers)
    model.to("cpu")
    if args.rank == 0:
        if not (args.test or args.test_pretrained):
            save_model(model, quantizer, args, args.models / model_name)
        print("done")
        done.write_text("done")
Example #30
def train(cfg, args):
    logger = logging.getLogger('model training')
    train_data = build_dataloader(cfg.train_data_cfg, args.distributed)
    logger.info("train data: {}".format(len(train_data)))
    val_data = build_dataloader(cfg.val_data_cfg, args.distributed)
    logger.info("val data: {}".format(len(val_data)))

    model = MultiInstanceRecognition(cfg.model_cfg).cuda()
    if cfg.resume_from is not None:
        logger.info('loading pretrained model from {}'.format(cfg.resume_from))
        model.load_state_dict(torch.load(cfg.resume_from))
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
    voc, char2id, id2char = get_vocabulary("ALLCASES_SYMBOLS")

    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    logger.info('Trainable params num: {}'.format(sum(params_num)))
    optimizer = optim.Adam(filtered_parameters, lr=cfg.lr, betas=(0.9, 0.999))
    lrScheduler = lr_scheduler.MultiStepLR(optimizer, [1, 2, 3], gamma=0.1)

    max_iters = cfg.max_iters
    start_iter = 0
    if cfg.resume_from is not None:
        start_iter = int(cfg.resume_from.split('_')[-1].split('.')[0])
        logger.info('continue to train, start_iter: {}'.format(start_iter))

    train_data_iter = iter(train_data)
    val_data_iter = iter(val_data)
    start_time = time.time()
    for i in range(start_iter, max_iters):
        model.train()
        try:
            batch_data = next(train_data_iter)
        except StopIteration:
            train_data_iter = iter(train_data)
            batch_data = next(train_data_iter)
        data_time_s = time.time()
        batch_imgs, batch_imgs_path, batch_rectangles, \
        batch_text_labels, batch_text_labels_mask, batch_words = \
            batch_data
        while batch_imgs is None:
            batch_data = next(train_data_iter)
            batch_imgs, batch_imgs_path, batch_rectangles, \
            batch_text_labels, batch_text_labels_mask, batch_words = \
                batch_data

        batch_imgs = batch_imgs.cuda(non_blocking=True)
        batch_rectangles = batch_rectangles.cuda(non_blocking=True)
        batch_text_labels = batch_text_labels.cuda(non_blocking=True)
        data_time = time.time() - data_time_s
        # print(time.time() -s)
        # s = time.time()
        loss, decoder_logits = model(batch_imgs, batch_text_labels,
                                     batch_rectangles, batch_text_labels_mask)
        del batch_data
        # print(time.time() - s)
        # print('------')
        # s = time.time()

        loss = loss.mean()
        print(loss)
        # del loss
        # print(time.time() - s)
        # print('------')

        if i % cfg.train_verbose == 0:
            this_time = time.time() - start_time
            if args.distributed:
                dist.reduce(loss, 0)  # reduces in place onto rank 0; returns None
            log_info = "train iter: {}, time: {:.2f}, data_time: {:.2f}, Loss: {:.3f}".format(
                i, this_time, data_time, loss.data)
            logger.info(log_info)
            torch.cuda.empty_cache()
            # break

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del loss
        if i % cfg.val_iter == 0:
            print("--------Val iteration---------")
            model.eval()

            try:
                val_batch = next(val_data_iter)
            except StopIteration:
                val_data_iter = iter(val_data)
                val_batch = next(val_data_iter)

            batch_imgs, batch_imgs_path, batch_rectangles, \
            batch_text_labels, batch_text_labels_mask, batch_words = \
                val_batch
            while batch_imgs is None:
                val_batch = next(val_data_iter)
                batch_imgs, batch_imgs_path, batch_rectangles, \
                batch_text_labels, batch_text_labels_mask, batch_words = \
                    val_batch
            del val_batch
            batch_imgs = batch_imgs.cuda(non_blocking=True)
            batch_rectangles = batch_rectangles.cuda(non_blocking=True)
            batch_text_labels = batch_text_labels.cuda(non_blocking=True)
            with torch.no_grad():
                val_loss, val_pred_logits = model(batch_imgs,
                                                  batch_text_labels,
                                                  batch_rectangles,
                                                  batch_text_labels_mask)
            pred_labels = val_pred_logits.argmax(dim=2).cpu().numpy()
            pred_value_str = idx2label(pred_labels, id2char, char2id)
            # gt_str = batch_words
            gt_str = []
            for words in batch_words:
                gt_str = gt_str + words
            val_dec_metrics_result = calc_metrics(pred_value_str,
                                                  gt_str,
                                                  metrics_type="accuracy")
            this_time = time.time() - start_time
            if args.distributed:
                dist.reduce(val_loss, 0)  # reduces in place onto rank 0; returns None
            log_info = "val iter: {}, time: {:.2f}, Loss: {:.3f}, acc: {:.2f}".format(
                i, this_time,
                val_loss.mean().data, val_dec_metrics_result)
            logger.info(log_info)
            del val_loss
        if (i + 1) % cfg.save_iter == 0:
            torch.save(model.state_dict(),
                       cfg.save_name + '_{}.pth'.format(i + 1))
        if i > 0 and i % cfg.lr_step == 0:  # 调整学习速率
            lrScheduler.step()
            logger.info("lr step")
        # torch.cuda.empty_cache()
    print('end the training')