import torch
from torch import topk
from tqdm import tqdm
from apex import amp  # NVIDIA apex mixed precision, used when conf["fp16"] is enabled

# AverageMeter and the entries of loss_functions are project utilities that are not
# part of this listing (a minimal AverageMeter sketch follows this example).


def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler,
                train_data_loader, summary_writer, conf, local_rank,
                only_valid):
    losses = AverageMeter()
    fake_losses = AverageMeter()
    real_losses = AverageMeter()
    max_iters = conf["batches_per_epoch"]
    print("training epoch {}".format(current_epoch))
    model.train()
    pbar = tqdm(enumerate(train_data_loader),
                total=max_iters,
                desc="Epoch {}".format(current_epoch),
                ncols=0)
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(current_epoch)
    for i, sample in pbar:
        imgs = sample["image"].cuda()
        labels = sample["labels"].cuda().float()
        out_labels = model(imgs)
        if only_valid:
            valid_idx = sample["valid"].cuda().float() > 0
            out_labels = out_labels[valid_idx]
            labels = labels[valid_idx]
            if labels.size(0) == 0:
                continue

        fake_loss = 0
        real_loss = 0
        fake_idx = labels > 0.5
        real_idx = labels <= 0.5

        ohem = conf.get("ohem_samples", None)
        if torch.sum(fake_idx * 1) > 0:
            fake_loss = loss_functions["classifier_loss"](out_labels[fake_idx],
                                                          labels[fake_idx])
        if torch.sum(real_idx * 1) > 0:
            real_loss = loss_functions["classifier_loss"](out_labels[real_idx],
                                                          labels[real_idx])
        if ohem:
            fake_loss = topk(fake_loss,
                             k=min(ohem, fake_loss.size(0)),
                             sorted=False)[0].mean()
            real_loss = topk(real_loss,
                             k=min(ohem, real_loss.size(0)),
                             sorted=False)[0].mean()

        loss = (fake_loss + real_loss) / 2
        losses.update(loss.item(), imgs.size(0))
        fake_losses.update(0 if fake_loss == 0 else fake_loss.item(),
                           imgs.size(0))
        real_losses.update(0 if real_loss == 0 else real_loss.item(),
                           imgs.size(0))

        optimizer.zero_grad()
        pbar.set_postfix({
            "lr": float(scheduler.get_lr()[-1]),
            "epoch": current_epoch,
            "loss": losses.avg,
            "fake_loss": fake_losses.avg,
            "real_loss": real_losses.avg
        })

        if conf['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        optimizer.step()
        torch.cuda.synchronize()
        if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
            scheduler.step(i + current_epoch * max_iters)
        if i == max_iters - 1:
            break
    pbar.close()
    if local_rank == 0:
        for idx, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            summary_writer.add_scalar('group{}/lr'.format(idx),
                                      float(lr),
                                      global_step=current_epoch)
        summary_writer.add_scalar('train/loss',
                                  float(losses.avg),
                                  global_step=current_epoch)
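
AverageMeter and the conf dictionary used above come from the surrounding project and are
not defined in this listing. A minimal sketch of both, with placeholder values, might look
like the following (the source project's versions may differ):

class AverageMeter:
    """Tracks a running average of a scalar such as the loss over an epoch."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        # value is the batch statistic, n the number of samples it covers
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count


# Hypothetical config containing only the keys train_epoch reads; a real config has more.
conf_example = {
    "batches_per_epoch": 2500,
    "fp16": True,
    "ohem_samples": None,  # set to an int to enable the online hard example mining branch
    "optimizer": {"schedule": {"mode": "step"}},
}
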
def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler,
                train_data_loader, summary_writer, conf, local_rank,
                only_valid):
    losses = AverageMeter()
    fake_losses = AverageMeter()
    real_losses = AverageMeter()
    max_iters = conf["batches_per_epoch"]
    print("training epoch {}".format(current_epoch))
    # Put the model in training mode so that layers such as dropout and batchnorm,
    # which behave differently during training and evaluation, act accordingly.
    model.train()
    pbar = tqdm(enumerate(train_data_loader),
                total=max_iters,
                desc="Epoch {}".format(current_epoch),
                ncols=0)
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(current_epoch)
    # Iterate over batches from the training data loader.
    for i, sample in pbar:
        imgs = sample["image"].cuda()
        labels = sample["labels"].cuda().float()
        out_labels = model(imgs)
        if only_valid:
            valid_idx = sample["valid"].cuda().float() > 0
            out_labels = out_labels[valid_idx]
            labels = labels[valid_idx]
            if labels.size(0) == 0:
                continue

        fake_loss = 0
        real_loss = 0
        fake_idx = labels > 0.5
        real_idx = labels <= 0.5

        ohem = conf.get("ohem_samples", None)
        # Compute the classification loss separately for the fake and the real samples;
        # torch.sum over each boolean index mask checks that the group is non-empty.
        if torch.sum(fake_idx * 1) > 0:
            fake_loss = loss_functions["classifier_loss"](out_labels[fake_idx],
                                                          labels[fake_idx])
        if torch.sum(real_idx * 1) > 0:
            real_loss = loss_functions["classifier_loss"](out_labels[real_idx],
                                                          labels[real_idx])
        if ohem:
            fake_loss = topk(fake_loss,
                             k=min(ohem, fake_loss.size(0)),
                             sorted=False)[0].mean()
            real_loss = topk(real_loss,
                             k=min(ohem, real_loss.size(0)),
                             sorted=False)[0].mean()

        loss = (fake_loss + real_loss) / 2
        losses.update(loss.item(), imgs.size(0))
        fake_losses.update(0 if fake_loss == 0 else fake_loss.item(),
                           imgs.size(0))
        real_losses.update(0 if real_loss == 0 else real_loss.item(),
                           imgs.size(0))
        # Reset gradients to zero before starting backpropagation.
        optimizer.zero_grad()
        # this specifies additional stats to display at the end of the bar
        pbar.set_postfix({
            "lr": float(scheduler.get_lr()[-1]),
            "epoch": current_epoch,
            "loss": losses.avg,
            "fake_loss": fake_losses.avg,
            "real_loss": real_losses.avg
        })

        # Backpropagation; when fp16 (apex amp) is enabled, the loss is scaled first
        # to avoid gradient underflow.
        if conf['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # torch.nn.utils.clip_grad_norm_ clips the gradient norm of an iterable of parameters.
        # The norm is computed over all gradients together, as if they were concatenated
        # into a single vector; gradients are modified in place.
        # amp.master_params yields the (master) parameters owned by the optimizer.
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        # optimizer.step performs a parameter update based on the current gradient (stored in .grad attribute of a parameter) and the update rule.
        optimizer.step()
        # wait for all kernels in all streams on a CUDA device to complete.
        torch.cuda.synchronize()
        if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
            scheduler.step(i + current_epoch * max_iters)
        if i == max_iters - 1:
            break
    pbar.close()
    if local_rank == 0:
        for idx, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            summary_writer.add_scalar('group{}/lr'.format(idx),
                                      float(lr),
                                      global_step=current_epoch)
        summary_writer.add_scalar('train/loss',
                                  float(losses.avg),
                                  global_step=current_epoch)
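
The online hard example mining (OHEM) branch above only makes sense if
loss_functions["classifier_loss"] returns per-sample losses (reduction="none"), so that
topk can select the hardest examples; with a scalar loss the topk call would fail. A
minimal sketch of that interaction, assuming BCEWithLogitsLoss as the classifier loss:

import torch

classifier_loss = torch.nn.BCEWithLogitsLoss(reduction="none")  # per-sample losses

logits = torch.randn(8)            # model outputs for 8 samples
targets = torch.ones(8)            # all labelled "fake" in this toy example
per_sample = classifier_loss(logits, targets)   # shape (8,)

ohem_samples = 4
hard_loss = torch.topk(per_sample,
                       k=min(ohem_samples, per_sample.size(0)),
                       sorted=False)[0].mean()  # mean over the 4 hardest samples

When ohem_samples is not configured, the same loss would typically be built with
reduction="mean" so that fake_loss and real_loss are already scalars.
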
def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler,
                train_data_loader, summary_writer, conf, local_rank):
    losses = AverageMeter()
    c_losses = AverageMeter()
    d_losses = AverageMeter()
    dices = AverageMeter()
    iterator = tqdm(train_data_loader)
    model.train()
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(current_epoch)
    for i, sample in enumerate(iterator):
        imgs = sample["image"].cuda()
        masks = sample["mask"].cuda().float()
        # if torch.sum(masks) < 100:
        #     continue
        centers = sample["center"].cuda().float()

        seg_mask, center_mask = model(imgs)
        with torch.no_grad():
            pred = torch.sigmoid(seg_mask)
            d = dice_round(pred[:, 0:1, ...].cpu(),
                           masks[:, 0:1, ...].cpu(),
                           t=0.5).item()
        dices.update(d, imgs.size(0))

        mask_loss = loss_functions["mask_loss"](seg_mask, masks)
        # if torch.isnan(mask_loss):
        #     print("nan loss, skipping!!!")
        #     optimizer.zero_grad()
        #     continue
        center_loss = loss_functions["center_loss"](center_mask, centers)
        center_loss *= 50
        loss = mask_loss + center_loss

        loss /= 2
        if current_epoch == 0:
            loss /= 10
        losses.update(loss.item(), imgs.size(0))
        d_losses.update(mask_loss.item(), imgs.size(0))

        c_losses.update(center_loss.item(), imgs.size(0))
        iterator.set_postfix({
            "lr": float(scheduler.get_lr()[-1]),
            "epoch": current_epoch,
            "loss": losses.avg,
            "dice": dices.avg,
            "d_loss": d_losses.avg,
            "c_loss": c_losses.avg,
        })
        optimizer.zero_grad()
        if conf['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        optimizer.step()
        torch.cuda.synchronize()

        if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
            scheduler.step(i + current_epoch * len(train_data_loader))

    if local_rank == 0:
        for idx, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            summary_writer.add_scalar('group{}/lr'.format(idx),
                                      float(lr),
                                      global_step=current_epoch)
        summary_writer.add_scalar('train/loss',
                                  float(losses.avg),
                                  global_step=current_epoch)
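
        # Fragment from a different train_epoch variant: a per-batch step using
        # torch.cuda.amp autocast/GradScaler with gradient accumulation. It relies on
        # args, scaler, criterion, pbar, max_iters, seenratio and epoch_img_names being
        # defined in that variant's scope.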
        if conf['fp16'] and args.device != 'cpu':
            with autocast():
                out = model(imgs)
                loss = criterion(out, labels)  # 0.6710
            scaler.scale(loss).backward()
            if (i % args.accum) == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        else:
            out = model(imgs)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        losses.update(loss.item(), imgs.size(0))
        pbar.set_postfix({
            "lr": float(scheduler.get_lr()[-1]),
            "epoch": current_epoch,
            "loss": losses.avg,
            'seen_prev': seenratio
        })

        if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
            scheduler.step(i + current_epoch * max_iters)
        if i == max_iters - 1:
            break
    pbar.close()
    if epoch > 0:
        seen = set(epoch_img_names[epoch]).intersection(
            set(itertools.chain(*[epoch_img_names[i] for i in range(epoch)])))
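
The autocast/GradScaler pattern used in the fragment above, shown as a self-contained
sketch; accum_steps, criterion and loader are placeholder names, not taken from the
source project:

import torch
from torch.cuda.amp import autocast, GradScaler

def amp_train_steps(model, loader, optimizer, criterion, accum_steps=1, device="cuda"):
    """Mixed-precision loop body with simple gradient accumulation."""
    scaler = GradScaler()
    model.train()
    optimizer.zero_grad()
    for i, (imgs, labels) in enumerate(loader):
        imgs, labels = imgs.to(device), labels.to(device)
        with autocast():                       # forward pass in mixed precision
            loss = criterion(model(imgs), labels) / accum_steps
        scaler.scale(loss).backward()          # scaled backward to avoid fp16 underflow
        if (i + 1) % accum_steps == 0:         # update only every accum_steps batches
            scaler.step(optimizer)             # unscales gradients, then optimizer.step()
            scaler.update()
            optimizer.zero_grad()
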
def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
                local_rank):
    losses = AverageMeter()
    c_losses = AverageMeter()
    o_losses = AverageMeter()
    e_losses = AverageMeter()
    d_losses = AverageMeter()
    s_dices = AverageMeter()
    cl_losses = AverageMeter()
    dices = AverageMeter()
    iterator = tqdm(train_data_loader)
    model.train()
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(current_epoch)
    for i, sample in enumerate(iterator):
        imgs = sample["image"].cuda()
        masks = sample["mask"].cuda().float()
        centers = sample["center"].cuda().float()
        offsets = sample["offset"].cuda().float()
        change_labels = sample["change_labels"].cuda().float()
        inner_labels = sample["inner_labels"].cuda().float()

        seg_mask, center_mask, offset_mask = model(imgs)
        with torch.no_grad():
            pred = torch.sigmoid(seg_mask)
            d = dice_round(pred[:, 0:1, ...], masks[:, 0:1, ...], t=0.5).item()
            s_d = dice_round(pred[:, 2:3, ...], masks[:, 2:3, ...], t=0.5).item()
        dices.update(d, imgs.size(0))
        s_dices.update(s_d, imgs.size(0))

        mask_loss = loss_functions["mask_loss"](seg_mask, masks)
        center_loss = loss_functions["center_loss"](center_mask, centers)
        offset_loss = loss_functions["offset_loss"](offset_mask, offsets)
        offset_e_loss = loss_functions["offset_e_loss"](offset_mask, offsets, masks[:, 0:1, ...].contiguous())

        #print("GT ({})_({}) PRED ({})_({})".format(torch.max(offsets).item(), torch.min(offsets).item(), torch.max(offset_mask).item(), torch.min(offset_mask).item()))
        #inner_sep_loss = soft_dice_loss(1 - torch.sigmoid(seg_mask[:, 2, ...]), 1 - inner_labels)
        mask_loss *= 2
        center_loss *= 70
        offset_loss *= 0.1
        offset_e_loss *= 0.1
        loss = mask_loss + center_loss + offset_loss + offset_e_loss  # + inner_sep_loss

        if torch.sum(change_labels) > 20:
            seg_mask_changed = seg_mask[change_labels > 0]
            mask_changed = masks[change_labels > 0]
            ch_lbl_loss = soft_dice_loss(torch.sigmoid(seg_mask_changed.view(1, -1)), mask_changed.view(1, -1))
            loss += ch_lbl_loss
            cl_losses.update(ch_lbl_loss.item(), imgs.size(0))
        loss /= 5
        losses.update(loss.item(), imgs.size(0))
        d_losses.update(mask_loss.item(), imgs.size(0))
        o_losses.update(offset_loss.item(), imgs.size(0))
        c_losses.update(center_loss.item(), imgs.size(0))
        e_losses.update(offset_e_loss.item(), imgs.size(0))
        iterator.set_postfix({"lr": float(scheduler.get_lr()[-1]),
                              "epoch": current_epoch,
                              "loss": losses.avg,
                              "dice": dices.avg,
                              "s_dice": s_dices.avg,
                              "d_loss": d_losses.avg,
                              "c_loss": c_losses.avg,
                              "o_loss": o_losses.avg,
                              "e_loss": e_losses.avg,
                              "cl_loss": cl_losses.avg,
                              })
        optimizer.zero_grad()
        if conf['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        optimizer.step()
        torch.cuda.synchronize()

        if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
            scheduler.step(i + current_epoch * len(train_data_loader))

    if local_rank == 0:
        for idx, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
        summary_writer.add_scalar('train/loss', float(losses.avg), global_step=current_epoch)
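
dice_round and soft_dice_loss are project utilities that are not shown in this listing.
A minimal sketch of what they typically compute, i.e. a binary Dice coefficient on
thresholded predictions and a differentiable soft-Dice loss; the source project's
implementations may differ in smoothing and per-channel handling:

import torch

def dice_round(pred, target, t=0.5, eps=1e-6):
    """Dice coefficient after thresholding predictions at t."""
    pred = (pred > t).float()
    inter = torch.sum(pred * target)
    union = torch.sum(pred) + torch.sum(target)
    return (2 * inter + eps) / (union + eps)

def soft_dice_loss(pred, target, eps=1e-6):
    """Differentiable soft-Dice loss (1 - Dice) on probabilities."""
    inter = torch.sum(pred * target)
    union = torch.sum(pred) + torch.sum(target)
    return 1 - (2 * inter + eps) / (union + eps)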