def evaluate(epoch, args):
    args.model.eval()
    loss_avg = AverageMeter()
    acc_avg = AverageMeter()
    with torch.no_grad():
        for data, target in tqdm(args.test_loader):
            batch_step(args, data, target, loss_avg, acc_avg)
    print("Epoch:", epoch, "Test Loss:", loss_avg.avg(),
          "Test Accuracy:", acc_avg.avg(), "\n")
def train(epoch, args):
    args.model.train()
    loss_avg = AverageMeter()
    acc_avg = AverageMeter()
    norm_avg = AverageMeter()
    progress_bar = tqdm(args.train_loader)
    for data, target in progress_bar:
        # learning rate warmup
        if args.warmup_lr:
            args.optimizer.param_groups[0]['lr'] = args.warmup_lr.pop(0)
        # a worker finished computing gradients according to the gamma distribution
        worker_rank = next(args.worker_order)
        load_worker(args, worker_rank)
        # worker computes gradient on its set of weights
        args.optimizer.zero_grad()
        loss = batch_step(args, data, target, loss_avg, acc_avg)
        loss.backward()
        # the master receives the gradients from the worker and updates its weights
        delay_compensation(args)
        load_master(args)
        args.optimizer.step()
        update_master(args)
        # the worker receives the master's new weights
        update_worker(args, worker_rank)
        # compute the gradient norm
        norm_avg.update(
            sum(p.grad.data.norm() ** 2 for p in args.model.parameters()) ** 0.5,
            target.shape[0])
        progress_bar.set_description(
            "Epoch: %d, Loss: %0.8f Norm: %0.4f LR: %0.4f" %
            (epoch, loss_avg.avg(), norm_avg.avg(),
             args.optimizer.param_groups[0]['lr']))
    progress_bar.close()
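# The warmup branch in train() above assumes args.warmup_lr is a plain list of
# learning rates consumed one value per batch via pop(0). The original code
# never shows how that list is built; the helper below is a hypothetical
# sketch (name `build_warmup_lr` and the linear-ramp shape are assumptions,
# not taken from the source):
def build_warmup_lr(base_lr, warmup_steps):
    # linearly ramp from base_lr / warmup_steps up to base_lr
    return [base_lr * (step + 1) / warmup_steps for step in range(warmup_steps)]

# e.g. args.warmup_lr = build_warmup_lr(0.1, warmup_steps=500)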
def test(model, testloader, criterion, use_cuda):
    model.eval()
    losses = AverageMeter()
    top1 = AverageMeter()
    # torch.no_grad() replaces the long-deprecated Variable(..., volatile=True)
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            prec1 = accuracy(outputs.data, targets.data, topk=(1,))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))
    return (losses.avg(), top1.avg())
def train(model, trainloader, criterion, optimizer, epoch, use_cuda):
    model.train()
    losses = AverageMeter()
    top1 = AverageMeter()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            # `async=True` is a syntax error on Python 3.7+; non_blocking is
            # the current name of the same argument
            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        prec1 = accuracy(outputs.data, targets.data, topk=(1,))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return (losses.avg(), top1.avg())
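# Every function in this file relies on an AverageMeter helper that is not
# shown. A minimal sketch consistent with how the *_epoch functions below use
# it (.val/.avg/.sum/.count attributes). Note that evaluate/train/test above
# instead call avg() as a method, so that code path assumes a variant with
# `def avg(self): return self.sum / max(self.count, 1)`.
class AverageMeter(object):
    """Tracks the latest value and the running average of a metric."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# The `accuracy` helper used by test/train above is also assumed; this is the
# standard top-k accuracy sketch found in many PyTorch training scripts:
def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)
    # indices of the top-k predictions per sample, shape (maxk, batch)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res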
def val_epoch(epoch, data_loader, model, criterion, device, logger,
              tb_writer=None, distributed=False):
    print('validation at epoch {}'.format(epoch))
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()
    end_time = time.time()
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(data_loader):
            data_time.update(time.time() - end_time)
            targets = targets.to(device, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            acc = calculate_accuracy(outputs, targets)
            losses.update(loss.item(), inputs.size(0))
            accuracies.update(acc, inputs.size(0))
            batch_time.update(time.time() - end_time)
            end_time = time.time()
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                      epoch, i + 1, len(data_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses,
                      acc=accuracies).expandtabs(tabsize=4))
    if distributed:
        loss_sum = torch.tensor([losses.sum], dtype=torch.float32,
                                device=device)
        loss_count = torch.tensor([losses.count], dtype=torch.float32,
                                  device=device)
        acc_sum = torch.tensor([accuracies.sum], dtype=torch.float32,
                               device=device)
        acc_count = torch.tensor([accuracies.count], dtype=torch.float32,
                                 device=device)
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)
        losses.avg = loss_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()
    if logger is not None:
        logger.log({'epoch': epoch, 'loss': losses.avg, 'acc': accuracies.avg})
    if tb_writer is not None:
        tb_writer.add_scalar('val/loss', losses.avg, epoch)
        tb_writer.add_scalar('val/acc', accuracies.avg, epoch)
    return losses.avg
def train_epoch(epoch, data_loader, model, criterion, optimizer, device,
                current_lr, epoch_logger, batch_logger, tb_writer=None,
                distributed=False):
    print('train at epoch {}'.format(epoch))
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()
    end_time = time.time()
    for i, (inputs, targets) in enumerate(data_loader):
        data_time.update(time.time() - end_time)
        targets = targets.to(device, non_blocking=True)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        acc = calculate_accuracy(outputs, targets)
        losses.update(loss.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end_time)
        end_time = time.time()
        if batch_logger is not None:
            batch_logger.log({
                'epoch': epoch,
                'batch': i + 1,
                'iter': (epoch - 1) * len(data_loader) + (i + 1),
                'loss': losses.val,
                'acc': accuracies.val,
                'lr': current_lr
            })
        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                  epoch, i + 1, len(data_loader), batch_time=batch_time,
                  data_time=data_time, loss=losses, acc=accuracies))
    if distributed:
        loss_sum = torch.tensor([losses.sum], dtype=torch.float32,
                                device=device)
        loss_count = torch.tensor([losses.count], dtype=torch.float32,
                                  device=device)
        acc_sum = torch.tensor([accuracies.sum], dtype=torch.float32,
                               device=device)
        acc_count = torch.tensor([accuracies.count], dtype=torch.float32,
                                 device=device)
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)
        losses.avg = loss_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()
    if epoch_logger is not None:
        epoch_logger.log({
            'epoch': epoch,
            'loss': losses.avg,
            'acc': accuracies.avg,
            'lr': current_lr
        })
    if tb_writer is not None:
        tb_writer.add_scalar('train/loss', losses.avg, epoch)
        tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
        tb_writer.add_scalar('train/lr', current_lr, epoch)
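# The distributed branch above recurs verbatim in every *_epoch function in
# this file: all-reduce each meter's sum and count across ranks, then
# recompute the global average. A hypothetical helper (not part of the
# original code) that factors the pattern out for the attribute-style
# AverageMeter:
def reduce_meter(meter, device):
    # sum the local (sum, count) pair over all ranks, then recompute avg
    total = torch.tensor([meter.sum, meter.count], dtype=torch.float32,
                         device=device)
    dist.all_reduce(total, op=dist.ReduceOp.SUM)
    meter.avg = total[0].item() / total[1].item()
    return meter.avg

# e.g. the distributed branch would reduce to:
#     reduce_meter(losses, device)
#     reduce_meter(accuracies, device)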
def val_epoch(epoch, data_loader, model, criterion, device, logger,
              tb_writer=None, distributed=False, rpn=None, det_interval=2,
              nrois=10):
    print('validation at epoch {}'.format(epoch))
    model.eval()
    if rpn is not None:
        rpn.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()
    end_time = time.time()
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(data_loader):
            data_time.update(time.time() - end_time)
            targets = targets.to(device, non_blocking=True)
            if rpn is not None:
                # There was an unexpected CUDNN_ERROR when len(rpn_inputs)
                # decreased, so a short final batch is padded back up to the
                # largest batch size seen so far and the extra proposals are
                # dropped afterwards.
                N, C, T, H, W = inputs.size()
                if i == 0:
                    max_N = N
                # sample frames for the RPN
                sample = torch.arange(0, T, det_interval)
                rpn_inputs = inputs[:, :, sample].transpose(1, 2).contiguous()
                rpn_inputs = rpn_inputs.view(-1, C, H, W)
                if len(inputs) < max_N:
                    print("Modified from {} to {}".format(len(inputs), max_N))
                    rpn_inputs = torch.cat(
                        (rpn_inputs,
                         rpn_inputs[:(max_N - len(inputs)) * (T // det_interval)]))
                proposals = rpn(rpn_inputs)
                proposals = proposals.view(-1, T // det_interval, nrois, 4)
                if len(inputs) < max_N:
                    proposals = proposals[:len(inputs)]
                outputs = model(inputs, proposals.detach())
                # keep track of the largest batch size seen so far
                max_N = max(N, max_N)
            else:
                outputs = model(inputs)
            loss = criterion(outputs, targets)
            acc = calculate_accuracy(outputs, targets)
            losses.update(loss.item(), inputs.size(0))
            accuracies.update(acc, inputs.size(0))
            batch_time.update(time.time() - end_time)
            end_time = time.time()
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                      epoch, i + 1, len(data_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, acc=accuracies))
    if distributed:
        loss_sum = torch.tensor([losses.sum], dtype=torch.float32,
                                device=device)
        loss_count = torch.tensor([losses.count], dtype=torch.float32,
                                  device=device)
        acc_sum = torch.tensor([accuracies.sum], dtype=torch.float32,
                               device=device)
        acc_count = torch.tensor([accuracies.count], dtype=torch.float32,
                                 device=device)
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)
        losses.avg = loss_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()
    if logger is not None:
        logger.log({'epoch': epoch, 'loss': losses.avg, 'acc': accuracies.avg})
    if tb_writer is not None:
        tb_writer.add_scalar('val/loss', losses.avg, epoch)
        tb_writer.add_scalar('val/acc', accuracies.avg, epoch)
    return losses.avg
def train_epoch(epoch, data_loader1, data_loader2, model, criterion, optimizer,
                device, current_lr, epoch_logger, batch_logger, is_master_node,
                tb_writer=None, distributed=False):
    print('train at epoch {}'.format(epoch))
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()
    end_time = time.time()
    # iterate over data_loader1 and cycle data_loader2, re-creating its
    # iterator whenever it is exhausted
    dataloader_iterator = iter(data_loader2)
    for i, data1 in enumerate(data_loader1):
        try:
            data2 = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(data_loader2)
            data2 = next(dataloader_iterator)
        data_time.update(time.time() - end_time)
        inputs1, targets1 = data1
        inputs2, targets2 = data2
        inputs = torch.cat((inputs1, inputs2), 0)
        targets = torch.cat((targets1, targets2), 0)
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        acc = calculate_accuracy(outputs, targets)
        losses.update(loss.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end_time)
        end_time = time.time()
        itera = (epoch - 1) * len(data_loader1) + (i + 1)
        batch_lr = get_lr(optimizer)
        if is_master_node:
            if tb_writer is not None:
                tb_writer.add_scalar('train_iter/loss_iter', losses.val, itera)
                tb_writer.add_scalar('train_iter/acc_iter', accuracies.val,
                                     itera)
                tb_writer.add_scalar('train_iter/lr_iter', batch_lr, itera)
            if batch_logger is not None:
                batch_logger.log({
                    'epoch': epoch,
                    'batch': i + 1,
                    'iter': itera,
                    'loss': losses.val,
                    'acc': accuracies.val,
                    'lr': current_lr
                })
        local_rank = 0
        if is_master_node:
            print('Train Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.3f} ({acc.avg:.3f})\t'
                  'RANK {rank}'.format(epoch, i + 1, len(data_loader1),
                                       batch_time=batch_time,
                                       data_time=data_time, loss=losses,
                                       acc=accuracies, rank=local_rank))
    if distributed:
        loss_sum = torch.tensor([losses.sum], dtype=torch.float32,
                                device=device)
        loss_count = torch.tensor([losses.count], dtype=torch.float32,
                                  device=device)
        acc_sum = torch.tensor([accuracies.sum], dtype=torch.float32,
                               device=device)
        acc_count = torch.tensor([accuracies.count], dtype=torch.float32,
                                 device=device)
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)
        losses.avg = loss_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()
    if epoch_logger is not None:
        epoch_logger.log({
            'epoch': epoch,
            'loss': losses.avg,
            'acc': accuracies.avg,
            'lr': current_lr,
            'rank': local_rank
        })
    if is_master_node:
        if tb_writer is not None:
            tb_writer.add_scalar('train/loss', losses.avg, epoch)
            tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
            tb_writer.add_scalar('train/lr', current_lr, epoch)
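# The two-loader train_epoch above re-creates the iterator for data_loader2
# by hand on StopIteration. A hypothetical generator that packages the same
# pattern (unlike itertools.cycle, this re-invokes iter() on each pass, so a
# shuffling sampler reshuffles instead of replaying cached batches):
def cycle_loader(data_loader):
    while True:
        # each pass through the for-loop calls iter(data_loader) afresh
        for batch in data_loader:
            yield batch

# e.g. dataloader_iterator = cycle_loader(data_loader2)
#      data2 = next(dataloader_iterator)   # never raises StopIteration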
def train_a_epoch(epoch, data_loader, model, joint_prediction_aud, criterion,
                  criterion_jsd, criterion_ct_av, optimizer, optimizer_av,
                  device, current_lr, epoch_logger, batch_logger,
                  tb_writer=None, distributed=False):
    print('train at epoch {}'.format(epoch))
    model.train()
    joint_prediction_aud.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    # classification loss
    losses_cls = AverageMeter()
    accuracies = AverageMeter()
    # contrastive loss
    losses_ct_av = AverageMeter()
    # jsd loss
    losses_jsd_a = AverageMeter()
    end_time = time.time()
    for i, (inputs, targets, audios) in enumerate(data_loader):
        data_time.update(time.time() - end_time)
        outputs, features = model(inputs)
        targets = targets.to(device, non_blocking=True)
        audios = audios.to(device, non_blocking=True)
        loss_cls_v = criterion(outputs, targets)  # video classification loss
        acc = calculate_accuracy(outputs, targets)

        # use the audio features directly, filtering out the all-zero
        # (unavailable) audio features
        available = audios.sum(dim=1) != 0
        features_aud = audios[available]
        features_vid = features[available]
        targets_new = targets[available]
        outputs_new = outputs[available]
        # compose audio and video features
        outputs_av, features_av = joint_prediction_aud(features_aud,
                                                       features_vid)
        loss_cls_a = criterion(outputs_av, targets_new)  # audio-video classification loss
        # contrastive learning (symmetric loss)
        # align video features to multimodal (audio-video) features
        loss_vm = (criterion_ct_av(features_vid, features_av, targets_new) +
                   criterion_ct_av(features_av, features_vid, targets_new))
        # align video features to audio features
        loss_va = (criterion_ct_av(features_vid, features_aud, targets_new) +
                   criterion_ct_av(features_aud, features_vid, targets_new))
        # align multimodal (audio-video) features to audio features
        loss_ma = (criterion_ct_av(features_av, features_aud, targets_new) +
                   criterion_ct_av(features_aud, features_av, targets_new))
        # contrastive losses
        loss_ct_av = loss_vm + loss_va
        loss_ct_a = loss_vm + loss_ma
        # jsd loss
        loss_jsd_a = criterion_jsd(outputs_new, outputs_av)

        total_loss_v = sum([loss_cls_v, loss_ct_av, loss_jsd_a])
        total_loss_a = sum([loss_cls_a, loss_ct_a, loss_jsd_a])
        losses_cls.update(loss_cls_v.item(), inputs.size(0))
        losses_ct_av.update(loss_ct_av.item(), inputs.size(0))
        losses_jsd_a.update(loss_jsd_a.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))
        optimizer_av.zero_grad()
        total_loss_a.backward(retain_graph=True)
        optimizer.zero_grad()
        total_loss_v.backward()
        optimizer_av.step()
        optimizer.step()

        batch_time.update(time.time() - end_time)
        end_time = time.time()
        write_to_batch_logger(batch_logger, epoch, i, data_loader,
                              losses_cls.val, accuracies.val, current_lr)
        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss_cls {loss_cls.val:.3f} ({loss_cls.avg:.3f})\t'
              'Loss_ct_a {loss_ct_av.val:.3f} ({loss_ct_av.avg:.3f})\t'
              'Loss_jsd_a {loss_jsd_a.val:.3f} ({loss_jsd_a.avg:.3f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                  epoch, i + 1, len(data_loader), batch_time=batch_time,
                  data_time=data_time, loss_cls=losses_cls,
                  loss_ct_av=losses_ct_av, loss_jsd_a=losses_jsd_a,
                  acc=accuracies), flush=True)
    if distributed:
        loss_cls_sum = torch.tensor([losses_cls.sum], dtype=torch.float32,
                                    device=device)
        loss_ct_av_sum = torch.tensor([losses_ct_av.sum], dtype=torch.float32,
                                      device=device)
        loss_jsd_a_sum = torch.tensor([losses_jsd_a.sum], dtype=torch.float32,
                                      device=device)
        acc_sum = torch.tensor([accuracies.sum], dtype=torch.float32,
                               device=device)
        loss_count = torch.tensor([losses_cls.count], dtype=torch.float32,
                                  device=device)
        acc_count = torch.tensor([accuracies.count], dtype=torch.float32,
                                 device=device)
        dist.all_reduce(loss_cls_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_ct_av_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_jsd_a_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)
        losses_cls.avg = loss_cls_sum.item() / loss_count.item()
        losses_ct_av.avg = loss_ct_av_sum.item() / loss_count.item()
        losses_jsd_a.avg = loss_jsd_a_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()
    # log epoch-level averages
    write_to_epoch_logger(epoch_logger, epoch, losses_cls.avg, accuracies.avg,
                          current_lr)
    if tb_writer is not None:
        tb_writer.add_scalar('train/loss_cls', losses_cls.avg, epoch)
        tb_writer.add_scalar('train/loss_ct_av', losses_ct_av.avg, epoch)
        tb_writer.add_scalar('train/loss_jsd_a', losses_jsd_a.avg, epoch)
        tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
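# criterion_jsd above is assumed to measure agreement between the unimodal and
# multimodal logits. Its implementation is not shown; the sketch below is a
# plain Jensen-Shannon divergence over logits (an assumption, not the
# original criterion):
import torch.nn.functional as F

def jsd_loss(logits_p, logits_q):
    p = F.softmax(logits_p, dim=1)
    q = F.softmax(logits_q, dim=1)
    m = 0.5 * (p + q)
    # JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m);
    # F.kl_div expects log-probabilities as its first argument
    return 0.5 * (F.kl_div(m.log(), p, reduction='batchmean') +
                  F.kl_div(m.log(), q, reduction='batchmean'))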
def train_i_epoch(epoch, data_loader, model, image_model, joint_prediction_img,
                  criterion, criterion_jsd, criterion_ct_iv, optimizer,
                  optimizer_iv, device, current_lr, epoch_logger, batch_logger,
                  tb_writer=None, distributed=False, image_size=None):
    print('train at epoch {}'.format(epoch))
    model.train()
    image_model.eval()
    joint_prediction_img.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    # classification loss
    losses_cls = AverageMeter()
    accuracies = AverageMeter()
    # contrastive loss
    losses_ct_iv = AverageMeter()
    # jsd loss
    losses_jsd_i = AverageMeter()
    end_time = time.time()
    for i, (inputs, targets) in enumerate(data_loader):
        data_time.update(time.time() - end_time)
        outputs, features = model(inputs)
        targets = targets.to(device, non_blocking=True)
        loss_cls_v = criterion(outputs, targets)  # video classification loss
        acc = calculate_accuracy(outputs, targets)

        if image_size is not None:
            # resize the frames to a larger size than the video input size
            # (better fit for the resnet); the temporal dimension (dim 2 of
            # the N, C, T, H, W input) is left unchanged
            inputs = F.interpolate(
                inputs, size=[inputs.shape[2], image_size, image_size])
        # randomly select a frame as the image input
        rand_img = random.randint(0, inputs.shape[2] - 1)
        images = inputs[:, :, rand_img, :, :]
        features_img = image_model(images)
        features_img = features_img.squeeze()
        # compose image and video features
        outputs_iv, features_iv = joint_prediction_img(features_img, features)
        loss_cls_i = criterion(outputs_iv, targets)  # image-video classification loss
        # contrastive learning (symmetric loss)
        # align video features to multimodal (image-video) features
        loss_vm = (criterion_ct_iv(features, features_iv, targets) +
                   criterion_ct_iv(features_iv, features, targets))
        # align video features to image features
        loss_vi = (criterion_ct_iv(features, features_img, targets) +
                   criterion_ct_iv(features_img, features, targets))
        # align multimodal features to image features
        loss_mi = (criterion_ct_iv(features_iv, features_img, targets) +
                   criterion_ct_iv(features_img, features_iv, targets))
        # contrastive losses
        loss_ct_iv = loss_vm + loss_vi
        loss_ct_i = loss_vm + loss_mi
        # jsd loss
        loss_jsd_i = criterion_jsd(outputs, outputs_iv)

        total_loss_v = sum([loss_cls_v, loss_ct_iv, loss_jsd_i])
        total_loss_i = sum([loss_cls_i, loss_ct_i, loss_jsd_i])
        losses_cls.update(loss_cls_v.item(), inputs.size(0))
        losses_ct_iv.update(loss_ct_iv.item(), inputs.size(0))
        losses_jsd_i.update(loss_jsd_i.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))
        optimizer_iv.zero_grad()
        total_loss_i.backward(retain_graph=True)
        optimizer.zero_grad()
        total_loss_v.backward()
        optimizer_iv.step()
        optimizer.step()

        batch_time.update(time.time() - end_time)
        end_time = time.time()
        write_to_batch_logger(batch_logger, epoch, i, data_loader,
                              losses_cls.val, accuracies.val, current_lr)
        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss_cls {loss_cls.val:.3f} ({loss_cls.avg:.3f})\t'
              'Loss_ct_i {loss_ct_iv.val:.3f} ({loss_ct_iv.avg:.3f})\t'
              'Loss_jsd_i {loss_jsd_i.val:.3f} ({loss_jsd_i.avg:.3f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                  epoch, i + 1, len(data_loader), batch_time=batch_time,
                  data_time=data_time, loss_cls=losses_cls,
                  loss_ct_iv=losses_ct_iv, loss_jsd_i=losses_jsd_i,
                  acc=accuracies), flush=True)
    if distributed:
        loss_cls_sum = torch.tensor([losses_cls.sum], dtype=torch.float32,
                                    device=device)
        loss_ct_iv_sum = torch.tensor([losses_ct_iv.sum], dtype=torch.float32,
                                      device=device)
        loss_jsd_i_sum = torch.tensor([losses_jsd_i.sum], dtype=torch.float32,
                                      device=device)
        acc_sum = torch.tensor([accuracies.sum], dtype=torch.float32,
                               device=device)
        loss_count = torch.tensor([losses_cls.count], dtype=torch.float32,
                                  device=device)
        acc_count = torch.tensor([accuracies.count], dtype=torch.float32,
                                 device=device)
        dist.all_reduce(loss_cls_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_ct_iv_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_jsd_i_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)
        losses_cls.avg = loss_cls_sum.item() / loss_count.item()
        losses_ct_iv.avg = loss_ct_iv_sum.item() / loss_count.item()
        losses_jsd_i.avg = loss_jsd_i_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()
    # log epoch-level averages
    write_to_epoch_logger(epoch_logger, epoch, losses_cls.avg, accuracies.avg,
                          current_lr)
    if tb_writer is not None:
        tb_writer.add_scalar('train/loss_cls', losses_cls.avg, epoch)
        tb_writer.add_scalar('train/loss_ct_iv', losses_ct_iv.avg, epoch)
        tb_writer.add_scalar('train/loss_jsd_i', losses_jsd_i.avg, epoch)
        tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
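# write_to_batch_logger / write_to_epoch_logger, used by train_a_epoch and
# train_i_epoch above, are not shown in this file. The sketches below are
# assumptions: thin wrappers whose signatures are inferred from the call
# sites, mirroring the batch_logger.log / epoch_logger.log dictionaries used
# by the plain train_epoch earlier:
def write_to_batch_logger(batch_logger, epoch, i, data_loader, loss, acc,
                          current_lr):
    if batch_logger is not None:
        batch_logger.log({
            'epoch': epoch,
            'batch': i + 1,
            'iter': (epoch - 1) * len(data_loader) + (i + 1),
            'loss': loss,
            'acc': acc,
            'lr': current_lr
        })


def write_to_epoch_logger(epoch_logger, epoch, loss, acc, current_lr):
    if epoch_logger is not None:
        epoch_logger.log({
            'epoch': epoch,
            'loss': loss,
            'acc': acc,
            'lr': current_lr
        })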