Exemplo n.º 1
0
def main():
    """Evaluate a checkpointed model on the test split and print its stats.

    All configuration comes from the module-level ``args`` object: ``amp``,
    ``seed``, ``dataset``, ``batch_size``, ``input_dir`` and
    ``checkpoint_path``. The model is always placed on the CUDA device.
    """
    # optionally switch on mixed-precision computation
    if args.amp:
        mixed_precision.enable_mixed_precision()

    # seed torch's default and CUDA random number generators
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # build the loaders; only the second returned value (test loader) is used
    _train_loader, test_loader, _num_classes = build_dataset(
        dataset=get_dataset(args.dataset),
        batch_size=args.batch_size,
        input_dir=args.input_dir,
    )

    # restore the model, move it to the GPU and hand it to mixed_precision
    device = torch.device("cuda")
    restored = Checkpointer().restore_model_from_checkpoint(
        args.checkpoint_path
    )
    model, _ = mixed_precision.initialize(restored.to(device), None)

    # run the evaluation and report the collected statistics
    stats = AverageMeterSet()
    test(model, test_loader, device, stats)
    print(stats.pretty_string(ignore=model.tasks))
Exemplo n.º 2
0
def main():
    """Prepare data, model and logging, then run the selected training task.

    All configuration comes from the module-level ``args`` object. The
    mixed-precision backend is passed to the training task as a string:
    ``""`` (disabled), ``"torch"`` (``--amp``) or ``"apex"`` (``--apex``).

    Raises:
        SystemExit: with a non-zero status when both ``--amp`` and ``--apex``
            are requested.
    """
    import sys

    # create target output dir (and any missing parents) if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # --amp and --apex are mutually exclusive; report the error on stderr and
    # exit with a failure status so calling scripts can detect it
    if args.amp and args.apex:
        sys.exit("Error: Cannot use both --amp and --apex.")

    # enable mixed-precision computation if desired
    amp = ""
    if args.amp:
        amp = "torch"
    elif args.apex:
        amp = "apex"
        # only the apex path goes through the mixed_precision helper
        mixed_precision.enable_mixed_precision()

    # set the RNG seeds (probably more hidden elsewhere...)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # get the dataset
    dataset = get_dataset(args.dataset)
    encoder_size = get_encoder_size(dataset)

    # get a helper object for tensorboard logging
    log_dir = os.path.join(args.output_dir, args.run_name)
    stat_tracker = StatTracker(log_dir=log_dir)

    # get dataloaders for training and testing
    train_loader, test_loader, num_classes = build_dataset(
        dataset=dataset,
        batch_size=args.batch_size,
        input_dir=args.input_dir,
        labeled_only=args.classifiers,
    )

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer(args.output_dir)
    if args.cpt_load_path:
        # resume from an existing checkpoint
        model = checkpointer.restore_model_from_checkpoint(
            args.cpt_load_path, training_classifier=args.classifiers)
    else:
        # create new model with random parameters
        model = Model(ndf=args.ndf,
                      n_classes=num_classes,
                      n_rkhs=args.n_rkhs,
                      tclip=args.tclip,
                      n_depth=args.n_depth,
                      encoder_size=encoder_size,
                      use_bn=(args.use_bn == 1))
        model.init_weights(init_scale=1.0)
        checkpointer.track_new_model(model)

    model = model.to(torch_device)

    # select which type of training to do
    task = train_classifiers if args.classifiers else train_self_supervised
    task(model, args.learning_rate, dataset, train_loader, test_loader,
         stat_tracker, checkpointer, args.output_dir, torch_device, amp)
Exemplo n.º 3
0
def main():
    """Prepare data, model and logging, then run the selected training task.

    All configuration comes from the module-level ``args`` mapping. The task
    is chosen in priority order: classifier training when
    ``args['classifiers']`` is truthy, decoder training when
    ``args['decoder']`` is truthy, otherwise self-supervised training.
    """
    # create target output dir (and any missing parents) if it doesn't exist
    os.makedirs(args['output_dir'], exist_ok=True)

    # enable mixed-precision computation if desired
    if args['amp']:
        mixed_precision.enable_mixed_precision()

    # set the RNG seeds (probably more hidden elsewhere...)
    torch.manual_seed(args['seed'])
    torch.cuda.manual_seed(args['seed'])

    # get the dataset
    dataset = get_dataset(args['dataset'])
    encoder_size = get_encoder_size(dataset)

    # get a helper object for tensorboard logging
    log_dir = os.path.join(args['output_dir'], args['run_name'])
    stat_tracker = StatTracker(log_dir=log_dir)

    # get dataloaders for training and testing
    train_loader, test_loader, num_classes = build_dataset(
        dataset=dataset,
        batch_size=args['batch_size'],
        input_dir=args['input_dir'],
        labeled_only=args['classifiers'],
    )

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer(args['output_dir'])
    if args['cpt_load_path']:
        # resume from an existing checkpoint
        model = checkpointer.restore_model_from_checkpoint(
            args['cpt_load_path'], training_classifier=args['classifiers'])
    else:
        # create new model with random parameters
        model = Model(ndf=args['ndf'],
                      n_classes=num_classes,
                      n_rkhs=args['n_rkhs'],
                      tclip=args['tclip'],
                      n_depth=args['n_depth'],
                      encoder_size=encoder_size,
                      use_bn=(args['use_bn'] == 1))
        model.init_weights(init_scale=1.0)
        checkpointer.track_new_model(model)

    model = model.to(torch_device)

    # select which type of training to do
    if args['classifiers']:
        task = train_classifiers
    elif args['decoder']:
        task = train_decoder
    else:
        task = train_self_supervised

    task(model, args['learning_rate'], dataset, train_loader, test_loader,
         stat_tracker, checkpointer, args['output_dir'], torch_device)
def obtain_model(
        model_type,
        checkpoint_path='runs/amdim_cpt.pth',
        robust_resume_path='../robust_classif/robustness_applications/models/CIFAR.pt'):
    """Load a pretrained model and wrap it in ``CommonModel``.

    Args:
        model_type: ``'robust'`` restores a ``resnet50`` for the robustness
            library's CIFAR dataset via ``model_utils``; any other value
            restores the model stored at ``checkpoint_path`` through
            ``Checkpointer`` and moves it to the CUDA device.
        checkpoint_path: checkpoint file used for non-robust model types.
        robust_resume_path: checkpoint file used for the ``'robust'`` type.

    Returns:
        A ``CommonModel`` built from the loaded model (already switched to
        eval mode) and ``model_type``.
    """
    if model_type != 'robust':
        checkpointer = Checkpointer()
        print('Loading model')
        model = checkpointer.restore_model_from_checkpoint(checkpoint_path)
        torch_device = torch.device('cuda')
        model = model.to(torch_device)
    else:
        dataset = robustness.datasets.CIFAR()
        model_kwargs = {
            'arch': 'resnet50',
            'dataset': dataset,
            'resume_path': robust_resume_path,
        }
        # NOTE(review): no explicit .to('cuda') on this branch; presumably
        # make_and_restore_model handles device placement — confirm.
        model, _ = model_utils.make_and_restore_model(**model_kwargs)
    model.eval()
    model = CommonModel(model, model_type)
    return model
Exemplo n.º 5
0
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        The entry is looked up in ``self.data``; its ``'image'`` value is
        passed to ``self._load_image`` and the result is returned together
        with the entry's ``'uttid'`` value as
        ``{"uttid": ..., "image": ...}``.
        """
        record = self.data[idx]
        # load the image before reading the utterance id
        loaded = self._load_image(record['image'])
        return {"uttid": record["uttid"], "image": loaded}


if __name__ == '__main__':
    # Script entry point: run every image of the dataset through the restored
    # model and write its "rkhs_glb" output, keyed by utterance id, to a
    # temporary ark file next to args.img_as_feats_scp.
    parser = get_parser()
    args = parser.parse_args()
    # Restore model from the checkpoint
    ckpt = Checkpointer()
    ckpt.restore_model_from_checkpoint(
        cpt_path="amdim_ndf256_rkhs2048_rd10.pth")
    ckpt.model.to('cuda')
    img_tmp_ark = os.path.splitext(args.img_as_feats_scp)[0] + '.tmp.ark'
    ds = ImageDataset(args.places_json)
    # Inference only: disable autograd so no graph is built (and kept alive)
    # for every forward pass.
    with torch.no_grad(), kio.open_or_fd(img_tmp_ark, 'wb') as f:
        # index-based loop: the dataset is accessed through __len__/__getitem__
        for i in tqdm(range(len(ds))):
            item = ds[i]
            # the model is fed a batch of two; slot 0 holds the real image and
            # slot 1 stays all-zero
            batch = torch.zeros(2, 3, 128, 128)
            batch[0] = item["image"]
            batch = batch.to('cuda')
            res_dict = ckpt.model(x1=batch, x2=batch, class_only=True)
            # keep only the row that belongs to the real image
            global_feats = res_dict["rkhs_glb"][:1]
            kio.write_mat(f, global_feats.cpu().detach().numpy(),
                          key=item["uttid"])
Exemplo n.º 6
0
    #update run id in config file
    model_conf['run_id'] = run_id
    json.dump(model_conf, open('config.json', 'w'))

    #Construct DataLoader and checkpointer
    train_loader = Image_Loader(img_path,
                                batch_size=batch_size,
                                shuffle=True,
                                drop_last=True,
                                num_workers=num_workers,
                                input_shape=input_shape,
                                stage='train')
    checkpointer = Checkpointer(run=run)

    # Load checkpoint if given, otherwise construct a new model
    encoder, mi_estimator = checkpointer.restore_model_from_checkpoint()

    # Compute on multiple GPUs, if there are more than one given
    if torch.cuda.device_count() > 1:
        print("Let's use %d GPUs" % torch.cuda.device_count())
        encoder = torch.nn.DataParallel(encoder).module
        mi_estimator = torch.nn.DataParallel(mi_estimator).module
    encoder.to(device)
    mi_estimator.to(device)

    enc_optim = torch.optim.Adam(encoder.parameters(), lr=lr)
    mi_optim = torch.optim.Adam(mi_estimator.parameters(), lr=lr)
    try:
        encoder.train()
        mi_estimator.train()
        torch.autograd.set_detect_anomaly(True)