Example 1
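# NOTE: this snippet assumes module-level imports (e.g. os, sys, numpy as np,
# tensorflow as tf, datetime) plus project helpers referenced below: `utilities`,
# `cifar10_input`, `cifar100_input`, `dataset_input`, `get_path_dir`, the `Model`
# class, and a `tpu_strategy` distribution strategy defined elsewhere in the file.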
def train(tf_seed, np_seed, train_steps, out_steps, summary_steps,
          checkpoint_steps, step_size_schedule, weight_decay, momentum,
          train_batch_size, epsilon, replay_m, model_dir, dataset,
          poison_alpha, poison_config, **kwargs):
    tf.compat.v1.set_random_seed(tf_seed)
    np.random.seed(np_seed)

    print('poison alpha = %f' % poison_alpha)

    model_dir = model_dir + '%s_m%d_eps%.1f_b%d' % (
        dataset, replay_m, epsilon, train_batch_size
    )  # TODO Replace with not defaults

    # Setting up the data and the model

    poison_config_dict = utilities.config_to_namedtuple(
        utilities.get_config(poison_config))

    print(poison_config_dict)

    data_path = get_path_dir(dataset=dataset, **kwargs)
    if dataset == 'cifar10':
        raw_data = cifar10_input.CIFAR10Data(data_path)
    elif dataset == 'cifar10_poisoned':
        raw_data = dataset_input.CIFAR10Data(poison_config_dict, seed=np_seed)
    else:
        raw_data = cifar100_input.CIFAR100Data(data_path)
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tpu_strategy.scope():
        model = Model(mode='train',
                      dataset=dataset,
                      train_batch_size=train_batch_size)

    # Setting up the optimizer
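    # step_size_schedule is assumed to be a list of (step, learning_rate) pairs;
    # the first boundary is dropped because the first pair only sets the initial rate.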
    boundaries = [int(sss[0]) for sss in step_size_schedule][1:]
    values = [sss[1] for sss in step_size_schedule]
    learning_rate = tf.compat.v1.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values)
    optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate, momentum)

    # Optimizing computation
    total_loss = model.mean_xent + weight_decay * model.weight_decay_loss
    grads = optimizer.compute_gradients(total_loss)

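    # "Free" adversarial training: a single backward pass supplies both the
    # gradient for the sign-ascent update on the image perturbation (below) and
    # the gradient for the descent step on the weights, and each mini-batch is
    # replayed replay_m times in the training loop instead of running a separate
    # inner attack loop.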
    # Compute new image
    pert_grad = [g for g, v in grads if 'perturbation' in v.name]
    sign_pert_grad = tf.sign(pert_grad[0])
    new_pert = model.pert + epsilon * sign_pert_grad
    clip_new_pert = tf.clip_by_value(new_pert, -epsilon, epsilon)
    assigned = tf.compat.v1.assign(model.pert, clip_new_pert)

    # Train
    no_pert_grad = [(tf.zeros_like(v), v) if 'perturbation' in v.name else
                    (g, v) for g, v in grads]
    with tf.control_dependencies([assigned]):
        min_step = optimizer.apply_gradients(no_pert_grad,
                                             global_step=global_step)
    # initialize_variables was removed from TF; the perturbation variable is
    # covered by global_variables_initializer below anyway.
    tf.compat.v1.variables_initializer([model.pert])

    # Setting up the Tensorboard and checkpoint outputs
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    saver = tf.compat.v1.train.Saver(max_to_keep=1)
    tf.compat.v1.summary.scalar('accuracy', model.accuracy)
    tf.compat.v1.summary.scalar('xent', model.xent / train_batch_size)
    tf.compat.v1.summary.scalar('total loss', total_loss / train_batch_size)
    merged_summaries = tf.compat.v1.summary.merge_all()

    gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=1.0)
    with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
            gpu_options=gpu_options)) as sess:
        print(
            '\n\n********** free training for epsilon=%.1f using m_replay=%d **********\n\n'
            % (epsilon, replay_m))
        print(
            'important params >>> \n model dir: %s \n dataset: %s \n training batch size: %d \n'
            % (model_dir, dataset, train_batch_size))
        if dataset == 'cifar100':
            print(
                'the ride for CIFAR100 is bumpy -- fasten your seatbelts!\n'
                'you will probably see the training and validation accuracy fluctuating a lot early in training;\n'
                'this is natural, especially for large replay_m values, because we see that mini-batch so many times.'
            )
        # initialize data augmentation
        if dataset == 'cifar10':
            data = cifar10_input.AugmentedCIFAR10Data(raw_data, sess, model)
        elif dataset == 'cifar10_poisoned':
            data = raw_data
        else:
            data = cifar100_input.AugmentedCIFAR100Data(raw_data, sess, model)

        # Initialize the summary writer, global variables, and our time counter.
        summary_writer = tf.compat.v1.summary.FileWriter(
            model_dir + '/train', sess.graph)
        eval_summary_writer = tf.compat.v1.summary.FileWriter(model_dir +
                                                              '/eval')
        sess.run(tf.compat.v1.global_variables_initializer())

        # Main training loop
        for ii in range(train_steps):
            if ii % replay_m == 0:
                x_batch, y_batch = data.train_data.get_next_batch(
                    train_batch_size, multiple_passes=True)
                nat_dict = {model.x_input: x_batch, model.y_input: y_batch}

            x_eval_batch, y_eval_batch = data.eval_data.get_next_batch(
                train_batch_size, multiple_passes=True)
            eval_dict = {
                model.x_input: x_eval_batch,
                model.y_input: y_eval_batch
            }

            # Tensorboard summaries and stdout output
            if ii % summary_steps == 0:
                train_acc, summary = sess.run(
                    [model.accuracy, merged_summaries], feed_dict=nat_dict)
                summary_writer.add_summary(summary, global_step.eval(sess))
                val_acc, summary = sess.run([model.accuracy, merged_summaries],
                                            feed_dict=eval_dict)
                eval_summary_writer.add_summary(summary,
                                                global_step.eval(sess))
                print('Step {}:    ({})'.format(ii, datetime.now()))
                print(
                    '    training nat accuracy {:.4}% -- validation nat accuracy {:.4}%'
                    .format(train_acc * 100, val_acc * 100))
                sys.stdout.flush()
            # Output to stdout only
            elif ii % out_steps == 0:
                nat_acc = sess.run(model.accuracy, feed_dict=nat_dict)
                print('Step {}:    ({})'.format(ii, datetime.now()))
                print('    training nat accuracy {:.4}%'.format(nat_acc * 100))

            # Write a checkpoint
            if (ii + 1) % checkpoint_steps == 0:
                saver.save(sess,
                           os.path.join(model_dir, 'checkpoint'),
                           global_step=global_step)

            # Actual training step
            sess.run(min_step, feed_dict=nat_dict)
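
A minimal sketch of how this train function might be driven, assuming a JSON file whose top-level keys mirror the signature above; the file name and layout are illustrative, not part of the original code:

if __name__ == '__main__':
    # Hypothetical driver: load hyperparameters from a JSON file whose keys
    # mirror the train() signature, then unpack them as keyword arguments.
    import json

    with open('config.json') as f:  # illustrative path
        params = json.load(f)

    train(**params)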
Example 2
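# NOTE: as in the other examples, this snippet assumes module-level imports
# (os, math, json, numpy as np, tensorflow as tf, trange from tqdm) and the
# project-local `dataset_input` module; `model` and `sess` are built by the caller.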
def evaluate(model, sess, config, summary_writer=None):
    eval_batch_size = config.eval.batch_size

    model_dir = config.model.output_dir
    # Setting up the Tensorboard and checkpoint outputs
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    poison_method = config.data.poison_method
    clean_label = config.data.clean_label
    target_label = config.data.target_label
    position = config.data.position
    color = config.data.color
    dataset = dataset_input.CIFAR10Data(config,
                                        seed=config.training.np_random_seed)
    print(poison_method, clean_label, target_label, position, color)

    global_step = tf.contrib.framework.get_or_create_global_step()
    # Iterate over the samples batch-by-batch
    num_eval_examples = len(dataset.eval_data.xs)
    num_clean_examples = 0
    num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
    total_xent_nat = 0.
    total_corr_nat = 0
    total_xent_pois = 0.
    total_corr_pois = 0

    for ibatch in trange(num_batches):
        bstart = ibatch * eval_batch_size
        bend = min(bstart + eval_batch_size, num_eval_examples)

        x_batch = dataset.eval_data.xs[bstart:bend, :]
        y_batch = dataset.eval_data.ys[bstart:bend]
        pois_x_batch = dataset.poisoned_eval_data.xs[bstart:bend, :]
        pois_y_batch = dataset.poisoned_eval_data.ys[bstart:bend]

        dict_nat = {
            model.x_input: x_batch,
            model.y_input: y_batch,
            model.is_training: False
        }

        cur_corr_nat, cur_xent_nat = sess.run([model.num_correct, model.xent],
                                              feed_dict=dict_nat)
        total_xent_nat += cur_xent_nat
        total_corr_nat += cur_corr_nat

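        # Attack-success bookkeeping: when a clean (source) label is given, only
        # poisoned versions of examples whose true label equals clean_label are
        # scored, and success is measured against target_label; otherwise every
        # example in the batch is relabeled with target_label.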
        if clean_label > -1:
            clean_indices = np.where(y_batch == clean_label)[0]
            if len(clean_indices) == 0: continue
            pois_x_batch = pois_x_batch[clean_indices]
            pois_y_batch = np.repeat(target_label, len(clean_indices))
        else:
            pois_y_batch = np.repeat(target_label, bend - bstart)
        num_clean_examples += len(pois_x_batch)

        dict_pois = {
            model.x_input: pois_x_batch,
            model.y_input: pois_y_batch,
            model.is_training: False
        }

        cur_corr_pois, cur_xent_pois, prsof = sess.run(
            [model.num_correct, model.xent, model.pre_softmax],
            feed_dict=dict_pois)
        print(prsof[0:5])
        total_xent_pois += cur_xent_pois
        total_corr_pois += cur_corr_pois

    avg_xent_nat = total_xent_nat / num_eval_examples
    acc_nat = total_corr_nat / num_eval_examples
    avg_xent_pois = total_xent_pois / num_clean_examples
    acc_pois = total_corr_pois / num_clean_examples

    if summary_writer:
        summary = tf.Summary(value=[
            tf.Summary.Value(tag='xent_nat_eval', simple_value=avg_xent_nat),
            tf.Summary.Value(tag='xent_nat', simple_value=avg_xent_nat),
            tf.Summary.Value(tag='accuracy_nat_eval', simple_value=acc_nat),
            tf.Summary.Value(tag='accuracy_nat', simple_value=acc_nat)
        ])
        summary_writer.add_summary(summary, global_step.eval(sess))

    step = global_step.eval(sess)
    print('Eval at step: {}'.format(step))
    print('  natural: {:.2f}%'.format(100 * acc_nat))
    print('  avg nat xent: {:.4f}'.format(avg_xent_nat))
    print('  poisoned: {:.2f}%'.format(100 * acc_pois))
    print('  avg pois xent: {:.4f}'.format(avg_xent_pois))

    result = {
        'nat': '{:.2f}%'.format(100 * acc_nat),
        'pois': '{:.2f}%'.format(100 * acc_pois)
    }
    with open('job_result.json', 'w') as result_file:
        json.dump(result, result_file, sort_keys=True, indent=4)
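
A hedged sketch of how evaluate might be called standalone, assuming the resnet.Model class and config object used in the next example and a checkpoint produced by the training loop; everything here is illustrative rather than part of the original code:

# Hypothetical standalone evaluation driver.
model = resnet.Model(config.model)
global_step = tf.contrib.framework.get_or_create_global_step()
saver = tf.train.Saver()

with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(config.model.output_dir)
    saver.restore(sess, ckpt)  # restores weights and the global step
    eval_writer = tf.summary.FileWriter(
        os.path.join(config.model.output_dir, 'eval'))
    evaluate(model, sess, config, summary_writer=eval_writer)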
Example 3
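# NOTE: this snippet assumes module-level imports (os, json, numpy as np,
# tensorflow as tf, datetime, trange from tqdm, timeit's default_timer as timer,
# tf.contrib.slim as slim) plus the project-local `dataset_input` and `resnet`
# modules and the `evaluate` function from the previous example.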
def wrapper():
    # NOTE: relies on a module-level `config`; the commented-out line switches
    # the data source to Restricted ImageNet.
    return dataset_input.CIFAR10Data(config, seed=config.training.np_random_seed)
    # return dataset_input.RestrictedImagenet(config, seed=config.training.np_random_seed)


def train(config):
    # seeding randomness
    tf.set_random_seed(config.training.tf_random_seed)
    np.random.seed(config.training.np_random_seed)

    # Setting up training parameters
    max_num_training_steps = config.training.max_num_training_steps
    step_size_schedule = config.training.step_size_schedule
    weight_decay = config.training.weight_decay
    momentum = config.training.momentum
    batch_size = config.training.batch_size
    eval_during_training = config.training.eval_during_training
    num_clean_examples = config.training.num_examples
    if eval_during_training:
        num_eval_steps = config.training.num_eval_steps

    # Setting up output parameters
    num_output_steps = config.training.num_output_steps
    num_summary_steps = config.training.num_summary_steps
    num_checkpoint_steps = config.training.num_checkpoint_steps

    # Setting up the data and the model
    dataset = dataset_input.CIFAR10Data(config,
                                        seed=config.training.np_random_seed)
    print('Num Poisoned Left: {}'.format(dataset.num_poisoned_left))
    print('Poison Position: {}'.format(config.data.position))
    print('Poison Color: {}'.format(config.data.color))
    num_training_examples = len(dataset.train_data.xs)
    global_step = tf.contrib.framework.get_or_create_global_step()
    model = resnet.Model(config.model)

    # print a list of trainable variables
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)

    # Setting up the optimizer
    boundaries = [int(sss[0]) for sss in step_size_schedule]
    boundaries = boundaries[1:]
    values = [sss[1] for sss in step_size_schedule]
    learning_rate = tf.train.piecewise_constant(tf.cast(global_step, tf.int32),
                                                boundaries, values)
    total_loss = model.mean_xent + weight_decay * model.weight_decay_loss

    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum)
    train_step = optimizer.minimize(total_loss, global_step=global_step)

    # Setting up the Tensorboard and checkpoint outputs
    model_dir = config.model.output_dir
    if eval_during_training:
        eval_dir = os.path.join(model_dir, 'eval')
        if not os.path.exists(eval_dir):
            os.makedirs(eval_dir)

    # We add accuracy and xent twice so we can easily make three types of
    # comparisons in Tensorboard:
    # - train vs eval (for a single run)
    # - train of different runs
    # - eval of different runs

    saver = tf.train.Saver(max_to_keep=3)

    tf.summary.scalar('accuracy_nat_train',
                      model.accuracy,
                      collections=['nat'])
    tf.summary.scalar('accuracy_nat', model.accuracy, collections=['nat'])
    tf.summary.scalar('xent_nat_train',
                      model.xent / batch_size,
                      collections=['nat'])
    tf.summary.scalar('xent_nat', model.xent / batch_size, collections=['nat'])
    tf.summary.image('images_nat_train', model.train_xs, collections=['nat'])
    tf.summary.scalar('learning_rate', learning_rate, collections=['nat'])
    nat_summaries = tf.summary.merge_all('nat')

    with tf.Session() as sess:
        print('Dataset Size: ', len(dataset.train_data.xs))

        # Initialize the summary writer, global variables, and our time counter.
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)
        if eval_during_training:
            eval_summary_writer = tf.summary.FileWriter(eval_dir)

        sess.run(tf.global_variables_initializer())
        training_time = 0.0

        # Main training loop
        for ii in range(max_num_training_steps + 1):
            x_batch, y_batch = dataset.train_data.get_next_batch(
                batch_size, multiple_passes=True)

            nat_dict = {
                model.x_input: x_batch,
                model.y_input: y_batch,
                model.is_training: False
            }
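
            # Note: is_training stays False for the accuracy/summary passes below;
            # it is flipped to True just before the actual optimizer step.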

            # Output to stdout
            if ii % num_output_steps == 0:
                nat_acc = sess.run(model.accuracy, feed_dict=nat_dict)
                print('Step {}:    ({})'.format(ii, datetime.now()))
                print('    training nat accuracy {:.4}%'.format(nat_acc * 100))
                if ii != 0:
                    print('    {} examples per second'.format(
                        num_output_steps * batch_size / training_time))
                    training_time = 0.0

            # Tensorboard summaries
            if ii % num_summary_steps == 0:
                summary = sess.run(nat_summaries, feed_dict=nat_dict)
                summary_writer.add_summary(summary, global_step.eval(sess))

            # Write a checkpoint
            if ii % num_checkpoint_steps == 0:
                saver.save(sess,
                           os.path.join(model_dir, 'checkpoint'),
                           global_step=global_step)

            if eval_during_training and ii % num_eval_steps == 0:
                evaluate(model, sess, config, eval_summary_writer)

            # Actual training step
            start = timer()
            nat_dict[model.is_training] = True
            sess.run(train_step, feed_dict=nat_dict)
            end = timer()
            training_time += end - start


def compute_corr(config):
    # seeding randomness
    tf.set_random_seed(config.training.tf_random_seed)
    np.random.seed(config.training.np_random_seed)

    # Setting up the data and the model
    poison_eps = config.data.poison_eps
    clean_label = config.data.clean_label
    target_label = config.data.target_label
    dataset = dataset_input.CIFAR10Data(config,
                                        seed=config.training.np_random_seed)
    num_poisoned_left = dataset.num_poisoned_left
    print('Num poisoned left: ', num_poisoned_left)
    num_training_examples = len(dataset.train_data.xs)
    global_step = tf.contrib.framework.get_or_create_global_step()
    model = resnet.Model(config.model)


    # Setting up the Tensorboard and checkpoint outputs
    model_dir = config.model.output_dir

    saver = tf.train.Saver(max_to_keep=3)

    with tf.Session() as sess:

        # initialize data augmentation
        print('Dataset Size: ', len(dataset.train_data.xs))

        sess.run(tf.global_variables_initializer())

        latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        if latest_checkpoint is not None:
            saver.restore(sess, latest_checkpoint)
            print('Restoring last saved checkpoint: ', latest_checkpoint)
        else:
            print('Check model directory')
            exit()

        lbl = target_label
        cur_indices = np.where(dataset.train_data.ys == lbl)[0]
        cur_examples = len(cur_indices)
        print('Label, num ex: ', lbl, cur_examples)
        cur_op = model.representation
        for iex in trange(cur_examples):
            cur_im = cur_indices[iex]
            x_batch = dataset.train_data.xs[cur_im:cur_im+1,:]
            y_batch = dataset.train_data.ys[cur_im:cur_im+1]

            dict_nat = {model.x_input: x_batch,
                        model.y_input: y_batch,
                        model.is_training: False}

            batch_grads = sess.run(cur_op, feed_dict=dict_nat)
            if iex == 0:
                clean_cov = np.zeros(shape=(cur_examples - num_poisoned_left, len(batch_grads)))
                full_cov = np.zeros(shape=(cur_examples, len(batch_grads)))
            if iex < (cur_examples - num_poisoned_left):
                clean_cov[iex] = batch_grads
            full_cov[iex] = batch_grads

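        # Spectral-signature style scoring: center the representations, take the
        # top right-singular vector of the centered (clean + poisoned)
        # representation matrix, score every example by the magnitude of its
        # correlation with that direction, and flag everything above the chosen
        # percentile -- poisoned examples tend to be outliers along it.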
        #np.save(corr_dir+str(lbl)+'_full_cov.npy', full_cov)
        total_p = config.data.percentile            
        clean_mean = np.mean(clean_cov, axis=0, keepdims=True)
        full_mean = np.mean(full_cov, axis=0, keepdims=True)            

        print('Norm of Difference in Mean: ', np.linalg.norm(clean_mean-full_mean))
        clean_centered_cov = clean_cov - clean_mean
        s_clean = np.linalg.svd(clean_centered_cov, full_matrices=False, compute_uv=False)
        print('Top 7 Clean SVs: ', s_clean[0:7])
        
        centered_cov = full_cov - full_mean
        u, s, v = np.linalg.svd(centered_cov, full_matrices=False)
        print('Top 7 Singular Values: ', s[0:7])
        eigs = v[0:1]
        p = total_p
        corrs = np.matmul(eigs, np.transpose(full_cov))  # shape: (num_top, num_active_indices)
        scores = np.linalg.norm(corrs, axis=0)  # shape: (num_active_indices,)
        np.save(os.path.join(model_dir, 'scores.npy'), scores)
        print('Length Scores: ', len(scores))
        p_score = np.percentile(scores, p)
        top_scores = np.where(scores > p_score)[0]
        print(top_scores)
        num_bad_removed = np.count_nonzero(top_scores >= (len(scores) - num_poisoned_left))
        print('Num Bad Removed: ', num_bad_removed)
        print('Num Good Removed: ', len(top_scores) - num_bad_removed)
        
        num_poisoned_after = num_poisoned_left - num_bad_removed
        removed_inds = np.copy(top_scores)
        
        removed_inds_file = os.path.join(model_dir, 'removed_inds.npy')
        np.save(removed_inds_file, cur_indices[removed_inds])        
        print('Num Poisoned Left: ', num_poisoned_after)    

        if os.path.exists('job_result.json'):
            with open('job_result.json') as result_file:
                result = json.load(result_file)
                result['num_poisoned_left'] = '{}'.format(num_poisoned_after)
        else:
            result = {'num_poisoned_left': '{}'.format(num_poisoned_after)}
        with open('job_result.json', 'w') as result_file:
            json.dump(result, result_file, sort_keys=True, indent=4)
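
A minimal sketch of how these pieces might be chained from a single config, assuming the `utilities` helpers shown in the first example; the file name is illustrative:

if __name__ == '__main__':
    # Hypothetical driver: train on the (possibly poisoned) dataset, then score
    # training examples and flag suspicious ones. Config loading mirrors Example 1.
    config = utilities.config_to_namedtuple(utilities.get_config('config.json'))
    train(config)
    tf.reset_default_graph()  # compute_corr builds a fresh graph
    compute_corr(config)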