Code example #1
def wrap(model):
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            # find_unused_parameters=True,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())
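The wrap helper above relies on a module-level world_size global that is set during process-group initialization. A minimal sketch of what that surrounding setup might look like, assuming a standard launcher exports WORLD_SIZE and RANK (the init_distributed name and the use of environment variables are assumptions, not part of the original example):

import os
import torch
import torch.distributed as dist

# Assumed module-level state read by wrap(); the surrounding module is
# expected to set these once, early in the program.
world_size = 1
rank = 0

def init_distributed(backend: str = "nccl"):
    # Sketch: read launcher-provided env vars and join the process group.
    global world_size, rank
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    if world_size > 1:
        torch.cuda.set_device(rank % torch.cuda.device_count())
        dist.init_process_group(backend, rank=rank, world_size=world_size)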
Code example #2
    def block_backward_sync(self, model: DistributedDataParallel):
        """
        Blocks DDP gradient synchronization behaviour during the backward pass.
        This is useful for skipping the sync when accumulating gradients, reducing communication overhead.

        Returns:
            context manager with sync behaviour off
        """
        yield model.no_sync()
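block_backward_sync yields model.no_sync(), DDP's built-in context manager that suppresses the gradient all-reduce. A minimal sketch of the typical gradient-accumulation pattern it supports; ddp_model, dataloader, loss_fn, optimizer, and accumulate_steps are placeholders, not names from the original code:

import contextlib

for step, (x, y) in enumerate(dataloader):
    is_last = (step + 1) % accumulate_steps == 0
    # no_sync() skips the all-reduce; gradients still accumulate locally.
    ctx = contextlib.nullcontext() if is_last else ddp_model.no_sync()
    with ctx:
        loss = loss_fn(ddp_model(x), y)
        (loss / accumulate_steps).backward()
    if is_last:
        optimizer.step()  # the all-reduce happened on the last backward
        optimizer.zero_grad()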
Code example #3
File: ddp.py Project: ashleve/pytorch-lightning
 def _setup_model(self, model: Module) -> DistributedDataParallel:
     """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
     device_ids = self.determine_ddp_device_ids()
     log.detail(
         f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}"
     )
     return DistributedDataParallel(module=model,
                                    device_ids=device_ids,
                                    **self._ddp_kwargs)
Code example #4
File: dqn.py Project: Fengdan92/Horizon
 def __init__(self, fc_dqn):
     super().__init__()
     self.state_dim = fc_dqn.state_dim
     self.action_dim = fc_dqn.action_dim
     current_device = torch.cuda.current_device()
     self.data_parallel = DistributedDataParallel(
         fc_dqn.fc, device_ids=[current_device], output_device=current_device
     )
     self.fc_dqn = fc_dqn
Code example #5
File: seq2slate.py Project: t-triobox/ReAgent
    def __init__(self, seq2slate_net: Seq2SlateNet):
        super().__init__()

        current_device = torch.cuda.current_device()
        self.data_parallel = DistributedDataParallel(
            seq2slate_net.seq2slate,
            device_ids=[current_device],
            output_device=current_device,
        )
        self.seq2slate_net = seq2slate_net
Code example #6
def run(rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir):
    os.environ["LOCAL_RANK"] = str(rank)
    if torch.distributed.is_available() and not torch.distributed.is_initialized():
        torch.distributed.init_process_group("gloo", rank=rank, world_size=2)

    to_device = partial(move_data_to_device, device=torch.device("cuda", rank))
    model = DistributedDataParallel(
        to_device(model),
        device_ids=[rank],
    )
    train_dataloader = DataLoader(
        train_dataloader.dataset,
        sampler=DistributedSampler(train_dataloader.dataset, rank=rank, num_replicas=2, seed=42, drop_last=False),
    )
    with precision_context(precision, accelerator):
        main(to_device, model, train_dataloader, num_epochs=num_epochs)

    if rank == 0:
        atomic_save(model.state_dict(), os.path.join(tmpdir, "model_spawn.pt"))
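The run function above is written to be called once per process with world_size=2 and the "gloo" backend. A minimal launcher sketch using torch.multiprocessing.spawn; the launch name and the MASTER_ADDR/MASTER_PORT values are assumptions for illustration:

import os
import torch.multiprocessing as mp

def launch(model, train_dataloader, num_epochs, precision, accelerator, tmpdir):
    # Rendezvous info used by init_process_group("gloo", ...) inside run().
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    # Each spawned process receives its index (0 or 1) as the first argument.
    mp.spawn(
        run,
        args=(model, train_dataloader, num_epochs, precision, accelerator, tmpdir),
        nprocs=2,
        join=True,
    )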
Code example #7
def load_vaes_256(H, logprint):
    vae = VAE_256(H)
    if H.restore_path:
        logprint(f'Restoring vae from {H.restore_path}')
        restore_params(vae,
                       H.restore_path,
                       map_cpu=True,
                       local_rank=H.local_rank,
                       mpi_size=H.mpi_size)

    ema_vae = VAE_256(H)
    if H.restore_ema_path:
        logprint(f'Restoring ema vae from {H.restore_ema_path}')
        restore_params(ema_vae,
                       H.restore_ema_path,
                       map_cpu=True,
                       local_rank=H.local_rank,
                       mpi_size=H.mpi_size)
    else:
        ema_vae.load_state_dict(vae.state_dict())
    ema_vae.requires_grad_(False)

    vae = vae.cuda(H.local_rank)
    ema_vae = ema_vae.cuda(H.local_rank)

    if H.image_size == 64:
        vae.decoder.requires_grad_(False)

    if H.image_size == 256:
        vae.encoder.requires_grad_(False)

    vae = DistributedDataParallel(vae,
                                  device_ids=[H.local_rank],
                                  output_device=H.local_rank,
                                  find_unused_parameters=True)

    if len(list(vae.named_parameters())) != len(list(vae.parameters())):
        raise ValueError('Some params are not named. Please name all params.')
    total_params = 0
    for name, p in vae.named_parameters():
        total_params += np.prod(p.shape)
    logprint(total_params=total_params, readable=f'{total_params:,}')
    return vae, ema_vae
Code example #8
def main(rank: int,
         epochs: int,
         model: nn.Module,
         train_loader: DataLoader,
         test_loader: DataLoader) -> nn.Module:
    device = torch.device(f'cuda:{rank}')
    model = model.to(device)
    model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)

    # initialize optimizer and loss function
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    loss = nn.CrossEntropyLoss()

    # train the model
    for i in range(epochs):
        model.train()
        train_loader.sampler.set_epoch(i)

        epoch_loss = 0
        # train the model for one epoch
        pbar = tqdm(train_loader)
        for x, y in pbar:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            x = x.view(x.shape[0], -1)
            optimizer.zero_grad()
            y_hat = model(x)
            batch_loss = loss(y_hat, y)
            batch_loss.backward()
            optimizer.step()
            batch_loss_scalar = batch_loss.item()
            epoch_loss += batch_loss_scalar / x.shape[0]
            pbar.set_description(f'training batch_loss={batch_loss_scalar:.4f}')

        # calculate validation loss
        with torch.no_grad():
            model.eval()
            val_loss = 0
            pbar = tqdm(test_loader)
            for x, y in pbar:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                x = x.view(x.shape[0], -1)
                y_hat = model(x)
                batch_loss = loss(y_hat, y)
                batch_loss_scalar = batch_loss.item()

                val_loss += batch_loss_scalar / x.shape[0]
                pbar.set_description(f'validation batch_loss={batch_loss_scalar:.4f}')

        print(f"Epoch={i}, train_loss={epoch_loss:.4f}, val_loss={val_loss:.4f}")

    return model.module
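main calls train_loader.sampler.set_epoch(i), so it expects loaders built on a DistributedSampler. A minimal sketch of that setup (function and argument names are placeholders):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_loader(dataset, rank, world_size, batch_size=64, shuffle=True):
    # Each rank sees a disjoint shard; set_epoch() reshuffles the shard every epoch.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, pin_memory=True)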
Code example #9
File: distrib.py Project: yunzqq/svoice
def wrap(model):
    """wrap.
    Wrap a model with DDP if distributed training is enabled.
    """
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())
Code example #10
    def load_model(self):
        # Build model
        if distutils.is_master():
            print("### Loading model: {}".format(self.config["model"]))

        # TODO(abhshkdz): Eventually move towards computing features on-the-fly
        # and remove dependence from `.edge_attr`.
        bond_feat_dim = None
        if self.config["task"]["dataset"] in [
                "trajectory_lmdb",
                "single_point_lmdb",
        ]:
            bond_feat_dim = self.config["model_attributes"].get(
                "num_gaussians", 50)
        else:
            raise NotImplementedError

        self.model = registry.get_model_class(self.config["model"])(
            self.train_loader.dataset[0].x.shape[-1]
            if hasattr(self.train_loader.dataset[0], "x")
            and self.train_loader.dataset[0].x is not None else None,
            bond_feat_dim,
            self.num_targets,
            **self.config["model_attributes"],
        ).to(self.device)

        if distutils.is_master():
            print("### Loaded {} with {} parameters.".format(
                self.model.__class__.__name__, self.model.num_params))

        if self.logger is not None:
            self.logger.watch(self.model)

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=1,
        )
        if distutils.initialized():
            self.model = DistributedDataParallel(self.model,
                                                 device_ids=[self.device])
Code example #11
def validate(ae_without_ddp, data_loader, config, device, distributed,
             device_ids):
    input_shape = config['input_shape']
    extended_model, _ = ae_util.get_extended_model(ae_without_ddp, config,
                                                   input_shape, device, True)
    if distributed:
        extended_model = DistributedDataParallel(extended_model,
                                                 device_ids=device_ids)
    return evaluate(extended_model,
                    data_loader,
                    device,
                    split_name='Validation')
Code example #12
    def __init__(self, config: BaseConfig):
        self._config = config
        self._model = DeepLab(num_classes=9, output_stride=8,
                              sync_bn=False).to(self._config.device)
        self._border_loss = TotalLoss(self._config)
        self._direction_loss = CrossEntropyLoss()
        self._loaders = get_data_loaders(config)
        self._writer = SummaryWriter()
        self._optimizer = torch.optim.SGD(self._model.parameters(),
                                          lr=self._config.lr,
                                          weight_decay=1e-4,
                                          nesterov=True,
                                          momentum=0.9)
        self._scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self._optimizer, gamma=0.97)

        if self._config.parallel:
            self._model = DistributedDataParallel(self._model,
                                                  device_ids=[
                                                      self._config.device,
                                                  ])
Code example #13
    def __init__(self, cem_planner: CEMPlanner):
        super().__init__()
        self.plan_horizon_length = cem_planner.plan_horizon_length
        self.state_dim = cem_planner.state_dim
        self.action_dim = cem_planner.action_dim
        self.discrete_action = cem_planner.discrete_action

        current_device = torch.cuda.current_device()  # type: ignore
        self.data_parallel = DistributedDataParallel(
            cem_planner.cem_planner_network,
            device_ids=[current_device],
            output_device=current_device,
        )
        self.cem_planner = cem_planner
Code example #14
File: world_model.py Project: sethips/Horizon
    def __init__(self, mem_net):
        super().__init__()
        self.num_hiddens = mem_net.num_hiddens
        self.num_hidden_layers = mem_net.num_hidden_layers
        self.state_dim = mem_net.state_dim
        self.action_dim = mem_net.action_dim
        self.num_gaussians = mem_net.num_gaussians

        current_device = torch.cuda.current_device()
        self.data_parallel = DistributedDataParallel(
            mem_net.mdnrnn,
            device_ids=[current_device],
            output_device=current_device)
        self.mem_net = mem_net
Code example #15
def main(args):
    config = yaml_util.load_yaml_file(args.config)
    if args.json is not None:
        main_util.overwrite_config(config, args.json)

    distributed, device_ids = main_util.init_distributed_mode(
        args.world_size, args.dist_url)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    teacher_model = get_model(config['teacher_model'], device)
    module_util.freeze_module_params(teacher_model)
    student_model_config = config['student_model']
    student_model = get_model(student_model_config,
                              device,
                              require_weights=args.require_weights)

    reset_unfrozen = False
    if 'reset_unfrozen' in student_model_config:
        reset_unfrozen = student_model_config['reset_unfrozen']
    freeze_modules(student_model,
                   student_model_config,
                   reset_unfrozen=reset_unfrozen)
    set_bottleneck_transformer(student_model, True)
    post_bn = False
    if 'post_batch_norm' in config['train']:
        post_bn = config['train']['post_batch_norm']

    # print('Updatable parameters: {}'.format(module_util.get_updatable_param_names(student_model)))
    distill_backbone_only = student_model_config['distill_backbone_only']
    train_config = config['train']

    train_sampler, train_data_loader, val_data_loader, test_data_loader = \
        data_util.get_coco_data_loaders(config['dataset'], train_config['batch_size'], distributed)
    if distributed:
        teacher_model = DataParallel(teacher_model, device_ids=device_ids)
        student_model = DistributedDataParallel(student_model,
                                                device_ids=device_ids)

    if args.distill:
        distill(teacher_model, student_model, train_sampler, train_data_loader,
                val_data_loader, device, distributed, distill_backbone_only,
                config, args)
        load_ckpt(
            config['student_model']['ckpt'],
            model=student_model.module if isinstance(
                student_model, DistributedDataParallel) else student_model)
    evaluate(teacher_model, student_model, test_data_loader, train_data_loader,
             device, args.skip_teacher_eval, args.transform_bottleneck,
             student_model_config, post_bn)
Code example #16
def train(model, train_loader, valid_loader, best_valid_acc, criterion, device,
          distributed, device_ids, train_config, num_epochs, start_epoch,
          init_lr, ckpt_file_path, model_type):
    model_without_ddp = model
    if distributed:
        model = DistributedDataParallel(model_without_ddp,
                                        device_ids=device_ids)
    elif device.type == 'cuda':
        model = DataParallel(model_without_ddp)

    optim_config = train_config['optimizer']
    if init_lr is not None:
        optim_config['params']['lr'] = init_lr

    optimizer = func_util.get_optimizer(model, optim_config['type'],
                                        optim_config['params'])
    scheduler_config = train_config['scheduler']
    scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'],
                                        scheduler_config['params'])
    interval = train_config['interval']
    if interval <= 0:
        num_batches = len(train_loader)
        interval = num_batches // 20 if num_batches >= 20 else 1

    end_epoch = start_epoch + train_config[
        'epoch'] if num_epochs is None else start_epoch + num_epochs
    start_time = time.time()
    for epoch in range(start_epoch, end_epoch):
        if distributed:
            train_loader.sampler.set_epoch(epoch)

        train_epoch(model, train_loader, optimizer, criterion, epoch, device,
                    interval)
        valid_acc = validate(model, valid_loader, device)
        if valid_acc > best_valid_acc and main_util.is_main_process():
            print(
                'Updating ckpt (Best top1 accuracy: {:.4f} -> {:.4f})'.format(
                    best_valid_acc, valid_acc))
            best_valid_acc = valid_acc
            save_ckpt(model_without_ddp, best_valid_acc, epoch, ckpt_file_path,
                      model_type)
        scheduler.step()

    dist.barrier()
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Code example #17
def train(epochs: int,
          train_data_loader: DataLoader,
          valid_data_loader: DataLoader = None,
          rank=None):
    device = torch.device(f'cuda:{rank}')
    model = create_model(model_type).to(device)
    model = DistributedDataParallel(model,
                                    device_ids=[rank],
                                    output_device=rank)
    optimizer = AdamW(model.parameters(), lr=lr)
    tokenizer = BertTokenizer.from_pretrained(model_type)

    def update_weights(bi, di, num_batches, batch_loss):
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if bi % 100 == 0:
            logger.info(
                f'training: device={di}; batch={bi+1}/{num_batches}; batch_error={batch_loss.item()};'
            )

    def valid_loss_progress_log(bi, di, num_batches, batch_loss):
        if bi % 100 == 0:
            logger.info(
                f'validation: device={di}; batch={bi+1}/{num_batches}; val_batch_error={batch_loss.item()};'
            )

    for i in range(epochs):
        model.train()
        train_data_loader.sampler.set_epoch(i)
        if valid_data_loader is not None:
            valid_data_loader.sampler.set_epoch(i)

        train_loss = run(model, train_data_loader, tokenizer, device,
                         update_weights)

        if valid_data_loader is not None:
            with torch.no_grad():
                model.eval()
                val_loss = run(model, valid_data_loader, tokenizer, device,
                               valid_loss_progress_log)
        else:
            val_loss = 'N/A'

        logger.info(
            f'epoch={i}; device={rank}; train_error={train_loss};  valid_error={val_loss};'
        )

    return model.module
Code example #18
def validate(student_model_without_ddp, data_loader, config, device,
             distributed, device_ids):
    teacher_model_config = config['teacher_model']
    org_model, teacher_model_type = mimic_util.get_org_model(
        teacher_model_config, device)
    mimic_model = mimic_util.get_mimic_model(
        config,
        org_model,
        teacher_model_type,
        teacher_model_config,
        device,
        head_model=student_model_without_ddp)
    mimic_model_without_dp = mimic_model.module if isinstance(
        mimic_model, DataParallel) else mimic_model
    if distributed:
        mimic_model = DistributedDataParallel(mimic_model_without_dp,
                                              device_ids=device_ids)
    return evaluate(mimic_model, data_loader, device, split_name='Validation')
Code example #19
File: seq2slate.py Project: zhenyu-captain/ReAgent
    def __init__(self, seq2slate_transformer_net: Seq2SlateTransformerNet):
        super().__init__()
        self.state_dim = seq2slate_transformer_net.state_dim
        self.candidate_dim = seq2slate_transformer_net.candidate_dim
        self.num_stacked_layers = seq2slate_transformer_net.num_stacked_layers
        self.num_heads = seq2slate_transformer_net.num_heads
        self.dim_model = seq2slate_transformer_net.dim_model
        self.dim_feedforward = seq2slate_transformer_net.dim_feedforward
        self.max_src_seq_len = seq2slate_transformer_net.max_src_seq_len
        self.max_tgt_seq_len = seq2slate_transformer_net.max_tgt_seq_len

        current_device = torch.cuda.current_device()  # type: ignore
        self.data_parallel = DistributedDataParallel(
            seq2slate_transformer_net.seq2slate_transformer,
            device_ids=[current_device],
            output_device=current_device,
        )
        self.seq2slate_transformer_net = seq2slate_transformer_net
Code example #20
def run(args):
    distributed, device_ids = main_util.init_distributed_mode(
        args.world_size, args.dist_url)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        cudnn.benchmark = True

    print(args)
    config = yaml_util.load_yaml_file(args.config)
    dataset_config = config['dataset']
    input_shape = config['input_shape']
    train_config = config['train']
    test_config = config['test']
    train_loader, valid_loader, test_loader =\
        dataset_util.get_data_loaders(dataset_config, batch_size=train_config['batch_size'],
                                      rough_size=train_config['rough_size'], reshape_size=input_shape[1:3],
                                      test_batch_size=test_config['batch_size'], jpeg_quality=-1,
                                      distributed=distributed)
    teacher_model_config = config['teacher_model']
    if not args.test_only:
        distill(train_loader, valid_loader, input_shape, args.aux, config,
                device, distributed, device_ids)

    org_model, teacher_model_type = mimic_util.get_org_model(
        teacher_model_config, device)
    if not args.student_only:
        if distributed:
            org_model = DataParallel(org_model, device_ids=device_ids)
        evaluate(org_model, test_loader, device, title='[Original model]')

    mimic_model = mimic_util.get_mimic_model(config, org_model,
                                             teacher_model_type,
                                             teacher_model_config, device)
    mimic_model_without_dp = mimic_model.module if isinstance(
        mimic_model, DataParallel) else mimic_model
    file_util.save_pickle(mimic_model_without_dp,
                          config['mimic_model']['ckpt'])
    if distributed:
        mimic_model = DistributedDataParallel(mimic_model_without_dp,
                                              device_ids=device_ids)
    evaluate(mimic_model, test_loader, device, title='[Mimic model]')
Code example #21
File: ddp.py Project: wenig/pytorch-distributed-rnn
    def __init__(self,
                 model,
                 training_set,
                 batch_size,
                 learning_rate,
                 validation_set=None,
                 test_set=None,
                 checkpoint_dir=None):
        init_process_group('mpi')
        model = DistributedDataParallel(model)
        rank = get_rank()
        world_size = get_world_size()

        super().__init__(rank=rank,
                         world_size=world_size,
                         model=model,
                         training_set=training_set,
                         validation_set=validation_set,
                         test_set=test_set,
                         batch_size=batch_size,
                         learning_rate=learning_rate,
                         checkpoint_dir=checkpoint_dir)
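This trainer initializes the 'mpi' backend and wraps the model without device_ids, so DDP synchronizes over whichever devices the module's parameters already occupy; it requires a PyTorch build with MPI support. For comparison, a minimal sketch of the same wrapping with the more commonly available 'gloo' backend for CPU training (the wrap_cpu_model name is an assumption, and env:// rendezvous variables are expected from the launcher):

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

def wrap_cpu_model(model: nn.Module) -> DistributedDataParallel:
    # "gloo" supports CPU tensors; omitting device_ids lets DDP use the
    # devices the parameters already live on (here: CPU).
    if not dist.is_initialized():
        dist.init_process_group("gloo")
    print(f"rank {dist.get_rank()} of {dist.get_world_size()}")
    return DistributedDataParallel(model)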
Code example #22
    def configure_ddp(self, model: LightningModule,
                      device_ids: List[int]) -> DistributedDataParallel:
        """
        Pass through all customizations from constructor to :class:`~torch.nn.parallel.DistributedDataParallel`.
        Override to define a custom DDP implementation.

        .. note:: This requires that your DDP implementation subclasses
            :class:`~torch.nn.parallel.DistributedDataParallel` and that
            the original LightningModule gets wrapped by
            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedModule`.

        The default implementation is::

            def configure_ddp(self, model, device_ids):
                model = DistributedDataParallel(
                    LightningDistributedModule(model),
                    device_ids=device_ids,
                    **self._ddp_kwargs,
                )
                return model

        Args:
            model: the LightningModule
            device_ids: the list of devices available

        Returns:
            the model wrapped in :class:`~torch.nn.parallel.DistributedDataParallel`

        """
        # if unset, default `find_unused_parameters` to `True`
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
            "find_unused_parameters", True)
        model = DistributedDataParallel(
            module=LightningDistributedModule(model),
            device_ids=device_ids,
            **self._ddp_kwargs,
        )
        return model
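As the docstring notes, configure_ddp can be overridden to change how the model is wrapped. A minimal sketch of a subclass that disables the (expensive) unused-parameter search again; the DDPPlugin base class name follows the Lightning conventions referenced in the docstring but should be treated as an assumption about the surrounding codebase:

class NoUnusedParamsDDPPlugin(DDPPlugin):  # assumed base class
    def configure_ddp(self, model, device_ids):
        # Explicitly turn the unused-parameter search off before wrapping.
        self._ddp_kwargs["find_unused_parameters"] = False
        return DistributedDataParallel(
            module=LightningDistributedModule(model),
            device_ids=device_ids,
            **self._ddp_kwargs,
        )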
Code example #23
def main():
    parser = get_parser()
    args = parser.parse_args()
    name = get_name(parser, args)
    print(f"Experiment {name}")

    if args.musdb is None and args.rank == 0:
        print(
            "You must provide the path to the MusDB dataset with the --musdb flag. "
            "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
            file=sys.stderr)
        sys.exit(1)

    eval_folder = args.evals / name
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.logs.mkdir(exist_ok=True)
    metrics_path = args.logs / f"{name}.json"
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.checkpoints.mkdir(exist_ok=True, parents=True)
    args.models.mkdir(exist_ok=True, parents=True)

    if args.device is None:
        device = "cpu"
        if th.cuda.is_available():
            device = "cuda"
    else:
        device = args.device

    th.manual_seed(args.seed)
    # Prevents too many threads from being started when running `museval`, as that
    # can be quite inefficient on NUMA architectures.
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"

    if args.world_size > 1:
        if device != "cuda" and args.rank == 0:
            print(
                "Error: distributed training is only available with cuda device",
                file=sys.stderr)
            sys.exit(1)
        th.cuda.set_device(args.rank % th.cuda.device_count())
        distributed.init_process_group(backend="nccl",
                                       init_method="tcp://" + args.master,
                                       rank=args.rank,
                                       world_size=args.world_size)

    checkpoint = args.checkpoints / f"{name}.th"
    checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
    if args.restart and checkpoint.exists() and args.rank == 0:
        checkpoint.unlink()

    if args.test or args.test_pretrained:
        args.epochs = 1
        args.repeat = 0
        if args.test:
            model = load_model(args.models / args.test)
        else:
            model = load_pretrained(args.test_pretrained)
    elif args.tasnet:
        model = ConvTasNet(audio_channels=args.audio_channels,
                           samplerate=args.samplerate,
                           X=args.X,
                           segment_length=4 * args.samples)
    else:
        model = Demucs(
            audio_channels=args.audio_channels,
            channels=args.channels,
            context=args.context,
            depth=args.depth,
            glu=args.glu,
            growth=args.growth,
            kernel_size=args.kernel_size,
            lstm_layers=args.lstm_layers,
            rescale=args.rescale,
            rewrite=args.rewrite,
            stride=args.conv_stride,
            resample=args.resample,
            samplerate=args.samplerate,
            segment_length=4 * args.samples,
        )
    model.to(device)
    if args.init:
        model.load_state_dict(load_pretrained(args.init).state_dict())

    if args.show:
        print(model)
        size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
        print(f"Model size {size}")
        return

    try:
        saved = th.load(checkpoint, map_location='cpu')
    except IOError:
        saved = SavedState()

    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    quantizer = None
    quantizer = get_quantizer(model, args, optimizer)

    if saved.last_state is not None:
        model.load_state_dict(saved.last_state, strict=False)
    if saved.optimizer is not None:
        optimizer.load_state_dict(saved.optimizer)

    model_name = f"{name}.th"
    if args.save_model:
        if args.rank == 0:
            model.to("cpu")
            model.load_state_dict(saved.best_state)
            save_model(model, quantizer, args, args.models / model_name)
        return
    elif args.save_state:
        model_name = f"{args.save_state}.th"
        if args.rank == 0:
            model.to("cpu")
            model.load_state_dict(saved.best_state)
            state = get_state(model, quantizer)
            save_state(state, args.models / model_name)
        return

    if args.rank == 0:
        done = args.logs / f"{name}.done"
        if done.exists():
            done.unlink()

    augment = [Shift(args.data_stride)]
    if args.augment:
        augment += [
            FlipSign(),
            FlipChannels(),
            Scale(),
            Remix(group_size=args.remix_group_size)
        ]
    augment = nn.Sequential(*augment).to(device)
    print("Augmentation pipeline:", augment)

    if args.mse:
        criterion = nn.MSELoss()
    else:
        criterion = nn.L1Loss()

    # Set the number of samples so that all convolution windows are full.
    # Prevents a hard-to-debug mistake where the prediction is shifted
    # relative to the input mixture.
    samples = model.valid_length(args.samples)
    print(f"Number of training samples adjusted to {samples}")
    samples = samples + args.data_stride
    if args.repitch:
        # We need a few more audio samples to account for potential
        # tempo changes.
        samples = math.ceil(samples / (1 - 0.01 * args.max_tempo))

    sources = 4
    if args.raw:
        train_set = Rawset(args.raw / "train",
                           samples=samples,
                           channels=args.audio_channels,
                           streams=range(1, sources + 1),
                           stride=args.data_stride)

        valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
    else:
        if not args.metadata.is_file() and args.rank == 0:
            build_musdb_metadata(args.metadata, args.musdb, args.workers)
        if args.world_size > 1:
            distributed.barrier()
        metadata = json.load(open(args.metadata))
        duration = Fraction(samples, args.samplerate)
        stride = Fraction(args.data_stride, args.samplerate)
        train_set = StemsSet(get_musdb_tracks(args.musdb,
                                              subsets=["train"],
                                              split="train"),
                             metadata,
                             duration=duration,
                             stride=stride,
                             streams=range(1, sources + 1),
                             samplerate=args.samplerate,
                             channels=args.audio_channels)
        valid_set = StemsSet(get_musdb_tracks(args.musdb,
                                              subsets=["train"],
                                              split="valid"),
                             metadata,
                             samplerate=args.samplerate,
                             channels=args.audio_channels)
    if args.repitch:
        train_set = RepitchedWrapper(train_set,
                                     proba=args.repitch,
                                     max_tempo=args.max_tempo)

    best_loss = float("inf")
    for epoch, metrics in enumerate(saved.metrics):
        print(f"Epoch {epoch:03d}: "
              f"train={metrics['train']:.8f} "
              f"valid={metrics['valid']:.8f} "
              f"best={metrics['best']:.4f} "
              f"ms={metrics.get('true_model_size', 0):.2f}MB "
              f"cms={metrics.get('compressed_model_size', 0):.2f}MB "
              f"duration={human_seconds(metrics['duration'])}")
        best_loss = metrics['best']

    if args.world_size > 1:
        dmodel = DistributedDataParallel(
            model,
            device_ids=[th.cuda.current_device()],
            output_device=th.cuda.current_device())
    else:
        dmodel = model

    for epoch in range(len(saved.metrics), args.epochs):
        begin = time.time()
        model.train()
        train_loss, model_size = train_model(epoch,
                                             train_set,
                                             dmodel,
                                             criterion,
                                             optimizer,
                                             augment,
                                             quantizer=quantizer,
                                             batch_size=args.batch_size,
                                             device=device,
                                             repeat=args.repeat,
                                             seed=args.seed,
                                             diffq=args.diffq,
                                             workers=args.workers,
                                             world_size=args.world_size)
        model.eval()
        valid_loss = validate_model(epoch,
                                    valid_set,
                                    model,
                                    criterion,
                                    device=device,
                                    rank=args.rank,
                                    split=args.split_valid,
                                    overlap=args.overlap,
                                    world_size=args.world_size)

        ms = 0
        cms = 0
        if quantizer and args.rank == 0:
            ms = quantizer.true_model_size()
            cms = quantizer.compressed_model_size(
                num_workers=min(40, args.world_size * 10))

        duration = time.time() - begin
        if valid_loss < best_loss and ms <= args.ms_target:
            best_loss = valid_loss
            saved.best_state = {
                key: value.to("cpu").clone()
                for key, value in model.state_dict().items()
            }

        saved.metrics.append({
            "train": train_loss,
            "valid": valid_loss,
            "best": best_loss,
            "duration": duration,
            "model_size": model_size,
            "true_model_size": ms,
            "compressed_model_size": cms,
        })
        if args.rank == 0:
            json.dump(saved.metrics, open(metrics_path, "w"))

        saved.last_state = model.state_dict()
        saved.optimizer = optimizer.state_dict()
        if args.rank == 0 and not args.test:
            th.save(saved, checkpoint_tmp)
            checkpoint_tmp.rename(checkpoint)

        print(
            f"Epoch {epoch:03d}: "
            f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} ms={ms:.2f}MB "
            f"cms={cms:.2f}MB "
            f"duration={human_seconds(duration)}")

    del dmodel
    model.load_state_dict(saved.best_state)
    if args.eval_cpu:
        device = "cpu"
        model.to(device)
    model.eval()
    evaluate(model,
             args.musdb,
             eval_folder,
             rank=args.rank,
             world_size=args.world_size,
             device=device,
             save=args.save,
             split=args.split_valid,
             shifts=args.shifts,
             overlap=args.overlap,
             workers=args.eval_workers)
    model.to("cpu")
    if args.rank == 0:
        if not (args.test or args.test_pretrained):
            save_model(model, quantizer, args, args.models / model_name)
        print("done")
        done.write_text("done")
Code example #24
def register_ddp_comm_hook(
    model: DistributedDataParallel,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[Callable] = None,
    ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.

    Args:
        model:
            DDP model
        ddp_comm_state:
            state is passed to the hook and can be used to maintain
            and update any state information that users would like to
            maintain as part of the training process. Examples: error
            feedback in gradient compression, peers to communicate with
            next in GossipGrad etc.
        ddp_comm_hook:
            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

            This callable function is called once the bucket is ready. The
            hook can perform whatever processing is needed and return
            a Future indicating completion of any async work (ex: allreduce).
            If the hook doesn't perform any communication, it can also
            just return a completed Future. The Future should hold the
            new value of grad bucket's tensors. Once a bucket is ready,
            c10d reducer would call this hook and use the tensors returned
            by the Future and copy grads to individual parameters.
        ddp_comm_wrapper:
            communication hook wrapper to support a communication hook such
            as FP16 compression as wrapper, which could be combined with
            ddp_comm_hook

    .. warning ::
        DDP communication hook needs pytorch version at least 1.8.0

    .. warning ::
        DDP communication wrapper needs pytorch version at least 1.9.0
        Post-localSGD hook needs pytorch version at least 1.9.0

    Examples:

        >>> from torch.distributed.algorithms.ddp_comm_hooks import ( # doctest: +SKIP
        ...     default_hooks as default,
        ...     powerSGD_hook as powerSGD,
        ...     post_localSGD_hook as post_localSGD,
        ... )
        >>>
        >>> # fp16_compress_hook for compressing gradients
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_hook=default.fp16_compress_hook,
        ... )
        >>>
        >>> # powerSGD_hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ... )
        >>>
        >>> # post_localSGD_hook
        >>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     state=post_localSGD.PostLocalSGDState(
        ...         process_group=None,
        ...         subgroup=subgroup,
        ...         start_localSGD_iter=1_000,
        ...     ),
        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
        ... )
        >>>
        >>> # fp16_compress_wrapper combined with other communication hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
        ... )
    """
    from pytorch_lightning.utilities import rank_zero_warn

    if not _TORCH_GREATER_EQUAL_1_8:
        rank_zero_warn(
            "Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0."
        )
        return
    if ddp_comm_hook is None:
        return
    # inform mypy that ddp_comm_hook is callable
    ddp_comm_hook: Callable = ddp_comm_hook

    if ddp_comm_wrapper is not None:
        if not _TORCH_GREATER_EQUAL_1_9:
            rank_zero_warn(
                "Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0."
            )
        else:
            new_rank_zero_info(
                f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
            )
            ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

    rank_zero_debug(
        f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
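Outside of this helper, the same effect can be achieved by calling DDP's own register_comm_hook directly. A minimal sketch registering the built-in fp16 compression hook (PyTorch >= 1.8); ddp_model is assumed to be an already constructed DistributedDataParallel instance:

from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

# Compress gradients to fp16 before the all-reduce, decompress afterwards.
ddp_model.register_comm_hook(state=None, hook=default.fp16_compress_hook)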
Code example #25
def train(cfg, args):
    logger = logging.getLogger('model training')
    train_data = build_dataloader(cfg.train_data_cfg, args.distributed)
    logger.info("train data: {}".format(len(train_data)))
    val_data = build_dataloader(cfg.val_data_cfg, args.distributed)
    logger.info("val data: {}".format(len(val_data)))

    model = MultiInstanceRecognition(cfg.model_cfg).cuda()
    if cfg.resume_from is not None:
        logger.info(f'loading pretrained model from {cfg.resume_from}')
        model.load_state_dict(torch.load(cfg.resume_from))
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
    voc, char2id, id2char = get_vocabulary("ALLCASES_SYMBOLS")

    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    logger.info('Trainable params num : {}'.format(sum(params_num)))
    optimizer = optim.Adam(filtered_parameters, lr=cfg.lr, betas=(0.9, 0.999))
    lrScheduler = lr_scheduler.MultiStepLR(optimizer, [1, 2, 3], gamma=0.1)

    max_iters = cfg.max_iters
    start_iter = 0
    if cfg.resume_from is not None:
        start_iter = int(cfg.resume_from.split('_')[-1].split('.')[0])
        logger.info(f'continue to train, start_iter: {start_iter}')

    train_data_iter = iter(train_data)
    val_data_iter = iter(val_data)
    start_time = time.time()
    for i in range(start_iter, max_iters):
        model.train()
        try:
            batch_data = next(train_data_iter)
        except StopIteration:
            train_data_iter = iter(train_data)
            batch_data = next(train_data_iter)
        data_time_s = time.time()
        batch_imgs, batch_imgs_path, batch_rectangles, \
        batch_text_labels, batch_text_labels_mask, batch_words = \
            batch_data
        while batch_imgs is None:
            batch_data = next(train_data_iter)
            batch_imgs, batch_imgs_path, batch_rectangles, \
            batch_text_labels, batch_text_labels_mask, batch_words = \
                batch_data

        batch_imgs = batch_imgs.cuda(non_blocking=True)
        batch_rectangles = batch_rectangles.cuda(non_blocking=True)
        batch_text_labels = batch_text_labels.cuda(non_blocking=True)
        data_time = time.time() - data_time_s
        # print(time.time() -s)
        # s = time.time()
        loss, decoder_logits = model(batch_imgs, batch_text_labels,
                                     batch_rectangles, batch_text_labels_mask)
        del batch_data
        # print(time.time() - s)
        # print('------')
        # s = time.time()

        loss = loss.mean()
        print(loss)
        # del loss
        # print(time.time() - s)
        # print('------')

        if i % cfg.train_verbose == 0:
            this_time = time.time() - start_time
            if args.distributed:
                dist.reduce(loss, 0)  # in-place reduction onto rank 0; returns None
            log_info = "train iter :{}, time: {:.2f}, data_time: {:.2f}, Loss: {:.3f}".format(
                i, this_time, data_time, loss.data)
            logger.info(log_info)
            torch.cuda.empty_cache()
            # break

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del loss
        if i % cfg.val_iter == 0:
            print("--------Val iteration---------")
            model.eval()

            try:
                val_batch = next(val_data_iter)
            except StopIteration:
                val_data_iter = iter(val_data)
                val_batch = next(val_data_iter)

            batch_imgs, batch_imgs_path, batch_rectangles, \
            batch_text_labels, batch_text_labels_mask, batch_words = \
                val_batch
            while batch_imgs is None:
                val_batch = next(val_data_iter)
                batch_imgs, batch_imgs_path, batch_rectangles, \
                batch_text_labels, batch_text_labels_mask, batch_words = \
                    val_batch
            del val_batch
            batch_imgs = batch_imgs.cuda(non_blocking=True)
            batch_rectangles = batch_rectangles.cuda(non_blocking=True)
            batch_text_labels = batch_text_labels.cuda(non_blocking=True)
            with torch.no_grad():
                val_loss, val_pred_logits = model(batch_imgs,
                                                  batch_text_labels,
                                                  batch_rectangles,
                                                  batch_text_labels_mask)
            pred_labels = val_pred_logits.argmax(dim=2).cpu().numpy()
            pred_value_str = idx2label(pred_labels, id2char, char2id)
            # gt_str = batch_words
            gt_str = []
            for words in batch_words:
                gt_str = gt_str + words
            val_dec_metrics_result = calc_metrics(pred_value_str,
                                                  gt_str,
                                                  metrics_type="accuracy")
            this_time = time.time() - start_time
            if args.distributed:
                dist.reduce(val_loss, 0)  # in-place reduction onto rank 0; returns None
            log_info = "val iter :{}, time: {:.2f} Loss: {:.3f}, acc: {:.2f}".format(
                i, this_time,
                val_loss.mean().data, val_dec_metrics_result)
            logger.info(log_info)
            del val_loss
        if (i + 1) % cfg.save_iter == 0:
            torch.save(model.state_dict(),
                       cfg.save_name + '_{}.pth'.format(i + 1))
        if i > 0 and i % cfg.lr_step == 0:  # adjust the learning rate
            lrScheduler.step()
            logger.info("lr step")
        # torch.cuda.empty_cache()
    print('end the training')
Code example #26
    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    ):
        self.model = model
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        if self.use_fp16:
            self._setup_fp16()

        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model
Code example #27
class TrainLoop:
    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    ):
        self.model = model
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        if self.use_fp16:
            self._setup_fp16()

        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model

    def _load_and_sync_parameters(self):
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                self.model.load_state_dict(
                    dist_util.load_state_dict(
                        resume_checkpoint, map_location=dist_util.dev()
                    )
                )

        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        ema_params = copy.deepcopy(self.master_params)

        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev()
                )
                ema_params = self._state_dict_to_master_params(state_dict)

        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def _setup_fp16(self):
        self.master_params = make_master_params(self.model_params)
        self.model.convert_to_fp16()

    def run_loop(self):
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond = next(self.data)
            self.run_step(batch, cond)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.save_interval == 0:
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        self.forward_backward(batch, cond)
        if self.use_fp16:
            self.optimize_fp16()
        else:
            self.optimize_normal()
        self.log_step()

    def forward_backward(self, batch, cond):
        zero_grad(self.model_params)
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            if self.use_fp16:
                loss_scale = 2 ** self.lg_loss_scale
                (loss * loss_scale).backward()
            else:
                loss.backward()

    def optimize_fp16(self):
        if any(not th.isfinite(p.grad).all() for p in self.model_params):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            return

        model_grads_to_master_grads(self.model_params, self.master_params)
        self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)
        master_params_to_model_params(self.model_params, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth

    def optimize_normal(self):
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)

    def _log_grad_norm(self):
        sqsum = 0.0
        for p in self.master_params:
            sqsum += (p.grad ** 2).sum().item()
        logger.logkv_mean("grad_norm", np.sqrt(sqsum))

    def _anneal_lr(self):
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
        if self.use_fp16:
            logger.logkv("lg_loss_scale", self.lg_loss_scale)

    def save(self):
        def save_checkpoint(rate, params):
            state_dict = self._master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"model{(self.step+self.resume_step):06d}.pt"
                else:
                    filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        if dist.get_rank() == 0:
            with bf.BlobFile(
                bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
                "wb",
            ) as f:
                th.save(self.opt.state_dict(), f)

        dist.barrier()

    def _master_params_to_state_dict(self, master_params):
        if self.use_fp16:
            master_params = unflatten_master_params(
                self.model.parameters(), master_params
            )
        state_dict = self.model.state_dict()
        for i, (name, _value) in enumerate(self.model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
        return state_dict

    def _state_dict_to_master_params(self, state_dict):
        params = [state_dict[name] for name, _ in self.model.named_parameters()]
        if self.use_fp16:
            return make_master_params(params)
        else:
            return params
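The fp16 branch above does dynamic loss scaling by hand: the loss is multiplied by 2 ** lg_loss_scale before backward, the scale is reduced and the step skipped whenever a gradient overflows, and the scale slowly grows again after clean steps. A minimal self-contained sketch of the same bookkeeping, assuming a throwaway linear model and SGD instead of the flat fp32 master-parameter setup used above:

import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
lg_loss_scale = 20.0   # start with a large scale, 2 ** 20
scale_growth = 1e-3    # grow the scale slowly after each clean step

for _ in range(100):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss * (2 ** lg_loss_scale)).backward()

    if any(not torch.isfinite(p.grad).all() for p in model.parameters()):
        lg_loss_scale -= 1   # overflow: back off the scale and skip the step
        continue

    for p in model.parameters():
        p.grad.mul_(1.0 / (2 ** lg_loss_scale))   # unscale before stepping
    opt.step()
    lg_loss_scale += scale_growth

In current PyTorch this mechanism is usually delegated to torch.cuda.amp.GradScaler, which is what the BaseTrainer in コード例 #30 below does.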
コード例 #28
0
def distill(train_loader, valid_loader, input_shape, aux_weight, config,
            device, distributed, device_ids):
    teacher_model_config = config['teacher_model']
    teacher_model, teacher_model_type = mimic_util.get_teacher_model(
        teacher_model_config, input_shape, device)
    module_util.freeze_module_params(teacher_model)
    student_model_config = config['student_model']
    student_model = mimic_util.get_student_model(teacher_model_type,
                                                 student_model_config,
                                                 config['dataset']['name'])
    student_model = student_model.to(device)
    start_epoch, best_valid_acc = mimic_util.resume_from_ckpt(
        student_model_config['ckpt'], student_model, is_student=True)
    if best_valid_acc is None:
        best_valid_acc = 0.0

    train_config = config['train']
    criterion_config = train_config['criterion']
    criterion = func_util.get_loss(criterion_config['type'],
                                   criterion_config['params'])
    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(student_model, optim_config['type'],
                                        optim_config['params'])
    scheduler_config = train_config['scheduler']
    scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'],
                                        scheduler_config['params'])
    interval = train_config['interval']
    if interval <= 0:
        num_batches = len(train_loader)
        interval = num_batches // 20 if num_batches >= 20 else 1

    # Keep a handle to the unwrapped student so checkpoints are saved without
    # the "module." prefix that DistributedDataParallel adds.
    student_model_without_ddp = student_model
    if distributed:
        # The teacher is frozen, so plain DataParallel is enough; only the
        # student's gradients need to be synchronized across ranks.
        teacher_model = DataParallel(teacher_model, device_ids=device_ids)
        student_model = DistributedDataParallel(student_model,
                                                device_ids=device_ids)
        student_model_without_ddp = student_model.module

    ckpt_file_path = student_model_config['ckpt']
    end_epoch = start_epoch + train_config['epoch']
    start_time = time.time()
    for epoch in range(start_epoch, end_epoch):
        if distributed:
            train_loader.sampler.set_epoch(epoch)

        distill_one_epoch(student_model, teacher_model, train_loader,
                          optimizer, criterion, epoch, device, interval,
                          aux_weight)
        valid_acc = validate(student_model, valid_loader, config, device,
                             distributed, device_ids)
        if valid_acc > best_valid_acc and main_util.is_main_process():
            print(
                'Updating ckpt (Best top1 accuracy: {:.4f} -> {:.4f})'.format(
                    best_valid_acc, valid_acc))
            best_valid_acc = valid_acc
            save_ckpt(student_model_without_ddp, epoch, best_valid_acc,
                      ckpt_file_path, teacher_model_type)
        scheduler.step()

    if distributed:
        dist.barrier()
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    del teacher_model
    del student_model
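Because save_ckpt receives student_model_without_ddp, the checkpoint holds the plain module's state dict, so its keys carry no "module." prefix and the file loads the same way with or without DDP. A small sketch of that save/load convention; StudentNet and the helper names here are hypothetical, not part of the code above:

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

class StudentNet(nn.Module):   # hypothetical stand-in for the real student model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

def save_unwrapped(model, path):
    # Unwrap DDP first so the checkpoint keys are not prefixed with "module.".
    to_save = model.module if isinstance(model, DistributedDataParallel) else model
    torch.save(to_save.state_dict(), path)

def load_into(model, path):
    target = model.module if isinstance(model, DistributedDataParallel) else model
    target.load_state_dict(torch.load(path, map_location="cpu"))

The ddp_to_dp branch of load_pretrained in コード例 #30 below handles the same mismatch from the other side, by stripping the prefix at load time.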
コード例 #29
0
        # # normalize the values
        # cam -= np.min(cam)
        # cam /= np.max(cam)
        # # resize to 224*224
        # cam = cv2.resize(cam, (224, 224))
        # return cam
        print("gradient: ", gradient.shape)
        print("feature: ", feature.shape)


os.environ['CUDA_VISIBLE_DEVICES'] = "1,3"
# Bring up a single-process group so the model can be wrapped in
# DistributedDataParallel for this Grad-CAM sanity check.
torch.distributed.init_process_group(backend="nccl",
                                     init_method='tcp://localhost:12556',
                                     rank=0,
                                     world_size=1)

model = ThinResNet()
model = model.cuda()
model = DistributedDataParallel(model)
gc = GradCAM(model, 'layer4')

x = torch.randn((20, 1, 224, 224)).cuda()
l = torch.arange(20).long().unsqueeze(1).cuda()  # torch.range is deprecated; arange covers 0..19

y = model(x)

cam = gc(x, l)

print(cam.shape)
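The GradCAM(model, 'layer4') helper above presumably captures the named layer's activations and gradients with hooks; once the model is wrapped in DistributedDataParallel, those hooks have to be attached through model.module. A minimal sketch of that hook bookkeeping using only standard PyTorch and torchvision, with resnet18 and the layer name as stand-ins for the ThinResNet used above:

import torch
import torchvision

model = torchvision.models.resnet18()
features, gradients = {}, {}

def forward_hook(module, inputs, output):
    features["value"] = output.detach()

def backward_hook(module, grad_input, grad_output):
    gradients["value"] = grad_output[0].detach()

# With a DDP-wrapped model, look the target layer up on model.module instead.
layer = dict(model.named_modules())["layer4"]
layer.register_forward_hook(forward_hook)
layer.register_full_backward_hook(backward_hook)

x = torch.randn(2, 3, 224, 224)
scores = model(x)
scores.gather(1, scores.argmax(dim=1, keepdim=True)).sum().backward()
print(features["value"].shape, gradients["value"].shape)   # both (2, 512, 7, 7)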
コード例 #30
0
class BaseTrainer:
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
        name="base_trainer",
    ):
        self.name = name
        if torch.cuda.is_available():
            self.device = local_rank
        else:
            self.device = "cpu"

        if run_dir is None:
            run_dir = os.getcwd()
        run_dir = Path(run_dir)

        timestamp = torch.tensor(datetime.datetime.now().timestamp()).to(
            self.device)
        # create directories from master rank only
        distutils.broadcast(timestamp, 0)
        timestamp = datetime.datetime.fromtimestamp(timestamp).strftime(
            "%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "amp": amp,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": str(run_dir / "checkpoints" / timestamp),
                "results_dir": str(run_dir / "results" / timestamp),
                "logs_dir": str(run_dir / "logs" / logger / timestamp),
            },
        }
        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
            if len(dataset) > 2:
                self.config["test_dataset"] = dataset[2]
        else:
            self.config["dataset"] = dataset

        if not is_debug and distutils.is_master():
            os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

        self.is_debug = is_debug
        self.is_vis = is_vis

        if distutils.is_master():
            print(yaml.dump(self.config, default_flow_style=False))
        self.load()

        self.evaluator = Evaluator(task=name)

    def load(self):
        self.load_seed_from_config()
        self.load_logger()
        self.load_task()
        self.load_model()
        self.load_criterion()
        self.load_optimizer()
        self.load_extras()

    # Note: this function is now deprecated. We build config outside of trainer.
    # See build_config in ocpmodels.common.utils.py.
    def load_config_from_yaml_and_cmd(self, args):
        self.config = build_config(args)

        # AMP Scaler
        self.scaler = (torch.cuda.amp.GradScaler()
                       if self.config["amp"] else None)

        # device
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # Are we just running sanity checks?
        self.is_debug = args.debug
        self.is_vis = args.vis

        # timestamps and directories
        args.timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if args.identifier:
            args.timestamp += "-{}".format(args.identifier)

        args.checkpoint_dir = os.path.join("checkpoints", args.timestamp)
        args.results_dir = os.path.join("results", args.timestamp)
        args.logs_dir = os.path.join("logs", self.config["logger"],
                                     args.timestamp)

        print(yaml.dump(self.config, default_flow_style=False))
        for arg in vars(args):
            print("{:<20}: {}".format(arg, getattr(args, arg)))

        # TODO(abhshkdz): Handle these parameters better. Maybe move to yaml.
        self.config["cmd"] = args.__dict__
        del args

        if not self.is_debug:
            os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

            # Dump config parameters
            json.dump(
                self.config,
                open(
                    os.path.join(self.config["cmd"]["checkpoint_dir"],
                                 "config.json"),
                    "w",
                ),
            )

    def load_seed_from_config(self):
        # https://pytorch.org/docs/stable/notes/randomness.html
        seed = self.config["cmd"]["seed"]
        if seed is None:
            return

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def load_logger(self):
        self.logger = None
        if not self.is_debug and distutils.is_master():
            assert (self.config["logger"]
                    is not None), "Specify logger in config"
            self.logger = registry.get_logger_class(self.config["logger"])(
                self.config)

    def load_task(self):
        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))
        dataset = registry.get_dataset_class(self.config["task"]["dataset"])(
            self.config["dataset"])

        if self.config["task"]["dataset"] in ["qm9", "dogss"]:
            num_targets = dataset.data.y.shape[-1]
            if ("label_index" in self.config["task"]
                    and self.config["task"]["label_index"] is not False):
                dataset.data.y = dataset.data.y[:,
                                                int(self.config["task"]
                                                    ["label_index"])]
                num_targets = 1
        else:
            num_targets = 1

        self.num_targets = num_targets
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
        ) = dataset.get_dataloaders(
            batch_size=int(self.config["optim"]["batch_size"]))

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", True):
            self.normalizers["target"] = Normalizer(
                self.train_loader.dataset.data.y[
                    self.train_loader.dataset.__indices__],
                self.device,
            )

        # If we're computing gradients wrt input, set mean of normalizer to 0 --
        # since it is lost when compute dy / dx -- and std to forward target std
        if "grad_input" in self.config["task"]:
            if self.config["dataset"].get("normalize_labels", True):
                self.normalizers["grad_target"] = Normalizer(
                    self.train_loader.dataset.data.y[
                        self.train_loader.dataset.__indices__],
                    self.device,
                )
                self.normalizers["grad_target"].mean.fill_(0)

        if self.is_vis and self.config["task"]["dataset"] != "qm9":
            # Plot label distribution.
            plots = [
                plot_histogram(
                    self.train_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: train",
                ),
                plot_histogram(
                    self.val_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: val",
                ),
                plot_histogram(
                    self.test_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: test",
                ),
            ]
            self.logger.log_plots(plots)

    def load_model(self):
        # Build model
        if distutils.is_master():
            print("### Loading model: {}".format(self.config["model"]))

        # TODO(abhshkdz): Eventually move towards computing features on-the-fly
        # and remove dependence from `.edge_attr`.
        bond_feat_dim = None
        if self.config["task"]["dataset"] in [
                "trajectory_lmdb",
                "single_point_lmdb",
        ]:
            bond_feat_dim = self.config["model_attributes"].get(
                "num_gaussians", 50)
        else:
            raise NotImplementedError

        self.model = registry.get_model_class(self.config["model"])(
            self.train_loader.dataset[0].x.shape[-1]
            if hasattr(self.train_loader.dataset[0], "x")
            and self.train_loader.dataset[0].x is not None else None,
            bond_feat_dim,
            self.num_targets,
            **self.config["model_attributes"],
        ).to(self.device)

        if distutils.is_master():
            print("### Loaded {} with {} parameters.".format(
                self.model.__class__.__name__, self.model.num_params))

        if self.logger is not None:
            self.logger.watch(self.model)

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=1,
        )
        if distutils.initialized():
            self.model = DistributedDataParallel(self.model,
                                                 device_ids=[self.device])

    def load_pretrained(self, checkpoint_path=None, ddp_to_dp=False):
        if checkpoint_path is None or os.path.isfile(checkpoint_path) is False:
            print(f"Checkpoint: {checkpoint_path} not found!")
            return False

        print("### Loading checkpoint from: {}".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)

        # Load model, optimizer, normalizer state dict.
        # if trained with ddp and want to load in non-ddp, modify keys from
        # module.module.. -> module..
        if ddp_to_dp:
            new_dict = OrderedDict()
            for k, v in checkpoint["state_dict"].items():
                name = k[7:]
                new_dict[name] = v
            self.model.load_state_dict(new_dict)
        else:
            self.model.load_state_dict(checkpoint["state_dict"])

        self.optimizer.load_state_dict(checkpoint["optimizer"])

        for key in checkpoint["normalizers"]:
            if key in self.normalizers:
                self.normalizers[key].load_state_dict(
                    checkpoint["normalizers"][key])
        # Restore the AMP scaler state once, outside the normalizer loop.
        if self.scaler and checkpoint["amp"]:
            self.scaler.load_state_dict(checkpoint["amp"])
        return True

    # TODO(abhshkdz): Rename function to something nicer.
    # TODO(abhshkdz): Support multiple loss functions.
    def load_criterion(self):
        self.criterion = self.config["optim"].get("criterion", nn.L1Loss())

    def load_optimizer(self):
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            self.config["optim"]["lr_initial"],  # weight_decay=3.0
        )

    def load_extras(self):
        # learning rate scheduler.
        scheduler_lambda_fn = lambda x: warmup_lr_lambda(
            x, self.config["optim"])
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=scheduler_lambda_fn)

        # metrics.
        self.meter = Meter(split="train")

    def save(self, epoch, metrics):
        if not self.is_debug and distutils.is_master():
            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "normalizers": {
                        key: value.state_dict()
                        for key, value in self.normalizers.items()
                    },
                    "config": self.config,
                    "val_metrics": metrics,
                    "amp": self.scaler.state_dict() if self.scaler else None,
                },
                self.config["cmd"]["checkpoint_dir"],
            )

    def train(self, max_epochs=None, return_metrics=False):
        # TODO(abhshkdz): Timers for dataloading and forward pass.
        num_epochs = (max_epochs if max_epochs is not None else
                      self.config["optim"]["max_epochs"])
        for epoch in range(num_epochs):
            self.model.train()

            for i, batch in enumerate(self.train_loader):
                batch = batch.to(self.device)

                # Forward, loss, backward.
                out, metrics = self._forward(batch)
                loss = self._compute_loss(out, batch)
                self._backward(loss)

                # Update meter.
                meter_update_dict = {
                    "epoch": epoch + (i + 1) / len(self.train_loader),
                    "loss": loss.item(),
                }
                meter_update_dict.update(metrics)
                self.meter.update(meter_update_dict)

                # Make plots.
                if self.logger is not None:
                    self.logger.log(
                        meter_update_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                # Print metrics.
                if i % self.config["cmd"]["print_every"] == 0:
                    print(self.meter)

            self.scheduler.step()

            with torch.no_grad():
                if self.val_loader is not None:
                    v_loss, v_mae = self.validate(split="val", epoch=epoch)

                if self.test_loader is not None:
                    test_loss, test_mae = self.validate(split="test",
                                                        epoch=epoch)

            if not self.is_debug:
                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "state_dict": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "normalizers": {
                            key: value.state_dict()
                            for key, value in self.normalizers.items()
                        },
                        "config": self.config,
                        "amp":
                        self.scaler.state_dict() if self.scaler else None,
                    },
                    self.config["cmd"]["checkpoint_dir"],
                )
        if return_metrics:
            return {
                "training_loss":
                float(self.meter.loss.global_avg),
                "training_mae":
                float(self.meter.meters[
                    self.config["task"]["labels"][0] + "/" +
                    self.config["task"]["metric"]].global_avg),
                "validation_loss":
                v_loss,
                "validation_mae":
                v_mae,
                "test_loss":
                test_loss,
                "test_mae":
                test_mae,
            }

    def validate(self, split="val", epoch=None):
        if distutils.is_master():
            print("### Evaluating on {}.".format(split))

        self.model.eval()
        evaluator, metrics = Evaluator(task=self.name), {}
        rank = distutils.get_rank()

        loader = self.val_loader if split == "val" else self.test_loader

        for i, batch in tqdm(
                enumerate(loader),
                total=len(loader),
                position=rank,
                desc="device {}".format(rank),
        ):
            # Forward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out, _ = self._forward(batch, compute_metrics=False)
            loss = self._compute_loss(out, batch)

            # Compute metrics.
            metrics = self._compute_metrics(out, batch, evaluator, metrics)
            metrics = evaluator.update("loss", loss.item(), metrics)

        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                "total":
                distutils.all_reduce(metrics[k]["total"],
                                     average=False,
                                     device=self.device),
                "numel":
                distutils.all_reduce(metrics[k]["numel"],
                                     average=False,
                                     device=self.device),
            }
            aggregated_metrics[k]["metric"] = (aggregated_metrics[k]["total"] /
                                               aggregated_metrics[k]["numel"])
        metrics = aggregated_metrics

        log_dict = {k: metrics[k]["metric"] for k in metrics}
        if epoch is not None:
            log_dict.update({"epoch": epoch + 1})
        if distutils.is_master():
            log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()]
            print(", ".join(log_str))

        # Make plots.
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        return metrics

    def _forward(self, batch, compute_metrics=True):
        out = {}

        # enable gradient wrt input.
        if "grad_input" in self.config["task"]:
            inp_for_grad = batch.pos
            batch.pos = batch.pos.requires_grad_(True)

        # forward pass.
        if self.config["model_attributes"].get("regress_forces", False):
            output, output_forces = self.model(batch)
        else:
            output = self.model(batch)

        if output.shape[-1] == 1:
            output = output.view(-1)

        out["output"] = output

        force_output = None
        if self.config["model_attributes"].get("regress_forces", False):
            out["force_output"] = output_forces
            force_output = output_forces

        if ("grad_input" in self.config["task"]
                and self.config["model_attributes"].get(
                    "regress_forces", False) is False):
            force_output = -1 * torch.autograd.grad(
                output,
                inp_for_grad,
                grad_outputs=torch.ones_like(output),
                create_graph=True,
                retain_graph=True,
            )[0]
            out["force_output"] = force_output

        if not compute_metrics:
            return out, None

        metrics = {}

        if self.config["dataset"].get("normalize_labels", True):
            errors = eval(self.config["task"]["metric"])(
                self.normalizers["target"].denorm(output).cpu(),
                batch.y.cpu()).view(-1)
        else:
            errors = eval(self.config["task"]["metric"])(
                output.cpu(), batch.y.cpu()).view(-1)

        if ("label_index" in self.config["task"]
                and self.config["task"]["label_index"] is not False):
            # TODO(abhshkdz): Get rid of this edge case for QM9.
            # This is only because QM9 has multiple targets and we can either
            # jointly predict all of them or one particular target.
            metrics["{}/{}".format(
                self.config["task"]["labels"][self.config["task"]
                                              ["label_index"]],
                self.config["task"]["metric"],
            )] = errors[0]
        else:
            for i, label in enumerate(self.config["task"]["labels"]):
                metrics["{}/{}".format(
                    label, self.config["task"]["metric"])] = errors[i]

        if "grad_input" in self.config["task"]:
            force_pred = force_output
            force_target = batch.force

            if self.config["task"].get("eval_on_free_atoms", True):
                mask = batch.fixed == 0
                force_pred = force_pred[mask]
                force_target = force_target[mask]

            if self.config["dataset"].get("normalize_labels", True):
                grad_input_errors = eval(self.config["task"]["metric"])(
                    self.normalizers["grad_target"].denorm(force_pred).cpu(),
                    force_target.cpu(),
                )
            else:
                grad_input_errors = eval(self.config["task"]["metric"])(
                    force_pred.cpu(), force_target.cpu())
            metrics["force_x/{}".format(
                self.config["task"]["metric"])] = grad_input_errors[0]
            metrics["force_y/{}".format(
                self.config["task"]["metric"])] = grad_input_errors[1]
            metrics["force_z/{}".format(
                self.config["task"]["metric"])] = grad_input_errors[2]

        return out, metrics

    def _compute_loss(self, out, batch):
        loss = []

        if self.config["dataset"].get("normalize_labels", True):
            target_normed = self.normalizers["target"].norm(batch.y)
        else:
            target_normed = batch.y

        loss.append(self.criterion(out["output"], target_normed))

        # TODO(abhshkdz): Test support for gradients wrt input.
        # TODO(abhshkdz): Make this general; remove dependence on `.forces`.
        if "grad_input" in self.config["task"]:
            if self.config["dataset"].get("normalize_labels", True):
                grad_target_normed = self.normalizers["grad_target"].norm(
                    batch.force)
            else:
                grad_target_normed = batch.force

            # Force coefficient = 30 has been working well for us.
            force_mult = self.config["optim"].get("force_coefficient", 30)
            if self.config["task"].get("train_on_free_atoms", False):
                mask = batch.fixed == 0
                loss.append(force_mult * self.criterion(
                    out["force_output"][mask], grad_target_normed[mask]))
            else:
                loss.append(
                    force_mult *
                    self.criterion(out["force_output"], grad_target_normed))

        # Sanity check to make sure the compute graph is correct.
        for lc in loss:
            assert hasattr(lc, "grad_fn")

        loss = sum(loss)
        return loss

    def _backward(self, loss):
        self.optimizer.zero_grad()
        # TODO(abhshkdz): Add support for gradient clipping.
        if self.scaler:
            # Scale the loss so GradScaler's unscaling in step() is consistent.
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            self.optimizer.step()

    def save_results(self, predictions, results_file, keys):
        if results_file is None:
            return

        results_file_path = os.path.join(
            self.config["cmd"]["results_dir"],
            f"{self.name}_{results_file}_{distutils.get_rank()}.npz",
        )
        np.savez_compressed(
            results_file_path,
            ids=predictions["id"],
            **{key: predictions[key]
               for key in keys},
        )

        distutils.synchronize()
        if distutils.is_master():
            gather_results = defaultdict(list)
            full_path = os.path.join(
                self.config["cmd"]["results_dir"],
                f"{self.name}_{results_file}.npz",
            )

            for i in range(distutils.get_world_size()):
                rank_path = os.path.join(
                    self.config["cmd"]["results_dir"],
                    f"{self.name}_{results_file}_{i}.npz",
                )
                rank_results = np.load(rank_path, allow_pickle=True)
                gather_results["ids"].extend(rank_results["ids"])
                for key in keys:
                    gather_results[key].extend(rank_results[key])
                os.remove(rank_path)

            # Because of how distributed sampler works, some system ids
            # might be repeated to make no. of samples even across GPUs.
            _, idx = np.unique(gather_results["ids"], return_index=True)
            gather_results["ids"] = np.array(gather_results["ids"])[idx]
            for k in keys:
                if k == "forces":
                    gather_results[k] = np.array(gather_results[k],
                                                 dtype=object)[idx]
                else:
                    gather_results[k] = np.array(gather_results[k])[idx]

            print(f"Writing results to {full_path}")
            np.savez_compressed(full_path, **gather_results)
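save_results leans on np.unique(..., return_index=True) to drop the samples a DistributedSampler duplicates when padding every rank to the same number of batches. A tiny illustration of that dedup step on made-up data:

import numpy as np

# Ranks 0 and 1 both predicted id "7" because the sampler padded the last batch.
ids = np.array(["3", "7", "1", "7", "5"])
energies = np.array([0.3, 0.7, 0.1, 0.7, 0.5])

_, idx = np.unique(ids, return_index=True)   # first occurrence of each id
print(ids[idx])        # ['1' '3' '5' '7']  (np.unique also sorts the ids)
print(energies[idx])   # [0.1 0.3 0.5 0.7]

The sorting is harmless here because every prediction array is re-indexed with the same idx, so ids and values stay aligned.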