Example #1
    def evaluate(self, epoch, mode='val'):
        self.mode = mode
        self.epoch = epoch

        self.model.eval()
        self.model.to(dtype=self.dtype_eval)

        if mode == 'val':
            self.criterion.validate()
        elif mode == 'test':
            self.criterion.test()
        self.criterion.epoch = epoch

        self.imsaver.join_background()

        if self.is_slave:
            tq = self.loaders[self.mode]
        else:
            tq = tqdm(self.loaders[self.mode], ncols=80, smoothing=0, bar_format='{desc}|{bar}{r_bar}')

        compute_loss = True
        torch.set_grad_enabled(False)  # disable autograd for the entire evaluation loop
        for idx, batch in enumerate(tq):
            input, target = data.common.to(
                batch[0], batch[1], device=self.device, dtype=self.dtype_eval)
            with amp.autocast(self.args.amp):
                output = self.model(input)

            if mode == 'demo':  # remove padded part
                pad_width = batch[2]
                output[0], _ = data.common.pad(output[0], pad_width=pad_width, negative=True)

            if isinstance(batch[1], torch.BoolTensor):
                compute_loss = False

            if compute_loss:
                self.criterion(output, target)
                if isinstance(tq, tqdm):
                    tq.set_description(self.criterion.get_loss_desc())

            if self.args.save_results != 'none':
                if isinstance(output, (list, tuple)):
                    result = output[0]  # select last output in a pyramid
                elif isinstance(output, torch.Tensor):
                    result = output

                names = batch[-1]

                if self.args.save_results == 'part' and compute_loss: # save all when GT not available
                    indices = batch[-2]
                    save_ids = [save_id for save_id, idx in enumerate(indices) if idx % 10 == 0]

                    result = result[save_ids]
                    names = [names[save_id] for save_id in save_ids]

                self.imsaver.save_image(result, names)

        if compute_loss:
            self.criterion.normalize()
            if isinstance(tq, tqdm):
                tq.set_description(self.criterion.get_loss_desc())
                tq.display(pos=-1)  # overwrite with synchronized loss

            self.criterion.step()
            if self.args.rank == 0:
                self.save()

        self.imsaver.end_background()
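A minimal, self-contained sketch of the same no-grad plus autocast inference pattern the evaluate() method above relies on; the model, loader, and flag names here are illustrative and not part of the original trainer class.

import torch
from torch import nn
from torch.cuda import amp


def evaluate_simple(model: nn.Module, loader, device='cuda', use_amp=True):
    """Run inference without gradients, optionally under mixed-precision autocast."""
    model.eval()
    outputs = []
    with torch.no_grad():  # unlike set_grad_enabled(False), this restores the previous grad mode on exit
        for inputs, _ in loader:
            inputs = inputs.to(device)
            with amp.autocast(enabled=use_amp):
                outputs.append(model(inputs).float().cpu())
    return outputs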
Example #2
def train_sanity_fit(
    model: nn.Module,
    train_loader,
    criterion,
    device: str,
    num_batches: int = None,
    log_interval: int = 100,
    fp16: bool = False,
):
    """
    Performs Sanity fit over train loader.
    Use this to dummy check your train_step function. It does not calculate metrics, timing, or does checkpointing.
    It iterates over both train_loader for given batches.
    Note: - It does not to loss.backward().
    Args:
         model : A PyTorch Detr Model.
        train_loader : Train loader.
        device : "cuda" or "cpu"
        criterion : Loss function to be optimized.
        num_batches : (optional) Integer To limit sanity fit over certain batches.
                                 Useful is data is too big even for sanity check.
        log_interval : (optional) Defualt 100. Integer to Log after specified batch ids in every batch.
        fp16: : (optional) If True uses PyTorch native mixed precision Training.
    """

    model = model.to(device)
    criterion = criterion.to(device)
    train_sanity_start = time.time()
    model.train()

    last_idx = len(train_loader) - 1
    criterion.train()
    cnt = 0
    if fp16 is True:
        scaler = amp.GradScaler()

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        last_batch = batch_idx == last_idx
        images = list(image.to(device) for image in inputs)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        if fp16 is True:
            with amp.autocast():
                outputs = model(images)
                loss_dict = criterion(outputs, targets)
                weight_dict = criterion.weight_dict
                loss = sum(loss_dict[k] * weight_dict[k]
                           for k in loss_dict.keys() if k in weight_dict)

        else:
            outputs = model(images)
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                       if k in weight_dict)

        cnt += 1
        if last_batch or batch_idx % log_interval == 0:
            print(
                f"Train sanity check passed for batch till {batch_idx} batches"
            )

        if num_batches is not None:
            if cnt >= num_batches:
                print(f"Done till {num_batches} train batches")
                print("All specified batches done")
                train_sanity_end = time.time()
                print(
                    f"Train sanity fit check passed in time {train_sanity_end-train_sanity_start}"
                )
                return True

    train_sanity_end = time.time()

    print("All specified batches done")
    print(
        f"Train sanity fit check passed in time {train_sanity_end-train_sanity_start}"
    )

    return True
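A hedged usage sketch for train_sanity_fit above; build_detr_model_and_criterion and train_loader are hypothetical stand-ins for your own DETR-style model/criterion factory and DataLoader.

# Hypothetical usage -- the factory and loader below are placeholders, not part of this codebase.
model, criterion = build_detr_model_and_criterion(num_classes=91)
passed = train_sanity_fit(
    model,
    train_loader,
    criterion,
    device="cuda",
    num_batches=4,     # a handful of batches is enough for a smoke test
    log_interval=2,
    fp16=False,
)
assert passed  # the helper returns True once the requested batches have run without error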
Example #3
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """

        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model,
                inputs,
                prediction_loss_only=prediction_loss_only,
                ignore_keys=ignore_keys)

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        gen_kwargs = {
            "max_length": (self._max_length if self._max_length is not None
                           else self.model.config.max_length),
            "num_beams": (self._num_beams if self._num_beams is not None
                          else self.model.config.num_beams),
            "min_length": (self._min_length if self._min_length is not None
                           else self.model.config.min_length),
        }

        generated_tokens = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **gen_kwargs,
        )
        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(
                generated_tokens, gen_kwargs["max_length"])

        with torch.no_grad():
            if self.use_amp:
                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
            if has_labels:
                if self.label_smoother is not None:
                    loss = self.label_smoother(
                        outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else
                            outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return (loss, None, None)

        labels = inputs["labels"]
        if labels.shape[-1] < gen_kwargs["max_length"]:
            labels = self._pad_tensors_to_max_len(labels,
                                                  gen_kwargs["max_length"])

        return (loss, generated_tokens, labels)
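Both generated_tokens and labels are padded to gen_kwargs["max_length"] before being returned, via a helper along the lines of _pad_tensors_to_max_len. A minimal sketch of such right-padding follows; the function name, pad token handling, and shapes are assumptions for illustration, not taken from the source.

import torch


def pad_tensors_to_max_len(tensor: torch.Tensor, max_length: int, pad_token_id: int) -> torch.Tensor:
    """Right-pad a (batch, seq_len) tensor with pad_token_id up to max_length."""
    if tensor.shape[-1] >= max_length:
        return tensor
    padded = tensor.new_full((tensor.shape[0], max_length), pad_token_id)
    padded[:, :tensor.shape[-1]] = tensor
    return padded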
Example #4
def train(hyp, opt, device, tb_writer=None, wandb=None):
    logger.info(
        colorstr('hyperparameters: ') + ', '.join(f'{k}={v}'
                                                  for k, v in hyp.items()))
    save_dir, epochs, batch_size, total_batch_size, weights, rank = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / 'last.pt'
    best = wdir / 'best.pt'
    results_file = save_dir / 'results.txt'

    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    plots = not opt.evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.FullLoader)  # data dict
    with torch_distributed_zero_first(rank):  # what is this used for?
        check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if opt.single_cls and len(
        data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (
        len(names), nc, opt.data)  # check

    # Model
    pretrained = weights.endswith('.pt')
    pretrained = 0  # force training from scratch, overriding the check above
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        if hyp.get('anchors'):
            ckpt['model'].yaml['anchors'] = round(
                hyp['anchors'])  # force autoanchor
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3,
                      nc=nc).to(device)  # create
        exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [
        ]  # exclude keys
        state_dict = ckpt['model'].float().state_dict()  # to FP32
        state_dict = intersect_dicts(state_dict,
                                     model.state_dict(),
                                     exclude=exclude)  # intersect
        model.load_state_dict(state_dict, strict=False)  # load
        logger.info(
            'Transferred %g/%g items from %s' %
            (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Model(opt.cfg, ch=3, nc=nc).to(device)  # create

    # Freeze
    freeze = []  # parameter names to freeze (full or partial)
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            print('freezing %s' % k)
            v.requires_grad = False

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size),
                     1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
    logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            pg1.append(v.weight)  # apply decay

    if opt.adam:
        optimizer = optim.Adam(pg0,
                               lr=hyp['lr0'],
                               betas=(hyp['momentum'],
                                      0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0,
                              lr=hyp['lr0'],
                              momentum=hyp['momentum'],
                              nesterov=True)

    optimizer.add_param_group({
        'params': pg1,
        'weight_decay': hyp['weight_decay']
    })  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' %
                (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # plot_lr_scheduler(optimizer, scheduler, epochs)

    # Logging
    if rank in [-1, 0] and wandb and wandb.run is None:
        opt.hyp = hyp  # add hyperparameters
        wandb_run = wandb.init(
            config=opt,
            resume="allow",
            project='YOLOv5'
            if opt.project == 'runs/train' else Path(opt.project).stem,
            name=save_dir.stem,
            id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
    loggers = {'wandb': wandb}  # loggers dict

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # Results
        if ckpt.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(ckpt['training_results'])  # write results.txt

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if opt.resume:
            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (
                weights, epochs)
        if epochs < start_epoch:
            logger.info(
                '%s has been trained for %g epochs. Fine-tuning for %g additional epochs.'
                % (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, state_dict

    # Image sizes
    gs = int(model.stride.max())  # grid size (max stride)
    nl = model.model[
        -1].nl  # number of detection layers (used for scaling hyp['obj'])
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size
                         ]  # verify imgsz are gs-multiples

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info('Using SyncBatchNorm()')

    # EMA
    ema = ModelEMA(model) if rank in [-1, 0] else None

    # DDP mode
    if cuda and rank != -1:
        model = DDP(model,
                    device_ids=[opt.local_rank],
                    output_device=opt.local_rank)

    # Trainloader
    dataloader, dataset = create_dataloader(train_path,
                                            imgsz,
                                            batch_size,
                                            gs,
                                            opt,
                                            hyp=hyp,
                                            augment=True,
                                            cache=opt.cache_images,
                                            rect=opt.rect,
                                            rank=rank,
                                            world_size=opt.world_size,
                                            workers=opt.workers,
                                            image_weights=opt.image_weights,
                                            quad=opt.quad,
                                            prefix=colorstr('train: '))
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (
        mlc, nc, opt.data, nc - 1)

    # Process 0
    if rank in [-1, 0]:
        ema.updates = start_epoch * nb // accumulate  # set EMA updates
        testloader = create_dataloader(
            test_path,
            imgsz_test,
            total_batch_size,
            gs,
            opt,  # testloader
            hyp=hyp,
            cache=opt.cache_images and not opt.notest,
            rect=True,
            rank=-1,
            world_size=opt.world_size,
            workers=opt.workers,
            pad=0.5,
            prefix=colorstr('val: '))[0]

        if not opt.resume:
            labels = np.concatenate(dataset.labels, 0)
            c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                plot_labels(labels, save_dir, loggers)
                if tb_writer:
                    tb_writer.add_histogram('classes', c, 0)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset,
                              model=model,
                              thr=hyp['anchor_t'],
                              imgsz=imgsz)

    # Model parameters
    hyp['box'] *= 3. / nl  # scale to layers
    hyp['cls'] *= nc / 80. * 3. / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640)**2 * 3. / nl  # scale to image size and layers
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(
        dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb),
             1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5:.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
    logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
                f'Using {dataloader.num_workers} dataloader workers\n'
                f'Logging results to {save_dir}\n'
                f'Starting training for {epochs} epochs...')
    for epoch in range(
            start_epoch, epochs
    ):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                cw = model.class_weights.cpu().numpy() * (
                    1 - maps)**2 / nc  # class weights
                iw = labels_to_image_weights(dataset.labels,
                                             nc=nc,
                                             class_weights=cw)  # image weights
                dataset.indices = random.choices(
                    range(dataset.n), weights=iw,
                    k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = (torch.tensor(dataset.indices)
                           if rank == 0 else torch.zeros(dataset.n)).int()
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        logger.info(
            ('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls',
                                   'total', 'targets', 'img_size'))
        if rank in [-1, 0]:
            pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (
                imgs, targets, paths, _
        ) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float(
            ) / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(
                    1,
                    np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [
                        hyp['warmup_bias_lr'] if j == 2 else 0.0,
                        x['initial_lr'] * lf(epoch)
                    ])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(
                            ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5,
                                      imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]
                          ]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs,
                                         size=ns,
                                         mode='bilinear',
                                         align_corners=False)

            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(
                    pred, targets.to(device),
                    model)  # loss scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1
                                                    )  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9
                                 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 +
                     '%10.4g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem,
                                      *mloss, targets.shape[0], imgs.shape[-1])
                pbar.set_description(s)

                # Plot
                if plots and ni < 3:
                    f = save_dir / f'train_batch{ni}.jpg'  # filename
                    Thread(target=plot_images,
                           args=(imgs, targets, paths, f),
                           daemon=True).start()
                    # if tb_writer:
                    #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                    #     tb_writer.add_graph(model, imgs)  # add model to tensorboard
                elif plots and ni == 3 and wandb:
                    wandb.log({
                        "Mosaics": [
                            wandb.Image(str(x), caption=x.name)
                            for x in save_dir.glob('train*.jpg')
                        ]
                    })

            # end batch ------------------------------------------------------------------------------------------------
        # end epoch ----------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            if ema:
                ema.update_attr(model,
                                include=[
                                    'yaml', 'nc', 'hyp', 'gr', 'names',
                                    'stride', 'class_weights'
                                ])
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                results, maps, times = test.test(
                    opt.data,
                    batch_size=total_batch_size,
                    imgsz=imgsz_test,
                    model=ema.ema,
                    single_cls=opt.single_cls,
                    dataloader=testloader,
                    save_dir=save_dir,
                    plots=plots and final_epoch,
                    log_imgs=opt.log_imgs if wandb else 0)

            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results +
                        '\n')  # P, R, mAP@.5, mAP@.5:.95, val_loss(box, obj, cls)
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp %s gs://%s/results/results%s.txt' %
                          (results_file, opt.bucket, opt.name))

            # Log
            tags = [
                'train/box_loss',
                'train/obj_loss',
                'train/cls_loss',  # train loss
                'metrics/precision',
                'metrics/recall',
                'metrics/mAP_0.5',
                'metrics/mAP_0.5:0.95',
                'val/box_loss',
                'val/obj_loss',
                'val/cls_loss',  # val loss
                'x/lr0',
                'x/lr1',
                'x/lr2'
            ]  # params
            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                if tb_writer:
                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                if wandb:
                    wandb.log({tag: x})  # W&B

            # Update best mAP
            fi = fitness(np.array(results).reshape(
                1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5:.95]
            if fi > best_fitness:
                best_fitness = fi

            # Save model
            save = (not opt.nosave) or (final_epoch and not opt.evolve)
            if save:
                with open(results_file, 'r') as f:  # create checkpoint
                    ckpt = {
                        'epoch':
                        epoch,
                        'best_fitness':
                        best_fitness,
                        'training_results':
                        f.read(),
                        'model':
                        ema.ema,
                        'optimizer':
                        None if final_epoch else optimizer.state_dict(),
                        'wandb_id':
                        wandb_run.id if wandb else None
                    }

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    if rank in [-1, 0]:
        # Strip optimizers
        final = best if best.exists() else last  # final model
        for f in [last, best]:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
        if opt.bucket:
            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload

        # Plots
        if plots:
            plot_results(save_dir=save_dir)  # save as results.png
            if wandb:
                files = [
                    'results.png', 'precision_recall_curve.png',
                    'confusion_matrix.png'
                ]
                wandb.log({
                    "Results": [
                        wandb.Image(str(save_dir / f), caption=f)
                        for f in files if (save_dir / f).exists()
                    ]
                })
                if opt.log_artifacts:
                    wandb.log_artifact(artifact_or_path=str(final),
                                       type='model',
                                       name=save_dir.stem)

        # Test best.pt
        logger.info('%g epochs completed in %.3f hours.\n' %
                    (epoch - start_epoch + 1, (time.time() - t0) / 3600))
        if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
            for conf, iou, save_json in ([0.25, 0.45,
                                          False], [0.001, 0.65,
                                                   True]):  # speed, mAP tests
                results, _, _ = test.test(opt.data,
                                          batch_size=total_batch_size,
                                          imgsz=imgsz_test,
                                          conf_thres=conf,
                                          iou_thres=iou,
                                          model=attempt_load(final,
                                                             device).half(),
                                          single_cls=opt.single_cls,
                                          dataloader=testloader,
                                          save_dir=save_dir,
                                          save_json=save_json,
                                          plots=False)

    else:
        dist.destroy_process_group()

    wandb.run.finish() if wandb and wandb.run else None
    torch.cuda.empty_cache()
    return results
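Checkpoint selection above keys off fitness(), described in its comment as a weighted combination of [P, R, mAP@.5, mAP@.5:.95]. A plausible sketch of such a metric follows; the exact weights are illustrative and may differ from the project's implementation.

import numpy as np


def fitness(x: np.ndarray) -> np.ndarray:
    """Row-wise weighted combination of [P, R, mAP@.5, mAP@.5:.95] for a (n, 4+) results array."""
    w = np.array([0.0, 0.0, 0.1, 0.9])  # illustrative weights favouring mAP@.5:.95
    return (x[:, :4] * w).sum(1)


# e.g. fi = fitness(np.array(results).reshape(1, -1)), as in the training loop above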
Example #5
def train(hyp, opt, device, tb_writer=None):
    print(f'Hyperparameters {hyp}')
    log_dir = Path(tb_writer.log_dir) if tb_writer else Path(
        opt.logdir) / 'evolve'  # logging directory
    wdir = str(log_dir / 'weights') + os.sep  # weights directory
    os.makedirs(wdir, exist_ok=True)
    last = wdir + 'last.pt'
    best = wdir + 'best.pt'
    results_file = str(log_dir / 'results.txt')
    epochs, batch_size, total_batch_size, weights, rank = \
        opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

    # TODO: Use DDP logging. Only the first process is allowed to log.
    # Save run settings
    with open(log_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(log_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.FullLoader)  # data dict
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc, names = (1, ['item']) if opt.single_cls else (int(
        data_dict['nc']), data_dict['names'])  # number classes, names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (
        len(names), nc, opt.data)  # check

    # Model
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        model = Darknet(opt.cfg).to(device)  # create
        state_dict = {
            k: v
            for k, v in ckpt['model'].items()
            if model.state_dict()[k].numel() == v.numel()
        }
        model.load_state_dict(state_dict, strict=False)
        print('Transferred %g/%g items from %s' %
              (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Darknet(opt.cfg).to(device)  # create

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size),
                     1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in dict(model.named_parameters()).items():
        if '.bias' in k:
            pg2.append(v)  # biases
        elif 'Conv2d.weight' in k:
            pg1.append(v)  # apply weight_decay
        else:
            pg0.append(v)  # all else

    if opt.adam:
        optimizer = optim.Adam(pg0,
                               lr=hyp['lr0'],
                               betas=(hyp['momentum'],
                                      0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0,
                              lr=hyp['lr0'],
                              momentum=hyp['momentum'],
                              nesterov=True)

    optimizer.add_param_group({
        'params': pg1,
        'weight_decay': hyp['weight_decay']
    })  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    print('Optimizer groups: %g .bias, %g conv.weight, %g other' %
          (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    lf = lambda x: ((
        (1 + math.cos(x * math.pi / epochs)) / 2)**1.0) * 0.8 + 0.2  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # plot_lr_scheduler(optimizer, scheduler, epochs)

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # Results
        if ckpt.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(ckpt['training_results'])  # write results.txt

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if epochs < start_epoch:
            print(
                '%s has been trained for %g epochs. Fine-tuning for %g additional epochs.'
                % (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, state_dict

    # Image sizes
    gs = 32  # grid size (max stride)
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size
                         ]  # verify imgsz are gs-multiples

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        print('Using SyncBatchNorm()')

    # Exponential moving average
    ema = ModelEMA(model) if rank in [-1, 0] else None

    # DDP mode
    if cuda and rank != -1:
        model = DDP(model,
                    device_ids=[opt.local_rank],
                    output_device=(opt.local_rank))

    # Trainloader
    dataloader, dataset = create_dataloader(train_path,
                                            imgsz,
                                            batch_size,
                                            gs,
                                            opt,
                                            hyp=hyp,
                                            augment=True,
                                            cache=opt.cache_images,
                                            rect=opt.rect,
                                            local_rank=rank,
                                            world_size=opt.world_size)
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (
        mlc, nc, opt.data, nc - 1)

    # Testloader
    if rank in [-1, 0]:
        ema.updates = start_epoch * nb // accumulate  # set EMA updates ***
        # local_rank is set to -1. Because only the first process is expected to do evaluation.
        testloader = create_dataloader(test_path,
                                       imgsz_test,
                                       batch_size,
                                       gs,
                                       opt,
                                       hyp=hyp,
                                       augment=False,
                                       cache=opt.cache_images,
                                       rect=True,
                                       local_rank=-1,
                                       world_size=opt.world_size)[0]

    # Model parameters
    hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(
        device)  # attach class weights
    model.names = names

    # Class frequency
    if rank in [-1, 0]:
        labels = np.concatenate(dataset.labels, 0)
        c = torch.tensor(labels[:, 0])  # classes
        # cf = torch.bincount(c.long(), minlength=nc) + 1.
        # model._initialize_biases(cf.to(device))
        plot_labels(labels, save_dir=log_dir)
        if tb_writer:
            tb_writer.add_histogram('classes', c, 0)

        # Check anchors
        #if not opt.noautoanchor:
        #    check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

    # Start training
    t0 = time.time()
    nw = max(3 * nb,
             1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    maps = np.zeros(nc)  # mAP per class
    results = (
        0, 0, 0, 0, 0, 0, 0
    )  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
    if rank in [0, -1]:
        print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
        print('Using %g dataloader workers' % dataloader.num_workers)
        print('Starting training for %g epochs...' % epochs)
    # torch.autograd.set_detect_anomaly(True)
    for epoch in range(
            start_epoch, epochs
    ):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if dataset.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                w = model.class_weights.cpu().numpy() * (
                    1 - maps)**2  # class weights
                image_weights = labels_to_image_weights(dataset.labels,
                                                        nc=nc,
                                                        class_weights=w)
                dataset.indices = random.choices(
                    range(dataset.n), weights=image_weights,
                    k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = torch.zeros([dataset.n], dtype=torch.int)
                if rank == 0:
                    indices[:] = torch.tensor(dataset.indices,
                                              dtype=torch.int)
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        if rank in [-1, 0]:
            print(
                ('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj',
                                       'cls', 'total', 'targets', 'img_size'))
            pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (
                imgs, targets, paths, _
        ) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float(
            ) / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
                accumulate = max(
                    1,
                    np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(
                        ni, xi,
                        [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi,
                                                  [0.9, hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5,
                                      imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]
                          ]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs,
                                         size=ns,
                                         mode='bilinear',
                                         align_corners=False)

            # Autocast
            with amp.autocast(enabled=cuda):
                # Forward
                pred = model(imgs)

                # Loss
                loss, loss_items = compute_loss(pred, targets.to(device),
                                                model)  # scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
                # if not torch.isfinite(loss):
                #     print('WARNING: non-finite loss, ending training ', loss_items)
                #     return results

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema is not None:
                    ema.update(model)

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1
                                                    )  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9
                                 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 +
                     '%10.4g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem,
                                      *mloss, targets.shape[0], imgs.shape[-1])
                pbar.set_description(s)

                # Plot
                if ni < 3:
                    f = str(log_dir / ('train_batch%g.jpg' % ni))  # filename
                    result = plot_images(images=imgs,
                                         targets=targets,
                                         paths=paths,
                                         fname=f)
                    if tb_writer and result is not None:
                        tb_writer.add_image(f,
                                            result,
                                            dataformats='HWC',
                                            global_step=epoch)
                        # tb_writer.add_graph(model, imgs)  # add model to tensorboard

            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        scheduler.step()

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            if ema is not None:
                ema.update_attr(model)
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                results, maps, times = test.test(
                    opt.data,
                    batch_size=batch_size,
                    imgsz=imgsz_test,
                    save_json=final_epoch
                    and opt.data.endswith(os.sep + 'coco.yaml'),
                    model=ema.ema.module
                    if hasattr(ema.ema, 'module') else ema.ema,
                    single_cls=opt.single_cls,
                    dataloader=testloader,
                    save_dir=log_dir)

            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results +
                        '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp %s gs://%s/results/results%s.txt' %
                          (results_file, opt.bucket, opt.name))

            # Tensorboard
            if tb_writer:
                tags = [
                    'train/giou_loss', 'train/obj_loss', 'train/cls_loss',
                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5',
                    'metrics/mAP_0.5:0.95', 'val/giou_loss', 'val/obj_loss',
                    'val/cls_loss'
                ]
                for x, tag in zip(list(mloss[:-1]) + list(results), tags):
                    tb_writer.add_scalar(tag, x, epoch)

            # Update best mAP
            fi = fitness(np.array(results).reshape(
                1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
            if fi > best_fitness:
                best_fitness = fi

            # Save model
            save = (not opt.nosave) or (final_epoch and not opt.evolve)
            if save:
                with open(results_file, 'r') as f:  # create checkpoint
                    ckpt = {
                        'epoch':
                        epoch,
                        'best_fitness':
                        best_fitness,
                        'training_results':
                        f.read(),
                        'model':
                        ema.ema.module.state_dict()
                        if hasattr(ema.ema, 'module') else ema.ema.state_dict(),
                        'optimizer':
                        None if final_epoch else optimizer.state_dict()
                    }

                # Save last, best and delete
                torch.save(ckpt, last)
                if epoch >= (epochs - 5):
                    torch.save(ckpt,
                               last.replace('.pt', '_{:03d}.pt'.format(epoch)))
                if (best_fitness == fi) and not final_epoch:
                    torch.save(ckpt, best)
                del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    if rank in [-1, 0]:
        # Strip optimizers
        n = ('_'
             if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
        fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
        for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'],
                          [flast, fbest, fresults]):
            if os.path.exists(f1):
                os.rename(f1, f2)  # rename
                ispt = f2.endswith('.pt')  # is *.pt
                strip_optimizer(f2) if ispt else None  # strip optimizer
                os.system('gsutil cp %s gs://%s/weights' % (
                    f2, opt.bucket)) if opt.bucket and ispt else None  # upload
        # Finish
        if not opt.evolve:
            plot_results(save_dir=log_dir)  # save as results.png
        print('%g epochs completed in %.3f hours.\n' %
              (epoch - start_epoch + 1, (time.time() - t0) / 3600))

    dist.destroy_process_group() if rank not in [-1, 0] else None
    torch.cuda.empty_cache()
    return results
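The LambdaLR scheduler above multiplies each parameter group's initial lr by lf(epoch). A standalone sketch that simply evaluates the same cosine lambda over a run (the epoch count here is illustrative):

import math

epochs = 300  # illustrative run length
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2  # cosine decay from 1.0 to 0.2

# lr multiplier at the start, mid-point and end of training
for e in (0, epochs // 2, epochs - 1):
    print(e, round(lf(e), 4))  # 1.0 -> 0.6 -> ~0.2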
Example #6
def training(model, optimizer, criterion, scaler, logger, process_group,
             run_dir):
    global phase, step

    step_offset = 0

    for current_phase, phase_params in hp.training.scheme.items():
        if current_phase < phase:
            step_offset += phase_params['steps']
            continue

        # Update learning rate
        for param_group in optimizer['generator'].param_groups:
            param_group['lr'] = phase_params['lr_generator']
        if phase_params['lr_discriminator'] is not None:
            for param_group in optimizer['discriminator'].param_groups:
                param_group['lr'] = phase_params['lr_discriminator']

        # Initialize data loaders
        train_data = ConvDataset(
            sp_files=utils.core.parse_data_structure(hp.files.train_speaker,
                                                     'speaker'),
            ir_files=utils.core.parse_data_structure(hp.files.train_ir, 'ir'),
            noise_files=utils.core.parse_data_structure(
                hp.files.train_noise, 'noise'),
            validation=False,
            augmentation=phase_params['augmentation'])
        valid_data = ConvDataset(
            sp_files=utils.core.parse_data_structure(hp.files.valid_speaker,
                                                     'speaker',
                                                     validation=True),
            ir_files=utils.core.parse_data_structure(hp.files.valid_ir,
                                                     'ir',
                                                     validation=True),
            noise_files=utils.core.parse_data_structure(hp.files.valid_noise,
                                                        'noise',
                                                        validation=True),
            validation=True,
            augmentation=phase_params['augmentation'])
        train_loader = DataLoader(dataset=train_data,
                                  collate_fn=collate,
                                  batch_size=phase_params['batch_size'],
                                  num_workers=hp.training.num_workers,
                                  pin_memory=True)
        valid_loader = DataLoader(dataset=valid_data,
                                  collate_fn=collate,
                                  batch_size=phase_params['batch_size'],
                                  num_workers=hp.training.num_workers,
                                  pin_memory=False)

        with tqdm(desc=f'Train {phase_params["modules"]}',
                  total=phase_params['steps']) as pbar:
            pbar.update(step)
            for inputs, ground_truth, conditioning in train_loader:

                model.train()

                if pbar.n >= phase_params['steps']:
                    break

                inputs = inputs.to(args.local_rank, non_blocking=True)
                ground_truth = ground_truth.to(args.local_rank,
                                               non_blocking=True)
                conditioning = conditioning.to(args.local_rank,
                                               non_blocking=True)

                training_loss = dict()

                with autocast(enabled=hp.training.mixed_precision):
                    if current_phase == 0:

                        prediction = utils.core.ddp(model).generator.wavenet(
                            inputs, conditioning)
                        loss = criterion.sample_loss(
                            ground_truth,
                            prediction) + criterion.spectrogram_loss(
                                ground_truth, prediction)

                        training_loss['wavenet'] = loss.item()

                    elif current_phase == 1:

                        prediction, prediction_postnet = utils.core.ddp(
                            model).generator(inputs, conditioning)
                        wavenet_loss = criterion.sample_loss(
                            ground_truth,
                            prediction) + criterion.spectrogram_loss(
                                ground_truth, prediction)
                        wavenet_postnet_loss = criterion.sample_loss(ground_truth, prediction_postnet) \
                                               + criterion.spectrogram_loss(ground_truth, prediction_postnet)
                        loss = wavenet_loss + wavenet_postnet_loss

                        training_loss['wavenet'] = wavenet_loss.item()
                        training_loss[
                            'wavenet-postnet'] = wavenet_postnet_loss.item()

                    else:

                        prediction, prediction_postnet, prediction_scores, \
                        discriminator_scores, L_FM_G = model(inputs, ground_truth, conditioning)

                        loss, wavenet_loss, wavenet_postnet_loss, \
                        G_loss, D_losses = criterion(pbar.n, ground_truth, prediction, prediction_postnet,
                                                     prediction_scores,
                                                     discriminator_scores, L_FM_G)

                        if G_loss is not None:
                            training_loss['wavenet'] = wavenet_loss.item()
                            training_loss[
                                'wavenet-postnet'] = wavenet_postnet_loss.item(
                                )
                            training_loss['G'] = G_loss.item()
                        training_loss['D_16kHz'] = D_losses[0].item()
                        training_loss['D_8kHz'] = D_losses[1].item()
                        training_loss['D_4kHz'] = D_losses[2].item()
                        training_loss['D_mel'] = D_losses[3].item()

                loss = utils.core.all_reduce(loss, group=process_group)

                if hp.training.mixed_precision:
                    scaler['generator'].scale(loss).backward(
                        retain_graph=current_phase == 2)
                    if current_phase == 2:
                        scaler['discriminator'].scale(loss).backward()
                else:
                    loss.backward()

                if hp.training.mixed_precision:
                    scaler['generator'].step(optimizer['generator'])
                    scaler['generator'].update()
                    if current_phase == 2:
                        scaler['discriminator'].step(
                            optimizer['discriminator'])
                        scaler['discriminator'].update()
                else:
                    optimizer['generator'].step()
                    if current_phase == 2:
                        optimizer['discriminator'].step()

                optimizer['generator'].zero_grad()
                optimizer['discriminator'].zero_grad()

                pbar.set_postfix(loss=training_loss)
                pbar.update()
                step = pbar.n

                if args.local_rank == 0:
                    logger.log_training(pbar.n + step_offset, {
                        'training.loss': training_loss,
                    })

                if pbar.n % hp.training.validation_every_n_steps == 0:
                    validation_loss, audio_data = validation(
                        model, criterion, valid_loader, process_group,
                        current_phase)

                    if args.local_rank == 0:
                        logger.log_validation(
                            model=utils.core.ddp(model),
                            step=pbar.n + step_offset,
                            scalars={'validation.loss': validation_loss},
                            audio_data=audio_data)
                        utils.core.save_checkpoint(run_dir,
                                                   utils.core.ddp(model),
                                                   optimizer, scaler,
                                                   current_phase, pbar.n)

        if current_phase < 2:
            phase = current_phase + 1
            step = 0
Example #7
0
def trainer_augment(loaders, model_params, model, criterion, val_criterion,
                    optimizer, lr_scheduler, optimizer_params, training_params,
                    save_path):

    start_epoch = training_params['start_epoch']
    total_epochs = training_params['num_epoch']
    device = training_params['device']
    device_ids = training_params['device_ids']
    augment_prob = model_params['special_augment_prob']

    model = nn.DataParallel(model, device_ids=device_ids)
    cuda = device.type != 'cpu'
    scaler = amp.GradScaler(enabled=cuda)
    ema = model_params['ema_model']

    print("Epochs: {}\n".format(total_epochs))
    best_epoch = 1
    best_acc = 0.0
    history = {
        "train": {
            "loss": [],
            "acc": []
        },
        "eval": {
            "loss": [],
            "acc": []
        },
        "lr": []
    }
    # num_layer = 213
    # num_layer = 340
    num_layer = 418
    for epoch in range(start_epoch, total_epochs + 1):
        if epoch <= training_params['warm_up'] and epoch == 1:
            training_params['TTA_time'] = 1
            model = freeze_model(model, num_layer)
            #print('-------------------------', ct)

        elif epoch == training_params['warm_up'] + 1:
            training_params['TTA_time'] = 5
            model = unfreeze_model(model)

        epoch_save_path = save_path + '_epoch-{}.pt'.format(epoch)
        head = "epoch {:2}/{:2}".format(epoch, total_epochs)
        print(head + "\n" + "-" * (len(head)))

        model.train()
        running_labels = 0
        running_scores = []
        running_loss = 0.0
        optimizer.zero_grad()
        for images, labels in tqdm.tqdm(loaders["train"]):
            images, labels = images.to(device), labels.to(device)
            with amp.autocast(enabled=cuda):

                snapmix_check = False
                if np.random.rand(1) >= 0:  # always true: SnapMix is applied to every batch, so the CutMix branch below is never taken
                    snapmix_check = True
                    SNAPMIX_ALPHA = 5.0
                    mixed_images, labels_1, labels_2, lam_a, lam_b = snapmix(
                        images, labels, SNAPMIX_ALPHA, model)
                    mixed_images, labels_1, labels_2 = torch.autograd.Variable(
                        mixed_images), torch.autograd.Variable(
                            labels_1), torch.autograd.Variable(labels_2)

                    outputs, _ = model(mixed_images, train_state=cuda)
                    loss_a = criterion(outputs, labels_1)
                    loss_b = criterion(outputs, labels_2)
                    loss = torch.mean(loss_a * lam_a + loss_b * lam_b)
                    running_labels += labels_1.shape[0]

                else:

                    mixed_images, labels_1, labels_2, lam = cutmix(
                        images, labels)
                    mixed_images, labels_1, labels_2 = torch.autograd.Variable(
                        mixed_images), torch.autograd.Variable(
                            labels_1), torch.autograd.Variable(labels_2)

                    outputs, _ = model(mixed_images, train_state=cuda)
                    loss = lam * criterion(outputs, labels_1.unsqueeze(1)) + (
                        1 - lam) * criterion(outputs, labels_2.unsqueeze(1))
                    running_labels += labels_1.shape[0]

            #first step
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if ema:
                ema.update(model)

            running_loss += loss.item() * images.size(0)

        epoch_loss = running_loss / running_labels
        history["train"]["loss"].append(epoch_loss)
        print("{} - loss: {:.4f}".format("train", epoch_loss))

        if ema:
            ema.update_attr(model, include=[])
            model_eval = ema.ema

        with torch.no_grad():
            if ema:
                model_eval = nn.DataParallel(model_eval, device_ids=device_ids)
                model_eval.eval()
            else:
                model.eval()
            final_scores = []
            final_loss = 0.0

            for TTA in range(training_params['TTA_time']):
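                # Test-time augmentation pass: log-softmax outputs from each pass are
                # accumulated and averaged over TTA_time passes after this loop.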
                running_labels = []
                running_scores = []
                running_outputs_softmax = np.empty((0, 5))
                running_loss = 0.0

                for images, labels in tqdm.tqdm(loaders["eval"]):
                    images, labels = images.to(device), labels.to(device)

                    if ema:
                        outputs, _ = model_eval(images, train_state=False)
                    else:
                        outputs, _ = model(images)

                    outputs_softmax = F.log_softmax(outputs, dim=-1)
                    scores = torch.argmax(outputs_softmax, 1)
                    loss = val_criterion(outputs, labels)
                    running_labels += list(
                        labels.unsqueeze(1).data.cpu().numpy())
                    running_scores += list(scores.cpu().detach().numpy())
                    running_outputs_softmax = np.append(
                        running_outputs_softmax,
                        outputs_softmax.cpu().detach().numpy(),
                        axis=0)
                    running_loss += loss.item() * images.size(0)
                print("{} - TTA loss: {:.4f} acc: {:.4f}".format(
                    "eval", running_loss / len(running_labels),
                    metrics.accuracy_score(running_labels,
                                           np.round(running_scores))))

                if TTA == 0:
                    final_scores = running_outputs_softmax
                else:
                    final_scores += running_outputs_softmax
                final_loss += running_loss

        final_scores_softmax_torch = torch.tensor(final_scores /
                                                  training_params['TTA_time'],
                                                  dtype=torch.float32)
        running_labels_torch = torch.tensor(running_labels,
                                            dtype=torch.float32)
        epoch_loss = val_criterion(
            final_scores_softmax_torch.to(device=training_params['device']),
            running_labels_torch.squeeze().to(
                device=training_params['device']))
        final_scores = np.argmax(final_scores, axis=1)
        epoch_accuracy_score = metrics.accuracy_score(running_labels,
                                                      np.round(final_scores))
        history["eval"]["loss"].append(epoch_loss.cpu().detach().numpy())
        history["eval"]["acc"].append(epoch_accuracy_score)
        print("{} loss: {:.4f} acc: {:.4f} lr: {:.9f}".format(
            "eval - epoch", epoch_loss, epoch_accuracy_score,
            optimizer.param_groups[0]["lr"]))
        history["lr"].append(optimizer.param_groups[0]["lr"])
        lr_scheduler.step()

        if epoch_accuracy_score > best_acc:
            best_epoch = epoch
            best_acc = epoch_accuracy_score
        save_checkpoint(model_eval if ema else model, optimizer, lr_scheduler,
                        epoch, epoch_save_path)

    print("\nFinish: - Best Epoch: {} - Best accuracy: {}\n".format(
        best_epoch, best_acc))
Example #8
0
def train_fn(model, optimizer, scheduler, loss_fn, dataloader, device, scaler,
             epoch, do_cutmix, do_fmix, do_mixup):
    model.train()
    final_loss = 0
    s = time.time()
    pbar = tqdm(enumerate(dataloader), total=len(dataloader))
    for i, (images, labels) in pbar:
        ### mix up
        p = np.random.uniform(0, 1)
        if do_cutmix and do_fmix:
            if p < 0.25:  # cutmix
                images, labels = cutmix(images, labels, alpha=1.)
            elif p < 0.5:  # fmix
                img_size = (images.size(2), images.size(3))
                images, labels = fmix(images,
                                      labels,
                                      alpha=1.,
                                      decay_power=3.,
                                      shape=img_size)
            else:
                eyes = torch.eye(5)
                labels = eyes[labels]

        elif do_cutmix and not do_fmix and p < 0.5:  # cutmix
            images, labels = cutmix(images, labels, alpha=1.)

        elif do_fmix and not do_cutmix and p < 0.5:  # fmix
            img_size = (images.size(2), images.size(3))
            images, labels = fmix(images,
                                  labels,
                                  alpha=1.,
                                  decay_power=3.,
                                  shape=img_size)

        elif do_mixup and p < 0.5:
            images, labels = mixup(images, labels, alpha=1.0)

        else:
            eyes = torch.eye(5)
            labels = eyes[labels]
        ########

        images = images.to(device).float()
        labels = labels.to(device).float()

        with autocast():
            outputs = model(images)  # moving this call out of autocast makes results deterministic but uses more memory; required when training at size 512
            loss = loss_fn(outputs, labels)
            scaler.scale(loss).backward()
            final_loss += loss.item()
            del loss
            torch.cuda.empty_cache()
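        # Gradient accumulation: the optimizer is stepped every 2 iterations (and on the final batch).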
        if (i + 1) % 2 == 0 or ((i + 1) == len(dataloader)):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        if i % 10 == 0 or (i + 1) == len(dataloader):
            description = f"[train] epoch {epoch} | iteration {i} | time {time.time() - s:.4f} | avg loss {final_loss / (i+1):.6f}"
            pbar.set_description(description)
        torch.cuda.empty_cache()

    final_loss /= len(dataloader)

    return final_loss
Example #9
0
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=False,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument(
        "--dont_normalize_text",
        default=False,
        action='store_true',
        help="Turn off trasnscript normalization. Recommended for non-English.",
    )
    parser.add_argument("--out_dir",
                        type=str,
                        required=True,
                        help="Destination dir for output files")
    parser.add_argument("--sctk_dir",
                        type=str,
                        required=False,
                        default="",
                        help="Path to sctk root dir")
    parser.add_argument("--glm",
                        type=str,
                        required=False,
                        default="",
                        help="Path to glm file")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    use_sctk = os.path.exists(args.sctk_dir)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': not args.dont_normalize_text,
        })
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()
    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])

    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    references = []
    all_log_probs = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1])
        for r in log_probs.cpu().numpy():
            all_log_probs.append(r)
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = ''.join([
                labels_map[c]
                for c in test_batch[2][batch_ind].cpu().detach().numpy()
            ])
            references.append(reference)
        del test_batch

    info_list = get_utt_info(args.dataset)
    hypfile = os.path.join(args.out_dir, "hyp.trn")
    reffile = os.path.join(args.out_dir, "ref.trn")
    with open(hypfile, "w") as hyp_f, open(reffile, "w") as ref_f:
        for i in range(len(hypotheses)):
            utt_id = os.path.splitext(
                os.path.basename(info_list[i]['audio_filepath']))[0]
            # rfilter in sctk likes each transcript to have a space at the beginning
            hyp_f.write(" " + hypotheses[i] + " (" + utt_id + ")" + "\n")
            ref_f.write(" " + references[i] + " (" + utt_id + ")" + "\n")

    if use_sctk:
        score_with_sctk(args.sctk_dir,
                        reffile,
                        hypfile,
                        args.out_dir,
                        glm=args.glm)
Example #10
0
    def train(self, G, D, dataset):

        mb_size = self.args.local_bs
        # Optimizers for the local participant (client)

        # optimizer_D = torch.optim.SGD(params=D.parameters(), lr=self.args.fed_d_lr)
        # optimizer_G = torch.optim.SGD(params=G.parameters(), lr=self.args.fed_g_lr)

        optimizer_D = torch.optim.Adam(params=D.parameters(), lr=self.args.fed_d_lr)
        optimizer_G = torch.optim.Adam(params=G.parameters(), lr=self.args.fed_g_lr)
        if self.args.use_amp:
            # Use AMP mixed precision
            scaler_D = GradScaler()
            scaler_G = GradScaler()

        # optimizer_D = torch.optim.RMSprop(params=D.parameters(), lr=self.args.fed_d_lr)
        # optimizer_G = torch.optim.RMSprop(params=G.parameters(), lr=self.args.fed_g_lr)

        # Per-epoch loss records (unused here)
        # G_epoch_loss = []
        # D_epoch_loss = []
        # Per-batch loss records
        G_batch_loss = []
        D_batch_loss = []
        # Track the generator's imputation quality
        G_MSE_batch_train_loss = []
        # G_MSE_epoch_train_loss = []
        G_MSE_batch_test_loss = []
        # G_MSE_epoch_test_loss = []

        self.Dim = dataset['d']
        self.trainX = dataset['train_x']
        self.testX = dataset['test_x']
        self.trainM = dataset['train_m']
        self.testM = dataset['test_m']
        self.Train_No = dataset['train_no']
        self.Test_No = dataset['test_no']

        for j in range(self.args.local_ep):
            mb_idx = self.sample_idx(self.Train_No, mb_size)
            X_mb = self.trainX[mb_idx, :]

            Z_mb = sample_Z(mb_size, self.Dim)
            M_mb = self.trainM[mb_idx, :]

            if self.p_hint == 1.0:
                H_mb = M_mb
            else:
                H_mb1 = self.sample_M(mb_size, self.Dim, 1 - self.p_hint)
                H_mb = M_mb * H_mb1

            New_X_mb = M_mb * X_mb + (1 - M_mb) * Z_mb  # Missing Data Introduce

            X_mb = torch.tensor(X_mb, device="cuda", dtype=torch.float32)
            M_mb = torch.tensor(M_mb, device="cuda", dtype=torch.float32)
            H_mb = torch.tensor(H_mb, device="cuda", dtype=torch.float32)
            New_X_mb = torch.tensor(New_X_mb, device="cuda", dtype=torch.float32)

            if self.args.use_amp:
                optimizer_D.zero_grad()
                with autocast():
                    D_loss_curr = self.compute_D_loss(G, D, M=M_mb, New_X=New_X_mb, H=H_mb)
                scaler_D.scale(D_loss_curr).backward()
                scaler_D.step(optimizer_D)
                scaler_D.update()

                if j % self.args.fed_n_critic == 0:
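                    # The generator is updated only once every fed_n_critic discriminator updates.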
                    optimizer_G.zero_grad()
                    with autocast():
                        G_loss_curr, MSE_train_loss_curr, MSE_test_loss_curr = self.compute_G_loss(G, D, X=X_mb, M=M_mb,
                                                                                                   New_X=New_X_mb,
                                                                                                   H=H_mb)
                    scaler_G.scale(G_loss_curr).backward()
                    scaler_G.step(optimizer_G)
                    scaler_G.update()

            else:
                optimizer_D.zero_grad()
                D_loss_curr = self.compute_D_loss(G, D, M=M_mb, New_X=New_X_mb, H=H_mb)
                # WGAN discriminator loss (alternative):
                # D_loss_curr = self.compute_WD_loss(G, D, M=M_mb, New_X=New_X_mb, H=H_mb)
                D_loss_curr.backward()
                optimizer_D.step()

                if j % self.args.fed_n_critic == 0:
                    optimizer_G.zero_grad()
                    G_loss_curr, MSE_train_loss_curr, MSE_test_loss_curr = self.compute_G_loss(G, D, X=X_mb, M=M_mb,
                                                                                               New_X=New_X_mb, H=H_mb)

                    G_loss_curr.backward()
                    optimizer_G.step()

            if self.args.verbose and j % 10 == 0:
                print('Update Epoch: {} [{}/{} ({:.0f}%)]\t G_Loss: {:.6f} D_Loss: {:.6f}'.format(
                    self.args.local_stations[self.idx], j, self.args.local_ep,
                    100. * j / self.args.local_ep,
                    G_loss_curr,
                    D_loss_curr
                ))
            G_batch_loss.append(G_loss_curr)
            D_batch_loss.append(D_loss_curr)
            G_MSE_batch_train_loss.append(MSE_train_loss_curr)
            G_MSE_batch_test_loss.append(MSE_test_loss_curr)

        G_epoch_loss_mean = sum(G_batch_loss) / len(G_batch_loss)
        D_epoch_loss_mean = sum(D_batch_loss) / len(D_batch_loss)
        G_MSE_epoch_train_loss_mean = sum(G_MSE_batch_train_loss) / len(G_MSE_batch_train_loss)
        G_MSE_epoch_test_loss_mean = sum(G_MSE_batch_test_loss) / len(G_MSE_batch_test_loss)
        return G.state_dict(), D.state_dict(), G_epoch_loss_mean, D_epoch_loss_mean, \
               G_MSE_epoch_train_loss_mean, G_MSE_epoch_test_loss_mean, self.Train_No
Example #11
0
        )
        args.epochs = start_epoch + 1

    logging.info('Training Epochs: from {0} to {1}'.format(
        start_epoch, args.epochs - 1))
    for epoch in range(start_epoch, args.epochs):
        #Train
        model.train()
        runningLoss = []
        train_loss = []
        print('Epoch ' + str(epoch) + ': Train')
        for i, (images, gt) in enumerate(tqdm(train_loader)):
            images = images.to(device)
            gt = gt.to(device)

            with autocast(enabled=args.amp):
                if type(model) is SRCNN3D:
                    output1, output2 = model(images)
                    loss1 = loss_func(output1, gt)
                    loss2 = loss_func(output2, gt)
                    loss = loss2 + loss1
                elif type(model) is UNetVSeg:
                    if IsDeepSup:
                        sys.exit("Not Implimented yet")
                    else:
                        out, _, _ = model(images)
                        loss = loss_func(out, gt)
                elif type(model) is ThisNewNet:
                    out, loss = model(images, gt=gt)
                else:
                    out = model(images)
Example #12
0
    def independent_training(self, G, D, dataset, station_name='', save_pth_pre=''):
        # Write training-progress data to a log file
        fw_name = save_pth_pre + 'indpt_' + station_name + '_log.txt'
        fw_fed_main = open(fw_name, 'w+')
        fw_fed_main.write('iter\t G_loss\t D_loss\t G_train_MSE_loss\t G_test_MSE_loss\t \n')

        mb_size = self.args.local_bs
        # Optimizers for the local participant (client)
        # optimizer_D = torch.optim.SGD(params=D.parameters(), lr=self.args.d_lr, momentum=0.9)
        # optimizer_G = torch.optim.SGD(params=G.parameters(), lr=self.args.g_lr, momentum=0.9)

        optimizer_D = torch.optim.Adam(params=D.parameters(), lr=self.args.d_lr)
        optimizer_G = torch.optim.Adam(params=G.parameters(), lr=self.args.g_lr)

        if self.args.use_amp:
            # Use AMP mixed precision
            scaler_D = GradScaler()
            scaler_G = GradScaler()

        # optimizer_D = torch.optim.RMSprop(params=D.parameters(), lr=self.args.d_lr)
        # optimizer_G = torch.optim.RMSprop(params=G.parameters(), lr=self.args.g_lr)

        # Learning-rate decay schedulers
        if self.args.lr_decay:
            D_StepLR = torch.optim.lr_scheduler.StepLR(optimizer_D,
                                                       step_size=self.args.d_lr_decay_step,
                                                       gamma=self.args.d_lr_decay)
            G_StepLR = torch.optim.lr_scheduler.StepLR(optimizer_G,
                                                       step_size=self.args.g_lr_decay_step,
                                                       gamma=self.args.g_lr_decay)
            # D_StepLR = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_D, mode='min', factor=0.1, patience=10,
            #                                                       verbose=True, threshold=0.0001,
            #                                                       threshold_mode='rel',
            #                                                       min_lr=0)
            # G_StepLR = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_G, mode='min', factor=0.1, patience=10,
            #                                                       verbose=True, threshold=0.0001,
            #                                                       threshold_mode='rel',
            #                                                       min_lr=0)

        # For plotting
        fig, axs = plt.subplots(nrows=2, ncols=2, constrained_layout=True)

        # Per-epoch loss records
        G_epoch_loss = []
        D_epoch_loss = []
        # Per-batch loss records (unused here)
        # G_batch_loss = []
        # D_batch_loss = []
        # Track the generator's imputation quality
        # G_MSE_batch_train_loss = []
        G_MSE_epoch_train_loss = []
        # G_MSE_batch_test_loss = []
        G_MSE_epoch_test_loss = []

        self.Dim = dataset['d']
        self.trainX = dataset['train_x']
        self.testX = dataset['test_x']
        self.trainM = dataset['train_m']
        self.testM = dataset['test_m']
        self.Train_No = dataset['train_no']
        self.Test_No = dataset['test_no']

        print('Station ' + station_name + ' is under Training...')
        # Note: self.args.local_ep is intentionally not used here
        with tqdm(range(self.args.epochs)) as tq:
            for j in tq:  # temporarily dropped the `self.args.local_ep *` factor
                self.cur_epoch = j
                tq.set_description('Local Updating')
                mb_idx = self.sample_idx(self.Train_No, mb_size)
                X_mb = self.trainX[mb_idx, :]

                Z_mb = sample_Z(mb_size, self.args.input_dim) * 0.1
                M_mb = self.trainM[mb_idx, :]

                # When p_hint == 1 this reduces to a plain conditional GAN; otherwise it is the GAIN algorithm
                if self.p_hint == 1.0:
                    H_mb = M_mb
                else:
                    H_mb1 = self.sample_M(mb_size, self.Dim, 1 - self.p_hint)
                    H_mb = M_mb * H_mb1

                New_X_mb = M_mb * X_mb + (1 - M_mb) * Z_mb  # Missing Data Introduce

                X_mb = torch.tensor(X_mb, device=self.args.device, dtype=torch.float32)
                M_mb = torch.tensor(M_mb, device=self.args.device, dtype=torch.float32)
                H_mb = torch.tensor(H_mb, device=self.args.device, dtype=torch.float32)
                New_X_mb = torch.tensor(New_X_mb, device=self.args.device, dtype=torch.float32)

                if self.args.use_amp:
                    optimizer_D.zero_grad()
                    with autocast():
                        D_loss_curr = self.compute_D_loss(G, D, M=M_mb, New_X=New_X_mb, H=H_mb)
                    scaler_D.scale(D_loss_curr).backward()
                    scaler_D.step(optimizer_D)
                    scaler_D.update()

                    if j % self.args.n_critic == 0:
                        optimizer_G.zero_grad()
                        with autocast():
                            G_loss_curr, MSE_train_loss_curr, RMSE_test_loss_curr = self.compute_G_loss(G, D,
                                                                                                        X=X_mb,
                                                                                                        M=M_mb,
                                                                                                        New_X=New_X_mb,
                                                                                                        H=H_mb)
                        scaler_G.scale(G_loss_curr).backward()
                        scaler_G.step(optimizer_G)
                        scaler_G.update()
                else:
                    optimizer_D.zero_grad()
                    D_loss_curr = self.compute_D_loss(G, D, M=M_mb, New_X=New_X_mb, H=H_mb)

                    D_loss_curr.backward()
                    optimizer_D.step()

                    if j % self.args.n_critic == 0:
                        optimizer_G.zero_grad()
                        G_loss_curr, MSE_train_loss_curr, RMSE_test_loss_curr = self.compute_G_loss(G, D, X=X_mb,
                                                                                                    M=M_mb,
                                                                                                    New_X=New_X_mb,
                                                                                                    H=H_mb)

                        G_loss_curr.backward()
                        optimizer_G.step()
                if self.args.lr_decay:
                    D_StepLR.step()
                    G_StepLR.step()

                # Folder name used when saving the model
                # file_name = save_pth_pre.split('/')[2]
                file_name = save_pth_pre.split('/')
                file_name = file_name[2] + '/' + file_name[3]
                # %% Intermediate Losses
                tq.set_postfix(Train_MSE_loss=np.sqrt(MSE_train_loss_curr.item()),
                               Train_RMSE=np.sqrt(RMSE_test_loss_curr.item()))
                # Save the model every 100 iterations
                if j % 100 == 0:
                    self.save_model(G, D, file_name, station_name)

                # (alternative) learning-rate scheduler steps
                # G_StepLR.step()
                # D_StepLR.step()

                if j % 1 == 0:
                    # Record training-progress data
                    G_epoch_loss.append(G_loss_curr)
                    D_epoch_loss.append(D_loss_curr)

                    G_MSE_epoch_train_loss.append(MSE_train_loss_curr)
                    G_MSE_epoch_test_loss.append(RMSE_test_loss_curr)

                    fw_fed_main.write('{}\t {:.5f}\t {:.5f}\t {:.5f}\t {:.5f}\t \n'.format(j, G_loss_curr, D_loss_curr,
                                                                                           MSE_train_loss_curr,
                                                                                           RMSE_test_loss_curr))

        # self.plot_progess_info(axs[0, 0], G_epoch_loss, 'G loss')
        # self.plot_progess_info(axs[0, 1], D_epoch_loss, 'D loss')
        # self.plot_progess_info(axs[1, 0], G_MSE_epoch_train_loss, 'G MSE training loss')
        # self.plot_progess_info(axs[1, 1], G_MSE_epoch_tets_loss, 'RMSE on training dataset')
        # Plot the loss curves recorded during training
        loss_plot(axs[0, 0], G_epoch_loss, 'G loss')
        loss_plot(axs[0, 1], D_epoch_loss, 'D loss')
        loss_plot(axs[1, 0], G_MSE_epoch_train_loss, 'G MSE training loss')
        loss_plot(axs[1, 1], G_MSE_epoch_test_loss, 'RMSE on training dataset')

        plt.savefig(save_pth_pre + 'indpt_{}_info.eps'.format(station_name))
        plt.savefig(save_pth_pre + 'indpt_{}_info.svg'.format(station_name))

        plt.close()

        return G
def train(args, train_dataloader):
    model = MyModel(args)
    model.train()
    if args.amp:
        scaler = GradScaler()
    device = args.local_rank if args.local_rank != -1 \
        else (torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu'))
    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)
    model.to(device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[
                                                          args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters() if not any(
            nd in n for nd in no_decay)], "weight_decay":args.weight_decay},
        {"params": [p for n, p in model.named_parameters() if any(
            nd in n for nd in no_decay)], "weight_decay":0.0}
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.lr)
    if args.warmup_ratio > 0:
        num_training_steps = len(train_dataloader)*args.max_epochs
        warmup_steps = args.warmup_ratio*num_training_steps
        scheduler = get_linear_schedule_with_warmup(
            optimizer, warmup_steps, num_training_steps)
    if args.local_rank < 1:
        mid = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime(time.time()))
    for epoch in range(args.max_epochs):
        if args.local_rank != -1:
            train_dataloader.sampler.set_epoch(epoch)
        tqdm_train_dataloader = tqdm(
            train_dataloader, desc="epoch:%d" % epoch, ncols=150)
        for i, batch in enumerate(tqdm_train_dataloader):
            torch.cuda.empty_cache()
            optimizer.zero_grad()
            txt_ids, attention_mask, token_type_ids, context_mask, turn_mask, tags = batch['txt_ids'], batch['attention_mask'], batch['token_type_ids'],\
                batch['context_mask'], batch['turn_mask'], batch['tags']
            txt_ids, attention_mask, token_type_ids, context_mask, turn_mask, tags = txt_ids.to(device), attention_mask.to(device), token_type_ids.to(device),\
                context_mask.to(device), turn_mask.to(device), tags.to(device)
            if args.amp:
                with autocast():
                    loss, (loss_t1, loss_t2) = model(
                        txt_ids, attention_mask, token_type_ids, context_mask, turn_mask, tags)
                scaler.scale(loss).backward()
                if args.max_grad_norm > 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss, (loss_t1, loss_t2) = model(txt_ids, attention_mask,
                                                 token_type_ids, context_mask, turn_mask, tags)
                loss.backward()
                if args.max_grad_norm > 0:
                    clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
            lr = optimizer.param_groups[0]['lr']
            named_parameters = [
                (n, p) for n, p in model.named_parameters() if not p.grad is None]
            grad_norm = torch.norm(torch.stack(
                [torch.norm(p.grad) for n, p in named_parameters])).item()
            if args.warmup_ratio > 0:
                scheduler.step()
            postfix_str = "norm:{:.2f},lr:{:.1e},loss:{:.2e},t1:{:.2e},t2:{:.2e}".format(
                grad_norm, lr, loss.item(), loss_t1, loss_t2)
            tqdm_train_dataloader.set_postfix_str(postfix_str)
        if args.local_rank in [-1, 0] and not args.not_save:
            if hasattr(model, 'module'):
                model_state_dict = model.module.state_dict()
            else:
                model_state_dict = model.state_dict()
            checkpoint = {"model_state_dict": model_state_dict}
            save_dir = './checkpoints/%s/%s/' % (args.dataset_tag, mid)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
                pickle.dump(args, open(save_dir+'args', 'wb'))
            save_path = save_dir+"checkpoint_%d.cpt" % epoch
            torch.save(checkpoint, save_path)
            print("model saved at:", save_path)
        if args.test_eval and args.local_rank in [-1, 0]:
            test_dataloader = load_t1_data(args.dataset_tag, args.test_path, args.pretrained_model_path,
                                           args.window_size, args.overlap, args.test_batch, args.max_len)  # test_dataloader is the dataloader for the first-turn (turn-1) QA
            (p1, r1, f1), (p2, r2, f2) = test_evaluation(
                model, test_dataloader, args.threshold, args.amp)
            print(
                "Turn 1: precision:{:.4f} recall:{:.4f} f1:{:.4f}".format(p1, r1, f1))
            print(
                "Turn 2: precision:{:.4f} recall:{:.4f} f1:{:.4f}".format(p2, r2, f2))
            model.train()
        if args.local_rank != -1:
            torch.distributed.barrier()
Example #14
0
 def autocast(self):
     """AMP context"""
     return amp.autocast()
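A minimal usage sketch for a wrapper method like this one; the `TinyTrainer` class, its model, and the MSE loss below are illustrative assumptions, not part of the original snippet:

# Sketch only: hypothetical trainer-style class showing how such an autocast() helper might be used.
import torch
from torch.cuda import amp

class TinyTrainer:
    def __init__(self, model):
        self.model = model
        self.scaler = amp.GradScaler()

    def autocast(self):
        """AMP context"""
        return amp.autocast()

    def training_step(self, x, y, optimizer):
        optimizer.zero_grad()
        with self.autocast():  # forward pass runs in mixed precision
            loss = torch.nn.functional.mse_loss(self.model(x), y)
        self.scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
        self.scaler.step(optimizer)
        self.scaler.update()
        return loss.item()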
Example #15
0
def train_one_epoch(model,
                    optimizer,
                    data_loader,
                    device,
                    epoch,
                    print_freq,
                    accumulate,
                    img_size,
                    grid_min,
                    grid_max,
                    gs,
                    multi_scale=False,
                    warmup=False):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    lr_scheduler = None
    if epoch == 0 and warmup is True:  # enable a warmup schedule for the first epoch (epoch == 0)
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters,
                                                 warmup_factor)
        accumulate = 1

    enable_amp = True if "cuda" in device.type else False
    scaler = amp.GradScaler(enabled=enable_amp)

    mloss = torch.zeros(4).to(device)  # mean losses
    now_lr = 0.
    nb = len(data_loader)  # number of batches
    # imgs: [batch_size, 3, img_size, img_size]
    # targets: [num_obj, 6] , that number 6 means -> (img_index, obj_index, x, y, w, h)
    # paths: list of img path
    for i, (imgs, targets, paths, _, _) in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):
        # ni: total number of batches processed since the start of training (epoch 0)
        ni = i + nb * epoch  # number integrated batches (since train start)
        imgs = imgs.to(
            device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
        targets = targets.to(device)

        # Multi-Scale
        if multi_scale:
            # Randomly change the input image size once every `accumulate` batches (64 images);
            # labels are already relative coordinates, so rescaling the images does not affect them.
            if ni % accumulate == 0:  # adjust img_size (67% - 150%) every 1 batch
                # randomly pick a size in [grid_min, grid_max] that is a multiple of 32 (gs)
                img_size = random.randrange(grid_min, grid_max + 1) * gs
            sf = img_size / max(imgs.shape[2:])  # scale factor

            # if the longest image side differs from img_size, rescale and round height/width to multiples of 32
            if sf != 1:
                # gs: (pixels) grid size
                ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]
                      ]  # new shape (stretched to 32-multiple)
                imgs = F.interpolate(imgs,
                                     size=ns,
                                     mode='bilinear',
                                     align_corners=False)

        # mixed-precision (autocast) context manager; it is a no-op when running on CPU
        with amp.autocast(enabled=enable_amp):
            pred = model(imgs)

            # loss
            loss_dict = compute_loss(pred, targets, model)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purpose
            loss_dict_reduced = utils.reduce_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            loss_items = torch.cat(
                (loss_dict_reduced["box_loss"], loss_dict_reduced["obj_loss"],
                 loss_dict_reduced["class_loss"], losses_reduced)).detach()
            mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses

            if not torch.isfinite(losses_reduced):
                print('WARNING: non-finite loss, ending training ',
                      loss_dict_reduced)
                print("training image path: {}".format(",".join(paths)))
                sys.exit(1)

            losses *= 1. / accumulate  # scale loss

        # backward
        scaler.scale(losses).backward()
        # optimize
        # update the weights once every `accumulate` batches (64 images)
        if ni % accumulate == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        now_lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=now_lr)

        if ni % accumulate == 0 and lr_scheduler is not None:  # the warmup schedule is only used in the first epoch
            lr_scheduler.step()

    return mloss, now_lr
Example #16
0
def train_epoch(model, config, train_loader, val_loader, epoch_i):
    optimizer = config['optimizer']
    scheduler = config['scheduler']
    writer = config['writer']
    scaler = config['scaler']
    opt = config['opt']

    if opt.criterion == 'MSELoss':
        criterion = torch.nn.MSELoss(reduction='sum').to(opt.device)
    else:
        criterion = torch.nn.CrossEntropyLoss().to(opt.device)

    # train one epoch
    model.train()
    total_loss = 0.
    final_val_loss = 0.
    total_examples = 0
    st_time = time.time()
    optimizer.zero_grad()
    for local_step, (x,y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        global_step = (len(train_loader) * epoch_i) + local_step
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        with autocast(enabled=opt.use_amp):
            if opt.use_profiler:
                with profiler.profile(profile_memory=True, record_shapes=True) as prof:
                    output = model(x)
                print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
            else:
                output = model(x)
            loss = criterion(output, y)
            if opt.gradient_accumulation_steps > 1:
                loss = loss / opt.gradient_accumulation_steps
        # back-propagation - begin
        scaler.scale(loss).backward()
        if (local_step + 1) % opt.gradient_accumulation_steps == 0:
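            # Unscale the gradients before clipping so the max-norm threshold applies to the true (unscaled) values.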
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if opt.use_transformers_optimizer: scheduler.step()
        # back-propagation - end
        cur_examples = y.size(0)
        total_examples += cur_examples
        total_loss += (loss.item() * cur_examples)
        if writer: writer.add_scalar('Loss/train', loss.item(), global_step)
    cur_loss = total_loss / total_examples

    # evaluate
    eval_loss, eval_acc = evaluate(model, config, val_loader)
    curr_time = time.time()
    elapsed_time = (curr_time - st_time) / 60
    st_time = curr_time
    curr_lr = scheduler.get_last_lr()[0] if scheduler else optimizer.param_groups[0]['lr']
    logger.info('{:3d} epoch | {:5d}/{:5d} | train loss : {:6.3f}, valid loss {:6.3f}, valid acc {:.4f}| lr :{:7.6f} | {:5.2f} min elapsed'.\
            format(epoch_i, local_step+1, len(train_loader), cur_loss, eval_loss, eval_acc, curr_lr, elapsed_time)) 
    if writer:
        writer.add_scalar('Loss/valid', eval_loss, global_step)
        writer.add_scalar('Acc/valid', eval_acc, global_step)
        writer.add_scalar('LearningRate/train', curr_lr, global_step)
    return eval_loss, eval_acc
def main():
    '''
    :return: None

    * :ref:`API in English <lif_fc_mnist.main-en>`

    .. _lif_fc_mnist.main-cn:

    Classify MNIST with a fully connected LIF network.\n
    This function initializes the network, trains it, and reports the accuracy on the test set during training.

    * :ref:`中文API <lif_fc_mnist.main-cn>`

    .. _lif_fc_mnist.main-en:

    The network with FC-LIF structure for classifying MNIST.\n
    This function initializes the network, starts training, and shows the accuracy on the test dataset.
    '''
    parser = argparse.ArgumentParser(description='LIF MNIST Training')
    parser.add_argument('-T',
                        default=100,
                        type=int,
                        help='simulating time-steps')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=64, type=int, help='batch size')
    parser.add_argument('-epochs',
                        default=100,
                        type=int,
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-data-dir',
                        type=str,
                        help='root dir of MNIST dataset')
    parser.add_argument('-out-dir',
                        type=str,
                        default='./logs',
                        help='root dir for saving logs and checkpoint')
    parser.add_argument('-resume',
                        type=str,
                        help='resume from the checkpoint path')
    parser.add_argument('-amp',
                        action='store_true',
                        help='automatic mixed precision training')
    parser.add_argument('-opt',
                        type=str,
                        choices=['sgd', 'adam'],
                        default='adam',
                        help='use which optimizer. SGD or Adam')
    parser.add_argument('-momentum',
                        default=0.9,
                        type=float,
                        help='momentum for SGD')
    parser.add_argument('-lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-tau',
                        default=2.0,
                        type=float,
                        help='parameter tau of LIF neuron')

    args = parser.parse_args()
    print(args)

    net = SNN(tau=args.tau)

    print(net)

    net.to(args.device)

    # Initialize the data loaders
    train_dataset = torchvision.datasets.MNIST(
        root=args.data_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    test_dataset = torchvision.datasets.MNIST(
        root=args.data_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True)

    train_data_loader = data.DataLoader(dataset=train_dataset,
                                        batch_size=args.b,
                                        shuffle=True,
                                        drop_last=True,
                                        num_workers=args.j,
                                        pin_memory=True)
    test_data_loader = data.DataLoader(dataset=test_dataset,
                                       batch_size=args.b,
                                       shuffle=False,
                                       drop_last=False,
                                       num_workers=args.j,
                                       pin_memory=True)

    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    start_epoch = 0
    max_test_acc = -1

    optimizer = None
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(args.opt)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    out_dir = os.path.join(args.out_dir,
                           f'T{args.T}_b{args.b}_{args.opt}_lr{args.lr}')

    if args.amp:
        out_dir += '_amp'

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    with open(os.path.join(out_dir, 'args.txt'), 'w',
              encoding='utf-8') as args_txt:
        args_txt.write(str(args))

    writer = SummaryWriter(out_dir, purge_step=start_epoch)
    with open(os.path.join(out_dir, 'args.txt'), 'w',
              encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    encoder = encoding.PoissonEncoder()

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for img, label in train_data_loader:
            optimizer.zero_grad()
            img = img.to(args.device)
            label = label.to(args.device)
            label_onehot = F.one_hot(label, 10).float()

            if scaler is not None:
                with amp.autocast():
                    out_fr = 0.
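                    # Accumulate the output firing rate over T simulation time steps;
                    # the Poisson encoder re-encodes the image at every step.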
                    for t in range(args.T):
                        encoded_img = encoder(img)
                        out_fr += net(encoded_img)
                    out_fr = out_fr / args.T
                    loss = F.mse_loss(out_fr, label_onehot)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                out_fr = 0.
                for t in range(args.T):
                    encoded_img = encoder(img)
                    out_fr += net(encoded_img)
                out_fr = out_fr / args.T
                loss = F.mse_loss(out_fr, label_onehot)
                loss.backward()
                optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            functional.reset_net(net)

        train_time = time.time()
        train_speed = train_samples / (train_time - start_time)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for img, label in test_data_loader:
                img = img.to(args.device)
                label = label.to(args.device)
                label_onehot = F.one_hot(label, 10).float()
                out_fr = 0.
                for t in range(args.T):
                    encoded_img = encoder(img)
                    out_fr += net(encoded_img)
                out_fr = out_fr / args.T
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)
        test_time = time.time()
        test_speed = test_samples / (test_time - train_time)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(out_dir)
        print(
            f'epoch ={epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}'
        )
        print(
            f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s'
        )
        print(
            f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n'
        )

    # Save data for plotting
    net.eval()
    # Register a forward hook on the output layer
    output_layer = net.layer[-1]  # output layer
    output_layer.v_seq = []
    output_layer.s_seq = []

    def save_hook(m, x, y):
        m.v_seq.append(m.v.unsqueeze(0))
        m.s_seq.append(y.unsqueeze(0))

    output_layer.register_forward_hook(save_hook)

    with torch.no_grad():
        img, label = test_dataset[0]
        img = img.to(args.device)
        out_fr = 0.
        for t in range(args.T):
            encoded_img = encoder(img)
            out_fr += net(encoded_img)
        out_spikes_counter_frequency = (out_fr / args.T).cpu().numpy()
        print(f'Firing rate: {out_spikes_counter_frequency}')

        output_layer.v_seq = torch.cat(output_layer.v_seq)
        output_layer.s_seq = torch.cat(output_layer.s_seq)
        v_t_array = output_layer.v_seq.cpu().numpy().squeeze()  # v_t_array[i][j]: membrane potential of neuron i at time step j
        np.save("v_t_array.npy", v_t_array)
        s_t_array = output_layer.s_seq.cpu().numpy().squeeze()  # s_t_array[i][j]: spike (0 or 1) emitted by neuron i at time step j
        np.save("s_t_array.npy", s_t_array)
Example #18
0
    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Trains one epoch and returns metrics.
        """
        logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
        cpu_memory_usage = []
        for worker, memory in common_util.peak_memory_mb().items():
            cpu_memory_usage.append((worker, memory))
            logger.info(f"Worker {worker} memory usage MB: {memory}")
        gpu_memory_usage = []
        for gpu, memory in common_util.gpu_memory_mb().items():
            gpu_memory_usage.append((gpu, memory))
            logger.info(f"GPU {gpu} memory usage MB: {memory}")

        regularization_penalty = self.model.get_regularization_penalty()

        train_loss = 0.0
        batch_loss = 0.0

        if regularization_penalty is not None:
            train_reg_loss = 0.0
            batch_reg_loss = 0.0
        else:
            train_reg_loss = None
            batch_reg_loss = None
        # Set the model to "train" mode.
        self._pytorch_model.train()

        # Get tqdm for the training batches
        batch_generator = iter(self.data_loader)
        batch_group_generator = common_util.lazy_groups_of(
            batch_generator, self._num_gradient_accumulation_steps
        )

        logger.info("Training")

        num_training_batches: Union[int, float]
        try:
            len_data_loader = len(self.data_loader)
            num_training_batches = math.ceil(
                len_data_loader / self._num_gradient_accumulation_steps
            )
        except TypeError:
            num_training_batches = float("inf")

        # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the master's
        # progress is shown
        if self._master:
            batch_group_generator_tqdm = Tqdm.tqdm(
                batch_group_generator, total=num_training_batches
            )
        else:
            batch_group_generator_tqdm = batch_group_generator

        self._last_log = time.time()

        batches_this_epoch = 0
        if self._batch_num_total is None:
            self._batch_num_total = 0

        done_early = False
        for batch_group in batch_group_generator_tqdm:
            if self._distributed:
                # Check whether the other workers have stopped already (due to differing amounts of
                # data in each). If so, we can't proceed because we would hang when we hit the
                # barrier implicit in Model.forward. We use an IntTensor instead of a BoolTensor
                # here because NCCL process groups apparently don't support BoolTensor.
                done = torch.tensor(0, device=self.cuda_device)
                torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
                if done.item() > 0:
                    done_early = True
                    logger.warning(
                        f"Worker {torch.distributed.get_rank()} finishing training early! "
                        "This implies that there is an imbalance in your training "
                        "data across the workers and that some amount of it will be "
                        "ignored. A small amount of this is fine, but a major imbalance "
                        "should be avoided. Note: This warning will appear unless your "
                        "data is perfectly balanced."
                    )
                    break

            batches_this_epoch += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            self.optimizer.zero_grad()

            batch_group_outputs = []
            for batch in batch_group:
                with amp.autocast(self._use_amp):
                    batch_outputs = self.batch_outputs(batch, for_training=True)
                    batch_group_outputs.append(batch_outputs)
                    loss = batch_outputs.get("loss")
                    reg_loss = batch_outputs.get("reg_loss")
                    if torch.isnan(loss):
                        raise ValueError("nan loss encountered")
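                    # Average the loss over the accumulation group so the summed gradients match a single large batch.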
                    loss = loss / len(batch_group)

                    batch_loss = loss.item()
                    train_loss += batch_loss
                    if reg_loss is not None:
                        reg_loss = reg_loss / len(batch_group)
                        batch_reg_loss = reg_loss.item()
                        train_reg_loss += batch_reg_loss

                if self._scaler is not None:
                    self._scaler.scale(loss).backward()
                else:
                    loss.backward()

            batch_grad_norm = self.rescale_gradients()

            # This does nothing if batch_num_total is None or you are using a
            # scheduler which doesn't update per batch.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step_batch(batch_num_total)
            if self._momentum_scheduler:
                self._momentum_scheduler.step_batch(batch_num_total)

            param_updates = None
            if self._tensorboard.should_log_histograms_this_batch() and self._master:
                # Get the magnitude of parameter updates for logging.  We need to do some
                # computation before and after the optimizer step, and it's expensive because of
                # GPU/CPU copies (necessary for large models, and for shipping to tensorboard), so
                # we don't do this every batch, only when it's requested.
                param_updates = {
                    name: param.detach().cpu().clone()
                    for name, param in self.model.named_parameters()
                }

                if self._scaler is not None:
                    self._scaler.step(self.optimizer)
                    self._scaler.update()
                else:
                    self.optimizer.step()

                for name, param in self.model.named_parameters():
                    param_updates[name].sub_(param.detach().cpu())
            else:
                if self._scaler is not None:
                    self._scaler.step(self.optimizer)
                    self._scaler.update()
                else:
                    self.optimizer.step()

            # Update moving averages
            if self._moving_average is not None:
                self._moving_average.apply(batch_num_total)

            # Update the description with the latest metrics
            metrics = training_util.get_metrics(
                self.model,
                train_loss,
                train_reg_loss,
                batch_loss,
                batch_reg_loss,
                batches_this_epoch,
                world_size=self._world_size,
                cuda_device=self.cuda_device,
            )

            if self._master:
                # Updating tqdm only for the master, as the other workers don't have one
                description = training_util.description_from_metrics(metrics)
                batch_group_generator_tqdm.set_description(description, refresh=False)
                self._tensorboard.log_batch(
                    self.model,
                    self.optimizer,
                    batch_grad_norm,
                    metrics,
                    batch_group,
                    param_updates,
                )

                self._checkpointer.maybe_save_checkpoint(self, epoch, batches_this_epoch)
            for callback in self._batch_callbacks:
                callback(
                    self,
                    batch_group,
                    batch_group_outputs,
                    epoch,
                    batches_this_epoch,
                    is_training=True,
                    is_master=self._master,
                )

        if self._distributed and not done_early:
            logger.warning(
                f"Worker {torch.distributed.get_rank()} completed its entire epoch (training)."
            )
            # Indicate that we're done so that any workers that have remaining data stop the epoch early.
            done = torch.tensor(1, device=self.cuda_device)
            torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
            assert done.item()

        # Let all workers finish their epoch before computing
        # the final statistics for the epoch.
        if self._distributed:
            dist.barrier()

        metrics = training_util.get_metrics(
            self.model,
            train_loss,
            train_reg_loss,
            batch_loss=None,
            batch_reg_loss=None,
            num_batches=batches_this_epoch,
            reset=True,
            world_size=self._world_size,
            cuda_device=self.cuda_device,
        )

        for (worker, memory) in cpu_memory_usage:
            metrics["worker_" + str(worker) + "_memory_MB"] = memory
        for (gpu_num, memory) in gpu_memory_usage:
            metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory
        return metrics
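
The integer-tensor handshake inside the loop above is a small synchronisation protocol: every worker contributes 0 to an all_reduce while it still has data and 1 once it has run out, so a worker can detect that another one finished early before it blocks on a collective call in Model.forward. A minimal sketch of just that pattern, assuming an already-initialised process group and a hypothetical step_fn standing in for the real batch processing:

import torch
import torch.distributed as dist


def run_epoch_with_early_stop(batches, device, step_fn):
    # Before each batch, agree (via an integer all_reduce) on whether any
    # worker has already exhausted its data, so nobody hangs on a collective.
    done_early = False
    for batch in batches:
        done = torch.tensor(0, device=device)
        dist.all_reduce(done, op=dist.ReduceOp.SUM)
        if done.item() > 0:
            done_early = True
            break
        step_fn(batch)

    if not done_early:
        # Signal the other workers that this worker is finished so they can
        # stop at their next check.
        done = torch.tensor(1, device=device)
        dist.all_reduce(done, op=dist.ReduceOp.SUM)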
Example #19
0
def validation(model, criterion, valid_loader, process_group, current_phase):
    validation_loss = dict()
    audio_data = dict()

    model.eval()
    with torch.no_grad():
        with tqdm(desc='Valid', total=hp.training.n_validation_steps) as pbar:
            for inputs, ground_truth, conditioning in valid_loader:

                if pbar.n >= hp.training.n_validation_steps:
                    break

                inputs = inputs.to(args.local_rank, non_blocking=True)
                ground_truth = ground_truth.to(args.local_rank,
                                               non_blocking=True)
                conditioning = conditioning.to(args.local_rank,
                                               non_blocking=True)

                with autocast(enabled=hp.training.mixed_precision):
                    if current_phase == 0:

                        prediction = utils.core.ddp(model).generator.wavenet(
                            inputs, conditioning)
                        wavenet_loss = criterion.sample_loss(
                            ground_truth,
                            prediction) + criterion.spectrogram_loss(
                                ground_truth, prediction)

                        wavenet_loss = utils.core.all_reduce(
                            wavenet_loss, group=process_group)

                        utils.core.list_into_dict(validation_loss, 'wavenet',
                                                  wavenet_loss.item())

                        if pbar.n == hp.training.n_validation_steps - 1:
                            audio_data['input'] = inputs.detach().cpu().numpy()
                            audio_data['ground_truth'] = ground_truth.detach(
                            ).cpu().numpy()
                            audio_data['wavenet'] = prediction.detach().cpu(
                            ).numpy()

                    elif current_phase == 1:

                        prediction, prediction_postnet = utils.core.ddp(
                            model).generator(inputs, conditioning)
                        wavenet_loss = criterion.sample_loss(
                            ground_truth,
                            prediction) + criterion.spectrogram_loss(
                                ground_truth, prediction)
                        wavenet_postnet_loss = criterion.sample_loss(ground_truth, prediction_postnet) \
                                               + criterion.spectrogram_loss(ground_truth, prediction_postnet)

                        wavenet_loss = utils.core.all_reduce(
                            wavenet_loss, group=process_group)
                        wavenet_postnet_loss = utils.core.all_reduce(
                            wavenet_postnet_loss, group=process_group)

                        utils.core.list_into_dict(validation_loss, 'wavenet',
                                                  wavenet_loss.item())
                        utils.core.list_into_dict(validation_loss,
                                                  'wavenet-postnet',
                                                  wavenet_postnet_loss.item())

                        if pbar.n == hp.training.n_validation_steps - 1:
                            audio_data['input'] = inputs.detach().cpu().numpy()
                            audio_data['ground_truth'] = ground_truth.detach(
                            ).cpu().numpy()
                            audio_data['wavenet'] = prediction.detach().cpu(
                            ).numpy()
                            audio_data[
                                'wavenet-postnet'] = prediction_postnet.detach(
                                ).cpu().numpy()

                    else:

                        prediction, prediction_postnet, prediction_scores, \
                        discriminator_scores, L_FM_G = model(inputs, ground_truth, conditioning)

                        _, wavenet_loss, wavenet_postnet_loss, \
                        G_loss, D_losses = criterion(pbar.n, ground_truth, prediction, prediction_postnet,
                                                     prediction_scores,
                                                     discriminator_scores, L_FM_G)

                        wavenet_loss = utils.core.all_reduce(
                            wavenet_loss, group=process_group)
                        wavenet_postnet_loss = utils.core.all_reduce(
                            wavenet_postnet_loss, group=process_group)
                        G_loss = utils.core.all_reduce(G_loss,
                                                       group=process_group)
                        D_losses = [
                            utils.core.all_reduce(D_loss, group=process_group)
                            for D_loss in D_losses
                        ]

                        if G_loss is not None:
                            utils.core.list_into_dict(validation_loss,
                                                      'wavenet',
                                                      wavenet_loss.item())
                            utils.core.list_into_dict(
                                validation_loss, 'wavenet-postnet',
                                wavenet_postnet_loss.item())
                            utils.core.list_into_dict(validation_loss, 'G',
                                                      G_loss.item())
                        utils.core.list_into_dict(validation_loss, 'D_16kHz',
                                                  D_losses[0].item())
                        utils.core.list_into_dict(validation_loss, 'D_8kHz',
                                                  D_losses[1].item())
                        utils.core.list_into_dict(validation_loss, 'D_4kHz',
                                                  D_losses[2].item())
                        utils.core.list_into_dict(validation_loss, 'D_mel',
                                                  D_losses[3].item())

                        if pbar.n == hp.training.n_validation_steps - 1:
                            audio_data['input'] = inputs.detach().cpu().numpy()
                            audio_data['ground_truth'] = ground_truth.detach(
                            ).cpu().numpy()
                            audio_data['wavenet'] = prediction.detach().cpu(
                            ).numpy()
                            audio_data[
                                'wavenet-postnet'] = prediction_postnet.detach(
                                ).cpu().numpy()

                    pbar.set_postfix(losses={
                        key: value[-1]
                        for key, value in validation_loss.items()
                    })
                    pbar.update()

        for key, value in validation_loss.items():
            validation_loss[key] = np.mean(value)

        return validation_loss, audio_data
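
utils.core.all_reduce above is a project-specific helper whose implementation is not shown; the assumption here is that it synchronises a scalar loss across the workers of the given process group. A minimal sketch of that common pattern (sum over workers, then divide by the world size):

import torch
import torch.distributed as dist


def average_across_workers(value: torch.Tensor, group=None) -> torch.Tensor:
    # Sum the (floating-point) tensor over every worker in the process group,
    # then divide by the number of participants to obtain the mean.
    value = value.clone()
    dist.all_reduce(value, op=dist.ReduceOp.SUM, group=group)
    value /= dist.get_world_size(group=group)
    return value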
Example #20
0
    def _validation_loss(self, epoch: int) -> Tuple[float, float, int]:
        """
        Computes the validation loss. Returns it, the regularization loss, and the number of batches.
        """
        logger.info("Validating")

        self._pytorch_model.eval()

        # Replace parameter values with the shadow values from the moving averages.
        if self._moving_average is not None:
            self._moving_average.assign_average_value()

        if self._validation_data_loader is not None:
            validation_data_loader = self._validation_data_loader
        else:
            raise ConfigurationError(
                "Validation results cannot be calculated without a validation_data_loader"
            )

        regularization_penalty = self.model.get_regularization_penalty()

        # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the master's
        # progress is shown
        if self._master:
            val_generator_tqdm = Tqdm.tqdm(validation_data_loader)
        else:
            val_generator_tqdm = validation_data_loader

        batches_this_epoch = 0
        val_loss = 0
        val_batch_loss = 0
        if regularization_penalty is not None:
            val_reg_loss = 0
            val_batch_reg_loss = 0
        else:
            val_reg_loss = None
            val_batch_reg_loss = None
        done_early = False
        for batch in val_generator_tqdm:
            if self._distributed:
                # Check whether the other workers have stopped already (due to differing amounts of
                # data in each). If so, we can't proceed because we would hang when we hit the
                # barrier implicit in Model.forward. We use an IntTensor instead of a BoolTensor
                # here because NCCL process groups apparently don't support BoolTensor.
                done = torch.tensor(0, device=self.cuda_device)
                torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
                if done.item() > 0:
                    done_early = True
                    logger.warning(
                        f"Worker {torch.distributed.get_rank()} finishing validation early! "
                        "This implies that there is an imbalance in your validation "
                        "data across the workers and that some amount of it will be "
                        "ignored. A small amount of this is fine, but a major imbalance "
                        "should be avoided. Note: This warning will appear unless your "
                        "data is perfectly balanced."
                    )
                    break

            with amp.autocast(self._use_amp):
                batch_outputs = self.batch_outputs(batch, for_training=False)
                loss = batch_outputs.get("loss")
                reg_loss = batch_outputs.get("reg_loss")
                if loss is not None:
                    # You shouldn't necessarily have to compute a loss for validation, so we allow for
                    # `loss` to be None.  We need to be careful, though - `batches_this_epoch` is
                    # currently only used as the divisor for the loss function, so we can safely only
                    # count those batches for which we actually have a loss.  If this variable ever
                    # gets used for something else, we might need to change things around a bit.
                    batches_this_epoch += 1
                    val_batch_loss = loss.detach().cpu().numpy()
                    val_loss += val_batch_loss
                    if reg_loss is not None:
                        val_batch_reg_loss = reg_loss.detach().cpu().numpy()
                        val_reg_loss += val_batch_reg_loss

            # Update the description with the latest metrics
            val_metrics = training_util.get_metrics(
                self.model,
                val_loss,
                val_reg_loss,
                val_batch_loss,
                val_batch_reg_loss,
                batches_this_epoch,
                world_size=self._world_size,
                cuda_device=self.cuda_device,
            )

            description = training_util.description_from_metrics(val_metrics)
            if self._master:
                val_generator_tqdm.set_description(description, refresh=False)

            for callback in self._batch_callbacks:
                callback(
                    self,
                    [batch],
                    [batch_outputs],
                    epoch,
                    batches_this_epoch,
                    is_training=False,
                    is_master=self._master,
                )

        if self._distributed and not done_early:
            logger.warning(
                f"Worker {torch.distributed.get_rank()} completed its entire epoch (validation)."
            )
            # Indicate that we're done so that any workers that have remaining data stop validation early.
            done = torch.tensor(1, device=self.cuda_device)
            torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
            assert done.item()

        # Now restore the original parameter values.
        if self._moving_average is not None:
            self._moving_average.restore()

        return val_loss, val_reg_loss, batches_this_epoch
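
The assign_average_value() / restore() pair used above swaps the model's parameters for their moving-average shadows during validation and puts the originals back afterwards. A minimal, self-contained sketch of that shadow-parameter bookkeeping (not the trainer's actual MovingAverage class):

import torch


class SimpleMovingAverage:
    """Keeps EMA shadows of a model's parameters and can swap them in for eval."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.model = model
        self.decay = decay
        self.shadow = {n: p.detach().clone() for n, p in model.named_parameters()}
        self._backup = {}

    def apply(self):
        # Call after each optimizer step to update the shadow parameters.
        with torch.no_grad():
            for n, p in self.model.named_parameters():
                self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    def assign_average_value(self):
        # Replace parameters with their shadows (e.g. before validation).
        with torch.no_grad():
            for n, p in self.model.named_parameters():
                self._backup[n] = p.detach().clone()
                p.copy_(self.shadow[n])

    def restore(self):
        # Put the original parameters back after validation.
        with torch.no_grad():
            for n, p in self.model.named_parameters():
                p.copy_(self._backup[n])
        self._backup = {}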
Example #21
0
    def prepare_model(
        self,
        model: torch.nn.Module,
        move_to_device: bool = True,
        wrap_ddp: bool = True,
        ddp_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        """Prepares the model for distributed execution.

        This allows you to use the same exact code regardless of number of
        workers or the device type being used (CPU, GPU).

        Args:
            model (torch.nn.Module): A torch model to prepare.
            move_to_device: Whether to move the model to the correct
                device. If set to False, the model needs to be moved to the
                correct device manually.
            wrap_ddp: Whether to wrap models in
                ``DistributedDataParallel``.
            ddp_kwargs (Dict[str, Any]): Args to pass into
                ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
                set to True.
        """
        ddp_kwargs = ddp_kwargs or {}

        rank = train.local_rank()

        device = self.get_device()

        if torch.cuda.is_available():
            torch.cuda.set_device(device)

        if move_to_device:
            logger.info(f"Moving model to device: {device}")
            model = model.to(device)

        def model_get_state(self):
            # `__getstate__` is a special method that informs pickle which attributes
            # to serialize. This custom implementation ensures that the wrapped forward
            # method and custom `__getstate__` method aren't serialized.
            if hasattr(self, "_original_get_state"):
                state = self._original_get_state()
                state["__getstate__"] = state["_original_get_state"]
                del state["_original_get_state"]
            else:
                # If model does not have a `__getstate__` already defined, use default
                # implementation.
                state = self.__dict__.copy()
                del state["__getstate__"]
            state["forward"] = state["_unwrapped_forward"]
            del state["_unwrapped_forward"]

            return state

        if self.amp_is_enabled:
            # Pickle cannot serialize the wrapped forward method. As a workaround,
            # define a custom `__getstate__` method that unwraps the forward method.
            model._unwrapped_forward = model.forward
            model.forward = autocast()(model.forward)

            # TODO(amogkam): Replace below logic with a generic "unpack model" method.
            # Replacing the `model.forward` method makes the model no longer
            # serializable. When serializing the model, we have to override the
            # `__getstate__` method to set back the original forward method.
            if hasattr(model, "__getstate__"):
                model._original_get_state = model.__getstate__
            # `__getstate__` must be a bound method rather than a callable attribute.
            # See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
            model.__getstate__ = types.MethodType(model_get_state, model)

        if wrap_ddp and train.world_size() > 1:
            logger.info("Wrapping provided model in DDP.")
            if torch.cuda.is_available():
                model = DistributedDataParallel(model,
                                                device_ids=[rank],
                                                output_device=rank,
                                                **ddp_kwargs)
            else:
                model = DistributedDataParallel(model, **ddp_kwargs)

        return model
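
A hedged usage sketch for prepare_model as defined above; MyNet and the surrounding worker function are placeholders, and the accelerator/trainer instance is assumed to be in scope as `self`:

import torch.nn as nn


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x)

# Inside the training worker, with `self` being the object that defines
# prepare_model() above:
#
#     model = self.prepare_model(
#         MyNet(),
#         move_to_device=True,
#         wrap_ddp=True,
#         ddp_kwargs={"find_unused_parameters": False},
#     )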
    def fit(
        self,
        train_objectives: Iterable[Tuple[DataLoader, nn.Module]],
        evaluator: SentenceEvaluator = None,
        epochs: int = 1,
        steps_per_epoch=None,
        scheduler: str = 'WarmupLinear',
        warmup_steps: int = 10000,
        optimizer_class: Type[Optimizer] = transformers.AdamW,
        optimizer_params: Dict[str, object] = {
            'lr': 2e-5,
            'correct_bias': True
        },
        weight_decay: float = 0.01,
        evaluation_steps: int = 0,
        output_path: str = None,
        save_best_model: bool = True,
        max_grad_norm: float = 1,
        use_amp: bool = False,
        callback: Callable[[float, int, int], None] = None,
        show_progress_bar: bool = True,
        baseline: float = 0.01,
        patience: int = 5,
    ):
        """
        Train the model with the given training objective
        Each training objective is sampled in turn for one batch.
        We sample only as many batches from each objective as there are in the smallest one
        to ensure equal training on each dataset.
        :param train_objectives: Tuples of (DataLoader, LossFunction). Pass more than one for multi-task learning
        :param evaluator: An evaluator (sentence_transformers.evaluation) evaluates the model performance during training on held-out dev data. It is used to determine the best model that is saved to disc.
        :param epochs: Number of epochs for training
        :param steps_per_epoch: Number of training steps per epoch. If set to None (default), one epoch equals the DataLoader size from train_objectives.
        :param scheduler: Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
        :param warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is decreased linearly back to zero.
        :param optimizer_class: Optimizer
        :param optimizer_params: Optimizer parameters
        :param weight_decay: Weight decay for model parameters
        :param evaluation_steps: If > 0, evaluate the model using evaluator after each number of training steps
        :param output_path: Storage path for the model and evaluation files
        :param save_best_model: If true, the best model (according to evaluator) is stored at output_path
        :param max_grad_norm: Maximum gradient norm used for gradient clipping.
        :param use_amp: Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0
        :param callback: Callback function that is invoked after each evaluation.
                It must accept the following three parameters in this order:
                `score`, `epoch`, `steps`
        :param show_progress_bar: If True, output a tqdm progress bar
        :param baseline: minimum improvement in the accuracy for a new model to be saved and best_score to be updated
        :param patience: maximum number of epochs to go without an improvement in the accuracy
        """
        self.acc_list = [1e-6]  # stores the accuracy while training
        training_acc_list = []

        t_evaluator = LabelAccuracyEvaluator(
            dataloader=train_objectives[0][0],
            softmax_model=train_objectives[0][1],
            name='lae-training')

        self.baseline = baseline
        self.patience = patience

        if use_amp:
            from torch.cuda.amp import autocast
            scaler = torch.cuda.amp.GradScaler()

        self.to(self._target_device)

        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)

        dataloaders = [dataloader for dataloader, _ in train_objectives]

        # Use smart batching
        for dataloader in dataloaders:
            dataloader.collate_fn = self.smart_batching_collate

        loss_models = [loss for _, loss in train_objectives]
        for loss_model in loss_models:
            loss_model.to(self._target_device)

        self.best_score = -9999999

        if steps_per_epoch is None or steps_per_epoch == 0:
            steps_per_epoch = min(
                [len(dataloader) for dataloader in dataloaders])

        num_train_steps = int(steps_per_epoch * epochs)

        # Prepare optimizers
        optimizers = []
        schedulers = []
        for loss_model in loss_models:
            param_optimizer = list(loss_model.named_parameters())

            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                weight_decay
            }, {
                'params': [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]

            optimizer = optimizer_class(optimizer_grouped_parameters,
                                        **optimizer_params)
            scheduler_obj = self._get_scheduler(optimizer,
                                                scheduler=scheduler,
                                                warmup_steps=warmup_steps,
                                                t_total=num_train_steps)

            optimizers.append(optimizer)
            schedulers.append(scheduler_obj)

        global_step = 0
        data_iterators = [iter(dataloader) for dataloader in dataloaders]

        num_train_objectives = len(train_objectives)

        skip_scheduler = False
        for epoch in trange(epochs,
                            desc="Epoch",
                            disable=not show_progress_bar):
            training_steps = 0

            for loss_model in loss_models:
                loss_model.zero_grad()
                loss_model.train()

            for _ in trange(steps_per_epoch,
                            desc="Iteration",
                            smoothing=0.05,
                            disable=not show_progress_bar):
                for train_idx in range(num_train_objectives):
                    loss_model = loss_models[train_idx]
                    optimizer = optimizers[train_idx]
                    scheduler = schedulers[train_idx]
                    data_iterator = data_iterators[train_idx]

                    try:
                        data = next(data_iterator)
                    except StopIteration:
                        data_iterator = iter(dataloaders[train_idx])
                        data_iterators[train_idx] = data_iterator
                        data = next(data_iterator)

                    features, labels = data

                    if use_amp:
                        with autocast():
                            loss_value = loss_model(features, labels)

                        scale_before_step = scaler.get_scale()
                        scaler.scale(loss_value).backward()
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(loss_model.parameters(),
                                                       max_grad_norm)
                        scaler.step(optimizer)
                        scaler.update()

                        skip_scheduler = scaler.get_scale(
                        ) != scale_before_step
                    else:
                        loss_value = loss_model(features, labels)
                        loss_value.backward()
                        torch.nn.utils.clip_grad_norm_(loss_model.parameters(),
                                                       max_grad_norm)
                        optimizer.step()

                    optimizer.zero_grad()

                    if not skip_scheduler:
                        scheduler.step()

                training_steps += 1
                global_step += 1

                if evaluation_steps > 0 and training_steps % evaluation_steps == 0:
                    for loss_model in loss_models:
                        loss_model.zero_grad()
                        loss_model.train()

            # training evaluation
            training_acc_evaluated = t_evaluator(self,
                                                 output_path=output_path,
                                                 epoch=epoch,
                                                 steps=-1)
            training_acc_list.append(training_acc_evaluated)

            wandb.log({"train_acc": training_acc_evaluated, "epoch": epoch})

            # validation evaluation
            flag = self._eval_during_training(evaluator, output_path, epoch,
                                              -1)

            if flag is True:
                print(f'Epoch: {epoch}')
                print(f"Best score: {self.best_score}")
                print('=' * 60)
            else:
                print('TRAINING EXITED. Best model has been found.')
                print(f'Epoch: {epoch}')
                print(f"Best score: {self.best_score}")
                print('=' * 60)
                return

            # remove the placeholder first element of acc_list; it was only needed until the first epoch produced a real value
            if epoch == 0:
                del self.acc_list[0]
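
One detail of the AMP branch in fit() worth isolating: GradScaler silently skips the optimizer step when it finds inf/nan gradients, and the code detects a skipped step by comparing the scale before and after scaler.update(), so the LR scheduler is not advanced for it. A condensed sketch of just that pattern, assuming model returns a scalar loss and that optimizer, scheduler and scaler already exist:

from torch.cuda.amp import GradScaler, autocast


def amp_step(model, batch, optimizer, scheduler, scaler: GradScaler):
    with autocast():
        loss = model(*batch)            # assumed to return a scalar loss
    scale_before = scaler.get_scale()
    scaler.scale(loss).backward()
    scaler.step(optimizer)              # skipped internally on inf/nan gradients
    scaler.update()
    optimizer.zero_grad()
    # Advance the LR schedule only if the optimizer step was not skipped.
    if scaler.get_scale() == scale_before:
        scheduler.step()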
Example #23
0
def train(max_iters, dataset, test_batch, G, D, optimizer_G, optimizer_D,
          recon_lambda, style_lambda, vgg_lambda, content_lambda,
          color_augment, amp, device, save):

    status = Status(max_iters)
    loss = LSGANLoss()
    vgg = VGGLoss(device, p=1)
    scaler = GradScaler() if amp else None

    while status.batches_done < max_iters:
        for rgb, gray in dataset:
            optimizer_D.zero_grad()
            optimizer_G.zero_grad()

            rgb = rgb.to(device)
            gray = gray.to(device)
            '''Discriminator'''
            with autocast(amp):
                real = color_augment(rgb)
                # D(real)
                real_prob, _ = D(torch.cat([gray, real], dim=1))
                # D(G(gray, rgb))
                fake = G(gray, real)
                fake_prob, _ = D(torch.cat([gray, fake.detach()], dim=1))

                # loss
                D_loss = loss.d_loss(real_prob, fake_prob)

            if scaler is not None:
                scaler.scale(D_loss).backward()
                scaler.step(optimizer_D)
            else:
                D_loss.backward()
                optimizer_D.step()
            '''Generator'''
            with autocast(amp):
                # D(G(gray, rgb))
                fake_prob, _ = D(torch.cat([gray, fake], dim=1))

                # loss
                G_loss = loss.g_loss(fake_prob)
                if recon_lambda > 0:
                    G_loss += l1_loss(fake, real) * recon_lambda
                if style_lambda > 0:
                    G_loss += vgg.style_loss(real, fake) * style_lambda
                if vgg_lambda > 0:
                    G_loss += vgg.vgg_loss(real, fake) * vgg_lambda
                if content_lambda > 0:
                    G_loss += vgg.content_loss(gray.repeat(1, 3, 1, 1),
                                               fake) * content_lambda

            if scaler is not None:
                scaler.scale(G_loss).backward()
                scaler.step(optimizer_G)
            else:
                G_loss.backward()
                optimizer_G.step()

            # save
            if status.batches_done % save == 0:
                with torch.no_grad():
                    G.eval()
                    image = G(test_batch[1], test_batch[0])
                    G.train()
                image_grid = _image_grid(test_batch[1], test_batch[0], image)
                save_image(
                    image_grid,
                    f'implementations/original/EDCNN/result/{status.batches_done}.jpg',
                    nrow=3 * 3,
                    normalize=True,
                    value_range=(-1, 1))
                torch.save(
                    G.state_dict(),
                    f'implementations/original/EDCNN/result/G_{status.batches_done}.pt'
                )
            save_image(fake,
                       'running.jpg',
                       nrow=4,
                       normalize=True,
                       value_range=(-1, 1))

            # updates
            loss_dict = dict(
                G=G_loss.item() if not torch.any(torch.isnan(G_loss)) else 0,
                D=D_loss.item() if not torch.any(torch.isnan(D_loss)) else 0)
            status.update(loss_dict)
            if scaler is not None:
                scaler.update()

            if status.batches_done == max_iters:
                break

    status.plot()
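
The loop above shares one GradScaler between the discriminator and generator optimizers: scaler.step() runs once per optimizer, while scaler.update() runs only once per iteration, after both steps, which matches PyTorch's guidance for multiple optimizers. A stripped-down sketch with hypothetical d_loss_fn / g_loss_fn callables standing in for the real forward passes; the scaler is assumed to be constructed once as GradScaler(enabled=use_amp):

from torch.cuda.amp import GradScaler, autocast


def gan_iteration(scaler: GradScaler, optimizer_D, optimizer_G,
                  d_loss_fn, g_loss_fn, use_amp=True):
    optimizer_D.zero_grad()
    with autocast(enabled=use_amp):
        D_loss = d_loss_fn()
    scaler.scale(D_loss).backward()
    scaler.step(optimizer_D)

    optimizer_G.zero_grad()
    with autocast(enabled=use_amp):
        G_loss = g_loss_fn()
    scaler.scale(G_loss).backward()
    scaler.step(optimizer_G)

    # One update() per iteration, after every optimizer has stepped.
    scaler.update()
    return D_loss.item(), G_loss.item()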
Example #24
0
    scaler = GradScaler()

    for epoch in range(start_epoch, num_epochs):
        # Train
        model.train()
        runningLoss = 0.0
        runningLossCounter = 0.0
        train_loss = 0.0
        print('Epoch ' + str(epoch) + ': Train')
        with tqdm(total=len(train_loader)) as pbar:
            for i, data in enumerate(train_loader):
                img = data['img']['data'].squeeze(-1)  # * 2 - 1 #For 2D cases
                images = Variable(img).to(device)
                optimizer.zero_grad()
                with autocast():
                    # VAE Part
                    loss_vae = 0
                    if ce_factor < 1:
                        x_r, _, z_dist = model(images)

                        kl_loss = 0
                        if beta > 0:
                            if IsVAE:
                                kl_loss = kl_loss_fn(
                                    z_dist, sumdim=(1,)) * beta
                            else:
                                sys.exit("KLD Not gonna work")
                                kl_loss = kl_loss_fn(
                                    z_dist, sumdim=(1, 2)) * beta
                        if model.d == 3:
Example #25
0
def train_step(
    model: nn.Module,
    train_loader,
    criterion,
    device: str,
    optimizer,
    scheduler=None,
    num_batches: int = None,
    log_interval: int = 100,
    scaler=None,
):
    """
    Performs one step of training: forward pass, loss computation, backward pass, and returns metrics.
    Args:
        model : PyTorch Detr Model.
        train_loader : Train loader.
        device : "cuda" or "cpu"
        criterion : Detr Loss function to be optimized.
        optimizer : Torch optimizer to train.
        scheduler : Learning rate scheduler.
        num_batches : (optional) Integer to limit training to a certain number of batches.
        log_interval : (optional) Default 100. Integer; logs progress after the specified number of batches.
        scaler: (optional) Pass torch.cuda.amp.GradScaler() for fp16 mixed-precision training.
    """

    model = model.to(device)
    criterion = criterion.to(device)
    start_train_step = time.time()
    model.train()
    last_idx = len(train_loader) - 1
    batch_time_m = utils.AverageMeter()
    criterion.train()
    cnt = 0
    batch_start = time.time()
    metrics = OrderedDict()

    total_loss = utils.AverageMeter()
    bbox_loss = utils.AverageMeter()
    giou_loss = utils.AverageMeter()
    labels_loss = utils.AverageMeter()

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        last_batch = batch_idx == last_idx
        images = list(image.to(device) for image in inputs)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()

        if scaler is not None:
            with amp.autocast():
                outputs = model(images)
                loss_dict = criterion(outputs, targets)
                weight_dict = criterion.weight_dict
                loss = sum(loss_dict[k] * weight_dict[k]
                           for k in loss_dict.keys() if k in weight_dict)
            # Backward pass, optimizer step and scaler update run outside the autocast context
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        else:
            outputs = model(images)
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                       if k in weight_dict)
            loss.backward()
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        cnt += 1
        total_loss.update(loss.item())
        bbox_loss.update(loss_dict["loss_bbox"].item())
        giou_loss.update(loss_dict["loss_giou"].item())
        labels_loss.update(loss_dict["loss_ce"].item())

        batch_time_m.update(time.time() - batch_start)
        batch_start = time.time()

        if last_batch or batch_idx % log_interval == 0:  # If we reach the log interval
            print(
                "Batch Train Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  "
                .format(batch_time=batch_time_m, ))

        if num_batches is not None:
            if cnt >= num_batches:
                end_train_step = time.time()
                metrics["total_loss"] = total_loss.avg
                metrics["bbox_loss"] = bbox_loss.avg
                metrics["giou_loss"] = giou_loss.avg
                metrics["labels_loss"] = labels_loss.avg

                print(f"Done till {num_batches} train batches")
                print(
                    f"Time taken for Training step = {end_train_step - start_train_step} sec"
                )
                return metrics

    end_train_step = time.time()
    metrics["total_loss"] = total_loss.avg
    metrics["bbox_loss"] = bbox_loss.avg
    metrics["giou_loss"] = giou_loss.avg
    metrics["labels_loss"] = labels_loss.avg
    print(
        f"Time taken for Training step = {end_train_step - start_train_step} sec"
    )
    return metrics
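
As the docstring notes, mixed precision is opted into by passing a GradScaler. A hedged usage sketch; model, criterion and train_loader are placeholders assumed to be built elsewhere (e.g. a DETR model and its matching loss):

import torch
from torch.cuda.amp import GradScaler


def run_one_epoch(model, criterion, train_loader, device="cuda"):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = GradScaler() if device == "cuda" else None  # fp16 only makes sense on GPU
    return train_step(
        model,
        train_loader,
        criterion,
        device,
        optimizer,
        num_batches=50,   # limit the pass for a quick smoke test
        log_interval=10,
        scaler=scaler,
    )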
Example #26
0
    def fit_batch(self, batch):
        """Trains one batch"""
        # Unpacking batch list
        mixture = batch.mix_sig
        targets = [batch.s1_sig, batch.s2_sig, batch.s3_sig, batch.s4_sig]

        if self.hparams.auto_mix_prec:
            with autocast():
                predictions, targets = self.compute_forward(
                    mixture, targets, sb.Stage.TRAIN
                )
                loss = self.compute_objectives(predictions, targets)

                # hard-threshold away the easy data items (keep only losses above th)
                if self.hparams.threshold_byloss:
                    th = self.hparams.threshold
                    loss_to_keep = loss[loss > th]
                    if loss_to_keep.nelement() > 0:
                        loss = loss_to_keep.mean()
                else:
                    loss = loss.mean()

            if (
                loss < self.hparams.loss_upper_lim and loss.nelement() > 0
            ):  # the fix for computational problems
                self.scaler.scale(loss).backward()
                if self.hparams.clip_grad_norm >= 0:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.modules.parameters(), self.hparams.clip_grad_norm,
                    )
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.nonfinite_count += 1
                logger.info(
                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
                        self.nonfinite_count
                    )
                )
                loss.data = torch.tensor(0).to(self.device)
        else:
            predictions, targets = self.compute_forward(
                mixture, targets, sb.Stage.TRAIN
            )
            loss = self.compute_objectives(predictions, targets)

            if self.hparams.threshold_byloss:
                th = self.hparams.threshold
                loss_to_keep = loss[loss > th]
                if loss_to_keep.nelement() > 0:
                    loss = loss_to_keep.mean()
            else:
                loss = loss.mean()

            if (
                loss < self.hparams.loss_upper_lim and loss.nelement() > 0
            ):  # the fix for computational problems
                loss.backward()
                if self.hparams.clip_grad_norm >= 0:
                    torch.nn.utils.clip_grad_norm_(
                        self.modules.parameters(), self.hparams.clip_grad_norm
                    )
                self.optimizer.step()
            else:
                self.nonfinite_count += 1
                logger.info(
                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
                        self.nonfinite_count
                    )
                )
                loss.data = torch.tensor(0).to(self.device)
        self.optimizer.zero_grad()

        return loss.detach().cpu()
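
The threshold_byloss branch above is a simple hard-example filter: per-item losses are kept only if they exceed a threshold, and the mean is taken over the survivors. A minimal standalone sketch of that reduction (here falling back to the full mean when nothing survives):

import torch


def thresholded_mean_loss(per_item_loss: torch.Tensor, threshold: float) -> torch.Tensor:
    # Keep only the "hard" items whose loss exceeds the threshold; if every
    # item is below it, fall back to averaging all of them.
    hard = per_item_loss[per_item_loss > threshold]
    if hard.nelement() > 0:
        return hard.mean()
    return per_item_loss.mean()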
    def train(engine, mini_batch):
        for language_model, model, optimizer in zip(engine.language_models,
                                                    engine.models,
                                                    engine.optimizers):
            language_model.eval()
            model.train()
            if engine.state.iteration % engine.config.iteration_per_update == 1 or \
                engine.config.iteration_per_update == 1:
                if engine.state.iteration > 1:
                    optimizer.zero_grad()

        device = next(engine.models[0].parameters()).device
        mini_batch.src = (mini_batch.src[0].to(device),
                          mini_batch.src[1].to(device))
        mini_batch.tgt = (mini_batch.tgt[0].to(device),
                          mini_batch.tgt[1].to(device))

        with autocast(not engine.config.off_autocast):
            # X2Y
            x, y = (mini_batch.src[0][:, 1:-1],
                    mini_batch.src[1] - 2), mini_batch.tgt[0][:, :-1]
            x_hat_lm, y_hat_lm = None, None
            # |x| = (batch_size, n)
            # |y| = (batch_size, m)
            y_hat = engine.models[X2Y](x, y)
            # |y_hat| = (batch_size, m, y_vocab_size)

            if engine.state.epoch > engine.config.dsl_n_warmup_epochs:
                with torch.no_grad():
                    y_hat_lm = engine.language_models[X2Y](y)
                    # |y_hat_lm| = |y_hat|

            # Y2X
            # Since encoder in seq2seq takes packed_sequence instance,
            # we need to re-sort if we use reversed src and tgt.
            x, y, restore_indice = DualSupervisedTrainingEngine._reorder(
                mini_batch.src[0][:, :-1],
                mini_batch.tgt[0][:, 1:-1],
                mini_batch.tgt[1] - 2,
            )
            # |x| = (batch_size, n)
            # |y| = (batch_size, m)
            x_hat = DualSupervisedTrainingEngine._restore_order(
                engine.models[Y2X](y, x),
                restore_indice=restore_indice,
            )
            # |x_hat| = (batch_size, n, x_vocab_size)

            if engine.state.epoch > engine.config.dsl_n_warmup_epochs:
                with torch.no_grad():
                    x_hat_lm = DualSupervisedTrainingEngine._restore_order(
                        engine.language_models[Y2X](x),
                        restore_indice=restore_indice,
                    )
                    # |x_hat_lm| = |x_hat|

            x, y = mini_batch.src[0][:, 1:], mini_batch.tgt[0][:, 1:]
            loss_x2y, loss_y2x, dual_loss = DualSupervisedTrainingEngine._get_loss(
                x,
                y,
                x_hat,
                y_hat,
                engine.crits,
                x_hat_lm,
                y_hat_lm,
                # According to the paper, DSL should be warm-started.
                # Thus, we turn off the regularization at the beginning.
                lagrange=engine.config.dsl_lambda if
                engine.state.epoch > engine.config.dsl_n_warmup_epochs else .0)

            backward_targets = [
                loss_x2y.div(y.size(0)).div(
                    engine.config.iteration_per_update),
                loss_y2x.div(x.size(0)).div(
                    engine.config.iteration_per_update),
            ]

        for scaler, backward_target in zip(engine.scalers, backward_targets):
            if engine.config.gpu_id >= 0 and not engine.config.off_autocast:
                scaler.scale(backward_target).backward()
            else:
                backward_target.backward()

        x_word_count = int(mini_batch.src[1].sum())
        y_word_count = int(mini_batch.tgt[1].sum())
        p_norm = float(
            get_parameter_norm(
                list(engine.models[X2Y].parameters()) +
                list(engine.models[Y2X].parameters())))
        g_norm = float(
            get_grad_norm(
                list(engine.models[X2Y].parameters()) +
                list(engine.models[Y2X].parameters())))

        if engine.state.iteration % engine.config.iteration_per_update == 0 and \
            engine.state.iteration > 0:
            for model, optimizer, scaler in zip(engine.models,
                                                engine.optimizers,
                                                engine.scalers):
                torch_utils.clip_grad_norm_(
                    model.parameters(),
                    engine.config.max_grad_norm,
                )
                # Take a step of gradient descent.
                if engine.config.gpu_id >= 0 and not engine.config.off_autocast:
                    # Use scaler instead of engine.optimizer.step()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

        return {
            'x2y':
            float(loss_x2y / y_word_count),
            'y2x':
            float(loss_y2x / x_word_count),
            'reg':
            float(dual_loss / x.size(0)),
            '|param|':
            p_norm if not np.isnan(p_norm) and not np.isinf(p_norm) else 0.,
            '|g_param|':
            g_norm if not np.isnan(g_norm) and not np.isinf(g_norm) else 0.,
        }
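
The engine above spreads one optimizer update across iteration_per_update mini-batches: gradients are zeroed at the start of each accumulation window, each backward pass divides its loss by the window size, and clipping plus the optimizer (or scaler) step happen only on the closing iteration. A compact sketch of that schedule with a placeholder compute_loss callable:

import torch


def accumulate_and_step(model, optimizer, batches, compute_loss,
                        iteration_per_update=4, max_grad_norm=1.0):
    # One optimizer step per `iteration_per_update` batches.
    for i, batch in enumerate(batches, start=1):
        if i % iteration_per_update == 1 or iteration_per_update == 1:
            optimizer.zero_grad()

        loss = compute_loss(model, batch) / iteration_per_update
        loss.backward()

        if i % iteration_per_update == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()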
def train_fn(train_loader, teacher_model, model, criterion, optimizer, epoch,
             scheduler, device):
    if CFG.device == 'GPU':
        scaler = GradScaler()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    global_step = 0
    for step, (images, images_annot, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        with torch.no_grad():
            teacher_features, _, _ = teacher_model(images_annot.to(device))
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        if CFG.device == 'GPU':
            with autocast():
                features, _, y_preds = model(images)
                loss = criterion(teacher_features, features, y_preds, labels)
                # record loss
                losses.update(loss.item(), batch_size)
                if CFG.gradient_accumulation_steps > 1:
                    loss = loss / CFG.gradient_accumulation_steps
                scaler.scale(loss).backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), CFG.max_grad_norm)
                if (step + 1) % CFG.gradient_accumulation_steps == 0:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    global_step += 1
        elif CFG.device == 'TPU':
            features, _, y_preds = model(images)
            loss = criterion(teacher_features, features, y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.gradient_accumulation_steps > 1:
                loss = loss / CFG.gradient_accumulation_steps
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       CFG.max_grad_norm)
            if (step + 1) % CFG.gradient_accumulation_steps == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()
                global_step += 1
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if CFG.device == 'GPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
                print('Epoch: [{0}][{1}/{2}] '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                      'Elapsed {remain:s} '
                      'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                      'Grad: {grad_norm:.4f}  '
                      #'LR: {lr:.6f}  '
                      .format(
                       epoch+1, step, len(train_loader), batch_time=batch_time,
                       data_time=data_time, loss=losses,
                       remain=timeSince(start, float(step+1)/len(train_loader)),
                       grad_norm=grad_norm,
                       #lr=scheduler.get_lr()[0],
                       ))
        elif CFG.device == 'TPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
                xm.master_print('Epoch: [{0}][{1}/{2}] '
                                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                                'Elapsed {remain:s} '
                                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                                'Grad: {grad_norm:.4f}  '
                                #'LR: {lr:.6f}  '
                                .format(
                                epoch+1, step, len(train_loader), batch_time=batch_time,
                                data_time=data_time, loss=losses,
                                remain=timeSince(start, float(step+1)/len(train_loader)),
                                grad_norm=grad_norm,
                                #lr=scheduler.get_lr()[0],
                                ))
    return losses.avg
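
In the GPU branch above, clip_grad_norm_ runs while the gradients are still multiplied by the loss scale, so both the clipping threshold and the reported grad_norm are in scaled units; PyTorch's documented recipe is to call scaler.unscale_(optimizer) first and to clip only on iterations where the optimizer actually steps. A hedged sketch of that recommended ordering, combined with gradient accumulation as in the loop above:

import torch
from torch.cuda.amp import GradScaler, autocast


def amp_accum_step(model, criterion, batch, optimizer, scaler: GradScaler,
                   step_now: bool, accumulation_steps: int = 1, max_grad_norm: float = 1.0):
    images, labels = batch
    with autocast():
        preds = model(images)
        loss = criterion(preds, labels) / accumulation_steps
    scaler.scale(loss).backward()

    grad_norm = None
    if step_now:
        # Unscale first so the norm is measured on the true gradients.
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    return loss.item() * accumulation_steps, grad_norm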
Example #29
0
    def train(
        self,
        train_dataset,
        output_dir,
        show_running_loss=True,
        eval_data=None,
        verbose=True,
        **kwargs,
    ):
        """
        Trains the model on train_dataset.

        Utility function to be used by the train_model() method. Not intended to be used directly.
        """

        model = self.model
        args = self.args
        device = self.device

        tb_writer = SummaryWriter(logdir=args.tensorboard_dir)
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(
            train_dataset,
            sampler=train_sampler,
            batch_size=args.train_batch_size,
            num_workers=self.args.dataloader_num_workers,
        )

        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps // (
                len(train_dataloader) // args.gradient_accumulation_steps) + 1
        else:
            t_total = len(
                train_dataloader
            ) // args.gradient_accumulation_steps * args.num_train_epochs

        no_decay = ["bias", "LayerNorm.weight"]

        optimizer_grouped_parameters = []
        custom_parameter_names = set()
        for group in self.args.custom_parameter_groups:
            params = group.pop("params")
            custom_parameter_names.update(params)
            param_group = {**group}
            param_group["params"] = [
                p for n, p in model.named_parameters() if n in params
            ]
            optimizer_grouped_parameters.append(param_group)

        for group in self.args.custom_layer_parameters:
            layer_number = group.pop("layer")
            layer = f"layer.{layer_number}."
            group_d = {**group}
            group_nd = {**group}
            group_nd["weight_decay"] = 0.0
            params_d = []
            params_nd = []
            for n, p in model.named_parameters():
                if n not in custom_parameter_names and layer in n:
                    if any(nd in n for nd in no_decay):
                        params_nd.append(p)
                    else:
                        params_d.append(p)
                    custom_parameter_names.add(n)
            group_d["params"] = params_d
            group_nd["params"] = params_nd

            optimizer_grouped_parameters.append(group_d)
            optimizer_grouped_parameters.append(group_nd)

        if not self.args.train_custom_parameters_only:
            optimizer_grouped_parameters.extend([
                {
                    "params": [
                        p for n, p in model.named_parameters()
                        if n not in custom_parameter_names and not any(
                            nd in n for nd in no_decay)
                    ],
                    "weight_decay":
                    args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in model.named_parameters()
                        if n not in custom_parameter_names and any(
                            nd in n for nd in no_decay)
                    ],
                    "weight_decay":
                    0.0,
                },
            ])

        warmup_steps = math.ceil(t_total * args.warmup_ratio)
        args.warmup_steps = warmup_steps if args.warmup_steps == 0 else args.warmup_steps

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)

        if (args.model_name and os.path.isfile(
                os.path.join(args.model_name, "optimizer.pt"))
                and os.path.isfile(
                    os.path.join(args.model_name, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(args.model_name, "optimizer.pt")))
            scheduler.load_state_dict(
                torch.load(os.path.join(args.model_name, "scheduler.pt")))

        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        logger.info(" Training started")

        global_step = 0
        training_progress_scores = None
        tr_loss, logging_loss = 0.0, 0.0
        model.zero_grad()
        train_iterator = trange(int(args.num_train_epochs),
                                desc="Epoch",
                                disable=args.silent,
                                mininterval=0)
        epoch_number = 0
        best_eval_metric = None
        early_stopping_counter = 0
        steps_trained_in_current_epoch = 0
        epochs_trained = 0

        if args.model_name and os.path.exists(args.model_name):
            try:
                # set global_step to the global_step of the last saved checkpoint from the model path
                checkpoint_suffix = args.model_name.split("/")[-1].split("-")
                if len(checkpoint_suffix) > 2:
                    checkpoint_suffix = checkpoint_suffix[1]
                else:
                    checkpoint_suffix = checkpoint_suffix[-1]
                global_step = int(checkpoint_suffix)
                epochs_trained = global_step // (
                    len(train_dataloader) // args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = global_step % (
                    len(train_dataloader) // args.gradient_accumulation_steps)

                logger.info(
                    "   Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("   Continuing training from epoch %d",
                            epochs_trained)
                logger.info("   Continuing training from global step %d",
                            global_step)
                logger.info(
                    "   Will skip the first %d steps in the current epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                logger.info("   Starting fine-tuning.")

        if args.evaluate_during_training:
            training_progress_scores = self._create_training_progress_scores(
                **kwargs)

        if args.wandb_project:
            wandb.init(project=args.wandb_project,
                       config={**asdict(args)},
                       **args.wandb_kwargs)
            wandb.watch(self.model)

        if args.fp16:
            from torch.cuda import amp

            scaler = amp.GradScaler()
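            # GradScaler multiplies the fp16 loss by a dynamic factor before
            # backward() so small gradients do not underflow, and divides it back
            # out (unscale) before the optimizer step.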

        for current_epoch in train_iterator:
            model.train()
            if epochs_trained > 0:
                epochs_trained -= 1
                continue
            train_iterator.set_description(
                f"Epoch {epoch_number + 1} of {args.num_train_epochs}")
            batch_iterator = tqdm(
                train_dataloader,
                desc=f"Running Epoch {epoch_number} of {args.num_train_epochs}",
                disable=args.silent,
                mininterval=0,
            )
            for step, batch in enumerate(batch_iterator):
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue
                batch = tuple(t.to(device) for t in batch)

                inputs = self._get_inputs_dict(batch)
                if args.fp16:
                    with amp.autocast():
                        outputs = model(**inputs)
                        # model outputs are always tuple in pytorch-transformers (see doc)
                        loss = outputs[0]
                else:
                    outputs = model(**inputs)
                    # model outputs are always tuple in pytorch-transformers (see doc)
                    loss = outputs[0]

                if args.n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu parallel training

                current_loss = loss.item()

                if show_running_loss:
                    batch_iterator.set_description(
                        f"Epochs {epoch_number}/{args.num_train_epochs}. Running Loss: {current_loss:9.4f}"
                    )

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
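                    # Dividing here makes the gradients summed over
                    # gradient_accumulation_steps micro-batches equivalent to one
                    # backward pass over the combined batch.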

                if args.fp16:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        scaler.unscale_(optimizer)
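                        # Gradients must be unscaled in place before clipping so
                        # max_grad_norm is compared against true gradient norms.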
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                    if args.fp16:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        # Log metrics
                        tb_writer.add_scalar("lr",
                                             scheduler.get_last_lr()[0],
                                             global_step)
                        tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                             args.logging_steps, global_step)
                        logging_loss = tr_loss
                        if args.wandb_project:
                            wandb.log({
                                "Training loss": current_loss,
                                "lr": scheduler.get_last_lr()[0],
                                "global_step": global_step,
                            })

                    if args.save_steps > 0 and global_step % args.save_steps == 0:
                        # Save model checkpoint
                        output_dir_current = os.path.join(
                            output_dir, "checkpoint-{}".format(global_step))

                        self.save_model(output_dir_current,
                                        optimizer,
                                        scheduler,
                                        model=model)

                    if args.evaluate_during_training and (
                            args.evaluate_during_training_steps > 0
                            and global_step %
                            args.evaluate_during_training_steps == 0):
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = self.eval_model(
                            eval_data,
                            verbose=verbose
                            and args.evaluate_during_training_verbose,
                            silent=args.evaluate_during_training_silent,
                            **kwargs,
                        )
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)

                        output_dir_current = os.path.join(
                            output_dir, "checkpoint-{}".format(global_step))

                        if args.save_eval_checkpoints:
                            self.save_model(output_dir_current,
                                            optimizer,
                                            scheduler,
                                            model=model,
                                            results=results)

                        training_progress_scores["global_step"].append(
                            global_step)
                        training_progress_scores["train_loss"].append(
                            current_loss)
                        for key in results:
                            training_progress_scores[key].append(results[key])
                        report = pd.DataFrame(training_progress_scores)
                        report.to_csv(
                            os.path.join(args.output_dir,
                                         "training_progress_scores.csv"),
                            index=False,
                        )

                        if args.wandb_project:
                            wandb.log(
                                self._get_last_metrics(
                                    training_progress_scores))

                        if not best_eval_metric:
                            best_eval_metric = results[
                                args.early_stopping_metric]
                            self.save_model(args.best_model_dir,
                                            optimizer,
                                            scheduler,
                                            model=model,
                                            results=results)
                        if best_eval_metric and args.early_stopping_metric_minimize:
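                            # args.early_stopping_metric_minimize: lower metric values
                            # count as improvements here; the else branch below mirrors
                            # this logic for maximized metrics.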
                            if results[args.early_stopping_metric] - best_eval_metric < args.early_stopping_delta:
                                best_eval_metric = results[
                                    args.early_stopping_metric]
                                self.save_model(args.best_model_dir,
                                                optimizer,
                                                scheduler,
                                                model=model,
                                                results=results)
                                early_stopping_counter = 0
                            else:
                                if args.use_early_stopping:
                                    if early_stopping_counter < args.early_stopping_patience:
                                        early_stopping_counter += 1
                                        if verbose:
                                            logger.info(
                                                f" No improvement in {args.early_stopping_metric}"
                                            )
                                            logger.info(
                                                f" Current step: {early_stopping_counter}"
                                            )
                                            logger.info(
                                                f" Early stopping patience: {args.early_stopping_patience}"
                                            )
                                    else:
                                        if verbose:
                                            logger.info(
                                                f" Patience of {args.early_stopping_patience} steps reached"
                                            )
                                            logger.info(
                                                " Training terminated.")
                                            train_iterator.close()
                                        return (
                                            global_step,
                                            tr_loss / global_step
                                            if not self.args.evaluate_during_training
                                            else training_progress_scores,
                                        )
                        else:
                            if results[args.early_stopping_metric] - best_eval_metric > args.early_stopping_delta:
                                best_eval_metric = results[
                                    args.early_stopping_metric]
                                self.save_model(args.best_model_dir,
                                                optimizer,
                                                scheduler,
                                                model=model,
                                                results=results)
                                early_stopping_counter = 0
                            else:
                                if args.use_early_stopping:
                                    if early_stopping_counter < args.early_stopping_patience:
                                        early_stopping_counter += 1
                                        if verbose:
                                            logger.info(
                                                f" No improvement in {args.early_stopping_metric}"
                                            )
                                            logger.info(
                                                f" Current step: {early_stopping_counter}"
                                            )
                                            logger.info(
                                                f" Early stopping patience: {args.early_stopping_patience}"
                                            )
                                    else:
                                        if verbose:
                                            logger.info(
                                                f" Patience of {args.early_stopping_patience} steps reached"
                                            )
                                            logger.info(
                                                " Training terminated.")
                                            train_iterator.close()
                                        return (
                                            global_step,
                                            tr_loss / global_step
                                            if not self.args.evaluate_during_training
                                            else training_progress_scores,
                                        )

            epoch_number += 1
            output_dir_current = os.path.join(
                output_dir,
                "checkpoint-{}-epoch-{}".format(global_step, epoch_number))

            if args.save_model_every_epoch or args.evaluate_during_training:
                os.makedirs(output_dir_current, exist_ok=True)

            if args.save_model_every_epoch:
                self.save_model(output_dir_current,
                                optimizer,
                                scheduler,
                                model=model)

            if args.evaluate_during_training and args.evaluate_each_epoch:
                results = self.eval_model(
                    eval_data,
                    verbose=verbose and args.evaluate_during_training_verbose,
                    silent=args.evaluate_during_training_silent,
                    **kwargs,
                )

                self.save_model(output_dir_current,
                                optimizer,
                                scheduler,
                                results=results)

                training_progress_scores["global_step"].append(global_step)
                training_progress_scores["train_loss"].append(current_loss)
                for key in results:
                    training_progress_scores[key].append(results[key])
                report = pd.DataFrame(training_progress_scores)
                report.to_csv(os.path.join(args.output_dir,
                                           "training_progress_scores.csv"),
                              index=False)

                if args.wandb_project:
                    wandb.log(self._get_last_metrics(training_progress_scores))

                if not best_eval_metric:
                    best_eval_metric = results[args.early_stopping_metric]
                    self.save_model(args.best_model_dir,
                                    optimizer,
                                    scheduler,
                                    model=model,
                                    results=results)
                if best_eval_metric and args.early_stopping_metric_minimize:
                    if results[args.early_stopping_metric] - best_eval_metric < args.early_stopping_delta:
                        best_eval_metric = results[args.early_stopping_metric]
                        self.save_model(args.best_model_dir,
                                        optimizer,
                                        scheduler,
                                        model=model,
                                        results=results)
                        early_stopping_counter = 0
                    else:
                        if args.use_early_stopping and args.early_stopping_consider_epochs:
                            if early_stopping_counter < args.early_stopping_patience:
                                early_stopping_counter += 1
                                if verbose:
                                    logger.info(
                                        f" No improvement in {args.early_stopping_metric}"
                                    )
                                    logger.info(
                                        f" Current step: {early_stopping_counter}"
                                    )
                                    logger.info(
                                        f" Early stopping patience: {args.early_stopping_patience}"
                                    )
                            else:
                                if verbose:
                                    logger.info(
                                        f" Patience of {args.early_stopping_patience} steps reached"
                                    )
                                    logger.info(" Training terminated.")
                                    train_iterator.close()
                                return (
                                    global_step,
                                    tr_loss / global_step
                                    if not self.args.evaluate_during_training
                                    else training_progress_scores,
                                )
                else:
                    if results[args.early_stopping_metric] - best_eval_metric > args.early_stopping_delta:
                        best_eval_metric = results[args.early_stopping_metric]
                        self.save_model(args.best_model_dir,
                                        optimizer,
                                        scheduler,
                                        model=model,
                                        results=results)
                        early_stopping_counter = 0
                    else:
                        if args.use_early_stopping and args.early_stopping_consider_epochs:
                            if early_stopping_counter < args.early_stopping_patience:
                                early_stopping_counter += 1
                                if verbose:
                                    logger.info(
                                        f" No improvement in {args.early_stopping_metric}"
                                    )
                                    logger.info(
                                        f" Current step: {early_stopping_counter}"
                                    )
                                    logger.info(
                                        f" Early stopping patience: {args.early_stopping_patience}"
                                    )
                            else:
                                if verbose:
                                    logger.info(
                                        f" Patience of {args.early_stopping_patience} steps reached"
                                    )
                                    logger.info(" Training terminated.")
                                    train_iterator.close()
                                return (
                                    global_step,
                                    tr_loss / global_step
                                    if not self.args.evaluate_during_training
                                    else training_progress_scores,
                                )

        return (
            global_step,
            tr_loss / global_step if not self.args.evaluate_during_training
            else training_progress_scores,
        )


def train(hyp, opt, device, tb_writer=None, wandb=None):
    logger.info(
        colorstr("hyperparameters: ") + ", ".join(f"{k}={v}"
                                                  for k, v in hyp.items()))
    save_dir, epochs, batch_size, total_batch_size, weights, rank = (
        Path(opt.save_dir),
        opt.epochs,
        opt.batch_size,
        opt.total_batch_size,
        opt.weights,
        opt.global_rank,
    )

    # Directories
    wdir = save_dir / "weights"
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / "last.pt"
    best = wdir / "best.pt"
    results_file = save_dir / "results.txt"

    # Save run settings
    with open(save_dir / "hyp.yaml", "w") as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / "opt.yaml", "w") as f:
        # yaml.dump(vars(opt), f, sort_keys=False)  # opt run parameters
        yaml.dump(str(opt), f, sort_keys=False)

    # Configure
    plots = not opt.evolve  # create plots
    cuda = device.type != "cpu"
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)  # check
    train_path = data_dict["train"]
    test_path = data_dict["val"]
    nc = 1 if opt.single_cls else int(data_dict["nc"])  # number of classes
    names = (["item"] if opt.single_cls and len(data_dict["names"]) != 1 else
             data_dict["names"])  # class names
    assert len(names) == nc, "%g names found for nc=%g dataset in %s" % (
        len(names),
        nc,
        opt.data,
    )  # check

    # Model
    pretrained = weights.endswith(".pt")
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        if hyp.get("anchors"):
            ckpt["model"].yaml["anchors"] = round(
                hyp["anchors"])  # force autoanchor
        model = Model(opt.cfg or ckpt["model"].yaml, ch=3,
                      nc=nc).to(device)  # create
        exclude = ["anchor"] if opt.cfg or hyp.get("anchors") else []  # exclude keys
        state_dict = ckpt["model"].float().state_dict()  # to FP32
        state_dict = intersect_dicts(state_dict,
                                     model.state_dict(),
                                     exclude=exclude)  # intersect
        model.load_state_dict(state_dict, strict=False)  # load
        logger.info(
            "Transferred %g/%g items from %s" %
            (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Model(opt.cfg, ch=3, nc=nc).to(device)  # create

    # Freeze
    freeze = []  # parameter names to freeze (full or partial)
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            print("freezing %s" % k)
            v.requires_grad = False

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size),
                     1)  # accumulate loss before optimizing
    hyp["weight_decay"] *= total_batch_size * accumulate / nbs  # scale weight_decay
    logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
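    # Example (hypothetical total_batch_size=16): accumulate = round(64 / 16) = 4,
    # so the effective batch is 64 images and weight_decay is multiplied by
    # 16 * 4 / 64 = 1.0 (i.e. unchanged at the nominal batch size).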

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
            pg1.append(v.weight)  # apply decay

    if opt.adam:
        optimizer = optim.Adam(pg0,
                               lr=hyp["lr0"],
                               betas=(hyp["momentum"],
                                      0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0,
                              lr=hyp["lr0"],
                              momentum=hyp["momentum"],
                              nesterov=True)

    optimizer.add_param_group({
        "params": pg1,
        "weight_decay": hyp["weight_decay"]
    })  # add pg1 with weight_decay
    optimizer.add_param_group({"params": pg2})  # add pg2 (biases)
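    # Only pg1 (conv/linear weights) receives weight decay; pg0 (BatchNorm weights)
    # and pg2 (biases) are deliberately left undecayed.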
    logger.info("Optimizer groups: %g .bias, %g conv.weight, %g other" %
                (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    if opt.linear_lr:
        lf = (lambda x: (1 - x / (epochs - 1)) *
              (1.0 - hyp["lrf"]) + hyp["lrf"])  # linear
    else:
        lf = one_cycle(1, hyp["lrf"], epochs)  # cosine 1->hyp['lrf']
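    # lf(epoch) returns a multiplier that moves from 1.0 down to hyp['lrf'];
    # LambdaLR applies it to each param group's initial_lr on every scheduler.step().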
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # plot_lr_scheduler(optimizer, scheduler, epochs)

    # Logging
    if rank in [-1, 0] and wandb and wandb.run is None:
        opt.hyp = hyp  # add hyperparameters
        wandb_run = wandb.init(
            config=opt,
            resume="allow",
            project="YOLOv5"
            if opt.project == "runs/train" else Path(opt.project).stem,
            name=save_dir.stem,
            id=ckpt.get("wandb_id") if "ckpt" in locals() else None,
        )
    loggers = {"wandb": wandb}  # loggers dict

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt["optimizer"] is not None:
            optimizer.load_state_dict(ckpt["optimizer"])
            best_fitness = ckpt["best_fitness"]

        # Results
        if ckpt.get("training_results") is not None:
            with open(results_file, "w") as file:
                file.write(ckpt["training_results"])  # write results.txt

        # Epochs
        start_epoch = ckpt["epoch"] + 1
        if opt.resume:
            assert (
                start_epoch > 0
            ), "%s training to %g epochs is finished, nothing to resume." % (
                weights,
                epochs,
            )
        if epochs < start_epoch:
            logger.info(
                "%s has been trained for %g epochs. Fine-tuning for %g additional epochs."
                % (weights, ckpt["epoch"], epochs))
            epochs += ckpt["epoch"]  # finetune additional epochs

        del ckpt, state_dict

    # Image sizes
    gs = int(model.stride.max())  # grid size (max stride)
    nl = model.model[-1].nl  # number of detection layers (used for scaling hyp['obj'])
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info("Using SyncBatchNorm()")

    # EMA
    ema = ModelEMA(model) if rank in [-1, 0] else None
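    # ModelEMA keeps an exponential moving average of the weights; the averaged
    # copy (ema.ema) is what test.test() evaluates and what gets checkpointed below.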

    # DDP mode
    if cuda and rank != -1:
        model = DDP(model,
                    device_ids=[opt.local_rank],
                    output_device=opt.local_rank)

    # Trainloader
    dataloader, dataset = create_dataloader(
        train_path,
        imgsz,
        batch_size,
        gs,
        opt,
        hyp=hyp,
        augment=True,
        cache=opt.cache_images,
        rect=opt.rect,
        rank=rank,
        world_size=opt.world_size,
        workers=opt.workers,
        image_weights=opt.image_weights,
        quad=opt.quad,
        prefix=colorstr("train: "),
    )
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert (
        mlc < nc
    ), "Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g" % (
        mlc,
        nc,
        opt.data,
        nc - 1,
    )

    # Process 0
    if rank in [-1, 0]:
        ema.updates = start_epoch * nb // accumulate  # set EMA updates
        testloader = create_dataloader(
            test_path,
            imgsz_test,
            batch_size * 2,
            gs,
            opt,  # testloader
            hyp=hyp,
            cache=opt.cache_images and not opt.notest,
            rect=True,
            rank=-1,
            world_size=opt.world_size,
            workers=opt.workers,
            pad=0.5,
            prefix=colorstr("val: "),
        )[0]

        if not opt.resume:
            labels = np.concatenate(dataset.labels, 0)
            c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                plot_labels(labels, save_dir, loggers)
                if tb_writer:
                    tb_writer.add_histogram("classes", c, 0)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset,
                              model=model,
                              thr=hyp["anchor_t"],
                              imgsz=imgsz)

    # Model parameters
    hyp["box"] *= 3.0 / nl  # scale to layers
    hyp["cls"] *= nc / 80.0 * 3.0 / nl  # scale to classes and layers
    hyp["obj"] *= (imgsz / 640)**2 * 3.0 / nl  # scale to image size and layers
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = (
        labels_to_class_weights(dataset.labels, nc).to(device) * nc
    )  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp["warmup_epochs"] * nb),
             1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@0.5, mAP@0.5:0.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
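    # With enabled=False (CPU runs) the scaler is a pass-through, so the same
    # scale()/step()/update() calls below work for both AMP and full-precision training.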
    compute_loss = ComputeLoss(model)  # init loss class
    logger.info(f"Image sizes {imgsz} train, {imgsz_test} test\n"
                f"Using {dataloader.num_workers} dataloader workers\n"
                f"Logging results to {save_dir}\n"
                f"Starting training for {epochs} epochs...")
    for epoch in range(
            start_epoch, epochs
    ):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                cw = (model.class_weights.cpu().numpy() * (1 - maps)**2 / nc
                      )  # class weights
                iw = labels_to_image_weights(dataset.labels,
                                             nc=nc,
                                             class_weights=cw)  # image weights
                dataset.indices = random.choices(
                    range(dataset.n), weights=iw,
                    k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = (torch.tensor(dataset.indices)
                           if rank == 0 else torch.zeros(dataset.n)).int()
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        logger.info(
            ("\n" + "%10s" * 8) % ("Epoch", "gpu_mem", "box", "obj", "cls",
                                   "total", "targets", "img_size"))
        if rank in [-1, 0]:
            pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (
                imgs,
                targets,
                paths,
                _,
        ) in (
                pbar
        ):  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = (imgs.to(device, non_blocking=True).float() / 255.0
                    )  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(
                    1,
                    np.interp(ni, xi, [1, nbs / total_batch_size]).round())
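                # During warmup the accumulation count ramps from 1 up to
                # nbs / total_batch_size, so the effective batch size
                # (total_batch_size * accumulate) grows toward the nominal nbs=64.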
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x["lr"] = np.interp(
                        ni,
                        xi,
                        [
                            hyp["warmup_bias_lr"] if j == 2 else 0.0,
                            x["initial_lr"] * lf(epoch),
                        ],
                    )
                    if "momentum" in x:
                        x["momentum"] = np.interp(
                            ni, xi, [hyp["warmup_momentum"], hyp["momentum"]])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5,
                                      imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]
                          ]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs,
                                         size=ns,
                                         mode="bilinear",
                                         align_corners=False)

            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(
                    pred, targets.to(device))  # loss scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.0

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
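                # The optimizer steps only every `accumulate` batches; gradients from
                # the intervening backward passes accumulate in .grad, and scaler.step()
                # unscales them before applying the update.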
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = "%.3gG" % (torch.cuda.memory_reserved() / 1e9
                                 if torch.cuda.is_available() else 0)  # (GB)
                s = ("%10s" * 2 + "%10.4g" * 6) % (
                    "%g/%g" % (epoch, epochs - 1),
                    mem,
                    *mloss,
                    targets.shape[0],
                    imgs.shape[-1],
                )
                pbar.set_description(s)

                # Plot
                if plots and ni < 3:
                    f = save_dir / f"train_batch{ni}.jpg"  # filename
                    Thread(target=plot_images,
                           args=(imgs, targets, paths, f),
                           daemon=True).start()
                    # if tb_writer:
                    #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                    #     tb_writer.add_graph(model, imgs)  # add model to tensorboard
                elif plots and ni == 10 and wandb:
                    wandb.log(
                        {
                            "Mosaics": [
                                wandb.Image(str(x), caption=x.name)
                                for x in save_dir.glob("train*.jpg")
                                if x.exists()
                            ]
                        },
                        commit=False,
                    )

            # end batch ------------------------------------------------------------------------------------------------
        # end epoch ----------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x["lr"] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            if ema:
                ema.update_attr(
                    model,
                    include=[
                        "yaml",
                        "nc",
                        "hyp",
                        "gr",
                        "names",
                        "stride",
                        "class_weights",
                    ],
                )
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                results, maps, times = test.test(
                    opt.data,
                    batch_size=batch_size * 2,
                    imgsz=imgsz_test,
                    model=ema.ema,
                    single_cls=opt.single_cls,
                    dataloader=testloader,
                    save_dir=save_dir,
                    verbose=nc < 50 and final_epoch,
                    plots=plots and final_epoch,
                    log_imgs=opt.log_imgs if wandb else 0,
                    compute_loss=compute_loss,
                )

            # Write
            with open(results_file, "a") as f:
                f.write(
                    s + "%10.4g" * 7 % results +
                    "\n")  # P, R, mAP@0.5, mAP@0.5:0.95, val_loss(box, obj, cls)
            if len(opt.name) and opt.bucket:
                os.system("gsutil cp %s gs://%s/results/results%s.txt" %
                          (results_file, opt.bucket, opt.name))

            # Log
            tags = [
                "train/box_loss",
                "train/obj_loss",
                "train/cls_loss",  # train loss
                "metrics/precision",
                "metrics/recall",
                "metrics/mAP_0.5",
                "metrics/mAP_0.5:0.95",
                "val/box_loss",
                "val/obj_loss",
                "val/cls_loss",  # val loss
                "x/lr0",
                "x/lr1",
                "x/lr2",
            ]  # params
            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                if tb_writer:
                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                if wandb:
                    wandb.log({tag: x}, step=epoch,
                              commit=tag == tags[-1])  # W&B

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@0.5, mAP@0.5:0.95]
            if fi > best_fitness:
                best_fitness = fi

            # Save model
            save = (not opt.nosave) or (final_epoch and not opt.evolve)
            if save:
                with open(results_file, "r") as f:  # create checkpoint
                    ckpt = {
                        "epoch": epoch,
                        "best_fitness": best_fitness,
                        "training_results": f.read(),
                        "model": ema.ema,
                        "optimizer": None if final_epoch else optimizer.state_dict(),
                        "wandb_id": wandb_run.id if wandb else None,
                    }

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    if rank in [-1, 0]:
        # Strip optimizers
        final = best if best.exists() else last  # final model
        for f in [last, best]:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
        if opt.bucket:
            os.system(f"gsutil cp {final} gs://{opt.bucket}/weights")  # upload

        # Plots
        if plots:
            plot_results(save_dir=save_dir)  # save as results.png
            if wandb:
                files = [
                    "results.png",
                    "confusion_matrix.png",
                    *[f"{x}_curve.png" for x in ("F1", "PR", "P", "R")],
                ]
                wandb.log({
                    "Results": [
                        wandb.Image(str(save_dir / f), caption=f)
                        for f in files if (save_dir / f).exists()
                    ]
                })
                if opt.log_artifacts:
                    wandb.log_artifact(artifact_or_path=str(final),
                                       type="model",
                                       name=save_dir.stem)

        # Test best.pt
        logger.info("%g epochs completed in %.3f hours.\n" %
                    (epoch - start_epoch + 1, (time.time() - t0) / 3600))
        if opt.data.endswith("coco.yaml") and nc == 80:  # if COCO
            for conf, iou, save_json in (
                [0.25, 0.45, False],
                [0.001, 0.65, True],
            ):  # speed, mAP tests
                results, _, _ = test.test(
                    opt.data,
                    batch_size=batch_size * 2,
                    imgsz=imgsz_test,
                    conf_thres=conf,
                    iou_thres=iou,
                    model=attempt_load(final, device).half(),
                    single_cls=opt.single_cls,
                    dataloader=testloader,
                    save_dir=save_dir,
                    save_json=save_json,
                    plots=False,
                )

    else:
        dist.destroy_process_group()

    if wandb and wandb.run:
        wandb.run.finish()
    torch.cuda.empty_cache()

    # mlflow
    with mlflow.start_run() as run:
        # Log args into mlflow
        for key, value in hyp.items():
            mlflow.log_param(key, value)

        for key, value in vars(opt).items():
            mlflow.log_param(key, value)

        # Log results into mlflow
        for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
            # Convert x to a plain Python float if it is a torch.Tensor
            if torch.is_tensor(x):
                x = x.item()

            # Replace ':' in the tag name with a space (mlflow does not accept ':' in metric names)
            if ":" in tag:
                tag = re.sub(r":", " ", tag)

            mlflow.log_metric(tag, x)

        # Log model
        mlflow.pytorch.log_model(model, "model")

    return results