Example #1

import tempfile
import time
from types import SimpleNamespace

import torch
import torch.utils.tensorboard

# Project-specific helpers (get_dataloader, get_valset_params,
# make_hook_periodic, validate, DatasetImpl, collate_wrapper, init_model,
# construct_train_tools, Losses, test_path) are assumed to be importable
# from the surrounding repository.

def create_hooks(args, model, optimizer, losses, logger, serializer):
    device = torch.device(args.device)
    loader = get_dataloader(get_valset_params(args))
    hooks = {
        'serialization':
        lambda step, samples: serializer.checkpoint_model(
            model, optimizer, global_step=step, samples_passed=samples),
        'validation':
        lambda step, samples: validate(model,
                                       device,
                                       loader,
                                       samples,
                                       logger,
                                       losses,
                                       weights=args.loss_weights,
                                       is_raw=args.is_raw)
    }
    periods = {
        'serialization': args.checkpointing_interval,
        'validation': args.vp
    }
    periodic_hooks = {
        k: make_hook_periodic(hooks[k], periods[k])
        for k in periods
    }
    return periodic_hooks, hooks
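
# Usage sketch (hypothetical): assuming make_hook_periodic(hook, period)
# returns a wrapper that invokes `hook` only when the step counter is a
# multiple of `period`, a training loop would drive the returned hooks once
# per step:
#
#     periodic_hooks, _ = create_hooks(args, model, optimizer, losses,
#                                      logger, serializer)
#     for step in range(args.training_steps):
#         samples = train_one_step(...)  # hypothetical helper
#         for hook in periodic_hooks.values():
#             hook(step, samples)
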
def test_validation():
    args = SimpleNamespace(wdw=0.01,
                           training_steps=1,
                           rs=0,
                           optimizer='ADAM',
                           lr=0.01,
                           half_life=1,
                           device=torch.device('cpu'),
                           num_warmup_steps=0)
    data_path = test_path / 'data/seq'
    shape = [256, 256]
    dataset = DatasetImpl(path=data_path,
                          shape=shape,
                          augmentation=False,
                          collapse_length=1,
                          is_raw=True,
                          max_seq_length=1)
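    # pin_memory requests page-locked host buffers, which speeds up any
    # later host-to-GPU copies when a GPU is in use.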
    data_loader = torch.utils.data.DataLoader(dataset,
                                              collate_fn=collate_wrapper,
                                              batch_size=2,
                                              pin_memory=True,
                                              shuffle=False)
    model = init_model(SimpleNamespace(flownet_path=test_path.parent /
                                       'EV_FlowNet',
                                       mish=False,
                                       sp=None,
                                       prefix_length=0,
                                       suffix_length=0,
                                       max_sequence_length=1,
                                       dynamic_sample_length=False,
                                       event_representation_depth=9),
                       device=args.device)
    optimizer, scheduler = construct_train_tools(args, model)
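    # Loss evaluator over a 4-level resolution pyramid, ordered coarse to
    # fine: (32, 32), (64, 64), (128, 128), (256, 256).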
    evaluator = Losses([
        tuple(map(lambda x: x // 2**i, shape)) for i in range(4)
    ][::-1], 2, args.device)
    with tempfile.TemporaryDirectory() as td:
        logger = torch.utils.tensorboard.SummaryWriter(log_dir=td)
        validate(model=model,
                 device=args.device,
                 loader=data_loader,
                 samples_passed=0,
                 logger=logger,
                 evaluator=evaluator)
        del logger
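    # Short pause, presumably to let any background writer threads wind down
    # before the next test runs.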
    time.sleep(1)
Example #3

import copy
import os
import random
import sys
import time

import torch
import torch.backends.cudnn as cudnn
import torchvision.datasets as dset
import torchvision.models as models
import torchvision.transforms as transforms

# Project-specific helpers (parse_arguments, get_model_path, print_log,
# manipulate_net_architecture, adjust_learning_rate, convert_secs2time,
# time_string, train_model, validate, save_checkpoint, RecorderMeter,
# AverageMeter, WEBSITES_DATASET_PATH) are assumed to be importable from
# the surrounding repository.

def main():
    args = parse_arguments()

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.use_cuda:
        torch.cuda.manual_seed_all(args.seed)
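    # Let cuDNN benchmark and cache the fastest convolution algorithms;
    # this pays off when input sizes stay fixed across iterations.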
    cudnn.benchmark = True

    model_path = get_model_path(args.dataset, args.arch, args.seed)

    # Init logger
    log_file_name = os.path.join(model_path, 'log.txt')
    print("Log file: {}".format(log_file_name))
    log = open(log_file_name, 'w')
    print_log('model path : {}'.format(model_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    for key, value in state.items():
        print_log("{} : {}".format(key, value), log)
    print_log("Random Seed: {}".format(args.seed), log)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("Torch  version : {}".format(torch.__version__), log)
    print_log("Cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Data specifications for the websites dataset
    mean = [0., 0., 0.]
    std = [1., 1., 1.]
    input_size = 224
    num_classes = 4

    # Dataset
    traindir = os.path.join(WEBSITES_DATASET_PATH, 'train')
    valdir = os.path.join(WEBSITES_DATASET_PATH, 'val')

    train_transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    test_transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    data_train = dset.ImageFolder(root=traindir, transform=train_transform)
    data_test = dset.ImageFolder(root=valdir, transform=test_transform)

    # Dataloader
    data_train_loader = torch.utils.data.DataLoader(data_train,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=args.workers,
                                                    pin_memory=True)
    data_test_loader = torch.utils.data.DataLoader(data_test,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

    # Network
    if args.arch == "vgg16":
        net = models.vgg16(pretrained=True)
    elif args.arch == "vgg19":
        net = models.vgg19(pretrained=True)
    elif args.arch == "resnet18":
        net = models.resnet18(pretrained=True)
    elif args.arch == "resnet50":
        net = models.resnet50(pretrained=True)
    elif args.arch == "resnet101":
        net = models.resnet101(pretrained=True)
    elif args.arch == "resnet152":
        net = models.resnet152(pretrained=True)
    else:
        raise ValueError("Network {} not supported".format(args.arch))

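    # torchvision models ship with a 1000-way ImageNet classifier; swap the
    # final layer so the network emits num_classes logits (assumed behavior
    # of manipulate_net_architecture).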
    if num_classes != 1000:
        net = manipulate_net_architecture(model_arch=args.arch,
                                          net=net,
                                          num_classes=num_classes)

    # Loss function
    if args.loss_function == "ce":
        criterion = torch.nn.CrossEntropyLoss()
    else:
        raise ValueError("Loss function {} not supported".format(
            args.loss_function))

    # Cuda
    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    # Optimizer
    momentum = 0.9
    decay = 5e-4
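    # Standard ImageNet-style settings: Nesterov momentum 0.9 with 5e-4
    # weight decay.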
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.learning_rate,
                                momentum=momentum,
                                weight_decay=decay,
                                nesterov=True)

    recorder = RecorderMeter(args.epochs)
    start_time = time.time()
    epoch_time = AverageMeter()

    # Main loop
    for epoch in range(args.epochs):
        current_learning_rate = adjust_learning_rate(args.learning_rate,
                                                     momentum, optimizer,
                                                     epoch, args.gammas,
                                                     args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} '
            '[learning_rate={:6.4f}]'.format(time_string(), epoch,
                                             args.epochs, need_time,
                                             current_learning_rate) +
            ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
                recorder.max_accuracy(False),
                100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train_model(data_loader=data_train_loader,
                                           model=net,
                                           criterion=criterion,
                                           optimizer=optimizer,
                                           epoch=epoch,
                                           log=log,
                                           print_freq=200,
                                           use_cuda=args.use_cuda)

        # evaluate on test set
        print_log("Validation on test dataset:", log)
        val_acc, val_loss = validate(data_test_loader,
                                     net,
                                     criterion,
                                     log=log,
                                     use_cuda=args.use_cuda)
        recorder.update(epoch, train_los, train_acc, val_loss, val_acc)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': copy.deepcopy(args),
            }, model_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(model_path, 'curve.png'))

    log.close()
Example #4

import os
import time
from datetime import datetime

import numpy as np
import torch
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Project-specific helpers (get_args, Apolloscape, PoseNet, PoseNetCriterion,
# get_experiment_name, train, validate, make_figure, make_checkpoint,
# model_results_pred_gt, quaternion_angular_error) are assumed to be
# importable from the surrounding repository.

def main():
    args = get_args()

    print('----- Params for debug: ----------------')
    print(args)

    print('data = {}'.format(args.data))
    print('road = {}'.format(args.road))

    print('Train model ...')

    # Imagenet normalization in case of pre-trained network
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Resize data before using
    transform = transforms.Compose([
        transforms.Resize(260),
        transforms.CenterCrop(250),
        transforms.ToTensor(), normalize
    ])

    train_record = None  # 'Record001'
    train_dataset = Apolloscape(root=args.data,
                                road=args.road,
                                transform=transform,
                                record=train_record,
                                normalize_poses=True,
                                pose_format='quat',
                                train=True,
                                cache_transform=not args.no_cache_transform,
                                stereo=args.stereo)

    val_record = None  # 'Record011'
    val_dataset = Apolloscape(root=args.data,
                              road=args.road,
                              transform=transform,
                              record=val_record,
                              normalize_poses=True,
                              pose_format='quat',
                              train=False,
                              cache_transform=not args.no_cache_transform,
                              stereo=args.stereo)

    # Show datasets
    print(train_dataset)
    print(val_dataset)

    shuffle_data = True

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=shuffle_data)  # batch_size = 75
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=shuffle_data)  # batch_size = 75

    # Get mean and std from dataset
    poses_mean = val_dataset.poses_mean
    poses_std = val_dataset.poses_std

    # Select active device
    if torch.cuda.is_available() and args.device == 'cuda':
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print('device = {}'.format(device))

    # Used as prefix for filenames
    time_str = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Create pretrained feature extractor
    if args.feature_net == 'resnet18':
        feature_extractor = models.resnet18(pretrained=args.pretrained)
    elif args.feature_net == 'resnet34':
        feature_extractor = models.resnet34(pretrained=args.pretrained)
    elif args.feature_net == 'resnet50':
        feature_extractor = models.resnet50(pretrained=args.pretrained)
    else:
        raise ValueError(
            'Feature network {} not supported'.format(args.feature_net))

    # Num features for the last layer before pose regressor
    num_features = args.feature_net_features  # 2048

    experiment_name = get_experiment_name(args)

    # Create model
    model = PoseNet(feature_extractor, num_features=num_features)
    model = model.to(device)

    # Criterion
    criterion = PoseNetCriterion(stereo=args.stereo,
                                 beta=args.beta,
                                 learn_beta=args.learn_beta)
    criterion.to(device)

    # Add all params for optimization
    param_list = [{'params': model.parameters()}]
    if criterion.learn_beta:
        param_list.append({'params': criterion.parameters()})
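    # When args.learn_beta is set, the loss's learnable weighting term(s)
    # are optimized jointly with the network weights.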

    # Create optimizer
    optimizer = optim.Adam(params=param_list, lr=args.lr, weight_decay=0.0005)

    start_epoch = 0

    # Restore from checkpoint if present
    if args.checkpoint is not None:
        checkpoint_file = args.checkpoint

        if os.path.isfile(checkpoint_file):
            print('\nLoading from checkpoint: {}'.format(checkpoint_file))
            checkpoint = torch.load(checkpoint_file)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optim_state_dict'])
            start_epoch = checkpoint['epoch']
            if 'criterion_state_dict' in checkpoint:
                criterion.load_state_dict(checkpoint['criterion_state_dict'])
                print('Loaded criterion params too.')

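    # Resume-aware horizon: train for args.epochs additional epochs beyond
    # any restored epoch counter.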
    n_epochs = start_epoch + args.epochs

    print('\nTraining ...')
    val_freq = args.val_freq
    for e in range(start_epoch, n_epochs):

        # Train for one epoch
        train(train_dataloader,
              model,
              criterion,
              optimizer,
              e,
              n_epochs,
              log_freq=args.log_freq,
              poses_mean=train_dataset.poses_mean,
              poses_std=train_dataset.poses_std,
              device=device,
              stereo=args.stereo)

        # Run validation loop
        if e > 0 and e % val_freq == 0:
            end = time.time()
            validate(val_dataloader,
                     model,
                     criterion,
                     e,
                     log_freq=args.log_freq,
                     device=device,
                     stereo=args.stereo)

        # Make figure
        if e > 0 and args.fig_save > 0 and e % args.fig_save == 0:
            exp_name = '{}_{}'.format(time_str, experiment_name)
            make_figure(model,
                        train_dataloader,
                        poses_mean=poses_mean,
                        poses_std=poses_std,
                        epoch=e,
                        experiment_name=exp_name,
                        device=device,
                        stereo=args.stereo)

        # Make checkpoint
        if e > 0 and e % args.checkpoint_save == 0:
            make_checkpoint(model,
                            optimizer,
                            criterion,
                            epoch=e,
                            time_str=time_str,
                            args=args)

    print('\nn_epochs = {}'.format(n_epochs))

    print('\n=== Test Training Dataset ======')
    pred_poses, gt_poses = model_results_pred_gt(model,
                                                 train_dataloader,
                                                 poses_mean,
                                                 poses_std,
                                                 device=device,
                                                 stereo=args.stereo)

    print('gt_poses = {}'.format(gt_poses.shape))
    print('pred_poses = {}'.format(pred_poses.shape))
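    # Translation error: Euclidean distance over the first three pose
    # components; rotation error: angular distance between the quaternion
    # parts (last four components).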
    t_loss = np.asarray([
        np.linalg.norm(p - t)
        for p, t in zip(pred_poses[:, :3], gt_poses[:, :3])
    ])
    q_loss = np.asarray([
        quaternion_angular_error(p, t)
        for p, t in zip(pred_poses[:, 3:], gt_poses[:, 3:])
    ])

    print('poses_std = {:.3f}'.format(np.linalg.norm(poses_std)))
    print('T: median = {:.3f}, mean = {:.3f}'.format(np.median(t_loss),
                                                     np.mean(t_loss)))
    print('R: median = {:.3f}, mean = {:.3f}'.format(np.median(q_loss),
                                                     np.mean(q_loss)))

    # Save for later visualization
    pred_poses_train = pred_poses
    gt_poses_train = gt_poses

    print('\n=== Test Validation Dataset ======')
    pred_poses, gt_poses = model_results_pred_gt(model,
                                                 val_dataloader,
                                                 poses_mean,
                                                 poses_std,
                                                 device=device,
                                                 stereo=args.stereo)

    print('gt_poses = {}'.format(gt_poses.shape))
    print('pred_poses = {}'.format(pred_poses.shape))
    t_loss = np.asarray([
        np.linalg.norm(p - t)
        for p, t in zip(pred_poses[:, :3], gt_poses[:, :3])
    ])
    q_loss = np.asarray([
        quaternion_angular_error(p, t)
        for p, t in zip(pred_poses[:, 3:], gt_poses[:, 3:])
    ])

    print('poses_std = {:.3f}'.format(np.linalg.norm(poses_std)))
    print('T: median = {:.3f}, mean = {:.3f}'.format(np.median(t_loss),
                                                     np.mean(t_loss)))
    print('R: median = {:.3f}, mean = {:.3f}'.format(np.median(q_loss),
                                                     np.mean(q_loss)))

    # Save for later visualization
    pred_poses_val = pred_poses
    gt_poses_val = gt_poses

    # Save checkpoint
    print('\nSaving model params ....')
    make_checkpoint(model,
                    optimizer,
                    criterion,
                    epoch=n_epochs,
                    time_str=time_str,
                    args=args)
Example #5

import copy
import os
import random
import sys
import time

import torch
import torch.backends.cudnn as cudnn

# Project-specific helpers (parse_arguments, get_model_path, print_log,
# get_data_specs, get_data, get_network, get_num_parameters,
# get_num_trainable_parameters, get_num_non_trainable_parameters,
# adjust_learning_rate, convert_secs2time, time_string, train_target_model,
# validate, save_checkpoint, RecorderMeter, AverageMeter) are assumed to be
# importable from the surrounding repository.

def main():
    args = parse_arguments()

    random.seed(args.pretrained_seed)
    torch.manual_seed(args.pretrained_seed)
    if args.use_cuda:
        torch.cuda.manual_seed_all(args.pretrained_seed)
    cudnn.benchmark = True

    # get a path for saving the model to be trained
    model_path = get_model_path(dataset_name=args.pretrained_dataset,
                                network_arch=args.pretrained_arch,
                                random_seed=args.pretrained_seed)

    # Init logger
    log_file_name = os.path.join(model_path, 'log_seed_{}.txt'.format(args.pretrained_seed))
    print("Log file: {}".format(log_file_name))
    log = open(log_file_name, 'w')
    print_log('save path : {}'.format(model_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    for key, value in state.items():
        print_log("{} : {}".format(key, value), log)
    print_log("Random Seed: {}".format(args.pretrained_seed), log)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("Torch  version : {}".format(torch.__version__), log)
    print_log("Cudnn  version : {}".format(torch.backends.cudnn.version()), log)
    # Get data specs
    num_classes, (mean, std), input_size, num_channels = get_data_specs(
        args.pretrained_dataset, args.pretrained_arch)
    pretrained_data_train, pretrained_data_test = get_data(
        args.pretrained_dataset,
        mean=mean,
        std=std,
        input_size=input_size,
        train_target_model=True)

    pretrained_data_train_loader = torch.utils.data.DataLoader(
        pretrained_data_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    pretrained_data_test_loader = torch.utils.data.DataLoader(
        pretrained_data_test,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    print_log("=> Creating model '{}'".format(args.pretrained_arch), log)
    # Init model, criterion, and optimizer
    net = get_network(args.pretrained_arch, input_size=input_size, num_classes=num_classes, finetune=args.finetune)
    print_log("=> Network :\n {}".format(net), log)
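    # Replicate the network across the first args.ngpu devices for
    # data-parallel training.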
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    non_trainable_params = get_num_non_trainable_parameters(net)
    trainable_params = get_num_trainable_parameters(net)
    total_params = get_num_parameters(net)
    print_log("Trainable parameters: {}".format(trainable_params), log)
    print_log("Non Trainable parameters: {}".format(non_trainable_params), log)
    print_log("Total # parameters: {}".format(total_params), log)

    # define loss function (criterion) and optimizer
    criterion_xent = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion_xent.cuda()

    recorder = RecorderMeter(args.epochs)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.epochs):
        current_learning_rate = adjust_learning_rate(args.learning_rate,
                                                     args.momentum, optimizer,
                                                     epoch, args.gammas,
                                                     args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} '
            '[learning_rate={:6.4f}]'.format(time_string(), epoch,
                                             args.epochs, need_time,
                                             current_learning_rate) +
            ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
                recorder.max_accuracy(False),
                100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train_target_model(
            pretrained_data_train_loader,
            net,
            criterion_xent,
            optimizer,
            epoch,
            log,
            print_freq=args.print_freq,
            use_cuda=args.use_cuda)

        # evaluate on validation set
        print_log("Validation on pretrained test dataset:", log)
        val_acc = validate(pretrained_data_test_loader, net, criterion_xent,
                           log, use_cuda=args.use_cuda)
        recorder.update(epoch, train_los, train_acc, 0., val_acc)

        save_checkpoint({
          'epoch'       : epoch + 1,
          'arch'        : args.pretrained_arch,
          'state_dict'  : net.state_dict(),
          'recorder'    : recorder,
          'optimizer'   : optimizer.state_dict(),
          'args'        : copy.deepcopy(args),
        }, model_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(model_path, 'curve.png'))

    log.close()