Example #1
    def _train_epoch(
        self,
        data: DataLoader,
        model: nn.Module,
        optimizer: optim.Optimizer,
        criterion: Callable,
        scheduler: optim.lr_scheduler._LRScheduler = None,
        clip: float = 1.0
    ):
        model.train()

        losses = []
        for i, inputs in enumerate(data):
            inputs = self.to_device(inputs)
            x, y = inputs['features'], inputs['targets']

            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            losses.append(loss.item())
            loss.backward()
            if clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), clip)

            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        return losses
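For reference, a minimal, self-contained sketch of the per-epoch pattern used by _train_epoch above: build an optimizer, a criterion and a concrete _LRScheduler (StepLR here), clip gradients on every batch, and step the scheduler once per epoch. The toy model and data are illustrative placeholders, not part of the example.

# Sketch only: toy wiring for the components consumed by a per-epoch training loop.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(10, 2)
loader = DataLoader(TensorDataset(torch.randn(64, 10),
                                  torch.randint(0, 2, (64,))), batch_size=16)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()
# StepLR is one concrete _LRScheduler; it is stepped once per epoch below.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

for epoch in range(10):
    model.train()
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # same clipping as above
        optimizer.step()
    scheduler.step()  # per-epoch learning-rate update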
Example #2
    def joint_train(self,
                    epoch: int = 0,
                    optimizer: optim.Optimizer = None,
                    lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                    poison_loader=None,
                    discrim_loader=None,
                    save=False,
                    **kwargs):
        in_dim = self.model._model.classifier[0].in_features
        D = nn.Sequential(
            OrderedDict([('fc1', nn.Linear(in_dim,
                                           256)), ('bn1', nn.BatchNorm1d(256)),
                         ('relu1', nn.LeakyReLU()),
                         ('fc2', nn.Linear(256, 128)),
                         ('bn2', nn.BatchNorm1d(128)), ('relu2', nn.ReLU()),
                         ('fc3', nn.Linear(128, 2))]))
        if env['num_gpus']:
            D.cuda()
        optim_params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            optim_params.extend(param_group['params'])
        optimizer.zero_grad()

        best_acc = 0.0
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')

        for _epoch in range(epoch):
            self.discrim_train(epoch=100, D=D, discrim_loader=discrim_loader)

            self.model.train()
            self.model.activate_params(optim_params)
            for data in poison_loader:
                optimizer.zero_grad()
                _input, _label_f, _label_d = self.bypass_get_data(data)
                out_f = self.model(_input)
                loss_f = self.model.criterion(out_f, _label_f)
                out_d = D(self.model.get_final_fm(_input))
                loss_d = self.model.criterion(out_d, _label_d)

                loss = loss_f - self.lambd * loss_d
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            if lr_scheduler:
                lr_scheduler.step()
            self.model.activate_params([])
            self.model.eval()
            _, cur_acc = self.validate_fn(get_data_fn=self.bypass_get_data)
            if cur_acc >= best_acc:
                prints('best result update!', indent=0)
                prints(
                    f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}',
                    indent=0)
                best_acc = cur_acc
                if save:
                    self.save()
            print('-' * 50)
Example #3
    def attack(self, epoch: int, lr_scheduler: optim.lr_scheduler._LRScheduler = None,
               save: bool = False, **kwargs):
        print('Sample Data')
        poison_loader, discrim_loader = self.sample_data()  # with poisoned images
        print('Joint Training')
        super().attack(epoch=10, lr_scheduler=lr_scheduler, **kwargs)
        if isinstance(lr_scheduler, optim.lr_scheduler._LRScheduler):
            lr_scheduler.step(0)
        self.joint_train(epoch=epoch, poison_loader=poison_loader, discrim_loader=discrim_loader,
                         save=save, lr_scheduler=lr_scheduler, **kwargs)
Example #4
    def save(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        scheduler: optim.lr_scheduler._LRScheduler,
        epoch: int,
        metric: float,
    ):
        if self.best_metric < metric:
            self.best_metric = metric
            self.best_epoch = epoch
            is_best = True
        else:
            is_best = False

        os.makedirs(self.root_dir, exist_ok=True)
        torch.save(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
                "best_epoch": self.best_epoch,
                "best_metric": self.best_metric,
            },
            osp.join(self.root_dir, f"{epoch:02d}.pth"),
        )

        if is_best:
            shutil.copy(
                osp.join(self.root_dir, f"{epoch:02d}.pth"),
                osp.join(self.root_dir, "best.pth"),
            )
Example #5
def log_checkpoints(
    checkpoint_dir: Path,
    model: Union[nn.Module, nn.DataParallel],
    optimizer: Optimizer,
    scheduler: optim.lr_scheduler._LRScheduler,
    epoch: int,
) -> None:
    """
    Serialize a PyTorch model in the `checkpoint_dir`.

    Args:
        checkpoint_dir: the directory to store checkpoints
        model: the model to serialize
        optimizer: the optimizer to be saved
        scheduler: the LR scheduler to be saved
        epoch: the epoch number
    """
    checkpoint_file = 'checkpoint_{}.pt'.format(epoch)
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    file_path = checkpoint_dir / checkpoint_file

    if isinstance(model, nn.DataParallel):
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()

    torch.save(  # type: ignore
        {
            'epoch': epoch,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        },
        file_path,
    )
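A hedged counterpart for reading such a checkpoint back in. It assumes only the key names written by log_checkpoints above, and that model, optimizer and scheduler have already been constructed the same way as when the checkpoint was saved; for an nn.DataParallel model the state dict would be loaded into model.module instead.

# Sketch: restore the state saved by log_checkpoints (key names taken from above).
from pathlib import Path

import torch


def load_checkpoint(checkpoint_dir: Path, epoch: int,
                    model, optimizer, scheduler) -> int:
    """Load a checkpoint written by log_checkpoints and return its epoch number."""
    checkpoint = torch.load(checkpoint_dir / 'checkpoint_{}.pt'.format(epoch),
                            map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint['epoch']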
Example #6
def _restore(
    mdl: nn.Module, optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler._LRScheduler, ckpt_loc: str
) -> t.Tuple[nn.Module, optim.Optimizer, optim.lr_scheduler._LRScheduler, int,
             float]:
    """Restore model training state

    Args:
        mdl (nn.Module):
            The randomly initialized model
        optimizer (optim.Optimizer):
            The optimizer
        scheduler (optim.lr_scheduler._LRScheduler):
            The scheduler for learning rate
        ckpt_loc (str):
            Location to store model checkpoints

    Returns:
        t.Tuple[nn.Module,
                optim.Optimizer,
                optim.lr_scheduler._LRScheduler,
                int, float]:
            The restored status
    """
    # Restore model checkpoint
    mdl.load_state_dict(torch.load(os.path.join(ckpt_loc, 'mdl.ckpt')))
    optimizer.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'optimizer.ckpt')))
    scheduler.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'scheduler.ckpt')))

    # Restore timer and step counter
    with open(os.path.join(ckpt_loc, 'log.out')) as f:
        records = f.readlines()
        if records[-1] != 'Training finished\n':
            final_record = records[-1]
        else:
            final_record = records[-2]
    global_counter, t_final = final_record.split('\t')[:2]
    global_counter = int(global_counter)
    t_final = float(t_final)
    t0 = time.time() - t_final * 60

    return mdl, optimizer, scheduler, global_counter, t0
Example #7
def train_run(model: nn.Module,
              train_dl: torch.utils.data.dataloader.DataLoader,
              criterion: nn.Module, optimizer: optim.Optimizer,
              scheduler: optim.lr_scheduler._LRScheduler, num_it: int,
              on_batch_end: Callable[[tqdm, int, float, float], None],
              device: torch.device):
    """Train for num_it iterations, restarting the train_dl iterator when it is exhausted."""
    iterator = iter(train_dl)
    bar = tqdm(range(num_it))
    for i in bar:
        try:
            xs, ys = next(iterator)
        except StopIteration:
            iterator = iter(train_dl)
            xs, ys = next(iterator)
        xs = xs.to(device)
        ys = ys.to(device)
        loss = train_batch(xs, ys, model, criterion, optimizer)
        on_batch_end(bar, i, loss, scheduler.get_lr()[0])
        scheduler.step()
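Because train_run calls scheduler.step() after every batch, it pairs naturally with per-iteration schedulers. Below is a small illustrative sketch of that pattern with OneCycleLR; the model, data and learning-rate settings are placeholders, and the current learning rate is read with get_last_lr(), which replaces the deprecated get_lr() used above in recent PyTorch releases.

# Sketch: per-batch scheduler stepping, mirroring the loop structure of train_run.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(10, 2)
train_dl = DataLoader(TensorDataset(torch.randn(64, 10),
                                    torch.randint(0, 2, (64,))), batch_size=16)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
num_it = 100
# OneCycleLR expects exactly total_steps calls to scheduler.step().
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, total_steps=num_it)

iterator = iter(train_dl)
for i in range(num_it):
    try:
        xs, ys = next(iterator)
    except StopIteration:  # restart the loader once exhausted, as train_run does
        iterator = iter(train_dl)
        xs, ys = next(iterator)
    optimizer.zero_grad()
    loss = criterion(model(xs), ys)
    loss.backward()
    optimizer.step()
    current_lr = scheduler.get_last_lr()[0]  # LR used for the step just taken
    scheduler.step()  # advance the LR schedule once per batch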
Example #8
def _train_step(mdl: nn.Module, optimizer: optim.Optimizer,
                scheduler: optim.lr_scheduler._LRScheduler, min_lr: float,
                clip_grad: float, device: torch.device,
                iter_train: t.Iterator):
    """Helper function to perform one step of training

    Args:
        mdl (nn.Module):
            The randomly initialized model
        optimizer (optim.Optimizer):
            The optimizer
        scheduler (optim.lr_scheduler._LRScheduler):
            The scheduler for learning rate
        min_lr (float):
            The minimum learning rate
        clip_grad (float):
            Gradient clipping
        device (torch.device):
            The device where tensors should be intialized
        iter_train (t.Iterator):
            The iterator for trainer
    """
    # Prepare for training
    optimizer.zero_grad()  # Clear gradient
    if all([
            params_group['lr'] > min_lr
            for params_group in optimizer.param_groups
    ]):
        # Update learning rate if it is still larger than min_lr
        scheduler.step()

    # Get data
    mol_array, log_p = next(iter_train)
    loss = _loss(mol_array, log_p, mdl, device)
    loss.backward()

    # Clip gradient
    torch.nn.utils.clip_grad_value_(mdl.parameters(), clip_grad)

    optimizer.step()
    return loss
Example #9
    def _train_step(
            self, rank: int, dataset: Dataset, model: nn.Module,
            optimizer: optim.Optimizer,
            scheduler: optim.lr_scheduler._LRScheduler) -> Dict[str, float]:
        model.train()
        optimizer.zero_grad()

        data = self._fetch_from(dataset, rank, self.config.batch_train)
        metrics = self.spec.train_objective(data, model)
        loss = metrics['loss']

        if self.config.use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()
        scheduler.step()

        return {k: self._to_value(v) for k, v in metrics.items()}
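The amp.scale_loss context manager used when use_amp is set is NVIDIA Apex's mixed-precision API. For comparison, here is a hedged sketch of the same step written against PyTorch's built-in torch.cuda.amp; the train_objective-style callable and the surrounding objects are placeholders, not the class above.

# Sketch: scaled backward pass with torch.cuda.amp instead of apex.amp.
import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()  # created once and reused across steps


def train_step(data, model, optimizer, scheduler, train_objective):
    model.train()
    optimizer.zero_grad()
    with autocast():  # run the forward pass in mixed precision
        metrics = train_objective(data, model)
        loss = metrics['loss']
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
    scaler.update()                # adjust the scale factor for the next iteration
    scheduler.step()
    return metrics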
Example #10
    def _run_epoch(self, optimizer: optim.Optimizer,
                   scheduler: optim.lr_scheduler._LRScheduler,
                   criterion: nn.Module):
        lr_step = self.min_lr
        with tqdm(self.loader,
                  postfix=["Current state: ",
                           dict(loss=0, lr=lr_step)]) as t:
            for data, target in t:
                data = data.to(self.device)
                target = target.to(self.device)
                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                scheduler.step()
                lr_step = optimizer.state_dict()['param_groups'][0]['lr']
                self.lr_s += [lr_step]
                self.losses += [loss.item()]
                t.postfix[1]['loss'] = loss.item()
                t.postfix[1]['lr'] = lr_step
                t.update()
Example #11
def _save(mdl: nn.Module, optimizer: optim.Optimizer,
          scheduler: optim.lr_scheduler._LRScheduler, global_counter: int,
          t0: float, loss: float, current_lr: float, ckpt_loc: str) -> str:
    """Saving checkpoint to file

    Args:
        mdl (nn.Module):
            The randomly initialized model
        optimizer (optim.Optimizer):
            The optimizer
        scheduler (optim.lr_scheduler._LRScheduler):
            The scheduler for learning rate
        global_counter (int):
            The global counter for training
        t0 (float):
            The time training was started
        loss (float):
            The loss of the model
        current_lr (float):
            The current learning rate
        ckpt_loc (str):
            Location to store model checkpoints

    Return:
        str:
            The message string
    """
    # Save status
    torch.save(mdl.state_dict(), os.path.join(ckpt_loc, 'mdl.ckpt'))
    torch.save(optimizer.state_dict(), os.path.join(ckpt_loc,
                                                    'optimizer.ckpt'))
    torch.save(scheduler.state_dict(), os.path.join(ckpt_loc,
                                                    'scheduler.ckpt'))

    message_str = (f'{global_counter}\t'
                   f'{float(time.time() - t0) / 60}\t'
                   f'{loss}\t'
                   f'{current_lr}\n')
    return message_str
Example #12
    def adv_train(self, epoch: int, optimizer: optim.Optimizer, lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                  validate_interval=10, save=False, verbose=True, indent=0, epoch_fn: Callable = None,
                  **kwargs):
        loader_train = self.dataset.loader['train']
        file_path = self.folder_path + self.get_filename() + '.pth'

        _, best_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)

        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        params = [param_group['params'] for param_group in optimizer.param_groups]
        for _epoch in range(epoch):
            if callable(epoch_fn):
                self.model.activate_params([])
                epoch_fn()
                self.model.activate_params(params)
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            if verbose and env['tqdm']:
                loader_train = tqdm(loader_train)
            optimizer.zero_grad()
            for data in loader_train:
                _input, _label = self.model.get_data(data)
                noise = torch.zeros_like(_input)

                poison_input, poison_label = self.get_poison_data(data)

                def loss_fn(X: torch.FloatTensor):
                    return -self.model.loss(X, _label)
                adv_x = _input
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                for m in range(self.pgd.iteration):
                    self.model.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input, noise=noise, loss_fn=loss_fn, iteration=1)

                    optimizer.zero_grad()
                    self.model.train()

                    x = torch.cat((adv_x, poison_input))
                    y = torch.cat((_label, poison_label))
                    loss = self.model.loss(x, y)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.model.get_logits(_input)
                acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
            epoch_time = str(datetime.timedelta(seconds=int(
                time.perf_counter() - epoch_start)))
            self.model.eval()
            self.model.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                    f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()

            if validate_interval != 0:
                if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                    _, cur_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)
                    if cur_acc < best_acc:
                        prints('best result update!', indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}', indent=indent)
                        best_acc = cur_acc
                    if save:
                        self.save()
                    if verbose:
                        print('-' * 50)
        self.model.zero_grad()
Example #13
    def adv_train(self,
                  epochs: int,
                  optimizer: optim.Optimizer,
                  lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                  validate_interval=10,
                  save=False,
                  verbose=True,
                  indent=0,
                  **kwargs):
        loader_train = self.dataset.loader['train']
        file_path = os.path.join(self.folder_path,
                                 self.get_filename() + '.pth')

        best_acc, _ = self.validate_fn(verbose=verbose,
                                       indent=indent,
                                       **kwargs)

        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        for _epoch in range(epochs):
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            if verbose and env['tqdm']:
                loader_train = tqdm(loader_train)
            self.model.activate_params(params)
            optimizer.zero_grad()
            for data in loader_train:
                _input, _label = self.model.get_data(data)
                noise = torch.zeros_like(_input)
                adv_x = _input
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                for m in range(self.pgd.iteration):
                    self.model.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input,
                                                 noise=noise,
                                                 target=_label,
                                                 iteration=1)
                    optimizer.zero_grad()
                    self.model.train()
                    loss = self.model.loss(adv_x, _label)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.model(_input)
                acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
            epoch_time = str(
                datetime.timedelta(seconds=int(time.perf_counter() -
                                               epoch_start)))
            self.model.eval()
            self.model.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epochs),
                    **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                    f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str,
                       _str,
                       prefix='{upline}{clear_line}'.format(
                           **ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()

            if validate_interval != 0:
                if (_epoch +
                        1) % validate_interval == 0 or _epoch == epochs - 1:
                    adv_acc, _ = self.validate_fn(verbose=verbose,
                                                  indent=indent,
                                                  **kwargs)
                    if adv_acc < best_acc:
                        prints('{purple}best result update!{reset}'.format(
                            **ansi),
                               indent=indent)
                        prints(
                            f'Current Acc: {adv_acc:.3f}    Previous Best Acc: {best_acc:.3f}',
                            indent=indent)
                        best_acc = adv_acc
                    if save:
                        self.model.save(file_path=file_path, verbose=verbose)
                    if verbose:
                        print('-' * 50)
        self.model.zero_grad()
Example #14
def _iter_impl(epoch: int, phase: str, data_loader: DataLoader, device: str,
               model: nn.Module, criterion: nn.Module,
               optimizer: optim.Optimizer,
               scheduler: optim.lr_scheduler._LRScheduler,
               larger_holder: LargerHolder, baseline_flops: float,
               flops_tester: FLOPs, logger: logging.Logger,
               output_directory: str, writer: SummaryWriter,
               log_frequency: int):
    start = datetime.now()

    clear_statistics(model)
    model.train(phase == "train")
    loss_metric = AverageMetric()
    accuracy_metric = AccuracyMetric(topk=(1, 5))

    for iter_, (datas, targets) in enumerate(data_loader, start=1):
        datas, targets = datas.to(device=device), targets.to(device=device)
        with torch.set_grad_enabled(phase == "train"):
            outputs = model(datas)
        loss = criterion(outputs, targets)

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss_metric.update(loss)
        accuracy_metric.update(targets, outputs)

        if iter_ % log_frequency == 0:
            logger.info(
                f"{phase.upper()}, epoch={epoch:03d}, iter={iter_}/{len(data_loader)}, "
                f"loss={loss_metric.last:.4f}({loss_metric.value:.4f}), "
                f"accuracy@1={accuracy_metric.last_accuracy(1).rate*100:.2f}%"
                f"({accuracy_metric.accuracy(1).rate*100:.2f}%), "
                f"accuracy@5={accuracy_metric.last_accuracy(5).rate*100:.2f}%"
                f"({accuracy_metric.accuracy(5).rate*100:.2f}%), ")

    if phase != "train":
        acc = accuracy_metric.accuracy(1).rate
        if larger_holder.update(new_value=acc, metadata=dict(epoch=epoch)):
            if output_directory is not None:
                torch.save(model.state_dict(),
                           os.path.join(output_directory, "best_model.pth"))

    if scheduler is not None:
        scheduler.step()

    flops = flops_tester.compute()
    logger.info(
        f"{phase.upper()} Complete, epoch={epoch:03d}, "
        f"loss={loss_metric.value:.4f}, "
        f"accuracy@1={accuracy_metric.accuracy(1).rate*100:.2f}%, "
        f"accuracy@5={accuracy_metric.accuracy(5).rate*100:.2f}%, "
        f"flops={flops/1e6:.2f}M({flops/baseline_flops*100:.2f}%), "
        f"best_accuracy={larger_holder.value*100:.2f}%(epoch={larger_holder.metadata['epoch']:03d}), "
        f"propotions={network_proportion(model)}, "
        f"eplased time={datetime.now()-start}.")

    writer.add_scalar(f"{phase}/loss", loss_metric.value, epoch)
    writer.add_scalar(f"{phase}/accuracy@1",
                      accuracy_metric.accuracy(1).rate, epoch)
    writer.add_scalar(f"{phase}/accuracy@5",
                      accuracy_metric.accuracy(5).rate, epoch)
Example #15
def train_model(
    train_ds: tf.data.Dataset,
    dev_ds: tf.data.Dataset,
    model: nn.Module,
    optimizer: optim.Optimizer,
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    args: argparse.Namespace,
) -> nn.Module:

    device = model_utils.get_device()
    loss_fn = model_utils.depth_proportional_loss
    val_loss_fn = model_utils.l1_norm_loss
    best_val_loss = torch.tensor(float('inf'))
    saved_checkpoints = []
    writer = SummaryWriter(log_dir=f'{args.log_dir}/{args.experiment}')

    cos = nn.CosineSimilarity(dim=1, eps=0)
    get_gradient: nn.Module = sobel.Sobel().to(device)

    for e in range(1, args.train_epochs + 1):
        print(f'Training epoch {e}...')

        if args.use_scheduler:
            lr_scheduler.step()

        # Training portion
        torch.cuda.empty_cache()
        torch.set_grad_enabled(True)
        with tqdm(total=args.train_batch_size * len(train_ds)) as progress_bar:
            model.train()
            for i, (x_batch_orig,
                    y_batch) in enumerate(train_ds.as_numpy_iterator()):
                x_batch, y_batch = model_utils.preprocess_training_example(
                    x_batch_orig, y_batch)
                y_blurred = model_utils.blur_depth_map(y_batch)

                ones = torch.ones(y_batch.shape,
                                  dtype=torch.float32,
                                  device=device)

                # Forward pass on model
                optimizer.zero_grad()
                y_pred = model(x_batch)

                depth_grad = get_gradient(y_blurred)
                output_grad = get_gradient(y_pred)
                depth_grad_dx = depth_grad[:, 0, :, :].contiguous().view_as(
                    y_blurred)
                depth_grad_dy = depth_grad[:, 1, :, :].contiguous().view_as(
                    y_batch)
                output_grad_dx = output_grad[:, 0, :, :].contiguous().view_as(
                    y_blurred)
                output_grad_dy = output_grad[:, 1, :, :].contiguous().view_as(
                    y_batch)

                depth_normal = torch.cat(
                    (-depth_grad_dx, -depth_grad_dy, ones), 1)
                output_normal = torch.cat(
                    (-output_grad_dx, -output_grad_dy, ones), 1)

                loss_depth = torch.log(torch.abs(y_pred - y_batch) +
                                       0.5).mean()
                loss_dx = torch.log(
                    torch.abs(output_grad_dx - depth_grad_dx) + 0.5).mean()
                loss_dy = torch.log(
                    torch.abs(output_grad_dy - depth_grad_dy) + 0.5).mean()
                loss_normal = torch.abs(
                    1 - cos(output_normal, depth_normal)).mean()

                loss = loss_depth + loss_normal + (loss_dx + loss_dy)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(loss=loss.item())
                writer.add_scalar("train/Loss", loss,
                                  ((e - 1) * len(train_ds) + i) *
                                  args.train_batch_size)

                # Periodically save a diagram
                if (i + 1) % args.picture_frequency == 0:
                    model_utils.make_diagram(
                        np.transpose(x_batch_orig, (0, 3, 1, 2)),
                        x_batch.cpu().numpy(),
                        y_batch.cpu().numpy(),
                        y_pred.cpu().detach().numpy(),
                        f'{args.save_path}/{args.experiment}/diagram_{e}_{i+1}.png',
                    )

                del x_batch
                del y_batch
                del y_blurred
                del y_pred
                del loss

        # Validation portion
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)

        with tqdm(total=args.dev_batch_size * len(dev_ds)) as progress_bar:
            model.eval()
            val_loss = 0.0
            num_batches_processed = 0
            total_pixels = 0
            total_examples = 0
            squared_error = 0
            rel_error = 0
            log_error = 0
            threshold1 = 0  # 1.25
            threshold2 = 0  # 1.25^2
            threshold3 = 0  # corresponds to 1.25^3

            for i, (x_batch, y_batch) in enumerate(dev_ds.as_numpy_iterator()):
                x_batch, y_batch = model_utils.preprocess_test_example(
                    x_batch, y_batch)
                # Forward pass on model in validation environment
                y_pred = model(x_batch)

                # TODO: Process y_pred in whatever way inference requires.
                loss = val_loss_fn(y_pred, y_batch)
                val_loss += loss.item()
                num_batches_processed += 1

                nanmask = getNanMask(y_batch)
                total_pixels = torch.sum(~nanmask)
                total_examples += x_batch.shape[0]

                # RMS, REL, LOG10, threshold calculation
                squared_error += (
                    torch.sum(torch.pow(y_pred - y_batch, 2)).item() /
                    total_pixels)**0.5
                rel_error += torch.sum(
                    removeNans(torch.abs(y_pred - y_batch) /
                               y_batch)).item() / total_pixels
                log_error += torch.sum(
                    torch.abs(
                        removeNans(torch.log10(y_pred)) - removeNans(
                            torch.log10(y_batch)))).item() / total_pixels
                threshold1 += torch.sum(
                    torch.max(y_pred / y_batch, y_batch /
                              y_pred) < 1.25).item() / total_pixels
                threshold2 += torch.sum(
                    torch.max(y_pred / y_batch, y_batch /
                              y_pred) < 1.25**2).item() / total_pixels
                threshold3 += torch.sum(
                    torch.max(y_pred / y_batch, y_batch /
                              y_pred) < 1.25**3).item() / total_pixels

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(val_loss=val_loss /
                                         num_batches_processed)
                writer.add_scalar("Val/Loss", loss,
                                  ((e - 1) * len(dev_ds) + i) *
                                  args.dev_batch_size)

                del x_batch
                del y_batch
                del y_pred
                del loss

            writer.add_scalar("Val/RMS", squared_error / total_examples, e)
            writer.add_scalar("Val/REL", rel_error / total_examples, e)
            writer.add_scalar("Val/LOG10", log_error / total_examples, e)
            writer.add_scalar("Val/delta1", threshold1 / total_examples, e)
            writer.add_scalar("Val/delta2", threshold2 / total_examples, e)
            writer.add_scalar("Val/delta3", threshold3 / total_examples, e)

            # Save model if it's the best one yet.
            if val_loss / num_batches_processed < best_val_loss:
                best_val_loss = val_loss / num_batches_processed
                filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_best_val.checkpoint'
                model_utils.save_model(model, filename)
                print(f'Model saved!')
                print(f'Best validation loss yet: {best_val_loss}')
            # Save model on checkpoints.
            if e % args.checkpoint_freq == 0:
                filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_epoch_{e}.checkpoint'
                model_utils.save_model(model, filename)
                print(f'Model checkpoint reached!')
                saved_checkpoints.append(filename)
                # Delete checkpoints if there are too many
                while len(saved_checkpoints) > args.num_checkpoints:
                    os.remove(saved_checkpoints.pop(0))

    return model
Example #16
def train(
    logger: lavd.Logger,
    model: nn.Module,
    optimiser: optim.Optimizer,  # type: ignore
    train_data_loader: DataLoader,
    validation_data_loaders: List[DataLoader],
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    device: torch.device,
    checkpoint: Dict,
    num_epochs: int = num_epochs,
    model_kind: str = default_model,
    amp_scaler: Optional[amp.GradScaler] = None,
    masked_lm: bool = True,
):
    start_epoch = checkpoint["epoch"]
    train_stats = checkpoint["train"]
    validation_cp = checkpoint["validation"]
    outdated_validations = checkpoint["outdated_validation"]

    validation_results_dict: Dict[str, Dict] = OrderedDict()
    for val_data_loader in validation_data_loaders:
        val_name = val_data_loader.dataset.name
        val_result = (validation_cp[val_name] if val_name in validation_cp else
                      OrderedDict(start=start_epoch,
                                  stats=OrderedDict(loss=[], perplexity=[])))
        validation_results_dict[val_name] = val_result

    # All validations that are no longer used, will be stored in outdated_validation
    # just to have them available.
    outdated_validations.append(
        OrderedDict({
            k: v
            for k, v in validation_cp.items()
            if k not in validation_results_dict
        }))

    tokeniser = train_data_loader.dataset.tokeniser  # type: ignore
    for epoch in range(num_epochs):
        actual_epoch = start_epoch + epoch + 1
        epoch_text = "[{current:>{pad}}/{end}] Epoch {epoch}".format(
            current=epoch + 1,
            end=num_epochs,
            epoch=actual_epoch,
            pad=len(str(num_epochs)),
        )
        logger.set_prefix(epoch_text)
        logger.start(epoch_text, prefix=False)
        start_time = time.time()

        logger.start("Train")
        train_result = run_epoch(
            train_data_loader,
            model,
            optimiser,
            device=device,
            epoch=epoch,
            train=True,
            name="Train",
            logger=logger,
            amp_scaler=amp_scaler,
            masked_lm=masked_lm,
        )
        train_stats["stats"]["loss"].append(train_result["loss"])
        train_stats["stats"]["perplexity"].append(train_result["perplexity"])
        epoch_lr = lr_scheduler.get_last_lr()[0]  # type: ignore
        train_stats["lr"].append(epoch_lr)
        lr_scheduler.step()
        logger.end("Train")

        validation_results = []
        for val_data_loader in validation_data_loaders:
            val_name = val_data_loader.dataset.name
            val_text = "Validation: {}".format(val_name)
            logger.start(val_text)
            validation_result = run_epoch(
                val_data_loader,
                model,
                optimiser,
                device=device,
                epoch=epoch,
                train=False,
                name=val_text,
                logger=logger,
                amp_scaler=amp_scaler,
                masked_lm=masked_lm,
            )
            validation_results.append(
                OrderedDict(name=val_name, stats=validation_result))
            validation_results_dict[val_name]["stats"]["loss"].append(
                validation_result["loss"])
            validation_results_dict[val_name]["stats"]["perplexity"].append(
                validation_result["perplexity"])
            logger.end(val_text)

        with logger.spinner("Checkpoint", placement="right"):
            # Multi-gpu models wrap the original model. To make the checkpoint
            # compatible with the original model, the state dict of .module is saved.
            model_unwrapped = (model.module if isinstance(
                model, DistributedDataParallel) else model)
            save_checkpoint(
                logger,
                model_unwrapped,
                tokeniser,
                stats=OrderedDict(
                    epoch=actual_epoch,
                    train=train_stats,
                    validation=validation_results_dict,
                    outdated_validation=outdated_validations,
                    model=OrderedDict(kind=model_kind),
                ),
                step=actual_epoch,
            )

        with logger.spinner("Logging Data", placement="right"):
            log_results(
                logger,
                actual_epoch,
                OrderedDict(lr=epoch_lr, stats=train_result),
                validation_results,
                model_unwrapped,
            )

        with logger.spinner("Best Checkpoints", placement="right"):
            val_stats = OrderedDict({
                val_name: {
                    "name": val_name,
                    "start": val_result["start"],
                    "stats": val_result["stats"],
                }
                for val_name, val_result in validation_results_dict.items()
            })
            log_top_checkpoints(logger, val_stats, metrics)

        time_difference = time.time() - start_time
        epoch_results = [OrderedDict(name="Train", stats=train_result)
                         ] + validation_results
        log_epoch_stats(logger,
                        epoch_results,
                        metrics,
                        lr=epoch_lr,
                        time_elapsed=time_difference)
        logger.end(epoch_text, prefix=False)
Example #17
def train_model(
    train_graph: pyg.torch_geometric.data.Data,
    valid_graph: pyg.torch_geometric.data.Data,
    train_dl: data.DataLoader,
    dev_dl: data.DataLoader,
    evaluator: Evaluator,
    model: nn.Module,
    optimizer: optim.Optimizer,
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    args: argparse.Namespace,
) -> nn.Module:

    device = model_utils.get_device()
    loss_fn = nn.functional.binary_cross_entropy
    val_loss_fn = nn.functional.binary_cross_entropy
    best_val_loss = torch.tensor(float('inf'))
    best_val_hits = torch.tensor(0.0)
    saved_checkpoints = []
    writer = SummaryWriter(log_dir=f'{args.log_dir}/{args.experiment}')

    for e in range(1, args.train_epochs + 1):
        print(f'Training epoch {e}...')

        # Training portion
        torch.cuda.empty_cache()
        torch.set_grad_enabled(True)
        with tqdm(total=args.train_batch_size * len(train_dl)) as progress_bar:
            model.train()

            # Load graph into GPU
            adj_t = train_graph.adj_t.to(device)
            edge_index = train_graph.edge_index.to(device)
            x = train_graph.x.to(device)

            pos_pred = []
            neg_pred = []

            for i, (y_pos_edges,) in enumerate(train_dl):
                y_pos_edges = y_pos_edges.to(device).T
                y_neg_edges = negative_sampling(
                    edge_index,
                    num_nodes=train_graph.num_nodes,
                    num_neg_samples=y_pos_edges.shape[1]
                ).to(device)
                y_batch = torch.cat([torch.ones(y_pos_edges.shape[1]), torch.zeros(
                    y_neg_edges.shape[1])], dim=0).to(device)  # Ground truth edge labels (1 or 0)

                # Forward pass on model
                optimizer.zero_grad()
                y_pred = model(adj_t, torch.cat(
                    [y_pos_edges, y_neg_edges], dim=1))
                loss = loss_fn(y_pred, y_batch)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()
                if args.use_scheduler:
                    lr_scheduler.step(loss)

                batch_acc = torch.mean(
                    1 - torch.abs(y_batch.detach() - torch.round(y_pred.detach()))).item()

                pos_pred += [y_pred[y_batch == 1].detach()]
                neg_pred += [y_pred[y_batch == 0].detach()]

                progress_bar.update(y_pos_edges.shape[1])
                progress_bar.set_postfix(loss=loss.item(), acc=batch_acc)
                writer.add_scalar(
                    "train/Loss", loss, ((e - 1) * len(train_dl) + i) * args.train_batch_size)
                writer.add_scalar("train/Accuracy", batch_acc,
                                  ((e - 1) * len(train_dl) + i) * args.train_batch_size)

                del y_pos_edges
                del y_neg_edges
                del y_pred
                del loss

            del adj_t
            del edge_index
            del x

            # Training set evaluation Hits@K Metrics
            pos_pred = torch.cat(pos_pred, dim=0)
            neg_pred = torch.cat(neg_pred, dim=0)
            results = {}
            for K in [10, 20, 30]:
                evaluator.K = K
                hits = evaluator.eval({
                    'y_pred_pos': pos_pred,
                    'y_pred_neg': neg_pred,
                })[f'hits@{K}']
                results[f'Hits@{K}'] = hits
            print()
            print(f'Train Statistics')
            print('*' * 30)
            for k, v in results.items():
                print(f'{k}: {v}')
                writer.add_scalar(
                    f"train/{k}", v, (pos_pred.shape[0] + neg_pred.shape[0]) * e)
            print('*' * 30)

            del pos_pred
            del neg_pred

        # Validation portion
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        with tqdm(total=args.val_batch_size * len(dev_dl)) as progress_bar:
            model.eval()

            adj_t = valid_graph.adj_t.to(device)
            edge_index = valid_graph.edge_index.to(device)
            x = valid_graph.x.to(device)

            val_loss = 0.0
            accuracy = 0
            num_samples_processed = 0
            pos_pred = []
            neg_pred = []
            for i, (edges_batch, y_batch) in enumerate(dev_dl):
                edges_batch = edges_batch.T.to(device)
                y_batch = y_batch.to(device)

                # Forward pass on model in validation environment
                y_pred = model(adj_t, edges_batch)
                loss = val_loss_fn(y_pred, y_batch)

                num_samples_processed += edges_batch.shape[1]
                batch_acc = torch.mean(
                    1 - torch.abs(y_batch - torch.round(y_pred))).item()
                accuracy += batch_acc * edges_batch.shape[1]
                val_loss += loss.item() * edges_batch.shape[1]

                pos_pred += [y_pred[y_batch == 1].detach()]
                neg_pred += [y_pred[y_batch == 0].detach()]

                progress_bar.update(edges_batch.shape[1])
                progress_bar.set_postfix(
                    val_loss=val_loss / num_samples_processed,
                    acc=accuracy/num_samples_processed)
                writer.add_scalar(
                    "Val/Loss", loss, ((e - 1) * len(dev_dl) + i) * args.val_batch_size)
                writer.add_scalar(
                    "Val/Accuracy", batch_acc, ((e - 1) * len(dev_dl) + i) * args.val_batch_size)

                del edges_batch
                del y_batch
                del y_pred
                del loss

            del adj_t
            del edge_index
            del x

            # Validation evaluation Hits@K Metrics
            pos_pred = torch.cat(pos_pred, dim=0)
            neg_pred = torch.cat(neg_pred, dim=0)
            results = {}
            for K in [10, 20, 30]:
                evaluator.K = K
                hits = evaluator.eval({
                    'y_pred_pos': pos_pred,
                    'y_pred_neg': neg_pred,
                })[f'hits@{K}']
                results[f'Hits@{K}'] = hits
            print()
            print(f'Validation Statistics')
            print('*' * 30)
            for k, v in results.items():
                print(f'{k}: {v}')
                writer.add_scalar(
                    f"Val/{k}", v, (pos_pred.shape[0] + neg_pred.shape[0]) * e)
            print('*' * 30)

            del pos_pred
            del neg_pred

            # Save model if it's the best one yet.
            if results['Hits@20'] > best_val_hits:
                best_val_hits = results['Hits@20']
                filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_best_val.checkpoint'
                model_utils.save_model(model, filename)
                print(f'Model saved!')
                print(f'Best validation Hits@20 yet: {best_val_hits}')
            # Save model on checkpoints.
            if e % args.checkpoint_freq == 0:
                filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_epoch_{e}.checkpoint'
                model_utils.save_model(model, filename)
                print(f'Model checkpoint reached!')
                saved_checkpoints.append(filename)
                # Delete checkpoints if there are too many
                while len(saved_checkpoints) > args.num_checkpoints:
                    os.remove(saved_checkpoints.pop(0))

    return model
Example #18
def train(epoch: int, model: nn.Module, loader: data.DataLoader,
          criterion: nn.modules.loss._Loss, optimizer: optim.Optimizer,
          scheduler: optim.lr_scheduler._LRScheduler, only_epoch_sche: bool,
          use_amp: bool, accumulated_steps: int, device: str,
          log_interval: int):
    model.train()

    scaler = GradScaler() if use_amp else None

    gradient_accumulator = GradientAccumulator(accumulated_steps)

    loss_metric = AverageMetric("loss")
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    ETA = EstimatedTimeArrival(len(loader))
    speed_tester = SpeedTester()

    lr = optimizer.param_groups[0]['lr']
    _logger.info(f"Train start, epoch={epoch:04d}, lr={lr:.6f}")

    for time_cost, iter_, (inputs, targets) in time_enumerate(loader, start=1):
        inputs, targets = inputs.to(device=device), targets.to(device=device)

        optimizer.zero_grad()

        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss: torch.Tensor = criterion(outputs, targets)

        gradient_accumulator.backward_step(model, loss, optimizer, scaler)

        if scheduler is not None:
            if only_epoch_sche:
                if iter_ == 1:
                    scheduler.step()
            else:
                scheduler.step()

        loss_metric.update(loss)
        accuracy_metric.update(outputs, targets)
        ETA.step()
        speed_tester.update(inputs)

        if iter_ % log_interval == 0 or iter_ == len(loader):
            _logger.info(", ".join([
                "TRAIN",
                f"epoch={epoch:04d}",
                f"iter={iter_:05d}/{len(loader):05d}",
                f"fetch data time cost={time_cost*1000:.2f}ms",
                f"fps={speed_tester.compute()*world_size():.0f} images/s",
                f"{loss_metric}",
                f"{accuracy_metric}",
                f"{ETA}",
            ]))
            speed_tester.reset()

    return {
        "lr": lr,
        "train/loss": loss_metric.compute(),
        "train/top1_acc": accuracy_metric.at(1).rate,
        "train/top5_acc": accuracy_metric.at(5).rate,
    }


def train_model(
    train_dl: data.DataLoader,
    dev_dl: data.DataLoader,
    model: nn.Module,
    optimizer: optim.Optimizer,
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    args: argparse.Namespace,
) -> nn.Module:

    device = model_utils.get_device()
    # loss_fn = nn.functional.binary_cross_entropy
    loss_fn = model_utils.l1_norm_loss
    val_loss_fn = model_utils.l1_norm_loss
    best_val_loss = torch.tensor(float('inf'))
    saved_checkpoints = []
    writer = SummaryWriter(log_dir=f'{args.log_dir}/{args.experiment}')
    scalar_rand = torch.distributions.uniform.Uniform(0.5, 1.5)

    for e in range(1, args.train_epochs + 1):
        print(f'Training epoch {e}...')

        # Training portion
        torch.cuda.empty_cache()
        with tqdm(total=args.train_batch_size * len(train_dl)) as progress_bar:
            model.train()
            for i, (x_batch, y_batch_biden, y_batch_trump,
                    _) in enumerate(train_dl):
                # trump_scale = scalar_rand.sample()
                # biden_scale = scalar_rand.sample()
                # y_batch_biden = y_batch_biden * biden_scale
                # y_batch_trump = y_batch_trump * trump_scale
                # x_batch = (y_batch_trump + y_batch_biden).abs().to(device)
                x_batch = x_batch.abs().to(device)
                y_batch_biden = y_batch_biden.abs().to(device)
                y_batch_trump = y_batch_trump.abs().to(device)

                # Forward pass on model
                optimizer.zero_grad()
                y_pred_b, y_pred_t = model(x_batch)
                if args.train_trump:
                    # loss = loss_fn(y_pred_t * x_batch, y_batch_trump)
                    loss = loss_fn(y_pred_t, y_batch_trump)
                else:
                    # loss = loss_fn(y_pred_b * x_batch, y_batch_biden)
                    loss = loss_fn(y_pred_b, y_batch_biden)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()
                if args.use_scheduler:
                    lr_scheduler.step(loss)

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(loss=loss.item())
                writer.add_scalar("train/Loss", loss,
                                  ((e - 1) * len(train_dl) + i) *
                                  args.train_batch_size)

                del x_batch
                del y_batch_biden
                del y_batch_trump
                del y_pred_b
                del y_pred_t
                del loss

        # Validation portion
        torch.cuda.empty_cache()
        with tqdm(total=args.val_batch_size * len(dev_dl)) as progress_bar:
            model.eval()
            val_loss = 0.0
            num_batches_processed = 0
            for i, (x_batch, y_batch_biden, y_batch_trump,
                    _) in enumerate(dev_dl):
                x_batch = x_batch.abs().to(device)
                y_batch_biden = y_batch_biden.abs().to(device)
                y_batch_trump = y_batch_trump.abs().to(device)

                # Forward pass on model
                y_pred_b, y_pred_t = model(x_batch)
                # y_pred_b_mask = torch.ones_like(y_pred_b) * (y_pred_b > args.alpha)
                # y_pred_t_mask = torch.ones_like(y_pred_t) * (y_pred_t > args.alpha)
                y_pred_b_mask = torch.clamp(y_pred_b / x_batch, 0, 1)
                y_pred_t_mask = torch.clamp(y_pred_t / x_batch, 0, 1)

                loss_trump = val_loss_fn(y_pred_t_mask * x_batch,
                                         y_batch_trump)
                loss_biden = val_loss_fn(y_pred_b_mask * x_batch,
                                         y_batch_biden)

                if args.train_trump:
                    val_loss += loss_trump.item()
                else:
                    val_loss += loss_biden.item()
                num_batches_processed += 1

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(val_loss=val_loss /
                                         num_batches_processed)
                writer.add_scalar("Val/Biden Loss", loss_biden,
                                  ((e - 1) * len(dev_dl) + i) *
                                  args.val_batch_size)
                writer.add_scalar("Val/Trump Loss", loss_trump,
                                  ((e - 1) * len(dev_dl) + i) *
                                  args.val_batch_size)

                del x_batch
                del y_batch_biden
                del y_batch_trump
                del y_pred_b
                del y_pred_t
                del loss_trump
                del loss_biden

            # Save model if it's the best one yet.
            if val_loss / num_batches_processed < best_val_loss:
                best_val_loss = val_loss / num_batches_processed
                filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_best_val.checkpoint'
                model_utils.save_model(model, filename)
                print(f'Model saved!')
                print(f'Best validation loss yet: {best_val_loss}')
            # Save model on checkpoints.
            if e % args.checkpoint_freq == 0:
                filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_epoch_{e}.checkpoint'
                model_utils.save_model(model, filename)
                print(f'Model checkpoint reached!')
                saved_checkpoints.append(filename)
                # Delete checkpoints if there are too many
                while len(saved_checkpoints) > args.num_checkpoints:
                    os.remove(saved_checkpoints.pop(0))

    return model