Example #1
import time

import torch
import torch.distributed as dist

# AverageMeter, calculate_accuracy, write_to_batch_logger and
# write_to_epoch_logger are project utilities defined elsewhere; a plausible
# sketch of the first two follows this function.
def train_epoch(epoch,
                data_loader,
                model,
                criterion,
                optimizer,
                device,
                current_lr,
                epoch_logger,
                batch_logger,
                tb_writer=None,
                distributed=False):
    print('train at epoch {}'.format(epoch))
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    end_time = time.time()
    for i, (inputs, targets) in enumerate(data_loader):
        data_time.update(time.time() - end_time)

        targets = targets.to(device, non_blocking=True)
        outputs, features = model(inputs)

        loss = criterion(outputs, targets)
        acc = calculate_accuracy(outputs, targets)

        losses.update(loss.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end_time)
        end_time = time.time()

        write_to_batch_logger(batch_logger, epoch, i, data_loader, losses.val,
                              accuracies.val, current_lr)

        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(epoch,
                                                         i + 1,
                                                         len(data_loader),
                                                         batch_time=batch_time,
                                                         data_time=data_time,
                                                         loss=losses,
                                                         acc=accuracies),
              flush=True)

    if distributed:
        loss_sum = torch.tensor([losses.sum],
                                dtype=torch.float32,
                                device=device)
        loss_count = torch.tensor([losses.count],
                                  dtype=torch.float32,
                                  device=device)
        acc_sum = torch.tensor([accuracies.sum],
                               dtype=torch.float32,
                               device=device)
        acc_count = torch.tensor([accuracies.count],
                                 dtype=torch.float32,
                                 device=device)

        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)

        losses.avg = loss_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()

    write_to_epoch_logger(epoch_logger, epoch, losses.avg, accuracies.avg,
                          current_lr)

    if tb_writer is not None:
        tb_writer.add_scalar('train/loss', losses.avg, epoch)
        tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
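
All three examples rely on an AverageMeter accumulator and a calculate_accuracy helper that the snippets do not define. Below is a minimal sketch consistent with the attributes used above (val, avg, sum, count, and a weighted update); the project's real utilities may differ in detail.

class AverageMeter:
    """Track the most recent value and a running weighted average."""

    def __init__(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted sum of all values
        self.count = 0   # total weight (e.g. number of samples seen)
        self.avg = 0.0   # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def calculate_accuracy(outputs, targets):
    """Top-1 accuracy of logits against integer class targets (a sketch)."""
    with torch.no_grad():
        preds = outputs.argmax(dim=1)
        return (preds == targets).float().mean().item()
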
Example #2
# Requires the same imports and project utilities as Example #1.
def train_a_epoch(epoch,
                  data_loader,
                  model,
                  joint_prediction_aud,
                  criterion,
                  criterion_jsd,
                  criterion_ct_av,
                  optimizer,
                  optimizer_av,
                  device,
                  current_lr,
                  epoch_logger,
                  batch_logger,
                  tb_writer=None,
                  distributed=False):
    print('train at epoch {}'.format(epoch))
    model.train()
    joint_prediction_aud.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    # classification loss
    losses_cls = AverageMeter()
    accuracies = AverageMeter()
    # contrastive loss
    losses_ct_av = AverageMeter()
    # jsd loss
    losses_jsd_a = AverageMeter()

    end_time = time.time()
    for i, (inputs, targets, audios) in enumerate(data_loader):
        data_time.update(time.time() - end_time)

        outputs, features = model(inputs)
        targets = targets.to(device, non_blocking=True)
        audios = audios.to(device, non_blocking=True)

        loss_cls_v = criterion(outputs, targets)  # video classification loss
        acc = calculate_accuracy(outputs, targets)
        #####################################################################################
        # keep only samples whose audio features are available (non-zero rows)
        has_audio = audios.sum(dim=1) != 0
        features_aud = audios[has_audio]
        features_vid = features[has_audio]
        targets_new = targets[has_audio]
        outputs_new = outputs[has_audio]

        # fuse audio and video features into a joint audio-video prediction
        outputs_av, features_av = joint_prediction_aud(features_aud,
                                                       features_vid)
        loss_cls_a = criterion(outputs_av,
                               targets_new)  # audio-video classification loss

        # contrastive learning (symmetric loss)
        # align video features to multimodal (audio-video) features
        loss_vm = criterion_ct_av(features_vid, features_av,
                                  targets_new) + criterion_ct_av(
                                      features_av, features_vid, targets_new)
        # align video features to audio features
        loss_va = criterion_ct_av(features_vid, features_aud,
                                  targets_new) + criterion_ct_av(
                                      features_aud, features_vid, targets_new)
        # align multimodal (audio-video) features to audio features
        loss_ma = criterion_ct_av(features_av, features_aud,
                                  targets_new) + criterion_ct_av(
                                      features_aud, features_av, targets_new)
        # contrastive loss
        loss_ct_av = loss_vm + loss_va
        loss_ct_a = loss_vm + loss_ma
        # jsd loss
        loss_jsd_a = criterion_jsd(outputs_new, outputs_av)
        #####################################################################################
        total_loss_v = sum([loss_cls_v, loss_ct_av, loss_jsd_a])
        total_loss_a = sum([loss_cls_a, loss_ct_a, loss_jsd_a])

        losses_cls.update(loss_cls_v.item(), inputs.size(0))
        losses_ct_av.update(loss_ct_av.item(), inputs.size(0))
        losses_jsd_a.update(loss_jsd_a.item(), inputs.size(0))

        accuracies.update(acc, inputs.size(0))

        optimizer_av.zero_grad()
        total_loss_a.backward(retain_graph=True)

        optimizer.zero_grad()
        total_loss_v.backward()

        optimizer_av.step()
        optimizer.step()
        #####################################################################################
        batch_time.update(time.time() - end_time)
        end_time = time.time()

        write_to_batch_logger(batch_logger, epoch, i, data_loader,
                              losses_cls.val, accuracies.val, current_lr)

        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss_cls {loss_cls.val:.3f} ({loss_cls.avg:.3f})\t'
              'Loss_ct_av {loss_ct_av.val:.3f} ({loss_ct_av.avg:.3f})\t'
              'Loss_jsd_a {loss_jsd_a.val:.3f} ({loss_jsd_a.avg:.3f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                  epoch,
                  i + 1,
                  len(data_loader),
                  batch_time=batch_time,
                  data_time=data_time,
                  loss_cls=losses_cls,
                  loss_ct_av=losses_ct_av,
                  loss_jsd_a=losses_jsd_a,
                  acc=accuracies),
              flush=True)

    if distributed:
        loss_cls_sum = torch.tensor([losses_cls.sum],
                                    dtype=torch.float32,
                                    device=device)
        loss_ct_av_sum = torch.tensor([losses_ct_av.sum],
                                      dtype=torch.float32,
                                      device=device)
        loss_jsd_a_sum = torch.tensor([losses_jsd_a.sum],
                                      dtype=torch.float32,
                                      device=device)
        acc_sum = torch.tensor([accuracies.sum],
                               dtype=torch.float32,
                               device=device)
        loss_count = torch.tensor([losses_cls.count],
                                  dtype=torch.float32,
                                  device=device)
        acc_count = torch.tensor([accuracies.count],
                                 dtype=torch.float32,
                                 device=device)

        dist.all_reduce(loss_cls_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_ct_av_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_jsd_a_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)

        losses_cls.avg = loss_cls_sum.item() / loss_count.item()
        losses_ct_av.avg = loss_ct_av_sum.item() / loss_count.item()
        losses_jsd_a.avg = loss_jsd_a_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()

    write_to_epoch_logger(epoch_logger, epoch, losses_cls.avg,
                          accuracies.avg, current_lr)

    if tb_writer is not None:
        tb_writer.add_scalar('train/loss_cls', losses_cls.avg, epoch)
        tb_writer.add_scalar('train/loss_ct_av', losses_ct_av.avg, epoch)
        tb_writer.add_scalar('train/loss_jsd_a', losses_jsd_a.avg, epoch)
        tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
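
criterion_jsd is passed in from the caller and is not shown. Since it is applied to two sets of logits (outputs_new and outputs_av), a Jensen-Shannon divergence between the two predicted distributions is a plausible reading; the sketch below is an assumption, not the project's confirmed criterion.

import torch.nn.functional as F
from torch import nn


class JSDLoss(nn.Module):
    """Jensen-Shannon divergence between two sets of logits (a sketch)."""

    def forward(self, logits_p, logits_q):
        p = F.softmax(logits_p, dim=1)
        q = F.softmax(logits_q, dim=1)
        m = 0.5 * (p + q)
        # JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m)
        kl_pm = F.kl_div(m.log(), p, reduction='batchmean')
        kl_qm = F.kl_div(m.log(), q, reduction='batchmean')
        return 0.5 * (kl_pm + kl_qm)
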
Example #3
# Requires the imports of Example #1, plus the following:
import random

import torch.nn.functional as F


def train_i_epoch(epoch,
                  data_loader,
                  model,
                  image_model,
                  joint_prediction_img,
                  criterion,
                  criterion_jsd,
                  criterion_ct_iv,
                  optimizer,
                  optimizer_iv,
                  device,
                  current_lr,
                  epoch_logger,
                  batch_logger,
                  tb_writer=None,
                  distributed=False,
                  image_size=None):
    print('train at epoch {}'.format(epoch))
    model.train()
    image_model.eval()
    joint_prediction_img.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    # classification loss
    losses_cls = AverageMeter()
    accuracies = AverageMeter()
    # contrastive loss
    losses_ct_iv = AverageMeter()
    # jsd loss
    losses_jsd_i = AverageMeter()

    end_time = time.time()
    for i, (inputs, targets) in enumerate(data_loader):
        data_time.update(time.time() - end_time)

        outputs, features = model(inputs)
        targets = targets.to(device, non_blocking=True)

        loss_cls_v = criterion(outputs, targets)  # video classification loss
        acc = calculate_accuracy(outputs, targets)
        #####################################################################################
        if image_size is not None:
            # resize frames to a larger spatial size than the video input
            # (a better fit for the image ResNet), keeping the temporal length
            inputs = F.interpolate(
                inputs, size=[inputs.shape[2], image_size, image_size])
        # randomly select one frame to serve as the still image
        rand_img = random.randint(0, inputs.shape[2] - 1)
        images = inputs[:, :, rand_img, :, :]
        features_img = image_model(images)
        # drop the trailing 1x1 spatial dims without squeezing the batch dim
        features_img = features_img.squeeze(-1).squeeze(-1)

        # compose the still image and the video into a joint prediction
        outputs_iv, features_iv = joint_prediction_img(features_img, features)
        loss_cls_i = criterion(outputs_iv,
                               targets)  # image-video classification loss

        # contrastive learning (symmetric loss)
        # align video features to multimodal (image-video) features
        loss_vm = criterion_ct_iv(features, features_iv,
                                  targets) + criterion_ct_iv(
                                      features_iv, features, targets)
        # align video features to image features
        loss_vi = criterion_ct_iv(features, features_img,
                                  targets) + criterion_ct_iv(
                                      features_img, features, targets)
        # align multimodal features to image features
        loss_mi = criterion_ct_iv(features_iv, features_img,
                                  targets) + criterion_ct_iv(
                                      features_img, features_iv, targets)
        # contrastive loss
        loss_ct_iv = loss_vm + loss_vi
        loss_ct_i = loss_vm + loss_mi
        # jsd loss
        loss_jsd_i = criterion_jsd(outputs, outputs_iv)
        #####################################################################################
        total_loss_v = sum([loss_cls_v, loss_ct_iv, loss_jsd_i])
        total_loss_i = sum([loss_cls_i, loss_ct_i, loss_jsd_i])

        losses_cls.update(loss_cls_v.item(), inputs.size(0))
        losses_ct_iv.update(loss_ct_iv.item(), inputs.size(0))
        losses_jsd_i.update(loss_jsd_i.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))

        optimizer_iv.zero_grad()
        total_loss_i.backward(retain_graph=True)

        optimizer.zero_grad()
        total_loss_v.backward()

        optimizer_iv.step()
        optimizer.step()
        #####################################################################################
        batch_time.update(time.time() - end_time)
        end_time = time.time()

        write_to_batch_logger(batch_logger, epoch, i, data_loader,
                              losses_cls.val, accuracies.val, current_lr)

        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss_cls {loss_cls.val:.3f} ({loss_cls.avg:.3f})\t'
              'Loss_ct_iv {loss_ct_iv.val:.3f} ({loss_ct_iv.avg:.3f})\t'
              'Loss_jsd_i {loss_jsd_i.val:.3f} ({loss_jsd_i.avg:.3f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                  epoch,
                  i + 1,
                  len(data_loader),
                  batch_time=batch_time,
                  data_time=data_time,
                  loss_cls=losses_cls,
                  loss_ct_iv=losses_ct_iv,
                  loss_jsd_i=losses_jsd_i,
                  acc=accuracies),
              flush=True)

    if distributed:
        loss_cls_sum = torch.tensor([losses_cls.sum],
                                    dtype=torch.float32,
                                    device=device)
        loss_ct_iv_sum = torch.tensor([losses_ct_iv.sum],
                                      dtype=torch.float32,
                                      device=device)
        loss_jsd_i_sum = torch.tensor([losses_jsd_i.sum],
                                      dtype=torch.float32,
                                      device=device)
        acc_sum = torch.tensor([accuracies.sum],
                               dtype=torch.float32,
                               device=device)
        loss_count = torch.tensor([losses_cls.count],
                                  dtype=torch.float32,
                                  device=device)
        acc_count = torch.tensor([accuracies.count],
                                 dtype=torch.float32,
                                 device=device)

        dist.all_reduce(loss_cls_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_ct_iv_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_jsd_i_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)

        losses_cls.avg = loss_cls_sum.item() / loss_count.item()
        losses_ct_iv.avg = loss_ct_iv_sum.item() / loss_count.item()
        losses_jsd_i.avg = loss_jsd_i_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()

    write_to_epoch_logger(epoch_logger, epoch, losses_cls.avg,
                          accuracies.avg, current_lr)

    if tb_writer is not None:
        tb_writer.add_scalar('train/loss_cls', losses_cls.avg, epoch)
        tb_writer.add_scalar('train/loss_ct_iv', losses_ct_iv.avg, epoch)
        tb_writer.add_scalar('train/loss_jsd_i', losses_jsd_i.avg, epoch)
        tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
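
criterion_ct_av and criterion_ct_iv are likewise external. Because they take two feature batches plus the class targets, a label-aware (supervised) contrastive loss is a reasonable interpretation; the class below is a hypothetical sketch under that assumption, with an assumed temperature of 0.07. Calling it twice with the feature arguments swapped, as the loops above do, produces the symmetric loss the comments describe.

import torch.nn.functional as F
from torch import nn


class SupConLoss(nn.Module):
    """Hypothetical cross-modal supervised contrastive loss.

    Anchor/candidate pairs that share a class label count as positives;
    all other candidates in the batch act as negatives.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, anchors, candidates, targets):
        anchors = F.normalize(anchors, dim=1)
        candidates = F.normalize(candidates, dim=1)
        logits = anchors @ candidates.t() / self.temperature       # (N, N)
        pos_mask = (targets.unsqueeze(0) == targets.unsqueeze(1)).float()
        log_prob = F.log_softmax(logits, dim=1)
        # average log-likelihood of each anchor's positives
        loss = -(log_prob * pos_mask).sum(1) / pos_mask.sum(1).clamp(min=1)
        return loss.mean()
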