Example #1
0
def main(args):
    """Train a model using settings read from a JSON config file.

    Command-line arguments override matching config keys when they are
    not None. Attack hyper-parameters are given in pixel units (0-255)
    and converted to the [0, 1] range before training.
    """
    # Read configs
    with open(args.cfg_path, "r") as fp:
        configs = json.load(fp)

    # Update the configs based on command line args (CLI wins when set)
    for key, value in vars(args).items():
        if key in configs and value is not None:
            configs[key] = value

    configs = utils.ConfigMapper(configs)

    # Pixel-scale epsilon / step size -> [0, 1] image scale.
    configs.attack_eps = float(configs.attack_eps) / 255
    configs.attack_lr = float(configs.attack_lr) / 255

    print("configs mode: ", configs.mode)
    print("configs lr: ", configs.lr)

    # Final layout: <save_path>/<mode>/<experiment name>
    configs.save_path = os.path.join(configs.save_path, configs.mode)
    experiment_name = exp_name(configs)
    configs.save_path = os.path.join(configs.save_path, experiment_name)
    pathlib.Path(configs.save_path).mkdir(parents=True, exist_ok=True)

    # NOTE(review): a second hand-built experiment_name string was computed
    # here in the original but never used anywhere; removed as dead code.

    trainer = Trainer(configs)
    trainer.train()

    print("training is over!!!")
Example #2
0
def main(args):
    """Load a JSON config, apply CLI overrides, and dispatch on run mode.

    Supported modes: 'train', 'eval', 'vis'. Anything else raises
    ValueError.
    """
    # Read configs
    with open(args.cfg_path, "r") as fp:
        configs = json.load(fp)

    # Command-line values take precedence over config-file entries.
    for key, value in vars(args).items():
        if key in configs and value is not None:
            configs[key] = value
    configs = utils.ConfigMapper(configs)

    # Pixel-scale attack hyper-parameters -> [0, 1] image scale.
    configs.attack_eps = float(configs.attack_eps) / 255
    configs.attack_lr = float(configs.attack_lr) / 255

    # Output layout: <save_path>/<mode>/<alg>
    configs.save_path = os.path.join(configs.save_path, configs.mode,
                                     configs.alg)
    pathlib.Path(configs.save_path).mkdir(parents=True, exist_ok=True)

    mode = configs.mode
    if mode == 'train':
        Trainer(configs).train()
    elif mode == 'eval':
        Evaluator(configs).eval()
    elif mode == 'vis':
        Visualizer(configs).visualize()
    else:
        raise ValueError('mode should be train, eval or vis')
Example #3
0
def main(args):
    """Read a JSON config, apply CLI overrides, and run training.

    Unlike the pixel-scale variants elsewhere in this file, attack_eps /
    attack_lr are only cast to float here, not divided by 255.
    """
    # Read configs
    with open(args.cfg_path, "r") as fp:
        configs = json.load(fp)

    # Command-line values override config entries when provided.
    for key, value in vars(args).items():
        if key in configs and value is not None:
            configs[key] = value

    configs = utils.ConfigMapper(configs)

    configs.attack_eps = float(configs.attack_eps)
    configs.attack_lr = float(configs.attack_lr)

    print("configs mode: ", configs.mode)
    print("configs lr: ", configs.lr)
    print("configs size: ", configs.size)

    # Final layout: <save_path>/<mode>/<experiment name>. The mode is
    # appended before exp_name is called, matching the original ordering.
    configs.save_path = os.path.join(configs.save_path, configs.mode)
    configs.save_path = os.path.join(configs.save_path, exp_name(configs))
    pathlib.Path(configs.save_path).mkdir(parents=True, exist_ok=True)

    trainer = Trainer(configs)
    trainer.train()

    print("training is over!!!")
Example #4
0
def main(args):
    # Train using settings loaded from a JSON config file.

    # Parse the JSON config into a dict, then wrap it for attribute access.
    with open(args.cfg_path, "r") as fp:
        configs = json.load(fp)
    configs = utils.ConfigMapper(configs)

    # Ensure the log directory and its models/ subdirectory exist.
    utils.mkdir_p(configs.log_dir)
    utils.mkdir_p(configs.log_dir + '/models')

    # NOTE: locals() is captured on purpose — print_model_settings receives
    # the local variable names/values, so renaming locals here would change
    # what gets printed.
    lib.print_model_settings(locals().copy())
    trainer.train(configs)
Example #5
0
File: main.py | Project: zd8692/robustOT
def main(args):
    """Build the run config and launch the requested trainer.

    CLI overrides: numeric keys use a -1 sentinel (override only when
    > -1); string keys use '' (override only when non-empty).

    Raises:
        ValueError: if config.method is not one of the known trainers.
    """
    assert os.path.exists(args.cfg_path)

    # Forming config
    with open(args.cfg_path) as json_file:
        config = json.load(json_file)

    args_to_override = [
        'method', 'model', 'ano_type', 'logdir', 'rho', 'ent_weight',
        'domain_src', 'domain_tgt'
    ]
    numeric_keys = ['ano_type', 'rho', 'ent_weight', 'vat_weight']
    arg_dict = vars(args)
    for key in args_to_override:
        if key in numeric_keys:
            if arg_dict[key] > -1:
                config[key] = arg_dict[key]
        else:
            if arg_dict[key] != '':
                config[key] = arg_dict[key]
    config = utils.ConfigMapper(config)

    # Initializing save_path; DomainNet runs also encode the domain pair.
    if config.dataset == 'DomainNet':
        domain_list = '{}_{}'.format(config.domain_src, config.domain_tgt)
        config.logdir = os.path.join(
            config.logdir, config.exp, domain_list, config.method,
            config.model,
            'ent_{}_vat_{}_rho_{}'.format(config.ent_weight, config.vat_weight,
                                          config.rho))
    else:
        config.logdir = os.path.join(
            config.logdir, config.exp, config.method, config.model,
            'ent_{}_vat_{}_rho_{}'.format(config.ent_weight, config.vat_weight,
                                          config.rho))
    if args.run_id > -1:
        config.logdir += '_run_{}'.format(args.run_id)
    Path(config.logdir).mkdir(parents=True, exist_ok=True)

    # Creating trainer. Fix: the original silently fell through for an
    # unknown method and crashed later with NameError on trainer.train().
    if config.method == 'sourceonly':
        trainer = SourceonlyTrainer(config)
    elif config.method == 'adversarial':
        trainer = AdversarialTrainer(config)
    elif config.method == 'robust_adversarial':
        trainer = RobustAdversarialTrainer(config)
    else:
        raise ValueError('unknown method: {}'.format(config.method))

    # Training !!
    trainer.train()
Example #6
0
def visualize(args):
    """Visualize the learned per-sample weight vector.

    Saves to args.results_path: a grid of up to num_vis target images
    sampled evenly across the sorted weight range (weight_vis.png), and a
    histogram of the full weight vector (weight_hist.png).
    """
    num_vis = 100

    # Forming config
    with open(args.cfg_path) as json_file:
        config = json.load(json_file)
    config = utils.ConfigMapper(config)

    # Create dataloader
    source_loader, target_loader, nclasses = datasets.form_visda_datasets(
        config=config, ignore_anomaly=True)

    # Loading model state
    model_state = torch.load(os.path.join(args.results_path,
                                          'model_state.pth'))

    weight_vector = model_state['weight_vector']
    indices_sorted = torch.argsort(weight_vector)
    num = weight_vector.shape[0]
    # Sample num_vis indices evenly across the sorted order. Fix: the
    # original int(num / num_vis) yields 0 when num < num_vis, and a slice
    # step of zero raises ValueError; clamp the step to at least 1.
    sampling_interval = max(1, num // num_vis)
    indices_sampled = indices_sorted[0:num:sampling_interval]

    # dataset.samples entries are indexed as (path, ...) — path first.
    path_vector_all = target_loader.dataset.samples
    paths = []
    for ind in indices_sampled:
        paths.append(path_vector_all[ind][0])
        print(weight_vector[ind])

    imgs = read_images(paths)
    vutils.save_image(imgs,
                      '{}/weight_vis.png'.format(args.results_path),
                      nrow=10)

    # Histogram of all weights.
    weight_vector_np = weight_vector.cpu().numpy()
    plt.figure()
    plt.rcParams.update({'font.size': 19})
    plt.gcf().subplots_adjust(bottom=0.15)
    plt.hist(weight_vector_np, bins=200)
    plt.xlabel('Weight')
    plt.ylabel('Count')
    plt.yticks([0, 1000, 2000, 3000, 4000, 5000, 6000])
    plt.savefig('{}/weight_hist.png'.format(args.results_path), dpi=300)
Example #7
0
def visualize(args):
    """Print normalized per-class source, target, and weighted-target
    distributions for the loaded weight vector."""
    num_vis = 100

    # Forming config
    with open(args.cfg_path) as json_file:
        config = json.load(json_file)
    config = utils.ConfigMapper(config)

    # Create dataloader
    source_loader, target_loader, nclasses = datasets.form_visda_datasets(config=config, ignore_anomaly=True)
    # Hard-coded class count override (VisDA has 12 classes).
    nclasses = 12

    model_state = torch.load(os.path.join(args.results_path, 'model_state.pth'))
    weights = model_state['weight_vector'].cpu().numpy()

    src_counts = [0] * nclasses
    tgt_counts = [0] * nclasses
    wgt_counts = [0] * nclasses

    # Per-class sample counts on the source side.
    for entry in source_loader.dataset.samples:
        src_counts[entry[1]] += 1

    # Per-class counts and accumulated weights on the target side; the
    # weight vector is aligned with dataset.samples order.
    for idx, entry in enumerate(target_loader.dataset.samples):
        tgt_counts[entry[1]] += 1
        wgt_counts[entry[1]] += weights[idx]

    src_counts = np.array(src_counts)
    tgt_counts = np.array(tgt_counts)
    wgt_counts = np.array(wgt_counts)

    # Normalize to distributions; weights are additionally scaled by the
    # class count so a uniform weighting maps to 1 per class.
    src_counts = src_counts / np.sum(src_counts)
    total_tgt = np.sum(tgt_counts)
    tgt_counts = tgt_counts / total_tgt
    wgt_counts = (wgt_counts / total_tgt) * nclasses

    print(src_counts)
    print(tgt_counts)
    print(wgt_counts * tgt_counts)
Example #8
0
def main():
    """Evaluate every WGAN run found under the unconditional results tree.

    Each subfolder of eval_root holds one run's config.json; its config is
    loaded, patched with CelebA-attribute evaluation settings, and passed
    to the Evaluator.
    """
    eval_root = 'results/unconditional/WGAN/CelebA_attributes'
    save_root = 'results/evaluation_WGAN'
    Path(save_root).mkdir(exist_ok=True, parents=True)

    folders = os.listdir(eval_root)
    for fol in folders:
        print('Evaluating {}'.format(fol))
        load_path = os.path.join(eval_root, fol)
        save_path = os.path.join(save_root, fol)
        # Fix: json.load(open(...)) leaked the file handle; use a context
        # manager so it is closed deterministically.
        with open('{}/config.json'.format(load_path), 'r') as fp:
            config = json.load(fp)

        # General args — fixed evaluation overrides on top of the run config.
        config = utils.ConfigMapper(config)
        config.imageSize = 32
        config.num_classes = 2
        config.dataset = 'celeba_attribute'
        config.dataroot = '/vulcanscratch/yogesh22/data/celebA/'
        config.G_bias = True

        evaluator_class = Evaluator(config, load_path, save_path)
        evaluator_class.eval()
        print('#########')
Example #9
0
def parse_args():
    """Parse CLI flags and merge them over the base + experiment configs.

    Precedence: base config < experiment config (--cfg_path) < CLI flags.
    Numeric flags use a -1 sentinel (only values > -1 override); string
    flags use ''. Returns a utils.ConfigMapper with 'logdir' created on
    disk and the experiment config copied into it (when one was given).
    """
    parser = argparse.ArgumentParser(description='PyTorch GAN hub')
    parser.add_argument('--base_cfg_path',
                        default='configs/base_config.json',
                        type=str,
                        help='path to config file')
    parser.add_argument('--cfg_path',
                        default='',
                        type=str,
                        help='path to config file')
    parser.add_argument('--anomaly_frac',
                        default=-1,
                        type=float,
                        help='Fraction of anomalies')
    parser.add_argument('--weight_update_iters',
                        default=-1,
                        type=int,
                        help='Number of iters to update weights')
    parser.add_argument('--rho', default=-1, type=float, help='rho used')
    parser.add_argument('--weight_update',
                        default=-1,
                        type=int,
                        help='Whether to update weights or not')
    parser.add_argument('--weight_update_type',
                        default=-1,
                        type=int,
                        help='Weight update type')
    parser.add_argument('--run_id', default=-1, type=int, help='Run id')
    args = parser.parse_args()

    with open(args.base_cfg_path, "r") as fp:
        configs = json.load(fp)

    # Overriding base configs with the experiment config file, if given
    if args.cfg_path != '':
        with open(args.cfg_path, "r") as fp:
            exp_configs = json.load(fp)
        for k in exp_configs.keys():
            configs[k] = exp_configs[k]

    # Overriding with parser args; logname accumulates a suffix encoding
    # the overridden values so each run gets a distinct log directory.
    logname = ''
    args_to_override = [
        'anomaly_frac', 'weight_update_iters', 'rho', 'weight_update',
        'weight_update_type', 'run_id'
    ]
    arg_dict = vars(args)
    for key in arg_dict:
        if key in args_to_override:
            if arg_dict[key] > -1:
                if key == 'weight_update':
                    # int flag -> bool config value
                    if arg_dict[key] == 0:
                        configs[key] = False
                        logname += '_baseline'
                    else:
                        configs[key] = True
                        logname += '_robust'
                elif key == 'weight_update_type':
                    # int flag -> named update type
                    if arg_dict[key] == 0:
                        configs[key] = 'discrete'
                        logname += '_discrete'
                    else:
                        configs[key] = 'cont'
                        logname += '_cont'
                else:
                    configs[key] = arg_dict[key]
                    logname += '_{}_{}'.format(key, configs[key])
    configs['logdir'] = configs['logdir'] + configs['expname'] + '/'
    configs['logdir'] += logname

    Path(configs['logdir']).mkdir(parents=True, exist_ok=True)
    # Fix: the original unconditionally copied args.cfg_path, which raises
    # FileNotFoundError when no --cfg_path was supplied (default '').
    if args.cfg_path != '':
        dst_path = os.path.join(configs['logdir'], 'config.json')
        shutil.copy(args.cfg_path, dst_path)

    configs = utils.ConfigMapper(configs)
    return configs