Example #1
def eval_setup(global_weights):
    args = gv.args

    if 'MNIST' in args.dataset:
        K.set_learning_phase(0)

    # global_weights_np = np.load(gv.dir_name + 'global_weights_t%s.npy' % t)
    global_weights_np = global_weights

    if 'MNIST' in args.dataset:
        global_model = model_mnist(type=args.model_num)
    elif args.dataset == 'CIFAR-10':
        global_model = cifar_10_model()
    elif args.dataset == 'census':
        global_model = census_model_1()

    if args.dataset == 'census':
        x = tf.placeholder(shape=(None, gv.DATA_DIM), dtype=tf.float32)
    else:
        x = tf.placeholder(shape=(None,
                                  gv.IMAGE_ROWS,
                                  gv.IMAGE_COLS,
                                  gv.NUM_CHANNELS), dtype=tf.float32)
    y = tf.placeholder(dtype=tf.int64)

    logits = global_model(x)
    prediction = tf.nn.softmax(logits)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y, logits=logits))

    if args.k > 1:
        config = tf.ConfigProto(gpu_options=gv.gpu_options)
        # config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
    elif args.k == 1:
        sess = tf.Session()
    
    K.set_session(sess)
    sess.run(tf.global_variables_initializer())

    global_model.set_weights(global_weights_np)

    return x, y, sess, prediction, loss
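A minimal usage sketch (not part of the original example), assuming gv has already been initialized and that data_setup() returns the same five values it returns in the standalone snippet further below:

global_weights = np.load(gv.dir_name + 'global_weights_t0.npy', allow_pickle=True)
x, y, sess, prediction, loss = eval_setup(global_weights)
_, _, X_test, Y_test, Y_test_uncat = data_setup()
pred_np, loss_np = sess.run([prediction, loss],
                            feed_dict={x: X_test, y: Y_test_uncat})
accuracy = np.mean(np.argmax(pred_np, axis=1) == Y_test_uncat)
print('Global model accuracy %.4f, loss %.4f' % (accuracy, loss_np))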
Example #2
def master():
    K.set_learning_phase(1)

    args = gv.args
    print('Initializing master model')
    config = tf.ConfigProto(gpu_options=gv.gpu_options)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    K.set_session(sess)
    sess.run(tf.global_variables_initializer())

    if 'MNIST' in args.dataset:
        global_model = model_mnist(type=args.model_num)
    elif args.dataset == 'census':
        global_model = census_model_1()
    global_model.summary()
    global_weights_np = global_model.get_weights()
    np.save(gv.dir_name + 'global_weights_t0.npy', global_weights_np)

    return
Example #3
def master():
    tf.keras.backend.set_learning_phase(1)

    args = gv.init()
    print('Initializing master model')
    config = tf.ConfigProto(gpu_options=gv.gpu_options)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    tf.keras.backend.set_session(sess)
    sess.run(tf.global_variables_initializer())

    if 'MNIST' in args.dataset:
        global_model = model_mnist(type=args.model_num)
    elif args.dataset == 'census':
        global_model = census_model_1()
    elif args.dataset == 'CIFAR-10':
        global_model = cifar10_model()

    global_weights_np = global_model.get_weights()
    np.save(gv.dir_name + 'global_weights_t0.npy', global_weights_np)
    print("[server] save global weights t0")
    return
Example #4
def mal_agent(X_shard, Y_shard, mal_data_X, mal_data_Y, t, gpu_id, return_dict,
              mal_visible, X_test, Y_test):

    args = gv.args

    shared_weights = np.load(gv.dir_name + 'global_weights_t%s.npy' % t)

    holdoff_flag = 0
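    # The 'holdoff' strategy pauses the attack in rounds where the current
    # global model already classifies the single malicious target with
    # confidence above 0.8.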
    if 'holdoff' in args.mal_strat:
        print('Checking holdoff')
        if 'single' in args.mal_obj:
            target, target_conf, actual, actual_conf = mal_eval_single(
                mal_data_X, mal_data_Y, shared_weights)
            if target_conf > 0.8:
                print('Holding off')
                holdoff_flag = 1

    # tf.reset_default_graph()

    K.set_learning_phase(1)

    print('Malicious Agent on GPU %s' % gpu_id)
    # set environment
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    if args.dataset == 'census':
        x = tf.placeholder(shape=(None, gv.DATA_DIM), dtype=tf.float32)
        y = tf.placeholder(dtype=tf.int64)
    else:
        x = tf.placeholder(shape=(None, gv.IMAGE_ROWS, gv.IMAGE_COLS,
                                  gv.NUM_CHANNELS),
                           dtype=tf.float32)
        y = tf.placeholder(dtype=tf.int64)

    if 'MNIST' in args.dataset:
        agent_model = model_mnist(type=args.model_num)
    elif args.dataset == 'CIFAR-10':
        agent_model = cifar_10_model()
    elif args.dataset == 'census':
        agent_model = census_model_1()

    logits = agent_model(x)
    prediction = tf.nn.softmax(logits)
    eval_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,
                                                       logits=logits))

    config = tf.ConfigProto(gpu_options=gv.gpu_options)
    # config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    K.set_session(sess)

    if t >= args.mal_delay and holdoff_flag == 0:
        if args.mal_obj == 'all':
            final_delta = mal_all_algs(x, y, logits, agent_model,
                                       shared_weights, sess, mal_data_X,
                                       mal_data_Y, t)
        elif args.mal_obj == 'single' or 'multiple' in args.mal_obj:
            final_delta, penul_delta = mal_single_algs(
                x, y, logits, agent_model, shared_weights, sess, mal_data_X,
                mal_data_Y, t, mal_visible, X_shard, Y_shard)
    elif t < args.mal_delay or holdoff_flag == 1:
        print('Delay/Hold-off')
        final_delta, _ = benign_train(x, y, agent_model, logits, X_shard,
                                      Y_shard, sess, shared_weights)

    final_weights = shared_weights + final_delta
    agent_model.set_weights(final_weights)

    print('---Eval at mal agent---')
    if 'single' in args.mal_obj:
        target, target_conf, actual, actual_conf = mal_eval_single(
            mal_data_X, mal_data_Y, final_weights)
        print(
            'Target:%s with conf. %s, Curr_pred on malicious model for iter %s:%s with conf. %s'
            % (target, target_conf, t, actual, actual_conf))
    elif 'multiple' in args.mal_obj:
        suc_count_local = mal_eval_multiple(mal_data_X, mal_data_Y,
                                            final_weights)
        print('%s of %s targets achieved' % (suc_count_local, args.mal_num))

    eval_success, eval_loss = eval_minimal(X_test, Y_test, final_weights)
    return_dict['mal_success'] = eval_success
    print('Malicious Agent: success {}, loss {}'.format(
        eval_success, eval_loss))
    write_dict = {}
    # just to maintain ordering
    write_dict['t'] = t + 1
    write_dict['eval_success'] = eval_success
    write_dict['eval_loss'] = eval_loss
    file_write(write_dict, purpose='mal_eval_loss')

    return_dict[str(gv.mal_agent_index)] = np.array(final_delta)
    np.save(gv.dir_name + 'mal_delta_t%s.npy' % t, final_delta)

    if 'auto' in args.mal_strat or 'multiple' in args.mal_obj:
        penul_weights = shared_weights + penul_delta
        if 'single' in args.mal_obj:
            target, target_conf, actual, actual_conf = mal_eval_single(
                mal_data_X, mal_data_Y, penul_weights)
            print(
                'Penul weights ---- Target:%s with conf. %s, Curr_pred on malicious model for iter %s:%s with conf. %s'
                % (target, target_conf, t, actual, actual_conf))
        elif 'multiple' in args.mal_obj:
            suc_count_local = mal_eval_multiple(mal_data_X, mal_data_Y,
                                                penul_weights)
            print('%s of %s targets achieved' %
                  (suc_count_local, args.mal_num))

        eval_success, eval_loss = eval_minimal(X_test, Y_test, penul_weights)
        print('Penul weights ---- Malicious Agent: success {}, loss {}'.format(
            eval_success, eval_loss))

    return
Example #5
def train_fn(X_train_shards,
             Y_train_shards,
             X_test,
             Y_test,
             return_dict,
             mal_data_X=None,
             mal_data_Y=None):
    # Start the training process
    num_agents_per_time = int(args.C * args.k)
    simul_agents = gv.num_gpus * gv.max_agents_per_gpu
    simul_num = min(num_agents_per_time, simul_agents)
    alpha_i = 1.0 / args.k
    agent_indices = np.arange(args.k)
    if args.mal:
        mal_agent_index = gv.mal_agent_index
    guidance = {
        'gradient': (np.ones(args.k) * alpha_i, np.ones(args.k)),
        'data': (np.ones(args.k) * alpha_i, np.zeros(args.k))
    }
    unupdated_frac = (args.k - num_agents_per_time) / float(args.k)
    t = 0
    mal_visible = []
    eval_loss_list = []
    loss_track_list = []
    lr = args.eta
    loss_count = 0
    if args.gar == 'krum':
        krum_select_indices = []

    while return_dict['eval_success'] < gv.max_acc and t < args.T:
        print('Time step %s' % t)

        process_list = []
        mal_active = 0
        # curr_agents = np.random.choice(agent_indices, num_agents_per_time,
        #                                replace=False)
        # if t == 0:
        #     curr_agents = np.random.choice(agent_indices, num_agents_per_time, replace=False)
        # else:
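        # Sample agents with probability given by a softmax over each agent's
        # running average update magnitude (accumulated RMS of its submitted
        # delta divided by the number of times it has been selected).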
        curr_agents = np.random.choice(
            agent_indices,
            num_agents_per_time,
            p=np.exp(guidance['gradient'][0] / guidance['gradient'][1]) /
            np.sum(np.exp(guidance['gradient'][0] / guidance['gradient'][1]),
                   axis=0),
            replace=False)
        print('Set of agents chosen: %s' % curr_agents)

        if t == 0:
            if 'MNIST' in args.dataset:
                global_model = model_mnist(type=args.model_num)
            elif args.dataset == 'CIFAR-10':
                global_model = cifar_10_model()
            elif args.dataset == 'census':
                global_model = census_model_1()
            global_weights = global_model.get_weights()
            np.save(gv.dir_name + 'global_weights_t%s.npy' % t, global_weights)
        else:
            global_weights = np.load(gv.dir_name +
                                     'global_weights_t%s.npy' % t,
                                     allow_pickle=True)

        k = 0
        agents_left = 1e4
        while k < num_agents_per_time:
            # true_simul = min(simul_num,agents_left)
            true_simul = 1
            print('training %s agents' % true_simul)
            for l in range(true_simul):
                gpu_index = int(l / gv.max_agents_per_gpu)
                gpu_id = gv.gpu_ids[gpu_index]
                i = curr_agents[k]
                if args.mal is False or i != mal_agent_index:
                    # p = Process(target=agent, args=(i, X_train_shards[i],
                    #                             Y_train_shards[i], t, gpu_id, return_dict, X_test, Y_test,lr))
                    agent(i, X_train_shards[i], Y_train_shards[i], t, gpu_id,
                          return_dict, X_test, Y_test, lr)
                elif args.mal is True and i == mal_agent_index:
                    # p = Process(target=mal_agent, args=(X_train_shards[mal_agent_index],
                    #                                 Y_train_shards[mal_agent_index], mal_data_X, mal_data_Y, t, gpu_id, return_dict, mal_visible, X_test, Y_test))
                    mal_agent(X_train_shards[mal_agent_index],
                              Y_train_shards[mal_agent_index], mal_data_X,
                              mal_data_Y, t, gpu_id, return_dict, mal_visible,
                              X_test, Y_test)
                    mal_active = 1
                # print (return_dict[str(i)])
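                # Accumulate the RMS magnitude of agent i's update and bump its
                # selection count; these statistics drive the softmax sampling
                # probabilities above.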
                guidance['gradient'][0][i] += np.sqrt(
                    np.array([(x**2).mean()
                              for x in return_dict[str(i)]]).mean())
                guidance['gradient'][1][i] += 1
                # p.start()
                # process_list.append(p)
                k += 1
            # for item in process_list:
            #     item.join()
            agents_left = num_agents_per_time - k
            print('Agents left:%s' % agents_left)

        if mal_active == 1:
            mal_visible.append(t)

        print('Joined all processes for time step %s' % t)

        if 'avg' in args.gar:
            if args.mal:
                count = 0
                for k in range(num_agents_per_time):
                    if curr_agents[k] != mal_agent_index:
                        if count == 0:
                            ben_delta = alpha_i * return_dict[str(
                                curr_agents[k])]
                            np.save(gv.dir_name + 'ben_delta_sample%s.npy' % t,
                                    return_dict[str(curr_agents[k])])
                            count += 1
                        else:
                            ben_delta += alpha_i * return_dict[str(
                                curr_agents[k])]

                np.save(gv.dir_name + 'ben_delta_t%s.npy' % t, ben_delta)
                global_weights += alpha_i * return_dict[str(mal_agent_index)]
                global_weights += ben_delta
            else:
                for k in range(num_agents_per_time):
                    global_weights += alpha_i * return_dict[str(
                        curr_agents[k])]

        elif 'krum' in args.gar:
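            # Krum aggregation: score each update by the sum of its Euclidean
            # distances to its (num_agents_per_time - 3) nearest other updates
            # (tolerating one malicious agent) and apply only the update with
            # the lowest score.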
            collated_weights = []
            collated_bias = []
            agg_num = int(num_agents_per_time - 1 - 2)
            for k in range(num_agents_per_time):
                # weights_curr, bias_curr = collate_weights(return_dict[str(curr_agents[k])])
                weights_curr, bias_curr = collate_weights(return_dict[str(k)])
                collated_weights.append(weights_curr)
                collated_bias.append(bias_curr)
            score_array = np.zeros(num_agents_per_time)
            for k in range(num_agents_per_time):
                dists = []
                for i in range(num_agents_per_time):
                    if i == k:
                        continue
                    else:
                        dists.append(
                            np.linalg.norm(collated_weights[k] -
                                           collated_weights[i]))
                dists = np.sort(np.array(dists))
                dists_subset = dists[:agg_num]
                score_array[k] = np.sum(dists_subset)
            print(score_array)
            krum_index = np.argmin(score_array)
            print(krum_index)
            global_weights += return_dict[str(krum_index)]
            if krum_index == mal_agent_index:
                krum_select_indices.append(t)
        elif 'coomed' in args.gar:
            # Fix for mean aggregation first!
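            # Coordinate-wise median: flatten every agent's weights and biases,
            # take the element-wise median across agents, then reshape back to
            # the original layer shapes before applying the update.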
            weight_tuple_0 = return_dict[str(curr_agents[0])]
            weights_0, bias_0 = collate_weights(weight_tuple_0)
            weights_array = np.zeros((num_agents_per_time, len(weights_0)))
            bias_array = np.zeros((num_agents_per_time, len(bias_0)))
            # collated_weights = []
            # collated_bias = []
            for k in range(num_agents_per_time):
                weight_tuple = return_dict[str(curr_agents[k])]
                weights_curr, bias_curr = collate_weights(weight_tuple)
                weights_array[k, :] = weights_curr
                bias_array[k, :] = bias_curr
            shape_size = model_shape_size(weight_tuple)
            # weights_array = np.reshape(np.array(collated_weights),(len(weights_curr),num_agents_per_time))
            # bias_array = np.reshape(np.array(collated_bias),(len(bias_curr),num_agents_per_time))
            med_weights = np.median(weights_array, axis=0)
            med_bias = np.median(bias_array, axis=0)
            num_layers = len(shape_size[0])
            update_list = []
            w_count = 0
            b_count = 0
            for i in range(num_layers):
                weights_length = shape_size[2][i]
                update_list.append(med_weights[w_count:w_count +
                                               weights_length].reshape(
                                                   shape_size[0][i]))
                w_count += weights_length
                bias_length = shape_size[3][i]
                update_list.append(med_bias[b_count:b_count +
                                            bias_length].reshape(
                                                shape_size[1][i]))
                b_count += bias_length
            assert model_shape_size(update_list) == shape_size
            global_weights += update_list

        # Saving for the next update
        np.save(gv.dir_name + 'global_weights_t%s.npy' % (t + 1),
                global_weights)

        # Evaluate global weight
        if args.mal:
            # p_eval = Process(target=eval_func, args=(
            #     X_test, Y_test, t + 1, return_dict, mal_data_X, mal_data_Y), kwargs={'global_weights': global_weights})
            eval_func(X_test,
                      Y_test,
                      t + 1,
                      return_dict,
                      mal_data_X,
                      mal_data_Y,
                      global_weights=global_weights)
        else:
            # p_eval = Process(target=eval_func, args=(
            #     X_test, Y_test, t + 1, return_dict), kwargs={'global_weights': global_weights})
            eval_func(X_test,
                      Y_test,
                      t + 1,
                      return_dict,
                      global_weights=global_weights)
        # p_eval.start()
        # p_eval.join()

        eval_loss_list.append(return_dict['eval_loss'])

        t += 1

    return t
Example #6
def agent(i, X_shard, Y_shard, t, gpu_id, return_dict, X_test, Y_test, lr=None):
    tf.keras.backend.set_learning_phase(1)

    args = gv.init()
    if lr is None:
        lr = args.eta
    print('Agent %s on GPU %s' % (i, gpu_id))
    # set environment
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    shared_weights = np.load(gv.dir_name + 'global_weights_t%s.npy' % t, allow_pickle=True)
    shard_size = len(X_shard)

    if 'theta{}'.format(gv.mal_agent_index) in return_dict.keys():
        pre_theta = return_dict['theta{}'.format(gv.mal_agent_index)]
    else:
        pre_theta = None

    # if i == 0:
    #     # eval_success, eval_loss = eval_minimal(X_test,Y_test,x, y, sess, prediction, loss)
    #     eval_success, eval_loss = eval_minimal(X_test,Y_test,shared_weights)
    #     print('Global success at time {}: {}, loss {}'.format(t,eval_success,eval_loss))

    if args.steps is not None:
        num_steps = args.steps
    else:
        num_steps = int(args.E * shard_size / args.B)

    # with tf.device('/gpu:'+str(gpu_id)):
    if args.dataset == 'census':
        x = tf.placeholder(shape=(None, gv.DATA_DIM), dtype=tf.float32)
        # y = tf.placeholder(dtype=tf.float32)
        y = tf.placeholder(dtype=tf.int64)
    else:
        x = tf.placeholder(shape=(None,
                                  gv.IMAGE_ROWS,
                                  gv.IMAGE_COLS,
                                  gv.NUM_CHANNELS), dtype=tf.float32)
        y = tf.placeholder(dtype=tf.int64)

    if 'MNIST' in args.dataset:
        agent_model = model_mnist(type=args.model_num)
    elif args.dataset == 'census':
        agent_model = census_model_1()
    elif args.dataset == 'CIFAR-10':
        agent_model = cifar10_model()
    else:
        return

    logits = agent_model(x)

    if args.dataset == 'census':
        # loss = tf.nn.sigmoid_cross_entropy_with_logits(
        #     labels=y, logits=logits)
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=y, logits=logits))
    else:
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=y, logits=logits))
    prediction = tf.nn.softmax(logits)

    if args.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate=lr).minimize(loss)
    elif args.optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=lr).minimize(loss)

    if args.k > 1:
        config = tf.ConfigProto(gpu_options=gv.gpu_options)
        # config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
    elif args.k == 1:
        sess = tf.Session()
    else:
        return
    tf.compat.v1.keras.backend.set_session(sess)
    sess.run(tf.global_variables_initializer())

    if pre_theta is not None:
        theta = pre_theta - gv.moving_rate * (pre_theta - shared_weights)
    else:
        theta = shared_weights
    agent_model.set_weights(theta)
    # print('loaded shared weights')

    start_offset = 0
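    # With a fixed number of local steps per round, continue iterating over the
    # shard from where the previous round left off.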
    if args.steps is not None:
        start_offset = (t * args.B * args.steps) % (shard_size - args.B)

    for step in range(num_steps):
        offset = (start_offset + step * args.B) % (shard_size - args.B)
        X_batch = X_shard[offset: (offset + args.B)]
        Y_batch = Y_shard[offset: (offset + args.B)]
        Y_batch_uncat = np.argmax(Y_batch, axis=1)
        _, loss_val = sess.run([optimizer, loss], feed_dict={x: X_batch, y: Y_batch_uncat})
        if step % 1000 == 0:
            print('Agent %s, Step %s, Loss %s, offset %s' % (i, step, loss_val, offset))
            # local_weights = agent_model.get_weights()
            # eval_success, eval_loss = eval_minimal(X_test,Y_test,x, y, sess, prediction, loss)
            # print('Agent {}, Step {}: success {}, loss {}'.format(i,step,eval_success,eval_loss))

    local_weights = agent_model.get_weights()
    local_delta = local_weights - shared_weights

    # eval_success, eval_loss = eval_minimal(X_test,Y_test,x, y, sess, prediction, loss)
    eval_success, eval_loss = eval_minimal(X_test, Y_test, local_weights)

    print('Agent {}: success {}, loss {}'.format(i, eval_success, eval_loss))

    return_dict[str(i)] = np.array(local_delta)
    return_dict["theta{}".format(i)] = np.array(local_weights)

    np.save(gv.dir_name + 'ben_delta_%s_t%s.npy' % (i, t), local_delta)

    return
Example #7
weights_np = np.load(gv.dir_name + 'global_weights_t%s.npy' % 8)

X_train, Y_train, X_test, Y_test, Y_test_uncat = data_setup()

mal_analyse = True

if mal_analyse:
    mal_data_X, mal_data_Y, true_labels = mal_data_setup(X_test,
                                                         Y_test,
                                                         Y_test_uncat,
                                                         gen_flag=False)

label_to_class_name = [str(i) for i in range(gv.NUM_CLASSES)]

if 'MNIST' in args.dataset:
    model = model_mnist(type=args.model_num)
elif args.dataset == 'CIFAR-10':
    model = cifar_10_model()

x = tf.compat.v1.placeholder(shape=(None, gv.IMAGE_ROWS, gv.IMAGE_COLS,
                                    gv.NUM_CHANNELS),
                             dtype=tf.float32)
y = tf.compat.v1.placeholder(dtype=tf.int64)

logits = model(x)
prediction = tf.nn.softmax(logits)

sess = tf.compat.v1.Session()

K.set_session(sess)
sess.run(tf.compat.v1.global_variables_initializer())
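
A possible continuation of this snippet (not part of the original): apply the loaded global weights and compare the model's predictions on the malicious data points against their true labels.

model.set_weights(weights_np)
if mal_analyse:
    pred_np = sess.run(prediction, feed_dict={x: mal_data_X})
    pred_classes = [label_to_class_name[p] for p in np.argmax(pred_np, axis=1)]
    print('True labels: %s' % true_labels)
    print('Malicious target labels: %s' % mal_data_Y)
    print('Predicted classes: %s' % pred_classes)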
Example #8
def agent(i,
          X_shard,
          Y_shard,
          t,
          gpu_id,
          return_dict,
          X_test,
          Y_test,
          lr=None):
    K.set_learning_phase(1)

    args = gv.args
    if lr is None:
        lr = args.eta
    print('Agent %s on GPU %s' % (i, gpu_id))
    ## Set environment
    #os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    #os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"

    ## Load global weights
    shared_weights = np.load(gv.dir_name + 'global_weights_t%s.npy' % t,
                             allow_pickle=True)
    shard_size = len(X_shard)

    # if i == 0:
    #     # eval_success, eval_loss = eval_minimal(X_test,Y_test,x, y, sess, prediction, loss)
    #     eval_success, eval_loss = eval_minimal(X_test,Y_test,shared_weights)
    #     print('Global success at time {}: {}, loss {}'.format(t,eval_success,eval_loss))

    if args.steps is not None:
        num_steps = args.steps
    else:
        # num_steps = num_epochs * num_batches_per_shard
        num_steps = int(args.E * shard_size / args.B)

    # with tf.device('/gpu:'+str(gpu_id)):
    ## Load data
    if args.dataset == 'census':
        x = tf.placeholder(shape=(None, gv.DATA_DIM), dtype=tf.float32)
        # y = tf.placeholder(dtype=tf.float32)
        y = tf.placeholder(dtype=tf.int64)
    else:
        x = tf.placeholder(shape=(None, gv.IMAGE_ROWS, gv.IMAGE_COLS,
                                  gv.NUM_CHANNELS),
                           dtype=tf.float32)
        y = tf.placeholder(dtype=tf.int64)

    ## Load model
    if 'MNIST' in args.dataset:
        agent_model = model_mnist(type=args.model_num)
    elif args.dataset == 'census':
        agent_model = census_model_1()

    logits = agent_model(x)

    if args.dataset == 'census':
        # loss = tf.nn.sigmoid_cross_entropy_with_logits(
        #     labels=y, logits=logits)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,
                                                           logits=logits))
    else:
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,
                                                           logits=logits))
    #prediction = tf.nn.softmax(logits)

    if args.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)
    elif args.optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=lr).minimize(loss)

    if args.k > 1:
        config = tf.ConfigProto(gpu_options=gv.gpu_options)
        # config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
    elif args.k == 1:
        sess = tf.Session()
    K.set_session(sess)
    sess.run(tf.global_variables_initializer())

    agent_model.set_weights(shared_weights)

    start_offset = 0
    # Start training data from where we left off last time
    if args.steps is not None:
        start_offset = (t * args.B * args.steps) % (shard_size - args.B)

    for step in range(num_steps):
        offset = (start_offset + step * args.B) % (shard_size - args.B)
        X_batch = X_shard[offset:(offset + args.B)]
        Y_batch = Y_shard[offset:(offset + args.B)]
        Y_batch_uncat = np.argmax(Y_batch, axis=1)
        _, loss_val = sess.run([optimizer, loss],
                               feed_dict={
                                   x: X_batch,
                                   y: Y_batch_uncat
                               })
        if step % 100 == 0:
            print('Agent %s, Step %s, Loss %s, offset %s' %
                  (i, step, loss_val, offset))
            # local_weights = agent_model.get_weights()
            # eval_success, eval_loss = eval_minimal(X_test,Y_test,x, y, sess, prediction, loss)
            # print('Agent {}, Step {}: success {}, loss {}'.format(i,step,eval_success,eval_loss))

    local_weights = agent_model.get_weights()
    local_delta = local_weights - shared_weights

    # eval_success, eval_loss = eval_minimal(X_test,Y_test,x, y, sess, prediction, loss)
    # Compute TPR and loss for the local model
    eval_success, eval_loss = eval_minimal(X_test, Y_test, local_weights)

    print('Evaluation on test data - Agent {}: success {}, loss {}'.format(
        i, eval_success, eval_loss))

    return_dict[str(i)] = np.array(local_delta)

    np.save(gv.dir_name + 'ben_delta_%s_t%s.npy' % (i, t), local_delta)

    return