Пример #1
0
def load_data(is_training=True):
    """Build the NYU-v2 dataset object for the train or test split.

    Side effect: stores the number of samples in the global IMAGE_NUM.
    """
    global IMAGE_NUM

    # Index files listing which sample indices belong to each split.
    train_idx_path = "data/trainNdxs.txt"
    test_idx_path = "data/testNdxs.txt"
    # NOTE(review): these three directories are assigned but never read below.
    input_rgb_images_dir = 'data/nyu_datasets_changed/input/'
    target_depth_images_dir = 'data/nyu_datasets_changed/target_depths/'
    target_labels_images_dir = 'data/nyu_datasets_changed/labels_38/'

    data_path = "data/nyu_depth_v2_labeled.mat"
    # Both index lists are loaded regardless of the requested split.
    train_idx = np.loadtxt(train_idx_path, dtype='int')
    test_idx = np.loadtxt(test_idx_path, dtype='int')

    # Per-modality transforms: RGB scaled to 120, depth to 60, labels untouched.
    input_transform = flow_transforms.Compose([flow_transforms.Scale(120)])
    target_depth_transform = flow_transforms.Compose(
        [flow_transforms.Scale_Single(60)])
    target_labels_transform = flow_transforms.Compose([])

    # Joint augmentation applied to input and targets together (train only).
    co_transform = flow_transforms.Compose([
        flow_transforms.RandomRotate(4),
        flow_transforms.RandomCrop((480, 640)),
        flow_transforms.RandomVerticalFlip(),
    ])

    if is_training:
        dataset = ListDataset(data_path, train_idx, input_transform,
                              target_depth_transform, target_labels_transform,
                              co_transform)
    else:
        dataset = ListDataset(data_path, test_idx, input_transform,
                              target_depth_transform, target_labels_transform)
    IMAGE_NUM = len(dataset)
    return dataset
Пример #2
0
    def __init__(self,
                 args,
                 root='datasets/FlyingChairs_release/data',
                 mode='train'):
        """Index the FlyingChairs dataset and set up per-split transforms.

        Args:
            args: parsed command-line arguments (kept for interface parity;
                not read in this constructor).
            root: dataset directory holding ``*_flow.flo`` and ``*_img?.ppm``.
            mode: 'train' enables random joint augmentation and selects the
                training list; anything else disables augmentation and
                selects the validation list.
        """
        # Normalize images to [-1,1]: bytes -> [0,1] -> (x-0.5)/0.5.
        input_transform = transforms.Compose([
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        target_transform = transforms.Compose(
            [flow_transforms.ArrayToTensor()])

        # Simple augmentation, applied jointly to the image pair and the
        # flow map; training split only.
        if mode == 'train':
            co_transform = flow_transforms.Compose([
                flow_transforms.RandomTranslate(cfg.RANDOM_TRANS),
                flow_transforms.RandomCrop(
                    (cfg.CROP_SIZE[0], cfg.CROP_SIZE[1])),
                flow_transforms.RandomVerticalFlip(),
                flow_transforms.RandomHorizontalFlip()
            ])
        else:
            co_transform = None

        self.root = root
        self.transform = input_transform
        self.target_transform = target_transform
        self.co_transform = co_transform

        # Pair each flow map with its two source frames.
        # BUG FIX: the original joined paths with the builtin ``dir`` instead
        # of ``root``, which raises TypeError for any discovered flow file.
        images = []
        for flow_map in sorted(glob.glob(os.path.join(root, '*_flow.flo'))):
            flow_map = os.path.basename(flow_map)
            root_filename = flow_map[:-9]  # strip the '_flow.flo' suffix
            img1 = root_filename + '_img1.ppm'
            img2 = root_filename + '_img2.ppm'
            if not (os.path.isfile(os.path.join(root, img1))
                    and os.path.isfile(os.path.join(root, img2))):
                continue
            images.append([[img1, img2], flow_map])
        # NOTE(review): ``images`` is built but never stored; ``path_list``
        # below comes from split2list — confirm whether this scan is needed.

        train_list, test_list = split2list(root, 'FlyingChairs_train_val.txt')
        if mode == 'train':
            self.path_list = train_list
        else:
            self.path_list = test_list
Пример #3
0
def main():
    """Parse CLI args, build data pipelines and the model, then train or
    evaluate an optical-flow network.

    Uses the module-level ``parser``, ``device``, ``models`` and ``datasets``
    objects, and mutates the globals ``args`` and ``best_EPE``.
    """
    global args, best_EPE
    args = parser.parse_args()

    # Resolve the dataset root from a JSON mapping when --data is not given.
    if not args.data:
        with open('train_src/data_loc.json', 'r') as f:
            content = f.read()
        data_loc = json.loads(content)
        args.data = data_loc[args.dataset]

    # Build a descriptive save directory (optionally date-stamped) unless an
    # explicit --savpath was provided.
    if not args.savpath:
        save_path = '{},{},{}epochs{},b{},lr{}'.format(
            args.arch,
            args.solver,
            args.epochs,
            ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
            args.batch_size,
            args.lr)
        if not args.no_date:
            timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
            save_path = os.path.join(timestamp, save_path)
    else:
        save_path = args.savpath
    save_path = os.path.join(args.dataset, save_path)
    print('=> will save everything to {}'.format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Persist the training configuration alongside the checkpoints.
    save_training_args(save_path, args)

    train_writer = SummaryWriter(os.path.join(save_path, 'train'))
    test_writer = SummaryWriter(os.path.join(save_path, 'test'))
    output_writers = []
    for i in range(3):
        output_writers.append(SummaryWriter(os.path.join(save_path, 'test', str(i))))

    # Data loading: scale bytes to [0,1], then subtract the per-channel mean.
    if args.grayscale:
        input_transform = transforms.Compose([
            flow_transforms.ArrayToTensor(),
            transforms.Grayscale(num_output_channels=3),
            transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
            transforms.Normalize(mean=[0.431, 0.431, 0.431], std=[1, 1, 1])  # 0.431=(0.45+0.432+0.411)/3
        ])
    else:
        input_transform = transforms.Compose([
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
            transforms.Normalize(mean=[0.45, 0.432, 0.411], std=[1, 1, 1])
        ])

    # Flow targets are divided by div_flow so the network regresses small values.
    target_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0], std=[args.div_flow, args.div_flow])
    ])

    # KITTI ground truth is sparse; translation/rotation would break sparsity.
    if 'KITTI' in args.dataset:
        args.sparse = True
    if args.sparse:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomCrop((320, 448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])
    else:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomTranslate(10),
            flow_transforms.RandomRotate(10, 5),
            flow_transforms.RandomCrop((320, 448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])

    print("=> fetching img pairs in '{}'".format(args.data))
    train_set, test_set = datasets.__dict__[args.dataset](
        args.data,
        transform=input_transform,
        target_transform=target_transform,
        co_transform=co_transform,
        split=args.split_file if args.split_file else args.split_value
    )
    # BUG FIX: the summary line previously printed "samp-les found".
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))
    if not args.evaluate:
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=args.batch_size,
            num_workers=args.workers, pin_memory=True, shuffle=True)

    val_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=True, shuffle=False)

    # Create the model, optionally warm-started from a checkpoint.
    if args.pretrained:
        network_data = torch.load(args.pretrained)
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        network_data = None
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch](data=network_data, args=args).to(device)

    assert (args.solver in ['adam', 'sgd', 'adamw'])
    print('=> setting {} solver'.format(args.solver))
    # Biases and weights get independent weight-decay settings.
    param_groups = [{'params': model.bias_parameters(), 'weight_decay': args.bias_decay},
                    {'params': model.weight_parameters(), 'weight_decay': args.weight_decay}]

    if device.type == "cuda":
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True

    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups, args.lr,
                                     betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups, args.lr,
                                    momentum=args.momentum)
    elif args.solver == 'adamw':
        optimizer = torch.optim.AdamW(param_groups, args.lr,
                                      betas=(args.momentum, args.beta))

    # Optionally export a textual model description; 'test' savpath means
    # we only wanted the export, so stop here.
    if args.print_model:
        exportpars(model, save_path, args)
        exportsummary(model, save_path, args)
        if args.savpath == 'test':
            return

    if args.evaluate:
        best_EPE = validate(val_loader, model, 0, output_writers)
        return

    if args.demo:
        demo(val_loader, model, 0, output_writers)
        return
    if args.demovideo:
        demovideo(val_loader, model, 0, output_writers)
        return

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)

    for epoch in range(args.start_epoch, args.epochs):

        # Train for one epoch.
        train_loss, train_EPE = train(train_loader, model, optimizer, epoch, train_writer)
        train_writer.add_scalar('mean EPE', train_EPE, epoch)
        scheduler.step()

        # Evaluate on the validation set.
        with torch.no_grad():
            EPE = validate(val_loader, model, epoch, output_writers)
        test_writer.add_scalar('mean EPE', EPE, epoch)

        # best_EPE starts negative; adopt the first measured EPE.
        if best_EPE < 0:
            best_EPE = EPE

        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)
        # NOTE(review): ``model.module`` assumes the DataParallel wrap above
        # ran — on a CPU-only device this raises AttributeError; confirm.
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.module.state_dict(),
            'best_EPE': best_EPE,
            'div_flow': args.div_flow
        }, is_best, save_path)
0
def main():
    """Train or evaluate a flow network with a multiscale loss.

    Uses module-level ``parser``, ``models``, ``datasets`` and mutates the
    globals ``args``, ``best_EPE`` and ``save_path``.
    """
    global args, best_EPE, save_path
    args = parser.parse_args()
    # Descriptive save directory, optionally prefixed with a timestamp.
    save_path = '{},{},{}epochs{},b{},lr{}'.format(
        args.arch, args.solver, args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.batch_size, args.lr)
    if not args.no_date:
        timestamp = datetime.datetime.now().strftime("%a-%b-%d-%H:%M")
        save_path = os.path.join(timestamp, save_path)
    save_path = os.path.join(args.dataset, save_path)
    print('=> will save everything to {}'.format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Data loading code: scale bytes to [0,1], then ImageNet normalization.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]), normalize
    ])
    # Flow targets are scaled down by div_flow.
    target_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0], std=[args.div_flow, args.div_flow])
    ])

    # KITTI ground truth is sparse, so only cropping is applied there.
    if 'KITTI' in args.dataset:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomCrop((320, 448)),
            #random flips are not supported yet for tensor conversion, but will be
            #flow_transforms.RandomVerticalFlip(),
            #flow_transforms.RandomHorizontalFlip()
        ])
    else:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomTranslate(10),
            flow_transforms.RandomRotate(10, 5),
            flow_transforms.RandomCrop((320, 448)),
            #random flips are not supported yet for tensor conversion, but will be
            #flow_transforms.RandomVerticalFlip(),
            #flow_transforms.RandomHorizontalFlip()
        ])

    print("=> fetching img pairs in '{}'".format(args.data))
    train_set, test_set = datasets.__dict__[args.dataset](
        args.data,
        transform=input_transform,
        target_transform=target_transform,
        co_transform=co_transform,
        split=args.split)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))
    # Training uses a balanced sampler so each epoch sees args.epoch_size items.
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        sampler=balancedsampler.RandomBalancedSampler(train_set,
                                                      args.epoch_size),
        num_workers=args.workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch](args.pretrained).cuda()

    model = torch.nn.DataParallel(model).cuda()
    # Multiscale training loss plus a single-scale EPE metric for reporting.
    criterion = multiscaleloss(sparse='KITTI' in args.dataset,
                               loss=args.loss).cuda()
    # NOTE(review): weights=(1) is the scalar 1, not a 1-tuple — confirm
    # multiscaleloss accepts a bare int here.
    high_res_EPE = multiscaleloss(scales=1,
                                  downscale=4,
                                  weights=(1),
                                  loss='L1',
                                  sparse='KITTI' in args.dataset).cuda()
    cudnn.benchmark = True

    assert (args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     betas=(args.momentum, args.beta),
                                     weight_decay=args.weight_decay)
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    if args.evaluate:
        best_EPE = validate(val_loader, model, criterion, high_res_EPE)
        return

    # Write CSV headers for the per-epoch summary and the per-batch log.
    with open(os.path.join(save_path, args.log_summary), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'train_EPE', 'EPE'])

    with open(os.path.join(save_path, args.log_full), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'train_EPE'])

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss, train_EPE = train(train_loader, model, criterion,
                                      high_res_EPE, optimizer, epoch)

        # evaluate on validation set

        EPE = validate(val_loader, model, criterion, high_res_EPE)
        # best_EPE starts negative; adopt the first measured EPE.
        if best_EPE < 0:
            best_EPE = EPE

        # remember best prec@1 and save checkpoint
        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.module.state_dict(),
                'best_EPE': best_EPE,
            }, is_best)

        # Append this epoch's numbers to the summary CSV.
        with open(os.path.join(save_path, args.log_summary), 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, train_EPE, EPE])
Пример #5
0
def main():
    """End-to-end training driver supporting two data back ends: the default
    PyTorch DataLoader and an NVIDIA DALI pipeline (``--data-loader``).

    Mutates the globals ``args`` and ``best_EPE`` (and, in DALI mode,
    ``train_length`` / ``val_length``).
    """
    global args, best_EPE
    args = parser.parse_args()
    # Descriptive save directory, optionally prefixed with a timestamp.
    save_path = "{},{},{}epochs{},b{},lr{}".format(
        args.arch,
        args.solver,
        args.epochs,
        ",epochSize" + str(args.epoch_size) if args.epoch_size > 0 else "",
        args.batch_size,
        args.lr,
    )
    if not args.no_date:
        timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
        save_path = os.path.join(timestamp, save_path)
    save_path = os.path.join(args.dataset, save_path)
    print("=> will save everything to {}".format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    train_writer = SummaryWriter(os.path.join(save_path, "train"))
    test_writer = SummaryWriter(os.path.join(save_path, "test"))
    output_writers = []
    for i in range(3):
        output_writers.append(
            SummaryWriter(os.path.join(save_path, "test", str(i))))

    # Data loading code
    if args.data_loader == "torch":
        print("Using default data loader \n")
        # Scale bytes to [0,1], then subtract the per-channel mean.
        input_transform = transforms.Compose([
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
            transforms.Normalize(mean=[0.45, 0.432, 0.411], std=[1, 1, 1]),
        ])
        # Flow targets are scaled down by div_flow.
        target_transform = transforms.Compose([
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0],
                                 std=[args.div_flow, args.div_flow]),
        ])
        # Validation images are resized instead of randomly cropped.
        test_transform = transforms.Compose([
            transforms.Resize((122, 162)),
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
            transforms.Normalize(mean=[0.45, 0.432, 0.411], std=[1, 1, 1]),
        ])

        # KITTI ground truth is sparse; skip transforms that break sparsity.
        if "KITTI" in args.dataset:
            args.sparse = True
        if args.sparse:
            co_transform = flow_transforms.Compose([
                flow_transforms.RandomCrop((122, 162)),
                flow_transforms.RandomVerticalFlip(),
                flow_transforms.RandomHorizontalFlip(),
            ])
        else:
            co_transform = flow_transforms.Compose([
                flow_transforms.RandomTranslate(10),
                flow_transforms.RandomRotate(10, 5),
                flow_transforms.RandomCrop((122, 162)),
                flow_transforms.RandomVerticalFlip(),
                flow_transforms.RandomHorizontalFlip(),
            ])

        print("=> fetching img pairs in '{}'".format(args.data))
        train_set, test_set = datasets.__dict__[args.dataset](
            args.data,
            transform=input_transform,
            test_transform=test_transform,
            target_transform=target_transform,
            co_transform=co_transform,
            split=args.split_file if args.split_file else args.split_value,
        )
        print("{} samples found, {} train samples and {} test samples ".format(
            len(test_set) + len(train_set), len(train_set), len(test_set)))
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            num_workers=args.workers,
            pin_memory=True,
            shuffle=True,
        )
        val_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=args.batch_size,
            num_workers=args.workers,
            pin_memory=True,
            shuffle=False,
        )

    if args.data_loader == "dali":
        print("Using NVIDIA DALI \n")
        # File-name lists for both frames and the flow maps, per split.
        (
            (image0_train_names, image0_val_names),
            (image1_train_names, image1_val_names),
            (flow_train_names, flow_val_names),
        ) = make_dali_dataset(
            args.data,
            split=args.split_file if args.split_file else args.split_value)
        print("{} samples found, {} train samples and {} test samples ".format(
            len(image0_val_names) + len(image0_train_names),
            len(image0_train_names),
            len(image0_val_names),
        ))
        global train_length
        global val_length
        train_length = len(image0_train_names)
        val_length = len(image0_val_names)

        def create_image_pipeline(
            batch_size,
            num_threads,
            device_id,
            image0_list,
            image1_list,
            flow_list,
            valBool,
        ):
            """Build a DALI pipeline that reads an image pair plus its flow
            map, augments them jointly (train only), and normalizes."""
            pipeline = Pipeline(batch_size, num_threads, device_id, seed=2)
            with pipeline:
                # Only the training pipeline shuffles.
                if valBool:
                    shuffleBool = False
                else:
                    shuffleBool = True
                """ READ FILES """
                # Shared seed keeps the three readers aligned per sample.
                image0, _ = fn.readers.file(
                    file_root=args.data,
                    files=image0_list,
                    random_shuffle=shuffleBool,
                    name="Reader",
                    seed=1,
                )
                image1, _ = fn.readers.file(
                    file_root=args.data,
                    files=image1_list,
                    random_shuffle=shuffleBool,
                    seed=1,
                )
                flo = fn.readers.numpy(
                    file_root=args.data,
                    files=flow_list,
                    random_shuffle=shuffleBool,
                    seed=1,
                )
                """ DECODE AND RESHAPE """
                image0 = fn.decoders.image(image0, device="cpu")
                image0 = fn.reshape(image0, layout="HWC")
                image1 = fn.decoders.image(image1, device="cpu")
                image1 = fn.reshape(image1, layout="HWC")
                # Stack the two frames channel-wise into one 6-channel input.
                images = fn.cat(image0, image1, axis=2)
                flo = fn.reshape(flo, layout="HWC")

                if valBool:
                    images = fn.resize(images, resize_x=162, resize_y=122)
                else:
                    """ CO-TRANSFORM """
                    # random translate
                    # angle_rng = fn.random.uniform(range=(-90, 90))
                    # images = fn.rotate(images, angle=angle_rng, fill_value=0)
                    # flo = fn.rotate(flo, angle=angle_rng, fill_value=0)

                    # Shared seed keeps image and flow crops identical.
                    images = fn.random_resized_crop(
                        images,
                        size=[122, 162],  # 122, 162
                        random_aspect_ratio=[1.3, 1.4],
                        random_area=[0.8, 0.9],
                        seed=1,
                    )
                    flo = fn.random_resized_crop(
                        flo,
                        size=[122, 162],
                        random_aspect_ratio=[1.3, 1.4],
                        random_area=[0.8, 0.9],
                        seed=1,
                    )

                    # coin1 = fn.random.coin_flip(dtype=types.DALIDataType.BOOL, seed=10)
                    # coin1_n = coin1 ^ True
                    # coin2 = fn.random.coin_flip(dtype=types.DALIDataType.BOOL, seed=20)
                    # coin2_n = coin2 ^ True

                    # images = (
                    #     fn.flip(images, horizontal=1, vertical=1) * coin1 * coin2
                    #     + fn.flip(images, horizontal=1) * coin1 * coin2_n
                    #     + fn.flip(images, vertical=1) * coin1_n * coin2
                    #     + images * coin1_n * coin2_n
                    # )
                    # flo = (
                    #     fn.flip(flo, horizontal=1, vertical=1) * coin1 * coin2
                    #     + fn.flip(flo, horizontal=1) * coin1 * coin2_n
                    #     + fn.flip(flo, vertical=1) * coin1_n * coin2
                    #     + flo * coin1_n * coin2_n
                    # )
                    # _flo = flo
                    # flo_0 = fn.slice(_flo, axis_names="C", start=0, shape=1)
                    # flo_1 = fn.slice(_flo, axis_names="C", start=1, shape=1)
                    # flo_0 = flo_0 * coin1 * -1 + flo_0 * coin1_n
                    # flo_1 = flo_1 * coin2 * -1 + flo_1 * coin2_n
                    # # flo  = noflip + vertical flip + horizontal flip + both_flip

                    # # A horizontal flip is around the vertical axis (switch left and right)
                    # # So for a vertical flip coin1 is activated and needs to give +1, coin2 is activated needs to give -1
                    # # for a horizontal flip coin1 is activated and needs to be -1, coin2_n needs +1
                    # # no flip coin coin1_n +1, coin2_n +1

                    # flo = fn.cat(flo_0, flo_1, axis_name="C")
                """ NORMALIZE """
                # 6-channel stats mirror the torch path's per-frame values.
                images = fn.crop_mirror_normalize(
                    images,
                    mean=[0, 0, 0, 0, 0, 0],
                    std=[255, 255, 255, 255, 255, 255])
                images = fn.crop_mirror_normalize(
                    images,
                    mean=[0.45, 0.432, 0.411, 0.45, 0.432, 0.411],
                    std=[1, 1, 1, 1, 1, 1],
                )
                flo = fn.crop_mirror_normalize(
                    flo, mean=[0, 0], std=[args.div_flow, args.div_flow])

                pipeline.set_outputs(images, flo)
            return pipeline

        class DALILoader:
            """Thin DataLoader-like wrapper around a DALI pipeline."""

            def __init__(
                self,
                batch_size,
                image0_names,
                image1_names,
                flow_names,
                valBool,
                num_threads,
                device_id,
            ):
                self.pipeline = create_image_pipeline(
                    batch_size,
                    num_threads,
                    device_id,
                    image0_names,
                    image1_names,
                    flow_names,
                    valBool,
                )
                self.pipeline.build()
                # Batches per epoch, derived from the named "Reader" op.
                self.epoch_size = self.pipeline.epoch_size(
                    "Reader") / batch_size

                output_names = ["images", "flow"]
                # NOTE(review): both branches build identical iterators;
                # presumably the val branch was meant to differ — confirm.
                if valBool:
                    self.dali_iterator = pytorch.DALIGenericIterator(
                        self.pipeline,
                        output_names,
                        reader_name="Reader",
                        last_batch_policy=pytorch.LastBatchPolicy.PARTIAL,
                        auto_reset=True,
                    )
                else:
                    self.dali_iterator = pytorch.DALIGenericIterator(
                        self.pipeline,
                        output_names,
                        reader_name="Reader",
                        last_batch_policy=pytorch.LastBatchPolicy.PARTIAL,
                        auto_reset=True,
                    )

            def __len__(self):
                return int(self.epoch_size)

            def __iter__(self):
                return self.dali_iterator.__iter__()

            def reset(self):
                return self.dali_iterator.reset()

        train_loader = DALILoader(
            batch_size=args.batch_size,
            num_threads=args.workers,
            device_id=0,
            image0_names=image0_train_names,
            image1_names=image1_train_names,
            flow_names=flow_train_names,
            valBool=False,
        )

        val_loader = DALILoader(
            batch_size=args.batch_size,
            num_threads=args.workers,
            device_id=0,
            image0_names=image0_val_names,
            image1_names=image1_val_names,
            flow_names=flow_val_names,
            valBool=True,
        )

    # create model
    if args.pretrained:
        network_data = torch.load(args.pretrained)
        args.arch = network_data["arch"]
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        network_data = None
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch](network_data).to(device)

    assert args.solver in ["adam", "sgd"]
    print("=> setting {} solver".format(args.solver))
    # Biases and weights get independent weight-decay settings.
    param_groups = [
        {
            "params": model.bias_parameters(),
            "weight_decay": args.bias_decay
        },
        {
            "params": model.weight_parameters(),
            "weight_decay": args.weight_decay
        },
    ]

    if device.type == "cuda":
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True

    if args.solver == "adam":
        optimizer = torch.optim.Adam(param_groups,
                                     args.lr,
                                     betas=(args.momentum, args.beta))
    elif args.solver == "sgd":
        optimizer = torch.optim.SGD(param_groups,
                                    args.lr,
                                    momentum=args.momentum)

    if args.evaluate:
        best_EPE = validate(val_loader, model, 0, output_writers)
        return

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=0.5)

    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch

        # # --- quant
        # model.train()
        # model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')  # torch.quantization.default_qconfig
        # # model = torch.quantization.fuse_modules(model, [['Conv2d', 'bn', 'relu']])
        # torch.backends.quantized.engine = 'qnnpack'
        # model = torch.quantization.prepare_qat(model)
        # # --- quant

        # my_sample = next(itertools.islice(train_loader, 10, None))
        # print(my_sample[1][0])
        # print("Maximum value is ", torch.max(my_sample[0][0]))
        # print("Minimum value is ", torch.min(my_sample[0][0]))

        train_loss, train_EPE = train(train_loader, model, optimizer, epoch,
                                      train_writer)
        train_writer.add_scalar("mean EPE", train_EPE, epoch)

        scheduler.step()

        # evaluate on validation set

        with torch.no_grad():
            EPE = validate(val_loader, model, epoch, output_writers)
        test_writer.add_scalar("mean EPE", EPE, epoch)

        # best_EPE starts negative; adopt the first measured EPE.
        if best_EPE < 0:
            best_EPE = EPE

        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)
        # if is_best:
        #     kernels = model.module.conv3_1[0].weight.data
        #     kernels = kernels.cpu()
        #     kernels = kernels - kernels.min()
        #     kernels = kernels / kernels.max()
        #     img = make_grid(kernels)
        #     plt.imshow(img.permute(1, 2, 0))
        #     plt.show()
        # NOTE(review): ``dummy_input`` is not defined in this function —
        # presumably a module-level global (e.g. for ONNX export); confirm.
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.arch,
                "state_dict": model.module.state_dict(),
                "best_EPE": best_EPE,
                "div_flow": args.div_flow,
            },
            is_best,
            save_path,
            model,
            dummy_input,
        )
Пример #6
0
# Directories for the three modalities: RGB input, depth targets, label maps.
data_dir = (
    input_rgb_images_dir,
    target_depth_images_dir,
    target_labels_images_dir,
)

# Per-modality transforms: rescale to 228 px, then convert arrays to tensors.
input_transform = transforms.Compose([
    flow_transforms.Scale(228),
    flow_transforms.ArrayToTensor(),
])
target_depth_transform = transforms.Compose([
    flow_transforms.Scale_Single(228),
    flow_transforms.ArrayToTensor(),
])
target_labels_transform = transforms.Compose([flow_transforms.ArrayToTensor()])

# Joint transform applied to input, ground-truth depth and label images alike.
co_transform = flow_transforms.Compose([
    flow_transforms.RandomCrop((480, 640)),
    flow_transforms.RandomHorizontalFlip()
])

# Train/val/test splits — augmentation (co_transform) is train-only.
train_dataset = ListDataset(data_dir, train_listing, input_transform,
                            target_depth_transform, target_labels_transform,
                            co_transform)

val_dataset = ListDataset(data_dir, val_listing, input_transform,
                          target_depth_transform, target_labels_transform)

test_dataset = ListDataset(data_dir, test_listing, input_transform,
                           target_depth_transform, target_labels_transform)

print("Loading data...")
Пример #7
0
def main():
    """Train (or evaluate) an optical-flow model on FlyingChairs.

    Configuration comes entirely from the module-level argparse ``parser``;
    the running best end-point error is kept in the global ``best_EPE``.
    """
    global args, best_EPE
    args = parser.parse_args()

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch](args.pretrained)

    model = torch.nn.DataParallel(model).cuda()
    criterion = multiscaleloss().cuda()
    # Single-scale L1 end-point error at 1/4 resolution, used for reporting.
    # NOTE(review): weights=(1) is just the int 1, not a 1-tuple -- confirm
    # multiscaleloss accepts a scalar here, otherwise (1,) was intended.
    high_res_EPE = multiscaleloss(scales=1,
                                  downscale=4,
                                  weights=(1),
                                  loss='L1').cuda()
    cudnn.benchmark = True

    # Data loading code
    # ImageNet per-channel normalization constants.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = datasets.FlyingChairs(
        args.data,
        transform=transforms.Compose([transforms.ToTensor(), normalize]),
        target_transform=None,
        co_transform=flow_transforms.Compose([
            flow_transforms.RandomTranslate(10),
            flow_transforms.RandomCropRotate(10, 360, 5),
            flow_transforms.RandomCrop((320, 448))
        ]),
        split=args.split)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # NOTE(review): train_loader and val_loader wrap the SAME dataset object,
    # and dataset.eval() mutates it in place -- the train loader will also
    # observe the dataset in eval mode. Verify this is intentional.
    dataset.eval()
    val_loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Evaluation-only mode: run one validation pass and exit.
    if args.evaluate:
        best_EPE = validate(val_loader, model, criterion, high_res_EPE)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, high_res_EPE, optimizer, epoch)

        # evaluate on validation set

        EPE = validate(val_loader, model, criterion, high_res_EPE)
        # A negative best_EPE marks "no result yet" (presumably initialized
        # to -1 at module level -- confirm).
        if best_EPE < 0:
            best_EPE = EPE

        # remember best prec@1 and save checkpoint
        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_EPE': best_EPE,
            }, is_best)
Пример #8
0
def main():
    """Build an (optionally quantized) optical-flow model and either
    evaluate it on the validation split or fall through for training.

    All configuration is read from the module-level argparse ``parser``;
    ``best_EPE`` (global) receives the validation end-point error in
    ``--evaluate`` mode.
    """
    global args, best_EPE
    args = parser.parse_args()

    # Data loading code
    # Scale raw pixels to [0, 1], then subtract the per-channel mean.
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        transforms.Normalize(mean=[0.45, 0.432, 0.411], std=[1, 1, 1])
    ])
    # Flow targets are divided by div_flow so the network regresses
    # small values.
    target_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0], std=[args.div_flow, args.div_flow])
    ])

    # KITTI ground truth is sparse; translation/rotation augmentations would
    # invalidate the sparse flow samples, so only crop + flips are used.
    if 'KITTI' in args.dataset:
        args.sparse = True
    if args.sparse:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomCrop((320, 448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])
    else:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomTranslate(10),
            flow_transforms.RandomRotate(10, 5),
            flow_transforms.RandomCrop((320, 448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])

    print("=> fetching img pairs in '{}'".format(args.data))
    train_set, test_set = datasets.__dict__[args.dataset](
        args.data,
        transform=input_transform,
        target_transform=target_transform,
        co_transform=co_transform,
        split=args.split_file if args.split_file else args.split_value)
    # BUGFIX: message previously read "samp-les".
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))
    # train_loader = torch.utils.data.DataLoader(
    #     train_set, batch_size=args.batch_size,
    #     num_workers=args.workers, pin_memory=True, shuffle=True)
    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             shuffle=False)

    # create model
    if args.pretrained:
        if torch.cuda.is_available():
            network_data = torch.load(args.pretrained)
        else:
            # CPU-only host: remap CUDA tensors in the checkpoint.
            network_data = torch.load(args.pretrained,
                                      map_location=torch.device('cpu'))
        # args.arch = network_data['arch']
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        network_data = None
        print("=> creating model '{}'".format(args.arch))

    # BUGFIX: the original condition `args.qw and args.qa is not None` mixed
    # truthiness with a None-check, so the quantized branch depended on the
    # truthiness of qw alone. Both bit-widths must be explicitly provided.
    if args.qw is not None and args.qa is not None:
        # Quantization-aware model: bitW/bitA are weight/activation widths.
        if torch.cuda.is_available():
            model = models.__dict__[args.arch](data=network_data,
                                               bitW=args.qw,
                                               bitA=args.qa).cuda()
        else:
            model = models.__dict__[args.arch](data=network_data,
                                               bitW=args.qw,
                                               bitA=args.qa)
    else:
        if torch.cuda.is_available():
            model = models.__dict__[args.arch](data=network_data).cuda()
        else:
            model = models.__dict__[args.arch](data=network_data)
    # if torch.cuda.is_available():
    #     model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    assert (args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    # Biases get their own (typically zero) weight decay.
    param_groups = [{
        'params': model.bias_parameters(),
        'weight_decay': args.bias_decay
    }, {
        'params': model.weight_parameters(),
        'weight_decay': args.weight_decay
    }]
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups,
                                     args.lr,
                                     betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups,
                                    args.lr,
                                    momentum=args.momentum)

    # print (summary(model, (6, 320, 448)))
    # print (model)

    # Evaluation-only mode: run one validation pass and exit.
    if args.evaluate:
        best_EPE = validate(val_loader, model, 0)
        # validate_chairs(model.module)
        # validate_sintel(model.module)
        return
Пример #9
0
def main():
    """Full training loop for an optical-flow network with TensorBoard
    logging, MultiStepLR scheduling and best-checkpoint tracking.

    Configuration comes from the module-level argparse ``parser``; globals
    ``best_EPE`` and ``save_path`` are shared with train/validate helpers.
    """
    global args, best_EPE, save_path
    args = parser.parse_args()
    # Encode the hyper-parameters in the output directory name.
    save_path = '{},{},{}epochs{},b{},lr{}'.format(
        args.arch, args.solver, args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.batch_size, args.lr)
    if not args.no_date:
        timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
        save_path = os.path.join(timestamp, save_path)
    save_path = os.path.join(args.dataset, save_path)
    print('=> will save everything to {}'.format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # One writer per split, plus three writers for per-sample flow images.
    train_writer = SummaryWriter(os.path.join(save_path, 'train'))
    test_writer = SummaryWriter(os.path.join(save_path, 'test'))
    output_writers = []
    for i in range(3):
        output_writers.append(
            SummaryWriter(os.path.join(save_path, 'test', str(i))))

    # Data loading code
    # Scale raw pixels to [0, 1], then subtract the per-channel mean.
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
    ])
    # Flow targets are divided by div_flow so the network regresses
    # small values.
    target_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0], std=[args.div_flow, args.div_flow])
    ])

    # KITTI ground truth is sparse; translation/rotation would invalidate
    # the sparse flow samples, so only crop + flips are used there.
    if 'KITTI' in args.dataset:
        args.sparse = True
    if args.sparse:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomCrop((320, 448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])
    else:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomTranslate(10),
            flow_transforms.RandomRotate(10, 5),
            flow_transforms.RandomCrop((320, 448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])

    print("=> fetching img pairs in '{}'".format(args.data))
    train_set, test_set = datasets.__dict__[args.dataset](
        args.data,
        transform=input_transform,
        target_transform=target_transform,
        co_transform=co_transform,
        split=args.split_file if args.split_file else args.split_value)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             shuffle=False)

    # create model
    if args.pretrained:
        network_data = torch.load(args.pretrained)
        # The checkpoint records which architecture it was trained with.
        args.arch = network_data['arch']
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        network_data = None
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch](network_data).cuda()
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    assert (args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    # Biases get their own (typically zero) weight decay.
    param_groups = [{
        'params': model.module.bias_parameters(),
        'weight_decay': args.bias_decay
    }, {
        'params': model.module.weight_parameters(),
        'weight_decay': args.weight_decay
    }]
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups,
                                     args.lr,
                                     betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups,
                                    args.lr,
                                    momentum=args.momentum)

    # Evaluation-only mode: run one validation pass and exit.
    if args.evaluate:
        best_EPE = validate(val_loader, model, 0, output_writers)
        return

    # Halve the LR at each milestone epoch.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=0.5)

    for epoch in range(args.start_epoch, args.epochs):
        # NOTE(review): scheduler.step() before train() is the pre-1.1
        # PyTorch ordering; modern PyTorch expects it after optimizer.step()
        # (i.e. after the epoch's training). Moving it would shift the LR
        # schedule by one epoch -- confirm before changing.
        scheduler.step()

        # train for one epoch
        train_loss, train_EPE = train(train_loader, model, optimizer, epoch,
                                      train_writer)
        train_writer.add_scalar('mean EPE', train_EPE, epoch)

        # evaluate on validation set

        EPE = validate(val_loader, model, epoch, output_writers)
        test_writer.add_scalar('mean EPE', EPE, epoch)

        # A negative best_EPE marks "no result yet" (presumably initialized
        # to -1 at module level -- confirm).
        if best_EPE < 0:
            best_EPE = EPE

        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.module.state_dict(),
                'best_EPE': best_EPE,
                'div_flow': args.div_flow
            }, is_best)
Пример #10
0
    def __getitem__(self, index):
        """Load a stereo pair + left disparity and apply training-time
        augmentation.

        Returns a tuple ``(left_img, right_img, dataL)`` where the images
        are transformed tensors (via ``preprocess.get_transform``) padded to
        512x768 and ``dataL`` is the float32 disparity map, zero-padded on
        the top/right to the same size.
        """
        left = self.left[index]
        right = self.right[index]
        # assumes self.loader returns PIL images (the torchvision
        # functional adjust_* calls below require PIL input) -- TODO confirm
        left_img = self.loader(left)
        right_img = self.loader(right)
        disp_L = self.disp_L[index]
        dataL = self.dploader(disp_L)
        # Invalid (infinite) disparities are zeroed out.
        dataL[dataL == np.inf] = 0

        # Right disparity is optional; only needed for the flip augmentation.
        if not (self.disp_R is None):
            disp_R = self.disp_R[index]
            dataR = self.dploader(disp_R)
            dataR[dataR == np.inf] = 0

        # Target canvas size after crop + pad: 512 x 768.
        max_h = 2048 // 4
        max_w = 3072 // 4

        # photometric unsymmetric-augmentation
        # Each eye gets an independent brightness/gamma/contrast jitter.
        random_brightness = np.random.uniform(0.5, 2., 2)
        random_gamma = np.random.uniform(0.8, 1.2, 2)
        random_contrast = np.random.uniform(0.8, 1.2, 2)
        left_img = torchvision.transforms.functional.adjust_brightness(
            left_img, random_brightness[0])
        left_img = torchvision.transforms.functional.adjust_gamma(
            left_img, random_gamma[0])
        left_img = torchvision.transforms.functional.adjust_contrast(
            left_img, random_contrast[0])
        right_img = torchvision.transforms.functional.adjust_brightness(
            right_img, random_brightness[1])
        right_img = torchvision.transforms.functional.adjust_gamma(
            right_img, random_gamma[1])
        right_img = torchvision.transforms.functional.adjust_contrast(
            right_img, random_contrast[1])
        right_img = np.asarray(right_img)
        left_img = np.asarray(left_img)

        # horizontal flip
        # Mirror both views AND swap left/right (with their disparities) so
        # the stereo geometry stays consistent; requires right disparity.
        if not (self.disp_R is None):
            if np.random.binomial(1, 0.5):
                tmp = right_img
                right_img = left_img[:, ::-1]
                left_img = tmp[:, ::-1]
                tmp = dataR
                dataR = dataL[:, ::-1]
                dataL = tmp[:, ::-1]

        # geometric unsymmetric-augmentation
        # Half the time, perturb the right image vertically (angle/pixels).
        angle = 0
        px = 0
        if np.random.binomial(1, 0.5):
            angle = 0.1
            px = 2
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomVdisp(angle, px),
            flow_transforms.Scale(np.random.uniform(self.rand_scale[0],
                                                    self.rand_scale[1]),
                                  order=self.order),
            flow_transforms.RandomCrop((max_h, max_w)),
        ])
        augmented, dataL = co_transform([left_img, right_img], dataL)
        left_img = augmented[0]
        right_img = augmented[1]

        # randomly occlude a region
        # Fill a random rectangle of the right image with its mean color to
        # simulate occlusion.
        if np.random.binomial(1, 0.5):
            sx = int(np.random.uniform(50, 150))
            sy = int(np.random.uniform(50, 150))
            cx = int(np.random.uniform(sx, right_img.shape[0] - sx))
            cy = int(np.random.uniform(sy, right_img.shape[1] - sy))
            right_img[cx - sx:cx + sx,
                      cy - sy:cy + sy] = np.mean(np.mean(right_img, 0),
                                                 0)[np.newaxis, np.newaxis]

        # Zero-pad images on the top and right up to (max_h, max_w).
        h, w, _ = left_img.shape
        top_pad = max_h - h
        left_pad = max_w - w
        left_img = np.lib.pad(left_img, ((top_pad, 0), (0, left_pad), (0, 0)),
                              mode='constant',
                              constant_values=0)
        right_img = np.lib.pad(right_img,
                               ((top_pad, 0), (0, left_pad), (0, 0)),
                               mode='constant',
                               constant_values=0)

        # Pad the disparity the same way (temporarily lifted to 4-D so the
        # same pad spec applies, then squeezed back to 2-D).
        dataL = np.expand_dims(np.expand_dims(dataL, 0), 0)
        dataL = np.lib.pad(dataL,
                           ((0, 0), (0, 0), (top_pad, 0), (0, left_pad)),
                           mode='constant',
                           constant_values=0)[0, 0]
        dataL = np.ascontiguousarray(dataL, dtype=np.float32)

        # Final normalization / tensor conversion.
        processed = preprocess.get_transform()
        left_img = processed(left_img)
        right_img = processed(right_img)
        return (left_img, right_img, dataL)
Пример #11
0
def main():
    """Train a superpixel network (SLIC-style), validating each epoch and
    checkpointing on the best semantic score.

    NOTE(review): ``args`` is used but never parsed here -- presumably
    ``parser.parse_args()`` runs at module level; ``intrinsic`` is declared
    global but never used in this function. Confirm both.
    """
    global args, best_EPE, save_path, intrinsic

    # ============= savor setting ===================
    # Encode the hyper-parameters in the output directory name.
    save_path = '{}_{}_{}epochs{}_b{}_lr{}_posW{}'.format(
        args.arch,
        args.solver,
        args.epochs,
        '_epochSize'+str(args.epoch_size) if args.epoch_size > 0 else '',
        args.batch_size,
        args.lr,
        args.pos_weight,
    )
    if not args.no_date:
        timestamp = datetime.datetime.now().strftime("%y_%m_%d_%H_%M")
    else:
        timestamp = ''
    save_path = os.path.abspath(args.savepath) + '/' + os.path.join(args.dataset, save_path  +  '_' + timestamp )

    # ==========  Data loading code ==============
    # Scale raw pixels to [0, 1], then subtract the per-channel mean.
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0,0,0], std=[255,255,255]),
        transforms.Normalize(mean=[0.411,0.432,0.45], std=[1,1,1])
    ])

    # Validation uses the same normalization (no augmentation).
    val_input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
    ])

    target_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
    ])

    # Joint crop + flips applied identically to input and target.
    co_transform = flow_transforms.Compose([
            flow_transforms.RandomCrop((args.train_img_height ,args.train_img_width)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])

    print("=> loading img pairs from '{}'".format(args.data))
    train_set, val_set = datasets.__dict__[args.dataset](
        args.data,
        transform=input_transform,
        val_transform = val_input_transform,
        target_transform=target_transform,
        co_transform=co_transform
    )
    print('{} samples found, {} train samples and {} val samples '.format(len(val_set)+len(train_set),
                                                                           len(train_set),
                                                                           len(val_set)))
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=True, shuffle=True, drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=True, shuffle=False, drop_last=True)

    # ============== create model ====================
    if args.pretrained:
        network_data = torch.load(args.pretrained)
        # The checkpoint records which architecture it was trained with.
        args.arch = network_data['arch']
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        network_data = None
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch]( data = network_data).cuda()
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    #=========== creat optimizer, we use adam by default ==================
    assert(args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    # Biases get their own (typically zero) weight decay.
    param_groups = [{'params': model.module.bias_parameters(), 'weight_decay': args.bias_decay},
                    {'params': model.module.weight_parameters(), 'weight_decay': args.weight_decay}]
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups, args.lr,
                                     betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups, args.lr,
                                    momentum=args.momentum)

    # for continues training
    # Resume optimizer state / epoch counter only when the checkpoint comes
    # from the same dataset; reuse the checkpoint's directory as save_path.
    if args.pretrained and ('dataset' in network_data):
        if args.pretrained and args.dataset == network_data['dataset'] :
            optimizer.load_state_dict(network_data['optimizer'])
            best_EPE = network_data['best_EPE']
            args.start_epoch = network_data['epoch']
            save_path = os.path.dirname(args.pretrained)

    print('=> will save everything to {}'.format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    train_writer = SummaryWriter(os.path.join(save_path, 'train'))
    val_writer = SummaryWriter(os.path.join(save_path, 'val'))

    # spixelID: superpixel ID for visualization,
    # XY_feat: the coordinate feature for position loss term
    spixelID, XY_feat_stack = init_spixel_grid(args)
    val_spixelID,  val_XY_feat_stack = init_spixel_grid(args, b_train=False)


    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_avg_slic, train_avg_sem, iteration = train(train_loader, model, optimizer, epoch,
                                                         train_writer, spixelID, XY_feat_stack )
        if epoch % args.record_freq == 0:
            train_writer.add_scalar('Mean avg_slic', train_avg_slic, epoch)

        # evaluate on validation set and save the module( and choose the best)
        with torch.no_grad():
            avg_slic, avg_sem  = validate(val_loader, model, epoch, val_writer, val_spixelID, val_XY_feat_stack)
            if epoch % args.record_freq == 0:
                val_writer.add_scalar('Mean avg_slic', avg_slic, epoch)

        rec_dict = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.module.state_dict(),
                'best_EPE': best_EPE,
                'optimizer': optimizer.state_dict(),
                'dataset': args.dataset
            }

        # Stop once the final milestone plus the extra step budget is reached.
        if (iteration) >= (args.milestones[-1] + args.additional_step):
            save_checkpoint(rec_dict, is_best =False, filename='%d_step.tar' % iteration)
            print("Train finished!")
            break

        # A negative best_EPE marks "no result yet"; the semantic score
        # (avg_sem) is the model-selection metric despite the EPE name.
        if best_EPE < 0:
            best_EPE = avg_sem
        is_best = avg_sem < best_EPE
        best_EPE = min(avg_sem, best_EPE)
        save_checkpoint(rec_dict, is_best)
Пример #12
0
def main():
    """Train FlowNetS on an optical-flow dataset with TensorBoard logging,
    MultiStepLR scheduling and best-checkpoint tracking.

    Configuration comes from the module-level argparse ``parser``.
    """
    global args, best_EPE
    args = parser.parse_args()
    #checkpoints and model_args
    # Encode the hyper-parameters in the output directory name.
    save_path = '{},{},{}epochs{},b{},lr{}'.format(
        args.arch, args.solver, args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.batch_size, args.lr)
    if not args.no_date:
        timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
        save_path = os.path.join(timestamp, save_path)
    save_path = os.path.join(args.dataset, save_path)
    print('=> will save everything to {}'.format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

#tensorboardX
    train_writer = SummaryWriter(
        os.path.join(save_path, 'train')
    )  #'KITTI_occ/05-29-11:36/flownets,adam,300epochs,epochSize1000,b8,lr0.0001'
    test_writer = SummaryWriter(os.path.join(save_path, 'test'))  #
    output_writers = []
    for i in range(3):
        output_writers.append(
            SummaryWriter(os.path.join(save_path, 'test', str(i))))


# Data loading code
    # Scale raw pixels to [0, 1], then subtract the per-channel mean.
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
    ])
    # Flow targets are divided by div_flow so the network regresses
    # small values.
    target_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0], std=[args.div_flow, args.div_flow])
    ])

    # KITTI ground truth is sparse; translation/rotation would invalidate
    # the sparse flow samples, so only crop + flips are used there.
    if 'KITTI' in args.dataset:
        args.sparse = True
    if args.sparse:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomCrop((320, 448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])
    else:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomTranslate(10),
            flow_transforms.RandomRotate(10, 5),
            flow_transforms.RandomCrop((320, 448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])

    print("=> fetching img pairs in '{}'".format(args.data))
    train_set, test_set = datasets.__dict__[args.dataset](
        args.data,
        transform=input_transform,
        target_transform=target_transform,
        co_transform=co_transform,
        split=args.split_file if args.split_file else args.split_value)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             shuffle=False)

    # create model (this setup pattern is worth borrowing)
    model = FlowNetS()
    if args.pretrained:
        network_data = torch.load(args.pretrained)
        # NOTE(review): load_state_dict returns a NamedTuple of
        # missing/unexpected keys, not the module, so `.to(device)` here
        # will raise; and `network_data` is likely the full checkpoint dict
        # (network_data['state_dict'] may be what load_state_dict needs).
        # This branch appears untested -- confirm before relying on it.
        model.load_state_dict(network_data).to(device)

        args.arch = network_data['arch']  #flownets_bn
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        #network_data = None
        #model.load_state_dict(network_data).to(device)
        model.init_weights()
        print("=> creating model '{}'".format(args.arch))

    # (the whole network structure used to be loaded directly here...)

    #model = models.__dict__[args.arch](network_data).to(device)
    model = torch.nn.DataParallel(model).to(device)  # for multi-GPU
    cudnn.benchmark = True

    #model settings
    #train settings
    assert (args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    # Biases get their own (typically zero) weight decay.
    param_groups = [{
        'params': model.module.bias_parameters(),
        'weight_decay': args.bias_decay
    }, {
        'params': model.module.weight_parameters(),
        'weight_decay': args.weight_decay
    }]
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups,
                                     args.lr,
                                     betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups,
                                    args.lr,
                                    momentum=args.momentum)
    #evaluate settings
    # Evaluation-only mode: run one validation pass and exit.
    if args.evaluate:
        best_EPE = validate(val_loader, model, 0, output_writers)
        return

    #Decays the learning rate of each parameter group by gamma once the number of epoch reaches
    # one of the milestones. Notice that such decay can happen simultaneously with other changes
    # to the learning rate from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=0.5)

    # main cycle for train and test
    for epoch in range(args.start_epoch, args.epochs):
        # NOTE(review): scheduler.step() before train() is the pre-1.1
        # PyTorch ordering; modern PyTorch expects it after the epoch's
        # optimizer steps. Moving it shifts the LR schedule -- confirm.
        scheduler.step()

        #1. train for one epoch
        train_loss, train_EPE = train(train_loader, model, optimizer, epoch,
                                      train_writer)
        train_writer.add_scalar('mean EPE', train_EPE, epoch)

        #2. evaluate on validation (test) set
        with torch.no_grad():  #len(val_loader) == (total *(1- 0.8))/batch_size
            EPE = validate(val_loader, model, epoch, output_writers)
        test_writer.add_scalar('mean EPE', EPE, epoch)
        #3. record the best EPE
        # A negative best_EPE marks "no result yet" (presumably initialized
        # to -1 at module level -- confirm).
        if best_EPE < 0:
            best_EPE = EPE

        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)

        #4. save checkpoint
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.module.state_dict(),
                'best_EPE': best_EPE,
                'div_flow': args.div_flow
            }, is_best, save_path)