Code example #1
def main():
    """
    driver for running controlled experiment from one device's perspective
    """
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL)

    # parse arguments
    parser = argparse.ArgumentParser(description='set params for controlled experiment')
    parser.add_argument('--out', dest='out_file',
                        type=str, default='logs/log.pickle', help='log history output')
    parser.add_argument('--cfg', dest='config_file',
                        type=str, default='configs/mnist_cfg.json', help='name of the config file')
    parser.add_argument('--draw_graph', dest='graph_file',
                        type=str, default=None, help='name of the output graph filename')
    parser.add_argument('--seed', dest='seed',
                        type=int, default=0, help='random seed')
    parser.add_argument('--global', dest='report_global',  # 'global' is a reserved word and cannot be an attribute name
                        type=int, default=0, help='report accuracies of global models for greedy-sim and greedy-no-sim')

    parsed = parser.parse_args()

    if parsed.config_file is None or parsed.out_file is None:
        print('Config file and output directory have to be specified. Run \'python driver.py -h\' for help.')
        return

    np.random.seed(parsed.seed)
    tf.compat.v1.set_random_seed(parsed.seed)

    # load config file
    with open(parsed.config_file, 'rb') as f:
        config_json = f.read()
    config = json.loads(config_json)

    if config['dataset'] == 'mnist':
        model_fn = custom_models.get_2nn_mnist_model
        x_train, y_train_orig, x_test, y_test_orig = get_mnist_dataset()
        
    elif config['dataset'] == 'cifar':
        model_fn = custom_models.get_big_cnn_cifar_model
        x_train, y_train_orig, x_test, y_test_orig = get_cifar_dataset()

    elif config['dataset'] == 'opportunity-uci':
        model_fn = custom_models.get_deep_conv_lstm_model
        x_train, y_train_orig, x_test, y_test_orig = get_opp_uci_dataset('../data/opportunity-uci/oppChallenge_gestures.data',
                                                                         SLIDING_WINDOW_LENGTH,
                                                                         SLIDING_WINDOW_STEP)
    else:
        print('invalid dataset name')
        return

    train_data_provider = dp.DataProvider(x_train, y_train_orig, config['local-task'])
    test_data_provider = dp.StableTestDataProvider(x_test, y_test_orig, config['test-data-per-label'], config['goal-tasks'])

    # get local dataset for clients
    client_label_conf = {}
    for l in config['local-set']:
        client_label_conf[l] = int(config['number-of-data-points'] / len(config['local-set']))
    x_train_client, y_train_client = train_data_provider.peek(client_label_conf)

    # get pretrained model
    with open(config['pretrained-model'], 'rb') as handle:
        pretrained_model_weight = pickle.load(handle)

    # set params for building clients
    opt_fn = keras.optimizers.SGD
    compile_config = {'loss': 'mean_squared_error', 'metrics': ['accuracy']}
    hyperparams = config['hyperparams']
    
    # initialize delegation_clients for simulation
    clients = {}
    i = 0
    for k in config['strategies'].keys():
        if config['strategies'][k]:
            client_class = get_client_class(k)
            if client_class is None:
                print("strategy name {} not defined".format(k))
                return

            train_config = {}
            c = client_class(i,
                            model_fn,
                            opt_fn,
                            copy.deepcopy(pretrained_model_weight),
                            x_train_client,
                            y_train_client,
                            train_data_provider,
                            test_data_provider,
                            config['goal-set'],
                            compile_config,
                            train_config,
                            hyperparams)
            clients[k] = c
            i += 1
    
    # initialize logs
    logs = {}
    for ck in clients.keys():
        logs[ck] = {}
        hist = clients[ck].eval()
        if config['hyperparams']['evaluation-metrics'] == 'loss-and-accuracy':
            logs[ck]['accuracy'] = []
            logs[ck]['loss'] = []
            logs[ck]['loss'].append(hist[0])
            logs[ck]['accuracy'].append(hist[1])
        elif config['hyperparams']['evaluation-metrics'] == 'f1-score-weighted':
            logs[ck]['f1-score'] = []
            logs[ck]['f1-score'].append(hist)
        elif config['hyperparams']['evaluation-metrics'] == 'split-f1-score-weighted':
            for labels in config['hyperparams']['split-test-labels']:
                logs[ck]['f1: ' + str(labels)] = []
                logs[ck]['f1: ' + str(labels)].append(hist[str(labels)])
        else:
            raise ValueError('invalid evaluation-metrics: {}'.format(config['hyperparams']['evaluation-metrics']))

    candidates = np.arange(0,10)
    if len(config['intervals']) != len(config['task-encounters']):
        raise ValueError('length of intervals and task-encounters should be the same: {} != {}'.format(config['intervals'], config['task-encounters']))
    
    all_labels = {}
    for label in np.unique(y_test_orig):
        all_labels[label] = config['number-of-data-points']/len(np.unique(y_test_orig))

    # optional config entries with defaults
    repeat = config.get('repeat', 1)
    same_repeat = config.get('same-repeat', False)

    if same_repeat:
        # pre-generate one fixed cycle of 'other' clients so that every
        # repetition replays exactly the same sequence of encounters
        enc_clients = []
        train_config = {}
        one_cycle_length = sum(config['intervals'])
        # map each encounter position in the cycle to its interval index
        interval_of_step = []
        for i, interval_len in enumerate(config['intervals']):
            interval_of_step.extend([i] * interval_len)
        for k in range(one_cycle_length):
            label_conf = all_labels

            # sample a task for this encounter from its interval, mirroring
            # the non-same-repeat path in the main loop below
            tasks = config['task-encounters'][interval_of_step[k]]
            task_num = tasks[np.random.randint(len(tasks))]
            rotated_train_data_provider = dp.DataProvider(x_train, y_train_orig, task_num)
            x_other, y_other = rotated_train_data_provider.peek(label_conf)

            enc_clients.append(
                # ck is the last strategy key left over from the logging loop above
                get_client_class(ck)(k,   # id unique within the cycle
                                    model_fn,
                                    opt_fn,
                                    copy.deepcopy(pretrained_model_weight),
                                    x_other,
                                    y_other,
                                    rotated_train_data_provider,
                                    test_data_provider,
                                    config['goal-set'],
                                    compile_config,
                                    train_config,
                                    hyperparams)
            )

    # main simulation loop: for each repetition, walk through every interval
    # and simulate one encounter per step for each strategy
    unique_ids = 10000  # per-encounter counter (not used elsewhere in this snippet)
    for j in range(repeat):
        ii = -1
        for i in range(len(config['intervals'])):
            print('simulating range {} of {} in repetition {} of {}'.format(i+1, len(config['intervals']), j+1, repeat))
            for _ in tqdm(range(config['intervals'][i])):
                # set labels
                label_conf = all_labels

                task_num = config['task-encounters'][i][np.random.randint(len(config['task-encounters'][i]))]
                rotated_train_data_provider = dp.DataProvider(x_train, y_train_orig, task_num)
                x_other, y_other = rotated_train_data_provider.peek(label_conf)
                # import matplotlib.pyplot as plt
                # plt.imshow(np.reshape(x_other[1], (28, 28)), cmap='gray')
                # plt.show()
                # return
                # run for different approaches: local, greedy, ...
                ii += 1
                unique_ids += 1
                for ck in clients.keys():
                    if not same_repeat:
                        other = get_client_class(ck)(100+task_num,   # id derived from the encountered task
                                                    model_fn,
                                                    opt_fn,
                                                    copy.deepcopy(pretrained_model_weight),
                                                    x_other,
                                                    y_other,
                                                    rotated_train_data_provider,
                                                    test_data_provider,
                                                    config['goal-set'],
                                                    compile_config,
                                                    train_config,
                                                    hyperparams)
                        # for bn in range(int(config['number-of-data-points']/config['hyperparams']['batch-size'])):
                        clients[ck].delegate(other, 1, 1)
                    else:
                        # for bn in range(int(config['number-of-data-points']/config['hyperparams']['batch-size'])):
                        clients[ck].delegate(enc_clients[ii], 1, 1)
                        
                    hist = clients[ck].eval()
                    if config['hyperparams']['evaluation-metrics'] == 'loss-and-accuracy':
                        logs[ck]['loss'].append(hist[0])
                        logs[ck]['accuracy'].append(hist[1])
                    elif config['hyperparams']['evaluation-metrics'] == 'f1-score-weighted':
                        logs[ck]['f1-score'].append(hist)
                    elif config['hyperparams']['evaluation-metrics'] == 'split-f1-score-weighted':
                        for labels in config['hyperparams']['split-test-labels']:
                            logs[ck]['f1: ' + str(labels)].append(hist[str(labels)])
                    else:
                        raise ValueError('invalid evaluation-metrics: {}'.format(config['hyperparams']['evaluation-metrics']))

                    if i == len(config['intervals'])-1 and j == repeat-1:
                        with open('weights/' + ck + '_last_weights.pickle', 'wb') as handle:
                            pickle.dump(clients[ck]._weights, handle, protocol=pickle.HIGHEST_PROTOCOL)

    with open(parsed.out_file, 'wb') as handle:
        pickle.dump(logs, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # draw graph
    if config['hyperparams']['evaluation-metrics'] == 'split-f1-score-weighted':
        if parsed.graph_file is not None:
            for k in logs.keys():
                # keep everything before the final extension so dots elsewhere in the path survive
                filename = parsed.graph_file.rsplit('.', 1)[0] + '_' + k + '.pdf'
                print(filename)
                for labels in logs[k].keys():
                    plt.plot(np.arange(0, len(logs[k][labels])), np.array(logs[k][labels]), lw=1.2)
                plt.legend(list(logs[k].keys()))
                plt.ylabel('F1-score')
                plt.xlabel("Encounters")
                plt.savefig(filename)
                plt.close()
        return

    if config['hyperparams']['evaluation-metrics'] == 'loss-and-accuracy':
        key = 'accuracy'
    elif config['hyperparams']['evaluation-metrics'] == 'f1-score-weighted':
        key = 'f1-score'
    else:
        raise ValueError('invalid evaluation-metrics: {}'.format(config['hyperparams']['evaluation-metrics']))
    
    if parsed.graph_file is not None:
        for k in logs.keys():
            plt.plot(np.arange(0, len(logs[k][key])), np.array(logs[k][key]), lw=1.2)
        plt.legend(list(logs.keys()))
        if key == 'accuracy':
            y_label = 'Accuracy'
        elif key == 'f1-score':
            y_label = 'F1-score'
        # plt.ylim(0.9, 0.940)
        # plt.title(parsed.graph_file)
        plt.ylabel(y_label)
        plt.xlabel("Encounters")
        plt.savefig(parsed.graph_file)
        plt.close()
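
The driver above is entirely config-driven. Below is a minimal sketch of the config file such a run could use: the key names mirror the lookups in main() above, but every value, and the snippet that writes the file, is an illustrative assumption rather than something taken from the project. With the file in place, a run would look like `python driver.py --cfg configs/mnist_cfg.json --out logs/log.pickle`.

import json

# Hypothetical minimal config for the driver above; key names mirror the
# config lookups in main(), values are illustrative assumptions only.
config = {
    "dataset": "mnist",
    "local-task": 0,                       # task generating the device's own data
    "local-set": [0, 1],                   # labels held locally
    "goal-set": [0, 1, 2, 3],              # labels the device wants to learn
    "goal-tasks": 1,
    "number-of-data-points": 100,
    "test-data-per-label": 50,
    "pretrained-model": "weights/pretrained.pickle",
    "strategies": {"local": True, "greedy": True},
    "intervals": [100, 100],               # encounters per phase; same length as task-encounters
    "task-encounters": [[0], [1]],         # candidate tasks per phase
    "repeat": 1,                           # optional; defaults to 1
    "same-repeat": False,                  # optional; defaults to False
    "hyperparams": {
        "evaluation-metrics": "loss-and-accuracy",
        "batch-size": 50
    }
}

with open("configs/mnist_cfg.json", "w") as f:
    json.dump(config, f, indent=2)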
Code example #2
File: swarm_driver.py Project: UT-MPC/swarm
def main():
    setup_env()

    # parse arguments
    parser = argparse.ArgumentParser(description='set params for simulation')
    parser.add_argument('--seed',
                        dest='seed',
                        type=int,
                        default=0,
                        help='random seed')

    parser.add_argument('--tag',
                        dest='tag',
                        type=str,
                        default='default_tag',
                        help='tag')
    parser.add_argument('--cfg',
                        dest='config_file',
                        type=str,
                        default='toy_realworld_mnist_cfg.json',
                        help='name of the config file')

    parsed = parser.parse_args()

    if parsed.config_file is None or parsed.tag is None:
        print(
            'Config file and tag have to be specified. Run \'python delegation_swarm_driver.py -h\' for help.'
        )
        return

    LOG_FILE_PATH = Path(LOG_FOLDER, parsed.tag + '.log')
    if LOG_FILE_PATH.exists():
        ans = input(
            "Simulation under the same tag already exists. Do you want to proceed? [y/N]: "
        )
        if not (ans == 'y' or ans == 'Y'):
            print('exit simulation.')
            exit()
    try:
        with open('configs/workstation_info.json', 'rb') as f:
            wsinfo_json = f.read()
        wsinfo = json.loads(wsinfo_json)
        wsinfo['workstation-name']  # verify the key exists; used for the S3 upload below
    except (OSError, KeyError, json.JSONDecodeError):
        # wsinfo is required at the end of the run, so bail out early
        print("create file 'configs/workstation_info.json' with a 'workstation-name' entry")
        return

    logging.basicConfig(filename=LOG_FILE_PATH,
                        filemode='w',
                        format='%(name)s - %(levelname)s - %(message)s',
                        level=logging.INFO)

    np.random.seed(parsed.seed)
    tf.compat.v1.set_random_seed(parsed.seed)

    # load config file
    with open(parsed.config_file, 'rb') as f:
        config_json = f.read()
    config = json.loads(config_json)

    logging.info('-----------------------<config file>-----------------------')
    for k in config:
        logging.info(k + ':')
        logging.info('    ' + str(config[k]))

    if config['dataset'] == 'mnist':
        num_classes = 10
        model_fn = custom_models.get_2nn_mnist_model
        x_train, y_train_orig, x_test, y_test_orig = get_mnist_dataset()

    elif config['dataset'] == 'cifar':
        num_classes = 10
        model_fn = custom_models.get_big_cnn_cifar_model
        x_train, y_train_orig, x_test, y_test_orig = get_cifar_dataset()
    elif config['dataset'] == 'opportunity-uci':
        model_fn = custom_models.get_deep_conv_lstm_model
        x_train, y_train_orig, x_test, y_test_orig = get_opp_uci_dataset(
            'data/opportunity-uci/oppChallenge_gestures.data',
            SLIDING_WINDOW_LENGTH, SLIDING_WINDOW_STEP)
        # num_classes is needed by the pretraining branch below but is not
        # hard-coded for this dataset, so derive it from the labels
        num_classes = len(np.unique(y_train_orig))
    else:
        print("invalid dataset name")
        return

    CLIENT_NUM = config['client-num']

    # use existing pretrained model
    if config['pretrained-model'] is not None:
        print("using existing pretrained model")
        with open(config['pretrained-model'], 'rb') as handle:
            init_weights = pickle.load(handle)
    # pretrain new model
    else:
        # pretraining setup
        x_pretrain, y_pretrain_orig = dp.filter_data_by_labels(
            x_train, y_train_orig, np.arange(num_classes),
            config['pretrain-config']['data-size'])
        y_pretrain = keras.utils.to_categorical(y_pretrain_orig, num_classes)

        pretrain_config = {'batch_size': 50, 'shuffle': True}
        compile_config = {
            'loss': 'mean_squared_error',
            'metrics': ['accuracy']
        }
        init_model = model_fn()
        init_model.compile(**compile_config)
        # note: the data size above comes from 'pretrain-config' but the epoch
        # count from 'pretrain-setup'; the config file must define both keys
        pretrain_config['epochs'] = config['pretrain-setup']['epochs']
        pretrain_config['x'] = x_pretrain
        pretrain_config['y'] = y_pretrain
        pretrain_config['verbose'] = 1
        init_model.fit(**pretrain_config)
        init_weights = init_model.get_weights()
        with open(
                'remote_hist/pretrained_model_2nn_local_updates_' +
                get_time({}) + '_.pickle', 'wb') as handle:
            pickle.dump(init_weights, handle, protocol=pickle.HIGHEST_PROTOCOL)

    enc_config = config['enc-exp-config']
    enc_exp_config = {}
    enc_exp_config['data_file_name'] = enc_config['encounter-data-file']
    enc_exp_config['communication_time'] = enc_config['communication-time']
    enc_exp_config['train_time_per_step'] = enc_config['train-time-per-step']
    try:
        enc_exp_config['max_rounds'] = enc_config['max-rounds']
    except KeyError:
        raise ValueError(
            'no \'max-rounds\' found in the config file (replaces max-delegations)'
        )
    # if config['mobility-model'] == 'levy-walk':
    # falls back to None when 'district-9' is absent
    enc_exp_config['local_data_per_quad'] = config.get('district-9')

    hyperparams = config['hyperparams']

    test_data_provider = dp.StableTestDataProvider(
        x_test, y_test_orig, config['hyperparams']['test-data-per-label'])

    test_swarms = []
    swarm_names = []

    # OPTIMIZER = keras.optimizers.SGD

    # baseline swarm of plain LocalClients; it is passed into each strategy
    # swarm below, presumably so all strategies share the same initial setup
    orig_swarm = Swarm(model_fn, keras.optimizers.SGD, LocalClient, CLIENT_NUM,
                       init_weights, x_train, y_train_orig, test_data_provider,
                       config['local-set-size'], config['goal-set-size'],
                       config['local-data-size'], enc_exp_config, hyperparams)

    for k in config['strategies'].keys():
        if config['strategies'][k]:
            swarm_names.append(k)
            client_class = get_client_class(k)
            test_swarms.append(
                Swarm(model_fn, keras.optimizers.SGD, client_class, CLIENT_NUM,
                      init_weights, x_train, y_train_orig, test_data_provider,
                      config['local-set-size'], config['goal-set-size'],
                      config['local-data-size'], enc_exp_config, hyperparams,
                      orig_swarm))

    # del orig_swarm

    hists = {}
    for i in range(len(test_swarms)):
        start = timer()
        print("{} == running {} with {}".format(
            swarm_names[i], test_swarms[i].__class__.__name__,
            test_swarms[i]._clients[0].__class__.__name__))
        print("swarm {} of {}".format(i + 1, len(test_swarms)))
        test_swarms[i].run()
        end = timer()
        print('-------------- Elapsed Time --------------')
        print(end - start)
        hists[swarm_names[i]] = test_swarms[i].hist

        # checkpoint the accumulated histories after each swarm, keeping only
        # the most recent partial file
        hist_file_path = PurePath(
            HIST_FOLDER, 'partial_{}_'.format(i) + parsed.tag + '.pickle')
        if i > 0:
            os.remove(
                PurePath(HIST_FOLDER,
                         'partial_{}_'.format(i - 1) + parsed.tag + '.pickle'))
        if i == len(test_swarms) - 1:
            hist_file_path = PurePath(HIST_FOLDER, parsed.tag + '.pickle')
        with open(hist_file_path, 'wb') as handle:
            pickle.dump(hists, handle, protocol=pickle.HIGHEST_PROTOCOL)

    print('drawing graph...')
    matplotlib.rcParams['pdf.fonttype'] = 42
    matplotlib.rcParams['ps.fonttype'] = 42

    processed_hists = {}
    for k in hists.keys():
        # if 'federated' in k:
        #     continue
        t, acc = get_accs_over_time(hists[k], 'clients')
        processed_hists[k] = {}
        processed_hists[k]['times'] = t
        processed_hists[k]['accs'] = acc

    for k in processed_hists.keys():
        # if 'federated' in k:
        #     continue
        plt.plot(np.array(processed_hists[k]['times']),
                 np.array(processed_hists[k]['accs']),
                 lw=1.2)
    plt.legend(list(processed_hists.keys()))
    if hyperparams['evaluation-metrics'] == 'f1-score-weighted':
        plt.ylabel("F1-score")
    else:
        plt.ylabel("Accuracy")
    plt.xlabel("Time")
    graph_file_path = PurePath(FIG_FOLDER, parsed.tag + '.pdf')
    plt.savefig(graph_file_path)
    plt.close()

    logging.info('Simulation completed successfully.')

    # upload to S3 storage
    client = boto3.client('s3')
    S3_BUCKET_NAME = 'opfl-sim-models'
    upload_log_path = PurePath(wsinfo['workstation-name'],
                               'logs/' + parsed.tag + '.log')
    client.upload_file(str(LOG_FILE_PATH), S3_BUCKET_NAME,
                       str(upload_log_path))
    upload_hist_path = PurePath(wsinfo['workstation-name'],
                                'hists/' + parsed.tag + '.pickle')
    client.upload_file(str(hist_file_path), S3_BUCKET_NAME,
                       str(upload_hist_path))
    upload_graph_path = PurePath(wsinfo['workstation-name'],
                                 'figs/' + parsed.tag + '.pdf')
    client.upload_file(str(graph_file_path), S3_BUCKET_NAME,
                       str(upload_graph_path))
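
The upload step at the end expects configs/workstation_info.json to exist and AWS credentials for boto3 to be configured in the environment. Here is a minimal sketch for creating that file; only the 'workstation-name' key is actually read by the driver, and the value shown is a made-up placeholder.

import json

# Hypothetical one-off setup: write the workstation_info.json that
# swarm_driver.py reads. Only 'workstation-name' is used (it prefixes the
# S3 upload paths); the value is an illustrative placeholder.
info = {"workstation-name": "my-workstation"}
with open("configs/workstation_info.json", "w") as f:
    json.dump(info, f, indent=2)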