Example #1
0
def train(
        net: torch.nn.Module,
        dataloader: DataLoader,
        criterion,  # torch.nn._Loss
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        epochs: int = 10,
) -> Iterable[tuple[torch.nn.Module, float]]:
    """Train ``net`` on ``dataloader``, yielding after every epoch.

    Yields:
        ``(net, last_loss)``: the model (trained in place) and the loss of
        the last batch of the epoch just completed.

    Raises:
        Exception: if the loss becomes NaN (exploding gradients).

    NOTE(review): relies on module-level globals ``TORCH_DEVICE``, ``writer``
    (a TensorBoard-style SummaryWriter) and ``np`` — confirm they are defined
    elsewhere in this module.
    """
    net.to(TORCH_DEVICE, dtype=torch.float)

    # Running count of samples seen; used as the x-axis for the scalar logs.
    i = 0
    for epoch in range(epochs):  # loop over the dataset multiple times
        for j, (inputs, labels) in enumerate(dataloader):
            # get the inputs
            inputs, labels = (
                inputs.to(TORCH_DEVICE, dtype=torch.float),
                labels.to(TORCH_DEVICE, dtype=torch.long)
            )
            i += len(inputs)
            # zero the parameter gradients
            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # NOTE(review): this check runs after optimizer.step(), so a NaN
            # gradient has already been applied by the time it raises.
            if np.isnan(loss.item()):
                raise Exception('gradients blew up')
            writer.add_scalar('Train/Loss', loss, i)
            writer.add_scalar('Train/LR', optimizer.param_groups[0]['lr'], i)
        # One scheduler step per epoch; yield so callers can checkpoint/eval.
        scheduler.step()
        yield net, loss.item()
Example #2
0
def load_checkpoint(
    checkpoint_file: Union[str, Path],
    model: torch.nn.Module,
    optimizer: Union[torch.optim.Optimizer, None] = None,
    scheduler: Union[torch.optim.lr_scheduler._LRScheduler, None] = None,
    map_location: Union[torch.device, str, Mapping[str, str], Callable, None] = None,
) -> None:
    """Shortcut for loading checkpoint state.

    Args:
        checkpoint_file (Union[str, Path]): path to checkpoint
        model (torch.nn.Module): model to initialize with checkpoint weights
        optimizer (torch.optim.Optimizer, optional): optimizer to initialize with checkpoint weights.
            If `None` then will be ignored.
            Default is None.
        scheduler (torch.optim.lr_scheduler._LRScheduler, optional): scheduler to initialize with checkpoint weights.
            If `None` then will be ignored.
            Default is None.
        map_location (Union[torch.device, str, Mapping[str, str], Callable], optional):
            location to use for loading checkpoint content.
            More about possible locations - https://pytorch.org/docs/master/generated/torch.load.html
            Default is None.
    """
    checkpoint = torch.load(str(checkpoint_file), map_location=map_location)
    loaded_items = []

    if "model_state_dict" in checkpoint and model is not None:
        state_dict = checkpoint["model_state_dict"]
        # DataParallel wraps the real model in `.module`; unwrap it so the
        # plain (un-prefixed) state-dict keys match.
        if isinstance(model, torch.nn.DataParallel):
            model.module.load_state_dict(state_dict)
        else:
            model.load_state_dict(state_dict)
        loaded_items.append("model")

    if "optimizer_state_dict" in checkpoint and optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        loaded_items.append("optimizer")

    if "scheduler_state_dict" in checkpoint and scheduler is not None:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        loaded_items.append("scheduler")

    if loaded_items:
        print("<= Loaded {} from '{}'".format(", ".join(loaded_items),
                                              checkpoint_file))

        if "stage" in checkpoint:
            print("Stage -", checkpoint["stage"])

        if "epoch" in checkpoint:
            print("Epoch -", checkpoint["epoch"])

        if "metrics" in checkpoint:
            print("Metrics:")
            pprint(checkpoint["metrics"])
Example #3
0
def train(
    model: torch.nn.Module,
    train_data: torch.utils.data.Dataset,
    val_data: torch.utils.data.Dataset,
    test_data: torch.utils.data.Dataset,
    optimiser: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    criterion: Optional[Callable[..., torch.Tensor]],
    loggers: Iterable[Logger],
    config: argparse.Namespace,
):
    """Train a model.

    Runs ``config.epochs`` epochs over ``train_data``; after each epoch the
    model is evaluated on ``val_data`` and ``test_data`` and the scheduler is
    stepped once.  Batch-level metrics (fractional epoch, lr, loss, accuracy)
    are sent to every logger in ``loggers`` every ``config.log_step`` batches
    and on the final batch of each epoch.

    NOTE(review): assumes ``train_data`` is TensorDataset-like —
    ``dataloader.dataset.tensors[1]`` must hold the integer class labels;
    confirm against callers.
    """
    dataloader = DataLoader(train_data,
                            batch_size=config.batch_size,
                            shuffle=True,
                            num_workers=0)
    # Number of distinct labels; used to size the one-hot targets below.
    class_count = len(set(dataloader.dataset.tensors[1].tolist()))
    for epoch in range(config.epochs):
        model.train()
        for batch, (feats, labels) in enumerate(dataloader):
            # Move data to GPU
            feats = feats.to(config.device)
            # Convert labels to one-hots if using BCEWithLogitsLoss
            if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
                labels = F.one_hot(labels,
                                   num_classes=class_count).type(torch.float32)
            labels = labels.to(config.device)
            optimiser.zero_grad()
            # Model is expected to return (posteriors, activations).
            clean_posteriors, clean_activations = model(feats)
            # TODO Make sure we apply log for NONE case (in new loss class)
            loss = criterion(clean_posteriors, labels)
            loss.backward()

            optimiser.step()

            preds = torch.argmax(clean_posteriors, dim=-1).cpu().numpy()
            # Log every `log_step` batches, and always on the last batch.
            if batch % config.log_step == config.log_step - 1 or batch == len(
                    dataloader) - 1:
                metrics = {
                    "train/epoch":
                    epoch +
                    batch * dataloader.batch_size / len(dataloader.dataset),
                    "train/lr":
                    scheduler.get_last_lr()[0],
                    "train/loss":
                    loss.item(),
                    "train/accuracy":
                    accuracy_score(labels.cpu().numpy(), preds),
                }
                for logger in loggers:
                    logger(metrics)
        # Per-epoch validation/test, then one scheduler step.
        evaluate(model, epoch + 1, val_data, criterion, loggers, config)
        test(model, test_data, loggers, config)
        scheduler.step()
Example #4
0
def train(train_loader: DataLoader, val_loader: DataLoader, model: nn.Module,
          criterion: nn.Module, optimizer: Optimizer,
          scheduler: torch.optim.lr_scheduler._LRScheduler, args):
    """Run the full training loop.

    :param train_loader: training data loader
    :param val_loader: validation data loader
    :param model: the model to train
    :param criterion: loss function
    :param optimizer: optimizer
    :param args: training hyper-parameters
    """
    writer = SummaryWriter(args.logdir)

    best_val_acc1 = 0
    learning_rate = 0
    for epoch in range(args.epochs):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)

        last_lr = scheduler.get_last_lr()
        learning_rate = last_lr[0] if isinstance(last_lr, list) else last_lr

        # One epoch of training, then evaluation on the validation set.
        train_loss, train_acc1 = train_epoch(
            train_loader, model, criterion, optimizer, epoch, args)
        val_acc1, val_loss, _ = test(val_loader, model, criterion, args)
        scheduler.step()

        # Checkpoint, tracking the best top-1 validation accuracy so far.
        is_best = val_acc1 > best_val_acc1
        if is_best:
            best_val_acc1 = val_acc1
        save_checkpoint({'state_dict': model.module.state_dict()}, is_best,
                        args)

        for tag, value in (('learning rate', learning_rate),
                           ('Train/Loss', train_loss),
                           ('Train/Accuracy', train_acc1),
                           ('Val/Loss', val_loss),
                           ('Val/Accuracy', val_acc1)):
            writer.add_scalar(tag, value, epoch)
        writer.flush()
    writer.close()
    logging.info(f'Training Over with lr={learning_rate}~~')
Example #5
0
    def fit(  # noqa: C901
        self,
        train_generator: DataLoader,
        val_generator: DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        epochs: int,
        criterion_dict: dict[str, nn.Module],
        normalizer_dict: dict[str, Normalizer],
        model_name: str,
        run_id: int,
        checkpoint: bool = True,
        writer: SummaryWriter = None,
        verbose: bool = True,
        patience: int = None,
    ) -> None:
        """Convenience class to carry out training loop.

        Args:
            train_generator (DataLoader): Dataloader containing training data.
            val_generator (DataLoader): Dataloader containing validation data.
            optimizer (torch.optim.Optimizer): Optimizer used to carry out parameter updates.
            scheduler (torch.optim.lr_scheduler._LRScheduler): Scheduler used to adjust
                Optimizer during training.
            epochs (int): Number of epochs to train for.
            criterion_dict (dict[str, nn.Module]): Dictionary of losses to apply for each task.
            normalizer_dict (dict[str, Normalizer]): Dictionary of Normalizers to apply
                to each task.
            model_name (str): String describing the model.
            run_id (int): Unique identifier of the model run.
            checkpoint (bool, optional): Whether to save model checkpoints. Defaults to True.
            writer (SummaryWriter, optional): TensorBoard writer to save logs in. Defaults to None.
            verbose (bool, optional): Whether to print out intermediate results. Defaults to True.
            patience (int, optional): Patience for early stopping. Defaults to None.
        """
        # Resume epoch numbering so repeated fit() calls continue the count.
        start_epoch = self.epoch

        # KeyboardInterrupt aborts training gracefully: we fall through to
        # close the writer instead of propagating.
        try:
            for epoch in range(start_epoch, start_epoch + epochs):
                self.epoch += 1
                # Training
                t_metrics = self.evaluate(
                    generator=train_generator,
                    criterion_dict=criterion_dict,
                    optimizer=optimizer,
                    normalizer_dict=normalizer_dict,
                    action="train",
                    verbose=verbose,
                )

                if writer is not None:
                    for task, metrics in t_metrics.items():
                        for metric, val in metrics.items():
                            writer.add_scalar(f"{task}/train/{metric}", val,
                                              epoch)

                if verbose:
                    print(f"Epoch: [{epoch}/{start_epoch + epochs - 1}]")
                    for task, metrics in t_metrics.items():
                        metrics_str = "".join([
                            f"{key} {val:.3f}\t"
                            for key, val in metrics.items()
                        ])
                        print(f"Train \t\t: {task} - {metrics_str}")

                # Validation
                if val_generator is not None:
                    with torch.no_grad():
                        # evaluate on validation set
                        v_metrics = self.evaluate(
                            generator=val_generator,
                            criterion_dict=criterion_dict,
                            optimizer=None,
                            normalizer_dict=normalizer_dict,
                            action="val",
                        )

                    if writer is not None:
                        for task, metrics in v_metrics.items():
                            for metric, val in metrics.items():
                                writer.add_scalar(
                                    f"{task}/validation/{metric}", val, epoch)

                    if verbose:
                        for task, metrics in v_metrics.items():
                            metrics_str = "".join([
                                f"{key} {val:.3f}\t"
                                for key, val in metrics.items()
                            ])
                            print(f"Validation \t: {task} - {metrics_str}")

                    # TODO test all tasks to see if they are best,
                    # save a best model if any is best.
                    # TODO what are the costs of this approach.
                    # It could involve saving a lot of models?

                    # NOTE(review): in both branches below, False is appended
                    # even when a new best was just recorded, so is_best can
                    # hold [True, False] for a single task. any(is_best) still
                    # behaves correctly, but the list is not one-entry-per-task
                    # — confirm this is intentional.
                    is_best: list[bool] = []

                    for name in self.best_val_scores:
                        if self.task_dict[name] == "regression":
                            if v_metrics[name]["MAE"] < self.best_val_scores[
                                    name]:
                                self.best_val_scores[name] = v_metrics[name][
                                    "MAE"]
                                is_best.append(True)
                            is_best.append(False)
                        elif self.task_dict[name] == "classification":
                            if v_metrics[name]["Acc"] > self.best_val_scores[
                                    name]:
                                self.best_val_scores[name] = v_metrics[name][
                                    "Acc"]
                                is_best.append(True)
                            is_best.append(False)

                    # Early stopping: reset patience on any improvement.
                    if any(is_best):
                        self.es_patience = 0
                    else:
                        self.es_patience += 1
                        if patience and self.es_patience > patience:
                            print(
                                "Stopping early due to lack of improvement on Validation set"
                            )
                            break

                if checkpoint:
                    # Everything needed to resume: weights, optimizer and
                    # scheduler state, plus per-task normalizer state.
                    checkpoint_dict = {
                        "model_params": self.model_params,
                        "state_dict": self.state_dict(),
                        "epoch": self.epoch,
                        "best_val_score": self.best_val_scores,
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "normalizer_dict": {
                            task: n.state_dict()
                            if isinstance(n, Normalizer) else None
                            for task, n in normalizer_dict.items()
                        },
                    }

                    # TODO saving a model at each epoch may be slow?
                    save_checkpoint(checkpoint_dict, False, model_name, run_id)

                    # TODO when to save best models? should this be done task-wise in
                    # the multi-task case?
                    # save_checkpoint(checkpoint_dict, is_best, model_name, run_id)

                # One scheduler step per epoch.
                scheduler.step()

                # catch memory leak
                gc.collect()

        except KeyboardInterrupt:
            pass

        if writer is not None:
            writer.close()
def train_model(
    model: torch.nn.Module,
    train_dataset: Dataset,
    val_dataset: Dataset,
    val_transforms: Callable,
    loss_fn: Callable,
    metrics: Dict[str, Callable],
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler = None,
    loader_workers: int = 4,
    batch_size: int = 32,
    n_epochs: int = 20,
    device: str = 'cuda:0',
    logger: ExperimentLogger = None,
    snapshot_dir: str = './snapshots',
    cycle: int = None,
):
    # TODO: adaptive learning rate, early stopping conditions
    train_loader = DataLoader(train_dataset,
                              batch_size,
                              shuffle=True,
                              num_workers=loader_workers)
    val_loader = DataLoader(val_dataset,
                            batch_size,
                            shuffle=False,
                            num_workers=loader_workers)
    if cycle is not None:
        snapshot_dir = os.path.join(snapshot_dir, f'cycle_{cycle}')
        os.makedirs(snapshot_dir)
    best_mertics_dict = {metric_name: 0 for metric_name in metrics}
    best_val_loss = 1e10
    for epoch in range(n_epochs):
        metrics_dict = defaultdict(list)

        train_loss = []
        val_loss = []
        model.train(True)
        for batch in train_loader:
            image, mask = batch['image'], batch['mask']
            image = image.to(device)
            y_pred = model(image)
            loss = loss_fn(y_pred, mask.to(device))
            for metric_name, metric in metrics.items():
                metrics_dict[f'train_{metric_name}'] += metric(
                    y_pred, mask.to(device))

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            train_loss.append(loss.item())

        model.eval()
        orig_masks = []
        pred_masks = []
        e_loss = []
        for batch in val_loader:
            image, mask_raw = batch['image'], batch['mask']
            image = image.to(device)
            with torch.no_grad():
                y_pred_raw = model(image)
                y_pred = torch.sigmoid(y_pred_raw).cpu().squeeze().numpy()
                mask = mask_raw.squeeze().numpy()
                e_loss.append(loss_fn(y_pred_raw, mask_raw.to(device)))
                for i in range(mask.shape[0]):
                    predicted_mask = y_pred
                    original_mask = mask
                    pred_masks.append(val_transforms(predicted_mask.squeeze()))
                    orig_masks.append(val_transforms(original_mask.squeeze()))

        pred_masks_numpy = np.array(pred_masks)
        orig_masks = torch.Tensor(orig_masks).unsqueeze(1)
        pred_masks = torch.Tensor(pred_masks_numpy).unsqueeze(1)
        val_loss.append(np.mean(e_loss))

        for metric_name, metric in metrics.items():
            metrics_dict[f'val_{metric_name}'] += metric(pred_masks,
                                                         orig_masks,
                                                         th=0.5)

        if scheduler:
            scheduler.step(epoch=epoch)

        logger.log(
            epoch=epoch,
            train_loss=np.mean(train_loss),
            train_iou=np.mean(metrics_dict['train_iou']),
            train_iout=np.mean(metrics_dict['train_iout']),
            val_loss=np.mean(val_loss),
            val_iou=np.mean(metrics_dict['val_iou']),
            val_iout=np.mean(metrics_dict['val_iout']),
            lr=optimizer.param_groups[0]['lr'],
        )
        logger.log_val_metrics(','.join(
            [f'{x:.4f}' for x in metrics_dict['val_iou']]))

        if np.mean(val_loss) < best_val_loss:
            best_val_loss = np.mean(val_loss)
            snapshot_name = f'epoch_{epoch}.loss_{np.mean(val_loss):.4f}.iout_{np.mean(metrics_dict["val_iout"]):.3f}'
            np.save(arr=pred_masks_numpy,
                    file=os.path.join(snapshot_dir,
                                      'val_masks_' + snapshot_name))
            torch.save(model.state_dict(),
                       os.path.join(snapshot_dir, snapshot_name))

        for metric_name in metrics:
            if best_mertics_dict[metric_name] < np.mean(
                    metrics_dict[f'val_{metric_name}']):
                best_mertics_dict[metric_name] = np.mean(
                    metrics_dict[f'val_{metric_name}'])
                snapshot_name = f'epoch_{epoch}.loss_{np.mean(val_loss):.4f}.iout_{np.mean(metrics_dict["val_iout"]):.3f}'
                np.save(arr=pred_masks_numpy,
                        file=os.path.join(snapshot_dir,
                                          'val_masks_' + snapshot_name))
                torch.save(model.state_dict(),
                           os.path.join(snapshot_dir, snapshot_name))
Example #7
0
    def attack(self, epochs: int, optimizer: torch.optim.Optimizer,
               lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
               validate_interval: int = 1, save: bool = False,
               verbose: bool = True, **kwargs):
        """Run the two-stage backdoor-attack training loop.

        Stage 1 trains the mask generator alone; stage 2 jointly trains the
        mark generator (and, unless ``self.natural``, the classifier),
        poisoning a fraction of each batch with trigger samples and another
        fraction with "cross" samples built from a second, shuffled batch.

        Args:
            epochs: number of stage-2 epochs.
            optimizer: optimizer for the classifier parameters (unused when
                ``self.natural`` is True).
            lr_scheduler: optional scheduler for ``optimizer``, stepped once
                per epoch.
            validate_interval: validate every N epochs (0 disables validation
                and best-model tracking).
            save: if True, ``self.save()`` whenever validation ASR improves.
            verbose: print progress and per-batch metric logs.

        Returns:
            The best validation result seen (as returned by
            ``self.validate_fn``).

        NOTE(review): if ``validate_interval == 0``, ``best_validate_result``
        is never assigned and the final ``return`` raises UnboundLocalError —
        confirm callers never pass 0.
        """
        if verbose:
            print('train mask generator')
        # Stage 1: freeze everything except the mask generator.
        self.mark_generator.requires_grad_(False)
        self.mask_generator.requires_grad_()
        self.model.requires_grad_(False)
        self.train_mask_generator(verbose=verbose)
        if verbose:
            print()
            print('train mark generator and model')

        # Stage 2: train mark generator; mask generator stays frozen.
        self.mark_generator.requires_grad_()
        self.mask_generator.requires_grad_(False)
        if not self.natural:
            params: list[nn.Parameter] = []
            for param_group in optimizer.param_groups:
                params.extend(param_group['params'])
            self.model.activate_params(params)

        mark_optimizer = torch.optim.Adam(self.mark_generator.parameters(), lr=1e-2, betas=(0.5, 0.9))
        mark_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            mark_optimizer, T_max=epochs)
        loader = self.dataset.loader['train']
        dataset = loader.dataset
        logger = MetricLogger()
        logger.create_meters(loss=None, div=None, ce=None)

        if validate_interval != 0:
            best_validate_result = self.validate_fn(verbose=verbose)
            best_asr = best_validate_result[0]
        for _epoch in range(epochs):
            _epoch += 1
            # Random permutation used to draw the second ("cross") batch.
            idx = torch.randperm(len(dataset))
            pos = 0
            logger.reset()
            if not self.natural:
                self.model.train()
            self.mark_generator.train()
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            for data in logger.log_every(loader, header=header) if verbose else loader:
                if not self.natural:
                    optimizer.zero_grad()
                mark_optimizer.zero_grad()
                _input, _label = self.model.get_data(data)
                batch_size = len(_input)
                data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
                _input2, _label2 = self.model.get_data(data2)
                pos += batch_size
                final_input, final_label = _input.clone(), _label.clone()

                # generate trigger input
                # modf + random rounding poisons len*percent samples in
                # expectation (handles the fractional part stochastically).
                trigger_dec, trigger_int = math.modf(len(_label) * self.poison_percent)
                trigger_int = int(trigger_int)
                if random.uniform(0, 1) < trigger_dec:
                    trigger_int += 1
                x = _input[:trigger_int]
                trigger_mark, trigger_mask = self.get_mark(x), self.get_mask(x)
                trigger_input = x + trigger_mask * (trigger_mark - x)
                final_input[:trigger_int] = trigger_input
                final_label[:trigger_int] = self.target_class

                # generate cross input
                cross_dec, cross_int = math.modf(len(_label) * self.cross_percent)
                cross_int = int(cross_int)
                if random.uniform(0, 1) < cross_dec:
                    cross_int += 1
                x = _input[trigger_int:trigger_int + cross_int]
                x2 = _input2[trigger_int:trigger_int + cross_int]
                cross_mark, cross_mask = self.get_mark(x2), self.get_mask(x2)
                cross_input = x + cross_mask * (cross_mark - x)
                final_input[trigger_int:trigger_int + cross_int] = cross_input

                # div loss
                # Truncate both sets to the shorter one so the pairwise
                # distance terms align element-wise.
                if len(trigger_input) <= len(cross_input):
                    length = len(trigger_input)
                    cross_input = cross_input[:length]
                    cross_mark = cross_mark[:length]
                    cross_mask = cross_mask[:length]
                else:
                    length = len(cross_input)
                    trigger_input = trigger_input[:length]
                    trigger_mark = trigger_mark[:length]
                    trigger_mask = trigger_mask[:length]
                input_dist: torch.Tensor = (trigger_input - cross_input).flatten(1).norm(p=2, dim=1)
                # +1e-5 guards against division by zero below.
                mark_dist: torch.Tensor = (trigger_mark - cross_mark).flatten(1).norm(p=2, dim=1) + 1e-5

                loss_ce = self.model.loss(final_input, final_label)
                loss_div = input_dist.div(mark_dist).mean()

                loss = loss_ce + self.lambda_div * loss_div
                loss.backward()
                if not self.natural:
                    optimizer.step()
                mark_optimizer.step()
                logger.update(n=batch_size, loss=loss.item(), div=loss_div.item(), ce=loss_ce.item())
            if not self.natural and lr_scheduler:
                lr_scheduler.step()
            mark_scheduler.step()
            if not self.natural:
                self.model.eval()
            self.mark_generator.eval()
            # Periodic validation; keep the result with the best ASR.
            if validate_interval != 0 and (_epoch % validate_interval == 0 or _epoch == epochs):
                validate_result = self.validate_fn(verbose=verbose)
                cur_asr = validate_result[0]
                if cur_asr >= best_asr:
                    best_validate_result = validate_result
                    best_asr = cur_asr
                    if save:
                        self.save()
        # Clear any leftover gradients and freeze everything again.
        if not self.natural:
            optimizer.zero_grad()
        mark_optimizer.zero_grad()
        self.mark_generator.requires_grad_(False)
        self.mask_generator.requires_grad_(False)
        self.model.requires_grad_(False)
        return best_validate_result
Example #8
0
def train_epoch(train_loader: torch.utils.data.DataLoader,
                base_model: torch.nn.Module,
                classification_layer: torch.nn.Module,
                forg_layer: torch.nn.Module,
                epoch: int,
                optimizer: torch.optim.Optimizer,
                lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
                callback: Optional[VisdomLogger],
                device: torch.device,
                args: Any):
    """ Trains the network for one epoch

        Parameters
        ----------
        train_loader: torch.utils.data.DataLoader
            Iterable that loads the training set (x, y) tuples
        base_model: torch.nn.Module
            The model architecture that "extract features" from signatures
        classification_layer: torch.nn.Module
            The classification layer (from features to predictions of which user
            wrote the signature)
        forg_layer: torch.nn.Module
            The forgery prediction layer (from features to predictions of whether
            the signature is a forgery). Only used in args.forg = True
        epoch: int
            The current epoch (used for reporting)
        optimizer: torch.optim.Optimizer
            The optimizer (already initialized)
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler
            The learning rate scheduler
        callback: VisdomLogger (optional)
            A callback to report the training progress
        device: torch.device
            The device (CPU or GPU) to use for training
        args: Namespace
            Extra arguments used for training:
            args.forg: bool
                Whether forgeries are being used for training
            args.lamb: float
                The weight used for the forgery loss (training with forgeries only)

        Returns
        -------
        None
        """

    step = 0
    n_steps = len(train_loader)
    for batch in train_loader:
        x, y = batch[0], batch[1]
        x = torch.tensor(x, dtype=torch.float).to(device)
        y = torch.tensor(y, dtype=torch.long).to(device)
        yforg = torch.tensor(batch[2], dtype=torch.float).to(device)

        # Forward propagation
        features = base_model(x)

        if args.forg:
            if args.loss_type == 'L1':
                # Eq (3) in https://arxiv.org/abs/1705.05787
                logits = classification_layer(features)
                class_loss = F.cross_entropy(logits, y)

                forg_logits = forg_layer(features).squeeze()
                forg_loss = F.binary_cross_entropy_with_logits(forg_logits, yforg)

                loss = (1 - args.lamb) * class_loss
                loss += args.lamb * forg_loss
            else: 
                # Eq (4) in https://arxiv.org/abs/1705.05787
                logits = classification_layer(features[yforg == 0])
                class_loss = F.cross_entropy(logits, y[yforg == 0])

                forg_logits = forg_layer(features).squeeze()
                forg_loss = F.binary_cross_entropy_with_logits(forg_logits, yforg)

                loss = (1 - args.lamb) * class_loss
                loss += args.lamb * forg_loss
        else:
            # Eq (1) in https://arxiv.org/abs/1705.05787
            logits = classification_layer(features)
            loss = class_loss = F.cross_entropy(logits, y)

        # Back propagation
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(optimizer.param_groups[0]['params'], 10)

        # Update weights
        optimizer.step()

        # Logging
        if callback and step % 100 == 0:
            iteration = epoch + (step / n_steps)
            callback.scalar('class_loss', iteration, class_loss.detach())

            pred = logits.argmax(1)
            acc = y[yforg == 0].eq(pred).float().mean()
            callback.scalar('train_acc', epoch + (step / n_steps), acc.detach())
            if args.forg:
                forg_pred = forg_logits > 0
                forg_acc = yforg.long().eq(forg_pred.long()).float().mean()
                callback.scalar('forg_loss', iteration, forg_loss.detach())
                callback.scalar('forg_acc', iteration, forg_acc.detach())

        step += 1
    lr_scheduler.step()
Example #9
0
def train(train_loader: DataLoader, val_loader: DataLoader, model: nn.Module,
          criterion: nn.Module, optimizer: Optimizer,
          scheduler: torch.optim.lr_scheduler._LRScheduler, args):
    """
    Train the model.
    :param train_loader: training data loader
    :param val_loader: validation data loader
    :param model: the model to train
    :param criterion: loss function
    :param optimizer: optimizer
    :param args: training hyper-parameters
    """
    writer = SummaryWriter(args.logdir)
    # writer.add_graph(model, (torch.rand(1, 3, args.image_size[0], args.image_size[1]),))
    # Lazily-initialised module-level caches, shared across train() calls.
    global mix_up, multi_scale, bn_gammas, net_weights
    if mix_up is None:
        mix_up = MixUp(args)
    if args.multi_scale and multi_scale is None:
        multi_scale = MultiScale(args.image_size)
    # Collect BatchNorm scale (gamma) parameters for histogram logging below.
    # NOTE(review): requires NVIDIA apex to be importable at module level.
    if bn_gammas is None:
        bn_gammas = [
            m.weight for m in model.modules() if isinstance(m, nn.BatchNorm2d)
            or isinstance(m, nn.SyncBatchNorm)
            or isinstance(m, apex.parallel.SyncBatchNorm)
        ]

    # All non-bias parameters (selected by name suffix).
    if net_weights is None:
        net_weights = [
            param for name, param in model.named_parameters()
            if name[-4:] != 'bias'
        ]

    best_val_acc1 = 0
    learning_rate = 0
    for epoch in range(args.epochs):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
        learning_rate = scheduler.get_last_lr()
        if isinstance(learning_rate, list):
            learning_rate = learning_rate[0]
        # Train for one epoch, then evaluate on the validation set.
        train_loss, train_acc1 = train_epoch(train_loader, model, criterion,
                                             optimizer, epoch, args)
        val_acc1, val_loss, _ = test(val_loader,
                                     model,
                                     criterion,
                                     args,
                                     is_confuse_matrix=False)
        scheduler.step()
        # Save the current checkpoint and track the best acc@1 checkpoint.
        is_best = val_acc1 >= best_val_acc1
        best_val_acc1 = max(val_acc1, best_val_acc1)
        save_checkpoint(
            {
                # 'epoch': epoch + 1,
                # 'arch': args.arch,
                'state_dict': model.module.state_dict(),
                # 'best_acc1': best_val_acc1,
                # 'optimizer': optimizer.state_dict(),
            },
            is_best,
            args)

        # Log the distribution of BN gammas (useful for slimming/pruning).
        all_bn_weight = []
        for gamma in bn_gammas:
            all_bn_weight.append(gamma.cpu().data.numpy())
        writer.add_histogram('BN gamma', np.concatenate(all_bn_weight, axis=0),
                             epoch)
        # writer.add_scalars('Loss', {'Train': train_loss, 'Val': val_loss}, epoch)
        # writer.add_scalars('Accuracy', {'Train': train_acc1, 'Val': val_acc1}, epoch)
        writer.add_scalar('Train/Loss', train_loss, epoch)
        writer.add_scalar('Train/Accuracy', train_acc1, epoch)
        writer.add_scalar('Val/Loss', val_loss, epoch)
        writer.add_scalar('Val/Accuracy', val_acc1, epoch)
        writer.add_scalar('learning rate', learning_rate, epoch)
        writer.flush()
    writer.close()
    logging.info(f'Training Over with lr={learning_rate}~~')
def actor_critic_train_per_episode(
        policy: SimplePolicyContinuous,
        critic: SimpleCritic,
        env: gym.Env,
        optimizer: Optimizer,
        run_params: RunParams,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None):
    """ Trains the actor critic on the given environment. Training is done at the end of each episode, instead
    of at the end of each step of an episode. This means the agent trains much more frequently.
    Both discrete and continuous actions spaces are supported. Several features can be optionally enabled:
    1) Scaling / normalizing the states / observations
    2) Logging training statistics on Tensorboard
    3) Render the environment periodically (pick render_frequency in the RunParams)
    4) Using a learning rate scheduler
    """
    training_info = TrainingInfo(GAMMA=run_params.gamma)
    print(
        f"The goal is a running reward of at least {env.spec.reward_threshold}."
    )

    # https://medium.com/@asteinbach/actor-critic-using-deep-rl-continuous-mountain-car-in-tensorflow-4c1fb2110f7c
    # says it's crucial to scale the state
    # NOTE(review): `scaler` is only bound when should_scale_states is True;
    # the guarded uses below keep that safe — preserve the pairing.
    if run_params.should_scale_states:
        scaler = setup_observation_scaler(env)

    writer = run_params.get_tensorboard_writer(
        env) if run_params.use_tensorboard else None

    for episode_number in itertools.count(
    ):  # itertools.count() is basically range(+infinity)
        state = env.reset()

        # Do a whole episode (upto 10000 steps, don't want infinite steps)
        for t in range(env.spec.max_episode_steps):
            if run_params.should_scale_states:
                state = scale_state(scaler, state)

            if run_params.continuous_actions:
                action = select_action_continuous(state, policy, training_info,
                                                  env)
            else:
                action = select_action_discrete(state, policy, training_info)

            state_value = get_state_value(state, critic)

            new_state, reward, done, _ = env.step(action)

            if run_params.should_render(episode_number):
                env.render()

            training_info.record_step(
                state, action, reward,
                state_value)  # Store reward and updates the running reward
            state = new_state
            if done:
                break

        training_info.update_running_reward()

        # Add some logging
        # NOTE(review): `reward` and `t` refer to the last step of the episode
        # just finished; this assumes every episode runs at least one step.
        log_on_console(env, episode_number, reward, run_params, t,
                       training_info)
        log_on_tensorboard(env, episode_number, reward, run_params, t,
                           training_info, writer)

        # Check if we have solved the environment reliably
        if env.spec.reward_threshold is not None and training_info.running_reward > env.spec.reward_threshold:
            print(
                f"Solved! The running reward is {training_info.running_reward:.2f}, which is above the threshold of "
                f"{env.spec.reward_threshold}. The last episode ran for {t} steps."
            )
            break

        train_policy_on_episode(optimizer, training_info, episode_number)

        # NOTE(review): passing an epoch index to scheduler.step() is
        # deprecated in recent PyTorch versions — confirm intended semantics.
        if lr_scheduler:
            lr_scheduler.step(episode_number)

    close_tensorboard(run_params, writer)