Example #1
def train(args, train_loader, pose_exp_net, optimizer, epoch_size, logger, tb_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    pose_exp_net.train()
    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, tgt_lf, ref_imgs, ref_lfs, intrinsics, intrinsics_inv, pose_gt) in enumerate(train_loader):

        data_time.update(time.time() - end)
        tgt_lf = tgt_lf.to(device)
        ref_lfs = [lf.to(device) for lf in ref_lfs]
        pose_gt = pose_gt.to(device)

        explainability_mask, pose = pose_exp_net(tgt_lf, ref_lfs)
        loss = (pose - pose_gt).abs().mean()
        losses.update(loss.item(), args.batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        logger.train_bar.update(i+1)
        tb_writer.add_scalar('loss/train', loss.item(), n_iter)

        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
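
All of these examples lean on an AverageMeter helper for running statistics. Below is a minimal sketch consistent with the calls above; the i/precision constructor arguments, the list-valued update, and the indexable avg are inferred from usage, not taken from any one repository:

class AverageMeter(object):
    """Track current value, sum, count, and running average of one
    or more scalar series (a sketch inferred from the call sites)."""

    def __init__(self, i=1, precision=3):
        self.meters = i
        self.precision = precision
        self.reset()

    def reset(self):
        self.val = [0] * self.meters
        self.avg = [0] * self.meters
        self.sum = [0] * self.meters
        self.count = 0

    def update(self, val, n=1):
        if not isinstance(val, list):
            val = [val]
        assert len(val) == self.meters
        self.count += n
        for k, v in enumerate(val):
            self.val[k] = v
            self.sum[k] += v * n
            self.avg[k] = self.sum[k] / self.count

    def __repr__(self):
        val = ' '.join('{:.{}f}'.format(v, self.precision) for v in self.val)
        avg = ' '.join('{:.{}f}'.format(a, self.precision) for a in self.avg)
        return '{} ({})'.format(val, avg)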
Example #2
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        output_writers=[]):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < len(
                output_writers):  # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[i].add_image('val Input {}'.format(j),
                                                tensor2array(tgt_img[0]), 0)
                    output_writers[i].add_image('val Input {}'.format(j),
                                                tensor2array(ref[0]), 1)

            log_output_tensorboard(output_writers[i], 'val', '', epoch,
                                   1. / disp, disp, warped, diff,
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
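
A minimal sketch of how a training script might invoke this validation routine, assuming TensorBoard SummaryWriter instances serve as the output writers (the names and paths here are illustrative):

from torch.utils.tensorboard import SummaryWriter

# a handful of per-sample writers; the first one also receives the histograms
output_writers = [SummaryWriter(str(args.save_path / 'valid' / str(k)))
                  for k in range(3)]

errors, error_names = validate_without_gt(args, val_loader, disp_net,
                                          pose_exp_net, epoch, logger,
                                          output_writers)
for name, value in zip(error_names, errors):
    output_writers[0].add_scalar(name, value, epoch)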
Example #3
def train(train_loader,
          alice_net,
          bob_net,
          mod_net,
          optimizer,
          epoch_size,
          logger=None,
          train_writer=None,
          mode='compete'):
    global args, n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # switch to train mode
    alice_net.train()
    bob_net.train()
    mod_net.train()

    end = time.time()

    for i, (img, target) in enumerate(train_loader):
        # measure data loading time
        #mode = 'compete' if (i%2)==0 else 'collaborate'

        data_time.update(time.time() - end)
        img_var = Variable(img.cuda())
        target_var = Variable(target.cuda())

        pred_alice = alice_net(img_var)
        pred_bob = bob_net(img_var)
        pred_mod = mod_net(img_var)

        loss_alice = F.cross_entropy(pred_alice, target_var, reduction='none')
        loss_bob = F.cross_entropy(pred_bob, target_var, reduction='none')

        if mode == 'compete':
            if args.fix_bob:
                if args.DEBUG: print("Training Alice Only")
                loss = loss_alice.mean()
            elif args.fix_alice:
                loss = loss_bob.mean()
            else:
                if args.DEBUG: print("Training Both Alice and Bob")

                pred_mod_soft = Variable(F.sigmoid(pred_mod).data,
                                         requires_grad=False)
                loss = pred_mod_soft * loss_alice + (1 -
                                                     pred_mod_soft) * loss_bob

                loss = loss.mean()

        elif mode == 'collaborate':
            loss_alice2 = Variable(loss_alice.data, requires_grad=False)
            loss_bob2 = Variable(loss_bob.data, requires_grad=False)

            loss1 = F.sigmoid(pred_mod) * loss_alice2 + (
                1 - F.sigmoid(pred_mod)) * loss_bob2

            loss2 = collaboration_loss(pred_mod, loss_alice2, loss_bob2)

            loss = loss1.mean() + loss2.mean(
            ) + args.wr * mod_regularization_loss(pred_mod)

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('loss_alice',
                                    loss_alice.mean().item(), n_iter)
            train_writer.add_scalar('loss_bob', loss_bob.mean().item(), n_iter)
            train_writer.add_scalar('mod_mean',
                                    F.sigmoid(pred_mod).mean().item(), n_iter)
            train_writer.add_scalar('mod_var',
                                    F.sigmoid(pred_mod).var().item(), n_iter)
            train_writer.add_scalar('loss_regularization',
                                    mod_regularization_loss(pred_mod).item(),
                                    n_iter)

            if mode == 'compete':
                train_writer.add_scalar('competition_loss', loss.item(),
                                        n_iter)
            elif mode == 'collaborate':
                train_writer.add_scalar('collaboration_loss', loss.item(),
                                        n_iter)

        # record loss
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_alice.mean().item(),
                loss_bob.mean().item()
            ])
        if args.log_terminal:
            logger.train_bar.update(i + 1)
            if i % args.print_freq == 0:
                logger.train_writer.write(
                    'Train: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
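
The compete branch gates the two per-sample losses with the moderator's sigmoid output, detached so no gradient flows into the moderator. A self-contained illustration of that weighting with toy tensors (not the networks above):

import torch

per_sample_alice = torch.rand(8)   # stand-ins for per-sample cross-entropy losses
per_sample_bob = torch.rand(8)
gate_logits = torch.randn(8)       # moderator output before the sigmoid

gate = torch.sigmoid(gate_logits).detach()  # block gradients into the moderator
loss = (gate * per_sample_alice + (1 - gate) * per_sample_bob).mean()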
Example #4
def validate_with_gt(args,
                     val_loader,
                     mvdnet,
                     depth_cons,
                     epoch,
                     output_writers=[]):
    batch_time = AverageMeter()
    error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'mean_angle'
    ]
    test_error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3',
        'mean_angle'
    ]
    test_error_names1 = list(test_error_names)
    errors = AverageMeter(i=len(error_names))
    test_errors = AverageMeter(i=len(test_error_names))
    test_errors1 = AverageMeter(i=len(test_error_names1))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    if args.train_cons:
        depth_cons.eval()
    else:
        mvdnet.eval()

    end = time.time()
    with torch.no_grad():
        for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics,
                intrinsics_inv, tgt_depth) in enumerate(val_loader):
            tgt_img_var = Variable(tgt_img.cuda())
            ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
            gt_nmap_var = Variable(gt_nmap.cuda())
            ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
            intrinsics_var = Variable(intrinsics.cuda())
            intrinsics_inv_var = Variable(intrinsics_inv.cuda())
            tgt_depth_var = Variable(tgt_depth.cuda())

            pose = torch.cat(ref_poses_var, 1)
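            # NaN guard: (pose != pose) is True only for NaN entries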
            if (pose != pose).any():
                continue

            outputs = mvdnet(tgt_img_var, ref_imgs_var, pose, intrinsics_var,
                             intrinsics_inv_var)
            output_depth = outputs[0]
            output_depth1 = output_depth.clone()
            nmap = outputs[1]
            nmap1 = nmap.clone()

            if args.train_cons:
                outputs = depth_cons(output_depth, nmap.permute(0, 3, 1, 2))
                nmap = outputs[:, 1:].permute(0, 2, 3, 1)
                output_depth = outputs[:, 0].unsqueeze(1)

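            # valid-depth mask; (tgt_depth == tgt_depth) filters out NaNs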
            mask = (tgt_depth <= args.nlabel * args.mindepth) & (
                tgt_depth >= args.mindepth) & (tgt_depth == tgt_depth)
            #mask = (tgt_depth <= 10) & (tgt_depth >= args.mindepth) & (tgt_depth == tgt_depth) #for DeMoN testing, to compare against DPSNet you might need to turn on this for fair comparison

            if not mask.any():
                continue

            output_depth1_ = torch.squeeze(output_depth1.data.cpu(), 1)
            output_depth_ = torch.squeeze(output_depth.data.cpu(), 1)

            errors_ = compute_errors_train(tgt_depth, output_depth_, mask)
            test_errors_ = list(
                compute_errors_test(tgt_depth[mask], output_depth_[mask]))
            test_errors1_ = list(
                compute_errors_test(tgt_depth[mask], output_depth1_[mask]))

            n_mask = (gt_nmap_var.permute(0, 2, 3, 1)[0, :, :] != 0)
            n_mask = n_mask[:, :, 0] | n_mask[:, :, 1] | n_mask[:, :, 2]
            total_angles_m = compute_angles(
                gt_nmap_var.permute(0, 2, 3, 1)[0], nmap[0])
            total_angles_m1 = compute_angles(
                gt_nmap_var.permute(0, 2, 3, 1)[0], nmap1[0])

            mask_angles = total_angles_m[n_mask]
            mask_angles1 = total_angles_m1[n_mask]
            total_angles_m[~n_mask] = 0
            total_angles_m1[~n_mask] = 0
            errors_.append(
                torch.mean(mask_angles).item()
            )  #/mask_angles.size(0)#[torch.sum(mask_angles).item(), (mask_angles.size(0)),  torch.sum(mask_angles < 7.5).item(), torch.sum(mask_angles < 15).item(), torch.sum(mask_angles < 30).item(), torch.sum(mask_angles < 45).item()]
            test_errors_.append(torch.mean(mask_angles).item())
            test_errors1_.append(torch.mean(mask_angles1).item())
            errors.update(errors_)
            test_errors.update(test_errors_)
            test_errors1.update(test_errors1_)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0 or i == len(val_loader) - 1:
                if args.train_cons:
                    print(
                        'valid: Time {} Prev Error {:.4f}({:.4f}) Curr Error {:.4f} ({:.4f}) Prev angle Error {:.4f} ({:.4f}) Curr angle Error {:.4f} ({:.4f}) Iter {}/{}'
                        .format(batch_time, test_errors1.val[0],
                                test_errors1.avg[0], test_errors.val[0],
                                test_errors.avg[0], test_errors1.val[-1],
                                test_errors1.avg[-1], test_errors.val[-1],
                                test_errors.avg[-1], i, len(val_loader)))
                else:
                    print(
                        'valid: Time {} Rel Error {:.4f} ({:.4f}) Angle Error {:.4f} ({:.4f}) Iter {}/{}'
                        .format(batch_time, test_errors.val[0],
                                test_errors.avg[0], test_errors.val[-1],
                                test_errors.avg[-1], i, len(val_loader)))
            if args.output_print:
                output_dir = Path(args.output_dir)
                if not os.path.isdir(output_dir):
                    os.mkdir(output_dir)
                plt.imsave(output_dir / '{:04d}_map{}'.format(i, '_dps.png'),
                           output_depth_.numpy()[0],
                           cmap='rainbow')
                np.save(output_dir / '{:04d}{}'.format(i, '_dps.npy'),
                        output_depth_.numpy()[0])
                if args.train_cons:
                    plt.imsave(output_dir /
                               '{:04d}_map{}'.format(i, '_prev.png'),
                               output_depth1_.numpy()[0],
                               cmap='rainbow')
                    np.save(output_dir / '{:04d}{}'.format(i, '_prev.npy'),
                            output_depth1_.numpy()[0])
                # np.save(output_dir/'{:04d}{}'.format(i,'_gt.npy'),tgt_depth.numpy()[0])
                # imsave(output_dir/'{:04d}_aimage{}'.format(i,'.png'), np.transpose(tgt_img.numpy()[0],(1,2,0)))
                # np.save(output_dir/'{:04d}_cam{}'.format(i,'.npy'),intrinsics_var.cpu().numpy()[0])
    if args.output_print:
        np.savetxt(output_dir / args.ttype + 'errors.csv',
                   test_errors.avg,
                   fmt='%1.4f',
                   delimiter=',')
        np.savetxt(output_dir / args.ttype + 'prev_errors.csv',
                   test_errors1.avg,
                   fmt='%1.4f',
                   delimiter=',')
    return errors.avg, error_names
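
compute_errors_test is not shown here; the depth metrics carrying these names are conventionally defined as below (a common formulation, not necessarily this repository's exact code):

import torch

def depth_metrics(gt, pred):
    # gt, pred: 1-D tensors of valid depths
    thresh = torch.max(gt / pred, pred / gt)
    a1 = (thresh < 1.25).float().mean()
    a2 = (thresh < 1.25 ** 2).float().mean()
    a3 = (thresh < 1.25 ** 3).float().mean()
    abs_rel = ((gt - pred).abs() / gt).mean()
    abs_diff = (gt - pred).abs().mean()
    sq_rel = ((gt - pred) ** 2 / gt).mean()
    rms = torch.sqrt(((gt - pred) ** 2).mean())
    log_rms = torch.sqrt(((torch.log(gt) - torch.log(pred)) ** 2).mean())
    return abs_rel, abs_diff, sq_rel, rms, log_rms, a1, a2, a3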
Example #5
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1 = photometric_reconstruction_loss(tgt_img, ref_imgs, intrinsics,
                                                 depth, explainability_mask,
                                                 pose, args.rotation_mode,
                                                 args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('explainability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)

            with torch.no_grad():
                for k, scaled_depth in enumerate(depth):
                    train_writer.add_image(
                        'train Dispnet Output Normalized {}'.format(k),
                        tensor2array(disparities[k][0],
                                     max_value=None,
                                     colormap='magma'), n_iter)
                    train_writer.add_image(
                        'train Depth Output Normalized {}'.format(k),
                        tensor2array(1 / disparities[k][0], max_value=None),
                        n_iter)
                    b, _, h, w = scaled_depth.size()
                    downscale = tgt_img.size(2) / h

                    tgt_img_scaled = F.interpolate(tgt_img, (h, w),
                                                   mode='area')
                    ref_imgs_scaled = [
                        F.interpolate(ref_img, (h, w), mode='area')
                        for ref_img in ref_imgs
                    ]

                    intrinsics_scaled = torch.cat(
                        (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]),
                        dim=1)

                    # log warped images along with explainability mask
                    for j, ref in enumerate(ref_imgs_scaled):
                        ref_warped = inverse_warp(
                            ref,
                            scaled_depth[:, 0],
                            pose[:, j],
                            intrinsics_scaled,
                            rotation_mode=args.rotation_mode,
                            padding_mode=args.padding_mode)[0]
                        train_writer.add_image(
                            'train Warped Outputs {} {}'.format(k, j),
                            tensor2array(ref_warped), n_iter)
                        train_writer.add_image(
                            'train Diff Outputs {} {}'.format(k, j),
                            tensor2array(
                                0.5 * (tgt_img_scaled[0] - ref_warped).abs()),
                            n_iter)
                        if explainability_mask[k] is not None:
                            train_writer.add_image(
                                'train Exp mask Outputs {} {}'.format(k, j),
                                tensor2array(explainability_mask[k][0, j],
                                             max_value=1,
                                             colormap='bone'), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
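
smooth_loss penalizes large gradients in the multi-scale disparity/depth predictions. A minimal first-order sketch under that assumption; the actual loss in repositories of this family often also uses second-order gradients and different scale weights:

def smooth_loss_sketch(pred_maps):
    # pred_maps: a (B, 1, H, W) tensor or a list of them, finest scale first
    if not isinstance(pred_maps, (list, tuple)):
        pred_maps = [pred_maps]
    loss, weight = 0., 1.
    for m in pred_maps:
        dx = (m[:, :, :, 1:] - m[:, :, :, :-1]).abs().mean()
        dy = (m[:, :, 1:, :] - m[:, :, :-1, :]).abs().mean()
        loss = loss + weight * (dx + dy)
        weight /= 2.3  # down-weight coarser scales
    return loss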
Example #6
def train(train_loader, model, optimizer, epoch, args, log):
    '''train given model and dataloader'''
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    mixing_avg = []

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        optimizer.zero_grad()

        input = input.cuda()
        target = target.long().cuda()

        unary = None
        noise = None
        adv_mask1 = 0
        adv_mask2 = 0

        # train with clean images
        if args.train == 'vanilla':
            input_var, target_var = Variable(input), Variable(target)
            output, reweighted_target = model(input_var, target_var)
            loss = bce_loss(softmax(output), reweighted_target)

        # train with mixup images
        elif args.train == 'mixup':
            # process for Puzzle Mix
            if args.graph:
                # whether to add adversarial noise or not
                if args.adv_p > 0:
                    adv_mask1 = np.random.binomial(n=1, p=args.adv_p)
                    adv_mask2 = np.random.binomial(n=1, p=args.adv_p)
                else:
                    adv_mask1 = 0
                    adv_mask2 = 0

                # random start
                if (adv_mask1 == 1 or adv_mask2 == 1):
                    noise = torch.zeros_like(input).uniform_(
                        -args.adv_eps / 255., args.adv_eps / 255.)
                    input_orig = input * args.std + args.mean
                    input_noise = input_orig + noise
                    input_noise = torch.clamp(input_noise, 0, 1)
                    noise = input_noise - input_orig
                    input_noise = (input_noise - args.mean) / args.std
                    input_var = Variable(input_noise, requires_grad=True)
                else:
                    input_var = Variable(input, requires_grad=True)
                target_var = Variable(target)

                # calculate saliency (unary)
                if args.clean_lam == 0:
                    model.eval()
                    output = model(input_var)
                    loss_batch = criterion_batch(output, target_var)
                else:
                    model.train()
                    output = model(input_var)
                    loss_batch = 2 * args.clean_lam * criterion_batch(
                        output, target_var) / args.num_classes

                loss_batch_mean = torch.mean(loss_batch, dim=0)
                loss_batch_mean.backward(retain_graph=True)

                unary = torch.sqrt(torch.mean(input_var.grad**2, dim=1))

                # calculate adversarial noise
                if (adv_mask1 == 1 or adv_mask2 == 1):
                    noise += (args.adv_eps + 2) / 255. * input_var.grad.sign()
                    noise = torch.clamp(noise, -args.adv_eps / 255.,
                                        args.adv_eps / 255.)
                    adv_mix_coef = np.random.uniform(0, 1)
                    noise = adv_mix_coef * noise

                if args.clean_lam == 0:
                    model.train()
                    optimizer.zero_grad()

            input_var, target_var = Variable(input), Variable(target)
            # perform mixup and calculate loss
            output, reweighted_target = model(input_var,
                                              target_var,
                                              mixup=True,
                                              args=args,
                                              grad=unary,
                                              noise=noise,
                                              adv_mask1=adv_mask1,
                                              adv_mask2=adv_mask2)
            loss = bce_loss(softmax(output), reweighted_target)

        # for manifold mixup
        elif args.train == 'mixup_hidden':
            input_var, target_var = Variable(input), Variable(target)
            output, reweighted_target = model(input_var,
                                              target_var,
                                              mixup_hidden=True,
                                              args=args)
            loss = bce_loss(softmax(output), reweighted_target)
        else:
            raise AssertionError('wrong train type!!')

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print_log(
        '  **Train** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'
        .format(top1=top1, top5=top5, error1=100 - top1.avg), log)
    return top1.avg, top5.avg, losses.avg
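
The softmax/bce_loss pair turns the re-weighted soft targets into a binary-cross-entropy objective. A sketch of the helpers these mixup examples assume; to_one_hot is reconstructed from its call sites:

import torch
import torch.nn as nn

softmax = nn.Softmax(dim=1)
bce_loss = nn.BCELoss()

def to_one_hot(labels, num_classes):
    # labels: LongTensor of shape (N,) -> one-hot FloatTensor (N, num_classes)
    out = torch.zeros(labels.size(0), num_classes, device=labels.device)
    return out.scatter_(1, labels.view(-1, 1), 1.0)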
Example #7
def main():
    # set up the experiment directories
    if not args.log_off:
        exp_name = experiment_name_non_mnist()
        exp_dir = os.path.join(args.root_dir, exp_name)

        if not os.path.exists(exp_dir):
            os.makedirs(exp_dir)

        copy_script_to_folder(os.path.abspath(__file__), exp_dir)

        result_png_path = os.path.join(exp_dir, 'results.png')
        log = open(os.path.join(exp_dir, 'log.txt'), 'w')
        print_log('save path : {}'.format(exp_dir), log)
    else:
        log = None

    global best_acc

    state = {k: v for k, v in args._get_kwargs()}
    print("")
    print_log(state, log)
    print("")
    print_log("Random Seed: {}".format(args.seed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # dataloader
    train_loader, valid_loader, _, test_loader, num_classes = load_data_subset(
        args.batch_size,
        2,
        args.dataset,
        args.data_dir,
        labels_per_class=args.labels_per_class,
        valid_labels_per_class=args.valid_labels_per_class,
        mixup_alpha=args.mixup_alpha)

    if args.dataset == 'tiny-imagenet-200':
        stride = 2
        args.mean = torch.tensor([0.5] * 3,
                                 dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.std = torch.tensor([0.5] * 3,
                                dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.labels_per_class = 500
    elif args.dataset == 'cifar10':
        stride = 1
        args.mean = torch.tensor([x / 255 for x in [125.3, 123.0, 113.9]],
                                 dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.std = torch.tensor([x / 255 for x in [63.0, 62.1, 66.7]],
                                dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.labels_per_class = 5000
    elif args.dataset == 'cifar100':
        stride = 1
        args.mean = torch.tensor([x / 255 for x in [129.3, 124.1, 112.4]],
                                 dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.std = torch.tensor([x / 255 for x in [68.2, 65.4, 70.4]],
                                dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.labels_per_class = 500
    else:
        raise AssertionError('Given Dataset is not supported!')

    # create model
    print_log("=> creating model '{}'".format(args.arch), log)
    net = models.__dict__[args.arch](num_classes, args.dropout, stride).cuda()
    args.num_classes = num_classes

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    recorder = RecorderMeter(args.epochs)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log(
                "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                    args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)
        if epoch == args.schedule[0]:
            args.clean_lam = 0  # disable the clean-loss term from the first milestone on

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        tr_acc, tr_acc5, tr_los = train(train_loader, net, optimizer, epoch,
                                        args, log)

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, log)
        if (epoch % 50) == 0 and args.adv_p > 0:
            _, _ = validate(test_loader,
                            net,
                            log,
                            fgsm=True,
                            eps=4,
                            mean=args.mean,
                            std=args.std)
            _, _ = validate(test_loader,
                            net,
                            log,
                            fgsm=True,
                            eps=8,
                            mean=args.mean,
                            std=args.std)

        train_loss.append(tr_los)
        train_acc.append(tr_acc)
        test_loss.append(val_los)
        test_acc.append(val_acc)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

        if args.log_off:
            continue

        # save log
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, exp_dir, 'checkpoint.pth.tar')

        dummy = recorder.update(epoch, tr_los, tr_acc, val_los, val_acc)
        if (epoch + 1) % 100 == 0:
            recorder.plot_curve(result_png_path)

        train_log = OrderedDict()
        train_log['train_loss'] = train_loss
        train_log['train_acc'] = train_acc
        train_log['test_loss'] = test_loss
        train_log['test_acc'] = test_acc

        pickle.dump(train_log, open(os.path.join(exp_dir, 'log.pkl'), 'wb'))
        plotting(exp_dir)

    acc_var = np.maximum(
        np.max(test_acc[-10:]) - np.median(test_acc[-10:]),
        np.median(test_acc[-10:]) - np.min(test_acc[-10:]))
    print_log(
        "\nfinal 10 epoch acc (median) : {:.2f} (+- {:.2f})".format(
            np.median(test_acc[-10:]), acc_var), log)

    if not args.log_off:
        log.close()
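
adjust_learning_rate is presumably the usual step schedule: multiply the base rate by one gamma per milestone already passed. A sketch consistent with the gammas/schedule arguments above, assuming args.learning_rate holds the base rate:

def adjust_learning_rate(optimizer, epoch, gammas, schedule):
    lr = args.learning_rate  # assumed global base rate
    for gamma, milestone in zip(gammas, schedule):
        if epoch >= milestone:
            lr *= gamma
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr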
Example #8
def validate_with_gt(args, val_loader, mvdnet, epoch, output_writers=[]):
    batch_time = AverageMeter()
    error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'mean_angle'
    ]
    test_error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3',
        'mean_angle'
    ]
    errors = AverageMeter(i=len(error_names))
    test_errors = AverageMeter(i=len(test_error_names))
    log_outputs = len(output_writers) > 0

    output_dir = Path(args.output_dir)
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)

    # switch to evaluate mode
    mvdnet.eval()

    end = time.time()
    with torch.no_grad():
        for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics,
                intrinsics_inv, tgt_depth) in enumerate(val_loader):
            tgt_img_var = Variable(tgt_img.cuda())
            ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
            gt_nmap_var = Variable(gt_nmap.cuda())
            ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
            intrinsics_var = Variable(intrinsics.cuda())
            intrinsics_inv_var = Variable(intrinsics_inv.cuda())
            tgt_depth_var = Variable(tgt_depth.cuda())

            pose = torch.cat(ref_poses_var, 1)

            if (pose != pose).any():
                continue

            if args.dataset == 'sceneflow':
                factor = (1.0 / args.scale) * intrinsics_var[:, 0, 0] / 1050.0
                factor = factor.view(-1, 1, 1)
            else:
                factor = torch.ones(
                    (tgt_depth_var.size(0), 1, 1)).type_as(tgt_depth_var)

            # get mask
            mask = (tgt_depth_var <= args.nlabel * args.mindepth * factor *
                    3) & (tgt_depth_var >= args.mindepth * factor) & (
                        tgt_depth_var == tgt_depth_var)

            if not mask.any():
                continue

            output_depth, nmap = mvdnet(tgt_img_var,
                                        ref_imgs_var,
                                        pose,
                                        intrinsics_var,
                                        intrinsics_inv_var,
                                        factor=factor.unsqueeze(1))
            output_disp = args.nlabel * args.mindepth / (output_depth)
            if args.dataset == 'sceneflow':
                output_disp = (args.nlabel *
                               args.mindepth) * 3 / (output_depth)
                output_depth = (args.nlabel * 3) * (args.mindepth *
                                                    factor) / output_disp

            tgt_disp_var = ((1.0 / args.scale) *
                            intrinsics_var[:, 0, 0].view(-1, 1, 1) /
                            tgt_depth_var)

            if args.dataset == 'sceneflow':
                output = torch.squeeze(output_disp.data.cpu(), 1)
                errors_ = compute_errors_train(tgt_disp_var.cpu(), output,
                                               mask)
                test_errors_ = list(
                    compute_errors_test(tgt_disp_var.cpu()[mask],
                                        output[mask]))
            else:
                output = torch.squeeze(output_depth.data.cpu(), 1)
                errors_ = compute_errors_train(tgt_depth, output, mask)
                test_errors_ = list(
                    compute_errors_test(tgt_depth[mask], output[mask]))

            n_mask = (gt_nmap_var.permute(0, 2, 3, 1)[0, :, :] != 0)
            n_mask = n_mask[:, :, 0] | n_mask[:, :, 1] | n_mask[:, :, 2]
            total_angles_m = compute_angles(
                gt_nmap_var.permute(0, 2, 3, 1)[0], nmap[0])

            mask_angles = total_angles_m[n_mask]
            total_angles_m[~n_mask] = 0
            errors_.append(
                torch.mean(mask_angles).item()
            )  #/mask_angles.size(0)#[torch.sum(mask_angles).item(), (mask_angles.size(0)),  torch.sum(mask_angles < 7.5).item(), torch.sum(mask_angles < 15).item(), torch.sum(mask_angles < 30).item(), torch.sum(mask_angles < 45).item()]
            test_errors_.append(torch.mean(mask_angles).item())
            errors.update(errors_)
            test_errors.update(test_errors_)
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if args.output_print:
                np.save(output_dir / '{:04d}{}'.format(i, '_depth.npy'),
                        output.numpy()[0])
                plt.imsave(output_dir / '{:04d}_gt{}'.format(i, '.png'),
                           tgt_depth.numpy()[0],
                           cmap='rainbow')
                imsave(output_dir / '{:04d}_aimage{}'.format(i, '.png'),
                       np.transpose(tgt_img.numpy()[0], (1, 2, 0)))
                np.save(output_dir / '{:04d}_cam{}'.format(i, '.npy'),
                        intrinsics_var.cpu().numpy()[0])
                np.save(output_dir / '{:04d}{}'.format(i, '_normal.npy'),
                        nmap.cpu().numpy()[0])

            if i % args.print_freq == 0:
                print(
                    'valid: Time {} Abs Error {:.4f} ({:.4f}) Abs angle Error {:.4f} ({:.4f}) Iter {}/{}'
                    .format(batch_time, test_errors.val[0], test_errors.avg[0],
                            test_errors.val[-1], test_errors.avg[-1], i,
                            len(val_loader)))
    if args.output_print:
        np.savetxt(output_dir / args.ttype + 'errors.csv',
                   test_errors.avg,
                   fmt='%1.4f',
                   delimiter=',')
        np.savetxt(output_dir / args.ttype + 'angle_errors.csv',
                   test_errors.avg,
                   fmt='%1.4f',
                   delimiter=',')
    return errors.avg, error_names
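
compute_angles measures the per-pixel angular error between ground-truth and predicted normal maps. A plausible implementation under that reading (an assumption; the repository's version may handle masking differently):

import math
import torch

def compute_angles_sketch(gt_nmap, pred_nmap, eps=1e-8):
    # gt_nmap, pred_nmap: (H, W, 3) normal maps
    dot = (gt_nmap * pred_nmap).sum(dim=-1)
    norms = gt_nmap.norm(dim=-1) * pred_nmap.norm(dim=-1) + eps
    cos = torch.clamp(dot / norms, -1.0, 1.0)
    return torch.acos(cos) * 180.0 / math.pi  # degrees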
Example #9
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        output_writers=[]):
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose,
                                                 args.rotation_mode,
                                                 args.padding_mode)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        if log_outputs and i % 100 == 0 and i / 100 < len(
                output_writers):  # log first output of every 100 batch
            index = int(i // 100)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(tgt_img[0]),
                                                    0)
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(ref[0]), 1)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)
            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1],
                                          depth[:1, 0],
                                          pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1],
                                          rotation_mode=args.rotation_mode,
                                          padding_mode=args.padding_mode)[0]

                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(
                        0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.data.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
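
tensor2array converts a tensor into an image array that add_image can consume, optionally applying a colormap. A simplified stand-in under those assumptions:

import numpy as np
from matplotlib import cm

def tensor2array_sketch(tensor, max_value=None, colormap='bone'):
    arr = tensor.detach().cpu().numpy()
    if arr.ndim == 3 and arr.shape[0] == 3:
        # color image, assumed roughly zero-centered: shift into [0, 1]
        return np.clip(0.5 + arr, 0, 1)
    if max_value is None:
        max_value = arr.max() or 1.0
    norm = np.clip(arr.squeeze() / max_value, 0, 1)
    return cm.get_cmap(colormap)(norm)[..., :3].transpose(2, 0, 1)  # CHW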
Example #10
def validate(val_loader, model, log):
    '''evaluate trained model'''
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Switch to evaluate mode
    model.eval()

    for i, (input, target) in enumerate(val_loader):
        if args.use_cuda:
            input = input.cuda()
            target = target.cuda()

        with torch.no_grad():
            output = model(input)
            target_reweighted = to_one_hot(target, args.num_classes)
            loss = bce_loss(softmax(output), target_reweighted)

        # Measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

    print_log(
        '**Test ** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Loss: {losses.avg:.3f} '
        .format(top1=top1, top5=top5, error1=100 - top1.avg,
                losses=losses), log)

    return top1.avg, losses.avg
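
accuracy is the standard top-k precision helper from the PyTorch ImageNet example, consistent with the topk=(1, 5) calls above:

def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)  # (N, maxk) indices
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res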
Example #11
def train(args, train_loader, mvdnet, optimizer, epoch_size, train_writer,
          epoch):
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    d_losses = AverageMeter(precision=4)
    nmap_losses = AverageMeter(precision=4)

    # switch to training mode
    mvdnet.train()

    print("Training")
    end = time.time()
    for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics, intrinsics_inv,
            tgt_depth) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        gt_nmap_var = Variable(gt_nmap.cuda())
        ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        tgt_depth_var = Variable(tgt_depth.cuda())

        # compute output
        pose = torch.cat(ref_poses_var, 1)

        if args.dataset == 'sceneflow':
            factor = (1.0 / args.scale) * intrinsics_var[:, 0, 0] / 1050.0
            factor = factor.view(-1, 1, 1)
        else:
            factor = torch.ones(
                (tgt_depth_var.size(0), 1, 1)).type_as(tgt_depth_var)

        # get mask
        mask = (tgt_depth_var <= args.nlabel * args.mindepth * factor * 3) & (
            tgt_depth_var >= args.mindepth * factor) & (tgt_depth_var
                                                        == tgt_depth_var)
        mask.detach_()
        if not mask.any():
            continue

        targetimg = inverse_warp(ref_imgs_var[0], tgt_depth_var.unsqueeze(1),
                                 pose[:, 0], intrinsics_var,
                                 intrinsics_inv_var)  #[B,CH,D,H,W,1]

        outputs = mvdnet(tgt_img_var,
                         ref_imgs_var,
                         pose,
                         intrinsics_var,
                         intrinsics_inv_var,
                         factor=factor.unsqueeze(1))

        nmap = outputs[2].permute(0, 3, 1, 2)
        depths = outputs[0:2]

        disps = [args.mindepth * args.nlabel / (depth)
                 for depth in depths]  # correct disps
        if args.dataset == 'sceneflow':
            disps = [(args.mindepth * args.nlabel) * 3 / (depth)
                     for depth in depths]  # correct disps
            depths = [(args.mindepth * factor) * (args.nlabel * 3) / disp
                      for disp in disps]
        loss = 0.
        d_loss = 0.
        nmap_loss = 0.
        if args.dataset == 'sceneflow':
            tgt_disp_var = ((1.0 / args.scale) *
                            intrinsics_var[:, 0, 0].view(-1, 1, 1) /
                            tgt_depth_var)
            for l, disp in enumerate(disps):
                output = torch.squeeze(disp, 1)
                d_loss = d_loss + F.smooth_l1_loss(output[mask],
                                                   tgt_disp_var[mask]) * pow(
                                                       0.7,
                                                       len(disps) - l - 1)
        else:
            for l, depth in enumerate(depths):
                output = torch.squeeze(depth, 1)
                d_loss = d_loss + F.smooth_l1_loss(output[mask],
                                                   tgt_depth_var[mask]) * pow(
                                                       0.7,
                                                       len(depths) - l - 1)

        n_mask = mask.unsqueeze(1).expand(-1, 3, -1, -1)
        nmap_loss = nmap_loss + F.smooth_l1_loss(nmap[n_mask],
                                                 gt_nmap_var[n_mask])

        loss = loss + args.d_weight * d_loss + args.n_weight * nmap_loss

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('total_loss', loss.item(), n_iter)
        # record loss and EPE
        losses.update(loss.item(), args.batch_size)
        d_losses.update(d_loss.item(), args.batch_size)
        nmap_losses.update(nmap_loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item()])
        if i % args.print_freq == 0:
            print(
                'Train: Time {} Data {} Loss {} NmapLoss {} DLoss {} Iter {}/{} Epoch {}/{}'
                .format(batch_time, data_time, losses, nmap_losses, d_losses,
                        i, len(train_loader), epoch, args.epochs))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
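
The depth term above applies deep supervision across scales, geometrically down-weighting coarser predictions via 0.7 ** (num_scales - l - 1). In isolation, a sketch assuming the predictions are ordered coarse to fine:

import torch.nn.functional as F

def multiscale_depth_loss(preds, target, mask, decay=0.7):
    # preds: list of (B, 1, H, W) depth maps ordered coarse -> fine
    loss = 0.
    for l, pred in enumerate(preds):
        w = decay ** (len(preds) - l - 1)  # finest (last) scale gets weight 1.0
        loss = loss + w * F.smooth_l1_loss(pred.squeeze(1)[mask], target[mask])
    return loss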
Example #12
def train(train_loader, model, optimizer, epoch, args, log, mpp=None):
    '''train given model and dataloader'''
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    mixing_avg = []

    # switch to train mode
    model.train()

    end = time.time()
    for input, target in train_loader:
        data_time.update(time.time() - end)
        optimizer.zero_grad()

        input = input.cuda()
        target = target.long().cuda()
        sc = None

        # train with clean images
        if not args.comix:
            target_reweighted = to_one_hot(target, args.num_classes)
            output = model(input)
            loss = bce_loss(softmax(output), target_reweighted)

        # train with Co-Mixup images
        else:
            input_var = Variable(input, requires_grad=True)
            target_var = Variable(target)
            A_dist = None

            # Calculate saliency (unary)
            if args.clean_lam == 0:
                model.eval()
                output = model(input_var)
                loss_batch = criterion_batch(output, target_var)
            else:
                model.train()
                output = model(input_var)
                loss_batch = 2 * args.clean_lam * criterion_batch(
                    output, target_var) / args.num_classes
            loss_batch_mean = torch.mean(loss_batch, dim=0)
            loss_batch_mean.backward(retain_graph=True)
            sc = torch.sqrt(torch.mean(input_var.grad**2, dim=1))

            # Calculate the distance between the most salient locations (compatibility);
            # various distance measures could be tried here
            with torch.no_grad():
                z = F.avg_pool2d(sc, kernel_size=8, stride=1)
                z_reshape = z.reshape(args.batch_size, -1)
                z_idx_1d = torch.argmax(z_reshape, dim=1)
                z_idx_2d = torch.zeros((args.batch_size, 2), device=z.device)
                z_idx_2d[:, 0] = z_idx_1d // z.shape[-1]
                z_idx_2d[:, 1] = z_idx_1d % z.shape[-1]
                A_dist = distance(z_idx_2d, dist_type='l1')

            if args.clean_lam == 0:
                model.train()
                optimizer.zero_grad()

            # Perform mixup and calculate loss
            target_reweighted = to_one_hot(target, args.num_classes)
            if args.parallel:
                device = input.device
                out, target_reweighted = mpp(input.cpu(),
                                             target_reweighted.cpu(),
                                             args=args,
                                             sc=sc.cpu(),
                                             A_dist=A_dist.cpu())
                out = out.to(device)
                target_reweighted = target_reweighted.to(device)

            else:
                out, target_reweighted = mixup_process(input,
                                                       target_reweighted,
                                                       args=args,
                                                       sc=sc,
                                                       A_dist=A_dist)

            out = model(out)
            loss = bce_loss(softmax(out), target_reweighted)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print_log(
        '**Train** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f}'
        .format(top1=top1, top5=top5, error1=100 - top1.avg), log)
    return top1.avg, top5.avg, losses.avg
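
distance builds a pairwise compatibility matrix from the (N, 2) salient-location coordinates. A minimal sketch of an L1 pairwise distance under that reading; the name and signature follow the call above, while the body is an assumption:

import torch

def distance_sketch(coords, dist_type='l1'):
    # coords: (N, 2) locations; returns an (N, N) pairwise distance matrix
    diff = coords.unsqueeze(0) - coords.unsqueeze(1)
    if dist_type == 'l1':
        return diff.abs().sum(dim=-1)
    return diff.pow(2).sum(dim=-1).sqrt()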
Example #13
def test(val_loader, disp_net, mask_net, pose_net, flow_net, tb_writer, global_vars_dict=None):
    # unpack shared state
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']


    data_time = AverageMeter()


    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    poses = np.zeros(((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))  # init pose buffer

    disp_list = []

    flow_list = []
    mask_list = []

    # 3. validation loop
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in tqdm(enumerate(val_loader)):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(device)
        # 3.1 forward pass
        # disparity
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        #pose
        pose = pose_net(tgt_img, ref_imgs)
        #flow----
        # build the flows w.r.t. the previous and next frames
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        #FLOW FWD [B,2,H,W]
        #flow cam :tensor[b,2,h,w]
        #flow_background
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)

        flows_cam_fwd = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics, intrinsics_inv)

        #exp_masks_target = consensus_exp_masks(flows_cam_fwd, flows_cam_bwd, flow_fwd, flow_bwd, tgt_img,
        #                                       ref_imgs[2], ref_imgs[1], wssim=args.wssim, wrig=args.wrig,
        #                                       ws=args.smooth_loss_weight)

        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()#[b,2,h,w]
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # mask
        # 4. explainability_mask (none)
        explainability_mask = mask_net(tgt_img, ref_imgs)  # valid regions?

        # list(5):item:tensor:[4,4,128,512]...[4,4,4,16] value:[0.33~0.48~0.63]
        end = time.time()


    # 3.4 check log

        # inspect the forward-pass results
    # 2. disp
        disp_to_show = tensor2array(disp[0].cpu(), max_value=None, colormap='bone')  # tensor disp_to_show: [1,h,w], 0.5~3.1~10
        tb_writer.add_image('Disp/disp0', disp_to_show,i)
        disp_list.append(disp_to_show)

        if i == 0:
            disp_arr = np.expand_dims(disp_to_show, axis=0)
        else:
            disp_to_show = np.expand_dims(disp_to_show, axis=0)
            disp_arr = np.concatenate([disp_arr, disp_to_show], 0)


    # 3. flow
        tb_writer.add_image('Flow/Flow Output', flow2rgb(flow_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/cam_Flow Output', flow2rgb(flow_cam[0], max_value=6), i)
        tb_writer.add_image('Flow/rigid_Flow Output', flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/rigidity_mask_fwd', flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        flow_list.append(flow2rgb(flow_fwd[0], max_value=6))
    # 4. mask
        tb_writer.add_image('Mask/mask0', tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'), i)
        #tb_writer.add_image('Mask Output/mask1 sample{}'.format(i),tensor2array(explainability_mask[1][0], max_value=None, colormap='magma'), epoch)
        #tb_writer.add_image('Mask Output/mask2 sample{}'.format(i),tensor2array(explainability_mask[2][0], max_value=None, colormap='magma'), epoch)
        #tb_writer.add_image('Mask Output/mask3 sample{}'.format(i),tensor2array(explainability_mask[3][0], max_value=None, colormap='magma'), epoch)
        mask_list.append(tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'))

    return disp_list,disp_arr,flow_list,mask_list
Example #14
def validate(args, val_loader, pose_exp_net, logger, tb_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    pose_exp_net.eval()
    end = time.time()
    logger.valid_bar.update(0)

    for i, (tgt_img, tgt_lf, ref_imgs, ref_lfs, intrinsics, intrinsics_inv, pose_gt) in enumerate(val_loader):

        data_time.update(time.time() - end)
        tgt_lf = tgt_lf.to(device)
        ref_lfs = [lf.to(device) for lf in ref_lfs]
        pose_gt = pose_gt.to(device)

        explainability_mask, pose = pose_exp_net(tgt_lf, ref_lfs)
        loss = (pose - pose_gt).abs().mean()
        losses.update(loss.item(), args.batch_size)

        batch_time.update(time.time() - end)
        logger.valid_bar.update(i+1)

        if i % args.print_freq == 0:
            logger.valid_writer.write('Validate: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))


        n_iter += 1

    tb_writer.add_scalar('loss/valid', losses.avg[0], n_iter)
    return losses.avg[0]
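Most of these examples share the same AverageMeter contract: update(val, n), paired .val/.avg lists, an i= argument for tracking several values, and a precision used when printing. A hedged sketch consistent with that usage (the real class may differ):

class AverageMeter(object):
    """Tracks running averages of one or more scalar values."""

    def __init__(self, i=1, precision=3):
        self.meters = i
        self.precision = precision
        self.reset()

    def reset(self):
        self.val = [0.0] * self.meters
        self.avg = [0.0] * self.meters
        self.sum = [0.0] * self.meters
        self.count = 0

    def update(self, val, n=1):
        if not isinstance(val, (list, tuple)):
            val = [val]
        self.count += n
        for k, v in enumerate(val):
            self.val[k] = v
            self.sum[k] += v * n
            self.avg[k] = self.sum[k] / self.count

    def __repr__(self):
        return ' '.join('{:.{p}f} ({:.{p}f})'.format(v, a, p=self.precision)
                        for v, a in zip(self.val, self.avg))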
Example #15
def validate_Make3D(args, val_loader, model, epoch, logger, mode='DtoD'):
    ##global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'ave_log10', 'rmse']
    errors = AverageMeter(i=len(error_names))
    min_errors = AverageMeter(i=len(error_names))
    min_errors_list = []

    abs_diff_tot, abs_rel_tot, ave_log10_tot, rmse_tot = [], [], [], []
    abs_diff_sum, abs_rel_sum, ave_log10_sum, rmse_sum = 0, 0, 0, 0

    # switch to evaluate mode
    # model.eval()  # left disabled in the original; inference below runs under torch.no_grad()
    print("mode: ", args.mode)
    end = time.time()
    logger.valid_bar.update(0)
    for i, (depth, img, depth_np) in enumerate(val_loader):
        img = img.cuda()
        depth = depth.cuda()
        depth_np = depth_np.cuda()
        # compute output
        if mode == 'RtoD' or mode == 'RtoD_test':
            input_img = img
        elif mode == 'DtoD' or mode == 'DtoD_test':
            input_img = depth
        with torch.no_grad():
            output_depth = model(input_img, istrain=False)
        err_result = compute_errors_Make3D(depth_np, depth, output_depth)
        errors.update(err_result)
        abs_diff_tot.append(err_result[0])
        abs_rel_tot.append(err_result[1])
        ave_log10_tot.append(err_result[2])
        rmse_tot.append(err_result[3])
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f})'.format(
                    batch_time, errors.val[0], errors.avg[0]))

    logger.valid_bar.update(len(val_loader))
    sorted_abs_diff = sorted(abs_diff_tot)
    #min_len = 72
    min_len = len(sorted_abs_diff)
    print("scene length: ", min_len)
    print("sorted_abs_diff length: ", len(sorted_abs_diff))
    for i in range(min_len):
        sort_idx = abs_diff_tot.index(sorted_abs_diff[i])
        abs_diff_sum += sorted_abs_diff[i]
        abs_rel_sum += abs_rel_tot[sort_idx]
        ave_log10_sum += ave_log10_tot[sort_idx]
        rmse_sum += rmse_tot[sort_idx]
    min_errors_list.append(abs_diff_sum / min_len)
    min_errors_list.append(abs_rel_sum / min_len)
    min_errors_list.append(ave_log10_sum / min_len)
    min_errors_list.append(rmse_sum / min_len)
    min_errors.update(min_errors_list)

    return errors.avg, min_errors.avg, error_names
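For reference, the four Make3D metrics accumulated above reduce to standard depth-error formulas. A hedged sketch (compute_errors_Make3D itself lives elsewhere in the repo, and the valid-pixel masking here is an assumption):

import torch

def make3d_errors(gt, pred):
    valid = gt > 0  # assumed mask for valid ground-truth pixels
    gt, pred = gt[valid], pred[valid]
    abs_diff = (gt - pred).abs().mean()
    abs_rel = ((gt - pred).abs() / gt).mean()
    ave_log10 = (torch.log10(gt) - torch.log10(pred)).abs().mean()
    rmse = torch.sqrt(((gt - pred) ** 2).mean())
    return [abs_diff.item(), abs_rel.item(), ave_log10.item(), rmse.item()]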
Example #16
def main():
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=5,
                                      transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)  # volatile=True is pre-0.4 PyTorch; torch.no_grad() is the modern equivalent
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
        flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics_var,
                                 intrinsics_inv_var)

        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)

        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        rigidity_mask = rigidity_mask.type_as(flow_fwd)
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam,
            flow_fwd, rigidity_mask_combined) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        gt_mask_np = obj_map_gt[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='bone')
            mask_viz = tensor2array(
                rigidity_mask_census_soft.data[0].prod(dim=0).cpu(),
                max_value=1,
                colormap='bone')
            rigid_flow_viz = flow_to_image(tensor2array(
                flow_cam.data[0].cpu()))
            non_rigid_flow_viz = flow_to_image(
                tensor2array(flow_fwd_non_rigid.data[0].cpu()))
            total_flow_viz = flow_to_image(
                tensor2array(total_flow.data[0].cpu()))
            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            row2_viz = np.hstack(
                (rigid_flow_viz, non_rigid_flow_viz, total_flow_viz))

            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
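The numbers printed above are end-point errors (EPE). A minimal sketch of the core quantity that compute_all_epes() aggregates (the real function additionally splits errors by the rigidity and ground-truth object masks):

import torch

def epe(flow_gt, flow_pred):
    """Mean Euclidean distance between flow vectors; flow_*: [B, 2, H, W]."""
    return torch.norm(flow_gt - flow_pred, p=2, dim=1).mean().item()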
Example #17
def train_AE_RtoD(args, model, DtoD_model, criterion_L2, criterion_L1,
                  optimizer, dataset_loader, val_loader, batch_size, n_epochs,
                  lr, logger, train_writer):
    global n_iter, best_error
    print("Training for %d epochs..." % n_epochs)
    num = 0
    model_num = 0
    data_iter = iter(dataset_loader)
    depth_fixed, rgb_fixed, _ = next(data_iter)
    depth_fixed = depth_fixed.cuda()
    rgb_fixed = rgb_fixed.cuda()

    predicted_dirs = './' + args.dataset + '_AE_RtoD_predicted_lr000%d_color_uNet_gen2_nogradf' % (
        lr * 100000)
    result_dirs = './' + args.dataset + '_AE_RtoD_feat_result_lr000%d_color_uNet_gen2_nogradf/out' % (
        lr * 100000)
    result_gt_dirs = './' + args.dataset + '_AE_RtoD_feat_result_lr000%d_color_uNet_gen2_nogradf/gt' % (
        lr * 100000)
    save_dir = './' + args.dataset + '_AE_RtoD_trained_model_lr000%d_color_uNet_gen2_nogradf' % (
        lr * 100000)
    if ((args.local_rank + 1) % 4 == 0):
        if not os.path.exists(predicted_dirs):
            os.makedirs(predicted_dirs)
        if not os.path.exists(result_dirs):
            os.makedirs(result_dirs)
        if not os.path.exists(result_gt_dirs):
            os.makedirs(result_gt_dirs)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

    H = depth_fixed.shape[2]
    W = depth_fixed.shape[3]
    num_sample_list = [16, 64, 64, 64, 64, 64, 16]
    figsize_x_list = [14, 16, 10, 10, 10, 16, 14]
    figsize_y_list = [7, 8, 5, 5, 5, 8, 7]

    if args.dataset != 'NYU':
        d_range_list = [4, 2, 2, 4, 2, 2, 4]
        ftmap_height_list = [
            H,
            int(H / 2),
            int(H / 8),
            int(H / 16),
            int(H / 8),
            int(H / 2), H
        ]
        ftmap_width_list = [
            W,
            int(W / 2),
            int(W / 8),
            int(W / 16),
            int(W / 8),
            int(W / 2), W
        ]
    else:
        d_range_list = [4, 2, 4, 2, 4, 2, 4]
        ftmap_height_list = [
            H,
            int(H / 2),
            int(H / 4),
            int(H / 8),
            int(H / 16),
            int(H / 2), H
        ]
        ftmap_width_list = [
            W,
            int(W / 2),
            int(W / 4),
            int(W / 8),
            int(W / 16),
            int(W / 2), W
        ]

    test_loss_dir = Path(args.save_path)
    test_loss_dir_rmse = str(test_loss_dir / 'test_rmse_list.txt')
    test_loss_dir = str(test_loss_dir / 'test_loss_list.txt')
    train_loss_dir = Path(args.save_path)
    train_loss_dir_rmse = str(train_loss_dir / 'train_rmse_list.txt')
    train_loss_dir = str(train_loss_dir / 'train_loss_list.txt')

    loss_list = []
    rmse_list = []
    train_loss_list = []
    train_rmse_list = []
    num_cnt = 0
    train_loss_cnt = 0
    if args.dataset == "KITTI":
        y1, y2 = int(0.40810811 * depth_fixed.size(2)), int(
            0.99189189 * depth_fixed.size(2))
        x1, x2 = int(0.03594771 * depth_fixed.size(3)), int(
            0.96405229 * depth_fixed.size(3))  ### Crop used by Garg ECCV 2016
        '''
        y1,y2 = int(0.3324324 * depth_fixed.size(2)), int(0.91351351 * depth_fixed.size(2))
        x1,x2 = int(0.0359477 * depth_fixed.size(3)), int(0.96405229 * depth_fixed.size(3))     ### Crop used by Godard CVPR 2017   
        '''
        print(" - valid y range: %d ~ %d" % (y1, y2))
        print(" - valid x range: %d ~ %d" % (x1, x2))
    for epoch in tqdm(range(n_epochs)):
        if args.dataset == "KITTI":
            crop_mask = depth_fixed != depth_fixed
            #print('crop_mask size: ',crop_mask.size())
            crop_mask[:, :, y1:y2, x1:x2] = 1
        if logger is not None:
            logger.epoch_bar.update(epoch)
            ####################################### one epoch training #############################################
            logger.reset_train_bar()
            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter(precision=4)
        ################ train mode ####################
        model.train()
        ################################################
        end = time.time()
        if logger is not None:
            logger.train_bar.update(0)
        for i, (gt_data, rgb_data, gt_data_2) in enumerate(dataset_loader):
            # data loading time
            if logger is not None:
                data_time.update(time.time() - end)
            # get the inputs
            inputs = rgb_data
            depths = gt_data
            if args.dataset != "KITTI":
                gt_data_2 = None
            # If gt_data_2 is None ==> NYU dataset!
            if gt_data_2 is not None:
                sparse_depths = gt_data_2.cuda()
                sparse_depths = Variable(sparse_depths)

            origin = depths
            inputs = inputs.cuda()
            depths = depths.cuda()

            # wrap them in Variable
            inputs, depths = Variable(inputs), Variable(depths)

            ########################################
            ### Train the AutoEncoder (Generator) ###
            ########################################
            '''AutoEncoder loss'''
            outputs = model(inputs, istrain=False)

            if args.mode != 'RtoD_single':
                with torch.no_grad():
                    ft_map1_tar, ft_map2_tar, ft_map3_tar, ft_map4_tar, _, _, _, _ = DtoD_model(
                        depths, istrain=True)
            if args.mode != 'RtoD_single':
                with torch.no_grad():
                    ft_map1, ft_map2, ft_map3, ft_map4, _, _, _, _ = DtoD_model(
                        outputs, istrain=True)
            # masking valid area
            if gt_data_2 is not None:
                valid_mask = sparse_depths > -1
                valid_mask = valid_mask[:, 0, :, :].unsqueeze(1)
                if (crop_mask.size(0) != valid_mask.size(0)):
                    crop_mask = crop_mask[0:valid_mask.size(0), :, :, :]

            diff = outputs - depths
            diff_abs = torch.abs(diff)
            diff_2 = torch.pow(outputs - depths, 2)
            c = 0.2 * torch.max(diff_abs.detach())
            mask2 = torch.gt(diff_abs.detach(), c)
            diff_abs[mask2] = (diff_2[mask2] + (c * c)) / (2 * c)
            if gt_data_2 is not None:
                diff_abs[~crop_mask] = 0.1 * diff_abs[~crop_mask]
                diff_abs[crop_mask
                         & (~valid_mask)] = 0.3 * diff_abs[crop_mask
                                                           & (~valid_mask)]
            output_loss = 3 * diff_abs.mean()

            diff2_clone = diff_2.clone().detach()
            rmse_loss = torch.sqrt(diff2_clone.mean())

            ################# BerHu Loss #########################
            latent_loss = torch.tensor(0.).cuda()
            if args.mode != 'RtoD_single':
                latent1 = criterion_L2(ft_map1, ft_map1_tar.detach())
                latent2 = 2.5 * criterion_L2(ft_map2, ft_map2_tar.detach())
                latent3 = 14 * criterion_L2(ft_map3, ft_map3_tar.detach())
                latent4 = 12 * criterion_L2(ft_map4, ft_map4_tar.detach())
                #print("latent1 : ",latent1.item(),"latent2 : ",latent2.item(),"latent3 : ",latent3.item(),"latent4 : ",latent4.item())
                latent_loss = 1.5 * (
                    (latent1 + latent2 + latent3 + latent4) / 4)
            ################# Latent Loss #########################
            #gradient_loss = imgrad_loss(outputs, depths) ## for kitti
            ##gradient_loss = 3.5* imgrad_loss(outputs, depths) ## for NYU
            ################# gradient loss #######################
            '''
            grad_latent_loss = torch.tensor(0.).cuda()
            if args.mode != 'RtoD_single':
                grad_latent1 = imgrad_loss(ft_map1, ft_map1_tar.detach())
                grad_latent2 = 1.5*imgrad_loss(ft_map2, ft_map2_tar.detach())
                grad_latent3 = 4*imgrad_loss(ft_map3, ft_map3_tar.detach()) ##for kitti
                ##grad_latent3 = 2*imgrad_loss(ft_map3, ft_map3_tar.detach()) ##for NYU
                grad_latent4 = 2.3*imgrad_loss(ft_map4, ft_map4_tar.detach())
                #print("g_latent1 : ",grad_latent1.item(),"g_latent2 : ",grad_latent2.item(),"g_latent3 : ",grad_latent3.item(),"g_latent4 : ",grad_latent4.item())
                ##grad_latent_loss = ((grad_latent1 + grad_latent2 + grad_latent3 + grad_latent4)/4.0) ## for kitti
                grad_latent_loss = ((grad_latent1 + grad_latent2 + grad_latent3 + grad_latent4)/4.0) ## for NYU
            ################# gradient latent loss ################
            grad_loss = (gradient_loss + grad_latent_loss)
            '''
            ################# gradient total loss ################
            depth_smoothness_tot = 0.1 * depth_smoothness(outputs, inputs)
            depth_smoothness_loss = torch.mean(torch.abs(depth_smoothness_tot))
            ################# smoothness loss ######################
            #loss = output_loss + latent_loss + grad_loss + depth_smoothness_loss
            loss = output_loss + latent_loss + depth_smoothness_loss
            if logger is not None:
                if i > 0 and n_iter % args.print_freq == 0:
                    train_writer.add_scalar('output_loss', output_loss.item(),
                                            n_iter)
                    train_writer.add_scalar('latent_loss', latent_loss.item(),
                                            n_iter)
                    train_writer.add_scalar('total_loss', loss.item(), n_iter)
                # record loss and EPE
                losses.update(loss.item(), args.batch_size)
            # zero the parameter gradients, then backward & optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # measure elapsed time
            if logger is not None:
                batch_time.update(time.time() - end)
                end = time.time()

                with open(args.save_path / args.log_full, 'a') as csvfile:
                    writer = csv.writer(csvfile, delimiter='\t')
                    writer.writerow(
                        [loss.item(),
                         output_loss.item(),
                         latent_loss.item()])
                logger.train_bar.update(i + 1)
                if i % args.print_freq == 0:
                    logger.train_writer.write(
                        'Train: Time {} Data {} Loss {}'.format(
                            batch_time, data_time, losses))
                n_iter += 1
            if i >= args.epoch_size - 1:
                break
            ### KITTI's learning decay ###
            if (epoch > 2):
                if ((i + 1) % 2200 == 0):
                    if (lr < 0.00002):
                        lr -= (lr / 100)
                    else:
                        lr -= (lr / 60)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Decayed learning rates, lr: {}'.format(lr))
            ### NYU's learning decay ###
            '''
            if (epoch>6):
                if ((i+1) % 1900 == 0):
                    if (lr < 0.00002):
                        lr -= (lr / 200)
                    else :
                        lr -= (lr / 40)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print ('Decayed learning rates, lr: {}'.format(lr))
            '''
            if ((i + 1) % 100 == 0):
                if ((args.local_rank + 1) % 4 == 0):
                    print("epoch: %d,  %d/%d" %
                          (epoch + 1, i + 1, args.epoch_size))
                    if args.mode == 'RtoD':
                        print(
                            "total_loss: %5f, output_loss: %5f, smoothness_loss: %5f, latent_loss: %5f"
                            %
                            (loss.item(), output_loss.item(),
                             depth_smoothness_loss.item(), latent_loss.item()))
                        #print("grad_loss: %5f, gradient_loss: %5f, grad_latent_loss: %5f"%(grad_loss.item(), gradient_loss.item(), grad_latent_loss.item()))
                    elif args.mode == 'RtoD_single':
                        print(
                            "total_loss: %5f, output_loss: %5f, smoothness_loss: %5f"
                            % (loss.item(), output_loss.item(),
                               depth_smoothness_loss.item()))
                        print("grad_loss: %5f" % (grad_loss.item()))
                '''
                total_loss = loss.item()
                rmse_loss = rmse_loss.item()
                loss_pdf = "train_loss.pdf"
                rmse_pdf = "train_rmse.pdf"
                train_loss_cnt = train_loss_cnt + 1
                all_plot(args.save_path,total_loss, rmse_loss, train_loss_list, train_rmse_list, train_loss_dir,train_loss_dir_rmse,loss_pdf, rmse_pdf, train_loss_cnt,True)
                print("")
                '''
                if ((i + 1) % 700 == 0):
                    save_image_batch(model, rgb_fixed, depth_fixed,
                                     predicted_dirs, num)
                    num = num + 1
            if ((i + 1) % 700 == 0):
                '''
                test_loss, rmse_test_loss = validate_in_test(args, val_loader, model,DtoD_model,n_epochs, logger,args.mode, crop_mask,criterion_L2)
                loss_pdf = "test_loss.pdf"
                rmse_pdf = "test_rmse.pdf"
                num_cnt = num_cnt + 1
                if((args.local_rank + 1)%4 == 0):
                    print('%d th test_set_loss :  %.4f'%(num_cnt,test_loss))
                all_plot(args.save_path,test_loss, rmse_test_loss, loss_list, rmse_list, test_loss_dir,test_loss_dir_rmse,loss_pdf, rmse_pdf, num_cnt,False)             
                '''
                if ((args.local_rank + 1) % 4 == 0):
                    output = outputs.cpu().detach().numpy()
                    save_image_tensor(output, result_dirs,
                                      'output_depth_%d.png' % (model_num + 1))
                    save_image_tensor(origin, result_gt_dirs,
                                      'origin_depth_%d.png' % (model_num + 1))

                    torch.save(
                        model.state_dict(),
                        save_dir + '/epoch_%d_AE_depth_loss_%.4f.pkl' %
                        (model_num + 1, loss))
                model_num = model_num + 1
        if logger is not None:
            ######################################################################################################
            logger.train_writer.write(' * Avg Loss : {:.3f}'.format(
                losses.avg[0]))
            ################################ evalutating on validation set ########################################
            logger.reset_valid_bar()
            errors, error_names = validate(args, val_loader, model, epoch,
                                           logger, args.mode)

            ################# training log ############################
            error_string = ', '.join(
                '{} : {:.3f}'.format(name, error)
                for name, error in zip(error_names, errors))
            logger.valid_writer.write(' * Avg {}'.format(error_string))

            for error, name in zip(errors, error_names):
                train_writer.add_scalar(name, error, epoch)

            # Choose the most relevant error to measure your model's performance;
            # careful, some measures must be maximized (such as a1, a2, a3)
            decisive_error = errors[1]
            if best_error < 0:
                best_error = decisive_error

            # remember lowest error and save checkpoint
            is_best = decisive_error < best_error
            best_error = min(best_error, decisive_error)
            if is_best:
                torch.save(model,
                           args.save_path / 'AE_RtoD_model_best.pth.tar')

            with open(args.save_path / args.log_summary, 'a') as csvfile:
                writer = csv.writer(csvfile, delimiter='\t')
                writer.writerow([loss, decisive_error])
            ###########################################################
        ##if ((epoch+1) % 2) and ((epoch+1) > n_epochs/2):
        #if (((epoch+1) % 2) and epoch>5):
        '''
        if epoch % 1 == 0:
            #print('\n','epoch: ',epoch+1,'  loss: ',loss.item())
            print('output_loss: ',output_loss.item(),'  latent_loss: ',latent_loss.item())
            
            output = outputs.cpu().detach().numpy()
            save_image_tensor(output,result_dirs,'output_depth_%d.png'%(model_num+1))
            save_image_tensor(origin,result_gt_dirs,'origin_depth_%d.png'%(model_num+1))
            
            torch.save(model.state_dict(), save_dir+'/epoch_%d_AE_depth_loss_%.4f.pkl' %(model_num+1,loss))
            model_num = model_num + 1
        '''
        #####################################################################################
        ################### Extracting feature_map ##########################################
        if ((epoch + 1) % 10 == 0 or epoch == 0):
            if ((args.local_rank + 1) % 4 == 0):
                with torch.no_grad():
                    rft1, rft2, rft3, rft4, rft5, rft6, rft7, rout = model(
                        inputs, istrain=True)
                    ft1_gt, ft2_gt, ft3_gt, ft4_gt, ft5_gt, ft6_gt, ft7_gt, _ = DtoD_model(
                        depths, istrain=True)
                    dft1, dft2, dft3, dft4, dft5, dft6, dft7, _ = DtoD_model(
                        rout, istrain=True)
                    rftmap_list = [rft1, rft2, rft3, rft4, rft5, rft6, rft7]
                    gt_ftmap_list = [
                        ft1_gt, ft2_gt, ft3_gt, ft4_gt, ft5_gt, ft6_gt, ft7_gt
                    ]
                    dftmap_list = [dft1, dft2, dft3, dft4, dft5, dft6, dft7]

                result_dir = result_dirs + '/epoch_%d_depth' % (epoch + 1)
                if not os.path.exists(result_dir):
                    os.makedirs(result_dir)

                for kk in range(len(rftmap_list)):
                    ftmap_extract(args, num_sample_list[kk],
                                  figsize_x_list[kk], figsize_y_list[kk],
                                  d_range_list[kk], rftmap_list[kk],
                                  ftmap_height_list[kk], ftmap_width_list[kk],
                                  result_dir + '/RtoD', epoch, kk + 1)
                    ftmap_extract(args, num_sample_list[kk],
                                  figsize_x_list[kk], figsize_y_list[kk],
                                  d_range_list[kk], gt_ftmap_list[kk],
                                  ftmap_height_list[kk], ftmap_width_list[kk],
                                  result_dir + '/DtoD_gt', epoch, kk + 1)
                    ftmap_extract(args, num_sample_list[kk],
                                  figsize_x_list[kk], figsize_y_list[kk],
                                  d_range_list[kk], dftmap_list[kk],
                                  ftmap_height_list[kk], ftmap_width_list[kk],
                                  result_dir + '/DtoD', epoch, kk + 1)

                print("featmap save is finished")

                inputs_ = inputs.cpu().detach().numpy()
                save_image_tensor(origin, result_dir, 'origin_depth.png')
                save_image_tensor(inputs_, result_dir, 'origin_input.png')

                print("origin_depth save is finished")
                print("origin_image save is finished")
            #####################################################################################
            #####################################################################################
    if logger is not None:
        logger.epoch_bar.finish()

    return loss, output_loss, latent_loss
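The diff_abs / c logic in the training loop above is a BerHu (reverse Huber) penalty: L1 below an adaptive threshold c, scaled L2 above it. A standalone sketch of the same idea, assuming dense targets:

import torch

def berhu_loss(pred, target):
    diff = (pred - target).abs()
    c = (0.2 * diff.max()).detach().clamp(min=1e-6)  # adaptive threshold
    l2_part = (diff ** 2 + c ** 2) / (2 * c)
    return torch.where(diff <= c, diff, l2_part).mean()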
Example #18
def adjust_shifts(args, train_set, adjust_loader, pose_exp_net, epoch, logger, train_writer):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    new_shifts = AverageMeter(args.sequence_length-1)
    pose_exp_net.train()
    poses = np.zeros(((len(adjust_loader)-1) * args.batch_size * (args.sequence_length-1),6))

    end = time.time()

    for i, (indices, tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(adjust_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]

        # compute output
        explainability_mask, pose_batch = pose_exp_net(tgt_img_var, ref_imgs_var)

        if i < len(adjust_loader)-1:
            step = args.batch_size*(args.sequence_length-1)
            poses[i * step:(i+1) * step] = pose_batch.data.cpu().view(-1,6).numpy()

        for index, pose in zip(indices, pose_batch):
            displacements = pose[:,:3].norm(p=2, dim=1).data.cpu().numpy()
            train_set.reset_shifts(index, displacements)
            new_shifts.update(train_set.samples[index]['ref_imgs'])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        logger.train_bar.update(i)
        if i % args.print_freq == 0:
            logger.train_writer.write('Adjustment: '
                                      'Time {} Data {} shifts {}'.format(batch_time, data_time, new_shifts))

    prefix = 'train poses'
    coeffs_names = ['tx', 'ty', 'tz']
    if args.rotation_mode == 'euler':
        coeffs_names.extend(['rx', 'ry', 'rz'])
    elif args.rotation_mode == 'quat':
        coeffs_names.extend(['qx', 'qy', 'qz'])
    for i in range(poses.shape[1]):
        train_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]), poses[:,i], epoch)

    return new_shifts.avg
Example #19
def validate(val_loader,
             model,
             log,
             fgsm=False,
             eps=4,
             rand_init=False,
             mean=None,
             std=None):
    '''evaluate trained model'''
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    for i, (input, target) in enumerate(val_loader):
        if args.use_cuda:
            input = input.cuda()
            target = target.cuda()

        # check FGSM for adversarial training
        if fgsm:
            input_var = Variable(input, requires_grad=True)
            target_var = Variable(target)

            optimizer_input = torch.optim.SGD([input_var], lr=0.1)
            output = model(input_var)
            loss = criterion(output, target_var)
            optimizer_input.zero_grad()
            loss.backward()

            sign_data_grad = input_var.grad.sign()
            input = input * std + mean + eps / 255. * sign_data_grad
            input = torch.clamp(input, 0, 1)
            input = (input - mean) / std

        with torch.no_grad():
            input_var = Variable(input)
            target_var = Variable(target)

            # compute output
            output = model(input_var)
            loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

    if fgsm:
        print_log(
            'Attack (eps : {}) Prec@1 {top1.avg:.2f}'.format(eps, top1=top1),
            log)
    else:
        print_log(
            '  **Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f} Loss: {losses.avg:.3f} '
            .format(top1=top1, top5=top5, error1=100 - top1.avg,
                    losses=losses), log)
    return top1.avg, losses.avg
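The in-place perturbation above is the classic FGSM attack applied in un-normalized pixel space. A compact, self-contained sketch of the same step (mean/std are assumed broadcastable to the input):

import torch
import torch.nn.functional as F

def fgsm_attack(model, input, target, eps, mean, std):
    input = input.clone().requires_grad_(True)
    loss = F.cross_entropy(model(input), target)
    loss.backward()
    # step along the gradient sign in [0, 1] pixel space, then re-normalize
    adv = input.detach() * std + mean + eps / 255. * input.grad.sign()
    adv = torch.clamp(adv, 0, 1)
    return (adv - mean) / std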
Example #20
def train(args, train_loader, disp_net, pose_net, optimizer, epoch_size,
          logger, train_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.geometry_consistency_weight

    # switch to train mode
    disp_net.train()
    pose_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        tgt_depth, ref_depths = compute_depth(disp_net, tgt_img, ref_imgs)

        poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img, ref_imgs,
                                                 intrinsics)
        #if poses is None:
        loss_1, loss_3 = compute_photo_and_geometry_loss(
            tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths, poses,
            poses_inv, args.num_scales, args.with_ssim, args.with_mask,
            args.with_auto_mask, args.padding_mode)

        loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_2.item(),
                                    n_iter)
            train_writer.add_scalar('geometry_consistency_loss', loss_3.item(),
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow(
                [loss.item(),
                 loss_1.item(),
                 loss_2.item(),
                 loss_3.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
Example #21
def train(args, train_loader, mvdnet, depth_cons, cons_loss_, optimizer,
          epoch_size, train_writer, epoch):
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    d_losses = AverageMeter(precision=4)
    nmap_losses = AverageMeter(precision=4)
    cons_losses = AverageMeter(precision=4)

    # switch to training mode
    if args.train_cons:
        depth_cons.train()
    else:
        mvdnet.train()

    print("Training")
    end = time.time()

    for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics, intrinsics_inv,
            tgt_depth) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        gt_nmap_var = Variable(gt_nmap.cuda())
        ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        tgt_depth_var = Variable(tgt_depth.cuda()).cuda()

        # compute output
        pose = torch.cat(ref_poses_var, 1)

        # get mask
        mask = (tgt_depth_var <= args.nlabel * args.mindepth) & (
            tgt_depth_var >= args.mindepth) & (tgt_depth_var == tgt_depth_var)  # last term filters NaNs
        mask.detach_()
        if not mask.any():  # skip batches with no valid depth pixels
            continue

        if args.train_cons:
            with torch.no_grad():
                outputs = mvdnet(tgt_img_var, ref_imgs_var, pose,
                                 intrinsics_var, intrinsics_inv_var)
                output_depth1 = outputs[0]
                nmap1 = outputs[1]
        else:
            outputs = mvdnet(tgt_img_var, ref_imgs_var, pose, intrinsics_var,
                             intrinsics_inv_var)
            output_depth1 = outputs[1]
            nmap1 = outputs[2]

        if args.train_cons:
            outputs = depth_cons(output_depth1, nmap1)
            nmap = outputs[:, 1:]
            depths = [outputs[:, 0]]
        else:
            nmap = nmap1.permute(0, 3, 1, 2)
            depths = [output_depth1.squeeze(1)]

        loss = 0.
        d_loss = 0.
        nmap_loss = 0.
        cons_loss = 0.

        for l, depth in enumerate(depths):
            output = torch.squeeze(depth, 1)
            d_loss = d_loss + F.smooth_l1_loss(output[mask],
                                               tgt_depth_var[mask])

        n_mask = mask.unsqueeze(1).expand(-1, 3, -1, -1)
        nmap_loss = nmap_loss + F.smooth_l1_loss(nmap[n_mask],
                                                 gt_nmap_var[n_mask])

        if args.train_cons:
            cons_loss = cons_loss + cons_loss_(
                depths[-1].unsqueeze(1), tgt_depth_var.unsqueeze(1),
                nmap.clone(), intrinsics_var, mask.unsqueeze(1))
            cons_losses.update(cons_loss.item(), args.batch_size)
        loss = loss + args.d_weight * d_loss + args.n_weight * nmap_loss + args.c_weight * cons_loss

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('total_loss', loss.item(), n_iter)
        # record loss and EPE
        losses.update(loss.item(), args.batch_size)
        d_losses.update(d_loss.item(), args.batch_size)
        nmap_losses.update(nmap_loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item()])
        if i % args.print_freq == 0:
            print(
                'Train: Time {} Data {} Loss {} NmapLoss {} DLoss {} ConsLoss {} Iter {}/{} Epoch {}/{}'
                .format(batch_time, data_time, losses, nmap_losses, d_losses,
                        cons_losses, i, len(train_loader), epoch, args.epochs))

        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
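The depth supervision above is smooth-L1 restricted to valid pixels. A minimal sketch of that masking pattern (the range bounds are assumptions drawn from the args used above):

import torch.nn.functional as F

def masked_depth_loss(pred_depth, gt_depth, min_depth, max_depth):
    mask = ((gt_depth >= min_depth) & (gt_depth <= max_depth)
            & (gt_depth == gt_depth))  # x == x is False for NaN
    return F.smooth_l1_loss(pred_depth[mask], gt_depth[mask])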
Example #22
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_net,
                        epoch,
                        logger,
                        output_writers=[]):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=4, precision=4)
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        tgt_depth = [1 / disp_net(tgt_img)]
        ref_depths = []
        for ref_img in ref_imgs:
            ref_depth = [1 / disp_net(ref_img)]
            ref_depths.append(ref_depth)

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input',
                                            tensor2array(tgt_img[0]), 0)

            output_writers[i].add_image(
                'val Dispnet Output Normalized',
                tensor2array(1 / tgt_depth[0][0],
                             max_value=None,
                             colormap='magma'), epoch)
            output_writers[i].add_image(
                'val Depth Output', tensor2array(tgt_depth[0][0],
                                                 max_value=10), epoch)

        poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img, ref_imgs,
                                                 intrinsics)

        loss_1, loss_3 = compute_photo_and_geometry_loss(
            tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths, poses,
            poses_inv, args.num_scales, args.with_ssim, args.with_mask, False,
            args.padding_mode)

        loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs)

        loss_1 = loss_1.item()
        loss_2 = loss_2.item()
        loss_3 = loss_3.item()

        loss = loss_1
        losses.update([loss, loss_1, loss_2, loss_3])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'Total loss', 'Photo loss', 'Smooth loss', 'Consistency loss'
    ]
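These validators invert disparity directly with depth = 1 / disp, which assumes a strictly positive network output. A hedged helper that guards against near-zero disparity:

def disp_to_depth(disp, eps=1e-6):
    """Convert disparity to depth, clamping to avoid division blow-ups."""
    return 1.0 / disp.clamp(min=eps)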
Example #23
def main():
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # for best-model selection

    #mkdir
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")

    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp

    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])
    #mk writers
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code and transpose

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    train_transform = custom_transforms.Compose([
        #custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    # train set/loader: only one is created
    from datasets.sequence_mc import SequenceFolder
    train_set = SequenceFolder(  # mc data folder
        args.data_dir,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,  # 5
        target_transform=None,
        depth_format='png')

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

# val sets/loaders are created one by one
#if args.val_with_depth_gt:
    from datasets.validation_folders2 import ValidationSet

    val_set_with_depth_gt = ValidationSet(args.data_dir,
                                          transform=valid_transform,
                                          depth_format='png')

    val_loader_depth = torch.utils.data.DataLoader(val_set_with_depth_gt,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))

    #1 create model
    print("=> creating model")
    #1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  #args.mask_loss_weight > 0

    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path /
                                     'dispnet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters())

    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    if args.resume and (args.save_path /
                        'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path /
                                       'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    #
    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader_depth))
        logger.reset_epoch_bar()
    else:
        logger = None


# evaluate once before training
    criterion_train = MaskedL1Loss().to(device)  # L1 loss is easy to optimize
    criterion_val = ComputeErrors().to(device)

    #depth_error_names,depth_errors = validate_depth_with_gt(val_loader_depth, disp_net,criterion=criterion_val, epoch=0, logger=logger,tb_writer=tb_writer,global_vars_dict=global_vars_dict)

    #logger.reset_epoch_bar()
    #    logger.epoch_logger_update(epoch=0,time=0,names=depth_error_names,values=depth_errors)
    epoch_time = AverageMeter()
    end = time.time()
    #3. main cycle
    for epoch in range(1, args.epochs):  # epoch 0 is evaluated before entering the loop

        logger.reset_train_bar()
        logger.reset_valid_bar()

        errors = [0]
        error_names = ['no error names depth']

        #3.2 train for one epoch---------
        loss_names, losses = train_depth_gt(train_loader=train_loader,
                                            disp_net=disp_net,
                                            optimizer=optimizer,
                                            criterion=criterion_train,
                                            logger=logger,
                                            train_writer=tb_writer,
                                            global_vars_dict=global_vars_dict)

        #3.3 evaluate on validation set-----
        depth_error_names, depth_errors = validate_depth_with_gt(
            val_loader=val_loader_depth,
            disp_net=disp_net,
            criterion=criterion_val,
            epoch=epoch,
            logger=logger,
            tb_writer=tb_writer,
            global_vars_dict=global_vars_dict)

        epoch_time.update(time.time() - end)
        end = time.time()

        #3.5 log_terminal
        #if args.log_terminal:
        if args.log_terminal:
            logger.epoch_logger_update(epoch=epoch,
                                       time=epoch_time,
                                       names=depth_error_names,
                                       values=depth_errors)

    # tensorboard scaler
    #train loss
        for loss_name, loss in zip(loss_names, losses.avg):
            tb_writer.add_scalar('train/' + loss_name, loss, epoch)

        #val_with_gt loss
        for name, error in zip(depth_error_names, depth_errors.avg):
            tb_writer.add_scalar('val/' + name, error, epoch)

        #3.6 save model and remember lowest error and save checkpoint
        total_loss = losses.avg[0]
        if best_error < 0:
            best_error = total_loss

        is_best = total_loss <= best_error
        best_error = min(best_error, total_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, is_best)

    if args.log_terminal:
        logger.epoch_bar.finish()
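
The save_checkpoint call above passes one state dict per sub-network plus an is_best flag. Below is a minimal sketch consistent with that call, assuming a path.Path-style save_path; the prefixes and filenames are hypothetical, not the repo's actual helper:

import torch
from shutil import copyfile

def save_checkpoint(save_path, disp_state, pose_state, mask_state, flow_state,
                    is_best, filename='checkpoint.pth.tar'):
    # hypothetical per-network prefixes; in the call above only disp_net
    # carries real weights, the other state dicts are placeholders
    prefixes = ['dispnet', 'posenet', 'masknet', 'flownet']
    states = [disp_state, pose_state, mask_state, flow_state]
    for prefix, state in zip(prefixes, states):
        torch.save(state, save_path / '{}_{}'.format(prefix, filename))
    if is_best:
        # keep a separate copy so later epochs cannot overwrite the best weights
        for prefix in prefixes:
            copyfile(save_path / '{}_{}'.format(prefix, filename),
                     save_path / '{}_model_best.pth.tar'.format(prefix))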
Example #24
def train(args, train_loader, bio_net, optimizer, epoch_size, logger):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # switch to train mode
    bio_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (sample, value) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        data = sample.to(device)

        # compute output
        estimated_y = bio_net(data)
        estimated_y = estimated_y.view(-1)

        value = value.float().to(device)

        # L1 regression loss. The original summed the raw signed difference,
        # detached it through .data and re-wrapped it in a fresh Variable with
        # requires_grad=True, which cuts the graph so no gradient ever reached
        # bio_net; computing the loss directly keeps the graph intact.
        loss = (estimated_y - value).abs().mean()

        # record loss
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
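
Every snippet in this collection leans on an AverageMeter helper. The sketch below is inferred from how it is used in these examples (multi-value meters via i=..., .val/.avg/.sum exposed as lists, and a precision-controlled "val (avg)" string for the progress lines); it is an assumption, not necessarily the original implementation:

class AverageMeter(object):
    """Tracks i running values; .val/.avg/.sum are lists, as indexed above."""

    def __init__(self, i=1, precision=3):
        self.meters = i
        self.precision = precision
        self.reset()

    def reset(self):
        self.val = [0] * self.meters
        self.avg = [0] * self.meters
        self.sum = [0] * self.meters
        self.count = 0

    def update(self, val, n=1):
        # accept a scalar for single-value meters or a list of i values
        if not isinstance(val, list):
            val = [val]
        assert len(val) == self.meters
        self.count += n
        for k, v in enumerate(val):
            self.val[k] = v
            self.sum[k] += v * n
            self.avg[k] = self.sum[k] / self.count

    def __repr__(self):
        # "val (avg)" formatting used by the Train/valid log lines
        val = ' '.join('{:.{}f}'.format(v, self.precision) for v in self.val)
        avg = ' '.join('{:.{}f}'.format(a, self.precision) for a in self.avg)
        return '{} ({})'.format(val, avg)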
Example #25
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = inverse_depth_smooth_loss(depth, tgt_img)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('explainability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(train_writer, "train", k, n_iter,
                                       *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
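
The explainability_loss referenced above is, in SfMLearner-style code, a binary cross-entropy that pulls every mask value toward 1, preventing the network from masking out all pixels just to zero the photometric error. A sketch under that assumption:

import torch
import torch.nn.functional as F

def explainability_loss(mask):
    # the mask may come as a multi-scale list; normalize to a list
    if type(mask) not in [tuple, list]:
        mask = [mask]
    loss = 0
    for mask_scaled in mask:
        # BCE against an all-ones target penalizes masking pixels out
        ones = torch.ones_like(mask_scaled)
        loss += F.binary_cross_entropy(mask_scaled, ones)
    return loss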
Example #26
def validate(val_loader,
             disp_net,
             pose_exp_net,
             epoch,
             logger,
             output_writers=[]):
    global args
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        if log_outputs and i % 100 == 0 and i / 100 < len(
                output_writers):  # log first output of every 100 batches
            index = int(i // 100)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(tgt_img[0]),
                                                    0)
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(ref[0]), 1)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)
            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1], depth[:1, 0], pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1])[0]
                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(
                        0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        output_writers[0].add_histogram('val poses_tx', poses[:, 0], epoch)
        output_writers[0].add_histogram('val poses_ty', poses[:, 1], epoch)
        output_writers[0].add_histogram('val poses_tz', poses[:, 2], epoch)
        output_writers[0].add_histogram('val poses_rx', poses[:, 3], epoch)
        output_writers[0].add_histogram('val poses_ry', poses[:, 4], epoch)
        output_writers[0].add_histogram('val poses_rz', poses[:, 5], epoch)

    return losses.avg
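
The smooth_loss used above penalizes rough disparity maps. A hedged sketch of the second-order-gradient version common in SfMLearner-derived code (the per-scale weight decay factor is an assumption):

def smooth_loss(pred_map):
    # finite-difference gradients along width and height
    def gradient(pred):
        D_dy = pred[:, :, 1:] - pred[:, :, :-1]
        D_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1]
        return D_dx, D_dy

    if type(pred_map) not in [tuple, list]:
        pred_map = [pred_map]

    loss = 0
    weight = 1.
    for scaled_map in pred_map:  # multi-scale predictions, finest first
        dx, dy = gradient(scaled_map)
        dx2, dxdy = gradient(dx)
        dydx, dy2 = gradient(dy)
        # second-order gradients favour smooth, not merely constant, maps
        loss += (dx2.abs().mean() + dxdy.abs().mean() +
                 dydx.abs().mean() + dy2.abs().mean()) * weight
        weight /= 2.3  # down-weight coarser scales
    return loss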
Example #27
def validate_with_gt(args,
                     val_loader,
                     disp_net,
                     epoch,
                     logger,
                     output_writers=[]):
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1 / output_disp[:, 0]

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input',
                                            tensor2array(tgt_img[0]), 0)
                depth_to_show = depth[0]
                output_writers[i].add_image(
                    'val target Depth',
                    tensor2array(depth_to_show, max_value=10), epoch)
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1 / depth_to_show).clamp(0, 10)
                output_writers[i].add_image(
                    'val target Disparity Normalized',
                    tensor2array(disp_to_show,
                                 max_value=None,
                                 colormap='magma'), epoch)

            output_writers[i].add_image(
                'val Dispnet Output Normalized',
                tensor2array(output_disp[0], max_value=None, colormap='magma'),
                epoch)
            output_writers[i].add_image(
                'val Depth Output', tensor2array(output_depth[0],
                                                 max_value=10), epoch)

        errors.update(compute_errors(depth, output_depth))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f})'.format(
                    batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
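
compute_errors above yields the six metrics named in error_names. A sketch following the standard Eigen-style depth metrics; the 80 m validity cap and the median scaling are assumptions:

import torch

@torch.no_grad()
def compute_errors(gt, pred):
    abs_diff, abs_rel, sq_rel, a1, a2, a3 = 0, 0, 0, 0, 0, 0
    batch_size = gt.size(0)
    for current_gt, current_pred in zip(gt, pred):
        valid = (current_gt > 0) & (current_gt < 80)  # keep valid GT points
        valid_gt = current_gt[valid]
        # median scaling: align the scale-ambiguous prediction to GT
        valid_pred = current_pred[valid] * torch.median(valid_gt) / torch.median(current_pred[valid])
        # threshold accuracies: fraction of pixels with max ratio < 1.25**k
        thresh = torch.max(valid_gt / valid_pred, valid_pred / valid_gt)
        a1 += (thresh < 1.25).float().mean()
        a2 += (thresh < 1.25 ** 2).float().mean()
        a3 += (thresh < 1.25 ** 3).float().mean()
        abs_diff += torch.mean(torch.abs(valid_gt - valid_pred))
        abs_rel += torch.mean(torch.abs(valid_gt - valid_pred) / valid_gt)
        sq_rel += torch.mean((valid_gt - valid_pred) ** 2 / valid_gt)
    return [(metric / batch_size).item() for metric in
            (abs_diff, abs_rel, sq_rel, a1, a2, a3)]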
Example #28
def main():
    global args
    args = parser.parse_args()

    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        rigidity_mask_dir = args.output_dir / 'rigidity'
        rigidity_census_mask_dir = args.output_dir / 'rigidity_census'
        explainability_mask_dir = args.output_dir / 'explainability'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()
        rigidity_mask_dir.makedirs_p()
        rigidity_census_mask_dir.makedirs_p()
        explainability_mask_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    val_flow_set = ValidationMask(root=args.kitti_dir,
                                  sequence_length=5,
                                  transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = ['tp_0', 'fp_0', 'fn_0', 'tp_1', 'fp_1', 'fn_1']
    errors = AverageMeter(i=len(error_names))
    errors_census = AverageMeter(i=len(error_names))
    errors_bare = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt, obj_map_gt,
            semantic_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        if args.flownet in ['Back2Future']:
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)

        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).pow(2).sum(dim=1).unsqueeze(1).sqrt()
        rigidity_mask_census_soft = 1 - rigidity_mask_census_soft / rigidity_mask_census_soft.max()
        rigidity_mask_census = rigidity_mask_census_soft > args.THRESH

        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        flow_fwd_non_rigid = (1 - rigidity_mask_combined).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = rigidity_mask_combined.type_as(flow_fwd).expand_as(
            flow_fwd) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        rigidity_mask_census_np = rigidity_mask_census.cpu().data[0].numpy()
        rigidity_mask_bare_np = rigidity_mask.cpu().data[0].numpy()

        gt_mask_np = obj_map_gt[0].numpy()
        semantic_map_np = semantic_map_gt[0].numpy()

        _errors = mask_error(gt_mask_np, semantic_map_np,
                             rigidity_mask_combined_np[0])
        _errors_census = mask_error(gt_mask_np, semantic_map_np,
                                    rigidity_mask_census_np[0])
        _errors_bare = mask_error(gt_mask_np, semantic_map_np,
                                  rigidity_mask_bare_np[0])

        errors.update(_errors)
        errors_census.update(_errors_census)
        errors_bare.update(_errors_bare)

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)
            np.save(rigidity_mask_dir / str(i).zfill(3),
                    rigidity_mask.cpu().data[0].numpy())
            np.save(rigidity_census_mask_dir / str(i).zfill(3),
                    rigidity_mask_census.cpu().data[0].numpy())
            np.save(explainability_mask_dir / str(i).zfill(3),
                    explainability_mask[:, 1].cpu().data[0].numpy())
            # rigidity_mask_dir rigidity_mask.numpy()
            # rigidity_census_mask_dir rigidity_mask_census.numpy()

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

        if args.output_dir is not None:
            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='magma')
            mask_viz = tensor2array(rigidity_mask_census_soft.data[0].cpu(),
                                    max_value=1,
                                    colormap='bone')
            row2_viz = flow_to_image(
                np.hstack((tensor2array(flow_cam.data[0].cpu()),
                           tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                           tensor2array(total_flow.data[0].cpu()))))

            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            ####### changed the two vstack calls to hstack ###############
            viz3 = np.hstack(
                (255 * tgt_img_viz, 255 * depth_viz, 255 * mask_viz,
                 flow_to_image(
                     np.hstack((tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                                tensor2array(total_flow.data[0].cpu()))))))
            ########################################################
            ######## manually added code ####################
            row1_viz = np.transpose(row1_viz, (1, 2, 0))
            row2_viz = np.transpose(row2_viz, (1, 2, 0))
            viz3 = np.transpose(viz3, (1, 2, 0))
            ##########################################

            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))
            viz3_im = Image.fromarray(viz3.astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')
            viz3_im.save(viz_dir / str(i).zfill(3) + '03.png')

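    # IoU from the accumulated confusion counts: each .sum is ordered
    # [tp_0, fp_0, fn_0, tp_1, fp_1, fn_1], so per class
    # IoU = tp / (tp + fp + fn), for background (0) and foreground (1)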
    bg_iou = errors.sum[0] / (errors.sum[0] + errors.sum[1] + errors.sum[2])
    fg_iou = errors.sum[3] / (errors.sum[3] + errors.sum[4] + errors.sum[5])
    avg_iou = (bg_iou + fg_iou) / 2

    bg_iou_census = errors_census.sum[0] / (
        errors_census.sum[0] + errors_census.sum[1] + errors_census.sum[2])
    fg_iou_census = errors_census.sum[3] / (
        errors_census.sum[3] + errors_census.sum[4] + errors_census.sum[5])
    avg_iou_census = (bg_iou_census + fg_iou_census) / 2

    bg_iou_bare = errors_bare.sum[0] / (
        errors_bare.sum[0] + errors_bare.sum[1] + errors_bare.sum[2])
    fg_iou_bare = errors_bare.sum[3] / (
        errors_bare.sum[3] + errors_bare.sum[4] + errors_bare.sum[5])
    avg_iou_bare = (bg_iou_bare + fg_iou_bare) / 2

    print("Results Full Model")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou, bg_iou, fg_iou))

    print("Results Census only")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou_census, bg_iou_census, fg_iou_census))

    print("Results Bare")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou_bare, bg_iou_bare, fg_iou_bare))
Example #29
def train(args, train_loader, disp_net, pose_net, optimizer, epoch_size, logger, tb_writer, w1, w3):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # Set the networks to training mode, batch norm and dropout are handled accordingly
    disp_net.train()
    pose_net.train()

    end = time.time()
    logger.train_bar.start()
    logger.train_bar.update(0)

    for i, trainingdata in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_lf = trainingdata['tgt_lf'].to(device)
        ref_lfs = [img.to(device) for img in trainingdata['ref_lfs']]

        if args.lfformat == "epi" and args.cameras_epi == "full":
            # in this case we have separate horizontal and vertical epis
            tgt_lf_formatted_h = trainingdata['tgt_lf_formatted_h'].to(device)
            tgt_lf_formatted_v = trainingdata['tgt_lf_formatted_v'].to(device)
            ref_lfs_formatted_h = [lf.to(device) for lf in trainingdata['ref_lfs_formatted_h']]
            ref_lfs_formatted_v = [lf.to(device) for lf in trainingdata['ref_lfs_formatted_v']]

            # stacked images
            tgt_stack = trainingdata['tgt_stack'].to(device)
            ref_stacks = [lf.to(device) for lf in trainingdata['ref_stacks']]

            # Encode the epi images further
            if args.without_disp_stack:
                # Stacked images should not be concatenated with the encoded EPI images
                tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted_v, None, tgt_lf_formatted_h)
            else:
                # Stacked images should be concatenated with the encoded EPI images
                tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted_v, tgt_stack, tgt_lf_formatted_h)

            tgt_lf_encoded_p, ref_lfs_encoded_p = pose_net.encode(tgt_lf_formatted_v, tgt_stack,
                                                                  ref_lfs_formatted_v, ref_stacks,
                                                                  tgt_lf_formatted_h, ref_lfs_formatted_h)
        else:
            tgt_lf_formatted = trainingdata['tgt_lf_formatted'].to(device)
            ref_lfs_formatted = [lf.to(device) for lf in trainingdata['ref_lfs_formatted']]

            # Encode the images if necessary
            if disp_net.has_encoder():
                # This will only be called for epi with horizontal or vertical only encoding
                if args.without_disp_stack:
                    # Stacked images should not be concatenated with the encoded EPI images
                    tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted, None)
                else:
                    # Stacked images should be concatenated with the encoded EPI images
                    # NOTE: Here we stack all 17 images, not 5. Here the images missing from the encoding,
                    # are covered in the stack. We are not using this case in the paper at all.
                    tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted, tgt_lf)
            else:
                # This will be called for focal stack and stack, where there is no encoding
                tgt_lf_encoded_d = tgt_lf_formatted

            if pose_net.has_encoder():
                tgt_lf_encoded_p, ref_lfs_encoded_p = pose_net.encode(tgt_lf_formatted, tgt_lf,
                                                                      ref_lfs_formatted, ref_lfs)
            else:
                tgt_lf_encoded_p = tgt_lf_formatted
                ref_lfs_encoded_p = ref_lfs_formatted

        # compute output of networks
        disparities = disp_net(tgt_lf_encoded_d)
        depth = [1/disp for disp in disparities]
        pose = pose_net(tgt_lf_encoded_p, ref_lfs_encoded_p)

        # if i==0:
        #     tb_writer.add_graph(disp_net, tgt_lf_encoded_d)
        #     tb_writer.add_graph(pose_net, (tgt_lf_encoded_p, ref_lfs_encoded_p))

        # compute photometric error
        intrinsics = trainingdata['intrinsics'].to(device)
        pose_gt_tgt_refs = trainingdata['pose_gt_tgt_refs'].to(device)
        metadata = trainingdata['metadata']
        photometric_error, warped, diff = multiwarp_photometric_loss(
            tgt_lf, ref_lfs, intrinsics, depth, pose, metadata, args.rotation_mode, args.padding_mode
        )

        # smoothness_error = smooth_loss(depth)                             # smoothness error
        smoothness_error = total_variation_loss(depth, sum_or_mean="mean")  # total variation error
        # smoothness_error = total_variation_squared_loss(depth)            # total variation error squared version
        mean_distance_error, mean_angle_error = pose_loss(pose, pose_gt_tgt_refs)

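        # uncertainty-style weighting (cf. Kendall et al. 2018): each error is
        # scaled by exp(-w) with w itself added as a regularizer, so w1 and w3
        # act as learned log-variances balancing the two terms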
        loss = w1 + torch.exp(-1.0 * w1) * photometric_error + w3 + torch.exp(-1.0 * w3) * smoothness_error

        if log_losses:
            tb_writer.add_scalar(tag='train/photometric_error', scalar_value=photometric_error.item(), global_step=n_iter)
            tb_writer.add_scalar(tag='train/smoothness_loss', scalar_value=smoothness_error.item(), global_step=n_iter)
            tb_writer.add_scalar(tag='train/total_loss', scalar_value=loss.item(), global_step=n_iter)
            tb_writer.add_scalar(tag='train/mean_distance_error', scalar_value=mean_distance_error.item(), global_step=n_iter)
            tb_writer.add_scalar(tag='train/mean_angle_error', scalar_value=mean_angle_error.item(), global_step=n_iter)
        if log_output:
            if args.lfformat == "epi" and args.cameras_epi == "full":
                b, n, h, w = tgt_lf_formatted_v.shape
                vis_img = tgt_lf_formatted_v[0, 0, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5
            else:
                b, n, h, w = tgt_lf_formatted.shape
                vis_img = tgt_lf_formatted[0, 0, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5

            b, n, h, w = depth[0].shape
            vis_depth = tensor2array(depth[0][0, 0, :, :], colormap='magma')
            vis_disp = tensor2array(disparities[0][0, 0, :, :], colormap='magma')
            vis_enc_f = tgt_lf_encoded_d[0, 0, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5
            vis_enc_b = tgt_lf_encoded_d[0, -1, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5

            tb_writer.add_image('train/input', vis_img, n_iter)
            tb_writer.add_image('train/encoded_front', vis_enc_f, n_iter)
            tb_writer.add_image('train/encoded_back', vis_enc_b, n_iter)
            tb_writer.add_image('train/depth', vis_depth, n_iter)
            tb_writer.add_image('train/disp', vis_disp, n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path + "/" + args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), photometric_error.item(), smoothness_error.item(),
                             mean_distance_error.item(), mean_angle_error.item()])
        logger.train_bar.update(i+1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    logger.train_bar.finish()
    return losses.avg[0]
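
The total_variation_loss above replaces the usual gradient smoothness term. A sketch consistent with the call signature total_variation_loss(depth, sum_or_mean=...); the exact reduction and multi-scale handling are assumptions:

import torch

def total_variation_loss(depth, sum_or_mean="mean"):
    # accept a single map or a multi-scale list, as elsewhere in these examples
    if type(depth) not in [tuple, list]:
        depth = [depth]
    reduce = torch.mean if sum_or_mean == "mean" else torch.sum
    loss = 0
    for scaled_depth in depth:
        # first-order total variation: absolute differences between neighbours
        dx = (scaled_depth[..., :, 1:] - scaled_depth[..., :, :-1]).abs()
        dy = (scaled_depth[..., 1:, :] - scaled_depth[..., :-1, :]).abs()
        loss += reduce(dx) + reduce(dy)
    return loss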
Example #30
def validate_with_gt(args,
                     val_loader,
                     depth_net,
                     pose_net,
                     epoch,
                     logger,
                     output_writers=[],
                     **env):
    global device
    batch_time = AverageMeter()
    depth_error_names = ['abs diff', 'abs rel', 'sq rel', 'a1', 'a2', 'a3']
    stab_depth_errors = AverageMeter(i=len(depth_error_names))
    unstab_depth_errors = AverageMeter(i=len(depth_error_names))
    pose_error_names = ['Absolute Trajectory Error', 'Rotation Error']
    pose_errors = AverageMeter(i=len(pose_error_names))

    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, sample in enumerate(val_loader):
        log_output = i < len(output_writers)

        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        batch_size, seq, c, h, w = imgs.size()

        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        if args.network_input_size is not None:
            imgs = F.interpolate(imgs, (c, *args.network_input_size),
                                 mode='area')

            downscale = h / args.network_input_size[0]
            intrinsics = torch.cat(
                (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
            intrinsics_inv = torch.cat(
                (intrinsics_inv[:, :, 0:2] * downscale, intrinsics_inv[:, :, 2:]),
                dim=2)

        GT_depth = sample['depth'].to(device)
        GT_pose = sample['pose'].to(device)

        mid_index = (args.sequence_length - 1) // 2

        tgt_img = imgs[:, mid_index]

        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                output_writers[i].add_image('val Input', tensor2array(img[0]),
                                            j)
            depth_to_show = GT_depth[0].cpu()
            # KITTI Like data routine to discard invalid data
            depth_to_show[depth_to_show == 0] = 1000
            disp_to_show = (1 / depth_to_show).clamp(0, 10)
            output_writers[i].add_image(
                'val target Disparity Normalized',
                tensor2array(disp_to_show, max_value=None, colormap='bone'),
                epoch)

        poses = pose_net(imgs)
        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]
        inverted_pose_matrices = invert_mat(pose_matrices)
        pose_errors.update(
            compute_pose_error(GT_pose[:, :-1],
                               inverted_pose_matrices.data[:, :-1]))

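        # express every pose relative to the middle (target) frame, for both
        # predictions and ground truth, so they can be compared and used to
        # stabilize the prior frames below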
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_predicted_poses = compensate_pose(pose_matrices, tgt_poses)
        compensated_GT_poses = compensate_pose(GT_pose, GT_pose[:, mid_index])

        for j in range(args.sequence_length):
            if j == mid_index:
                if log_output and epoch == 1:
                    output_writers[i].add_image(
                        'val Input Stabilized',
                        tensor2array(sample['imgs'][j][0]), j)
                continue
            # compute displacement magnitude for each element of the batch,
            # and rescale depth accordingly

            prior_img = imgs[:, j]
            displacement = compensated_GT_poses[:, j, :, -1]  # [B,3]
            displacement_magnitude = displacement.norm(p=2, dim=1)  # [B]
            current_GT_depth = GT_depth * args.nominal_displacement / displacement_magnitude.view(
                -1, 1, 1)

            prior_predicted_pose = compensated_predicted_poses[:, j]  # [B, 3, 4]
            prior_GT_pose = compensated_GT_poses[:, j]

            prior_predicted_rot = prior_predicted_pose[:, :, :-1]
            prior_GT_rot = prior_GT_pose[:, :, :-1].transpose(1, 2)

            prior_compensated_from_GT = inverse_rotate(prior_img, prior_GT_rot,
                                                       intrinsics,
                                                       intrinsics_inv)
            if log_output and epoch == 1:
                depth_to_show = current_GT_depth[0]
                output_writers[i].add_image(
                    'val target Depth {}'.format(j),
                    tensor2array(depth_to_show, max_value=args.max_depth),
                    epoch)
                output_writers[i].add_image(
                    'val Input Stabilized',
                    tensor2array(prior_compensated_from_GT[0]), j)

            prior_compensated_from_prediction = inverse_rotate(
                prior_img, prior_predicted_rot, intrinsics, intrinsics_inv)
            predicted_input_pair = torch.cat(
                [prior_compensated_from_prediction, tgt_img],
                dim=1)  # [B, 6, W, H]
            GT_input_pair = torch.cat([prior_compensated_from_GT, tgt_img],
                                      dim=1)  # [B, 6, W, H]

            # This is the depth from footage stabilized with GT pose, it should be better than depth from raw footage without any GT info
            raw_depth_stab = depth_net(GT_input_pair)
            raw_depth_unstab = depth_net(predicted_input_pair)

            # Upsample depth so that it matches GT size
            scale_factor = GT_depth.size(-1) // raw_depth_stab.size(-1)
            depth_stab = F.interpolate(raw_depth_stab,
                                       scale_factor=scale_factor,
                                       mode='bilinear',
                                       align_corners=False)
            depth_unstab = F.interpolate(raw_depth_unstab,
                                         scale_factor=scale_factor,
                                         mode='bilinear',
                                         align_corners=False)

            for k, depth in enumerate([depth_stab, depth_unstab]):
                disparity = 1 / depth
                errors = stab_depth_errors if k == 0 else unstab_depth_errors
                errors.update(
                    compute_depth_errors(current_GT_depth, depth, crop=True))
                if log_output:
                    prefix = 'stabilized' if k == 0 else 'unstabilized'
                    output_writers[i].add_image(
                        'val {} Dispnet Output Normalized {}'.format(
                            prefix, j),
                        tensor2array(disparity[0],
                                     max_value=None,
                                     colormap='bone'), epoch)
                    output_writers[i].add_image(
                        'val {} Depth Output {}'.format(prefix, j),
                        tensor2array(depth[0], max_value=args.max_depth),
                        epoch)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} ATE Error {:.4f} ({:.4f}), Unstab Rel Abs Error {:.4f} ({:.4f})'
                .format(batch_time, pose_errors.val[0], pose_errors.avg[0],
                        unstab_depth_errors.val[1],
                        unstab_depth_errors.avg[1]))
    logger.valid_bar.update(len(val_loader))

    errors = (*pose_errors.avg, *unstab_depth_errors.avg,
              *stab_depth_errors.avg)
    error_names = (*pose_error_names,
                   *['unstab {}'.format(e) for e in depth_error_names],
                   *['stab {}'.format(e) for e in depth_error_names])

    return OrderedDict(zip(error_names, errors))
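
The compute_pose_error above reduces to the two values logged as 'Absolute Trajectory Error' and 'Rotation Error'. A hedged sketch for [B, seq, 3, 4] pose matrices; these definitions are assumptions, not taken from the repo:

import torch

def compute_pose_error(gt, pred):
    # translation part: mean Euclidean distance between camera positions
    ate = (gt[..., -1] - pred[..., -1]).norm(p=2, dim=-1).mean()
    # rotation part: geodesic angle of R_gt @ R_pred^T, recovered from its trace
    relative_rot = gt[..., :3] @ pred[..., :3].transpose(-2, -1)
    trace = relative_rot[..., 0, 0] + relative_rot[..., 1, 1] + relative_rot[..., 2, 2]
    angle = torch.acos(((trace - 1) / 2).clamp(-1, 1)).mean()
    return [ate.item(), angle.item()]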