Example 1
def aggregate_gradients(model, world_size):
    """Average gradients of models across all processes."""
    for param in model.parameters():
        # Sum each gradient across processes, then divide by the world size.
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= world_size
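Note: on PyTorch 1.11 and later with the NCCL backend, `dist.ReduceOp.AVG` performs the same averaging in a single collective; a minimal sketch under that assumption:

def aggregate_gradients_avg(model):
    """Sketch: average gradients in one collective (assumes NCCL and ReduceOp.AVG)."""
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)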
Example 2
    def run_epoch(self) -> float:
        logger.info("Computing reward")
        rewards = self.individual_pool.compute_all_local_rewards()
        logger.info("Pushing reward")

        # Sum the rewards across all machines
        distributed.all_reduce(rewards, group=self.process_group)

        # Divide the rewards by the number of machines.  We do this because
        # there is no "average" all_reduce operator.
        rewards /= self.num_nodes

        self.iteration += 1
        self.individual_pool.apply_global_reward(rewards, self.iteration)
        most_recent_avg_rewards = float(torch.mean(rewards))
        new_parent_reward = self.individual_pool.compute_local_reward(
            self.individual_pool.parent_tensors
        )
        logger.info(
            "ITERATION: {0} MEAN REWARD: {1}, NEW PARENT REWARD: {2}".format(
                self.iteration, most_recent_avg_rewards, new_parent_reward
            )
        )

        return new_parent_reward
Example 3
        def _process_batch():
            dev_grad_batch, dev_events, job_event = queue.get()
            dev_coalesced = []
            # Coalesce the tensors on all devices and start a local reduction
            for dev_id, grad_batch, event, stream in zip(device_ids, dev_grad_batch, dev_events, reduction_streams):
                with torch.cuda.device(dev_id), torch.cuda.stream(stream):
                    stream.wait_event(event)
                    coalesced = _flatten_tensors(grad_batch)
                    dev_coalesced.append(coalesced)
            # Wait for all copies to complete before starting the NCCL kernel
            for stream in reduction_streams:
                stream.synchronize()
            nccl.reduce(dev_coalesced, root=device_ids[0], streams=nccl_streams)

            # From now on we're only going to work on the first device (from device_ids)
            grad_batch = dev_grad_batch[0]
            coalesced = dev_coalesced[0]
            reduce_stream = reduction_streams[0]
            with torch.cuda.stream(reduce_stream):
                reduce_stream.wait_stream(nccl_streams[0])
                coalesced /= dist.get_world_size()
                dist.all_reduce(coalesced, group=group_id)
                for grad, reduced in zip(grad_batch, _unflatten_tensors(coalesced, grad_batch)):
                    grad.copy_(reduced)
            job_event.set()
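Stripped of the per-device streams, the coalesce/reduce/uncoalesce pattern above can be reproduced with the flatten helpers from `torch._utils`; a single-bucket sketch (the function name is illustrative):

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def all_reduce_coalesced(tensors, world_size):
    # Pack the bucket into one contiguous buffer so a single collective
    # covers every tensor, then scatter the averaged values back.
    flat = _flatten_dense_tensors(tensors)
    flat /= world_size
    dist.all_reduce(flat)
    for t, synced in zip(tensors, _unflatten_dense_tensors(flat, tensors)):
        t.copy_(synced)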
Example 4
    def on_batch_end(self, last_output, last_target, **kwargs):
        "Update metric computation with `last_output` and `last_target`."
        if not is_listy(last_target): last_target = [last_target]
        self.count += last_target[0].size(0)
        val = self.func(last_output, *last_target)
        if self.world:
            val = val.clone()
            dist.all_reduce(val, op=dist.ReduceOp.SUM)
            val /= self.world
        self.val += last_target[0].size(0) * val.detach().cpu()
Example 5
def test_mpi():
    dist.init_process_group('mpi')
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    vector = [0] * world_size
    vector[rank] = 1
    vector = torch.DoubleTensor(vector)

    dist.all_reduce(vector, op=dist.ReduceOp.SUM)
    print("Host {} : Rank {} : {}".format(get_hostname(), rank, vector))
Example 6
        def allreduce_params():
            if self.needs_reduction:
                self.needs_reduction = False
                buckets = defaultdict(list)
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        buckets[tp].append(param)

                for bucket in buckets.values():
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)
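A closure like this is typically re-armed and called once per backward pass; a hedged sketch of the surrounding training step (all names below are illustrative, not from the source module):

        def training_step(self, loss, optimizer):
            self.needs_reduction = True   # re-arm the one-shot reduction flag
            optimizer.zero_grad()
            loss.backward()
            allreduce_params()            # bucketed gradient averaging
            optimizer.step()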
Example 7
    def _test_all_reduce_helper(self, group, group_id, rank, op, master_value,
                                worker_value, expected_value, cuda=False):
        for src in group:
            if rank == src:
                tensor = _build_tensor(src + 1).fill_(master_value)
                if cuda:
                    tensor = tensor.cuda()
                dist.all_reduce(tensor, op, group_id)
                self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
            else:
                tensor = _build_tensor(src + 1).fill_(worker_value)
                if cuda:
                    tensor = tensor.cuda()
                dist.all_reduce(tensor, op, group_id)
                self.assertEqual(tensor, _build_tensor(src + 1, expected_value))

        self._barrier()
Example 8
def _ranks_on_same_node(rank, world_size):
    hostname = socket.gethostname()
    hostname_length = torch.IntTensor([len(hostname)])
    dist.all_reduce(hostname_length, op=dist.ReduceOp.MAX)
    max_hostname_length = hostname_length.item()

    encoding = [ord(c) for c in hostname]
    encoding += [-1] * (max_hostname_length - len(hostname))
    encoding = torch.IntTensor(encoding)

    all_encodings = [torch.IntTensor([0] * max_hostname_length) for _ in range(world_size)]
    dist.all_gather(all_encodings, encoding)

    all_encodings = [ec.numpy().tolist() for ec in all_encodings]
    counter = 0
    for i in range(rank):
        if all_encodings[rank] == all_encodings[i]:
            counter += 1
    return counter
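Example 9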
        def allreduce_params():
            if self.needs_reduction:
                self.needs_reduction = False
                buckets = {}
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)
                if self.warn_on_half:
                    if torch.cuda.HalfTensor in buckets:
                        print("WARNING: gloo dist backend for half parameters may be extremely slow." +
                              " It is recommended to use the NCCL backend in this case.")
                        self.warn_on_half = False

                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)
Example 10
def avg_param(model):
    for param in model.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data /= float(dist.get_world_size())
Example 11
def reduce_loss(total_loss, n_samples):
    reduction = torch.FloatTensor([total_loss, n_samples])
    dist.all_reduce(reduction, op=dist.ReduceOp.SUM)
    if dist.get_rank() == 0:
        print('n_samples : ', int(reduction[1].item()))
    return float(reduction[0].item() / reduction[1].item())
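Example 12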
def train_one_epoch(
        dataloader: torch.utils.data.DataLoader,
        valid_dataloader: torch.utils.data.DataLoader, model: AcousticModel,
        ali_model: Optional[AcousticModel], device: torch.device,
        graph_compiler: MmiTrainingGraphCompiler, use_pruned_intersect: bool,
        optimizer: torch.optim.Optimizer, accum_grad: int, den_scale: float,
        att_rate: float, current_epoch: int, tb_writer: SummaryWriter,
        num_epochs: int, global_batch_idx_train: int, world_size: int,
        scaler: GradScaler):
    """One epoch training and validation.

    Args:
        dataloader: Training dataloader
        valid_dataloader: Validation dataloader
        model: Acoustic model to be trained
        device: Training device, torch.device("cpu") or torch.device("cuda", device_id)
        graph_compiler: MMI training graph compiler
        optimizer: Training optimizer
        accum_grad: Number of gradient accumulation
        den_scale: Denominator scale in mmi loss
        att_rate: Attention loss rate, final loss is att_rate * att_loss + (1-att_rate) * other_loss
        current_epoch: current training epoch, for logging only
        tb_writer: tensorboard SummaryWriter
        num_epochs: total number of training epochs, for logging only
        global_batch_idx_train: global training batch index before this epoch, for logging only

    Returns:
        A tuple of 3 scalar:  (total_objf / total_frames, valid_average_objf, global_batch_idx_train)
        - `total_objf / total_frames` is the average training loss
        - `valid_average_objf` is the average validation loss
        - `global_batch_idx_train` is the global training batch index after this epoch
    """
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    forward_count = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        forward_count += 1
        if forward_count == accum_grad:
            is_update = True
            forward_count = 0
        else:
            is_update = False

        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            ali_model=ali_model,
            device=device,
            graph_compiler=graph_compiler,
            use_pruned_intersect=use_pruned_intersect,
            is_training=True,
            is_update=is_update,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer,
            scaler=scaler)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames,
                                     global_batch_idx_train)

                tb_writer.add_scalar(
                    'train/current_batch_average_objf',
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    global_batch_idx_train)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                ali_model=ali_model,
                device=device,
                graph_compiler=graph_compiler,
                use_pruned_intersect=use_pruned_intersect,
                scaler=scaler)
            if world_size > 1:
                s = torch.tensor([
                    total_valid_objf, total_valid_frames,
                    total_valid_all_frames
                ]).to(device)

                dist.all_reduce(s, op=dist.ReduceOp.SUM)
                total_valid_objf, total_valid_frames, total_valid_all_frames = s.cpu().tolist()

            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                if hasattr(model, 'module'):
                    model.module.write_tensorboard_diagnostics(
                        tb_writer, global_step=global_batch_idx_train)
                else:
                    model.write_tensorboard_diagnostics(
                        tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
Example 13
def elementwise_min(tensor):
    dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
    return tensor
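The same in-place pattern gives the elementwise maximum; a trivial counterpart sketch:

def elementwise_max(tensor):
    # all_reduce mutates `tensor` in place; returning it is a convenience.
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    return tensor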
Example 14
        x = self.dropout6(x)

        x = x.view(-1, 9 * 9 * 128)

        x = self.dropout7(F.relu(self.fc1(x)))
        x = self.dropout8(F.relu(self.fc2(x)))
        x = self.fc3(x)
        return x


model = Net2()

# Make sure that all nodes have the same model
for param in model.parameters():
    tensor0 = param.data
    dist.all_reduce(tensor0, op=dist.ReduceOp.SUM)
    param.data = tensor0 / np.sqrt(float(num_nodes))

model.cuda()

Path_Save = '/projects/sciteam/bahp/RNN/TinyImageNetModel'
# torch.save(model.state_dict(), Path_Save)
# model.load_state_dict(torch.load(Path_Save))

LR = 0.001
batch_size = 100
Num_Epochs = 1000

criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=LR)
Example 15
def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
Example 16
def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size
Example 17
def test_dist_allreduce():
    x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
    sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
    result = torch.ones(1, 3).cuda() * sum_of_ranks
    dist.all_reduce(x)
    assert torch.all(x == result)
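Example 18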
def reduce_mean(tensor, nprocs):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt
Example 19
    def _schedule_shadow_all_reduce_for_fwd_pass(self):
        all_active_procs = torch.zeros(1, device=self.device)
        dist.all_reduce(all_active_procs, group=self.process_group)
        return all_active_procs.item()
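The zero marks this rank as a "shadow" participant; active ranks presumably contribute a one, so the summed result counts processes still doing real work. A hedged sketch of that active-side counterpart (the method name is hypothetical):

    def _schedule_all_reduce_for_fwd_pass(self):
        # Active ranks contribute 1; summed with the shadow ranks' zeros,
        # the result is the number of processes still in the forward pass.
        all_active_procs = torch.ones(1, device=self.device)
        dist.all_reduce(all_active_procs, group=self.process_group)
        return all_active_procs.item()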
Example 20
def train(train_loader, model, criterion, optimizer, scheduler, epoch):
    global logger, conf, tb
    batch_time = utils.AverageMeter()
    data_time = utils.AverageMeter()
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(epoch)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        if conf["optimizer"]["schedule"]["mode"] == "step":
            scheduler.step(i + epoch * len(train_loader))

        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        if conf["optimizer"]["clip"] != 0.0:
            nn.utils.clip_grad_norm_(model.parameters(),
                                     conf["optimizer"]["clip"])
        optimizer.step()

        # measure accuracy and record loss
        with torch.no_grad():
            output = output.detach()
            loss = loss.detach() * target.shape[0]
            prec1, prec5 = utils.accuracy_sum(output, target, topk=(1, 5))
            count = target.new_tensor([target.shape[0]], dtype=torch.long)
            if dist.is_initialized():
                dist.all_reduce(count, dist.ReduceOp.SUM)
            for meter, val in (losses, loss), (top1, prec1), (top5, prec5):
                if dist.is_initialized():
                    dist.all_reduce(val, dist.ReduceOp.SUM)
                val /= count.item()
                meter.update(val.item(), count.item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            logger.info("Epoch: [{0}][{1}/{2}]\t"
                        "Time {batch_time.val:.3f} ({batch_time.avg:.3f}) \t"
                        "Data {data_time.val:.3f} ({data_time.avg:.3f}) \t"
                        "Loss {loss.val:.4f} ({loss.avg:.4f}) \t"
                        "Prec@1 {top1.val:.3f} ({top1.avg:.3f}) \t"
                        "Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format(
                            epoch,
                            i,
                            len(train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=losses,
                            top1=top1,
                            top5=top5,
                        ))

        if not dist.is_initialized() or dist.get_rank() == 0:
            tb.add_scalar("train/loss", losses.val,
                          i + epoch * len(train_loader))
            tb.add_scalar("train/lr",
                          scheduler.get_lr()[0], i + epoch * len(train_loader))
            tb.add_scalar("train/top1", top1.val,
                          i + epoch * len(train_loader))
            tb.add_scalar("train/top5", top5.val,
                          i + epoch * len(train_loader))
            if args.log_hist and i % 10 == 0:
                for name, param in model.named_parameters():
                    if name.find("fc") != -1 or name.find("bn_out") != -1:
                        tb.add_histogram(
                            name,
                            param.clone().cpu().data.numpy(),
                            i + epoch * len(train_loader),
                        )
Example 21
def run(rank, size):
    """ Simple point-to-point communication. """
    group = dist.new_group([0, 1])
    tensor = torch.ones(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor[0])
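This snippet assumes an already-initialized process group; a typical single-machine bootstrap sketch (the address and port values are illustrative):

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def init_process(rank, size, fn, backend='gloo'):
    # Minimal rendezvous so each spawned worker can run `run(rank, size)`.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

if __name__ == '__main__':
    size = 2
    mp.spawn(init_process, args=(size, run), nprocs=size, join=True)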
Example 22
def train(train_loader, model, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    derain_loss_meter = AverageMeter()
    seg_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    psnr_meter = AverageMeter()
    ssim_meter = AverageMeter()

    list_multiply = lambda x, y: x * y
    assert len(args.seg_loss_step_weight) == args.num_steps

    model.train()
    end = time.time()
    max_iter = args.epochs * len(train_loader)
    for i, (clear_label, rain_input) in enumerate(train_loader):
        data_time.update(time.time() - end)

        clear_label = clear_label.cuda(non_blocking=True)
        rain_input = rain_input.cuda(non_blocking=True)
        derain_output, derain_losses = model(rain_input, clear_label)
        derain_losses = map(list_multiply, derain_losses,
                            args.derain_loss_step_weight)
        derain_sum_loss = sum(derain_losses)
        if not args.multiprocessing_distributed:
            derain_sum_loss = torch.mean(derain_sum_loss)
        loss = args.derain_loss_weight * derain_sum_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = rain_input.size(0)
        if args.multiprocessing_distributed:
            derain_sum_loss, loss = derain_sum_loss.detach() * n, \
                                    loss * n  # not considering ignore pixels
            count = clear_label.new_tensor([n], dtype=torch.long)
            dist.all_reduce(derain_sum_loss)
            dist.all_reduce(loss)
            dist.all_reduce(count)
            n = count.item()
            derain_sum_loss, loss = derain_sum_loss / n, loss / n

        # intersection, union, target = intersectionAndUnionCPU(seg_output, seg_label, args.classes, args.ignore_label)
        psnr, ssim = batchPSNRandSSIMGPU(derain_output, clear_label)
        # if args.multiprocessing_distributed:
        #     dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target)
        # intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        # intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)
        psnr_meter.update(psnr), ssim_meter.update(ssim)

        # accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        accuracy = 0
        psnr_val = psnr_meter.val
        ssim_val = ssim_meter.val
        derain_loss_meter.update(derain_sum_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        current_iter = epoch * len(train_loader) + i + 1
        current_lr = poly_learning_rate(args.base_lr,
                                        current_iter,
                                        max_iter,
                                        power=args.power)
        for index in range(0, args.index_split_1):
            optimizer.param_groups[index]['lr'] = current_lr * 0
        for index in range(args.index_split_1, args.index_split_2):
            optimizer.param_groups[index]['lr'] = current_lr * 10
        for index in range(args.index_split_2, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 0
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        if (i + 1) % args.print_freq == 0 and main_process():
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'DerainLoss {derain_loss_meter:.4f} '
                        'SegLoss {seg_loss_meter:.4f} '
                        'Loss {loss_meter:.4f} '
                        'Accuracy {accuracy:.4f}.'
                        'PSNR {psnr_val:.2f}.'
                        'SSIM {ssim_val:.4f}.'.format(
                            epoch + 1,
                            args.epochs,
                            i + 1,
                            len(train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            remain_time=remain_time,
                            derain_loss_meter=derain_loss_meter.val,
                            seg_loss_meter=seg_loss_meter.val,
                            loss_meter=loss_meter.val,
                            accuracy=accuracy,
                            psnr_val=psnr_val,
                            ssim_val=ssim_val))

        if main_process():
            writer.add_scalar('derain_loss_train_batch', derain_loss_meter.val,
                              current_iter)
            writer.add_scalar('seg_loss_train_batch', seg_loss_meter.val,
                              current_iter)
            writer.add_scalar('loss_train_batch', loss_meter.val, current_iter)
            # writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter)
            # writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)
            writer.add_scalar('psnr_train_batch', psnr_val, current_iter)
            writer.add_scalar('ssim_train_batch', ssim_val, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    # mIoU = np.mean(iou_class)
    # mAcc = np.mean(accuracy_class)
    # allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    mIoU = 0
    mAcc = 0
    allAcc = 0
    if main_process():
        logger.info(
            'Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'
            .format(epoch + 1, args.epochs, mIoU, mAcc, allAcc))
        logger.info(
            'Train result at epoch [{}/{}]: PSNR/SSIM {:.4f}/{:.4f}.'.format(
                epoch + 1, args.epochs, psnr_meter.avg, ssim_meter.avg))
    return loss_meter.avg, mIoU, mAcc, allAcc, psnr_meter.avg, ssim_meter.avg
Example 23
def reduce_sum(tensor):
    import torch.distributed as dist
    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor
Example 24
def validate_kitti(model, args, eval_loader, group):
    """ Peform validation using the KITTI-2015 (train) split """
    """ Peform validation using the KITTI-2015 (train) split """
    model.eval()
    gpu = args.gpu
    eval_measures_depth = torch.zeros(10).cuda(device=gpu)
    eval_epe = torch.zeros(2).cuda(device=gpu)
    eval_out = torch.zeros(2).cuda(device=gpu)

    for val_id, data_blob in enumerate(tqdm(eval_loader)):
        image1 = data_blob['img1'].cuda(gpu) / 255.0
        image2 = data_blob['img2'].cuda(gpu) / 255.0
        intrinsic = data_blob['intrinsic'].cuda(gpu)
        insmap = data_blob['insmap'].cuda(gpu)
        posepred = data_blob['posepred'].cuda(gpu)
        depthgt = data_blob['depthmap'].cuda(gpu)
        flowmap = data_blob['flowmap'].cuda(gpu)

        selfpose_gt = data_blob['rel_pose'].cuda(gpu)
        fixed_posepred = (
            selfpose_gt @ torch.inverse(posepred[:, 0])).unsqueeze(1).expand(
                [-1, args.maxinsnum, -1, -1]) @ posepred

        outputs = model(image1,
                        image2,
                        intrinsic,
                        fixed_posepred,
                        insmap,
                        iters=args.iters)

        depth_predictions = outputs['depth_predictions']

        depth_prediction = depth_predictions[-1]

        selector = ((depth_prediction > 0) * (depthgt > args.min_depth_eval) *
                    (depthgt < args.max_depth_eval)).float()
        depth_prediction = torch.clamp(depth_prediction,
                                       min=args.min_depth_eval,
                                       max=args.max_depth_eval)
        depth_gt_flatten = depthgt[selector == 1].cpu().numpy()
        pred_depth_flatten = depth_prediction[selector == 1].cpu().numpy()

        eval_measures_depth_np = compute_errors(gt=depth_gt_flatten,
                                                pred=pred_depth_flatten)

        eval_measures_depth[:9] += torch.tensor(eval_measures_depth_np).cuda(
            device=gpu)
        eval_measures_depth[9] += 1

        flowpred = outputs['flowpred']
        epe = torch.sum((flowmap - flowpred)**2, dim=1).sqrt()
        mag = torch.sum(flowmap**2, dim=1).sqrt()

        epe = epe.view(-1)
        mag = mag.view(-1)
        val = (mag >= 0.5) * (mag < MAX_FLOW)

        out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()
        eval_epe[0] += epe[val].mean()
        eval_epe[1] += 1
        eval_out[0] += out[val].mean()
        eval_out[1] += 1

    if args.distributed:
        dist.all_reduce(tensor=eval_measures_depth,
                        op=dist.ReduceOp.SUM,
                        group=group)
        dist.all_reduce(tensor=eval_epe, op=dist.ReduceOp.SUM, group=group)
        dist.all_reduce(tensor=eval_out, op=dist.ReduceOp.SUM, group=group)

    if args.gpu == 0:
        eval_measures_depth[0:9] = eval_measures_depth[0:9] / eval_measures_depth[9]
        eval_measures_depth = eval_measures_depth.cpu().numpy()
        eval_epe[0] = eval_epe[0] / eval_epe[1]
        eval_epe = eval_epe.cpu().numpy()
        eval_out[0] = eval_out[0] / eval_out[1]
        eval_out = eval_out.cpu().numpy()
        print('Computing Depth errors for %f eval samples' %
              (eval_measures_depth[9].item()))
        print(
            "{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}"
            .format('silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms',
                    'd1', 'd2', 'd3', 'out'))
        for i in range(9):
            print('{:7.3f}, '.format(eval_measures_depth[i]), end='')
        print('{:7.3f}'.format(eval_out[0]))

        return {
            'silog': float(eval_measures_depth[0]),
            'abs_rel': float(eval_measures_depth[1]),
            'log10': float(eval_measures_depth[2]),
            'rms': float(eval_measures_depth[3]),
            'sq_rel': float(eval_measures_depth[4]),
            'log_rms': float(eval_measures_depth[5]),
            'd1': float(eval_measures_depth[6]),
            'd2': float(eval_measures_depth[7]),
            'd3': float(eval_measures_depth[8]),
            'out': float(eval_out[0]),
            'epe': float(eval_epe[0]),
        }
    else:
        return None
Example 25
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
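This backward pairs naturally with a gather-style forward; a minimal self-contained sketch of the whole autograd.Function, assuming equally shaped tensors on every rank (not necessarily the source's actual class):

class AllGatherWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor):
        # Collect one equally shaped tensor from every rank.
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        # Sum each output's gradient across ranks, keep this rank's slice.
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]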
Example 26
def all_reduce(tensor, group=None):
    if group is None:
        group = get_default_group()
    # all_reduce mutates `tensor` in place; return it for convenience.
    dist.all_reduce(tensor, group=group)
    return tensor
Example 27
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output, async_op=False)
        return grad_output
Example 28
    def train(self, envs):

        self.training_step = 0
        best_reward = torch.zeros((1,), device=self.device)
        eplen = torch.zeros((1,), device=self.device, dtype=torch.int32)
        visited_rooms = set()

        rollout_idx = 0
        state = np.transpose(envs.reset(), (0, 3, 1, 2))

        # rollout
        while rollout_idx < self.num_rollouts:
            # sync model
            distributed_util.sync_model(self.actor_critic)

            states = np.zeros(
                (self.num_steps, self.num_envs, 1, 84, 84), np.float32)
            actions = np.zeros((self.num_steps, self.num_envs), np.int32)
            action_log_probs = np.zeros(
                (self.num_steps, self.num_envs), np.float32)
            rewards = np.zeros((self.num_steps, self.num_envs), np.float32)
            next_states = np.zeros(
                (self.num_steps, self.num_envs, 1, 84, 84), np.float32)
            dones = np.zeros((self.num_steps, self.num_envs), np.int32)

            current_best_reward = torch.zeros((1,), device=self.device)
            hidden = None

            for t in range(self.num_steps):
                action, action_log_prob, hidden = self.select_action(
                    state, hidden)
                next_state, reward, done, info = envs.step(action)
                # TensorFlow format to PyTorch
                next_state = np.transpose(next_state, (0, 3, 1, 2))

                # transitions
                states[t, ...] = state
                actions[t, ...] = action
                action_log_probs[t, ...] = action_log_prob
                rewards[t, ...] = reward
                next_states[t, ...] = next_state
                dones[t, ...] = done

                if self.render:
                    envs.render(0)
                state = next_state

                # done
                for i, dne in enumerate(done):
                    if dne:
                        epinfo = info[i]['episode']
                        if 'visited_rooms' in epinfo:
                            visited_rooms |= epinfo['visited_rooms']

                        best_reward[0] = max(epinfo['r'], best_reward[0])
                        current_best_reward[0] = max(
                            epinfo['r'], current_best_reward[0])
                        eplen[0] += epinfo['l']

            # logger
            dist.all_reduce(best_reward, op=dist.ReduceOp.MAX)
            dist.all_reduce(current_best_reward, op=dist.ReduceOp.MAX)
            # TODO: sync visited_rooms

            if self.rank == 0:
                logger.info('GAME STATUS')
                logger.record_tabular('rollout_idx', rollout_idx)
                logger.record_tabular('visited_rooms',
                                      str(len(visited_rooms)) + ', ' + str(visited_rooms))
                logger.record_tabular('best_reward', best_reward.item())
                logger.record_tabular(
                    'current_best_reward', current_best_reward.item())
                logger.record_tabular(
                    'eplen', eplen.item() * dist.get_world_size())
                logger.dump_tabular()

            # train neural networks
            self.update_parameters(states, actions, action_log_probs,
                                   rewards, next_states, dones)
            rollout_idx += 1
Example 29
def reduce_tensor(tensor, world_size, reduce_op_max=False):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.MAX if reduce_op_max else dist.ReduceOp.SUM)  # Default to sum
    if not reduce_op_max:
        rt /= world_size
    return rt
Example 30
    def helper(array):
        array = torch.FloatTensor(array)
        dist.all_reduce(array, op=dist.ReduceOp.SUM)
        return array[0] / array[1]
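Typical use of this helper is a globally weighted average, e.g. a loss summed locally alongside its sample count (a sketch; `local_loss_sum` and `local_count` are illustrative names, not from the source):

    # Each rank passes [sum_of_losses, num_samples]; after the SUM
    # reduction the ratio is the global mean loss.
    global_mean_loss = helper([local_loss_sum, local_count])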
Example 31
                for param_group in optimizer.param_groups:
                    param_group['lr'] = cfg.lr * 0.1**cfg.lr_steps.index(step)

            if cfg.cuda:
                images = images.cuda().detach()
                targets = [ann.cuda().detach() for ann in targets]
                masks = [mask.cuda().detach() for mask in masks]

            with timer.counter('for+loss'):
                loss_c, loss_b, loss_m, loss_s = net(images, targets, masks)

                if cfg.cuda:
                    # use .all_reduce() to get the summed loss from all GPUs
                    all_loss = torch.stack([loss_c, loss_b, loss_m, loss_s],
                                           dim=0)
                    dist.all_reduce(all_loss)

            with timer.counter('backward'):
                loss_total = loss_c + loss_b + loss_m + loss_s
                optimizer.zero_grad()
                loss_total.backward()

            with timer.counter('update'):
                optimizer.step()

            time_this = time.time()
            if step > start_step:
                batch_time = time_this - time_last
                timer.add_batch_time(batch_time)
            time_last = time_this
Example 32
            engine.update_iteration(epoch, idx)

            minibatch = dataloader.next()
            imgs = minibatch['data']
            gts = minibatch['label']
            cgts = minibatch['aux_label']

            imgs = imgs.cuda(non_blocking=True)
            gts = gts.cuda(non_blocking=True)
            cgts = cgts.cuda(non_blocking=True)

            loss = model(imgs, gts, cgts)

            # reduce the whole loss over multi-gpu
            if engine.distributed:
                dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss = loss / engine.world_size
            else:
                loss = Reduce.apply(*loss) / len(loss)

            current_idx = epoch * config.niters_per_epoch + idx
            lr = lr_policy.get_lr(current_idx)

            optimizer.param_groups[0]['lr'] = lr
            optimizer.param_groups[1]['lr'] = lr
            for i in range(2, len(optimizer.param_groups)):
                optimizer.param_groups[i]['lr'] = lr * 10

            loss.backward()
            optimizer.step()
            print_str = 'Epoch{}/{}'.format(epoch, config.nepochs) \
Example 33
def _average_gradients(model):
    # Gradient averaging.
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM, group=dist.group.WORLD)
        param.grad.data /= size
Example 34
def test_synchronize_sgd():
    torch.manual_seed(42)
    dist.init_process_group('mpi')
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    device = torch.device('cpu')
    # device = torch.device('cuda') # Uncomment this to run on GPU

    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold input and outputs
    x = torch.randn(N, D_in, device=device)
    y = torch.randn(N, D_out, device=device)

    x = x[rank::world_size]
    y = y[rank::world_size]

    # Create random Tensors for weights; setting requires_grad=True means that we
    # want to compute gradients for these Tensors during the backward pass.
    w1 = torch.randn(D_in, H, device=device, requires_grad=True)
    w2 = torch.randn(H, D_out, device=device, requires_grad=True)

    learning_rate = 1e-6
    for t in range(500):
        # Forward pass: compute predicted y using operations on Tensors. Since w1 and
        # w2 have requires_grad=True, operations involving these Tensors will cause
        # PyTorch to build a computational graph, allowing automatic computation of
        # gradients. Since we are no longer implementing the backward pass by hand we
        # don't need to keep references to intermediate values.
        y_pred = x.mm(w1).clamp(min=0).mm(w2)

        # Compute and print loss. Loss is a Tensor of shape (), and loss.item()
        # is a Python number giving its value.
        loss = (y_pred - y).pow(2).sum()

        if rank == 0:
            print("Iter {} : {:10.3e}".format(t, loss.item()))

        # Use autograd to compute the backward pass. This call will compute the
        # gradient of loss with respect to all Tensors with requires_grad=True.
        # After this call w1.grad and w2.grad will be Tensors holding the gradient
        # of the loss with respect to w1 and w2 respectively.
        loss.backward()

        # Update weights using gradient descent. For this step we just want to mutate
        # the values of w1 and w2 in-place; we don't want to build up a computational
        # graph for the update steps, so we use the torch.no_grad() context manager
        # to prevent PyTorch from building a computational graph for the updates
        with torch.no_grad():
            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad

            # Manually zero the gradients after running the backward pass
            w1.grad.zero_()
            w2.grad.zero_()

            # Synchronize weights
            dist.all_reduce(w1, op=dist.ReduceOp.SUM)
            dist.all_reduce(w2, op=dist.ReduceOp.SUM)
            w1 /= world_size
            w2 /= world_size
Example 35
def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt
Example 36
def reduce_dict(info_dict):
    for it in info_dict:
        p = info_dict[it].clone()
        dist.all_reduce(p, op=dist.ReduceOp.SUM)
        info_dict[it] = p / dist.get_world_size()
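A usage sketch for this in-place dictionary reduction (the keys and tensors are illustrative):

# Average logged metrics across workers, then report on rank 0 only.
stats = {'loss': loss.detach(), 'acc': acc.detach()}
reduce_dict(stats)
if dist.get_rank() == 0:
    print({k: v.item() for k, v in stats.items()})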
Example 37
def online_eval(model, dataloader_eval, gpu, ngpus):
    if gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device(f'cuda:{gpu}')

    eval_measures = torch.zeros(10).to(device)
    for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
        with torch.no_grad():
            image = torch.autograd.Variable(
                eval_sample_batched['image']).to(device)
            focal = torch.autograd.Variable(
                eval_sample_batched['focal']).to(device)
            gt_depth = eval_sample_batched['depth']
            has_valid_depth = eval_sample_batched['has_valid_depth']
            if not has_valid_depth:
                # print('Invalid depth. continue.')
                continue

            _, _, _, _, pred_depth = model(image, focal)

            pred_depth = pred_depth.cpu().numpy().squeeze()
            gt_depth = gt_depth.cpu().numpy().squeeze()

        if args.do_kb_crop:
            height, width = gt_depth.shape
            top_margin = int(height - 352)
            left_margin = int((width - 1216) / 2)
            pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
            pred_depth_uncropped[top_margin:top_margin +
                                 352, left_margin:left_margin +
                                 1216] = pred_depth
            pred_depth = pred_depth_uncropped

        pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
        pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
        pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
        pred_depth[np.isnan(pred_depth)] = args.min_depth_eval

        valid_mask = np.logical_and(gt_depth > args.min_depth_eval,
                                    gt_depth < args.max_depth_eval)

        if args.garg_crop or args.eigen_crop:
            gt_height, gt_width = gt_depth.shape
            eval_mask = np.zeros(valid_mask.shape)

            if args.garg_crop:
                eval_mask[int(0.40810811 * gt_height):int(0.99189189 *
                                                          gt_height),
                          int(0.03594771 * gt_width):int(0.96405229 *
                                                         gt_width)] = 1

            elif args.eigen_crop:
                if args.dataset == 'kitti':
                    eval_mask[int(0.3324324 * gt_height):int(0.91351351 *
                                                             gt_height),
                              int(0.0359477 * gt_width):int(0.96405229 *
                                                            gt_width)] = 1
                else:
                    eval_mask[45:471, 41:601] = 1

            valid_mask = np.logical_and(valid_mask, eval_mask)

        measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])

        eval_measures[:9] += torch.tensor(measures).to(device)
        eval_measures[9] += 1

    if args.multiprocessing_distributed:
        group = dist.new_group(list(range(ngpus)))
        dist.all_reduce(tensor=eval_measures, op=dist.ReduceOp.SUM, group=group)

    if not args.multiprocessing_distributed or gpu == 0:
        eval_measures_cpu = eval_measures.cpu()
        cnt = eval_measures_cpu[9].item()
        eval_measures_cpu /= cnt
        print('Computing errors for {} eval samples'.format(int(cnt)))
        print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}"
              .format('silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms',
                      'd1', 'd2', 'd3'))
        for i in range(8):
            print('{:7.3f}, '.format(eval_measures_cpu[i]), end='')
        print('{:7.3f}'.format(eval_measures_cpu[8]))
        return eval_measures_cpu

    return None
Example 38
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.reduce(tensor, 0)
dist.barrier()

if rank == 0:
    print_header("all reduce")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.all_reduce(tensor)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
    print()
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.all_reduce(tensor)
dist.barrier()

if rank == 0:
    print_header("scatter")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
Example 39
def reduce_tensor(tensor, n_gpus):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n_gpus
    return rt
Example 40
def reduce_tensor(tensor, n):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt
Example 41
def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt
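Example 42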
def average_gradients(model):
    """ average gradients """
    size = float(dist.get_world_size())
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            dist.all_reduce(param.grad.data)
            param.grad.data /= size
Example 43
def determine_max_batch_size(cfg, distributed, dataset_len_per_gpu):
    def get_fake_input(cfg, orig_img_shape=(128, 128, 3), device='cuda'):
        test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
        test_pipeline = Compose(test_pipeline)
        data = dict(img=np.zeros(orig_img_shape, dtype=np.uint8))
        data = test_pipeline(data)
        data = scatter(collate([data], samples_per_gpu=1), [device])[0]
        return data

    model = build_detector(cfg.model,
                           train_cfg=cfg.train_cfg,
                           test_cfg=cfg.test_cfg).cuda()

    if 'pipeline' in cfg.data.train:
        img_shape = [
            t for t in cfg.data.train.pipeline if t['type'] == 'Resize'
        ][0]['img_scale']
    else:
        img_shape = [
            t for t in cfg.data.train.dataset.pipeline if t['type'] == 'Resize'
        ][0]['img_scale']

    channels = 3

    fake_input = get_fake_input(cfg,
                                orig_img_shape=list(img_shape) + [channels])
    img_shape = fake_input['img_metas'][0][0]['pad_shape']

    width, height = img_shape[0], img_shape[1]

    percentage = 0.9

    min_bs = 2
    max_bs = min(512, int(dataset_len_per_gpu / percentage) + 1)
    step = 1

    batch_size = min_bs
    for bs in range(min_bs, max_bs, step):
        try:
            gt_boxes = [
                torch.tensor([[0., 0., width, height]]).cuda()
                for _ in range(bs)
            ]
            gt_labels = [
                torch.tensor([0], dtype=torch.long).cuda() for _ in range(bs)
            ]
            img_metas = [fake_input['img_metas'][0][0] for _ in range(bs)]

            gt_masks = None

            if isinstance(model,
                          TwoStageDetector) and model.roi_head.with_mask:
                rles = maskUtils.frPyObjects(
                    [[0.0, 0.0, width, 0.0, width, height, 0.0, height]],
                    height, width)
                rle = maskUtils.merge(rles)
                mask = maskUtils.decode(rle)
                gt_masks = [
                    BitmapMasks([mask], height, width) for _ in range(bs)
                ]

            if gt_masks is None:
                model(torch.rand(bs, channels, height, width).cuda(),
                      img_metas=img_metas,
                      gt_bboxes=gt_boxes,
                      gt_labels=gt_labels)
            else:
                model(torch.rand(bs, channels, height, width).cuda(),
                      img_metas=img_metas,
                      gt_bboxes=gt_boxes,
                      gt_labels=gt_labels,
                      gt_masks=gt_masks)

            batch_size = bs
        except RuntimeError as e:
            if str(e).startswith('CUDA out of memory'):
                break

    resulting_batch_size = int(batch_size * percentage)

    del model
    torch.cuda.empty_cache()

    if distributed:
        rank, world_size = get_dist_info()

        resulting_batch_size = torch.tensor(resulting_batch_size).cuda()
        dist.all_reduce(resulting_batch_size, torch.distributed.ReduceOp.MIN)
        print('rank', rank, 'resulting_batch_size', resulting_batch_size)

        resulting_batch_size = int(resulting_batch_size.cpu())
    else:
        print('resulting_batch_size', resulting_batch_size)

    return resulting_batch_size
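Example 44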
def average_params(model):
    """ broadcast model parameters """
    worldsize = dist.get_world_size()
    for p in model.state_dict().values():
        dist.all_reduce(p)
        p /= worldsize
Example 45
    def run(self):
        local_step_count = global_step_count = self.initial_step_count
        ep_rewards = torch.zeros(self.nb_env)

        obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        internals = listd_to_dlist([
            self.network.new_internals(self.device) for _ in
            range(self.nb_env)
        ])
        start_time = time()
        while global_step_count < self.nb_step:
            actions, internals = self.agent.act(self.network, obs, internals)
            next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
            next_obs = dtensor_to_dev(next_obs, self.device)

            self.agent.observe(
                obs,
                rewards.to(self.device).float(),
                terminals.to(self.device).float(),
                infos
            )
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v

            # Perform state updates
            local_step_count += self.nb_env
            global_step_count += self.nb_env * self.world_size
            ep_rewards += rewards.float()
            obs = next_obs

            term_rewards = []
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v
                    term_rewards.append(ep_rewards[i].item())
                    ep_rewards[i].zero_()

            if term_rewards:
                term_reward = np.mean(term_rewards)
                delta_t = time() - start_time
                self.logger.info(
                    'RANK: {} '
                    'GLOBAL STEP: {} '
                    'REWARD: {} '
                    'GLOBAL STEP/S: {} '
                    'LOCAL STEP/S: {}'.format(
                        self.global_rank,
                        global_step_count,
                        term_reward,
                        (global_step_count - self.initial_step_count) / delta_t,
                        (local_step_count - self.initial_step_count) / delta_t
                    )
                )

            # Learn
            if self.agent.is_ready():
                loss_dict, metric_dict = self.agent.compute_loss(
                    self.network, next_obs, internals
                )
                total_loss = torch.sum(
                    torch.stack(tuple(loss for loss in loss_dict.values()))
                )

                self.optimizer.zero_grad()
                total_loss.backward()
                dist.barrier()
                handles = []
                for param in self.network.parameters():
                    handles.append(
                        dist.all_reduce(param.grad, async_op=True))
                for handle in handles:
                    handle.wait()
                # for param in self.network.parameters():
                #     param.grad.mul_(1. / self.world_size)
                self.optimizer.step()

                self.agent.clear()
                for k, vs in internals.items():
                    internals[k] = [v.detach() for v in vs]
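The async pattern above (launch every handle, then wait) can be factored into a reusable helper; a sketch that also applies the world-size rescaling hinted at by the commented-out loop:

def all_reduce_grads_async(network, world_size):
    # Launch one non-blocking all_reduce per gradient so the collectives
    # can overlap, then wait on every handle before rescaling.
    handles = [
        dist.all_reduce(p.grad, async_op=True)
        for p in network.parameters() if p.grad is not None
    ]
    for handle in handles:
        handle.wait()
    for p in network.parameters():
        if p.grad is not None:
            p.grad.div_(world_size)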