def train(tf_seed, np_seed, train_steps, out_steps, summary_steps,
          checkpoint_steps, step_size_schedule, weight_decay, momentum,
          train_batch_size, epsilon, replay_m, model_dir, dataset,
          poison_alpha, poison_config, **kwargs):
    """'Free' adversarial training on CIFAR-10/100 (optionally poisoned).

    Each mini-batch is replayed `replay_m` times; every step simultaneously
    updates the model weights (momentum SGD) and a shared adversarial
    perturbation `model.pert` (one sign-gradient ascent step, clipped to the
    L-inf epsilon ball).

    Args:
        tf_seed, np_seed: RNG seeds for TensorFlow and NumPy.
        train_steps: total number of optimisation steps.
        out_steps: period (steps) of plain stdout accuracy reports.
        summary_steps: period (steps) of Tensorboard summary writes.
        checkpoint_steps: period (steps) of checkpoint saves.
        step_size_schedule: iterable of (step, lr) pairs; the first entry
            supplies the initial learning rate.
        weight_decay: L2 regularisation coefficient.
        momentum: momentum coefficient for the optimizer.
        train_batch_size: mini-batch size for both train and eval batches.
        epsilon: L-inf perturbation radius (also the ascent step size).
        replay_m: number of consecutive steps each mini-batch is reused.
        model_dir: output directory prefix; a parameter-derived suffix is
            appended.
        dataset: 'cifar10', 'cifar10_poisoned', or anything else for CIFAR-100.
        poison_alpha: poison blending factor (only printed here).
        poison_config: path to the poisoning config file.
        **kwargs: forwarded to get_path_dir().
    """
    tf.compat.v1.set_random_seed(tf_seed)
    np.random.seed(np_seed)
    print('poison alpha = %f' % poison_alpha)
    model_dir = model_dir + '%s_m%d_eps%.1f_b%d' % (
        dataset, replay_m, epsilon, train_batch_size
    )  # TODO Replace with not defaults

    # Setting up the data and the model
    poison_config_dict = utilities.config_to_namedtuple(
        utilities.get_config(poison_config))
    print(poison_config_dict)
    data_path = get_path_dir(dataset=dataset, **kwargs)
    if dataset == 'cifar10':
        raw_data = cifar10_input.CIFAR10Data(data_path)
    elif dataset == 'cifar10_poisoned':
        raw_data = dataset_input.CIFAR10Data(poison_config_dict, seed=np_seed)
    else:
        raw_data = cifar100_input.CIFAR100Data(data_path)
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tpu_strategy.scope():
        model = Model(mode='train', dataset=dataset,
                      train_batch_size=train_batch_size)

    # Setting up the optimizer: piecewise-constant learning rate. The first
    # schedule entry gives the initial rate, so its boundary is dropped.
    boundaries = [int(sss[0]) for sss in step_size_schedule][1:]
    values = [sss[1] for sss in step_size_schedule]
    learning_rate = tf.compat.v1.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values)
    optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate, momentum)

    # Optimizing computation
    total_loss = model.mean_xent + weight_decay * model.weight_decay_loss
    grads = optimizer.compute_gradients(total_loss)

    # Compute new perturbation: FGSM-style ascent on the shared perturbation
    # variable, clipped back into the epsilon ball.
    pert_grad = [g for g, v in grads if 'perturbation' in v.name]
    sign_pert_grad = tf.sign(pert_grad[0])
    new_pert = model.pert + epsilon * sign_pert_grad
    clip_new_pert = tf.clip_by_value(new_pert, -epsilon, epsilon)
    assigned = tf.compat.v1.assign(model.pert, clip_new_pert)

    # Weight update: zero out the perturbation's gradient so the optimizer
    # only touches model weights; the perturbation assign must run first.
    no_pert_grad = [(tf.zeros_like(v), v) if 'perturbation' in v.name
                    else (g, v) for g, v in grads]
    with tf.control_dependencies([assigned]):
        min_step = optimizer.apply_gradients(no_pert_grad,
                                             global_step=global_step)
    # NOTE: a dead `tf.compat.v1.initialize_variables([model.pert])` call was
    # removed here — its returned op was never run, the API has been removed
    # from TF (the original TODO said so), and model.pert is initialized by
    # global_variables_initializer() below anyway.

    # Setting up the Tensorboard and checkpoint outputs
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    saver = tf.compat.v1.train.Saver(max_to_keep=1)
    tf.compat.v1.summary.scalar('accuracy', model.accuracy)
    tf.compat.v1.summary.scalar('xent', model.xent / train_batch_size)
    tf.compat.v1.summary.scalar('total loss', total_loss / train_batch_size)
    merged_summaries = tf.compat.v1.summary.merge_all()

    gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=1.0)
    with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
            gpu_options=gpu_options)) as sess:
        print(
            '\n\n********** free training for epsilon=%.1f using m_replay=%d **********\n\n'
            % (epsilon, replay_m))
        print(
            'important params >>> \n model dir: %s \n dataset: %s \n training batch size: %d \n'
            % (model_dir, dataset, train_batch_size))
        if dataset == 'cifar100':
            # (typo "trainnig" fixed in the message below)
            print('the ride for CIFAR100 is bumpy -- fasten your seatbelts! \n'
                  'you will probably see the training and validation accuracy '
                  'fluctuating a lot early in training \n'
                  'this is natural especially for large replay_m values '
                  'because we see that mini-batch so many times.')

        # initialize data augmentation (the poisoned dataset ships its own)
        if dataset == 'cifar10':
            data = cifar10_input.AugmentedCIFAR10Data(raw_data, sess, model)
        elif dataset == 'cifar10_poisoned':
            data = raw_data
        else:
            data = cifar100_input.AugmentedCIFAR100Data(raw_data, sess, model)

        # Initialize the summary writer, global variables, and our time counter.
        summary_writer = tf.compat.v1.summary.FileWriter(
            model_dir + '/train', sess.graph)
        eval_summary_writer = tf.compat.v1.summary.FileWriter(model_dir + '/eval')
        sess.run(tf.compat.v1.global_variables_initializer())

        # Main training loop
        for ii in range(train_steps):
            # Fetch fresh train/eval batches only every replay_m steps
            # ("free" training replays the same batch in between).
            if ii % replay_m == 0:
                x_batch, y_batch = data.train_data.get_next_batch(
                    train_batch_size, multiple_passes=True)
                nat_dict = {model.x_input: x_batch, model.y_input: y_batch}
                x_eval_batch, y_eval_batch = data.eval_data.get_next_batch(
                    train_batch_size, multiple_passes=True)
                eval_dict = {
                    model.x_input: x_eval_batch,
                    model.y_input: y_eval_batch
                }

            # Tensorboard summaries (train + eval) plus a progress report.
            # (The original comments on these two branches were swapped.)
            if ii % summary_steps == 0:
                train_acc, summary = sess.run(
                    [model.accuracy, merged_summaries], feed_dict=nat_dict)
                summary_writer.add_summary(summary, global_step.eval(sess))
                val_acc, summary = sess.run([model.accuracy, merged_summaries],
                                            feed_dict=eval_dict)
                eval_summary_writer.add_summary(summary, global_step.eval(sess))
                print('Step {}: ({})'.format(ii, datetime.now()))
                print(
                    ' training nat accuracy {:.4}% -- validation nat accuracy {:.4}%'
                    .format(train_acc * 100, val_acc * 100))
                sys.stdout.flush()
            # Output to stdout only
            elif ii % out_steps == 0:
                nat_acc = sess.run(model.accuracy, feed_dict=nat_dict)
                print('Step {}: ({})'.format(ii, datetime.now()))
                print(' training nat accuracy {:.4}%'.format(nat_acc * 100))

            # Write a checkpoint
            if (ii + 1) % checkpoint_steps == 0:
                saver.save(sess, os.path.join(model_dir, 'checkpoint'),
                           global_step=global_step)

            # Actual training step (updates weights AND perturbation)
            sess.run(min_step, feed_dict=nat_dict)
def evaluate(model, sess, config, summary_writer=None):
    """Evaluate `model` on the clean and poisoned eval splits.

    Runs batched inference with `sess`, accumulating cross-entropy and
    correct-prediction counts for (a) the natural eval set and (b) the
    poisoned eval set relabelled to the attack target. Prints both
    accuracies, optionally writes natural-eval summaries, and dumps the
    result to 'job_result.json' in the working directory.

    Args:
        model: graph object exposing x_input, y_input, is_training,
            num_correct, xent and pre_softmax tensors.
        sess: active tf.Session with the model's variables restored.
        config: namespace providing eval.batch_size, model.output_dir,
            training.np_random_seed and data.{poison_method, clean_label,
            target_label, position, color}.
        summary_writer: optional tf.summary.FileWriter for eval summaries.
    """
    eval_batch_size = config.eval.batch_size
    model_dir = config.model.output_dir
    # Setting up the Tensorboard and checkpoint outputs
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    poison_method = config.data.poison_method
    clean_label = config.data.clean_label
    target_label = config.data.target_label
    position = config.data.position
    color = config.data.color
    dataset = dataset_input.CIFAR10Data(config,
                                        seed=config.training.np_random_seed)
    print(poison_method, clean_label, target_label, position, color)
    global_step = tf.contrib.framework.get_or_create_global_step()
    # Iterate over the samples batch-by-batch
    num_eval_examples = len(dataset.eval_data.xs)
    num_clean_examples = 0  # count of poisoned examples actually evaluated
    num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
    total_xent_nat = 0.
    total_corr_nat = 0
    total_xent_pois = 0.
    total_corr_pois = 0
    for ibatch in trange(num_batches):
        bstart = ibatch * eval_batch_size
        bend = min(bstart + eval_batch_size, num_eval_examples)
        x_batch = dataset.eval_data.xs[bstart:bend, :]
        y_batch = dataset.eval_data.ys[bstart:bend]
        pois_x_batch = dataset.poisoned_eval_data.xs[bstart:bend, :]
        pois_y_batch = dataset.poisoned_eval_data.ys[bstart:bend]
        dict_nat = {
            model.x_input: x_batch,
            model.y_input: y_batch,
            model.is_training: False
        }
        cur_corr_nat, cur_xent_nat = sess.run([model.num_correct, model.xent],
                                              feed_dict=dict_nat)
        total_xent_nat += cur_xent_nat
        total_corr_nat += cur_corr_nat
        if clean_label > -1:
            # Clean-label attack: only images whose true class is
            # `clean_label` carry the trigger, so restrict the poisoned
            # batch to them and relabel with the attack target.
            clean_indices = np.where(y_batch == clean_label)[0]
            if len(clean_indices) == 0:
                continue  # no eligible images in this batch
            pois_x_batch = pois_x_batch[clean_indices]
            pois_y_batch = np.repeat(target_label, len(clean_indices))
        else:
            # Every image is poisoned; relabel the whole batch to the target.
            pois_y_batch = np.repeat(target_label, bend - bstart)
        num_clean_examples += len(pois_x_batch)
        dict_pois = {
            model.x_input: pois_x_batch,
            model.y_input: pois_y_batch,
            model.is_training: False
        }
        cur_corr_pois, cur_xent_pois, prsof = sess.run(
            [model.num_correct, model.xent, model.pre_softmax],
            feed_dict=dict_pois)
        print(prsof[0:5])  # debug: logits of the first few poisoned examples
        total_xent_pois += cur_xent_pois
        total_corr_pois += cur_corr_pois
    avg_xent_nat = total_xent_nat / num_eval_examples
    acc_nat = total_corr_nat / num_eval_examples
    # Poisoned metrics are normalised by the number of examples actually fed,
    # which can be far fewer than num_eval_examples in the clean-label case.
    avg_xent_pois = total_xent_pois / num_clean_examples
    acc_pois = total_corr_pois / num_clean_examples
    if summary_writer:
        # Each value is written under two tags so Tensorboard can overlay
        # train-vs-eval curves as well as compare runs.
        summary = tf.Summary(value=[
            tf.Summary.Value(tag='xent_nat_eval', simple_value=avg_xent_nat),
            tf.Summary.Value(tag='xent_nat', simple_value=avg_xent_nat),
            tf.Summary.Value(tag='accuracy_nat_eval', simple_value=acc_nat),
            tf.Summary.Value(tag='accuracy_nat', simple_value=acc_nat)
        ])
        summary_writer.add_summary(summary, global_step.eval(sess))
    step = global_step.eval(sess)
    print('Eval at step: {}'.format(step))
    print(' natural: {:.2f}%'.format(100 * acc_nat))
    print(' avg nat xent: {:.4f}'.format(avg_xent_nat))
    print(' poisoned: {:.2f}%'.format(100 * acc_pois))
    print(' avg pois xent: {:.4f}'.format(avg_xent_pois))
    result = {
        'nat': '{:.2f}%'.format(100 * acc_nat),
        'pois': '{:.2f}%'.format(100 * acc_pois)
    }
    with open('job_result.json', 'w') as result_file:
        json.dump(result, result_file, sort_keys=True, indent=4)
def wrapper():
    """Dataset factory: build the (possibly poisoned) CIFAR-10 dataset.

    NOTE(review): references a free variable `config` — presumably a
    module-level config or an enclosing scope; confirm where this is defined.
    """
    return dataset_input.CIFAR10Data(config, seed=config.training.np_random_seed)
    #return dataset_input.RestrictedImagenet(config, seed=config.training.np_random_seed)
def train(config):
    """Train the ResNet on (possibly poisoned) CIFAR-10 per `config`.

    Builds the TF1 graph, then runs the training loop with periodic stdout
    reports, Tensorboard summaries, checkpoints and (optionally) calls to
    evaluate().

    Args:
        config: namespace with `training`, `model` and `data` sections
            (seeds, schedule, batch size, output dir, poison parameters).
    """
    # seeding randomness
    tf.set_random_seed(config.training.tf_random_seed)
    np.random.seed(config.training.np_random_seed)

    # Setting up training parameters
    max_num_training_steps = config.training.max_num_training_steps
    step_size_schedule = config.training.step_size_schedule
    weight_decay = config.training.weight_decay
    momentum = config.training.momentum
    batch_size = config.training.batch_size
    eval_during_training = config.training.eval_during_training
    if eval_during_training:
        num_eval_steps = config.training.num_eval_steps
    # NOTE: removed unused locals `num_clean_examples` and
    # `num_training_examples` from the original.

    # Setting up output parameters
    num_output_steps = config.training.num_output_steps
    num_summary_steps = config.training.num_summary_steps
    num_checkpoint_steps = config.training.num_checkpoint_steps

    # Setting up the data and the model
    dataset = dataset_input.CIFAR10Data(config,
                                        seed=config.training.np_random_seed)
    print('Num Poisoned Left: {}'.format(dataset.num_poisoned_left))
    print('Poison Position: {}'.format(config.data.position))
    print('Poison Color: {}'.format(config.data.color))
    global_step = tf.contrib.framework.get_or_create_global_step()
    model = resnet.Model(config.model)

    # Print a summary of the trainable variables (the original comment said
    # "uncomment to get ..." but this code has always been active).
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)

    # Setting up the optimizer: piecewise-constant LR schedule. The first
    # schedule entry gives the initial rate, so its boundary is dropped.
    boundaries = [int(sss[0]) for sss in step_size_schedule][1:]
    values = [sss[1] for sss in step_size_schedule]
    learning_rate = tf.train.piecewise_constant(tf.cast(global_step, tf.int32),
                                                boundaries, values)
    total_loss = model.mean_xent + weight_decay * model.weight_decay_loss
    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum)
    train_step = optimizer.minimize(total_loss, global_step=global_step)

    # Setting up the Tensorboard and checkpoint outputs
    model_dir = config.model.output_dir
    if eval_during_training:
        eval_dir = os.path.join(model_dir, 'eval')
        if not os.path.exists(eval_dir):
            os.makedirs(eval_dir)

    # We add accuracy and xent twice so we can easily make three types of
    # comparisons in Tensorboard:
    # - train vs eval (for a single run)
    # - train of different runs
    # - eval of different runs
    saver = tf.train.Saver(max_to_keep=3)
    tf.summary.scalar('accuracy_nat_train', model.accuracy, collections=['nat'])
    tf.summary.scalar('accuracy_nat', model.accuracy, collections=['nat'])
    tf.summary.scalar('xent_nat_train', model.xent / batch_size,
                      collections=['nat'])
    tf.summary.scalar('xent_nat', model.xent / batch_size, collections=['nat'])
    tf.summary.image('images_nat_train', model.train_xs, collections=['nat'])
    tf.summary.scalar('learning_rate', learning_rate, collections=['nat'])
    nat_summaries = tf.summary.merge_all('nat')

    with tf.Session() as sess:
        print('Dataset Size: ', len(dataset.train_data.xs))

        # Initialize the summary writer, global variables, and our time counter.
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)
        if eval_during_training:
            eval_summary_writer = tf.summary.FileWriter(eval_dir)
        sess.run(tf.global_variables_initializer())
        training_time = 0.0

        # Main training loop
        for ii in range(max_num_training_steps + 1):
            x_batch, y_batch = dataset.train_data.get_next_batch(
                batch_size, multiple_passes=True)
            # is_training=False so the periodic accuracy/summary reads are
            # deterministic; flipped to True just before the update step.
            nat_dict = {
                model.x_input: x_batch,
                model.y_input: y_batch,
                model.is_training: False
            }

            # Output to stdout
            if ii % num_output_steps == 0:
                nat_acc = sess.run(model.accuracy, feed_dict=nat_dict)
                print('Step {}: ({})'.format(ii, datetime.now()))
                print(' training nat accuracy {:.4}%'.format(nat_acc * 100))
                if ii != 0:
                    print(' {} examples per second'.format(
                        num_output_steps * batch_size / training_time))
                    training_time = 0.0

            # Tensorboard summaries
            if ii % num_summary_steps == 0:
                summary = sess.run(nat_summaries, feed_dict=nat_dict)
                summary_writer.add_summary(summary, global_step.eval(sess))

            # Write a checkpoint
            if ii % num_checkpoint_steps == 0:
                saver.save(sess, os.path.join(model_dir, 'checkpoint'),
                           global_step=global_step)

            if eval_during_training and ii % num_eval_steps == 0:
                evaluate(model, sess, config, eval_summary_writer)

            # Actual training step (timed for the examples-per-second report)
            start = timer()
            nat_dict[model.is_training] = True
            sess.run(train_step, feed_dict=nat_dict)
            end = timer()
            training_time += end - start
def compute_corr(config):
    """Spectral-signature correlation analysis for backdoor detection.

    Restores the latest checkpoint, computes the model's representation for
    every training example of the attack target class, centres the
    representations and projects them on the top singular vector. Examples
    whose projection norm exceeds the `config.data.percentile` percentile are
    flagged as suspected poisons. Saves the outlier scores and removed
    indices under the model directory and records the number of poisons
    remaining in 'job_result.json'.

    Args:
        config: namespace with `training`, `model` and `data` sections; a
            trained checkpoint must exist in config.model.output_dir.
    """
    # seeding randomness
    tf.set_random_seed(config.training.tf_random_seed)
    np.random.seed(config.training.np_random_seed)

    # Setting up the data and the model.
    # NOTE: removed unused locals `poison_eps`, `clean_label` and
    # `num_training_examples` from the original.
    target_label = config.data.target_label
    dataset = dataset_input.CIFAR10Data(config,
                                        seed=config.training.np_random_seed)
    num_poisoned_left = dataset.num_poisoned_left
    print('Num poisoned left: ', num_poisoned_left)
    global_step = tf.contrib.framework.get_or_create_global_step()
    model = resnet.Model(config.model)

    # Setting up the Tensorboard and checkpoint outputs
    model_dir = config.model.output_dir
    saver = tf.train.Saver(max_to_keep=3)

    with tf.Session() as sess:
        print('Dataset Size: ', len(dataset.train_data.xs))
        sess.run(tf.global_variables_initializer())
        latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        if latest_checkpoint is not None:
            saver.restore(sess, latest_checkpoint)
            print('Restoring last saved checkpoint: ', latest_checkpoint)
        else:
            print('Check model directory')
            # was bare exit() — sys.exit is the reliable form outside the REPL
            sys.exit(1)

        lbl = target_label
        cur_indices = np.where(dataset.train_data.ys == lbl)[0]
        cur_examples = len(cur_indices)
        print('Label, num ex: ', lbl, cur_examples)
        cur_op = model.representation

        # Collect the representation of every target-class example, one at a
        # time. NOTE(review): the clean/full split assumes the poisoned
        # examples occupy the LAST `num_poisoned_left` positions of
        # `cur_indices` — confirm against dataset_input's construction order.
        for iex in trange(cur_examples):
            cur_im = cur_indices[iex]
            x_batch = dataset.train_data.xs[cur_im:cur_im + 1, :]
            y_batch = dataset.train_data.ys[cur_im:cur_im + 1]
            dict_nat = {
                model.x_input: x_batch,
                model.y_input: y_batch,
                model.is_training: False
            }
            batch_grads = sess.run(cur_op, feed_dict=dict_nat)
            if iex == 0:
                # Allocate once we know the representation size.
                clean_cov = np.zeros(shape=(cur_examples - num_poisoned_left,
                                            len(batch_grads)))
                full_cov = np.zeros(shape=(cur_examples, len(batch_grads)))
            if iex < (cur_examples - num_poisoned_left):
                clean_cov[iex] = batch_grads
            full_cov[iex] = batch_grads

        total_p = config.data.percentile
        clean_mean = np.mean(clean_cov, axis=0, keepdims=True)
        full_mean = np.mean(full_cov, axis=0, keepdims=True)
        print('Norm of Difference in Mean: ',
              np.linalg.norm(clean_mean - full_mean))
        clean_centered_cov = clean_cov - clean_mean
        s_clean = np.linalg.svd(clean_centered_cov, full_matrices=False,
                                compute_uv=False)
        print('Top 7 Clean SVs: ', s_clean[0:7])
        centered_cov = full_cov - full_mean
        u, s, v = np.linalg.svd(centered_cov, full_matrices=False)
        print('Top 7 Singular Values: ', s[0:7])
        eigs = v[0:1]  # top right-singular vector of the centred matrix

        # Outlier score: norm of each example's projection on the top
        # singular direction(s).
        corrs = np.matmul(eigs, np.transpose(full_cov))  # shape num_top, num_active_indices
        scores = np.linalg.norm(corrs, axis=0)  # shape num_active_indices
        np.save(os.path.join(model_dir, 'scores.npy'), scores)
        print('Length Scores: ', len(scores))
        p_score = np.percentile(scores, total_p)
        top_scores = np.where(scores > p_score)[0]
        print(top_scores)
        # Poisons sit at the tail of cur_indices, so indices past this cut
        # count as correctly removed.
        num_bad_removed = np.count_nonzero(
            top_scores >= (len(scores) - num_poisoned_left))
        print('Num Bad Removed: ', num_bad_removed)
        # (typo "Rmoved" fixed in the message below)
        print('Num Good Removed: ', len(top_scores) - num_bad_removed)
        num_poisoned_after = num_poisoned_left - num_bad_removed

        removed_inds = np.copy(top_scores)
        removed_inds_file = os.path.join(model_dir, 'removed_inds.npy')
        np.save(removed_inds_file, cur_indices[removed_inds])
        print('Num Poisoned Left: ', num_poisoned_after)

        # Record the result, merging into any existing job_result.json.
        if os.path.exists('job_result.json'):
            with open('job_result.json') as result_file:
                result = json.load(result_file)
            result['num_poisoned_left'] = '{}'.format(num_poisoned_after)
        else:
            result = {'num_poisoned_left': '{}'.format(num_poisoned_after)}
        with open('job_result.json', 'w') as result_file:
            json.dump(result, result_file, sort_keys=True, indent=4)