Exemplo n.º 1
0
def main(run_dir,
         epoch,
         rtol,
         atol,
         _log):
    """Evaluate a saved checkpoint on the test loader and log the result.

    The model and the data loaders are rebuilt from the sources stored in
    ``run_dir`` (``model_ingredient`` / ``data_ingredient``) using the run's
    saved config, with ``isgpu`` forced to ``True`` for both.

    Args:
        run_dir: Run directory handed to ``read_config``, ``import_source``
            and ``get_model_path``.
        epoch: Passed to ``get_model_path`` to select the checkpoint file.
        rtol: Assigned to ``model.ODEnet.odeblock.rtol`` when the model has
            an ``ODEnet`` attribute.
        atol: Assigned to ``model.ODEnet.odeblock.atol`` under the same
            condition.
        _log: Logger-like object; only ``info`` is used.
    """
    config = read_config(run_dir)
    _log.info('Load config from {}'.format(run_dir))

    # Rebuild the model from the code snapshot saved with the run.
    model_source = import_source(run_dir, 'model_ingredient')
    model_kwargs = {**config['model'], 'isgpu': True}
    model = model_source.make_model(**model_kwargs, _log=_log)

    checkpoint = get_model_path(run_dir, epoch)
    # The state dict is loaded into the wrapped module when the model is
    # wrapped in DataParallel, otherwise into the model itself.
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(torch.load(checkpoint))
    _log.info('Load paras from {}'.format(checkpoint))

    model = model.eval()

    # Override solver tolerances only for models exposing an ODEnet.
    if hasattr(model, 'ODEnet'):
        ode_block = model.ODEnet.odeblock
        ode_block.rtol = rtol
        ode_block.atol = atol

    data_source = import_source(run_dir, 'data_ingredient')
    loader_kwargs = {**config['dataset'], 'isgpu': True}
    # make_dataloaders returns four items; only the last one is used here.
    _dset, _train, _val, test_loader = data_source.make_dataloaders(
        **loader_kwargs, _log=_log)

    _log.info('Testing models...')
    test_loss, test_acc = validate(model, test_loader)
    message = "Test loss = {test_loss:.6f}, Test accuracy = {test_acc:.4f}".format(
        test_loss=test_loss, test_acc=test_acc)
    _log.info(message)
Exemplo n.º 2
0
def main(run_dir, epoch, _log):
    """Evaluate a saved ResNet checkpoint on the test loader and log the result.

    Args:
        run_dir: Run directory handed to ``read_config``, ``import_source``
            and ``get_model_path``.
        epoch: Passed to ``get_model_path`` to select the checkpoint file.
        _log: Logger-like object; only ``info`` is used.
    """
    config = read_config(run_dir)
    _log.info('Load config from {}'.format(run_dir))

    model_source = import_source(run_dir, 'model_ingredient')
    model_config = config['model']
    print(model_config)
    # make_resnet is called positionally with the configured layers and True.
    # NOTE(review): the second argument is presumably the GPU flag - confirm
    # against make_resnet's signature.
    model = model_source.make_resnet(model_config['layers'], True)

    # Normalise Windows-style separators in the checkpoint path.
    checkpoint = get_model_path(run_dir, epoch).replace('\\', '/')
    # The state dict is loaded into the wrapped module when the model is
    # wrapped in DataParallel, otherwise into the model itself.
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(torch.load(checkpoint))
    _log.info('Load paras from {}'.format(checkpoint))

    model = model.eval()

    data_source = import_source(run_dir, 'data_ingredient')
    loader_kwargs = {**config['dataset'], 'isgpu': True}
    # make_dataloaders returns four items; only the last one is used here.
    _dset, _train, _val, test_loader = data_source.make_dataloaders(
        **loader_kwargs, _log=_log)

    _log.info('Testing resnet models...')
    test_loss, test_acc = validate(model, test_loader)
    message = (
        "Resnet Test loss = {test_loss:.6f}, Test accuracy = {test_acc:.4f}".
        format(test_loss=test_loss, test_acc=test_acc))
    _log.info(message)
Exemplo n.º 3
0
def main(run_dir, epoch, device, attack, epsilon, min_end_time, max_end_time,
         tol, batches, pgd_step_size, pgd_num_steps, pgd_random_start, _log):
    """Evaluate a saved checkpoint on adversarially perturbed test batches.

    Args:
        run_dir: Run directory handed to ``read_config``, ``import_source``
            and ``get_model_path``.
        epoch: Passed to ``get_model_path`` to select the checkpoint file.
        device: Forwarded as ``device`` to ``make_model`` and
            ``make_dataloaders`` and used as ``map_location`` when loading
            the checkpoint.
        attack: Key into ``ATTACKS`` selecting the attack function.
        epsilon: Forwarded to the attack as ``epsilon``.
        min_end_time: Assigned to ``model.odeblock.min_end_time`` when the
            model has an ``odeblock`` attribute.
        max_end_time: Assigned to ``model.odeblock.max_end_time`` under the
            same condition.
        tol: Assigned to both ``atol`` and ``rtol`` of ``model.odeblock``.
        batches: Passed to ``TruncIterator`` together with the adversarial
            loader.
        pgd_step_size: Forwarded as ``step_size`` when ``attack == 'pgd'``.
        pgd_num_steps: Forwarded as ``num_steps`` when ``attack == 'pgd'``.
        pgd_random_start: Forwarded as ``random_start`` when
            ``attack == 'pgd'``.
        _log: Logger-like object; only ``info`` is used.

    Returns:
        The formatted "Test loss = ..., Test accuracy = ..." string that is
        also logged.
    """
    config = read_config(run_dir)
    _log.info(f"Read config from {run_dir}")

    model_ing = import_source(run_dir, "model_ingredient")
    model = model_ing.make_model(**{**config['model'], 'device': device},
                                 _log=_log)
    path = get_model_path(run_dir, epoch)
    # Load the checkpoint once with map_location for both the DataParallel and
    # the plain case, so tensors are remapped to `device` either way
    # (previously the DataParallel branch omitted map_location).
    state_dict = torch.load(path, map_location=device)
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(state_dict)
    model = model.eval()
    _log.info(f"Loaded state dict from {path}")

    # Integration times and tolerances are only overridden for models that
    # expose an odeblock.
    if hasattr(model, "odeblock"):
        _log.info(f"Updated times to {[min_end_time, max_end_time]}")
        model.odeblock.min_end_time = min_end_time
        model.odeblock.max_end_time = max_end_time
        model.odeblock.atol = tol
        model.odeblock.rtol = tol

    data_ing = import_source(run_dir, "data_ingredient")
    # make_dataloaders returns four items; only the last one is used here.
    _dset, _tl, _vl, test_loader = data_ing.make_dataloaders(
        **{**config['dataset'], 'device': device}, _log=_log)

    # Only the 'pgd' attack receives the extra step/restart settings.
    attack_kwargs = {'epsilon': epsilon}
    if attack == 'pgd':
        attack_kwargs.update(step_size=pgd_step_size,
                             num_steps=pgd_num_steps,
                             random_start=pgd_random_start)
    attack_fn = partial(ATTACKS[attack], **attack_kwargs)

    adv_test_loader = adv.AdversarialLoader(model, test_loader, attack_fn)
    adv_test_loader = TruncIterator(adv_test_loader, batches)

    _log.info("Testing model...")
    test_loss, test_acc = validate(model, adv_test_loader)

    output = f"Test loss = {test_loss:.6f}, Test accuracy = {test_acc:.4f}"
    _log.info(output)

    return output
Exemplo n.º 4
0
def main(run_dir, epoch, isgpu, batches, atol, rtol, att, epsilon, step_size,
         num_step, random_start, _log):
    """Evaluate a saved checkpoint on adversarially perturbed test batches.

    Args:
        run_dir: Run directory handed to ``read_config``, ``import_source``
            and ``get_model_path``.
        epoch: Passed to ``get_model_path`` to select the checkpoint file.
        isgpu: Forwarded as ``isgpu`` to ``make_model`` and
            ``make_dataloaders``.
        batches: Passed to ``truncIterator`` together with the adversarial
            loader.
        atol: Assigned to ``model.ODEnet.odeblock.atol`` when the model has
            an ``ODEnet`` attribute.
        rtol: Assigned to ``model.ODEnet.odeblock.rtol`` under the same
            condition.
        att: Key into ``ATTACK`` selecting the attack function.
        epsilon: Forwarded to the attack as ``epsilon``.
        step_size: Forwarded as ``step_size`` when ``att == 'pgd'``.
        num_step: Forwarded as ``num_step`` when ``att == 'pgd'``.
        random_start: Forwarded as ``random_start`` when ``att == 'pgd'``.
        _log: Logger-like object; only ``info`` is used.

    Returns:
        The ``(test_loss, test_acc)`` pair produced by ``validate``.
    """
    # Construct the model and load its parameters.
    config = read_config(run_dir)
    # Named replacement fields need keyword arguments; passing the value
    # positionally raised KeyError before anything was logged.
    _log.info('Read config from {run_dir}'.format(run_dir=run_dir))

    mod_ing = import_source(run_dir, 'model_ingredient')
    model = mod_ing.make_model(**{**config['model'], 'isgpu': isgpu},
                               _log=_log)
    path = get_model_path(run_dir, epoch)

    if isinstance(model, nn.DataParallel):
        model.module.load_state_dict(torch.load(path))
    else:
        model.load_state_dict(torch.load(path))
    model = model.eval()
    _log.info('Load model from {path}'.format(path=path))

    data_ing = import_source(run_dir, 'data_ingredient')
    # make_dataloaders returns four items; only the last one is used here.
    dset, train, val, test = data_ing.make_dataloaders(
        **{**config['dataset'], 'isgpu': isgpu}, _log=_log)

    # Only the 'pgd' attack receives the extra step/restart settings.
    if att == 'pgd':
        attack = partial(ATTACK[att],
                         epsilon=epsilon,
                         step_size=step_size,
                         num_step=num_step,
                         random_start=random_start)
    else:
        attack = partial(ATTACK[att], epsilon=epsilon)

    # Override solver tolerances only for models exposing an ODEnet.
    if hasattr(model, 'ODEnet'):
        model.ODEnet.odeblock.atol = atol
        model.ODEnet.odeblock.rtol = rtol

    advTestLoader = adv.advloader(test, model, attack)
    advTestLoader = truncIterator(advTestLoader, batches)

    _log.info('Testing Model...')
    test_loss, test_acc = validate(model,
                                   advTestLoader,
                                   _log=logging.getLogger('validate'))

    # The original mixed a named field ({loss}) with an automatic one ({})
    # while passing only positional arguments, which raised KeyError.
    _log.info('Test loss : {loss}, accuracy : {acc}'.format(loss=test_loss,
                                                           acc=test_acc))

    return test_loss, test_acc
Exemplo n.º 5
0
def main(run_dir, epoch, device, attack, epsilon, end_time_start, end_time_end,
         num_times, tol, batches, _log):
    """Sweep the ODE block end time and evaluate under an adversarial attack.

    For each end time in ``np.linspace(end_time_start, end_time_end,
    num_times)`` the model's ``odeblock`` is reconfigured, the truncated
    adversarial test loader is evaluated with ``validate``, and the loss,
    accuracy and end time are recorded through ``ex.log_scalar`` and logged.

    Args:
        run_dir: Run directory handed to ``read_config``, ``import_source``
            and ``get_model_path``.
        epoch: Passed to ``get_model_path`` to select the checkpoint file.
        device: Forwarded as ``device`` to ``make_model`` and
            ``make_dataloaders`` and used as ``map_location`` when loading
            the checkpoint.
        attack: Key into ``ATTACKS`` selecting the attack function.
        epsilon: Forwarded to the attack as ``epsilon``.
        end_time_start: First end time of the sweep.
        end_time_end: Last end time of the sweep.
        num_times: Number of end times generated by ``np.linspace``.
        tol: Assigned to both ``atol`` and ``rtol`` of ``model.odeblock``
            when the model has an ``odeblock`` attribute.
        batches: Passed to ``TruncIterator`` together with the adversarial
            loader.
        _log: Logger-like object; only ``info`` is used.
    """
    config = read_config(run_dir)
    _log.info(f"Read config from {run_dir}")

    model_ing = import_source(run_dir, "model_ingredient")
    model = model_ing.make_model(**{**config['model'], 'device': device},
                                 _log=_log)
    path = get_model_path(run_dir, epoch)
    # Load the checkpoint once with map_location for both the DataParallel and
    # the plain case, so tensors are remapped to `device` either way
    # (previously the DataParallel branch omitted map_location).
    state_dict = torch.load(path, map_location=device)
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(state_dict)
    model = model.eval()
    _log.info(f"Loaded state dict from {path}")

    if hasattr(model, "odeblock"):
        model.odeblock.atol = tol
        model.odeblock.rtol = tol

    data_ing = import_source(run_dir, "data_ingredient")
    # make_dataloaders returns four items; only the last one is used here.
    _dset, _tl, _vl, test_loader = data_ing.make_dataloaders(
        **{**config['dataset'], 'device': device}, _log=_log)

    # Bound to a new name so the `attack` parameter is not shadowed.
    attack_fn = partial(ATTACKS[attack], epsilon=epsilon)
    adv_test_loader = adv.AdversarialLoader(model, test_loader, attack_fn)
    adv_test_loader = TruncIterator(adv_test_loader, batches)
    _log.info("Testing model...")

    # NOTE(review): unlike the tolerance override above, this loop accesses
    # model.odeblock without a hasattr guard, so it requires a model that
    # exposes an odeblock.
    for end_time in np.linspace(end_time_start, end_time_end, num_times):
        # Pin min and max to the same value for this evaluation.
        model.odeblock.min_end_time = end_time
        model.odeblock.max_end_time = end_time
        model.odeblock.t = torch.tensor([0, end_time]).float()
        test_loss, test_acc = validate(model, adv_test_loader)
        ex.log_scalar("test_loss", test_loss)
        ex.log_scalar("test_acc", test_acc)
        ex.log_scalar("end_time", end_time)
        _log.info(
            f"end_time={end_time:.6f}, Test loss = {test_loss:.6f}, Test accuracy = {test_acc:.4f}"
        )
Exemplo n.º 6
0
def main(run_dir,
         epoch,
         device,
         min_end_time,
         max_end_time,
         tol,
         _log):
    """Evaluate a saved checkpoint on the test loader and log the result.

    Args:
        run_dir: Run directory handed to ``read_config``, ``import_source``
            and ``get_model_path``.
        epoch: Passed to ``get_model_path`` to select the checkpoint file.
        device: Forwarded as ``device`` to ``make_model`` and
            ``make_dataloaders`` and used as ``map_location`` when loading
            the checkpoint.
        min_end_time: Assigned to ``model.odeblock.min_end_time`` when the
            model has an ``odeblock`` attribute.
        max_end_time: Assigned to ``model.odeblock.max_end_time`` under the
            same condition.
        tol: Assigned to both ``atol`` and ``rtol`` of ``model.odeblock``.
        _log: Logger-like object; only ``info`` is used.
    """
    config = read_config(run_dir)
    _log.info(f"Read config from {run_dir}")

    model_ing = import_source(run_dir, "model_ingredient")
    model = model_ing.make_model(**{**config['model'], 'device': device},
                                 _log=_log)
    path = get_model_path(run_dir, epoch)
    # Load the checkpoint once with map_location for both the DataParallel and
    # the plain case, so tensors are remapped to `device` either way
    # (previously the DataParallel branch omitted map_location).
    state_dict = torch.load(path, map_location=device)
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(state_dict)
    model = model.eval()
    _log.info(f"Loaded state dict from {path}")

    # Integration times and tolerances are only overridden for models that
    # expose an odeblock.
    if hasattr(model, "odeblock"):
        _log.info(f"Updated times to {[min_end_time, max_end_time]}")
        model.odeblock.min_end_time = min_end_time
        model.odeblock.max_end_time = max_end_time
        model.odeblock.atol = tol
        model.odeblock.rtol = tol

    data_ing = import_source(run_dir, "data_ingredient")
    # make_dataloaders returns four items; only the last one is used here.
    _dset, _tl, _vl, test_loader = data_ing.make_dataloaders(
        **{**config['dataset'], 'device': device}, _log=_log)

    _log.info("Testing model...")
    test_loss, test_acc = validate(model, test_loader)

    _log.info(f"Test loss = {test_loss:.6f}, Test accuracy = {test_acc:.4f}")