def do_epoch(args: argparse.Namespace,
             train_loader: torch.utils.data.DataLoader, model: DDP,
             optimizer: torch.optim.Optimizer,
             scheduler: torch.optim.lr_scheduler._LRScheduler, epoch: int,
             callback: VisdomLogger, iter_per_epoch: int,
             log_iter: int) -> Tuple[torch.Tensor, torch.Tensor]:
    loss_meter = AverageMeter()
    train_losses = torch.zeros(log_iter).to(dist.get_rank())
    train_mIous = torch.zeros(log_iter).to(dist.get_rank())

    iterable_train_loader = iter(train_loader)

    if main_process(args):
        bar = tqdm(range(iter_per_epoch))
    else:
        bar = range(iter_per_epoch)

    for i in bar:
        model.train()
        current_iter = epoch * len(train_loader) + i + 1

        images, gt = next(iterable_train_loader)
        images = images.to(dist.get_rank(), non_blocking=True)
        gt = gt.to(dist.get_rank(), non_blocking=True)

        loss = compute_loss(
            args=args,
            model=model,
            images=images,
            targets=gt.long(),
            num_classes=args.num_classes_tr,
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if args.scheduler == 'cosine':
            scheduler.step()

        if i % args.log_freq == 0:
            model.eval()
            with torch.no_grad():
                logits = model(images)
            intersection, union, target = intersectionAndUnionGPU(
                logits.argmax(1), gt, args.num_classes_tr, 255)
            if args.distributed:
                dist.all_reduce(loss)
                dist.all_reduce(intersection)
                dist.all_reduce(union)
                dist.all_reduce(target)

            allAcc = (intersection.sum() / (target.sum() + 1e-10))  # scalar
            mAcc = (intersection / (target + 1e-10)).mean()
            mIoU = (intersection / (union + 1e-10)).mean()
            loss_meter.update(loss.item() / dist.get_world_size())

            if main_process(args):
                if callback is not None:
                    t = current_iter / len(train_loader)
                    callback.scalar('loss_train_batch',
                                    t,
                                    loss_meter.avg,
                                    title='Loss')
                    callback.scalars(['mIoU', 'mAcc', 'allAcc'],
                                     t, [mIoU, mAcc, allAcc],
                                     title='Training metrics')
                    for index, param_group in enumerate(
                            optimizer.param_groups):
                        lr = param_group['lr']
                        callback.scalar('lr', t, lr, title='Learning rate')
                        break

                train_losses[int(i / args.log_freq)] = loss_meter.avg
                train_mIous[int(i / args.log_freq)] = mIoU

    if args.scheduler != 'cosine':
        scheduler.step()

    return train_mIous, train_losses
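
The loop above assumes an `AverageMeter` helper (along with `main_process`, `compute_loss` and `intersectionAndUnionGPU` utilities) defined elsewhere in the repository. A minimal sketch of the meter, matching only the `update()` / `.avg` usage seen here, could be:

class AverageMeter:
    """Running average of a scalar value (e.g. the per-batch loss)."""

    def __init__(self):
        self.sum = 0.0    # accumulated total
        self.count = 0    # number of updates
        self.avg = 0.0    # current running average

    def update(self, value, n=1):
        # `value` is assumed to already be averaged over `n` samples
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count
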
Example #2
    def run(self,
            train_loader,
            val_loader,
            optimizer: torch.optim.Optimizer,
            n_epochs,
            scheduler=None,
            scheduler_after_each_batch=False,
            clip_grad_value=None,
            callbacks: list = None,
            metrics: OrderedDict = None):
        callbacks = [] if callbacks is None else callbacks
        metrics = OrderedDict([("loss", self.criterion)
                               ]) if metrics is None else metrics

        self.model.to(self.device)
        optimizer.zero_grad()
        self.callbacks_on_training_begin(callbacks)

        self.state.optimizer = optimizer
        self.state.scheduler = scheduler

        self.state.end_epoch = self.state.epoch + n_epochs
        for epoch in range(self.state.epoch, self.state.end_epoch):
            self.state.epoch = epoch

            self.callbacks_on_epoch_begin(callbacks)

            train_metrics_evaluator = SupervisedMetricsEvaluator(metrics)
            losses = []
            self.model.train(True)
            start_time = time.time()
            for iteration, (inp, target) in enumerate(train_loader):
                self.state.iteration = iteration

                inp, target = inp.to(self.device), target.to(self.device)
                output = self.model(inp)
                train_metrics_evaluator.process_batch(output.detach(),
                                                      target.detach())
                loss = self.criterion(output, target)

                losses.append(loss.item())

                loss.backward()
                if clip_grad_value is not None:
                    # note: despite the parameter name, this clips by gradient *norm*
                    for group in optimizer.param_groups:
                        group_params = group['params']
                        torch.nn.utils.clip_grad_norm_(group_params,
                                                       clip_grad_value,
                                                       norm_type=2)
                optimizer.step()
                optimizer.zero_grad()
                if scheduler_after_each_batch and scheduler is not None:
                    scheduler.step()
            end_time = time.time()
            self.state.elapsed = end_time - start_time
            if scheduler is not None and not scheduler_after_each_batch:
                scheduler.step()
            self.model.train(False)

            self.state.run_avg_loss = np.mean(losses)

            # print(losses)
            self.state.metrics_per_category[
                "train"] = train_metrics_evaluator.compute_result_and_reset(
                    True)
            self.state.metrics_per_category["valid"] = run_supervised_metrics(
                self.model, metrics, val_loader, self.device)
            self.callbacks_on_epoch_end(callbacks)

        self.callbacks_on_training_end(callbacks)
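
`SupervisedMetricsEvaluator` and `run_supervised_metrics` are project helpers that are not shown here. A hedged sketch of the evaluator, reproducing only the `process_batch()` / `compute_result_and_reset()` interface used in `run` (the real class may track metrics differently):

from collections import OrderedDict

import torch


class SupervisedMetricsEvaluator:
    """Accumulates per-batch metric values and reports their mean."""

    def __init__(self, metrics: OrderedDict):
        self.metrics = metrics
        self._sums = {name: 0.0 for name in metrics}
        self._count = 0

    def process_batch(self, output: torch.Tensor, target: torch.Tensor):
        for name, fn in self.metrics.items():
            self._sums[name] += float(fn(output, target))
        self._count += 1

    def compute_result_and_reset(self, verbose: bool = False):
        result = {name: s / max(self._count, 1) for name, s in self._sums.items()}
        if verbose:
            print(result)
        self._sums = {name: 0.0 for name in self.metrics}
        self._count = 0
        return result
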
Example #3
def train(model: torch.nn.Module,
          train_dl: DataLoader,
          optimizer: torch.optim.Optimizer,
          scheduler: LambdaLR,
          validation_evaluator: ClassificationEvaluator,
          n_epochs: int,
          device: AnyStr,
          log_interval: int = 1,
          patience: int = 10,
          neg_class_weight: float = None,
          model_dir: str = "local",
          split: str = '') -> float:
    best_loss = float('inf')
    patience_counter = 0
    best_f1 = 0.0
    weights_found = False
    loss_fn = torch.nn.CrossEntropyLoss(
        weight=torch.tensor([neg_class_weight, 1.]).to(device))

    # Main loop
    for ep in range(n_epochs):
        # Training loop
        for i, batch in enumerate(tqdm(train_dl)):
            model.train()
            optimizer.zero_grad()
            batch = tuple(t.to(device) for t in batch)
            input_ids = batch[0]
            masks = batch[1]
            labels = batch[2]

            (logits, ) = model(input_ids, attention_mask=masks)
            loss = loss_fn(logits.view(-1, 2), labels.view(-1))

            loss.backward()
            optimizer.step()
            scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)

        # Saving the best model and early stopping
        if F1 > best_f1:
            weights_found = True
            best_model = model.state_dict()
            # best_loss = val_loss
            best_f1 = F1
            torch.save(model.state_dict(), f'{model_dir}/model_{split}.pth')
            patience_counter = 0
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

    if not weights_found:
        print("No good weights found, saving weights from last epoch")
        # Save one just in case
        torch.save(model.state_dict(), f'{model_dir}/model_{split}.pth')

    gc.collect()
    return best_f1
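
Note that `neg_class_weight` defaults to `None`, which would make `torch.tensor([neg_class_weight, 1.])` fail, so a value has to be supplied. One common (hypothetical, not necessarily the authors') way to derive it from the training labels is to balance the two classes:

import torch


def negative_class_weight(labels: torch.Tensor) -> float:
    """Weight for class 0 so that both classes contribute equally to the loss.

    Assumes binary labels in {0, 1}; illustrative only.
    """
    n_pos = int((labels == 1).sum())
    n_neg = int((labels == 0).sum())
    # scale the negative class so its total weight matches the positive class
    return n_pos / max(n_neg, 1)
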
Example #4
def train_model(model: nn.Module,
                train_loader: DataLoader,
                valid_loader: DataLoader,
                optimizer: torch.optim.Optimizer,
                criterion,
                scheduler,
                gpu: bool = False,
                epochs: int = 5):
    losses = []
    metrics = {
        'train_accuracy': [],
        'valid_accuracy': [],
        'train_f1': [],
        'valid_f1': []
    }

    best_valid_loss = np.inf

    for n_epoch in range(epochs):

        train_losses = []
        train_preds = []
        train_targets = []
        valid_losses = []
        valid_preds = []
        valid_targets = []

        progress_bar = tqdm(total=len(train_loader.dataset),
                            desc=f'Epoch {n_epoch + 1} of {epochs}')

        model.train()

        for batch_index, (x, y) in enumerate(train_loader):

            if gpu:
                # `device` is assumed to be defined at module level,
                # e.g. device = torch.device('cuda')
                x = x.to(device)
                y = y.to(device)

            x = x.view(x.shape[0], -1)

            optimizer.zero_grad()

            preds = model(x)

            loss = criterion(preds, y)

            loss.backward()

            optimizer.step()

            if batch_index % 100 == 0:
                scheduler.step()
#                 print(scheduler.get_last_lr()) # debug

            train_losses.append(loss.item())
            losses.append(loss.item())

            train_preds.append(torch.argmax(preds, dim=1).view(-1, 1).cpu().numpy())
            train_targets.append(y.cpu().numpy())

            progress_bar.set_postfix(train_loss=np.mean(losses[-500:]))

            progress_bar.update(x.shape[0])

        progress_bar.close()

        model.eval()

        for x, y in valid_loader:
            if gpu:
                x = x.to(device)
                y = y.to(device)

            x = x.view(x.shape[0], -1)

            with torch.no_grad():
                preds = model(x)

            valid_preds.append(torch.argmax(preds, dim=1).view(-1, 1).cpu().numpy())
            valid_targets.append(y.cpu().numpy())

            loss = criterion(preds, y)

            valid_losses.append(loss.item())

        print(50 * '-')
        print(f'Epoch {n_epoch + 1} results')
        print(
            f'Mean train loss: {np.mean(train_losses):.3f}, Mean valid loss: {np.mean(valid_losses):.3f}'
        )

        train_accuracy = accuracy_score(np.concatenate(train_targets),
                                        np.concatenate(train_preds))
        valid_accuracy = accuracy_score(np.concatenate(valid_targets),
                                        np.concatenate(valid_preds))

        train_f1 = f1_score(np.concatenate(train_targets),
                            np.concatenate(train_preds),
                            average='weighted')

        valid_f1 = f1_score(np.concatenate(valid_targets),
                            np.concatenate(valid_preds),
                            average='weighted')

        metrics['train_accuracy'].append(train_accuracy)
        metrics['valid_accuracy'].append(valid_accuracy)
        metrics['train_f1'].append(train_f1)
        metrics['valid_f1'].append(valid_f1)

        print(
            f'Train accuracy: {train_accuracy:.3f}, valid accuracy: {valid_accuracy:.3f}'
        )
        print(
            f'Train weighted F1: {train_f1:.3f}, valid weighted F1: {valid_f1:.3f}'
        )
        print(50 * '-')

    return model, losses, metrics, np.concatenate(valid_preds), np.concatenate(
        valid_targets)
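
`train_model` flattens every batch with `x.view(x.shape[0], -1)`, so it expects a plain feed-forward classifier. A minimal, hypothetical set of arguments that satisfies the interface (the input size and class count below are placeholders):

import torch
from torch import nn

model = nn.Sequential(               # e.g. 28x28 grayscale images -> 10 classes
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# the loop above calls scheduler.step() every 100 batches
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

# model, losses, metrics, preds, targets = train_model(
#     model, train_loader, valid_loader, optimizer, criterion, scheduler,
#     gpu=False, epochs=5)
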
Example #5
def train_one_epoch(model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for samples, targets in metric_logger.log_every(data_loader, print_freq,
                                                    header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                     if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f'{k}_unscaled': v
            for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value,
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
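
`utils.reduce_dict` is not shown above; its role is to average the per-loss values over all distributed processes so the logged numbers agree across ranks. A sketch consistent with that usage (details may differ from the project's own helper):

import torch
import torch.distributed as dist


def reduce_dict(input_dict, average=True):
    """All-reduce the values of a dict of scalar tensors across processes."""
    world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = sorted(input_dict.keys())  # keep key order identical on every rank
        values = torch.stack([input_dict[k] for k in names])
        dist.all_reduce(values)
        if average:
            values /= world_size
        return dict(zip(names, values))
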
Example #6
def train(
    dataset: torch.utils.data.Dataset,
    autoencoder: torch.nn.Module,
    epochs: int,
    batch_size: int,
    optimizer: torch.optim.Optimizer,
    scheduler: Any = None,
    validation: Optional[torch.utils.data.Dataset] = None,
    corruption: Optional[float] = None,
    cuda: bool = True,
    sampler: Optional[torch.utils.data.sampler.Sampler] = None,
    silent: bool = False,
    update_freq: Optional[int] = 1,
    update_callback: Optional[Callable[[int, float, float, float], None]] = None,
    num_workers: Optional[int] = None,
    epoch_callback: Optional[Callable[[int, torch.nn.Module], None]] = None
) -> None:
    """
    Function to train an autoencoder using the provided dataset. If the dataset consists of 2-tuples or lists of
    (feature, prediction), then the prediction is stripped away.

    :param dataset: training Dataset, consisting of tensors shape [batch_size, features]
    :param autoencoder: autoencoder to train
    :param epochs: number of training epochs
    :param batch_size: batch size for training
    :param optimizer: optimizer to use
    :param scheduler: scheduler to use, or None to disable, defaults to None
    :param corruption: proportion of masking corruption to apply, set to None to disable, defaults to None
    :param validation: instance of Dataset to use for validation, set to None to disable, defaults to None
    :param cuda: whether CUDA is used, defaults to True
    :param sampler: sampler to use in the DataLoader, set to None to disable, defaults to None
    :param silent: set to True to prevent printing out summary statistics, defaults to False
    :param update_freq: frequency of batches with which to update counter, set to None disables, default 1
    :param update_callback: optional function of epoch, learning rate, training loss and validation loss
    :param num_workers: optional number of workers to use for data loading
    :param epoch_callback: optional function of epoch and model
    :return: None
    """
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=False,
        sampler=sampler,
        shuffle=True if sampler is None else False,
        num_workers=num_workers if num_workers is not None else 0)
    if validation is not None:
        validation_loader = DataLoader(validation,
                                       batch_size=batch_size,
                                       pin_memory=False,
                                       sampler=None,
                                       shuffle=False)
    else:
        validation_loader = None
    loss_function = nn.MSELoss()
    autoencoder.train()
    validation_loss_value = -1
    loss_value = 0
    for epoch in range(epochs):
        if scheduler is not None:
            scheduler.step()
        data_iterator = tqdm(
            dataloader,
            leave=True,
            unit='batch',
            postfix={
                'epo': epoch,
                'lss': '%.6f' % 0.0,
                'vls': '%.6f' % -1,
            },
            disable=silent,
        )
        for index, batch in enumerate(data_iterator):
            if isinstance(batch, (tuple, list)) and len(batch) in [1, 2]:
                batch = batch[0]
            if cuda:
                batch = batch.cuda(non_blocking=True)
            # run the batch through the autoencoder and obtain the output
            if corruption is not None:
                output = autoencoder(F.dropout(batch, corruption))
            else:
                output = autoencoder(batch)
            loss = loss_function(output, batch)
            # accuracy = pretrain_accuracy(output, batch)
            loss_value = float(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step(closure=None)
            data_iterator.set_postfix(
                epo=epoch,
                lss='%.6f' % loss_value,
                vls='%.6f' % validation_loss_value,
            )
        if update_freq is not None and epoch % update_freq == 0:
            if validation_loader is not None:
                validation_output = predict(validation,
                                            autoencoder,
                                            batch_size,
                                            cuda=cuda,
                                            silent=True,
                                            encode=False)
                validation_inputs = []
                for val_batch in validation_loader:
                    if (isinstance(val_batch, tuple) or isinstance(
                            val_batch, list)) and len(val_batch) in [1, 2]:
                        validation_inputs.append(val_batch[0])
                    else:
                        validation_inputs.append(val_batch)
                validation_actual = torch.cat(validation_inputs)
                if cuda:
                    validation_actual = validation_actual.cuda(
                        non_blocking=True)
                    validation_output = validation_output.cuda(
                        non_blocking=True)
                validation_loss = loss_function(validation_output,
                                                validation_actual)
                # validation_accuracy = pretrain_accuracy(validation_output, validation_actual)
                validation_loss_value = float(validation_loss.item())
                data_iterator.set_postfix(
                    epo=epoch,
                    lss='%.6f' % loss_value,
                    vls='%.6f' % validation_loss_value,
                )
                autoencoder.train()
            else:
                validation_loss_value = -1
                # validation_accuracy = -1
                data_iterator.set_postfix(
                    epo=epoch,
                    lss='%.6f' % loss_value,
                    vls='%.6f' % -1,
                )
            if update_callback is not None:
                update_callback(epoch, optimizer.param_groups[0]['lr'],
                                loss_value, validation_loss_value)
        if epoch_callback is not None:
            autoencoder.eval()
            epoch_callback(epoch, autoencoder)
            autoencoder.train()
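
The autoencoder itself only has to map a batch back to its own shape, since the loss is a plain `nn.MSELoss()` against the (optionally corrupted) input. An illustrative model and call, with placeholder dimensions:

import torch
from torch import nn


class TinyAutoencoder(nn.Module):
    """Placeholder encoder/decoder pair: 784 -> 32 -> 784."""

    def __init__(self, in_features: int = 784, hidden: int = 32):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        self.decoder = nn.Linear(hidden, in_features)

    def forward(self, x):
        return self.decoder(self.encoder(x))


# autoencoder = TinyAutoencoder()
# optimizer = torch.optim.SGD(autoencoder.parameters(), lr=0.1, momentum=0.9)
# train(dataset, autoencoder, epochs=20, batch_size=256, optimizer=optimizer,
#       corruption=0.2, cuda=False)
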
Example #7
def train(model: torch.nn.Module,
          train_dls: List[DataLoader],
          optimizer: torch.optim.Optimizer,
          scheduler: LambdaLR,
          validation_evaluator: MultiDatasetClassificationEvaluator,
          n_epochs: int,
          device: AnyStr,
          log_interval: int = 1,
          patience: int = 10,
          model_dir: str = "wandb_local",
          gradient_accumulation: int = 1,
          domain_name: str = ''):
    #best_loss = float('inf')
    best_acc = 0.0
    patience_counter = 0

    epoch_counter = 0
    total = sum(len(dl) for dl in train_dls)

    # Main loop
    while epoch_counter < n_epochs:
        dl_iters = [iter(dl) for dl in train_dls]
        dl_idx = list(range(len(dl_iters)))
        finished = [0] * len(dl_iters)
        i = 0
        with tqdm(total=total, desc="Training") as pbar:
            while sum(finished) < len(dl_iters):
                random.shuffle(dl_idx)
                for d in dl_idx:
                    domain_dl = dl_iters[d]
                    batches = []
                    try:
                        for j in range(gradient_accumulation):
                            batches.append(next(domain_dl))
                    except StopIteration:
                        finished[d] = 1
                        if len(batches) == 0:
                            continue
                    optimizer.zero_grad()
                    for batch in batches:
                        model.train()
                        batch = tuple(t.to(device) for t in batch)
                        input_ids = batch[0]
                        masks = batch[1]
                        labels = batch[2]
                        # Null the labels if its the test data
                        if d == len(train_dls) - 1:
                            labels = None
                        # Testing with random domains to see if any effect
                        #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
                        domains = batch[3]

                        loss, logits = model(input_ids,
                                             attention_mask=masks,
                                             domains=domains,
                                             labels=labels)
                        loss = loss.mean() / gradient_accumulation

                        if i % log_interval == 0:
                            wandb.log({"Loss": loss.item()})

                        loss.backward()
                        i += 1
                        pbar.update(1)

                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation acc: {acc}")

        #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')

        # Saving the best model and early stopping
        #if val_loss < best_loss:
        if acc > best_acc:
            best_model = model.state_dict()
            #best_loss = val_loss
            best_acc = acc
            #wandb.run.summary['best_validation_loss'] = best_loss
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth'
            )
            patience_counter = 0
            # Log to wandb
            wandb.log({
                'Validation accuracy': acc,
                'Validation Precision': P,
                'Validation Recall': R,
                'Validation F1': F1,
                'Validation loss': val_loss
            })
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1
Example #8
def train_siamese(
    model: Module,
    optimiser: torch.optim.Optimizer,
    criterion: callable,
    *,
    writer: Writer = MockWriter(),
    train_number_epochs: int,
    data_dir: Path,
    train_batch_size: int,
    model_name: str,
    save_path: Path,
    save_best: bool = False,
    img_size: Tuple[int, int],
    validation_interval: int = 1,
):
    """
    :param img_size:
    :type img_size:
    :param validation_interval:
    :type validation_interval:
    :param data_dir:
    :type data_dir:
    :param optimiser:
    :type optimiser:
    :param criterion:
    :type criterion:
    :param writer:
    :type writer:
    :param model_name:
    :type model_name:
    :param save_path:
    :type save_path:
    :param save_best:
    :type save_best:
    :param model:
    :type model:
    :param train_number_epochs:
    :type train_number_epochs:
    :param train_batch_size:
    :type train_batch_size:
    :return:
    :rtype:"""

    train_dataloader = DataLoader(
        PairDataset(
            data_path=data_dir,
            transform=transforms.Compose(
                [
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]
            ),
            split=SplitEnum.training,
        ),
        shuffle=True,
        num_workers=0,
        batch_size=train_batch_size,
    )

    valid_dataloader = DataLoader(
        PairDataset(
            data_path=data_dir,
            transform=transforms.Compose(
                [
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]
            ),
            split=SplitEnum.validation,
        ),
        shuffle=True,
        num_workers=0,
        batch_size=train_batch_size,
    )

    best = math.inf

    E = tqdm(range(0, train_number_epochs))
    batch_counter = count()

    for epoch in E:
        for tss in train_dataloader:
            batch_i = next(batch_counter)
            with TorchTrainSession(model):
                o = [t.to(global_torch_device()) for t in tss]
                optimiser.zero_grad()
                loss_contrastive = criterion(model(*o[:2]), o[2].to(dtype=torch.float))
                loss_contrastive.backward()
                optimiser.step()
                train_loss = loss_contrastive.cpu().item()
                writer.scalar("train_loss", train_loss, batch_i)
            if batch_i % validation_interval == 0:
                with TorchEvalSession(model):
                    for tsv in valid_dataloader:
                        ov = [t.to(global_torch_device()) for t in tsv]
                        v_o, fact = model(*ov[:2]), ov[2].to(dtype=torch.float)
                        valid_loss = criterion(v_o, fact).cpu().item()
                        valid_accuracy = (
                            accuracy(distances=v_o, is_diff=fact).cpu().item()
                        )
                        writer.scalar("valid_loss", valid_loss, batch_i)
                        if valid_loss < best:
                            best = valid_loss
                            print(f"new best {best}")
                            writer.blip("new_best", batch_i)
                            if save_best:
                                save_model_parameters(
                                    model,
                                    optimiser=optimiser,
                                    model_name=model_name,
                                    save_directory=save_path,
                                )
            E.set_description(
                f"Epoch number {epoch}, Current train loss {train_loss}, valid loss {valid_loss}, valid_accuracy "
                f"{valid_accuracy}"
            )

    return model
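
The `accuracy(distances=..., is_diff=...)` helper used during validation is not defined above. A hedged sketch, assuming the model output is one distance per pair, a label of 1.0 marks a dissimilar pair, and an arbitrary 0.5 decision threshold:

import torch


def accuracy(*, distances: torch.Tensor, is_diff: torch.Tensor,
             threshold: float = 0.5) -> torch.Tensor:
    """Fraction of pairs whose thresholded distance agrees with the label."""
    predictions = (distances > threshold).to(dtype=torch.float)
    return (predictions.squeeze() == is_diff.squeeze()).to(dtype=torch.float).mean()
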
Example #9
def train(
    model: nn.Module,
    model_name: str,
    EPOCH,
    optimizer: torch.optim.Optimizer,
    loss_func: nn.Module,
    exp_lr_scheduler,
    data_set: dict,
    data_loader: dict,
    test_result_output_func,
    save_dir,
    cuda_mode=None,
    print_inter=2,
    val_inter=50,
    scheduler_step_inter=100,
):
    """
        通用训练函数
    将训练的过程与模型以及训练所需的内容解耦合
    只需给出所需的一系列东西即可训练
    :param model: 模型
    :param model_name:
    :param EPOCH:  epoch总数
    :param optimizer: 优化器
    :param loss_func: 损失函数 nn.XX Loss
    :param exp_lr_scheduler: LR 规划器
    :param data_set: 数据集对象 继承自torch中DateSet对象,
                    Dataloader从中load数据并在训练中进行feed
    :param data_loader: torch的data loader
    :param test_result_output_func: how to print the test result after test
    :param save_dir: log及model param保存位置
    :param cuda_mode: use which GPU for train
                      give GPU ID, if is None use CPU
    :param print_inter: 输出loss的epoch间隔
    :param val_inter: 进行测试的epoch间隔
    :param scheduler_step_inter how many epoch passed let scheduler step once

    :return:
    """
    if cuda_mode is not None:
        torch.cuda.set_device(cuda_mode)
        model.cuda(cuda_mode)
    else:
        model.cpu()
    start_time_raw = time.time()
    start_time = time.strftime('%H:%M:%S', time.localtime(start_time_raw))
    print('start_at: %s' % start_time)

    # start training
    # epoch: one full pass over all the training data counts as one epoch
    accuracy_res = ""
    curr_step_inter = scheduler_step_inter
    try:
        for epoch in range(EPOCH + 1):
            loss_his = []
            if epoch % curr_step_inter == 0 and epoch != 0:
                curr_step_inter = int(curr_step_inter * 1.5)
                exp_lr_scheduler.step()
            # disable the size up the batch size
            # if (epoch % int(scheduler_step_inter*1.8) ) == 0 and \
            #     epoch != 0 and data_loader['train'].batch_size < 512:
            #
            #         data_loader['train'] = DataLoader.DataLoader(data_loader['train'].dataset,
            #                                                      shuffle=True,
            #                                                      batch_size=data_loader['train'].batch_size * 2,
            #                                                      num_workers=1)

            for batch_x, batch_y in data_loader['train']:
                if model_name.startswith("cnn"):
                    batch_x = batch_x.cpu()
                else:
                    batch_x = [each.cpu() for each in batch_x]
                batch_y = batch_y.cpu()

                if cuda_mode is not None:
                    if model_name.startswith("cnn"):
                        batch_x = batch_x.cuda()
                    else:
                        batch_x = [each.cuda() for each in batch_x]
                    batch_y = batch_y.cuda()

                # in siamese training mode there may be two inputs and two
                # outputs, so they need to be wrapped in a tuple/list
                if model_name.startswith("cnn"):
                    batch_out = model(batch_x)
                    batch_out = torch.squeeze(batch_out)
                    loss = loss_func(batch_out, batch_y)
                else:
                    batch_out = model(*batch_x)
                    if model_name.startswith('verify'):
                        batch_y = batch_y.float()
                        loss = loss_func(batch_out[0], batch_out[1], batch_y)
                    else:
                        # print(batch_out)
                        # print(batch_y)
                        loss = loss_func(batch_out, batch_y)

                loss_his.append(loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_val = np.mean(np.array(loss_his))

            if epoch % print_inter == 0:
                print("epoch %d loss %s" % (epoch, loss_val))

            if epoch % val_inter == 0:

                # start testing
                model.eval()
                model.cpu()
                # switch to evaluation mode
                test_result_list = []
                for test_x, target_y in data_loader['test']:
                    if model_name.startswith("cnn"):
                        test_x = test_x.cpu()
                    else:
                        test_x = [each.cpu() for each in test_x]
                    target_y = target_y.cpu()
                    if model_name.startswith("cnn"):
                        test_output = model(test_x).cpu()
                    else:
                        test_output = model(*test_x)

                    # in the cnn model: max-probability category label vs. ground-truth label
                    # in the verify model: dissimilarity score vs. ground-truth result

                    # only the classification mode needs the argmax index

                    if not model_name.startswith("verify"):
                        test_output = get_max_index(test_output)
                        test_output = test_output.item()

                    target_y = target_y.item()  # extract the scalar value from the tensor
                    test_result_list.append((target_y, test_output))

                accuracy_res = test_result_output_func(test_result_list,
                                                       epoch=epoch,
                                                       loss=loss_val)
                model.train()
                if cuda_mode is not None:
                    model.cuda()

    except KeyboardInterrupt:
        print("stop train\n save model ?")
        res = input()
        if res != 'y':
            return

    end_time_raw = time.time()
    end_time = time.strftime('%H:%M:%S', time.localtime(end_time_raw))
    print('end_at: %s' % end_time)

    cost_time = end_time_raw - start_time_raw
    cost_time = time.strftime('%H:%M:%S', time.gmtime(cost_time, ))
    print('cost time: %s' % cost_time)

    end_time = time.strftime('%m-%d,%H-%M', time.localtime(end_time_raw))
    model = model.cpu()
    # save all model info
    torch.save(
        model.state_dict(),
        os.path.join(save_dir, '%s_model%s.pkl' % (model_name, end_time)))

    file = open(
        os.path.join(save_dir,
                     '%s_models_info_%s.txt' % (model_name, end_time)), 'w')
    info = 'data_set_size:%d\n' % len(data_set['train']) + \
           str(accuracy_res) + \
           'loss: %f\n' % loss_val + \
           'Epoch: %d\n' % EPOCH
    info += str(model)

    file.writelines(info)
    file.close()
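
Since the function decouples the loop from the model, calling it only requires dicts keyed by 'train'/'test' plus a result-formatting callback (it also depends on a `get_max_index` helper defined elsewhere). A hedged usage sketch with synthetic stand-in data; every name below is a placeholder, not from the original project:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

features = torch.randn(128, 28 * 28)
labels = torch.randint(0, 4, (128,))
data_set = {'train': TensorDataset(features, labels),
            'test': TensorDataset(features[:16], labels[:16])}
data_loader = {'train': DataLoader(data_set['train'], batch_size=32, shuffle=True),
               'test': DataLoader(data_set['test'], batch_size=1)}

model = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 4))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)


def print_results(results, epoch, loss):
    acc = sum(int(t == p) for t, p in results) / len(results)
    print(f"epoch {epoch}: test accuracy {acc:.3f}, loss {loss:.4f}")
    return f"accuracy: {acc:.3f}\n"


# train(model, 'cnn_demo', 4, optimizer, nn.CrossEntropyLoss(), scheduler,
#       data_set, data_loader, print_results, save_dir='.')
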
Example #10
def train_REINFORCE(input_dir: str, output_dir: str,
                    dataset: torch.utils.data.Dataset,
                    synthesizer: Synthesizer,
                    model: nn.Module,
                    optimizer: torch.optim.Optimizer,
                    loss: Callable[[Any], torch.Tensor],
                    evaluate: Optional[Callable[[], None]],
                    metric: str,
                    reward: Callable[[Environment, Any], float],
                    collate: Callable[[List[Any]], Any],
                    batch_size: int,
                    n_rollout: int,
                    length: Length,
                    evaluation_interval: Optional[Length] = None,
                    snapshot_interval: Optional[Length] = None,
                    maximize: bool = True,
                    threshold: Optional[float] = None,
                    use_pretrained_model: bool = False,
                    use_pretrained_optimizer: bool = False,
                    n_dataloader_worker: int = 2,
                    device: torch.device = torch.device("cpu")) \
        -> None:
    logger.info("Prepare model")
    model.to(device)
    model.train()

    group = get_world_process_group(device)

    if hasattr(dataset, "__len__"):
        iter_per_epoch = len(dataset) // batch_size
    else:
        iter_per_epoch = 1

    evaluation_interval = evaluation_interval or Epoch(1)
    snapshot_interval = snapshot_interval or Epoch(1)

    n_iter = length.n_iter(iter_per_epoch)
    evaluation_interval_iter = evaluation_interval.n_iter(iter_per_epoch)
    snapshot_interval_iter = snapshot_interval.n_iter(iter_per_epoch)

    if use_pretrained_model:
        logger.info("Load pretrained model")
        pretrained_model = os.path.join(input_dir, "model.pt")
        state_dict = torch.load(pretrained_model,
                                map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
    if use_pretrained_optimizer:
        logger.info("Load pretrained optimizer")
        pretrained_optimizer = os.path.join(input_dir, "optimizer.pt")
        state_dict = torch.load(pretrained_optimizer,
                                map_location=torch.device("cpu"))
        optimizer.load_state_dict(state_dict)

    # Initialize extensions manager
    manager = \
        create_extensions_manager(
            n_iter, evaluation_interval_iter, snapshot_interval_iter,
            iter_per_epoch,
            model, optimizer,
            evaluate, metric, maximize, threshold, output_dir,
            report_metrics=["reward"])

    train_model = setup_distributed_training(model, loss, group)

    logger.info("Start training")
    try:
        while manager.iteration < n_iter:
            loader = create_dataloader(dataset, batch_size,
                                       n_dataloader_worker, lambda x: x)

            for samples in logger.iterable_block("iteration", loader, True):
                if manager.iteration >= n_iter:
                    break
                # Rollout
                rollouts = []
                train_model.train()
                with torch.no_grad():
                    for sample in logger.iterable_block("rollout", samples):
                        sample_inputs = sample.clone_without_supervision()
                        sample_inputs.to(device)
                        for rollout in logger.iterable_block(
                                "sample",
                                synthesizer(sample_inputs,
                                            n_required_output=n_rollout)):
                            if not rollout.is_finished:
                                continue
                            for _ in range(rollout.num):
                                output = sample.clone()
                                output["ground_truth"] = rollout.output
                                output.mark_as_supervision("ground_truth")
                                output["reward"] = \
                                    torch.tensor(reward(sample.clone(), rollout.output))
                                rollouts.append(output)
                if len(rollouts) == 0:
                    logger.warning("No rollout")
                    continue
                if len(rollouts) != n_rollout:
                    logger.warning(
                        "#rollout is unexpected: "
                        f"expected={n_rollout} actual={len(rollouts)}")

                with manager.run_iteration():
                    model.train()
                    with logger.block("collate"):
                        batch2 = collate(rollouts)
                    with logger.block("to"):
                        batch2.to(device)
                    with logger.block("forward"):
                        train_model.train()
                        bloss = train_model(batch2)
                    with logger.block("backward"):
                        optimizer.zero_grad(set_to_none=True)
                        bloss.backward()
                    with logger.block("optimizer.step"):
                        optimizer.step()

                    ppe.reporting.report({"loss": bloss.item()})
                    ppe.reporting.report(
                        {"reward": batch2["reward"].float().mean().item()})
                    logger.dump_elapsed_time_log()
                    if device.type == "cuda":
                        ppe.reporting.report({
                            "gpu.max_memory_allocated":
                            torch.cuda.max_memory_allocated(device)
                        })
    except RuntimeError as e:  # noqa
        logger.critical(traceback.format_exc())

    save_results(output_dir, model, optimizer)
Example #11
def train(net: Connect4Network,
          dataset: np.ndarray,
          optimizer: torch.optim.Optimizer,
          scheduler: torch.optim.lr_scheduler.MultiStepLR,
          start_epoch: int,
          iteration: int,
          arguments: AlphaZeroArgs,
          cpu: int = 0):
    """
    Training function, optimizing the weights of our NeuralNetwork using state, policy and value of the datasets
    @param net: Neural Network
    @param dataset: Dataset generated by Self play
    @param optimizer: Pytorch optimizer
    @param scheduler: Pytorch scheduler
    @param start_epoch: start epoch
    @param iteration: current iteration
    @param arguments: AlphaZeroArgs
    @param cpu: CPU index
    """
    torch.manual_seed(cpu)
    cuda = torch.cuda.is_available()
    # Neural Net in TrainMode
    net.train()

    # Custom Alpha loss function
    criterion = AlphaLoss()

    # Initialize Training Set
    train_set = BoardData(dataset)
    train_loader = DataLoader(train_set,
                              batch_size=arguments.batch_size,
                              shuffle=True,
                              num_workers=0,
                              pin_memory=False)
    losses_per_epoch = load_results(iteration + 1)

    print("Starting training process...")
    if len(train_loader) > 10:
        update_rate = len(train_loader) // 10
    else:
        update_rate = 1
    print("Update step rate: %d" % update_rate)
    for epoch in range(start_epoch, arguments.num_epochs):
        total_loss = 0.0
        losses_per_batch = []
        for i, data in enumerate(train_loader, 0):
            # Training using State, policy, value generated by MCTS
            state, policy, value = data
            state, policy, value = state.float(), policy.float(), value.float()
            # CUDA check
            if cuda:
                state, policy, value = state.cuda(), policy.cuda(), value.cuda()
            policy_pred, value_pred = net(state)
            # AlphaLoss for calculation
            loss = criterion(value_pred[:, 0], value, policy_pred, policy)
            loss = loss / arguments.gradient_acc_steps
            loss.backward()
            clip_grad_norm_(net.parameters(), arguments.max_norm)
            if (i + 1) % arguments.gradient_acc_steps == 0:
                # step the optimizer once gradients have been accumulated
                # over gradient_acc_steps mini-batches
                optimizer.step()
                optimizer.zero_grad()

            # Add current loss to total
            total_loss += loss.item()
            # print every update_rate mini-batches (each of size batch_size)
            if i % update_rate == (update_rate - 1):
                losses_per_batch.append(arguments.gradient_acc_steps *
                                        total_loss / update_rate)
                print(
                    '[Iteration %d] Process ID: %d [Epoch: %d, %5d/ %d points] total loss per batch: %.3f'
                    % (iteration, os.getpid(), epoch + 1,
                       (i + 1) * arguments.batch_size, len(train_set),
                       losses_per_batch[-1]))
                print("Policy (actual, predicted):", policy[0].argmax().item(),
                      policy_pred[0].argmax().item())
                print("Policy data:", policy[0])
                print("Policy pred:", policy_pred[0])
                print("Value (actual, predicted):", value[0].item(),
                      value_pred[0, 0].item())
                # print("Conv grad: %.7f" % net.conv.conv1.weight.grad.mean().item())
                # print("Res18 grad %.7f:" % net.res_18.conv1.weight.grad.mean().item())
                total_loss = 0.0

        scheduler.step()
        if len(losses_per_batch) >= 1:
            losses_per_epoch.append(
                sum(losses_per_batch) / len(losses_per_batch))
        if (epoch % 2) == 0:
            # Save trained Model
            util.pickle_save(util.get_losses_file(iteration + 1),
                             losses_per_epoch)

            torch_dest = util.get_model_file_path(arguments.neural_net_name,
                                                  iteration + 1)
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }, torch_dest)
        '''
        # Early stopping
        if len(losses_per_epoch) > 50:
            if abs(sum(losses_per_epoch[-4:-1])/3-sum(losses_per_epoch[-16:-13])/3) <= 0.00017:
                break
        '''

    print("Finished Training!")
    # Plotting Feature
    fig = plt.figure()
    ax = fig.add_subplot(222)
    ax.scatter(
        [e for e in range(start_epoch, (len(losses_per_epoch) + start_epoch))],
        losses_per_epoch)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss per batch")
    ax.set_title("Loss vs Epoch")

    util.create_model_directory()

    # Save Plot as PNG
    plt.savefig(
        os.path.join(
            util.get_model_directory(),
            f"Loss_vs_Epoch_iter_{iteration + 1}_{datetime.datetime.today().strftime('%Y-%m-%d')}.png"
        ))
    plt.show()
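
`AlphaLoss` is the project's custom criterion and is not shown here; in AlphaZero-style training it combines a mean-squared value error with a policy cross-entropy. A sketch consistent with the call `criterion(value_pred[:, 0], value, policy_pred, policy)` above, assuming `policy_pred` holds probabilities rather than logits:

import torch
from torch import nn


class AlphaLoss(nn.Module):
    """(z - v)^2 - pi^T log(p), averaged over the batch."""

    def forward(self, value_pred, value, policy_pred, policy):
        value_error = (value - value_pred) ** 2
        policy_error = -torch.sum(policy * torch.log(policy_pred + 1e-8), dim=1)
        return (value_error.view(-1) + policy_error).mean()
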
Example #12
def train_supervised(output_dir: str,
                     dataset: torch.utils.data.Dataset,
                     model: nn.Module,
                     optimizer: torch.optim.Optimizer,
                     loss: Callable[[Any], torch.Tensor],
                     evaluate: Optional[Callable[[], None]],
                     metric: str,
                     collate: Callable[[List[Any]], Any],
                     batch_size: int,
                     length: Length,
                     evaluation_interval: Optional[Length] = None,
                     snapshot_interval: Optional[Length] = None,
                     maximize: bool = True,
                     threshold: Optional[float] = None,
                     n_dataloader_worker: int = 1,
                     device: torch.device = torch.device("cpu")) \
        -> None:
    logger.info("Prepare model")
    model.to(device)
    model.train()

    group = get_world_process_group(device)
    global_batch_size = batch_size * distributed.size(group)

    if hasattr(dataset, "__len__"):
        iter_per_epoch = len(dataset) // global_batch_size
    else:
        iter_per_epoch = 1

    evaluation_interval = evaluation_interval or Epoch(1)
    snapshot_interval = snapshot_interval or Epoch(1)

    n_iter = length.n_iter(iter_per_epoch)
    evaluation_interval_iter = evaluation_interval.n_iter(iter_per_epoch)
    snapshot_interval_iter = snapshot_interval.n_iter(iter_per_epoch)

    # Initialize extensions manager
    manager = \
        create_extensions_manager(
            n_iter, evaluation_interval_iter, snapshot_interval_iter,
            iter_per_epoch,
            model, optimizer,
            evaluate, metric, maximize, threshold, output_dir)

    train_model = setup_distributed_training(model, loss, group)

    logger.info("Start training")
    try:
        while manager.iteration < n_iter:
            loader = create_dataloader(dataset, batch_size,
                                       n_dataloader_worker, collate)

            for batch in logger.iterable_block("iteration", loader, True):
                if manager.iteration >= n_iter:
                    break
                if len(batch.to_dict()) == 0:
                    logger.warning(f"Skip {manager.iteration} th batch")
                    continue
                with manager.run_iteration():
                    train_model.train()
                    with logger.block("to"):
                        batch.to(device=device)
                    with logger.block("forward"):
                        bloss = train_model(batch)
                    with logger.block("backward"):
                        optimizer.zero_grad(set_to_none=True)
                        bloss.backward()
                    with logger.block("optimizer.step"):
                        optimizer.step()

                    ppe.reporting.report({"loss": bloss.item()})
                    logger.dump_elapsed_time_log()
                    if device.type == "cuda":
                        ppe.reporting.report({
                            "gpu.max_memory_allocated":
                            torch.cuda.max_memory_allocated(device)
                        })
    except RuntimeError as e:  # noqa
        logger.critical(traceback.format_exc())

    save_results(output_dir, model, optimizer)
Example #13
    def step_optimizer(
        self,
        optimizer: torch.optim.Optimizer,
        clip_grads: Optional[Callable[[Iterator], None]] = None,
        auto_zero_grads: bool = True,
        scaler: Optional[Any] = None,
        # Should be torch.cuda.amp.GradScaler, but:
        #   * other implementations might be possible
        #   * requiring this type forces upgrades to PyTorch 1.6+
    ) -> None:
        """
        Perform a single optimization step.

        This function must be called once for each optimizer. However, the order of
        different optimizers' steps can be specified by calling this function in different
        orders. Also, gradient accumulation across iterations is performed by the Determined
        training loop by setting the experiment configuration field
        :ref:`optimizations.aggregation_frequency <config-aggregation-frequency>`.

        Here is a code example:

        .. code-block:: python

            def clip_grads(params):
                torch.nn.utils.clip_grad_norm_(params, 0.0001)

            self.context.step_optimizer(self.opt1, clip_grads)

        Arguments:
            optimizer(``torch.optim.Optimizer``): Which optimizer should be stepped.
            clip_grads(a function, optional): This function should have one argument for
                parameters in order to clip the gradients.
            auto_zero_grads(bool, optional): Automatically zero out gradients after
                stepping the optimizer. If false, you need to call ``optimizer.zero_grad()``
                manually. Note that if :ref:`optimizations.aggregation_frequency
                <config-aggregation-frequency>` is greater than 1, ``auto_zero_grads`` must be true.
            scaler(``torch.cuda.amp.GradScaler``, optional): The scaler to use for stepping the
                optimizer. This should be unset if not using AMP, and is necessary if
                ``wrap_scaler()`` was called directly.
        """

        check.true(
            auto_zero_grads or self.hvd_config.aggregation_frequency == 1,
            "if optimizations.aggregation_frequency is larger than 1, "
            "you can only set auto_zero_grads to be true. ",
        )

        if not self._should_communicate_and_update():
            return

        # Communication needs to be synchronized so that it is completed
        # before we apply gradient clipping and `step()`. In the case of APEX
        # this is called in backward() instead, so that it's inside the context
        # manager and before unscaling.
        if self.hvd_config.use and not self._use_apex:
            optimizer.synchronize()  # type: ignore

        parameters = ([
            p for group in optimizer.param_groups
            for p in group.get("params", [])
        ] if not self._use_apex else apex.amp.master_params(optimizer))

        if self.hvd_config.average_aggregated_gradients:
            self._average_gradients(
                parameters=parameters,
                divisor=self.hvd_config.aggregation_frequency)

        if clip_grads is not None:
            if self._scaler and self.experimental._auto_amp:
                self._scaler.unscale_(optimizer)
            clip_grads(parameters)

        # For stepping the optimizer we will operate on the scaler passed
        # in, or fall back to the wrapped scaler (if any).
        if scaler is None and self.experimental._auto_amp:
            scaler = self._scaler
        if scaler:

            def step_fn() -> None:
                scaler.step(optimizer)  # type: ignore

        else:
            step_fn = optimizer.step  # type: ignore

        if self.hvd_config.use:
            with optimizer.skip_synchronize():  # type: ignore
                step_fn()
        else:
            step_fn()

        if auto_zero_grads:
            optimizer.zero_grad()
Example #14
def train_one_epoch(
    model: torch.nn.Module,
    criterion: Optional[torch.nn.Module],
    contrastive_criterion: Optional[torch.nn.Module],
    qa_criterion: Optional[torch.nn.Module],
    weight_dict: Dict[str, float],
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    args,
    max_norm: float = 0,
    model_ema: Optional[torch.nn.Module] = None,
):
    model.train()
    if criterion is not None:
        criterion.train()
    if contrastive_criterion is not None:
        contrastive_criterion.train()
    if qa_criterion is not None:
        qa_criterion.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr",
                            SmoothedValue(window_size=1, fmt="{value:.6f}"))
    metric_logger.add_meter("lr_backbone",
                            SmoothedValue(window_size=1, fmt="{value:.6f}"))
    metric_logger.add_meter("lr_text_encoder",
                            SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 10

    num_training_steps = int(len(data_loader) * args.epochs)
    for i, batch_dict in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):
        curr_step = epoch * len(data_loader) + i
        samples = batch_dict["samples"].to(device)
        positive_map = batch_dict["positive_map"].to(
            device) if "positive_map" in batch_dict else None
        targets = batch_dict["targets"]
        answers = {k: v.to(device)
                   for k, v in batch_dict["answers"].items()
                   } if "answers" in batch_dict else None
        captions = [t["caption"] for t in targets]

        targets = targets_to(targets, device)

        memory_cache = None
        if args.masks:
            outputs = model(samples, captions)
        else:
            memory_cache = model(samples, captions, encode_and_save=True)
            outputs = model(samples,
                            captions,
                            encode_and_save=False,
                            memory_cache=memory_cache)

        loss_dict = {}
        if criterion is not None:
            loss_dict.update(criterion(outputs, targets, positive_map))

        if contrastive_criterion is not None:
            assert memory_cache is not None
            contrastive_loss = contrastive_criterion(
                memory_cache["text_pooled_op"], memory_cache["img_pooled_op"])
            loss_dict["contrastive_loss"] = contrastive_loss

        if qa_criterion is not None:
            answer_losses = qa_criterion(outputs, answers)
            loss_dict.update(answer_losses)

        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                     if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = dist.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f"{k}_unscaled": v
            for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        adjust_learning_rate(
            optimizer,
            epoch,
            curr_step,
            num_training_steps=num_training_steps,
            args=args,
        )
        if model_ema is not None:
            update_ema(model, model_ema, args.ema_decay)

        metric_logger.update(loss=loss_value,
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(lr_backbone=optimizer.param_groups[1]["lr"])
        metric_logger.update(lr_text_encoder=optimizer.param_groups[2]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
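
The loop above only optimizes the loss components whose keys appear in weight_dict, and logs both scaled and unscaled values. A minimal, self-contained sketch of that weighting pattern; the keys and weights below are made up for illustration:

import torch

# Hypothetical component losses; in the loop above the keys come from the criterion.
loss_dict = {"loss_ce": torch.tensor(0.7), "loss_bbox": torch.tensor(0.2),
             "contrastive_loss": torch.tensor(1.3), "aux_debug_loss": torch.tensor(5.0)}
weight_dict = {"loss_ce": 1.0, "loss_bbox": 5.0, "contrastive_loss": 0.5}

# Only keys present in weight_dict contribute to the optimized objective.
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)

scaled = {k: v * weight_dict[k] for k, v in loss_dict.items() if k in weight_dict}
unscaled = {f"{k}_unscaled": v for k, v in loss_dict.items()}
print(float(losses), {k: float(v) for k, v in scaled.items()})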
Example #15
def train(dataloader: DataLoader, model: RNN, optimizer: torch.optim.Optimizer,
          loss_function: Union[SplitCrossEntropyLoss, CrossEntropyLoss], use_apex=False, amp=None,
          lr_weights: dict = None, prior: Union[str, VIPrior] = 'ninf', scaling: str = None, total_steps: int = 0, steps: int = 0,
          bptt: int = 125, alpha: float = 0., beta: float = 0., log_interval: int = 200, n_samples: int = 4,
          device: Union[torch.device, str] = 'cpu', tb_writer=None, **kwargs):
    total_loss = 0
    batch = 0

    tr_kl = 0.
    logging_kl = 0.
    tr_loss = 0.
    logging_loss = 0.

    model.train()

    log.info('Starting training loop')
    start_time = time.time()

    with tqdm(dataloader, total=len(dataloader)) as pbar:
        for data, targets, seq_len, lang in pbar:

            data = data.squeeze(0).to(device)
            targets = targets.squeeze(0).to(device)
            lang = lang.to(device)

            hidden = model.init_hidden(batchsize=data.size(-1))

            lr2 = optimizer.param_groups[0]['lr']
            if lr_weights is not None:
                optimizer.param_groups[0]['lr'] = lr2 * seq_len.item() / bptt * lr_weights[lang.item()]
            else:
                optimizer.param_groups[0]['lr'] = lr2 * seq_len.item() / bptt

            hidden = detach(hidden)
            optimizer.zero_grad()

            loss = 0

            if not isinstance(prior, VIPrior):
                n_samples = 1

            for s in range(n_samples):
                output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, lang, return_h=True)

                if isinstance(loss_function, SplitCrossEntropyLoss):
                    raw_loss = loss_function(model.decoder.weight, model.decoder.bias, output, targets)
                else:
                    raw_loss = loss_function(output, targets)

                if alpha:
                    raw_loss = raw_loss + sum(alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
                # Temporal Activation Regularization (slowness)
                if beta:
                    raw_loss = raw_loss + sum(beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])

                loss += raw_loss

            loss /= n_samples

            log_loss = loss

            if isinstance(prior, VIPrior):
                kl_term = prior.kl_div()

                if scaling == "uniform":
                    scale = 1. / total_steps
                elif scaling == "linear_annealing":
                    scale = ((total_steps - steps - 1) * 2. + 1.) / total_steps ** 2
                elif scaling == "logistic_annealing":
                    steepness = 0.0025
                    scale = 1. / (1 + np.exp(-steepness * (steps - total_steps / 2.)))
                else:
                    scale = 1.
                loss = loss + scale * kl_term
                tr_kl += kl_term.item()

            if use_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if tb_writer is not None:
                tb_writer.add_scalar('train/loss', log_loss.item(), steps)

                if isinstance(prior, VIPrior):
                    tb_writer.add_scalar('train/kl', kl_term.item(), steps)
                    tb_writer.add_scalar('train/loss+kl', loss.item(), steps)

                    logging_kl += tr_kl

                logging_loss += tr_loss

            optimizer.step()

            total_loss += raw_loss.data
            batch += 1
            steps += 1

            # reset lr to optimiser default
            optimizer.param_groups[0]['lr'] = lr2

            if batch % log_interval == 0 and batch > 0:
                cur_loss = total_loss.item() / log_interval
                elapsed = time.time() - start_time
                log.debug(
                    '| {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                        batch, len(dataloader), optimizer.param_groups[0]['lr'], elapsed * 1000 / log_interval,
                        cur_loss, math.exp(cur_loss), cur_loss / math.log(2)))
                total_loss = 0
                start_time = time.time()

            pbar.set_description('Training, end of batch {} | Loss {}'.format(batch, loss.data))

    return steps
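
The three scaling modes above reduce to simple closed-form schedules for the KL weight. A standalone sketch of those schedules, assuming the same steps/total_steps semantics as the loop above:

import numpy as np

def kl_scale(scaling: str, steps: int, total_steps: int, steepness: float = 0.0025) -> float:
    """Mirror the KL-weight branches of the training loop above."""
    if scaling == "uniform":
        return 1.0 / total_steps
    if scaling == "linear_annealing":
        return ((total_steps - steps - 1) * 2.0 + 1.0) / total_steps ** 2
    if scaling == "logistic_annealing":
        return 1.0 / (1.0 + np.exp(-steepness * (steps - total_steps / 2.0)))
    return 1.0

print([round(kl_scale("linear_annealing", s, 1000), 6) for s in (0, 500, 999)])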
Example #16
    def step_optimizer(
        self,
        optimizer: torch.optim.Optimizer,  # type: ignore
        clip_grads: Optional[Callable[[Iterator], None]] = None,
        auto_zero_grads: bool = True,
    ) -> None:
        """
        Perform a single optimization step.

        This function must be called once for each optimizer. However, the order of
        different optimizers' steps can be specified by calling this function in different
        orders. Also, gradient accumulation across iterations is performed by the Determined
        training loop by setting the experiment configuration field
        :ref:`optimizations.aggregation_frequency <config-aggregation-frequency>`.

        Here is a code example:

        .. code-block:: python

            def clip_grads(params):
                torch.nn.utils.clip_grad_norm_(params, 0.0001)

            self.context.step_optimizer(self.opt1, clip_grads)

        Arguments:
            optimizer(``torch.optim.Optimizer``): Which optimizer should be stepped.
            clip_grads(a function, optional): This function should have one argument for
                parameters in order to clip the gradients.
            auto_zero_grads(bool, optional): Automatically zero out gradients after
                stepping the optimizer. If false, you need to call ``optimizer.zero_grad()``
                manually. Note that if :ref:`optimizations.aggregation_frequency
                <config-aggregation-frequency>` is greater than 1, ``auto_zero_grads`` must be true.
        """

        check.true(
            auto_zero_grads or self.hvd_config.aggregation_frequency == 1,
            "if optimizations.aggregation_frequency is larger than 1, "
            "you can only set auto_zero_grads to be true. ",
        )
        if self._should_communicate_and_update():
            # Communication needs to be synchronized so that it is completed
            # before we apply gradient clipping and `step()`.
            if self.hvd_config.use and not self._use_amp:
                optimizer.synchronize()

            parameters = ([
                p for group in optimizer.param_groups
                for p in group.get("params", [])
            ] if not self._use_amp else apex.amp.master_params(optimizer))

            if self.hvd_config.average_aggregated_gradients:
                self._average_gradients(
                    parameters=parameters,
                    divisor=self.hvd_config.aggregation_frequency)

            if clip_grads is not None:
                clip_grads(parameters)

            if self.hvd_config.use:
                with optimizer.skip_synchronize():
                    optimizer.step()
            else:
                optimizer.step()

            if auto_zero_grads:
                optimizer.zero_grad()
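
The _average_gradients helper is not shown in this excerpt; a plausible minimal sketch, assuming it simply divides the accumulated gradients by the aggregation frequency:

import torch

def average_gradients(parameters, divisor: int) -> None:
    """Divide accumulated gradients by the aggregation frequency (assumed behaviour)."""
    if divisor <= 1:
        return
    with torch.no_grad():
        for p in parameters:
            if p.grad is not None:
                p.grad.div_(divisor)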
Example #17
def inner_train_ssd(
    *,
    data_root: Path,
    cfg: NOD,
    model: Module,
    data_loader: DataLoader,
    optimiser: torch.optim.Optimizer,
    scheduler: WarmupMultiStepLR,
    check_pointer: callable,
    device: torch.device,
    arguments: dict,
    kws: NOD,
) -> Module:
    """

    :param data_root:
    :type data_root:
    :param cfg:
    :type cfg:
    :param model:
    :type model:
    :param data_loader:
    :type data_loader:
    :param optimiser:
    :type optimiser:
    :param scheduler:
    :type scheduler:
    :param check_pointer:
    :type check_pointer:
    :param device:
    :type device:
    :param arguments:
    :type arguments:
    :param kws:
    :type kws:
    :return:
    :rtype:"""
    logger = logging.getLogger("SSD.trainer")
    logger.info("Start training ...")
    meters = MetricLogger()

    with TorchTrainSession(model):
        save_to_disk = global_distribution_rank() == 0
        if kws.use_tensorboard and save_to_disk:
            import tensorboardX

            writer = tensorboardX.SummaryWriter(
                log_dir=str(PROJECT_APP_PATH.user_data / "results" /
                            "tf_logs"))
        else:
            writer = None

        max_iter = len(data_loader)
        start_iter = arguments["iteration"]
        start_training_time = time.time()
        end = time.time()
        for iteration, (images, targets,
                        _) in enumerate(data_loader, start_iter):
            arguments["iteration"] = iteration

            images = images.to(device)
            targets = targets.to(device)
            loss_instance = MultiBoxLoss(neg_pos_ratio=cfg.model.neg_pos_ratio)
            cls_logits, bbox_pred = model(images)

            reg_loss, cls_loss = loss_instance(cls_logits, bbox_pred,
                                               targets.labels, targets.boxes)
            loss_dict = dict(reg_loss=reg_loss, cls_loss=cls_loss)

            loss = sum(loss for loss in loss_dict.values())

            loss_dict_reduced = reduce_loss_dict(
                loss_dict)  # reduce losses over all GPUs for logging purposes
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            meters.update(total_loss=losses_reduced, **loss_dict_reduced)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
            scheduler.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time)
            if iteration % kws.log_step == 0:
                eta_seconds = meters.time.global_avg * (max_iter - iteration)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                logger.info(
                    meters.delimiter.join([
                        f"iter: {iteration:06d}",
                        f"lr: {optimiser.param_groups[0]['lr']:.5f}",
                        f"{str(meters)}",
                        f"eta: {eta_string}",
                        f"mem: {round(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)}M",
                    ]))
                if writer:
                    global_step = iteration
                    writer.add_scalar("losses/total_loss",
                                      losses_reduced,
                                      global_step=global_step)
                    for loss_name, loss_item in loss_dict_reduced.items():
                        writer.add_scalar(f"losses/{loss_name}",
                                          loss_item,
                                          global_step=global_step)
                    writer.add_scalar("lr",
                                      optimiser.param_groups[0]["lr"],
                                      global_step=global_step)

            if iteration % kws.save_step == 0:
                check_pointer.save(f"model_{iteration:06d}", **arguments)

            if (kws.eval_step > 0 and iteration % kws.eval_step == 0
                    and not iteration == max_iter):
                with TorchEvalSession(model):
                    eval_results = do_ssd_evaluation(
                        data_root,
                        cfg,
                        model,
                        distributed=kws.distributed,
                        iteration=iteration,
                    )
                    if global_distribution_rank() == 0 and writer:
                        for eval_result, dataset in zip(
                                eval_results, cfg.datasets.test):
                            write_metrics_recursive(
                                eval_result["metrics"],
                                "metrics/" + dataset,
                                writer,
                                iteration,
                            )

        check_pointer.save("model_final", **arguments)

        total_training_time = int(time.time() -
                                  start_training_time)  # compute training time
        logger.info(
            f"Total training time: {datetime.timedelta(seconds=total_training_time)} ("
            f"{total_training_time / max_iter:.4f} s / it)")
        return model
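
reduce_loss_dict is referenced above but not defined in this excerpt. A common implementation, and the assumption made here, averages every loss value over all processes for logging and is a no-op in single-process runs:

import torch
import torch.distributed as dist

def reduce_loss_dict(loss_dict: dict) -> dict:
    """Average each value of loss_dict across processes (for logging only)."""
    initialized = dist.is_available() and dist.is_initialized()
    world_size = dist.get_world_size() if initialized else 1
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        names = sorted(loss_dict.keys())  # keep the key order identical on every rank
        values = torch.stack([loss_dict[k] for k in names])
        dist.all_reduce(values)
        values /= world_size
        return dict(zip(names, values))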
Example #18
def train(args,
          worker_id: int,
          global_model: Union[ActorNetwork, ActorCriticNetwork],
          T: Value,
          global_reward: Value,
          optimizer: torch.optim.Optimizer = None,
          global_model_critic: CriticNetwork = None,
          optimizer_critic: torch.optim.Optimizer = None,
          lr_scheduler: torch.optim.lr_scheduler = None,
          lr_scheduler_critic: torch.optim.lr_scheduler = None):
    """
    Start worker in training mode, i.e. training the shared model with backprop
    loosely based on https://github.com/ikostrikov/pytorch-a3c/blob/master/train.py
    :param args: console arguments
    :param worker_id: id of the worker, used to differentiate workers and to init different seeds
    :param global_model: global model which is optimized / for split models: the actor
    :param T: global counter of steps
    :param global_reward: global running reward value
    :param optimizer: optimizer for shared model/ for split models: actor model
    :param global_model_critic: optional global critic model for split networks
    :param optimizer_critic: optional critic optimizer for split networks
    :param lr_scheduler: optional learning rate scheduler instance for shared model
    / for fixed model: actor learning rate scheduler
    :param lr_scheduler_critic: optional learning rate scheduler instance for critic model
    :return: None
    """
    torch.manual_seed(args.seed + worker_id)

    if args.worker == 1:
        logging.info(f"Running A2C with {args.n_envs} environments.")
        if "RR" not in args.env_name:
            env = SubprocVecEnv([
                make_env(args.env_name, args.seed, i, args.log_dir)
                for i in range(args.n_envs)
            ])
        else:
            env = DummyVecEnv(
                [make_env(args.env_name, args.seed, worker_id, args.log_dir)])
    else:
        logging.info(f"Running A3C: training worker {worker_id} started.")
        env = DummyVecEnv(
            [make_env(args.env_name, args.seed, worker_id, args.log_dir)])
        # avoid any issues if this is not 1
        args.n_envs = 1

    normalizer = get_normalizer(args.normalizer, env)

    # init local NN instance for worker thread
    model = copy.deepcopy(global_model)
    model.train()

    model_critic = None

    if global_model_critic:
        model_critic = copy.deepcopy(global_model_critic)
        model_critic.train()

    # if no shared optimizer is provided use individual one
    if not optimizer:
        optimizer, optimizer_critic = get_optimizer(
            args.optimizer,
            global_model,
            args.lr,
            model_critic=global_model_critic,
            lr_critic=args.lr_critic)
        if args.lr_scheduler == "exponential":
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                                  gamma=0.99)
            if optimizer_critic:
                lr_scheduler_critic = torch.optim.lr_scheduler.ExponentialLR(
                    optimizer_critic, gamma=0.99)

    state = torch.Tensor(env.reset())

    t = np.zeros(args.n_envs)
    global_iter = 0
    episode_reward = np.zeros(args.n_envs)

    if worker_id == 0:
        writer = SummaryWriter(log_dir='experiments/runs/')

    while True:
        # Get state of the global model
        model.load_state_dict(global_model.state_dict())
        if not args.shared_model:
            model_critic.load_state_dict(global_model_critic.state_dict())

        # containers for computing loss
        values = []
        log_probs = []
        rewards = []
        entropies = []
        # container to check whether a terminal state was reached from one of the envs
        terminals = []

        # reward_sum = 0
        for step in range(args.rollout_steps):
            t += 1

            if args.shared_model:
                value, mu, std = model(normalizer(state))
            else:
                mu, std = model(normalizer(state))
                value = model_critic(normalizer(state))

            dist = torch.distributions.Normal(mu, std)

            # ------------------------------------------
            # select action
            action = dist.sample()

            # ------------------------------------------
            # Compute statistics for loss
            entropy = dist.entropy().sum(-1).unsqueeze(-1)
            log_prob = dist.log_prob(action).sum(-1).unsqueeze(-1)

            # make selected move
            action = np.clip(action.detach().numpy(), -args.max_action,
                             args.max_action)
            state, reward, dones, _ = env.step(
                action[0]
                if not args.worker == 1 or "RR" in args.env_name else action)

            reward = shape_reward(args, reward)

            episode_reward += reward

            # also mark the step as terminal once the max episode length is reached
            dones = np.logical_or(dones, t >= args.max_episode_length)

            values.append(value)
            log_probs.append(log_prob)
            rewards.append(torch.Tensor(reward).unsqueeze(-1))
            entropies.append(entropy)
            terminals.append(torch.Tensor(1 - dones).unsqueeze(-1))

            for i, done in enumerate(dones):
                if done:
                    # keep track of the avg overall global reward
                    with global_reward.get_lock():
                        if global_reward.value == -np.inf:
                            global_reward.value = episode_reward[i]
                        else:
                            global_reward.value = .99 * global_reward.value + .01 * episode_reward[
                                i]
                    if worker_id == 0 and T.value % args.log_frequency == 0:
                        writer.add_scalar("reward/global", global_reward.value,
                                          T.value)

                    episode_reward[i] = 0
                    t[i] = 0
                    if args.worker != 1 or "RR" in args.env_name:
                        env.reset()

            with T.get_lock():
                # this is one for a3c and n for A2C (actually the lock is not needed for A2C)
                T.value += args.n_envs

            if lr_scheduler and worker_id == 0 and T.value % args.lr_scheduler_step == 0 and global_iter != 0:
                lr_scheduler.step(T.value / args.lr_scheduler_step)

                if lr_scheduler_critic:
                    lr_scheduler_critic.step(T.value / args.lr_scheduler_step)

            state = torch.Tensor(state)

        if args.shared_model:
            v, _, _ = model(normalizer(state))
            G = v.detach()
        else:
            G = model_critic(normalizer(state)).detach()

        values.append(G)

        # compute loss and backprop
        advantages = torch.zeros((args.n_envs, 1))

        ret = torch.zeros((args.rollout_steps, args.n_envs, 1))
        adv = torch.zeros((args.rollout_steps, args.n_envs, 1))

        # iterate over all time steps from most recent to the starting one
        for i in reversed(range(args.rollout_steps)):
            # G can be seen essentially as the return over the course of the rollout
            G = rewards[i] + args.discount * terminals[i] * G
            if not args.no_gae:
                # Generalized Advantage Estimation
                td_error = rewards[i] + args.discount * terminals[i] * values[
                    i + 1] - values[i]
                # terminals here to "reset" advantages to 0, because reset ist called internally in the env
                # and new trajectory started
                advantages = advantages * args.discount * args.tau * terminals[
                    i] + td_error
            else:
                advantages = G - values[i].detach()

            adv[i] = advantages.detach()
            ret[i] = G.detach()

        policy_loss = -(torch.stack(log_probs) * adv).mean()
        # minus 1 in order to remove the last element, which is only necessary for next timestep value
        value_loss = .5 * (ret - torch.stack(values[:-1])).pow(2).mean()
        entropy_loss = torch.stack(entropies).mean()

        # zero grads to reset the gradients
        optimizer.zero_grad()

        if args.shared_model:
            # combined loss for shared architecture
            total_loss = policy_loss + args.value_loss_weight * value_loss - args.entropy_loss_weight * entropy_loss
            total_loss.backward()
        else:
            optimizer_critic.zero_grad()

            value_loss.backward()
            (policy_loss - args.entropy_loss_weight * entropy_loss).backward()

            # this is just used for plotting in tensorboard
            total_loss = policy_loss + args.value_loss_weight * value_loss - args.entropy_loss_weight * entropy_loss

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        sync_grads(model, global_model)
        optimizer.step()

        if not args.shared_model:
            torch.nn.utils.clip_grad_norm_(model_critic.parameters(),
                                           args.max_grad_norm)
            sync_grads(model_critic, global_model_critic)
            optimizer_critic.step()

        global_iter += 1

        if worker_id == 0 and T.value % args.log_frequency == 0:
            log_to_tensorboard(writer,
                               model,
                               optimizer,
                               rewards,
                               values,
                               total_loss,
                               policy_loss,
                               value_loss,
                               entropy_loss,
                               T.value,
                               model_critic=model_critic,
                               optimizer_critic=optimizer_critic)
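
sync_grads is what moves the worker's gradients onto the shared model before the shared optimizer steps. The helper is not shown here; a minimal sketch of the usual A3C pattern, where the shared parameters simply borrow the worker's gradient tensors:

def sync_grads(local_model, global_model) -> None:
    """Attach the worker's freshly computed gradients to the shared model (assumed behaviour)."""
    for local_param, global_param in zip(local_model.parameters(),
                                         global_model.parameters()):
        if global_param.grad is not None:
            return  # another worker already attached gradients for this step
        global_param._grad = local_param.grad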
Example #19
def train_person_segmentor(
        model: torch.nn.Module,
        train_loader: torch.utils.data.DataLoader,
        valid_loader: torch.utils.data.DataLoader,
        criterion: callable,
        optimiser: torch.optim.Optimizer,
        *,
        save_model_path: Path,
        learning_rate: Number = 6e-2,
        scheduler: torch.optim.lr_scheduler = None,
        n_epochs: int = 100,
        writer: ImageWriterMixin = MockWriter(),
):
    """

    :param model:
    :type model:
    :param train_loader:
    :type train_loader:
    :param valid_loader:
    :type valid_loader:
    :param criterion:
    :type criterion:
    :param optimiser:
    :type optimiser:
    :param scheduler:
    :type scheduler:
    :param save_model_path:
    :type save_model_path:
    :param n_epochs:
    :type n_epochs:
    :return:
    :rtype:"""
    valid_loss_min = numpy.Inf  # track change in validation loss
    assert n_epochs > 0, n_epochs
    E = tqdm(range(1, n_epochs + 1))
    for epoch_i in E:
        train_loss = 0.0
        valid_loss = 0.0

        with TorchTrainSession(model):
            for data, target in tqdm(train_loader):
                output, *_ = model(data.to(global_torch_device()))
                loss = criterion(output,
                                 target.to(global_torch_device()).float())

                optimiser.zero_grad()
                loss.backward()
                optimiser.step()

                train_loss += loss.cpu().item() * data.size(0)

        with TorchEvalSession(model):
            with torch.no_grad():
                for data, target in tqdm(valid_loader):
                    target = target.float()
                    (
                        output,
                        *_,
                    ) = model(  # forward pass: compute predicted outputs by passing inputs to the model
                        data.to(global_torch_device()))
                    validation_loss = criterion(
                        output, target.to(
                            global_torch_device()))  # calculate the batch loss
                    writer.scalar(
                        "dice_validation",
                        dice_loss(output, target.to(global_torch_device())),
                    )

                    valid_loss += validation_loss.detach().cpu().item(
                    ) * data.size(0)  # update average validation loss
                writer.image("input", data, epoch_i)  # write the last batch
                writer.image("truth", target, epoch_i)  # write the last batch
                writer.image("prediction", torch.sigmoid(output),
                             epoch_i)  # write the last batch

        # calculate average losses
        train_loss = train_loss / len(train_loader.dataset)
        valid_loss = valid_loss / len(valid_loader.dataset)

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print(
                f"Validation loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f}).  Saving model ..."
            )
            torch.save(model.state_dict(), save_model_path)
            valid_loss_min = valid_loss

        if scheduler:
            scheduler.step()
            optimiser, scheduler = reschedule_learning_rate(
                model,
                optimiser,
                epoch_i,
                scheduler,
                starting_learning_rate=learning_rate,
            )

        # print training/validation statistics
        current_lr = next(iter(optimiser.param_groups))["lr"]
        E.set_description(f"Epoch: {epoch_i} "
                          f"Training Loss: {train_loss:.6f} "
                          f"Validation Loss: {valid_loss:.6f} "
                          f"Learning rate: {current_lr:.6f}")
        writer.scalar("training_loss", train_loss)
        writer.scalar("validation_loss", valid_loss)
        writer.scalar("learning_rate", current_lr)

    return model
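
dice_loss is only logged above, not optimized, and its definition is not part of this excerpt. A plausible sketch of such a helper, assuming raw logits and binary masks:

import torch

def dice_loss(logits: torch.Tensor, target: torch.Tensor, eps: float = 1.0) -> torch.Tensor:
    """Soft Dice loss on sigmoid probabilities; 0 means perfect overlap."""
    probs = torch.sigmoid(logits).flatten()
    target = target.flatten()
    intersection = (probs * target).sum()
    return 1.0 - (2.0 * intersection + eps) / (probs.sum() + target.sum() + eps)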
Example #20
def train(data_loader: DataLoader, model: nn.Module, num_iterations: int, start_iteration: int, device: torch.device,
          losses_dict: dict, accuracies_dict: dict, optimizer: torch.optim.Optimizer = None,
          debug_dumper: DebugDumper = None, log_file_name: str = None):

    loss_func = nn.CrossEntropyLoss()

    iterator = iter(data_loader)
    epoch_size = len(data_loader.dataset)

    is_train = optimizer is not None
    if is_train:
        model.train()
    else:
        model.eval()

    global_av_loss = 0
    global_av_accuracy = 0
    for i in range(num_iterations):

        try:
            batch = next(iterator)
        except StopIteration:
            iterator = iter(data_loader)
            batch = next(iterator)

        # debug dump
        if debug_dumper is not None:
            debug_dumper.dump(batch)

        # transferring data to device
        for k in batch:
            batch[k] = batch[k].to(device)

        av_loss = 0
        av_accuracy = 0
        av_n = 0
        with torch.set_grad_enabled(is_train):

            predictions = model.forward(batch)

            labels = batch['label']
            labels = labels.view(labels.size(0))
            loss = loss_func(predictions, labels)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            av_loss += loss.item()
            global_av_loss += loss.item()

            # calculating accuracy
            predicted_labels = torch.argmax(predictions, dim=1)
            correct_count = predicted_labels.eq(labels).sum().item()
            accuracy = correct_count / labels.size(0)
            av_accuracy += accuracy
            global_av_accuracy += accuracy

            av_n += 1

        iteration = i + start_iteration
        if i % 10 == 0:
            e = iteration / epoch_size
            av_loss /= av_n
            av_accuracy = 100 * av_accuracy / av_n
            if is_train:
                log_string = "%d: epoch=%1.2f, loss=%f, accuracy=%1.2f%%" % (iteration, e, av_loss, av_accuracy)
            else:
                log_string = "%d: loss=%f, accuracy=%1.2f%%" % (i, av_loss, av_accuracy)
            print(log_string)
            if log_file_name is not None:
                with open(log_file_name, 'a+') as lf:
                    lf.write(log_string + '\n')

        # adding loss to log
        if is_train:
            losses_dict[iteration] = loss.item()
            accuracies_dict[iteration] = accuracy * 100.0

    if not is_train:
        global_av_loss /= num_iterations
        global_av_accuracy /= num_iterations
        print("average test loss = %f, average test accuracy = %f" % (global_av_loss, global_av_accuracy))
        losses_dict[num_iterations + start_iteration] = global_av_loss
        accuracies_dict[num_iterations + start_iteration] = global_av_accuracy

    return num_iterations + start_iteration
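
Because passing optimizer=None switches the same loop into evaluation mode, the function can drive both phases. A hypothetical usage sketch; model, train_loader, test_loader and device are assumed to exist already:

losses, accuracies = {}, {}
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training pass: gradients enabled, per-iteration metrics recorded.
next_iter = train(train_loader, model, num_iterations=1000, start_iteration=0,
                  device=device, losses_dict=losses, accuracies_dict=accuracies,
                  optimizer=optimizer)

# Evaluation pass: optimizer omitted, so gradients stay disabled.
train(test_loader, model, num_iterations=100, start_iteration=next_iter,
      device=device, losses_dict=losses, accuracies_dict=accuracies)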
Example #21
def train(
    epoch: int,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    preconditioner: kfac.preconditioner.KFACPreconditioner | None,
    loss_func: torch.nn.Module,
    train_sampler: torch.utils.data.distributed.DistributedSampler[SampleT],
    train_loader: torch.utils.data.DataLoader[SampleT],
    args: argparse.Namespace,
) -> None:
    """Train model."""
    model.train()
    train_sampler.set_epoch(epoch)
    train_loss = Metric('train_loss')
    train_accuracy = Metric('train_accuracy')
    scaler = args.grad_scaler if 'grad_scaler' in args else None
    mini_step = 0
    step_loss = torch.tensor(0.0).to('cuda' if args.cuda else 'cpu')
    step_accuracy = torch.tensor(0.0).to('cuda' if args.cuda else 'cpu')

    with tqdm(
        total=math.ceil(len(train_loader) / args.batches_per_allreduce),
        bar_format='{l_bar}{bar:10}{r_bar}',
        desc=f'Epoch {epoch:3d}/{args.epochs:3d}',
        disable=not args.verbose,
    ) as t:
        for batch_idx, (data, target) in enumerate(train_loader):
            mini_step += 1
            if args.cuda:
                data, target = data.cuda(), target.cuda()

            if scaler is not None:
                with torch.cuda.amp.autocast():
                    output = model(data)
                    loss = loss_func(output, target)
            else:
                output = model(data)
                loss = loss_func(output, target)

            with torch.no_grad():
                step_loss += loss
                step_accuracy += accuracy(output, target)

            loss = loss / args.batches_per_allreduce

            if (
                mini_step % args.batches_per_allreduce == 0
                or batch_idx + 1 == len(train_loader)
            ):
                if scaler is not None:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
            else:
                with model.no_sync():  # type: ignore
                    if scaler is not None:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

            if (
                mini_step % args.batches_per_allreduce == 0
                or batch_idx + 1 == len(train_loader)
            ):
                if preconditioner is not None:
                    if scaler is not None:
                        scaler.unscale_(optimizer)
                    preconditioner.step()
                if scaler is not None:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()

                train_loss.update(step_loss / mini_step)
                train_accuracy.update(step_accuracy / mini_step)
                step_loss.zero_()
                step_accuracy.zero_()

                t.set_postfix_str(
                    'loss: {:.4f}, acc: {:.2f}%, lr: {:.4f}'.format(
                        train_loss.avg,
                        100 * train_accuracy.avg,
                        optimizer.param_groups[0]['lr'],
                    ),
                )
                t.update(1)
                mini_step = 0

    if args.log_writer is not None:
        args.log_writer.add_scalar('train/loss', train_loss.avg, epoch)
        args.log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch)
        args.log_writer.add_scalar(
            'train/lr',
            optimizer.param_groups[0]['lr'],
            epoch,
        )
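
The accumulation logic above synchronizes gradients only on the last micro-batch of each group. A condensed sketch of that pattern, assuming a DDP-wrapped model, a loss_fn, and an accumulation factor k:

import torch

def train_with_accumulation(model, loader, loss_fn, optimizer, k: int) -> None:
    """Sync gradients every k micro-batches; accumulate locally in between."""
    for i, (data, target) in enumerate(loader):
        loss = loss_fn(model(data), target) / k  # scale so the sum matches one large batch
        boundary = (i + 1) % k == 0 or (i + 1) == len(loader)
        if boundary:
            loss.backward()  # DDP all-reduces the gradients here
            optimizer.step()
            optimizer.zero_grad()
        else:
            with model.no_sync():  # skip the all-reduce, just accumulate
                loss.backward()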
Example #22
def train_epoch(
        train_loader: torch.utils.data.DataLoader, base_model: torch.nn.Module,
        classification_layer: torch.nn.Module, forg_layer: torch.nn.Module,
        epoch: int, optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
        callback: Optional[VisdomLogger], device: torch.device, args: Any):
    """ Trains the network for one epoch

        Parameters
        ----------
        train_loader: torch.utils.data.DataLoader
            Iterable that loads the training set (x, y) tuples
        base_model: torch.nn.Module
            The model architecture that "extract features" from signatures
        classification_layer: torch.nn.Module
            The classification layer (from features to predictions of which user
            wrote the signature)
        forg_layer: torch.nn.Module
            The forgery prediction layer (from features to predictions of whether
            the signature is a forgery). Only used if args.forg = True
        epoch: int
            The current epoch (used for reporting)
        optimizer: torch.optim.Optimizer
            The optimizer (already initialized)
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler
            The learning rate scheduler
        callback: VisdomLogger (optional)
            A callback to report the training progress
        device: torch.device
            The device (CPU or GPU) to use for training
        args: Namespace
            Extra arguments used for training:
            args.forg: bool
                Whether forgeries are being used for training
            args.lamb: float
                The weight used for the forgery loss (training with forgeries only)

        Returns
        -------
        None
        """

    step = 0
    n_steps = len(train_loader)
    for batch in train_loader:
        x, y = batch[0], batch[1]
        x = torch.tensor(x, dtype=torch.float).to(device)
        y = torch.tensor(y, dtype=torch.long).to(device)
        yforg = torch.tensor(batch[2], dtype=torch.float).to(device)

        # Forward propagation
        features = base_model(x)

        if args.forg:
            # Eq (4) in https://arxiv.org/abs/1705.05787
            logits = classification_layer(features[yforg == 0])
            class_loss = F.cross_entropy(logits, y[yforg == 0])

            forg_logits = forg_layer(features).squeeze()
            forg_loss = F.binary_cross_entropy_with_logits(forg_logits, yforg)

            loss = (1 - args.lamb) * class_loss
            loss += args.lamb * forg_loss
        else:
            # Eq (1) in https://arxiv.org/abs/1705.05787
            logits = classification_layer(features)
            loss = class_loss = F.cross_entropy(logits, y)

        # Back propagation
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(optimizer.param_groups[0]['params'],
                                        10)

        # Update weights
        optimizer.step()

        # Logging
        if callback and step % 100 == 0:
            iteration = epoch + (step / n_steps)
            callback.scalar('class_loss', iteration, class_loss.detach())

            pred = logits.argmax(1)
            acc = y[yforg == 0].eq(pred).float().mean()
            callback.scalar('train_acc', epoch + (step / n_steps),
                            acc.detach())
            if args.forg:
                forg_pred = forg_logits > 0
                forg_acc = yforg.long().eq(forg_pred.long()).float().mean()
                callback.scalar('forg_loss', iteration, forg_loss.detach())
                callback.scalar('forg_acc', iteration, forg_acc.detach())

        step += 1
    lr_scheduler.step()
Example #23
def train_one_epoch(model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    loss_scaler,
                    max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None,
                    mixup_fn: Optional[Mixup] = None):
    # TODO fix this for finetuning
    model.train()
    criterion.train()
    end = time.time()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 50

    for samples, targets in metric_logger.log_every(data_loader, print_freq,
                                                    header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        metric_logger.update(data_time=time.time() - end)
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        end = time.time()
        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss,
                    optimizer,
                    clip_grad=max_norm,
                    parameters=model.parameters(),
                    create_graph=is_second_order)
        batch_time = time.time() - end
        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(batch_time=batch_time)
        metric_logger.update(throughput=samples.size(0) / batch_time)
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
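
loss_scaler is a callable that runs backward, optional clipping, and the optimizer step under mixed precision. A reduced sketch of such a scaler (timm's NativeScaler behaves along these lines, but this is an assumption, not that library's exact code):

import torch

class SimpleLossScaler:
    def __init__(self) -> None:
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=0.0, parameters=None,
                 create_graph=False) -> None:
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if clip_grad and parameters is not None:
            self._scaler.unscale_(optimizer)  # clip on unscaled gradients
            torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        self._scaler.step(optimizer)
        self._scaler.update()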
Example #24
def train(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module,
          loss_func: Callable, optimizer: torch.optim.Optimizer, epoch: int,
          args: argparse.Namespace):
    """
    Train the given model for a single epoch using the given dataloader.

    Args:
        dataloader: The dataloader containing the training data.
        model: Instance of the model that is being trained.
        loss_func: A loss function to compute the error between the
            actual and the desired output of the model.
        optimizer: An instance of an optimizer that is used to compute
            and perform the updates to the weights of the network.
        epoch: The current training epoch.
        args: Namespace object containing some global variable (e.g.,
            command line arguments, such as the batch size)
    """

    # -------------------------------------------------------------------------
    # Preliminaries
    # -------------------------------------------------------------------------

    # Activate training mode
    model.train()

    # Keep track of the time needed to process a batch, as well as the batch losses
    batch_times = AverageMeter()
    batch_losses = AverageMeter()

    # -------------------------------------------------------------------------
    # Process the training dataset in mini-batches
    # -------------------------------------------------------------------------

    # TODO: Check order here
    for batch_idx, (target, data) in enumerate(dataloader):

        # Initialize start time of the batch
        batch_start = time.time()

        # Fetch data and move to device
        data, target = data.to(args.device), target.to(args.device)
        target = target.squeeze()

        # Clear gradients
        optimizer.zero_grad()

        # Compute forward pass through model
        if config['model']['class'] == 'DefaultAutoencoder':
            output = model.forward(target).squeeze()

        else:
            output = model.forward(data).squeeze()

        # Calculate the loss for the batch
        # TODO: This needs to be adjusted if we also want to compute a loss
        #       on the latent space (to ensure the physical interpretability
        #       of the latent dimensions)
        loss = loss_func(output, target)

        # Back-propagate the loss and update the weights
        loss.backward()
        optimizer.step(closure=None)

        # ---------------------------------------------------------------------
        # Log information about current batch to TensorBoard
        # ---------------------------------------------------------------------

        if args.tensorboard:
            # Compute how many examples we have processed already and log the
            # loss value for the current batch
            global_step = ((epoch - 1) * args.n_train_batches + batch_idx) * \
                          args.batch_size
            args.logger.add_scalar(tag='loss/train',
                                   scalar_value=loss.item(),
                                   global_step=global_step)

        # ---------------------------------------------------------------------
        # Additional logging to console
        # ---------------------------------------------------------------------

        # Store the loss and processing time for the current batch
        batch_losses.update(loss.item())
        batch_times.update(time.time() - batch_start)

        # Print information to console, if applicable
        if batch_idx % args.log_interval == 0:

            # Which fraction of batches have we already processed this epoch?
            percent = 100. * batch_idx / args.n_train_batches

            # Print some information about how the training is going
            print(f'Epoch: {epoch:>3}/{args.epochs}', end=' | ', flush=True)
            print(f'Batch: {batch_idx:>3}/{args.n_train_batches}',
                  flush=True,
                  end=' ')
            print(f'({percent:>4.1f}%)', end=' | ', flush=True)
            print(f'Loss: {loss.item():.6f}', end=' | ', flush=True)
            print(f'Time: {batch_times.value:>6.3f}s', flush=True)
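
AverageMeter tracks the most recent value and a running average; its definition is not shown here. A minimal sketch consistent with how batch_times.value and batch_losses.update(...) are used above:

class AverageMeter:
    """Running average of a scalar, keeping the last value as well."""

    def __init__(self) -> None:
        self.value = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value: float, n: int = 1) -> None:
        self.value = value
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count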
Example #25
def train_epoch(model: nn.Module, optimizer: torch.optim.Optimizer, loss_func: nn.Module,
                loader: DataLoader, cfg: Dict, epoch: int, use_mse: bool):
    """Train model for a single epoch.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model to train
    optimizer : torch.optim.Optimizer
        Optimizer used for weight updating
    loss_func : nn.Module
        The loss function, implemented as a PyTorch Module
    loader : DataLoader
        PyTorch DataLoader containing the training data in batches.
    cfg : Dict
        Dictionary containing the run config
    epoch : int
        Current epoch number
    use_mse : bool
        If True, loss_func is nn.MSELoss(), else NSELoss(), which expects an additional vector with
        the std of the discharge

    """
    model.train()

    # process bar handle
    pbar = tqdm(loader, file=sys.stdout)
    pbar.set_description(f'# Epoch {epoch}')

    # Iterate in batches over training set
    for data in pbar:
        # delete old gradients
        optimizer.zero_grad()

        # forward pass through LSTM
        if not entity_aware:
            (x, _, _, _, _, x_s, _, q_stds), y = data
            x, y, q_stds = x.to(DEVICE), y.to(DEVICE), q_stds.to(DEVICE)
            predictions = model(x)[0]

        # forward pass through EALSTM
        elif entity_aware:
            (x_d, _, _, _, _, x_s, _, q_stds), y = data
            x_d, x_s, y = x_d.to(DEVICE), x_s.to(DEVICE), y.to(DEVICE)
            predictions = model(x_d, x_s[:, 0, :])[0]

        # MSELoss
        if use_mse:
            loss = loss_func(predictions, y)

        # NSELoss needs std of each basin for each sample
        else:
            q_stds = q_stds.to(DEVICE)
            loss = loss_func(predictions, y, q_stds)

        # calculate gradients
        loss.backward()

        if cfg["clip_norm"]:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["clip_value"])

        # perform parameter update
        optimizer.step()

        pbar.set_postfix_str(f"Loss: {loss.item():5f}")
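
When use_mse is False, loss_func is an NSE-style loss that needs each basin's discharge standard deviation. A sketch of such a loss, assuming the squared error is weighted by 1 / (std + eps)^2:

import torch

class NSELoss(torch.nn.Module):
    """Basin-wise NSE loss: squared errors normalised by discharge variability."""

    def __init__(self, eps: float = 0.1) -> None:
        super().__init__()
        self.eps = eps

    def forward(self, prediction: torch.Tensor, target: torch.Tensor,
                q_stds: torch.Tensor) -> torch.Tensor:
        weights = 1.0 / (q_stds + self.eps) ** 2
        return torch.mean(weights * (prediction - target) ** 2)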
Example #26
def train_one_epoch(
    test_name: str,  # experiment name
    yolov3_net: model.yolov3net.YoloV3Net,  # network model
    yolov3_losses: model.yolov3loss.YoloV3Loss,  # loss function
    optimizer: torch.optim.Optimizer,  # optimizer
    epoch: int,  # current epoch
    train_batch_num: int,  # number of training batches, i.e. training set size divided by batch size
    validate_batch_num: int,  # number of validation batches, i.e. validation set size divided by batch size
    total_epoch: int,  # total number of epochs
    train_data_loader: torch.utils.data.dataloader.DataLoader,  # training set
    validate_data_loader: torch.utils.data.dataloader.DataLoader,  # validation set
    cuda: bool,
) -> None:
    """
    Train for one epoch
    :return:
    """

    # -----------------------------------------------------------------------------------------------------------#
    # step 1. training
    # -----------------------------------------------------------------------------------------------------------#
    total_train_loss = 0  # total training loss of the current epoch

    # 1. switch the network to training mode
    yolov3_net = yolov3_net.train()

    # torch.save(yolov3_net.state_dict(), "logs/" + "begin" + ".pth")

    # 2. set up the tqdm progress bar
    with tqdm.tqdm(total=train_batch_num,
                   desc=f"Epoch {epoch + 1}/{total_epoch}",
                   postfix=dict) as pbar:
        # 3. iterate over the dataset in batches
        for iteration, (tensord_images,
                        tensord_target_list) in enumerate(train_data_loader):
            if cuda:
                tensord_images = tensord_images.cuda()

            # print("train in cuda") if cuda else print("train not in cuda")

            # 4. zero the gradients
            optimizer.zero_grad()

            # 5. forward pass
            predict_feature_list = yolov3_net(tensord_images)

            # 6. compute the loss
            loss = yolov3_losses(predict_feature_list, tensord_target_list)

            # 7. backward pass
            loss.backward()

            # 8. let the optimizer update the parameters
            optimizer.step()

            # 9. update the progress bar
            total_train_loss += loss.item()
            pbar.set_postfix(
                **{
                    "lr": optimizer.param_groups[0]["lr"],  # current learning rate of the optimizer
                    "train_loss": total_train_loss /
                    (iteration + 1),  # total training loss of the current epoch / number of iterations
                })
            pbar.update(1)  # advance the progress bar

    # -----------------------------------------------------------------------------------------------------------#
    # step 2. validation
    # -----------------------------------------------------------------------------------------------------------#
    total_validate_loss = 0  # total validation loss of the current epoch

    # 1. switch the network to evaluation mode
    yolov3_net = yolov3_net.eval()

    # 2. set up the tqdm progress bar
    with tqdm.tqdm(total=validate_batch_num,
                   desc=f"Epoch {epoch + 1}/{total_epoch}",
                   postfix=dict) as pbar:
        # 3. iterate over the dataset in batches
        for iteration, (
                tensord_images,
                tensord_target_list) in enumerate(validate_data_loader):
            if cuda:
                tensord_images = tensord_images.cuda()

            # print("eval in cuda") if cuda else print("eval not in cuda")

            # 4. zero the gradients
            optimizer.zero_grad()

            # 5. forward pass
            predict_feature_list = yolov3_net(tensord_images)

            # 6. compute the loss
            loss = yolov3_losses(predict_feature_list, tensord_target_list)

            # 7. update the progress bar
            total_validate_loss += loss.item()
            pbar.set_postfix(
                **{
                    "validate_loss": total_validate_loss /
                    (iteration + 1),  # total validation loss of the current epoch / number of iterations
                })
            pbar.update(1)  # advance the progress bar

    # -----------------------------------------------------------------------------------------------------------#
    # step 3. results
    # -----------------------------------------------------------------------------------------------------------#
    # 1. compute the average losses
    train_loss = total_train_loss / train_batch_num
    validate_loss = total_validate_loss / validate_batch_num

    # 2. format the result string
    ret = "Epoch%04d-Train_Loss%.4f-Val_Loss%.4f" % (epoch + 1, train_loss,
                                                     validate_loss)
    # print(ret)

    # 3. save the weights
    torch.save(
        yolov3_net.state_dict(),
        os.path.join(os.path.join(os.getcwd(), "logs"),
                     test_name + "_" + ret + ".pth"))
Example #27
def train(model,
          train_loader,
          val_loader,
          optimizer: torch.optim.Optimizer,
          criterion_g,
          criterion_v,
          criterion_c,
          criterion_feat_g,
          criterion_feat_v,
          criterion_feat_c,
          workspace: Workspace,
          scheduler=None,
          n_epoch=30,
          cutmix_prob=0,
          mixup_prob=0,
          freeze_bn_epochs=None,
          feat_loss_weight=1.0,
          use_apex=False,
          decrease_ohem_rate=False,
          use_grapheme_code=False,
          grapheme_classifier=None,
          criterion_grapheme=None,
          final_ft=False):
    score = evaluate(model, val_loader)
    workspace.log(f'Score={score}', epoch=0)
    workspace.plot_score('val/score', score, 0)

    freeze_bn_epochs = freeze_bn_epochs or []
    global_step = -1

    if final_ft:
        workspace.log('Freeze backbone')
        M.freeze_backbone(model)
        M.freeze_multihead(model)

    for epoch in range(1, n_epoch + 1):
        model.train()
        if grapheme_classifier is not None:
            grapheme_classifier.train()

        if epoch in freeze_bn_epochs:
            model.apply(set_batchnorm_eval)
            workspace.log(f'Freeze BN', epoch=epoch)

        if scheduler:
            if isinstance(scheduler,
                          torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(score)
            else:
                scheduler.step()
            workspace.log(f'Scheduler.step()', epoch=epoch)

        if decrease_ohem_rate:
            if isinstance(criterion_g, L.OHEMCrossEntropyLoss):
                r_before, r_after = criterion_g.adjust_rate(epoch)
                workspace.log(f'OHEM(g).rate: {r_before} -> {r_after}')
            if isinstance(criterion_v, L.OHEMCrossEntropyLoss):
                r_before, r_after = criterion_v.adjust_rate(epoch)
                workspace.log(f'OHEM(v).rate: {r_before} -> {r_after}')
            if isinstance(criterion_c, L.OHEMCrossEntropyLoss):
                r_before, r_after = criterion_c.adjust_rate(epoch)
                workspace.log(f'OHEM(c).rate: {r_before} -> {r_after}')
            if isinstance(criterion_grapheme, L.OHEMCrossEntropyLoss):
                r_before, r_after = criterion_grapheme.adjust_rate(epoch)
                workspace.log(f'OHEM(grapheme).rate: {r_before} -> {r_after}')
        else:
            if isinstance(criterion_g, L.OHEMCrossEntropyLoss):
                workspace.log(f'OHEM(g).rate: {criterion_g.rate}')
            if isinstance(criterion_v, L.OHEMCrossEntropyLoss):
                workspace.log(f'OHEM(v).rate: {criterion_v.rate}')
            if isinstance(criterion_c, L.OHEMCrossEntropyLoss):
                workspace.log(f'OHEM(c).rate: {criterion_c.rate}')
            if isinstance(criterion_grapheme, L.OHEMCrossEntropyLoss):
                workspace.log(
                    f'OHEM(grapheme).rate: {criterion_grapheme.rate}')

        for iteration, data_tuple in enumerate(train_loader):
            global_step += 1
            if use_grapheme_code:
                (x, tg, tv, tc, tgrapheme) = data_tuple
            else:
                (x, tg, tv, tc) = data_tuple

            if global_step == 0:
                workspace.log(
                    f'Check tensor size: x={x.size()}, '
                    f'tg={tg.size()}, tv={tv.size()}, tc={tc.size()}')

            r = np.random.rand(1)
            if r < cutmix_prob:
                use_cutmix = True
                use_mixup = False
            elif r < cutmix_prob + mixup_prob:
                use_cutmix = False
                use_mixup = True
            else:
                use_cutmix = False
                use_mixup = False

            x = x.cuda()
            (tg, tv, tc) = (tg.cuda(), tv.cuda(), tc.cuda())

            if use_grapheme_code:
                tgrapheme = tgrapheme.cuda()

            loss_feat_g = 0
            loss_feat_v = 0
            loss_feat_c = 0
            loss_grapheme = 0

            if use_cutmix or use_mixup:
                if use_cutmix:
                    x, rand_index, lam = cutmix(x, beta=1.0)
                    mix_criterion = cutmix_criterion
                elif use_mixup:
                    x, rand_index, lam = mixup(x, alpha=1.0)
                    mix_criterion = mixup_criterion

                tga, tgb = tg, tg[rand_index]
                tva, tvb = tv, tv[rand_index]
                tca, tcb = tc, tc[rand_index]

                if isinstance(model, (M.BengaliResNet34JPUAF, )):
                    logit_g, logit_v, logit_c = model(x, tg=tg, tv=tv, tc=tc)
                elif isinstance(
                        model,
                    (M.BengaliResNet34V3, M.BengaliResNet34V4,
                     M.BengaliResNet34AGeMV4, M.BengaliSEResNeXt50V4,
                     M.BengaliEfficientNetB0V4, M.BengaliEfficientNetB3V4)):
                    (feat, feat_g, logit_g, feat_v, logit_v, feat_c,
                     logit_c) = model(x)
                else:
                    logit_g, logit_v, logit_c = model(x)

                loss_g = mix_criterion(logit_g,
                                       tga,
                                       tgb,
                                       lam,
                                       criterion=criterion_g)
                loss_v = mix_criterion(logit_v,
                                       tva,
                                       tvb,
                                       lam,
                                       criterion=criterion_v)
                loss_c = mix_criterion(logit_c,
                                       tca,
                                       tcb,
                                       lam,
                                       criterion=criterion_c)
            else:
                if isinstance(model, (M.BengaliResNet34JPUAF, )):
                    logit_g, logit_v, logit_c = model(x, tg=tg, tv=tv, tc=tc)
                elif isinstance(
                        model,
                    (M.BengaliResNet34V3, M.BengaliResNet34V4,
                     M.BengaliResNet34AGeMV4, M.BengaliSEResNeXt50V4,
                     M.BengaliEfficientNetB0V4, M.BengaliEfficientNetB3V4)):
                    (feat, feat_g, logit_g, feat_v, logit_v, feat_c,
                     logit_c) = model(x)

                    if criterion_feat_g is not None:
                        loss_feat_g = criterion_feat_g(feat_g, tg)

                    if criterion_feat_v is not None:
                        loss_feat_v = criterion_feat_v(feat_v, tv)

                    if criterion_feat_c is not None:
                        loss_feat_c = criterion_feat_c(feat_c, tc)
                else:
                    logit_g, logit_v, logit_c = model(x)

                loss_g = criterion_g(logit_g, tg)
                loss_v = criterion_v(logit_v, tv)
                loss_c = criterion_c(logit_c, tc)

                if use_grapheme_code:
                    logit_grapheme = grapheme_classifier(
                        torch.cat([logit_g, logit_v, logit_c], dim=1))
                    loss_grapheme = criterion_grapheme(logit_grapheme,
                                                       tgrapheme)

            loss_feat = loss_feat_g + loss_feat_v + loss_feat_c
            loss = (loss_g + loss_v + loss_c + loss_grapheme +
                    feat_loss_weight * loss_feat)

            if global_step % 20 == 0:
                message = f'Iteration={iteration}, Loss={loss}'
                if loss_grapheme != 0:
                    message += f', LossGrapheme={loss_grapheme}'
                if loss_feat != 0:
                    message += f', FeatLoss={loss_feat}'
                workspace.log(message, epoch=epoch)
                workspace.plot_score('train/loss', float(loss.item()),
                                     global_step)
                workspace.plot_score('train/lr',
                                     float(get_current_lr(optimizer)),
                                     global_step)

            optimizer.zero_grad()

            if use_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            if isinstance(scheduler, CosineLRWithRestarts):
                scheduler.batch_step()

        score = evaluate(model, val_loader)

        workspace.log(f'Score={score}', epoch=epoch)
        workspace.plot_score('val/score', score, epoch)
        saved = workspace.save_bestmodel(model, epoch, score)

        if saved:
            checkpoint = {
                'optimizer': optimizer.state_dict(),
                'amp': None if not use_apex else amp.state_dict()
            }
            if scheduler is not None:
                checkpoint['scheduler'] = scheduler.state_dict()
            if isinstance(criterion_feat_g, nn.Module):
                checkpoint['criterion_feat_g'] = criterion_feat_g.state_dict()
            if isinstance(criterion_feat_v, nn.Module):
                checkpoint['criterion_feat_v'] = criterion_feat_v.state_dict()
            if isinstance(criterion_feat_c, nn.Module):
                checkpoint['criterion_feat_c'] = criterion_feat_c.state_dict()
            workspace.save_checkpoint(epoch, name='best', **checkpoint)
    workspace.save_model(model, n_epoch)
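The mixed-target losses above rely on helpers such as mixup and mixup_criterion (and their cutmix counterparts), which are not shown in this example. The sketch below is an assumption about their typical shape, on the further assumption that cutmix_criterion shares the same combination rule with lam derived from the cut-out area:

import numpy as np
import torch


def mixup(x, alpha=1.0):
    # Hypothetical helper: blend each sample with a randomly paired one.
    lam = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[rand_index]
    return mixed_x, rand_index, lam


def mixup_criterion(logits, target_a, target_b, lam, criterion):
    # Loss for mixed inputs: convex combination of the losses against
    # the two target sets.
    return (lam * criterion(logits, target_a) +
            (1 - lam) * criterion(logits, target_b))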
Example #28
def training_step(cost: torch.FloatTensor, optimizer: torch.optim.Optimizer):
    optimizer.zero_grad()
    cost.backward()
    optimizer.step()
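A minimal usage sketch for training_step; the tiny model and the random batch below are hypothetical placeholders:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(8, 10)
targets = torch.randn(8, 1)

cost = nn.functional.mse_loss(model(inputs), targets)
training_step(cost, optimizer)  # zero_grad -> backward -> step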
Example #29
    def attach(self, optimizer: torch.optim.Optimizer):
        r"""
        Attaches the privacy engine to the optimizer.

        Attaches an optimizer object to the ``PrivacyEngine`` and injects
        itself into the optimizer's step. To do that, it:

        1. Validates that the model does not have unsupported layers.

        2. Adds a pointer to this object (the ``PrivacyEngine``) inside the optimizer.

        3. Moves optimizer's original ``step()`` function to ``original_step()``.

        4. Monkeypatches the optimizer's ``step()`` function to call ``step()`` on
        the privacy engine automatically whenever it would call ``step()`` for itself.

        Args:
            optimizer: The optimizer to which the privacy engine will attach
        """

        self.validator.validate(self.module)
        norm_clipper = (
            # pyre-fixme[6]: Expected `float` for 1st param but got
            #  `Union[List[float], float]`.
            clipping.ConstantFlatClipper(self.max_grad_norm)
            if not isinstance(self.max_grad_norm, list)
            # pyre-fixme[6]: Expected `List[float]` for 1st param but got
            #  `Union[List[float], float]`.
            else clipping.ConstantPerLayerClipper(self.max_grad_norm))

        if self.misc_settings.get("experimental", False):
            norm_clipper = clipping._Dynamic_Clipper_(
                # pyre-fixme[6]: Expected `List[float]` for 1st param but got
                #  `List[Union[List[float], float]]`.
                [self.max_grad_norm],
                self.misc_settings.get("clip_per_layer", False),
                self.misc_settings.get("clipping_method",
                                       clipping.ClippingMethod.STATIC),
                self.misc_settings.get("clipping_ratio", 0.0),
                self.misc_settings.get("clipping_momentum", 0.0),
            )

        self.clipper = PerSampleGradientClipper(
            self.module,
            norm_clipper,
            self.batch_first,
            self.loss_reduction,
        )

        def dp_zero_grad(self):
            self.privacy_engine.zero_grad()
            self.original_zero_grad()

        def dp_step(self, closure=None):
            self.privacy_engine.step()
            self.original_step(closure)

        # Pyre doesn't like monkeypatching. But we'll do it anyway :)
        optimizer.privacy_engine = self  # pyre-ignore
        optimizer.original_step = optimizer.step  # pyre-ignore
        optimizer.step = types.MethodType(dp_step, optimizer)  # pyre-ignore

        optimizer.original_zero_grad = optimizer.zero_grad  # pyre-ignore
        optimizer.zero_grad = types.MethodType(dp_zero_grad,
                                               optimizer)  # pyre-ignore

        def virtual_step(self):
            self.privacy_engine.virtual_step()

        # pyre-ignore
        optimizer.virtual_step = types.MethodType(virtual_step, optimizer)

        # create a cross reference for detaching
        self.optimizer = optimizer  # pyre-ignore
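The monkeypatching described in steps 3 and 4 of the docstring can be illustrated in isolation. The sketch below is a toy stand-in for the privacy engine, not the Opacus API: it only shows how the original step is preserved and step is re-bound with types.MethodType.

import types

import torch
import torch.nn as nn


class ToyHook:
    # Toy stand-in: runs extra work before every optimizer.step(),
    # mirroring how the privacy engine clips and noises gradients.
    def attach(self, optimizer: torch.optim.Optimizer):
        def patched_step(opt_self, closure=None):
            print("extra work before the real update")
            return opt_self.original_step(closure)

        optimizer.original_step = optimizer.step
        optimizer.step = types.MethodType(patched_step, optimizer)


model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ToyHook().attach(optimizer)

loss = model(torch.randn(2, 4)).sum()
loss.backward()
optimizer.step()  # prints first, then performs the usual SGD update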
Example #30
    def train(
        self,
        base_path: Union[Path, str],
        learning_rate: float = 0.1,
        mini_batch_size: int = 32,
        mini_batch_chunk_size: Optional[int] = None,
        max_epochs: int = 100,
        train_with_dev: bool = False,
        train_with_test: bool = False,
        monitor_train: bool = False,
        monitor_test: bool = False,
        main_evaluation_metric: Tuple[str, str] = ("micro avg", 'f1-score'),
        scheduler=AnnealOnPlateau,
        anneal_factor: float = 0.5,
        patience: int = 3,
        min_learning_rate: float = 0.0001,
        initial_extra_patience: int = 0,
        optimizer: torch.optim.Optimizer = SGD,
        cycle_momentum: bool = False,
        warmup_fraction: float = 0.1,
        embeddings_storage_mode: str = "cpu",
        checkpoint: bool = False,
        save_final_model: bool = True,
        anneal_with_restarts: bool = False,
        anneal_with_prestarts: bool = False,
        anneal_against_dev_loss: bool = False,
        batch_growth_annealing: bool = False,
        shuffle: bool = True,
        param_selection_mode: bool = False,
        write_weights: bool = False,
        num_workers: int = 6,
        sampler=None,
        use_amp: bool = False,
        amp_opt_level: str = "O1",
        eval_on_train_fraction: float = 0.0,
        eval_on_train_shuffle: bool = False,
        save_model_each_k_epochs: int = 0,
        tensorboard_comment: str = '',
        use_swa: bool = False,
        use_final_model_for_eval: bool = False,
        gold_label_dictionary_for_eval: Optional[Dictionary] = None,
        create_file_logs: bool = True,
        create_loss_file: bool = True,
        epoch: int = 0,
        use_tensorboard: bool = False,
        tensorboard_log_dir=None,
        metrics_for_tensorboard=[],
        optimizer_state_dict: Optional[dict] = None,
        scheduler_state_dict: Optional[dict] = None,
        save_optimizer_state: bool = False,
        **kwargs,
    ) -> dict:
        """
        Trains any class that implements the flair.nn.Model interface.
        :param base_path: Main path to which all output during training is logged and models are saved
        :param learning_rate: Initial learning rate (or max, if scheduler is OneCycleLR)
        :param mini_batch_size: Size of mini-batches during training
        :param mini_batch_chunk_size: If mini-batches are larger than this number, they get broken down into chunks of this size for processing purposes
        :param max_epochs: Maximum number of epochs to train. Terminates training if this number is surpassed.
        :param scheduler: The learning rate scheduler to use
        :param checkpoint: If True, a full checkpoint is saved at end of each epoch
        :param cycle_momentum: If scheduler is OneCycleLR, whether the scheduler should cycle also the momentum
        :param anneal_factor: The factor by which the learning rate is annealed
        :param patience: Number of epochs with no improvement after which the
         Trainer anneals the learning rate
        :param min_learning_rate: If the learning rate falls below this threshold, training terminates
        :param warmup_fraction: Fraction of warmup steps if the scheduler is LinearSchedulerWithWarmup
        :param train_with_dev:  If True, the data from dev split is added to the training data
        :param train_with_test: If True, the data from test split is added to the training data
        :param monitor_train: If True, training data is evaluated at end of each epoch
        :param monitor_test: If True, test data is evaluated at end of each epoch
        :param embeddings_storage_mode: One of 'none' (all embeddings are deleted and freshly recomputed),
        'cpu' (embeddings are stored on CPU) or 'gpu' (embeddings are stored on GPU)
        :param save_final_model: If True, final model is saved
        :param anneal_with_restarts: If True, the last best model is restored when annealing the learning rate
        :param shuffle: If True, data is shuffled during training
        :param param_selection_mode: If True, testing is performed against dev data. Use this mode when doing
        parameter selection.
        :param num_workers: Number of workers in your data loader.
        :param sampler: You can pass a data sampler here for special sampling of data.
        :param eval_on_train_fraction: the fraction of the train data to evaluate on;
        if 0.0, no evaluation is performed on a fraction of the training data;
        if 'dev', the fraction size is determined from the dev set size
        :param eval_on_train_shuffle: if True, the train data fraction is re-sampled at the beginning
        of each epoch; otherwise it is determined once at the start of training and kept fixed
        :param save_model_each_k_epochs: Each k epochs, a model state will be written out. If set to 5, a model will
        be saved every 5 epochs. Default is 0, which means no model saving.
        :param main_evaluation_metric: Type of metric to use for best model tracking and learning rate scheduling (if dev data is available, otherwise loss will be used), currently only applicable for text_classification_model
        :param tensorboard_comment: Comment to use for tensorboard logging
        :param create_file_logs: If True, the logs will also be stored in a file 'training.log' in the model folder
        :param create_loss_file: If True, the loss will be written to a file 'loss.tsv' in the model folder
        :param optimizer: The optimizer to use (typically SGD or Adam)
        :param epoch: The starting epoch (normally 0 but could be higher if you continue training model)
        :param use_tensorboard: If True, writes out tensorboard information
        :param tensorboard_log_dir: Directory into which tensorboard log files will be written
        :param metrics_for_tensorboard: List of tuples that specify which metrics (in addition to the main_score) shall be plotted in tensorboard, could be [("macro avg", 'f1-score'), ("macro avg", 'precision')] for example
        :param kwargs: Other arguments for the Optimizer
        :return:
        """

        # create a model card for this model with Flair and PyTorch version
        model_card = {
            'flair_version': flair.__version__,
            'pytorch_version': torch.__version__
        }

        # also record Transformers version if library is loaded
        try:
            import transformers
            model_card['transformers_version'] = transformers.__version__
        except ImportError:
            pass

        # remember all parameters used in train() call
        local_variables = locals()
        training_parameters = {}
        for parameter in signature(self.train).parameters:
            training_parameters[parameter] = local_variables[parameter]
        model_card['training_parameters'] = training_parameters

        # add model card to model
        self.model.model_card = model_card

        if use_tensorboard:
            try:
                from torch.utils.tensorboard import SummaryWriter

                if tensorboard_log_dir is not None and not os.path.exists(
                        tensorboard_log_dir):
                    os.mkdir(tensorboard_log_dir)
                writer = SummaryWriter(log_dir=tensorboard_log_dir,
                                       comment=tensorboard_comment)
                log.info(f"tensorboard logging path is {tensorboard_log_dir}")

            except:
                log_line(log)
                log.warning(
                    "ATTENTION! PyTorch >= 1.1.0 and pillow are required for TensorBoard support!"
                )
                log_line(log)
                use_tensorboard = False

        if use_amp:
            if sys.version_info < (3, 0):
                raise RuntimeError(
                    "Apex currently only supports Python 3. Aborting.")
            if amp is None:
                raise RuntimeError(
                    "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                    "to enable mixed-precision training.")

        if mini_batch_chunk_size is None:
            mini_batch_chunk_size = mini_batch_size
        if learning_rate < min_learning_rate:
            min_learning_rate = learning_rate / 10

        initial_learning_rate = learning_rate

        # cast string to Path
        if type(base_path) is str:
            base_path = Path(base_path)
        base_path.mkdir(exist_ok=True, parents=True)

        if create_file_logs:
            log_handler = add_file_handler(log, base_path / "training.log")
        else:
            log_handler = None

        log_line(log)
        log.info(f'Model: "{self.model}"')
        log_line(log)
        log.info(f'Corpus: "{self.corpus}"')
        log_line(log)
        log.info("Parameters:")
        log.info(f' - learning_rate: "{learning_rate}"')
        log.info(f' - mini_batch_size: "{mini_batch_size}"')
        log.info(f' - patience: "{patience}"')
        log.info(f' - anneal_factor: "{anneal_factor}"')
        log.info(f' - max_epochs: "{max_epochs}"')
        log.info(f' - shuffle: "{shuffle}"')
        log.info(f' - train_with_dev: "{train_with_dev}"')
        log.info(f' - batch_growth_annealing: "{batch_growth_annealing}"')
        log_line(log)
        log.info(f'Model training base path: "{base_path}"')
        log_line(log)
        log.info(f"Device: {flair.device}")
        log_line(log)
        log.info(f"Embeddings storage mode: {embeddings_storage_mode}")
        if isinstance(self.model, SequenceTagger
                      ) and self.model.weight_dict and self.model.use_crf:
            log_line(log)
            log.warning(
                f'WARNING: Specified class weights will not take effect when using CRF'
            )

        # check for previously saved best models in the current training folder and delete them
        self.check_for_and_delete_previous_best_models(base_path)

        # determine what splits (train, dev, test) to evaluate and log
        log_train = monitor_train
        log_test = bool(not param_selection_mode and self.corpus.test
                        and monitor_test)
        log_dev = bool(not train_with_dev and self.corpus.dev)
        log_train_part = bool(eval_on_train_fraction == "dev"
                              or eval_on_train_fraction > 0.0)

        if log_train_part:
            train_part_size = len(self.corpus.dev) if eval_on_train_fraction == "dev" \
                else int(len(self.corpus.train) * eval_on_train_fraction)

            assert train_part_size > 0
            if not eval_on_train_shuffle:
                train_part_indices = list(range(train_part_size))
                train_part = torch.utils.data.dataset.Subset(
                    self.corpus.train, train_part_indices)

        # prepare loss logging file and set up header
        loss_txt = init_output_file(base_path,
                                    "loss.tsv") if create_loss_file else None

        weight_extractor = WeightExtractor(base_path)

        # if optimizer class is passed, instantiate:
        if inspect.isclass(optimizer):
            optimizer: torch.optim.Optimizer = optimizer(
                self.model.parameters(), lr=learning_rate, **kwargs)

        if use_swa:
            import torchcontrib
            optimizer = torchcontrib.optim.SWA(optimizer,
                                               swa_start=10,
                                               swa_freq=5,
                                               swa_lr=learning_rate)

        if use_amp:
            self.model, optimizer = amp.initialize(self.model,
                                                   optimizer,
                                                   opt_level=amp_opt_level)

        # load existing optimizer state dictionary if it exists
        if optimizer_state_dict:
            optimizer.load_state_dict(optimizer_state_dict)

        # minimize training loss if training with dev data, else maximize dev score
        anneal_mode = "min" if train_with_dev or anneal_against_dev_loss else "max"
        best_validation_score = 100000000000 if train_with_dev or anneal_against_dev_loss else 0.

        dataset_size = len(self.corpus.train)
        if train_with_dev:
            dataset_size += len(self.corpus.dev)

        # if scheduler is passed as a class, instantiate
        if inspect.isclass(scheduler):
            if scheduler == OneCycleLR:
                scheduler = OneCycleLR(
                    optimizer,
                    max_lr=learning_rate,
                    steps_per_epoch=dataset_size // mini_batch_size + 1,
                    epochs=max_epochs - epoch,
                    # if we load a checkpoint, we have already trained for epoch
                    pct_start=0.0,
                    cycle_momentum=cycle_momentum)
            elif scheduler == LinearSchedulerWithWarmup:
                steps_per_epoch = (dataset_size + mini_batch_size -
                                   1) / mini_batch_size
                num_train_steps = int(steps_per_epoch * max_epochs)
                num_warmup_steps = int(num_train_steps * warmup_fraction)

                scheduler = LinearSchedulerWithWarmup(
                    optimizer,
                    num_train_steps=num_train_steps,
                    num_warmup_steps=num_warmup_steps)
            else:
                scheduler = scheduler(
                    optimizer,
                    factor=anneal_factor,
                    patience=patience,
                    initial_extra_patience=initial_extra_patience,
                    mode=anneal_mode,
                    verbose=True,
                )

        # load existing scheduler state dictionary if it exists
        if scheduler_state_dict:
            scheduler.load_state_dict(scheduler_state_dict)

        # update optimizer and scheduler in model card
        model_card['training_parameters']['optimizer'] = optimizer
        model_card['training_parameters']['scheduler'] = scheduler

        if isinstance(scheduler, OneCycleLR) and batch_growth_annealing:
            raise ValueError(
                "Batch growth with OneCycle policy is not implemented.")

        train_data = self.corpus.train

        # if training also uses dev/train data, include in training set
        if train_with_dev or train_with_test:

            parts = [self.corpus.train]
            if train_with_dev: parts.append(self.corpus.dev)
            if train_with_test: parts.append(self.corpus.test)

            train_data = ConcatDataset(parts)

        # initialize sampler if provided
        if sampler is not None:
            # init with default values if only class is provided
            if inspect.isclass(sampler):
                sampler = sampler()
            # set dataset to sample from
            sampler.set_dataset(train_data)
            shuffle = False

        dev_score_history = []
        dev_loss_history = []
        train_loss_history = []

        micro_batch_size = mini_batch_chunk_size

        # At any point you can hit Ctrl + C to break out of training early.
        try:
            previous_learning_rate = learning_rate
            momentum = 0
            for group in optimizer.param_groups:
                if "momentum" in group:
                    momentum = group["momentum"]

            for epoch in range(epoch + 1, max_epochs + 1):
                log_line(log)

                # update epoch in model card
                self.model.model_card['training_parameters']['epoch'] = epoch

                if anneal_with_prestarts:
                    last_epoch_model_state_dict = copy.deepcopy(
                        self.model.state_dict())

                if eval_on_train_shuffle:
                    train_part_indices = list(range(len(self.corpus.train)))
                    random.shuffle(train_part_indices)
                    train_part_indices = train_part_indices[:train_part_size]
                    train_part = torch.utils.data.dataset.Subset(
                        self.corpus.train, train_part_indices)

                # get new learning rate
                for group in optimizer.param_groups:
                    learning_rate = group["lr"]

                if learning_rate != previous_learning_rate and batch_growth_annealing:
                    mini_batch_size *= 2

                # reload last best model if annealing with restarts is enabled
                if ((anneal_with_restarts or anneal_with_prestarts)
                        and learning_rate != previous_learning_rate
                        and os.path.exists(base_path / "best-model.pt")):
                    if anneal_with_restarts:
                        log.info("resetting to best model")
                        self.model.load_state_dict(
                            self.model.load(base_path /
                                            "best-model.pt").state_dict())
                    if anneal_with_prestarts:
                        log.info("resetting to pre-best model")
                        self.model.load_state_dict(
                            self.model.load(base_path /
                                            "pre-best-model.pt").state_dict())

                previous_learning_rate = learning_rate
                if use_tensorboard:
                    writer.add_scalar("learning_rate", learning_rate, epoch)

                # stop training if learning rate becomes too small
                if ((not isinstance(scheduler,
                                    (OneCycleLR, LinearSchedulerWithWarmup))
                     and learning_rate < min_learning_rate)):
                    log_line(log)
                    log.info("learning rate too small - quitting training!")
                    log_line(log)
                    break

                batch_loader = DataLoader(
                    train_data,
                    batch_size=mini_batch_size,
                    shuffle=shuffle
                    if epoch > 1 else False,  # never shuffle the first epoch
                    num_workers=num_workers,
                    sampler=sampler,
                )

                self.model.train()

                train_loss: float = 0

                seen_batches = 0
                total_number_of_batches = len(batch_loader)

                modulo = max(1, int(total_number_of_batches / 10))

                # process mini-batches
                batch_time = 0
                average_over = 0
                for batch_no, batch in enumerate(batch_loader):

                    start_time = time.time()

                    # zero the gradients on the model and optimizer
                    self.model.zero_grad()
                    optimizer.zero_grad()

                    # if necessary, make batch_steps
                    batch_steps = [batch]
                    if len(batch) > micro_batch_size:
                        batch_steps = [
                            batch[x:x + micro_batch_size]
                            for x in range(0, len(batch), micro_batch_size)
                        ]
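                    # Note: every chunk below calls backward() and accumulates
                    # gradients; optimizer.step() runs once per full mini-batch,
                    # so chunking bounds peak memory without changing the
                    # effective batch size.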

                    # forward and backward for batch
                    for batch_step in batch_steps:

                        # forward pass
                        loss = self.model.forward_loss(batch_step)

                        if isinstance(loss, tuple):
                            average_over += loss[1]
                            loss = loss[0]

                        # Backward
                        if use_amp:
                            with amp.scale_loss(loss,
                                                optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()
                        train_loss += loss.item()

                    # do the optimizer step
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   5.0)
                    optimizer.step()

                    # do the scheduler step if one-cycle or linear decay
                    if isinstance(scheduler,
                                  (OneCycleLR, LinearSchedulerWithWarmup)):
                        scheduler.step()
                        # get new learning rate
                        for group in optimizer.param_groups:
                            learning_rate = group["lr"]
                            if "momentum" in group:
                                momentum = group["momentum"]
                            if "betas" in group:
                                momentum, _ = group["betas"]

                    seen_batches += 1

                    # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                    store_embeddings(batch, embeddings_storage_mode)

                    batch_time += time.time() - start_time
                    if seen_batches % modulo == 0:
                        momentum_info = f' - momentum: {momentum:.4f}' if cycle_momentum else ''
                        intermittent_loss = train_loss / average_over if average_over > 0 else train_loss / seen_batches
                        log.info(
                            f"epoch {epoch} - iter {seen_batches}/{total_number_of_batches} - loss "
                            f"{intermittent_loss:.8f} - samples/sec: {mini_batch_size * modulo / batch_time:.2f}"
                            f" - lr: {learning_rate:.6f}{momentum_info}")
                        batch_time = 0
                        iteration = epoch * total_number_of_batches + batch_no
                        if not param_selection_mode and write_weights:
                            weight_extractor.extract_weights(
                                self.model.state_dict(), iteration)

                if average_over != 0:
                    train_loss /= average_over

                self.model.eval()

                log_line(log)
                log.info(
                    f"EPOCH {epoch} done: loss {train_loss:.4f} - lr {learning_rate:.7f}"
                )

                if use_tensorboard:
                    writer.add_scalar("train_loss", train_loss, epoch)

                # evaluate on train / dev / test split depending on training settings
                result_line: str = ""

                if log_train:
                    train_eval_result = self.model.evaluate(
                        self.corpus.train,
                        gold_label_type=self.model.label_type,
                        mini_batch_size=mini_batch_chunk_size,
                        num_workers=num_workers,
                        embedding_storage_mode=embeddings_storage_mode,
                        main_evaluation_metric=main_evaluation_metric,
                        gold_label_dictionary=gold_label_dictionary_for_eval,
                    )
                    result_line += f"\t{train_eval_result.log_line}"

                    # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                    store_embeddings(self.corpus.train,
                                     embeddings_storage_mode)

                if log_train_part:
                    train_part_eval_result, train_part_loss = self.model.evaluate(
                        train_part,
                        gold_label_type=self.model.label_type,
                        mini_batch_size=mini_batch_chunk_size,
                        num_workers=num_workers,
                        embedding_storage_mode=embeddings_storage_mode,
                        main_evaluation_metric=main_evaluation_metric,
                        gold_label_dictionary=gold_label_dictionary_for_eval,
                    )
                    result_line += f"\t{train_part_loss}\t{train_part_eval_result.log_line}"

                    log.info(
                        f"TRAIN_SPLIT : loss {train_part_loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]}) {round(train_part_eval_result.main_score, 4)}"
                    )
                if use_tensorboard:
                    for (metric_class_avg_type,
                         metric_type) in metrics_for_tensorboard:
                        writer.add_scalar(
                            f"train_{metric_class_avg_type}_{metric_type}",
                            train_part_eval_result.classification_report[
                                metric_class_avg_type][metric_type], epoch)

                if log_dev:
                    dev_eval_result = self.model.evaluate(
                        self.corpus.dev,
                        gold_label_type=self.model.label_type,
                        mini_batch_size=mini_batch_chunk_size,
                        num_workers=num_workers,
                        out_path=base_path / "dev.tsv",
                        embedding_storage_mode=embeddings_storage_mode,
                        main_evaluation_metric=main_evaluation_metric,
                        gold_label_dictionary=gold_label_dictionary_for_eval,
                    )
                    result_line += f"\t{dev_eval_result.loss}\t{dev_eval_result.log_line}"
                    log.info(
                        f"DEV : loss {dev_eval_result.loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]})  {round(dev_eval_result.main_score, 4)}"
                    )
                    # calculate scores using dev data if available
                    # append dev score to score history
                    dev_score_history.append(dev_eval_result.main_score)
                    dev_loss_history.append(dev_eval_result.loss)

                    dev_score = dev_eval_result.main_score

                    # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                    store_embeddings(self.corpus.dev, embeddings_storage_mode)

                    if use_tensorboard:
                        writer.add_scalar("dev_loss", dev_eval_result.loss,
                                          epoch)
                        writer.add_scalar("dev_score",
                                          dev_eval_result.main_score, epoch)
                        for (metric_class_avg_type,
                             metric_type) in metrics_for_tensorboard:
                            writer.add_scalar(
                                f"dev_{metric_class_avg_type}_{metric_type}",
                                dev_eval_result.classification_report[
                                    metric_class_avg_type][metric_type], epoch)

                if log_test:
                    test_eval_result = self.model.evaluate(
                        self.corpus.test,
                        gold_label_type=self.model.label_type,
                        mini_batch_size=mini_batch_chunk_size,
                        num_workers=num_workers,
                        out_path=base_path / "test.tsv",
                        embedding_storage_mode=embeddings_storage_mode,
                        main_evaluation_metric=main_evaluation_metric,
                        gold_label_dictionary=gold_label_dictionary_for_eval,
                    )
                    result_line += f"\t{test_eval_result.loss}\t{test_eval_result.log_line}"
                    log.info(
                        f"TEST : loss {test_eval_result.loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]})  {round(test_eval_result.main_score, 4)}"
                    )

                    # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                    store_embeddings(self.corpus.test, embeddings_storage_mode)

                    if use_tensorboard:
                        writer.add_scalar("test_loss", test_eval_result.loss,
                                          epoch)
                        writer.add_scalar("test_score",
                                          test_eval_result.main_score, epoch)
                        for (metric_class_avg_type,
                             metric_type) in metrics_for_tensorboard:
                            writer.add_scalar(
                                f"test_{metric_class_avg_type}_{metric_type}",
                                test_eval_result.classification_report[
                                    metric_class_avg_type][metric_type], epoch)

                # determine if this is the best model or if we need to anneal
                current_epoch_has_best_model_so_far = False
                # default mode: anneal against dev score
                if not train_with_dev and not anneal_against_dev_loss:
                    if dev_score > best_validation_score:
                        current_epoch_has_best_model_so_far = True
                        best_validation_score = dev_score

                    if isinstance(scheduler, AnnealOnPlateau):
                        scheduler.step(dev_score, dev_eval_result.loss)

                # alternative: anneal against dev loss
                if not train_with_dev and anneal_against_dev_loss:
                    if dev_eval_result.loss < best_validation_score:
                        current_epoch_has_best_model_so_far = True
                        best_validation_score = dev_eval_result.loss

                    if isinstance(scheduler, AnnealOnPlateau):
                        scheduler.step(dev_eval_result.loss)

                # alternative: anneal against train loss
                if train_with_dev:
                    if train_loss < best_validation_score:
                        current_epoch_has_best_model_so_far = True
                        best_validation_score = train_loss

                    if isinstance(scheduler, AnnealOnPlateau):
                        scheduler.step(train_loss)

                train_loss_history.append(train_loss)

                # determine bad epoch number
                try:
                    bad_epochs = scheduler.num_bad_epochs
                except AttributeError:
                    bad_epochs = 0
                for group in optimizer.param_groups:
                    new_learning_rate = group["lr"]
                if new_learning_rate != previous_learning_rate:
                    bad_epochs = patience + 1
                    if previous_learning_rate == initial_learning_rate:
                        bad_epochs += initial_extra_patience

                # log bad epochs
                log.info(f"BAD EPOCHS (no improvement): {bad_epochs}")

                if create_loss_file:
                    # output log file
                    with open(loss_txt, "a") as f:

                        # make headers on first epoch
                        if epoch == 1:
                            f.write(
                                f"EPOCH\tTIMESTAMP\tBAD_EPOCHS\tLEARNING_RATE\tTRAIN_LOSS"
                            )

                            if log_train:
                                f.write("\tTRAIN_" + "\tTRAIN_".join(
                                    train_eval_result.log_header.split("\t")))

                            if log_train_part:
                                f.write("\tTRAIN_PART_LOSS\tTRAIN_PART_" +
                                        "\tTRAIN_PART_".join(
                                            train_part_eval_result.log_header.
                                            split("\t")))

                            if log_dev:
                                f.write("\tDEV_LOSS\tDEV_" + "\tDEV_".join(
                                    dev_eval_result.log_header.split("\t")))

                            if log_test:
                                f.write("\tTEST_LOSS\tTEST_" + "\tTEST_".join(
                                    test_eval_result.log_header.split("\t")))

                        f.write(
                            f"\n{epoch}\t{datetime.datetime.now():%H:%M:%S}\t{bad_epochs}\t{learning_rate:.4f}\t{train_loss}"
                        )
                        f.write(result_line)

                # if checkpoint is enabled, save model at each epoch
                if checkpoint and not param_selection_mode:
                    self.model.save(base_path / "checkpoint.pt",
                                    checkpoint=True)

                # Check whether to save best model
                if ((not train_with_dev or anneal_with_restarts
                     or anneal_with_prestarts) and not param_selection_mode
                        and current_epoch_has_best_model_so_far
                        and not use_final_model_for_eval):
                    log.info("saving best model")
                    self.model.save(base_path / "best-model.pt",
                                    checkpoint=save_optimizer_state)

                    if anneal_with_prestarts:
                        current_state_dict = self.model.state_dict()
                        self.model.load_state_dict(last_epoch_model_state_dict)
                        self.model.save(base_path / "pre-best-model.pt")
                        self.model.load_state_dict(current_state_dict)

                if save_model_each_k_epochs > 0 and not epoch % save_model_each_k_epochs:
                    print("saving model of current epoch")
                    model_name = "model_epoch_" + str(epoch) + ".pt"
                    self.model.save(base_path / model_name,
                                    checkpoint=save_optimizer_state)

            if use_swa:
                optimizer.swap_swa_sgd()

            # if we do not use dev data for model selection, save final model
            if save_final_model and not param_selection_mode:
                self.model.save(base_path / "final-model.pt",
                                checkpoint=save_optimizer_state)

        except KeyboardInterrupt:
            log_line(log)
            log.info("Exiting from training early.")

            if use_tensorboard:
                writer.close()

            if not param_selection_mode:
                log.info("Saving model ...")
                self.model.save(base_path / "final-model.pt",
                                checkpoint=save_optimizer_state)
                log.info("Done.")

        # test best model if test data is present
        if self.corpus.test and not train_with_test:
            final_score = self.final_test(
                base_path=base_path,
                eval_mini_batch_size=mini_batch_chunk_size,
                num_workers=num_workers,
                main_evaluation_metric=main_evaluation_metric,
                gold_label_dictionary_for_eval=gold_label_dictionary_for_eval,
            )
        else:
            final_score = 0
            log.info("Test data not provided setting final score to 0")

        if create_file_logs:
            log_handler.close()
            log.removeHandler(log_handler)

        if use_tensorboard:
            writer.close()

        return {
            "test_score": final_score,
            "dev_score_history": dev_score_history,
            "train_loss_history": train_loss_history,
            "dev_loss_history": dev_loss_history,
        }
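A typical call into this train() method, assuming a flair ModelTrainer that already wraps a constructed model and corpus (both are placeholders below), overrides only a few of the defaults:

# Sketch: `tagger` and `corpus` are assumed to be built elsewhere
# (a flair.nn.Model implementation and a flair Corpus).
from flair.trainers import ModelTrainer

trainer = ModelTrainer(tagger, corpus)
result = trainer.train(
    base_path="resources/taggers/example",
    learning_rate=0.1,
    mini_batch_size=32,
    max_epochs=10,
    checkpoint=True,
)
print(result["test_score"])
print(result["dev_score_history"][-3:])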