Example No. 1
def track(sequence: Iterable,
          description: str = "Working...",
          disable: bool = False,
          style: Literal["rich", "tqdm"] = None,
          **kwargs):
    """Progress bar with `'rich'` and `'tqdm'` styles."""
    if style is None:
        style = settings.progress_bar_style
    if style not in ["rich", "tqdm"]:
        raise ValueError("style must be one of ['rich', 'tqdm']")
    if disable:
        return sequence
    if style == "tqdm":
        # fixes repeated pbar in jupyter
        # see https://github.com/tqdm/tqdm/issues/375
        if hasattr(tqdm_base, "_instances"):
            for instance in list(tqdm_base._instances):
                tqdm_base._decr_instances(instance)
        return tqdm_base(sequence, desc=description, file=sys.stdout, **kwargs)
    else:
        in_colab = "google.colab" in sys.modules
        force_jupyter = None if not in_colab else True
        console = Console(force_jupyter=force_jupyter)
        return track_base(sequence,
                          description=description,
                          console=console,
                          **kwargs)
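
The snippet above omits its imports. A plausible, minimal set for running it, assuming it comes from an scvi-tools-style codebase (the `settings` import is an assumption and project-specific; everything else is standard tqdm/rich):

import sys
from typing import Iterable, Literal  # Literal needs Python 3.8+ (else typing_extensions)

from rich.console import Console
from rich.progress import track as track_base
from tqdm import tqdm as tqdm_base

from scvi import settings  # assumption: project-level settings object providing progress_bar_style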
Example No. 2
def tqdm(*args, **kwargs):
    # Clear any tqdm instances left over from an interrupted run
    # (avoids duplicated progress bars, e.g. in Jupyter).
    if hasattr(tqdm_base, '_instances'):
        for instance in list(tqdm_base._instances):
            tqdm_base._decr_instances(instance)
    return tqdm_base(*args, **kwargs)
Example No. 3
def tqdm(*args, **kwargs):
    """Decorator of tqdm, to avoid some errors if tqdm 
    terminated unexpectedly

    Returns
    -------
    an decorated `tqdm` class
    """

    if hasattr(tqdm_base, '_instances'):
        for instance in list(tqdm_base._instances):
            tqdm_base._decr_instances(instance)
    return tqdm_base(*args, **kwargs)
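
The wrapper above is intended as a drop-in replacement for `tqdm` itself. A minimal usage sketch, assuming the wrapper and its `tqdm_base` import are in scope (the loop body is purely illustrative):

import time

# `tqdm` here is the wrapper defined above; it clears any stale instances
# before delegating to the real tqdm class.
for item in tqdm(range(5), desc="processing"):
    time.sleep(0.1)  # stand-in for real per-item work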
Example No. 4
File: _track.py Project: vals/scVI
def track(sequence: Iterable,
          description: str = "Working...",
          disable: bool = False,
          style: Literal["rich", "tqdm"] = None,
          **kwargs):
    """
    Progress bar with `'rich'` and `'tqdm'` styles.

    Parameters
    ----------
    sequence
        Iterable sequence.
    description
        First text shown to left of progress bar.
    disable
        Switch to turn off progress bar.
    style
        One of ["rich", "tqdm"]. "rich" is interactive
        and is not persistent after close.
    **kwargs
        Keyword args to tqdm or rich.

    Examples
    --------
    >>> from scvi.utils import track
    >>> my_list = [1, 2, 3]
    >>> for i in track(my_list): print(i)
    """
    if style is None:
        style = settings.progress_bar_style
    if style not in ["rich", "tqdm"]:
        raise ValueError("style must be one of ['rich', 'tqdm']")
    if disable:
        return sequence
    if style == "tqdm":
        # fixes repeated pbar in jupyter
        # see https://github.com/tqdm/tqdm/issues/375
        if hasattr(tqdm_base, "_instances"):
            for instance in list(tqdm_base._instances):
                tqdm_base._decr_instances(instance)
        return tqdm_base(sequence, desc=description, file=sys.stdout, **kwargs)
    else:
        in_colab = "google.colab" in sys.modules
        force_jupyter = None if not in_colab else True
        console = Console(force_jupyter=force_jupyter)
        return track_base(sequence,
                          description=description,
                          console=console,
                          **kwargs)
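
Building on the docstring's example, two further call patterns the signature supports (illustrative only; the iterables here are placeholders):

# Force the tqdm backend and customize the label shown next to the bar.
for batch in track(range(3), description="Training...", style="tqdm"):
    pass

# Turn the bar off entirely, e.g. for non-interactive batch jobs.
for batch in track(range(3), disable=True):
    pass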
Example No. 5
def track(sequence: Iterable,
          description: str = "Working...",
          disable: bool = False,
          style: Literal["rich", "tqdm"] = None,
          **kwargs):
    """Progress bar with `'rich'` and `'tqdm'` styles."""
    if style is None:
        style = settings.progress_bar_style
    if style not in ["rich", "tqdm"]:
        raise ValueError("style must be one of ['rich', 'tqdm']")
    if disable:
        return sequence
    if style == "tqdm":
        # fixes repeated pbar in jupyter
        # see https://github.com/tqdm/tqdm/issues/375
        if hasattr(tqdm_base, "_instances"):
            for instance in list(tqdm_base._instances):
                tqdm_base._decr_instances(instance)
        return tqdm_base(sequence, desc=description, file=sys.stdout, **kwargs)
    else:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return tqdm_rich(sequence, desc=description, **kwargs)
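
Unlike Example No. 1, this variant hands the "rich" style to tqdm's own rich frontend instead of building a rich Console directly. The extra imports it assumes are roughly the following (`tqdm.rich` is marked experimental upstream and requires `rich` to be installed):

import sys
import warnings

from tqdm import tqdm as tqdm_base
from tqdm.rich import tqdm_rich  # experimental rich-based tqdm frontend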
Example No. 6
def reset_tqdm():
    # Decrement and drop any live tqdm instances so the next bar starts clean.
    if hasattr(tqdm, '_instances'):
        for instance in list(tqdm._instances):
            tqdm._decr_instances(instance)
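
A typical place for such a reset is right after an interrupted loop, so the next bar does not draw on top of the stale one. A small sketch (the loop body and the interrupt handling are illustrative):

from tqdm import tqdm

try:
    for _ in tqdm(range(1_000_000)):
        pass  # long-running work that the user may interrupt
except KeyboardInterrupt:
    reset_tqdm()  # drop the half-finished bar before starting a new one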
Example No. 7
def tqdm(*args, **kwargs):
    if hasattr(tqdm_base, '_instances'):
        for instance in list(tqdm_base._instances):
            tqdm_base._decr_instances(instance)
    return tqdm_base(*args, **kwargs)
Example No. 8
def main():

    loss_history = {"epoch": [], "train": [], "val": []}
    train_metrices_history = {
        "epoch": [],
        "acc": [],
        "iou": [],
        "dice": [],
        "sens": [],
        "spec": []
    }
    val_metrices_history = {
        "epoch": [],
        "acc": [],
        "iou": [],
        "dice": [],
        "sens": [],
        "spec": []
    }
    prev_loss = -100
    loss_increase_counter = 0
    early_stop = True
    early_stop_threshold = 5

    for e in range(epochs):
        model.train()
        running_loss = 0.0
        train_accuracy = 0.0
        train_sensitivity = 0.0
        train_specificity = 0.0
        train_iou = 0.0  # Jaccard Score
        train_dice = 0.0
        ts = time.time()

        for i, d in tqdm(enumerate(trainloader),
                         total=len(trainloader),
                         leave=True,
                         position=0,
                         desc='Epoch: {}'.format(e)):
            inputs_, targets_, _ = d
            inputs, targets = inputs_.to(device), targets_.to(device)

            opt.zero_grad()
            outputs = model(inputs)
            loss = loss_f(outputs, targets.long())
            loss.backward()
            opt.step()
            running_loss += loss.item()

            train_iou += IoU()(outputs, targets)
            train_dice += Fscore()(outputs, targets)
            train_accuracy += Accuracy()(outputs, targets)
            train_sensitivity += Sensitivity()(outputs, targets)
            train_specificity += Specificity()(outputs, targets)

        state = {
            'epoch': e,
            'state_dict': model.state_dict(),
            'optimizer': opt.state_dict(),
            'loss': loss.item()
        }

        tr_loss = running_loss / len(trainloader)

        vl_loss, val_acc, val_iou, val_dice, val_sens, val_spec = val()

        loss_history["epoch"].append(e)
        loss_history["train"].append(tr_loss)
        loss_history["val"].append(vl_loss)

        tr_acc = train_accuracy / len(trainloader)
        tr_iou = train_iou / len(trainloader)
        tr_dice = train_dice / len(trainloader)
        tr_sens = train_sensitivity / len(trainloader)
        tr_spec = train_specificity / len(trainloader)

        val_metrices_history["epoch"].append(e)
        val_metrices_history["acc"].append(val_acc)
        val_metrices_history["iou"].append(val_iou)
        val_metrices_history["dice"].append(val_dice)
        val_metrices_history["sens"].append(val_sens)
        val_metrices_history["spec"].append(val_spec)

        train_metrices_history["epoch"].append(e)
        train_metrices_history["acc"].append(tr_acc)
        train_metrices_history["iou"].append(tr_iou)
        train_metrices_history["dice"].append(tr_dice)
        train_metrices_history["sens"].append(tr_sens)
        train_metrices_history["spec"].append(tr_spec)

        file_name = os.path.join(ckp_path,
                                 'task_2_weights_epoch_{}.pt'.format(e))

        print(
            "Finished Epoch {0}, Time Elapsed {1}, train loss: {2:.6g}, val loss: {3:.6g}"
            .format(e,
                    time.time() - ts, tr_loss, vl_loss))
        print(
            "Metrics Train : Acc {0:.6g}, IOU {1:.6g}, Dice {2:.6g}, Sens {3:.6g}, Spec {4:.6g}"
            .format(tr_acc, tr_iou, tr_dice, tr_sens, tr_spec))
        print(
            "Metrics Valid : Acc {0:.6g}, IOU {1:.6g}, Dice {2:.6g}, Sens {3:.6g}, Spec {4:.6g}"
            .format(val_acc, val_iou, val_dice, val_sens, val_spec))
        print("-" * 60)
        torch.save(state, file_name)
        # Implemented early stopping
        if vl_loss > prev_loss:
            loss_increase_counter += 1
        else:
            loss_increase_counter = 0
        if early_stop and loss_increase_counter > early_stop_threshold:
            print("Early Stopping..")
            break

        prev_loss = vl_loss

        torch.cuda.empty_cache()

        # Clear lingering tqdm instances so the next epoch's bar renders cleanly.
        if hasattr(tqdm, '_instances'):
            for instance in list(tqdm._instances):
                tqdm._decr_instances(instance)

    print('Finished Training')

    # Write the train loss to the csv
    (pd.DataFrame.from_dict(data=loss_history, orient='columns').to_csv(
        os.path.join(logs_path, 'loss.csv'),
        header=['epoch', 'train_loss', 'val_loss']))
    (pd.DataFrame.from_dict(
        data=train_metrices_history, orient='columns').to_csv(
            os.path.join(logs_path, 'train_metrices.csv'),
            header=["epoch", "acc", "iou", "dice", "sens", "spec"]))
    (pd.DataFrame.from_dict(
        data=val_metrices_history, orient='columns').to_csv(
            os.path.join(logs_path, 'val_metrices.csv'),
            header=["epoch", "acc", "iou", "dice", "sens", "spec"]))
Example No. 9
def tqdm_fixed(*args, **kwargs):
    from tqdm import tqdm as tqdm_base
    if hasattr(tqdm_base, '_instances'):
        for instance in list(tqdm_base._instances):
            tqdm_base._decr_instances(instance)
    return tqdm_base(*args, **kwargs)
Example No. 10
def clear_tqdm():
    if hasattr(tqdm, '_instances'):
        for instance in list(tqdm._instances):
            tqdm._decr_instances(instance)
Example No. 11
            loss = criterion(output, target_tokens)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

N_EPOCHS = args.epochs
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    train_loss = train(model, train_iterator, criterion, optimizer, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), args.save+'/model.pth')

    print('Epoch: ', epoch)
    print('Train loss: ', train_loss)
    print('Valid loss: ', valid_loss)

test_loss = evaluate(model, test_iterator, criterion)

print('Test Loss: {:.2f}'.format(test_loss))

# Clear any remaining tqdm progress-bar instances before the script exits.
if hasattr(tqdm, '_instances'):
    for instance in list(tqdm._instances):
        tqdm._decr_instances(instance)