示例#1
0
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, model, writer_dict, device):
    """Run one training epoch.

    The objective combines the main, auxiliary and error losses
    (weights 1 / 0.4 / 1), back-propagates, steps the optimizer and a
    poly-style LR schedule via ``adjust_learning_rate``, and logs
    all-reduced average losses every ``config.PRINT_FREQ`` iterations
    on rank 0.
    """
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    ave_loss1 = AverageMeter()
    ave_aux_loss = AverageMeter()
    ave_error_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader):
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)
        losses, aux_loss, error_loss, _ = model(images, labels)
        # Total objective: main + 0.4 * auxiliary + error loss.
        loss = losses.mean() + 0.4 * aux_loss.mean() + error_loss.mean()

        # All-reduce the *scalar means* for logging. Reducing the raw loss
        # tensors (as before) breaks the .item() calls below whenever the
        # model returns per-sample/per-GPU loss vectors.
        reduced_loss = reduce_tensor(loss)
        loss1 = reduce_tensor(losses.mean())
        aux_loss = reduce_tensor(aux_loss.mean())
        error_losses = reduce_tensor(error_loss.mean())

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update running averages of the reduced losses
        ave_loss.update(reduced_loss.item())
        ave_loss1.update(loss1.item())
        ave_aux_loss.update(aux_loss.item())
        ave_error_loss.update(error_losses.item())

        lr = adjust_learning_rate(optimizer, base_lr, num_iters,
                                  i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0 and rank == 0:
            # Divide by world_size; reduce_tensor presumably sums across
            # ranks -- TODO confirm against its definition.
            print_loss = ave_loss.average() / world_size
            print_loss1 = ave_loss1.average() / world_size
            print_loss_aux = ave_aux_loss.average() / world_size
            print_error_loss = ave_error_loss.average() / world_size
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}, Loss_1: {:.6f}, Loss_aux: {:.6f}, error_loss: {:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), lr, print_loss, print_loss1, print_loss_aux, print_error_loss)
            logging.info(msg)

            writer.add_scalar('train_loss', print_loss, global_steps)
            writer_dict['train_global_steps'] = global_steps + 1
示例#2
0
def validate(config, testloader, model, writer_dict, device):
    """Evaluate the two-branch boundary model over ``testloader``.

    Accumulates a confusion matrix across the dataset, all-reduces it
    over ranks, logs loss/mIoU on rank 0, and returns
    ``(print_loss, mean_IoU, IoU_array)``. Rank 0 also dumps the error
    map of the *last* batch to disk.
    """
    rank = get_rank()
    world_size = get_world_size()
    model.eval()
    ave_loss = AverageMeter()
    confusion_matrix = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))

    with torch.no_grad():
        for _, batch in enumerate(testloader):
            image, label, boundary_gt, _, _ = batch
            size = label.size()
            image = image.to(device)
            boundary_gt = boundary_gt.to(device)
            label = label.long().to(device)

            losses, aux_loss, error_loss, losses_2, aux_loss_2, error_loss_2, preds = model(
                image, label, boundary_gt.float())
            # F.upsample is deprecated; F.interpolate is the drop-in
            # replacement with identical defaults.
            pred = F.interpolate(input=preds[0],
                                 size=(size[-2], size[-1]),
                                 mode='bilinear')

            # Both branches contribute: main + 0.4*aux + 4*error each.
            loss = (losses + 0.4 * aux_loss + 4 * error_loss + losses_2 +
                    0.4 * aux_loss_2 + 4 * error_loss_2).mean()
            reduced_loss = reduce_tensor(loss)
            ave_loss.update(reduced_loss.item())

            confusion_matrix += get_confusion_matrix(
                label, pred, size, config.DATASET.NUM_CLASSES,
                config.TRAIN.IGNORE_LABEL)

    # All-reduce the confusion matrix across ranks before computing IoU.
    confusion_matrix = torch.from_numpy(confusion_matrix).to(device)
    reduced_confusion_matrix = reduce_tensor(confusion_matrix)

    confusion_matrix = reduced_confusion_matrix.cpu().numpy()
    pos = confusion_matrix.sum(1)   # ground-truth pixels per class
    res = confusion_matrix.sum(0)   # predicted pixels per class
    tp = np.diag(confusion_matrix)  # true positives per class
    IoU_array = (tp / np.maximum(1.0, pos + res - tp))
    mean_IoU = IoU_array.mean()
    print_loss = ave_loss.average() / world_size

    if rank == 0:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', print_loss, global_steps)
        writer.add_scalar('valid_mIoU', mean_IoU, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1
        # NOTE: `preds` holds the outputs of the *last* batch only.
        cv2.imwrite(
            str(global_steps) + '_error.png',
            (preds[2][0][0].data.cpu().numpy() * 255).astype(np.uint8))
    return print_loss, mean_IoU, IoU_array
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, lr_scheduler, model, writer_dict, device):
    """Train for a single epoch, passing the scheduler step count into the
    model forward and logging averaged loss on rank 0."""
    model.train()

    iter_timer = AverageMeter()
    loss_meter = AverageMeter()
    last_tick = time.time()
    iters_done = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader):
        images, labels, _, _ = batch
        images, labels = images.to(device), labels.long().to(device)

        losses, _ = model(images, labels,
                          train_step=(lr_scheduler._step_count - 1))
        loss = losses.mean()
        reduced_loss = reduce_tensor(loss)

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # The 'step' LR policy is not advanced here.
        if config.TRAIN.LR_SCHEDULER != 'step':
            lr_scheduler.step()

        # Per-iteration wall time.
        now = time.time()
        iter_timer.update(now - last_tick)
        last_tick = now

        loss_meter.update(reduced_loss.item())

        lr = adjust_learning_rate(optimizer,
                                  base_lr,
                                  num_iters,
                                  i_iter + iters_done)

        if rank == 0 and i_iter % config.PRINT_FREQ == 0:
            print_loss = loss_meter.average() / world_size
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      iter_timer.average(), lr, print_loss)
            logging.info(msg)

            writer.add_scalar('train_loss', print_loss, global_steps)
            writer_dict['train_global_steps'] = global_steps + 1
            # Restart the timing window after each report.
            iter_timer = AverageMeter()
示例#4
0
def validate(config, testloader, model, writer_dict, device):
    """Validate; computes mIoU from summed per-class intersection/union
    and returns (print_loss, mean_IoU)."""
    rank = get_rank()
    world_size = get_world_size()
    model.eval()

    ave_loss = AverageMeter()
    num_classes = config.DATASET.NUM_CLASSES
    inter_sum = np.zeros(num_classes)
    union_sum = np.zeros(num_classes)

    with torch.no_grad():
        for i_iter, batch in enumerate(testloader):
            image, label, _, _ = batch
            size = label.size()
            label = label.long().to(device)
            image = image.to(device)

            loss, pred = model(image, label)
            # Bring the prediction up to label resolution if needed.
            if pred.shape[-2:] != size[-2:]:
                pred = F.interpolate(pred,
                                     size=(size[-2], size[-1]),
                                     mode='bilinear',
                                     align_corners=False)
            ave_loss.update(reduce_tensor(loss).item())

            inter, union = batch_intersection_union(pred, label, num_classes)
            inter_sum += inter
            union_sum += union

            if i_iter % config.PRINT_FREQ == 0 and rank == 0:
                msg = f'Iter: {i_iter}, Loss: {ave_loss.average() / world_size:.6f}'
                logging.info(msg)

    # All-reduce the accumulated statistics across ranks.
    inter_sum = reduce_tensor(torch.from_numpy(inter_sum).to(device)).cpu().numpy()
    union_sum = reduce_tensor(torch.from_numpy(union_sum).to(device)).cpu().numpy()
    IoU = np.float64(1.0) * inter_sum / (np.spacing(1, dtype=np.float64) +
                                         union_sum)
    mean_IoU = IoU.mean()
    print_loss = ave_loss.average() / world_size

    if rank == 0:
        writer = writer_dict['writer']
        step = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', print_loss, step)
        writer.add_scalar('valid_mIoU', mean_IoU, step)
        writer_dict['valid_global_steps'] = step + 1

    return print_loss, mean_IoU
示例#5
0
def validate(config, testloader, model, writer_dict, device):
    """Standard segmentation validation.

    Accumulates a confusion matrix, all-reduces it across ranks, logs
    loss/mIoU on rank 0, and returns ``(print_loss, mean_IoU, IoU_array)``.
    """
    rank = get_rank()
    world_size = get_world_size()
    model.eval()
    ave_loss = AverageMeter()
    confusion_matrix = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))

    with torch.no_grad():
        for _, batch in enumerate(testloader):
            image, label, _, _ = batch
            size = label.size()
            image = image.to(device)
            label = label.long().to(device)

            losses, pred = model(image, label)
            # F.upsample is deprecated; F.interpolate is the drop-in
            # replacement with identical defaults.
            pred = F.interpolate(input=pred, size=(
                        size[-2], size[-1]), mode='bilinear')
            loss = losses.mean()
            reduced_loss = reduce_tensor(loss)
            ave_loss.update(reduced_loss.item())

            confusion_matrix += get_confusion_matrix(
                label,
                pred,
                size,
                config.DATASET.NUM_CLASSES,
                config.TRAIN.IGNORE_LABEL)

    # All-reduce the confusion matrix across ranks before computing IoU.
    confusion_matrix = torch.from_numpy(confusion_matrix).to(device)
    reduced_confusion_matrix = reduce_tensor(confusion_matrix)

    confusion_matrix = reduced_confusion_matrix.cpu().numpy()
    pos = confusion_matrix.sum(1)   # ground-truth pixels per class
    res = confusion_matrix.sum(0)   # predicted pixels per class
    tp = np.diag(confusion_matrix)  # true positives per class
    IoU_array = (tp / np.maximum(1.0, pos + res - tp))
    mean_IoU = IoU_array.mean()
    print_loss = ave_loss.average() / world_size

    if rank == 0:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', print_loss, global_steps)
        writer.add_scalar('valid_mIoU', mean_IoU, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1
    return print_loss, mean_IoU, IoU_array
示例#6
0
def train(config, epoch, num_epoch, epoch_iters, trainloader, optimizer,
          lr_scheduler, model, writer_dict, device):
    """Train for one epoch; the learning rate is read back from the
    optimizer after each lr_scheduler step."""
    model.train()

    iter_timer = AverageMeter()
    loss_meter = AverageMeter()
    last_tick = time.time()
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader, 0):
        images, labels, _, _ = batch
        labels = labels.long().to(device)
        images = images.to(device)

        loss, _ = model(images, labels)
        reduced_loss = reduce_tensor(loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # Per-iteration wall time.
        now = time.time()
        iter_timer.update(now - last_tick)
        last_tick = now

        loss_meter.update(reduced_loss.item())

        lr = optimizer.param_groups[0]['lr']

        if rank == 0 and i_iter % config.PRINT_FREQ == 0:
            print_loss = loss_meter.average() / world_size
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      iter_timer.average(), lr, print_loss)
            logging.info(msg)

    # One tensorboard point per epoch.
    if rank == 0:
        writer = writer_dict['writer']
        step = writer_dict['train_global_steps']
        writer.add_scalar('train_loss',
                          loss_meter.average() / world_size, step)
        writer_dict['train_global_steps'] = step + 1
def lovasz_softmax_flat(probas, labels, classes='present', ignore=255):
    """
    Multi-class Lovasz-Softmax loss with a *global* (cross-GPU) error sort.
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      ignore: void label value; NOTE the body hard-codes 255 and does not
        read this argument.

    Each rank broadcasts its labels/probabilities so the Lovasz gradient is
    computed over the pixels of *all* GPUs; each rank then applies only the
    slice of that gradient belonging to its own pixels.
    """
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes

    # --- gather every rank's labels/probas ---------------------------------
    num_gpus = dist.get_world_size()
    rank = get_rank()
    labels_collect = []
    probas_collect = []
    # Placeholder buffers, one per rank; assumes all ranks hold tensors of
    # identical size -- TODO confirm.
    for r in range(num_gpus):
        labels_collect.append(to_cuda(torch.ones(labels.size()).long()))
        probas_collect.append(to_cuda(torch.ones(probas.size())))

    labels_collect[rank] = labels.clone()
    probas_collect[rank] = probas.clone()

    # Broadcast each rank's slot so all ranks end up with the full set.
    for r in range(num_gpus):
        dist.broadcast(labels_collect[r], src=r)
        dist.broadcast(probas_collect[r], src=r)

    # Cumulative count of valid (non-void) pixels per rank; used below to
    # slice this rank's portion out of the concatenated tensors.
    num_valids = []
    for r in range(num_gpus):
        num_valids.append(torch.sum(labels_collect[r] != 255).item())
    num_valids = np.cumsum(num_valids)

    labels_collect = torch.cat(labels_collect, dim=0).detach()
    probas_collect = torch.cat(probas_collect, dim=0).detach()

    # Drop void pixels (label 255) from the gathered tensors.
    valid_labels = (labels_collect != 255)
    assert(torch.sum(valid_labels).item() == num_valids[-1])
    labels_collect = labels_collect[valid_labels]
    probas_collect = probas_collect[valid_labels.nonzero().squeeze()]

    # --- per-class Lovasz gradients over the global pixel set --------------
    lg_collect_cls = {}
    # [start, end) delimits this rank's pixels inside the concatenation.
    start = 0 if rank == 0 else num_valids[rank-1]
    end = num_valids[rank]

    for c in class_to_sum:
        fg_collect = (labels_collect == c).float()
        if (classes == 'present' and fg_collect.sum() == 0):
            continue

        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred_collect = probas_collect[:, 0]
        else:
            class_pred_collect = probas_collect[:, c]

        # Absolute error between the foreground indicator and the prediction.
        errors_collect = (fg_collect - class_pred_collect).abs()

        # Sort errors globally (descending) and evaluate the Lovasz gradient
        # on the correspondingly sorted foreground indicator.
        _, perm = torch.sort(errors_collect, 0, descending=True)
        perm = perm.data
        fg_collect_sorted = fg_collect[perm]
        lg_collect = lovasz_grad(fg_collect_sorted)
        assert(num_valids[-1] == lg_collect.size(0))

        # Scatter the gradient back to the original (unsorted) pixel order.
        lg_collect = to_cuda(torch.zeros(lg_collect.size())).scatter_(0, perm,
                              lg_collect).detach()

        #errors_collect = to_cuda(torch.zeros(errors_collect.size())).scatter_(0, perm,
        #                      errors_collect).detach()
        #
        #lg = lg_collect[start:end].data
        #errors = errors_collect[start:end]
        #losses.append(torch.dot(errors, lg))

        lg_collect_cls[c] = lg_collect

    # --- local loss: this rank's errors vs its slice of the gradient -------
    #print(num_valids)
    valid = (labels != 255)
    labels = labels[valid]
    probas = probas[valid.nonzero().squeeze()]

    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        #return probas * 0.
        return 0.0

    for c in class_to_sum:
        fg = (labels == c).float() # foreground for class c

        if (classes == 'present' and fg.sum() == 0):
            continue

        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]

        # Local errors stay differentiable; the gradient slice is detached.
        errors = (fg - class_pred).abs()

        lg = lg_collect_cls[c][start:end].data
        losses.append(torch.dot(errors, lg))

    # Scaled by num_gpus, presumably to cancel a later cross-rank average --
    # NOTE(review): verify against the surrounding training code.
    return mean(losses) * num_gpus
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Load SQuAD examples/features, using an on-disk cache when available.

    Args:
        args: argparse namespace; several SQuAD preprocessing fields are
            forced to fixed values below (max_seq_length=384, doc_stride=128,
            max_query_length=64, v1.1-only, single thread).
        tokenizer: tokenizer forwarded to squad_convert_examples_to_features.
        evaluate: build the dev set instead of the train set.
        output_examples: additionally return the raw examples and features.

    Returns:
        dataset, or (dataset, examples, features) when output_examples.
    """
    # Hard-coded SQuAD v1.1 preprocessing configuration.
    args.cache_dir = args.data_dir
    args.version_2_with_negative = False
    args.overwrite_cache = False
    args.max_seq_length = 384
    args.doc_stride = 128
    args.max_query_length = 64
    args.threads = 1
    args.train_file = None
    args.predict_file = None
    args.joint_prediction = False

    if args.local_rank not in [-1, 0] and not evaluate:
        # Only the first process in distributed training builds the dataset;
        # the other ranks wait here and then read it from the cache.
        torch.distributed.barrier()

    # Load data features from cache or dataset file.
    input_dir = args.data_dir or args.cache_dir or "."
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}".format(
            "2.0" if args.version_2_with_negative else "1.1",
            "dev" if evaluate else "train_aug" if args.aug_train else "train",
            str(args.max_seq_length),
        ),
    )

    # Init features and dataset from cache if it exists.
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # cache files from trusted locations.
        features_and_dataset = torch.load(cached_features_file)
        if not output_examples:
            dataset = features_and_dataset["dataset"]
        else:
            features, dataset, examples = (
                features_and_dataset["features"],
                features_and_dataset["dataset"],
                features_and_dataset["examples"],
            )
        del features_and_dataset
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)):
            try:
                import tensorflow_datasets as tfds
            except ImportError:
                raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")

            if args.version_2_with_negative:
                # logger.warn() is deprecated; warning() is the supported API.
                logger.warning("tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            examples = SquadV1Processor(aug_data=args.aug_train).get_examples_from_dataset(tfds_examples, evaluate=evaluate)
        else:
            print("process {} calling".format(get_rank()))
            processor = SquadV2Processor(aug_data=args.aug_train) if args.version_2_with_negative else SquadV1Processor(aug_data=args.aug_train)
            if evaluate:
                examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
            else:
                examples = processor.get_train_examples(args.data_dir, filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            if not output_examples:
                torch.save({"dataset": dataset}, cached_features_file)
            else:
                torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Rank 0 is done building; release the other ranks from the barrier.
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
    return dataset
示例#9
0
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, model, writer_dict, device):
    """One training epoch for the joint pose (15 heatmaps) + depth model.

    Besides optimizing, every iteration dumps ground-truth/predicted joint
    overlays and depth maps under ``results/full_RGBD/train/``, and logs
    averaged losses/accuracy every ``config.PRINT_FREQ`` iterations on
    rank 0.
    """
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    ave_loss_joints = AverageMeter()
    ave_loss_inp = AverageMeter()
    ave_acc = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader):
        # Unused batch entries (meta, name, joints) are discarded.
        images, labels, target_weight, _, _, _, joints_vis = batch
        images = images.to(device)
        labels = labels.to(device)

        losses, losses_joints, losses_inp, pred = model(
            images, labels, target_weight)  # forward

        # Channels 0..14 hold joint heatmaps; extract argmax locations.
        label_joints, _ = get_max_preds(
            labels[:, 0:15, :, :].detach().cpu().numpy())
        pred_joints, _ = get_max_preds(
            pred[:, 0:15, :, :].detach().cpu().numpy())

        _, acc, _, _ = accuracy(pred[:, 0:15, :, :].detach().cpu().numpy(),
                                labels[:, 0:15, :, :].detach().cpu().numpy())

        # *4: heatmap-to-image coordinate scale (presumably stride 4 --
        # TODO confirm against the dataset/transform code).
        save_batch_image_with_joints(
            images[:, 0:3, :, :], label_joints * 4, joints_vis,
            'results/full_RGBD/train/joint_gt/{}_gt.png'.format(i_iter))
        save_batch_image_with_joints(
            images[:, 0:3, :, :], pred_joints * 4, joints_vis,
            'results/full_RGBD/train/joint_pred/{}_pred.png'.format(i_iter))

        # F.upsample is deprecated; F.interpolate is the drop-in replacement.
        labels = F.interpolate(input=labels, size=(256, 256), mode='bilinear')
        pred = F.interpolate(input=pred, size=(256, 256), mode='bilinear')

        # Channel 15 is saved as the depth map.
        cv2.imwrite(
            'results/full_RGBD/train/depth_gt/{}_gt.png'.format(i_iter),
            labels[0, 15, :, :].detach().cpu().numpy())
        cv2.imwrite(
            'results/full_RGBD/train/depth_pred/{}_pred.png'.format(i_iter),
            pred[0, 15, :, :].detach().cpu().numpy())

        loss = losses.mean()
        loss_joints = losses_joints.mean()
        loss_inp = losses_inp.mean()

        reduced_loss = reduce_tensor(loss)
        reduced_loss_joints = reduce_tensor(loss_joints)
        reduced_loss_inp = reduce_tensor(loss_inp)

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average losses and accuracy
        ave_loss.update(reduced_loss.item())
        ave_loss_joints.update(reduced_loss_joints.item())
        ave_loss_inp.update(reduced_loss_inp.item())
        ave_acc.update(acc)

        lr = adjust_learning_rate(optimizer, base_lr, num_iters,
                                  i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0 and rank == 0:
            print_loss = ave_loss.average() / world_size
            print_loss_joints = ave_loss_joints.average() / world_size
            print_loss_inp = ave_loss_inp.average() / world_size
            print_acc = ave_acc.average() / world_size

            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}, {:.6f}, {:.6f}, Acc: {:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), lr, print_loss, print_loss_joints, print_loss_inp, print_acc)
            logging.info(msg)

            writer.add_scalar('train_loss', print_loss, global_steps)
            writer.add_scalar('train_loss_joint', print_loss_joints,
                              global_steps)
            writer.add_scalar('train_loss_depth', print_loss_inp, global_steps)
            writer.add_scalar('train_accuracy', print_acc, global_steps)
            writer_dict['train_global_steps'] = global_steps + 1
示例#10
0
def validate(config, testloader, model, writer_dict, device):
    """Validation pass for the joint pose (15 heatmaps) + depth model.

    Saves joint overlays and depth maps for every batch under
    ``results/full_RGBD/val/``, averages the reduced losses/accuracies,
    logs them on rank 0 (including per-joint accuracy), and returns
    ``(print_loss, print_loss_joints, print_loss_inp, print_acc)``.
    """
    rank = get_rank()
    world_size = get_world_size()
    model.eval()
    ave_loss = AverageMeter()
    ave_loss_joints = AverageMeter()
    ave_loss_inp = AverageMeter()
    ave_accs = AverageMeter()
    ave_acc = AverageMeter()

    with torch.no_grad():
        for i_iter, batch in enumerate(testloader):
            # Unused batch entries (meta, name, joints) are discarded.
            image, label, target_weight, _, _, _, joints_vis = batch
            image = image.to(device)
            label = label.to(device)

            losses, losses_joints, losses_inp, pred = model(
                image, label, target_weight)

            # Channels 0..14 hold joint heatmaps; extract argmax locations.
            label_joints, _ = get_max_preds(
                label[:, 0:15, :, :].detach().cpu().numpy())
            pred_joints, _ = get_max_preds(
                pred[:, 0:15, :, :].detach().cpu().numpy())

            accs, acc, _, _ = accuracy(
                pred[:, 0:15, :, :].detach().cpu().numpy(),
                label[:, 0:15, :, :].detach().cpu().numpy())

            # *4: heatmap-to-image coordinate scale (presumably stride 4 --
            # TODO confirm against the dataset/transform code).
            save_batch_image_with_joints(
                image[:, 0:3, :, :], label_joints * 4, joints_vis,
                'results/full_RGBD/val/joint_gt/{}_gt.png'.format(i_iter))
            save_batch_image_with_joints(
                image[:, 0:3, :, :], pred_joints * 4, joints_vis,
                'results/full_RGBD/val/joint_pred/{}_pred.png'.format(i_iter))

            # F.upsample is deprecated; F.interpolate is the replacement.
            label = F.interpolate(input=label, size=(256, 256), mode='bilinear')
            pred = F.interpolate(input=pred, size=(256, 256), mode='bilinear')

            # Channel 15 is saved as the depth map.
            cv2.imwrite(
                'results/full_RGBD/val/depth_gt/{}_gt.png'.format(i_iter),
                label[0, 15, :, :].detach().cpu().numpy())
            cv2.imwrite(
                'results/full_RGBD/val/depth_pred/{}_pred.png'.format(i_iter),
                pred[0, 15, :, :].detach().cpu().numpy())

            loss = losses.mean()
            loss_joints = losses_joints.mean()
            loss_inp = losses_inp.mean()

            reduced_loss = reduce_tensor(loss)
            reduced_loss_joints = reduce_tensor(loss_joints)
            reduced_loss_inp = reduce_tensor(loss_inp)

            ave_loss.update(reduced_loss.item())
            ave_loss_joints.update(reduced_loss_joints.item())
            ave_loss_inp.update(reduced_loss_inp.item())
            ave_acc.update(acc)
            ave_accs.update(accs)

    print_loss = ave_loss.average() / world_size
    print_loss_joints = ave_loss_joints.average() / world_size
    print_loss_inp = ave_loss_inp.average() / world_size
    print_acc = ave_acc.average() / world_size
    print_accs = ave_accs.average() / world_size

    if rank == 0:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', print_loss, global_steps)
        writer.add_scalar('valid_loss_joint', print_loss_joints, global_steps)
        writer.add_scalar('valid_loss_depth', print_loss_inp, global_steps)
        writer.add_scalar('valid_accuracy', print_acc, global_steps)
        # Per-joint accuracy for each of the 15 joints.
        for i in range(15):
            writer.add_scalar('valid_each_accuracy_' + str(i), print_accs[i],
                              global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1
    return print_loss, print_loss_joints, print_loss_inp, print_acc