def main():
    """Evaluate a quantized MixNet-S checkpoint; print its accuracy and score.

    Builds ``mixnet_s`` with first/last-layer quantization enabled, runs
    ``add_lsqmodule`` on it, overlays the weights stored at
    ``args.model_path`` on the freshly initialised state dict, computes the
    MicroNet score, and finally evaluates on the ImageNet test split on GPU.
    """
    args = get_arguments()
    act_bits = args.activation_bits

    net = mixnet_s(
        quan_first=True,
        quan_last=True,
        constr_activation=get_constraint(act_bits, 'activation'),
        preactivation=False,
        bw_act=act_bits,
    )
    test_loader = dataloader_imagenet(
        args.data_root,
        split='test',
        batch_size=args.batch_size,
    )
    add_lsqmodule(net, bit_width=args.weight_bits)

    # Start from the network's own state dict and let the saved entries
    # override it, so keys missing from the file keep their initial values.
    saved_state = torch.load(args.model_path)
    merged_state = net.state_dict()
    merged_state.update(saved_state)
    load_checkpoint(net, merged_state)

    loss_fn = torch.nn.CrossEntropyLoss()

    # The score is computed before the model is moved to the GPU.
    score = get_micronet_score(net, args.weight_bits, args.activation_bits)

    net = net.cuda()
    perf = eval_performance(net, test_loader, loss_fn)

    # Index 1 of the evaluation result is reported as the accuracy.
    print("Accuracy:", perf[1])
    print("Score:", score)
Exemple #2
0
def train(cfg_file: str, ckpt=None) -> None:
    """
    Train a model as described by a config file, then run the test routine.

    :param cfg_file: path to the configuration file
    :param ckpt: optional path to a checkpoint; when given, the model
        parameters stored under its ``"model_state"`` key are loaded before
        training starts
    """
    # Load the config file
    cfg = load_config(cfg_file)

    # Set the random seed
    set_seed(seed=cfg["training"].get("random_seed", 42))

    # Load the data - Trg as (batch, # of frames, joints + 1 )
    train_data, dev_data, test_data, src_vocab, trg_vocab = load_data(cfg=cfg)

    # Build the Progressive Transformer model
    model = build_model(cfg, src_vocab=src_vocab, trg_vocab=trg_vocab)

    if ckpt is not None:
        # Initialise the freshly built model from the given checkpoint
        use_cuda = cfg["training"].get("use_cuda", False)
        model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)
        model.load_state_dict(model_checkpoint["model_state"])

    # for training management, e.g. early stopping and model selection
    trainer = TrainManager(model=model, config=cfg)

    # Store copy of original training config in model dir
    shutil.copy2(cfg_file, trainer.model_dir + "/config.yaml")
    # Log all entries of config
    log_cfg(cfg, trainer.logger)

    # Train the model
    trainer.train_and_validate(train_data=train_data, valid_data=dev_data)

    # Test the model with the best checkpoint. ``ckpt`` is passed explicitly:
    # None tells test() to look the checkpoint up in the model directory
    # itself, and omitting the argument raises TypeError when test() declares
    # it without a default.
    test(cfg_file, ckpt=None)
Exemple #3
0
def main():
    """Predict the top-k classes for one image and print them with timing.

    Reads the command-line arguments, loads the model from the given
    checkpoint onto the selected device, runs ``predict`` and prints the
    returned probabilities and classes, then the total elapsed time.
    """
    # Sample inputs:
    #   flowers/test/1/image_06743.jpg
    #   checkpoints/cp_tmp.pth
    started_at = time()

    # The criterion is not used below; the call is kept as in the original
    # flow in case it has side effects.
    criterion = get_criterion()

    in_arg = get_args_predict()
    device = get_device(in_arg.gpu)
    model = load_checkpoint(in_arg.checkpoint_path, device)
    cat_to_name = load_names(in_arg.category_names)

    predict_kwargs = {
        "image_path": in_arg.image_path,
        "model": model,
        "cat_to_name": cat_to_name,
        "device": device,
        "topk": in_arg.top_k,
    }
    top_ps, top_class = predict(**predict_kwargs)

    print(top_ps)
    print(top_class)

    elapsed = time() - started_at
    print(f"\n** Total Elapsed Runtime: {elapsed:.3f} seconds")
def do_run():
    """List the blobs of the public dashboard container and ingest new ones.

    Reads the input configuration (stanza name, optional ``http_proxy`` /
    ``https_proxy``, ``request_timeout`` in seconds, default 30), requests
    the XML blob listing (prefix ``data``) and, for every blob whose ETag
    ``load_checkpoint`` does not report as already seen, downloads the blob
    as JSON and forwards its "overview", "countries", "regions" and "utlas"
    sections to ``iterate_json_data``. The ETag of every listed blob is then
    passed to ``save_checkpoint``.

    Exits the process with status 2 if a ``RuntimeError`` is raised.
    """
    config = get_input_config()

    # setup some globals
    global STANZA
    STANZA = config.get("name")

    # Only schemes that are actually configured end up in the proxy mapping.
    proxies = {}
    for scheme in ("http", "https"):
        proxy = config.get(scheme + "_proxy")
        if proxy is not None:
            proxies[scheme] = proxy

    request_timeout = int(config.get("request_timeout", 30))

    try:
        req_args = {"verify": True, "timeout": float(request_timeout)}
        if proxies:
            req_args["proxies"] = proxies

        # These are transport options and must be passed as keyword
        # arguments. Passing them via ``params=`` would append them to the
        # URL query string, so verify/timeout/proxies would never apply.
        req = requests.get(
            url=
            "https://publicdashacc.blob.core.windows.net/publicdata?restype=container&comp=list&prefix=data",
            **req_args)
        xmldom = etree.fromstring(req.content)

        blobs = xmldom.xpath('/EnumerationResults/Blobs/Blob')
        for blob in blobs:
            blob_etag = blob.xpath('Properties/Etag')[0].text
            blob_name = blob.xpath('Name')[0].text
            logging.info("Found file=%s etag=%s", blob_name, blob_etag)
            blob_url = "https://publicdashacc.blob.core.windows.net/publicdata/%s" % (
                blob_name)
            if not load_checkpoint(config, blob_etag):
                print("Processing file={}".format(blob_url))
                data_req = requests.get(url=blob_url, **req_args)
                data_json = data_req.json()

                for section in ("overview", "countries", "regions", "utlas"):
                    iterate_json_data(section, data_json, blob_name)
            # NOTE(review): the checkpoint is saved for every listed blob,
            # including ones skipped above - confirm this is intended.
            logging.info("Marking file=%s etag=%s as processed", blob_name,
                         blob_etag)
            save_checkpoint(config, blob_etag)

    except RuntimeError as e:
        # "except X as e" is valid on Python 2.6+ and Python 3; the old
        # "except X, e" form is a syntax error on Python 3.
        logging.error("Looks like an error: %s", str(e))
        sys.exit(2)
Exemple #5
0
def main():
    """Classify one image and print each predicted label with its probability.

    The image at ``in_args.image_dir`` is opened and preprocessed, the model
    is restored from ``in_args.checkpoint``, and the predicted classes are
    translated to names through the category-name mapping.
    """
    in_args = predict_input_args()

    image = process_image(Image.open(in_args.image_dir))
    model = load_checkpoint(in_args.checkpoint)
    cat_to_name = category_name(in_args.category_names)

    probs, classes = predict(image, in_args, model)

    # Translate class ids to names, then pair each name with its probability.
    top_labels = []
    for class_id in classes:
        top_labels.append(cat_to_name[class_id])
    label_prob = dict(zip(top_labels, probs))

    for label in label_prob:
        print(f'image name: {label} , probability: {label_prob[label]}')
    print(f'Device: {in_args.gpu}')
Exemple #6
0
    def init_from_checkpoint(self, path: str) -> None:
        """
        Restore model, optimizer, scheduler and counters from a checkpoint.

        :param path: path of the checkpoint to load
        """
        state = load_checkpoint(path=path, use_cuda=self.use_cuda)

        # Model and optimizer parameters are always restored.
        self.model.load_state_dict(state["model_state"])
        self.optimizer.load_state_dict(state["optimizer_state"])

        # The scheduler is restored only if both a saved state and a
        # scheduler instance exist.
        scheduler_state = state["scheduler_state"]
        if not (scheduler_state is None or self.scheduler is None):
            self.scheduler.load_state_dict(scheduler_state)

        # Restore the training counters and best-checkpoint bookkeeping.
        for field in ("steps", "total_tokens", "best_ckpt_score",
                      "best_ckpt_iteration"):
            setattr(self, field, state[field])

        # Move parameters to cuda
        if self.use_cuda:
            self.model.cuda()
Exemple #7
0
def test(cfg_file, ckpt: str = None) -> None:
    """
    Load a model from a checkpoint, run it on the test set and produce videos.

    :param cfg_file: path to the configuration file
    :param ckpt: path to the checkpoint to load; when None, the latest
        checkpoint with the ``_best`` post-fix in the configured model
        directory is used
    :raises FileNotFoundError: if ``ckpt`` is None and the model directory
        contains no such checkpoint
    """
    # Load the config file
    cfg = load_config(cfg_file)
    train_cfg = cfg["training"]

    # Load the model directory and checkpoint
    model_dir = train_cfg["model_dir"]
    # when checkpoint is not specified, take latest (best) from model dir
    if ckpt is None:
        ckpt = get_latest_checkpoint(model_dir, post_fix="_best")
        if ckpt is None:
            raise FileNotFoundError(
                "No checkpoint found in directory {}.".format(model_dir))

    # Evaluation settings fall back to the training settings when unset.
    batch_size = train_cfg.get("eval_batch_size", train_cfg["batch_size"])
    batch_type = train_cfg.get(
        "eval_batch_type", train_cfg.get("batch_type", "sentence"))
    use_cuda = train_cfg.get("use_cuda", False)
    eval_metric = train_cfg["eval_metric"]
    max_output_length = train_cfg.get("max_output_length", None)

    # load the data (only the test split and the vocabularies are used here)
    _train_data, _dev_data, test_data, src_vocab, trg_vocab = \
        load_data(cfg=cfg)

    # To produce testing results
    data_to_predict = {"test": test_data}
    # To produce validation results instead, map "dev" to the dev split.

    # Load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # Build model and load parameters into it
    model = build_model(cfg, src_vocab=src_vocab, trg_vocab=trg_vocab)
    model.load_state_dict(model_checkpoint["model_state"])
    # If cuda, set model as cuda
    if use_cuda:
        model.cuda()

    # Set up trainer to produce videos
    trainer = TrainManager(model=model, config=cfg, test=True)

    # For each of the required data, produce results
    for data_set_name, data_set in data_to_predict.items():

        # Validate for this data set. Strings are compared by value (!=):
        # an identity test against a literal depends on interning.
        _score, _loss, references, hypotheses, \
            inputs, _all_dtw_scores, file_paths = \
            validate_on_data(
                model=model,
                data=data_set,
                batch_size=batch_size,
                max_output_length=max_output_length,
                eval_metric=eval_metric,
                loss_function=None,
                batch_type=batch_type,
                type="val" if data_set_name != "train" else "train_inf"
            )

        # Produce a video for every generated sequence
        display = list(range(len(hypotheses)))

        # Produce videos for the produced hypotheses
        trainer.produce_validation_video(
            output_joints=hypotheses,
            inputs=inputs,
            references=references,
            model_dir=model_dir,
            display=display,
            type="test",
            file_paths=file_paths,
        )
Exemple #8
0
# Substrings identifying activation quantizers that get the symmetric
# ('weight') constraint; every other quantizer gets the asymmetric
# ('activation') one.
_EFFICIENTNET_SYMMETRIC_ACTS = ("_in_act_quant", "first_act",
                                "_head_act_quant0", "_head_act_quant1")
_MIXNET_SYMMETRIC_ACTS = ("last_act", "out_act_quant", "first_act")


def _activation_constraint(layer_name, bits, symmetric_markers):
    """Return the constraint for one activation quantizer at ``bits`` bits."""
    symmetric = any(marker in layer_name for marker in symmetric_markers)
    return get_constraint(bits, 'weight' if symmetric else 'activation')


def main():
    """Train an LSQ-quantized network together with a teacher network.

    Steps, in order:

    1. Select the student constructor, the data loader and the teacher
       (``t_net``) from ``args.dataset`` / ``args.network``.
    2. Build the student and insert LSQ modules. Per-layer bit widths come
       from one of ``--cem`` (vector plus strategy file), ``--manual``
       (hard-coded dicts) or ``--haq`` (hard-coded list); otherwise the
       uniform ``args.weight_bits`` is used.
    3. Overlay the checkpoint found under ``args.model_root`` on the
       student's state dict and load it non-strictly.
    4. Move both networks to the GPU, wrap them in ``nn.DataParallel`` and
       run ``Trainer``.

    :raises ValueError: if the selected configuration has no teacher network
        or no bit-width strategy for the chosen mode
    """
    # dirname(__file__) is '' when __file__ is a bare relative file name, and
    # os.chdir('') raises FileNotFoundError; resolve to an absolute path first.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    args = get_arguments()
    constr_weight = get_constraint(args.weight_bits, 'weight')  # not used below
    constr_activation = get_constraint(args.activation_bits, 'activation')

    t_net = None
    if args.dataset == 'cifar10':
        network = resnet20
        dataloader = dataloader_cifar10
    elif args.dataset == 'cifar100':
        t_net = WRN40_6()
        state = torch.load(
            "/prj/neo_lv/user/ybhalgat/LSQ-KD-0911/cifar100_pretrained/wrn40_6.pth"
        )
        t_net.load_state_dict(state)
        network = WRN40_4
        dataloader = dataloader_cifar100
    else:
        if args.network == 'resnet18':
            network = resnet18
        elif args.network == 'resnet50':
            network = resnet50
        elif args.network == 'efficientnet-b0':
            t_net = EfficientNet.from_pretrained("efficientnet-b1")
            network = efficientnet_b0
        elif args.network == "mixnet_s":
            t_net = MixNet(net_type=args.teacher)
            t_net.load_state_dict(
                torch.load("../imagenet_pretrained/" + args.teacher + ".pth"))
            network = mixnet_s
        else:
            print('Not Support Network Type: %s' % args.network)
            return
        dataloader = dataloader_imagenet

    # The teacher is used unconditionally further down. Fail here with a
    # clear message instead of an UnboundLocalError after data and weights
    # have already been loaded.
    if t_net is None:
        raise ValueError(
            'No teacher network is defined for dataset=%s network=%s' %
            (args.dataset, args.network))

    train_loader = dataloader(args.data_root,
                              split='train',
                              batch_size=args.batch_size)
    test_loader = dataloader(args.data_root,
                             split='test',
                             batch_size=args.batch_size)
    net = network(quan_first=args.quan_first,
                  quan_last=args.quan_last,
                  constr_activation=constr_activation,
                  preactivation=args.preactivation,
                  bw_act=args.activation_bits)

    if args.cem:
        # CEM vector for 1.5x_W7A7_CEM prefinetuning 70%
        cem_input = [
            7, 7, 7, 7, 7, 5, 7, 7, 7, 7, 7, 7, 7, 7, 6, 6, 7, 7, 6, 7, 6, 7,
            7, 7, 7, 6, 7, 7, 7, 7, 6, 7, 7, 7, 7, 4, 7, 6, 7, 5, 7, 7, 7, 7,
            7, 5, 7, 7, 7, 5, 5, 7, 7, 7, 5, 6, 7, 7, 7, 6, 4, 7, 7, 6, 5, 4,
            7, 6, 5, 5, 4, 7, 7, 6, 5, 4, 7, 7, 6, 5, 5, 3
        ]

        # Each strategy line is comma separated: the activation layer name
        # first, the weight layer name second.
        strategy_path = "/prj/neo_lv/user/ybhalgat/LSQ-implementation/lsq_quantizer/cem_strategy_relaxed.txt"
        with open(strategy_path) as fp:
            strategy = [line.strip().split(",") for line in fp]

        # Collect the module names once instead of rescanning the network
        # for every entry of the CEM vector.
        module_names = set()
        for name, _ in net.named_modules():
            if name.startswith('module'):
                name = name[7:]  # remove `module.`
            module_names.add(name)

        strat = {}
        act_strat = {}
        for idx, width in enumerate(cem_input):
            act_layer_name = strategy[idx][0]
            weight_layer_name = strategy[idx][1]
            if weight_layer_name in module_names:
                strat[weight_layer_name] = int(width)
            if act_layer_name in module_names:
                act_strat[act_layer_name] = int(width)

        add_lsqmodule(net, bit_width=args.weight_bits, strategy=strat)

        for name, module in net.named_modules():
            if name not in act_strat:
                continue
            if "efficientnet" in args.network:
                markers = _EFFICIENTNET_SYMMETRIC_ACTS
            elif "mixnet" in args.network:
                markers = _MIXNET_SYMMETRIC_ACTS
            else:
                raise ValueError(
                    'No CEM activation constraints defined for network %s' %
                    args.network)
            module.constraint = _activation_constraint(
                name, act_strat[name], markers)

    elif args.manual:
        if args.network == "wrn40_4":
            strategy = {
                "block3.layer.0.conv2": 3,
                "block3.layer.2.conv1": 3,
                "block3.layer.3.conv1": 3,
                "block3.layer.4.conv1": 3,
                "block3.layer.2.conv2": 3,
                "block3.layer.1.conv2": 3,
                "block3.layer.3.conv2": 3,
                "block3.layer.1.conv1": 3,
                "block3.layer.5.conv1": 2,
                "block1.layer.1.conv2": 1
            }
            act_strategy = {
                "block3.layer.0.relu2": 3,
                "block3.layer.2.relu1": 3,
                "block3.layer.3.relu1": 3,
                "block3.layer.4.relu1": 3,
                "block3.layer.2.relu2": 3,
                "block3.layer.1.relu2": 3,
                "block3.layer.3.relu2": 3,
                "block3.layer.1.relu1": 3,
                "block3.layer.5.relu1": 2,
                "block1.layer.1.relu2": 1
            }

        elif args.network == 'efficientnet-b0':
            strategy = {
                "_fc": 3,
                "_conv_head": 5,
                "_blocks.15._project_conv": 5,
                "_blocks.15._expand_conv": 4,
                "_blocks.14._expand_conv": 4,
                "_blocks.13._expand_conv": 4,
                "_blocks.12._expand_conv": 4,
                "_blocks.13._project_conv": 4,
                "_blocks.14._project_conv": 4,
                "_blocks.12._project_conv": 5,
                "_blocks.9._expand_conv": 4,
                "_blocks.10._expand_conv": 4
            }
            act_strategy = {
                "_head_act_quant1": 3,
                "_head_act_quant0": 5,
                "_blocks.15._pre_proj_activation": 5,
                "_blocks.15._in_act_quant": 4,
                "_blocks.14._in_act_quant": 4,
                "_blocks.13._in_act_quant": 4,
                "_blocks.12._in_act_quant": 4,
                "_blocks.13._pre_proj_activation": 4,
                "_blocks.14._pre_proj_activation": 4,
                "_blocks.12._pre_proj_activation": 5,
                "_blocks.9._in_act_quant": 4,
                "_blocks.10._in_act_quant": 4
            }
        else:
            raise ValueError(
                'No manual bit-width strategy defined for network %s' %
                args.network)
        add_lsqmodule(net, bit_width=args.weight_bits, strategy=strategy)

        for name, module in net.named_modules():
            if name in act_strategy:
                module.constraint = _activation_constraint(
                    name, act_strategy[name], _EFFICIENTNET_SYMMETRIC_ACTS)

    elif args.haq:
        if args.network == 'resnet50':
            strategy = [
                6, 6, 5, 5, 5, 5, 4, 5, 5, 4, 5, 5, 5, 5, 5, 5, 3, 5, 4, 3, 5,
                4, 3, 4, 4, 4, 2, 5, 4, 3, 3, 5, 3, 2, 5, 3, 2, 4, 3, 2, 5, 3,
                2, 5, 3, 4, 2, 5, 2, 3, 4, 2, 3, 4
            ]
        elif args.network == 'efficientnet-b0':
            strategy = [
                7, 8, 8, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
                7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 7, 6, 7, 6, 7, 6, 5, 6, 5, 6, 4,
                5, 6, 5, 6, 4, 4, 5, 4, 5, 2, 3, 4, 3, 4, 2, 3, 4, 4, 7, 5, 2,
                4, 2, 5, 5, 2, 4, 2, 5, 5, 2, 4, 2, 5, 5, 2, 4, 3, 3, 2
            ]
        else:
            raise ValueError(
                'No HAQ bit-width strategy defined for network %s' %
                args.network)
        add_lsqmodule(net, strategy=strategy)

    else:
        add_lsqmodule(net, bit_width=args.weight_bits)

    # Try '<name>.pth.tar' first, then '<name>.pth' (suffix '.tar' dropped).
    model_path = os.path.join(args.model_root, args.model_name + '.pth.tar')
    if not os.path.exists(model_path):
        model_path = model_path[:-4]
    # Saved entries override the freshly initialised ones; keys missing
    # from the file keep their initial values.
    name_weights_old = torch.load(model_path)
    name_weights_new = net.state_dict()
    name_weights_new.update(name_weights_old)
    load_checkpoint(net, name_weights_new, strict=False)

    print(net)
    net = net.cuda()
    net = nn.DataParallel(net)

    t_net = t_net.cuda()
    t_net = nn.DataParallel(t_net)

    if args.pruned:
        start_LSQ(net)

    quan_activation = isinstance(constr_activation, np.ndarray)
    postfix = '_w' if not quan_activation else '_a'
    new_model_name = args.prefix + args.model_name + '_lsq' + postfix
    cache_root = os.path.join('.', 'cache')
    train_loger = LogHelper(new_model_name, cache_root, quan_activation,
                            args.resume)
    optimizer, lr_scheduler, optimizer_t, lr_scheduler_t = get_optimizer(
        s_net=net,
        t_net=t_net,
        optimizer=args.optimizer,
        lr_base=args.learning_rate,
        weight_decay=args.weight_decay,
        lr_scheduler=args.lr_scheduler,
        total_epoch=args.total_epoch,
        quan_activation=quan_activation,
        act_lr_factor=args.act_lr_factor,
        weight_lr_factor=args.weight_lr_factor)
    trainer = Trainer(net=net,
                      t_net=t_net,
                      train_loader=train_loader,
                      test_loader=test_loader,
                      optimizer=optimizer,
                      optimizer_t=optimizer_t,
                      lr_scheduler=lr_scheduler,
                      lr_scheduler_t=lr_scheduler_t,
                      model_name=new_model_name,
                      train_loger=train_loger,
                      pruned=args.pruned)
    trainer(total_epoch=args.total_epoch,
            save_check_point=True,
            resume=args.resume)
def main():
    """Train an LSQ-quantized network together with a teacher network.

    Selects the student constructor, data loader and teacher (``t_net``)
    from ``args.dataset`` / ``args.network``, loads the student checkpoint
    found under ``args.model_root``, inserts LSQ modules (uniform
    ``args.weight_bits``, or a hard-coded per-layer list with ``--haq``),
    moves both networks to all visible GPUs and runs ``Trainer``.

    :raises ValueError: if the selected configuration has no teacher network
        or no HAQ strategy
    """
    # dirname(__file__) is '' when __file__ is a bare relative file name, and
    # os.chdir('') raises FileNotFoundError; resolve to an absolute path first.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    args = get_arguments()
    constr_weight = get_constraint(args.weight_bits, 'weight')  # not used below
    constr_activation = get_constraint(args.activation_bits, 'activation')

    t_net = None
    if args.dataset == 'cifar10':
        network = resnet20
        dataloader = dataloader_cifar10
    elif args.dataset == 'cifar100':
        t_net = ResNet(depth=56, num_classes=100)
        state = torch.load("/prj/neo_lv/user/ybhalgat/LSQ-KD/cifar100_pretrained/resnet56.pth.tar")
        t_net.load_state_dict(state)
        network = resnet20
        dataloader = dataloader_cifar100
    else:
        if args.network == 'resnet18':
            network = resnet18
        elif args.network == 'resnet50':
            network = resnet50
        elif args.network == 'efficientnet-b0':
            t_net = EfficientNet.from_pretrained("efficientnet-b3")
            network = efficientnet_b0
        else:
            print('Not Support Network Type: %s' % args.network)
            return
        dataloader = dataloader_imagenet

    # The teacher is used unconditionally further down. Fail here with a
    # clear message instead of an UnboundLocalError after data and weights
    # have already been loaded.
    if t_net is None:
        raise ValueError(
            'No teacher network is defined for dataset=%s network=%s' %
            (args.dataset, args.network))

    train_loader = dataloader(args.data_root, split='train', batch_size=args.batch_size)
    test_loader = dataloader(args.data_root, split='test', batch_size=args.batch_size)
    net = network(quan_first=args.quan_first,
                  quan_last=args.quan_last,
                  constr_activation=constr_activation,
                  preactivation=args.preactivation,
                  bw_act=args.activation_bits)

    # Try '<name>.pth.tar' first, then '<name>.pth' (suffix '.tar' dropped).
    model_path = os.path.join(args.model_root, args.model_name + '.pth.tar')
    if not os.path.exists(model_path):
        model_path = model_path[:-4]
    # Saved entries override the freshly initialised ones.
    name_weights_old = torch.load(model_path)
    name_weights_new = net.state_dict()
    name_weights_new.update(name_weights_old)
    load_checkpoint(net, name_weights_new)

    # The LSQ modules are added after the checkpoint has been loaded.
    if not args.haq:
        add_lsqmodule(net, bit_width=args.weight_bits)
    else:
        if args.network == 'resnet50':
            strategy = [6, 6, 5, 5, 5, 5, 4, 5, 5, 4, 5, 5, 5, 5, 5, 5, 3, 5, 4, 3, 5, 4, 3, 4, 4, 4, 2, 5,
                        4, 3, 3, 5, 3, 2, 5, 3, 2, 4, 3, 2, 5, 3, 2, 5, 3, 4, 2, 5, 2, 3, 4, 2, 3, 4]
        elif args.network == 'efficientnet-b0':
            strategy = [7, 8, 8, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 6, 6, 6,
                        6, 6, 6, 6, 6, 7, 6, 7, 6, 7, 6, 5, 6, 5, 6, 4, 5, 6, 5, 6, 4, 4, 5, 4, 5, 2,
                        3, 4, 3, 4, 2, 3, 4, 4, 7, 5, 2, 4, 2, 5, 5, 2, 4, 2, 5, 5, 2, 4, 2, 5, 5, 2,
                        4, 3, 3, 2]
        else:
            raise ValueError(
                'No HAQ bit-width strategy defined for network %s' %
                args.network)
        add_lsqmodule(net, strategy=strategy)

    print(net)
    device_ids = range(cuda.device_count())
    net = net.cuda()
    net = nn.DataParallel(net, device_ids=device_ids)

    t_net = t_net.cuda()
    t_net = nn.DataParallel(t_net, device_ids=device_ids)

    quan_activation = isinstance(constr_activation, np.ndarray)
    postfix = '_w' if not quan_activation else '_a'
    new_model_name = args.prefix + args.model_name + '_lsq' + postfix
    cache_root = os.path.join('.', 'cache')
    train_loger = LogHelper(new_model_name, cache_root, quan_activation, args.resume)
    optimizer, lr_scheduler, optimizer_t = get_optimizer(s_net=net,
                                            t_net=t_net,
                                            optimizer=args.optimizer,
                                            lr_base=args.learning_rate,
                                            weight_decay=args.weight_decay,
                                            lr_scheduler=args.lr_scheduler,
                                            total_epoch=args.total_epoch,
                                            quan_activation=quan_activation,
                                            act_lr_factor=args.act_lr_factor,
                                            weight_lr_factor=args.weight_lr_factor)
    trainer = Trainer(net=net,
                      t_net=t_net,
                      train_loader=train_loader,
                      test_loader=test_loader,
                      optimizer=optimizer,
                      optimizer_t=optimizer_t,
                      lr_scheduler=lr_scheduler,
                      model_name=new_model_name,
                      train_loger=train_loger)
    trainer(total_epoch=args.total_epoch,
            save_check_point=True,
            resume=args.resume)
def main():
    """Evaluate a quantized WRN40-4 on CIFAR-100; print accuracy and score.

    Builds ``WRN40_4`` without first/last-layer quantization, inserts LSQ
    modules, optionally applies the hard-coded per-layer bit widths
    (``--cem``), overlays the weights stored at ``args.model_path``, computes
    the MicroNet score and evaluates on the test split on GPU.
    """
    args = get_arguments()
    constr_activation = get_constraint(args.activation_bits, 'activation')

    net = WRN40_4(quan_first=False,
                  quan_last=False,
                  constr_activation=constr_activation,
                  preactivation=False,
                  bw_act=args.activation_bits)
    test_loader = dataloader_cifar100(args.data_root,
                                      split='test',
                                      batch_size=args.batch_size)
    add_lsqmodule(net, bit_width=args.weight_bits)

    # Baseline figures and input resolution passed to the score function.
    score_kwargs = dict(input_res=(3, 32, 32),
                        baseline_params=36500000,
                        baseline_MAC=10490000000)

    if args.cem:
        strategy = {
            "block3.layer.0.conv2": 3,
            "block3.layer.2.conv1": 3,
            "block3.layer.3.conv1": 3,
            "block3.layer.4.conv1": 3,
            "block3.layer.2.conv2": 3,
            "block3.layer.1.conv2": 3,
            "block3.layer.3.conv2": 3,
            "block3.layer.1.conv1": 3,
            "block3.layer.5.conv1": 2,
            "block1.layer.1.conv2": 1
        }
        act_strategy = {
            "block3.layer.0.relu2": 3,
            "block3.layer.2.relu1": 3,
            "block3.layer.3.relu1": 3,
            "block3.layer.4.relu1": 3,
            "block3.layer.2.relu2": 3,
            "block3.layer.1.relu2": 3,
            "block3.layer.3.relu2": 3,
            "block3.layer.1.relu1": 3,
            "block3.layer.5.relu1": 2,
            "block1.layer.1.relu2": 1
        }

        # NOTE(review): add_lsqmodule was already called above without a
        # strategy; confirm that calling it a second time is intended.
        add_lsqmodule(net, bit_width=args.weight_bits, strategy=strategy)

        symmetric_markers = ("_in_act_quant", "first_act",
                             "_head_act_quant0", "_head_act_quant1")
        for name, module in net.named_modules():
            if name not in act_strategy:
                continue
            if any(marker in name for marker in symmetric_markers):
                constraint_kind = 'weight'  # symmetric
            else:
                constraint_kind = 'activation'  # asymmetric
            module.constraint = get_constraint(act_strategy[name],
                                               constraint_kind)

        # The strategies exist only in CEM mode; passing them
        # unconditionally raised NameError whenever --cem was not given.
        score_kwargs.update(weight_strategy=strategy,
                            activation_strategy=act_strategy)

    # Saved entries override the freshly initialised ones.
    name_weights_old = torch.load(args.model_path)
    name_weights_new = net.state_dict()
    name_weights_new.update(name_weights_old)
    load_checkpoint(net, name_weights_new)

    criterion = torch.nn.CrossEntropyLoss()

    # Without --cem the strategy arguments are omitted; this assumes
    # get_micronet_score has defaults for them - TODO confirm.
    score = get_micronet_score(net, args.weight_bits, args.activation_bits,
                               **score_kwargs)

    # Calculate accuracy
    net = net.cuda()

    quan_perf_epoch = eval_performance(net, test_loader, criterion)
    accuracy = quan_perf_epoch[1]

    print("Accuracy:", accuracy)
    print("Score:", score)
Exemple #11
0
def test(
    cfg_file, ckpt: str, output_path: str = None, logger: logging.Logger = None
) -> None:
    """
    Evaluate a trained model on the dev and test partitions.

    Loads the model from a checkpoint, sweeps the configured CTC beam sizes
    (recognition) and beam size / alpha pairs (translation) on the dev set,
    then evaluates the test set once with the best dev settings. Results are
    logged and, when ``output_path`` is given, written to disk as ``.gls`` /
    ``.txt`` hypothesis files plus pickled result dictionaries.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load; when None, the checkpoint returned
        by ``get_latest_checkpoint`` for the configured model dir is used
    :param output_path: prefix of the output files (nothing is written if None)
    :param logger: log output to this logger (creates new logger if not set)
    :raises ValueError: if the config has no "test" entry under "data"
    :raises FileNotFoundError: if no checkpoint is given and none is found
    :raises RuntimeError: if a dev sweep selects no best result (empty sweep
        list, or no score that beats the initial bound, e.g. NaN)
    """

    if logger is None:
        logger = logging.getLogger(__name__)
        if not logger.handlers:
            FORMAT = "%(asctime)-15s - %(message)s"
            logging.basicConfig(format=FORMAT)
            logger.setLevel(level=logging.DEBUG)

    cfg = load_config(cfg_file)

    if "test" not in cfg["data"].keys():
        raise ValueError("Test data must be specified in config.")

    # when checkpoint is not specified, take latest (best) from model dir
    if ckpt is None:
        model_dir = cfg["training"]["model_dir"]
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError(
                "No checkpoint found in directory {}.".format(model_dir)
            )

    batch_size = cfg["training"]["batch_size"]
    batch_type = cfg["training"].get("batch_type", "sentence")
    use_cuda = cfg["training"].get("use_cuda", False)
    level = cfg["data"]["level"]
    dataset_version = cfg["data"].get("version", "phoenix_2014_trans")
    translation_max_output_length = cfg["training"].get(
        "translation_max_output_length", None
    )

    # load the data
    _, dev_data, test_data, gls_vocab, txt_vocab = load_data(data_cfg=cfg["data"])

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # A task is enabled when its loss weight is strictly positive.
    do_recognition = cfg["training"].get("recognition_loss_weight", 1.0) > 0.0
    do_translation = cfg["training"].get("translation_loss_weight", 1.0) > 0.0

    # "feature_size" may be a single number or a list of sizes that are summed.
    feature_size = cfg["data"]["feature_size"]
    sgn_dim = sum(feature_size) if isinstance(feature_size, list) else feature_size

    # build model and load parameters into it
    model = build_model(
        cfg=cfg["model"],
        gls_vocab=gls_vocab,
        txt_vocab=txt_vocab,
        sgn_dim=sgn_dim,
        do_recognition=do_recognition,
        do_translation=do_translation,
    )
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.cuda()

    # Data Augmentation Parameters
    frame_subsampling_ratio = cfg["data"].get("frame_subsampling_ratio", None)
    # Note (Cihan): we are not using 'random_frame_subsampling' and
    #   'random_frame_masking_ratio' in testing as they are just for training.

    # Decoding sweep settings; every entry falls back to a single default
    # setting when the config has no "testing" section.
    testing_cfg = cfg["testing"] if "testing" in cfg else {}
    recognition_beam_sizes = testing_cfg.get("recognition_beam_sizes", [1])
    translation_beam_sizes = testing_cfg.get("translation_beam_sizes", [1])
    translation_beam_alphas = testing_cfg.get("translation_beam_alphas", [-1])
    max_recognition_beam_size = testing_cfg.get("max_recognition_beam_size", None)
    if max_recognition_beam_size is not None:
        # An explicit maximum overrides the list: sweep 1..max inclusive.
        recognition_beam_sizes = list(range(1, max_recognition_beam_size + 1))

    # Loss functions stay None for a disabled task.
    recognition_loss_function = None
    translation_loss_function = None
    if do_recognition:
        recognition_loss_function = torch.nn.CTCLoss(
            blank=model.gls_vocab.stoi[SIL_TOKEN], zero_infinity=True
        )
        if use_cuda:
            recognition_loss_function.cuda()
    if do_translation:
        translation_loss_function = XentLoss(
            pad_index=txt_vocab.stoi[PAD_TOKEN], smoothing=0.0
        )
        if use_cuda:
            translation_loss_function.cuda()

    # NOTE (Cihan): Currently Hardcoded to be 0 for TensorFlow decoding
    assert model.gls_vocab.stoi[SIL_TOKEN] == 0

    def _evaluate(
        data, recognition_beam_size, translation_beam_size, translation_beam_alpha
    ):
        """Run validate_on_data on `data`; arguments of a disabled task become None."""
        return validate_on_data(
            model=model,
            data=data,
            batch_size=batch_size,
            use_cuda=use_cuda,
            batch_type=batch_type,
            dataset_version=dataset_version,
            sgn_dim=sgn_dim,
            txt_pad_index=txt_vocab.stoi[PAD_TOKEN],
            # Recognition Parameters
            do_recognition=do_recognition,
            recognition_loss_function=recognition_loss_function,
            recognition_loss_weight=1 if do_recognition else None,
            recognition_beam_size=recognition_beam_size if do_recognition else None,
            # Translation Parameters
            do_translation=do_translation,
            translation_loss_function=translation_loss_function,
            translation_loss_weight=1 if do_translation else None,
            translation_max_output_length=translation_max_output_length
            if do_translation
            else None,
            level=level if do_translation else None,
            translation_beam_size=translation_beam_size if do_translation else None,
            translation_beam_alpha=translation_beam_alpha
            if do_translation
            else None,
            frame_subsampling_ratio=frame_subsampling_ratio,
        )

    def _log_summary(partition: str, recognition_scores, translation_scores):
        """Log the combined score summary; a disabled task is reported as -1."""
        rec = recognition_scores
        trans = translation_scores
        logger.info(
            "[" + partition + "] partition [Recognition & Translation] results:\n\t"
            "Best CTC Decode Beam Size: %d\n\t"
            "Best Translation Beam Size: %d and Alpha: %d\n\t"
            "WER %3.2f\t(DEL: %3.2f,\tINS: %3.2f,\tSUB: %3.2f)\n\t"
            "Acc %3.2f\n\t"
            "BLEU-4 %.2f\t(BLEU-1: %.2f,\tBLEU-2: %.2f,\tBLEU-3: %.2f,\tBLEU-4: %.2f)\n\t"
            "CHRF %.2f\t"
            "ROUGE %.2f",
            dev_best_recognition_beam_size if do_recognition else -1,
            dev_best_translation_beam_size if do_translation else -1,
            dev_best_translation_alpha if do_translation else -1,
            rec["wer"] if do_recognition else -1,
            rec["wer_scores"]["del_rate"] if do_recognition else -1,
            rec["wer_scores"]["ins_rate"] if do_recognition else -1,
            rec["wer_scores"]["sub_rate"] if do_recognition else -1,
            # Acc
            rec["acc"] if do_recognition else -1,
            trans["bleu"] if do_translation else -1,
            trans["bleu_scores"]["bleu1"] if do_translation else -1,
            trans["bleu_scores"]["bleu2"] if do_translation else -1,
            trans["bleu_scores"]["bleu3"] if do_translation else -1,
            trans["bleu_scores"]["bleu4"] if do_translation else -1,
            trans["chrf"] if do_translation else -1,
            trans["rouge"] if do_translation else -1,
        )

    # Dev Recognition CTC Beam Search Results: the lowest WER wins.
    dev_recognition_results = {}
    dev_best_recognition_beam_size = 1
    dev_best_recognition_result = None
    if do_recognition:
        dev_best_wer_score = float("inf")
        for rbw in recognition_beam_sizes:
            logger.info("-" * 60)
            valid_start_time = time.time()
            logger.info("[DEV] partition [RECOGNITION] experiment [BW]: %d", rbw)
            dev_recognition_results[rbw] = _evaluate(
                dev_data,
                recognition_beam_size=rbw,
                translation_beam_size=1,
                translation_beam_alpha=-1,
            )
            logger.info("finished in %.4fs ", time.time() - valid_start_time)
            rec_scores = dev_recognition_results[rbw]["valid_scores"]
            if rec_scores["wer"] < dev_best_wer_score:
                dev_best_wer_score = rec_scores["wer"]
                dev_best_recognition_beam_size = rbw
                dev_best_recognition_result = dev_recognition_results[rbw]
                logger.info("*" * 60)
                logger.info(
                    "[DEV] partition [RECOGNITION] results:\n\t"
                    "New Best CTC Decode Beam Size: %d\n\t"
                    "WER %3.2f\t(DEL: %3.2f,\tINS: %3.2f,\tSUB: %3.2f)",
                    dev_best_recognition_beam_size,
                    rec_scores["wer"],
                    rec_scores["wer_scores"]["del_rate"],
                    rec_scores["wer_scores"]["ins_rate"],
                    rec_scores["wer_scores"]["sub_rate"],
                )
                logger.info("*" * 60)
        if dev_best_recognition_result is None:
            # Fail here with a clear message instead of an UnboundLocalError
            # in the summary below.
            raise RuntimeError(
                "No dev recognition result was selected; check the recognition "
                "beam sizes in the testing config and the dev WER scores."
            )

    # Dev translation sweep over beam size and alpha: the highest BLEU wins.
    dev_translation_results = {}
    dev_best_translation_beam_size = 1
    dev_best_translation_alpha = 1
    dev_best_translation_result = None
    if do_translation:
        logger.info("=" * 60)
        dev_best_bleu_score = float("-inf")
        for tbw in translation_beam_sizes:
            dev_translation_results[tbw] = {}
            for ta in translation_beam_alphas:
                dev_translation_results[tbw][ta] = _evaluate(
                    dev_data,
                    recognition_beam_size=1,
                    translation_beam_size=tbw,
                    translation_beam_alpha=ta,
                )
                trans_scores = dev_translation_results[tbw][ta]["valid_scores"]
                if trans_scores["bleu"] > dev_best_bleu_score:
                    dev_best_bleu_score = trans_scores["bleu"]
                    dev_best_translation_beam_size = tbw
                    dev_best_translation_alpha = ta
                    dev_best_translation_result = dev_translation_results[tbw][ta]
                    logger.info(
                        "[DEV] partition [Translation] results:\n\t"
                        "New Best Translation Beam Size: %d and Alpha: %d\n\t"
                        "BLEU-4 %.2f\t(BLEU-1: %.2f,\tBLEU-2: %.2f,\tBLEU-3: %.2f,\tBLEU-4: %.2f)\n\t"
                        "CHRF %.2f\t"
                        "ROUGE %.2f",
                        dev_best_translation_beam_size,
                        dev_best_translation_alpha,
                        trans_scores["bleu"],
                        trans_scores["bleu_scores"]["bleu1"],
                        trans_scores["bleu_scores"]["bleu2"],
                        trans_scores["bleu_scores"]["bleu3"],
                        trans_scores["bleu_scores"]["bleu4"],
                        trans_scores["chrf"],
                        trans_scores["rouge"],
                    )
                    logger.info("-" * 60)
        if dev_best_translation_result is None:
            raise RuntimeError(
                "No dev translation result was selected; check the translation "
                "beam sizes/alphas in the testing config and the dev BLEU scores."
            )

    logger.info("*" * 60)
    _log_summary(
        "DEV",
        dev_best_recognition_result["valid_scores"] if do_recognition else None,
        dev_best_translation_result["valid_scores"] if do_translation else None,
    )
    logger.info("*" * 60)

    # Single test run with the settings that performed best on dev.
    test_best_result = _evaluate(
        test_data,
        recognition_beam_size=dev_best_recognition_beam_size,
        translation_beam_size=dev_best_translation_beam_size,
        translation_beam_alpha=dev_best_translation_alpha,
    )

    _log_summary(
        "TEST", test_best_result["valid_scores"], test_best_result["valid_scores"]
    )
    logger.info("*" * 60)

    def _write_to_file(file_path: str, sequence_ids: List[str], hypotheses: List[str]):
        """Write one "<sequence id>|<hypothesis>" line per sample."""
        with open(file_path, mode="w", encoding="utf-8") as out_file:
            for seq, hyp in zip(sequence_ids, hypotheses):
                out_file.write(seq + "|" + hyp + "\n")

    if output_path is not None:
        if do_recognition:
            for split, data, result in (
                ("dev", dev_data, dev_best_recognition_result),
                ("test", test_data, test_best_result),
            ):
                gls_output_path = "{}.BW_{:03d}.{}.gls".format(
                    output_path, dev_best_recognition_beam_size, split
                )
                _write_to_file(gls_output_path, list(data.sequence), result["gls_hyp"])

        if do_translation:
            if dev_best_translation_beam_size > -1:
                dev_txt_output_path_set = "{}.BW_{:02d}.A_{:1d}.{}.txt".format(
                    output_path,
                    dev_best_translation_beam_size,
                    dev_best_translation_alpha,
                    "dev",
                )
                test_txt_output_path_set = "{}.BW_{:02d}.A_{:1d}.{}.txt".format(
                    output_path,
                    dev_best_translation_beam_size,
                    dev_best_translation_alpha,
                    "test",
                )
            else:
                dev_txt_output_path_set = "{}.BW_{:02d}.{}.txt".format(
                    output_path, dev_best_translation_beam_size, "dev"
                )
                test_txt_output_path_set = "{}.BW_{:02d}.{}.txt".format(
                    output_path, dev_best_translation_beam_size, "test"
                )

            _write_to_file(
                dev_txt_output_path_set,
                list(dev_data.sequence),
                dev_best_translation_result["txt_hyp"],
            )
            _write_to_file(
                test_txt_output_path_set,
                list(test_data.sequence),
                test_best_result["txt_hyp"],
            )

        # Keep the complete dev sweep (every setting tried) and the single
        # test result for later inspection.
        with open(output_path + ".dev_results.pkl", "wb") as out:
            pickle.dump(
                {
                    "recognition_results": dev_recognition_results
                    if do_recognition
                    else None,
                    "translation_results": dev_translation_results
                    if do_translation
                    else None,
                },
                out,
            )
        with open(output_path + ".test_results.pkl", "wb") as out:
            pickle.dump(test_best_result, out)
def main():
    """Evaluate a quantized EfficientNet-B0 checkpoint; print accuracy and score.

    Builds the network, wraps it with LSQ quantization modules, optionally
    applies the per-layer mixed-precision (CEM) bit widths when ``args.cem``
    is set, loads the checkpoint at ``args.model_path``, and prints the test
    accuracy together with the value returned by ``get_micronet_score``.
    """
    args = get_arguments()
    constr_activation = get_constraint(args.activation_bits, 'activation')

    net = efficientnet_b0(quan_first=True,
                          quan_last=True,
                          constr_activation=constr_activation,
                          preactivation=False,
                          bw_act=args.activation_bits)
    test_loader = dataloader_imagenet(args.data_root,
                                      split='test',
                                      batch_size=args.batch_size)
    add_lsqmodule(net, bit_width=args.weight_bits)

    # Per-layer strategies exist only for CEM runs. Without them the score is
    # requested with the uniform bit widths alone; previously the strategy
    # names were referenced unconditionally and a non-CEM run crashed.
    strategy_kwargs = {}

    if args.cem:
        ##### CEM vector for 1.5x_W7A7_CEM
        cem_input = [
            7, 7, 7, 7, 7, 5, 7, 7, 7, 7, 7, 7, 7, 7, 6, 6, 7, 7, 6, 7, 6, 7,
            7, 7, 7, 6, 7, 7, 7, 7, 6, 7, 7, 7, 7, 4, 7, 6, 7, 5, 7, 7, 7, 7,
            7, 5, 7, 7, 7, 5, 5, 7, 7, 7, 5, 6, 7, 7, 7, 6, 4, 7, 7, 6, 5, 4,
            7, 6, 5, 5, 4, 7, 7, 6, 5, 4, 7, 7, 6, 5, 5, 3
        ]

        # Each row is "<activation layer>,<weight layer>", in the same order
        # as the entries of cem_input.
        strategy_path = "lsq_quantizer/cem_strategy_relaxed.txt"
        with open(strategy_path) as fp:
            strategy = [line.strip().split(",") for line in fp]

        # Collect the module names once, without a leading `module.` prefix,
        # instead of walking the network again for every CEM entry.
        prefix = 'module.'
        module_names = set()
        for name, _ in net.named_modules():
            if name.startswith(prefix):
                name = name[len(prefix):]
            module_names.add(name)

        strat = {}
        act_strat = {}
        for idx, width in enumerate(cem_input):
            weight_layer_name = strategy[idx][1]
            act_layer_name = strategy[idx][0]
            # Only layers that actually exist in the network get an entry.
            if weight_layer_name in module_names:
                strat[weight_layer_name] = int(width)
            if act_layer_name in module_names:
                act_strat[act_layer_name] = int(width)

        add_lsqmodule(net, bit_width=args.weight_bits, strategy=strat)

        symmetric_tags = ("_in_act_quant", "first_act", "_head_act_quant0",
                          "_head_act_quant1")
        for name, module in net.named_modules():
            if name not in act_strat:
                continue
            if any(tag in name for tag in symmetric_tags):
                constraint_kind = 'weight'  # symmetric
            else:
                constraint_kind = 'activation'  # asymmetric
            module.constraint = get_constraint(act_strat[name], constraint_kind)

        strategy_kwargs = {
            "weight_strategy": strat,
            "activation_strategy": act_strat,
        }

    # Start from the freshly built state dict so that keys absent from the
    # checkpoint keep their initial values.
    name_weights_old = torch.load(args.model_path)
    name_weights_new = net.state_dict()
    name_weights_new.update(name_weights_old)
    load_checkpoint(net, name_weights_new)

    score = get_micronet_score(net,
                               args.weight_bits,
                               args.activation_bits,
                               **strategy_kwargs)

    criterion = torch.nn.CrossEntropyLoss()

    # Calculate accuracy
    net = net.cuda()

    quan_perf_epoch = eval_performance(net, test_loader, criterion)
    accuracy = quan_perf_epoch[1]

    print("Accuracy:", accuracy)
    print("Score:", score)
Exemple #13
0
# Prediction setup: pick a checkpoint, load it into `model` (which must already
# exist at this point; it is not created in this part of the script), and
# build the dataloaders.
#
# Earlier checkpoint locations, left commented out; only the uncommented
# assignment below is active.
# checkpoint_path = '/home/delvinso/neuro/output/bay_cog_comp_sb_18m/axial/_best.path.tar'
# checkpoint_path = '/home/delvinso/neuro/output/archive/bay_cog_comp_sb_18m/axial/_best.path.tar'
# checkpoint_path = '/home/delvinso/neuro/output/tensor_check_norm/bay_cog_comp_sb_18m-axial/_best.path.tar'
# checkpoint_path = '/home/delvinso/neuro/output/no_norm/bay_cog_comp_sb_18m-axial/_best.path.tar'
# checkpoint_path = '/home/delvinso/neuro/output/models/notensor_dropout/bay_cog_comp_sb_18m-axial/_best.path.tar'
# checkpoint_path = '/home/delvinso/neuro/output/models/notensor/bay_cog_comp_sb_18m-axial/_best.path.tar'
# checkpoint_path = '/home/delvinso/neuro/output/models/bay_cog_comp_sb_18m-axial/_best.path.tar'

checkpoint_path = '/home/delvinso/neuro/output/models/test_ss/bay_cog_comp_sb_18m-axial/_best.path.tar'
# optimizer = torch.optim.Adam(model.parameters(), lr = 1e-5, weight_decay=0.01)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Switch the model to eval mode and move it to the chosen device.
model.eval().to(device)

# No optimizer is passed: it does not matter for making predictions.
# NOTE(review): presumably this restores the saved weights into `model` in
# place (the return value is not used) - confirm in helpers.load_checkpoint.
helpers.load_checkpoint(checkpoint=checkpoint_path, model=model)

# Arguments for the dataloaders.
sets = ['train', 'valid']
view = 'axial'
outcome = 'bay_cog_comp_sb_18m'
data_dir = '/home/delvinso/neuro'
manifest_path = '/home/delvinso/neuro/output/ubc_npy_outcomes_v3_ss.csv'
# Initialize the dataloaders.
# NOTE(review): return_pid=True presumably makes each batch also carry an
# identifier for the sample - confirm in get_dataloader.
dls = get_dataloader(sets=sets,
                     view=view,
                     outcome=outcome,
                     data_dir=data_dir,
                     return_pid=True,
                     manifest_path=manifest_path)