Example #1
def train(args):

    model = RAFT(args)
    model = nn.DataParallel(model)
    print("Parameter Count: %d" % count_parameters(model))

    if args.restore_ckpt is not None:
        model.load_state_dict(torch.load(args.restore_ckpt))

    model.cuda()
    model.train()

    if 'chairs' not in args.dataset:
        model.module.freeze_bn()

    train_loader = fetch_dataloader(args)
    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    logger = Logger(model, scheduler)

    should_keep_training = True
    while should_keep_training:

        for i_batch, data_blob in enumerate(train_loader):
            image1, image2, flow, valid = [x.cuda() for x in data_blob]

            optimizer.zero_grad()
            flow_predictions = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, flow, valid)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()
            total_steps += 1

            logger.push(metrics)

            if total_steps % VAL_FREQ == VAL_FREQ - 1:
                PATH = 'checkpoints/%d_%s.pth' % (total_steps + 1, args.name)
                torch.save(model.state_dict(), PATH)

            if total_steps == args.num_steps:
                should_keep_training = False
                break

    PATH = 'checkpoints/%s.pth' % args.name
    torch.save(model.state_dict(), PATH)

    return PATH
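
The examples here call sequence_loss without defining it. Below is a minimal sketch of that helper, modeled on the public RAFT reference implementation and consistent with the two-value return used in Example #1; the max_flow threshold and the exact metric names are assumptions, not part of the snippets on this page.

import torch

def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=400):
    """Exponentially weighted L1 loss over the iterative flow predictions."""
    n_predictions = len(flow_preds)
    flow_loss = 0.0

    # Exclude invalid pixels and extremely large displacements.
    mag = torch.sum(flow_gt**2, dim=1).sqrt()
    valid = (valid >= 0.5) & (mag < max_flow)

    for i in range(n_predictions):
        # Later refinement iterations receive exponentially larger weight.
        i_weight = gamma**(n_predictions - i - 1)
        i_loss = (flow_preds[i] - flow_gt).abs()
        flow_loss += i_weight * (valid[:, None] * i_loss).mean()

    # End-point error of the final prediction, for logging only.
    epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
    epe = epe.view(-1)[valid.view(-1)]

    metrics = {
        'epe': epe.mean().item(),
        '1px': (epe < 1).float().mean().item(),
        '3px': (epe < 3).float().mean().item(),
        '5px': (epe < 5).float().mean().item(),
    }
    return flow_loss, metrics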
Example #2
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()
    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" %
          (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    if args.stage != 'chairs':
        model.module.freeze_bn()

    train_entries, evaluation_entries = read_splits()
    aug_params = {
        'crop_size': args.image_size,
        'min_scale': -0.2,
        'max_scale': 0.4,
        'do_flip': False
    }
    train_dataset = VirtualKITTI2(aug_params,
                                  split='training',
                                  root=args.dataset_root,
                                  entries=train_entries)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset) if args.distributed else None
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   pin_memory=False,
                                   shuffle=(train_sampler is None),
                                   num_workers=args.num_workers,
                                   drop_last=True,
                                   sampler=train_sampler)

    eval_dataset = VirtualKITTI2(split='evaluation',
                                 root=args.dataset_root,
                                 entries=evaluation_entries)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(
        eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=args.batch_size,
                                  pin_memory=False,
                                  shuffle=(eval_sampler is None),
                                  num_workers=args.num_workers,
                                  drop_last=True,
                                  sampler=eval_sampler)

    # Give `group` a value in the non-distributed case as well; it is passed
    # to validate_VRKitti2 below regardless of the distributed flag.
    group = dist.new_group([i for i in range(ngpus_per_node)]) if args.distributed else None

    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    scaler = GradScaler(enabled=args.mixed_precision)

    if args.gpu == 0:
        logger = Logger(model, scheduler, logroot)
        logger_evaluation = Logger(
            model, scheduler,
            os.path.join(args.logroot, 'evaluation_VRKitti', args.name))

    VAL_FREQ = 500
    add_noise = True
    epoch = 0

    should_keep_training = True
    while should_keep_training:
        if train_sampler is not None:
            # Vary the shuffling each epoch under DistributedSampler.
            train_sampler.set_epoch(epoch)

        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()
            image1, image2, flow, valid = data_blob

            # torch.autograd.Variable is deprecated, and inputs/targets do not
            # need requires_grad; just move the tensors to the GPU.
            image1 = image1.cuda(gpu, non_blocking=True)
            image2 = image2.cuda(gpu, non_blocking=True)
            flow = flow.cuda(gpu, non_blocking=True)
            valid = valid.cuda(gpu, non_blocking=True)

            if add_noise:
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda(
                    gpu, non_blocking=True)).clamp(0.0, 255.0)
                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda(
                    gpu, non_blocking=True)).clamp(0.0, 255.0)

            flow_predictions = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, flow, valid,
                                          args.gamma)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            scaler.step(optimizer)
            scheduler.step()
            scaler.update()

            if args.gpu == 0:
                logger.push(metrics, image1, image2, flow, flow_predictions,
                            valid)

            if total_steps % VAL_FREQ == VAL_FREQ - 1:

                results = validate_VRKitti2(model.module, args, eval_loader,
                                            group)

                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

                if args.gpu == 0:
                    logger_evaluation.write_dict(results, total_steps)
                    PATH = os.path.join(
                        logroot, '%s.pth' % (str(total_steps + 1).zfill(3)))
                    torch.save(model.state_dict(), PATH)

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break
        epoch = epoch + 1

    # Define PATH on every rank so the return below never hits an unbound name.
    PATH = os.path.join(logroot, 'final.pth')
    if args.gpu == 0:
        logger.close()
        torch.save(model.state_dict(), PATH)

    return PATH
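
Examples #2 through #5 take (gpu, ngpus_per_node, args), the per-process signature expected by torch.multiprocessing.spawn. A minimal launcher sketch, assuming args carries the distributed flag read by the examples:

import torch
import torch.multiprocessing as mp

def main(args):
    ngpus_per_node = torch.cuda.device_count()
    if args.distributed:
        # Spawn one training process per GPU; each process receives its
        # rank as the first argument of train().
        mp.spawn(train, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        train(gpu=0, ngpus_per_node=1, args=args)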
Example #3
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args=args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True, output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" % (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    train_entries, evaluation_entries = read_splits()

    train_dataset = KITTI_eigen(root=args.dataset_root, inheight=args.inheight, inwidth=args.inwidth, entries=train_entries, maxinsnum=args.maxinsnum,
                                depth_root=args.depth_root, depthvls_root=args.depthvlsgt_root, prediction_root=args.prediction_root, ins_root=args.ins_root,
                                istrain=True, muteaug=False, banremovedup=False, isgarg=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True, num_workers=int(args.num_workers / ngpus_per_node), drop_last=True, sampler=train_sampler)

    eval_dataset = KITTI_eigen(root=args.dataset_root, inheight=args.evalheight, inwidth=args.evalwidth, entries=evaluation_entries, maxinsnum=args.maxinsnum,
                               depth_root=args.depth_root, depthvls_root=args.depthvlsgt_root, prediction_root=args.prediction_root, ins_root=args.ins_root, istrain=False, isgarg=True)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset, batch_size=1, pin_memory=True, num_workers=3, drop_last=True, sampler=eval_sampler)

    print("Training splits contain %d images while test splits contain %d images" % (train_dataset.__len__(), eval_dataset.__len__()))

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    optimizer, scheduler = fetch_optimizer(args, model, int(train_dataset.__len__() / 2))

    total_steps = 0

    if args.gpu == 0:
        logger = Logger(logroot)
        logger_evaluation = Logger(os.path.join(args.logroot, 'evaluation_eigen_background', args.name))
        logger.create_summarywriter()
        logger_evaluation.create_summarywriter()

    VAL_FREQ = 5000
    epoch = 0
    minout = 1

    st = time.time()
    should_keep_training = True
    while should_keep_training:
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            image1 = data_blob['img1'].cuda(gpu) / 255.0
            image2 = data_blob['img2'].cuda(gpu) / 255.0
            flowmap = data_blob['flowmap'].cuda(gpu)

            outputs = model(image1, image2)

            selector = (flowmap[:, 0, :, :] != 0)
            flow_loss = sequence_loss(outputs, flowmap, selector, gamma=args.gamma, max_flow=MAX_FLOW)

            # Store a plain float so the logger does not hold onto the autograd graph.
            metrics = {'flow_loss': flow_loss.item()}

            loss = flow_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            scheduler.step()

            if args.gpu == 0:
                logger.write_dict(metrics, step=total_steps)
                if total_steps % SUM_FREQ == 0:
                    dr = time.time() - st
                    resths = (args.num_steps - total_steps) * dr / (total_steps + 1) / 60 / 60  # estimated hours remaining
                    print("Step: %d, remaining hours: %f, flow loss: %f" % (total_steps, resths, flow_loss.item()))
                    logger.write_vls(data_blob, outputs, selector.unsqueeze(1), total_steps)

            if total_steps % VAL_FREQ == 1:
                if args.gpu == 0:
                    results = validate_kitti(model.module, args, eval_loader, logger, group, total_steps)
                else:
                    results = validate_kitti(model.module, args, eval_loader, None, group, None)

                if args.gpu == 0:
                    logger_evaluation.write_dict(results, total_steps)
                    if minout > results['out']:
                        minout = results['out']
                        PATH = os.path.join(logroot, 'minout.pth')
                        torch.save(model.state_dict(), PATH)
                        print("model saved to %s" % PATH)

                model.train()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break
        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()
        PATH = os.path.join(logroot, 'final.pth')
        torch.save(model.state_dict(), PATH)

    return
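
Example #3 (and Example #5 below) read the module-level constants SUM_FREQ and MAX_FLOW, Example #1 reads a module-level VAL_FREQ, and every example calls count_parameters; none of these are defined in the snippets. A sketch of the assumed module-level definitions follows; the constant values are assumptions, while count_parameters follows the usual RAFT-style pattern.

SUM_FREQ = 100    # log a training summary every N steps (assumed value)
VAL_FREQ = 5000   # validate / checkpoint every N steps (assumed value)
MAX_FLOW = 400    # ignore ground-truth displacements above this magnitude (assumed value)

def count_parameters(model):
    # Count only the trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)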
Example #4
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True, output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    # The epipolar-constraint modules are used by validate_kitti in both
    # branches, so build them regardless of the distributed flag.
    eppCbck = eppConstrainer_background(height=args.image_size[0], width=args.image_size[1], bz=args.batch_size)
    eppCbck.to(f'cuda:{args.gpu}')

    eppconcluer = eppConcluer()
    eppconcluer.to(f'cuda:{args.gpu}')

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" % (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    if args.stage != 'chairs':
        model.module.freeze_bn()

    train_entries, evaluation_entries = read_splits()

    aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
    train_dataset = KITTI_eigen(aug_params, split='training', root=args.dataset_root, entries=train_entries, semantics_root=args.semantics_root, depth_root=args.depth_root)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True,
                                   shuffle=(train_sampler is None), num_workers=int(args.num_workers / ngpus_per_node), drop_last=True,
                                   sampler=train_sampler)

    eval_dataset = KITTI_eigen(split='evaluation', root=args.dataset_root, entries=evaluation_entries, semantics_root=args.semantics_root, depth_root=args.depth_root)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset, batch_size=1, pin_memory=True,
                                   shuffle=(eval_sampler is None), num_workers=3, drop_last=True,
                                   sampler=eval_sampler)

    # Give `group` a value in the non-distributed case as well; it is passed to validate_kitti below.
    group = dist.new_group([i for i in range(ngpus_per_node)]) if args.distributed else None

    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0

    if args.gpu == 0:
        logger = Logger(model, scheduler, logroot, args.num_steps)
        logger_evaluation = Logger(model, scheduler, os.path.join(args.logroot, 'evaluation_eigen_background', args.name), args.num_steps)

    VAL_FREQ = 5000
    add_noise = False
    epoch = 0

    should_keep_training = True
    print(validate_kitti(model.module, args, eval_loader, eppCbck, eppconcluer, group))
    while should_keep_training:

        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            # torch.autograd.Variable is deprecated, and inputs/targets do not
            # need requires_grad; just move the tensors to the GPU.
            image1 = data_blob['img1'].cuda(gpu, non_blocking=True)
            image2 = data_blob['img2'].cuda(gpu, non_blocking=True)
            flow = data_blob['flow'].cuda(gpu, non_blocking=True)
            valid = data_blob['valid'].cuda(gpu, non_blocking=True)
            E = data_blob['E'].cuda(gpu, non_blocking=True)
            semantic_selector = data_blob['semantic_selector'].cuda(gpu, non_blocking=True)

            if add_noise:
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda(gpu, non_blocking=True)).clamp(0.0, 255.0)
                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda(gpu, non_blocking=True)).clamp(0.0, 255.0)

            flow_predictions = model(image1, image2, iters=args.iters)

            metrics = dict()
            loss_flow, metrics_flow = sequence_flowloss(flow_predictions, flow, valid, args.gamma)
            loss_eppc, metrics_eppc = sequence_eppcloss(eppCbck, flow_predictions, semantic_selector, E, args.gamma)

            metrics.update(metrics_flow)
            metrics.update(metrics_eppc)

            loss = loss_flow + loss_eppc * args.eppcw
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            scheduler.step()

            if args.gpu == 0:
                logger.push(metrics, image1, image2, flow, flow_predictions, valid, data_blob['depth'])

            if total_steps % VAL_FREQ == VAL_FREQ - 1:

                results = validate_kitti(model.module, args, eval_loader, eppCbck, eppconcluer, group)

                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

                if args.gpu == 0:
                    logger_evaluation.write_dict(results, total_steps)
                    PATH = os.path.join(logroot, '%s.pth' % (str(total_steps + 1).zfill(3)))
                    torch.save(model.state_dict(), PATH)

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break
        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()
        PATH = os.path.join(logroot, 'final.pth')
        torch.save(model.state_dict(), PATH)

    return PATH
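
All five examples obtain their optimizer and scheduler from fetch_optimizer (Example #3 passes an extra per-epoch step count). A sketch of the two-argument variant, modeled on the RAFT reference implementation, pairing AdamW with a one-cycle schedule stepped once per training step; the args fields lr, wdecay, and epsilon are assumptions.

import torch.optim as optim

def fetch_optimizer(args, model):
    """Create the optimizer and the per-step learning-rate scheduler."""
    optimizer = optim.AdamW(model.parameters(), lr=args.lr,
                            weight_decay=args.wdecay, eps=args.epsilon)
    # OneCycleLR is stepped after every optimizer step, matching the
    # scheduler.step() calls inside the training loops above.
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr,
                                              args.num_steps + 100,
                                              pct_start=0.05,
                                              cycle_momentum=False,
                                              anneal_strategy='linear')
    return optimizer, scheduler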
Example #5
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()
    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" %
          (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    train_entries, evaluation_entries = read_splits()
    train_dataset = VirtualKITTI2(args=args,
                                  root=args.dataset_root,
                                  inheight=args.inheight,
                                  inwidth=args.inwidth,
                                  entries=train_entries,
                                  istrain=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset) if args.distributed else None
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   pin_memory=False,
                                   shuffle=(train_sampler is None),
                                   num_workers=args.num_workers,
                                   drop_last=True,
                                   sampler=train_sampler)

    eval_dataset = VirtualKITTI2(args=args,
                                 root=args.dataset_root,
                                 inheight=args.evalheight,
                                 inwidth=args.evalwidth,
                                 entries=evaluation_entries,
                                 istrain=False)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(
        eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=args.batch_size,
                                  pin_memory=False,
                                  shuffle=(eval_sampler is None),
                                  num_workers=2,
                                  drop_last=True,
                                  sampler=eval_sampler)

    print(
        "Training split contains %d images, validation split contains %d images"
        % (len(train_entries), len(evaluation_entries)))

    # Give `group` a value in the non-distributed case as well; it is passed
    # to validate_VRKitti2 below regardless of the distributed flag.
    group = dist.new_group([i for i in range(ngpus_per_node)]) if args.distributed else None

    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    VAL_ITERINC = 4

    if args.gpu == 0:
        logger = Logger(model, scheduler, logroot)
        logger.create_summarywriter()

        logger_evaluations = dict()
        for num_iters in range(args.iters, args.iters * 2 + 1, VAL_ITERINC):
            logger_evaluation = Logger(
                model, scheduler,
                os.path.join(
                    args.logroot, 'evaluation_VRKitti',
                    "{}_iternum{}".format(args.name,
                                          str(num_iters).zfill(2))))
            logger_evaluation.create_summarywriter()
            logger_evaluations[num_iters] = logger_evaluation

    VAL_FREQ = 2000
    minout = 1  # best (lowest) outlier rate seen so far
    epoch = 0

    st = time.time()
    should_keep_training = True
    while should_keep_training:
        if train_sampler is not None:
            # Vary the shuffling each epoch under DistributedSampler.
            train_sampler.set_epoch(epoch)

        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            image1 = data_blob['img1'].cuda(gpu, non_blocking=True)
            image2 = data_blob['img2'].cuda(gpu, non_blocking=True)
            flow = data_blob['flowmap'].cuda(gpu, non_blocking=True)

            # Exclude invalid pixels and extremely large displacements.
            mag = torch.sum(flow**2, dim=1).sqrt()
            valid = ((flow[:, 0] != 0) & (flow[:, 1] != 0) &
                     (mag < MAX_FLOW)).unsqueeze(1)

            flow_predictions = model(image1, image2, iters=args.iters)
            # Only the scalar loss is logged here; the metrics dict returned
            # by sequence_loss is not used.
            loss, _ = sequence_loss(flow_predictions, flow, valid, args.gamma)

            metrics = {'loss_flow': loss.float().item()}

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()

            if args.gpu == 0:
                logger.write_dict(metrics, total_steps)
                if total_steps % SUM_FREQ == 0:
                    dr = time.time() - st
                    resths = (args.num_steps -
                              total_steps) * dr / (total_steps + 1) / 60 / 60  # estimated hours remaining
                    print("Step: %d, remaining hours: %f, flow loss: %f" %
                          (total_steps, resths, loss.item()))
                    logger.write_vls(data_blob, flow_predictions, valid,
                                     total_steps)

            if total_steps % VAL_FREQ == 1:
                for num_iters in range(args.iters, args.iters * 2 + 1,
                                       VAL_ITERINC):

                    if args.gpu == 0 and num_iters == 24:
                        results = validate_VRKitti2(model.module, args,
                                                    eval_loader, num_iters,
                                                    group, logger, total_steps)
                    else:
                        results = validate_VRKitti2(model.module, args,
                                                    eval_loader, num_iters,
                                                    group, None, None)

                    if args.gpu == 0:
                        logger_evaluations[num_iters].write_dict(
                            results, total_steps)

                        if num_iters == 24:
                            if results['out'] < minout:
                                minout = results['out']
                                PATH = os.path.join(logroot, 'minout.pth')
                                torch.save(model.state_dict(), PATH)
                                print("model saved to %s" % PATH)

                model.train()
            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break
        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()
        PATH = os.path.join(logroot, 'final.pth')
        torch.save(model.state_dict(), PATH)

    return PATH
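
Examples #1, #2, and #4 call model.module.freeze_bn() when training on a stage other than 'chairs'. A sketch of what that method typically looks like in RAFT-style code: BatchNorm layers are switched to eval mode so their running statistics stay frozen during finetuning.

import torch.nn as nn

class RAFT(nn.Module):
    # ... feature/context encoders and the update block are elided ...

    def freeze_bn(self):
        # Keep BatchNorm running statistics fixed during finetuning.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()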