Example 1
def evaluate(epoch, args):
    args.model.eval()
    loss_avg = AverageMeter()
    acc_avg = AverageMeter()

    with torch.no_grad():
        for data, target in tqdm(args.test_loader):
            batch_step(args, data, target, loss_avg, acc_avg)

    print("Epoch:", epoch, "Test Loss:", loss_avg.avg(), "Test Accuracy:",
          acc_avg.avg(), "\n")
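All ten examples lean on an AverageMeter helper that is never shown. Below is a minimal sketch of the attribute-style meter assumed by Examples 5-10 (val/avg/sum/count fields); note that Examples 1-4 instead call avg() as a method, so their variant must expose the average through a method rather than an attribute.

class AverageMeter:
    """Tracks the most recent value and a running weighted average."""

    def __init__(self):
        self.val = 0.0    # last value seen
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total weight (e.g. number of samples)
        self.avg = 0.0    # running average, sum / count

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count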
Example 2
def train(epoch, args):
    args.model.train()
    loss_avg = AverageMeter()
    acc_avg = AverageMeter()
    norm_avg = AverageMeter()

    progress_bar = tqdm(args.train_loader)
    for data, target in progress_bar:

        # learning rate warmup
        if args.warmup_lr:
            args.optimizer.param_groups[0]['lr'] = args.warmup_lr.pop(0)

        # the next worker finishes its gradient computation; the finish
        # order is simulated with gamma-distributed compute delays
        worker_rank = next(args.worker_order)
        load_worker(args, worker_rank)

        # worker computes gradient on its set of weights
        args.optimizer.zero_grad()
        loss = batch_step(args, data, target, loss_avg, acc_avg)
        loss.backward()

        # the master receives the gradients from the worker and updates its weights
        delay_compensation(args)
        load_master(args)
        args.optimizer.step()
        update_master(args)

        # the worker receives the master's new weights
        update_worker(args, worker_rank)

        # track the global gradient norm (skip parameters that have no
        # gradient, e.g. frozen layers)
        norm_avg.update(
            sum(p.grad.data.norm()**2 for p in args.model.parameters()
                if p.grad is not None)**0.5,
            target.shape[0])
        progress_bar.set_description(
            "Epoch: %d, Loss: %0.8f Norm: %0.4f LR: %0.4f" %
            (epoch, loss_avg.avg(), norm_avg.avg(),
             args.optimizer.param_groups[0]['lr']))
    progress_bar.close()
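The loop above consumes two schedules from args: a warmup_lr list that is drained one pop(0) per batch, and a worker_order iterator yielding gamma-distributed worker ranks. The repository's construction is not shown; the following is a hypothetical setup (SimpleNamespace, base_lr, warmup_steps, and the gamma parameters are all assumptions) consistent with how the loop uses them.

from types import SimpleNamespace

import numpy as np

args = SimpleNamespace()

# Linear warmup: a list of learning rates consumed one per batch until empty.
base_lr, warmup_steps = 0.1, 500
args.warmup_lr = [base_lr * (i + 1) / warmup_steps
                  for i in range(warmup_steps)]

def gamma_worker_order(num_workers, shape=2.0, scale=1.0, seed=0):
    """Yield worker ranks in simulated finish order, drawing per-round
    compute delays from a gamma distribution."""
    rng = np.random.default_rng(seed)
    while True:
        delays = rng.gamma(shape, scale, size=num_workers)
        for rank in np.argsort(delays):
            yield int(rank)

args.worker_order = gamma_worker_order(num_workers=4)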
Example 3
def test(model, testloader, criterion, use_cuda):

    model.eval()

    losses = AverageMeter()
    top1 = AverageMeter()

    # torch.no_grad() replaces the long-deprecated
    # Variable(..., volatile=True) inference API
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            prec1 = accuracy(outputs, targets, topk=(1, ))
            # loss.item() replaces the deprecated 0-dim indexing loss.data[0]
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))

    return (losses.avg(), top1.avg())
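Examples 3 and 4 call an accuracy helper with a topk argument; a sketch of the conventional ImageNet-style implementation follows (the repository's own version may differ, e.g. in returning fractions instead of percentages).

def accuracy(output, target, topk=(1, )):
    """Top-k precision: the percentage of samples whose true class is
    among the k highest-scoring predictions."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()                                       # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res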
Example 4
def train(model, trainloader, criterion, optimizer, epoch, use_cuda):
    model.train()

    losses = AverageMeter()
    top1 = AverageMeter()

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            # non_blocking=True is the current keyword; async became a
            # reserved word in Python 3.7, so async=True no longer parses
            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        prec1 = accuracy(outputs, targets, topk=(1, ))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return (losses.avg(), top1.avg())
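Together, train and test form the usual per-epoch loop. A minimal, hypothetical wiring (num_epochs and the model/loader/optimizer objects are assumed to exist):

num_epochs = 90  # hypothetical
for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train(model, trainloader, criterion, optimizer,
                                  epoch, use_cuda)
    test_loss, test_acc = test(model, testloader, criterion, use_cuda)
    print('epoch %d: train acc %.2f, test acc %.2f' %
          (epoch, train_acc, test_acc))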
Example 5
def val_epoch(epoch,
              data_loader,
              model,
              criterion,
              device,
              logger,
              tb_writer=None,
              distributed=False):
    print('validation at epoch {}'.format(epoch))

    model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    end_time = time.time()

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(data_loader):
            data_time.update(time.time() - end_time)

            targets = targets.to(device, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            acc = calculate_accuracy(outputs, targets)

            losses.update(loss.item(), inputs.size(0))
            accuracies.update(acc, inputs.size(0))

            batch_time.update(time.time() - end_time)
            end_time = time.time()

            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                      epoch,
                      i + 1,
                      len(data_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      acc=accuracies).expandtabs(tabsize=4))

    if distributed:
        loss_sum = torch.tensor([losses.sum],
                                dtype=torch.float32,
                                device=device)
        loss_count = torch.tensor([losses.count],
                                  dtype=torch.float32,
                                  device=device)
        acc_sum = torch.tensor([accuracies.sum],
                               dtype=torch.float32,
                               device=device)
        acc_count = torch.tensor([accuracies.count],
                                 dtype=torch.float32,
                                 device=device)

        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)

        losses.avg = loss_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()

    if logger is not None:
        logger.log({'epoch': epoch, 'loss': losses.avg, 'acc': accuracies.avg})

    if tb_writer is not None:
        tb_writer.add_scalar('val/loss', losses.avg, epoch)
        tb_writer.add_scalar('val/acc', accuracies.avg, epoch)

    return losses.avg
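Examples 5-10 use calculate_accuracy, which evidently returns a plain number (its result is fed straight into accuracies.update). A sketch consistent with those call sites:

import torch

def calculate_accuracy(outputs, targets):
    """Top-1 accuracy as a plain float in [0, 1]."""
    with torch.no_grad():
        batch_size = targets.size(0)
        _, pred = outputs.topk(1, dim=1, largest=True)
        correct = pred.t().eq(targets.view(1, -1))
        return correct.float().sum().item() / batch_size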
Example 6
def train_epoch(epoch,
                data_loader,
                model,
                criterion,
                optimizer,
                device,
                current_lr,
                epoch_logger,
                batch_logger,
                tb_writer=None,
                distributed=False):
    print('train at epoch {}'.format(epoch))

    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    end_time = time.time()
    for i, (inputs, targets) in enumerate(data_loader):
        data_time.update(time.time() - end_time)

        targets = targets.to(device, non_blocking=True)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        acc = calculate_accuracy(outputs, targets)

        losses.update(loss.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end_time)
        end_time = time.time()

        if batch_logger is not None:
            batch_logger.log({
                'epoch': epoch,
                'batch': i + 1,
                'iter': (epoch - 1) * len(data_loader) + (i + 1),
                'loss': losses.val,
                'acc': accuracies.val,
                'lr': current_lr
            })

        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(epoch,
                                                         i + 1,
                                                         len(data_loader),
                                                         batch_time=batch_time,
                                                         data_time=data_time,
                                                         loss=losses,
                                                         acc=accuracies))

    if distributed:
        loss_sum = torch.tensor([losses.sum],
                                dtype=torch.float32,
                                device=device)
        loss_count = torch.tensor([losses.count],
                                  dtype=torch.float32,
                                  device=device)
        acc_sum = torch.tensor([accuracies.sum],
                               dtype=torch.float32,
                               device=device)
        acc_count = torch.tensor([accuracies.count],
                                 dtype=torch.float32,
                                 device=device)

        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)

        losses.avg = loss_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()

    if epoch_logger is not None:
        epoch_logger.log({
            'epoch': epoch,
            'loss': losses.avg,
            'acc': accuracies.avg,
            'lr': current_lr
        })

    if tb_writer is not None:
        tb_writer.add_scalar('train/loss', losses.avg, epoch)
        tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
        tb_writer.add_scalar('train/lr', current_lr, epoch)
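The four-tensor all_reduce block above recurs almost verbatim in Examples 5-10. A sketch of a helper (hypothetical name all_reduce_meter) that would factor it out, packing each meter's sum and count into one tensor so it costs a single all_reduce:

import torch
import torch.distributed as dist

def all_reduce_meter(meter, device):
    """Recompute a meter's average over all processes by summing the
    local weighted sums and counts across ranks."""
    stats = torch.tensor([meter.sum, meter.count],
                         dtype=torch.float32, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    meter.avg = (stats[0] / stats[1]).item()

# usage, replacing the inline block:
#     if distributed:
#         all_reduce_meter(losses, device)
#         all_reduce_meter(accuracies, device)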
Example 7
def val_epoch(epoch,
              data_loader,
              model,
              criterion,
              device,
              logger,
              tb_writer=None,
              distributed=False,
              rpn=None,
              det_interval=2,
              nrois=10):
    print('validation at epoch {}'.format(epoch))

    model.eval()
    if rpn is not None:
        rpn.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    end_time = time.time()

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(data_loader):
            data_time.update(time.time() - end_time)
            targets = targets.to(device, non_blocking=True)
            if rpn is not None:
                # Workaround: an unexpected CUDNN_ERROR occurred when
                # len(rpn_inputs) decreased between batches, so short
                # batches are padded back up to the largest size seen.
                N, C, T, H, W = inputs.size()
                if i == 0:
                    max_N = N
                # sample frames for RPN
                sample = torch.arange(0, T, det_interval)
                rpn_inputs = inputs[:, :, sample].transpose(1, 2).contiguous()
                rpn_inputs = rpn_inputs.view(-1, C, H, W)
                if len(inputs) < max_N:
                    print("Modified from {} to {}".format(len(inputs), max_N))
                    rpn_inputs = torch.cat(
                        (rpn_inputs, rpn_inputs[:(max_N - len(inputs)) *
                                                (T // det_interval)]))
                # the outer torch.no_grad() already covers this forward pass
                proposals = rpn(rpn_inputs)
                proposals = proposals.view(-1, T // det_interval, nrois, 4)
                if len(inputs) < max_N:
                    proposals = proposals[:len(inputs)]
                outputs = model(inputs, proposals.detach())
                # update to the largest batch_size
                max_N = max(N, max_N)
            else:
                outputs = model(inputs)

            loss = criterion(outputs, targets)
            acc = calculate_accuracy(outputs, targets)

            losses.update(loss.item(), inputs.size(0))
            accuracies.update(acc, inputs.size(0))

            batch_time.update(time.time() - end_time)
            end_time = time.time()

            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                      epoch,
                      i + 1,
                      len(data_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      acc=accuracies))

    if distributed:
        loss_sum = torch.tensor([losses.sum],
                                dtype=torch.float32,
                                device=device)
        loss_count = torch.tensor([losses.count],
                                  dtype=torch.float32,
                                  device=device)
        acc_sum = torch.tensor([accuracies.sum],
                               dtype=torch.float32,
                               device=device)
        acc_count = torch.tensor([accuracies.count],
                                 dtype=torch.float32,
                                 device=device)

        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)

        losses.avg = loss_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()

    if logger is not None:
        logger.log({'epoch': epoch, 'loss': losses.avg, 'acc': accuracies.avg})

    if tb_writer is not None:
        tb_writer.add_scalar('val/loss', losses.avg, epoch)
        tb_writer.add_scalar('val/acc', accuracies.avg, epoch)

    return losses.avg
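The tensor reshuffling in Example 7 is easier to follow with concrete shapes. A self-contained walkthrough with hypothetical sizes:

import torch

N, C, T, H, W = 8, 3, 16, 112, 112
det_interval = 2

inputs = torch.randn(N, C, T, H, W)
sample = torch.arange(0, T, det_interval)             # every det_interval-th frame
rpn_inputs = inputs[:, :, sample]                     # (N, C, T//det_interval, H, W)
rpn_inputs = rpn_inputs.transpose(1, 2).contiguous()  # (N, T//det_interval, C, H, W)
rpn_inputs = rpn_inputs.view(-1, C, H, W)             # one 2D image per sampled frame
assert rpn_inputs.shape == (N * T // det_interval, C, H, W)
# the RPN scores nrois boxes per sampled frame; its output is reshaped to
# (N, T//det_interval, nrois, 4) and truncated back if the batch was padded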
Example 8
def train_epoch(epoch,
                data_loader1,
                data_loader2,
                model,
                criterion,
                optimizer,
                device,
                current_lr,
                epoch_logger,
                batch_logger,
                is_master_node,
                tb_writer=None,
                distributed=False):

    print('train at epoch {}'.format(epoch))

    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()

    losses = AverageMeter()
    accuracies = AverageMeter()

    # hoisted out of the loop so it is defined even for an empty loader;
    # a distributed run would normally query the real local rank instead
    local_rank = 0

    end_time = time.time()
    # iterate over data_loader1 and cycle data_loader2 alongside it,
    # restarting data_loader2 whenever it is exhausted
    dataloader_iterator = iter(data_loader2)
    for i, data1 in enumerate(data_loader1):

        try:
            data2 = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(data_loader2)
            data2 = next(dataloader_iterator)

        data_time.update(time.time() - end_time)

        inputs1, targets1 = data1
        inputs2, targets2 = data2
        inputs = torch.cat((inputs1, inputs2), 0)

        targets = torch.cat((targets1, targets2), 0)

        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = criterion(outputs, targets)
        acc = calculate_accuracy(outputs, targets)

        losses.update(loss.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))

        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end_time)
        end_time = time.time()
        itera = (epoch - 1) * int(len(data_loader1)) + (i + 1)
        batch_lr = get_lr(optimizer)
        if is_master_node and tb_writer is not None:
            tb_writer.add_scalar('train_iter/loss_iter', losses.val, itera)
            tb_writer.add_scalar('train_iter/acc_iter', accuracies.val, itera)
            tb_writer.add_scalar('train_iter/lr_iter', batch_lr, itera)

        if batch_logger is not None:
            batch_logger.log({
                'epoch': epoch,
                'batch': i + 1,
                'iter': itera,
                'loss': losses.val,
                'acc': accuracies.val,
                'lr': current_lr
            })

        if is_master_node:
            print('Train Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.3f} ({acc.avg:.3f})\t'
                  'RANK {rank}'.format(epoch,
                                       i + 1,
                                       len(data_loader1),
                                       batch_time=batch_time,
                                       data_time=data_time,
                                       loss=losses,
                                       acc=accuracies,
                                       rank=local_rank))

    # reduce the meters across processes once, at the end of the epoch
    if distributed:
        loss_sum = torch.tensor([losses.sum],
                                dtype=torch.float32,
                                device=device)
        loss_count = torch.tensor([losses.count],
                                  dtype=torch.float32,
                                  device=device)
        acc_sum = torch.tensor([accuracies.sum],
                               dtype=torch.float32,
                               device=device)
        acc_count = torch.tensor([accuracies.count],
                                 dtype=torch.float32,
                                 device=device)

        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)

        losses.avg = loss_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()

    if epoch_logger is not None:
        epoch_logger.log({
            'epoch': epoch,
            'loss': losses.avg,
            'acc': accuracies.avg,
            'lr': current_lr,
            'rank': local_rank
        })
    if is_master_node and tb_writer is not None:
        tb_writer.add_scalar('train/loss', losses.avg, epoch)
        tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
        tb_writer.add_scalar('train/lr', current_lr, epoch)
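Example 8 calls get_lr, which is not shown. A minimal sketch, assuming all parameter groups share one learning rate as in the examples above:

def get_lr(optimizer):
    # return the learning rate of the first parameter group
    return optimizer.param_groups[0]['lr']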
Example 9
def train_a_epoch(epoch,
                  data_loader,
                  model,
                  joint_prediction_aud,
                  criterion,
                  criterion_jsd,
                  criterion_ct_av,
                  optimizer,
                  optimizer_av,
                  device,
                  current_lr,
                  epoch_logger,
                  batch_logger,
                  tb_writer=None,
                  distributed=False):
    print('train at epoch {}'.format(epoch))
    model.train()
    joint_prediction_aud.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    # classification loss
    losses_cls = AverageMeter()
    accuracies = AverageMeter()
    # contrastive loss
    losses_ct_av = AverageMeter()
    # jsd loss
    losses_jsd_a = AverageMeter()

    end_time = time.time()
    for i, (inputs, targets, audios) in enumerate(data_loader):
        data_time.update(time.time() - end_time)

        outputs, features = model(inputs)
        targets = targets.to(device, non_blocking=True)
        audios = audios.to(device, non_blocking=True)

        loss_cls_v = criterion(outputs, targets)  # video classification loss
        acc = calculate_accuracy(outputs, targets)
        #####################################################################################
        # keep only the samples whose audio features are available
        # (all-zero rows mark missing audio)
        has_audio = audios.sum(dim=1) != 0
        features_aud = audios[has_audio]
        features_vid = features[has_audio]
        targets_new = targets[has_audio]
        outputs_new = outputs[has_audio]

        # fuse audio and video features for the joint prediction head
        outputs_av, features_av = joint_prediction_aud(features_aud,
                                                       features_vid)
        loss_cls_a = criterion(outputs_av,
                               targets_new)  # audio-video classification loss

        # contrastive learning (symmetric loss)
        # align video features to multimodal (audio-video) features
        loss_vm = criterion_ct_av(features_vid, features_av,
                                  targets_new) + criterion_ct_av(
                                      features_av, features_vid, targets_new)
        # align video features to audio features
        loss_va = criterion_ct_av(features_vid, features_aud,
                                  targets_new) + criterion_ct_av(
                                      features_aud, features_vid, targets_new)
        # align multimodal (audio-video) features to audio features
        loss_ma = criterion_ct_av(features_av, features_aud,
                                  targets_new) + criterion_ct_av(
                                      features_aud, features_av, targets_new)
        # contrastive loss
        loss_ct_av = loss_vm + loss_va
        loss_ct_a = loss_vm + loss_ma
        # jsd loss
        loss_jsd_a = criterion_jsd(outputs_new, outputs_av)
        #####################################################################################
        total_loss_v = sum([loss_cls_v, loss_ct_av, loss_jsd_a])
        total_loss_a = sum([loss_cls_a, loss_ct_a, loss_jsd_a])

        losses_cls.update(loss_cls_v.item(), inputs.size(0))
        losses_ct_av.update(loss_ct_av.item(), inputs.size(0))
        losses_jsd_a.update(loss_jsd_a.item(), inputs.size(0))

        accuracies.update(acc, inputs.size(0))

        optimizer_av.zero_grad()
        total_loss_a.backward(retain_graph=True)

        optimizer.zero_grad()
        total_loss_v.backward()

        optimizer_av.step()
        optimizer.step()
        #####################################################################################
        batch_time.update(time.time() - end_time)
        end_time = time.time()

        write_to_batch_logger(batch_logger, epoch, i, data_loader,
                              losses_cls.val, accuracies.val, current_lr)

        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss_cls {loss_cls.val:.3f} ({loss_cls.avg:.3f})\t'
              'Loss_ct_a {loss_ct_av.val:.3f} ({loss_ct_av.avg:.3f})\t'
              'Loss_jsd_a {loss_jsd_a.val:.3f} ({loss_jsd_a.avg:.3f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                  epoch,
                  i + 1,
                  len(data_loader),
                  batch_time=batch_time,
                  data_time=data_time,
                  loss_cls=losses_cls,
                  loss_ct_av=losses_ct_av,
                  loss_jsd_a=losses_jsd_a,
                  acc=accuracies),
              flush=True)

    if distributed:
        loss_cls_sum = torch.tensor([losses_cls.sum],
                                    dtype=torch.float32,
                                    device=device)
        loss_ct_av_sum = torch.tensor([losses_ct_av.sum],
                                      dtype=torch.float32,
                                      device=device)
        loss_jsd_a_sum = torch.tensor([losses_jsd_a.sum],
                                      dtype=torch.float32,
                                      device=device)
        acc_sum = torch.tensor([accuracies.sum],
                               dtype=torch.float32,
                               device=device)
        loss_count = torch.tensor([losses_cls.count],
                                  dtype=torch.float32,
                                  device=device)
        acc_count = torch.tensor([accuracies.count],
                                 dtype=torch.float32,
                                 device=device)

        dist.all_reduce(loss_cls_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_ct_av_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_jsd_a_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)

        losses_cls.avg = loss_cls_sum.item() / loss_count.item()
        losses_ct_av.avg = loss_ct_av_sum.item() / loss_count.item()
        losses_jsd_a.avg = loss_jsd_a_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()

    # epoch-level logging runs once, after the batch loop
    write_to_epoch_logger(epoch_logger, epoch, losses_cls.avg,
                          accuracies.avg, current_lr)

    if tb_writer is not None:
        tb_writer.add_scalar('train/loss_cls', losses_cls.avg, epoch)
        tb_writer.add_scalar('train/loss_ct_av', losses_ct_av.avg, epoch)
        tb_writer.add_scalar('train/loss_jsd_a', losses_jsd_a.avg, epoch)
        tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
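Examples 9 and 10 delegate logging to write_to_batch_logger and write_to_epoch_logger, neither of which is shown. A sketch with signatures inferred from the call sites, mirroring the inline logger calls of Example 6:

def write_to_batch_logger(batch_logger, epoch, i, data_loader, loss, acc, lr):
    if batch_logger is not None:
        batch_logger.log({
            'epoch': epoch,
            'batch': i + 1,
            'iter': (epoch - 1) * len(data_loader) + (i + 1),
            'loss': loss,
            'acc': acc,
            'lr': lr
        })


def write_to_epoch_logger(epoch_logger, epoch, loss, acc, lr):
    if epoch_logger is not None:
        epoch_logger.log({'epoch': epoch, 'loss': loss, 'acc': acc, 'lr': lr})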
Example 10
def train_i_epoch(epoch,
                  data_loader,
                  model,
                  image_model,
                  joint_prediction_img,
                  criterion,
                  criterion_jsd,
                  criterion_ct_iv,
                  optimizer,
                  optimizer_iv,
                  device,
                  current_lr,
                  epoch_logger,
                  batch_logger,
                  tb_writer=None,
                  distributed=False,
                  image_size=None):
    print('train at epoch {}'.format(epoch))
    model.train()
    image_model.eval()
    joint_prediction_img.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    # classification loss
    losses_cls = AverageMeter()
    accuracies = AverageMeter()
    # contrastive loss
    losses_ct_iv = AverageMeter()
    # jsd loss
    losses_jsd_i = AverageMeter()

    end_time = time.time()
    for i, (inputs, targets) in enumerate(data_loader):
        data_time.update(time.time() - end_time)

        outputs, features = model(inputs)
        targets = targets.to(device, non_blocking=True)

        loss_cls_v = criterion(outputs, targets)  # video classification loss
        acc = calculate_accuracy(outputs, targets)
        #####################################################################################
        if image_size is not None:
            # resize the frames to a larger size than the video input
            # (a better fit for the resnet); the temporal length
            # (inputs.shape[2]) is preserved, only H and W change
            inputs = F.interpolate(
                inputs, size=[inputs.shape[2], image_size, image_size])
        # randomly select one frame to feed the image model
        rand_img = random.randint(0, inputs.shape[2] - 1)
        images = inputs[:, :, rand_img, :, :]
        features_img = image_model(images)
        # squeeze only the spatial dims so a batch of size 1 keeps its dim
        features_img = features_img.squeeze(-1).squeeze(-1)

        # fuse image and video features for the joint prediction head
        outputs_iv, features_iv = joint_prediction_img(features_img, features)
        loss_cls_i = criterion(outputs_iv,
                               targets)  # image-video classification loss

        # contrastive learning (symmetric loss)
        # align video features to multimodal (image-video) features
        loss_vm = criterion_ct_iv(features, features_iv,
                                  targets) + criterion_ct_iv(
                                      features_iv, features, targets)
        # align video features to image features
        loss_vi = criterion_ct_iv(features, features_img,
                                  targets) + criterion_ct_iv(
                                      features_img, features, targets)
        # align multimodal features to image features
        loss_mi = criterion_ct_iv(features_iv, features_img,
                                  targets) + criterion_ct_iv(
                                      features_img, features_iv, targets)
        # contrastive loss
        loss_ct_iv = loss_vm + loss_vi
        loss_ct_i = loss_vm + loss_mi
        # jsd loss
        loss_jsd_i = criterion_jsd(outputs, outputs_iv)
        #####################################################################################
        total_loss_v = sum([loss_cls_v, loss_ct_iv, loss_jsd_i])
        total_loss_i = sum([loss_cls_i, loss_ct_i, loss_jsd_i])

        losses_cls.update(loss_cls_v.item(), inputs.size(0))
        losses_ct_iv.update(loss_ct_iv.item(), inputs.size(0))
        losses_jsd_i.update(loss_jsd_i.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))

        optimizer_iv.zero_grad()
        total_loss_i.backward(retain_graph=True)

        optimizer.zero_grad()
        total_loss_v.backward()

        optimizer_iv.step()
        optimizer.step()
        #####################################################################################
        batch_time.update(time.time() - end_time)
        end_time = time.time()

        write_to_batch_logger(batch_logger, epoch, i, data_loader,
                              losses_cls.val, accuracies.val, current_lr)

        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss_cls {loss_cls.val:.3f} ({loss_cls.avg:.3f})\t'
              'Loss_ct_i {loss_ct_iv.val:.3f} ({loss_ct_iv.avg:.3f})\t'
              'Loss_jsd_i {loss_jsd_i.val:.3f} ({loss_jsd_i.avg:.3f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                  epoch,
                  i + 1,
                  len(data_loader),
                  batch_time=batch_time,
                  data_time=data_time,
                  loss_cls=losses_cls,
                  loss_ct_iv=losses_ct_iv,
                  loss_jsd_i=losses_jsd_i,
                  acc=accuracies),
              flush=True)

    if distributed:
        loss_cls_sum = torch.tensor([losses_cls.sum],
                                    dtype=torch.float32,
                                    device=device)
        loss_ct_iv_sum = torch.tensor([losses_ct_iv.sum],
                                      dtype=torch.float32,
                                      device=device)
        loss_jsd_i_sum = torch.tensor([losses_jsd_i.sum],
                                      dtype=torch.float32,
                                      device=device)
        acc_sum = torch.tensor([accuracies.sum],
                               dtype=torch.float32,
                               device=device)
        loss_count = torch.tensor([losses_cls.count],
                                  dtype=torch.float32,
                                  device=device)
        acc_count = torch.tensor([accuracies.count],
                                 dtype=torch.float32,
                                 device=device)

        dist.all_reduce(loss_cls_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_ct_iv_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_jsd_i_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)

        losses_cls.avg = loss_cls_sum.item() / loss_count.item()
        losses_ct_iv.avg = loss_ct_iv_sum.item() / loss_count.item()
        losses_jsd_i.avg = loss_jsd_i_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()

    # epoch-level logging runs once, after the batch loop
    write_to_epoch_logger(epoch_logger, epoch, losses_cls.avg,
                          accuracies.avg, current_lr)

    if tb_writer is not None:
        tb_writer.add_scalar('train/loss_cls', losses_cls.avg, epoch)
        tb_writer.add_scalar('train/loss_ct_iv', losses_ct_iv.avg, epoch)
        tb_writer.add_scalar('train/loss_jsd_i', losses_jsd_i.avg, epoch)
        tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
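criterion_jsd in Examples 9 and 10 receives two sets of logits. A sketch of a standard Jensen-Shannon divergence criterion, one plausible implementation (the repository's actual criterion may differ):

import torch
import torch.nn.functional as F

class JSDLoss(torch.nn.Module):
    """JSD(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M), with M = (P + Q) / 2.
    Inputs are raw logits; distributions are taken over dim 1."""

    def forward(self, logits_p, logits_q):
        p = F.softmax(logits_p, dim=1)
        q = F.softmax(logits_q, dim=1)
        log_m = (0.5 * (p + q)).log()
        # F.kl_div(input, target) computes KL(target || input), expecting
        # log-probabilities as input and probabilities as target
        return 0.5 * (F.kl_div(log_m, p, reduction='batchmean') +
                      F.kl_div(log_m, q, reduction='batchmean'))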